1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
|
"""Socket wrappers to go with scheduling.py.
Classes:
- SocketTransport: a transport implementation wrapping a socket.
- SslTransport: a transport implementation wrapping SSL around a socket.
- BufferedReader: a buffer wrapping the read end of a transport.
Functions (all coroutines):
- connect(): connect a socket.
- getaddrinfo(): look up an address.
- create_connection(): look up address and return a connected socket for it.
- create_transport(): look up address and return a connected transport.
TODO:
- Improve transport abstraction.
- Make a nice protocol abstraction.
- Unit tests.
- A write() call that isn't a generator (needed so you can substitute it
for sys.stderr, pass it to logging.StreamHandler, etc.).
"""
# Module metadata: original author of this module.
__author__ = 'Guido van Rossum <guido@python.org>'
# Stdlib imports.
import errno
import os
import re
import socket
import ssl

# Local imports.
import scheduling
class SocketTransport:
    """Transport over a plain (non-SSL) socket.

    *sock* must already be connected and in non-blocking mode; every I/O
    call below first waits on the scheduler for the matching readiness.
    """

    def __init__(self, sock):
        self.sock = sock

    def recv(self, n):
        """COROUTINE: Wait once for readability, then read up to n bytes."""
        assert n >= 0, n
        fd = self.sock.fileno()
        scheduling.block_r(fd)
        yield
        return self.sock.recv(n)

    def send(self, data):
        """COROUTINE: Write all of *data*.

        Waits for writability before each send() call and keeps going
        with whatever the kernel did not accept.
        """
        pending = data
        while pending:
            scheduling.block_w(self.sock.fileno())
            yield
            written = self.sock.send(pending)
            assert 0 <= written <= len(pending), (written, len(pending))
            # Slicing off everything leaves an empty buffer, ending the loop.
            pending = pending[written:]

    def close(self):
        """Close the wrapped socket. (Plain method, not a coroutine.)"""
        self.sock.close()
class SslTransport:
    """Transport wrapping a socket in SSL.

    The socket must already be connected at the TCP level in
    non-blocking mode. All I/O goes through ``self.sslsock``.
    """

    def __init__(self, rawsock, sslcontext=None):
        self.rawsock = rawsock
        # NOTE(review): the fallback is a bare PROTOCOL_SSLv23 (an alias of
        # the deprecated PROTOCOL_TLS) context; per the ssl docs such a
        # context does not verify certificates by default — confirm that
        # is acceptable for callers relying on this default.
        self.sslcontext = sslcontext or ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        # The handshake is driven by do_handshake() so it can yield to the
        # scheduler instead of blocking inside wrap_socket().
        self.sslsock = self.sslcontext.wrap_socket(
            self.rawsock, do_handshake_on_connect=False)

    def do_handshake(self):
        """COROUTINE: Finish the SSL handshake."""
        while True:
            try:
                self.sslsock.do_handshake()
            except ssl.SSLWantReadError:
                scheduling.block_r(self.sslsock.fileno())
                yield
            except ssl.SSLWantWriteError:
                scheduling.block_w(self.sslsock.fileno())
                yield
            else:
                break

    def recv(self, n):
        """COROUTINE: Read up to n bytes.

        This blocks until at least one byte is read, or until EOF.
        Only "would block" conditions are retried; any other error
        (connection reset, SSL protocol error, ...) propagates to the
        caller instead of being retried forever.
        """
        while True:
            try:
                return self.sslsock.recv(n)
            except (ssl.SSLWantReadError, BlockingIOError):
                scheduling.block_r(self.sslsock.fileno())
                yield
            except ssl.SSLWantWriteError:
                # The SSL layer must flush data (e.g. renegotiation) first.
                scheduling.block_w(self.sslsock.fileno())
                yield

    def send(self, data):
        """COROUTINE: Send all of data to the socket, blocking as needed.

        Only "would block" conditions are retried; other errors propagate.
        """
        while data:
            try:
                n = self.sslsock.send(data)
            except (ssl.SSLWantWriteError, BlockingIOError):
                scheduling.block_w(self.sslsock.fileno())
                yield
                # Nothing was written: retry with the same buffer rather than
                # falling through with an unbound or stale byte count.
                continue
            except ssl.SSLWantReadError:
                # The SSL layer must read data (e.g. renegotiation) first.
                scheduling.block_r(self.sslsock.fileno())
                yield
                continue
            data = data[n:]

    def close(self):
        """Close the socket. (Not a coroutine.)

        This also closes the raw socket.
        """
        self.sslsock.close()

    # TODO: More SSL-specific methods, e.g. certificate stuff, unwrap(), ...
class BufferedReader:
    """A buffered reader wrapping a transport.

    The transport must provide a ``recv(n)`` coroutine returning up to n
    bytes, or b'' once the peer has closed (which sets ``self.eof``).
    """

    def __init__(self, trans, limit=8192):
        self.trans = trans
        # Minimum fill size for read(); also the maximum line length
        # returned by readline().
        self.limit = limit
        self.buffer = b''
        self.eof = False

    def read(self, n):
        """COROUTINE: Read up to n bytes, blocking at most once.

        Returns b'' when the buffer is empty and EOF has been reached.
        """
        assert n >= 0, n
        if not self.buffer and not self.eof:
            yield from self._fillbuffer(max(n, self.limit))
        return self._getfrombuffer(n)

    def readexactly(self, n):
        """COROUTINE: Read exactly n bytes, or fewer if EOF comes first."""
        blocks = []
        count = 0
        while count < n:
            block = yield from self.read(n - count)
            if not block:
                # EOF: read() will keep returning b'', so stop instead of
                # looping forever.
                break
            blocks.append(block)
            count += len(block)
        return b''.join(blocks)

    def readline(self):
        """COROUTINE: Read up to newline or limit, whichever comes first."""
        end = self.buffer.find(b'\n') + 1  # Point past newline, or 0.
        while not end and not self.eof and len(self.buffer) < self.limit:
            # Only search the newly appended bytes for a newline.
            anchor = len(self.buffer)
            yield from self._fillbuffer(self.limit)
            end = self.buffer.find(b'\n', anchor) + 1
        if not end:
            end = len(self.buffer)
        if end > self.limit:
            end = self.limit
        return self._getfrombuffer(end)

    def _getfrombuffer(self, n):
        """Read up to n bytes without blocking (not a coroutine)."""
        if n >= len(self.buffer):
            result, self.buffer = self.buffer, b''
        else:
            result, self.buffer = self.buffer[:n], self.buffer[n:]
        return result

    def _fillbuffer(self, n):
        """COROUTINE: Append one recv() of up to n bytes from the transport.

        An empty result marks EOF.
        """
        assert not self.eof, '_fillbuffer called at eof'
        data = yield from self.trans.recv(n)
        if data:
            self.buffer += data
        else:
            self.eof = True
def connect(sock, address):
    """COROUTINE: Connect a non-blocking socket to an address.

    Starts the connect, waits until the socket is writable, then reads
    SO_ERROR to learn the outcome.

    Raises:
        socket.error (OSError): if connect() fails immediately with
            anything other than EINPROGRESS, or if SO_ERROR reports a
            failure. In the latter case the errno and message describe the
            actual error (refused, timed out, unreachable, ...).
    """
    try:
        sock.connect(address)
    except socket.error as err:
        # EINPROGRESS just means the non-blocking connect is under way.
        if err.errno != errno.EINPROGRESS:
            raise
    scheduling.block_w(sock.fileno())
    yield
    error_code = sock.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
    if error_code != 0:
        # OSError(errno, msg) builds the matching subclass, e.g.
        # ConnectionRefusedError for ECONNREFUSED.
        raise OSError(error_code, os.strerror(error_code))
def getaddrinfo(host, port, af=0, socktype=0, proto=0):
    """COROUTINE: Resolve (host, port) and return its address infos.

    The blocking socket.getaddrinfo() call is run through
    scheduling.call_in_thread(). The result is socket.getaddrinfo()'s
    list of (af, socktype, protocol, canonname, address) tuples.
    """
    lookup_args = (host, port, af, socktype, proto)
    return (yield from scheduling.call_in_thread(socket.getaddrinfo,
                                                 *lookup_args))
def create_connection(host, port, af=0, socktype=socket.SOCK_STREAM):
    """COROUTINE: Look up address and create a socket connected to it.

    A dotted-quad IPv4 literal host skips the lookup (when *af* is 0 or
    AF_INET); otherwise getaddrinfo() is consulted. Candidate addresses
    are tried in order until one connects.

    Args:
        host: host name or numeric address.
        port: port, passed through to the lookup / connect.
        af: address family filter; 0 means any family.
        socktype: socket type to create (default SOCK_STREAM).

    Returns:
        The connected socket, left in non-blocking mode.

    Raises:
        socket.error: the first error seen, if every candidate fails.
    """
    infos = None
    # Fast path only for str hosts: re.match() would raise on bytes/None,
    # which getaddrinfo() itself accepts.
    if isinstance(host, str) and af in (0, socket.AF_INET):
        match = re.match(r'(\d+)\.(\d+)\.(\d+)\.(\d+)\Z', host)
        if match and all(0 <= int(octet) <= 255 for octet in match.groups()):
            # Keep SOL_TCP for stream sockets as before; let the OS choose
            # the default protocol for any other socket type.
            proto = socket.SOL_TCP if socktype == socket.SOCK_STREAM else 0
            infos = [(socket.AF_INET, socktype, proto, '', (host, port))]
    if infos is None:
        infos = yield from getaddrinfo(host, port, af=af, socktype=socktype)
    assert infos, 'No address info for (%r, %r)' % (host, port)
    first_error = None
    for family, kind, proto, _canonname, address in infos:
        try:
            sock = socket.socket(family, kind, proto)
        except socket.error as err:
            if first_error is None:
                first_error = err
            continue
        try:
            sock.setblocking(False)
            yield from connect(sock, address)
        except socket.error as err:
            sock.close()
            if first_error is None:
                first_error = err
        except BaseException:
            # E.g. GeneratorExit if this coroutine is closed mid-connect:
            # don't leak the descriptor.
            sock.close()
            raise
        else:
            return sock
    raise first_error
def create_transport(host, port, af=0, ssl=None):
    """COROUTINE: Look up address and create a transport connected to it.

    Args:
        host, port, af: passed to create_connection().
        ssl: True for an SslTransport, False for a SocketTransport; None
            (the default) means use SSL only when port == 443. Note this
            parameter shadows the ``ssl`` module inside this function.

    Returns:
        A SocketTransport, or an SslTransport whose handshake completed.
    """
    if ssl is None:
        ssl = (port == 443)
    sock = yield from create_connection(host, port, af=af)
    if not ssl:
        return SocketTransport(sock)
    trans = None
    try:
        trans = SslTransport(sock)
        yield from trans.do_handshake()
    except BaseException:
        # Don't leak the connection if wrapping or the handshake fails.
        # wrap_socket() may have taken over the descriptor, so close the
        # transport if it exists and the raw socket regardless (closing an
        # already-detached socket is a no-op).
        if trans is not None:
            trans.close()
        sock.close()
        raise
    return trans
|