summaryrefslogtreecommitdiff
path: root/src/net/dnsclient_unix_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'src/net/dnsclient_unix_test.go')
-rw-r--r--src/net/dnsclient_unix_test.go77
1 files changed, 55 insertions, 22 deletions
diff --git a/src/net/dnsclient_unix_test.go b/src/net/dnsclient_unix_test.go
index a95b2fe645..bb014b903a 100644
--- a/src/net/dnsclient_unix_test.go
+++ b/src/net/dnsclient_unix_test.go
@@ -113,7 +113,7 @@ var specialDomainNameTests = []struct {
}
func TestSpecialDomainName(t *testing.T) {
- fake := fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
r := dnsmessage.Message{
Header: dnsmessage.Header{
ID: q.ID,
@@ -189,7 +189,7 @@ func TestAvoidDNSName(t *testing.T) {
}
}
-var fakeDNSServerSuccessful = fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+var fakeDNSServerSuccessful = fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
r := dnsmessage.Message{
Header: dnsmessage.Header{
ID: q.ID,
@@ -473,7 +473,7 @@ var goLookupIPWithResolverConfigTests = []struct {
func TestGoLookupIPWithResolverConfig(t *testing.T) {
defer dnsWaitGroup.Wait()
- fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
switch s {
case "[2001:4860:4860::8888]:53", "8.8.8.8:53":
break
@@ -571,7 +571,7 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) {
func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
defer dnsWaitGroup.Wait()
- fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) {
+ fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, tm time.Time) (dnsmessage.Message, error) {
r := dnsmessage.Message{
Header: dnsmessage.Header{
ID: q.ID,
@@ -641,7 +641,7 @@ func TestErrorForOriginalNameWhenSearching(t *testing.T) {
t.Fatal(err)
}
- fake := fakeDNSServer{func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ fake := fakeDNSServer{rh: func(_, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
r := dnsmessage.Message{
Header: dnsmessage.Header{
ID: q.ID,
@@ -696,7 +696,7 @@ func TestIgnoreLameReferrals(t *testing.T) {
t.Fatal(err)
}
- fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
t.Log(s, q)
r := dnsmessage.Message{
Header: dnsmessage.Header{
@@ -788,12 +788,15 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
}
type fakeDNSServer struct {
- rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error)
+ rh func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error)
+ alwaysTCP bool
}
func (server *fakeDNSServer) DialContext(_ context.Context, n, s string) (Conn, error) {
- tcp := n == "tcp" || n == "tcp4" || n == "tcp6"
- return &fakeDNSConn{tcp: tcp, server: server, n: n, s: s}, nil
+ if server.alwaysTCP || n == "tcp" || n == "tcp4" || n == "tcp6" {
+ return &fakeDNSConn{tcp: true, server: server, n: n, s: s}, nil
+ }
+ return &fakeDNSPacketConn{fakeDNSConn: fakeDNSConn{tcp: false, server: server, n: n, s: s}}, nil
}
type fakeDNSConn struct {
@@ -846,10 +849,6 @@ func (f *fakeDNSConn) Read(b []byte) (int, error) {
return len(bb), nil
}
-func (f *fakeDNSConn) ReadFrom(b []byte) (int, Addr, error) {
- return 0, nil, nil
-}
-
func (f *fakeDNSConn) Write(b []byte) (int, error) {
if f.tcp && len(b) >= 2 {
b = b[2:]
@@ -860,15 +859,24 @@ func (f *fakeDNSConn) Write(b []byte) (int, error) {
return len(b), nil
}
-func (f *fakeDNSConn) WriteTo(b []byte, addr Addr) (int, error) {
- return 0, nil
-}
-
func (f *fakeDNSConn) SetDeadline(t time.Time) error {
f.t = t
return nil
}
+type fakeDNSPacketConn struct {
+ PacketConn
+ fakeDNSConn
+}
+
+func (f *fakeDNSPacketConn) SetDeadline(t time.Time) error {
+ return f.fakeDNSConn.SetDeadline(t)
+}
+
+func (f *fakeDNSPacketConn) Close() error {
+ return f.fakeDNSConn.Close()
+}
+
// UDP round-tripper algorithm should ignore invalid DNS responses (issue 13281).
func TestIgnoreDNSForgeries(t *testing.T) {
c, s := Pipe()
@@ -973,7 +981,7 @@ func TestRetryTimeout(t *testing.T) {
var deadline0 time.Time
- fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
+ fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
t.Log(s, q, deadline)
if deadline.IsZero() {
@@ -1034,7 +1042,7 @@ func testRotate(t *testing.T, rotate bool, nameservers, wantServers []string) {
}
var usedServers []string
- fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
+ fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
usedServers = append(usedServers, s)
return mockTXTResponse(q), nil
}}
@@ -1218,7 +1226,7 @@ func TestStrictErrorsLookupIP(t *testing.T) {
}
for i, tt := range cases {
- fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
+ fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
t.Log(s, q)
switch tt.resolveWhich(q.Questions[0]) {
@@ -1356,7 +1364,7 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
const searchY = "test.y.golang.org."
const txt = "Hello World"
- fake := fakeDNSServer{func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
+ fake := fakeDNSServer{rh: func(_, s string, q dnsmessage.Message, deadline time.Time) (dnsmessage.Message, error) {
t.Log(s, q)
switch q.Questions[0].Name.String() {
@@ -1402,7 +1410,7 @@ func TestStrictErrorsLookupTXT(t *testing.T) {
func TestDNSGoroutineRace(t *testing.T) {
defer dnsWaitGroup.Wait()
- fake := fakeDNSServer{func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) {
+ fake := fakeDNSServer{rh: func(n, s string, q dnsmessage.Message, t time.Time) (dnsmessage.Message, error) {
time.Sleep(10 * time.Microsecond)
return dnsmessage.Message{}, poll.ErrTimeout
}}
@@ -1502,3 +1510,28 @@ func TestIssue12778(t *testing.T) {
t.Fatalf("Err = %#v; wanted %q", de.Err, errNoSuchHost.Error())
}
}
+
+// Issue 26573: verify that Conns that don't implement PacketConn are treated
+// as streams even when udp was requested.
+func TestDNSDialTCP(t *testing.T) {
+ fake := fakeDNSServer{
+ rh: func(n, _ string, q dnsmessage.Message, _ time.Time) (dnsmessage.Message, error) {
+ r := dnsmessage.Message{
+ Header: dnsmessage.Header{
+ ID: q.Header.ID,
+ Response: true,
+ RCode: dnsmessage.RCodeSuccess,
+ },
+ Questions: q.Questions,
+ }
+ return r, nil
+ },
+ alwaysTCP: true,
+ }
+ r := Resolver{PreferGo: true, Dial: fake.DialContext}
+ ctx := context.Background()
+ _, _, err := r.exchange(ctx, "0.0.0.0", mustQuestion("com.", dnsmessage.TypeALL, dnsmessage.ClassINET), time.Second)
+ if err != nil {
+ t.Fatal("exhange failed:", err)
+ }
+}