diff options
Diffstat (limited to 'Lib/test/test_ssl.py')
| -rw-r--r-- | Lib/test/test_ssl.py | 302 | 
1 files changed, 265 insertions, 37 deletions
| diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py index 1b08f2e7dc..c6ce07545c 100644 --- a/Lib/test/test_ssl.py +++ b/Lib/test/test_ssl.py @@ -42,6 +42,9 @@ ONLYCERT = data_file("ssl_cert.pem")  ONLYKEY = data_file("ssl_key.pem")  BYTES_ONLYCERT = os.fsencode(ONLYCERT)  BYTES_ONLYKEY = os.fsencode(ONLYKEY) +CERTFILE_PROTECTED = data_file("keycert.passwd.pem") +ONLYKEY_PROTECTED = data_file("ssl_key.passwd.pem") +KEY_PASSWORD = "somepass"  CAPATH = data_file("capath")  BYTES_CAPATH = os.fsencode(CAPATH) @@ -53,6 +56,8 @@ WRONGCERT = data_file("XXXnonexisting.pem")  BADKEY = data_file("badkey.pem")  NOKIACERT = data_file("nokia.pem") +DHFILE = data_file("dh512.pem") +BYTES_DHFILE = os.fsencode(DHFILE)  def handle_error(prefix):      exc_format = ' '.join(traceback.format_exception(*sys.exc_info())) @@ -95,7 +100,14 @@ class BasicSocketTests(unittest.TestCase):          ssl.CERT_NONE          ssl.CERT_OPTIONAL          ssl.CERT_REQUIRED +        ssl.OP_CIPHER_SERVER_PREFERENCE +        ssl.OP_SINGLE_DH_USE +        if ssl.HAS_ECDH: +            ssl.OP_SINGLE_ECDH_USE +        if ssl.OPENSSL_VERSION_INFO >= (1, 0): +            ssl.OP_NO_COMPRESSION          self.assertIn(ssl.HAS_SNI, {True, False}) +        self.assertIn(ssl.HAS_ECDH, {True, False})      def test_random(self):          v = ssl.RAND_status() @@ -103,6 +115,16 @@ class BasicSocketTests(unittest.TestCase):              sys.stdout.write("\n RAND_status is %d (%s)\n"                               % (v, (v and "sufficient randomness") or                                  "insufficient randomness")) + +        data, is_cryptographic = ssl.RAND_pseudo_bytes(16) +        self.assertEqual(len(data), 16) +        self.assertEqual(is_cryptographic, v == 1) +        if v: +            data = ssl.RAND_bytes(16) +            self.assertEqual(len(data), 16) +        else: +            self.assertRaises(ssl.SSLError, ssl.RAND_bytes, 16) +          try:              ssl.RAND_egd(1)          except TypeError: @@ -337,6 +359,25 @@ class BasicSocketTests(unittest.TestCase):              self.assertRaises(ValueError, ctx.wrap_socket, sock, True,                                server_hostname="some.hostname") +    def test_unknown_channel_binding(self): +        # should raise ValueError for unknown type +        s = socket.socket(socket.AF_INET) +        ss = ssl.wrap_socket(s) +        with self.assertRaises(ValueError): +            ss.get_channel_binding("unknown-type") + +    @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES, +                         "'tls-unique' channel binding not available") +    def test_tls_unique_channel_binding(self): +        # unconnected should return None for known type +        s = socket.socket(socket.AF_INET) +        ss = ssl.wrap_socket(s) +        self.assertIsNone(ss.get_channel_binding("tls-unique")) +        # the same for server-side +        s = socket.socket(socket.AF_INET) +        ss = ssl.wrap_socket(s, server_side=True, certfile=CERTFILE) +        self.assertIsNone(ss.get_channel_binding("tls-unique")) +  class ContextTests(unittest.TestCase):      @skip_if_broken_ubuntu_ssl @@ -427,6 +468,60 @@ class ContextTests(unittest.TestCase):          ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)          with self.assertRaisesRegex(ssl.SSLError, "key values mismatch"):              ctx.load_cert_chain(SVN_PYTHON_ORG_ROOT_CERT, ONLYKEY) +        # Password protected key and cert +        ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD) +        ctx.load_cert_chain(CERTFILE_PROTECTED, password=KEY_PASSWORD.encode()) +        ctx.load_cert_chain(CERTFILE_PROTECTED, +                            password=bytearray(KEY_PASSWORD.encode())) +        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD) +        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, KEY_PASSWORD.encode()) +        ctx.load_cert_chain(ONLYCERT, ONLYKEY_PROTECTED, +                            bytearray(KEY_PASSWORD.encode())) +        with self.assertRaisesRegex(TypeError, "should be a string"): +            ctx.load_cert_chain(CERTFILE_PROTECTED, password=True) +        with self.assertRaises(ssl.SSLError): +            ctx.load_cert_chain(CERTFILE_PROTECTED, password="badpass") +        with self.assertRaisesRegex(ValueError, "cannot be longer"): +            # openssl has a fixed limit on the password buffer. +            # PEM_BUFSIZE is generally set to 1kb. +            # Return a string larger than this. +            ctx.load_cert_chain(CERTFILE_PROTECTED, password=b'a' * 102400) +        # Password callback +        def getpass_unicode(): +            return KEY_PASSWORD +        def getpass_bytes(): +            return KEY_PASSWORD.encode() +        def getpass_bytearray(): +            return bytearray(KEY_PASSWORD.encode()) +        def getpass_badpass(): +            return "badpass" +        def getpass_huge(): +            return b'a' * (1024 * 1024) +        def getpass_bad_type(): +            return 9 +        def getpass_exception(): +            raise Exception('getpass error') +        class GetPassCallable: +            def __call__(self): +                return KEY_PASSWORD +            def getpass(self): +                return KEY_PASSWORD +        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_unicode) +        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytes) +        ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bytearray) +        ctx.load_cert_chain(CERTFILE_PROTECTED, password=GetPassCallable()) +        ctx.load_cert_chain(CERTFILE_PROTECTED, +                            password=GetPassCallable().getpass) +        with self.assertRaises(ssl.SSLError): +            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_badpass) +        with self.assertRaisesRegex(ValueError, "cannot be longer"): +            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_huge) +        with self.assertRaisesRegex(TypeError, "must return a string"): +            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_bad_type) +        with self.assertRaisesRegex(Exception, "getpass error"): +            ctx.load_cert_chain(CERTFILE_PROTECTED, password=getpass_exception) +        # Make sure the password function isn't called if it isn't needed +        ctx.load_cert_chain(CERTFILE, password=getpass_exception)      def test_load_verify_locations(self):          ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) @@ -447,6 +542,19 @@ class ContextTests(unittest.TestCase):          # Issue #10989: crash if the second argument type is invalid          self.assertRaises(TypeError, ctx.load_verify_locations, None, True) +    def test_load_dh_params(self): +        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +        ctx.load_dh_params(DHFILE) +        if os.name != 'nt': +            ctx.load_dh_params(BYTES_DHFILE) +        self.assertRaises(TypeError, ctx.load_dh_params) +        self.assertRaises(TypeError, ctx.load_dh_params, None) +        with self.assertRaises(FileNotFoundError) as cm: +            ctx.load_dh_params(WRONGCERT) +        self.assertEqual(cm.exception.errno, errno.ENOENT) +        with self.assertRaisesRegex(ssl.SSLError, "PEM routines"): +            ctx.load_dh_params(CERTFILE) +      @skip_if_broken_ubuntu_ssl      def test_session_stats(self):          for proto in PROTOCOLS: @@ -471,6 +579,16 @@ class ContextTests(unittest.TestCase):          ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1)          ctx.set_default_verify_paths() +    @unittest.skipUnless(ssl.HAS_ECDH, "ECDH disabled on this OpenSSL build") +    def test_set_ecdh_curve(self): +        ctx = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +        ctx.set_ecdh_curve("prime256v1") +        ctx.set_ecdh_curve(b"prime256v1") +        self.assertRaises(TypeError, ctx.set_ecdh_curve) +        self.assertRaises(TypeError, ctx.set_ecdh_curve, None) +        self.assertRaises(ValueError, ctx.set_ecdh_curve, "foo") +        self.assertRaises(ValueError, ctx.set_ecdh_curve, b"foo") +  class NetworkedTests(unittest.TestCase): @@ -533,13 +651,10 @@ class NetworkedTests(unittest.TestCase):                      try:                          s.do_handshake()                          break -                    except ssl.SSLError as err: -                        if err.args[0] == ssl.SSL_ERROR_WANT_READ: -                            select.select([s], [], [], 5.0) -                        elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE: -                            select.select([], [s], [], 5.0) -                        else: -                            raise +                    except ssl.SSLWantReadError: +                        select.select([s], [], [], 5.0) +                    except ssl.SSLWantWriteError: +                        select.select([], [s], [], 5.0)                  # SSL established                  self.assertTrue(s.getpeercert())              finally: @@ -659,37 +774,39 @@ class NetworkedTests(unittest.TestCase):                      count += 1                      s.do_handshake()                      break -                except ssl.SSLError as err: -                    if err.args[0] == ssl.SSL_ERROR_WANT_READ: -                        select.select([s], [], []) -                    elif err.args[0] == ssl.SSL_ERROR_WANT_WRITE: -                        select.select([], [s], []) -                    else: -                        raise +                except ssl.SSLWantReadError: +                    select.select([s], [], []) +                except ssl.SSLWantWriteError: +                    select.select([], [s], [])              s.close()              if support.verbose:                  sys.stdout.write("\nNeeded %d calls to do_handshake() to establish session.\n" % count)      def test_get_server_certificate(self): -        with support.transient_internet("svn.python.org"): -            pem = ssl.get_server_certificate(("svn.python.org", 443)) -            if not pem: -                self.fail("No server certificate on svn.python.org:443!") +        def _test_get_server_certificate(host, port, cert=None): +            with support.transient_internet(host): +                pem = ssl.get_server_certificate((host, port)) +                if not pem: +                    self.fail("No server certificate on %s:%s!" % (host, port)) -            try: -                pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=CERTFILE) -            except ssl.SSLError as x: -                #should fail +                try: +                    pem = ssl.get_server_certificate((host, port), ca_certs=CERTFILE) +                except ssl.SSLError as x: +                    #should fail +                    if support.verbose: +                        sys.stdout.write("%s\n" % x) +                else: +                    self.fail("Got server certificate %s for %s:%s!" % (pem, host, port)) + +                pem = ssl.get_server_certificate((host, port), ca_certs=cert) +                if not pem: +                    self.fail("No server certificate on %s:%s!" % (host, port))                  if support.verbose: -                    sys.stdout.write("%s\n" % x) -            else: -                self.fail("Got server certificate %s for svn.python.org!" % pem) +                    sys.stdout.write("\nVerified certificate for %s:%s is\n%s\n" % (host, port ,pem)) -            pem = ssl.get_server_certificate(("svn.python.org", 443), ca_certs=SVN_PYTHON_ORG_ROOT_CERT) -            if not pem: -                self.fail("No server certificate on svn.python.org:443!") -            if support.verbose: -                sys.stdout.write("\nVerified certificate for svn.python.org:443 is\n%s\n" % pem) +        _test_get_server_certificate('svn.python.org', 443, SVN_PYTHON_ORG_ROOT_CERT) +        if support.IPV6_ENABLED: +            _test_get_server_certificate('ipv6.google.com', 443)      def test_ciphers(self):          remote = ("svn.python.org", 443) @@ -838,6 +955,11 @@ else:                              self.sslconn = None                              if support.verbose and self.server.connectionchatty:                                  sys.stdout.write(" server: connection is now unencrypted...\n") +                        elif stripped == b'CB tls-unique': +                            if support.verbose and self.server.connectionchatty: +                                sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n") +                            data = self.sslconn.get_channel_binding("tls-unique") +                            self.write(repr(data).encode("us-ascii") + b"\n")                          else:                              if (support.verbose and                                  self.server.connectionchatty): @@ -946,12 +1068,11 @@ else:                  def _do_ssl_handshake(self):                      try:                          self.socket.do_handshake() -                    except ssl.SSLError as err: -                        if err.args[0] in (ssl.SSL_ERROR_WANT_READ, -                                           ssl.SSL_ERROR_WANT_WRITE): -                            return -                        elif err.args[0] == ssl.SSL_ERROR_EOF: -                            return self.handle_close() +                    except (ssl.SSLWantReadError, ssl.SSLWantWriteError): +                        return +                    except ssl.SSLEOFError: +                        return self.handle_close() +                    except ssl.SSLError:                          raise                      except socket.error as err:                          if err.args[0] == errno.ECONNABORTED: @@ -1099,7 +1220,12 @@ else:                  if connectionchatty:                      if support.verbose:                          sys.stdout.write(" client:  closing connection.\n") +                stats = { +                    'compression': s.compression(), +                    'cipher': s.cipher(), +                }                  s.close() +                return stats      def try_protocol_combo(server_protocol, client_protocol, expect_success,                             certsreqs=None, server_options=0, client_options=0): @@ -1251,7 +1377,8 @@ else:                  t.join()          @skip_if_broken_ubuntu_ssl -        @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv2'), "need SSLv2") +        @unittest.skipUnless(hasattr(ssl, 'PROTOCOL_SSLv2'), +                             "OpenSSL is compiled without SSLv2 support")          def test_protocol_sslv2(self):              """Connecting to an SSLv2 server with various client options"""              if support.verbose: @@ -1557,6 +1684,15 @@ else:                              )                          # consume data                          s.read() + +                # Make sure sendmsg et al are disallowed to avoid +                # inadvertent disclosure of data and/or corruption +                # of the encrypted data stream +                self.assertRaises(NotImplementedError, s.sendmsg, [b"data"]) +                self.assertRaises(NotImplementedError, s.recvmsg, 100) +                self.assertRaises(NotImplementedError, +                                  s.recvmsg_into, bytearray(100)) +                  s.write(b"over\n")                  s.close() @@ -1625,6 +1761,98 @@ else:                          s.connect((HOST, server.port))              self.assertIn("no shared cipher", str(server.conn_errors[0])) +        @unittest.skipUnless("tls-unique" in ssl.CHANNEL_BINDING_TYPES, +                             "'tls-unique' channel binding not available") +        def test_tls_unique_channel_binding(self): +            """Test tls-unique channel binding.""" +            if support.verbose: +                sys.stdout.write("\n") + +            server = ThreadedEchoServer(CERTFILE, +                                        certreqs=ssl.CERT_NONE, +                                        ssl_version=ssl.PROTOCOL_TLSv1, +                                        cacerts=CERTFILE, +                                        chatty=True, +                                        connectionchatty=False) +            with server: +                s = ssl.wrap_socket(socket.socket(), +                                    server_side=False, +                                    certfile=CERTFILE, +                                    ca_certs=CERTFILE, +                                    cert_reqs=ssl.CERT_NONE, +                                    ssl_version=ssl.PROTOCOL_TLSv1) +                s.connect((HOST, server.port)) +                # get the data +                cb_data = s.get_channel_binding("tls-unique") +                if support.verbose: +                    sys.stdout.write(" got channel binding data: {0!r}\n" +                                     .format(cb_data)) + +                # check if it is sane +                self.assertIsNotNone(cb_data) +                self.assertEqual(len(cb_data), 12) # True for TLSv1 + +                # and compare with the peers version +                s.write(b"CB tls-unique\n") +                peer_data_repr = s.read().strip() +                self.assertEqual(peer_data_repr, +                                 repr(cb_data).encode("us-ascii")) +                s.close() + +                # now, again +                s = ssl.wrap_socket(socket.socket(), +                                    server_side=False, +                                    certfile=CERTFILE, +                                    ca_certs=CERTFILE, +                                    cert_reqs=ssl.CERT_NONE, +                                    ssl_version=ssl.PROTOCOL_TLSv1) +                s.connect((HOST, server.port)) +                new_cb_data = s.get_channel_binding("tls-unique") +                if support.verbose: +                    sys.stdout.write(" got another channel binding data: {0!r}\n" +                                     .format(new_cb_data)) +                # is it really unique +                self.assertNotEqual(cb_data, new_cb_data) +                self.assertIsNotNone(cb_data) +                self.assertEqual(len(cb_data), 12) # True for TLSv1 +                s.write(b"CB tls-unique\n") +                peer_data_repr = s.read().strip() +                self.assertEqual(peer_data_repr, +                                 repr(new_cb_data).encode("us-ascii")) +                s.close() + +        def test_compression(self): +            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +            context.load_cert_chain(CERTFILE) +            stats = server_params_test(context, context, +                                       chatty=True, connectionchatty=True) +            if support.verbose: +                sys.stdout.write(" got compression: {!r}\n".format(stats['compression'])) +            self.assertIn(stats['compression'], { None, 'ZLIB', 'RLE' }) + +        @unittest.skipUnless(hasattr(ssl, 'OP_NO_COMPRESSION'), +                             "ssl.OP_NO_COMPRESSION needed for this test") +        def test_compression_disabled(self): +            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +            context.load_cert_chain(CERTFILE) +            context.options |= ssl.OP_NO_COMPRESSION +            stats = server_params_test(context, context, +                                       chatty=True, connectionchatty=True) +            self.assertIs(stats['compression'], None) + +        def test_dh_params(self): +            # Check we can get a connection with ephemeral Diffie-Hellman +            context = ssl.SSLContext(ssl.PROTOCOL_TLSv1) +            context.load_cert_chain(CERTFILE) +            context.load_dh_params(DHFILE) +            context.set_ciphers("kEDH") +            stats = server_params_test(context, context, +                                       chatty=True, connectionchatty=True) +            cipher = stats["cipher"][0] +            parts = cipher.split("-") +            if "ADH" not in parts and "EDH" not in parts and "DHE" not in parts: +                self.fail("Non-DH cipher: " + cipher[0]) +  def test_main(verbose=False):      if support.verbose: | 
