summaryrefslogtreecommitdiff
path: root/Lib/ssl.py
diff options
context:
space:
mode:
Diffstat (limited to 'Lib/ssl.py')
-rw-r--r--Lib/ssl.py310
1 files changed, 253 insertions, 57 deletions
diff --git a/Lib/ssl.py b/Lib/ssl.py
index ec42e38d08..ab7a49b576 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -87,17 +87,18 @@ ALERT_DESCRIPTION_BAD_CERTIFICATE_HASH_VALUE
ALERT_DESCRIPTION_UNKNOWN_PSK_IDENTITY
"""
+import ipaddress
import textwrap
import re
import sys
import os
from collections import namedtuple
-from enum import Enum as _Enum
+from enum import Enum as _Enum, IntEnum as _IntEnum
import _ssl # if we can't import it, let the error propagate
from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION
-from _ssl import _SSLContext
+from _ssl import _SSLContext, MemoryBIO
from _ssl import (
SSLError, SSLZeroReturnError, SSLWantReadError, SSLWantWriteError,
SSLSyscallError, SSLEOFError,
@@ -119,30 +120,23 @@ def _import_symbols(prefix):
_import_symbols('OP_')
_import_symbols('ALERT_DESCRIPTION_')
_import_symbols('SSL_ERROR_')
-_import_symbols('PROTOCOL_')
_import_symbols('VERIFY_')
-from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN
+from _ssl import HAS_SNI, HAS_ECDH, HAS_NPN, HAS_ALPN
from _ssl import _OPENSSL_API_VERSION
+_IntEnum._convert(
+ '_SSLMethod', __name__,
+ lambda name: name.startswith('PROTOCOL_'),
+ source=_ssl)
+
+_PROTOCOL_NAMES = {value: name for name, value in _SSLMethod.__members__.items()}
-_PROTOCOL_NAMES = {value: name for name, value in globals().items() if name.startswith('PROTOCOL_')}
try:
- from _ssl import PROTOCOL_SSLv2
_SSLv2_IF_EXISTS = PROTOCOL_SSLv2
-except ImportError:
+except NameError:
_SSLv2_IF_EXISTS = None
-else:
- _PROTOCOL_NAMES[PROTOCOL_SSLv2] = "SSLv2"
-
-try:
- from _ssl import PROTOCOL_TLSv1_1, PROTOCOL_TLSv1_2
-except ImportError:
- pass
-else:
- _PROTOCOL_NAMES[PROTOCOL_TLSv1_1] = "TLSv1.1"
- _PROTOCOL_NAMES[PROTOCOL_TLSv1_2] = "TLSv1.2"
if sys.platform == "win32":
from _ssl import enum_certificates, enum_crls
@@ -246,6 +240,17 @@ def _dnsname_match(dn, hostname, max_wildcards=1):
return pat.match(hostname)
+def _ipaddress_match(ipname, host_ip):
+ """Exact matching of IP addresses.
+
+ RFC 6125 explicitly doesn't define an algorithm for this
+ (section 1.7.2 - "Out of Scope").
+ """
+ # OpenSSL may add a trailing newline to a subjectAltName's IP address
+ ip = ipaddress.ip_address(ipname.rstrip())
+ return ip == host_ip
+
+
def match_hostname(cert, hostname):
"""Verify that *cert* (in decoded format as returned by
SSLSocket.getpeercert()) matches the *hostname*. RFC 2818 and RFC 6125
@@ -258,11 +263,20 @@ def match_hostname(cert, hostname):
raise ValueError("empty or no certificate, match_hostname needs a "
"SSL socket or SSL context with either "
"CERT_OPTIONAL or CERT_REQUIRED")
+ try:
+ host_ip = ipaddress.ip_address(hostname)
+ except ValueError:
+ # Not an IP address (common case)
+ host_ip = None
dnsnames = []
san = cert.get('subjectAltName', ())
for key, value in san:
if key == 'DNS':
- if _dnsname_match(value, hostname):
+ if host_ip is None and _dnsname_match(value, hostname):
+ return
+ dnsnames.append(value)
+ elif key == 'IP Address':
+ if host_ip is not None and _ipaddress_match(value, host_ip):
return
dnsnames.append(value)
if not dnsnames:
@@ -361,6 +375,12 @@ class SSLContext(_SSLContext):
server_hostname=server_hostname,
_context=self)
+ def wrap_bio(self, incoming, outgoing, server_side=False,
+ server_hostname=None):
+ sslobj = self._wrap_bio(incoming, outgoing, server_side=server_side,
+ server_hostname=server_hostname)
+ return SSLObject(sslobj)
+
def set_npn_protocols(self, npn_protocols):
protos = bytearray()
for protocol in npn_protocols:
@@ -372,6 +392,17 @@ class SSLContext(_SSLContext):
self._set_npn_protocols(protos)
+ def set_alpn_protocols(self, alpn_protocols):
+ protos = bytearray()
+ for protocol in alpn_protocols:
+ b = bytes(protocol, 'ascii')
+ if len(b) == 0 or len(b) > 255:
+ raise SSLError('ALPN protocols must be 1 to 255 in length')
+ protos.append(len(b))
+ protos.extend(b)
+
+ self._set_alpn_protocols(protos)
+
def _load_windows_store_certs(self, storename, purpose):
certs = bytearray()
for cert, encoding, trust in enum_certificates(storename):
@@ -488,6 +519,141 @@ _create_default_https_context = create_default_context
_create_stdlib_context = _create_unverified_context
+class SSLObject:
+ """This class implements an interface on top of a low-level SSL object as
+ implemented by OpenSSL. This object captures the state of an SSL connection
+ but does not provide any network IO itself. IO needs to be performed
+ through separate "BIO" objects which are OpenSSL's IO abstraction layer.
+
+ This class does not have a public constructor. Instances are returned by
+ ``SSLContext.wrap_bio``. This class is typically used by framework authors
+ that want to implement asynchronous IO for SSL through memory buffers.
+
+ When compared to ``SSLSocket``, this object lacks the following features:
+
+ * Any form of network IO incluging methods such as ``recv`` and ``send``.
+ * The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery.
+ """
+
+ def __init__(self, sslobj, owner=None):
+ self._sslobj = sslobj
+ # Note: _sslobj takes a weak reference to owner
+ self._sslobj.owner = owner or self
+
+ @property
+ def context(self):
+ """The SSLContext that is currently in use."""
+ return self._sslobj.context
+
+ @context.setter
+ def context(self, ctx):
+ self._sslobj.context = ctx
+
+ @property
+ def server_side(self):
+ """Whether this is a server-side socket."""
+ return self._sslobj.server_side
+
+ @property
+ def server_hostname(self):
+ """The currently set server hostname (for SNI), or ``None`` if no
+ server hostame is set."""
+ return self._sslobj.server_hostname
+
+ def read(self, len=0, buffer=None):
+ """Read up to 'len' bytes from the SSL object and return them.
+
+ If 'buffer' is provided, read into this buffer and return the number of
+ bytes read.
+ """
+ if buffer is not None:
+ v = self._sslobj.read(len, buffer)
+ else:
+ v = self._sslobj.read(len or 1024)
+ return v
+
+ def write(self, data):
+ """Write 'data' to the SSL object and return the number of bytes
+ written.
+
+ The 'data' argument must support the buffer interface.
+ """
+ return self._sslobj.write(data)
+
+ def getpeercert(self, binary_form=False):
+ """Returns a formatted version of the data in the certificate provided
+ by the other end of the SSL channel.
+
+ Return None if no certificate was provided, {} if a certificate was
+ provided, but not validated.
+ """
+ return self._sslobj.peer_certificate(binary_form)
+
+ def selected_npn_protocol(self):
+ """Return the currently selected NPN protocol as a string, or ``None``
+ if a next protocol was not negotiated or if NPN is not supported by one
+ of the peers."""
+ if _ssl.HAS_NPN:
+ return self._sslobj.selected_npn_protocol()
+
+ def selected_alpn_protocol(self):
+ """Return the currently selected ALPN protocol as a string, or ``None``
+ if a next protocol was not negotiated or if ALPN is not supported by one
+ of the peers."""
+ if _ssl.HAS_ALPN:
+ return self._sslobj.selected_alpn_protocol()
+
+ def cipher(self):
+ """Return the currently selected cipher as a 3-tuple ``(name,
+ ssl_version, secret_bits)``."""
+ return self._sslobj.cipher()
+
+ def shared_ciphers(self):
+ """Return a list of ciphers shared by the client during the handshake or
+ None if this is not a valid server connection.
+ """
+ return self._sslobj.shared_ciphers()
+
+ def compression(self):
+ """Return the current compression algorithm in use, or ``None`` if
+ compression was not negotiated or not supported by one of the peers."""
+ return self._sslobj.compression()
+
+ def pending(self):
+ """Return the number of bytes that can be read immediately."""
+ return self._sslobj.pending()
+
+ def do_handshake(self):
+ """Start the SSL/TLS handshake."""
+ self._sslobj.do_handshake()
+ if self.context.check_hostname:
+ if not self.server_hostname:
+ raise ValueError("check_hostname needs server_hostname "
+ "argument")
+ match_hostname(self.getpeercert(), self.server_hostname)
+
+ def unwrap(self):
+ """Start the SSL shutdown handshake."""
+ return self._sslobj.shutdown()
+
+ def get_channel_binding(self, cb_type="tls-unique"):
+ """Get channel binding data for current connection. Raise ValueError
+ if the requested `cb_type` is not supported. Return bytes of the data
+ or None if the data is not available (e.g. before the handshake)."""
+ if cb_type not in CHANNEL_BINDING_TYPES:
+ raise ValueError("Unsupported channel binding type")
+ if cb_type != "tls-unique":
+ raise NotImplementedError(
+ "{0} channel binding type not implemented"
+ .format(cb_type))
+ return self._sslobj.tls_unique_cb()
+
+ def version(self):
+ """Return a string identifying the protocol version used by the
+ current SSL channel. """
+ return self._sslobj.version()
+
+
class SSLSocket(socket):
"""This class implements a subtype of socket.socket that wraps
the underlying OS socket in an SSL context when necessary, and
@@ -570,8 +736,9 @@ class SSLSocket(socket):
if connected:
# create the SSL object
try:
- self._sslobj = self._context._wrap_socket(self, server_side,
- server_hostname)
+ sslobj = self._context._wrap_socket(self, server_side,
+ server_hostname)
+ self._sslobj = SSLObject(sslobj, owner=self)
if do_handshake_on_connect:
timeout = self.gettimeout()
if timeout == 0.0:
@@ -616,11 +783,7 @@ class SSLSocket(socket):
if not self._sslobj:
raise ValueError("Read on closed or unwrapped SSL socket.")
try:
- if buffer is not None:
- v = self._sslobj.read(len, buffer)
- else:
- v = self._sslobj.read(len or 1024)
- return v
+ return self._sslobj.read(len, buffer)
except SSLError as x:
if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
if buffer is not None:
@@ -647,7 +810,7 @@ class SSLSocket(socket):
self._checkClosed()
self._check_connected()
- return self._sslobj.peer_certificate(binary_form)
+ return self._sslobj.getpeercert(binary_form)
def selected_npn_protocol(self):
self._checkClosed()
@@ -656,6 +819,13 @@ class SSLSocket(socket):
else:
return self._sslobj.selected_npn_protocol()
+ def selected_alpn_protocol(self):
+ self._checkClosed()
+ if not self._sslobj or not _ssl.HAS_ALPN:
+ return None
+ else:
+ return self._sslobj.selected_alpn_protocol()
+
def cipher(self):
self._checkClosed()
if not self._sslobj:
@@ -663,6 +833,12 @@ class SSLSocket(socket):
else:
return self._sslobj.cipher()
+ def shared_ciphers(self):
+ self._checkClosed()
+ if not self._sslobj:
+ return None
+ return self._sslobj.shared_ciphers()
+
def compression(self):
self._checkClosed()
if not self._sslobj:
@@ -677,17 +853,7 @@ class SSLSocket(socket):
raise ValueError(
"non-zero flags not allowed in calls to send() on %s" %
self.__class__)
- try:
- v = self._sslobj.write(data)
- except SSLError as x:
- if x.args[0] == SSL_ERROR_WANT_READ:
- return 0
- elif x.args[0] == SSL_ERROR_WANT_WRITE:
- return 0
- else:
- raise
- else:
- return v
+ return self._sslobj.write(data)
else:
return socket.send(self, data, flags)
@@ -723,6 +889,16 @@ class SSLSocket(socket):
else:
return socket.sendall(self, data, flags)
+ def sendfile(self, file, offset=0, count=None):
+ """Send a file, possibly by using os.sendfile() if this is a
+ clear-text socket. Return the total number of bytes sent.
+ """
+ if self._sslobj is None:
+ # os.sendfile() works with plain sockets only
+ return super().sendfile(file, offset, count)
+ else:
+ return self._sendfile_use_send(file, offset, count)
+
def recv(self, buflen=1024, flags=0):
self._checkClosed()
if self._sslobj:
@@ -787,7 +963,7 @@ class SSLSocket(socket):
def unwrap(self):
if self._sslobj:
- s = self._sslobj.shutdown()
+ s = self._sslobj.unwrap()
self._sslobj = None
return s
else:
@@ -808,12 +984,6 @@ class SSLSocket(socket):
finally:
self.settimeout(timeout)
- if self.context.check_hostname:
- if not self.server_hostname:
- raise ValueError("check_hostname needs server_hostname "
- "argument")
- match_hostname(self.getpeercert(), self.server_hostname)
-
def _real_connect(self, addr, connect_ex):
if self.server_side:
raise ValueError("can't connect in server-side mode")
@@ -821,7 +991,8 @@ class SSLSocket(socket):
# connected at the time of the call. We connect it, then wrap it.
if self._connected:
raise ValueError("attempt to connect already-connected SSLSocket!")
- self._sslobj = self.context._wrap_socket(self, False, self.server_hostname)
+ sslobj = self.context._wrap_socket(self, False, self.server_hostname)
+ self._sslobj = SSLObject(sslobj, owner=self)
try:
if connect_ex:
rc = socket.connect_ex(self, addr)
@@ -864,15 +1035,18 @@ class SSLSocket(socket):
if the requested `cb_type` is not supported. Return bytes of the data
or None if the data is not available (e.g. before the handshake).
"""
- if cb_type not in CHANNEL_BINDING_TYPES:
- raise ValueError("Unsupported channel binding type")
- if cb_type != "tls-unique":
- raise NotImplementedError(
- "{0} channel binding type not implemented"
- .format(cb_type))
if self._sslobj is None:
return None
- return self._sslobj.tls_unique_cb()
+ return self._sslobj.get_channel_binding(cb_type)
+
+ def version(self):
+ """
+ Return a string identifying the protocol version used by the
+ current SSL channel, or None if there is no established channel.
+ """
+ if self._sslobj is None:
+ return None
+ return self._sslobj.version()
def wrap_socket(sock, keyfile=None, certfile=None,
@@ -892,12 +1066,34 @@ def wrap_socket(sock, keyfile=None, certfile=None,
# some utility functions
def cert_time_to_seconds(cert_time):
- """Takes a date-time string in standard ASN1_print form
- ("MON DAY 24HOUR:MINUTE:SEC YEAR TIMEZONE") and return
- a Python time value in seconds past the epoch."""
+ """Return the time in seconds since the Epoch, given the timestring
+ representing the "notBefore" or "notAfter" date from a certificate
+ in ``"%b %d %H:%M:%S %Y %Z"`` strptime format (C locale).
+
+ "notBefore" or "notAfter" dates must use UTC (RFC 5280).
- import time
- return time.mktime(time.strptime(cert_time, "%b %d %H:%M:%S %Y GMT"))
+ Month is one of: Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec
+ UTC should be specified as GMT (see ASN1_TIME_print())
+ """
+ from time import strptime
+ from calendar import timegm
+
+ months = (
+ "Jan","Feb","Mar","Apr","May","Jun",
+ "Jul","Aug","Sep","Oct","Nov","Dec"
+ )
+ time_format = ' %d %H:%M:%S %Y GMT' # NOTE: no month, fixed GMT
+ try:
+ month_number = months.index(cert_time[:3].title()) + 1
+ except ValueError:
+ raise ValueError('time data %r does not match '
+ 'format "%%b%s"' % (cert_time, time_format))
+ else:
+ # found valid month
+ tt = strptime(cert_time[3:], time_format)
+ # return an integer, the previous mktime()-based implementation
+ # returned a float (fractional seconds are always zero here).
+ return timegm((tt[0], month_number) + tt[2:6])
PEM_HEADER = "-----BEGIN CERTIFICATE-----"
PEM_FOOTER = "-----END CERTIFICATE-----"