summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBob Halley <halley@dnspython.org>2023-03-02 07:51:50 -0800
committerGitHub <noreply@github.com>2023-03-02 07:51:50 -0800
commit3dd5ed5a8889ca2453a3796f5c77412313d76758 (patch)
tree512553a37b323b49dec8b9154817755385c5887b
parente37866650cbba14ffc1947430494660ec3a405fd (diff)
downloaddnspython-3dd5ed5a8889ca2453a3796f5c77412313d76758.tar.gz
Fix hangs when QUIC connection fails [#899]. (#900)
This also fixes problems with computing the wait_for() timeout for the sync and asyncio ports, and fixes delivery of the timeout for the sync port.
-rw-r--r--dns/quic/_asyncio.py12
-rw-r--r--dns/quic/_sync.py48
-rw-r--r--dns/quic/_trio.py37
3 files changed, 62 insertions, 35 deletions
diff --git a/dns/quic/_asyncio.py b/dns/quic/_asyncio.py
index bcce048..80f244d 100644
--- a/dns/quic/_asyncio.py
+++ b/dns/quic/_asyncio.py
@@ -17,6 +17,7 @@ from dns.quic._common import (
AsyncQuicConnection,
AsyncQuicManager,
QUIC_MAX_DATAGRAM,
+ UnexpectedEOF,
)
@@ -30,8 +31,8 @@ class AsyncioQuicStream(BaseQuicStream):
await self._wake_up.wait()
async def wait_for(self, amount, expiration):
- timeout = self._timeout_from_expiration(expiration)
while True:
+ timeout = self._timeout_from_expiration(expiration)
if self._buffer.have(amount):
return
self._expecting = amount
@@ -106,6 +107,11 @@ class AsyncioQuicConnection(AsyncQuicConnection):
self._wake_timer.notify_all()
except Exception:
pass
+ finally:
+ self._done = True
+ async with self._wake_timer:
+ self._wake_timer.notify_all()
+ self._handshake_complete.set()
async def _wait_for_wake_timer(self):
async with self._wake_timer:
@@ -115,7 +121,7 @@ class AsyncioQuicConnection(AsyncQuicConnection):
await self._socket_created.wait()
while not self._done:
datagrams = self._connection.datagrams_to_send(time.time())
- for (datagram, address) in datagrams:
+ for datagram, address in datagrams:
assert address == self._peer[0]
await self._socket.sendto(datagram, self._peer, None)
(expiration, interval) = self._get_timer_values()
@@ -162,6 +168,8 @@ class AsyncioQuicConnection(AsyncQuicConnection):
async def make_stream(self):
await self._handshake_complete.wait()
+ if self._done:
+ raise UnexpectedEOF
stream_id = self._connection.get_next_available_stream_id(False)
stream = AsyncioQuicStream(self, stream_id)
self._streams[stream_id] = stream
diff --git a/dns/quic/_sync.py b/dns/quic/_sync.py
index 8cc606a..bc034fa 100644
--- a/dns/quic/_sync.py
+++ b/dns/quic/_sync.py
@@ -17,6 +17,7 @@ from dns.quic._common import (
BaseQuicConnection,
BaseQuicManager,
QUIC_MAX_DATAGRAM,
+ UnexpectedEOF,
)
# Avoid circularity with dns.query
@@ -33,14 +34,15 @@ class SyncQuicStream(BaseQuicStream):
self._lock = threading.Lock()
def wait_for(self, amount, expiration):
- timeout = self._timeout_from_expiration(expiration)
while True:
+ timeout = self._timeout_from_expiration(expiration)
with self._lock:
if self._buffer.have(amount):
return
self._expecting = amount
with self._wake_up:
- self._wake_up.wait(timeout)
+ if not self._wake_up.wait(timeout):
+ raise TimeoutError
self._expecting = 0
def receive(self, timeout=None):
@@ -114,24 +116,30 @@ class SyncQuicConnection(BaseQuicConnection):
return
def _worker(self):
- sel = _selector_class()
- sel.register(self._socket, selectors.EVENT_READ, self._read)
- sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup)
- while not self._done:
- (expiration, interval) = self._get_timer_values(False)
- items = sel.select(interval)
- for (key, _) in items:
- key.data()
+ try:
+ sel = _selector_class()
+ sel.register(self._socket, selectors.EVENT_READ, self._read)
+ sel.register(self._receive_wakeup, selectors.EVENT_READ, self._drain_wakeup)
+ while not self._done:
+ (expiration, interval) = self._get_timer_values(False)
+ items = sel.select(interval)
+ for key, _ in items:
+ key.data()
+ with self._lock:
+ self._handle_timer(expiration)
+ datagrams = self._connection.datagrams_to_send(time.time())
+ for datagram, _ in datagrams:
+ try:
+ self._socket.send(datagram)
+ except BlockingIOError:
+ # we let QUIC handle any lossage
+ pass
+ self._handle_events()
+ finally:
with self._lock:
- self._handle_timer(expiration)
- datagrams = self._connection.datagrams_to_send(time.time())
- for (datagram, _) in datagrams:
- try:
- self._socket.send(datagram)
- except BlockingIOError:
- # we let QUIC handle any lossage
- pass
- self._handle_events()
+ self._done = True
+ # Ensure anyone waiting for this gets woken up.
+ self._handshake_complete.set()
def _handle_events(self):
while True:
@@ -166,6 +174,8 @@ class SyncQuicConnection(BaseQuicConnection):
def make_stream(self):
self._handshake_complete.wait()
with self._lock:
+ if self._done:
+ raise UnexpectedEOF
stream_id = self._connection.get_next_available_stream_id(False)
stream = SyncQuicStream(self, stream_id)
self._streams[stream_id] = stream
diff --git a/dns/quic/_trio.py b/dns/quic/_trio.py
index 543e3cb..7f81061 100644
--- a/dns/quic/_trio.py
+++ b/dns/quic/_trio.py
@@ -17,6 +17,7 @@ from dns.quic._common import (
AsyncQuicConnection,
AsyncQuicManager,
QUIC_MAX_DATAGRAM,
+ UnexpectedEOF,
)
@@ -80,20 +81,26 @@ class TrioQuicConnection(AsyncQuicConnection):
self._worker_scope = None
async def _worker(self):
- await self._socket.connect(self._peer)
- while not self._done:
- (expiration, interval) = self._get_timer_values(False)
- with trio.CancelScope(
- deadline=trio.current_time() + interval
- ) as self._worker_scope:
- datagram = await self._socket.recv(QUIC_MAX_DATAGRAM)
- self._connection.receive_datagram(datagram, self._peer[0], time.time())
- self._worker_scope = None
- self._handle_timer(expiration)
- datagrams = self._connection.datagrams_to_send(time.time())
- for (datagram, _) in datagrams:
- await self._socket.send(datagram)
- await self._handle_events()
+ try:
+ await self._socket.connect(self._peer)
+ while not self._done:
+ (expiration, interval) = self._get_timer_values(False)
+ with trio.CancelScope(
+ deadline=trio.current_time() + interval
+ ) as self._worker_scope:
+ datagram = await self._socket.recv(QUIC_MAX_DATAGRAM)
+ self._connection.receive_datagram(
+ datagram, self._peer[0], time.time()
+ )
+ self._worker_scope = None
+ self._handle_timer(expiration)
+ datagrams = self._connection.datagrams_to_send(time.time())
+ for datagram, _ in datagrams:
+ await self._socket.send(datagram)
+ await self._handle_events()
+ finally:
+ self._done = True
+ self._handshake_complete.set()
async def _handle_events(self):
count = 0
@@ -132,6 +139,8 @@ class TrioQuicConnection(AsyncQuicConnection):
async def make_stream(self):
await self._handshake_complete.wait()
+ if self._done:
+ raise UnexpectedEOF
stream_id = self._connection.get_next_available_stream_id(False)
stream = TrioQuicStream(self, stream_id)
self._streams[stream_id] = stream