summaryrefslogtreecommitdiff
path: root/Lib/test/test_asyncio/test_sslproto.py
blob: a7498e85c25ff212e1550ad1e272e15763711c8a (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
"""Tests for asyncio/sslproto.py."""

import logging
import unittest
from unittest import mock
try:
    import ssl
except ImportError:
    ssl = None

import asyncio
from asyncio import log
from asyncio import sslproto
from asyncio import tasks
from test.test_asyncio import utils as test_utils


@unittest.skipIf(ssl is None, 'No ssl module')
class SslProtoHandshakeTests(test_utils.TestCase):
    """Handshake and connection life-cycle tests for sslproto.SSLProtocol.

    Tests that go through the connection_made() helper run the protocol
    against a mock transport and a mock ``_SSLPipe``.
    """

    def setUp(self):
        """Create a fresh event loop and register it with the test case."""
        super().setUp()
        self.loop = asyncio.new_event_loop()
        self.set_event_loop(self.loop)

    def ssl_protocol(self, waiter=None):
        """Return an SSLProtocol wrapping a plain asyncio.Protocol.

        The protocol is created with a dummy SSL context and a 0.1 second
        handshake timeout.  *waiter* is passed through to SSLProtocol
        unchanged.  The application transport is closed on test cleanup.
        """
        sslcontext = test_utils.dummy_ssl_context()
        app_proto = asyncio.Protocol()
        proto = sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
                                     ssl_handshake_timeout=0.1)
        self.assertIs(proto._app_transport.get_protocol(), app_proto)
        self.addCleanup(proto._app_transport.close)
        return proto

    def connection_made(self, ssl_proto, do_handshake=None):
        """Call ssl_proto.connection_made() with mocked collaborators.

        ``asyncio.sslproto._SSLPipe`` is patched for the duration of the
        call so that it returns a mock pipe.  *do_handshake*, when given,
        is installed as the side effect of the mock pipe's do_handshake();
        otherwise a stub is used that returns an empty list and never
        invokes the handshake callback.

        Return the mock transport handed to the protocol.
        """
        transport = mock.Mock()
        sslpipe = mock.Mock()
        sslpipe.shutdown.return_value = b''
        if do_handshake:
            sslpipe.do_handshake.side_effect = do_handshake
        else:
            def mock_handshake(callback):
                return []
            sslpipe.do_handshake.side_effect = mock_handshake
        with mock.patch('asyncio.sslproto._SSLPipe', return_value=sslpipe):
            ssl_proto.connection_made(transport)
        return transport

    def test_cancel_handshake(self):
        # Python issue #23197: cancelling a handshake must not raise an
        # exception or log an error, even if the handshake failed
        waiter = asyncio.Future(loop=self.loop)
        ssl_proto = self.ssl_protocol(waiter)
        handshake_fut = asyncio.Future(loop=self.loop)

        def do_handshake(callback):
            # Report a failed handshake, then signal the test that the
            # handshake step has run.
            exc = Exception()
            callback(exc)
            handshake_fut.set_result(None)
            return []

        # Cancel before the connection is made so the handshake failure
        # is reported to an already-cancelled waiter.
        waiter.cancel()
        self.connection_made(ssl_proto, do_handshake)

        with test_utils.disable_logger():
            self.loop.run_until_complete(handshake_fut)

    def test_handshake_timeout(self):
        # bpo-29970: Check that a connection is aborted if handshake is not
        # completed in timeout period, instead of remaining open indefinitely
        ssl_proto = self.ssl_protocol()
        transport = self.connection_made(ssl_proto)

        # Sleep for longer than the 0.1 second handshake timeout configured
        # by ssl_protocol(); the default stub never completes the handshake.
        with test_utils.disable_logger():
            self.loop.run_until_complete(tasks.sleep(0.2, loop=self.loop))
        self.assertTrue(transport.abort.called)

    def test_handshake_timeout_zero(self):
        """A zero handshake timeout is rejected with ValueError."""
        sslcontext = test_utils.dummy_ssl_context()
        app_proto = mock.Mock()
        waiter = mock.Mock()
        with self.assertRaisesRegex(ValueError, 'a positive number'):
            sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
                                 ssl_handshake_timeout=0)

    def test_handshake_timeout_negative(self):
        """A negative handshake timeout is rejected with ValueError."""
        sslcontext = test_utils.dummy_ssl_context()
        app_proto = mock.Mock()
        waiter = mock.Mock()
        with self.assertRaisesRegex(ValueError, 'a positive number'):
            sslproto.SSLProtocol(self.loop, app_proto, sslcontext, waiter,
                                 ssl_handshake_timeout=-10)

    def test_eof_received_waiter(self):
        """EOF after connection_made() fails the waiter with
        ConnectionResetError."""
        waiter = asyncio.Future(loop=self.loop)
        ssl_proto = self.ssl_protocol(waiter)
        self.connection_made(ssl_proto)
        ssl_proto.eof_received()
        test_utils.run_briefly(self.loop)
        self.assertIsInstance(waiter.exception(), ConnectionResetError)

    def test_fatal_error_no_name_error(self):
        # From issue #363.
        # _fatal_error() generates a NameError if sslproto.py
        # does not import base_events.
        waiter = asyncio.Future(loop=self.loop)
        ssl_proto = self.ssl_protocol(waiter)
        # Temporarily turn off error logging so as not to spoil test output.
        # Save the logger's *own* level rather than its effective level:
        # restoring an inherited effective level via setLevel() would pin it
        # on this logger and leak a configuration change into later tests.
        saved_level = log.logger.level
        log.logger.setLevel(logging.FATAL)
        try:
            ssl_proto._fatal_error(None)
        finally:
            # Restore error logging.
            log.logger.setLevel(saved_level)

    def test_connection_lost(self):
        # From issue #472.
        # "yield from waiter" hangs if connection_lost() was called.
        waiter = asyncio.Future(loop=self.loop)
        ssl_proto = self.ssl_protocol(waiter)
        self.connection_made(ssl_proto)
        ssl_proto.connection_lost(ConnectionAbortedError)
        test_utils.run_briefly(self.loop)
        self.assertIsInstance(waiter.exception(), ConnectionAbortedError)

    def test_close_during_handshake(self):
        # bpo-29743 Closing transport during handshake process leaks socket
        waiter = asyncio.Future(loop=self.loop)
        ssl_proto = self.ssl_protocol(waiter)

        # The default stub handshake never invokes its callback, so the
        # handshake is still pending when the transport is closed below.
        transport = self.connection_made(ssl_proto)
        test_utils.run_briefly(self.loop)

        ssl_proto._app_transport.close()
        self.assertTrue(transport.abort.called)

    def test_get_extra_info_on_closed_connection(self):
        """_get_extra_info() falls back to the default when no transport
        is attached, both before connection_made() and after
        connection_lost()."""
        waiter = asyncio.Future(loop=self.loop)
        ssl_proto = self.ssl_protocol(waiter)
        self.assertIsNone(ssl_proto._get_extra_info('socket'))
        default = object()
        self.assertIs(ssl_proto._get_extra_info('socket', default), default)
        self.connection_made(ssl_proto)
        self.assertIsNotNone(ssl_proto._get_extra_info('socket'))
        ssl_proto.connection_lost(None)
        self.assertIsNone(ssl_proto._get_extra_info('socket'))

    def test_set_new_app_protocol(self):
        """set_protocol() on the app transport replaces the protocol seen
        by both the transport and the SSLProtocol."""
        waiter = asyncio.Future(loop=self.loop)
        ssl_proto = self.ssl_protocol(waiter)
        new_app_proto = asyncio.Protocol()
        ssl_proto._app_transport.set_protocol(new_app_proto)
        self.assertIs(ssl_proto._app_transport.get_protocol(), new_app_proto)
        self.assertIs(ssl_proto._app_protocol, new_app_proto)


# Run the tests in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()