summaryrefslogtreecommitdiff
path: root/trollius/py3_ssl.py
blob: e044f06416316d2daef99ed69e7e3820c9afb347 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
"""
Backport SSL functions and exceptions:
- BACKPORT_SSL_ERRORS (bool)
- SSLWantReadError, SSLWantWriteError, SSLEOFError
- BACKPORT_SSL_CONTEXT (bool)
- SSLContext
- wrap_socket()
- wrap_ssl_error()
"""
import errno
import ssl
import sys
from trollius.py33_exceptions import _wrap_error

# Names exported by ``from trollius.py3_ssl import *``.  Kept as a list so
# other modules can concatenate it with their own ``__all__`` lists.
__all__ = [
    "SSLContext",
    "BACKPORT_SSL_ERRORS",
    "BACKPORT_SSL_CONTEXT",
    "SSLWantReadError",
    "SSLWantWriteError",
    "SSLEOFError",
]

try:
    # Python 3.3+ provides these ssl.SSLError subclasses natively.
    SSLWantReadError, SSLWantWriteError, SSLEOFError = (
        ssl.SSLWantReadError,
        ssl.SSLWantWriteError,
        ssl.SSLEOFError,
    )
except AttributeError:
    # Python < 3.3: define look-alike subclasses of ssl.SSLError so the same
    # names can be caught on every supported Python version.
    class SSLWantReadError(ssl.SSLError):
        """Backport of ssl.SSLWantReadError (Python 3.3+)."""

    class SSLWantWriteError(ssl.SSLError):
        """Backport of ssl.SSLWantWriteError (Python 3.3+)."""

    class SSLEOFError(ssl.SSLError):
        """Backport of ssl.SSLEOFError (Python 3.3+)."""

    BACKPORT_SSL_ERRORS = True
else:
    BACKPORT_SSL_ERRORS = False


try:
    SSLContext = ssl.SSLContext
except AttributeError:
    # Python < 3.2 (and Python 2 before 2.7.9): there is no ssl.SSLContext,
    # so emulate the small subset of its API that is needed here.
    BACKPORT_SSL_CONTEXT = True

    if (sys.version_info < (2, 6, 6)):
        # SSLSocket constructor has bugs in Python older than 2.6.6:
        #    http://bugs.python.org/issue5103
        #    http://bugs.python.org/issue7943
        from socket import socket, error as socket_error, _delegate_methods
        import _ssl

        class BackportSSLSocket(ssl.SSLSocket):
            """ssl.SSLSocket with a replacement __init__() for Python < 2.6.6."""

            # Override SSLSocket.__init__()
            def __init__(self, sock, keyfile=None, certfile=None,
                         server_side=False, cert_reqs=ssl.CERT_NONE,
                         ssl_version=ssl.PROTOCOL_SSLv23, ca_certs=None,
                         do_handshake_on_connect=True,
                         suppress_ragged_eofs=True):
                socket.__init__(self, _sock=sock._sock)
                # The initializer for socket overrides the methods send(), recv(), etc.
                # in the instance, which we don't need -- but we want to provide the
                # methods defined in SSLSocket.
                for attr in _delegate_methods:
                    try:
                        delattr(self, attr)
                    except AttributeError:
                        pass

                if certfile and not keyfile:
                    keyfile = certfile
                # see if it's connected
                try:
                    socket.getpeername(self)
                except socket_error as e:
                    if e.errno != errno.ENOTCONN:
                        raise
                    # no, no connection yet
                    self._connected = False
                    self._sslobj = None
                else:
                    # yes, create the SSL object
                    self._connected = True
                    self._sslobj = _ssl.sslwrap(self._sock, server_side,
                                                keyfile, certfile,
                                                cert_reqs, ssl_version, ca_certs)
                    if do_handshake_on_connect:
                        self.do_handshake()
                self.keyfile = keyfile
                self.certfile = certfile
                self.cert_reqs = cert_reqs
                self.ssl_version = ssl_version
                self.ca_certs = ca_certs
                self.do_handshake_on_connect = do_handshake_on_connect
                self.suppress_ragged_eofs = suppress_ragged_eofs
                self._makefile_refs = 0

        def wrap_socket(sock, keyfile=None, certfile=None,
                        server_side=False, cert_reqs=ssl.CERT_NONE,
                        ssl_version=ssl.PROTOCOL_SSLv23, ca_certs=None,
                        do_handshake_on_connect=True,
                        suppress_ragged_eofs=True):
            """Wrap *sock* in a BackportSSLSocket (same signature as ssl.wrap_socket)."""
            return BackportSSLSocket(sock, keyfile=keyfile, certfile=certfile,
                             server_side=server_side, cert_reqs=cert_reqs,
                             ssl_version=ssl_version, ca_certs=ca_certs,
                             do_handshake_on_connect=do_handshake_on_connect,
                             suppress_ragged_eofs=suppress_ragged_eofs)
    else:
        wrap_socket = ssl.wrap_socket


    class SSLContext(object):
        """Minimal stand-in for ssl.SSLContext on Pythons that lack it.

        It only records the protocol and the certificate/key file paths.
        Wrapping is delegated to the module-level wrap_socket().
        """

        def __init__(self, protocol=ssl.PROTOCOL_SSLv23):
            self.protocol = protocol
            self.certfile = None
            self.keyfile = None

        def load_cert_chain(self, certfile, keyfile=None):
            """Remember the certificate (and optional private key) paths.

            keyfile defaults to None, as in ssl.SSLContext.load_cert_chain().
            """
            self.certfile = certfile
            self.keyfile = keyfile

        def wrap_socket(self, sock, **kw):
            """Wrap *sock* using the stored protocol and cert/key files.

            Extra keyword arguments are forwarded to the module-level
            wrap_socket().
            """
            return wrap_socket(sock,
                               ssl_version=self.protocol,
                               certfile=self.certfile,
                               keyfile=self.keyfile,
                               **kw)

        @property
        def verify_mode(self):
            # The backport never verifies peer certificates itself.
            return ssl.CERT_NONE
else:
    BACKPORT_SSL_CONTEXT = False
    # Look up ssl.wrap_socket() separately from ssl.SSLContext: Python 3.12
    # removed the former while keeping the latter.  Sharing one try-block
    # used to send 3.12+ into the "no SSLContext" branch, which crashed at
    # import time.
    try:
        wrap_socket = ssl.wrap_socket
    except AttributeError:
        def wrap_socket(sock, keyfile=None, certfile=None,
                        server_side=False, cert_reqs=ssl.CERT_NONE,
                        ssl_version=ssl.PROTOCOL_SSLv23, ca_certs=None,
                        do_handshake_on_connect=True,
                        suppress_ragged_eofs=True, ciphers=None):
            """Replacement for ssl.wrap_socket(), which Python 3.12 removed.

            Builds a one-off ssl.SSLContext from the legacy arguments and
            wraps *sock* with it, in the same way the removed stdlib
            function did.

            Raises:
                ValueError: server_side is true without a certfile, or a
                    keyfile is given without a certfile.
            """
            if server_side and not certfile:
                raise ValueError("certfile must be specified for server-side "
                                 "operations")
            if keyfile and not certfile:
                raise ValueError("certfile must be specified")
            context = ssl.SSLContext(ssl_version)
            context.verify_mode = cert_reqs
            if ca_certs:
                context.load_verify_locations(ca_certs)
            if certfile:
                context.load_cert_chain(certfile, keyfile)
            if ciphers:
                context.set_ciphers(ciphers)
            return context.wrap_socket(
                sock, server_side=server_side,
                do_handshake_on_connect=do_handshake_on_connect,
                suppress_ragged_eofs=suppress_ragged_eofs)


if not BACKPORT_SSL_ERRORS:
    def wrap_ssl_error(func, *args, **kw):
        """Call func(*args, **kw) and return its result unchanged.

        This Python raises SSLWantReadError/SSLWantWriteError/SSLEOFError
        natively, so no translation is performed.
        """
        return func(*args, **kw)
else:
    # Legacy ssl.SSLError error codes (found in exc.args[0]) mapped to the
    # backport exception subclasses defined earlier in this module.
    _MAP_ERRORS = dict([
        (ssl.SSL_ERROR_WANT_READ, SSLWantReadError),
        (ssl.SSL_ERROR_WANT_WRITE, SSLWantWriteError),
        (ssl.SSL_ERROR_EOF, SSLEOFError),
    ])

    def wrap_ssl_error(func, *args, **kw):
        """Call func(*args, **kw), translating legacy ssl.SSLError codes.

        If func raises ssl.SSLError with a non-empty args tuple, the error and
        its first argument are passed to _wrap_error() together with
        _MAP_ERRORS.  NOTE(review): _wrap_error is presumably expected to
        re-raise the mapped subclass; confirm in trollius.py33_exceptions.
        If it returns instead, the original ssl.SSLError is re-raised.
        """
        try:
            result = func(*args, **kw)
        except ssl.SSLError as exc:
            if len(exc.args) > 0:
                _wrap_error(exc, _MAP_ERRORS, exc.args[0])
            raise
        return result