summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBob Halley <halley@dnspython.org>2021-02-25 08:55:15 -0800
committerBob Halley <halley@dnspython.org>2021-02-25 08:55:15 -0800
commit5969ce9d22277f907ece5bddc4f02054fedd8734 (patch)
tree41fd2145847332f252ee3b62e5f0a174fb0fddc7
parentc249ca426bc43a5cd634387f8b0d09848fa9a169 (diff)
downloaddnspython-windows_asyncio_fix.tar.gz
asyncio on Windows requries connected sockets. [Issue #637]windows_asyncio_fix
-rw-r--r--dns/_asyncbackend.py3
-rw-r--r--dns/_asyncio_backend.py14
-rw-r--r--dns/asyncquery.py7
-rw-r--r--tests/test_async.py14
4 files changed, 35 insertions, 3 deletions
diff --git a/dns/_asyncbackend.py b/dns/_asyncbackend.py
index 0ce316b..69411df 100644
--- a/dns/_asyncbackend.py
+++ b/dns/_asyncbackend.py
@@ -64,3 +64,6 @@ class Backend: # pragma: no cover
source=None, destination=None, timeout=None,
ssl_context=None, server_hostname=None):
raise NotImplementedError
+
+ def datagram_connection_required(self):
+ return False
diff --git a/dns/_asyncio_backend.py b/dns/_asyncio_backend.py
index 17bd0f7..80c31dc 100644
--- a/dns/_asyncio_backend.py
+++ b/dns/_asyncio_backend.py
@@ -4,11 +4,14 @@
import socket
import asyncio
+import sys
import dns._asyncbackend
import dns.exception
+_is_win32 = sys.platform == 'win32'
+
def _get_running_loop():
try:
return asyncio.get_running_loop()
@@ -114,11 +117,16 @@ class Backend(dns._asyncbackend.Backend):
async def make_socket(self, af, socktype, proto=0,
source=None, destination=None, timeout=None,
ssl_context=None, server_hostname=None):
+ if destination is None and socktype == socket.SOCK_DGRAM and \
+ _is_win32:
+ raise NotImplementedError('destinationless datagram sockets '
+ 'are not supported by asyncio '
+ 'on Windows')
loop = _get_running_loop()
if socktype == socket.SOCK_DGRAM:
transport, protocol = await loop.create_datagram_endpoint(
_DatagramProtocol, source, family=af,
- proto=proto)
+ proto=proto, remote_addr=destination)
return DatagramSocket(af, transport, protocol)
elif socktype == socket.SOCK_STREAM:
(r, w) = await _maybe_wait_for(
@@ -136,3 +144,7 @@ class Backend(dns._asyncbackend.Backend):
async def sleep(self, interval):
await asyncio.sleep(interval)
+
+ def datagram_connection_required(self):
+ return _is_win32
+
diff --git a/dns/asyncquery.py b/dns/asyncquery.py
index 89c2622..0e353e8 100644
--- a/dns/asyncquery.py
+++ b/dns/asyncquery.py
@@ -142,7 +142,12 @@ async def udp(q, where, timeout=None, port=53, source=None, source_port=0,
if not backend:
backend = dns.asyncbackend.get_default_backend()
stuple = _source_tuple(af, source, source_port)
- s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple)
+ if backend.datagram_connection_required():
+ dtuple = (where, port)
+ else:
+ dtuple = None
+ s = await backend.make_socket(af, socket.SOCK_DGRAM, 0, stuple,
+ dtuple)
await send_udp(s, wire, destination, expiration)
(r, received_time, _) = await receive_udp(s, destination, expiration,
ignore_unexpected,
diff --git a/tests/test_async.py b/tests/test_async.py
index e9a26bb..0252f22 100644
--- a/tests/test_async.py
+++ b/tests/test_async.py
@@ -17,6 +17,7 @@
import asyncio
import socket
+import sys
import time
import unittest
@@ -152,6 +153,7 @@ class MiscQuery(unittest.TestCase):
@unittest.skipIf(not _network_available, "Internet not reachable")
class AsyncTests(unittest.TestCase):
+ connect_udp = sys.platform == 'win32'
def setUp(self):
self.backend = dns.asyncbackend.set_default_backend('asyncio')
@@ -261,9 +263,13 @@ class AsyncTests(unittest.TestCase):
for address in query_addresses:
qname = dns.name.from_text('dns.google.')
async def run():
+ if self.connect_udp:
+ dtuple=(address, 53)
+ else:
+ dtuple=None
async with await self.backend.make_socket(
dns.inet.af_for_address(address),
- socket.SOCK_DGRAM) as s:
+ socket.SOCK_DGRAM, 0, None, dtuple) as s:
q = dns.message.make_query(qname, dns.rdatatype.A)
return await dns.asyncquery.udp(q, address, sock=s,
timeout=2)
@@ -373,6 +379,8 @@ class AsyncTests(unittest.TestCase):
self.assertFalse(tcp)
def testUDPReceiveQuery(self):
+ if self.connect_udp:
+ self.skipTest('test needs connectionless sockets')
async def run():
async with await self.backend.make_socket(
socket.AF_INET, socket.SOCK_DGRAM,
@@ -392,6 +400,8 @@ class AsyncTests(unittest.TestCase):
self.assertEqual(sender_address, recv_address)
def testUDPReceiveTimeout(self):
+ if self.connect_udp:
+ self.skipTest('test needs connectionless sockets')
async def arun():
async with await self.backend.make_socket(socket.AF_INET,
socket.SOCK_DGRAM, 0,
@@ -430,6 +440,7 @@ try:
return trio.run(afunc)
class TrioAsyncTests(AsyncTests):
+ connect_udp = False
def setUp(self):
self.backend = dns.asyncbackend.set_default_backend('trio')
@@ -453,6 +464,7 @@ try:
return curio.run(afunc)
class CurioAsyncTests(AsyncTests):
+ connect_udp = False
def setUp(self):
self.backend = dns.asyncbackend.set_default_backend('curio')