diff options
author | Bob Halley <halley@dnspython.org> | 2020-06-02 07:22:42 -0700 |
---|---|---|
committer | Bob Halley <halley@dnspython.org> | 2020-06-02 07:22:42 -0700 |
commit | 842dc3a0c1466126d547a05b32abd29150ecc1d2 (patch) | |
tree | 3f6ad05194597b0c0d7e926e7f5dc8e8dbd72fa4 | |
parent | 856317e9d43b83bc8b14c85dd6a8a967da4b49cd (diff) | |
download | dnspython-842dc3a0c1466126d547a05b32abd29150ecc1d2.tar.gz |
pass peer and connection type to nanoserver handle()
-rw-r--r-- | tests/nanonameserver.py | 25 | ||||
-rw-r--r-- | tests/test_resolver.py | 2 |
2 files changed, 19 insertions, 8 deletions
diff --git a/tests/nanonameserver.py b/tests/nanonameserver.py index 1b4a434..a14d925 100644 --- a/tests/nanonameserver.py +++ b/tests/nanonameserver.py @@ -1,6 +1,7 @@ # Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license import contextlib +import enum import functools import socket import struct @@ -11,6 +12,9 @@ import dns.message import dns.rcode import dns.trio.query +class ConnectionType(enum.IntEnum): + UDP = 1 + TCP = 2 class Server(threading.Thread): @@ -95,7 +99,7 @@ class Server(threading.Thread): finally: raise EOFError - def handle(self, message): + def handle(self, message, peer, connection_type): # # Handle message 'message'. Override this method to change # how the server behaves. @@ -113,7 +117,7 @@ class Server(threading.Thread): except Exception: return None - def handle_wire(self, wire): + def handle_wire(self, wire, peer, connection_type): # # This is the common code to parse wire format, call handle() on # the message, and then generate resposne wire format (if handle() @@ -123,6 +127,12 @@ class Server(threading.Thread): # # Returns a wire format message to send, or None indicating there # is nothing to do. + # + # XXXRTH It might be nice to have a "debug mode" in the server + # where we'd print something in all the places we're eating + # exceptions. That way bugs in handle() would be easier to + # find. + # r = None try: q = dns.message.from_wire(wire) @@ -142,7 +152,7 @@ class Server(threading.Thread): # r might have been set above, so skip handle() if we # already have a response. if r is None: - r = self.handle(q) + r = self.handle(q, peer, connection_type) except Exception: # Exceptions from handle get a SERVFAIL response. r = dns.message.make_response(q) @@ -158,20 +168,21 @@ class Server(threading.Thread): self.udp = None # we own cleanup while True: try: - (wire, from_address) = await sock.recvfrom(65535) - wire = self.handle_wire(wire) + (wire, peer) = await sock.recvfrom(65535) + wire = self.handle_wire(wire, peer, ConnectionType.UDP) if wire is not None: - await sock.sendto(wire, from_address) + await sock.sendto(wire, peer) except Exception: pass async def serve_tcp(self, stream): try: + peer = stream.socket.getpeername() while True: ldata = await dns.trio.query.read_exactly(stream, 2) (l,) = struct.unpack("!H", ldata) wire = await dns.trio.query.read_exactly(stream, l) - wire = self.handle_wire(wire) + wire = self.handle_wire(wire, peer, ConnectionType.TCP) if wire is not None: l = len(wire) stream_message = struct.pack("!H", l) + wire diff --git a/tests/test_resolver.py b/tests/test_resolver.py index 87aebaa..309a89d 100644 --- a/tests/test_resolver.py +++ b/tests/test_resolver.py @@ -595,7 +595,7 @@ class ResolverNameserverValidTypeTestCase(unittest.TestCase): class NaptrNanoNameserver(Server): - def handle(self, message): + def handle(self, message, peer, connection_type): response = dns.message.make_response(message) response.set_rcode(dns.rcode.REFUSED) response.flags |= dns.flags.RA |