summaryrefslogtreecommitdiff
path: root/lib/py/src
diff options
context:
space:
mode:
authorBryan Duxbury <bryanduxbury@apache.org>2011-03-21 17:59:49 +0000
committerBryan Duxbury <bryanduxbury@apache.org>2011-03-21 17:59:49 +0000
commit5040911bfab39b5c9f2a0d715cea0ee9012f7450 (patch)
tree55d5743795503b8df6c7a66576e473ed0743d34a /lib/py/src
parent59d4efda817f73eb195f47ff9f04cb0f4ec47525 (diff)
downloadthrift-5040911bfab39b5c9f2a0d715cea0ee9012f7450.tar.gz
THRIFT-1100. py: python TSSLSocket improvements, including certificate validation
This patch adds a number of features to TSSLSocket and TSSLServerSocket. Patch: Will Pierce git-svn-id: https://svn.apache.org/repos/asf/thrift/trunk@1083880 13f79535-47bb-0310-9956-ffa450edef68
Diffstat (limited to 'lib/py/src')
-rw-r--r--lib/py/src/transport/TSSLSocket.py186
-rw-r--r--lib/py/src/transport/TSocket.py4
2 files changed, 159 insertions, 31 deletions
diff --git a/lib/py/src/transport/TSSLSocket.py b/lib/py/src/transport/TSSLSocket.py
index 8ab91ca61..5eff5e619 100644
--- a/lib/py/src/transport/TSSLSocket.py
+++ b/lib/py/src/transport/TSSLSocket.py
@@ -1,38 +1,166 @@
-import sys
-sys.path.append('/usr/lib/python2.6/site-packages/')
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+#
+import os
+import socket
+import ssl
from thrift.transport import TSocket
-import socket, ssl
+from thrift.transport.TTransport import TTransportException
class TSSLSocket(TSocket.TSocket):
- def open(self):
+ """
+ SSL implementation of client-side TSocket
+
+ This class creates outbound sockets wrapped using the
+ python standard ssl module for encrypted connections.
+
+ The protocol used is set using the class variable
+ SSL_VERSION, which must be one of ssl.PROTOCOL_* and
+ defaults to ssl.PROTOCOL_TLSv1 for greatest security.
+ """
+ SSL_VERSION = ssl.PROTOCOL_TLSv1
+
+ def __init__(self, validate=True, ca_certs=None, *args, **kwargs):
+ """
+ @param validate: Set to False to disable SSL certificate validation entirely.
+ @type validate: bool
+ @param ca_certs: Filename to the Certificate Authority pem file, possibly a
+ file downloaded from: http://curl.haxx.se/ca/cacert.pem This is passed to
+ the ssl_wrap function as the 'ca_certs' parameter.
+ @type ca_certs: str
+
+ Raises an IOError exception if validate is True and the ca_certs file is
+ None, not present or unreadable.
+ """
+ self.validate = validate
+ self.is_valid = False
+ self.peercert = None
+ if not validate:
+ self.cert_reqs = ssl.CERT_NONE
+ else:
+ self.cert_reqs = ssl.CERT_REQUIRED
+ self.ca_certs = ca_certs
+ if validate and ca_certs is not None:
+ if not os.access(ca_certs, os.R_OK):
+ raise IOError('Certificate Authority ca_certs file "%s" is not readable, cannot validate SSL certificates.' % (ca_certs))
+ TSocket.TSocket.__init__(self, *args, **kwargs)
+
+ def open(self):
+ try:
+ res0 = self._resolveAddr()
+ for res in res0:
+ sock_family, sock_type= res[0:2]
+ ip_port = res[4]
+ plain_sock = socket.socket(sock_family, sock_type)
+ self.handle = ssl.wrap_socket(plain_sock, ssl_version=self.SSL_VERSION,
+ do_handshake_on_connect=True, ca_certs=self.ca_certs, cert_reqs=self.cert_reqs)
+ self.handle.settimeout(self._timeout)
try:
- res0 = self._resolveAddr()
- for res in res0:
- plain_sock = socket.socket(res[0], res[1])
- #TODO verify server cert
- self.handle = ssl.wrap_socket(plain_sock, ssl_version=ssl.PROTOCOL_TLSv1)
- self.handle.settimeout(self._timeout)
- try:
- self.handle.connect(res[4])
- except socket.error, e:
- if res is not res0[-1]:
- continue
- else:
- raise e
- break
+ self.handle.connect(ip_port)
except socket.error, e:
- if self._unix_socket:
- message = 'Could not connect to secure socket %s' % self._unix_socket
+ if res is not res0[-1]:
+ continue
else:
- message = 'Could not connect to %s:%d' % (self.host, self.port)
- raise TTransportException(type=TTransportException.NOT_OPEN, message=message)
+ raise e
+ break
+ except socket.error, e:
+ if self._unix_socket:
+ message = 'Could not connect to secure socket %s' % self._unix_socket
+ else:
+ message = 'Could not connect to %s:%d' % (self.host, self.port)
+ raise TTransportException(type=TTransportException.NOT_OPEN, message=message)
+ if self.validate:
+ self._validate_cert()
+
+ def _validate_cert(self):
+ """internal method to validate the peer's SSL certificate, and to check the
+ commonName of the certificate to ensure it matches the hostname we
+ used to make this connection. Does not support subjectAltName records
+ in certificates.
+
+ raises TTransportException if the certificate fails validation."""
+ cert = self.handle.getpeercert()
+ self.peercert = cert
+ if 'subject' not in cert:
+ raise TTransportException(type=TTransportException.NOT_OPEN,
+ message='No SSL certificate found from %s:%s' % (self.host, self.port))
+ fields = cert['subject']
+ for field in fields:
+ # ensure structure we get back is what we expect
+ if not isinstance(field, tuple):
+ continue
+ cert_pair = field[0]
+ if len(cert_pair) < 2:
+ continue
+ cert_key, cert_value = cert_pair[0:2]
+ if cert_key != 'commonName':
+ continue
+ certhost = cert_value
+ if certhost == self.host:
+ # success, cert commonName matches desired hostname
+ self.is_valid = True
+ return
+ else:
+ raise TTransportException(type=TTransportException.UNKNOWN,
+ message='Host name we connected to "%s" doesn\'t match certificate provided commonName "%s"' % (self.host, certhost))
+ raise TTransportException(type=TTransportException.UNKNOWN,
+ message='Could not validate SSL certificate from host "%s". Cert=%s' % (self.host, cert))
class TSSLServerSocket(TSocket.TServerSocket):
- def accept(self):
- plain_client, addr = self.handle.accept()
- result = TSocket.TSocket()
- #TODO take certfile/keyfile as a parameter at setup
- client = ssl.wrap_socket(plain_client, certfile='cert.pem', server_side=True)
- result.setHandle(client)
- return result
+ """
+ SSL implementation of TServerSocket
+
+ This uses the ssl module's wrap_socket() method to provide SSL
+ negotiated encryption.
+ """
+ SSL_VERSION = ssl.PROTOCOL_TLSv1
+
+ def __init__(self, certfile='cert.pem', *args, **kwargs):
+ """Initialize a TSSLServerSocket
+
+ @param certfile: The filename of the server certificate file, defaults to cert.pem
+ @type certfile: str
+ @param host: The hostname or IP to bind the listen socket to, i.e. 'localhost' for only allowing
+ local network connections. Pass None to bind to all interfaces.
+ @type host: str
+ @param port: The port to listen on for inbound connections.
+ @type port: int
+ """
+ self.setCertfile(certfile)
+ TSocket.TServerSocket.__init__(self, *args, **kwargs)
+
+ def setCertfile(self, certfile):
+ """Set or change the server certificate file used to wrap new connections.
+
+ @param certfile: The filename of the server certificate, i.e. '/etc/certs/server.pem'
+ @type certfile: str
+
+ Raises an IOError exception if the certfile is not present or unreadable.
+ """
+ if not os.access(certfile, os.R_OK):
+ raise IOError('No such certfile found: %s' % (certfile))
+ self.certfile = certfile
+
+ def accept(self):
+ plain_client, addr = self.handle.accept()
+ result = TSocket.TSocket()
+ client = ssl.wrap_socket(plain_client, certfile=self.certfile,
+ server_side=True, ssl_version=self.SSL_VERSION)
+ result.setHandle(client)
+ return result
diff --git a/lib/py/src/transport/TSocket.py b/lib/py/src/transport/TSocket.py
index 085a5eef0..be6167802 100644
--- a/lib/py/src/transport/TSocket.py
+++ b/lib/py/src/transport/TSocket.py
@@ -57,7 +57,7 @@ class TSocket(TSocketBase):
self.handle = h
def isOpen(self):
- return self.handle != None
+ return self.handle is not None
def setTimeout(self, ms):
if ms is None:
@@ -65,7 +65,7 @@ class TSocket(TSocketBase):
else:
self._timeout = ms/1000.0
- if (self.handle != None):
+ if self.handle is not None:
self.handle.settimeout(self._timeout)
def open(self):