summaryrefslogtreecommitdiff
path: root/dns/_asyncio_backend.py
blob: f82eb823916363091f3d53e7f206e6a7934494ce (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license

"""asyncio library query support"""

import socket
import asyncio

import dns._asyncbackend
import dns.exception


def _get_running_loop():
    try:
        return asyncio.get_running_loop()
    except AttributeError:
        return asyncio.get_event_loop()


class _DatagramProtocol:
    def __init__(self):
        self.transport = None
        self.recvfrom = None

    def connection_made(self, transport):
        self.transport = transport

    def datagram_received(self, data, addr):
        if self.recvfrom:
            self.recvfrom.set_result((data, addr))
            self.recvfrom = None

    def error_received(self, exc):
        if self.recvfrom:
            self.recvfrom.set_exception(exc)

    def connection_lost(self, exc):
        if self.recvfrom:
            self.recvfrom.set_exception(exc)

    def close(self):
        self.transport.close()


async def _maybe_wait_for(awaitable, timeout):
    if timeout:
        try:
            return await asyncio.wait_for(awaitable, timeout)
        except asyncio.TimeoutError:
            raise dns.exception.Timeout(timeout=timeout)
    else:
        return await awaitable


class DatagramSocket(dns._asyncbackend.DatagramSocket):
    """UDP socket for the asyncio backend, built on a datagram transport
    and its _DatagramProtocol."""

    def __init__(self, family, transport, protocol):
        self.family = family
        self.transport = transport
        self.protocol = protocol

    async def sendto(self, what, destination, timeout):
        # asyncio's transport.sendto() does not block, so timeout is unused.
        self.transport.sendto(what, destination)

    async def recvfrom(self, size, timeout):
        """Wait up to *timeout* for the next datagram and return the
        (data, addr) pair delivered by the protocol."""
        # size is ignored: the protocol cannot be told a maximum size.
        done = _get_running_loop().create_future()
        # Only one reader may wait on the protocol at a time.
        assert self.protocol.recvfrom is None
        self.protocol.recvfrom = done
        try:
            return await _maybe_wait_for(done, timeout)
        finally:
            # On timeout or cancellation the future is still installed.
            # Detach it so the next recvfrom() does not trip the assertion.
            # Only clear our own future: after the protocol resolves and
            # detaches it, another reader may already have installed theirs.
            if self.protocol.recvfrom is done:
                self.protocol.recvfrom = None

    async def close(self):
        self.protocol.close()

    async def getpeername(self):
        return self.transport.get_extra_info('peername')


class StreamSocket(dns._asyncbackend.StreamSocket):
    """TCP (optionally TLS) socket for the asyncio backend, wrapping an
    asyncio StreamReader/StreamWriter pair."""

    def __init__(self, af, reader, writer):
        self.family = af
        self.reader = reader
        self.writer = writer

    async def sendall(self, what, timeout):
        """Queue *what* on the writer and wait, bounded by *timeout*, for
        the writer to drain."""
        self.writer.write(what)
        return await _maybe_wait_for(self.writer.drain(), timeout)

    async def recv(self, count, timeout):
        """Read up to *count* bytes from the reader, bounded by *timeout*."""
        return await _maybe_wait_for(self.reader.read(count), timeout)

    async def close(self):
        self.writer.close()
        # Older StreamWriter objects have no wait_closed(). Check for the
        # attribute instead of catching AttributeError around the await,
        # which would also hide errors raised inside wait_closed().
        wait_closed = getattr(self.writer, 'wait_closed', None)
        if wait_closed is not None:
            await wait_closed()

    async def getpeername(self):
        return self.writer.get_extra_info('peername')


class Backend(dns._asyncbackend.Backend):
    """dnspython asynchronous backend driven by the asyncio event loop."""

    def name(self):
        """Return this backend's identifier."""
        return 'asyncio'

    async def make_socket(self, af, socktype, proto=0,
                          source=None, destination=None, timeout=None,
                          ssl_context=None, server_hostname=None):
        """Create a backend socket of the requested type.

        SOCK_DGRAM: create a datagram endpoint bound to *source* and wrap it
        in a DatagramSocket. *destination*, *timeout*, *ssl_context* and
        *server_hostname* are not used on this path.

        SOCK_STREAM: open a connection to destination[0]:destination[1]
        (optionally from *source*, optionally over TLS with *ssl_context*
        and *server_hostname*), bounded by *timeout*, and wrap it in a
        StreamSocket.

        Any other *socktype* raises NotImplementedError.
        """
        loop = _get_running_loop()
        if socktype == socket.SOCK_DGRAM:
            endpoint = loop.create_datagram_endpoint(
                _DatagramProtocol, source, family=af, proto=proto)
            transport, protocol = await endpoint
            return DatagramSocket(af, transport, protocol)
        if socktype != socket.SOCK_STREAM:
            raise NotImplementedError(f'unsupported socket type {socktype}')
        connecting = asyncio.open_connection(
            destination[0], destination[1],
            ssl=ssl_context, family=af, proto=proto,
            local_addr=source, server_hostname=server_hostname)
        reader, writer = await _maybe_wait_for(connecting, timeout)
        return StreamSocket(af, reader, writer)

    async def sleep(self, interval):
        await asyncio.sleep(interval)