summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBob Halley <halley@dnspython.org>2020-06-02 07:22:42 -0700
committerBob Halley <halley@dnspython.org>2020-06-02 07:22:42 -0700
commit842dc3a0c1466126d547a05b32abd29150ecc1d2 (patch)
tree3f6ad05194597b0c0d7e926e7f5dc8e8dbd72fa4
parent856317e9d43b83bc8b14c85dd6a8a967da4b49cd (diff)
downloaddnspython-842dc3a0c1466126d547a05b32abd29150ecc1d2.tar.gz
pass peer and connection type to nanoserver handle()
-rw-r--r--tests/nanonameserver.py25
-rw-r--r--tests/test_resolver.py2
2 files changed, 19 insertions, 8 deletions
diff --git a/tests/nanonameserver.py b/tests/nanonameserver.py
index 1b4a434..a14d925 100644
--- a/tests/nanonameserver.py
+++ b/tests/nanonameserver.py
@@ -1,6 +1,7 @@
# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license
import contextlib
+import enum
import functools
import socket
import struct
@@ -11,6 +12,9 @@ import dns.message
import dns.rcode
import dns.trio.query
+class ConnectionType(enum.IntEnum):
+ UDP = 1
+ TCP = 2
class Server(threading.Thread):
@@ -95,7 +99,7 @@ class Server(threading.Thread):
finally:
raise EOFError
- def handle(self, message):
+ def handle(self, message, peer, connection_type):
#
# Handle message 'message'. Override this method to change
# how the server behaves.
@@ -113,7 +117,7 @@ class Server(threading.Thread):
except Exception:
return None
- def handle_wire(self, wire):
+ def handle_wire(self, wire, peer, connection_type):
#
# This is the common code to parse wire format, call handle() on
# the message, and then generate resposne wire format (if handle()
@@ -123,6 +127,12 @@ class Server(threading.Thread):
#
# Returns a wire format message to send, or None indicating there
# is nothing to do.
+ #
+ # XXXRTH It might be nice to have a "debug mode" in the server
+ # where we'd print something in all the places we're eating
+ # exceptions. That way bugs in handle() would be easier to
+ # find.
+ #
r = None
try:
q = dns.message.from_wire(wire)
@@ -142,7 +152,7 @@ class Server(threading.Thread):
# r might have been set above, so skip handle() if we
# already have a response.
if r is None:
- r = self.handle(q)
+ r = self.handle(q, peer, connection_type)
except Exception:
# Exceptions from handle get a SERVFAIL response.
r = dns.message.make_response(q)
@@ -158,20 +168,21 @@ class Server(threading.Thread):
self.udp = None # we own cleanup
while True:
try:
- (wire, from_address) = await sock.recvfrom(65535)
- wire = self.handle_wire(wire)
+ (wire, peer) = await sock.recvfrom(65535)
+ wire = self.handle_wire(wire, peer, ConnectionType.UDP)
if wire is not None:
- await sock.sendto(wire, from_address)
+ await sock.sendto(wire, peer)
except Exception:
pass
async def serve_tcp(self, stream):
try:
+ peer = stream.socket.getpeername()
while True:
ldata = await dns.trio.query.read_exactly(stream, 2)
(l,) = struct.unpack("!H", ldata)
wire = await dns.trio.query.read_exactly(stream, l)
- wire = self.handle_wire(wire)
+ wire = self.handle_wire(wire, peer, ConnectionType.TCP)
if wire is not None:
l = len(wire)
stream_message = struct.pack("!H", l) + wire
diff --git a/tests/test_resolver.py b/tests/test_resolver.py
index 87aebaa..309a89d 100644
--- a/tests/test_resolver.py
+++ b/tests/test_resolver.py
@@ -595,7 +595,7 @@ class ResolverNameserverValidTypeTestCase(unittest.TestCase):
class NaptrNanoNameserver(Server):
- def handle(self, message):
+ def handle(self, message, peer, connection_type):
response = dns.message.make_response(message)
response.set_rcode(dns.rcode.REFUSED)
response.flags |= dns.flags.RA