summaryrefslogtreecommitdiff
path: root/python3
diff options
context:
space:
mode:
authorJoe Gregorio <jcgregorio@google.com>2011-06-07 15:44:51 -0400
committerJoe Gregorio <jcgregorio@google.com>2011-06-07 15:44:51 -0400
commitb53de9b961c5e28d4976a0ae2793846a267bc3c0 (patch)
tree91d09180b99e60f62a6f5b8140d6ba7400e232f2 /python3
parent327f360d32d3e366c4d344a02f9ebb568d8d8022 (diff)
downloadhttplib2-b53de9b961c5e28d4976a0ae2793846a267bc3c0.tar.gz
Add certificate validation. Work initially started by Christoph Kern.
Diffstat (limited to 'python3')
-rw-r--r--python3/httplib2/__init__.py130
-rwxr-xr-xpython3/httplib2test.py42
2 files changed, 78 insertions, 94 deletions
diff --git a/python3/httplib2/__init__.py b/python3/httplib2/__init__.py
index 90ec4d9..af2c3ee 100644
--- a/python3/httplib2/__init__.py
+++ b/python3/httplib2/__init__.py
@@ -92,6 +92,7 @@ class UnimplementedHmacDigestAuthOptionError(HttpLib2ErrorWithResponse): pass
class MalformedHeader(HttpLib2Error): pass
class RelativeURIError(HttpLib2Error): pass
class ServerNotFoundError(HttpLib2Error): pass
+class CertificateValidationUnsupportedInPython31(HttpLib2Error): pass
# Open Items:
# -----------
@@ -118,6 +119,10 @@ DEFAULT_MAX_REDIRECTS = 5
# Which headers are hop-by-hop headers by default
HOP_BY_HOP = ['connection', 'keep-alive', 'proxy-authenticate', 'proxy-authorization', 'te', 'trailers', 'transfer-encoding', 'upgrade']
+# Default CA certificates file bundled with httplib2.
+CA_CERTS = os.path.join(
+ os.path.dirname(os.path.abspath(__file__ )), "cacerts.txt")
+
def _get_end2end_headers(response):
hopbyhop = list(HOP_BY_HOP)
hopbyhop.extend([x.strip() for x in response.get('connection', '').split(',')])
@@ -219,10 +224,10 @@ def _parse_www_authenticate(headers, headername='www-authenticate'):
while authenticate:
# Break off the scheme at the beginning of the line
if headername == 'authentication-info':
- (auth_scheme, the_rest) = ('digest', authenticate)
+ (auth_scheme, the_rest) = ('digest', authenticate)
else:
(auth_scheme, the_rest) = authenticate.split(" ", 1)
- # Now loop over all the key value pairs that come after the scheme,
+ # Now loop over all the key value pairs that come after the scheme,
# being careful not to roll into the next scheme
match = www_auth.search(the_rest)
auth_params = {}
@@ -712,43 +717,11 @@ class HTTPConnectionWithTimeout(http.client.HTTPConnection):
http://docs.python.org/library/socket.html#socket.setdefaulttimeout
"""
- def __init__(self, host, port=None, strict=None, timeout=None, proxy_info=None):
- http.client.HTTPConnection.__init__(self, host, port, strict, timeout)
+ def __init__(self, host, port=None, timeout=None, proxy_info=None):
+ http.client.HTTPConnection.__init__(self, host, port=port,
+ timeout=timeout)
self.proxy_info = proxy_info
- def connect(self):
- """Connect to the host and port specified in __init__."""
- self.sock = socket.create_connection((self.host,self.port),
- self.timeout)
- # Mostly verbatim from httplib.py.
- msg = "getaddrinfo returns an empty list"
- for res in socket.getaddrinfo(self.host, self.port, 0,
- socket.SOCK_STREAM):
- af, socktype, proto, canonname, sa = res
- try:
- if self.proxy_info and self.proxy_info.isgood():
- self.sock = socks.socksocket(af, socktype, proto)
- self.sock.setproxy(*self.proxy_info.astuple())
- else:
- self.sock = socket.socket(af, socktype, proto)
- self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
- # Different from httplib: support timeouts.
- if has_timeout(self.timeout):
- self.sock.settimeout(self.timeout)
- # End of difference from httplib.
- if self.debuglevel > 0:
- print("connect: (%s, %s)" % (self.host, self.port))
- self.sock.connect(sa)
- except socket.error as msg:
- if self.debuglevel > 0:
- print('connect fail:', (self.host, self.port))
- if self.sock:
- self.sock.close()
- self.sock = None
- continue
- break
- if not self.sock:
- raise socket.error(msg)
class HTTPSConnectionWithTimeout(http.client.HTTPSConnection):
"""
@@ -761,43 +734,25 @@ class HTTPSConnectionWithTimeout(http.client.HTTPSConnection):
"""
def __init__(self, host, port=None, key_file=None, cert_file=None,
- strict=None, timeout=None, proxy_info=None):
+ timeout=None, proxy_info=None,
+ ca_certs=None, disable_ssl_certificate_validation=False):
self.proxy_info = proxy_info
+ context = None
+ if ca_certs is None:
+ ca_certs = CA_CERTS
+ if (cert_file or ca_certs) and not disable_ssl_certificate_validation:
+ if not hasattr(ssl, 'SSLContext'):
+ raise CertificateValidationUnsupportedInPython31()
+ context = ssl.SSLContext(ssl.PROTOCOL_TLSv1)
+ context.verify_mode = ssl.CERT_REQUIRED
+ if cert_file:
+ context.load_cert_chain(cert_file, key_file)
+ if ca_certs:
+ context.load_verify_locations(ca_certs)
http.client.HTTPSConnection.__init__(self, host, port=port, key_file=key_file,
- cert_file=cert_file, strict=strict, timeout=timeout)
-
- def connect(self):
- "Connect to a host on a given (SSL) port."
+ cert_file=cert_file, timeout=timeout, context=context,
+ check_hostname=True)
- msg = "getaddrinfo returns an empty list"
- self.sock = None
- for family, socktype, proto, canonname, sockaddr in socket.getaddrinfo(
- self.host, self.port, 0, socket.SOCK_STREAM):
- try:
- if self.proxy_info and self.proxy_info.isgood():
- sock = socks.socksocket(family, socktype, proto)
- sock.setproxy(*self.proxy_info.astuple())
- else:
- sock = socket.socket(family, socktype, proto)
- sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
-
- if has_timeout(self.timeout):
- sock.settimeout(self.timeout)
- sock.connect((self.host, self.port))
- self.sock =_ssl_wrap_socket(sock, self.key_file, self.cert_file)
- if self.debuglevel > 0:
- print("connect: (%s, %s)" % (self.host, self.port))
- except socket.error as err:
- if self.debuglevel > 0:
- print('connect fail:', (self.host, self.port))
- if self.sock:
- self.sock.close()
- self.sock = None
- msg = err
- continue
- break
- if self.sock is None:
- raise socket.error(msg)
class Http(object):
@@ -813,13 +768,17 @@ class Http(object):
and more.
"""
- def __init__(self, cache=None, timeout=None, proxy_info=None):
+ def __init__(self, cache=None, timeout=None, proxy_info=None,
+ ca_certs=None, disable_ssl_certificate_validation=False):
"""The value of proxy_info is a ProxyInfo instance.
If 'cache' is a string then it is used as a directory name
for a disk cache. Otherwise it must be an object that supports
the same interface as FileCache."""
self.proxy_info = proxy_info
+ self.ca_certs = ca_certs
+ self.disable_ssl_certificate_validation = \
+ disable_ssl_certificate_validation
# Map domain name to an httplib connection
self.connections = {}
# The location of the cache, for now a directory
@@ -884,8 +843,11 @@ the same interface as FileCache."""
def _conn_request(self, conn, request_uri, method, body, headers):
for i in range(2):
try:
+ if conn.sock is None:
+ conn.connect()
conn.request(method, request_uri, body, headers)
except socket.timeout:
+ conn.close()
raise
except socket.gaierror:
conn.close()
@@ -913,6 +875,7 @@ the same interface as FileCache."""
try:
response = conn.getresponse()
except (socket.error, http.client.HTTPException):
+ conn.close()
if i == 0:
conn.close()
conn.connect()
@@ -1054,11 +1017,26 @@ a string that contains the response entity body.
if not connection_type:
connection_type = (scheme == 'https') and HTTPSConnectionWithTimeout or HTTPConnectionWithTimeout
certs = list(self.certificates.iter(authority))
- if scheme == 'https' and certs:
- conn = self.connections[conn_key] = connection_type(authority, key_file=certs[0][0],
- cert_file=certs[0][1], timeout=self.timeout, proxy_info=self.proxy_info)
+ if issubclass(connection_type, HTTPSConnectionWithTimeout):
+ if certs:
+ conn = self.connections[conn_key] = connection_type(
+ authority, key_file=certs[0][0],
+ cert_file=certs[0][1], timeout=self.timeout,
+ proxy_info=self.proxy_info,
+ ca_certs=self.ca_certs,
+ disable_ssl_certificate_validation=
+ self.disable_ssl_certificate_validation)
+ else:
+ conn = self.connections[conn_key] = connection_type(
+ authority, timeout=self.timeout,
+ proxy_info=self.proxy_info,
+ ca_certs=self.ca_certs,
+ disable_ssl_certificate_validation=
+ self.disable_ssl_certificate_validation)
else:
- conn = self.connections[conn_key] = connection_type(authority, timeout=self.timeout, proxy_info=self.proxy_info)
+ conn = self.connections[conn_key] = connection_type(
+ authority, timeout=self.timeout,
+ proxy_info=self.proxy_info)
conn.set_debuglevel(debuglevel)
if 'range' not in headers and 'accept-encoding' not in headers:
diff --git a/python3/httplib2test.py b/python3/httplib2test.py
index 4ef3a78..40e087c 100755
--- a/python3/httplib2test.py
+++ b/python3/httplib2test.py
@@ -20,6 +20,7 @@ import httplib2
import io
import os
import socket
+import ssl
import sys
import time
import unittest
@@ -117,6 +118,7 @@ class _MyHTTPConnection(object):
self.port = port
self.timeout = timeout
self.log = ""
+ self.sock = None
def set_debuglevel(self, level):
pass
@@ -473,8 +475,26 @@ class HttpTest(unittest.TestCase):
# Skip on 3.2
pass
-
-
+ def testSslCertValidation(self):
+ # Test that we get an ssl.SSLError when specifying a non-existent CA
+ # certs file.
+ http = httplib2.Http(ca_certs='/nosuchfile')
+ self.assertRaises(IOError,
+ http.request, "https://www.google.com/", "GET")
+
+ # Test that we get a SSLHandshakeError if we try to access
+ # https://www.google.com, using a CA cert file that doesn't contain
+ # the CA Gogole uses (i.e., simulating a cert that's not signed by a
+ # trusted CA).
+ other_ca_certs = os.path.join(
+ os.path.dirname(os.path.abspath(httplib2.__file__ )),
+ "test", "other_cacerts.txt")
+ http = httplib2.Http(ca_certs=other_ca_certs)
+ self.assertRaises(ssl.SSLError,
+ http.request,"https://www.google.com/", "GET")
+
+ def testSniHostnameValidation(self):
+ self.http.request("https://google.com/", method="GET")
def testGet303(self):
# Do a follow-up GET on a Location: header
@@ -736,20 +756,6 @@ class HttpTest(unittest.TestCase):
self.assertEqual(response.status, 500)
self.assertTrue(response.reason.startswith("Content purported"))
- def testTimeout(self):
- self.http.force_exception_to_status_code = True
- uri = urllib.parse.urljoin(base, "timeout/timeout.cgi")
- try:
- import socket
- socket.setdefaulttimeout(1)
- except:
- # Don't run the test if we can't set the timeout
- return
- (response, content) = self.http.request(uri)
- self.assertEqual(response.status, 408)
- self.assertTrue(response.reason.startswith("Request Timeout"))
- self.assertTrue(content.startswith(b"Request Timeout"))
-
def testIndividualTimeout(self):
uri = urllib.parse.urljoin(base, "timeout/timeout.cgi")
http = httplib2.Http(timeout=1)
@@ -1469,11 +1475,11 @@ class HttpPrivateTest(unittest.TestCase):
# Degenerate case of no headers
response = {}
end2end = httplib2._get_end2end_headers(response)
- self.assertEquals(0, len(end2end))
+ self.assertEqual(0, len(end2end))
# Degenerate case of connection referrring to a header not passed in
response = {'connection': 'content-type'}
end2end = httplib2._get_end2end_headers(response)
- self.assertEquals(0, len(end2end))
+ self.assertEqual(0, len(end2end))
unittest.main()