summaryrefslogtreecommitdiff
path: root/dns/_asyncio_backend.py
blob: 80c31dcd19e920745a7ea6028455072d6e097445 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license

"""asyncio library query support"""

import socket
import asyncio
import sys

import dns._asyncbackend
import dns.exception


# True when the interpreter reports the 'win32' platform.  Used below to
# reject destinationless datagram sockets and to tell callers that datagram
# sockets must be connected on that platform.
_is_win32 = sys.platform == 'win32'

def _get_running_loop():
    """Return the currently running event loop.

    Uses ``asyncio.get_running_loop()`` when the asyncio module provides it
    and falls back to ``asyncio.get_event_loop()`` when it does not.
    """
    lookup = getattr(asyncio, 'get_running_loop', None)
    if lookup is None:  # pragma: no cover
        return asyncio.get_event_loop()
    return lookup()


class _DatagramProtocol:
    """asyncio datagram protocol that delivers events to a single waiter.

    ``recvfrom`` holds at most one future installed by whoever is waiting
    for the next datagram.  The first event that arrives (datagram, error or
    connection loss) completes that future and detaches it, so a future is
    never completed twice.
    """

    def __init__(self):
        self.transport = None
        self.recvfrom = None

    def _pop_waiter(self):
        """Detach the installed future and return it if still pending.

        Returns ``None`` when no future is installed or when the installed
        future is already done (for example, cancelled because the waiter
        timed out); completing such a future again would raise
        ``asyncio.InvalidStateError``.
        """
        waiter = self.recvfrom
        self.recvfrom = None
        if waiter is None or waiter.done():
            return None
        return waiter

    def connection_made(self, transport):
        """Remember the transport so that ``close()`` can close it."""
        self.transport = transport

    def datagram_received(self, data, addr):
        """Hand ``(data, addr)`` to the pending waiter, if there is one."""
        waiter = self._pop_waiter()
        if waiter is not None:
            waiter.set_result((data, addr))

    def error_received(self, exc):  # pragma: no cover
        """Fail the pending waiter, if there is one, with *exc*."""
        waiter = self._pop_waiter()
        if waiter is not None:
            waiter.set_exception(exc)

    def connection_lost(self, exc):
        """Fail the pending waiter, if there is one, when the link goes away.

        *exc* is ``None`` when the transport was closed normally; a future
        cannot be failed with ``None``, so ``EOFError`` is used instead.
        """
        waiter = self._pop_waiter()
        if waiter is not None:
            if exc is None:
                exc = EOFError()
            waiter.set_exception(exc)

    def close(self):
        """Close the underlying transport."""
        self.transport.close()


async def _maybe_wait_for(awaitable, timeout):
    """Await *awaitable*, optionally bounded by *timeout* seconds.

    *timeout* of ``None`` means wait indefinitely.  Any other value,
    including ``0`` (an already-expired deadline), is passed to
    ``asyncio.wait_for``.

    Returns whatever *awaitable* produces.

    Raises ``dns.exception.Timeout`` if the wait times out.
    """
    # Compare against None rather than testing truthiness: a timeout of 0
    # must expire immediately instead of turning into an unbounded wait.
    if timeout is not None:
        try:
            return await asyncio.wait_for(awaitable, timeout)
        except asyncio.TimeoutError:
            raise dns.exception.Timeout(timeout=timeout)
    else:
        return await awaitable


class DatagramSocket(dns._asyncbackend.DatagramSocket):
    """Datagram socket built on an asyncio transport/protocol pair."""

    def __init__(self, family, transport, protocol):
        self.family = family
        self.transport = transport
        self.protocol = protocol

    async def sendto(self, what, destination, timeout):  # pragma: no cover
        """Queue *what* for sending to *destination*; *timeout* is unused."""
        # no timeout for asyncio sendto
        self.transport.sendto(what, destination)

    async def recvfrom(self, size, timeout):
        """Wait for the next datagram and return ``(data, address)``.

        *size* is ignored.  *timeout* is handed to ``_maybe_wait_for``, so
        whatever it raises on expiry propagates from here.  Only one
        ``recvfrom`` may be outstanding at a time.
        """
        # ignore size as there's no way I know to tell protocol about it
        done = _get_running_loop().create_future()
        assert self.protocol.recvfrom is None
        self.protocol.recvfrom = done
        try:
            await _maybe_wait_for(done, timeout)
            return done.result()
        finally:
            # Always detach our future, including on timeout, cancellation
            # or error.  Otherwise the stale future stays installed, the next
            # recvfrom() trips the assertion above, and a late datagram would
            # be delivered to a future that is already done.
            if self.protocol.recvfrom is done:
                self.protocol.recvfrom = None

    async def close(self):
        """Close the socket by closing the protocol."""
        self.protocol.close()

    async def getpeername(self):
        """Return the transport's 'peername' extra info."""
        return self.transport.get_extra_info('peername')

    async def getsockname(self):
        """Return the transport's 'sockname' extra info."""
        return self.transport.get_extra_info('sockname')


class StreamSocket(dns._asyncbackend.StreamSocket):
    """Stream socket backed by an asyncio reader/writer pair."""

    def __init__(self, af, reader, writer):
        self.family = af
        self.reader, self.writer = reader, writer

    async def sendall(self, what, timeout):
        """Write *what*, then wait (bounded by *timeout*) for the drain."""
        stream = self.writer
        stream.write(what)
        draining = stream.drain()
        return await _maybe_wait_for(draining, timeout)

    async def recv(self, count, timeout):
        """Read up to *count* bytes, bounded by *timeout*."""
        pending_read = self.reader.read(count)
        return await _maybe_wait_for(pending_read, timeout)

    async def close(self):
        """Close the writer and wait for the close to finish."""
        stream = self.writer
        stream.close()
        try:
            await stream.wait_closed()
        except AttributeError:  # pragma: no cover
            # Presumably a writer without wait_closed(); nothing to wait on.
            pass

    async def getpeername(self):
        """Return the writer's 'peername' extra info."""
        extra = self.writer.get_extra_info
        return extra('peername')

    async def getsockname(self):
        """Return the writer's 'sockname' extra info."""
        extra = self.writer.get_extra_info
        return extra('sockname')


class Backend(dns._asyncbackend.Backend):
    """asyncio implementation of the async backend interface."""

    def name(self):
        """Return the backend's name."""
        return 'asyncio'

    async def make_socket(self, af, socktype, proto=0,
                          source=None, destination=None, timeout=None,
                          ssl_context=None, server_hostname=None):
        """Create a ``DatagramSocket`` or ``StreamSocket`` for *socktype*.

        *timeout* and the TLS arguments apply only to stream sockets.
        Raises ``NotImplementedError`` for a destinationless datagram socket
        on Windows and for any other socket type.
        """
        wants_datagram = socktype == socket.SOCK_DGRAM
        if wants_datagram and destination is None and _is_win32:
            raise NotImplementedError('destinationless datagram sockets '
                                      'are not supported by asyncio '
                                      'on Windows')
        loop = _get_running_loop()
        if wants_datagram:
            endpoint = await loop.create_datagram_endpoint(
                _DatagramProtocol, source, family=af,
                proto=proto, remote_addr=destination)
            transport, protocol = endpoint
            return DatagramSocket(af, transport, protocol)
        if socktype != socket.SOCK_STREAM:
            raise NotImplementedError('unsupported socket ' +
                                      f'type {socktype}')  # pragma: no cover
        host, port = destination[0], destination[1]
        connecting = asyncio.open_connection(host,
                                             port,
                                             ssl=ssl_context,
                                             family=af,
                                             proto=proto,
                                             local_addr=source,
                                             server_hostname=server_hostname)
        reader, writer = await _maybe_wait_for(connecting, timeout)
        return StreamSocket(af, reader, writer)

    async def sleep(self, interval):
        """Suspend the caller for *interval* seconds via ``asyncio.sleep``."""
        await asyncio.sleep(interval)

    def datagram_connection_required(self):
        """Return True when running on Windows, False otherwise."""
        return _is_win32