diff options
author | Bob Halley <halley@dnspython.org> | 2012-04-08 13:25:27 +0100 |
---|---|---|
committer | Bob Halley <halley@dnspython.org> | 2012-04-08 13:25:27 +0100 |
commit | 0f8559141f589ad112ead7c1ae1432d3b2dbfe4b (patch) | |
tree | 0b5c3d9a9b0b3221d0ef26606d82cb77770827bc | |
parent | 70a37051ef8486e086e7f914bc493940c70cecd1 (diff) | |
download | dnspython-0f8559141f589ad112ead7c1ae1432d3b2dbfe4b.tar.gz |
Add source_port support to resolver; fix source_port in query code
-rw-r--r-- | ChangeLog | 7 | ||||
-rw-r--r-- | README | 6 | ||||
-rw-r--r-- | dns/query.py | 73 | ||||
-rw-r--r-- | dns/resolver.py | 21 |
4 files changed, 58 insertions, 49 deletions
@@ -1,5 +1,12 @@ 2012-04-08 Bob Halley <halley@dnspython.org> + * dns/query.py: Specifying source_port had no effect if source was + not specified. We now use the appropriate wildcard source in + that case. + + * dns/resolver.py (Resolver.query): source_port may now be + specified. + * dns/resolver.py (Resolver.query): Switch to TCP when a UDP response is truncated. Handle nameservers that serve on UDP but not TCP. @@ -47,6 +47,8 @@ New since 1.9.4: Trailing junk checking can be disabled. + A source port can be specified when creating a resolver query. + Bugs fixed since 1.9.4: IPv4 and IPv6 address processing is now stricter. @@ -56,6 +58,10 @@ Bugs fixed since 1.9.4: expected) now raise dns.exception.FormError rather than IndexError. + Specifying a source port without specifying source used to + have no effect, but now uses the wildcard address and the + specified port. + New since 1.9.3: Nothing. diff --git a/dns/query.py b/dns/query.py index addee4e..2ed0b6a 100644 --- a/dns/query.py +++ b/dns/query.py @@ -144,6 +144,28 @@ def _addresses_equal(af, a1, a2): n2 = dns.inet.inet_pton(af, a2[0]) return n1 == n2 and a1[1:] == a2[1:] +def _destination_and_source(af, where, port, source, source_port): + # Apply defaults and compute destination and source tuples + # suitable for use in connect(), sendto(), or bind(). + if af is None: + try: + af = dns.inet.af_for_address(where) + except: + af = dns.inet.AF_INET + if af == dns.inet.AF_INET: + destination = (where, port) + if source is not None or source_port != 0: + if source is None: + source = '0.0.0.0' + source = (source, source_port) + elif af == dns.inet.AF_INET6: + destination = (where, port, 0, 0) + if source is not None or source_port != 0: + if source is None: + source = '::' + source = (source, source_port, 0, 0) + return (af, destination, source) + def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, ignore_unexpected=False, one_rr_per_rrset=False): """Return the response obtained after sending a query via UDP. @@ -162,7 +184,7 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, If the inference attempt fails, AF_INET is used. @type af: int @rtype: dns.message.Message object - @param source: source address. The default is the IPv4 wildcard address. + @param source: source address. The default is the wildcard address. @type source: string @param source_port: The port from which to send the message. The default is 0. @@ -175,19 +197,8 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, """ wire = q.to_wire() - if af is None: - try: - af = dns.inet.af_for_address(where) - except: - af = dns.inet.AF_INET - if af == dns.inet.AF_INET: - destination = (where, port) - if source is not None: - source = (source, source_port) - elif af == dns.inet.AF_INET6: - destination = (where, port, 0, 0) - if source is not None: - source = (source, source_port, 0, 0) + (af, destination, source) = _destination_and_source(af, where, port, source, + source_port) s = socket.socket(af, socket.SOCK_DGRAM, 0) try: expiration = _compute_expiration(timeout) @@ -270,7 +281,7 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, If the inference attempt fails, AF_INET is used. @type af: int @rtype: dns.message.Message object - @param source: source address. The default is the IPv4 wildcard address. + @param source: source address. The default is the wildcard address. @type source: string @param source_port: The port from which to send the message. The default is 0. @@ -280,19 +291,8 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, """ wire = q.to_wire() - if af is None: - try: - af = dns.inet.af_for_address(where) - except: - af = dns.inet.AF_INET - if af == dns.inet.AF_INET: - destination = (where, port) - if source is not None: - source = (source, source_port) - elif af == dns.inet.AF_INET6: - destination = (where, port, 0, 0) - if source is not None: - source = (source, source_port, 0, 0) + (af, destination, source) = _destination_and_source(af, where, port, source, + source_port) s = socket.socket(af, socket.SOCK_STREAM, 0) try: expiration = _compute_expiration(timeout) @@ -357,7 +357,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, take. @type lifetime: float @rtype: generator of dns.message.Message objects. - @param source: source address. The default is the IPv4 wildcard address. + @param source: source address. The default is the wildcard address. @type source: string @param source_port: The port from which to send the message. The default is 0. @@ -384,19 +384,8 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, if not keyring is None: q.use_tsig(keyring, keyname, algorithm=keyalgorithm) wire = q.to_wire() - if af is None: - try: - af = dns.inet.af_for_address(where) - except: - af = dns.inet.AF_INET - if af == dns.inet.AF_INET: - destination = (where, port) - if source is not None: - source = (source, source_port) - elif af == dns.inet.AF_INET6: - destination = (where, port, 0, 0) - if source is not None: - source = (source, source_port, 0, 0) + (af, destination, source) = _destination_and_source(af, where, port, source, + source_port) if use_udp: if rdtype != dns.rdatatype.IXFR: raise ValueError('cannot do a UDP AXFR') diff --git a/dns/resolver.py b/dns/resolver.py index 2f56c7d..437e14e 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -694,7 +694,7 @@ class Resolver(object): return min(self.lifetime - duration, self.timeout) def query(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, - tcp=False, source=None, raise_on_no_answer=True): + tcp=False, source=None, raise_on_no_answer=True, source_port=0): """Query nameservers to find the answer to the question. The I{qname}, I{rdtype}, and I{rdclass} parameters may be objects @@ -715,6 +715,9 @@ class Resolver(object): @param raise_on_no_answer: raise NoAnswer if there's no answer (defaults is True). @type raise_on_no_answer: bool + @param source_port: The port from which to send the message. + The default is 0. + @type source_port: int @rtype: dns.resolver.Answer instance @raises Timeout: no answers could be found in the specified lifetime @raises NXDOMAIN: the query name does not exist @@ -774,17 +777,20 @@ class Resolver(object): if tcp: response = dns.query.tcp(request, nameserver, timeout, self.port, - source=source) + source=source, + source_port=source_port) else: response = dns.query.udp(request, nameserver, timeout, self.port, - source=source) + source=source, + source_port=source_port) if response.flags & dns.flags.TC: # Response truncated; retry with TCP. timeout = self._compute_timeout(start) response = dns.query.tcp(request, nameserver, - timeout, self.port, - source=source) + timeout, self.port, + source=source, + source_port=source_port) except (socket.error, dns.exception.Timeout): # # Communication failure or timeout. Go to the @@ -903,7 +909,8 @@ def get_default_resolver(): return default_resolver def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, - tcp=False, source=None, raise_on_no_answer=True): + tcp=False, source=None, raise_on_no_answer=True, + source_port=0): """Query nameservers to find the answer to the question. This is a convenience function that uses the default resolver @@ -911,7 +918,7 @@ def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, @see: L{dns.resolver.Resolver.query} for more information on the parameters.""" return get_default_resolver().query(qname, rdtype, rdclass, tcp, source, - raise_on_no_answer) + raise_on_no_answer, source_port) def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None): """Find the name of the zone which contains the specified name. |