diff options
Diffstat (limited to 'docker/transport')
-rw-r--r-- | docker/transport/sshconn.py | 111 |
1 files changed, 58 insertions, 53 deletions
diff --git a/docker/transport/sshconn.py b/docker/transport/sshconn.py index cdeeae4..5cdaa27 100644 --- a/docker/transport/sshconn.py +++ b/docker/transport/sshconn.py @@ -1,9 +1,9 @@ -import io import paramiko import requests.adapters import six import logging import os +import signal import socket import subprocess @@ -23,40 +23,6 @@ except ImportError: RecentlyUsedContainer = urllib3._collections.RecentlyUsedContainer -def create_paramiko_client(base_url): - logging.getLogger("paramiko").setLevel(logging.WARNING) - ssh_client = paramiko.SSHClient() - base_url = six.moves.urllib_parse.urlparse(base_url) - ssh_params = { - "hostname": base_url.hostname, - "port": base_url.port, - "username": base_url.username - } - ssh_config_file = os.path.expanduser("~/.ssh/config") - if os.path.exists(ssh_config_file): - conf = paramiko.SSHConfig() - with open(ssh_config_file) as f: - conf.parse(f) - host_config = conf.lookup(base_url.hostname) - ssh_conf = host_config - if 'proxycommand' in host_config: - ssh_params["sock"] = paramiko.ProxyCommand( - ssh_conf['proxycommand'] - ) - if 'hostname' in host_config: - ssh_params['hostname'] = host_config['hostname'] - if 'identityfile' in host_config: - ssh_params['key_filename'] = host_config['identityfile'] - if base_url.port is None and 'port' in host_config: - ssh_params['port'] = ssh_conf['port'] - if base_url.username is None and 'user' in host_config: - ssh_params['username'] = ssh_conf['user'] - - ssh_client.load_system_host_keys() - ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy()) - return ssh_client, ssh_params - - class SSHSocket(socket.socket): def __init__(self, host): super(SSHSocket, self).__init__( @@ -80,7 +46,8 @@ class SSHSocket(socket.socket): ' '.join(args), shell=True, stdout=subprocess.PIPE, - stdin=subprocess.PIPE) + stdin=subprocess.PIPE, + preexec_fn=lambda: signal.signal(signal.SIGINT, signal.SIG_IGN)) def _write(self, data): if not self.proc or self.proc.stdin.closed: @@ -96,17 +63,18 @@ class SSHSocket(socket.socket): def send(self, data): return self._write(data) - def recv(self): + def recv(self, n): if not self.proc: raise Exception('SSH subprocess not initiated.' 'connect() must be called first.') - return self.proc.stdout.read() + return self.proc.stdout.read(n) def makefile(self, mode): - if not self.proc or self.proc.stdout.closed: - buf = io.BytesIO() - buf.write(b'\n\n') - return buf + if not self.proc: + self.connect() + if six.PY3: + self.proc.stdout.channel = self + return self.proc.stdout def close(self): @@ -124,7 +92,7 @@ class SSHConnection(httplib.HTTPConnection, object): ) self.ssh_transport = ssh_transport self.timeout = timeout - self.host = host + self.ssh_host = host def connect(self): if self.ssh_transport: @@ -132,7 +100,7 @@ class SSHConnection(httplib.HTTPConnection, object): sock.settimeout(self.timeout) sock.exec_command('docker system dial-stdio') else: - sock = SSHSocket(self.host) + sock = SSHSocket(self.ssh_host) sock.settimeout(self.timeout) sock.connect() @@ -147,16 +115,16 @@ class SSHConnectionPool(urllib3.connectionpool.HTTPConnectionPool): 'localhost', timeout=timeout, maxsize=maxsize ) self.ssh_transport = None + self.timeout = timeout if ssh_client: self.ssh_transport = ssh_client.get_transport() - self.timeout = timeout - self.host = host - self.port = None + self.ssh_host = host + self.ssh_port = None if ':' in host: - self.host, self.port = host.split(':') + self.ssh_host, self.ssh_port = host.split(':') def _new_conn(self): - return SSHConnection(self.ssh_transport, self.timeout, self.host) + return SSHConnection(self.ssh_transport, self.timeout, self.ssh_host) # When re-using connections, urllib3 calls fileno() on our # SSH channel instance, quickly overloading our fd limit. To avoid this, @@ -193,10 +161,10 @@ class SSHHTTPAdapter(BaseHTTPAdapter): shell_out=True): self.ssh_client = None if not shell_out: - self.ssh_client, self.ssh_params = create_paramiko_client(base_url) + self._create_paramiko_client(base_url) self._connect() - base_url = base_url.lstrip('ssh://') - self.host = base_url + + self.ssh_host = base_url.lstrip('ssh://') self.timeout = timeout self.max_pool_size = max_pool_size self.pools = RecentlyUsedContainer( @@ -204,11 +172,48 @@ class SSHHTTPAdapter(BaseHTTPAdapter): ) super(SSHHTTPAdapter, self).__init__() + def _create_paramiko_client(self, base_url): + logging.getLogger("paramiko").setLevel(logging.WARNING) + self.ssh_client = paramiko.SSHClient() + base_url = six.moves.urllib_parse.urlparse(base_url) + self.ssh_params = { + "hostname": base_url.hostname, + "port": base_url.port, + "username": base_url.username + } + ssh_config_file = os.path.expanduser("~/.ssh/config") + if os.path.exists(ssh_config_file): + conf = paramiko.SSHConfig() + with open(ssh_config_file) as f: + conf.parse(f) + host_config = conf.lookup(base_url.hostname) + self.ssh_conf = host_config + if 'proxycommand' in host_config: + self.ssh_params["sock"] = paramiko.ProxyCommand( + self.ssh_conf['proxycommand'] + ) + if 'hostname' in host_config: + self.ssh_params['hostname'] = host_config['hostname'] + if base_url.port is None and 'port' in host_config: + self.ssh_params['port'] = self.ssh_conf['port'] + if base_url.username is None and 'user' in host_config: + self.ssh_params['username'] = self.ssh_conf['user'] + + self.ssh_client.load_system_host_keys() + self.ssh_client.set_missing_host_key_policy(paramiko.WarningPolicy()) + def _connect(self): if self.ssh_client: self.ssh_client.connect(**self.ssh_params) def get_connection(self, url, proxies=None): + if not self.ssh_client: + return SSHConnectionPool( + ssh_client=self.ssh_client, + timeout=self.timeout, + maxsize=self.max_pool_size, + host=self.ssh_host + ) with self.pools.lock: pool = self.pools.get(url) if pool: @@ -222,7 +227,7 @@ class SSHHTTPAdapter(BaseHTTPAdapter): ssh_client=self.ssh_client, timeout=self.timeout, maxsize=self.max_pool_size, - host=self.host + host=self.ssh_host ) self.pools[url] = pool |