diff options
Diffstat (limited to 'Lib/ssl.py')
-rw-r--r-- | Lib/ssl.py | 99 |
1 files changed, 62 insertions, 37 deletions
diff --git a/Lib/ssl.py b/Lib/ssl.py index 3c0783fadd..f3e5123976 100644 --- a/Lib/ssl.py +++ b/Lib/ssl.py @@ -59,9 +59,9 @@ import textwrap import _ssl # if we can't import it, let the error propagate +from _ssl import OPENSSL_VERSION_NUMBER, OPENSSL_VERSION_INFO, OPENSSL_VERSION from _ssl import SSLError from _ssl import CERT_NONE, CERT_OPTIONAL, CERT_REQUIRED -from _ssl import PROTOCOL_SSLv2, PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1 from _ssl import RAND_status, RAND_egd, RAND_add from _ssl import \ SSL_ERROR_ZERO_RETURN, \ @@ -73,9 +73,20 @@ from _ssl import \ SSL_ERROR_WANT_CONNECT, \ SSL_ERROR_EOF, \ SSL_ERROR_INVALID_ERROR_CODE - -from socket import socket, _fileobject, _delegate_methods -from socket import error as socket_error +from _ssl import PROTOCOL_SSLv3, PROTOCOL_SSLv23, PROTOCOL_TLSv1 +_PROTOCOL_NAMES = { + PROTOCOL_TLSv1: "TLSv1", + PROTOCOL_SSLv23: "SSLv23", + PROTOCOL_SSLv3: "SSLv3", +} +try: + from _ssl import PROTOCOL_SSLv2 +except ImportError: + pass +else: + _PROTOCOL_NAMES[PROTOCOL_SSLv2] = "SSLv2" + +from socket import socket, _fileobject, _delegate_methods, error as socket_error from socket import getnameinfo as _getnameinfo import base64 # for DER-to-PEM translation import errno @@ -90,7 +101,7 @@ class SSLSocket(socket): server_side=False, cert_reqs=CERT_NONE, ssl_version=PROTOCOL_SSLv23, ca_certs=None, do_handshake_on_connect=True, - suppress_ragged_eofs=True): + suppress_ragged_eofs=True, ciphers=None): socket.__init__(self, _sock=sock._sock) # The initializer for socket overrides the methods send(), recv(), etc. # in the instancce, which we don't need -- but we want to provide the @@ -110,12 +121,15 @@ class SSLSocket(socket): if e.errno != errno.ENOTCONN: raise # no, no connection yet + self._connected = False self._sslobj = None else: # yes, create the SSL object + self._connected = True self._sslobj = _ssl.sslwrap(self._sock, server_side, keyfile, certfile, - cert_reqs, ssl_version, ca_certs) + cert_reqs, ssl_version, ca_certs, + ciphers) if do_handshake_on_connect: self.do_handshake() self.keyfile = keyfile @@ -123,6 +137,7 @@ class SSLSocket(socket): self.cert_reqs = cert_reqs self.ssl_version = ssl_version self.ca_certs = ca_certs + self.ciphers = ciphers self.do_handshake_on_connect = do_handshake_on_connect self.suppress_ragged_eofs = suppress_ragged_eofs self._makefile_refs = 0 @@ -182,14 +197,16 @@ class SSLSocket(socket): else: return v else: - return socket.send(self, data, flags) + return self._sock.send(data, flags) - def sendto(self, data, addr, flags=0): + def sendto(self, data, flags_or_addr, addr=None): if self._sslobj: raise ValueError("sendto not allowed on instances of %s" % self.__class__) + elif addr is None: + return self._sock.sendto(data, flags_or_addr) else: - return socket.sendto(self, data, addr, flags) + return self._sock.sendto(data, flags_or_addr, addr) def sendall(self, data, flags=0): if self._sslobj: @@ -214,7 +231,7 @@ class SSLSocket(socket): self.__class__) return self.read(buflen) else: - return socket.recv(self, buflen, flags) + return self._sock.recv(buflen, flags) def recv_into(self, buffer, nbytes=None, flags=0): if buffer and (nbytes is None): @@ -231,21 +248,21 @@ class SSLSocket(socket): buffer[:v] = tmp_buffer return v else: - return socket.recv_into(self, buffer, nbytes, flags) + return self._sock.recv_into(buffer, nbytes, flags) - def recvfrom(self, addr, buflen=1024, flags=0): + def recvfrom(self, buflen=1024, flags=0): if self._sslobj: raise ValueError("recvfrom not allowed on instances of %s" % self.__class__) else: - return socket.recvfrom(self, addr, buflen, flags) + return self._sock.recvfrom(buflen, flags) def recvfrom_into(self, buffer, nbytes=None, flags=0): if self._sslobj: raise ValueError("recvfrom_into not allowed on instances of %s" % self.__class__) else: - return socket.recvfrom_into(self, buffer, nbytes, flags) + return self._sock.recvfrom_into(buffer, nbytes, flags) def pending(self): if self._sslobj: @@ -278,21 +295,36 @@ class SSLSocket(socket): self._sslobj.do_handshake() - def connect(self, addr): - - """Connects to remote ADDR, and then wraps the connection in - an SSL channel.""" - + def _real_connect(self, addr, return_errno): # Here we assume that the socket is client-side, and not # connected at the time of the call. We connect it, then wrap it. - if self._sslobj: + if self._connected: raise ValueError("attempt to connect already-connected SSLSocket!") - socket.connect(self, addr) self._sslobj = _ssl.sslwrap(self._sock, False, self.keyfile, self.certfile, self.cert_reqs, self.ssl_version, - self.ca_certs) - if self.do_handshake_on_connect: - self.do_handshake() + self.ca_certs, self.ciphers) + try: + socket.connect(self, addr) + if self.do_handshake_on_connect: + self.do_handshake() + except socket_error as e: + if return_errno: + return e.errno + else: + self._sslobj = None + raise e + self._connected = True + return 0 + + def connect(self, addr): + """Connects to remote ADDR, and then wraps the connection in + an SSL channel.""" + self._real_connect(addr, False) + + def connect_ex(self, addr): + """Connects to remote ADDR, and then wraps the connection in + an SSL channel.""" + return self._real_connect(addr, True) def accept(self): @@ -308,6 +340,7 @@ class SSLSocket(socket): cert_reqs=self.cert_reqs, ssl_version=self.ssl_version, ca_certs=self.ca_certs, + ciphers=self.ciphers, do_handshake_on_connect=self.do_handshake_on_connect, suppress_ragged_eofs=self.suppress_ragged_eofs), addr) @@ -329,13 +362,14 @@ def wrap_socket(sock, keyfile=None, certfile=None, server_side=False, cert_reqs=CERT_NONE, ssl_version=PROTOCOL_SSLv23, ca_certs=None, do_handshake_on_connect=True, - suppress_ragged_eofs=True): + suppress_ragged_eofs=True, ciphers=None): return SSLSocket(sock, keyfile=keyfile, certfile=certfile, server_side=server_side, cert_reqs=cert_reqs, ssl_version=ssl_version, ca_certs=ca_certs, do_handshake_on_connect=do_handshake_on_connect, - suppress_ragged_eofs=suppress_ragged_eofs) + suppress_ragged_eofs=suppress_ragged_eofs, + ciphers=ciphers) # some utility functions @@ -402,16 +436,7 @@ def get_server_certificate(addr, ssl_version=PROTOCOL_SSLv3, ca_certs=None): return DER_cert_to_PEM_cert(dercert) def get_protocol_name(protocol_code): - if protocol_code == PROTOCOL_TLSv1: - return "TLSv1" - elif protocol_code == PROTOCOL_SSLv23: - return "SSLv23" - elif protocol_code == PROTOCOL_SSLv2: - return "SSLv2" - elif protocol_code == PROTOCOL_SSLv3: - return "SSLv3" - else: - return "<unknown>" + return _PROTOCOL_NAMES.get(protocol_code, '<unknown>') # a replacement for the old socket.ssl function @@ -429,7 +454,7 @@ def sslwrap_simple(sock, keyfile=None, certfile=None): PROTOCOL_SSLv23, None) try: sock.getpeername() - except: + except socket_error: # no, no connection yet pass else: |