summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorBob Halley <halley@dnspython.org>2012-04-08 13:25:27 +0100
committerBob Halley <halley@dnspython.org>2012-04-08 13:25:27 +0100
commit0f8559141f589ad112ead7c1ae1432d3b2dbfe4b (patch)
tree0b5c3d9a9b0b3221d0ef26606d82cb77770827bc
parent70a37051ef8486e086e7f914bc493940c70cecd1 (diff)
downloaddnspython-0f8559141f589ad112ead7c1ae1432d3b2dbfe4b.tar.gz
Add source_port support to resolver; fix source_port in query code
-rw-r--r--ChangeLog7
-rw-r--r--README6
-rw-r--r--dns/query.py73
-rw-r--r--dns/resolver.py21
4 files changed, 58 insertions, 49 deletions
diff --git a/ChangeLog b/ChangeLog
index 9b5d4e7..1736999 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,5 +1,12 @@
2012-04-08 Bob Halley <halley@dnspython.org>
+ * dns/query.py: Specifying source_port had no effect if source was
+ not specified. We now use the appropriate wildcard source in
+ that case.
+
+ * dns/resolver.py (Resolver.query): source_port may now be
+ specified.
+
* dns/resolver.py (Resolver.query): Switch to TCP when a UDP
response is truncated. Handle nameservers that serve on UDP
but not TCP.
diff --git a/README b/README
index d5609d1..ff1ff12 100644
--- a/README
+++ b/README
@@ -47,6 +47,8 @@ New since 1.9.4:
Trailing junk checking can be disabled.
+ A source port can be specified when creating a resolver query.
+
Bugs fixed since 1.9.4:
IPv4 and IPv6 address processing is now stricter.
@@ -56,6 +58,10 @@ Bugs fixed since 1.9.4:
expected) now raise dns.exception.FormError rather than
IndexError.
+ Specifying a source port without specifying source used to
+ have no effect, but now uses the wildcard address and the
+ specified port.
+
New since 1.9.3:
Nothing.
diff --git a/dns/query.py b/dns/query.py
index addee4e..2ed0b6a 100644
--- a/dns/query.py
+++ b/dns/query.py
@@ -144,6 +144,28 @@ def _addresses_equal(af, a1, a2):
n2 = dns.inet.inet_pton(af, a2[0])
return n1 == n2 and a1[1:] == a2[1:]
+def _destination_and_source(af, where, port, source, source_port):
+ # Apply defaults and compute destination and source tuples
+ # suitable for use in connect(), sendto(), or bind().
+ if af is None:
+ try:
+ af = dns.inet.af_for_address(where)
+ except:
+ af = dns.inet.AF_INET
+ if af == dns.inet.AF_INET:
+ destination = (where, port)
+ if source is not None or source_port != 0:
+ if source is None:
+ source = '0.0.0.0'
+ source = (source, source_port)
+ elif af == dns.inet.AF_INET6:
+ destination = (where, port, 0, 0)
+ if source is not None or source_port != 0:
+ if source is None:
+ source = '::'
+ source = (source, source_port, 0, 0)
+ return (af, destination, source)
+
def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
ignore_unexpected=False, one_rr_per_rrset=False):
"""Return the response obtained after sending a query via UDP.
@@ -162,7 +184,7 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
If the inference attempt fails, AF_INET is used.
@type af: int
@rtype: dns.message.Message object
- @param source: source address. The default is the IPv4 wildcard address.
+ @param source: source address. The default is the wildcard address.
@type source: string
@param source_port: The port from which to send the message.
The default is 0.
@@ -175,19 +197,8 @@ def udp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
"""
wire = q.to_wire()
- if af is None:
- try:
- af = dns.inet.af_for_address(where)
- except:
- af = dns.inet.AF_INET
- if af == dns.inet.AF_INET:
- destination = (where, port)
- if source is not None:
- source = (source, source_port)
- elif af == dns.inet.AF_INET6:
- destination = (where, port, 0, 0)
- if source is not None:
- source = (source, source_port, 0, 0)
+ (af, destination, source) = _destination_and_source(af, where, port, source,
+ source_port)
s = socket.socket(af, socket.SOCK_DGRAM, 0)
try:
expiration = _compute_expiration(timeout)
@@ -270,7 +281,7 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
If the inference attempt fails, AF_INET is used.
@type af: int
@rtype: dns.message.Message object
- @param source: source address. The default is the IPv4 wildcard address.
+ @param source: source address. The default is the wildcard address.
@type source: string
@param source_port: The port from which to send the message.
The default is 0.
@@ -280,19 +291,8 @@ def tcp(q, where, timeout=None, port=53, af=None, source=None, source_port=0,
"""
wire = q.to_wire()
- if af is None:
- try:
- af = dns.inet.af_for_address(where)
- except:
- af = dns.inet.AF_INET
- if af == dns.inet.AF_INET:
- destination = (where, port)
- if source is not None:
- source = (source, source_port)
- elif af == dns.inet.AF_INET6:
- destination = (where, port, 0, 0)
- if source is not None:
- source = (source, source_port, 0, 0)
+ (af, destination, source) = _destination_and_source(af, where, port, source,
+ source_port)
s = socket.socket(af, socket.SOCK_STREAM, 0)
try:
expiration = _compute_expiration(timeout)
@@ -357,7 +357,7 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
take.
@type lifetime: float
@rtype: generator of dns.message.Message objects.
- @param source: source address. The default is the IPv4 wildcard address.
+ @param source: source address. The default is the wildcard address.
@type source: string
@param source_port: The port from which to send the message.
The default is 0.
@@ -384,19 +384,8 @@ def xfr(where, zone, rdtype=dns.rdatatype.AXFR, rdclass=dns.rdataclass.IN,
if not keyring is None:
q.use_tsig(keyring, keyname, algorithm=keyalgorithm)
wire = q.to_wire()
- if af is None:
- try:
- af = dns.inet.af_for_address(where)
- except:
- af = dns.inet.AF_INET
- if af == dns.inet.AF_INET:
- destination = (where, port)
- if source is not None:
- source = (source, source_port)
- elif af == dns.inet.AF_INET6:
- destination = (where, port, 0, 0)
- if source is not None:
- source = (source, source_port, 0, 0)
+ (af, destination, source) = _destination_and_source(af, where, port, source,
+ source_port)
if use_udp:
if rdtype != dns.rdatatype.IXFR:
raise ValueError('cannot do a UDP AXFR')
diff --git a/dns/resolver.py b/dns/resolver.py
index 2f56c7d..437e14e 100644
--- a/dns/resolver.py
+++ b/dns/resolver.py
@@ -694,7 +694,7 @@ class Resolver(object):
return min(self.lifetime - duration, self.timeout)
def query(self, qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
- tcp=False, source=None, raise_on_no_answer=True):
+ tcp=False, source=None, raise_on_no_answer=True, source_port=0):
"""Query nameservers to find the answer to the question.
The I{qname}, I{rdtype}, and I{rdclass} parameters may be objects
@@ -715,6 +715,9 @@ class Resolver(object):
@param raise_on_no_answer: raise NoAnswer if there's no answer
(defaults is True).
@type raise_on_no_answer: bool
+ @param source_port: The port from which to send the message.
+ The default is 0.
+ @type source_port: int
@rtype: dns.resolver.Answer instance
@raises Timeout: no answers could be found in the specified lifetime
@raises NXDOMAIN: the query name does not exist
@@ -774,17 +777,20 @@ class Resolver(object):
if tcp:
response = dns.query.tcp(request, nameserver,
timeout, self.port,
- source=source)
+ source=source,
+ source_port=source_port)
else:
response = dns.query.udp(request, nameserver,
timeout, self.port,
- source=source)
+ source=source,
+ source_port=source_port)
if response.flags & dns.flags.TC:
# Response truncated; retry with TCP.
timeout = self._compute_timeout(start)
response = dns.query.tcp(request, nameserver,
- timeout, self.port,
- source=source)
+ timeout, self.port,
+ source=source,
+ source_port=source_port)
except (socket.error, dns.exception.Timeout):
#
# Communication failure or timeout. Go to the
@@ -903,7 +909,8 @@ def get_default_resolver():
return default_resolver
def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
- tcp=False, source=None, raise_on_no_answer=True):
+ tcp=False, source=None, raise_on_no_answer=True,
+ source_port=0):
"""Query nameservers to find the answer to the question.
This is a convenience function that uses the default resolver
@@ -911,7 +918,7 @@ def query(qname, rdtype=dns.rdatatype.A, rdclass=dns.rdataclass.IN,
@see: L{dns.resolver.Resolver.query} for more information on the
parameters."""
return get_default_resolver().query(qname, rdtype, rdclass, tcp, source,
- raise_on_no_answer)
+ raise_on_no_answer, source_port)
def zone_for_name(name, rdclass=dns.rdataclass.IN, tcp=False, resolver=None):
"""Find the name of the zone which contains the specified name.