diff options
| author | Bob Halley <halley@nominum.com> | 2012-04-08 13:25:36 +0100 |
|---|---|---|
| committer | Bob Halley <halley@nominum.com> | 2012-04-08 13:25:36 +0100 |
| commit | 4d26315313c188502f44de17914f5905744f8f15 (patch) | |
| tree | be053ff0ff8c93fe8a4c4e40b7161dee36ff6387 | |
| parent | e1e32a32396c660746b5094330b70e4ad5f477e9 (diff) | |
| download | dnspython-4d26315313c188502f44de17914f5905744f8f15.tar.gz | |
Add source_port support to resolver; fix source_port in query code
| -rw-r--r-- | ChangeLog | 7 | ||||
| -rw-r--r-- | README | 6 | ||||
| -rw-r--r-- | dns/query.py | 73 | ||||
| -rw-r--r-- | dns/resolver.py | 21 |
4 files changed, 58 insertions, 49 deletions
@@ -1,5 +1,12 @@ 2012-04-08 Bob Halley <halley@dnspython.org> + * dns/query.py: Specifying source_port had no effect if source was + not specified. We now use the appropriate wildcard source in + that case. + + * dns/resolver.py (Resolver.query): source_port may now be + specified. + * dns/resolver.py (Resolver.query): Switch to TCP when a UDP response is truncated. Handle nameservers that serve on UDP but not TCP. @@ -47,6 +47,8 @@ New since 1.9.4: Trailing junk checking can be disabled. + A source port can be specified when creating a resolver query. + Bugs fixed since 1.9.4: IPv4 and IPv6 address processing is now stricter. @@ -56,6 +58,10 @@ Bugs fixed since 1.9.4: expected) now raise dns.exception.FormError rather than IndexError. + Specifying a source port without specifying source used to + have no effect, but now uses the wildcard address and the + specified port. + New since 1.9.3: Nothing. diff --git a/dns/query.py b/dns/query.py index 7bba352..0e6eb92 100644 --- a/dns/query.py +++ b/dns/query.py @@ -144,6 +144,28 @@ def _addresses_equal(af, a1, a2): n2 = dns.inet.inet_pton(af, a2[0]) return n1 == n2 and a1[1:] == a2[1:] +def _destination_and_source(af, where, port, source, source_port): + # Apply defaults and compute destination and source tuples + # suitable for use in connect(), sendto(), or bind(). + if af is None: + try: + af = dns.inet.af_for_address(where) + except: + af = dns.inet.AF_INET + if af == dns.inet.AF_INET: + destination = (where, port) + if source is not None or source_port != 0: + if source is None: + source = '0.0.0.0' + source = (source, source_port) + elif af == dns.inet.AF_INET6: + destination = (where, port, 0, 0) + if source is not None or source_port != 0: + if source is None: + source = '::' + source = (source, source_port, 0, 0) + return (af, destination, source) + def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, ignore_unexpected=False, one_rr_per_rrset=False): """Return the response obtained after sending a query via UDP. @@ -162,7 +184,7 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, If the inference attempt fails, AF_INET is used. @type af: int @rtype: dns.message.Message object - @param source: source address. The default is the IPv4 wildcard address. + @param source: source address. The default is the wildcard address. @type source: string @param source_port: The port from which to send the message. The default is 0. @@ -175,19 +197,8 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, """ wire = q.to_wire() - if af is None: - try: - af = dns.inet.af_for_address(where) - except: - af = dns.inet.AF_INET - if af == dns.inet.AF_INET: - destination = (where, port) - if source is not None: - source = (source, source_port) - elif af == dns.inet.AF_INET6: - destination = (where, port, 0, 0) - if source is not None: - source = (source, source_port, 0, 0) + (af, destination, source) = _destination_and_source(af, where, port, source, + source_port) s = socket.socket(af, socket.SOCK_DGRAM, 0) try: expiration = _compute_expiration(timeout) @@ -270,7 +281,7 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, If the inference attempt fails, AF_INET is used. @type af: int @rtype: dns.message.Message object - @param source: source address. The default is the IPv4 wildcard address. + @param source: source address. The default is the wildcard address. @type source: string @param source_port: The port from which to send the message. The default is 0. @@ -280,19 +291,8 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0, """ wire = q.to_wire() - if af is None: - try: - af = dns.inet.af_for_address(where) - except: - af = dns.inet.AF_INET - if af == dns.inet.AF_INET: - destination = (where, port) - if source is not None: - source = (source, source_port) - elif af == dns.inet.AF_INET6: - destination = (where, port, 0, 0) - if source is not None: - source = (source, source_port, 0, 0) + (af, destination, source) = _destination_and_source(af, where, port, source, + source_port) s = socket.socket(af, socket.SOCK_STREAM, 0) try: expiration = _compute_expiration(timeout) @@ -357,7 +357,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, take. @type lifetime: float @rtype: generator of dns.message.Message objects. - @param source: source address. The default is the IPv4 wildcard address. + @param source: source address. The default is the wildcard address. @type source: string @param source_port: The port from which to send the message. The default is 0. @@ -384,19 +384,8 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN, if not keyring is None: q.use_tsig(keyring, keyname, algorithm=keyalgorithm) wire = q.to_wire() - if af is None: - try: - af = dns.inet.af_for_address(where) - except: - af = dns.inet.AF_INET - if af == dns.inet.AF_INET: - destination = (where, port) - if source is not None: - source = (source, source_port) - elif af == dns.inet.AF_INET6: - destination = (where, port, 0, 0) - if source is not None: - source = (source, source_port, 0, 0) + (af, destination, source) = _destination_and_source(af, where, port, source, + source_port) if use_udp: if rdtype != dns.rdatatype.IXFR: raise ValueError('cannot do a UDP AXFR') diff --git a/dns/resolver.py b/dns/resolver.py index 9f9b438..4fb13d3 100644 --- a/dns/resolver.py +++ b/dns/resolver.py @@ -688,7 +688,7 @@ class Resolver(object): return min(self.lifetime - duration, self.timeout) def query(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, - tcp=False, source=None, raise_on_no_answer=True): + tcp=False, source=None, raise_on_no_answer=True, source_port=0): """Query nameservers to find the answer to the question. The I{qname}, I{rdtype}, and I{rdclass} parameters may be objects @@ -709,6 +709,9 @@ class Resolver(object): @param raise_on_no_answer: raise NoAnswer if there's no answer (defaults is True). @type raise_on_no_answer: bool + @param source_port: The port from which to send the message. + The default is 0. + @type source_port: int @rtype: dns.resolver.Answer instance @raises Timeout: no answers could be found in the specified lifetime @raises NXDOMAIN: the query name does not exist @@ -768,17 +771,20 @@ class Resolver(object): if tcp: response = dns.query.tcp(request, nameserver, timeout, self.port, - source=source) + source=source, + source_port=source_port) else: response = dns.query.udp(request, nameserver, timeout, self.port, - source=source) + source=source, + source_port=source_port) if response.flags & dns.flags.TC: # Response truncated; retry with TCP. timeout = self._compute_timeout(start) response = dns.query.tcp(request, nameserver, - timeout, self.port, - source=source) + timeout, self.port, + source=source, + source_port=source_port) except (socket.error, dns.exception.Timeout): # @@ -898,7 +904,8 @@ def get_default_resolver(): return default_resolver def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, - tcp=False, source=None, raise_on_no_answer=True): + tcp=False, source=None, raise_on_no_answer=True, + source_port=0): """Query nameservers to find the answer to the question. This is a convenience function that uses the default resolver @@ -906,7 +913,7 @@ def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN, @see: L{dns.resolver.Resolver.query} for more information on the parameters.""" return get_default_resolver().query(qname, rdtype, rdclass, tcp, source, - raise_on_no_answer) + raise_on_no_answer, source_port) def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None): """Find the name of the zone which contains the specified name. |
