1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
|
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
"""asyncio library query support"""
import socket
import asyncio
import sys
import dns._asyncbackend
import dns.exception
# True when running on Windows.  Backend.make_socket() refuses to create
# destinationless datagram sockets in that case, and
# Backend.datagram_connection_required() reports this flag to callers.
_is_win32 = sys.platform == 'win32'
def _get_running_loop():
    """Return the event loop to use for the current coroutine.

    ``asyncio.get_running_loop()`` is preferred; if the ``asyncio`` module
    lacks that function (the lookup raises ``AttributeError``) the result
    of ``asyncio.get_event_loop()`` is returned instead.
    """
    try:
        loop = asyncio.get_running_loop()
    except AttributeError:  # pragma: no cover
        loop = asyncio.get_event_loop()
    return loop
class _DatagramProtocol:
    """Datagram protocol that resolves one pending future per event.

    A reader installs a future in ``recvfrom``; the next datagram, error or
    connection loss resolves that future and detaches it.  Events that
    arrive while no live future is installed are dropped.
    """

    def __init__(self):
        # Transport handed over by connection_made(); None until then.
        self.transport = None
        # Future awaiting the next datagram, or None when nobody is waiting.
        self.recvfrom = None

    def connection_made(self, transport):
        """Remember the transport so close() can shut it down."""
        self.transport = transport

    def _take_waiter(self):
        """Detach and return the pending future, or None if there is none.

        A future that is already done (for example, cancelled because the
        reader timed out) is discarded: resolving it again would raise
        ``asyncio.InvalidStateError``.
        """
        waiter, self.recvfrom = self.recvfrom, None
        if waiter is None or waiter.done():
            return None
        return waiter

    def datagram_received(self, data, addr):
        """Deliver ``(data, addr)`` to the pending reader, if any."""
        waiter = self._take_waiter()
        if waiter is not None:
            waiter.set_result((data, addr))

    def error_received(self, exc):  # pragma: no cover
        """Fail the pending reader, if any, with *exc*."""
        waiter = self._take_waiter()
        if waiter is not None:
            waiter.set_exception(exc)

    def connection_lost(self, exc):
        """Fail the pending reader, if any, because the transport closed.

        *exc* is ``None`` on a clean close; a future cannot be failed with
        ``None``, so the reader gets ``EOFError`` in that case.
        """
        waiter = self._take_waiter()
        if waiter is not None:
            waiter.set_exception(EOFError() if exc is None else exc)

    def close(self):
        """Close the underlying transport."""
        self.transport.close()
async def _maybe_wait_for(awaitable, timeout):
    """Await *awaitable*, optionally bounded by *timeout* seconds.

    *timeout* of ``None`` waits indefinitely.  Any number, including ``0``,
    is passed to ``asyncio.wait_for``; a zero timeout therefore means the
    deadline has already passed rather than "no limit".

    Returns whatever *awaitable* produces.

    Raises ``dns.exception.Timeout`` if the timeout elapses first.
    """
    if timeout is None:
        return await awaitable
    try:
        return await asyncio.wait_for(awaitable, timeout)
    except asyncio.TimeoutError:
        raise dns.exception.Timeout(timeout=timeout)
class DatagramSocket(dns._asyncbackend.DatagramSocket):
    """Datagram socket built on an asyncio transport/protocol pair."""

    def __init__(self, family, transport, protocol):
        self.family = family
        self.transport = transport
        self.protocol = protocol

    async def sendto(self, what, destination, timeout):  # pragma: no cover
        """Queue *what* for sending to *destination*.

        *timeout* is ignored.  Returns the number of bytes handed to the
        transport.
        """
        # no timeout for asyncio sendto
        self.transport.sendto(what, destination)
        return len(what)

    async def recvfrom(self, size, timeout):
        """Wait for the next datagram and return ``(data, address)``.

        *size* is ignored.  *timeout* is applied by ``_maybe_wait_for``.
        Only one receive may be outstanding at a time.
        """
        # ignore size as there's no way I know to tell protocol about it
        done = _get_running_loop().create_future()
        assert self.protocol.recvfrom is None
        self.protocol.recvfrom = done
        try:
            await _maybe_wait_for(done, timeout)
            return done.result()
        finally:
            # Detach our future on every exit path (timeout, cancellation,
            # error) so the next recvfrom() does not trip the assertion
            # above.  Only clear it if it is still ours.
            if self.protocol.recvfrom is done:
                self.protocol.recvfrom = None

    async def close(self):
        """Close the socket by closing the protocol."""
        self.protocol.close()

    async def getpeername(self):
        """Return the transport's ``'peername'`` extra info."""
        return self.transport.get_extra_info('peername')

    async def getsockname(self):
        """Return the transport's ``'sockname'`` extra info."""
        return self.transport.get_extra_info('sockname')
class StreamSocket(dns._asyncbackend.StreamSocket):
    """Stream socket wrapping a reader/writer pair.

    Backend.make_socket() builds instances from the pair returned by
    ``asyncio.open_connection``.
    """

    def __init__(self, af, reader, writer):
        self.family = af
        self.reader = reader
        self.writer = writer

    async def sendall(self, what, timeout):
        """Write *what* and wait, up to *timeout*, for the writer to drain."""
        stream = self.writer
        stream.write(what)
        flushed = stream.drain()
        return await _maybe_wait_for(flushed, timeout)

    async def recv(self, count, timeout):
        """Read up to *count* bytes, waiting at most *timeout*."""
        pending_read = self.reader.read(count)
        return await _maybe_wait_for(pending_read, timeout)

    async def close(self):
        """Close the writer and wait for the close to finish if possible."""
        stream = self.writer
        stream.close()
        try:
            await stream.wait_closed()
        except AttributeError:  # pragma: no cover
            # Writer has no wait_closed(); closing is all that can be done.
            pass

    async def getpeername(self):
        """Return the writer's ``'peername'`` extra info."""
        return self.writer.get_extra_info('peername')

    async def getsockname(self):
        """Return the writer's ``'sockname'`` extra info."""
        return self.writer.get_extra_info('sockname')
class Backend(dns._asyncbackend.Backend):
    """Backend that creates sockets on the running asyncio event loop."""

    def name(self):
        """Return the backend name, ``'asyncio'``."""
        return 'asyncio'

    async def make_socket(self, af, socktype, proto=0,
                          source=None, destination=None, timeout=None,
                          ssl_context=None, server_hostname=None):
        """Create a DatagramSocket or StreamSocket.

        For ``socket.SOCK_DGRAM`` a datagram endpoint is created, bound to
        *source* and connected to *destination* when given.  For
        ``socket.SOCK_STREAM`` a connection to *destination* is opened,
        bounded by *timeout*; *ssl_context* and *server_hostname* are passed
        through to ``asyncio.open_connection``.

        Raises ``NotImplementedError`` for any other socket type, and for a
        datagram socket without a destination on Windows.
        """
        wants_datagram = socktype == socket.SOCK_DGRAM
        if wants_datagram and destination is None and _is_win32:
            raise NotImplementedError('destinationless datagram sockets '
                                      'are not supported by asyncio '
                                      'on Windows')
        # Looked up before dispatching on the socket type, so a missing
        # running loop is reported first.
        loop = _get_running_loop()
        if wants_datagram:
            endpoint = loop.create_datagram_endpoint(
                _DatagramProtocol, local_addr=source, family=af,
                proto=proto, remote_addr=destination)
            transport, protocol = await endpoint
            return DatagramSocket(af, transport, protocol)
        if socktype == socket.SOCK_STREAM:
            connecting = asyncio.open_connection(
                destination[0],
                destination[1],
                ssl=ssl_context,
                family=af,
                proto=proto,
                local_addr=source,
                server_hostname=server_hostname)
            reader, writer = await _maybe_wait_for(connecting, timeout)
            return StreamSocket(af, reader, writer)
        raise NotImplementedError(
            f'unsupported socket type {socktype}')  # pragma: no cover

    async def sleep(self, interval):
        """Sleep for *interval* seconds without blocking the event loop."""
        await asyncio.sleep(interval)

    def datagram_connection_required(self):
        """Return True when datagram sockets must have a destination."""
        return _is_win32
|