diff options
author | Bob Halley <halley@dnspython.org> | 2021-02-25 08:55:15 -0800 |
---|---|---|
committer | Bob Halley <halley@dnspython.org> | 2021-02-25 08:55:15 -0800 |
commit | 5969ce9d22277f907ece5bddc4f02054fedd8734 (patch) | |
tree | 41fd2145847332f252ee3b62e5f0a174fb0fddc7 | |
parent | c249ca426bc43a5cd634387f8b0d09848fa9a169 (diff) | |
download | dnspython-windows_asyncio_fix.tar.gz |
asyncio on Windows requries connected sockets. [Issue #637]windows_asyncio_fix
-rw-r--r-- | dns/_asyncbackend.py | 3 | ||||
-rw-r--r-- | dns/_asyncio_backend.py | 14 | ||||
-rw-r--r-- | dns/asyncquery.py | 7 | ||||
-rw-r--r-- | tests/test_async.py | 14 |
4 files changed, 35 insertions, 3 deletions
diff --git a/dns/_asyncbackend.py b/dns/_asyncbackend.py index 0ce316b..69411df 100644 --- a/dns/_asyncbackend.py +++ b/dns/_asyncbackend.py @@ -64,3 +64,6 @@ class Backend: # pragma: no cover source=None, destination=None, timeout=None, ssl_context=None, server_hostname=None): raise NotImplementedError + + def datagram_connection_required(self): + return False diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py index 17bd0f7..80c31dc 100644 --- a/dns/_asyncio_backend.py +++ b/dns/_asyncio_backend.py @@ -4,11 +4,14 @@ import socket import asyncio +import sys import dns._asyncbackend import dns.exception +_is_win32 = sys.platform == 'win32' + def _get_running_loop(): try: return asyncio.get_running_loop() @@ -114,11 +117,16 @@ class Backend(dns._asyncbackend.Backend): async def make_socket(self, af, socktype, proto=0, source=None, destination=None, timeout=None, ssl_context=None, server_hostname=None): + if destination is None and socktype == socket.SOCK_DGRAM and \ + _is_win32: + raise NotImplementedError('destinationless datagram sockets ' + 'are not supported by asyncio ' + 'on Windows') loop = _get_running_loop() if socktype == socket.SOCK_DGRAM: transport, protocol = await loop.create_datagram_endpoint( _DatagramProtocol, source, family=af, - proto=proto) + proto=proto, remote_addr=destination) return DatagramSocket(af, transport, protocol) elif socktype == socket.SOCK_STREAM: (r, w) = await _maybe_wait_for( @@ -136,3 +144,7 @@ class Backend(dns._asyncbackend.Backend): async def sleep(self, interval): await asyncio.sleep(interval) + + def datagram_connection_required(self): + return _is_win32 + diff --git a/dns/asyncquery.py b/dns/asyncquery.py index 89c2622..0e353e8 100644 --- a/dns/asyncquery.py +++ b/dns/asyncquery.py @@ -142,7 +142,12 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0, if not backend: backend = dns.asyncbackend.get_default_backend() stuple = _source_tuple(af, source, source_port) - s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple) + if backend.datagram_connection_required(): + dtuple = (where, port) + else: + dtuple = None + s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple, + dtuple) await send_udp(s, wire, destination, expiration) (r, received_time, _) = await receive_udp(s, destination, expiration, ignore_unexpected, diff --git a/tests/test_async.py b/tests/test_async.py index e9a26bb..0252f22 100644 --- a/tests/test_async.py +++ b/tests/test_async.py @@ -17,6 +17,7 @@ import asyncio import socket +import sys import time import unittest @@ -152,6 +153,7 @@ class MiscQuery(unittest.TestCase): @unittest.skipIf(not _network_available, "Internet not reachable") class AsyncTests(unittest.TestCase): + connect_udp = sys.platform == 'win32' def setUp(self): self.backend = dns.asyncbackend.set_default_backend('asyncio') @@ -261,9 +263,13 @@ class AsyncTests(unittest.TestCase): for address in query_addresses: qname = dns.name.from_text('dns.google.') async def run(): + if self.connect_udp: + dtuple=(address, 53) + else: + dtuple=None async with await self.backend.make_socket( dns.inet.af_for_address(address), - socket.SOCK_DGRAM) as s: + socket.SOCK_DGRAM, 0, None, dtuple) as s: q = dns.message.make_query(qname, dns.rdatatype.A) return await dns.asyncquery.udp(q, address, sock=s, timeout=2) @@ -373,6 +379,8 @@ class AsyncTests(unittest.TestCase): self.assertFalse(tcp) def testUDPReceiveQuery(self): + if self.connect_udp: + self.skipTest('test needs connectionless sockets') async def run(): async with await self.backend.make_socket( socket.AF_INET, socket.SOCK_DGRAM, @@ -392,6 +400,8 @@ class AsyncTests(unittest.TestCase): self.assertEqual(sender_address, recv_address) def testUDPReceiveTimeout(self): + if self.connect_udp: + self.skipTest('test needs connectionless sockets') async def arun(): async with await self.backend.make_socket(socket.AF_INET, socket.SOCK_DGRAM, 0, @@ -430,6 +440,7 @@ try: return trio.run(afunc) class TrioAsyncTests(AsyncTests): + connect_udp = False def setUp(self): self.backend = dns.asyncbackend.set_default_backend('trio') @@ -453,6 +464,7 @@ try: return curio.run(afunc) class CurioAsyncTests(AsyncTests): + connect_udp = False def setUp(self): self.backend = dns.asyncbackend.set_default_backend('curio') |