diff options
Diffstat (limited to 'python3')
-rw-r--r-- | python3/httplib2/__init__.py | 130 | ||||
-rwxr-xr-x | python3/httplib2test.py | 42 |
2 files changed, 78 insertions, 94 deletions
diff --git a/python3/httplib2/__init__.py b/python3/httplib2/__init__.py index 90ec4d9..af2c3ee 100644 --- a/python3/httplib2/__init__.py +++ b/python3/httplib2/__init__.py @@ -92,6 +92,7 @@ class UnimplementedHmacDigestAuthOptionError(HttpLib2ErrorWithResponse): pass class MalformedHeader(HttpLib2Error): pass class RelativeURIError(HttpLib2Error): pass class ServerNotFoundError(HttpLib2Error): pass +class CertificateValidationUnsupportedInPython31(HttpLib2Error): pass # Open Items: # ----------- @@ -118,6 +119,10 @@ DEFAULT_MAX_REDIRECTS = 5 # Which headers are hop-by-hop headers by default HOP_BY_HOP = ['connection', 'keep-alive', 'proxy-authenticate', 'proxy-authorization', 'te', 'trailers', 'transfer-encoding', 'upgrade'] +# Default CA certificates file bundled with httplib2. +CA_CERTS = os.path.join( + os.path.dirname(os.path.abspath(__file__ )), "cacerts.txt") + def _get_end2end_headers(response): hopbyhop = list(HOP_BY_HOP) hopbyhop.extend([x.strip() for x in response.get('connection', '').split(',')]) @@ -219,10 +224,10 @@ def _parse_www_authenticate(headers, headername='www-authenticate'): while authenticate: # Break off the scheme at the beginning of the line if headername == 'authentication-info': - (auth_scheme, the_rest) = ('digest', authenticate) + (auth_scheme, the_rest) = ('digest', authenticate) else: (auth_scheme, the_rest) = authenticate.split(" ", 1) - # Now loop over all the key value pairs that come after the scheme, + # Now loop over all the key value pairs that come after the scheme, # being careful not to roll into the next scheme match = www_auth.search(the_rest) auth_params = {} @@ -712,43 +717,11 @@ class HTTPConnectionWithTimeout(http.client.HTTPConnection): http://docs.python.org/library/socket.html#socket.setdefaulttimeout """ - def __init__(self, host, port=None, strict=None, timeout=None, proxy_info=None): - http.client.HTTPConnection.__init__(self, host, port, strict, timeout) + def __init__(self, host, port=None, timeout=None, proxy_info=None): + http.client.HTTPConnection.__init__(self, host, port=port, + timeout=timeout) self.proxy_info = proxy_info - def connect(self): - """Connect to the host and port specified in __init__.""" - self.sock = socket.create_connection((self.host,self.port), - self.timeout) - # Mostly verbatim from httplib.py. - msg = "getaddrinfo returns an empty list" - for res in socket.getaddrinfo(self.host, self.port, 0, - socket.SOCK_STREAM): - af, socktype, proto, canonname, sa = res - try: - if self.proxy_info and self.proxy_info.isgood(): - self.sock = socks.socksocket(af, socktype, proto) - self.sock.setproxy(*self.proxy_info.astuple()) - else: - self.sock = socket.socket(af, socktype, proto) - self.sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - # Different from httplib: support timeouts. - if has_timeout(self.timeout): - self.sock.settimeout(self.timeout) - # End of difference from httplib. - if self.debuglevel > 0: - print("connect: (%s, %s)" % (self.host, self.port)) - self.sock.connect(sa) - except socket.error as msg: - if self.debuglevel > 0: - print('connect fail:', (self.host, self.port)) - if self.sock: - self.sock.close() - self.sock = None - continue - break - if not self.sock: - raise socket.error(msg) class HTTPSConnectionWithTimeout(http.client.HTTPSConnection): """ @@ -761,43 +734,25 @@ class HTTPSConnectionWithTimeout(http.client.HTTPSConnection): """ def __init__(self, host, port=None, key_file=None, cert_file=None, - strict=None, timeout=None, proxy_info=None): + timeout=None, proxy_info=None, + ca_certs=None, disable_ssl_certificate_validation=False): self.proxy_info = proxy_info + context = None + if ca_certs is None: + ca_certs = CA_CERTS + if (cert_file or ca_certs) and not disable_ssl_certificate_validation: + if not hasattr(ssl, 'SSLContext'): + raise CertificateValidationUnsupportedInPython31() + context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) + context.verify_mode = ssl.CERT_REQUIRED + if cert_file: + context.load_cert_chain(cert_file, key_file) + if ca_certs: + context.load_verify_locations(ca_certs) http.client.HTTPSConnection.__init__(self, host, port=port, key_file=key_file, - cert_file=cert_file, strict=strict, timeout=timeout) - - def connect(self): - "Connect to a host on a given (SSL) port." + cert_file=cert_file, timeout=timeout, context=context, + check_hostname=True) - msg = "getaddrinfo returns an empty list" - self.sock = None - for family, socktype, proto, canonname, sockaddr in socket.getaddrinfo( - self.host, self.port, 0, socket.SOCK_STREAM): - try: - if self.proxy_info and self.proxy_info.isgood(): - sock = socks.socksocket(family, socktype, proto) - sock.setproxy(*self.proxy_info.astuple()) - else: - sock = socket.socket(family, socktype, proto) - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - - if has_timeout(self.timeout): - sock.settimeout(self.timeout) - sock.connect((self.host, self.port)) - self.sock =_ssl_wrap_socket(sock, self.key_file, self.cert_file) - if self.debuglevel > 0: - print("connect: (%s, %s)" % (self.host, self.port)) - except socket.error as err: - if self.debuglevel > 0: - print('connect fail:', (self.host, self.port)) - if self.sock: - self.sock.close() - self.sock = None - msg = err - continue - break - if self.sock is None: - raise socket.error(msg) class Http(object): @@ -813,13 +768,17 @@ class Http(object): and more. """ - def __init__(self, cache=None, timeout=None, proxy_info=None): + def __init__(self, cache=None, timeout=None, proxy_info=None, + ca_certs=None, disable_ssl_certificate_validation=False): """The value of proxy_info is a ProxyInfo instance. If 'cache' is a string then it is used as a directory name for a disk cache. Otherwise it must be an object that supports the same interface as FileCache.""" self.proxy_info = proxy_info + self.ca_certs = ca_certs + self.disable_ssl_certificate_validation = \ + disable_ssl_certificate_validation # Map domain name to an httplib connection self.connections = {} # The location of the cache, for now a directory @@ -884,8 +843,11 @@ the same interface as FileCache.""" def _conn_request(self, conn, request_uri, method, body, headers): for i in range(2): try: + if conn.sock is None: + conn.connect() conn.request(method, request_uri, body, headers) except socket.timeout: + conn.close() raise except socket.gaierror: conn.close() @@ -913,6 +875,7 @@ the same interface as FileCache.""" try: response = conn.getresponse() except (socket.error, http.client.HTTPException): + conn.close() if i == 0: conn.close() conn.connect() @@ -1054,11 +1017,26 @@ a string that contains the response entity body. if not connection_type: connection_type = (scheme == 'https') and HTTPSConnectionWithTimeout or HTTPConnectionWithTimeout certs = list(self.certificates.iter(authority)) - if scheme == 'https' and certs: - conn = self.connections[conn_key] = connection_type(authority, key_file=certs[0][0], - cert_file=certs[0][1], timeout=self.timeout, proxy_info=self.proxy_info) + if issubclass(connection_type, HTTPSConnectionWithTimeout): + if certs: + conn = self.connections[conn_key] = connection_type( + authority, key_file=certs[0][0], + cert_file=certs[0][1], timeout=self.timeout, + proxy_info=self.proxy_info, + ca_certs=self.ca_certs, + disable_ssl_certificate_validation= + self.disable_ssl_certificate_validation) + else: + conn = self.connections[conn_key] = connection_type( + authority, timeout=self.timeout, + proxy_info=self.proxy_info, + ca_certs=self.ca_certs, + disable_ssl_certificate_validation= + self.disable_ssl_certificate_validation) else: - conn = self.connections[conn_key] = connection_type(authority, timeout=self.timeout, proxy_info=self.proxy_info) + conn = self.connections[conn_key] = connection_type( + authority, timeout=self.timeout, + proxy_info=self.proxy_info) conn.set_debuglevel(debuglevel) if 'range' not in headers and 'accept-encoding' not in headers: diff --git a/python3/httplib2test.py b/python3/httplib2test.py index 4ef3a78..40e087c 100755 --- a/python3/httplib2test.py +++ b/python3/httplib2test.py @@ -20,6 +20,7 @@ import httplib2 import io
import os
import socket
+import ssl
import sys
import time
import unittest
@@ -117,6 +118,7 @@ class _MyHTTPConnection(object): self.port = port
self.timeout = timeout
self.log = ""
+ self.sock = None
def set_debuglevel(self, level):
pass
@@ -473,8 +475,26 @@ class HttpTest(unittest.TestCase): # Skip on 3.2
pass
-
-
+ def testSslCertValidation(self):
+ # Test that we get an ssl.SSLError when specifying a non-existent CA
+ # certs file.
+ http = httplib2.Http(ca_certs='/nosuchfile')
+ self.assertRaises(IOError,
+ http.request, "https://www.google.com/", "GET")
+
+ # Test that we get a SSLHandshakeError if we try to access
+ # https://www.google.com, using a CA cert file that doesn't contain
+ # the CA Gogole uses (i.e., simulating a cert that's not signed by a
+ # trusted CA).
+ other_ca_certs = os.path.join(
+ os.path.dirname(os.path.abspath(httplib2.__file__ )),
+ "test", "other_cacerts.txt")
+ http = httplib2.Http(ca_certs=other_ca_certs)
+ self.assertRaises(ssl.SSLError,
+ http.request,"https://www.google.com/", "GET")
+
+ def testSniHostnameValidation(self):
+ self.http.request("https://google.com/", method="GET")
def testGet303(self):
# Do a follow-up GET on a Location: header
@@ -736,20 +756,6 @@ class HttpTest(unittest.TestCase): self.assertEqual(response.status, 500)
self.assertTrue(response.reason.startswith("Content purported"))
- def testTimeout(self):
- self.http.force_exception_to_status_code = True
- uri = urllib.parse.urljoin(base, "timeout/timeout.cgi")
- try:
- import socket
- socket.setdefaulttimeout(1)
- except:
- # Don't run the test if we can't set the timeout
- return
- (response, content) = self.http.request(uri)
- self.assertEqual(response.status, 408)
- self.assertTrue(response.reason.startswith("Request Timeout"))
- self.assertTrue(content.startswith(b"Request Timeout"))
-
def testIndividualTimeout(self):
uri = urllib.parse.urljoin(base, "timeout/timeout.cgi")
http = httplib2.Http(timeout=1)
@@ -1469,11 +1475,11 @@ class HttpPrivateTest(unittest.TestCase): # Degenerate case of no headers
response = {}
end2end = httplib2._get_end2end_headers(response)
- self.assertEquals(0, len(end2end))
+ self.assertEqual(0, len(end2end))
# Degenerate case of connection referrring to a header not passed in
response = {'connection': 'content-type'}
end2end = httplib2._get_end2end_headers(response)
- self.assertEquals(0, len(end2end))
+ self.assertEqual(0, len(end2end))
unittest.main()
|