path: root/libgo/go/net
author    ian <ian@138bc75d-0d04-0410-961f-82ee72b054a4>  2016-02-03 21:58:02 +0000
committer ian <ian@138bc75d-0d04-0410-961f-82ee72b054a4>  2016-02-03 21:58:02 +0000
commit    0694cef2844753fb80be4f71f7d2eb82eb5ba464 (patch)
tree      2f8da9862a9c1fe0df138917f997b03439c02773 /libgo/go/net
parent    397fecd695789eccab667bf771a354df71d843e8 (diff)
download  gcc-0694cef2844753fb80be4f71f7d2eb82eb5ba464.tar.gz
libgo: Update to go1.6rc1.
Reviewed-on: https://go-review.googlesource.com/19200
git-svn-id: svn+ssh://gcc.gnu.org/svn/gcc/trunk@233110 138bc75d-0d04-0410-961f-82ee72b054a4
Diffstat (limited to 'libgo/go/net')
-rw-r--r--  libgo/go/net/addrselect.go  43
-rw-r--r--  libgo/go/net/addrselect_test.go  115
-rw-r--r--  libgo/go/net/cgo_socknew.go  4
-rw-r--r--  libgo/go/net/cgo_sockold.go  4
-rw-r--r--  libgo/go/net/cgo_unix.go  19
-rw-r--r--  libgo/go/net/conf.go  3
-rw-r--r--  libgo/go/net/dial.go  17
-rw-r--r--  libgo/go/net/dial_test.go  74
-rw-r--r--  libgo/go/net/dnsclient.go  33
-rw-r--r--  libgo/go/net/dnsclient_test.go  48
-rw-r--r--  libgo/go/net/dnsclient_unix.go  41
-rw-r--r--  libgo/go/net/dnsclient_unix_test.go  133
-rw-r--r--  libgo/go/net/dnsmsg.go  3
-rw-r--r--  libgo/go/net/error_test.go  48
-rw-r--r--  libgo/go/net/fd_plan9.go  8
-rw-r--r--  libgo/go/net/fd_unix.go  20
-rw-r--r--  libgo/go/net/fd_windows.go  41
-rw-r--r--  libgo/go/net/file_test.go  366
-rw-r--r--  libgo/go/net/file_unix.go  2
-rw-r--r--  libgo/go/net/hosts.go  95
-rw-r--r--  libgo/go/net/hosts_test.go  27
-rw-r--r--  libgo/go/net/http/cgi/host.go  24
-rw-r--r--  libgo/go/net/http/cgi/host_test.go  27
-rw-r--r--  libgo/go/net/http/client.go  221
-rw-r--r--  libgo/go/net/http/client_test.go  209
-rw-r--r--  libgo/go/net/http/clientserver_test.go  1056
-rw-r--r--  libgo/go/net/http/doc.go  15
-rw-r--r--  libgo/go/net/http/export_test.go  106
-rw-r--r--  libgo/go/net/http/fcgi/child.go  3
-rw-r--r--  libgo/go/net/http/fcgi/fcgi_test.go  24
-rw-r--r--  libgo/go/net/http/fs.go  76
-rw-r--r--  libgo/go/net/http/fs_test.go  130
-rw-r--r--  libgo/go/net/http/h2_bundle.go  6530
-rw-r--r--  libgo/go/net/http/header.go  10
-rw-r--r--  libgo/go/net/http/httptest/recorder.go  45
-rw-r--r--  libgo/go/net/http/httptest/recorder_test.go  55
-rw-r--r--  libgo/go/net/http/httptest/server.go  248
-rw-r--r--  libgo/go/net/http/httptest/server_test.go  57
-rw-r--r--  libgo/go/net/http/httputil/dump.go  63
-rw-r--r--  libgo/go/net/http/httputil/dump_test.go  17
-rw-r--r--  libgo/go/net/http/httputil/example_test.go  125
-rw-r--r--  libgo/go/net/http/httputil/reverseproxy.go  26
-rw-r--r--  libgo/go/net/http/httputil/reverseproxy_test.go  116
-rw-r--r--  libgo/go/net/http/internal/chunked.go  42
-rw-r--r--  libgo/go/net/http/internal/chunked_test.go  51
-rw-r--r--  libgo/go/net/http/internal/testcert.go  41
-rw-r--r--  libgo/go/net/http/lex.go  14
-rw-r--r--  libgo/go/net/http/main_test.go  18
-rw-r--r--  libgo/go/net/http/method.go  20
-rw-r--r--  libgo/go/net/http/pprof/pprof.go  15
-rw-r--r--  libgo/go/net/http/request.go  179
-rw-r--r--  libgo/go/net/http/request_test.go  38
-rw-r--r--  libgo/go/net/http/response.go  22
-rw-r--r--  libgo/go/net/http/response_test.go  157
-rw-r--r--  libgo/go/net/http/serve_test.go  950
-rw-r--r--  libgo/go/net/http/server.go  729
-rw-r--r--  libgo/go/net/http/sniff.go  29
-rw-r--r--  libgo/go/net/http/sniff_test.go  46
-rw-r--r--  libgo/go/net/http/status.go  46
-rw-r--r--  libgo/go/net/http/transfer.go  87
-rw-r--r--  libgo/go/net/http/transport.go  737
-rw-r--r--  libgo/go/net/http/transport_test.go  415
-rw-r--r--  libgo/go/net/http/triv.go  5
-rw-r--r--  libgo/go/net/interface_test.go  44
-rw-r--r--  libgo/go/net/interface_windows.go  330
-rw-r--r--  libgo/go/net/interface_windows_test.go  132
-rw-r--r--  libgo/go/net/internal/socktest/switch.go  4
-rw-r--r--  libgo/go/net/iprawsock_posix.go  4
-rw-r--r--  libgo/go/net/ipsock.go  2
-rw-r--r--  libgo/go/net/ipsock_posix.go  15
-rw-r--r--  libgo/go/net/lookup.go  18
-rw-r--r--  libgo/go/net/lookup_plan9.go  12
-rw-r--r--  libgo/go/net/lookup_test.go  201
-rw-r--r--  libgo/go/net/lookup_windows.go  21
-rw-r--r--  libgo/go/net/mac.go  11
-rw-r--r--  libgo/go/net/mac_test.go  24
-rw-r--r--  libgo/go/net/mail/message.go  23
-rw-r--r--  libgo/go/net/mail/message_test.go  68
-rw-r--r--  libgo/go/net/net.go  24
-rw-r--r--  libgo/go/net/net_test.go  32
-rw-r--r--  libgo/go/net/non_unix_test.go  15
-rw-r--r--  libgo/go/net/parse.go  46
-rw-r--r--  libgo/go/net/parse_test.go  22
-rw-r--r--  libgo/go/net/platform_test.go  4
-rw-r--r--  libgo/go/net/port.go  24
-rw-r--r--  libgo/go/net/port_test.go  57
-rw-r--r--  libgo/go/net/race.go  31
-rw-r--r--  libgo/go/net/race0.go  26
-rw-r--r--  libgo/go/net/rpc/server.go  8
-rw-r--r--  libgo/go/net/rpc/server_test.go  37
-rw-r--r--  libgo/go/net/sendfile_solaris.go  2
-rw-r--r--  libgo/go/net/server_test.go  6
-rw-r--r--  libgo/go/net/sock_posix.go  8
-rw-r--r--  libgo/go/net/tcp_test.go  8
-rw-r--r--  libgo/go/net/tcpsock_plan9.go  5
-rw-r--r--  libgo/go/net/tcpsock_posix.go  10
-rw-r--r--  libgo/go/net/tcpsockopt_plan9.go  3
-rw-r--r--  libgo/go/net/testdata/case-hosts  2
-rw-r--r--  libgo/go/net/testdata/hosts  3
-rw-r--r--  libgo/go/net/textproto/reader.go  18
-rw-r--r--  libgo/go/net/textproto/reader_test.go  55
-rw-r--r--  libgo/go/net/timeout_test.go  51
-rw-r--r--  libgo/go/net/udpsock_posix.go  6
-rw-r--r--  libgo/go/net/unix_test.go  50
-rw-r--r--  libgo/go/net/unixsock_posix.go  11
-rw-r--r--  libgo/go/net/url/url.go  119
-rw-r--r--  libgo/go/net/url/url_test.go  224
107 files changed, 13901 insertions, 1883 deletions
diff --git a/libgo/go/net/addrselect.go b/libgo/go/net/addrselect.go
index e22fbac5cee..58ab7d706c6 100644
--- a/libgo/go/net/addrselect.go
+++ b/libgo/go/net/addrselect.go
@@ -197,6 +197,24 @@ func (s *byRFC6724) Less(i, j int) bool {
if da4 == db4 {
commonA := commonPrefixLen(SourceDA, DA)
commonB := commonPrefixLen(SourceDB, DB)
+
+ // CommonPrefixLen doesn't really make sense for IPv4, and even
+ // causes problems for common load balancing practices
+ // (e.g., https://golang.org/issue/13283). Glibc instead only
+ // uses CommonPrefixLen for IPv4 when the source and destination
+ // addresses are on the same subnet, but that requires extra
+ // work to find the netmask for our source addresses. As a
+ // simpler heuristic, we limit its use to when the source and
+ // destination belong to the same special purpose block.
+ if da4 {
+ if !sameIPv4SpecialPurposeBlock(SourceDA, DA) {
+ commonA = 0
+ }
+ if !sameIPv4SpecialPurposeBlock(SourceDB, DB) {
+ commonB = 0
+ }
+ }
+
if commonA > commonB {
return preferDA
}
@@ -386,3 +404,28 @@ func commonPrefixLen(a, b IP) (cpl int) {
}
return
}
+
+// sameIPv4SpecialPurposeBlock reports whether a and b belong to the same
+// address block reserved by the IANA IPv4 Special-Purpose Address Registry:
+// http://www.iana.org/assignments/iana-ipv4-special-registry/iana-ipv4-special-registry.xhtml
+func sameIPv4SpecialPurposeBlock(a, b IP) bool {
+ a, b = a.To4(), b.To4()
+ if a == nil || b == nil || a[0] != b[0] {
+ return false
+ }
+ // IANA defines more special-purpose blocks, but these are the only
+ // ones likely to be relevant to typical Go systems.
+ switch a[0] {
+ case 10: // 10.0.0.0/8: Private-Use
+ return true
+ case 127: // 127.0.0.0/8: Loopback
+ return true
+ case 169: // 169.254.0.0/16: Link Local
+ return a[1] == 254 && b[1] == 254
+ case 172: // 172.16.0.0/12: Private-Use
+ return a[1]&0xf0 == 16 && b[1]&0xf0 == 16
+ case 192: // 192.168.0.0/16: Private-Use
+ return a[1] == 168 && b[1] == 168
+ }
+ return false
+}
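Editor's note: the new heuristic only lets the common-prefix comparison influence IPv4 sorting when source and destination fall inside the same special-purpose block. A minimal standalone sketch (not part of the patch) that mirrors the block check on plain 4-byte addresses; the helper name is illustrative.

package main

import "fmt"

// sameIPv4Block mirrors the patch's sameIPv4SpecialPurposeBlock heuristic
// for 4-byte addresses (illustrative reimplementation, not the net
// package's internal function).
func sameIPv4Block(a, b [4]byte) bool {
	if a[0] != b[0] {
		return false
	}
	switch a[0] {
	case 10, 127: // 10.0.0.0/8 Private-Use, 127.0.0.0/8 Loopback
		return true
	case 169: // 169.254.0.0/16 Link Local
		return a[1] == 254 && b[1] == 254
	case 172: // 172.16.0.0/12 Private-Use
		return a[1]&0xf0 == 16 && b[1]&0xf0 == 16
	case 192: // 192.168.0.0/16 Private-Use
		return a[1] == 168 && b[1] == 168
	}
	return false
}

func main() {
	src := [4]byte{10, 2, 3, 4}
	dstSameBlock := [4]byte{10, 66, 0, 1}
	dstPublic := [4]byte{23, 23, 172, 185}
	// Only the 10/8 pair keeps its common prefix length; for the public
	// destination the patch forces the common prefix to 0, so a 10/8
	// source no longer makes 23/8 destinations look "closer".
	fmt.Println(sameIPv4Block(src, dstSameBlock)) // true
	fmt.Println(sameIPv4Block(src, dstPublic))    // false
}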
diff --git a/libgo/go/net/addrselect_test.go b/libgo/go/net/addrselect_test.go
index 562022772fa..80aa4eb195b 100644
--- a/libgo/go/net/addrselect_test.go
+++ b/libgo/go/net/addrselect_test.go
@@ -87,6 +87,57 @@ func TestSortByRFC6724(t *testing.T) {
},
reverse: true,
},
+
+ // Issue 13283. Having a 10/8 source address does not
+ // mean we should prefer 23/8 destination addresses.
+ {
+ in: []IPAddr{
+ {IP: ParseIP("54.83.193.112")},
+ {IP: ParseIP("184.72.238.214")},
+ {IP: ParseIP("23.23.172.185")},
+ {IP: ParseIP("75.101.148.21")},
+ {IP: ParseIP("23.23.134.56")},
+ {IP: ParseIP("23.21.50.150")},
+ },
+ srcs: []IP{
+ ParseIP("10.2.3.4"),
+ ParseIP("10.2.3.4"),
+ ParseIP("10.2.3.4"),
+ ParseIP("10.2.3.4"),
+ ParseIP("10.2.3.4"),
+ ParseIP("10.2.3.4"),
+ },
+ want: []IPAddr{
+ {IP: ParseIP("54.83.193.112")},
+ {IP: ParseIP("184.72.238.214")},
+ {IP: ParseIP("23.23.172.185")},
+ {IP: ParseIP("75.101.148.21")},
+ {IP: ParseIP("23.23.134.56")},
+ {IP: ParseIP("23.21.50.150")},
+ },
+ reverse: false,
+ },
+
+ // Prefer longer common prefixes, but only for IPv4 address
+ // pairs in the same special-purpose block.
+ {
+ in: []IPAddr{
+ {IP: ParseIP("1.2.3.4")},
+ {IP: ParseIP("10.55.0.1")},
+ {IP: ParseIP("10.66.0.1")},
+ },
+ srcs: []IP{
+ ParseIP("1.2.3.5"),
+ ParseIP("10.66.1.2"),
+ ParseIP("10.66.1.2"),
+ },
+ want: []IPAddr{
+ {IP: ParseIP("10.66.0.1")},
+ {IP: ParseIP("10.55.0.1")},
+ {IP: ParseIP("1.2.3.4")},
+ },
+ reverse: true,
+ },
}
for i, tt := range tests {
inCopy := make([]IPAddr, len(tt.in))
@@ -217,3 +268,67 @@ func TestRFC6724CommonPrefixLength(t *testing.T) {
}
}
+
+func mustParseCIDRs(t *testing.T, blocks ...string) []*IPNet {
+ res := make([]*IPNet, len(blocks))
+ for i, block := range blocks {
+ var err error
+ _, res[i], err = ParseCIDR(block)
+ if err != nil {
+ t.Fatalf("ParseCIDR(%s) failed: %v", block, err)
+ }
+ }
+ return res
+}
+
+func TestSameIPv4SpecialPurposeBlock(t *testing.T) {
+ blocks := mustParseCIDRs(t,
+ "10.0.0.0/8",
+ "127.0.0.0/8",
+ "169.254.0.0/16",
+ "172.16.0.0/12",
+ "192.168.0.0/16",
+ )
+
+ addrs := []struct {
+ ip IP
+ block int // index or -1
+ }{
+ {IP{1, 2, 3, 4}, -1},
+ {IP{2, 3, 4, 5}, -1},
+ {IP{10, 2, 3, 4}, 0},
+ {IP{10, 6, 7, 8}, 0},
+ {IP{127, 0, 0, 1}, 1},
+ {IP{127, 255, 255, 255}, 1},
+ {IP{169, 254, 77, 99}, 2},
+ {IP{169, 254, 44, 22}, 2},
+ {IP{169, 255, 0, 1}, -1},
+ {IP{172, 15, 5, 6}, -1},
+ {IP{172, 16, 32, 41}, 3},
+ {IP{172, 31, 128, 9}, 3},
+ {IP{172, 32, 88, 100}, -1},
+ {IP{192, 168, 1, 1}, 4},
+ {IP{192, 168, 128, 42}, 4},
+ {IP{192, 169, 1, 1}, -1},
+ }
+
+ for i, addr := range addrs {
+ for j, block := range blocks {
+ got := block.Contains(addr.ip)
+ want := addr.block == j
+ if got != want {
+ t.Errorf("%d/%d. %s.Contains(%s): got %v, want %v", i, j, block, addr.ip, got, want)
+ }
+ }
+ }
+
+ for i, addr1 := range addrs {
+ for j, addr2 := range addrs {
+ got := sameIPv4SpecialPurposeBlock(addr1.ip, addr2.ip)
+ want := addr1.block >= 0 && addr1.block == addr2.block
+ if got != want {
+ t.Errorf("%d/%d. sameIPv4SpecialPurposeBlock(%s, %s): got %v, want %v", i, j, addr1.ip, addr2.ip, got, want)
+ }
+ }
+ }
+}
diff --git a/libgo/go/net/cgo_socknew.go b/libgo/go/net/cgo_socknew.go
index 81816c644f8..9dcd15883e4 100644
--- a/libgo/go/net/cgo_socknew.go
+++ b/libgo/go/net/cgo_socknew.go
@@ -25,8 +25,8 @@ func cgoSockaddrInet4(ip IP) *syscall.RawSockaddr {
return (*syscall.RawSockaddr)(unsafe.Pointer(&sa))
}
-func cgoSockaddrInet6(ip IP) *syscall.RawSockaddr {
- sa := syscall.RawSockaddrInet6{Family: syscall.AF_INET6}
+func cgoSockaddrInet6(ip IP, zone int) *syscall.RawSockaddr {
+ sa := syscall.RawSockaddrInet6{Family: syscall.AF_INET6, Scope_id: uint32(zone)}
copy(sa.Addr[:], ip)
return (*syscall.RawSockaddr)(unsafe.Pointer(&sa))
}
diff --git a/libgo/go/net/cgo_sockold.go b/libgo/go/net/cgo_sockold.go
index e80e03b2b1e..432634b5521 100644
--- a/libgo/go/net/cgo_sockold.go
+++ b/libgo/go/net/cgo_sockold.go
@@ -25,8 +25,8 @@ func cgoSockaddrInet4(ip IP) *syscall.RawSockaddr {
return (*syscall.RawSockaddr)(unsafe.Pointer(&sa))
}
-func cgoSockaddrInet6(ip IP) *syscall.RawSockaddr {
- sa := syscall.RawSockaddrInet6{Len: syscall.SizeofSockaddrInet6, Family: syscall.AF_INET6}
+func cgoSockaddrInet6(ip IP, zone int) *syscall.RawSockaddr {
+ sa := syscall.RawSockaddrInet6{Len: syscall.SizeofSockaddrInet6, Family: syscall.AF_INET6, Scope_id: uint32(zone)}
copy(sa.Addr[:], ip)
return (*syscall.RawSockaddr)(unsafe.Pointer(&sa))
}
diff --git a/libgo/go/net/cgo_unix.go b/libgo/go/net/cgo_unix.go
index 8eafa8cbd44..f634323ed8b 100644
--- a/libgo/go/net/cgo_unix.go
+++ b/libgo/go/net/cgo_unix.go
@@ -211,11 +211,15 @@ func cgoLookupPTR(addr string) ([]string, error, bool) {
acquireThread()
defer releaseThread()
- ip := ParseIP(addr)
+ var zone string
+ ip := parseIPv4(addr)
+ if ip == nil {
+ ip, zone = parseIPv6(addr, true)
+ }
if ip == nil {
return nil, &DNSError{Err: "invalid address", Name: addr}, true
}
- sa, salen := cgoSockaddr(ip)
+ sa, salen := cgoSockaddr(ip, zone)
if sa == nil {
return nil, &DNSError{Err: "invalid address " + ip.String(), Name: addr}, true
}
@@ -247,20 +251,15 @@ func cgoLookupPTR(addr string) ([]string, error, bool) {
break
}
}
- // Add trailing dot to match pure Go reverse resolver
- // and all other lookup routines. See golang.org/issue/12189.
- if len(b) > 0 && b[len(b)-1] != '.' {
- b = append(b, '.')
- }
- return []string{string(b)}, nil, true
+ return []string{absDomainName(b)}, nil, true
}
-func cgoSockaddr(ip IP) (*syscall.RawSockaddr, syscall.Socklen_t) {
+func cgoSockaddr(ip IP, zone string) (*syscall.RawSockaddr, syscall.Socklen_t) {
if ip4 := ip.To4(); ip4 != nil {
return cgoSockaddrInet4(ip4), syscall.Socklen_t(syscall.SizeofSockaddrInet4)
}
if ip6 := ip.To16(); ip6 != nil {
- return cgoSockaddrInet6(ip6), syscall.Socklen_t(syscall.SizeofSockaddrInet6)
+ return cgoSockaddrInet6(ip6, zoneToInt(zone)), syscall.Socklen_t(syscall.SizeofSockaddrInet6)
}
return nil, 0
}
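Editor's note: the cgoLookupPTR change parses the zone out of an IPv6 literal and passes its scope id through to the raw sockaddr. A rough sketch, using only exported API, of what splitting the zone and mapping it to a scope id looks like; splitZone and zoneToScopeID are hypothetical helpers, not the parseIPv6/zoneToInt internals.

package main

import (
	"fmt"
	"net"
	"strings"
)

// splitZone separates an IPv6 literal like "fe80::1%eth0" into the
// address and the zone identifier.
func splitZone(addr string) (ip net.IP, zone string) {
	if i := strings.IndexByte(addr, '%'); i >= 0 {
		addr, zone = addr[:i], addr[i+1:]
	}
	return net.ParseIP(addr), zone
}

// zoneToScopeID maps a zone name such as "eth0" to the interface index
// used as the sockaddr's scope id, falling back to a numeric zone.
func zoneToScopeID(zone string) int {
	if zone == "" {
		return 0
	}
	if ifi, err := net.InterfaceByName(zone); err == nil {
		return ifi.Index
	}
	var n int
	fmt.Sscanf(zone, "%d", &n)
	return n
}

func main() {
	ip, zone := splitZone("fe80::1%lo")
	fmt.Println(ip, zone, zoneToScopeID(zone))
}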
diff --git a/libgo/go/net/conf.go b/libgo/go/net/conf.go
index c92e579d7e6..ddaa978f4f6 100644
--- a/libgo/go/net/conf.go
+++ b/libgo/go/net/conf.go
@@ -9,7 +9,6 @@ package net
import (
"os"
"runtime"
- "strconv"
"sync"
"syscall"
)
@@ -293,7 +292,7 @@ func goDebugNetDNS() (dnsMode string, debugLevel int) {
return
}
if '0' <= s[0] && s[0] <= '9' {
- debugLevel, _ = strconv.Atoi(s)
+ debugLevel, _, _ = dtoi(s, 0)
} else {
dnsMode = s
}
diff --git a/libgo/go/net/dial.go b/libgo/go/net/dial.go
index cb4ec216d53..193776fe413 100644
--- a/libgo/go/net/dial.go
+++ b/libgo/go/net/dial.go
@@ -57,6 +57,11 @@ type Dialer struct {
// If zero, keep-alives are not enabled. Network protocols
// that do not support keep-alives ignore this field.
KeepAlive time.Duration
+
+ // Cancel is an optional channel whose closure indicates that
+ // the dial should be canceled. Not all types of dials support
+ // cancelation.
+ Cancel <-chan struct{}
}
// Return either now+Timeout or Deadline, whichever comes first.
@@ -165,12 +170,14 @@ func resolveAddrList(op, net, addr string, deadline time.Time) (addrList, error)
// in square brackets as in "[::1]:80" or "[ipv6-host%zone]:80".
// The functions JoinHostPort and SplitHostPort manipulate addresses
// in this form.
+// If the host is empty, as in ":80", the local system is assumed.
//
// Examples:
// Dial("tcp", "12.34.56.78:80")
// Dial("tcp", "google.com:http")
// Dial("tcp", "[2001:db8::1]:http")
// Dial("tcp", "[fe80::1%lo0]:80")
+// Dial("tcp", ":80")
//
// For IP networks, the network must be "ip", "ip4" or "ip6" followed
// by a colon and a protocol number or name and the addr must be a
@@ -361,7 +368,7 @@ func dialSingle(ctx *dialContext, ra Addr, deadline time.Time) (c Conn, err erro
switch ra := ra.(type) {
case *TCPAddr:
la, _ := la.(*TCPAddr)
- c, err = testHookDialTCP(ctx.network, la, ra, deadline)
+ c, err = testHookDialTCP(ctx.network, la, ra, deadline, ctx.Cancel)
case *UDPAddr:
la, _ := la.(*UDPAddr)
c, err = dialUDP(ctx.network, la, ra, deadline)
@@ -383,7 +390,10 @@ func dialSingle(ctx *dialContext, ra Addr, deadline time.Time) (c Conn, err erro
// Listen announces on the local network address laddr.
// The network net must be a stream-oriented network: "tcp", "tcp4",
// "tcp6", "unix" or "unixpacket".
-// See Dial for the syntax of laddr.
+// For TCP and UDP, the syntax of laddr is "host:port", like "127.0.0.1:8080".
+// If host is omitted, as in ":8080", Listen listens on all available interfaces
+// instead of just the interface with the given host address.
+// See Dial for more details about address syntax.
func Listen(net, laddr string) (Listener, error) {
addrs, err := resolveAddrList("listen", net, laddr, noDeadline)
if err != nil {
@@ -407,6 +417,9 @@ func Listen(net, laddr string) (Listener, error) {
// ListenPacket announces on the local network address laddr.
// The network net must be a packet-oriented network: "udp", "udp4",
// "udp6", "ip", "ip4", "ip6" or "unixgram".
+// For TCP and UDP, the syntax of laddr is "host:port", like "127.0.0.1:8080".
+// If host is omitted, as in ":8080", ListenPacket listens on all available interfaces
+// instead of just the interface with the given host address.
// See Dial for the syntax of laddr.
func ListenPacket(net, laddr string) (PacketConn, error) {
addrs, err := resolveAddrList("listen", net, laddr, noDeadline)
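Editor's note: the new Dialer.Cancel field lets a caller abort an in-flight dial by closing a channel. A small usage sketch; the blackhole address (TEST-NET-1) is an arbitrary choice to keep the dial pending long enough to cancel.

package main

import (
	"fmt"
	"net"
	"time"
)

func main() {
	cancel := make(chan struct{})
	d := net.Dialer{Cancel: cancel}

	// Give up on the dial after 100ms by closing the Cancel channel.
	time.AfterFunc(100*time.Millisecond, func() { close(cancel) })

	// 192.0.2.1 (TEST-NET-1) is unlikely to answer, so the dial blocks
	// until the cancel channel is closed.
	c, err := d.Dial("tcp", "192.0.2.1:80")
	if err != nil {
		fmt.Println("dial aborted:", err)
		return
	}
	c.Close()
}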
diff --git a/libgo/go/net/dial_test.go b/libgo/go/net/dial_test.go
index ed6d7cc42f1..2311b108241 100644
--- a/libgo/go/net/dial_test.go
+++ b/libgo/go/net/dial_test.go
@@ -5,6 +5,7 @@
package net
import (
+ "internal/testenv"
"io"
"net/internal/socktest"
"runtime"
@@ -236,8 +237,8 @@ const (
// In some environments, the slow IPs may be explicitly unreachable, and fail
// more quickly than expected. This test hook prevents dialTCP from returning
// before the deadline.
-func slowDialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time) (*TCPConn, error) {
- c, err := dialTCP(net, laddr, raddr, deadline)
+func slowDialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) {
+ c, err := dialTCP(net, laddr, raddr, deadline, cancel)
if ParseIP(slowDst4).Equal(raddr.IP) || ParseIP(slowDst6).Equal(raddr.IP) {
time.Sleep(deadline.Sub(time.Now()))
}
@@ -509,6 +510,9 @@ func TestDialerFallbackDelay(t *testing.T) {
}
func TestDialSerialAsyncSpuriousConnection(t *testing.T) {
+ if runtime.GOOS == "plan9" {
+ t.Skip("skipping on plan9; no deadline support, golang.org/issue/11932")
+ }
ln, err := newLocalListener("tcp")
if err != nil {
t.Fatal(err)
@@ -643,7 +647,7 @@ func TestDialerDualStack(t *testing.T) {
}
}
- var timeout = 100*time.Millisecond + closedPortDelay
+ var timeout = 150*time.Millisecond + closedPortDelay
for _, dualstack := range []bool{false, true} {
dss, err := newDualStackServer([]streamListener{
{network: "tcp4", address: "127.0.0.1"},
@@ -713,3 +717,67 @@ func TestDialerKeepAlive(t *testing.T) {
}
}
}
+
+func TestDialCancel(t *testing.T) {
+ if runtime.GOOS == "plan9" || runtime.GOOS == "nacl" {
+ // plan9 is not implemented and nacl doesn't have
+ // external network access.
+ t.Skipf("skipping on %s", runtime.GOOS)
+ }
+ onGoBuildFarm := testenv.Builder() != ""
+ if testing.Short() && !onGoBuildFarm {
+ t.Skip("skipping in short mode")
+ }
+
+ blackholeIPPort := JoinHostPort(slowDst4, "1234")
+ if !supportsIPv4 {
+ blackholeIPPort = JoinHostPort(slowDst6, "1234")
+ }
+
+ ticker := time.NewTicker(10 * time.Millisecond)
+ defer ticker.Stop()
+
+ const cancelTick = 5 // the timer tick we cancel the dial at
+ const timeoutTick = 100
+
+ var d Dialer
+ cancel := make(chan struct{})
+ d.Cancel = cancel
+ errc := make(chan error, 1)
+ connc := make(chan Conn, 1)
+ go func() {
+ if c, err := d.Dial("tcp", blackholeIPPort); err != nil {
+ errc <- err
+ } else {
+ connc <- c
+ }
+ }()
+ ticks := 0
+ for {
+ select {
+ case <-ticker.C:
+ ticks++
+ if ticks == cancelTick {
+ close(cancel)
+ }
+ if ticks == timeoutTick {
+ t.Fatal("timeout waiting for dial to fail")
+ }
+ case c := <-connc:
+ c.Close()
+ t.Fatal("unexpected successful connection")
+ case err := <-errc:
+ if perr := parseDialError(err); perr != nil {
+ t.Error(perr)
+ }
+ if ticks < cancelTick {
+ t.Fatalf("dial error after %d ticks (%d before cancel sent): %v",
+ ticks, cancelTick-ticks, err)
+ }
+ if oe, ok := err.(*OpError); !ok || oe.Err != errCanceled {
+ t.Fatalf("dial error = %v (%T); want OpError with Err == errCanceled", err, err)
+ }
+ return // success.
+ }
+ }
+}
diff --git a/libgo/go/net/dnsclient.go b/libgo/go/net/dnsclient.go
index ce48521bc60..5dc2a0368ce 100644
--- a/libgo/go/net/dnsclient.go
+++ b/libgo/go/net/dnsclient.go
@@ -40,15 +40,20 @@ func reverseaddr(addr string) (arpa string, err error) {
func answer(name, server string, dns *dnsMsg, qtype uint16) (cname string, addrs []dnsRR, err error) {
addrs = make([]dnsRR, 0, len(dns.answer))
- if dns.rcode == dnsRcodeNameError && dns.recursion_available {
+ if dns.rcode == dnsRcodeNameError {
return "", nil, &DNSError{Err: errNoSuchHost.Error(), Name: name, Server: server}
}
if dns.rcode != dnsRcodeSuccess {
// None of the error codes make sense
// for the query we sent. If we didn't get
// a name error and we didn't get success,
- // the server is behaving incorrectly.
- return "", nil, &DNSError{Err: "server misbehaving", Name: name, Server: server}
+ // the server is behaving incorrectly or
+ // having temporary trouble.
+ err := &DNSError{Err: "server misbehaving", Name: name, Server: server}
+ if dns.rcode == dnsRcodeServerFailure {
+ err.IsTemporary = true
+ }
+ return "", nil, err
}
// Look for the name.
@@ -156,6 +161,28 @@ func isDomainName(s string) bool {
return ok
}
+// absDomainName returns an absolute domain name which ends with a
+// trailing dot to match pure Go reverse resolver and all other lookup
+// routines.
+// See golang.org/issue/12189.
+// But we don't want to add dots for local names from /etc/hosts.
+// It's hard to tell so we settle on the heuristic that names without dots
+// (like "localhost" or "myhost") do not get trailing dots, but any other
+// names do.
+func absDomainName(b []byte) string {
+ hasDots := false
+ for _, x := range b {
+ if x == '.' {
+ hasDots = true
+ break
+ }
+ }
+ if hasDots && b[len(b)-1] != '.' {
+ b = append(b, '.')
+ }
+ return string(b)
+}
+
// An SRV represents a single DNS SRV record.
type SRV struct {
Target string
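Editor's note: absDomainName appends a trailing dot only to names that already contain a dot, so unqualified names such as "localhost" stay unqualified. A standalone sketch of the same heuristic on strings rather than byte slices.

package main

import (
	"fmt"
	"strings"
)

// absName mirrors the absDomainName heuristic from the patch: names
// containing a dot become absolute (trailing dot added); single-label
// names are left alone.
func absName(name string) string {
	if strings.Contains(name, ".") && !strings.HasSuffix(name, ".") {
		return name + "."
	}
	return name
}

func main() {
	fmt.Println(absName("localhost"))   // "localhost"
	fmt.Println(absName("golang.org"))  // "golang.org."
	fmt.Println(absName("golang.org.")) // "golang.org."
}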
diff --git a/libgo/go/net/dnsclient_test.go b/libgo/go/net/dnsclient_test.go
index 3ab2b836ef6..7308fb03fa1 100644
--- a/libgo/go/net/dnsclient_test.go
+++ b/libgo/go/net/dnsclient_test.go
@@ -67,3 +67,51 @@ func testWeighting(t *testing.T, margin float64) {
func TestWeighting(t *testing.T) {
testWeighting(t, 0.05)
}
+
+// Issue 8434: verify that Temporary returns true on an error when rcode
+// is SERVFAIL
+func TestIssue8434(t *testing.T) {
+ msg := &dnsMsg{
+ dnsMsgHdr: dnsMsgHdr{
+ rcode: dnsRcodeServerFailure,
+ },
+ }
+
+ _, _, err := answer("golang.org", "foo:53", msg, uint16(dnsTypeSRV))
+ if err == nil {
+ t.Fatal("expected an error")
+ }
+ if ne, ok := err.(Error); !ok {
+ t.Fatalf("err = %#v; wanted something supporting net.Error", err)
+ } else if !ne.Temporary() {
+ t.Fatalf("Temporary = false for err = %#v; want Temporary == true", err)
+ }
+ if de, ok := err.(*DNSError); !ok {
+ t.Fatalf("err = %#v; wanted a *net.DNSError", err)
+ } else if !de.IsTemporary {
+ t.Fatalf("IsTemporary = false for err = %#v; want IsTemporary == true", err)
+ }
+}
+
+// Issue 12778: verify that NXDOMAIN without RA bit errors as
+// "no such host" and not "server misbehaving"
+func TestIssue12778(t *testing.T) {
+ msg := &dnsMsg{
+ dnsMsgHdr: dnsMsgHdr{
+ rcode: dnsRcodeNameError,
+ recursion_available: false,
+ },
+ }
+
+ _, _, err := answer("golang.org", "foo:53", msg, uint16(dnsTypeSRV))
+ if err == nil {
+ t.Fatal("expected an error")
+ }
+ de, ok := err.(*DNSError)
+ if !ok {
+ t.Fatalf("err = %#v; wanted a *net.DNSError", err)
+ }
+ if de.Err != errNoSuchHost.Error() {
+ t.Fatalf("Err = %#v; wanted %q", de.Err, errNoSuchHost.Error())
+ }
+}
diff --git a/libgo/go/net/dnsclient_unix.go b/libgo/go/net/dnsclient_unix.go
index c03c1b1159f..17188f0024c 100644
--- a/libgo/go/net/dnsclient_unix.go
+++ b/libgo/go/net/dnsclient_unix.go
@@ -20,14 +20,22 @@ import (
"io"
"math/rand"
"os"
- "strconv"
"sync"
"time"
)
+// A dnsDialer provides dialing suitable for DNS queries.
+type dnsDialer interface {
+ dialDNS(string, string) (dnsConn, error)
+}
+
+var testHookDNSDialer = func(d time.Duration) dnsDialer { return &Dialer{Timeout: d} }
+
// A dnsConn represents a DNS transport endpoint.
type dnsConn interface {
- Conn
+ io.Closer
+
+ SetDeadline(time.Time) error
// readDNSResponse reads a DNS response message from the DNS
// transport endpoint and returns the received DNS response
@@ -122,7 +130,7 @@ func (d *Dialer) dialDNS(network, server string) (dnsConn, error) {
// exchange sends a query on the connection and hopes for a response.
func exchange(server, name string, qtype uint16, timeout time.Duration) (*dnsMsg, error) {
- d := Dialer{Timeout: timeout}
+ d := testHookDNSDialer(timeout)
out := dnsMsg{
dnsMsgHdr: dnsMsgHdr{
recursion_desired: true,
@@ -165,9 +173,6 @@ func tryOneName(cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, err
if len(cfg.servers) == 0 {
return "", nil, &DNSError{Err: "no DNS servers", Name: name}
}
- if len(name) >= 256 {
- return "", nil, &DNSError{Err: "DNS name too long", Name: name}
- }
timeout := time.Duration(cfg.timeout) * time.Second
var lastErr error
for i := 0; i < cfg.attempts; i++ {
@@ -186,7 +191,11 @@ func tryOneName(cfg *dnsConfig, name string, qtype uint16) (string, []dnsRR, err
continue
}
cname, rrs, err := answer(name, server, msg, qtype)
- if err == nil || msg.rcode == dnsRcodeSuccess || msg.rcode == dnsRcodeNameError && msg.recursion_available {
+ // If answer errored for rcodes dnsRcodeSuccess or dnsRcodeNameError,
+ // it means the response in msg was not useful and trying another
+ // server probably won't help. Return now in those cases.
+ // TODO: indicate this in a more obvious way, such as a field on DNSError?
+ if err == nil || msg.rcode == dnsRcodeSuccess || msg.rcode == dnsRcodeNameError {
return cname, rrs, err
}
lastErr = err
@@ -374,7 +383,7 @@ func (o hostLookupOrder) String() string {
if s, ok := lookupOrderName[o]; ok {
return s
}
- return "hostLookupOrder=" + strconv.Itoa(int(o)) + "??"
+ return "hostLookupOrder=" + itoa(int(o)) + "??"
}
// goLookupHost is the native Go implementation of LookupHost.
@@ -440,7 +449,8 @@ func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err er
conf := resolvConf.dnsConfig
resolvConf.mu.RUnlock()
type racer struct {
- rrs []dnsRR
+ fqdn string
+ rrs []dnsRR
error
}
lane := make(chan racer, 1)
@@ -450,13 +460,16 @@ func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err er
for _, qtype := range qtypes {
go func(qtype uint16) {
_, rrs, err := tryOneName(conf, fqdn, qtype)
- lane <- racer{rrs, err}
+ lane <- racer{fqdn, rrs, err}
}(qtype)
}
for range qtypes {
racer := <-lane
if racer.error != nil {
- lastErr = racer.error
+ // Prefer error for original name.
+ if lastErr == nil || racer.fqdn == name+"." {
+ lastErr = racer.error
+ }
continue
}
addrs = append(addrs, addrRecordList(racer.rrs)...)
@@ -473,12 +486,12 @@ func goLookupIPOrder(name string, order hostLookupOrder) (addrs []IPAddr, err er
}
sortByRFC6724(addrs)
if len(addrs) == 0 {
- if lastErr != nil {
- return nil, lastErr
- }
if order == hostLookupDNSFiles {
addrs = goLookupIPFiles(name)
}
+ if len(addrs) == 0 && lastErr != nil {
+ return nil, lastErr
+ }
}
return addrs, nil
}
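Editor's note: the lookup loop fans out queries and, when collecting results, now keeps the error reported for the original name in preference to errors from search-domain variants (issue 12712). A simplified sketch of that collection pattern; the query function and types are stand-ins, not the goLookupIPOrder internals.

package main

import (
	"errors"
	"fmt"
)

type racer struct {
	fqdn  string
	addrs []string
	err   error
}

// collect queries each candidate name concurrently and prefers the
// error for the original name over errors from generated names.
func collect(name string, fqdns []string, query func(string) ([]string, error)) ([]string, error) {
	lane := make(chan racer, len(fqdns))
	for _, fqdn := range fqdns {
		go func(fqdn string) {
			addrs, err := query(fqdn)
			lane <- racer{fqdn, addrs, err}
		}(fqdn)
	}
	var addrs []string
	var lastErr error
	for range fqdns {
		r := <-lane
		if r.err != nil {
			if lastErr == nil || r.fqdn == name+"." {
				lastErr = r.err // prefer the error for the original name
			}
			continue
		}
		addrs = append(addrs, r.addrs...)
	}
	if len(addrs) == 0 {
		return nil, lastErr
	}
	return addrs, nil
}

func main() {
	query := func(fqdn string) ([]string, error) {
		return nil, errors.New("no such host: " + fqdn)
	}
	_, err := collect("doesnotexist.domain",
		[]string{"doesnotexist.domain.", "doesnotexist.domain.servfail."}, query)
	fmt.Println(err)
}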
diff --git a/libgo/go/net/dnsclient_unix_test.go b/libgo/go/net/dnsclient_unix_test.go
index a999f8f0607..934f25b2c94 100644
--- a/libgo/go/net/dnsclient_unix_test.go
+++ b/libgo/go/net/dnsclient_unix_test.go
@@ -80,7 +80,7 @@ func TestSpecialDomainName(t *testing.T) {
server := "8.8.8.8:53"
for _, tt := range specialDomainNameTests {
- msg, err := exchange(server, tt.name, tt.qtype, 0)
+ msg, err := exchange(server, tt.name, tt.qtype, 3*time.Second)
if err != nil {
t.Error(err)
continue
@@ -378,6 +378,103 @@ func TestGoLookupIPWithResolverConfig(t *testing.T) {
}
}
+// Test that goLookupIPOrder falls back to the host file when no DNS servers are available.
+func TestGoLookupIPOrderFallbackToFile(t *testing.T) {
+ if testing.Short() || !*testExternal {
+ t.Skip("avoid external network")
+ }
+
+ // Add a config that simulates no dns servers being available.
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := conf.writeAndUpdate([]string{}); err != nil {
+ t.Fatal(err)
+ }
+ conf.tryUpdate(conf.path)
+ // Redirect host file lookups.
+ defer func(orig string) { testHookHostsPath = orig }(testHookHostsPath)
+ testHookHostsPath = "testdata/hosts"
+
+ for _, order := range []hostLookupOrder{hostLookupFilesDNS, hostLookupDNSFiles} {
+ name := fmt.Sprintf("order %v", order)
+
+ // First ensure that we get an error when contacting a non-existent host.
+ _, err := goLookupIPOrder("notarealhost", order)
+ if err == nil {
+ t.Errorf("%s: expected error while looking up name not in hosts file", name)
+ continue
+ }
+
+ // Now check that we get an address when the name appears in the hosts file.
+ addrs, err := goLookupIPOrder("thor", order) // entry is in "testdata/hosts"
+ if err != nil {
+ t.Errorf("%s: expected to successfully lookup host entry", name)
+ continue
+ }
+ if len(addrs) != 1 {
+ t.Errorf("%s: expected exactly one result, but got %v", name, addrs)
+ continue
+ }
+ if got, want := addrs[0].String(), "127.1.1.1"; got != want {
+ t.Errorf("%s: address doesn't match expectation. got %v, want %v", name, got, want)
+ }
+ }
+ defer conf.teardown()
+}
+
+// Issue 12712.
+// When using search domains, return the error encountered
+// querying the original name instead of an error encountered
+// querying a generated name.
+func TestErrorForOriginalNameWhenSearching(t *testing.T) {
+ const fqdn = "doesnotexist.domain"
+
+ origTestHookDNSDialer := testHookDNSDialer
+ defer func() { testHookDNSDialer = origTestHookDNSDialer }()
+
+ conf, err := newResolvConfTest()
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conf.teardown()
+
+ if err := conf.writeAndUpdate([]string{"search servfail"}); err != nil {
+ t.Fatal(err)
+ }
+
+ d := &fakeDNSConn{}
+ testHookDNSDialer = func(time.Duration) dnsDialer { return d }
+
+ d.rh = func(q *dnsMsg) (*dnsMsg, error) {
+ r := &dnsMsg{
+ dnsMsgHdr: dnsMsgHdr{
+ id: q.id,
+ },
+ }
+
+ switch q.question[0].Name {
+ case fqdn + ".servfail.":
+ r.rcode = dnsRcodeServerFailure
+ default:
+ r.rcode = dnsRcodeNameError
+ }
+
+ return r, nil
+ }
+
+ _, err = goLookupIP(fqdn)
+ if err == nil {
+ t.Fatal("expected an error")
+ }
+
+ want := &DNSError{Name: fqdn, Err: errNoSuchHost.Error()}
+ if err, ok := err.(*DNSError); !ok || err.Name != want.Name || err.Err != want.Err {
+ t.Errorf("got %v; want %v", err, want)
+ }
+}
+
func BenchmarkGoLookupIP(b *testing.B) {
testHookUninstaller.Do(uninstallTestHooks)
@@ -415,3 +512,37 @@ func BenchmarkGoLookupIPWithBrokenNameServer(b *testing.B) {
goLookupIP("www.example.com")
}
}
+
+type fakeDNSConn struct {
+ // last query
+ qmu sync.Mutex // guards q
+ q *dnsMsg
+ // reply handler
+ rh func(*dnsMsg) (*dnsMsg, error)
+}
+
+func (f *fakeDNSConn) dialDNS(n, s string) (dnsConn, error) {
+ return f, nil
+}
+
+func (f *fakeDNSConn) Close() error {
+ return nil
+}
+
+func (f *fakeDNSConn) SetDeadline(time.Time) error {
+ return nil
+}
+
+func (f *fakeDNSConn) writeDNSQuery(q *dnsMsg) error {
+ f.qmu.Lock()
+ defer f.qmu.Unlock()
+ f.q = q
+ return nil
+}
+
+func (f *fakeDNSConn) readDNSResponse() (*dnsMsg, error) {
+ f.qmu.Lock()
+ q := f.q
+ f.qmu.Unlock()
+ return f.rh(q)
+}
diff --git a/libgo/go/net/dnsmsg.go b/libgo/go/net/dnsmsg.go
index 6ecaa948230..93078fe8499 100644
--- a/libgo/go/net/dnsmsg.go
+++ b/libgo/go/net/dnsmsg.go
@@ -691,6 +691,9 @@ func packRR(rr dnsRR, msg []byte, off int) (off2 int, ok bool) {
// off1 is end of header
// off2 is end of rr
off1, ok = packStruct(rr.Header(), msg, off)
+ if !ok {
+ return len(msg), false
+ }
off2, ok = packStruct(rr, msg, off)
if !ok {
return len(msg), false
diff --git a/libgo/go/net/error_test.go b/libgo/go/net/error_test.go
index bf95ff6108c..1aab14c4499 100644
--- a/libgo/go/net/error_test.go
+++ b/libgo/go/net/error_test.go
@@ -93,7 +93,7 @@ second:
goto third
}
switch nestedErr {
- case errClosing, errMissingAddress:
+ case errCanceled, errClosing, errMissingAddress:
return nil
}
return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr)
@@ -116,8 +116,10 @@ var dialErrorTests = []struct {
{"tcp", "no-such-name:80"},
{"tcp", "mh/astro/r70:http"},
- {"tcp", "127.0.0.1:0"},
- {"udp", "127.0.0.1:0"},
+ {"tcp", JoinHostPort("127.0.0.1", "-1")},
+ {"tcp", JoinHostPort("127.0.0.1", "123456789")},
+ {"udp", JoinHostPort("127.0.0.1", "-1")},
+ {"udp", JoinHostPort("127.0.0.1", "123456789")},
{"ip:icmp", "127.0.0.1"},
{"unix", "/path/to/somewhere"},
@@ -145,10 +147,23 @@ func TestDialError(t *testing.T) {
for i, tt := range dialErrorTests {
c, err := d.Dial(tt.network, tt.address)
if err == nil {
- t.Errorf("#%d: should fail; %s:%s->%s", i, tt.network, c.LocalAddr(), c.RemoteAddr())
+ t.Errorf("#%d: should fail; %s:%s->%s", i, c.LocalAddr().Network(), c.LocalAddr(), c.RemoteAddr())
c.Close()
continue
}
+ if tt.network == "tcp" || tt.network == "udp" {
+ nerr := err
+ if op, ok := nerr.(*OpError); ok {
+ nerr = op.Err
+ }
+ if sys, ok := nerr.(*os.SyscallError); ok {
+ nerr = sys.Err
+ }
+ if nerr == errOpNotSupported {
+ t.Errorf("#%d: should fail without %v; %s:%s->", i, nerr, tt.network, tt.address)
+ continue
+ }
+ }
if c != nil {
t.Errorf("Dial returned non-nil interface %T(%v) with err != nil", c, c)
}
@@ -198,7 +213,8 @@ var listenErrorTests = []struct {
{"tcp", "no-such-name:80"},
{"tcp", "mh/astro/r70:http"},
- {"tcp", "127.0.0.1:0"},
+ {"tcp", JoinHostPort("127.0.0.1", "-1")},
+ {"tcp", JoinHostPort("127.0.0.1", "123456789")},
{"unix", "/path/to/somewhere"},
{"unixpacket", "/path/to/somewhere"},
@@ -223,10 +239,23 @@ func TestListenError(t *testing.T) {
for i, tt := range listenErrorTests {
ln, err := Listen(tt.network, tt.address)
if err == nil {
- t.Errorf("#%d: should fail; %s:%s->", i, tt.network, ln.Addr())
+ t.Errorf("#%d: should fail; %s:%s->", i, ln.Addr().Network(), ln.Addr())
ln.Close()
continue
}
+ if tt.network == "tcp" {
+ nerr := err
+ if op, ok := nerr.(*OpError); ok {
+ nerr = op.Err
+ }
+ if sys, ok := nerr.(*os.SyscallError); ok {
+ nerr = sys.Err
+ }
+ if nerr == errOpNotSupported {
+ t.Errorf("#%d: should fail without %v; %s:%s->", i, nerr, tt.network, tt.address)
+ continue
+ }
+ }
if ln != nil {
t.Errorf("Listen returned non-nil interface %T(%v) with err != nil", ln, ln)
}
@@ -246,6 +275,9 @@ var listenPacketErrorTests = []struct {
{"udp", "127.0.0.1:☺"},
{"udp", "no-such-name:80"},
{"udp", "mh/astro/r70:http"},
+
+ {"udp", JoinHostPort("127.0.0.1", "-1")},
+ {"udp", JoinHostPort("127.0.0.1", "123456789")},
}
func TestListenPacketError(t *testing.T) {
@@ -263,7 +295,7 @@ func TestListenPacketError(t *testing.T) {
for i, tt := range listenPacketErrorTests {
c, err := ListenPacket(tt.network, tt.address)
if err == nil {
- t.Errorf("#%d: should fail; %s:%s->", i, tt.network, c.LocalAddr())
+ t.Errorf("#%d: should fail; %s:%s->", i, c.LocalAddr().Network(), c.LocalAddr())
c.Close()
continue
}
@@ -381,7 +413,7 @@ second:
goto third
}
switch nestedErr {
- case errClosing, errTimeout, ErrWriteToConnected, io.ErrUnexpectedEOF:
+ case errCanceled, errClosing, errTimeout, ErrWriteToConnected, io.ErrUnexpectedEOF:
return nil
}
return fmt.Errorf("unexpected type on 2nd nested level: %T", nestedErr)
diff --git a/libgo/go/net/fd_plan9.go b/libgo/go/net/fd_plan9.go
index 32766f53b58..cec88609d06 100644
--- a/libgo/go/net/fd_plan9.go
+++ b/libgo/go/net/fd_plan9.go
@@ -171,6 +171,14 @@ func (fd *netFD) Close() error {
if !fd.ok() {
return syscall.EINVAL
}
+ if fd.net == "tcp" {
+ // The following line is required to unblock Reads.
+ // For some reason, WriteString returns an error:
+ // "write /net/tcp/39/listen: inappropriate use of fd"
+ // But without it, Reads on dead conns hang forever.
+ // See Issue 9554.
+ fd.ctl.WriteString("hangup")
+ }
err := fd.ctl.Close()
if fd.data != nil {
if err1 := fd.data.Close(); err1 != nil && err == nil {
diff --git a/libgo/go/net/fd_unix.go b/libgo/go/net/fd_unix.go
index 465023fb23d..ff498c2bff0 100644
--- a/libgo/go/net/fd_unix.go
+++ b/libgo/go/net/fd_unix.go
@@ -68,7 +68,7 @@ func (fd *netFD) name() string {
return fd.net + ":" + ls + "->" + rs
}
-func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error {
+func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-chan struct{}) error {
// Do not need to call fd.writeLock here,
// because fd is not yet accessible to user,
// so no concurrent operations are possible.
@@ -102,6 +102,19 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error {
fd.setWriteDeadline(deadline)
defer fd.setWriteDeadline(noDeadline)
}
+ if cancel != nil {
+ done := make(chan bool)
+ defer close(done)
+ go func() {
+ select {
+ case <-cancel:
+ // Force the runtime's poller to immediately give
+ // up waiting for writability.
+ fd.setWriteDeadline(aLongTimeAgo)
+ case <-done:
+ }
+ }()
+ }
for {
// Performing multiple connect system calls on a
// non-blocking socket under Unix variants does not
@@ -112,6 +125,11 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error {
// succeeded or failed. See issue 7474 for further
// details.
if err := fd.pd.WaitWrite(); err != nil {
+ select {
+ case <-cancel:
+ return errCanceled
+ default:
+ }
return err
}
nerr, err := getsockoptIntFunc(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_ERROR)
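Editor's note: the connect path makes a blocking wait cancelable by having a helper goroutine watch the cancel channel and, when it fires, set a deadline far in the past so the runtime poller gives up immediately (aLongTimeAgo in the patch). A generic sketch of the same pattern applied to a blocking Read on an established connection, using only exported API; readCancelable is an illustrative helper, not the netFD code.

package main

import (
	"errors"
	"fmt"
	"net"
	"time"
)

// readCancelable blocks in c.Read but aborts promptly once cancel is
// closed, by forcing a read deadline in the distant past.
func readCancelable(c net.Conn, buf []byte, cancel <-chan struct{}) (int, error) {
	done := make(chan struct{})
	defer close(done)
	go func() {
		select {
		case <-cancel:
			c.SetReadDeadline(time.Unix(1, 0)) // wake up the blocked Read
		case <-done:
		}
	}()
	n, err := c.Read(buf)
	if err != nil {
		select {
		case <-cancel:
			return n, errors.New("read canceled")
		default:
		}
	}
	return n, err
}

func main() {
	ln, _ := net.Listen("tcp", "127.0.0.1:0")
	defer ln.Close()
	go func() {
		c, err := ln.Accept()
		if err != nil {
			return
		}
		defer c.Close()
		select {} // hold the connection open, never write
	}()

	c, err := net.Dial("tcp", ln.Addr().String())
	if err != nil {
		fmt.Println(err)
		return
	}
	defer c.Close()

	cancel := make(chan struct{})
	time.AfterFunc(100*time.Millisecond, func() { close(cancel) })
	_, err = readCancelable(c, make([]byte, 1), cancel)
	fmt.Println(err) // read canceled
}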
diff --git a/libgo/go/net/fd_windows.go b/libgo/go/net/fd_windows.go
index 205daff9e46..fd50d772d60 100644
--- a/libgo/go/net/fd_windows.go
+++ b/libgo/go/net/fd_windows.go
@@ -5,6 +5,7 @@
package net
import (
+ "internal/race"
"os"
"runtime"
"sync"
@@ -208,7 +209,7 @@ func (s *ioSrv) ExecIO(o *operation, name string, submit func(o *operation) erro
s.req <- ioSrvReq{o, nil}
<-o.errc
}
- // Wait for cancellation to complete.
+ // Wait for cancelation to complete.
fd.pd.WaitCanceled(int(o.mode))
if o.errno != 0 {
err = syscall.Errno(o.errno)
@@ -217,8 +218,8 @@ func (s *ioSrv) ExecIO(o *operation, name string, submit func(o *operation) erro
}
return 0, err
}
- // We issued cancellation request. But, it seems, IO operation succeeded
- // before cancellation request run. We need to treat IO operation as
+ // We issued a cancelation request. But, it seems, IO operation succeeded
+ // before the cancelation request ran. We need to treat the IO operation as
// succeeded (the bytes are actually sent/recv from network).
return int(o.qty), nil
}
@@ -319,7 +320,7 @@ func (fd *netFD) setAddr(laddr, raddr Addr) {
runtime.SetFinalizer(fd, (*netFD).Close)
}
-func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error {
+func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time, cancel <-chan struct{}) error {
// Do not need to call fd.writeLock here,
// because fd is not yet accessible to user,
// so no concurrent operations are possible.
@@ -350,14 +351,32 @@ func (fd *netFD) connect(la, ra syscall.Sockaddr, deadline time.Time) error {
// Call ConnectEx API.
o := &fd.wop
o.sa = ra
+ if cancel != nil {
+ done := make(chan struct{})
+ defer close(done)
+ go func() {
+ select {
+ case <-cancel:
+ // Force the runtime's poller to immediately give
+ // up waiting for writability.
+ fd.setWriteDeadline(aLongTimeAgo)
+ case <-done:
+ }
+ }()
+ }
_, err := wsrv.ExecIO(o, "ConnectEx", func(o *operation) error {
return connectExFunc(o.fd.sysfd, o.sa, nil, 0, nil, &o.o)
})
if err != nil {
- if _, ok := err.(syscall.Errno); ok {
- err = os.NewSyscallError("connectex", err)
+ select {
+ case <-cancel:
+ return errCanceled
+ default:
+ if _, ok := err.(syscall.Errno); ok {
+ err = os.NewSyscallError("connectex", err)
+ }
+ return err
}
- return err
}
// Refresh socket properties.
return os.NewSyscallError("setsockopt", syscall.Setsockopt(fd.sysfd, syscall.SOL_SOCKET, syscall.SO_UPDATE_CONNECT_CONTEXT, (*byte)(unsafe.Pointer(&fd.sysfd)), int32(unsafe.Sizeof(fd.sysfd))))
@@ -461,8 +480,8 @@ func (fd *netFD) Read(buf []byte) (int, error) {
n, err := rsrv.ExecIO(o, "WSARecv", func(o *operation) error {
return syscall.WSARecv(o.fd.sysfd, &o.buf, 1, &o.qty, &o.flags, &o.o, nil)
})
- if raceenabled {
- raceAcquire(unsafe.Pointer(&ioSync))
+ if race.Enabled {
+ race.Acquire(unsafe.Pointer(&ioSync))
}
err = fd.eofError(n, err)
if _, ok := err.(syscall.Errno); ok {
@@ -504,8 +523,8 @@ func (fd *netFD) Write(buf []byte) (int, error) {
return 0, err
}
defer fd.writeUnlock()
- if raceenabled {
- raceReleaseMerge(unsafe.Pointer(&ioSync))
+ if race.Enabled {
+ race.ReleaseMerge(unsafe.Pointer(&ioSync))
}
o := &fd.wop
o.InitBuf(buf)
diff --git a/libgo/go/net/file_test.go b/libgo/go/net/file_test.go
index 003dbb2ecb7..6566ce21a1f 100644
--- a/libgo/go/net/file_test.go
+++ b/libgo/go/net/file_test.go
@@ -8,158 +8,222 @@ import (
"os"
"reflect"
"runtime"
+ "sync"
"testing"
)
-type listenerFile interface {
- Listener
- File() (f *os.File, err error)
-}
-
-type packetConnFile interface {
- PacketConn
- File() (f *os.File, err error)
-}
+// The full stack test cases for IPConn have been moved to the
+// following:
+// golang.org/x/net/ipv4
+// golang.org/x/net/ipv6
+// golang.org/x/net/icmp
-type connFile interface {
- Conn
- File() (f *os.File, err error)
+var fileConnTests = []struct {
+ network string
+}{
+ {"tcp"},
+ {"udp"},
+ {"unix"},
+ {"unixpacket"},
}
-func testFileListener(t *testing.T, net, laddr string) {
- l, err := Listen(net, laddr)
- if err != nil {
- t.Fatal(err)
- }
- defer l.Close()
- lf := l.(listenerFile)
- f, err := lf.File()
- if err != nil {
- t.Fatal(err)
- }
- c, err := FileListener(f)
- if err != nil {
- t.Fatal(err)
- }
- if !reflect.DeepEqual(l.Addr(), c.Addr()) {
- t.Fatalf("got %#v; want%#v", l.Addr(), c.Addr())
- }
- if err := c.Close(); err != nil {
- t.Fatal(err)
- }
- if err := f.Close(); err != nil {
- t.Fatal(err)
+func TestFileConn(t *testing.T) {
+ switch runtime.GOOS {
+ case "nacl", "plan9", "windows":
+ t.Skipf("not supported on %s", runtime.GOOS)
}
-}
-var fileListenerTests = []struct {
- net string
- laddr string
-}{
- {net: "tcp", laddr: ":0"},
- {net: "tcp", laddr: "0.0.0.0:0"},
- {net: "tcp", laddr: "[::ffff:0.0.0.0]:0"},
- {net: "tcp", laddr: "[::]:0"},
+ for _, tt := range fileConnTests {
+ if !testableNetwork(tt.network) {
+ t.Logf("skipping %s test", tt.network)
+ continue
+ }
- {net: "tcp", laddr: "127.0.0.1:0"},
- {net: "tcp", laddr: "[::ffff:127.0.0.1]:0"},
- {net: "tcp", laddr: "[::1]:0"},
+ var network, address string
+ switch tt.network {
+ case "udp":
+ c, err := newLocalPacketListener(tt.network)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ network = c.LocalAddr().Network()
+ address = c.LocalAddr().String()
+ default:
+ handler := func(ls *localServer, ln Listener) {
+ c, err := ln.Accept()
+ if err != nil {
+ return
+ }
+ defer c.Close()
+ var b [1]byte
+ c.Read(b[:])
+ }
+ ls, err := newLocalServer(tt.network)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer ls.teardown()
+ if err := ls.buildup(handler); err != nil {
+ t.Fatal(err)
+ }
+ network = ls.Listener.Addr().Network()
+ address = ls.Listener.Addr().String()
+ }
- {net: "tcp4", laddr: ":0"},
- {net: "tcp4", laddr: "0.0.0.0:0"},
- {net: "tcp4", laddr: "[::ffff:0.0.0.0]:0"},
+ c1, err := Dial(network, address)
+ if err != nil {
+ if perr := parseDialError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ addr := c1.LocalAddr()
- {net: "tcp4", laddr: "127.0.0.1:0"},
- {net: "tcp4", laddr: "[::ffff:127.0.0.1]:0"},
+ var f *os.File
+ switch c1 := c1.(type) {
+ case *TCPConn:
+ f, err = c1.File()
+ case *UDPConn:
+ f, err = c1.File()
+ case *UnixConn:
+ f, err = c1.File()
+ }
+ if err := c1.Close(); err != nil {
+ if perr := parseCloseError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Error(err)
+ }
+ if err != nil {
+ if perr := parseCommonError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
- {net: "tcp6", laddr: ":0"},
- {net: "tcp6", laddr: "[::]:0"},
+ c2, err := FileConn(f)
+ if err := f.Close(); err != nil {
+ t.Error(err)
+ }
+ if err != nil {
+ if perr := parseCommonError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ defer c2.Close()
- {net: "tcp6", laddr: "[::1]:0"},
+ if _, err := c2.Write([]byte("FILECONN TEST")); err != nil {
+ if perr := parseWriteError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(c2.LocalAddr(), addr) {
+ t.Fatalf("got %#v; want %#v", c2.LocalAddr(), addr)
+ }
+ }
+}
- {net: "unix", laddr: "@gotest/net"},
- {net: "unixpacket", laddr: "@gotest/net"},
+var fileListenerTests = []struct {
+ network string
+}{
+ {"tcp"},
+ {"unix"},
+ {"unixpacket"},
}
func TestFileListener(t *testing.T) {
switch runtime.GOOS {
- case "nacl", "windows":
+ case "nacl", "plan9", "windows":
t.Skipf("not supported on %s", runtime.GOOS)
}
for _, tt := range fileListenerTests {
- if !testableListenArgs(tt.net, tt.laddr, "") {
- t.Logf("skipping %s test", tt.net+" "+tt.laddr)
+ if !testableNetwork(tt.network) {
+ t.Logf("skipping %s test", tt.network)
continue
}
- testFileListener(t, tt.net, tt.laddr)
- }
-}
-func testFilePacketConn(t *testing.T, pcf packetConnFile, listen bool) {
- f, err := pcf.File()
- if err != nil {
- t.Fatal(err)
- }
- c, err := FilePacketConn(f)
- if err != nil {
- t.Fatal(err)
- }
- if !reflect.DeepEqual(pcf.LocalAddr(), c.LocalAddr()) {
- t.Fatalf("got %#v; want %#v", pcf.LocalAddr(), c.LocalAddr())
- }
- if listen {
- if _, err := c.WriteTo([]byte{}, c.LocalAddr()); err != nil {
+ ln1, err := newLocalListener(tt.network)
+ if err != nil {
t.Fatal(err)
}
- }
- if err := c.Close(); err != nil {
- t.Fatal(err)
- }
- if err := f.Close(); err != nil {
- t.Fatal(err)
- }
-}
+ switch tt.network {
+ case "unix", "unixpacket":
+ defer os.Remove(ln1.Addr().String())
+ }
+ addr := ln1.Addr()
-func testFilePacketConnListen(t *testing.T, net, laddr string) {
- l, err := ListenPacket(net, laddr)
- if err != nil {
- t.Fatal(err)
- }
- testFilePacketConn(t, l.(packetConnFile), true)
- if err := l.Close(); err != nil {
- t.Fatal(err)
- }
-}
+ var f *os.File
+ switch ln1 := ln1.(type) {
+ case *TCPListener:
+ f, err = ln1.File()
+ case *UnixListener:
+ f, err = ln1.File()
+ }
+ switch tt.network {
+ case "unix", "unixpacket":
+ defer ln1.Close() // UnixListener.Close calls syscall.Unlink internally
+ default:
+ if err := ln1.Close(); err != nil {
+ t.Error(err)
+ }
+ }
+ if err != nil {
+ if perr := parseCommonError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
-func testFilePacketConnDial(t *testing.T, net, raddr string) {
- c, err := Dial(net, raddr)
- if err != nil {
- t.Fatal(err)
- }
- testFilePacketConn(t, c.(packetConnFile), false)
- if err := c.Close(); err != nil {
- t.Fatal(err)
+ ln2, err := FileListener(f)
+ if err := f.Close(); err != nil {
+ t.Error(err)
+ }
+ if err != nil {
+ if perr := parseCommonError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ defer ln2.Close()
+
+ var wg sync.WaitGroup
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ c, err := Dial(ln2.Addr().Network(), ln2.Addr().String())
+ if err != nil {
+ if perr := parseDialError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Error(err)
+ return
+ }
+ c.Close()
+ }()
+ c, err := ln2.Accept()
+ if err != nil {
+ if perr := parseAcceptError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ c.Close()
+ wg.Wait()
+ if !reflect.DeepEqual(ln2.Addr(), addr) {
+ t.Fatalf("got %#v; want %#v", ln2.Addr(), addr)
+ }
}
}
var filePacketConnTests = []struct {
- net string
- addr string
+ network string
}{
- {net: "udp", addr: "127.0.0.1:0"},
- {net: "udp", addr: "[::ffff:127.0.0.1]:0"},
- {net: "udp", addr: "[::1]:0"},
-
- {net: "udp4", addr: "127.0.0.1:0"},
- {net: "udp4", addr: "[::ffff:127.0.0.1]:0"},
-
- {net: "udp6", addr: "[::1]:0"},
-
- // TODO(mikioh,bradfitz): reenable once 10730 is fixed
- // {net: "ip4:icmp", addr: "127.0.0.1"},
-
- {net: "unixgram", addr: "@gotest3/net"},
+ {"udp"},
+ {"unixgram"},
}
func TestFilePacketConn(t *testing.T) {
@@ -169,25 +233,61 @@ func TestFilePacketConn(t *testing.T) {
}
for _, tt := range filePacketConnTests {
- if !testableListenArgs(tt.net, tt.addr, "") {
- t.Logf("skipping %s test", tt.net+" "+tt.addr)
+ if !testableNetwork(tt.network) {
+ t.Logf("skipping %s test", tt.network)
continue
}
- if os.Getuid() != 0 && tt.net == "ip4:icmp" {
- t.Log("skipping test; must be root")
- continue
+
+ c1, err := newLocalPacketListener(tt.network)
+ if err != nil {
+ t.Fatal(err)
}
- testFilePacketConnListen(t, tt.net, tt.addr)
- switch tt.net {
- case "udp", "udp4", "udp6":
- host, _, err := SplitHostPort(tt.addr)
- if err != nil {
- t.Error(err)
- continue
+ switch tt.network {
+ case "unixgram":
+ defer os.Remove(c1.LocalAddr().String())
+ }
+ addr := c1.LocalAddr()
+
+ var f *os.File
+ switch c1 := c1.(type) {
+ case *UDPConn:
+ f, err = c1.File()
+ case *UnixConn:
+ f, err = c1.File()
+ }
+ if err := c1.Close(); err != nil {
+ if perr := parseCloseError(err); perr != nil {
+ t.Error(perr)
}
- testFilePacketConnDial(t, tt.net, JoinHostPort(host, "12345"))
- case "ip4:icmp":
- testFilePacketConnDial(t, tt.net, tt.addr)
+ t.Error(err)
+ }
+ if err != nil {
+ if perr := parseCommonError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+
+ c2, err := FilePacketConn(f)
+ if err := f.Close(); err != nil {
+ t.Error(err)
+ }
+ if err != nil {
+ if perr := parseCommonError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ defer c2.Close()
+
+ if _, err := c2.WriteTo([]byte("FILEPACKETCONN TEST"), addr); err != nil {
+ if perr := parseWriteError(err); perr != nil {
+ t.Error(perr)
+ }
+ t.Fatal(err)
+ }
+ if !reflect.DeepEqual(c2.LocalAddr(), addr) {
+ t.Fatalf("got %#v; want %#v", c2.LocalAddr(), addr)
}
}
}
diff --git a/libgo/go/net/file_unix.go b/libgo/go/net/file_unix.go
index 5b24c7d09d1..9e581fcb419 100644
--- a/libgo/go/net/file_unix.go
+++ b/libgo/go/net/file_unix.go
@@ -91,7 +91,7 @@ func fileListener(f *os.File) (Listener, error) {
case *TCPAddr:
return &TCPListener{fd}, nil
case *UnixAddr:
- return &UnixListener{fd, laddr.Name}, nil
+ return &UnixListener{fd: fd, path: laddr.Name, unlink: false}, nil
}
fd.Close()
return nil, syscall.EINVAL
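Editor's note: FileListener reconstructs a Listener from an *os.File wrapping a listening socket (for example one inherited from a parent process). A minimal round trip through the exported File/FileListener API.

package main

import (
	"fmt"
	"net"
)

func main() {
	ln, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		fmt.Println(err)
		return
	}
	defer ln.Close()

	// File returns a duplicate of the underlying socket; the original
	// listener and the file can be closed independently.
	f, err := ln.(*net.TCPListener).File()
	if err != nil {
		fmt.Println(err)
		return
	}
	defer f.Close()

	ln2, err := net.FileListener(f)
	if err != nil {
		fmt.Println(err)
		return
	}
	defer ln2.Close()

	fmt.Println(ln.Addr(), ln2.Addr()) // both refer to the same bound address
}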
diff --git a/libgo/go/net/hosts.go b/libgo/go/net/hosts.go
index 27958c7cc50..c4de1b6a972 100644
--- a/libgo/go/net/hosts.go
+++ b/libgo/go/net/hosts.go
@@ -9,7 +9,7 @@ import (
"time"
)
-const cacheMaxAge = 5 * time.Minute
+const cacheMaxAge = 5 * time.Second
func parseLiteralIP(addr string) string {
var ip IP
@@ -27,51 +27,76 @@ func parseLiteralIP(addr string) string {
return ip.String() + "%" + zone
}
-// Simple cache.
+// hosts contains known host entries.
var hosts struct {
sync.Mutex
+
+ // Key for the list of literal IP addresses must be a host
+ // name. It would be part of DNS labels, a FQDN or an absolute
+ // FQDN.
+ // For now the key is converted to lower case for convenience.
byName map[string][]string
+
+ // Key for the list of host names must be a literal IP address
+ // including IPv6 address with zone identifier.
+ // We don't support old-classful IP address notation.
byAddr map[string][]string
+
expire time.Time
path string
+ mtime time.Time
+ size int64
}
func readHosts() {
now := time.Now()
hp := testHookHostsPath
- if len(hosts.byName) == 0 || now.After(hosts.expire) || hosts.path != hp {
- hs := make(map[string][]string)
- is := make(map[string][]string)
- var file *file
- if file, _ = open(hp); file == nil {
- return
+
+ if now.Before(hosts.expire) && hosts.path == hp && len(hosts.byName) > 0 {
+ return
+ }
+ mtime, size, err := stat(hp)
+ if err == nil && hosts.path == hp && hosts.mtime.Equal(mtime) && hosts.size == size {
+ hosts.expire = now.Add(cacheMaxAge)
+ return
+ }
+
+ hs := make(map[string][]string)
+ is := make(map[string][]string)
+ var file *file
+ if file, _ = open(hp); file == nil {
+ return
+ }
+ for line, ok := file.readLine(); ok; line, ok = file.readLine() {
+ if i := byteIndex(line, '#'); i >= 0 {
+ // Discard comments.
+ line = line[0:i]
}
- for line, ok := file.readLine(); ok; line, ok = file.readLine() {
- if i := byteIndex(line, '#'); i >= 0 {
- // Discard comments.
- line = line[0:i]
- }
- f := getFields(line)
- if len(f) < 2 {
- continue
- }
- addr := parseLiteralIP(f[0])
- if addr == "" {
- continue
- }
- for i := 1; i < len(f); i++ {
- h := f[i]
- hs[h] = append(hs[h], addr)
- is[addr] = append(is[addr], h)
- }
+ f := getFields(line)
+ if len(f) < 2 {
+ continue
+ }
+ addr := parseLiteralIP(f[0])
+ if addr == "" {
+ continue
+ }
+ for i := 1; i < len(f); i++ {
+ name := absDomainName([]byte(f[i]))
+ h := []byte(f[i])
+ lowerASCIIBytes(h)
+ key := absDomainName(h)
+ hs[key] = append(hs[key], addr)
+ is[addr] = append(is[addr], name)
}
- // Update the data cache.
- hosts.expire = now.Add(cacheMaxAge)
- hosts.path = hp
- hosts.byName = hs
- hosts.byAddr = is
- file.close()
}
+ // Update the data cache.
+ hosts.expire = now.Add(cacheMaxAge)
+ hosts.path = hp
+ hosts.byName = hs
+ hosts.byAddr = is
+ hosts.mtime = mtime
+ hosts.size = size
+ file.close()
}
// lookupStaticHost looks up the addresses for the given host from /etc/hosts.
@@ -80,7 +105,11 @@ func lookupStaticHost(host string) []string {
defer hosts.Unlock()
readHosts()
if len(hosts.byName) != 0 {
- if ips, ok := hosts.byName[host]; ok {
+ // TODO(jbd,bradfitz): avoid this alloc if host is already all lowercase?
+ // or linear scan the byName map if it's small enough?
+ lowerHost := []byte(host)
+ lowerASCIIBytes(lowerHost)
+ if ips, ok := hosts.byName[absDomainName(lowerHost)]; ok {
return ips
}
}
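Editor's note: readHosts now skips re-parsing the hosts file when its mtime and size are unchanged, merely refreshing the expiry time. A condensed standalone sketch of that stat-based invalidation plus short TTL; the type and the parse callback are illustrative stand-ins.

package main

import (
	"fmt"
	"os"
	"time"
)

const cacheMaxAge = 5 * time.Second

type hostsCache struct {
	path   string
	expire time.Time
	mtime  time.Time
	size   int64
	byName map[string][]string
}

// load re-reads path only when the cache has expired and the file's
// metadata changed; parse stands in for the real hosts-file parser.
func (c *hostsCache) load(path string, parse func(string) map[string][]string) {
	now := time.Now()
	if now.Before(c.expire) && c.path == path && len(c.byName) > 0 {
		return // still fresh
	}
	fi, err := os.Stat(path)
	if err == nil && c.path == path && c.mtime.Equal(fi.ModTime()) && c.size == fi.Size() {
		c.expire = now.Add(cacheMaxAge) // unchanged on disk: just extend the TTL
		return
	}
	c.byName = parse(path)
	c.path, c.expire = path, now.Add(cacheMaxAge)
	if err == nil {
		c.mtime, c.size = fi.ModTime(), fi.Size()
	}
}

func main() {
	var c hostsCache
	parse := func(string) map[string][]string {
		fmt.Println("parsing hosts file")
		return map[string][]string{"localhost": {"127.0.0.1"}}
	}
	c.load("/etc/hosts", parse)
	c.load("/etc/hosts", parse) // cached: no second "parsing" line
	fmt.Println(c.byName["localhost"])
}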
diff --git a/libgo/go/net/hosts_test.go b/libgo/go/net/hosts_test.go
index aca64c38b05..4c67bfa9824 100644
--- a/libgo/go/net/hosts_test.go
+++ b/libgo/go/net/hosts_test.go
@@ -6,6 +6,7 @@ package net
import (
"reflect"
+ "strings"
"testing"
)
@@ -48,6 +49,13 @@ var lookupStaticHostTests = []struct {
{"localhost.localdomain", []string{"fe80::3%lo0"}},
},
},
+ {
+ "testdata/case-hosts", // see golang.org/issue/12806
+ []staticHostEntry{
+ {"PreserveMe", []string{"127.0.0.1", "::1"}},
+ {"PreserveMe.local", []string{"127.0.0.1", "::1"}},
+ },
+ },
}
func TestLookupStaticHost(t *testing.T) {
@@ -56,9 +64,12 @@ func TestLookupStaticHost(t *testing.T) {
for _, tt := range lookupStaticHostTests {
testHookHostsPath = tt.name
for _, ent := range tt.ents {
- addrs := lookupStaticHost(ent.in)
- if !reflect.DeepEqual(addrs, ent.out) {
- t.Errorf("%s, lookupStaticHost(%s) = %v; want %v", tt.name, ent.in, addrs, ent.out)
+ ins := []string{ent.in, absDomainName([]byte(ent.in)), strings.ToLower(ent.in), strings.ToUpper(ent.in)}
+ for _, in := range ins {
+ addrs := lookupStaticHost(in)
+ if !reflect.DeepEqual(addrs, ent.out) {
+ t.Errorf("%s, lookupStaticHost(%s) = %v; want %v", tt.name, in, addrs, ent.out)
+ }
}
}
}
@@ -103,6 +114,13 @@ var lookupStaticAddrTests = []struct {
{"fe80::3%lo0", []string{"localhost", "localhost.localdomain"}},
},
},
+ {
+ "testdata/case-hosts", // see golang.org/issue/12806
+ []staticHostEntry{
+ {"127.0.0.1", []string{"PreserveMe", "PreserveMe.local"}},
+ {"::1", []string{"PreserveMe", "PreserveMe.local"}},
+ },
+ },
}
func TestLookupStaticAddr(t *testing.T) {
@@ -112,6 +130,9 @@ func TestLookupStaticAddr(t *testing.T) {
testHookHostsPath = tt.name
for _, ent := range tt.ents {
hosts := lookupStaticAddr(ent.in)
+ for i := range ent.out {
+ ent.out[i] = absDomainName([]byte(ent.out[i]))
+ }
if !reflect.DeepEqual(hosts, ent.out) {
t.Errorf("%s, lookupStaticAddr(%s) = %v; want %v", tt.name, ent.in, hosts, ent.out)
}
diff --git a/libgo/go/net/http/cgi/host.go b/libgo/go/net/http/cgi/host.go
index 4efbe7abeec..9b4d8754183 100644
--- a/libgo/go/net/http/cgi/host.go
+++ b/libgo/go/net/http/cgi/host.go
@@ -77,15 +77,15 @@ type Handler struct {
// Env: []string{"SCRIPT_FILENAME=foo.php"},
// }
func removeLeadingDuplicates(env []string) (ret []string) {
- n := len(env)
- for i := 0; i < n; i++ {
- e := env[i]
- s := strings.SplitN(e, "=", 2)[0]
+ for i, e := range env {
found := false
- for j := i + 1; j < n; j++ {
- if s == strings.SplitN(env[j], "=", 2)[0] {
- found = true
- break
+ if eq := strings.IndexByte(e, '='); eq != -1 {
+ keq := e[:eq+1] // "key="
+ for _, e2 := range env[i+1:] {
+ if strings.HasPrefix(e2, keq) {
+ found = true
+ break
+ }
}
}
if !found {
@@ -159,10 +159,6 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
env = append(env, "CONTENT_TYPE="+ctype)
}
- if h.Env != nil {
- env = append(env, h.Env...)
- }
-
envPath := os.Getenv("PATH")
if envPath == "" {
envPath = "/bin:/usr/bin:/usr/ucb:/usr/bsd:/usr/local/bin"
@@ -181,6 +177,10 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
}
}
+ if h.Env != nil {
+ env = append(env, h.Env...)
+ }
+
env = removeLeadingDuplicates(env)
var cwd, path string
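Editor's note: the rewritten removeLeadingDuplicates keeps only the last occurrence of each KEY= entry, which is why moving the h.Env append after the PATH handling lets handler-supplied variables win. A self-contained copy of the same logic with a quick demonstration.

package main

import (
	"fmt"
	"strings"
)

// removeLeadingDuplicates drops a "key=value" entry when the same key
// appears again later in env, so later entries take precedence.
func removeLeadingDuplicates(env []string) (ret []string) {
	for i, e := range env {
		found := false
		if eq := strings.IndexByte(e, '='); eq != -1 {
			keq := e[:eq+1] // "key="
			for _, e2 := range env[i+1:] {
				if strings.HasPrefix(e2, keq) {
					found = true
					break
				}
			}
		}
		if !found {
			ret = append(ret, e)
		}
	}
	return
}

func main() {
	env := []string{"PATH=/bin", "SCRIPT_FILENAME=foo.cgi", "PATH=/wibble"}
	fmt.Println(removeLeadingDuplicates(env)) // [SCRIPT_FILENAME=foo.cgi PATH=/wibble]
}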
diff --git a/libgo/go/net/http/cgi/host_test.go b/libgo/go/net/http/cgi/host_test.go
index f3411105ca9..fb7d66adb9f 100644
--- a/libgo/go/net/http/cgi/host_test.go
+++ b/libgo/go/net/http/cgi/host_test.go
@@ -16,6 +16,7 @@ import (
"os"
"os/exec"
"path/filepath"
+ "reflect"
"runtime"
"strconv"
"strings"
@@ -488,12 +489,36 @@ func TestEnvOverride(t *testing.T) {
Args: []string{cgifile},
Env: []string{
"SCRIPT_FILENAME=" + cgifile,
- "REQUEST_URI=/foo/bar"},
+ "REQUEST_URI=/foo/bar",
+ "PATH=/wibble"},
}
expectedMap := map[string]string{
"cwd": cwd,
"env-SCRIPT_FILENAME": cgifile,
"env-REQUEST_URI": "/foo/bar",
+ "env-PATH": "/wibble",
}
runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap)
}
+
+func TestRemoveLeadingDuplicates(t *testing.T) {
+ tests := []struct {
+ env []string
+ want []string
+ }{
+ {
+ env: []string{"a=b", "b=c", "a=b2"},
+ want: []string{"b=c", "a=b2"},
+ },
+ {
+ env: []string{"a=b", "b=c", "d", "e=f"},
+ want: []string{"a=b", "b=c", "d", "e=f"},
+ },
+ }
+ for _, tt := range tests {
+ got := removeLeadingDuplicates(tt.env)
+ if !reflect.DeepEqual(got, tt.want) {
+ t.Errorf("removeLeadingDuplicates(%q) = %q; want %q", tt.env, got, tt.want)
+ }
+ }
+}
diff --git a/libgo/go/net/http/client.go b/libgo/go/net/http/client.go
index 7f2fbb4678e..3106d229da6 100644
--- a/libgo/go/net/http/client.go
+++ b/libgo/go/net/http/client.go
@@ -10,6 +10,7 @@
package http
import (
+ "crypto/tls"
"encoding/base64"
"errors"
"fmt"
@@ -19,7 +20,6 @@ import (
"net/url"
"strings"
"sync"
- "sync/atomic"
"time"
)
@@ -65,10 +65,15 @@ type Client struct {
//
// A Timeout of zero means no timeout.
//
- // The Client's Transport must support the CancelRequest
- // method or Client will return errors when attempting to make
- // a request with Get, Head, Post, or Do. Client's default
- // Transport (DefaultTransport) supports CancelRequest.
+ // The Client cancels requests to the underlying Transport
+ // using the Request.Cancel mechanism. Requests passed
+ // to Client.Do may still set Request.Cancel; both will
+ // cancel the request.
+ //
+ // For compatibility, the Client will also use the deprecated
+ // CancelRequest method on Transport if found. New
+ // RoundTripper implementations should use Request.Cancel
+ // instead of implementing CancelRequest.
Timeout time.Duration
}
@@ -82,19 +87,26 @@ var DefaultClient = &Client{}
// goroutines.
type RoundTripper interface {
// RoundTrip executes a single HTTP transaction, returning
- // the Response for the request req. RoundTrip should not
- // attempt to interpret the response. In particular,
- // RoundTrip must return err == nil if it obtained a response,
- // regardless of the response's HTTP status code. A non-nil
- // err should be reserved for failure to obtain a response.
- // Similarly, RoundTrip should not attempt to handle
- // higher-level protocol details such as redirects,
+ // a Response for the provided Request.
+ //
+ // RoundTrip should not attempt to interpret the response. In
+ // particular, RoundTrip must return err == nil if it obtained
+ // a response, regardless of the response's HTTP status code.
+ // A non-nil err should be reserved for failure to obtain a
+ // response. Similarly, RoundTrip should not attempt to
+ // handle higher-level protocol details such as redirects,
// authentication, or cookies.
//
// RoundTrip should not modify the request, except for
- // consuming and closing the Body, including on errors. The
- // request's URL and Header fields are guaranteed to be
- // initialized.
+ // consuming and closing the Request's Body.
+ //
+ // RoundTrip must always close the body, including on errors,
+ // but depending on the implementation may do so in a separate
+ // goroutine even after RoundTrip returns. This means that
+ // callers wanting to reuse the body for subsequent requests
+ // must arrange to wait for the Close call before doing so.
+ //
+ // The Request's URL and Header fields must be initialized.
RoundTrip(*Request) (*Response, error)
}
@@ -134,13 +146,13 @@ type readClose struct {
io.Closer
}
-func (c *Client) send(req *Request) (*Response, error) {
+func (c *Client) send(req *Request, deadline time.Time) (*Response, error) {
if c.Jar != nil {
for _, cookie := range c.Jar.Cookies(req.URL) {
req.AddCookie(cookie)
}
}
- resp, err := send(req, c.transport())
+ resp, err := send(req, c.transport(), deadline)
if err != nil {
return nil, err
}
@@ -171,13 +183,21 @@ func (c *Client) send(req *Request) (*Response, error) {
//
// Generally Get, Post, or PostForm will be used instead of Do.
func (c *Client) Do(req *Request) (resp *Response, err error) {
- if req.Method == "GET" || req.Method == "HEAD" {
+ method := valueOrDefault(req.Method, "GET")
+ if method == "GET" || method == "HEAD" {
return c.doFollowingRedirects(req, shouldRedirectGet)
}
- if req.Method == "POST" || req.Method == "PUT" {
+ if method == "POST" || method == "PUT" {
return c.doFollowingRedirects(req, shouldRedirectPost)
}
- return c.send(req)
+ return c.send(req, c.deadline())
+}
+
+func (c *Client) deadline() time.Time {
+ if c.Timeout > 0 {
+ return time.Now().Add(c.Timeout)
+ }
+ return time.Time{}
}
func (c *Client) transport() RoundTripper {
@@ -189,8 +209,10 @@ func (c *Client) transport() RoundTripper {
// send issues an HTTP request.
// Caller should close resp.Body when done reading from it.
-func send(req *Request, t RoundTripper) (resp *Response, err error) {
- if t == nil {
+func send(ireq *Request, rt RoundTripper, deadline time.Time) (*Response, error) {
+ req := ireq // req is either the original request, or a modified fork
+
+ if rt == nil {
req.closeBody()
return nil, errors.New("http: no Client.Transport or DefaultTransport")
}
@@ -205,28 +227,122 @@ func send(req *Request, t RoundTripper) (resp *Response, err error) {
return nil, errors.New("http: Request.RequestURI can't be set in client requests.")
}
+ // forkReq forks req into a shallow clone of ireq the first
+ // time it's called.
+ forkReq := func() {
+ if ireq == req {
+ req = new(Request)
+ *req = *ireq // shallow clone
+ }
+ }
+
// Most of the callers of send (Get, Post, et al) don't need
// Headers, leaving it uninitialized. We guarantee to the
// Transport that this has been initialized, though.
if req.Header == nil {
+ forkReq()
req.Header = make(Header)
}
if u := req.URL.User; u != nil && req.Header.Get("Authorization") == "" {
username := u.Username()
password, _ := u.Password()
+ forkReq()
+ req.Header = cloneHeader(ireq.Header)
req.Header.Set("Authorization", "Basic "+basicAuth(username, password))
}
- resp, err = t.RoundTrip(req)
+
+ if !deadline.IsZero() {
+ forkReq()
+ }
+ stopTimer, wasCanceled := setRequestCancel(req, rt, deadline)
+
+ resp, err := rt.RoundTrip(req)
if err != nil {
+ stopTimer()
if resp != nil {
log.Printf("RoundTripper returned a response & error; ignoring response")
}
+ if tlsErr, ok := err.(tls.RecordHeaderError); ok {
+ // If we get a bad TLS record header, check to see if the
+ // response looks like HTTP and give a more helpful error.
+ // See golang.org/issue/11111.
+ if string(tlsErr.RecordHeader[:]) == "HTTP/" {
+ err = errors.New("http: server gave HTTP response to HTTPS client")
+ }
+ }
return nil, err
}
+ if !deadline.IsZero() {
+ resp.Body = &cancelTimerBody{
+ stop: stopTimer,
+ rc: resp.Body,
+ reqWasCanceled: wasCanceled,
+ }
+ }
return resp, nil
}
+// setRequestCancel sets the Cancel field of req, if deadline is
+// non-zero. The RoundTripper's type is used to determine whether the legacy
+// CancelRequest behavior should be used.
+func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTimer func(), wasCanceled func() bool) {
+ if deadline.IsZero() {
+ return nop, alwaysFalse
+ }
+
+ initialReqCancel := req.Cancel // the user's original Request.Cancel, if any
+
+ cancel := make(chan struct{})
+ req.Cancel = cancel
+
+ wasCanceled = func() bool {
+ select {
+ case <-cancel:
+ return true
+ default:
+ return false
+ }
+ }
+
+ doCancel := func() {
+ // The new way:
+ close(cancel)
+
+ // The legacy compatibility way, used only
+ // for RoundTripper implementations written
+ // before Go 1.5 or Go 1.6.
+ type canceler interface {
+ CancelRequest(*Request)
+ }
+ switch v := rt.(type) {
+ case *Transport, *http2Transport:
+ // Do nothing. The net/http package's transports
+ // support the new Request.Cancel channel
+ case canceler:
+ v.CancelRequest(req)
+ }
+ }
+
+ stopTimerCh := make(chan struct{})
+ var once sync.Once
+ stopTimer = func() { once.Do(func() { close(stopTimerCh) }) }
+
+ timer := time.NewTimer(deadline.Sub(time.Now()))
+ go func() {
+ select {
+ case <-initialReqCancel:
+ doCancel()
+ case <-timer.C:
+ doCancel()
+ case <-stopTimerCh:
+ timer.Stop()
+ }
+ }()
+
+ return stopTimer, wasCanceled
+}
+
// See 2 (end of page 4) http://www.ietf.org/rfc/rfc2617.txt
// "To receive authorization, the client sends the userid and password,
// separated by a single colon (":") character, within a base64
@@ -321,34 +437,15 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
return nil, errors.New("http: nil Request.URL")
}
- var reqmu sync.Mutex // guards req
req := ireq
-
- var timer *time.Timer
- var atomicWasCanceled int32 // atomic bool (1 or 0)
- var wasCanceled = alwaysFalse
- if c.Timeout > 0 {
- wasCanceled = func() bool { return atomic.LoadInt32(&atomicWasCanceled) != 0 }
- type canceler interface {
- CancelRequest(*Request)
- }
- tr, ok := c.transport().(canceler)
- if !ok {
- return nil, fmt.Errorf("net/http: Client Transport of type %T doesn't support CancelRequest; Timeout not supported", c.transport())
- }
- timer = time.AfterFunc(c.Timeout, func() {
- atomic.StoreInt32(&atomicWasCanceled, 1)
- reqmu.Lock()
- defer reqmu.Unlock()
- tr.CancelRequest(req)
- })
- }
+ deadline := c.deadline()
urlStr := "" // next relative or absolute URL to fetch (after first request)
redirectFailed := false
for redirect := 0; ; redirect++ {
if redirect != 0 {
nreq := new(Request)
+ nreq.Cancel = ireq.Cancel
nreq.Method = ireq.Method
if ireq.Method == "POST" || ireq.Method == "PUT" {
nreq.Method = "GET"
@@ -371,14 +468,12 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
break
}
}
- reqmu.Lock()
req = nreq
- reqmu.Unlock()
}
urlStr = req.URL.String()
- if resp, err = c.send(req); err != nil {
- if wasCanceled() {
+ if resp, err = c.send(req, deadline); err != nil {
+ if !deadline.IsZero() && !time.Now().Before(deadline) {
err = &httpError{
err: err.Error() + " (Client.Timeout exceeded while awaiting headers)",
timeout: true,
@@ -403,19 +498,12 @@ func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bo
via = append(via, req)
continue
}
- if timer != nil {
- resp.Body = &cancelTimerBody{
- t: timer,
- rc: resp.Body,
- reqWasCanceled: wasCanceled,
- }
- }
return resp, nil
}
- method := ireq.Method
+ method := valueOrDefault(ireq.Method, "GET")
urlErr := &url.Error{
- Op: method[0:1] + strings.ToLower(method[1:]),
+ Op: method[:1] + strings.ToLower(method[1:]),
URL: urlStr,
Err: err,
}
@@ -528,30 +616,35 @@ func (c *Client) Head(url string) (resp *Response, err error) {
}
// cancelTimerBody is an io.ReadCloser that wraps rc with two features:
-// 1) on Read EOF or Close, the timer t is Stopped,
+// 1) On Read error or Close, the stop func is called.
// 2) On Read failure, if reqWasCanceled is true, the error is wrapped and
// marked as net.Error that hit its timeout.
type cancelTimerBody struct {
- t *time.Timer
+ stop func() // stops the time.Timer waiting to cancel the request
rc io.ReadCloser
reqWasCanceled func() bool
}
func (b *cancelTimerBody) Read(p []byte) (n int, err error) {
n, err = b.rc.Read(p)
+ if err == nil {
+ return n, nil
+ }
+ b.stop()
if err == io.EOF {
- b.t.Stop()
- } else if err != nil && b.reqWasCanceled() {
- return n, &httpError{
+ return n, err
+ }
+ if b.reqWasCanceled() {
+ err = &httpError{
err: err.Error() + " (Client.Timeout exceeded while reading body)",
timeout: true,
}
}
- return
+ return n, err
}
func (b *cancelTimerBody) Close() error {
err := b.rc.Close()
- b.t.Stop()
+ b.stop()
return err
}
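
Note: with the rewrite above, Client.Timeout no longer requires a Transport that implements CancelRequest — the deadline is propagated through Request.Cancel, with CancelRequest kept only as a legacy fallback, and a caller-supplied Cancel channel keeps working alongside it. A minimal sketch, using a placeholder URL:

    package main

    import (
        "fmt"
        "net/http"
        "time"
    )

    func main() {
        // The timeout covers connecting, redirects, and reading the body.
        c := &http.Client{Timeout: 2 * time.Second}

        req, err := http.NewRequest("GET", "https://example.com/", nil) // placeholder URL
        if err != nil {
            fmt.Println(err)
            return
        }

        // A caller-supplied Cancel channel is honored in addition to the
        // timeout; closing it aborts the request early.
        cancel := make(chan struct{})
        req.Cancel = cancel
        go func() {
            time.Sleep(500 * time.Millisecond)
            close(cancel)
        }()

        if _, err := c.Do(req); err != nil {
            fmt.Println("request ended:", err)
        }
    }
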
diff --git a/libgo/go/net/http/client_test.go b/libgo/go/net/http/client_test.go
index 7b524d381bc..8939dc8baf9 100644
--- a/libgo/go/net/http/client_test.go
+++ b/libgo/go/net/http/client_test.go
@@ -20,8 +20,6 @@ import (
. "net/http"
"net/http/httptest"
"net/url"
- "reflect"
- "sort"
"strconv"
"strings"
"sync"
@@ -83,12 +81,15 @@ func TestClient(t *testing.T) {
}
}
-func TestClientHead(t *testing.T) {
+func TestClientHead_h1(t *testing.T) { testClientHead(t, h1Mode) }
+func TestClientHead_h2(t *testing.T) { testClientHead(t, h2Mode) }
+
+func testClientHead(t *testing.T, h2 bool) {
defer afterTest(t)
- ts := httptest.NewServer(robotsTxtHandler)
- defer ts.Close()
+ cst := newClientServerTest(t, h2, robotsTxtHandler)
+ defer cst.close()
- r, err := Head(ts.URL)
+ r, err := cst.c.Head(cst.ts.URL)
if err != nil {
t.Fatal(err)
}
@@ -230,9 +231,18 @@ func TestClientRedirects(t *testing.T) {
t.Errorf("with default client Do, expected error %q, got %q", e, g)
}
+ // Requests with an empty Method should also redirect (Issue 12705)
+ greq.Method = ""
+ _, err = c.Do(greq)
+ if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g {
+ t.Errorf("with default client Do and empty Method, expected error %q, got %q", e, g)
+ }
+
var checkErr error
var lastVia []*Request
- c = &Client{CheckRedirect: func(_ *Request, via []*Request) error {
+ var lastReq *Request
+ c = &Client{CheckRedirect: func(req *Request, via []*Request) error {
+ lastReq = req
lastVia = via
return checkErr
}}
@@ -252,6 +262,20 @@ func TestClientRedirects(t *testing.T) {
t.Errorf("expected lastVia to have contained %d elements; got %d", e, g)
}
+ // Test that Request.Cancel is propagated between requests (Issue 14053)
+ creq, _ := NewRequest("HEAD", ts.URL, nil)
+ cancel := make(chan struct{})
+ creq.Cancel = cancel
+ if _, err := c.Do(creq); err != nil {
+ t.Fatal(err)
+ }
+ if lastReq == nil {
+ t.Fatal("didn't see redirect")
+ }
+ if lastReq.Cancel != cancel {
+ t.Errorf("expected lastReq to have the cancel channel set on the initial req")
+ }
+
checkErr = errors.New("no redirects allowed")
res, err = c.Get(ts.URL)
if urlError, ok := err.(*url.Error); !ok || urlError.Err != checkErr {
@@ -486,20 +510,23 @@ func (j *RecordingJar) logf(format string, args ...interface{}) {
fmt.Fprintf(&j.log, format, args...)
}
-func TestStreamingGet(t *testing.T) {
+func TestStreamingGet_h1(t *testing.T) { testStreamingGet(t, h1Mode) }
+func TestStreamingGet_h2(t *testing.T) { testStreamingGet(t, h2Mode) }
+
+func testStreamingGet(t *testing.T, h2 bool) {
defer afterTest(t)
say := make(chan string)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
w.(Flusher).Flush()
for str := range say {
w.Write([]byte(str))
w.(Flusher).Flush()
}
}))
- defer ts.Close()
+ defer cst.close()
- c := &Client{}
- res, err := c.Get(ts.URL)
+ c := cst.c
+ res, err := c.Get(cst.ts.URL)
if err != nil {
t.Fatal(err)
}
@@ -642,14 +669,18 @@ func newTLSTransport(t *testing.T, ts *httptest.Server) *Transport {
func TestClientWithCorrectTLSServerName(t *testing.T) {
defer afterTest(t)
+
+ const serverName = "example.com"
ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {
- if r.TLS.ServerName != "127.0.0.1" {
- t.Errorf("expected client to set ServerName 127.0.0.1, got: %q", r.TLS.ServerName)
+ if r.TLS.ServerName != serverName {
+ t.Errorf("expected client to set ServerName %q, got: %q", serverName, r.TLS.ServerName)
}
}))
defer ts.Close()
- c := &Client{Transport: newTLSTransport(t, ts)}
+ trans := newTLSTransport(t, ts)
+ trans.TLSClientConfig.ServerName = serverName
+ c := &Client{Transport: trans}
if _, err := c.Get(ts.URL); err != nil {
t.Fatalf("expected successful TLS connection, got error: %v", err)
}
@@ -739,15 +770,37 @@ func TestResponseSetsTLSConnectionState(t *testing.T) {
}
}
+// Check that an HTTPS client can interpret a particular TLS error
+// to determine that the server is speaking HTTP.
+// See golang.org/issue/11111.
+func TestHTTPSClientDetectsHTTPServer(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
+ defer ts.Close()
+
+ _, err := Get(strings.Replace(ts.URL, "http", "https", 1))
+ if got := err.Error(); !strings.Contains(got, "HTTP response to HTTPS client") {
+ t.Fatalf("error = %q; want error indicating HTTP response to HTTPS request", got)
+ }
+}
+
// Verify Response.ContentLength is populated. https://golang.org/issue/4126
-func TestClientHeadContentLength(t *testing.T) {
+func TestClientHeadContentLength_h1(t *testing.T) {
+ testClientHeadContentLength(t, h1Mode)
+}
+
+func TestClientHeadContentLength_h2(t *testing.T) {
+ testClientHeadContentLength(t, h2Mode)
+}
+
+func testClientHeadContentLength(t *testing.T, h2 bool) {
defer afterTest(t)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
if v := r.FormValue("cl"); v != "" {
w.Header().Set("Content-Length", v)
}
}))
- defer ts.Close()
+ defer cst.close()
tests := []struct {
suffix string
want int64
@@ -757,8 +810,8 @@ func TestClientHeadContentLength(t *testing.T) {
{"", -1},
}
for _, tt := range tests {
- req, _ := NewRequest("HEAD", ts.URL+tt.suffix, nil)
- res, err := DefaultClient.Do(req)
+ req, _ := NewRequest("HEAD", cst.ts.URL+tt.suffix, nil)
+ res, err := cst.c.Do(req)
if err != nil {
t.Fatal(err)
}
@@ -884,14 +937,17 @@ func TestBasicAuthHeadersPreserved(t *testing.T) {
}
-func TestClientTimeout(t *testing.T) {
+func TestClientTimeout_h1(t *testing.T) { testClientTimeout(t, h1Mode) }
+func TestClientTimeout_h2(t *testing.T) { testClientTimeout(t, h2Mode) }
+
+func testClientTimeout(t *testing.T, h2 bool) {
if testing.Short() {
t.Skip("skipping in short mode")
}
defer afterTest(t)
sawRoot := make(chan bool, 1)
sawSlow := make(chan bool, 1)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
if r.URL.Path == "/" {
sawRoot <- true
Redirect(w, r, "/slow", StatusFound)
@@ -905,13 +961,11 @@ func TestClientTimeout(t *testing.T) {
return
}
}))
- defer ts.Close()
+ defer cst.close()
const timeout = 500 * time.Millisecond
- c := &Client{
- Timeout: timeout,
- }
+ cst.c.Timeout = timeout
- res, err := c.Get(ts.URL)
+ res, err := cst.c.Get(cst.ts.URL)
if err != nil {
t.Fatal(err)
}
@@ -957,17 +1011,20 @@ func TestClientTimeout(t *testing.T) {
}
}
+func TestClientTimeout_Headers_h1(t *testing.T) { testClientTimeout_Headers(t, h1Mode) }
+func TestClientTimeout_Headers_h2(t *testing.T) { testClientTimeout_Headers(t, h2Mode) }
+
// Client.Timeout firing before getting to the body
-func TestClientTimeout_Headers(t *testing.T) {
+func testClientTimeout_Headers(t *testing.T, h2 bool) {
if testing.Short() {
t.Skip("skipping in short mode")
}
defer afterTest(t)
donec := make(chan bool)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
<-donec
}))
- defer ts.Close()
+ defer cst.close()
// Note that we use a channel send here and not a close.
// The race detector doesn't know that we're waiting for a timeout
// and thinks that the waitgroup inside httptest.Server is added to concurrently
@@ -977,19 +1034,17 @@ func TestClientTimeout_Headers(t *testing.T) {
// doesn't know this, so synchronize explicitly.
defer func() { donec <- true }()
- c := &Client{Timeout: 500 * time.Millisecond}
-
- _, err := c.Get(ts.URL)
+ cst.c.Timeout = 500 * time.Millisecond
+ _, err := cst.c.Get(cst.ts.URL)
if err == nil {
t.Fatal("got response from Get; expected error")
}
- ue, ok := err.(*url.Error)
- if !ok {
+ if _, ok := err.(*url.Error); !ok {
t.Fatalf("Got error of type %T; want *url.Error", err)
}
- ne, ok := ue.Err.(net.Error)
+ ne, ok := err.(net.Error)
if !ok {
- t.Fatalf("Got url.Error.Err of type %T; want some net.Error", err)
+ t.Fatalf("Got error of type %T; want some net.Error", err)
}
if !ne.Timeout() {
t.Error("net.Error.Timeout = false; want true")
@@ -999,18 +1054,20 @@ func TestClientTimeout_Headers(t *testing.T) {
}
}
-func TestClientRedirectEatsBody(t *testing.T) {
+func TestClientRedirectEatsBody_h1(t *testing.T) { testClientRedirectEatsBody(t, h1Mode) }
+func TestClientRedirectEatsBody_h2(t *testing.T) { testClientRedirectEatsBody(t, h2Mode) }
+func testClientRedirectEatsBody(t *testing.T, h2 bool) {
defer afterTest(t)
saw := make(chan string, 2)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
saw <- r.RemoteAddr
if r.URL.Path == "/" {
Redirect(w, r, "/foo", StatusFound) // which includes a body
}
}))
- defer ts.Close()
+ defer cst.close()
- res, err := Get(ts.URL)
+ res, err := cst.c.Get(cst.ts.URL)
if err != nil {
t.Fatal(err)
}
@@ -1047,76 +1104,6 @@ func (f eofReaderFunc) Read(p []byte) (n int, err error) {
return 0, io.EOF
}
-func TestClientTrailers(t *testing.T) {
- defer afterTest(t)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
- w.Header().Set("Connection", "close")
- w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
- w.Header().Add("Trailer", "Server-Trailer-C")
-
- var decl []string
- for k := range r.Trailer {
- decl = append(decl, k)
- }
- sort.Strings(decl)
-
- slurp, err := ioutil.ReadAll(r.Body)
- if err != nil {
- t.Errorf("Server reading request body: %v", err)
- }
- if string(slurp) != "foo" {
- t.Errorf("Server read request body %q; want foo", slurp)
- }
- if r.Trailer == nil {
- io.WriteString(w, "nil Trailer")
- } else {
- fmt.Fprintf(w, "decl: %v, vals: %s, %s",
- decl,
- r.Trailer.Get("Client-Trailer-A"),
- r.Trailer.Get("Client-Trailer-B"))
- }
-
- // How handlers set Trailers: declare it ahead of time
- // with the Trailer header, and then mutate the
- // Header() of those values later, after the response
- // has been written (we wrote to w above).
- w.Header().Set("Server-Trailer-A", "valuea")
- w.Header().Set("Server-Trailer-C", "valuec") // skipping B
- }))
- defer ts.Close()
-
- var req *Request
- req, _ = NewRequest("POST", ts.URL, io.MultiReader(
- eofReaderFunc(func() {
- req.Trailer["Client-Trailer-A"] = []string{"valuea"}
- }),
- strings.NewReader("foo"),
- eofReaderFunc(func() {
- req.Trailer["Client-Trailer-B"] = []string{"valueb"}
- }),
- ))
- req.Trailer = Header{
- "Client-Trailer-A": nil, // to be set later
- "Client-Trailer-B": nil, // to be set later
- }
- req.ContentLength = -1
- res, err := DefaultClient.Do(req)
- if err != nil {
- t.Fatal(err)
- }
- if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
- t.Error(err)
- }
- want := Header{
- "Server-Trailer-A": []string{"valuea"},
- "Server-Trailer-B": nil,
- "Server-Trailer-C": []string{"valuec"},
- }
- if !reflect.DeepEqual(res.Trailer, want) {
- t.Errorf("Response trailers = %#v; want %#v", res.Trailer, want)
- }
-}
-
func TestReferer(t *testing.T) {
tests := []struct {
lastReq, newReq string // from -> to URLs
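
Note: TestClientWithCorrectTLSServerName now pins the expected name via TLSClientConfig.ServerName rather than relying on the dialed address. A minimal sketch of the same idea for callers, assuming the server's certificate is issued for example.com:

    package main

    import (
        "crypto/tls"
        "net/http"
    )

    func main() {
        tr := &http.Transport{
            TLSClientConfig: &tls.Config{
                // Used for SNI and certificate verification when the
                // URL's host (for example an IP address) doesn't match
                // the name on the certificate.
                ServerName: "example.com",
            },
        }
        c := &http.Client{Transport: tr}
        _ = c // issue requests with c.Get, c.Do, ...
    }
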
diff --git a/libgo/go/net/http/clientserver_test.go b/libgo/go/net/http/clientserver_test.go
new file mode 100644
index 00000000000..3c87fd0cf83
--- /dev/null
+++ b/libgo/go/net/http/clientserver_test.go
@@ -0,0 +1,1056 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Tests that use both the client & server, in both HTTP/1 and HTTP/2 mode.
+
+package http_test
+
+import (
+ "bytes"
+ "compress/gzip"
+ "crypto/tls"
+ "fmt"
+ "io"
+ "io/ioutil"
+ "log"
+ "net"
+ . "net/http"
+ "net/http/httptest"
+ "net/url"
+ "os"
+ "reflect"
+ "runtime"
+ "sort"
+ "strings"
+ "sync"
+ "sync/atomic"
+ "testing"
+ "time"
+)
+
+type clientServerTest struct {
+ t *testing.T
+ h2 bool
+ h Handler
+ ts *httptest.Server
+ tr *Transport
+ c *Client
+}
+
+func (t *clientServerTest) close() {
+ t.tr.CloseIdleConnections()
+ t.ts.Close()
+}
+
+const (
+ h1Mode = false
+ h2Mode = true
+)
+
+func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...interface{}) *clientServerTest {
+ cst := &clientServerTest{
+ t: t,
+ h2: h2,
+ h: h,
+ tr: &Transport{},
+ }
+ cst.c = &Client{Transport: cst.tr}
+
+ for _, opt := range opts {
+ switch opt := opt.(type) {
+ case func(*Transport):
+ opt(cst.tr)
+ default:
+ t.Fatalf("unhandled option type %T", opt)
+ }
+ }
+
+ if !h2 {
+ cst.ts = httptest.NewServer(h)
+ return cst
+ }
+ cst.ts = httptest.NewUnstartedServer(h)
+ ExportHttp2ConfigureServer(cst.ts.Config, nil)
+ cst.ts.TLS = cst.ts.Config.TLSConfig
+ cst.ts.StartTLS()
+
+ cst.tr.TLSClientConfig = &tls.Config{
+ InsecureSkipVerify: true,
+ }
+ if err := ExportHttp2ConfigureTransport(cst.tr); err != nil {
+ t.Fatal(err)
+ }
+ return cst
+}
+
+// Testing the newClientServerTest helper itself.
+func TestNewClientServerTest(t *testing.T) {
+ var got struct {
+ sync.Mutex
+ log []string
+ }
+ h := HandlerFunc(func(w ResponseWriter, r *Request) {
+ got.Lock()
+ defer got.Unlock()
+ got.log = append(got.log, r.Proto)
+ })
+ for _, v := range [2]bool{false, true} {
+ cst := newClientServerTest(t, v, h)
+ if _, err := cst.c.Head(cst.ts.URL); err != nil {
+ t.Fatal(err)
+ }
+ cst.close()
+ }
+ got.Lock() // no need to unlock
+ if want := []string{"HTTP/1.1", "HTTP/2.0"}; !reflect.DeepEqual(got.log, want) {
+ t.Errorf("got %q; want %q", got.log, want)
+ }
+}
+
+func TestChunkedResponseHeaders_h1(t *testing.T) { testChunkedResponseHeaders(t, h1Mode) }
+func TestChunkedResponseHeaders_h2(t *testing.T) { testChunkedResponseHeaders(t, h2Mode) }
+
+func testChunkedResponseHeaders(t *testing.T, h2 bool) {
+ defer afterTest(t)
+ log.SetOutput(ioutil.Discard) // is noisy otherwise
+ defer log.SetOutput(os.Stderr)
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted
+ w.(Flusher).Flush()
+ fmt.Fprintf(w, "I am a chunked response.")
+ }))
+ defer cst.close()
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatalf("Get error: %v", err)
+ }
+ defer res.Body.Close()
+ if g, e := res.ContentLength, int64(-1); g != e {
+ t.Errorf("expected ContentLength of %d; got %d", e, g)
+ }
+ wantTE := []string{"chunked"}
+ if h2 {
+ wantTE = nil
+ }
+ if !reflect.DeepEqual(res.TransferEncoding, wantTE) {
+ t.Errorf("TransferEncoding = %v; want %v", res.TransferEncoding, wantTE)
+ }
+ if got, haveCL := res.Header["Content-Length"]; haveCL {
+ t.Errorf("Unexpected Content-Length: %q", got)
+ }
+}
+
+type reqFunc func(c *Client, url string) (*Response, error)
+
+// h12Compare is a test that compares HTTP/1 and HTTP/2 behavior
+// against each other.
+type h12Compare struct {
+ Handler func(ResponseWriter, *Request) // required
+ ReqFunc reqFunc // optional
+ CheckResponse func(proto string, res *Response) // optional
+ Opts []interface{}
+}
+
+func (tt h12Compare) reqFunc() reqFunc {
+ if tt.ReqFunc == nil {
+ return (*Client).Get
+ }
+ return tt.ReqFunc
+}
+
+func (tt h12Compare) run(t *testing.T) {
+ cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler), tt.Opts...)
+ defer cst1.close()
+ cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler), tt.Opts...)
+ defer cst2.close()
+
+ res1, err := tt.reqFunc()(cst1.c, cst1.ts.URL)
+ if err != nil {
+ t.Errorf("HTTP/1 request: %v", err)
+ return
+ }
+ res2, err := tt.reqFunc()(cst2.c, cst2.ts.URL)
+ if err != nil {
+ t.Errorf("HTTP/2 request: %v", err)
+ return
+ }
+ tt.normalizeRes(t, res1, "HTTP/1.1")
+ tt.normalizeRes(t, res2, "HTTP/2.0")
+ res1body, res2body := res1.Body, res2.Body
+
+ eres1 := mostlyCopy(res1)
+ eres2 := mostlyCopy(res2)
+ if !reflect.DeepEqual(eres1, eres2) {
+ t.Errorf("Response headers to handler differed:\nhttp/1 (%v):\n\t%#v\nhttp/2 (%v):\n\t%#v",
+ cst1.ts.URL, eres1, cst2.ts.URL, eres2)
+ }
+ if !reflect.DeepEqual(res1body, res2body) {
+ t.Errorf("Response bodies to handler differed.\nhttp1: %v\nhttp2: %v\n", res1body, res2body)
+ }
+ if fn := tt.CheckResponse; fn != nil {
+ res1.Body, res2.Body = res1body, res2body
+ fn("HTTP/1.1", res1)
+ fn("HTTP/2.0", res2)
+ }
+}
+
+func mostlyCopy(r *Response) *Response {
+ c := *r
+ c.Body = nil
+ c.TransferEncoding = nil
+ c.TLS = nil
+ c.Request = nil
+ return &c
+}
+
+type slurpResult struct {
+ io.ReadCloser
+ body []byte
+ err error
+}
+
+func (sr slurpResult) String() string { return fmt.Sprintf("body %q; err %v", sr.body, sr.err) }
+
+func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) {
+ if res.Proto == wantProto {
+ res.Proto, res.ProtoMajor, res.ProtoMinor = "", 0, 0
+ } else {
+ t.Errorf("got %q response; want %q", res.Proto, wantProto)
+ }
+ slurp, err := ioutil.ReadAll(res.Body)
+ res.Body.Close()
+ res.Body = slurpResult{
+ ReadCloser: ioutil.NopCloser(bytes.NewReader(slurp)),
+ body: slurp,
+ err: err,
+ }
+ for i, v := range res.Header["Date"] {
+ res.Header["Date"][i] = strings.Repeat("x", len(v))
+ }
+ if res.Request == nil {
+ t.Errorf("for %s, no request", wantProto)
+ }
+ if (res.TLS != nil) != (wantProto == "HTTP/2.0") {
+ t.Errorf("TLS set = %v; want %v", res.TLS != nil, res.TLS == nil)
+ }
+}
+
+// Issue 13532
+func TestH12_HeadContentLengthNoBody(t *testing.T) {
+ h12Compare{
+ ReqFunc: (*Client).Head,
+ Handler: func(w ResponseWriter, r *Request) {
+ },
+ }.run(t)
+}
+
+func TestH12_HeadContentLengthSmallBody(t *testing.T) {
+ h12Compare{
+ ReqFunc: (*Client).Head,
+ Handler: func(w ResponseWriter, r *Request) {
+ io.WriteString(w, "small")
+ },
+ }.run(t)
+}
+
+func TestH12_HeadContentLengthLargeBody(t *testing.T) {
+ h12Compare{
+ ReqFunc: (*Client).Head,
+ Handler: func(w ResponseWriter, r *Request) {
+ chunk := strings.Repeat("x", 512<<10)
+ for i := 0; i < 10; i++ {
+ io.WriteString(w, chunk)
+ }
+ },
+ }.run(t)
+}
+
+func TestH12_200NoBody(t *testing.T) {
+ h12Compare{Handler: func(w ResponseWriter, r *Request) {}}.run(t)
+}
+
+func TestH2_204NoBody(t *testing.T) { testH12_noBody(t, 204) }
+func TestH2_304NoBody(t *testing.T) { testH12_noBody(t, 304) }
+func TestH2_404NoBody(t *testing.T) { testH12_noBody(t, 404) }
+
+func testH12_noBody(t *testing.T, status int) {
+ h12Compare{Handler: func(w ResponseWriter, r *Request) {
+ w.WriteHeader(status)
+ }}.run(t)
+}
+
+func TestH12_SmallBody(t *testing.T) {
+ h12Compare{Handler: func(w ResponseWriter, r *Request) {
+ io.WriteString(w, "small body")
+ }}.run(t)
+}
+
+func TestH12_ExplicitContentLength(t *testing.T) {
+ h12Compare{Handler: func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "3")
+ io.WriteString(w, "foo")
+ }}.run(t)
+}
+
+func TestH12_FlushBeforeBody(t *testing.T) {
+ h12Compare{Handler: func(w ResponseWriter, r *Request) {
+ w.(Flusher).Flush()
+ io.WriteString(w, "foo")
+ }}.run(t)
+}
+
+func TestH12_FlushMidBody(t *testing.T) {
+ h12Compare{Handler: func(w ResponseWriter, r *Request) {
+ io.WriteString(w, "foo")
+ w.(Flusher).Flush()
+ io.WriteString(w, "bar")
+ }}.run(t)
+}
+
+func TestH12_Head_ExplicitLen(t *testing.T) {
+ h12Compare{
+ ReqFunc: (*Client).Head,
+ Handler: func(w ResponseWriter, r *Request) {
+ if r.Method != "HEAD" {
+ t.Errorf("unexpected method %q", r.Method)
+ }
+ w.Header().Set("Content-Length", "1235")
+ },
+ }.run(t)
+}
+
+func TestH12_Head_ImplicitLen(t *testing.T) {
+ h12Compare{
+ ReqFunc: (*Client).Head,
+ Handler: func(w ResponseWriter, r *Request) {
+ if r.Method != "HEAD" {
+ t.Errorf("unexpected method %q", r.Method)
+ }
+ io.WriteString(w, "foo")
+ },
+ }.run(t)
+}
+
+func TestH12_HandlerWritesTooLittle(t *testing.T) {
+ h12Compare{
+ Handler: func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "3")
+ io.WriteString(w, "12") // one byte short
+ },
+ CheckResponse: func(proto string, res *Response) {
+ sr, ok := res.Body.(slurpResult)
+ if !ok {
+ t.Errorf("%s body is %T; want slurpResult", proto, res.Body)
+ return
+ }
+ if sr.err != io.ErrUnexpectedEOF {
+ t.Errorf("%s read error = %v; want io.ErrUnexpectedEOF", proto, sr.err)
+ }
+ if string(sr.body) != "12" {
+ t.Errorf("%s body = %q; want %q", proto, sr.body, "12")
+ }
+ },
+ }.run(t)
+}
+
+// Tests that the HTTP/1 and HTTP/2 servers prevent handlers from
+// writing more than they declared. This test does not test whether
+// the transport deals with too much data, though, since the server
+// doesn't make it possible to send bogus data. For those tests, see
+// transport_test.go (for HTTP/1) or x/net/http2/transport_test.go
+// (for HTTP/2).
+func TestH12_HandlerWritesTooMuch(t *testing.T) {
+ h12Compare{
+ Handler: func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Length", "3")
+ w.(Flusher).Flush()
+ io.WriteString(w, "123")
+ w.(Flusher).Flush()
+ n, err := io.WriteString(w, "x") // too many
+ if n > 0 || err == nil {
+ t.Errorf("for proto %q, final write = %v, %v; want 0, some error", r.Proto, n, err)
+ }
+ },
+ }.run(t)
+}
+
+// Verify that both our HTTP/1 and HTTP/2 transports request and auto-decompress gzip.
+// Some hosts send gzip even if you don't ask for it; see golang.org/issue/13298
+func TestH12_AutoGzip(t *testing.T) {
+ h12Compare{
+ Handler: func(w ResponseWriter, r *Request) {
+ if ae := r.Header.Get("Accept-Encoding"); ae != "gzip" {
+ t.Errorf("%s Accept-Encoding = %q; want gzip", r.Proto, ae)
+ }
+ w.Header().Set("Content-Encoding", "gzip")
+ gz := gzip.NewWriter(w)
+ io.WriteString(gz, "I am some gzipped content. Go go go go go go go go go go go go should compress well.")
+ gz.Close()
+ },
+ }.run(t)
+}
+
+func TestH12_AutoGzip_Disabled(t *testing.T) {
+ h12Compare{
+ Opts: []interface{}{
+ func(tr *Transport) { tr.DisableCompression = true },
+ },
+ Handler: func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "%q", r.Header["Accept-Encoding"])
+ if ae := r.Header.Get("Accept-Encoding"); ae != "" {
+ t.Errorf("%s Accept-Encoding = %q; want empty", r.Proto, ae)
+ }
+ },
+ }.run(t)
+}
+
+// Test304Responses verifies that 304s don't declare that they're
+// chunking in their response headers and aren't allowed to produce
+// output.
+func Test304Responses_h1(t *testing.T) { test304Responses(t, h1Mode) }
+func Test304Responses_h2(t *testing.T) { test304Responses(t, h2Mode) }
+
+func test304Responses(t *testing.T, h2 bool) {
+ defer afterTest(t)
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.WriteHeader(StatusNotModified)
+ _, err := w.Write([]byte("illegal body"))
+ if err != ErrBodyNotAllowed {
+ t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
+ }
+ }))
+ defer cst.close()
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if len(res.TransferEncoding) > 0 {
+ t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
+ }
+ body, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Error(err)
+ }
+ if len(body) > 0 {
+ t.Errorf("got unexpected body %q", string(body))
+ }
+}
+
+func TestH12_ServerEmptyContentLength(t *testing.T) {
+ h12Compare{
+ Handler: func(w ResponseWriter, r *Request) {
+ w.Header()["Content-Type"] = []string{""}
+ io.WriteString(w, "<html><body>hi</body></html>")
+ },
+ }.run(t)
+}
+
+func TestH12_RequestContentLength_Known_NonZero(t *testing.T) {
+ h12requestContentLength(t, func() io.Reader { return strings.NewReader("FOUR") }, 4)
+}
+
+func TestH12_RequestContentLength_Known_Zero(t *testing.T) {
+ h12requestContentLength(t, func() io.Reader { return strings.NewReader("") }, 0)
+}
+
+func TestH12_RequestContentLength_Unknown(t *testing.T) {
+ h12requestContentLength(t, func() io.Reader { return struct{ io.Reader }{strings.NewReader("Stuff")} }, -1)
+}
+
+func h12requestContentLength(t *testing.T, bodyfn func() io.Reader, wantLen int64) {
+ h12Compare{
+ Handler: func(w ResponseWriter, r *Request) {
+ w.Header().Set("Got-Length", fmt.Sprint(r.ContentLength))
+ fmt.Fprintf(w, "Req.ContentLength=%v", r.ContentLength)
+ },
+ ReqFunc: func(c *Client, url string) (*Response, error) {
+ return c.Post(url, "text/plain", bodyfn())
+ },
+ CheckResponse: func(proto string, res *Response) {
+ if got, want := res.Header.Get("Got-Length"), fmt.Sprint(wantLen); got != want {
+ t.Errorf("Proto %q got length %q; want %q", proto, got, want)
+ }
+ },
+ }.run(t)
+}
+
+// Tests that the Request.Cancel channel can be closed while the
+// response body is still being read. Issue 13159.
+func TestCancelRequestMidBody_h1(t *testing.T) { testCancelRequestMidBody(t, h1Mode) }
+func TestCancelRequestMidBody_h2(t *testing.T) { testCancelRequestMidBody(t, h2Mode) }
+func testCancelRequestMidBody(t *testing.T, h2 bool) {
+ defer afterTest(t)
+ unblock := make(chan bool)
+ didFlush := make(chan bool, 1)
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.WriteString(w, "Hello")
+ w.(Flusher).Flush()
+ didFlush <- true
+ <-unblock
+ io.WriteString(w, ", world.")
+ }))
+ defer cst.close()
+ defer close(unblock)
+
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ cancel := make(chan struct{})
+ req.Cancel = cancel
+
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ <-didFlush
+
+ // Read a bit before we cancel. (Issue 13626)
+ // We should have "Hello" at least sitting there.
+ firstRead := make([]byte, 10)
+ n, err := res.Body.Read(firstRead)
+ if err != nil {
+ t.Fatal(err)
+ }
+ firstRead = firstRead[:n]
+
+ close(cancel)
+
+ rest, err := ioutil.ReadAll(res.Body)
+ all := string(firstRead) + string(rest)
+ if all != "Hello" {
+ t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest)
+ }
+ if !reflect.DeepEqual(err, ExportErrRequestCanceled) {
+ t.Errorf("ReadAll error = %v; want %v", err, ExportErrRequestCanceled)
+ }
+}
+
+// Tests that clients can send trailers to a server and that the server can read them.
+func TestTrailersClientToServer_h1(t *testing.T) { testTrailersClientToServer(t, h1Mode) }
+func TestTrailersClientToServer_h2(t *testing.T) { testTrailersClientToServer(t, h2Mode) }
+
+func testTrailersClientToServer(t *testing.T, h2 bool) {
+ defer afterTest(t)
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ var decl []string
+ for k := range r.Trailer {
+ decl = append(decl, k)
+ }
+ sort.Strings(decl)
+
+ slurp, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ t.Errorf("Server reading request body: %v", err)
+ }
+ if string(slurp) != "foo" {
+ t.Errorf("Server read request body %q; want foo", slurp)
+ }
+ if r.Trailer == nil {
+ io.WriteString(w, "nil Trailer")
+ } else {
+ fmt.Fprintf(w, "decl: %v, vals: %s, %s",
+ decl,
+ r.Trailer.Get("Client-Trailer-A"),
+ r.Trailer.Get("Client-Trailer-B"))
+ }
+ }))
+ defer cst.close()
+
+ var req *Request
+ req, _ = NewRequest("POST", cst.ts.URL, io.MultiReader(
+ eofReaderFunc(func() {
+ req.Trailer["Client-Trailer-A"] = []string{"valuea"}
+ }),
+ strings.NewReader("foo"),
+ eofReaderFunc(func() {
+ req.Trailer["Client-Trailer-B"] = []string{"valueb"}
+ }),
+ ))
+ req.Trailer = Header{
+ "Client-Trailer-A": nil, // to be set later
+ "Client-Trailer-B": nil, // to be set later
+ }
+ req.ContentLength = -1
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if err := wantBody(res, err, "decl: [Client-Trailer-A Client-Trailer-B], vals: valuea, valueb"); err != nil {
+ t.Error(err)
+ }
+}
+
+// Tests that servers send trailers to a client and that the client can read them.
+func TestTrailersServerToClient_h1(t *testing.T) { testTrailersServerToClient(t, h1Mode, false) }
+func TestTrailersServerToClient_h2(t *testing.T) { testTrailersServerToClient(t, h2Mode, false) }
+func TestTrailersServerToClient_Flush_h1(t *testing.T) { testTrailersServerToClient(t, h1Mode, true) }
+func TestTrailersServerToClient_Flush_h2(t *testing.T) { testTrailersServerToClient(t, h2Mode, true) }
+
+func testTrailersServerToClient(t *testing.T, h2, flush bool) {
+ defer afterTest(t)
+ const body = "Some body"
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Trailer", "Server-Trailer-A, Server-Trailer-B")
+ w.Header().Add("Trailer", "Server-Trailer-C")
+
+ io.WriteString(w, body)
+ if flush {
+ w.(Flusher).Flush()
+ }
+
+ // How handlers set Trailers: declare it ahead of time
+ // with the Trailer header, and then mutate the
+ // Header() of those values later, after the response
+ // has been written (we wrote to w above).
+ w.Header().Set("Server-Trailer-A", "valuea")
+ w.Header().Set("Server-Trailer-C", "valuec") // skipping B
+ w.Header().Set("Server-Trailer-NotDeclared", "should be omitted")
+ }))
+ defer cst.close()
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ wantHeader := Header{
+ "Content-Type": {"text/plain; charset=utf-8"},
+ }
+ wantLen := -1
+ if h2 && !flush {
+ // In HTTP/1.1, any use of trailers forces HTTP/1.1
+ // chunking and a flush at the first write. That's
+ // unnecessary with HTTP/2's framing, so the server
+ // is able to calculate the length while still sending
+ // trailers afterwards.
+ wantLen = len(body)
+ wantHeader["Content-Length"] = []string{fmt.Sprint(wantLen)}
+ }
+ if res.ContentLength != int64(wantLen) {
+ t.Errorf("ContentLength = %v; want %v", res.ContentLength, wantLen)
+ }
+
+ delete(res.Header, "Date") // irrelevant for test
+ if !reflect.DeepEqual(res.Header, wantHeader) {
+ t.Errorf("Header = %v; want %v", res.Header, wantHeader)
+ }
+
+ if got, want := res.Trailer, (Header{
+ "Server-Trailer-A": nil,
+ "Server-Trailer-B": nil,
+ "Server-Trailer-C": nil,
+ }); !reflect.DeepEqual(got, want) {
+ t.Errorf("Trailer before body read = %v; want %v", got, want)
+ }
+
+ if err := wantBody(res, nil, body); err != nil {
+ t.Fatal(err)
+ }
+
+ if got, want := res.Trailer, (Header{
+ "Server-Trailer-A": {"valuea"},
+ "Server-Trailer-B": nil,
+ "Server-Trailer-C": {"valuec"},
+ }); !reflect.DeepEqual(got, want) {
+ t.Errorf("Trailer after body read = %v; want %v", got, want)
+ }
+}
+
+// Don't allow a Body.Read after Body.Close. Issue 13648.
+func TestResponseBodyReadAfterClose_h1(t *testing.T) { testResponseBodyReadAfterClose(t, h1Mode) }
+func TestResponseBodyReadAfterClose_h2(t *testing.T) { testResponseBodyReadAfterClose(t, h2Mode) }
+
+func testResponseBodyReadAfterClose(t *testing.T, h2 bool) {
+ defer afterTest(t)
+ const body = "Some body"
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ io.WriteString(w, body)
+ }))
+ defer cst.close()
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ data, err := ioutil.ReadAll(res.Body)
+ if len(data) != 0 || err == nil {
+ t.Fatalf("ReadAll returned %q, %v; want error", data, err)
+ }
+}
+
+func TestConcurrentReadWriteReqBody_h1(t *testing.T) { testConcurrentReadWriteReqBody(t, h1Mode) }
+func TestConcurrentReadWriteReqBody_h2(t *testing.T) { testConcurrentReadWriteReqBody(t, h2Mode) }
+func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) {
+ defer afterTest(t)
+ const reqBody = "some request body"
+ const resBody = "some response body"
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ var wg sync.WaitGroup
+ wg.Add(2)
+ didRead := make(chan bool, 1)
+ // Read in one goroutine.
+ go func() {
+ defer wg.Done()
+ data, err := ioutil.ReadAll(r.Body)
+ if string(data) != reqBody {
+ t.Errorf("Handler read %q; want %q", data, reqBody)
+ }
+ if err != nil {
+ t.Errorf("Handler Read: %v", err)
+ }
+ didRead <- true
+ }()
+ // Write in another goroutine.
+ go func() {
+ defer wg.Done()
+ if !h2 {
+ // our HTTP/1 implementation intentionally
+ // doesn't permit writes during read (mostly
+ // due to it being undefined); if that is ever
+ // relaxed, change this.
+ <-didRead
+ }
+ io.WriteString(w, resBody)
+ }()
+ wg.Wait()
+ }))
+ defer cst.close()
+ req, _ := NewRequest("POST", cst.ts.URL, strings.NewReader(reqBody))
+ req.Header.Add("Expect", "100-continue") // just to complicate things
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ data, err := ioutil.ReadAll(res.Body)
+ defer res.Body.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(data) != resBody {
+ t.Errorf("read %q; want %q", data, resBody)
+ }
+}
+
+func TestConnectRequest_h1(t *testing.T) { testConnectRequest(t, h1Mode) }
+func TestConnectRequest_h2(t *testing.T) { testConnectRequest(t, h2Mode) }
+func testConnectRequest(t *testing.T, h2 bool) {
+ defer afterTest(t)
+ gotc := make(chan *Request, 1)
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ gotc <- r
+ }))
+ defer cst.close()
+
+ u, err := url.Parse(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ tests := []struct {
+ req *Request
+ want string
+ }{
+ {
+ req: &Request{
+ Method: "CONNECT",
+ Header: Header{},
+ URL: u,
+ },
+ want: u.Host,
+ },
+ {
+ req: &Request{
+ Method: "CONNECT",
+ Header: Header{},
+ URL: u,
+ Host: "example.com:123",
+ },
+ want: "example.com:123",
+ },
+ }
+
+ for i, tt := range tests {
+ res, err := cst.c.Do(tt.req)
+ if err != nil {
+ t.Errorf("%d. RoundTrip = %v", i, err)
+ continue
+ }
+ res.Body.Close()
+ req := <-gotc
+ if req.Method != "CONNECT" {
+ t.Errorf("method = %q; want CONNECT", req.Method)
+ }
+ if req.Host != tt.want {
+ t.Errorf("Host = %q; want %q", req.Host, tt.want)
+ }
+ if req.URL.Host != tt.want {
+ t.Errorf("URL.Host = %q; want %q", req.URL.Host, tt.want)
+ }
+ }
+}
+
+func TestTransportUserAgent_h1(t *testing.T) { testTransportUserAgent(t, h1Mode) }
+func TestTransportUserAgent_h2(t *testing.T) { testTransportUserAgent(t, h2Mode) }
+func testTransportUserAgent(t *testing.T, h2 bool) {
+ defer afterTest(t)
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "%q", r.Header["User-Agent"])
+ }))
+ defer cst.close()
+
+ either := func(a, b string) string {
+ if h2 {
+ return b
+ }
+ return a
+ }
+
+ tests := []struct {
+ setup func(*Request)
+ want string
+ }{
+ {
+ func(r *Request) {},
+ either(`["Go-http-client/1.1"]`, `["Go-http-client/2.0"]`),
+ },
+ {
+ func(r *Request) { r.Header.Set("User-Agent", "foo/1.2.3") },
+ `["foo/1.2.3"]`,
+ },
+ {
+ func(r *Request) { r.Header["User-Agent"] = []string{"single", "or", "multiple"} },
+ `["single"]`,
+ },
+ {
+ func(r *Request) { r.Header.Set("User-Agent", "") },
+ `[]`,
+ },
+ {
+ func(r *Request) { r.Header["User-Agent"] = nil },
+ `[]`,
+ },
+ }
+ for i, tt := range tests {
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ tt.setup(req)
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Errorf("%d. RoundTrip = %v", i, err)
+ continue
+ }
+ slurp, err := ioutil.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ t.Errorf("%d. read body = %v", i, err)
+ continue
+ }
+ if string(slurp) != tt.want {
+ t.Errorf("%d. body mismatch.\n got: %s\nwant: %s\n", i, slurp, tt.want)
+ }
+ }
+}
+
+func TestStarRequestFoo_h1(t *testing.T) { testStarRequest(t, "FOO", h1Mode) }
+func TestStarRequestFoo_h2(t *testing.T) { testStarRequest(t, "FOO", h2Mode) }
+func TestStarRequestOptions_h1(t *testing.T) { testStarRequest(t, "OPTIONS", h1Mode) }
+func TestStarRequestOptions_h2(t *testing.T) { testStarRequest(t, "OPTIONS", h2Mode) }
+func testStarRequest(t *testing.T, method string, h2 bool) {
+ defer afterTest(t)
+ gotc := make(chan *Request, 1)
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("foo", "bar")
+ gotc <- r
+ w.(Flusher).Flush()
+ }))
+ defer cst.close()
+
+ u, err := url.Parse(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ u.Path = "*"
+
+ req := &Request{
+ Method: method,
+ Header: Header{},
+ URL: u,
+ }
+
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatalf("RoundTrip = %v", err)
+ }
+ res.Body.Close()
+
+ wantFoo := "bar"
+ wantLen := int64(-1)
+ if method == "OPTIONS" {
+ wantFoo = ""
+ wantLen = 0
+ }
+ if res.StatusCode != 200 {
+ t.Errorf("status code = %v; want %d", res.Status, 200)
+ }
+ if res.ContentLength != wantLen {
+ t.Errorf("content length = %v; want %d", res.ContentLength, wantLen)
+ }
+ if got := res.Header.Get("foo"); got != wantFoo {
+ t.Errorf("response \"foo\" header = %q; want %q", got, wantFoo)
+ }
+ select {
+ case req = <-gotc:
+ default:
+ req = nil
+ }
+ if req == nil {
+ if method != "OPTIONS" {
+ t.Fatalf("handler never got request")
+ }
+ return
+ }
+ if req.Method != method {
+ t.Errorf("method = %q; want %q", req.Method, method)
+ }
+ if req.URL.Path != "*" {
+ t.Errorf("URL.Path = %q; want *", req.URL.Path)
+ }
+ if req.RequestURI != "*" {
+ t.Errorf("RequestURI = %q; want *", req.RequestURI)
+ }
+}
+
+// Issue 13957
+func TestTransportDiscardsUnneededConns(t *testing.T) {
+ defer afterTest(t)
+ cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "Hello, %v", r.RemoteAddr)
+ }))
+ defer cst.close()
+
+ var numOpen, numClose int32 // atomic
+
+ tlsConfig := &tls.Config{InsecureSkipVerify: true}
+ tr := &Transport{
+ TLSClientConfig: tlsConfig,
+ DialTLS: func(_, addr string) (net.Conn, error) {
+ time.Sleep(10 * time.Millisecond)
+ rc, err := net.Dial("tcp", addr)
+ if err != nil {
+ return nil, err
+ }
+ atomic.AddInt32(&numOpen, 1)
+ c := noteCloseConn{rc, func() { atomic.AddInt32(&numClose, 1) }}
+ return tls.Client(c, tlsConfig), nil
+ },
+ }
+ if err := ExportHttp2ConfigureTransport(tr); err != nil {
+ t.Fatal(err)
+ }
+ defer tr.CloseIdleConnections()
+
+ c := &Client{Transport: tr}
+
+ const N = 10
+ gotBody := make(chan string, N)
+ var wg sync.WaitGroup
+ for i := 0; i < N; i++ {
+ wg.Add(1)
+ go func() {
+ defer wg.Done()
+ resp, err := c.Get(cst.ts.URL)
+ if err != nil {
+ t.Errorf("Get: %v", err)
+ return
+ }
+ defer resp.Body.Close()
+ slurp, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Error(err)
+ }
+ gotBody <- string(slurp)
+ }()
+ }
+ wg.Wait()
+ close(gotBody)
+
+ var last string
+ for got := range gotBody {
+ if last == "" {
+ last = got
+ continue
+ }
+ if got != last {
+ t.Errorf("Response body changed: %q -> %q", last, got)
+ }
+ }
+
+ var open, close int32
+ for i := 0; i < 150; i++ {
+ open, close = atomic.LoadInt32(&numOpen), atomic.LoadInt32(&numClose)
+ if open < 1 {
+ t.Fatalf("open = %d; want at least 1", open)
+ }
+ if close == open-1 {
+ // Success
+ return
+ }
+ time.Sleep(10 * time.Millisecond)
+ }
+ t.Errorf("%d connections opened, %d closed; want %d to close", open, close, open-1)
+}
+
+// Tests that Transport doesn't retain a pointer to the provided request.
+func TestTransportGCRequest_h1(t *testing.T) { testTransportGCRequest(t, h1Mode) }
+func TestTransportGCRequest_h2(t *testing.T) { testTransportGCRequest(t, h2Mode) }
+func testTransportGCRequest(t *testing.T, h2 bool) {
+ if runtime.Compiler == "gccgo" {
+ t.Skip("skipping on gccgo because conservative GC means that finalizer may never run")
+ }
+
+ defer afterTest(t)
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ ioutil.ReadAll(r.Body)
+ io.WriteString(w, "Hello.")
+ }))
+ defer cst.close()
+
+ didGC := make(chan struct{})
+ (func() {
+ body := strings.NewReader("some body")
+ req, _ := NewRequest("POST", cst.ts.URL, body)
+ runtime.SetFinalizer(req, func(*Request) { close(didGC) })
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := ioutil.ReadAll(res.Body); err != nil {
+ t.Fatal(err)
+ }
+ if err := res.Body.Close(); err != nil {
+ t.Fatal(err)
+ }
+ })()
+ timeout := time.NewTimer(5 * time.Second)
+ defer timeout.Stop()
+ for {
+ select {
+ case <-didGC:
+ return
+ case <-time.After(100 * time.Millisecond):
+ runtime.GC()
+ case <-timeout.C:
+ t.Fatal("never saw GC of request")
+ }
+ }
+}
+
+type noteCloseConn struct {
+ net.Conn
+ closeFunc func()
+}
+
+func (x noteCloseConn) Close() error {
+ x.closeFunc()
+ return x.Conn.Close()
+}
diff --git a/libgo/go/net/http/doc.go b/libgo/go/net/http/doc.go
index b1216e8dafa..4ec8272f628 100644
--- a/libgo/go/net/http/doc.go
+++ b/libgo/go/net/http/doc.go
@@ -76,5 +76,20 @@ custom Server:
MaxHeaderBytes: 1 << 20,
}
log.Fatal(s.ListenAndServe())
+
+The http package has transparent support for the HTTP/2 protocol when
+using HTTPS. Programs that must disable HTTP/2 can do so by setting
+Transport.TLSNextProto (for clients) or Server.TLSNextProto (for
+servers) to a non-nil, empty map. Alternatively, the following GODEBUG
+environment variables are currently supported:
+
+ GODEBUG=http2client=0 # disable HTTP/2 client support
+ GODEBUG=http2server=0 # disable HTTP/2 server support
+ GODEBUG=http2debug=1 # enable verbose HTTP/2 debug logs
+ GODEBUG=http2debug=2 # ... even more verbose, with frame dumps
+
+The GODEBUG variables are not covered by Go's API compatibility promise.
+HTTP/2 support was added in Go 1.6. Please report any issues instead of
+disabling HTTP/2 support: https://golang.org/s/http2bug
*/
package http
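
Note: a minimal sketch of the TLSNextProto escape hatch described above, disabling HTTP/2 on both a client Transport and a Server by installing non-nil, empty maps:

    package main

    import (
        "crypto/tls"
        "net/http"
    )

    func main() {
        // Client side: an empty (but non-nil) map prevents the Transport
        // from negotiating HTTP/2 on HTTPS connections.
        tr := &http.Transport{
            TLSNextProto: map[string]func(string, *tls.Conn) http.RoundTripper{},
        }
        _ = &http.Client{Transport: tr}

        // Server side: same idea for *http.Server.
        srv := &http.Server{
            Addr:         ":8443",
            TLSNextProto: map[string]func(*http.Server, *tls.Conn, http.Handler){},
        }
        _ = srv
    }
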
diff --git a/libgo/go/net/http/export_test.go b/libgo/go/net/http/export_test.go
index 0457be50da6..52bccbdce31 100644
--- a/libgo/go/net/http/export_test.go
+++ b/libgo/go/net/http/export_test.go
@@ -9,11 +9,24 @@ package http
import (
"net"
- "net/url"
"sync"
"time"
)
+var (
+ DefaultUserAgent = defaultUserAgent
+ NewLoggingConn = newLoggingConn
+ ExportAppendTime = appendTime
+ ExportRefererForURL = refererForURL
+ ExportServerNewConn = (*Server).newConn
+ ExportCloseWriteAndWait = (*conn).closeWriteAndWait
+ ExportErrRequestCanceled = errRequestCanceled
+ ExportErrRequestCanceledConn = errRequestCanceledConn
+ ExportServeFile = serveFile
+ ExportHttp2ConfigureTransport = http2ConfigureTransport
+ ExportHttp2ConfigureServer = http2ConfigureServer
+)
+
func init() {
// We only want to pay for this cost during testing.
// When not under test, these values are always nil
@@ -21,11 +34,42 @@ func init() {
testHookMu = new(sync.Mutex)
}
-func NewLoggingConn(baseName string, c net.Conn) net.Conn {
- return newLoggingConn(baseName, c)
+var (
+ SetEnterRoundTripHook = hookSetter(&testHookEnterRoundTrip)
+ SetTestHookWaitResLoop = hookSetter(&testHookWaitResLoop)
+ SetRoundTripRetried = hookSetter(&testHookRoundTripRetried)
+)
+
+func SetReadLoopBeforeNextReadHook(f func()) {
+ testHookMu.Lock()
+ defer testHookMu.Unlock()
+ unnilTestHook(&f)
+ testHookReadLoopBeforeNextRead = f
+}
+
+// SetPendingDialHooks sets the hooks that run before and after handling
+// pending dials.
+func SetPendingDialHooks(before, after func()) {
+ unnilTestHook(&before)
+ unnilTestHook(&after)
+ testHookPrePendingDial, testHookPostPendingDial = before, after
}
-var ExportAppendTime = appendTime
+func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServe = fn }
+
+func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler {
+ return &timeoutHandler{
+ handler: handler,
+ timeout: func() <-chan time.Time { return ch },
+ // (no body and nil cancelTimer)
+ }
+}
+
+func ResetCachedEnvironment() {
+ httpProxyEnv.reset()
+ httpsProxyEnv.reset()
+ noProxyEnv.reset()
+}
func (t *Transport) NumPendingRequestsForTesting() int {
t.reqMu.Lock()
@@ -78,55 +122,25 @@ func (t *Transport) RequestIdleConnChForTesting() {
func (t *Transport) PutIdleTestConn() bool {
c, _ := net.Pipe()
- return t.putIdleConn(&persistConn{
+ return t.tryPutIdleConn(&persistConn{
t: t,
conn: c, // dummy
closech: make(chan struct{}), // so it can be closed
cacheKey: connectMethodKey{"", "http", "example.com"},
- })
-}
-
-func SetInstallConnClosedHook(f func()) {
- testHookPersistConnClosedGotRes = f
+ }) == nil
}
-func SetEnterRoundTripHook(f func()) {
- testHookEnterRoundTrip = f
-}
-
-func SetReadLoopBeforeNextReadHook(f func()) {
- testHookMu.Lock()
- defer testHookMu.Unlock()
- testHookReadLoopBeforeNextRead = f
-}
-
-func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler {
- f := func() <-chan time.Time {
- return ch
+// All test hooks must be non-nil so they can be called directly,
+// but the tests use nil to mean hook disabled.
+func unnilTestHook(f *func()) {
+ if *f == nil {
+ *f = nop
}
- return &timeoutHandler{handler, f, ""}
}
-func ResetCachedEnvironment() {
- httpProxyEnv.reset()
- httpsProxyEnv.reset()
- noProxyEnv.reset()
-}
-
-var DefaultUserAgent = defaultUserAgent
-
-func ExportRefererForURL(lastReq, newReq *url.URL) string {
- return refererForURL(lastReq, newReq)
-}
-
-// SetPendingDialHooks sets the hooks that run before and after handling
-// pending dials.
-func SetPendingDialHooks(before, after func()) {
- prePendingDial, postPendingDial = before, after
+func hookSetter(dst *func()) func(func()) {
+ return func(fn func()) {
+ unnilTestHook(&fn)
+ *dst = fn
+ }
}
-
-var ExportServerNewConn = (*Server).newConn
-
-var ExportCloseWriteAndWait = (*conn).closeWriteAndWait
-
-var ExportErrRequestCanceled = errRequestCanceled
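
Note: the export_test.go rewrite consolidates the test hooks behind hookSetter/unnilTestHook so production code can call hooks unconditionally while tests pass nil to disable them. A standalone sketch of that pattern, with hypothetical hook names:

    package hooks

    func nop() {}

    // testHookBeforeWork is called unconditionally by production code,
    // so it must never be nil.
    var testHookBeforeWork = nop

    // hookSetter returns a setter that installs fn, substituting a no-op
    // when the caller passes nil to mean "hook disabled".
    func hookSetter(dst *func()) func(func()) {
        return func(fn func()) {
            if fn == nil {
                fn = nop
            }
            *dst = fn
        }
    }

    var SetBeforeWorkHook = hookSetter(&testHookBeforeWork)
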
diff --git a/libgo/go/net/http/fcgi/child.go b/libgo/go/net/http/fcgi/child.go
index da824ed717e..88704245db8 100644
--- a/libgo/go/net/http/fcgi/child.go
+++ b/libgo/go/net/http/fcgi/child.go
@@ -56,6 +56,9 @@ func (r *request) parseParams() {
return
}
text = text[n:]
+ if int(keyLen)+int(valLen) > len(text) {
+ return
+ }
key := readString(text, keyLen)
text = text[keyLen:]
val := readString(text, valLen)
diff --git a/libgo/go/net/http/fcgi/fcgi_test.go b/libgo/go/net/http/fcgi/fcgi_test.go
index de0f7f831f6..b6013bfdd51 100644
--- a/libgo/go/net/http/fcgi/fcgi_test.go
+++ b/libgo/go/net/http/fcgi/fcgi_test.go
@@ -254,3 +254,27 @@ func TestChildServeCleansUp(t *testing.T) {
<-done
}
}
+
+type rwNopCloser struct {
+ io.Reader
+ io.Writer
+}
+
+func (rwNopCloser) Close() error {
+ return nil
+}
+
+// Verifies it doesn't crash. Issue 11824.
+func TestMalformedParams(t *testing.T) {
+ input := []byte{
+ // beginRequest, requestId=1, contentLength=8, role=1, keepConn=1
+ 1, 1, 0, 1, 0, 8, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0,
+ // params, requestId=1, contentLength=10, k1Len=50, v1Len=50 (malformed, wrong length)
+ 1, 4, 0, 1, 0, 10, 0, 0, 50, 50, 3, 4, 5, 6, 7, 8, 9, 10,
+ // end of params
+ 1, 4, 0, 1, 0, 0, 0, 0,
+ }
+ rw := rwNopCloser{bytes.NewReader(input), ioutil.Discard}
+ c := newChild(rw, http.DefaultServeMux)
+ c.serve()
+}
diff --git a/libgo/go/net/http/fs.go b/libgo/go/net/http/fs.go
index 75720234c25..f61c138c1d9 100644
--- a/libgo/go/net/http/fs.go
+++ b/libgo/go/net/http/fs.go
@@ -17,6 +17,7 @@ import (
"os"
"path"
"path/filepath"
+ "sort"
"strconv"
"strings"
"time"
@@ -62,30 +63,34 @@ type FileSystem interface {
type File interface {
io.Closer
io.Reader
+ io.Seeker
Readdir(count int) ([]os.FileInfo, error)
- Seek(offset int64, whence int) (int64, error)
Stat() (os.FileInfo, error)
}
func dirList(w ResponseWriter, f File) {
+ dirs, err := f.Readdir(-1)
+ if err != nil {
+ // TODO: log err.Error() to the Server.ErrorLog, once it's possible
+ // for a handler to get at its Server via the ResponseWriter. See
+ // Issue 12438.
+ Error(w, "Error reading directory", StatusInternalServerError)
+ return
+ }
+ sort.Sort(byName(dirs))
+
w.Header().Set("Content-Type", "text/html; charset=utf-8")
fmt.Fprintf(w, "<pre>\n")
- for {
- dirs, err := f.Readdir(100)
- if err != nil || len(dirs) == 0 {
- break
- }
- for _, d := range dirs {
- name := d.Name()
- if d.IsDir() {
- name += "/"
- }
- // name may contain '?' or '#', which must be escaped to remain
- // part of the URL path, and not indicate the start of a query
- // string or fragment.
- url := url.URL{Path: name}
- fmt.Fprintf(w, "<a href=\"%s\">%s</a>\n", url.String(), htmlReplacer.Replace(name))
+ for _, d := range dirs {
+ name := d.Name()
+ if d.IsDir() {
+ name += "/"
}
+ // name may contain '?' or '#', which must be escaped to remain
+ // part of the URL path, and not indicate the start of a query
+ // string or fragment.
+ url := url.URL{Path: name}
+ fmt.Fprintf(w, "<a href=\"%s\">%s</a>\n", url.String(), htmlReplacer.Replace(name))
}
fmt.Fprintf(w, "</pre>\n")
}
@@ -364,8 +369,8 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec
}
defer f.Close()
- d, err1 := f.Stat()
- if err1 != nil {
+ d, err := f.Stat()
+ if err != nil {
msg, code := toHTTPError(err)
Error(w, msg, code)
return
@@ -446,15 +451,44 @@ func localRedirect(w ResponseWriter, r *Request, newPath string) {
// ServeFile replies to the request with the contents of the named
// file or directory.
//
+// If the provided file or directory name is a relative path, it is
+// interpreted relative to the current directory and may ascend to parent
+// directories. If the provided name is constructed from user input, it
+// should be sanitized before calling ServeFile. As a precaution, ServeFile
+// will reject requests where r.URL.Path contains a ".." path element.
+//
// As a special case, ServeFile redirects any request where r.URL.Path
// ends in "/index.html" to the same path, without the final
// "index.html". To avoid such redirects either modify the path or
// use ServeContent.
func ServeFile(w ResponseWriter, r *Request, name string) {
+ if containsDotDot(r.URL.Path) {
+ // Too many programs use r.URL.Path to construct the argument to
+		// serveFile. Reject the request on the assumption that this happened
+		// here and that ".." is not wanted.
+ // Note that name might not contain "..", for example if code (still
+ // incorrectly) used filepath.Join(myDir, r.URL.Path).
+ Error(w, "invalid URL path", StatusBadRequest)
+ return
+ }
dir, file := filepath.Split(name)
serveFile(w, r, Dir(dir), file, false)
}
+func containsDotDot(v string) bool {
+ if !strings.Contains(v, "..") {
+ return false
+ }
+ for _, ent := range strings.FieldsFunc(v, isSlashRune) {
+ if ent == ".." {
+ return true
+ }
+ }
+ return false
+}
+
+func isSlashRune(r rune) bool { return r == '/' || r == '\\' }
+
type fileHandler struct {
root FileSystem
}
@@ -585,3 +619,9 @@ func sumRangesSize(ranges []httpRange) (size int64) {
}
return
}
+
+type byName []os.FileInfo
+
+func (s byName) Len() int { return len(s) }
+func (s byName) Less(i, j int) bool { return s[i].Name() < s[j].Name() }
+func (s byName) Swap(i, j int) { s[i], s[j] = s[j], s[i] }
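
The containsDotDot guard added above rejects only genuine ".." path elements (splitting on both '/' and '\'), not file names that merely contain dots. A standalone sketch of the same check, reimplemented here for illustration:

    package main

    import (
        "fmt"
        "strings"
    )

    func isSlashRune(r rune) bool { return r == '/' || r == '\\' }

    // containsDotDot reports whether any slash-separated element of v is "..".
    func containsDotDot(v string) bool {
        if !strings.Contains(v, "..") {
            return false
        }
        for _, ent := range strings.FieldsFunc(v, isSlashRune) {
            if ent == ".." {
                return true
            }
        }
        return false
    }

    func main() {
        for _, p := range []string{"/file/a..", "/file/a/..", "/..\\foo"} {
            fmt.Printf("%q -> %v\n", p, containsDotDot(p))
        }
        // "/file/a.."  -> false (".." inside a file name is fine)
        // "/file/a/.." -> true  (a real parent-directory element)
        // "/..\foo"    -> true  (backslash counts as a separator too)
    }
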
diff --git a/libgo/go/net/http/fs_test.go b/libgo/go/net/http/fs_test.go
index 538f34d7201..cf5b63c9f75 100644
--- a/libgo/go/net/http/fs_test.go
+++ b/libgo/go/net/http/fs_test.go
@@ -5,6 +5,7 @@
package http_test
import (
+ "bufio"
"bytes"
"errors"
"fmt"
@@ -177,6 +178,36 @@ Cases:
}
}
+func TestServeFile_DotDot(t *testing.T) {
+ tests := []struct {
+ req string
+ wantStatus int
+ }{
+ {"/testdata/file", 200},
+ {"/../file", 400},
+ {"/..", 400},
+ {"/../", 400},
+ {"/../foo", 400},
+ {"/..\\foo", 400},
+ {"/file/a", 200},
+ {"/file/a..", 200},
+ {"/file/a/..", 400},
+ {"/file/a\\..", 400},
+ }
+ for _, tt := range tests {
+ req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + tt.req + " HTTP/1.1\r\nHost: foo\r\n\r\n")))
+ if err != nil {
+ t.Errorf("bad request %q: %v", tt.req, err)
+ continue
+ }
+ rec := httptest.NewRecorder()
+ ServeFile(rec, req, "testdata/file")
+ if rec.Code != tt.wantStatus {
+ t.Errorf("for request %q, status = %d; want %d", tt.req, rec.Code, tt.wantStatus)
+ }
+ }
+}
+
var fsRedirectTestData = []struct {
original, redirect string
}{
@@ -283,6 +314,49 @@ func TestFileServerEscapesNames(t *testing.T) {
}
}
+func TestFileServerSortsNames(t *testing.T) {
+ defer afterTest(t)
+ const contents = "I am a fake file"
+ dirMod := time.Unix(123, 0).UTC()
+ fileMod := time.Unix(1000000000, 0).UTC()
+ fs := fakeFS{
+ "/": &fakeFileInfo{
+ dir: true,
+ modtime: dirMod,
+ ents: []*fakeFileInfo{
+ {
+ basename: "b",
+ modtime: fileMod,
+ contents: contents,
+ },
+ {
+ basename: "a",
+ modtime: fileMod,
+ contents: contents,
+ },
+ },
+ },
+ }
+
+ ts := httptest.NewServer(FileServer(&fs))
+ defer ts.Close()
+
+ res, err := Get(ts.URL)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ defer res.Body.Close()
+
+ b, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatalf("read Body: %v", err)
+ }
+ s := string(b)
+ if !strings.Contains(s, "<a href=\"a\">a</a>\n<a href=\"b\">b</a>") {
+ t.Errorf("output appears to be unsorted:\n%s", s)
+ }
+}
+
func mustRemoveAll(dir string) {
err := os.RemoveAll(dir)
if err != nil {
@@ -434,14 +508,27 @@ func TestServeFileFromCWD(t *testing.T) {
}
}
-func TestServeFileWithContentEncoding(t *testing.T) {
+// Tests that ServeFile doesn't add a Content-Length if a Content-Encoding is
+// specified.
+func TestServeFileWithContentEncoding_h1(t *testing.T) { testServeFileWithContentEncoding(t, h1Mode) }
+func TestServeFileWithContentEncoding_h2(t *testing.T) { testServeFileWithContentEncoding(t, h2Mode) }
+func testServeFileWithContentEncoding(t *testing.T, h2 bool) {
defer afterTest(t)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header().Set("Content-Encoding", "foo")
ServeFile(w, r, "testdata/file")
+
+ // Because the testdata is so small, it would fit in
+ // both the h1 and h2 Server's write buffers. For h1,
+ // sendfile is used, though, forcing a header flush at
+ // the io.Copy. http2 doesn't do a header flush so
+ // buffers all 11 bytes and then adds its own
+ // Content-Length. To prevent the Server's
+ // Content-Length and test ServeFile only, flush here.
+ w.(Flusher).Flush()
}))
- defer ts.Close()
- resp, err := Get(ts.URL)
+ defer cst.close()
+ resp, err := cst.c.Get(cst.ts.URL)
if err != nil {
t.Fatal(err)
}
@@ -807,6 +894,28 @@ func TestServeContent(t *testing.T) {
}
}
+// Issue 12991
+func TestServerFileStatError(t *testing.T) {
+ rec := httptest.NewRecorder()
+ r, _ := NewRequest("GET", "http://foo/", nil)
+ redirect := false
+ name := "file.txt"
+ fs := issue12991FS{}
+ ExportServeFile(rec, r, fs, name, redirect)
+ if body := rec.Body.String(); !strings.Contains(body, "403") || !strings.Contains(body, "Forbidden") {
+ t.Errorf("wanted 403 forbidden message; got: %s", body)
+ }
+}
+
+type issue12991FS struct{}
+
+func (issue12991FS) Open(string) (File, error) { return issue12991File{}, nil }
+
+type issue12991File struct{ File }
+
+func (issue12991File) Stat() (os.FileInfo, error) { return nil, os.ErrPermission }
+func (issue12991File) Close() error { return nil }
+
func TestServeContentErrorMessages(t *testing.T) {
defer afterTest(t)
fs := fakeFS{
@@ -852,13 +961,16 @@ func TestLinuxSendfile(t *testing.T) {
}
defer ln.Close()
- trace := "trace=sendfile"
- if runtime.GOARCH != "alpha" {
- trace = trace + ",sendfile64"
+ syscalls := "sendfile,sendfile64"
+ switch runtime.GOARCH {
+ case "mips64", "mips64le", "alpha":
+		// strace on these platforms doesn't support sendfile64 and will
+		// error out if we specify that with `-e trace='.
+ syscalls = "sendfile"
}
var buf bytes.Buffer
- child := exec.Command("strace", "-f", "-q", "-e", trace, os.Args[0], "-test.run=TestLinuxSendfileChild")
+ child := exec.Command("strace", "-f", "-q", "-e", "trace="+syscalls, os.Args[0], "-test.run=TestLinuxSendfileChild")
child.ExtraFiles = append(child.ExtraFiles, lnf)
child.Env = append([]string{"GO_WANT_HELPER_PROCESS=1"}, os.Environ()...)
child.Stdout = &buf
@@ -878,7 +990,7 @@ func TestLinuxSendfile(t *testing.T) {
res.Body.Close()
// Force child to exit cleanly.
- Get(fmt.Sprintf("http://%s/quit", ln.Addr()))
+ Post(fmt.Sprintf("http://%s/quit", ln.Addr()), "", nil)
child.Wait()
rx := regexp.MustCompile(`sendfile(64)?\(\d+,\s*\d+,\s*NULL,\s*\d+\)\s*=\s*\d+\s*\n`)
diff --git a/libgo/go/net/http/h2_bundle.go b/libgo/go/net/http/h2_bundle.go
new file mode 100644
index 00000000000..e7236299e22
--- /dev/null
+++ b/libgo/go/net/http/h2_bundle.go
@@ -0,0 +1,6530 @@
+// Code generated by golang.org/x/tools/cmd/bundle command:
+// $ bundle golang.org/x/net/http2 net/http http2
+
+// Package http2 implements the HTTP/2 protocol.
+//
+// This package is low-level and intended to be used directly by very
+// few people. Most users will use it indirectly through the automatic
+// use by the net/http package (from Go 1.6 and later).
+// For use in earlier Go versions see ConfigureServer. (Transport support
+// requires Go 1.6 or later)
+//
+// See https://http2.github.io/ for more information on HTTP/2.
+//
+// See https://http2.golang.org/ for a test server running this code.
+//
+
+package http
+
+import (
+ "bufio"
+ "bytes"
+ "compress/gzip"
+ "crypto/tls"
+ "encoding/binary"
+ "errors"
+ "fmt"
+ "internal/golang.org/x/net/http2/hpack"
+ "io"
+ "io/ioutil"
+ "log"
+ "net"
+ "net/textproto"
+ "net/url"
+ "os"
+ "reflect"
+ "runtime"
+ "sort"
+ "strconv"
+ "strings"
+ "sync"
+ "time"
+)
+
+// ClientConnPool manages a pool of HTTP/2 client connections.
+type http2ClientConnPool interface {
+ GetClientConn(req *Request, addr string) (*http2ClientConn, error)
+ MarkDead(*http2ClientConn)
+}
+
+// TODO: use singleflight for dialing and addConnCalls?
+type http2clientConnPool struct {
+ t *http2Transport
+
+ mu sync.Mutex // TODO: maybe switch to RWMutex
+ // TODO: add support for sharing conns based on cert names
+ // (e.g. share conn for googleapis.com and appspot.com)
+ conns map[string][]*http2ClientConn // key is host:port
+ dialing map[string]*http2dialCall // currently in-flight dials
+ keys map[*http2ClientConn][]string
+	addConnCalls map[string]*http2addConnCall // in-flight addConnIfNeeded calls
+}
+
+func (p *http2clientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) {
+ return p.getClientConn(req, addr, http2dialOnMiss)
+}
+
+const (
+ http2dialOnMiss = true
+ http2noDialOnMiss = false
+)
+
+func (p *http2clientConnPool) getClientConn(_ *Request, addr string, dialOnMiss bool) (*http2ClientConn, error) {
+ p.mu.Lock()
+ for _, cc := range p.conns[addr] {
+ if cc.CanTakeNewRequest() {
+ p.mu.Unlock()
+ return cc, nil
+ }
+ }
+ if !dialOnMiss {
+ p.mu.Unlock()
+ return nil, http2ErrNoCachedConn
+ }
+ call := p.getStartDialLocked(addr)
+ p.mu.Unlock()
+ <-call.done
+ return call.res, call.err
+}
+
+// dialCall is an in-flight Transport dial call to a host.
+type http2dialCall struct {
+ p *http2clientConnPool
+ done chan struct{} // closed when done
+ res *http2ClientConn // valid after done is closed
+ err error // valid after done is closed
+}
+
+// requires p.mu is held.
+func (p *http2clientConnPool) getStartDialLocked(addr string) *http2dialCall {
+ if call, ok := p.dialing[addr]; ok {
+
+ return call
+ }
+ call := &http2dialCall{p: p, done: make(chan struct{})}
+ if p.dialing == nil {
+ p.dialing = make(map[string]*http2dialCall)
+ }
+ p.dialing[addr] = call
+ go call.dial(addr)
+ return call
+}
+
+// run in its own goroutine.
+func (c *http2dialCall) dial(addr string) {
+ c.res, c.err = c.p.t.dialClientConn(addr)
+ close(c.done)
+
+ c.p.mu.Lock()
+ delete(c.p.dialing, addr)
+ if c.err == nil {
+ c.p.addConnLocked(addr, c.res)
+ }
+ c.p.mu.Unlock()
+}
+
+// addConnIfNeeded makes a NewClientConn out of c if a connection for key doesn't
+// already exist. It coalesces concurrent calls with the same key.
+// This is used by the http1 Transport code when it creates a new connection. Because
+// the http1 Transport doesn't de-dup TCP dials to outbound hosts (because it doesn't know
+// the protocol), it can get into a situation where it has multiple TLS connections.
+// This code decides which ones live or die.
+// The return value reports whether c was used.
+// c is never closed.
+func (p *http2clientConnPool) addConnIfNeeded(key string, t *http2Transport, c *tls.Conn) (used bool, err error) {
+ p.mu.Lock()
+ for _, cc := range p.conns[key] {
+ if cc.CanTakeNewRequest() {
+ p.mu.Unlock()
+ return false, nil
+ }
+ }
+ call, dup := p.addConnCalls[key]
+ if !dup {
+ if p.addConnCalls == nil {
+ p.addConnCalls = make(map[string]*http2addConnCall)
+ }
+ call = &http2addConnCall{
+ p: p,
+ done: make(chan struct{}),
+ }
+ p.addConnCalls[key] = call
+ go call.run(t, key, c)
+ }
+ p.mu.Unlock()
+
+ <-call.done
+ if call.err != nil {
+ return false, call.err
+ }
+ return !dup, nil
+}
+
+type http2addConnCall struct {
+ p *http2clientConnPool
+ done chan struct{} // closed when done
+ err error
+}
+
+func (c *http2addConnCall) run(t *http2Transport, key string, tc *tls.Conn) {
+ cc, err := t.NewClientConn(tc)
+
+ p := c.p
+ p.mu.Lock()
+ if err != nil {
+ c.err = err
+ } else {
+ p.addConnLocked(key, cc)
+ }
+ delete(p.addConnCalls, key)
+ p.mu.Unlock()
+ close(c.done)
+}
+
+func (p *http2clientConnPool) addConn(key string, cc *http2ClientConn) {
+ p.mu.Lock()
+ p.addConnLocked(key, cc)
+ p.mu.Unlock()
+}
+
+// p.mu must be held
+func (p *http2clientConnPool) addConnLocked(key string, cc *http2ClientConn) {
+ for _, v := range p.conns[key] {
+ if v == cc {
+ return
+ }
+ }
+ if p.conns == nil {
+ p.conns = make(map[string][]*http2ClientConn)
+ }
+ if p.keys == nil {
+ p.keys = make(map[*http2ClientConn][]string)
+ }
+ p.conns[key] = append(p.conns[key], cc)
+ p.keys[cc] = append(p.keys[cc], key)
+}
+
+func (p *http2clientConnPool) MarkDead(cc *http2ClientConn) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ for _, key := range p.keys[cc] {
+ vv, ok := p.conns[key]
+ if !ok {
+ continue
+ }
+ newList := http2filterOutClientConn(vv, cc)
+ if len(newList) > 0 {
+ p.conns[key] = newList
+ } else {
+ delete(p.conns, key)
+ }
+ }
+ delete(p.keys, cc)
+}
+
+func (p *http2clientConnPool) closeIdleConnections() {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+
+ for _, vv := range p.conns {
+ for _, cc := range vv {
+ cc.closeIfIdle()
+ }
+ }
+}
+
+func http2filterOutClientConn(in []*http2ClientConn, exclude *http2ClientConn) []*http2ClientConn {
+ out := in[:0]
+ for _, v := range in {
+ if v != exclude {
+ out = append(out, v)
+ }
+ }
+
+ if len(in) != len(out) {
+ in[len(in)-1] = nil
+ }
+ return out
+}
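
http2filterOutClientConn filters in place, reusing the slice's backing array and clearing the freed tail slot so the removed connection can be garbage collected. The same idiom in a standalone sketch (illustrative element type):

    package main

    import "fmt"

    // filterOut removes exclude from in, reusing in's backing array and
    // clearing the freed tail slot so the removed pointer can be collected.
    func filterOut(in []*int, exclude *int) []*int {
        out := in[:0]
        for _, v := range in {
            if v != exclude {
                out = append(out, v)
            }
        }
        if len(in) != len(out) {
            in[len(in)-1] = nil // drop the dangling reference
        }
        return out
    }

    func main() {
        a, b, c := 1, 2, 3
        s := []*int{&a, &b, &c}
        s = filterOut(s, &b)
        fmt.Println(len(s), *s[0], *s[1]) // 2 1 3
    }
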
+
+func http2configureTransport(t1 *Transport) (*http2Transport, error) {
+ connPool := new(http2clientConnPool)
+ t2 := &http2Transport{
+ ConnPool: http2noDialClientConnPool{connPool},
+ t1: t1,
+ }
+ connPool.t = t2
+ if err := http2registerHTTPSProtocol(t1, http2noDialH2RoundTripper{t2}); err != nil {
+ return nil, err
+ }
+ if t1.TLSClientConfig == nil {
+ t1.TLSClientConfig = new(tls.Config)
+ }
+ if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "h2") {
+ t1.TLSClientConfig.NextProtos = append([]string{"h2"}, t1.TLSClientConfig.NextProtos...)
+ }
+ if !http2strSliceContains(t1.TLSClientConfig.NextProtos, "http/1.1") {
+ t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1")
+ }
+ upgradeFn := func(authority string, c *tls.Conn) RoundTripper {
+ addr := http2authorityAddr(authority)
+ if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil {
+ go c.Close()
+ return http2erringRoundTripper{err}
+ } else if !used {
+
+ go c.Close()
+ }
+ return t2
+ }
+ if m := t1.TLSNextProto; len(m) == 0 {
+ t1.TLSNextProto = map[string]func(string, *tls.Conn) RoundTripper{
+ "h2": upgradeFn,
+ }
+ } else {
+ m["h2"] = upgradeFn
+ }
+ return t2, nil
+}
+
+// registerHTTPSProtocol calls Transport.RegisterProtocol but
+// converting panics into errors.
+func http2registerHTTPSProtocol(t *Transport, rt RoundTripper) (err error) {
+ defer func() {
+ if e := recover(); e != nil {
+ err = fmt.Errorf("%v", e)
+ }
+ }()
+ t.RegisterProtocol("https", rt)
+ return nil
+}
+
+// noDialClientConnPool is an implementation of http2.ClientConnPool
+// which never dials. We let the HTTP/1.1 client dial and use its TLS
+// connection instead.
+type http2noDialClientConnPool struct{ *http2clientConnPool }
+
+func (p http2noDialClientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) {
+ return p.getClientConn(req, addr, http2noDialOnMiss)
+}
+
+// noDialH2RoundTripper is a RoundTripper which only tries to complete the request
+// if there's already a cached connection to the host.
+type http2noDialH2RoundTripper struct{ t *http2Transport }
+
+func (rt http2noDialH2RoundTripper) RoundTrip(req *Request) (*Response, error) {
+ res, err := rt.t.RoundTrip(req)
+ if err == http2ErrNoCachedConn {
+ return nil, ErrSkipAltProtocol
+ }
+ return res, err
+}
+
+// An ErrCode is an unsigned 32-bit error code as defined in the HTTP/2 spec.
+type http2ErrCode uint32
+
+const (
+ http2ErrCodeNo http2ErrCode = 0x0
+ http2ErrCodeProtocol http2ErrCode = 0x1
+ http2ErrCodeInternal http2ErrCode = 0x2
+ http2ErrCodeFlowControl http2ErrCode = 0x3
+ http2ErrCodeSettingsTimeout http2ErrCode = 0x4
+ http2ErrCodeStreamClosed http2ErrCode = 0x5
+ http2ErrCodeFrameSize http2ErrCode = 0x6
+ http2ErrCodeRefusedStream http2ErrCode = 0x7
+ http2ErrCodeCancel http2ErrCode = 0x8
+ http2ErrCodeCompression http2ErrCode = 0x9
+ http2ErrCodeConnect http2ErrCode = 0xa
+ http2ErrCodeEnhanceYourCalm http2ErrCode = 0xb
+ http2ErrCodeInadequateSecurity http2ErrCode = 0xc
+ http2ErrCodeHTTP11Required http2ErrCode = 0xd
+)
+
+var http2errCodeName = map[http2ErrCode]string{
+ http2ErrCodeNo: "NO_ERROR",
+ http2ErrCodeProtocol: "PROTOCOL_ERROR",
+ http2ErrCodeInternal: "INTERNAL_ERROR",
+ http2ErrCodeFlowControl: "FLOW_CONTROL_ERROR",
+ http2ErrCodeSettingsTimeout: "SETTINGS_TIMEOUT",
+ http2ErrCodeStreamClosed: "STREAM_CLOSED",
+ http2ErrCodeFrameSize: "FRAME_SIZE_ERROR",
+ http2ErrCodeRefusedStream: "REFUSED_STREAM",
+ http2ErrCodeCancel: "CANCEL",
+ http2ErrCodeCompression: "COMPRESSION_ERROR",
+ http2ErrCodeConnect: "CONNECT_ERROR",
+ http2ErrCodeEnhanceYourCalm: "ENHANCE_YOUR_CALM",
+ http2ErrCodeInadequateSecurity: "INADEQUATE_SECURITY",
+ http2ErrCodeHTTP11Required: "HTTP_1_1_REQUIRED",
+}
+
+func (e http2ErrCode) String() string {
+ if s, ok := http2errCodeName[e]; ok {
+ return s
+ }
+ return fmt.Sprintf("unknown error code 0x%x", uint32(e))
+}
+
+// ConnectionError is an error that results in the termination of the
+// entire connection.
+type http2ConnectionError http2ErrCode
+
+func (e http2ConnectionError) Error() string {
+ return fmt.Sprintf("connection error: %s", http2ErrCode(e))
+}
+
+// StreamError is an error that only affects one stream within an
+// HTTP/2 connection.
+type http2StreamError struct {
+ StreamID uint32
+ Code http2ErrCode
+}
+
+func (e http2StreamError) Error() string {
+ return fmt.Sprintf("stream error: stream ID %d; %v", e.StreamID, e.Code)
+}
+
+// 6.9.1 The Flow Control Window
+// "If a sender receives a WINDOW_UPDATE that causes a flow control
+// window to exceed this maximum it MUST terminate either the stream
+// or the connection, as appropriate. For streams, [...]; for the
+// connection, a GOAWAY frame with a FLOW_CONTROL_ERROR code."
+type http2goAwayFlowError struct{}
+
+func (http2goAwayFlowError) Error() string { return "connection exceeded flow control window size" }
+
+// Errors of this type are only returned by the frame parser functions
+// and converted into ConnectionError(ErrCodeProtocol).
+type http2connError struct {
+ Code http2ErrCode
+ Reason string
+}
+
+func (e http2connError) Error() string {
+ return fmt.Sprintf("http2: connection error: %v: %v", e.Code, e.Reason)
+}
+
+// fixedBuffer is an io.ReadWriter backed by a fixed size buffer.
+// It never allocates, but moves old data as new data is written.
+type http2fixedBuffer struct {
+ buf []byte
+ r, w int
+}
+
+var (
+ http2errReadEmpty = errors.New("read from empty fixedBuffer")
+ http2errWriteFull = errors.New("write on full fixedBuffer")
+)
+
+// Read copies bytes from the buffer into p.
+// It is an error to read when no data is available.
+func (b *http2fixedBuffer) Read(p []byte) (n int, err error) {
+ if b.r == b.w {
+ return 0, http2errReadEmpty
+ }
+ n = copy(p, b.buf[b.r:b.w])
+ b.r += n
+ if b.r == b.w {
+ b.r = 0
+ b.w = 0
+ }
+ return n, nil
+}
+
+// Len returns the number of bytes of the unread portion of the buffer.
+func (b *http2fixedBuffer) Len() int {
+ return b.w - b.r
+}
+
+// Write copies bytes from p into the buffer.
+// It is an error to write more data than the buffer can hold.
+func (b *http2fixedBuffer) Write(p []byte) (n int, err error) {
+
+ if b.r > 0 && len(p) > len(b.buf)-b.w {
+ copy(b.buf, b.buf[b.r:b.w])
+ b.w -= b.r
+ b.r = 0
+ }
+
+ n = copy(b.buf[b.w:], p)
+ b.w += n
+ if n < len(p) {
+ err = http2errWriteFull
+ }
+ return n, err
+}
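
The fixedBuffer above never allocates after construction: when a write would not fit at the tail, the unread bytes are first slid to the front. A standalone sketch of that compaction (not the bundled type, just the same logic):

    package main

    import "fmt"

    type fixedBuffer struct {
        buf  []byte
        r, w int
    }

    // Write copies p into the buffer, compacting unread data first if needed.
    func (b *fixedBuffer) Write(p []byte) (int, error) {
        if b.r > 0 && len(p) > len(b.buf)-b.w {
            copy(b.buf, b.buf[b.r:b.w]) // slide unread bytes to the front
            b.w -= b.r
            b.r = 0
        }
        n := copy(b.buf[b.w:], p)
        b.w += n
        if n < len(p) {
            return n, fmt.Errorf("write on full buffer")
        }
        return n, nil
    }

    // Read copies unread bytes into p and resets the indexes when drained.
    func (b *fixedBuffer) Read(p []byte) (int, error) {
        if b.r == b.w {
            return 0, fmt.Errorf("read from empty buffer")
        }
        n := copy(p, b.buf[b.r:b.w])
        b.r += n
        if b.r == b.w {
            b.r, b.w = 0, 0
        }
        return n, nil
    }

    func main() {
        b := &fixedBuffer{buf: make([]byte, 8)}
        b.Write([]byte("abcdef"))
        tmp := make([]byte, 4)
        b.Read(tmp)             // consumes "abcd", leaving "ef" unread
        b.Write([]byte("ghij")) // triggers compaction: "ef" slides to the front
        n, _ := b.Read(tmp)
        fmt.Printf("%s\n", tmp[:n]) // "efgh"
    }
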
+
+// flow is the flow control window's size.
+type http2flow struct {
+ // n is the number of DATA bytes we're allowed to send.
+ // A flow is kept both on a conn and a per-stream.
+ n int32
+
+ // conn points to the shared connection-level flow that is
+ // shared by all streams on that conn. It is nil for the flow
+ // that's on the conn directly.
+ conn *http2flow
+}
+
+func (f *http2flow) setConnFlow(cf *http2flow) { f.conn = cf }
+
+func (f *http2flow) available() int32 {
+ n := f.n
+ if f.conn != nil && f.conn.n < n {
+ n = f.conn.n
+ }
+ return n
+}
+
+func (f *http2flow) take(n int32) {
+ if n > f.available() {
+ panic("internal error: took too much")
+ }
+ f.n -= n
+ if f.conn != nil {
+ f.conn.n -= n
+ }
+}
+
+// add adds n bytes (positive or negative) to the flow control window.
+// It returns false if the sum would exceed 2^31-1.
+func (f *http2flow) add(n int32) bool {
+ remain := (1<<31 - 1) - f.n
+ if n > remain {
+ return false
+ }
+ f.n += n
+ return true
+}
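
The flow type couples a per-stream window to the shared connection-level window: a sender may only use the minimum of the two, and consuming bytes debits both. A small standalone sketch of that accounting (illustrative values):

    package main

    import "fmt"

    type flow struct {
        n    int32
        conn *flow // shared connection-level window; nil on the conn flow itself
    }

    func (f *flow) available() int32 {
        n := f.n
        if f.conn != nil && f.conn.n < n {
            n = f.conn.n
        }
        return n
    }

    func (f *flow) take(n int32) {
        f.n -= n
        if f.conn != nil {
            f.conn.n -= n
        }
    }

    func main() {
        conn := &flow{n: 100}
        stream := &flow{n: 65535, conn: conn}
        fmt.Println(stream.available()) // 100: the connection window is the limiter
        stream.take(60)
        fmt.Println(stream.available()) // 40: both windows were debited
    }
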
+
+const http2frameHeaderLen = 9
+
+var http2padZeros = make([]byte, 255) // zeros for padding
+
+// A FrameType is a registered frame type as defined in
+// http://http2.github.io/http2-spec/#rfc.section.11.2
+type http2FrameType uint8
+
+const (
+ http2FrameData http2FrameType = 0x0
+ http2FrameHeaders http2FrameType = 0x1
+ http2FramePriority http2FrameType = 0x2
+ http2FrameRSTStream http2FrameType = 0x3
+ http2FrameSettings http2FrameType = 0x4
+ http2FramePushPromise http2FrameType = 0x5
+ http2FramePing http2FrameType = 0x6
+ http2FrameGoAway http2FrameType = 0x7
+ http2FrameWindowUpdate http2FrameType = 0x8
+ http2FrameContinuation http2FrameType = 0x9
+)
+
+var http2frameName = map[http2FrameType]string{
+ http2FrameData: "DATA",
+ http2FrameHeaders: "HEADERS",
+ http2FramePriority: "PRIORITY",
+ http2FrameRSTStream: "RST_STREAM",
+ http2FrameSettings: "SETTINGS",
+ http2FramePushPromise: "PUSH_PROMISE",
+ http2FramePing: "PING",
+ http2FrameGoAway: "GOAWAY",
+ http2FrameWindowUpdate: "WINDOW_UPDATE",
+ http2FrameContinuation: "CONTINUATION",
+}
+
+func (t http2FrameType) String() string {
+ if s, ok := http2frameName[t]; ok {
+ return s
+ }
+ return fmt.Sprintf("UNKNOWN_FRAME_TYPE_%d", uint8(t))
+}
+
+// Flags is a bitmask of HTTP/2 flags.
+// The meaning of flags varies depending on the frame type.
+type http2Flags uint8
+
+// Has reports whether f contains all (0 or more) flags in v.
+func (f http2Flags) Has(v http2Flags) bool {
+ return (f & v) == v
+}
+
+// Frame-specific FrameHeader flag bits.
+const (
+ // Data Frame
+ http2FlagDataEndStream http2Flags = 0x1
+ http2FlagDataPadded http2Flags = 0x8
+
+ // Headers Frame
+ http2FlagHeadersEndStream http2Flags = 0x1
+ http2FlagHeadersEndHeaders http2Flags = 0x4
+ http2FlagHeadersPadded http2Flags = 0x8
+ http2FlagHeadersPriority http2Flags = 0x20
+
+ // Settings Frame
+ http2FlagSettingsAck http2Flags = 0x1
+
+ // Ping Frame
+ http2FlagPingAck http2Flags = 0x1
+
+ // Continuation Frame
+ http2FlagContinuationEndHeaders http2Flags = 0x4
+
+ http2FlagPushPromiseEndHeaders http2Flags = 0x4
+ http2FlagPushPromisePadded http2Flags = 0x8
+)
+
+var http2flagName = map[http2FrameType]map[http2Flags]string{
+ http2FrameData: {
+ http2FlagDataEndStream: "END_STREAM",
+ http2FlagDataPadded: "PADDED",
+ },
+ http2FrameHeaders: {
+ http2FlagHeadersEndStream: "END_STREAM",
+ http2FlagHeadersEndHeaders: "END_HEADERS",
+ http2FlagHeadersPadded: "PADDED",
+ http2FlagHeadersPriority: "PRIORITY",
+ },
+ http2FrameSettings: {
+ http2FlagSettingsAck: "ACK",
+ },
+ http2FramePing: {
+ http2FlagPingAck: "ACK",
+ },
+ http2FrameContinuation: {
+ http2FlagContinuationEndHeaders: "END_HEADERS",
+ },
+ http2FramePushPromise: {
+ http2FlagPushPromiseEndHeaders: "END_HEADERS",
+ http2FlagPushPromisePadded: "PADDED",
+ },
+}
+
+// a frameParser parses a frame given its FrameHeader and payload
+// bytes. The length of payload will always equal fh.Length (which
+// might be 0).
+type http2frameParser func(fh http2FrameHeader, payload []byte) (http2Frame, error)
+
+var http2frameParsers = map[http2FrameType]http2frameParser{
+ http2FrameData: http2parseDataFrame,
+ http2FrameHeaders: http2parseHeadersFrame,
+ http2FramePriority: http2parsePriorityFrame,
+ http2FrameRSTStream: http2parseRSTStreamFrame,
+ http2FrameSettings: http2parseSettingsFrame,
+ http2FramePushPromise: http2parsePushPromise,
+ http2FramePing: http2parsePingFrame,
+ http2FrameGoAway: http2parseGoAwayFrame,
+ http2FrameWindowUpdate: http2parseWindowUpdateFrame,
+ http2FrameContinuation: http2parseContinuationFrame,
+}
+
+func http2typeFrameParser(t http2FrameType) http2frameParser {
+ if f := http2frameParsers[t]; f != nil {
+ return f
+ }
+ return http2parseUnknownFrame
+}
+
+// A FrameHeader is the 9 byte header of all HTTP/2 frames.
+//
+// See http://http2.github.io/http2-spec/#FrameHeader
+type http2FrameHeader struct {
+ valid bool // caller can access []byte fields in the Frame
+
+ // Type is the 1 byte frame type. There are ten standard frame
+ // types, but extension frame types may be written by WriteRawFrame
+ // and will be returned by ReadFrame (as UnknownFrame).
+ Type http2FrameType
+
+ // Flags are the 1 byte of 8 potential bit flags per frame.
+ // They are specific to the frame type.
+ Flags http2Flags
+
+ // Length is the length of the frame, not including the 9 byte header.
+ // The maximum size is one byte less than 16MB (uint24), but only
+ // frames up to 16KB are allowed without peer agreement.
+ Length uint32
+
+ // StreamID is which stream this frame is for. Certain frames
+ // are not stream-specific, in which case this field is 0.
+ StreamID uint32
+}
+
+// Header returns h. It exists so FrameHeaders can be embedded in other
+// specific frame types and implement the Frame interface.
+func (h http2FrameHeader) Header() http2FrameHeader { return h }
+
+func (h http2FrameHeader) String() string {
+ var buf bytes.Buffer
+ buf.WriteString("[FrameHeader ")
+ h.writeDebug(&buf)
+ buf.WriteByte(']')
+ return buf.String()
+}
+
+func (h http2FrameHeader) writeDebug(buf *bytes.Buffer) {
+ buf.WriteString(h.Type.String())
+ if h.Flags != 0 {
+ buf.WriteString(" flags=")
+ set := 0
+ for i := uint8(0); i < 8; i++ {
+ if h.Flags&(1<<i) == 0 {
+ continue
+ }
+ set++
+ if set > 1 {
+ buf.WriteByte('|')
+ }
+ name := http2flagName[h.Type][http2Flags(1<<i)]
+ if name != "" {
+ buf.WriteString(name)
+ } else {
+ fmt.Fprintf(buf, "0x%x", 1<<i)
+ }
+ }
+ }
+ if h.StreamID != 0 {
+ fmt.Fprintf(buf, " stream=%d", h.StreamID)
+ }
+ fmt.Fprintf(buf, " len=%d", h.Length)
+}
+
+func (h *http2FrameHeader) checkValid() {
+ if !h.valid {
+ panic("Frame accessor called on non-owned Frame")
+ }
+}
+
+func (h *http2FrameHeader) invalidate() { h.valid = false }
+
+// frame header bytes.
+// Used only by ReadFrameHeader.
+var http2fhBytes = sync.Pool{
+ New: func() interface{} {
+ buf := make([]byte, http2frameHeaderLen)
+ return &buf
+ },
+}
+
+// ReadFrameHeader reads 9 bytes from r and returns a FrameHeader.
+// Most users should use Framer.ReadFrame instead.
+func http2ReadFrameHeader(r io.Reader) (http2FrameHeader, error) {
+ bufp := http2fhBytes.Get().(*[]byte)
+ defer http2fhBytes.Put(bufp)
+ return http2readFrameHeader(*bufp, r)
+}
+
+func http2readFrameHeader(buf []byte, r io.Reader) (http2FrameHeader, error) {
+ _, err := io.ReadFull(r, buf[:http2frameHeaderLen])
+ if err != nil {
+ return http2FrameHeader{}, err
+ }
+ return http2FrameHeader{
+ Length: (uint32(buf[0])<<16 | uint32(buf[1])<<8 | uint32(buf[2])),
+ Type: http2FrameType(buf[3]),
+ Flags: http2Flags(buf[4]),
+ StreamID: binary.BigEndian.Uint32(buf[5:]) & (1<<31 - 1),
+ valid: true,
+ }, nil
+}
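
The 9-byte header decoded above packs a 24-bit length, one type byte, one flags byte, and a 31-bit stream ID with the high bit reserved. A standalone round-trip of that layout (illustrative values, independent of the Framer):

    package main

    import (
        "encoding/binary"
        "fmt"
    )

    func main() {
        const (
            length   = 16  // 24-bit payload length
            ftype    = 0x1 // HEADERS
            flags    = 0x4 // END_HEADERS
            streamID = 3   // 31-bit stream identifier
        )
        var hdr [9]byte
        hdr[0], hdr[1], hdr[2] = byte(length>>16), byte(length>>8), byte(length)
        hdr[3] = ftype
        hdr[4] = flags
        binary.BigEndian.PutUint32(hdr[5:], streamID&0x7fffffff)

        // Decode it back the same way readFrameHeader does.
        gotLen := uint32(hdr[0])<<16 | uint32(hdr[1])<<8 | uint32(hdr[2])
        gotStream := binary.BigEndian.Uint32(hdr[5:]) & (1<<31 - 1)
        fmt.Println(gotLen, hdr[3], hdr[4], gotStream) // 16 1 4 3
    }
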
+
+// A Frame is the base interface implemented by all frame types.
+// Callers will generally type-assert the specific frame type:
+// *HeadersFrame, *SettingsFrame, *WindowUpdateFrame, etc.
+//
+// Frames are only valid until the next call to Framer.ReadFrame.
+type http2Frame interface {
+ Header() http2FrameHeader
+
+	// invalidate is called by Framer.ReadFrame to mark this
+	// frame's buffers as invalid, since the subsequent
+ // frame will reuse them.
+ invalidate()
+}
+
+// A Framer reads and writes Frames.
+type http2Framer struct {
+ r io.Reader
+ lastFrame http2Frame
+ errReason string
+
+ // lastHeaderStream is non-zero if the last frame was an
+ // unfinished HEADERS/CONTINUATION.
+ lastHeaderStream uint32
+
+ maxReadSize uint32
+ headerBuf [http2frameHeaderLen]byte
+
+ // TODO: let getReadBuf be configurable, and use a less memory-pinning
+ // allocator in server.go to minimize memory pinned for many idle conns.
+ // Will probably also need to make frame invalidation have a hook too.
+ getReadBuf func(size uint32) []byte
+ readBuf []byte // cache for default getReadBuf
+
+ maxWriteSize uint32 // zero means unlimited; TODO: implement
+
+ w io.Writer
+ wbuf []byte
+
+ // AllowIllegalWrites permits the Framer's Write methods to
+ // write frames that do not conform to the HTTP/2 spec. This
+ // permits using the Framer to test other HTTP/2
+ // implementations' conformance to the spec.
+ // If false, the Write methods will prefer to return an error
+ // rather than comply.
+ AllowIllegalWrites bool
+
+ // AllowIllegalReads permits the Framer's ReadFrame method
+ // to return non-compliant frames or frame orders.
+ // This is for testing and permits using the Framer to test
+ // other HTTP/2 implementations' conformance to the spec.
+ AllowIllegalReads bool
+
+ logReads bool
+
+	debugFramer *http2Framer // only used for logging written frames
+ debugFramerBuf *bytes.Buffer
+}
+
+func (f *http2Framer) startWrite(ftype http2FrameType, flags http2Flags, streamID uint32) {
+
+ f.wbuf = append(f.wbuf[:0],
+ 0,
+ 0,
+ 0,
+ byte(ftype),
+ byte(flags),
+ byte(streamID>>24),
+ byte(streamID>>16),
+ byte(streamID>>8),
+ byte(streamID))
+}
+
+func (f *http2Framer) endWrite() error {
+
+ length := len(f.wbuf) - http2frameHeaderLen
+ if length >= (1 << 24) {
+ return http2ErrFrameTooLarge
+ }
+ _ = append(f.wbuf[:0],
+ byte(length>>16),
+ byte(length>>8),
+ byte(length))
+ if http2logFrameWrites {
+ f.logWrite()
+ }
+
+ n, err := f.w.Write(f.wbuf)
+ if err == nil && n != len(f.wbuf) {
+ err = io.ErrShortWrite
+ }
+ return err
+}
+
+func (f *http2Framer) logWrite() {
+ if f.debugFramer == nil {
+ f.debugFramerBuf = new(bytes.Buffer)
+ f.debugFramer = http2NewFramer(nil, f.debugFramerBuf)
+ f.debugFramer.logReads = false
+
+ f.debugFramer.AllowIllegalReads = true
+ }
+ f.debugFramerBuf.Write(f.wbuf)
+ fr, err := f.debugFramer.ReadFrame()
+ if err != nil {
+ log.Printf("http2: Framer %p: failed to decode just-written frame", f)
+ return
+ }
+ log.Printf("http2: Framer %p: wrote %v", f, http2summarizeFrame(fr))
+}
+
+func (f *http2Framer) writeByte(v byte) { f.wbuf = append(f.wbuf, v) }
+
+func (f *http2Framer) writeBytes(v []byte) { f.wbuf = append(f.wbuf, v...) }
+
+func (f *http2Framer) writeUint16(v uint16) { f.wbuf = append(f.wbuf, byte(v>>8), byte(v)) }
+
+func (f *http2Framer) writeUint32(v uint32) {
+ f.wbuf = append(f.wbuf, byte(v>>24), byte(v>>16), byte(v>>8), byte(v))
+}
+
+const (
+ http2minMaxFrameSize = 1 << 14
+ http2maxFrameSize = 1<<24 - 1
+)
+
+// NewFramer returns a Framer that writes frames to w and reads them from r.
+func http2NewFramer(w io.Writer, r io.Reader) *http2Framer {
+ fr := &http2Framer{
+ w: w,
+ r: r,
+ logReads: http2logFrameReads,
+ }
+ fr.getReadBuf = func(size uint32) []byte {
+ if cap(fr.readBuf) >= int(size) {
+ return fr.readBuf[:size]
+ }
+ fr.readBuf = make([]byte, size)
+ return fr.readBuf
+ }
+ fr.SetMaxReadFrameSize(http2maxFrameSize)
+ return fr
+}
+
+// SetMaxReadFrameSize sets the maximum size of a frame
+// that will be read by a subsequent call to ReadFrame.
+// It is the caller's responsibility to advertise this
+// limit with a SETTINGS frame.
+func (fr *http2Framer) SetMaxReadFrameSize(v uint32) {
+ if v > http2maxFrameSize {
+ v = http2maxFrameSize
+ }
+ fr.maxReadSize = v
+}
+
+// ErrFrameTooLarge is returned from Framer.ReadFrame when the peer
+// sends a frame that is larger than declared with SetMaxReadFrameSize.
+var http2ErrFrameTooLarge = errors.New("http2: frame too large")
+
+// terminalReadFrameError reports whether err is an unrecoverable
+// error from ReadFrame and no other frames should be read.
+func http2terminalReadFrameError(err error) bool {
+ if _, ok := err.(http2StreamError); ok {
+ return false
+ }
+ return err != nil
+}
+
+// ReadFrame reads a single frame. The returned Frame is only valid
+// until the next call to ReadFrame.
+//
+// If the frame is larger than previously set with SetMaxReadFrameSize, the
+// returned error is ErrFrameTooLarge. Other errors may be of type
+// ConnectionError, StreamError, or anything else from the underlying
+// reader.
+func (fr *http2Framer) ReadFrame() (http2Frame, error) {
+ if fr.lastFrame != nil {
+ fr.lastFrame.invalidate()
+ }
+ fh, err := http2readFrameHeader(fr.headerBuf[:], fr.r)
+ if err != nil {
+ return nil, err
+ }
+ if fh.Length > fr.maxReadSize {
+ return nil, http2ErrFrameTooLarge
+ }
+ payload := fr.getReadBuf(fh.Length)
+ if _, err := io.ReadFull(fr.r, payload); err != nil {
+ return nil, err
+ }
+ f, err := http2typeFrameParser(fh.Type)(fh, payload)
+ if err != nil {
+ if ce, ok := err.(http2connError); ok {
+ return nil, fr.connError(ce.Code, ce.Reason)
+ }
+ return nil, err
+ }
+ if err := fr.checkFrameOrder(f); err != nil {
+ return nil, err
+ }
+ if fr.logReads {
+ log.Printf("http2: Framer %p: read %v", fr, http2summarizeFrame(f))
+ }
+ return f, nil
+}
+
+// connError returns ConnectionError(code) but first
+// stashes away a public reason so the caller can optionally relay it
+// to the peer before hanging up on them. This might help others debug
+// their implementations.
+func (fr *http2Framer) connError(code http2ErrCode, reason string) error {
+ fr.errReason = reason
+ return http2ConnectionError(code)
+}
+
+// checkFrameOrder reports an error if f is an invalid frame to return
+// next from ReadFrame. Mostly it checks whether HEADERS and
+// CONTINUATION frames are contiguous.
+func (fr *http2Framer) checkFrameOrder(f http2Frame) error {
+ last := fr.lastFrame
+ fr.lastFrame = f
+ if fr.AllowIllegalReads {
+ return nil
+ }
+
+ fh := f.Header()
+ if fr.lastHeaderStream != 0 {
+ if fh.Type != http2FrameContinuation {
+ return fr.connError(http2ErrCodeProtocol,
+ fmt.Sprintf("got %s for stream %d; expected CONTINUATION following %s for stream %d",
+ fh.Type, fh.StreamID,
+ last.Header().Type, fr.lastHeaderStream))
+ }
+ if fh.StreamID != fr.lastHeaderStream {
+ return fr.connError(http2ErrCodeProtocol,
+ fmt.Sprintf("got CONTINUATION for stream %d; expected stream %d",
+ fh.StreamID, fr.lastHeaderStream))
+ }
+ } else if fh.Type == http2FrameContinuation {
+ return fr.connError(http2ErrCodeProtocol, fmt.Sprintf("unexpected CONTINUATION for stream %d", fh.StreamID))
+ }
+
+ switch fh.Type {
+ case http2FrameHeaders, http2FrameContinuation:
+ if fh.Flags.Has(http2FlagHeadersEndHeaders) {
+ fr.lastHeaderStream = 0
+ } else {
+ fr.lastHeaderStream = fh.StreamID
+ }
+ }
+
+ return nil
+}
+
+// A DataFrame conveys arbitrary, variable-length sequences of octets
+// associated with a stream.
+// See http://http2.github.io/http2-spec/#rfc.section.6.1
+type http2DataFrame struct {
+ http2FrameHeader
+ data []byte
+}
+
+func (f *http2DataFrame) StreamEnded() bool {
+ return f.http2FrameHeader.Flags.Has(http2FlagDataEndStream)
+}
+
+// Data returns the frame's data octets, not including any padding
+// size byte or padding suffix bytes.
+// The caller must not retain the returned memory past the next
+// call to ReadFrame.
+func (f *http2DataFrame) Data() []byte {
+ f.checkValid()
+ return f.data
+}
+
+func http2parseDataFrame(fh http2FrameHeader, payload []byte) (http2Frame, error) {
+ if fh.StreamID == 0 {
+
+ return nil, http2connError{http2ErrCodeProtocol, "DATA frame with stream ID 0"}
+ }
+ f := &http2DataFrame{
+ http2FrameHeader: fh,
+ }
+ var padSize byte
+ if fh.Flags.Has(http2FlagDataPadded) {
+ var err error
+ payload, padSize, err = http2readByte(payload)
+ if err != nil {
+ return nil, err
+ }
+ }
+ if int(padSize) > len(payload) {
+
+ return nil, http2connError{http2ErrCodeProtocol, "pad size larger than data payload"}
+ }
+ f.data = payload[:len(payload)-int(padSize)]
+ return f, nil
+}
+
+var http2errStreamID = errors.New("invalid streamid")
+
+func http2validStreamID(streamID uint32) bool {
+ return streamID != 0 && streamID&(1<<31) == 0
+}
+
+// WriteData writes a DATA frame.
+//
+// It will perform exactly one Write to the underlying Writer.
+// It is the caller's responsibility to not call other Write methods concurrently.
+func (f *http2Framer) WriteData(streamID uint32, endStream bool, data []byte) error {
+
+ if !http2validStreamID(streamID) && !f.AllowIllegalWrites {
+ return http2errStreamID
+ }
+ var flags http2Flags
+ if endStream {
+ flags |= http2FlagDataEndStream
+ }
+ f.startWrite(http2FrameData, flags, streamID)
+ f.wbuf = append(f.wbuf, data...)
+ return f.endWrite()
+}
+
+// A SettingsFrame conveys configuration parameters that affect how
+// endpoints communicate, such as preferences and constraints on peer
+// behavior.
+//
+// See http://http2.github.io/http2-spec/#SETTINGS
+type http2SettingsFrame struct {
+ http2FrameHeader
+ p []byte
+}
+
+func http2parseSettingsFrame(fh http2FrameHeader, p []byte) (http2Frame, error) {
+ if fh.Flags.Has(http2FlagSettingsAck) && fh.Length > 0 {
+
+ return nil, http2ConnectionError(http2ErrCodeFrameSize)
+ }
+ if fh.StreamID != 0 {
+
+ return nil, http2ConnectionError(http2ErrCodeProtocol)
+ }
+ if len(p)%6 != 0 {
+
+ return nil, http2ConnectionError(http2ErrCodeFrameSize)
+ }
+ f := &http2SettingsFrame{http2FrameHeader: fh, p: p}
+ if v, ok := f.Value(http2SettingInitialWindowSize); ok && v > (1<<31)-1 {
+
+ return nil, http2ConnectionError(http2ErrCodeFlowControl)
+ }
+ return f, nil
+}
+
+func (f *http2SettingsFrame) IsAck() bool {
+ return f.http2FrameHeader.Flags.Has(http2FlagSettingsAck)
+}
+
+func (f *http2SettingsFrame) Value(s http2SettingID) (v uint32, ok bool) {
+ f.checkValid()
+ buf := f.p
+ for len(buf) > 0 {
+ settingID := http2SettingID(binary.BigEndian.Uint16(buf[:2]))
+ if settingID == s {
+ return binary.BigEndian.Uint32(buf[2:6]), true
+ }
+ buf = buf[6:]
+ }
+ return 0, false
+}
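
Each SETTINGS entry scanned by Value and ForeachSetting above is a fixed 6-byte record: a 16-bit setting identifier followed by a 32-bit value. A standalone sketch of the same scan (example identifiers taken from the HTTP/2 spec):

    package main

    import (
        "encoding/binary"
        "fmt"
    )

    func main() {
        // Two entries: SETTINGS_MAX_CONCURRENT_STREAMS (0x3) = 100 and
        // SETTINGS_INITIAL_WINDOW_SIZE (0x4) = 65535.
        payload := []byte{
            0x00, 0x03, 0x00, 0x00, 0x00, 0x64,
            0x00, 0x04, 0x00, 0x00, 0xff, 0xff,
        }
        for buf := payload; len(buf) >= 6; buf = buf[6:] {
            id := binary.BigEndian.Uint16(buf[:2])
            val := binary.BigEndian.Uint32(buf[2:6])
            fmt.Printf("setting 0x%x = %d\n", id, val)
        }
    }
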
+
+// ForeachSetting runs fn for each setting.
+// It stops and returns the first error.
+func (f *http2SettingsFrame) ForeachSetting(fn func(http2Setting) error) error {
+ f.checkValid()
+ buf := f.p
+ for len(buf) > 0 {
+ if err := fn(http2Setting{
+ http2SettingID(binary.BigEndian.Uint16(buf[:2])),
+ binary.BigEndian.Uint32(buf[2:6]),
+ }); err != nil {
+ return err
+ }
+ buf = buf[6:]
+ }
+ return nil
+}
+
+// WriteSettings writes a SETTINGS frame with zero or more settings
+// specified and the ACK bit not set.
+//
+// It will perform exactly one Write to the underlying Writer.
+// It is the caller's responsibility to not call other Write methods concurrently.
+func (f *http2Framer) WriteSettings(settings ...http2Setting) error {
+ f.startWrite(http2FrameSettings, 0, 0)
+ for _, s := range settings {
+ f.writeUint16(uint16(s.ID))
+ f.writeUint32(s.Val)
+ }
+ return f.endWrite()
+}
+
+// WriteSettingsAck writes an empty SETTINGS frame with the ACK bit set.
+//
+// It will perform exactly one Write to the underlying Writer.
+// It is the caller's responsibility to not call other Write methods concurrently.
+func (f *http2Framer) WriteSettingsAck() error {
+ f.startWrite(http2FrameSettings, http2FlagSettingsAck, 0)
+ return f.endWrite()
+}
+
+// A PingFrame is a mechanism for measuring a minimal round trip time
+// from the sender, as well as determining whether an idle connection
+// is still functional.
+// See http://http2.github.io/http2-spec/#rfc.section.6.7
+type http2PingFrame struct {
+ http2FrameHeader
+ Data [8]byte
+}
+
+func (f *http2PingFrame) IsAck() bool { return f.Flags.Has(http2FlagPingAck) }
+
+func http2parsePingFrame(fh http2FrameHeader, payload []byte) (http2Frame, error) {
+ if len(payload) != 8 {
+ return nil, http2ConnectionError(http2ErrCodeFrameSize)
+ }
+ if fh.StreamID != 0 {
+ return nil, http2ConnectionError(http2ErrCodeProtocol)
+ }
+ f := &http2PingFrame{http2FrameHeader: fh}
+ copy(f.Data[:], payload)
+ return f, nil
+}
+
+func (f *http2Framer) WritePing(ack bool, data [8]byte) error {
+ var flags http2Flags
+ if ack {
+ flags = http2FlagPingAck
+ }
+ f.startWrite(http2FramePing, flags, 0)
+ f.writeBytes(data[:])
+ return f.endWrite()
+}
+
+// A GoAwayFrame informs the remote peer to stop creating streams on this connection.
+// See http://http2.github.io/http2-spec/#rfc.section.6.8
+type http2GoAwayFrame struct {
+ http2FrameHeader
+ LastStreamID uint32
+ ErrCode http2ErrCode
+ debugData []byte
+}
+
+// DebugData returns any debug data in the GOAWAY frame. Its contents
+// are not defined.
+// The caller must not retain the returned memory past the next
+// call to ReadFrame.
+func (f *http2GoAwayFrame) DebugData() []byte {
+ f.checkValid()
+ return f.debugData
+}
+
+func http2parseGoAwayFrame(fh http2FrameHeader, p []byte) (http2Frame, error) {
+ if fh.StreamID != 0 {
+ return nil, http2ConnectionError(http2ErrCodeProtocol)
+ }
+ if len(p) < 8 {
+ return nil, http2ConnectionError(http2ErrCodeFrameSize)
+ }
+ return &http2GoAwayFrame{
+ http2FrameHeader: fh,
+ LastStreamID: binary.BigEndian.Uint32(p[:4]) & (1<<31 - 1),
+ ErrCode: http2ErrCode(binary.BigEndian.Uint32(p[4:8])),
+ debugData: p[8:],
+ }, nil
+}
+
+func (f *http2Framer) WriteGoAway(maxStreamID uint32, code http2ErrCode, debugData []byte) error {
+ f.startWrite(http2FrameGoAway, 0, 0)
+ f.writeUint32(maxStreamID & (1<<31 - 1))
+ f.writeUint32(uint32(code))
+ f.writeBytes(debugData)
+ return f.endWrite()
+}
+
+// An UnknownFrame is the frame type returned when the frame type is unknown
+// or no specific frame type parser exists.
+type http2UnknownFrame struct {
+ http2FrameHeader
+ p []byte
+}
+
+// Payload returns the frame's payload (after the header). It is not
+// valid to call this method after a subsequent call to
+// Framer.ReadFrame, nor is it valid to retain the returned slice.
+// The memory is owned by the Framer and is invalidated when the next
+// frame is read.
+func (f *http2UnknownFrame) Payload() []byte {
+ f.checkValid()
+ return f.p
+}
+
+func http2parseUnknownFrame(fh http2FrameHeader, p []byte) (http2Frame, error) {
+ return &http2UnknownFrame{fh, p}, nil
+}
+
+// A WindowUpdateFrame is used to implement flow control.
+// See http://http2.github.io/http2-spec/#rfc.section.6.9
+type http2WindowUpdateFrame struct {
+ http2FrameHeader
+ Increment uint32 // never read with high bit set
+}
+
+func http2parseWindowUpdateFrame(fh http2FrameHeader, p []byte) (http2Frame, error) {
+ if len(p) != 4 {
+ return nil, http2ConnectionError(http2ErrCodeFrameSize)
+ }
+ inc := binary.BigEndian.Uint32(p[:4]) & 0x7fffffff
+ if inc == 0 {
+
+ if fh.StreamID == 0 {
+ return nil, http2ConnectionError(http2ErrCodeProtocol)
+ }
+ return nil, http2StreamError{fh.StreamID, http2ErrCodeProtocol}
+ }
+ return &http2WindowUpdateFrame{
+ http2FrameHeader: fh,
+ Increment: inc,
+ }, nil
+}
+
+// WriteWindowUpdate writes a WINDOW_UPDATE frame.
+// The increment value must be between 1 and 2,147,483,647, inclusive.
+// If the Stream ID is zero, the window update applies to the
+// connection as a whole.
+func (f *http2Framer) WriteWindowUpdate(streamID, incr uint32) error {
+
+ if (incr < 1 || incr > 2147483647) && !f.AllowIllegalWrites {
+ return errors.New("illegal window increment value")
+ }
+ f.startWrite(http2FrameWindowUpdate, 0, streamID)
+ f.writeUint32(incr)
+ return f.endWrite()
+}
+
+// A HeadersFrame is used to open a stream and additionally carries a
+// header block fragment.
+type http2HeadersFrame struct {
+ http2FrameHeader
+
+ // Priority is set if FlagHeadersPriority is set in the FrameHeader.
+ Priority http2PriorityParam
+
+ headerFragBuf []byte // not owned
+}
+
+func (f *http2HeadersFrame) HeaderBlockFragment() []byte {
+ f.checkValid()
+ return f.headerFragBuf
+}
+
+func (f *http2HeadersFrame) HeadersEnded() bool {
+ return f.http2FrameHeader.Flags.Has(http2FlagHeadersEndHeaders)
+}
+
+func (f *http2HeadersFrame) StreamEnded() bool {
+ return f.http2FrameHeader.Flags.Has(http2FlagHeadersEndStream)
+}
+
+func (f *http2HeadersFrame) HasPriority() bool {
+ return f.http2FrameHeader.Flags.Has(http2FlagHeadersPriority)
+}
+
+func http2parseHeadersFrame(fh http2FrameHeader, p []byte) (_ http2Frame, err error) {
+ hf := &http2HeadersFrame{
+ http2FrameHeader: fh,
+ }
+ if fh.StreamID == 0 {
+
+ return nil, http2connError{http2ErrCodeProtocol, "HEADERS frame with stream ID 0"}
+ }
+ var padLength uint8
+ if fh.Flags.Has(http2FlagHeadersPadded) {
+ if p, padLength, err = http2readByte(p); err != nil {
+ return
+ }
+ }
+ if fh.Flags.Has(http2FlagHeadersPriority) {
+ var v uint32
+ p, v, err = http2readUint32(p)
+ if err != nil {
+ return nil, err
+ }
+ hf.Priority.StreamDep = v & 0x7fffffff
+ hf.Priority.Exclusive = (v != hf.Priority.StreamDep)
+ p, hf.Priority.Weight, err = http2readByte(p)
+ if err != nil {
+ return nil, err
+ }
+ }
+ if len(p)-int(padLength) <= 0 {
+ return nil, http2StreamError{fh.StreamID, http2ErrCodeProtocol}
+ }
+ hf.headerFragBuf = p[:len(p)-int(padLength)]
+ return hf, nil
+}
+
+// HeadersFrameParam are the parameters for writing a HEADERS frame.
+type http2HeadersFrameParam struct {
+ // StreamID is the required Stream ID to initiate.
+ StreamID uint32
+ // BlockFragment is part (or all) of a Header Block.
+ BlockFragment []byte
+
+ // EndStream indicates that the header block is the last that
+ // the endpoint will send for the identified stream. Setting
+ // this flag causes the stream to enter one of "half closed"
+ // states.
+ EndStream bool
+
+ // EndHeaders indicates that this frame contains an entire
+ // header block and is not followed by any
+ // CONTINUATION frames.
+ EndHeaders bool
+
+ // PadLength is the optional number of bytes of zeros to add
+ // to this frame.
+ PadLength uint8
+
+ // Priority, if non-zero, includes stream priority information
+ // in the HEADER frame.
+ Priority http2PriorityParam
+}
+
+// WriteHeaders writes a single HEADERS frame.
+//
+// This is a low-level header writing method. Encoding headers and
+// splitting them into any necessary CONTINUATION frames is handled
+// elsewhere.
+//
+// It will perform exactly one Write to the underlying Writer.
+// It is the caller's responsibility to not call other Write methods concurrently.
+func (f *http2Framer) WriteHeaders(p http2HeadersFrameParam) error {
+ if !http2validStreamID(p.StreamID) && !f.AllowIllegalWrites {
+ return http2errStreamID
+ }
+ var flags http2Flags
+ if p.PadLength != 0 {
+ flags |= http2FlagHeadersPadded
+ }
+ if p.EndStream {
+ flags |= http2FlagHeadersEndStream
+ }
+ if p.EndHeaders {
+ flags |= http2FlagHeadersEndHeaders
+ }
+ if !p.Priority.IsZero() {
+ flags |= http2FlagHeadersPriority
+ }
+ f.startWrite(http2FrameHeaders, flags, p.StreamID)
+ if p.PadLength != 0 {
+ f.writeByte(p.PadLength)
+ }
+ if !p.Priority.IsZero() {
+ v := p.Priority.StreamDep
+ if !http2validStreamID(v) && !f.AllowIllegalWrites {
+ return errors.New("invalid dependent stream id")
+ }
+ if p.Priority.Exclusive {
+ v |= 1 << 31
+ }
+ f.writeUint32(v)
+ f.writeByte(p.Priority.Weight)
+ }
+ f.wbuf = append(f.wbuf, p.BlockFragment...)
+ f.wbuf = append(f.wbuf, http2padZeros[:p.PadLength]...)
+ return f.endWrite()
+}
+
+// A PriorityFrame specifies the sender-advised priority of a stream.
+// See http://http2.github.io/http2-spec/#rfc.section.6.3
+type http2PriorityFrame struct {
+ http2FrameHeader
+ http2PriorityParam
+}
+
+// PriorityParam are the stream prioritization parameters.
+type http2PriorityParam struct {
+ // StreamDep is a 31-bit stream identifier for the
+ // stream that this stream depends on. Zero means no
+ // dependency.
+ StreamDep uint32
+
+ // Exclusive is whether the dependency is exclusive.
+ Exclusive bool
+
+ // Weight is the stream's zero-indexed weight. It should be
+ // set together with StreamDep, or neither should be set. Per
+ // the spec, "Add one to the value to obtain a weight between
+ // 1 and 256."
+ Weight uint8
+}
+
+func (p http2PriorityParam) IsZero() bool {
+ return p == http2PriorityParam{}
+}
+
+func http2parsePriorityFrame(fh http2FrameHeader, payload []byte) (http2Frame, error) {
+ if fh.StreamID == 0 {
+ return nil, http2connError{http2ErrCodeProtocol, "PRIORITY frame with stream ID 0"}
+ }
+ if len(payload) != 5 {
+ return nil, http2connError{http2ErrCodeFrameSize, fmt.Sprintf("PRIORITY frame payload size was %d; want 5", len(payload))}
+ }
+ v := binary.BigEndian.Uint32(payload[:4])
+ streamID := v & 0x7fffffff
+ return &http2PriorityFrame{
+ http2FrameHeader: fh,
+ http2PriorityParam: http2PriorityParam{
+ Weight: payload[4],
+ StreamDep: streamID,
+ Exclusive: streamID != v,
+ },
+ }, nil
+}
+
+// WritePriority writes a PRIORITY frame.
+//
+// It will perform exactly one Write to the underlying Writer.
+// It is the caller's responsibility to not call other Write methods concurrently.
+func (f *http2Framer) WritePriority(streamID uint32, p http2PriorityParam) error {
+ if !http2validStreamID(streamID) && !f.AllowIllegalWrites {
+ return http2errStreamID
+ }
+ f.startWrite(http2FramePriority, 0, streamID)
+ v := p.StreamDep
+ if p.Exclusive {
+ v |= 1 << 31
+ }
+ f.writeUint32(v)
+ f.writeByte(p.Weight)
+ return f.endWrite()
+}
+
+// A RSTStreamFrame allows for abnormal termination of a stream.
+// See http://http2.github.io/http2-spec/#rfc.section.6.4
+type http2RSTStreamFrame struct {
+ http2FrameHeader
+ ErrCode http2ErrCode
+}
+
+func http2parseRSTStreamFrame(fh http2FrameHeader, p []byte) (http2Frame, error) {
+ if len(p) != 4 {
+ return nil, http2ConnectionError(http2ErrCodeFrameSize)
+ }
+ if fh.StreamID == 0 {
+ return nil, http2ConnectionError(http2ErrCodeProtocol)
+ }
+ return &http2RSTStreamFrame{fh, http2ErrCode(binary.BigEndian.Uint32(p[:4]))}, nil
+}
+
+// WriteRSTStream writes a RST_STREAM frame.
+//
+// It will perform exactly one Write to the underlying Writer.
+// It is the caller's responsibility to not call other Write methods concurrently.
+func (f *http2Framer) WriteRSTStream(streamID uint32, code http2ErrCode) error {
+ if !http2validStreamID(streamID) && !f.AllowIllegalWrites {
+ return http2errStreamID
+ }
+ f.startWrite(http2FrameRSTStream, 0, streamID)
+ f.writeUint32(uint32(code))
+ return f.endWrite()
+}
+
+// A ContinuationFrame is used to continue a sequence of header block fragments.
+// See http://http2.github.io/http2-spec/#rfc.section.6.10
+type http2ContinuationFrame struct {
+ http2FrameHeader
+ headerFragBuf []byte
+}
+
+func http2parseContinuationFrame(fh http2FrameHeader, p []byte) (http2Frame, error) {
+ if fh.StreamID == 0 {
+ return nil, http2connError{http2ErrCodeProtocol, "CONTINUATION frame with stream ID 0"}
+ }
+ return &http2ContinuationFrame{fh, p}, nil
+}
+
+func (f *http2ContinuationFrame) HeaderBlockFragment() []byte {
+ f.checkValid()
+ return f.headerFragBuf
+}
+
+func (f *http2ContinuationFrame) HeadersEnded() bool {
+ return f.http2FrameHeader.Flags.Has(http2FlagContinuationEndHeaders)
+}
+
+// WriteContinuation writes a CONTINUATION frame.
+//
+// It will perform exactly one Write to the underlying Writer.
+// It is the caller's responsibility to not call other Write methods concurrently.
+func (f *http2Framer) WriteContinuation(streamID uint32, endHeaders bool, headerBlockFragment []byte) error {
+ if !http2validStreamID(streamID) && !f.AllowIllegalWrites {
+ return http2errStreamID
+ }
+ var flags http2Flags
+ if endHeaders {
+ flags |= http2FlagContinuationEndHeaders
+ }
+ f.startWrite(http2FrameContinuation, flags, streamID)
+ f.wbuf = append(f.wbuf, headerBlockFragment...)
+ return f.endWrite()
+}
+
+// A PushPromiseFrame is used to initiate a server stream.
+// See http://http2.github.io/http2-spec/#rfc.section.6.6
+type http2PushPromiseFrame struct {
+ http2FrameHeader
+ PromiseID uint32
+ headerFragBuf []byte // not owned
+}
+
+func (f *http2PushPromiseFrame) HeaderBlockFragment() []byte {
+ f.checkValid()
+ return f.headerFragBuf
+}
+
+func (f *http2PushPromiseFrame) HeadersEnded() bool {
+ return f.http2FrameHeader.Flags.Has(http2FlagPushPromiseEndHeaders)
+}
+
+func http2parsePushPromise(fh http2FrameHeader, p []byte) (_ http2Frame, err error) {
+ pp := &http2PushPromiseFrame{
+ http2FrameHeader: fh,
+ }
+ if pp.StreamID == 0 {
+
+ return nil, http2ConnectionError(http2ErrCodeProtocol)
+ }
+ // The PUSH_PROMISE frame includes optional padding.
+ // Padding fields and flags are identical to those defined for DATA frames
+ var padLength uint8
+ if fh.Flags.Has(http2FlagPushPromisePadded) {
+ if p, padLength, err = http2readByte(p); err != nil {
+ return
+ }
+ }
+
+ p, pp.PromiseID, err = http2readUint32(p)
+ if err != nil {
+ return
+ }
+ pp.PromiseID = pp.PromiseID & (1<<31 - 1)
+
+ if int(padLength) > len(p) {
+
+ return nil, http2ConnectionError(http2ErrCodeProtocol)
+ }
+ pp.headerFragBuf = p[:len(p)-int(padLength)]
+ return pp, nil
+}
+
+// PushPromiseParam are the parameters for writing a PUSH_PROMISE frame.
+type http2PushPromiseParam struct {
+ // StreamID is the required Stream ID to initiate.
+ StreamID uint32
+
+	// PromiseID is the required Stream ID which this
+	// frame promises.
+ PromiseID uint32
+
+ // BlockFragment is part (or all) of a Header Block.
+ BlockFragment []byte
+
+ // EndHeaders indicates that this frame contains an entire
+ // header block and is not followed by any
+ // CONTINUATION frames.
+ EndHeaders bool
+
+ // PadLength is the optional number of bytes of zeros to add
+ // to this frame.
+ PadLength uint8
+}
+
+// WritePushPromise writes a single PushPromise Frame.
+//
+// As with HEADERS frames, this is the low-level call for writing
+// individual frames. Continuation frames are handled elsewhere.
+//
+// It will perform exactly one Write to the underlying Writer.
+// It is the caller's responsibility to not call other Write methods concurrently.
+func (f *http2Framer) WritePushPromise(p http2PushPromiseParam) error {
+ if !http2validStreamID(p.StreamID) && !f.AllowIllegalWrites {
+ return http2errStreamID
+ }
+ var flags http2Flags
+ if p.PadLength != 0 {
+ flags |= http2FlagPushPromisePadded
+ }
+ if p.EndHeaders {
+ flags |= http2FlagPushPromiseEndHeaders
+ }
+ f.startWrite(http2FramePushPromise, flags, p.StreamID)
+ if p.PadLength != 0 {
+ f.writeByte(p.PadLength)
+ }
+ if !http2validStreamID(p.PromiseID) && !f.AllowIllegalWrites {
+ return http2errStreamID
+ }
+ f.writeUint32(p.PromiseID)
+ f.wbuf = append(f.wbuf, p.BlockFragment...)
+ f.wbuf = append(f.wbuf, http2padZeros[:p.PadLength]...)
+ return f.endWrite()
+}
+
+// WriteRawFrame writes a raw frame. This can be used to write
+// extension frames unknown to this package.
+func (f *http2Framer) WriteRawFrame(t http2FrameType, flags http2Flags, streamID uint32, payload []byte) error {
+ f.startWrite(t, flags, streamID)
+ f.writeBytes(payload)
+ return f.endWrite()
+}
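+
+// The helper below is an illustrative sketch, not part of the bundled
+// implementation: it shows one way the write-side Framer methods above
+// might be driven from inside this package, using a bytes.Buffer as the
+// underlying writer. The name sketchFramerWrites and the 0xfa extension
+// frame type are arbitrary choices for the example.
+func sketchFramerWrites() error {
+ var buf bytes.Buffer
+ fr := http2NewFramer(&buf, &buf)
+ // Reset stream 1 with CANCEL; each Write* call performs exactly one
+ // Write to the underlying writer.
+ if err := fr.WriteRSTStream(1, http2ErrCodeCancel); err != nil {
+ return err
+ }
+ // WriteRawFrame can carry a frame type unknown to this package.
+ return fr.WriteRawFrame(http2FrameType(0xfa), 0, 0, []byte("ext"))
+}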
+
+func http2readByte(p []byte) (remain []byte, b byte, err error) {
+ if len(p) == 0 {
+ return nil, 0, io.ErrUnexpectedEOF
+ }
+ return p[1:], p[0], nil
+}
+
+func http2readUint32(p []byte) (remain []byte, v uint32, err error) {
+ if len(p) < 4 {
+ return nil, 0, io.ErrUnexpectedEOF
+ }
+ return p[4:], binary.BigEndian.Uint32(p[:4]), nil
+}
+
+type http2streamEnder interface {
+ StreamEnded() bool
+}
+
+type http2headersEnder interface {
+ HeadersEnded() bool
+}
+
+func http2summarizeFrame(f http2Frame) string {
+ var buf bytes.Buffer
+ f.Header().writeDebug(&buf)
+ switch f := f.(type) {
+ case *http2SettingsFrame:
+ n := 0
+ f.ForeachSetting(func(s http2Setting) error {
+ n++
+ if n == 1 {
+ buf.WriteString(", settings:")
+ }
+ fmt.Fprintf(&buf, " %v=%v,", s.ID, s.Val)
+ return nil
+ })
+ if n > 0 {
+ buf.Truncate(buf.Len() - 1)
+ }
+ case *http2DataFrame:
+ data := f.Data()
+ const max = 256
+ if len(data) > max {
+ data = data[:max]
+ }
+ fmt.Fprintf(&buf, " data=%q", data)
+ if len(f.Data()) > max {
+ fmt.Fprintf(&buf, " (%d bytes omitted)", len(f.Data())-max)
+ }
+ case *http2WindowUpdateFrame:
+ if f.StreamID == 0 {
+ buf.WriteString(" (conn)")
+ }
+ fmt.Fprintf(&buf, " incr=%v", f.Increment)
+ case *http2PingFrame:
+ fmt.Fprintf(&buf, " ping=%q", f.Data[:])
+ case *http2GoAwayFrame:
+ fmt.Fprintf(&buf, " LastStreamID=%v ErrCode=%v Debug=%q",
+ f.LastStreamID, f.ErrCode, f.debugData)
+ case *http2RSTStreamFrame:
+ fmt.Fprintf(&buf, " ErrCode=%v", f.ErrCode)
+ }
+ return buf.String()
+}
+
+func http2requestCancel(req *Request) <-chan struct{} { return req.Cancel }
+
+var http2DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1"
+
+type http2goroutineLock uint64
+
+func http2newGoroutineLock() http2goroutineLock {
+ if !http2DebugGoroutines {
+ return 0
+ }
+ return http2goroutineLock(http2curGoroutineID())
+}
+
+func (g http2goroutineLock) check() {
+ if !http2DebugGoroutines {
+ return
+ }
+ if http2curGoroutineID() != uint64(g) {
+ panic("running on the wrong goroutine")
+ }
+}
+
+func (g http2goroutineLock) checkNotOn() {
+ if !http2DebugGoroutines {
+ return
+ }
+ if http2curGoroutineID() == uint64(g) {
+ panic("running on the wrong goroutine")
+ }
+}
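+
+// A minimal illustrative sketch (not part of the patch): the lock above is
+// a debug aid, so both checks are no-ops unless DEBUG_HTTP2_GOROUTINES=1 is
+// set in the environment. The sketchOwnership name is arbitrary.
+func sketchOwnership() {
+ owner := http2newGoroutineLock()
+ // On the owning goroutine this passes (or is a no-op when disabled);
+ // calling owner.check from another goroutine would panic when enabled.
+ owner.check()
+}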
+
+var http2goroutineSpace = []byte("goroutine ")
+
+func http2curGoroutineID() uint64 {
+ bp := http2littleBuf.Get().(*[]byte)
+ defer http2littleBuf.Put(bp)
+ b := *bp
+ b = b[:runtime.Stack(b, false)]
+
+ b = bytes.TrimPrefix(b, http2goroutineSpace)
+ i := bytes.IndexByte(b, ' ')
+ if i < 0 {
+ panic(fmt.Sprintf("No space found in %q", b))
+ }
+ b = b[:i]
+ n, err := http2parseUintBytes(b, 10, 64)
+ if err != nil {
+ panic(fmt.Sprintf("Failed to parse goroutine ID out of %q: %v", b, err))
+ }
+ return n
+}
+
+var http2littleBuf = sync.Pool{
+ New: func() interface{} {
+ buf := make([]byte, 64)
+ return &buf
+ },
+}
+
+// parseUintBytes is like strconv.ParseUint, but using a []byte.
+func http2parseUintBytes(s []byte, base int, bitSize int) (n uint64, err error) {
+ var cutoff, maxVal uint64
+
+ if bitSize == 0 {
+ bitSize = int(strconv.IntSize)
+ }
+
+ s0 := s
+ switch {
+ case len(s) < 1:
+ err = strconv.ErrSyntax
+ goto Error
+
+ case 2 <= base && base <= 36:
+
+ case base == 0:
+
+ switch {
+ case s[0] == '0' && len(s) > 1 && (s[1] == 'x' || s[1] == 'X'):
+ base = 16
+ s = s[2:]
+ if len(s) < 1 {
+ err = strconv.ErrSyntax
+ goto Error
+ }
+ case s[0] == '0':
+ base = 8
+ default:
+ base = 10
+ }
+
+ default:
+ err = errors.New("invalid base " + strconv.Itoa(base))
+ goto Error
+ }
+
+ n = 0
+ cutoff = http2cutoff64(base)
+ maxVal = 1<<uint(bitSize) - 1
+
+ for i := 0; i < len(s); i++ {
+ var v byte
+ d := s[i]
+ switch {
+ case '0' <= d && d <= '9':
+ v = d - '0'
+ case 'a' <= d && d <= 'z':
+ v = d - 'a' + 10
+ case 'A' <= d && d <= 'Z':
+ v = d - 'A' + 10
+ default:
+ n = 0
+ err = strconv.ErrSyntax
+ goto Error
+ }
+ if int(v) >= base {
+ n = 0
+ err = strconv.ErrSyntax
+ goto Error
+ }
+
+ if n >= cutoff {
+
+ n = 1<<64 - 1
+ err = strconv.ErrRange
+ goto Error
+ }
+ n *= uint64(base)
+
+ n1 := n + uint64(v)
+ if n1 < n || n1 > maxVal {
+
+ n = 1<<64 - 1
+ err = strconv.ErrRange
+ goto Error
+ }
+ n = n1
+ }
+
+ return n, nil
+
+Error:
+ return n, &strconv.NumError{Func: "ParseUint", Num: string(s0), Err: err}
+}
+
+// Return the first number n such that n*base >= 1<<64.
+func http2cutoff64(base int) uint64 {
+ if base < 2 {
+ return 0
+ }
+ return (1<<64-1)/uint64(base) + 1
+}
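+
+// Illustrative only: a couple of worked inputs for parseUintBytes above,
+// showing the strconv-style base auto-detection that curGoroutineID relies
+// on. The sketchParseUint name is arbitrary.
+func sketchParseUint() {
+ n, _ := http2parseUintBytes([]byte("0x1f"), 0, 64) // "0x" prefix -> base 16, n == 31
+ m, _ := http2parseUintBytes([]byte("123"), 0, 64) // no prefix -> base 10, m == 123
+ _, _ = n, m
+}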
+
+var (
+ http2commonLowerHeader = map[string]string{} // Go-Canonical-Case -> lower-case
+ http2commonCanonHeader = map[string]string{} // lower-case -> Go-Canonical-Case
+)
+
+func init() {
+ for _, v := range []string{
+ "accept",
+ "accept-charset",
+ "accept-encoding",
+ "accept-language",
+ "accept-ranges",
+ "age",
+ "access-control-allow-origin",
+ "allow",
+ "authorization",
+ "cache-control",
+ "content-disposition",
+ "content-encoding",
+ "content-language",
+ "content-length",
+ "content-location",
+ "content-range",
+ "content-type",
+ "cookie",
+ "date",
+ "etag",
+ "expect",
+ "expires",
+ "from",
+ "host",
+ "if-match",
+ "if-modified-since",
+ "if-none-match",
+ "if-unmodified-since",
+ "last-modified",
+ "link",
+ "location",
+ "max-forwards",
+ "proxy-authenticate",
+ "proxy-authorization",
+ "range",
+ "referer",
+ "refresh",
+ "retry-after",
+ "server",
+ "set-cookie",
+ "strict-transport-security",
+ "trailer",
+ "transfer-encoding",
+ "user-agent",
+ "vary",
+ "via",
+ "www-authenticate",
+ } {
+ chk := CanonicalHeaderKey(v)
+ http2commonLowerHeader[chk] = v
+ http2commonCanonHeader[v] = chk
+ }
+}
+
+func http2lowerHeader(v string) string {
+ if s, ok := http2commonLowerHeader[v]; ok {
+ return s
+ }
+ return strings.ToLower(v)
+}
+
+var (
+ http2VerboseLogs bool
+ http2logFrameWrites bool
+ http2logFrameReads bool
+)
+
+func init() {
+ e := os.Getenv("GODEBUG")
+ if strings.Contains(e, "http2debug=1") {
+ http2VerboseLogs = true
+ }
+ if strings.Contains(e, "http2debug=2") {
+ http2VerboseLogs = true
+ http2logFrameWrites = true
+ http2logFrameReads = true
+ }
+}
+
+const (
+ // ClientPreface is the string that must be sent by new
+ // connections from clients.
+ http2ClientPreface = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"
+
+ // SETTINGS_MAX_FRAME_SIZE default
+ // http://http2.github.io/http2-spec/#rfc.section.6.5.2
+ http2initialMaxFrameSize = 16384
+
+ // NextProtoTLS is the NPN/ALPN protocol negotiated during
+ // HTTP/2's TLS setup.
+ http2NextProtoTLS = "h2"
+
+ // http://http2.github.io/http2-spec/#SettingValues
+ http2initialHeaderTableSize = 4096
+
+ http2initialWindowSize = 65535 // 6.9.2 Initial Flow Control Window Size
+
+ http2defaultMaxReadFrameSize = 1 << 20
+)
+
+var (
+ http2clientPreface = []byte(http2ClientPreface)
+)
+
+type http2streamState int
+
+const (
+ http2stateIdle http2streamState = iota
+ http2stateOpen
+ http2stateHalfClosedLocal
+ http2stateHalfClosedRemote
+ http2stateResvLocal
+ http2stateResvRemote
+ http2stateClosed
+)
+
+var http2stateName = [...]string{
+ http2stateIdle: "Idle",
+ http2stateOpen: "Open",
+ http2stateHalfClosedLocal: "HalfClosedLocal",
+ http2stateHalfClosedRemote: "HalfClosedRemote",
+ http2stateResvLocal: "ResvLocal",
+ http2stateResvRemote: "ResvRemote",
+ http2stateClosed: "Closed",
+}
+
+func (st http2streamState) String() string {
+ return http2stateName[st]
+}
+
+// Setting is a setting parameter: which setting it is, and its value.
+type http2Setting struct {
+ // ID is which setting is being set.
+ // See http://http2.github.io/http2-spec/#SettingValues
+ ID http2SettingID
+
+ // Val is the value.
+ Val uint32
+}
+
+func (s http2Setting) String() string {
+ return fmt.Sprintf("[%v = %d]", s.ID, s.Val)
+}
+
+// Valid reports whether the setting is valid.
+func (s http2Setting) Valid() error {
+
+ switch s.ID {
+ case http2SettingEnablePush:
+ if s.Val != 1 && s.Val != 0 {
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
+ case http2SettingInitialWindowSize:
+ if s.Val > 1<<31-1 {
+ return http2ConnectionError(http2ErrCodeFlowControl)
+ }
+ case http2SettingMaxFrameSize:
+ if s.Val < 16384 || s.Val > 1<<24-1 {
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
+ }
+ return nil
+}
+
+// A SettingID is an HTTP/2 setting as defined in
+// http://http2.github.io/http2-spec/#iana-settings
+type http2SettingID uint16
+
+const (
+ http2SettingHeaderTableSize http2SettingID = 0x1
+ http2SettingEnablePush http2SettingID = 0x2
+ http2SettingMaxConcurrentStreams http2SettingID = 0x3
+ http2SettingInitialWindowSize http2SettingID = 0x4
+ http2SettingMaxFrameSize http2SettingID = 0x5
+ http2SettingMaxHeaderListSize http2SettingID = 0x6
+)
+
+var http2settingName = map[http2SettingID]string{
+ http2SettingHeaderTableSize: "HEADER_TABLE_SIZE",
+ http2SettingEnablePush: "ENABLE_PUSH",
+ http2SettingMaxConcurrentStreams: "MAX_CONCURRENT_STREAMS",
+ http2SettingInitialWindowSize: "INITIAL_WINDOW_SIZE",
+ http2SettingMaxFrameSize: "MAX_FRAME_SIZE",
+ http2SettingMaxHeaderListSize: "MAX_HEADER_LIST_SIZE",
+}
+
+func (s http2SettingID) String() string {
+ if v, ok := http2settingName[s]; ok {
+ return v
+ }
+ return fmt.Sprintf("UNKNOWN_SETTING_%d", uint16(s))
+}
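+
+// Illustrative sketch only: how a received setting might be checked with
+// Valid above before being applied, mirroring what processSetting does
+// later in this file. The value 1<<30 is an arbitrary out-of-range example.
+func sketchValidateSetting() error {
+ s := http2Setting{ID: http2SettingMaxFrameSize, Val: 1 << 30}
+ // 1<<30 exceeds the 2^24-1 cap, so Valid returns a connection error
+ // carrying the PROTOCOL_ERROR code.
+ return s.Valid()
+}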
+
+var (
+ http2errInvalidHeaderFieldName = errors.New("http2: invalid header field name")
+ http2errInvalidHeaderFieldValue = errors.New("http2: invalid header field value")
+)
+
+// validHeaderFieldName reports whether v is a valid header field name (key).
+// RFC 7230 says:
+// header-field = field-name ":" OWS field-value OWS
+// field-name = token
+// tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." /
+// "^" / "_" / "
+// Further, http2 says:
+// "Just as in HTTP/1.x, header field names are strings of ASCII
+// characters that are compared in a case-insensitive
+// fashion. However, header field names MUST be converted to
+// lowercase prior to their encoding in HTTP/2. "
+func http2validHeaderFieldName(v string) bool {
+ if len(v) == 0 {
+ return false
+ }
+ for _, r := range v {
+ if int(r) >= len(http2isTokenTable) || ('A' <= r && r <= 'Z') {
+ return false
+ }
+ if !http2isTokenTable[byte(r)] {
+ return false
+ }
+ }
+ return true
+}
+
+// validHeaderFieldValue reports whether v is a valid header field value.
+//
+// RFC 7230 says:
+// field-value = *( field-content / obs-fold )
+// obs-fold = N/A to http2, and deprecated
+// field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ]
+// field-vchar = VCHAR / obs-text
+// obs-text = %x80-FF
+// VCHAR = "any visible [USASCII] character"
+//
+// http2 further says: "Similarly, HTTP/2 allows header field values
+// that are not valid. While most of the values that can be encoded
+// will not alter header field parsing, carriage return (CR, ASCII
+// 0xd), line feed (LF, ASCII 0xa), and the zero character (NUL, ASCII
+// 0x0) might be exploited by an attacker if they are translated
+// verbatim. Any request or response that contains a character not
+// permitted in a header field value MUST be treated as malformed
+// (Section 8.1.2.6). Valid characters are defined by the
+// field-content ABNF rule in Section 3.2 of [RFC7230]."
+//
+// This function does not (yet?) properly handle the rejection of
+// strings that begin or end with SP or HTAB.
+func http2validHeaderFieldValue(v string) bool {
+ for i := 0; i < len(v); i++ {
+ if b := v[i]; b < ' ' && b != '\t' || b == 0x7f {
+ return false
+ }
+ }
+ return true
+}
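+
+// Tiny illustrative sketch (not part of the change): the two validators
+// above reject uppercase names and control characters in values, which is
+// how onNewHeaderField uses them further down in this file.
+func sketchHeaderValidation() {
+ _ = http2validHeaderFieldName("content-type") // true: lowercase token
+ _ = http2validHeaderFieldName("Content-Type") // false: uppercase rejected
+ _ = http2validHeaderFieldValue("gzip, deflate") // true
+ _ = http2validHeaderFieldValue("bad\r\nvalue") // false: CR/LF not allowed
+}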
+
+var http2httpCodeStringCommon = map[int]string{} // n -> strconv.Itoa(n)
+
+func init() {
+ for i := 100; i <= 999; i++ {
+ if v := StatusText(i); v != "" {
+ http2httpCodeStringCommon[i] = strconv.Itoa(i)
+ }
+ }
+}
+
+func http2httpCodeString(code int) string {
+ if s, ok := http2httpCodeStringCommon[code]; ok {
+ return s
+ }
+ return strconv.Itoa(code)
+}
+
+// from pkg io
+type http2stringWriter interface {
+ WriteString(s string) (n int, err error)
+}
+
+// A gate lets two goroutines coordinate their activities.
+type http2gate chan struct{}
+
+func (g http2gate) Done() { g <- struct{}{} }
+
+func (g http2gate) Wait() { <-g }
+
+// A closeWaiter is like a sync.WaitGroup but only goes 1 to 0 (open to closed).
+type http2closeWaiter chan struct{}
+
+// Init makes a closeWaiter usable.
+// It exists so that a closeWaiter value can be placed inside a
+// larger struct and initialized in place, without needing a
+// separate constructor.
+func (cw *http2closeWaiter) Init() {
+ *cw = make(chan struct{})
+}
+
+// Close marks the closeWaiter as closed and unblocks any waiters.
+func (cw http2closeWaiter) Close() {
+ close(cw)
+}
+
+// Wait waits for the closeWaiter to become closed.
+func (cw http2closeWaiter) Wait() {
+ <-cw
+}
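+
+// Illustrative sketch only: a closeWaiter is closed exactly once and any
+// number of goroutines may Wait on it, which is how stream teardown is
+// signalled by closeStream later in this file.
+func sketchCloseWaiter() {
+ var cw http2closeWaiter
+ cw.Init()
+ go cw.Close() // unblocks every Wait
+ cw.Wait()
+}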
+
+// bufferedWriter is a buffered writer that writes to w.
+// Its buffered writer is lazily allocated as needed, to minimize
+// idle memory usage with many connections.
+type http2bufferedWriter struct {
+ w io.Writer // immutable
+ bw *bufio.Writer // non-nil when data is buffered
+}
+
+func http2newBufferedWriter(w io.Writer) *http2bufferedWriter {
+ return &http2bufferedWriter{w: w}
+}
+
+var http2bufWriterPool = sync.Pool{
+ New: func() interface{} {
+
+ return bufio.NewWriterSize(nil, 4<<10)
+ },
+}
+
+func (w *http2bufferedWriter) Write(p []byte) (n int, err error) {
+ if w.bw == nil {
+ bw := http2bufWriterPool.Get().(*bufio.Writer)
+ bw.Reset(w.w)
+ w.bw = bw
+ }
+ return w.bw.Write(p)
+}
+
+func (w *http2bufferedWriter) Flush() error {
+ bw := w.bw
+ if bw == nil {
+ return nil
+ }
+ err := bw.Flush()
+ bw.Reset(nil)
+ http2bufWriterPool.Put(bw)
+ w.bw = nil
+ return err
+}
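+
+// Illustrative sketch, not part of the change: small writes are held in a
+// pooled bufio.Writer and reach the underlying writer on Flush, after which
+// the bufio.Writer goes back to the pool.
+func sketchBufferedWriter() {
+ var out bytes.Buffer
+ bw := http2newBufferedWriter(&out)
+ bw.Write([]byte("frame bytes")) // buffered; out is still empty
+ bw.Flush() // out now contains "frame bytes"
+}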
+
+func http2mustUint31(v int32) uint32 {
+ if v < 0 || v > 2147483647 {
+ panic("out of range")
+ }
+ return uint32(v)
+}
+
+// bodyAllowedForStatus reports whether a given response status code
+// permits a body. See RFC2616, section 4.4.
+func http2bodyAllowedForStatus(status int) bool {
+ switch {
+ case status >= 100 && status <= 199:
+ return false
+ case status == 204:
+ return false
+ case status == 304:
+ return false
+ }
+ return true
+}
+
+type http2httpError struct {
+ msg string
+ timeout bool
+}
+
+func (e *http2httpError) Error() string { return e.msg }
+
+func (e *http2httpError) Timeout() bool { return e.timeout }
+
+func (e *http2httpError) Temporary() bool { return true }
+
+var http2errTimeout error = &http2httpError{msg: "http2: timeout awaiting response headers", timeout: true}
+
+var http2isTokenTable = [127]bool{
+ '!': true,
+ '#': true,
+ '$': true,
+ '%': true,
+ '&': true,
+ '\'': true,
+ '*': true,
+ '+': true,
+ '-': true,
+ '.': true,
+ '0': true,
+ '1': true,
+ '2': true,
+ '3': true,
+ '4': true,
+ '5': true,
+ '6': true,
+ '7': true,
+ '8': true,
+ '9': true,
+ 'A': true,
+ 'B': true,
+ 'C': true,
+ 'D': true,
+ 'E': true,
+ 'F': true,
+ 'G': true,
+ 'H': true,
+ 'I': true,
+ 'J': true,
+ 'K': true,
+ 'L': true,
+ 'M': true,
+ 'N': true,
+ 'O': true,
+ 'P': true,
+ 'Q': true,
+ 'R': true,
+ 'S': true,
+ 'T': true,
+ 'U': true,
+ 'V': true,
+ 'W': true,
+ 'X': true,
+ 'Y': true,
+ 'Z': true,
+ '^': true,
+ '_': true,
+ '`': true,
+ 'a': true,
+ 'b': true,
+ 'c': true,
+ 'd': true,
+ 'e': true,
+ 'f': true,
+ 'g': true,
+ 'h': true,
+ 'i': true,
+ 'j': true,
+ 'k': true,
+ 'l': true,
+ 'm': true,
+ 'n': true,
+ 'o': true,
+ 'p': true,
+ 'q': true,
+ 'r': true,
+ 's': true,
+ 't': true,
+ 'u': true,
+ 'v': true,
+ 'w': true,
+ 'x': true,
+ 'y': true,
+ 'z': true,
+ '|': true,
+ '~': true,
+}
+
+// pipe is a goroutine-safe io.Reader/io.Writer pair. It's like
+// io.Pipe except there are no PipeReader/PipeWriter halves, and the
+// underlying buffer is an interface. (io.Pipe is always unbuffered)
+type http2pipe struct {
+ mu sync.Mutex
+ c sync.Cond // c.L lazily initialized to &p.mu
+ b http2pipeBuffer
+ err error // read error once empty. non-nil means closed.
+ breakErr error // immediate read error (caller doesn't see rest of b)
+ donec chan struct{} // closed on error
+ readFn func() // optional code to run in Read before error
+}
+
+type http2pipeBuffer interface {
+ Len() int
+ io.Writer
+ io.Reader
+}
+
+// Read waits until data is available and copies bytes
+// from the buffer into p.
+func (p *http2pipe) Read(d []byte) (n int, err error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if p.c.L == nil {
+ p.c.L = &p.mu
+ }
+ for {
+ if p.breakErr != nil {
+ return 0, p.breakErr
+ }
+ if p.b.Len() > 0 {
+ return p.b.Read(d)
+ }
+ if p.err != nil {
+ if p.readFn != nil {
+ p.readFn()
+ p.readFn = nil
+ }
+ return 0, p.err
+ }
+ p.c.Wait()
+ }
+}
+
+var http2errClosedPipeWrite = errors.New("write on closed buffer")
+
+// Write copies bytes from p into the buffer and wakes a reader.
+// It is an error to write more data than the buffer can hold.
+func (p *http2pipe) Write(d []byte) (n int, err error) {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if p.c.L == nil {
+ p.c.L = &p.mu
+ }
+ defer p.c.Signal()
+ if p.err != nil {
+ return 0, http2errClosedPipeWrite
+ }
+ return p.b.Write(d)
+}
+
+// CloseWithError causes the next Read (waking up a current blocked
+// Read if needed) to return the provided err after all data has been
+// read.
+//
+// The error must be non-nil.
+func (p *http2pipe) CloseWithError(err error) { p.closeWithError(&p.err, err, nil) }
+
+// BreakWithError causes the next Read (waking up a current blocked
+// Read if needed) to return the provided err immediately, without
+// waiting for unread data.
+func (p *http2pipe) BreakWithError(err error) { p.closeWithError(&p.breakErr, err, nil) }
+
+// closeWithErrorAndCode is like CloseWithError but also sets some code to run
+// in the caller's goroutine before returning the error.
+func (p *http2pipe) closeWithErrorAndCode(err error, fn func()) { p.closeWithError(&p.err, err, fn) }
+
+func (p *http2pipe) closeWithError(dst *error, err error, fn func()) {
+ if err == nil {
+ panic("err must be non-nil")
+ }
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if p.c.L == nil {
+ p.c.L = &p.mu
+ }
+ defer p.c.Signal()
+ if *dst != nil {
+
+ return
+ }
+ p.readFn = fn
+ *dst = err
+ p.closeDoneLocked()
+}
+
+// requires p.mu be held.
+func (p *http2pipe) closeDoneLocked() {
+ if p.donec == nil {
+ return
+ }
+
+ select {
+ case <-p.donec:
+ default:
+ close(p.donec)
+ }
+}
+
+// Err returns the error (if any) first set by BreakWithError or CloseWithError.
+func (p *http2pipe) Err() error {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if p.breakErr != nil {
+ return p.breakErr
+ }
+ return p.err
+}
+
+// Done returns a channel which is closed if and when this pipe is closed
+// with CloseWithError.
+func (p *http2pipe) Done() <-chan struct{} {
+ p.mu.Lock()
+ defer p.mu.Unlock()
+ if p.donec == nil {
+ p.donec = make(chan struct{})
+ if p.err != nil || p.breakErr != nil {
+
+ p.closeDoneLocked()
+ }
+ }
+ return p.donec
+}
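+
+// Illustrative sketch only: with a bytes.Buffer as its pipeBuffer the pipe
+// above acts like a small buffered io.Pipe; CloseWithError lets readers
+// drain the buffered data first and then see the error.
+func sketchPipe() {
+ p := &http2pipe{b: new(bytes.Buffer)}
+ p.Write([]byte("body"))
+ p.CloseWithError(io.EOF)
+ buf := make([]byte, 4)
+ n, _ := p.Read(buf) // n == 4: buffered data is returned first
+ _, err := p.Read(buf) // err == io.EOF once the buffer is empty
+ _, _ = n, err
+}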
+
+const (
+ http2prefaceTimeout = 10 * time.Second
+ http2firstSettingsTimeout = 2 * time.Second // should be in-flight with preface anyway
+ http2handlerChunkWriteSize = 4 << 10
+ http2defaultMaxStreams = 250 // TODO: make this 100 as the GFE seems to?
+)
+
+var (
+ http2errClientDisconnected = errors.New("client disconnected")
+ http2errClosedBody = errors.New("body closed by handler")
+ http2errHandlerComplete = errors.New("http2: request body closed due to handler exiting")
+ http2errStreamClosed = errors.New("http2: stream closed")
+)
+
+var http2responseWriterStatePool = sync.Pool{
+ New: func() interface{} {
+ rws := &http2responseWriterState{}
+ rws.bw = bufio.NewWriterSize(http2chunkWriter{rws}, http2handlerChunkWriteSize)
+ return rws
+ },
+}
+
+// Test hooks.
+var (
+ http2testHookOnConn func()
+ http2testHookGetServerConn func(*http2serverConn)
+ http2testHookOnPanicMu *sync.Mutex // nil except in tests
+ http2testHookOnPanic func(sc *http2serverConn, panicVal interface{}) (rePanic bool)
+)
+
+// Server is an HTTP/2 server.
+type http2Server struct {
+ // MaxHandlers limits the number of http.Handler ServeHTTP goroutines
+ // which may run at a time over all connections.
+ // Negative or zero means no limit.
+ // TODO: implement
+ MaxHandlers int
+
+ // MaxConcurrentStreams optionally specifies the number of
+ // concurrent streams that each client may have open at a
+ // time. This is unrelated to the number of http.Handler goroutines
+ // which may be active globally, which is MaxHandlers.
+ // If zero, MaxConcurrentStreams defaults to at least 100, per
+ // the HTTP/2 spec's recommendations.
+ MaxConcurrentStreams uint32
+
+ // MaxReadFrameSize optionally specifies the largest frame
+ // this server is willing to read. A valid value is between
+ // 16k and 16M, inclusive. If zero or otherwise invalid, a
+ // default value is used.
+ MaxReadFrameSize uint32
+
+ // PermitProhibitedCipherSuites, if true, permits the use of
+ // cipher suites prohibited by the HTTP/2 spec.
+ PermitProhibitedCipherSuites bool
+}
+
+func (s *http2Server) maxReadFrameSize() uint32 {
+ if v := s.MaxReadFrameSize; v >= http2minMaxFrameSize && v <= http2maxFrameSize {
+ return v
+ }
+ return http2defaultMaxReadFrameSize
+}
+
+func (s *http2Server) maxConcurrentStreams() uint32 {
+ if v := s.MaxConcurrentStreams; v > 0 {
+ return v
+ }
+ return http2defaultMaxStreams
+}
+
+// ConfigureServer adds HTTP/2 support to a net/http Server.
+//
+// The configuration conf may be nil.
+//
+// ConfigureServer must be called before s begins serving.
+func http2ConfigureServer(s *Server, conf *http2Server) error {
+ if conf == nil {
+ conf = new(http2Server)
+ }
+
+ if s.TLSConfig == nil {
+ s.TLSConfig = new(tls.Config)
+ } else if s.TLSConfig.CipherSuites != nil {
+ // If they already provided a CipherSuite list, return
+ // an error if it has a bad order or is missing
+ // ECDHE_RSA_WITH_AES_128_GCM_SHA256.
+ const requiredCipher = tls.TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256
+ haveRequired := false
+ sawBad := false
+ for i, cs := range s.TLSConfig.CipherSuites {
+ if cs == requiredCipher {
+ haveRequired = true
+ }
+ if http2isBadCipher(cs) {
+ sawBad = true
+ } else if sawBad {
+ return fmt.Errorf("http2: TLSConfig.CipherSuites index %d contains an HTTP/2-approved cipher suite (%#04x), but it comes after unapproved cipher suites. With this configuration, clients that don't support previous, approved cipher suites may be given an unapproved one and reject the connection.", i, cs)
+ }
+ }
+ if !haveRequired {
+ return fmt.Errorf("http2: TLSConfig.CipherSuites is missing HTTP/2-required TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256")
+ }
+ }
+
+ s.TLSConfig.PreferServerCipherSuites = true
+
+ haveNPN := false
+ for _, p := range s.TLSConfig.NextProtos {
+ if p == http2NextProtoTLS {
+ haveNPN = true
+ break
+ }
+ }
+ if !haveNPN {
+ s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, http2NextProtoTLS)
+ }
+
+ s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "h2-14")
+
+ if s.TLSNextProto == nil {
+ s.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){}
+ }
+ protoHandler := func(hs *Server, c *tls.Conn, h Handler) {
+ if http2testHookOnConn != nil {
+ http2testHookOnConn()
+ }
+ conf.handleConn(hs, c, h)
+ }
+ s.TLSNextProto[http2NextProtoTLS] = protoHandler
+ s.TLSNextProto["h2-14"] = protoHandler
+ return nil
+}
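+
+// Illustrative sketch only (the cert/key paths and the 128-stream limit are
+// placeholders): calling ConfigureServer by hand is how the http2Server
+// options above get applied; it must happen before the server starts
+// serving TLS connections.
+func sketchConfigureServer() error {
+ srv := &Server{Addr: ":8443", Handler: DefaultServeMux}
+ if err := http2ConfigureServer(srv, &http2Server{MaxConcurrentStreams: 128}); err != nil {
+ return err
+ }
+ return srv.ListenAndServeTLS("cert.pem", "key.pem")
+}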
+
+func (srv *http2Server) handleConn(hs *Server, c net.Conn, h Handler) {
+ sc := &http2serverConn{
+ srv: srv,
+ hs: hs,
+ conn: c,
+ remoteAddrStr: c.RemoteAddr().String(),
+ bw: http2newBufferedWriter(c),
+ handler: h,
+ streams: make(map[uint32]*http2stream),
+ readFrameCh: make(chan http2readFrameResult),
+ wantWriteFrameCh: make(chan http2frameWriteMsg, 8),
+ wroteFrameCh: make(chan http2frameWriteResult, 1),
+ bodyReadCh: make(chan http2bodyReadMsg),
+ doneServing: make(chan struct{}),
+ advMaxStreams: srv.maxConcurrentStreams(),
+ writeSched: http2writeScheduler{
+ maxFrameSize: http2initialMaxFrameSize,
+ },
+ initialWindowSize: http2initialWindowSize,
+ headerTableSize: http2initialHeaderTableSize,
+ serveG: http2newGoroutineLock(),
+ pushEnabled: true,
+ }
+ sc.flow.add(http2initialWindowSize)
+ sc.inflow.add(http2initialWindowSize)
+ sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf)
+ sc.hpackDecoder = hpack.NewDecoder(http2initialHeaderTableSize, nil)
+ sc.hpackDecoder.SetMaxStringLength(sc.maxHeaderStringLen())
+
+ fr := http2NewFramer(sc.bw, c)
+ fr.SetMaxReadFrameSize(srv.maxReadFrameSize())
+ sc.framer = fr
+
+ if tc, ok := c.(*tls.Conn); ok {
+ sc.tlsState = new(tls.ConnectionState)
+ *sc.tlsState = tc.ConnectionState()
+
+ if sc.tlsState.Version < tls.VersionTLS12 {
+ sc.rejectConn(http2ErrCodeInadequateSecurity, "TLS version too low")
+ return
+ }
+
+ if sc.tlsState.ServerName == "" {
+
+ }
+
+ if !srv.PermitProhibitedCipherSuites && http2isBadCipher(sc.tlsState.CipherSuite) {
+
+ sc.rejectConn(http2ErrCodeInadequateSecurity, fmt.Sprintf("Prohibited TLS 1.2 Cipher Suite: %x", sc.tlsState.CipherSuite))
+ return
+ }
+ }
+
+ if hook := http2testHookGetServerConn; hook != nil {
+ hook(sc)
+ }
+ sc.serve()
+}
+
+// isBadCipher reports whether the cipher is blacklisted by the HTTP/2 spec.
+func http2isBadCipher(cipher uint16) bool {
+ switch cipher {
+ case tls.TLS_RSA_WITH_RC4_128_SHA,
+ tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA,
+ tls.TLS_RSA_WITH_AES_128_CBC_SHA,
+ tls.TLS_RSA_WITH_AES_256_CBC_SHA,
+ tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA,
+ tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA,
+ tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA,
+ tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA,
+ tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA,
+ tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
+ tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA:
+
+ return true
+ default:
+ return false
+ }
+}
+
+func (sc *http2serverConn) rejectConn(err http2ErrCode, debug string) {
+ sc.vlogf("http2: server rejecting conn: %v, %s", err, debug)
+
+ sc.framer.WriteGoAway(0, err, []byte(debug))
+ sc.bw.Flush()
+ sc.conn.Close()
+}
+
+type http2serverConn struct {
+ // Immutable:
+ srv *http2Server
+ hs *Server
+ conn net.Conn
+ bw *http2bufferedWriter // writing to conn
+ handler Handler
+ framer *http2Framer
+ hpackDecoder *hpack.Decoder
+ doneServing chan struct{} // closed when serverConn.serve ends
+ readFrameCh chan http2readFrameResult // written by serverConn.readFrames
+ wantWriteFrameCh chan http2frameWriteMsg // from handlers -> serve
+ wroteFrameCh chan http2frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes
+ bodyReadCh chan http2bodyReadMsg // from handlers -> serve
+ testHookCh chan func(int) // code to run on the serve loop
+ flow http2flow // conn-wide (not stream-specific) outbound flow control
+ inflow http2flow // conn-wide inbound flow control
+ tlsState *tls.ConnectionState // shared by all handlers, like net/http
+ remoteAddrStr string
+
+ // Everything following is owned by the serve loop; use serveG.check():
+ serveG http2goroutineLock // used to verify funcs are on serve()
+ pushEnabled bool
+ sawFirstSettings bool // got the initial SETTINGS frame after the preface
+ needToSendSettingsAck bool
+ unackedSettings int // how many SETTINGS have we sent without ACKs?
+ clientMaxStreams uint32 // SETTINGS_MAX_CONCURRENT_STREAMS from client (our PUSH_PROMISE limit)
+ advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised to the client
+ curOpenStreams uint32 // client's number of open streams
+ maxStreamID uint32 // max ever seen
+ streams map[uint32]*http2stream
+ initialWindowSize int32
+ headerTableSize uint32
+ peerMaxHeaderListSize uint32 // zero means unknown (default)
+ canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case
+ req http2requestParam // non-zero while reading request headers
+ writingFrame bool // started write goroutine but haven't heard back on wroteFrameCh
+ needsFrameFlush bool // last frame write wasn't a flush
+ writeSched http2writeScheduler
+ inGoAway bool // we've started to or sent GOAWAY
+ needToSendGoAway bool // we need to schedule a GOAWAY frame write
+ goAwayCode http2ErrCode
+ shutdownTimerCh <-chan time.Time // nil until used
+ shutdownTimer *time.Timer // nil until used
+
+ // Owned by the writeFrameAsync goroutine:
+ headerWriteBuf bytes.Buffer
+ hpackEncoder *hpack.Encoder
+}
+
+func (sc *http2serverConn) maxHeaderStringLen() int {
+ v := sc.maxHeaderListSize()
+ if uint32(int(v)) == v {
+ return int(v)
+ }
+
+ return 0
+}
+
+func (sc *http2serverConn) maxHeaderListSize() uint32 {
+ n := sc.hs.MaxHeaderBytes
+ if n <= 0 {
+ n = DefaultMaxHeaderBytes
+ }
+ // http2's count is in a slightly different unit and includes 32 bytes per pair.
+ // So, take the net/http.Server value and pad it up a bit, assuming 10 headers.
+ const perFieldOverhead = 32 // per http2 spec
+ const typicalHeaders = 10 // conservative
+ return uint32(n + typicalHeaders*perFieldOverhead)
+}
+
+// requestParam is the state of the next request, initialized over
+// potentially several frames: a HEADERS frame plus zero or more
+// CONTINUATION frames.
+type http2requestParam struct {
+ // stream is non-nil if we're reading (HEADER or CONTINUATION)
+ // frames for a request (but not DATA).
+ stream *http2stream
+ header Header
+ method, path string
+ scheme, authority string
+ sawRegularHeader bool // saw a non-pseudo header already
+ invalidHeader bool // an invalid header was seen
+ headerListSize int64 // actually uint32, but easier math this way
+}
+
+// stream represents a stream. This is the minimal metadata needed by
+// the serve goroutine. Most of the actual stream state is owned by
+// the http.Handler's goroutine in the responseWriter. Because the
+// responseWriter's responseWriterState is recycled at the end of a
+// handler, this struct intentionally has no pointer to the
+// *responseWriter{,State} itself, as the Handler ending nils out the
+// responseWriter's state field.
+type http2stream struct {
+ // immutable:
+ sc *http2serverConn
+ id uint32
+ body *http2pipe // non-nil if expecting DATA frames
+ cw http2closeWaiter // closed once the stream transitions to the closed state
+
+ // owned by serverConn's serve loop:
+ bodyBytes int64 // body bytes seen so far
+ declBodyBytes int64 // or -1 if undeclared
+ flow http2flow // limits writing from Handler to client
+ inflow http2flow // what the client is allowed to POST/etc to us
+ parent *http2stream // or nil
+ numTrailerValues int64
+ weight uint8
+ state http2streamState
+ sentReset bool // only true once detached from streams map
+ gotReset bool // only true once detached from streams map
+ gotTrailerHeader bool // HEADER frame for trailers was seen
+
+ trailer Header // accumulated trailers
+ reqTrailer Header // handler's Request.Trailer
+}
+
+func (sc *http2serverConn) Framer() *http2Framer { return sc.framer }
+
+func (sc *http2serverConn) CloseConn() error { return sc.conn.Close() }
+
+func (sc *http2serverConn) Flush() error { return sc.bw.Flush() }
+
+func (sc *http2serverConn) HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) {
+ return sc.hpackEncoder, &sc.headerWriteBuf
+}
+
+func (sc *http2serverConn) state(streamID uint32) (http2streamState, *http2stream) {
+ sc.serveG.check()
+
+ if st, ok := sc.streams[streamID]; ok {
+ return st.state, st
+ }
+
+ if streamID <= sc.maxStreamID {
+ return http2stateClosed, nil
+ }
+ return http2stateIdle, nil
+}
+
+// setConnState calls the net/http ConnState hook for this connection, if configured.
+// Note that the net/http package does StateNew and StateClosed for us.
+// There is currently no plan for StateHijacked or hijacking HTTP/2 connections.
+func (sc *http2serverConn) setConnState(state ConnState) {
+ if sc.hs.ConnState != nil {
+ sc.hs.ConnState(sc.conn, state)
+ }
+}
+
+func (sc *http2serverConn) vlogf(format string, args ...interface{}) {
+ if http2VerboseLogs {
+ sc.logf(format, args...)
+ }
+}
+
+func (sc *http2serverConn) logf(format string, args ...interface{}) {
+ if lg := sc.hs.ErrorLog; lg != nil {
+ lg.Printf(format, args...)
+ } else {
+ log.Printf(format, args...)
+ }
+}
+
+var http2uintptrType = reflect.TypeOf(uintptr(0))
+
+// errno returns v's underlying uintptr, else 0.
+//
+// TODO: remove this helper function once http2 can use build
+// tags. See comment in isClosedConnError.
+func http2errno(v error) uintptr {
+ if rv := reflect.ValueOf(v); rv.Kind() == reflect.Uintptr {
+ return uintptr(rv.Uint())
+ }
+ return 0
+}
+
+// isClosedConnError reports whether err is an error from use of a closed
+// network connection.
+func http2isClosedConnError(err error) bool {
+ if err == nil {
+ return false
+ }
+
+ str := err.Error()
+ if strings.Contains(str, "use of closed network connection") {
+ return true
+ }
+
+ if runtime.GOOS == "windows" {
+ if oe, ok := err.(*net.OpError); ok && oe.Op == "read" {
+ if se, ok := oe.Err.(*os.SyscallError); ok && se.Syscall == "wsarecv" {
+ const WSAECONNABORTED = 10053
+ const WSAECONNRESET = 10054
+ if n := http2errno(se.Err); n == WSAECONNRESET || n == WSAECONNABORTED {
+ return true
+ }
+ }
+ }
+ }
+ return false
+}
+
+func (sc *http2serverConn) condlogf(err error, format string, args ...interface{}) {
+ if err == nil {
+ return
+ }
+ if err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err) {
+
+ sc.vlogf(format, args...)
+ } else {
+ sc.logf(format, args...)
+ }
+}
+
+func (sc *http2serverConn) onNewHeaderField(f hpack.HeaderField) {
+ sc.serveG.check()
+ if http2VerboseLogs {
+ sc.vlogf("http2: server decoded %v", f)
+ }
+ switch {
+ case !http2validHeaderFieldValue(f.Value):
+ sc.req.invalidHeader = true
+ case strings.HasPrefix(f.Name, ":"):
+ if sc.req.sawRegularHeader {
+ sc.logf("pseudo-header after regular header")
+ sc.req.invalidHeader = true
+ return
+ }
+ var dst *string
+ switch f.Name {
+ case ":method":
+ dst = &sc.req.method
+ case ":path":
+ dst = &sc.req.path
+ case ":scheme":
+ dst = &sc.req.scheme
+ case ":authority":
+ dst = &sc.req.authority
+ default:
+
+ sc.logf("invalid pseudo-header %q", f.Name)
+ sc.req.invalidHeader = true
+ return
+ }
+ if *dst != "" {
+ sc.logf("duplicate pseudo-header %q sent", f.Name)
+ sc.req.invalidHeader = true
+ return
+ }
+ *dst = f.Value
+ case !http2validHeaderFieldName(f.Name):
+ sc.req.invalidHeader = true
+ default:
+ sc.req.sawRegularHeader = true
+ sc.req.header.Add(sc.canonicalHeader(f.Name), f.Value)
+ const headerFieldOverhead = 32 // per spec
+ sc.req.headerListSize += int64(len(f.Name)) + int64(len(f.Value)) + headerFieldOverhead
+ if sc.req.headerListSize > int64(sc.maxHeaderListSize()) {
+ sc.hpackDecoder.SetEmitEnabled(false)
+ }
+ }
+}
+
+func (st *http2stream) onNewTrailerField(f hpack.HeaderField) {
+ sc := st.sc
+ sc.serveG.check()
+ if http2VerboseLogs {
+ sc.vlogf("http2: server decoded trailer %v", f)
+ }
+ switch {
+ case strings.HasPrefix(f.Name, ":"):
+ sc.req.invalidHeader = true
+ return
+ case !http2validHeaderFieldName(f.Name) || !http2validHeaderFieldValue(f.Value):
+ sc.req.invalidHeader = true
+ return
+ default:
+ key := sc.canonicalHeader(f.Name)
+ if st.trailer != nil {
+ vv := append(st.trailer[key], f.Value)
+ st.trailer[key] = vv
+
+ // arbitrary; TODO: read spec about header list size limits wrt trailers
+ const tooBig = 1000
+ if len(vv) >= tooBig {
+ sc.hpackDecoder.SetEmitEnabled(false)
+ }
+ }
+ }
+}
+
+func (sc *http2serverConn) canonicalHeader(v string) string {
+ sc.serveG.check()
+ cv, ok := http2commonCanonHeader[v]
+ if ok {
+ return cv
+ }
+ cv, ok = sc.canonHeader[v]
+ if ok {
+ return cv
+ }
+ if sc.canonHeader == nil {
+ sc.canonHeader = make(map[string]string)
+ }
+ cv = CanonicalHeaderKey(v)
+ sc.canonHeader[v] = cv
+ return cv
+}
+
+type http2readFrameResult struct {
+ f http2Frame // valid until readMore is called
+ err error
+
+ // readMore should be called once the consumer no longer needs or
+ // retains f. After readMore, f is invalid and more frames can be
+ // read.
+ readMore func()
+}
+
+// readFrames is the loop that reads incoming frames.
+// It takes care to only read one frame at a time, blocking until the
+// consumer is done with the frame.
+// It's run on its own goroutine.
+func (sc *http2serverConn) readFrames() {
+ gate := make(http2gate)
+ for {
+ f, err := sc.framer.ReadFrame()
+ select {
+ case sc.readFrameCh <- http2readFrameResult{f, err, gate.Done}:
+ case <-sc.doneServing:
+ return
+ }
+ select {
+ case <-gate:
+ case <-sc.doneServing:
+ return
+ }
+ if http2terminalReadFrameError(err) {
+ return
+ }
+ }
+}
+
+// frameWriteResult is the message passed from writeFrameAsync to the serve goroutine.
+type http2frameWriteResult struct {
+ wm http2frameWriteMsg // what was written (or attempted)
+ err error // result of the writeFrame call
+}
+
+// writeFrameAsync runs in its own goroutine and writes a single frame
+// and then reports when it's done.
+// At most one goroutine can be running writeFrameAsync at a time per
+// serverConn.
+func (sc *http2serverConn) writeFrameAsync(wm http2frameWriteMsg) {
+ err := wm.write.writeFrame(sc)
+ sc.wroteFrameCh <- http2frameWriteResult{wm, err}
+}
+
+func (sc *http2serverConn) closeAllStreamsOnConnClose() {
+ sc.serveG.check()
+ for _, st := range sc.streams {
+ sc.closeStream(st, http2errClientDisconnected)
+ }
+}
+
+func (sc *http2serverConn) stopShutdownTimer() {
+ sc.serveG.check()
+ if t := sc.shutdownTimer; t != nil {
+ t.Stop()
+ }
+}
+
+func (sc *http2serverConn) notePanic() {
+
+ if http2testHookOnPanicMu != nil {
+ http2testHookOnPanicMu.Lock()
+ defer http2testHookOnPanicMu.Unlock()
+ }
+ if http2testHookOnPanic != nil {
+ if e := recover(); e != nil {
+ if http2testHookOnPanic(sc, e) {
+ panic(e)
+ }
+ }
+ }
+}
+
+func (sc *http2serverConn) serve() {
+ sc.serveG.check()
+ defer sc.notePanic()
+ defer sc.conn.Close()
+ defer sc.closeAllStreamsOnConnClose()
+ defer sc.stopShutdownTimer()
+ defer close(sc.doneServing)
+
+ if http2VerboseLogs {
+ sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs)
+ }
+
+ sc.writeFrame(http2frameWriteMsg{
+ write: http2writeSettings{
+ {http2SettingMaxFrameSize, sc.srv.maxReadFrameSize()},
+ {http2SettingMaxConcurrentStreams, sc.advMaxStreams},
+ {http2SettingMaxHeaderListSize, sc.maxHeaderListSize()},
+ },
+ })
+ sc.unackedSettings++
+
+ if err := sc.readPreface(); err != nil {
+ sc.condlogf(err, "http2: server: error reading preface from client %v: %v", sc.conn.RemoteAddr(), err)
+ return
+ }
+
+ sc.setConnState(StateActive)
+ sc.setConnState(StateIdle)
+
+ go sc.readFrames()
+
+ settingsTimer := time.NewTimer(http2firstSettingsTimeout)
+ loopNum := 0
+ for {
+ loopNum++
+ select {
+ case wm := <-sc.wantWriteFrameCh:
+ sc.writeFrame(wm)
+ case res := <-sc.wroteFrameCh:
+ sc.wroteFrame(res)
+ case res := <-sc.readFrameCh:
+ if !sc.processFrameFromReader(res) {
+ return
+ }
+ res.readMore()
+ if settingsTimer.C != nil {
+ settingsTimer.Stop()
+ settingsTimer.C = nil
+ }
+ case m := <-sc.bodyReadCh:
+ sc.noteBodyRead(m.st, m.n)
+ case <-settingsTimer.C:
+ sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr())
+ return
+ case <-sc.shutdownTimerCh:
+ sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr())
+ return
+ case fn := <-sc.testHookCh:
+ fn(loopNum)
+ }
+ }
+}
+
+// readPreface reads the ClientPreface greeting from the peer
+// or returns an error on timeout or an invalid greeting.
+func (sc *http2serverConn) readPreface() error {
+ errc := make(chan error, 1)
+ go func() {
+
+ buf := make([]byte, len(http2ClientPreface))
+ if _, err := io.ReadFull(sc.conn, buf); err != nil {
+ errc <- err
+ } else if !bytes.Equal(buf, http2clientPreface) {
+ errc <- fmt.Errorf("bogus greeting %q", buf)
+ } else {
+ errc <- nil
+ }
+ }()
+ timer := time.NewTimer(http2prefaceTimeout)
+ defer timer.Stop()
+ select {
+ case <-timer.C:
+ return errors.New("timeout waiting for client preface")
+ case err := <-errc:
+ if err == nil {
+ if http2VerboseLogs {
+ sc.vlogf("http2: server: client %v said hello", sc.conn.RemoteAddr())
+ }
+ }
+ return err
+ }
+}
+
+var http2errChanPool = sync.Pool{
+ New: func() interface{} { return make(chan error, 1) },
+}
+
+var http2writeDataPool = sync.Pool{
+ New: func() interface{} { return new(http2writeData) },
+}
+
+// writeDataFromHandler writes DATA response frames from a handler on
+// the given stream.
+func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte, endStream bool) error {
+ ch := http2errChanPool.Get().(chan error)
+ writeArg := http2writeDataPool.Get().(*http2writeData)
+ *writeArg = http2writeData{stream.id, data, endStream}
+ err := sc.writeFrameFromHandler(http2frameWriteMsg{
+ write: writeArg,
+ stream: stream,
+ done: ch,
+ })
+ if err != nil {
+ return err
+ }
+ var frameWriteDone bool // the frame write is done (successfully or not)
+ select {
+ case err = <-ch:
+ frameWriteDone = true
+ case <-sc.doneServing:
+ return http2errClientDisconnected
+ case <-stream.cw:
+
+ select {
+ case err = <-ch:
+ frameWriteDone = true
+ default:
+ return http2errStreamClosed
+ }
+ }
+ http2errChanPool.Put(ch)
+ if frameWriteDone {
+ http2writeDataPool.Put(writeArg)
+ }
+ return err
+}
+
+// writeFrameFromHandler sends wm to sc.wantWriteFrameCh, but aborts
+// if the connection has gone away.
+//
+// This must not be run from the serve goroutine itself, else it might
+// deadlock writing to sc.wantWriteFrameCh (which is only mildly
+// buffered and is read by serve itself). If you're on the serve
+// goroutine, call writeFrame instead.
+func (sc *http2serverConn) writeFrameFromHandler(wm http2frameWriteMsg) error {
+ sc.serveG.checkNotOn()
+ select {
+ case sc.wantWriteFrameCh <- wm:
+ return nil
+ case <-sc.doneServing:
+
+ return http2errClientDisconnected
+ }
+}
+
+// writeFrame schedules a frame to write and sends it if there's nothing
+// already being written.
+//
+// There is no pushback here (the serve goroutine never blocks). It's
+// the http.Handlers that block, waiting for their previous frames to
+// make it onto the wire.
+//
+// If you're not on the serve goroutine, use writeFrameFromHandler instead.
+func (sc *http2serverConn) writeFrame(wm http2frameWriteMsg) {
+ sc.serveG.check()
+ sc.writeSched.add(wm)
+ sc.scheduleFrameWrite()
+}
+
+// startFrameWrite starts a goroutine to write wm (in a separate
+// goroutine since that might block on the network), and updates the
+// serve goroutine's state about the world from the info in wm.
+func (sc *http2serverConn) startFrameWrite(wm http2frameWriteMsg) {
+ sc.serveG.check()
+ if sc.writingFrame {
+ panic("internal error: can only be writing one frame at a time")
+ }
+
+ st := wm.stream
+ if st != nil {
+ switch st.state {
+ case http2stateHalfClosedLocal:
+ panic("internal error: attempt to send frame on half-closed-local stream")
+ case http2stateClosed:
+ if st.sentReset || st.gotReset {
+
+ sc.scheduleFrameWrite()
+ return
+ }
+ panic(fmt.Sprintf("internal error: attempt to send a write %v on a closed stream", wm))
+ }
+ }
+
+ sc.writingFrame = true
+ sc.needsFrameFlush = true
+ go sc.writeFrameAsync(wm)
+}
+
+// errHandlerPanicked is the error given to any callers blocked in a read from
+// Request.Body when the main goroutine panics. Since most handlers read in
+// the main ServeHTTP goroutine, this will show up rarely.
+var http2errHandlerPanicked = errors.New("http2: handler panicked")
+
+// wroteFrame is called on the serve goroutine with the result of
+// whatever happened on writeFrameAsync.
+func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) {
+ sc.serveG.check()
+ if !sc.writingFrame {
+ panic("internal error: expected to be already writing a frame")
+ }
+ sc.writingFrame = false
+
+ wm := res.wm
+ st := wm.stream
+
+ closeStream := http2endsStream(wm.write)
+
+ if _, ok := wm.write.(http2handlerPanicRST); ok {
+ sc.closeStream(st, http2errHandlerPanicked)
+ }
+
+ if ch := wm.done; ch != nil {
+ select {
+ case ch <- res.err:
+ default:
+ panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wm.write))
+ }
+ }
+ wm.write = nil
+
+ if closeStream {
+ if st == nil {
+ panic("internal error: expecting non-nil stream")
+ }
+ switch st.state {
+ case http2stateOpen:
+
+ st.state = http2stateHalfClosedLocal
+ errCancel := http2StreamError{st.id, http2ErrCodeCancel}
+ sc.resetStream(errCancel)
+ case http2stateHalfClosedRemote:
+ sc.closeStream(st, http2errHandlerComplete)
+ }
+ }
+
+ sc.scheduleFrameWrite()
+}
+
+// scheduleFrameWrite tickles the frame writing scheduler.
+//
+// If a frame is already being written, nothing happens. This will be called again
+// when the frame is done being written.
+//
+// If a frame isn't being written and we need to send one, the best
+// frame to send is selected, preferring first things that aren't
+// stream-specific (e.g. ACKing settings), and then the
+// highest-priority stream.
+//
+// If a frame isn't being written and there's nothing else to send, we
+// flush the write buffer.
+func (sc *http2serverConn) scheduleFrameWrite() {
+ sc.serveG.check()
+ if sc.writingFrame {
+ return
+ }
+ if sc.needToSendGoAway {
+ sc.needToSendGoAway = false
+ sc.startFrameWrite(http2frameWriteMsg{
+ write: &http2writeGoAway{
+ maxStreamID: sc.maxStreamID,
+ code: sc.goAwayCode,
+ },
+ })
+ return
+ }
+ if sc.needToSendSettingsAck {
+ sc.needToSendSettingsAck = false
+ sc.startFrameWrite(http2frameWriteMsg{write: http2writeSettingsAck{}})
+ return
+ }
+ if !sc.inGoAway {
+ if wm, ok := sc.writeSched.take(); ok {
+ sc.startFrameWrite(wm)
+ return
+ }
+ }
+ if sc.needsFrameFlush {
+ sc.startFrameWrite(http2frameWriteMsg{write: http2flushFrameWriter{}})
+ sc.needsFrameFlush = false
+ return
+ }
+}
+
+func (sc *http2serverConn) goAway(code http2ErrCode) {
+ sc.serveG.check()
+ if sc.inGoAway {
+ return
+ }
+ if code != http2ErrCodeNo {
+ sc.shutDownIn(250 * time.Millisecond)
+ } else {
+
+ sc.shutDownIn(1 * time.Second)
+ }
+ sc.inGoAway = true
+ sc.needToSendGoAway = true
+ sc.goAwayCode = code
+ sc.scheduleFrameWrite()
+}
+
+func (sc *http2serverConn) shutDownIn(d time.Duration) {
+ sc.serveG.check()
+ sc.shutdownTimer = time.NewTimer(d)
+ sc.shutdownTimerCh = sc.shutdownTimer.C
+}
+
+func (sc *http2serverConn) resetStream(se http2StreamError) {
+ sc.serveG.check()
+ sc.writeFrame(http2frameWriteMsg{write: se})
+ if st, ok := sc.streams[se.StreamID]; ok {
+ st.sentReset = true
+ sc.closeStream(st, se)
+ }
+}
+
+// processFrameFromReader processes the serve loop's read from readFrameCh from the
+// frame-reading goroutine.
+// processFrameFromReader returns whether the connection should be kept open.
+func (sc *http2serverConn) processFrameFromReader(res http2readFrameResult) bool {
+ sc.serveG.check()
+ err := res.err
+ if err != nil {
+ if err == http2ErrFrameTooLarge {
+ sc.goAway(http2ErrCodeFrameSize)
+ return true
+ }
+ clientGone := err == io.EOF || err == io.ErrUnexpectedEOF || http2isClosedConnError(err)
+ if clientGone {
+
+ return false
+ }
+ } else {
+ f := res.f
+ if http2VerboseLogs {
+ sc.vlogf("http2: server read frame %v", http2summarizeFrame(f))
+ }
+ err = sc.processFrame(f)
+ if err == nil {
+ return true
+ }
+ }
+
+ switch ev := err.(type) {
+ case http2StreamError:
+ sc.resetStream(ev)
+ return true
+ case http2goAwayFlowError:
+ sc.goAway(http2ErrCodeFlowControl)
+ return true
+ case http2ConnectionError:
+ sc.logf("http2: server connection error from %v: %v", sc.conn.RemoteAddr(), ev)
+ sc.goAway(http2ErrCode(ev))
+ return true
+ default:
+ if res.err != nil {
+ sc.vlogf("http2: server closing client connection; error reading frame from client %s: %v", sc.conn.RemoteAddr(), err)
+ } else {
+ sc.logf("http2: server closing client connection: %v", err)
+ }
+ return false
+ }
+}
+
+func (sc *http2serverConn) processFrame(f http2Frame) error {
+ sc.serveG.check()
+
+ if !sc.sawFirstSettings {
+ if _, ok := f.(*http2SettingsFrame); !ok {
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
+ sc.sawFirstSettings = true
+ }
+
+ switch f := f.(type) {
+ case *http2SettingsFrame:
+ return sc.processSettings(f)
+ case *http2HeadersFrame:
+ return sc.processHeaders(f)
+ case *http2ContinuationFrame:
+ return sc.processContinuation(f)
+ case *http2WindowUpdateFrame:
+ return sc.processWindowUpdate(f)
+ case *http2PingFrame:
+ return sc.processPing(f)
+ case *http2DataFrame:
+ return sc.processData(f)
+ case *http2RSTStreamFrame:
+ return sc.processResetStream(f)
+ case *http2PriorityFrame:
+ return sc.processPriority(f)
+ case *http2PushPromiseFrame:
+
+ return http2ConnectionError(http2ErrCodeProtocol)
+ default:
+ sc.vlogf("http2: server ignoring frame: %v", f.Header())
+ return nil
+ }
+}
+
+func (sc *http2serverConn) processPing(f *http2PingFrame) error {
+ sc.serveG.check()
+ if f.IsAck() {
+
+ return nil
+ }
+ if f.StreamID != 0 {
+
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
+ sc.writeFrame(http2frameWriteMsg{write: http2writePingAck{f}})
+ return nil
+}
+
+func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error {
+ sc.serveG.check()
+ switch {
+ case f.StreamID != 0:
+ st := sc.streams[f.StreamID]
+ if st == nil {
+
+ return nil
+ }
+ if !st.flow.add(int32(f.Increment)) {
+ return http2StreamError{f.StreamID, http2ErrCodeFlowControl}
+ }
+ default:
+ if !sc.flow.add(int32(f.Increment)) {
+ return http2goAwayFlowError{}
+ }
+ }
+ sc.scheduleFrameWrite()
+ return nil
+}
+
+func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error {
+ sc.serveG.check()
+
+ state, st := sc.state(f.StreamID)
+ if state == http2stateIdle {
+
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
+ if st != nil {
+ st.gotReset = true
+ sc.closeStream(st, http2StreamError{f.StreamID, f.ErrCode})
+ }
+ return nil
+}
+
+func (sc *http2serverConn) closeStream(st *http2stream, err error) {
+ sc.serveG.check()
+ if st.state == http2stateIdle || st.state == http2stateClosed {
+ panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state))
+ }
+ st.state = http2stateClosed
+ sc.curOpenStreams--
+ if sc.curOpenStreams == 0 {
+ sc.setConnState(StateIdle)
+ }
+ delete(sc.streams, st.id)
+ if p := st.body; p != nil {
+ p.CloseWithError(err)
+ }
+ st.cw.Close()
+ sc.writeSched.forgetStream(st.id)
+}
+
+func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error {
+ sc.serveG.check()
+ if f.IsAck() {
+ sc.unackedSettings--
+ if sc.unackedSettings < 0 {
+
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
+ return nil
+ }
+ if err := f.ForeachSetting(sc.processSetting); err != nil {
+ return err
+ }
+ sc.needToSendSettingsAck = true
+ sc.scheduleFrameWrite()
+ return nil
+}
+
+func (sc *http2serverConn) processSetting(s http2Setting) error {
+ sc.serveG.check()
+ if err := s.Valid(); err != nil {
+ return err
+ }
+ if http2VerboseLogs {
+ sc.vlogf("http2: server processing setting %v", s)
+ }
+ switch s.ID {
+ case http2SettingHeaderTableSize:
+ sc.headerTableSize = s.Val
+ sc.hpackEncoder.SetMaxDynamicTableSize(s.Val)
+ case http2SettingEnablePush:
+ sc.pushEnabled = s.Val != 0
+ case http2SettingMaxConcurrentStreams:
+ sc.clientMaxStreams = s.Val
+ case http2SettingInitialWindowSize:
+ return sc.processSettingInitialWindowSize(s.Val)
+ case http2SettingMaxFrameSize:
+ sc.writeSched.maxFrameSize = s.Val
+ case http2SettingMaxHeaderListSize:
+ sc.peerMaxHeaderListSize = s.Val
+ default:
+
+ if http2VerboseLogs {
+ sc.vlogf("http2: server ignoring unknown setting %v", s)
+ }
+ }
+ return nil
+}
+
+func (sc *http2serverConn) processSettingInitialWindowSize(val uint32) error {
+ sc.serveG.check()
+
+ old := sc.initialWindowSize
+ sc.initialWindowSize = int32(val)
+ growth := sc.initialWindowSize - old
+ for _, st := range sc.streams {
+ if !st.flow.add(growth) {
+
+ return http2ConnectionError(http2ErrCodeFlowControl)
+ }
+ }
+ return nil
+}
+
+func (sc *http2serverConn) processData(f *http2DataFrame) error {
+ sc.serveG.check()
+
+ id := f.Header().StreamID
+ st, ok := sc.streams[id]
+ if !ok || st.state != http2stateOpen || st.gotTrailerHeader {
+
+ return http2StreamError{id, http2ErrCodeStreamClosed}
+ }
+ if st.body == nil {
+ panic("internal error: should have a body in this state")
+ }
+ data := f.Data()
+
+ if st.declBodyBytes != -1 && st.bodyBytes+int64(len(data)) > st.declBodyBytes {
+ st.body.CloseWithError(fmt.Errorf("sender tried to send more than declared Content-Length of %d bytes", st.declBodyBytes))
+ return http2StreamError{id, http2ErrCodeStreamClosed}
+ }
+ if len(data) > 0 {
+
+ if int(st.inflow.available()) < len(data) {
+ return http2StreamError{id, http2ErrCodeFlowControl}
+ }
+ st.inflow.take(int32(len(data)))
+ wrote, err := st.body.Write(data)
+ if err != nil {
+ return http2StreamError{id, http2ErrCodeStreamClosed}
+ }
+ if wrote != len(data) {
+ panic("internal error: bad Writer")
+ }
+ st.bodyBytes += int64(len(data))
+ }
+ if f.StreamEnded() {
+ st.endStream()
+ }
+ return nil
+}
+
+// endStream closes a Request.Body's pipe. It is called when a DATA
+// frame says a request body is over (or after trailers).
+func (st *http2stream) endStream() {
+ sc := st.sc
+ sc.serveG.check()
+
+ if st.declBodyBytes != -1 && st.declBodyBytes != st.bodyBytes {
+ st.body.CloseWithError(fmt.Errorf("request declared a Content-Length of %d but only wrote %d bytes",
+ st.declBodyBytes, st.bodyBytes))
+ } else {
+ st.body.closeWithErrorAndCode(io.EOF, st.copyTrailersToHandlerRequest)
+ st.body.CloseWithError(io.EOF)
+ }
+ st.state = http2stateHalfClosedRemote
+}
+
+// copyTrailersToHandlerRequest is run in the Handler's goroutine in
+// its Request.Body.Read just before it gets io.EOF.
+func (st *http2stream) copyTrailersToHandlerRequest() {
+ for k, vv := range st.trailer {
+ if _, ok := st.reqTrailer[k]; ok {
+
+ st.reqTrailer[k] = vv
+ }
+ }
+}
+
+func (sc *http2serverConn) processHeaders(f *http2HeadersFrame) error {
+ sc.serveG.check()
+ id := f.Header().StreamID
+ if sc.inGoAway {
+
+ return nil
+ }
+
+ if id%2 != 1 {
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
+
+ st := sc.streams[f.Header().StreamID]
+ if st != nil {
+ return st.processTrailerHeaders(f)
+ }
+
+ if id <= sc.maxStreamID || sc.req.stream != nil {
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
+
+ if id > sc.maxStreamID {
+ sc.maxStreamID = id
+ }
+ st = &http2stream{
+ sc: sc,
+ id: id,
+ state: http2stateOpen,
+ }
+ if f.StreamEnded() {
+ st.state = http2stateHalfClosedRemote
+ }
+ st.cw.Init()
+
+ st.flow.conn = &sc.flow
+ st.flow.add(sc.initialWindowSize)
+ st.inflow.conn = &sc.inflow
+ st.inflow.add(http2initialWindowSize)
+
+ sc.streams[id] = st
+ if f.HasPriority() {
+ http2adjustStreamPriority(sc.streams, st.id, f.Priority)
+ }
+ sc.curOpenStreams++
+ if sc.curOpenStreams == 1 {
+ sc.setConnState(StateActive)
+ }
+ sc.req = http2requestParam{
+ stream: st,
+ header: make(Header),
+ }
+ sc.hpackDecoder.SetEmitFunc(sc.onNewHeaderField)
+ sc.hpackDecoder.SetEmitEnabled(true)
+ return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
+}
+
+func (st *http2stream) processTrailerHeaders(f *http2HeadersFrame) error {
+ sc := st.sc
+ sc.serveG.check()
+ if st.gotTrailerHeader {
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
+ st.gotTrailerHeader = true
+ if !f.StreamEnded() {
+ return http2StreamError{st.id, http2ErrCodeProtocol}
+ }
+ sc.resetPendingRequest()
+ return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
+}
+
+func (sc *http2serverConn) processContinuation(f *http2ContinuationFrame) error {
+ sc.serveG.check()
+ st := sc.streams[f.Header().StreamID]
+ if st.gotTrailerHeader {
+ return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded())
+ }
+ return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded())
+}
+
+func (sc *http2serverConn) processHeaderBlockFragment(st *http2stream, frag []byte, end bool) error {
+ sc.serveG.check()
+ if _, err := sc.hpackDecoder.Write(frag); err != nil {
+ return http2ConnectionError(http2ErrCodeCompression)
+ }
+ if !end {
+ return nil
+ }
+ if err := sc.hpackDecoder.Close(); err != nil {
+ return http2ConnectionError(http2ErrCodeCompression)
+ }
+ defer sc.resetPendingRequest()
+ if sc.curOpenStreams > sc.advMaxStreams {
+
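+ // The peer opened more concurrent streams than our advertised SETTINGS_MAX_CONCURRENT_STREAMS.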
+ if sc.unackedSettings == 0 {
+
+ return http2StreamError{st.id, http2ErrCodeProtocol}
+ }
+
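+ // The peer may not have processed our SETTINGS yet, so refuse the stream rather than treating it as a protocol error.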
+ return http2StreamError{st.id, http2ErrCodeRefusedStream}
+ }
+
+ rw, req, err := sc.newWriterAndRequest()
+ if err != nil {
+ return err
+ }
+ st.reqTrailer = req.Trailer
+ if st.reqTrailer != nil {
+ st.trailer = make(Header)
+ }
+ st.body = req.Body.(*http2requestBody).pipe
+ st.declBodyBytes = req.ContentLength
+
+ handler := sc.handler.ServeHTTP
+ if !sc.hpackDecoder.EmitEnabled() {
+
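+ // The request's header list exceeded our advertised limit; reply with 431 instead of running the user's handler.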
+ handler = http2handleHeaderListTooLong
+ }
+
+ go sc.runHandler(rw, req, handler)
+ return nil
+}
+
+func (st *http2stream) processTrailerHeaderBlockFragment(frag []byte, end bool) error {
+ sc := st.sc
+ sc.serveG.check()
+ sc.hpackDecoder.SetEmitFunc(st.onNewTrailerField)
+ if _, err := sc.hpackDecoder.Write(frag); err != nil {
+ return http2ConnectionError(http2ErrCodeCompression)
+ }
+ if !end {
+ return nil
+ }
+
+ rp := &sc.req
+ if rp.invalidHeader {
+ return http2StreamError{rp.stream.id, http2ErrCodeProtocol}
+ }
+
+ err := sc.hpackDecoder.Close()
+ st.endStream()
+ if err != nil {
+ return http2ConnectionError(http2ErrCodeCompression)
+ }
+ return nil
+}
+
+func (sc *http2serverConn) processPriority(f *http2PriorityFrame) error {
+ http2adjustStreamPriority(sc.streams, f.StreamID, f.http2PriorityParam)
+ return nil
+}
+
+func http2adjustStreamPriority(streams map[uint32]*http2stream, streamID uint32, priority http2PriorityParam) {
+ st, ok := streams[streamID]
+ if !ok {
+
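+ // Unknown stream; ignore the priority update.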
+ return
+ }
+ st.weight = priority.Weight
+ parent := streams[priority.StreamDep]
+ if parent == st {
+
+ return
+ }
+
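+ // If st is an ancestor of the new parent, re-parent the new parent onto st's old parent first to avoid a dependency cycle.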
+ for piter := parent; piter != nil; piter = piter.parent {
+ if piter == st {
+ parent.parent = st.parent
+ break
+ }
+ }
+ st.parent = parent
+ if priority.Exclusive && (st.parent != nil || priority.StreamDep == 0) {
+ for _, openStream := range streams {
+ if openStream != st && openStream.parent == st.parent {
+ openStream.parent = st
+ }
+ }
+ }
+}
+
+// resetPendingRequest zeros out all state related to a HEADERS frame
+// and its zero or more CONTINUATION frames sent to start a new
+// request.
+func (sc *http2serverConn) resetPendingRequest() {
+ sc.serveG.check()
+ sc.req = http2requestParam{}
+}
+
+func (sc *http2serverConn) newWriterAndRequest() (*http2responseWriter, *Request, error) {
+ sc.serveG.check()
+ rp := &sc.req
+
+ if rp.invalidHeader {
+ return nil, nil, http2StreamError{rp.stream.id, http2ErrCodeProtocol}
+ }
+
+ isConnect := rp.method == "CONNECT"
+ if isConnect {
+ if rp.path != "" || rp.scheme != "" || rp.authority == "" {
+ return nil, nil, http2StreamError{rp.stream.id, http2ErrCodeProtocol}
+ }
+ } else if rp.method == "" || rp.path == "" ||
+ (rp.scheme != "https" && rp.scheme != "http") {
+
+ return nil, nil, http2StreamError{rp.stream.id, http2ErrCodeProtocol}
+ }
+
+ bodyOpen := rp.stream.state == http2stateOpen
+ if rp.method == "HEAD" && bodyOpen {
+
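+ // Reject HEAD requests that carry a request body.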
+ return nil, nil, http2StreamError{rp.stream.id, http2ErrCodeProtocol}
+ }
+ var tlsState *tls.ConnectionState // nil if not scheme https
+
+ if rp.scheme == "https" {
+ tlsState = sc.tlsState
+ }
+ authority := rp.authority
+ if authority == "" {
+ authority = rp.header.Get("Host")
+ }
+ needsContinue := rp.header.Get("Expect") == "100-continue"
+ if needsContinue {
+ rp.header.Del("Expect")
+ }
+
+ if cookies := rp.header["Cookie"]; len(cookies) > 1 {
+ rp.header.Set("Cookie", strings.Join(cookies, "; "))
+ }
+
+ // Set up trailers.
+ var trailer Header
+ for _, v := range rp.header["Trailer"] {
+ for _, key := range strings.Split(v, ",") {
+ key = CanonicalHeaderKey(strings.TrimSpace(key))
+ switch key {
+ case "Transfer-Encoding", "Trailer", "Content-Length":
+
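+ // These headers are not allowed as trailers; ignore them.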
+ default:
+ if trailer == nil {
+ trailer = make(Header)
+ }
+ trailer[key] = nil
+ }
+ }
+ }
+ delete(rp.header, "Trailer")
+
+ body := &http2requestBody{
+ conn: sc,
+ stream: rp.stream,
+ needsContinue: needsContinue,
+ }
+ var url_ *url.URL
+ var requestURI string
+ if isConnect {
+ url_ = &url.URL{Host: rp.authority}
+ requestURI = rp.authority
+ } else {
+ var err error
+ url_, err = url.ParseRequestURI(rp.path)
+ if err != nil {
+ return nil, nil, http2StreamError{rp.stream.id, http2ErrCodeProtocol}
+ }
+ requestURI = rp.path
+ }
+ req := &Request{
+ Method: rp.method,
+ URL: url_,
+ RemoteAddr: sc.remoteAddrStr,
+ Header: rp.header,
+ RequestURI: requestURI,
+ Proto: "HTTP/2.0",
+ ProtoMajor: 2,
+ ProtoMinor: 0,
+ TLS: tlsState,
+ Host: authority,
+ Body: body,
+ Trailer: trailer,
+ }
+ if bodyOpen {
+ body.pipe = &http2pipe{
+ b: &http2fixedBuffer{buf: make([]byte, http2initialWindowSize)},
+ }
+
+ if vv, ok := rp.header["Content-Length"]; ok {
+ req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64)
+ } else {
+ req.ContentLength = -1
+ }
+ }
+
+ rws := http2responseWriterStatePool.Get().(*http2responseWriterState)
+ bwSave := rws.bw
+ *rws = http2responseWriterState{}
+ rws.conn = sc
+ rws.bw = bwSave
+ rws.bw.Reset(http2chunkWriter{rws})
+ rws.stream = rp.stream
+ rws.req = req
+ rws.body = body
+
+ rw := &http2responseWriter{rws: rws}
+ return rw, req, nil
+}
+
+// Run on its own goroutine.
+func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *Request, handler func(ResponseWriter, *Request)) {
+ didPanic := true
+ defer func() {
+ if didPanic {
+ e := recover()
+ // Same as net/http:
+ const size = 64 << 10
+ buf := make([]byte, size)
+ buf = buf[:runtime.Stack(buf, false)]
+ sc.writeFrameFromHandler(http2frameWriteMsg{
+ write: http2handlerPanicRST{rw.rws.stream.id},
+ stream: rw.rws.stream,
+ })
+ sc.logf("http2: panic serving %v: %v\n%s", sc.conn.RemoteAddr(), e, buf)
+ return
+ }
+ rw.handlerDone()
+ }()
+ handler(rw, req)
+ didPanic = false
+}
+
+func http2handleHeaderListTooLong(w ResponseWriter, r *Request) {
+ // 10.5.1 Limits on Header Block Size:
+ // .. "A server that receives a larger header block than it is
+ // willing to handle can send an HTTP 431 (Request Header Fields Too
+ // Large) status code"
+ const statusRequestHeaderFieldsTooLarge = 431 // only in Go 1.6+
+ w.WriteHeader(statusRequestHeaderFieldsTooLarge)
+ io.WriteString(w, "<h1>HTTP Error 431</h1><p>Request Header Field(s) Too Large</p>")
+}
+
+// called from handler goroutines.
+// h may be nil.
+func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeResHeaders) error {
+ sc.serveG.checkNotOn()
+ var errc chan error
+ if headerData.h != nil {
+
+ errc = http2errChanPool.Get().(chan error)
+ }
+ if err := sc.writeFrameFromHandler(http2frameWriteMsg{
+ write: headerData,
+ stream: st,
+ done: errc,
+ }); err != nil {
+ return err
+ }
+ if errc != nil {
+ select {
+ case err := <-errc:
+ http2errChanPool.Put(errc)
+ return err
+ case <-sc.doneServing:
+ return http2errClientDisconnected
+ case <-st.cw:
+ return http2errStreamClosed
+ }
+ }
+ return nil
+}
+
+// called from handler goroutines.
+func (sc *http2serverConn) write100ContinueHeaders(st *http2stream) {
+ sc.writeFrameFromHandler(http2frameWriteMsg{
+ write: http2write100ContinueHeadersFrame{st.id},
+ stream: st,
+ })
+}
+
+// A bodyReadMsg tells the server loop that the http.Handler read n
+// bytes of the DATA from the client on the given stream.
+type http2bodyReadMsg struct {
+ st *http2stream
+ n int
+}
+
+// called from handler goroutines.
+// Notes that the handler for the given stream ID read n bytes of its body
+// and schedules flow control tokens to be sent.
+func (sc *http2serverConn) noteBodyReadFromHandler(st *http2stream, n int) {
+ sc.serveG.checkNotOn()
+ select {
+ case sc.bodyReadCh <- http2bodyReadMsg{st, n}:
+ case <-sc.doneServing:
+ }
+}
+
+func (sc *http2serverConn) noteBodyRead(st *http2stream, n int) {
+ sc.serveG.check()
+ sc.sendWindowUpdate(nil, n)
+ if st.state != http2stateHalfClosedRemote && st.state != http2stateClosed {
+
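+ // Only refill the stream-level window while the stream can still receive DATA.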
+ sc.sendWindowUpdate(st, n)
+ }
+}
+
+// st may be nil for conn-level
+func (sc *http2serverConn) sendWindowUpdate(st *http2stream, n int) {
+ sc.serveG.check()
+ // "The legal range for the increment to the flow control
+ // window is 1 to 2^31-1 (2,147,483,647) octets."
+ // A Go Read call on 64-bit machines could in theory read
+ // a larger Read than this. Very unlikely, but we handle it here
+ // rather than elsewhere for now.
+ const maxUint31 = 1<<31 - 1
+ for n >= maxUint31 {
+ sc.sendWindowUpdate32(st, maxUint31)
+ n -= maxUint31
+ }
+ sc.sendWindowUpdate32(st, int32(n))
+}
+
+// st may be nil for conn-level
+func (sc *http2serverConn) sendWindowUpdate32(st *http2stream, n int32) {
+ sc.serveG.check()
+ if n == 0 {
+ return
+ }
+ if n < 0 {
+ panic("negative update")
+ }
+ var streamID uint32
+ if st != nil {
+ streamID = st.id
+ }
+ sc.writeFrame(http2frameWriteMsg{
+ write: http2writeWindowUpdate{streamID: streamID, n: uint32(n)},
+ stream: st,
+ })
+ var ok bool
+ if st == nil {
+ ok = sc.inflow.add(n)
+ } else {
+ ok = st.inflow.add(n)
+ }
+ if !ok {
+ panic("internal error; sent too many window updates without decrements?")
+ }
+}
+
+type http2requestBody struct {
+ stream *http2stream
+ conn *http2serverConn
+ closed bool
+ pipe *http2pipe // non-nil if we have a HTTP entity message body
+ needsContinue bool // need to send a 100-continue
+}
+
+func (b *http2requestBody) Close() error {
+ if b.pipe != nil {
+ b.pipe.CloseWithError(http2errClosedBody)
+ }
+ b.closed = true
+ return nil
+}
+
+func (b *http2requestBody) Read(p []byte) (n int, err error) {
+ if b.needsContinue {
+ b.needsContinue = false
+ b.conn.write100ContinueHeaders(b.stream)
+ }
+ if b.pipe == nil {
+ return 0, io.EOF
+ }
+ n, err = b.pipe.Read(p)
+ if n > 0 {
+ b.conn.noteBodyReadFromHandler(b.stream, n)
+ }
+ return
+}
+
+// responseWriter is the http.ResponseWriter implementation. It's
+// intentionally small (1 pointer wide) to minimize garbage. The
+// responseWriterState pointer inside is zeroed at the end of a
+// request (in handlerDone) and calls on the responseWriter thereafter
+// simply crash (caller's mistake), but the much larger responseWriterState
+// and buffers are reused between multiple requests.
+type http2responseWriter struct {
+ rws *http2responseWriterState
+}
+
+// Optional http.ResponseWriter interfaces implemented.
+var (
+ _ CloseNotifier = (*http2responseWriter)(nil)
+ _ Flusher = (*http2responseWriter)(nil)
+ _ http2stringWriter = (*http2responseWriter)(nil)
+)
+
+type http2responseWriterState struct {
+ // immutable within a request:
+ stream *http2stream
+ req *Request
+ body *http2requestBody // to close at end of request, if DATA frames didn't
+ conn *http2serverConn
+
+ // TODO: adjust buffer writing sizes based on server config, frame size updates from peer, etc
+ bw *bufio.Writer // writing to a chunkWriter{this *responseWriterState}
+
+ // mutated by http.Handler goroutine:
+ handlerHeader Header // nil until called
+ snapHeader Header // snapshot of handlerHeader at WriteHeader time
+ trailers []string // set in writeChunk
+ status int // status code passed to WriteHeader
+ wroteHeader bool // WriteHeader called (explicitly or implicitly). Not necessarily sent to user yet.
+ sentHeader bool // have we sent the header frame?
+ handlerDone bool // handler has finished
+
+ sentContentLen int64 // non-zero if handler set a Content-Length header
+ wroteBytes int64
+
+ closeNotifierMu sync.Mutex // guards closeNotifierCh
+ closeNotifierCh chan bool // nil until first used
+}
+
+type http2chunkWriter struct{ rws *http2responseWriterState }
+
+func (cw http2chunkWriter) Write(p []byte) (n int, err error) { return cw.rws.writeChunk(p) }
+
+func (rws *http2responseWriterState) hasTrailers() bool { return len(rws.trailers) != 0 }
+
+// declareTrailer is called for each Trailer header when the
+// response header is written. It notes that a header will need to be
+// written in the trailers at the end of the response.
+func (rws *http2responseWriterState) declareTrailer(k string) {
+ k = CanonicalHeaderKey(k)
+ switch k {
+ case "Transfer-Encoding", "Content-Length", "Trailer":
+
+ return
+ }
+ rws.trailers = append(rws.trailers, k)
+}
+
+// writeChunk writes chunks from the bufio.Writer. But because
+// bufio.Writer may bypass its chunking, sometimes p may be
+// arbitrarily large.
+//
+// writeChunk is also responsible (on the first chunk) for sending the
+// HEADERS response.
+func (rws *http2responseWriterState) writeChunk(p []byte) (n int, err error) {
+ if !rws.wroteHeader {
+ rws.writeHeader(200)
+ }
+
+ isHeadResp := rws.req.Method == "HEAD"
+ if !rws.sentHeader {
+ rws.sentHeader = true
+ var ctype, clen string
+ if clen = rws.snapHeader.Get("Content-Length"); clen != "" {
+ rws.snapHeader.Del("Content-Length")
+ clen64, err := strconv.ParseInt(clen, 10, 64)
+ if err == nil && clen64 >= 0 {
+ rws.sentContentLen = clen64
+ } else {
+ clen = ""
+ }
+ }
+ if clen == "" && rws.handlerDone && http2bodyAllowedForStatus(rws.status) && (len(p) > 0 || !isHeadResp) {
+ clen = strconv.Itoa(len(p))
+ }
+ _, hasContentType := rws.snapHeader["Content-Type"]
+ if !hasContentType && http2bodyAllowedForStatus(rws.status) {
+ ctype = DetectContentType(p)
+ }
+ var date string
+ if _, ok := rws.snapHeader["Date"]; !ok {
+
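+ // The handler didn't set a Date header; generate one now.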
+ date = time.Now().UTC().Format(TimeFormat)
+ }
+
+ for _, v := range rws.snapHeader["Trailer"] {
+ http2foreachHeaderElement(v, rws.declareTrailer)
+ }
+
+ endStream := (rws.handlerDone && !rws.hasTrailers() && len(p) == 0) || isHeadResp
+ err = rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{
+ streamID: rws.stream.id,
+ httpResCode: rws.status,
+ h: rws.snapHeader,
+ endStream: endStream,
+ contentType: ctype,
+ contentLength: clen,
+ date: date,
+ })
+ if err != nil {
+ return 0, err
+ }
+ if endStream {
+ return 0, nil
+ }
+ }
+ if isHeadResp {
+ return len(p), nil
+ }
+ if len(p) == 0 && !rws.handlerDone {
+ return 0, nil
+ }
+
+ endStream := rws.handlerDone && !rws.hasTrailers()
+ if len(p) > 0 || endStream {
+
+ if err := rws.conn.writeDataFromHandler(rws.stream, p, endStream); err != nil {
+ return 0, err
+ }
+ }
+
+ if rws.handlerDone && rws.hasTrailers() {
+ err = rws.conn.writeHeaders(rws.stream, &http2writeResHeaders{
+ streamID: rws.stream.id,
+ h: rws.handlerHeader,
+ trailers: rws.trailers,
+ endStream: true,
+ })
+ return len(p), err
+ }
+ return len(p), nil
+}
+
+func (w *http2responseWriter) Flush() {
+ rws := w.rws
+ if rws == nil {
+ panic("Header called after Handler finished")
+ }
+ if rws.bw.Buffered() > 0 {
+ if err := rws.bw.Flush(); err != nil {
+
+ return
+ }
+ } else {
+
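+ // Nothing is buffered, but call writeChunk with no data so any pending headers get written.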
+ rws.writeChunk(nil)
+ }
+}
+
+func (w *http2responseWriter) CloseNotify() <-chan bool {
+ rws := w.rws
+ if rws == nil {
+ panic("CloseNotify called after Handler finished")
+ }
+ rws.closeNotifierMu.Lock()
+ ch := rws.closeNotifierCh
+ if ch == nil {
+ ch = make(chan bool, 1)
+ rws.closeNotifierCh = ch
+ go func() {
+ rws.stream.cw.Wait()
+ ch <- true
+ }()
+ }
+ rws.closeNotifierMu.Unlock()
+ return ch
+}
+
+func (w *http2responseWriter) Header() Header {
+ rws := w.rws
+ if rws == nil {
+ panic("Header called after Handler finished")
+ }
+ if rws.handlerHeader == nil {
+ rws.handlerHeader = make(Header)
+ }
+ return rws.handlerHeader
+}
+
+func (w *http2responseWriter) WriteHeader(code int) {
+ rws := w.rws
+ if rws == nil {
+ panic("WriteHeader called after Handler finished")
+ }
+ rws.writeHeader(code)
+}
+
+func (rws *http2responseWriterState) writeHeader(code int) {
+ if !rws.wroteHeader {
+ rws.wroteHeader = true
+ rws.status = code
+ if len(rws.handlerHeader) > 0 {
+ rws.snapHeader = http2cloneHeader(rws.handlerHeader)
+ }
+ }
+}
+
+func http2cloneHeader(h Header) Header {
+ h2 := make(Header, len(h))
+ for k, vv := range h {
+ vv2 := make([]string, len(vv))
+ copy(vv2, vv)
+ h2[k] = vv2
+ }
+ return h2
+}
+
+// The Life Of A Write is like this:
+//
+// * Handler calls w.Write or w.WriteString ->
+// * -> rws.bw (*bufio.Writer) ->
+// * (Handler might call Flush)
+// * -> chunkWriter{rws}
+// * -> responseWriterState.writeChunk(p []byte) (most of the magic; see comment there)
+func (w *http2responseWriter) Write(p []byte) (n int, err error) {
+ return w.write(len(p), p, "")
+}
+
+func (w *http2responseWriter) WriteString(s string) (n int, err error) {
+ return w.write(len(s), nil, s)
+}
+
+// At most one of dataB and dataS is non-empty.
+func (w *http2responseWriter) write(lenData int, dataB []byte, dataS string) (n int, err error) {
+ rws := w.rws
+ if rws == nil {
+ panic("Write called after Handler finished")
+ }
+ if !rws.wroteHeader {
+ w.WriteHeader(200)
+ }
+ if !http2bodyAllowedForStatus(rws.status) {
+ return 0, ErrBodyNotAllowed
+ }
+ rws.wroteBytes += int64(len(dataB)) + int64(len(dataS))
+ if rws.sentContentLen != 0 && rws.wroteBytes > rws.sentContentLen {
+
+ return 0, errors.New("http2: handler wrote more than declared Content-Length")
+ }
+
+ if dataB != nil {
+ return rws.bw.Write(dataB)
+ } else {
+ return rws.bw.WriteString(dataS)
+ }
+}
+
+func (w *http2responseWriter) handlerDone() {
+ rws := w.rws
+ rws.handlerDone = true
+ w.Flush()
+ w.rws = nil
+ http2responseWriterStatePool.Put(rws)
+}
+
+// foreachHeaderElement splits v according to the "#rule" construction
+// in RFC 2616 section 2.1 and calls fn for each non-empty element.
+func http2foreachHeaderElement(v string, fn func(string)) {
+ v = textproto.TrimString(v)
+ if v == "" {
+ return
+ }
+ if !strings.Contains(v, ",") {
+ fn(v)
+ return
+ }
+ for _, f := range strings.Split(v, ",") {
+ if f = textproto.TrimString(f); f != "" {
+ fn(f)
+ }
+ }
+}
+
+const (
+ // transportDefaultConnFlow is how many connection-level flow control
+ // tokens we give the server at start-up, past the default 64k.
+ http2transportDefaultConnFlow = 1 << 30
+
+ // transportDefaultStreamFlow is how many stream-level flow
+ // control tokens we announce to the peer, and how many bytes
+ // we buffer per stream.
+ http2transportDefaultStreamFlow = 4 << 20
+
+ // transportDefaultStreamMinRefresh is the minimum number of bytes we'll send
+ // a stream-level WINDOW_UPDATE for at a time.
+ http2transportDefaultStreamMinRefresh = 4 << 10
+
+ http2defaultUserAgent = "Go-http-client/2.0"
+)
+
+// Transport is an HTTP/2 Transport.
+//
+// A Transport internally caches connections to servers. It is safe
+// for concurrent use by multiple goroutines.
+type http2Transport struct {
+ // DialTLS specifies an optional dial function for creating
+ // TLS connections for requests.
+ //
+ // If DialTLS is nil, tls.Dial is used.
+ //
+ // If the returned net.Conn has a ConnectionState method like tls.Conn,
+ // it will be used to set http.Response.TLS.
+ DialTLS func(network, addr string, cfg *tls.Config) (net.Conn, error)
+
+ // TLSClientConfig specifies the TLS configuration to use with
+ // tls.Client. If nil, the default configuration is used.
+ TLSClientConfig *tls.Config
+
+ // ConnPool optionally specifies an alternate connection pool to use.
+ // If nil, the default is used.
+ ConnPool http2ClientConnPool
+
+ // DisableCompression, if true, prevents the Transport from
+ // requesting compression with an "Accept-Encoding: gzip"
+ // request header when the Request contains no existing
+ // Accept-Encoding value. If the Transport requests gzip on
+ // its own and gets a gzipped response, it's transparently
+ // decoded in the Response.Body. However, if the user
+ // explicitly requested gzip it is not automatically
+ // uncompressed.
+ DisableCompression bool
+
+ // MaxHeaderListSize is the http2 SETTINGS_MAX_HEADER_LIST_SIZE to
+ // send in the initial settings frame. It is how many bytes
+ // of response headers are allowed. Unlike the http2 spec, zero here
+ // means to use a default limit (currently 10MB). If you actually
+ // want to advertise an unlimited value to the peer, Transport
+ // interprets the highest possible value here (0xffffffff or 1<<32-1)
+ // to mean no limit.
+ MaxHeaderListSize uint32
+
+ // t1, if non-nil, is the standard library Transport using
+ // this transport. Its settings are used (but not its
+ // RoundTrip method, etc).
+ t1 *Transport
+
+ connPoolOnce sync.Once
+ connPoolOrDef http2ClientConnPool // non-nil version of ConnPool
+}
+
+func (t *http2Transport) maxHeaderListSize() uint32 {
+ if t.MaxHeaderListSize == 0 {
+ return 10 << 20
+ }
+ if t.MaxHeaderListSize == 0xffffffff {
+ return 0
+ }
+ return t.MaxHeaderListSize
+}
+
+func (t *http2Transport) disableCompression() bool {
+ return t.DisableCompression || (t.t1 != nil && t.t1.DisableCompression)
+}
+
+var http2errTransportVersion = errors.New("http2: ConfigureTransport is only supported starting at Go 1.6")
+
+// ConfigureTransport configures a net/http HTTP/1 Transport to use HTTP/2.
+// It requires Go 1.6 or later and returns an error if the net/http package is too old
+// or if t1 has already been HTTP/2-enabled.
+func http2ConfigureTransport(t1 *Transport) error {
+ _, err := http2configureTransport(t1)
+ return err
+}
+
+func (t *http2Transport) connPool() http2ClientConnPool {
+ t.connPoolOnce.Do(t.initConnPool)
+ return t.connPoolOrDef
+}
+
+func (t *http2Transport) initConnPool() {
+ if t.ConnPool != nil {
+ t.connPoolOrDef = t.ConnPool
+ } else {
+ t.connPoolOrDef = &http2clientConnPool{t: t}
+ }
+}
+
+// ClientConn is the state of a single HTTP/2 client connection to an
+// HTTP/2 server.
+type http2ClientConn struct {
+ t *http2Transport
+ tconn net.Conn // usually *tls.Conn, except specialized impls
+ tlsState *tls.ConnectionState // nil only for specialized impls
+
+ // readLoop goroutine fields:
+ readerDone chan struct{} // closed on error
+ readerErr error // set before readerDone is closed
+
+ mu sync.Mutex // guards following
+ cond *sync.Cond // hold mu; broadcast on flow/closed changes
+ flow http2flow // our conn-level flow control quota (cs.flow is per stream)
+ inflow http2flow // peer's conn-level flow control
+ closed bool
+ goAway *http2GoAwayFrame // if non-nil, the GoAwayFrame we received
+ streams map[uint32]*http2clientStream // client-initiated
+ nextStreamID uint32
+ bw *bufio.Writer
+ br *bufio.Reader
+ fr *http2Framer
+ // Settings from peer:
+ maxFrameSize uint32
+ maxConcurrentStreams uint32
+ initialWindowSize uint32
+ hbuf bytes.Buffer // HPACK encoder writes into this
+ henc *hpack.Encoder
+ freeBuf [][]byte
+
+ wmu sync.Mutex // held while writing; acquire AFTER mu if holding both
+ werr error // first write error that has occurred
+}
+
+// clientStream is the state for a single HTTP/2 stream. One of these
+// is created for each Transport.RoundTrip call.
+type http2clientStream struct {
+ cc *http2ClientConn
+ req *Request
+ ID uint32
+ resc chan http2resAndError
+ bufPipe http2pipe // buffered pipe with the flow-controlled response payload
+ requestedGzip bool
+
+ flow http2flow // guarded by cc.mu
+ inflow http2flow // guarded by cc.mu
+ bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read
+ readErr error // sticky read error; owned by transportResponseBody.Read
+ stopReqBody error // if non-nil, stop writing req body; guarded by cc.mu
+
+ peerReset chan struct{} // closed on peer reset
+ resetErr error // populated before peerReset is closed
+
+ done chan struct{} // closed when stream remove from cc.streams map; close calls guarded by cc.mu
+
+ // owned by clientConnReadLoop:
+ pastHeaders bool // got HEADERS w/ END_HEADERS
+ pastTrailers bool // got second HEADERS frame w/ END_HEADERS
+
+ trailer Header // accumulated trailers
+ resTrailer *Header // client's Response.Trailer
+}
+
+// awaitRequestCancel runs in its own goroutine and waits for the user
+// to either cancel a RoundTrip request (using the provided
+// Request.Cancel channel), or for the request to be done (any way it
+// might be removed from the cc.streams map: peer reset, successful
+// completion, TCP connection breakage, etc).
+func (cs *http2clientStream) awaitRequestCancel(cancel <-chan struct{}) {
+ if cancel == nil {
+ return
+ }
+ select {
+ case <-cancel:
+ cs.bufPipe.CloseWithError(http2errRequestCanceled)
+ cs.cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
+ case <-cs.done:
+ }
+}
+
+// checkReset reports any error sent in a RST_STREAM frame by the
+// server.
+func (cs *http2clientStream) checkReset() error {
+ select {
+ case <-cs.peerReset:
+ return cs.resetErr
+ default:
+ return nil
+ }
+}
+
+func (cs *http2clientStream) abortRequestBodyWrite(err error) {
+ if err == nil {
+ panic("nil error")
+ }
+ cc := cs.cc
+ cc.mu.Lock()
+ cs.stopReqBody = err
+ cc.cond.Broadcast()
+ cc.mu.Unlock()
+}
+
+type http2stickyErrWriter struct {
+ w io.Writer
+ err *error
+}
+
+func (sew http2stickyErrWriter) Write(p []byte) (n int, err error) {
+ if *sew.err != nil {
+ return 0, *sew.err
+ }
+ n, err = sew.w.Write(p)
+ *sew.err = err
+ return
+}
+
+var http2ErrNoCachedConn = errors.New("http2: no cached connection was available")
+
+// RoundTripOpt are options for the Transport.RoundTripOpt method.
+type http2RoundTripOpt struct {
+ // OnlyCachedConn controls whether RoundTripOpt may
+ // create a new TCP connection. If set true and
+ // no cached connection is available, RoundTripOpt
+ // will return ErrNoCachedConn.
+ OnlyCachedConn bool
+}
+
+func (t *http2Transport) RoundTrip(req *Request) (*Response, error) {
+ return t.RoundTripOpt(req, http2RoundTripOpt{})
+}
+
+// authorityAddr converts a given authority (a host/IP, or host:port / ip:port)
+// to a host:port, adding port 443 if no port is present.
+func http2authorityAddr(authority string) (addr string) {
+ if _, _, err := net.SplitHostPort(authority); err == nil {
+ return authority
+ }
+ return net.JoinHostPort(authority, "443")
+}
+
+// RoundTripOpt is like RoundTrip, but takes options.
+func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Response, error) {
+ if req.URL.Scheme != "https" {
+ return nil, errors.New("http2: unsupported scheme")
+ }
+
+ addr := http2authorityAddr(req.URL.Host)
+ for {
+ cc, err := t.connPool().GetClientConn(req, addr)
+ if err != nil {
+ t.vlogf("http2: Transport failed to get client conn for %s: %v", addr, err)
+ return nil, err
+ }
+ res, err := cc.RoundTrip(req)
+ if http2shouldRetryRequest(req, err) {
+ continue
+ }
+ if err != nil {
+ t.vlogf("RoundTrip failure: %v", err)
+ return nil, err
+ }
+ return res, nil
+ }
+}
+
+// CloseIdleConnections closes any connections which were previously
+// connected from previous requests but are now sitting idle.
+// It does not interrupt any connections currently in use.
+func (t *http2Transport) CloseIdleConnections() {
+ if cp, ok := t.connPool().(*http2clientConnPool); ok {
+ cp.closeIdleConnections()
+ }
+}
+
+var (
+ http2errClientConnClosed = errors.New("http2: client conn is closed")
+ http2errClientConnUnusable = errors.New("http2: client conn not usable")
+)
+
+func http2shouldRetryRequest(req *Request, err error) bool {
+
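+ // Only retry when the pooled connection turned out to be unusable.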
+ return err == http2errClientConnUnusable
+}
+
+func (t *http2Transport) dialClientConn(addr string) (*http2ClientConn, error) {
+ host, _, err := net.SplitHostPort(addr)
+ if err != nil {
+ return nil, err
+ }
+ tconn, err := t.dialTLS()("tcp", addr, t.newTLSConfig(host))
+ if err != nil {
+ return nil, err
+ }
+ return t.NewClientConn(tconn)
+}
+
+func (t *http2Transport) newTLSConfig(host string) *tls.Config {
+ cfg := new(tls.Config)
+ if t.TLSClientConfig != nil {
+ *cfg = *t.TLSClientConfig
+ }
+ cfg.NextProtos = []string{http2NextProtoTLS}
+ cfg.ServerName = host
+ return cfg
+}
+
+func (t *http2Transport) dialTLS() func(string, string, *tls.Config) (net.Conn, error) {
+ if t.DialTLS != nil {
+ return t.DialTLS
+ }
+ return t.dialTLSDefault
+}
+
+func (t *http2Transport) dialTLSDefault(network, addr string, cfg *tls.Config) (net.Conn, error) {
+ cn, err := tls.Dial(network, addr, cfg)
+ if err != nil {
+ return nil, err
+ }
+ if err := cn.Handshake(); err != nil {
+ return nil, err
+ }
+ if !cfg.InsecureSkipVerify {
+ if err := cn.VerifyHostname(cfg.ServerName); err != nil {
+ return nil, err
+ }
+ }
+ state := cn.ConnectionState()
+ if p := state.NegotiatedProtocol; p != http2NextProtoTLS {
+ return nil, fmt.Errorf("http2: unexpected ALPN protocol %q; want %q", p, http2NextProtoTLS)
+ }
+ if !state.NegotiatedProtocolIsMutual {
+ return nil, errors.New("http2: could not negotiate protocol mutually")
+ }
+ return cn, nil
+}
+
+// disableKeepAlives reports whether connections should be closed as
+// soon as possible after handling the first request.
+func (t *http2Transport) disableKeepAlives() bool {
+ return t.t1 != nil && t.t1.DisableKeepAlives
+}
+
+func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) {
+ if http2VerboseLogs {
+ t.vlogf("http2: Transport creating client conn to %v", c.RemoteAddr())
+ }
+ if _, err := c.Write(http2clientPreface); err != nil {
+ t.vlogf("client preface write error: %v", err)
+ return nil, err
+ }
+
+ cc := &http2ClientConn{
+ t: t,
+ tconn: c,
+ readerDone: make(chan struct{}),
+ nextStreamID: 1,
+ maxFrameSize: 16 << 10,
+ initialWindowSize: 65535,
+ maxConcurrentStreams: 1000,
+ streams: make(map[uint32]*http2clientStream),
+ }
+ cc.cond = sync.NewCond(&cc.mu)
+ cc.flow.add(int32(http2initialWindowSize))
+
+ cc.bw = bufio.NewWriter(http2stickyErrWriter{c, &cc.werr})
+ cc.br = bufio.NewReader(c)
+ cc.fr = http2NewFramer(cc.bw, cc.br)
+
+ cc.henc = hpack.NewEncoder(&cc.hbuf)
+
+ type connectionStater interface {
+ ConnectionState() tls.ConnectionState
+ }
+ if cs, ok := c.(connectionStater); ok {
+ state := cs.ConnectionState()
+ cc.tlsState = &state
+ }
+
+ initialSettings := []http2Setting{
+ http2Setting{ID: http2SettingEnablePush, Val: 0},
+ http2Setting{ID: http2SettingInitialWindowSize, Val: http2transportDefaultStreamFlow},
+ }
+ if max := t.maxHeaderListSize(); max != 0 {
+ initialSettings = append(initialSettings, http2Setting{ID: http2SettingMaxHeaderListSize, Val: max})
+ }
+ cc.fr.WriteSettings(initialSettings...)
+ cc.fr.WriteWindowUpdate(0, http2transportDefaultConnFlow)
+ cc.inflow.add(http2transportDefaultConnFlow + http2initialWindowSize)
+ cc.bw.Flush()
+ if cc.werr != nil {
+ return nil, cc.werr
+ }
+
+ f, err := cc.fr.ReadFrame()
+ if err != nil {
+ return nil, err
+ }
+ sf, ok := f.(*http2SettingsFrame)
+ if !ok {
+ return nil, fmt.Errorf("expected settings frame, got: %T", f)
+ }
+ cc.fr.WriteSettingsAck()
+ cc.bw.Flush()
+
+ sf.ForeachSetting(func(s http2Setting) error {
+ switch s.ID {
+ case http2SettingMaxFrameSize:
+ cc.maxFrameSize = s.Val
+ case http2SettingMaxConcurrentStreams:
+ cc.maxConcurrentStreams = s.Val
+ case http2SettingInitialWindowSize:
+ cc.initialWindowSize = s.Val
+ default:
+
+ t.vlogf("Unhandled Setting: %v", s)
+ }
+ return nil
+ })
+
+ go cc.readLoop()
+ return cc, nil
+}
+
+func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ cc.goAway = f
+}
+
+func (cc *http2ClientConn) CanTakeNewRequest() bool {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ return cc.canTakeNewRequestLocked()
+}
+
+func (cc *http2ClientConn) canTakeNewRequestLocked() bool {
+ return cc.goAway == nil && !cc.closed &&
+ int64(len(cc.streams)+1) < int64(cc.maxConcurrentStreams) &&
+ cc.nextStreamID < 2147483647
+}
+
+func (cc *http2ClientConn) closeIfIdle() {
+ cc.mu.Lock()
+ if len(cc.streams) > 0 {
+ cc.mu.Unlock()
+ return
+ }
+ cc.closed = true
+
+ cc.mu.Unlock()
+
+ cc.tconn.Close()
+}
+
+const http2maxAllocFrameSize = 512 << 10
+
+// frameBuffer returns a scratch buffer suitable for writing DATA frames.
+// They're capped at the smaller of the peer's max frame size and 512KB
+// (somewhat arbitrarily), so we never allocate 4GB buffers.
+func (cc *http2ClientConn) frameScratchBuffer() []byte {
+ cc.mu.Lock()
+ size := cc.maxFrameSize
+ if size > http2maxAllocFrameSize {
+ size = http2maxAllocFrameSize
+ }
+ for i, buf := range cc.freeBuf {
+ if len(buf) >= int(size) {
+ cc.freeBuf[i] = nil
+ cc.mu.Unlock()
+ return buf[:size]
+ }
+ }
+ cc.mu.Unlock()
+ return make([]byte, size)
+}
+
+func (cc *http2ClientConn) putFrameScratchBuffer(buf []byte) {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ const maxBufs = 4 // arbitrary; 4 concurrent requests per conn? investigate.
+ if len(cc.freeBuf) < maxBufs {
+ cc.freeBuf = append(cc.freeBuf, buf)
+ return
+ }
+ for i, old := range cc.freeBuf {
+ if old == nil {
+ cc.freeBuf[i] = buf
+ return
+ }
+ }
+
+}
+
+// errRequestCanceled is a copy of net/http's errRequestCanceled because it's not
+// exported. At least they'll be DeepEqual for h1-vs-h2 comparison tests.
+var http2errRequestCanceled = errors.New("net/http: request canceled")
+
+func http2commaSeparatedTrailers(req *Request) (string, error) {
+ keys := make([]string, 0, len(req.Trailer))
+ for k := range req.Trailer {
+ k = CanonicalHeaderKey(k)
+ switch k {
+ case "Transfer-Encoding", "Trailer", "Content-Length":
+ return "", &http2badStringError{"invalid Trailer key", k}
+ }
+ keys = append(keys, k)
+ }
+ if len(keys) > 0 {
+ sort.Strings(keys)
+
+ return strings.Join(keys, ","), nil
+ }
+ return "", nil
+}
+
+func (cc *http2ClientConn) responseHeaderTimeout() time.Duration {
+ if cc.t.t1 != nil {
+ return cc.t.t1.ResponseHeaderTimeout
+ }
+
+ return 0
+}
+
+func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) {
+ trailers, err := http2commaSeparatedTrailers(req)
+ if err != nil {
+ return nil, err
+ }
+ hasTrailers := trailers != ""
+
+ var body io.Reader = req.Body
+ contentLen := req.ContentLength
+ if req.Body != nil && contentLen == 0 {
+ // Test to see if it's actually zero or just unset.
+ var buf [1]byte
+ n, rerr := io.ReadFull(body, buf[:])
+ if rerr != nil && rerr != io.EOF {
+ contentLen = -1
+ body = http2errorReader{rerr}
+ } else if n == 1 {
+
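+ // The body has at least one byte; treat its length as unknown and stitch the sniffed byte back onto the reader.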
+ contentLen = -1
+ body = io.MultiReader(bytes.NewReader(buf[:]), body)
+ } else {
+
+ body = nil
+ }
+ }
+
+ cc.mu.Lock()
+ if cc.closed || !cc.canTakeNewRequestLocked() {
+ cc.mu.Unlock()
+ return nil, http2errClientConnUnusable
+ }
+
+ cs := cc.newStream()
+ cs.req = req
+ hasBody := body != nil
+
+ if !cc.t.disableCompression() &&
+ req.Header.Get("Accept-Encoding") == "" &&
+ req.Header.Get("Range") == "" &&
+ req.Method != "HEAD" {
+
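+ // Ask for gzip ourselves; the response body is transparently decompressed later.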
+ cs.requestedGzip = true
+ }
+
+ hdrs := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen)
+ cc.wmu.Lock()
+ endStream := !hasBody && !hasTrailers
+ werr := cc.writeHeaders(cs.ID, endStream, hdrs)
+ cc.wmu.Unlock()
+ cc.mu.Unlock()
+
+ if werr != nil {
+ if hasBody {
+ req.Body.Close()
+ }
+ cc.forgetStreamID(cs.ID)
+
+ return nil, werr
+ }
+
+ var respHeaderTimer <-chan time.Time
+ var bodyCopyErrc chan error // result of body copy
+ if hasBody {
+ bodyCopyErrc = make(chan error, 1)
+ go func() {
+ bodyCopyErrc <- cs.writeRequestBody(body, req.Body)
+ }()
+ } else {
+ if d := cc.responseHeaderTimeout(); d != 0 {
+ timer := time.NewTimer(d)
+ defer timer.Stop()
+ respHeaderTimer = timer.C
+ }
+ }
+
+ readLoopResCh := cs.resc
+ requestCanceledCh := http2requestCancel(req)
+ bodyWritten := false
+
+ for {
+ select {
+ case re := <-readLoopResCh:
+ res := re.res
+ if re.err != nil || res.StatusCode > 299 {
+
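+ // On error, or once the response status is above 299, stop sending the request body.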
+ cs.abortRequestBodyWrite(http2errStopReqBodyWrite)
+ }
+ if re.err != nil {
+ cc.forgetStreamID(cs.ID)
+ return nil, re.err
+ }
+ res.Request = req
+ res.TLS = cc.tlsState
+ return res, nil
+ case <-respHeaderTimer:
+ cc.forgetStreamID(cs.ID)
+ if !hasBody || bodyWritten {
+ cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
+ } else {
+ cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel)
+ }
+ return nil, http2errTimeout
+ case <-requestCanceledCh:
+ cc.forgetStreamID(cs.ID)
+ if !hasBody || bodyWritten {
+ cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
+ } else {
+ cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel)
+ }
+ return nil, http2errRequestCanceled
+ case <-cs.peerReset:
+
+ return nil, cs.resetErr
+ case err := <-bodyCopyErrc:
+ if err != nil {
+ return nil, err
+ }
+ bodyWritten = true
+ if d := cc.responseHeaderTimeout(); d != 0 {
+ timer := time.NewTimer(d)
+ defer timer.Stop()
+ respHeaderTimer = timer.C
+ }
+ }
+ }
+}
+
+// requires cc.wmu be held
+func (cc *http2ClientConn) writeHeaders(streamID uint32, endStream bool, hdrs []byte) error {
+ first := true
+ frameSize := int(cc.maxFrameSize)
+ for len(hdrs) > 0 && cc.werr == nil {
+ chunk := hdrs
+ if len(chunk) > frameSize {
+ chunk = chunk[:frameSize]
+ }
+ hdrs = hdrs[len(chunk):]
+ endHeaders := len(hdrs) == 0
+ if first {
+ cc.fr.WriteHeaders(http2HeadersFrameParam{
+ StreamID: streamID,
+ BlockFragment: chunk,
+ EndStream: endStream,
+ EndHeaders: endHeaders,
+ })
+ first = false
+ } else {
+ cc.fr.WriteContinuation(streamID, endHeaders, chunk)
+ }
+ }
+
+ cc.bw.Flush()
+ return cc.werr
+}
+
+// internal error values; they don't escape to callers
+var (
+ // abort request body write; don't send cancel
+ http2errStopReqBodyWrite = errors.New("http2: aborting request body write")
+
+ // abort request body write, but send a CANCEL stream reset.
+ http2errStopReqBodyWriteAndCancel = errors.New("http2: canceling request")
+)
+
+func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Closer) (err error) {
+ cc := cs.cc
+ sentEnd := false
+ buf := cc.frameScratchBuffer()
+ defer cc.putFrameScratchBuffer(buf)
+
+ defer func() {
+
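+ // Always close the request body; report its Close error only if nothing else failed.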
+ cerr := bodyCloser.Close()
+ if err == nil {
+ err = cerr
+ }
+ }()
+
+ req := cs.req
+ hasTrailers := req.Trailer != nil
+
+ var sawEOF bool
+ for !sawEOF {
+ n, err := body.Read(buf)
+ if err == io.EOF {
+ sawEOF = true
+ err = nil
+ } else if err != nil {
+ return err
+ }
+
+ remain := buf[:n]
+ for len(remain) > 0 && err == nil {
+ var allowed int32
+ allowed, err = cs.awaitFlowControl(len(remain))
+ switch {
+ case err == http2errStopReqBodyWrite:
+ return err
+ case err == http2errStopReqBodyWriteAndCancel:
+ cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
+ return err
+ case err != nil:
+ return err
+ }
+ cc.wmu.Lock()
+ data := remain[:allowed]
+ remain = remain[allowed:]
+ sentEnd = sawEOF && len(remain) == 0 && !hasTrailers
+ err = cc.fr.WriteData(cs.ID, sentEnd, data)
+ if err == nil {
+
+ err = cc.bw.Flush()
+ }
+ cc.wmu.Unlock()
+ }
+ if err != nil {
+ return err
+ }
+ }
+
+ cc.wmu.Lock()
+ if !sentEnd {
+ var trls []byte
+ if hasTrailers {
+ cc.mu.Lock()
+ trls = cc.encodeTrailers(req)
+ cc.mu.Unlock()
+ }
+
+ if len(trls) > 0 {
+ err = cc.writeHeaders(cs.ID, true, trls)
+ } else {
+ err = cc.fr.WriteData(cs.ID, true, nil)
+ }
+ }
+ if ferr := cc.bw.Flush(); ferr != nil && err == nil {
+ err = ferr
+ }
+ cc.wmu.Unlock()
+
+ return err
+}
+
+// awaitFlowControl waits for [1, min(maxBytes, cc.maxFrameSize)] flow
+// control tokens from the server.
+// It returns either the non-zero number of tokens taken or an error
+// if the stream is dead.
+func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err error) {
+ cc := cs.cc
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ for {
+ if cc.closed {
+ return 0, http2errClientConnClosed
+ }
+ if cs.stopReqBody != nil {
+ return 0, cs.stopReqBody
+ }
+ if err := cs.checkReset(); err != nil {
+ return 0, err
+ }
+ if a := cs.flow.available(); a > 0 {
+ take := a
+ if int(take) > maxBytes {
+
+ take = int32(maxBytes)
+ }
+ if take > int32(cc.maxFrameSize) {
+ take = int32(cc.maxFrameSize)
+ }
+ cs.flow.take(take)
+ return take, nil
+ }
+ cc.cond.Wait()
+ }
+}
+
+type http2badStringError struct {
+ what string
+ str string
+}
+
+func (e *http2badStringError) Error() string { return fmt.Sprintf("%s %q", e.what, e.str) }
+
+// requires cc.mu be held.
+func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trailers string, contentLength int64) []byte {
+ cc.hbuf.Reset()
+
+ host := req.Host
+ if host == "" {
+ host = req.URL.Host
+ }
+
+ cc.writeHeader(":authority", host)
+ cc.writeHeader(":method", req.Method)
+ if req.Method != "CONNECT" {
+ cc.writeHeader(":path", req.URL.RequestURI())
+ cc.writeHeader(":scheme", "https")
+ }
+ if trailers != "" {
+ cc.writeHeader("trailer", trailers)
+ }
+
+ var didUA bool
+ for k, vv := range req.Header {
+ lowKey := strings.ToLower(k)
+ if lowKey == "host" || lowKey == "content-length" {
+ continue
+ }
+ if lowKey == "user-agent" {
+
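+ // The caller set User-Agent explicitly: send at most its first value, or omit the header if that value is empty.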
+ didUA = true
+ if len(vv) < 1 {
+ continue
+ }
+ vv = vv[:1]
+ if vv[0] == "" {
+ continue
+ }
+ }
+ for _, v := range vv {
+ cc.writeHeader(lowKey, v)
+ }
+ }
+ if http2shouldSendReqContentLength(req.Method, contentLength) {
+ cc.writeHeader("content-length", strconv.FormatInt(contentLength, 10))
+ }
+ if addGzipHeader {
+ cc.writeHeader("accept-encoding", "gzip")
+ }
+ if !didUA {
+ cc.writeHeader("user-agent", http2defaultUserAgent)
+ }
+ return cc.hbuf.Bytes()
+}
+
+// shouldSendReqContentLength reports whether the http2.Transport should send
+// a "content-length" request header. This logic is basically a copy of the net/http
+// transferWriter.shouldSendContentLength.
+// The contentLength is the corrected contentLength (so 0 means actually 0, not unknown).
+// -1 means unknown.
+func http2shouldSendReqContentLength(method string, contentLength int64) bool {
+ if contentLength > 0 {
+ return true
+ }
+ if contentLength < 0 {
+ return false
+ }
+
+ switch method {
+ case "POST", "PUT", "PATCH":
+ return true
+ default:
+ return false
+ }
+}
+
+// requires cc.mu be held.
+func (cc *http2ClientConn) encodeTrailers(req *Request) []byte {
+ cc.hbuf.Reset()
+ for k, vv := range req.Trailer {
+
+ lowKey := strings.ToLower(k)
+ for _, v := range vv {
+ cc.writeHeader(lowKey, v)
+ }
+ }
+ return cc.hbuf.Bytes()
+}
+
+func (cc *http2ClientConn) writeHeader(name, value string) {
+ if http2VerboseLogs {
+ log.Printf("http2: Transport encoding header %q = %q", name, value)
+ }
+ cc.henc.WriteField(hpack.HeaderField{Name: name, Value: value})
+}
+
+type http2resAndError struct {
+ res *Response
+ err error
+}
+
+// requires cc.mu be held.
+func (cc *http2ClientConn) newStream() *http2clientStream {
+ cs := &http2clientStream{
+ cc: cc,
+ ID: cc.nextStreamID,
+ resc: make(chan http2resAndError, 1),
+ peerReset: make(chan struct{}),
+ done: make(chan struct{}),
+ }
+ cs.flow.add(int32(cc.initialWindowSize))
+ cs.flow.setConnFlow(&cc.flow)
+ cs.inflow.add(http2transportDefaultStreamFlow)
+ cs.inflow.setConnFlow(&cc.inflow)
+ cc.nextStreamID += 2
+ cc.streams[cs.ID] = cs
+ return cs
+}
+
+func (cc *http2ClientConn) forgetStreamID(id uint32) {
+ cc.streamByID(id, true)
+}
+
+func (cc *http2ClientConn) streamByID(id uint32, andRemove bool) *http2clientStream {
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ cs := cc.streams[id]
+ if andRemove && cs != nil && !cc.closed {
+ delete(cc.streams, id)
+ close(cs.done)
+ }
+ return cs
+}
+
+// clientConnReadLoop is the state owned by the clientConn's frame-reading readLoop.
+type http2clientConnReadLoop struct {
+ cc *http2ClientConn
+ activeRes map[uint32]*http2clientStream // keyed by streamID
+
+ hdec *hpack.Decoder
+
+ // Fields reset on each HEADERS:
+ nextRes *Response
+ sawRegHeader bool // saw non-pseudo header
+ reqMalformed error // non-nil once known to be malformed
+ lastHeaderEndsStream bool
+ headerListSize int64 // actually uint32, but easier math this way
+}
+
+// readLoop runs in its own goroutine and reads and dispatches frames.
+func (cc *http2ClientConn) readLoop() {
+ rl := &http2clientConnReadLoop{
+ cc: cc,
+ activeRes: make(map[uint32]*http2clientStream),
+ }
+ rl.hdec = hpack.NewDecoder(http2initialHeaderTableSize, rl.onNewHeaderField)
+
+ defer rl.cleanup()
+ cc.readerErr = rl.run()
+ if ce, ok := cc.readerErr.(http2ConnectionError); ok {
+ cc.wmu.Lock()
+ cc.fr.WriteGoAway(0, http2ErrCode(ce), nil)
+ cc.wmu.Unlock()
+ }
+}
+
+func (rl *http2clientConnReadLoop) cleanup() {
+ cc := rl.cc
+ defer cc.tconn.Close()
+ defer cc.t.connPool().MarkDead(cc)
+ defer close(cc.readerDone)
+
+ err := cc.readerErr
+ if err == io.EOF {
+ err = io.ErrUnexpectedEOF
+ }
+ cc.mu.Lock()
+ for _, cs := range rl.activeRes {
+ cs.bufPipe.CloseWithError(err)
+ }
+ for _, cs := range cc.streams {
+ select {
+ case cs.resc <- http2resAndError{err: err}:
+ default:
+ }
+ close(cs.done)
+ }
+ cc.closed = true
+ cc.cond.Broadcast()
+ cc.mu.Unlock()
+}
+
+func (rl *http2clientConnReadLoop) run() error {
+ cc := rl.cc
+ closeWhenIdle := cc.t.disableKeepAlives()
+ gotReply := false
+ for {
+ f, err := cc.fr.ReadFrame()
+ if err != nil {
+ cc.vlogf("Transport readFrame error: (%T) %v", err, err)
+ }
+ if se, ok := err.(http2StreamError); ok {
+
+ return se
+ } else if err != nil {
+ return err
+ }
+ if http2VerboseLogs {
+ cc.vlogf("http2: Transport received %s", http2summarizeFrame(f))
+ }
+ maybeIdle := false
+
+ switch f := f.(type) {
+ case *http2HeadersFrame:
+ err = rl.processHeaders(f)
+ maybeIdle = true
+ gotReply = true
+ case *http2ContinuationFrame:
+ err = rl.processContinuation(f)
+ maybeIdle = true
+ case *http2DataFrame:
+ err = rl.processData(f)
+ maybeIdle = true
+ case *http2GoAwayFrame:
+ err = rl.processGoAway(f)
+ maybeIdle = true
+ case *http2RSTStreamFrame:
+ err = rl.processResetStream(f)
+ maybeIdle = true
+ case *http2SettingsFrame:
+ err = rl.processSettings(f)
+ case *http2PushPromiseFrame:
+ err = rl.processPushPromise(f)
+ case *http2WindowUpdateFrame:
+ err = rl.processWindowUpdate(f)
+ case *http2PingFrame:
+ err = rl.processPing(f)
+ default:
+ cc.logf("Transport: unhandled response frame type %T", f)
+ }
+ if err != nil {
+ return err
+ }
+ if closeWhenIdle && gotReply && maybeIdle && len(rl.activeRes) == 0 {
+ cc.closeIfIdle()
+ }
+ }
+}
+
+func (rl *http2clientConnReadLoop) processHeaders(f *http2HeadersFrame) error {
+ rl.sawRegHeader = false
+ rl.reqMalformed = nil
+ rl.lastHeaderEndsStream = f.StreamEnded()
+ rl.headerListSize = 0
+ rl.nextRes = &Response{
+ Proto: "HTTP/2.0",
+ ProtoMajor: 2,
+ Header: make(Header),
+ }
+ rl.hdec.SetEmitEnabled(true)
+ return rl.processHeaderBlockFragment(f.HeaderBlockFragment(), f.StreamID, f.HeadersEnded())
+}
+
+func (rl *http2clientConnReadLoop) processContinuation(f *http2ContinuationFrame) error {
+ return rl.processHeaderBlockFragment(f.HeaderBlockFragment(), f.StreamID, f.HeadersEnded())
+}
+
+func (rl *http2clientConnReadLoop) processHeaderBlockFragment(frag []byte, streamID uint32, finalFrag bool) error {
+ cc := rl.cc
+ streamEnded := rl.lastHeaderEndsStream
+ cs := cc.streamByID(streamID, streamEnded && finalFrag)
+ if cs == nil {
+
+ return nil
+ }
+ if cs.pastHeaders {
+ rl.hdec.SetEmitFunc(func(f hpack.HeaderField) { rl.onNewTrailerField(cs, f) })
+ } else {
+ rl.hdec.SetEmitFunc(rl.onNewHeaderField)
+ }
+ _, err := rl.hdec.Write(frag)
+ if err != nil {
+ return http2ConnectionError(http2ErrCodeCompression)
+ }
+ if finalFrag {
+ if err := rl.hdec.Close(); err != nil {
+ return http2ConnectionError(http2ErrCodeCompression)
+ }
+ }
+
+ if !finalFrag {
+ return nil
+ }
+
+ if !cs.pastHeaders {
+ cs.pastHeaders = true
+ } else {
+
+ if cs.pastTrailers {
+
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
+ cs.pastTrailers = true
+ if !streamEnded {
+
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
+ rl.endStream(cs)
+ return nil
+ }
+
+ if rl.reqMalformed != nil {
+ cs.resc <- http2resAndError{err: rl.reqMalformed}
+ rl.cc.writeStreamReset(cs.ID, http2ErrCodeProtocol, rl.reqMalformed)
+ return nil
+ }
+
+ res := rl.nextRes
+
+ if res.StatusCode == 100 {
+
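+ // 100 Continue: expect another HEADERS block carrying the real response.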
+ cs.pastHeaders = false
+ return nil
+ }
+
+ if !streamEnded || cs.req.Method == "HEAD" {
+ res.ContentLength = -1
+ if clens := res.Header["Content-Length"]; len(clens) == 1 {
+ if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil {
+ res.ContentLength = clen64
+ } else {
+
+ }
+ } else if len(clens) > 1 {
+
+ }
+ }
+
+ if streamEnded {
+ res.Body = http2noBody
+ } else {
+ buf := new(bytes.Buffer)
+ cs.bufPipe = http2pipe{b: buf}
+ cs.bytesRemain = res.ContentLength
+ res.Body = http2transportResponseBody{cs}
+ go cs.awaitRequestCancel(http2requestCancel(cs.req))
+
+ if cs.requestedGzip && res.Header.Get("Content-Encoding") == "gzip" {
+ res.Header.Del("Content-Encoding")
+ res.Header.Del("Content-Length")
+ res.ContentLength = -1
+ res.Body = &http2gzipReader{body: res.Body}
+ }
+ }
+
+ cs.resTrailer = &res.Trailer
+ rl.activeRes[cs.ID] = cs
+ cs.resc <- http2resAndError{res: res}
+ rl.nextRes = nil
+ return nil
+}
+
+// transportResponseBody is the concrete type of Transport.RoundTrip's
+// Response.Body. It is an io.ReadCloser. On Read, it reads from cs.bufPipe.
+// On Close it sends RST_STREAM if EOF wasn't already seen.
+type http2transportResponseBody struct {
+ cs *http2clientStream
+}
+
+func (b http2transportResponseBody) Read(p []byte) (n int, err error) {
+ cs := b.cs
+ cc := cs.cc
+
+ if cs.readErr != nil {
+ return 0, cs.readErr
+ }
+ n, err = b.cs.bufPipe.Read(p)
+ if cs.bytesRemain != -1 {
+ if int64(n) > cs.bytesRemain {
+ n = int(cs.bytesRemain)
+ if err == nil {
+ err = errors.New("net/http: server replied with more than declared Content-Length; truncated")
+ cc.writeStreamReset(cs.ID, http2ErrCodeProtocol, err)
+ }
+ cs.readErr = err
+ return int(cs.bytesRemain), err
+ }
+ cs.bytesRemain -= int64(n)
+ if err == io.EOF && cs.bytesRemain > 0 {
+ err = io.ErrUnexpectedEOF
+ cs.readErr = err
+ return n, err
+ }
+ }
+ if n == 0 {
+
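+ // Nothing was read, so there is no flow-control credit to return yet.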
+ return
+ }
+
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+
+ var connAdd, streamAdd int32
+
+ if v := cc.inflow.available(); v < http2transportDefaultConnFlow/2 {
+ connAdd = http2transportDefaultConnFlow - v
+ cc.inflow.add(connAdd)
+ }
+ if err == nil {
+ if v := cs.inflow.available(); v < http2transportDefaultStreamFlow-http2transportDefaultStreamMinRefresh {
+ streamAdd = http2transportDefaultStreamFlow - v
+ cs.inflow.add(streamAdd)
+ }
+ }
+ if connAdd != 0 || streamAdd != 0 {
+ cc.wmu.Lock()
+ defer cc.wmu.Unlock()
+ if connAdd != 0 {
+ cc.fr.WriteWindowUpdate(0, http2mustUint31(connAdd))
+ }
+ if streamAdd != 0 {
+ cc.fr.WriteWindowUpdate(cs.ID, http2mustUint31(streamAdd))
+ }
+ cc.bw.Flush()
+ }
+ return
+}
+
+var http2errClosedResponseBody = errors.New("http2: response body closed")
+
+func (b http2transportResponseBody) Close() error {
+ cs := b.cs
+ if cs.bufPipe.Err() != io.EOF {
+
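+ // The stream isn't finished from the server's point of view; cancel it.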
+ cs.cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil)
+ }
+ cs.bufPipe.BreakWithError(http2errClosedResponseBody)
+ return nil
+}
+
+func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error {
+ cc := rl.cc
+ cs := cc.streamByID(f.StreamID, f.StreamEnded())
+ if cs == nil {
+ cc.mu.Lock()
+ neverSent := cc.nextStreamID
+ cc.mu.Unlock()
+ if f.StreamID >= neverSent {
+
+ cc.logf("http2: Transport received unsolicited DATA frame; closing connection")
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
+
+ return nil
+ }
+ if data := f.Data(); len(data) > 0 {
+ if cs.bufPipe.b == nil {
+
+ cc.logf("http2: Transport received DATA frame for closed stream; closing connection")
+ return http2ConnectionError(http2ErrCodeProtocol)
+ }
+
+ cc.mu.Lock()
+ if cs.inflow.available() >= int32(len(data)) {
+ cs.inflow.take(int32(len(data)))
+ } else {
+ cc.mu.Unlock()
+ return http2ConnectionError(http2ErrCodeFlowControl)
+ }
+ cc.mu.Unlock()
+
+ if _, err := cs.bufPipe.Write(data); err != nil {
+ return err
+ }
+ }
+
+ if f.StreamEnded() {
+ rl.endStream(cs)
+ }
+ return nil
+}
+
+var http2errInvalidTrailers = errors.New("http2: invalid trailers")
+
+func (rl *http2clientConnReadLoop) endStream(cs *http2clientStream) {
+
+ err := io.EOF
+ code := cs.copyTrailers
+ if rl.reqMalformed != nil {
+ err = rl.reqMalformed
+ code = nil
+ }
+ cs.bufPipe.closeWithErrorAndCode(err, code)
+ delete(rl.activeRes, cs.ID)
+}
+
+func (cs *http2clientStream) copyTrailers() {
+ for k, vv := range cs.trailer {
+ t := cs.resTrailer
+ if *t == nil {
+ *t = make(Header)
+ }
+ (*t)[k] = vv
+ }
+}
+
+func (rl *http2clientConnReadLoop) processGoAway(f *http2GoAwayFrame) error {
+ cc := rl.cc
+ cc.t.connPool().MarkDead(cc)
+ if f.ErrCode != 0 {
+
+ cc.vlogf("transport got GOAWAY with error code = %v", f.ErrCode)
+ }
+ cc.setGoAway(f)
+ return nil
+}
+
+func (rl *http2clientConnReadLoop) processSettings(f *http2SettingsFrame) error {
+ cc := rl.cc
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+ return f.ForeachSetting(func(s http2Setting) error {
+ switch s.ID {
+ case http2SettingMaxFrameSize:
+ cc.maxFrameSize = s.Val
+ case http2SettingMaxConcurrentStreams:
+ cc.maxConcurrentStreams = s.Val
+ case http2SettingInitialWindowSize:
+
+ cc.initialWindowSize = s.Val
+ default:
+
+ cc.vlogf("Unhandled Setting: %v", s)
+ }
+ return nil
+ })
+}
+
+func (rl *http2clientConnReadLoop) processWindowUpdate(f *http2WindowUpdateFrame) error {
+ cc := rl.cc
+ cs := cc.streamByID(f.StreamID, false)
+ if f.StreamID != 0 && cs == nil {
+ return nil
+ }
+
+ cc.mu.Lock()
+ defer cc.mu.Unlock()
+
+ fl := &cc.flow
+ if cs != nil {
+ fl = &cs.flow
+ }
+ if !fl.add(int32(f.Increment)) {
+ return http2ConnectionError(http2ErrCodeFlowControl)
+ }
+ cc.cond.Broadcast()
+ return nil
+}
+
+func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) error {
+ cs := rl.cc.streamByID(f.StreamID, true)
+ if cs == nil {
+
+ return nil
+ }
+ select {
+ case <-cs.peerReset:
+
+ default:
+ err := http2StreamError{cs.ID, f.ErrCode}
+ cs.resetErr = err
+ close(cs.peerReset)
+ cs.bufPipe.CloseWithError(err)
+ cs.cc.cond.Broadcast()
+ }
+ delete(rl.activeRes, cs.ID)
+ return nil
+}
+
+func (rl *http2clientConnReadLoop) processPing(f *http2PingFrame) error {
+ if f.IsAck() {
+
+ return nil
+ }
+ cc := rl.cc
+ cc.wmu.Lock()
+ defer cc.wmu.Unlock()
+ if err := cc.fr.WritePing(true, f.Data); err != nil {
+ return err
+ }
+ return cc.bw.Flush()
+}
+
+func (rl *http2clientConnReadLoop) processPushPromise(f *http2PushPromiseFrame) error {
+
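+ // We disabled server push via SETTINGS_ENABLE_PUSH = 0, so any PUSH_PROMISE is a protocol violation.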
+ return http2ConnectionError(http2ErrCodeProtocol)
+}
+
+func (cc *http2ClientConn) writeStreamReset(streamID uint32, code http2ErrCode, err error) {
+
+ cc.wmu.Lock()
+ cc.fr.WriteRSTStream(streamID, code)
+ cc.bw.Flush()
+ cc.wmu.Unlock()
+}
+
+var (
+ http2errResponseHeaderListSize = errors.New("http2: response header list larger than advertised limit")
+ http2errPseudoTrailers = errors.New("http2: invalid pseudo header in trailers")
+)
+
+func (rl *http2clientConnReadLoop) checkHeaderField(f hpack.HeaderField) bool {
+ if rl.reqMalformed != nil {
+ return false
+ }
+
+ const headerFieldOverhead = 32 // per spec
+ rl.headerListSize += int64(len(f.Name)) + int64(len(f.Value)) + headerFieldOverhead
+ if max := rl.cc.t.maxHeaderListSize(); max != 0 && rl.headerListSize > int64(max) {
+ rl.hdec.SetEmitEnabled(false)
+ rl.reqMalformed = http2errResponseHeaderListSize
+ return false
+ }
+
+ if !http2validHeaderFieldValue(f.Value) {
+ rl.reqMalformed = http2errInvalidHeaderFieldValue
+ return false
+ }
+
+ isPseudo := strings.HasPrefix(f.Name, ":")
+ if isPseudo {
+ if rl.sawRegHeader {
+ rl.reqMalformed = errors.New("http2: invalid pseudo header after regular header")
+ return false
+ }
+ } else {
+ if !http2validHeaderFieldName(f.Name) {
+ rl.reqMalformed = http2errInvalidHeaderFieldName
+ return false
+ }
+ rl.sawRegHeader = true
+ }
+
+ return true
+}
+
+// onNewHeaderField runs on the readLoop goroutine whenever a new
+// hpack header field is decoded.
+func (rl *http2clientConnReadLoop) onNewHeaderField(f hpack.HeaderField) {
+ cc := rl.cc
+ if http2VerboseLogs {
+ cc.logf("http2: Transport decoded %v", f)
+ }
+
+ if !rl.checkHeaderField(f) {
+ return
+ }
+
+ isPseudo := strings.HasPrefix(f.Name, ":")
+ if isPseudo {
+ switch f.Name {
+ case ":status":
+ code, err := strconv.Atoi(f.Value)
+ if err != nil {
+ rl.reqMalformed = errors.New("http2: invalid :status")
+ return
+ }
+ rl.nextRes.Status = f.Value + " " + StatusText(code)
+ rl.nextRes.StatusCode = code
+ default:
+
+ rl.reqMalformed = fmt.Errorf("http2: unknown response pseudo header %q", f.Name)
+ }
+ return
+ }
+
+ key := CanonicalHeaderKey(f.Name)
+ if key == "Trailer" {
+ t := rl.nextRes.Trailer
+ if t == nil {
+ t = make(Header)
+ rl.nextRes.Trailer = t
+ }
+ http2foreachHeaderElement(f.Value, func(v string) {
+ t[CanonicalHeaderKey(v)] = nil
+ })
+ } else {
+ rl.nextRes.Header.Add(key, f.Value)
+ }
+}
+
+func (rl *http2clientConnReadLoop) onNewTrailerField(cs *http2clientStream, f hpack.HeaderField) {
+ if http2VerboseLogs {
+ rl.cc.logf("http2: Transport decoded trailer %v", f)
+ }
+ if !rl.checkHeaderField(f) {
+ return
+ }
+ if strings.HasPrefix(f.Name, ":") {
+
+ rl.reqMalformed = http2errPseudoTrailers
+ return
+ }
+
+ key := CanonicalHeaderKey(f.Name)
+
+ // The spec says one must predeclare their trailers, but in practice
+ // popular users (which is to say the only user we found) do not, so we
+ // violate the spec and accept all of them.
+ const acceptAllTrailers = true
+ if _, ok := (*cs.resTrailer)[key]; ok || acceptAllTrailers {
+ if cs.trailer == nil {
+ cs.trailer = make(Header)
+ }
+ cs.trailer[key] = append(cs.trailer[key], f.Value)
+ }
+}
+
+func (cc *http2ClientConn) logf(format string, args ...interface{}) {
+ cc.t.logf(format, args...)
+}
+
+func (cc *http2ClientConn) vlogf(format string, args ...interface{}) {
+ cc.t.vlogf(format, args...)
+}
+
+func (t *http2Transport) vlogf(format string, args ...interface{}) {
+ if http2VerboseLogs {
+ t.logf(format, args...)
+ }
+}
+
+func (t *http2Transport) logf(format string, args ...interface{}) {
+ log.Printf(format, args...)
+}
+
+var http2noBody io.ReadCloser = ioutil.NopCloser(bytes.NewReader(nil))
+
+func http2strSliceContains(ss []string, s string) bool {
+ for _, v := range ss {
+ if v == s {
+ return true
+ }
+ }
+ return false
+}
+
+type http2erringRoundTripper struct{ err error }
+
+func (rt http2erringRoundTripper) RoundTrip(*Request) (*Response, error) { return nil, rt.err }
+
+// gzipReader wraps a response body so it can lazily
+// call gzip.NewReader on the first call to Read
+type http2gzipReader struct {
+ body io.ReadCloser // underlying Response.Body
+ zr io.Reader // lazily-initialized gzip reader
+}
+
+func (gz *http2gzipReader) Read(p []byte) (n int, err error) {
+ if gz.zr == nil {
+ gz.zr, err = gzip.NewReader(gz.body)
+ if err != nil {
+ return 0, err
+ }
+ }
+ return gz.zr.Read(p)
+}
+
+func (gz *http2gzipReader) Close() error {
+ return gz.body.Close()
+}
+
+type http2errorReader struct{ err error }
+
+func (r http2errorReader) Read(p []byte) (int, error) { return 0, r.err }
+
+// writeFramer is implemented by any type that is used to write frames.
+type http2writeFramer interface {
+ writeFrame(http2writeContext) error
+}
+
+// writeContext is the interface needed by the various frame writer
+// types below. All the writeFrame methods below are scheduled via the
+// frame writing scheduler (see writeScheduler in writesched.go).
+//
+// This interface is implemented by *serverConn.
+//
+// TODO: decide whether to a) use this in the client code (which didn't
+// end up using this yet, because it has a simpler design, not
+// currently implementing priorities), or b) delete this and
+// make the server code a bit more concrete.
+type http2writeContext interface {
+ Framer() *http2Framer
+ Flush() error
+ CloseConn() error
+ // HeaderEncoder returns an HPACK encoder that writes to the
+ // returned buffer.
+ HeaderEncoder() (*hpack.Encoder, *bytes.Buffer)
+}
+
+// endsStream reports whether the given frame writer w will locally
+// close the stream.
+func http2endsStream(w http2writeFramer) bool {
+ switch v := w.(type) {
+ case *http2writeData:
+ return v.endStream
+ case *http2writeResHeaders:
+ return v.endStream
+ case nil:
+
+ panic("endsStream called on nil writeFramer")
+ }
+ return false
+}
+
+type http2flushFrameWriter struct{}
+
+func (http2flushFrameWriter) writeFrame(ctx http2writeContext) error {
+ return ctx.Flush()
+}
+
+type http2writeSettings []http2Setting
+
+func (s http2writeSettings) writeFrame(ctx http2writeContext) error {
+ return ctx.Framer().WriteSettings([]http2Setting(s)...)
+}
+
+type http2writeGoAway struct {
+ maxStreamID uint32
+ code http2ErrCode
+}
+
+func (p *http2writeGoAway) writeFrame(ctx http2writeContext) error {
+ err := ctx.Framer().WriteGoAway(p.maxStreamID, p.code, nil)
+ if p.code != 0 {
+ ctx.Flush()
+ time.Sleep(50 * time.Millisecond)
+ ctx.CloseConn()
+ }
+ return err
+}
+
+type http2writeData struct {
+ streamID uint32
+ p []byte
+ endStream bool
+}
+
+func (w *http2writeData) String() string {
+ return fmt.Sprintf("writeData(stream=%d, p=%d, endStream=%v)", w.streamID, len(w.p), w.endStream)
+}
+
+func (w *http2writeData) writeFrame(ctx http2writeContext) error {
+ return ctx.Framer().WriteData(w.streamID, w.endStream, w.p)
+}
+
+// handlerPanicRST is the message sent from handler goroutines when
+// the handler panics.
+type http2handlerPanicRST struct {
+ StreamID uint32
+}
+
+func (hp http2handlerPanicRST) writeFrame(ctx http2writeContext) error {
+ return ctx.Framer().WriteRSTStream(hp.StreamID, http2ErrCodeInternal)
+}
+
+func (se http2StreamError) writeFrame(ctx http2writeContext) error {
+ return ctx.Framer().WriteRSTStream(se.StreamID, se.Code)
+}
+
+type http2writePingAck struct{ pf *http2PingFrame }
+
+func (w http2writePingAck) writeFrame(ctx http2writeContext) error {
+ return ctx.Framer().WritePing(true, w.pf.Data)
+}
+
+type http2writeSettingsAck struct{}
+
+func (http2writeSettingsAck) writeFrame(ctx http2writeContext) error {
+ return ctx.Framer().WriteSettingsAck()
+}
+
+// writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames
+// for HTTP response headers or trailers from a server handler.
+type http2writeResHeaders struct {
+ streamID uint32
+ httpResCode int // 0 means no ":status" line
+ h Header // may be nil
+ trailers []string // if non-nil, which keys of h to write. nil means all.
+ endStream bool
+
+ date string
+ contentType string
+ contentLength string
+}
+
+func http2encKV(enc *hpack.Encoder, k, v string) {
+ if http2VerboseLogs {
+ log.Printf("http2: server encoding header %q = %q", k, v)
+ }
+ enc.WriteField(hpack.HeaderField{Name: k, Value: v})
+}
+
+func (w *http2writeResHeaders) writeFrame(ctx http2writeContext) error {
+ enc, buf := ctx.HeaderEncoder()
+ buf.Reset()
+
+ if w.httpResCode != 0 {
+ http2encKV(enc, ":status", http2httpCodeString(w.httpResCode))
+ }
+
+ http2encodeHeaders(enc, w.h, w.trailers)
+
+ if w.contentType != "" {
+ http2encKV(enc, "content-type", w.contentType)
+ }
+ if w.contentLength != "" {
+ http2encKV(enc, "content-length", w.contentLength)
+ }
+ if w.date != "" {
+ http2encKV(enc, "date", w.date)
+ }
+
+ headerBlock := buf.Bytes()
+ if len(headerBlock) == 0 && w.trailers == nil {
+ panic("unexpected empty hpack")
+ }
+
+ // For now we're lazy and just pick the minimum MAX_FRAME_SIZE
+ // that all peers must support (16KB). Later we could care
+ // more and send larger frames if the peer advertised it, but
+ // there's little point. Most headers are small anyway (so we
+ // generally won't have CONTINUATION frames), and extra frames
+ // only waste 9 bytes anyway.
+ const maxFrameSize = 16384
+
+ first := true
+ for len(headerBlock) > 0 {
+ frag := headerBlock
+ if len(frag) > maxFrameSize {
+ frag = frag[:maxFrameSize]
+ }
+ headerBlock = headerBlock[len(frag):]
+ endHeaders := len(headerBlock) == 0
+ var err error
+ if first {
+ first = false
+ err = ctx.Framer().WriteHeaders(http2HeadersFrameParam{
+ StreamID: w.streamID,
+ BlockFragment: frag,
+ EndStream: w.endStream,
+ EndHeaders: endHeaders,
+ })
+ } else {
+ err = ctx.Framer().WriteContinuation(w.streamID, endHeaders, frag)
+ }
+ if err != nil {
+ return err
+ }
+ }
+ return nil
+}
+
+type http2write100ContinueHeadersFrame struct {
+ streamID uint32
+}
+
+func (w http2write100ContinueHeadersFrame) writeFrame(ctx http2writeContext) error {
+ enc, buf := ctx.HeaderEncoder()
+ buf.Reset()
+ http2encKV(enc, ":status", "100")
+ return ctx.Framer().WriteHeaders(http2HeadersFrameParam{
+ StreamID: w.streamID,
+ BlockFragment: buf.Bytes(),
+ EndStream: false,
+ EndHeaders: true,
+ })
+}
+
+type http2writeWindowUpdate struct {
+ streamID uint32 // or 0 for conn-level
+ n uint32
+}
+
+func (wu http2writeWindowUpdate) writeFrame(ctx http2writeContext) error {
+ return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n)
+}
+
+func http2encodeHeaders(enc *hpack.Encoder, h Header, keys []string) {
+
+ if keys == nil {
+ keys = make([]string, 0, len(h))
+ for k := range h {
+ keys = append(keys, k)
+ }
+ sort.Strings(keys)
+ }
+ for _, k := range keys {
+ vv := h[k]
+ k = http2lowerHeader(k)
+ isTE := k == "transfer-encoding"
+ for _, v := range vv {
+
+ if isTE && v != "trailers" {
+ continue
+ }
+ http2encKV(enc, k, v)
+ }
+ }
+}
+
+// frameWriteMsg is a request to write a frame.
+type http2frameWriteMsg struct {
+ // write is the interface value that does the writing, once the
+ // writeScheduler (below) has decided to select this frame
+ // to write. The write functions are all defined in write.go.
+ write http2writeFramer
+
+ stream *http2stream // used for prioritization. nil for non-stream frames.
+
+ // done, if non-nil, must be a buffered channel with space for
+ // 1 message and is sent the return value from write (or an
+ // earlier error) when the frame has been written.
+ done chan error
+}
+
+// for debugging only:
+func (wm http2frameWriteMsg) String() string {
+ var streamID uint32
+ if wm.stream != nil {
+ streamID = wm.stream.id
+ }
+ var des string
+ if s, ok := wm.write.(fmt.Stringer); ok {
+ des = s.String()
+ } else {
+ des = fmt.Sprintf("%T", wm.write)
+ }
+ return fmt.Sprintf("[frameWriteMsg stream=%d, ch=%v, type: %v]", streamID, wm.done != nil, des)
+}
+
+// writeScheduler tracks pending frames to write, priorities, and decides
+// the next one to use. It is not thread-safe.
+type http2writeScheduler struct {
+ // zero are frames not associated with a specific stream.
+ // They're sent before any stream-specific frames.
+ zero http2writeQueue
+
+ // maxFrameSize is the maximum size of a DATA frame
+ // we'll write. Must be non-zero and between 16K-16M.
+ maxFrameSize uint32
+
+ // sq contains the stream-specific queues, keyed by stream ID.
+ // when a stream is idle, it's deleted from the map.
+ sq map[uint32]*http2writeQueue
+
+ // canSend is a slice of memory that's reused between frame
+ // scheduling decisions to hold the list of writeQueues (from sq)
+ // which have enough flow control data to send. After canSend is
+ // built, the best is selected.
+ canSend []*http2writeQueue
+
+ // pool of empty queues for reuse.
+ queuePool []*http2writeQueue
+}
+
+func (ws *http2writeScheduler) putEmptyQueue(q *http2writeQueue) {
+ if len(q.s) != 0 {
+ panic("queue must be empty")
+ }
+ ws.queuePool = append(ws.queuePool, q)
+}
+
+func (ws *http2writeScheduler) getEmptyQueue() *http2writeQueue {
+ ln := len(ws.queuePool)
+ if ln == 0 {
+ return new(http2writeQueue)
+ }
+ q := ws.queuePool[ln-1]
+ ws.queuePool = ws.queuePool[:ln-1]
+ return q
+}
+
+func (ws *http2writeScheduler) empty() bool { return ws.zero.empty() && len(ws.sq) == 0 }
+
+func (ws *http2writeScheduler) add(wm http2frameWriteMsg) {
+ st := wm.stream
+ if st == nil {
+ ws.zero.push(wm)
+ } else {
+ ws.streamQueue(st.id).push(wm)
+ }
+}
+
+func (ws *http2writeScheduler) streamQueue(streamID uint32) *http2writeQueue {
+ if q, ok := ws.sq[streamID]; ok {
+ return q
+ }
+ if ws.sq == nil {
+ ws.sq = make(map[uint32]*http2writeQueue)
+ }
+ q := ws.getEmptyQueue()
+ ws.sq[streamID] = q
+ return q
+}
+
+// take returns the most important frame to write and removes it from the scheduler.
+// It is illegal to call this if the scheduler is empty or if there are no connection-level
+// flow control bytes available.
+func (ws *http2writeScheduler) take() (wm http2frameWriteMsg, ok bool) {
+ if ws.maxFrameSize == 0 {
+ panic("internal error: ws.maxFrameSize not initialized or invalid")
+ }
+
+ if !ws.zero.empty() {
+ return ws.zero.shift(), true
+ }
+ if len(ws.sq) == 0 {
+ return
+ }
+
+ for id, q := range ws.sq {
+ if q.firstIsNoCost() {
+ return ws.takeFrom(id, q)
+ }
+ }
+
+ if len(ws.canSend) != 0 {
+ panic("should be empty")
+ }
+ for _, q := range ws.sq {
+ if n := ws.streamWritableBytes(q); n > 0 {
+ ws.canSend = append(ws.canSend, q)
+ }
+ }
+ if len(ws.canSend) == 0 {
+ return
+ }
+ defer ws.zeroCanSend()
+
+ q := ws.canSend[0]
+
+ return ws.takeFrom(q.streamID(), q)
+}
+
+// zeroCanSend is deferred from take.
+func (ws *http2writeScheduler) zeroCanSend() {
+ for i := range ws.canSend {
+ ws.canSend[i] = nil
+ }
+ ws.canSend = ws.canSend[:0]
+}
+
+// streamWritableBytes returns the number of DATA bytes we could write
+// from the given queue's stream, if this stream/queue were
+// selected. It is an error to call this if q's head isn't a
+// *writeData.
+func (ws *http2writeScheduler) streamWritableBytes(q *http2writeQueue) int32 {
+ wm := q.head()
+ ret := wm.stream.flow.available()
+ if ret == 0 {
+ return 0
+ }
+ if int32(ws.maxFrameSize) < ret {
+ ret = int32(ws.maxFrameSize)
+ }
+ if ret == 0 {
+ panic("internal error: ws.maxFrameSize not initialized or invalid")
+ }
+ wd := wm.write.(*http2writeData)
+ if len(wd.p) < int(ret) {
+ ret = int32(len(wd.p))
+ }
+ return ret
+}
+
+func (ws *http2writeScheduler) takeFrom(id uint32, q *http2writeQueue) (wm http2frameWriteMsg, ok bool) {
+ wm = q.head()
+
+ if wd, ok := wm.write.(*http2writeData); ok && len(wd.p) > 0 {
+ allowed := wm.stream.flow.available()
+ if allowed == 0 {
+
+ return http2frameWriteMsg{}, false
+ }
+ if int32(ws.maxFrameSize) < allowed {
+ allowed = int32(ws.maxFrameSize)
+ }
+
+ if len(wd.p) > int(allowed) {
+ wm.stream.flow.take(allowed)
+ chunk := wd.p[:allowed]
+ wd.p = wd.p[allowed:]
+
+ return http2frameWriteMsg{
+ stream: wm.stream,
+ write: &http2writeData{
+ streamID: wd.streamID,
+ p: chunk,
+
+ endStream: false,
+ },
+
+ done: nil,
+ }, true
+ }
+ wm.stream.flow.take(int32(len(wd.p)))
+ }
+
+ q.shift()
+ if q.empty() {
+ ws.putEmptyQueue(q)
+ delete(ws.sq, id)
+ }
+ return wm, true
+}
+
+func (ws *http2writeScheduler) forgetStream(id uint32) {
+ q, ok := ws.sq[id]
+ if !ok {
+ return
+ }
+ delete(ws.sq, id)
+
+ for i := range q.s {
+ q.s[i] = http2frameWriteMsg{}
+ }
+ q.s = q.s[:0]
+ ws.putEmptyQueue(q)
+}
+
+type http2writeQueue struct {
+ s []http2frameWriteMsg
+}
+
+// streamID returns the stream ID for a non-empty stream-specific queue.
+func (q *http2writeQueue) streamID() uint32 { return q.s[0].stream.id }
+
+func (q *http2writeQueue) empty() bool { return len(q.s) == 0 }
+
+func (q *http2writeQueue) push(wm http2frameWriteMsg) {
+ q.s = append(q.s, wm)
+}
+
+// head returns the next item that would be removed by shift.
+func (q *http2writeQueue) head() http2frameWriteMsg {
+ if len(q.s) == 0 {
+ panic("invalid use of queue")
+ }
+ return q.s[0]
+}
+
+func (q *http2writeQueue) shift() http2frameWriteMsg {
+ if len(q.s) == 0 {
+ panic("invalid use of queue")
+ }
+ wm := q.s[0]
+
+ copy(q.s, q.s[1:])
+ q.s[len(q.s)-1] = http2frameWriteMsg{}
+ q.s = q.s[:len(q.s)-1]
+ return wm
+}
+
+func (q *http2writeQueue) firstIsNoCost() bool {
+ if df, ok := q.s[0].write.(*http2writeData); ok {
+ return len(df.p) == 0
+ }
+ return true
+}
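
The takeFrom logic above caps each outgoing DATA frame twice: by the stream's flow-control window and by the scheduler's maxFrameSize, splitting the payload when it is larger than the allowance. A minimal stand-alone sketch of that clamping arithmetic (chunkLen is a hypothetical helper, not part of the bundled package):

package main

import "fmt"

// chunkLen mirrors the clamping done in takeFrom above: a pending DATA
// payload is limited first by the stream's flow-control window, then by
// the scheduler's maximum frame size.
func chunkLen(pending, flowAvail, maxFrameSize int32) int32 {
	allowed := flowAvail
	if maxFrameSize < allowed {
		allowed = maxFrameSize
	}
	if pending < allowed {
		return pending
	}
	return allowed
}

func main() {
	// 1 MiB pending, a 64 KiB window, 16 KiB frames: send 16 KiB now.
	fmt.Println(chunkLen(1<<20, 64<<10, 16<<10)) // 16384
}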
diff --git a/libgo/go/net/http/header.go b/libgo/go/net/http/header.go
index d847b131184..049f32f27dc 100644
--- a/libgo/go/net/http/header.go
+++ b/libgo/go/net/http/header.go
@@ -211,3 +211,13 @@ func hasToken(v, token string) bool {
func isTokenBoundary(b byte) bool {
return b == ' ' || b == ',' || b == '\t'
}
+
+func cloneHeader(h Header) Header {
+ h2 := make(Header, len(h))
+ for k, vv := range h {
+ vv2 := make([]string, len(vv))
+ copy(vv2, vv)
+ h2[k] = vv2
+ }
+ return h2
+}
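
cloneHeader copies the value slices as well as the map, so mutating the clone never aliases the original header. The same pattern works in user code with the exported http.Header type; copyHeader below is a hypothetical re-sketch, since cloneHeader itself is unexported:

package main

import (
	"fmt"
	"net/http"
)

// copyHeader deep-copies h: copying each value slice keeps the copy
// from aliasing the original header's backing arrays.
func copyHeader(h http.Header) http.Header {
	h2 := make(http.Header, len(h))
	for k, vv := range h {
		vv2 := make([]string, len(vv))
		copy(vv2, vv)
		h2[k] = vv2
	}
	return h2
}

func main() {
	orig := http.Header{"Accept": {"text/html"}}
	dup := copyHeader(orig)
	dup.Add("Accept", "application/json")
	fmt.Println(orig["Accept"]) // [text/html]
	fmt.Println(dup["Accept"])  // [text/html application/json]
}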
diff --git a/libgo/go/net/http/httptest/recorder.go b/libgo/go/net/http/httptest/recorder.go
index 5451f54234c..7c51af1867a 100644
--- a/libgo/go/net/http/httptest/recorder.go
+++ b/libgo/go/net/http/httptest/recorder.go
@@ -44,23 +44,60 @@ func (rw *ResponseRecorder) Header() http.Header {
return m
}
+// writeHeader writes a header if it was not written yet and
+// detects Content-Type if needed.
+//
+// bytes or str are the beginning of the response body.
+// We pass both to avoid unnecessarily generating garbage
+// in rw.WriteString, which was created for performance reasons.
+// Non-nil bytes win.
+func (rw *ResponseRecorder) writeHeader(b []byte, str string) {
+ if rw.wroteHeader {
+ return
+ }
+ if len(str) > 512 {
+ str = str[:512]
+ }
+
+ _, hasType := rw.HeaderMap["Content-Type"]
+ hasTE := rw.HeaderMap.Get("Transfer-Encoding") != ""
+ if !hasType && !hasTE {
+ if b == nil {
+ b = []byte(str)
+ }
+ if rw.HeaderMap == nil {
+ rw.HeaderMap = make(http.Header)
+ }
+ rw.HeaderMap.Set("Content-Type", http.DetectContentType(b))
+ }
+
+ rw.WriteHeader(200)
+}
+
// Write always succeeds and writes to rw.Body, if not nil.
func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
- if !rw.wroteHeader {
- rw.WriteHeader(200)
- }
+ rw.writeHeader(buf, "")
if rw.Body != nil {
rw.Body.Write(buf)
}
return len(buf), nil
}
+// WriteString always succeeds and writes to rw.Body, if not nil.
+func (rw *ResponseRecorder) WriteString(str string) (int, error) {
+ rw.writeHeader(nil, str)
+ if rw.Body != nil {
+ rw.Body.WriteString(str)
+ }
+ return len(str), nil
+}
+
// WriteHeader sets rw.Code.
func (rw *ResponseRecorder) WriteHeader(code int) {
if !rw.wroteHeader {
rw.Code = code
+ rw.wroteHeader = true
}
- rw.wroteHeader = true
}
// Flush sets rw.Flushed to true.
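
With the recorder change above, the first bytes written to a ResponseRecorder are sniffed for a Content-Type, matching what a real server response would carry. A small usage sketch, assuming only the exported httptest API:

package main

import (
	"fmt"
	"io"
	"net/http/httptest"
)

func main() {
	rec := httptest.NewRecorder()
	io.WriteString(rec, "<html><body>hi</body></html>")
	// The recorder sniffs a Content-Type from the first bytes written.
	fmt.Println(rec.Code)                          // 200
	fmt.Println(rec.HeaderMap.Get("Content-Type")) // text/html; charset=utf-8
}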
diff --git a/libgo/go/net/http/httptest/recorder_test.go b/libgo/go/net/http/httptest/recorder_test.go
index 2b563260c76..c29b6d4cf91 100644
--- a/libgo/go/net/http/httptest/recorder_test.go
+++ b/libgo/go/net/http/httptest/recorder_test.go
@@ -6,6 +6,7 @@ package httptest
import (
"fmt"
+ "io"
"net/http"
"testing"
)
@@ -38,6 +39,14 @@ func TestRecorder(t *testing.T) {
return nil
}
}
+ hasHeader := func(key, want string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if got := rec.HeaderMap.Get(key); got != want {
+ return fmt.Errorf("header %s = %q; want %q", key, got, want)
+ }
+ return nil
+ }
+ }
tests := []struct {
name string
@@ -68,6 +77,18 @@ func TestRecorder(t *testing.T) {
check(hasStatus(200), hasContents("hi first"), hasFlush(false)),
},
{
+ "write string",
+ func(w http.ResponseWriter, r *http.Request) {
+ io.WriteString(w, "hi first")
+ },
+ check(
+ hasStatus(200),
+ hasContents("hi first"),
+ hasFlush(false),
+ hasHeader("Content-Type", "text/plain; charset=utf-8"),
+ ),
+ },
+ {
"flush",
func(w http.ResponseWriter, r *http.Request) {
w.(http.Flusher).Flush() // also sends a 200
@@ -75,6 +96,40 @@ func TestRecorder(t *testing.T) {
},
check(hasStatus(200), hasFlush(true)),
},
+ {
+ "Content-Type detection",
+ func(w http.ResponseWriter, r *http.Request) {
+ io.WriteString(w, "<html>")
+ },
+ check(hasHeader("Content-Type", "text/html; charset=utf-8")),
+ },
+ {
+ "no Content-Type detection with Transfer-Encoding",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Transfer-Encoding", "some encoding")
+ io.WriteString(w, "<html>")
+ },
+ check(hasHeader("Content-Type", "")), // no header
+ },
+ {
+ "no Content-Type detection if set explicitly",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Content-Type", "some/type")
+ io.WriteString(w, "<html>")
+ },
+ check(hasHeader("Content-Type", "some/type")),
+ },
+ {
+ "Content-Type detection doesn't crash if HeaderMap is nil",
+ func(w http.ResponseWriter, r *http.Request) {
+ // Act as if the user wrote new(httptest.ResponseRecorder)
+ // rather than using NewRecorder (which initializes
+ // HeaderMap)
+ w.(*ResponseRecorder).HeaderMap = nil
+ io.WriteString(w, "<html>")
+ },
+ check(hasHeader("Content-Type", "text/html; charset=utf-8")),
+ },
}
r, _ := http.NewRequest("GET", "http://foo.com/", nil)
for _, tt := range tests {
diff --git a/libgo/go/net/http/httptest/server.go b/libgo/go/net/http/httptest/server.go
index 96eb0ef6d2f..5c19c0ca340 100644
--- a/libgo/go/net/http/httptest/server.go
+++ b/libgo/go/net/http/httptest/server.go
@@ -7,13 +7,18 @@
package httptest
import (
+ "bytes"
"crypto/tls"
"flag"
"fmt"
+ "log"
"net"
"net/http"
+ "net/http/internal"
"os"
+ "runtime"
"sync"
+ "time"
)
// A Server is an HTTP server listening on a system-chosen port on the
@@ -34,24 +39,10 @@ type Server struct {
// wg counts the number of outstanding HTTP requests on this server.
// Close blocks until all requests are finished.
wg sync.WaitGroup
-}
-
-// historyListener keeps track of all connections that it's ever
-// accepted.
-type historyListener struct {
- net.Listener
- sync.Mutex // protects history
- history []net.Conn
-}
-func (hs *historyListener) Accept() (c net.Conn, err error) {
- c, err = hs.Listener.Accept()
- if err == nil {
- hs.Lock()
- hs.history = append(hs.history, c)
- hs.Unlock()
- }
- return
+ mu sync.Mutex // guards closed and conns
+ closed bool
+ conns map[net.Conn]http.ConnState // except terminal states
}
func newLocalListener() net.Listener {
@@ -103,10 +94,9 @@ func (s *Server) Start() {
if s.URL != "" {
panic("Server already started")
}
- s.Listener = &historyListener{Listener: s.Listener}
s.URL = "http://" + s.Listener.Addr().String()
- s.wrapHandler()
- go s.Config.Serve(s.Listener)
+ s.wrap()
+ s.goServe()
if *serve != "" {
fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
select {}
@@ -118,7 +108,7 @@ func (s *Server) StartTLS() {
if s.URL != "" {
panic("Server already started")
}
- cert, err := tls.X509KeyPair(localhostCert, localhostKey)
+ cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey)
if err != nil {
panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
}
@@ -134,23 +124,10 @@ func (s *Server) StartTLS() {
if len(s.TLS.Certificates) == 0 {
s.TLS.Certificates = []tls.Certificate{cert}
}
- tlsListener := tls.NewListener(s.Listener, s.TLS)
-
- s.Listener = &historyListener{Listener: tlsListener}
+ s.Listener = tls.NewListener(s.Listener, s.TLS)
s.URL = "https://" + s.Listener.Addr().String()
- s.wrapHandler()
- go s.Config.Serve(s.Listener)
-}
-
-func (s *Server) wrapHandler() {
- h := s.Config.Handler
- if h == nil {
- h = http.DefaultServeMux
- }
- s.Config.Handler = &waitGroupHandler{
- s: s,
- h: h,
- }
+ s.wrap()
+ s.goServe()
}
// NewTLSServer starts and returns a new Server using TLS.
@@ -161,78 +138,155 @@ func NewTLSServer(handler http.Handler) *Server {
return ts
}
+type closeIdleTransport interface {
+ CloseIdleConnections()
+}
+
// Close shuts down the server and blocks until all outstanding
// requests on this server have completed.
func (s *Server) Close() {
- s.Listener.Close()
- s.wg.Wait()
- s.CloseClientConnections()
- if t, ok := http.DefaultTransport.(*http.Transport); ok {
+ s.mu.Lock()
+ if !s.closed {
+ s.closed = true
+ s.Listener.Close()
+ s.Config.SetKeepAlivesEnabled(false)
+ for c, st := range s.conns {
+ // Force-close any idle connections (those between
+ // requests) and new connections (those which connected
+ // but never sent a request). StateNew connections are
+ // super rare and have only been seen (in
+ // previously-flaky tests) in the case of
+ // socket-late-binding races from the http Client
+ // dialing this server and then getting an idle
+ // connection before the dial completed. There is thus
+ // a connected connection in StateNew with no
+ // associated Request. We only close StateIdle and
+ // StateNew because they're not doing anything. It's
+ // possible StateNew is about to do something in a few
+ // milliseconds, but a previous CL to check again in a
+ // few milliseconds wasn't liked (early versions of
+ // https://golang.org/cl/15151) so now we just
+ // forcefully close StateNew. The docs for Server.Close say
+ // we wait for "outstanding requests", so we don't close things
+ // in StateActive.
+ if st == http.StateIdle || st == http.StateNew {
+ s.closeConn(c)
+ }
+ }
+ // If this server doesn't shut down in 20 seconds, tell the user why.
+ t := time.AfterFunc(20*time.Second, s.logCloseHangDebugInfo)
+ defer t.Stop()
+ }
+ s.mu.Unlock()
+
+ // Not part of httptest.Server's correctness, but assume most
+ // users of httptest.Server will be using the standard
+ // transport, so help them out and close any idle connections for them.
+ if t, ok := http.DefaultTransport.(closeIdleTransport); ok {
t.CloseIdleConnections()
}
+
+ s.wg.Wait()
}
-// CloseClientConnections closes any currently open HTTP connections
-// to the test Server.
-func (s *Server) CloseClientConnections() {
- hl, ok := s.Listener.(*historyListener)
- if !ok {
- return
+func (s *Server) logCloseHangDebugInfo() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ var buf bytes.Buffer
+ buf.WriteString("httptest.Server blocked in Close after 5 seconds, waiting for connections:\n")
+ for c, st := range s.conns {
+ fmt.Fprintf(&buf, " %T %p %v in state %v\n", c, c, c.RemoteAddr(), st)
}
- hl.Lock()
- for _, conn := range hl.history {
- conn.Close()
+ log.Print(buf.String())
+}
+
+// CloseClientConnections closes any open HTTP connections to the test Server.
+func (s *Server) CloseClientConnections() {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ for c := range s.conns {
+ s.closeConn(c)
}
- hl.Unlock()
}
-// waitGroupHandler wraps a handler, incrementing and decrementing a
-// sync.WaitGroup on each request, to enable Server.Close to block
-// until outstanding requests are finished.
-type waitGroupHandler struct {
- s *Server
- h http.Handler // non-nil
+func (s *Server) goServe() {
+ s.wg.Add(1)
+ go func() {
+ defer s.wg.Done()
+ s.Config.Serve(s.Listener)
+ }()
+}
+
+// wrap installs the connection state-tracking hook to know which
+// connections are idle.
+func (s *Server) wrap() {
+ oldHook := s.Config.ConnState
+ s.Config.ConnState = func(c net.Conn, cs http.ConnState) {
+ s.mu.Lock()
+ defer s.mu.Unlock()
+ switch cs {
+ case http.StateNew:
+ s.wg.Add(1)
+ if _, exists := s.conns[c]; exists {
+ panic("invalid state transition")
+ }
+ if s.conns == nil {
+ s.conns = make(map[net.Conn]http.ConnState)
+ }
+ s.conns[c] = cs
+ if s.closed {
+ // Probably just a socket-late-binding dial from
+ // the default transport that lost the race (and
+ // thus this connection is now idle and will
+ // never be used).
+ s.closeConn(c)
+ }
+ case http.StateActive:
+ if oldState, ok := s.conns[c]; ok {
+ if oldState != http.StateNew && oldState != http.StateIdle {
+ panic("invalid state transition")
+ }
+ s.conns[c] = cs
+ }
+ case http.StateIdle:
+ if oldState, ok := s.conns[c]; ok {
+ if oldState != http.StateActive {
+ panic("invalid state transition")
+ }
+ s.conns[c] = cs
+ }
+ if s.closed {
+ s.closeConn(c)
+ }
+ case http.StateHijacked, http.StateClosed:
+ s.forgetConn(c)
+ }
+ if oldHook != nil {
+ oldHook(c, cs)
+ }
+ }
}
-func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
- h.s.wg.Add(1)
- defer h.s.wg.Done() // a defer, in case ServeHTTP below panics
- h.h.ServeHTTP(w, r)
+// closeConn closes c. Except on plan9, which is special. See comment below.
+// s.mu must be held.
+func (s *Server) closeConn(c net.Conn) {
+ if runtime.GOOS == "plan9" {
+ // Go's Plan 9 net package isn't great at unblocking reads when
+ // their underlying TCP connections are closed. Don't trust
+ // that the ConnState state machine will get to
+ // StateClosed. Instead, just go there directly. Plan 9 may leak
+ // resources if the syscall doesn't end up returning. Oh well.
+ s.forgetConn(c)
+ }
+ go c.Close()
}
-// localhostCert is a PEM-encoded TLS cert with SAN IPs
-// "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end
-// of ASN.1 time).
-// generated from src/crypto/tls:
-// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
-var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
-MIICEzCCAXygAwIBAgIQMIMChMLGrR+QvmQvpwAU6zANBgkqhkiG9w0BAQsFADAS
-MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw
-MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB
-iQKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9SjY1bIw4
-iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZBl2+XsDul
-rKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQABo2gwZjAO
-BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw
-AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA
-AAAAATANBgkqhkiG9w0BAQsFAAOBgQCEcetwO59EWk7WiJsG4x8SY+UIAA+flUI9
-tyC4lNhbcF2Idq9greZwbYCqTTTr2XiRNSMLCOjKyI7ukPoPjo16ocHj+P3vZGfs
-h1fIw3cSS2OolhloGw/XM6RWPWtPAlGykKLciQrBru5NAPvCMsb/I1DAceTiotQM
-fblo6RBxUQ==
------END CERTIFICATE-----`)
-
-// localhostKey is the private key for localhostCert.
-var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
-MIICXgIBAAKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9
-SjY1bIw4iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZB
-l2+XsDulrKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQAB
-AoGAGRzwwir7XvBOAy5tM/uV6e+Zf6anZzus1s1Y1ClbjbE6HXbnWWF/wbZGOpet
-3Zm4vD6MXc7jpTLryzTQIvVdfQbRc6+MUVeLKwZatTXtdZrhu+Jk7hx0nTPy8Jcb
-uJqFk541aEw+mMogY/xEcfbWd6IOkp+4xqjlFLBEDytgbIECQQDvH/E6nk+hgN4H
-qzzVtxxr397vWrjrIgPbJpQvBsafG7b0dA4AFjwVbFLmQcj2PprIMmPcQrooz8vp
-jy4SHEg1AkEA/v13/5M47K9vCxmb8QeD/asydfsgS5TeuNi8DoUBEmiSJwma7FXY
-fFUtxuvL7XvjwjN5B30pNEbc6Iuyt7y4MQJBAIt21su4b3sjXNueLKH85Q+phy2U
-fQtuUE9txblTu14q3N7gHRZB4ZMhFYyDy8CKrN2cPg/Fvyt0Xlp/DoCzjA0CQQDU
-y2ptGsuSmgUtWj3NM9xuwYPm+Z/F84K6+ARYiZ6PYj013sovGKUFfYAqVXVlxtIX
-qyUBnu3X9ps8ZfjLZO7BAkEAlT4R5Yl6cGhaJQYZHOde3JEMhNRcVFMO8dJDaFeo
-f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA==
------END RSA PRIVATE KEY-----`)
+// forgetConn removes c from the set of tracked conns and decrements the
+// waitgroup count, unless c was previously removed.
+// s.mu must be held.
+func (s *Server) forgetConn(c net.Conn) {
+ if _, ok := s.conns[c]; ok {
+ delete(s.conns, c)
+ s.wg.Done()
+ }
+}
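
The wrap hook above is built entirely on net/http's exported Server.ConnState callback. A minimal sketch of the same bookkeeping outside httptest, tracking non-terminal connection states in a map (the address and handler are placeholders):

package main

import (
	"fmt"
	"log"
	"net"
	"net/http"
	"sync"
)

func main() {
	var (
		mu    sync.Mutex
		conns = map[net.Conn]http.ConnState{}
	)
	srv := &http.Server{
		Addr: "127.0.0.1:8080",
		Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			fmt.Fprintln(w, "ok")
		}),
		// Track non-terminal states, dropping connections once they
		// close or are hijacked, the same bookkeeping s.wrap does.
		ConnState: func(c net.Conn, st http.ConnState) {
			mu.Lock()
			defer mu.Unlock()
			switch st {
			case http.StateHijacked, http.StateClosed:
				delete(conns, c)
			default:
				conns[c] = st
			}
		},
	}
	log.Fatal(srv.ListenAndServe())
}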
diff --git a/libgo/go/net/http/httptest/server_test.go b/libgo/go/net/http/httptest/server_test.go
index 500a9f0b800..6ffc671e575 100644
--- a/libgo/go/net/http/httptest/server_test.go
+++ b/libgo/go/net/http/httptest/server_test.go
@@ -5,7 +5,9 @@
package httptest
import (
+ "bufio"
"io/ioutil"
+ "net"
"net/http"
"testing"
)
@@ -27,3 +29,58 @@ func TestServer(t *testing.T) {
t.Errorf("got %q, want hello", string(got))
}
}
+
+// Issue 12781
+func TestGetAfterClose(t *testing.T) {
+ ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("hello"))
+ }))
+
+ res, err := http.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ got, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(got) != "hello" {
+ t.Fatalf("got %q, want hello", string(got))
+ }
+
+ ts.Close()
+
+ res, err = http.Get(ts.URL)
+ if err == nil {
+ body, _ := ioutil.ReadAll(res.Body)
+ t.Fatalf("Unexected response after close: %v, %v, %s", res.Status, res.Header, body)
+ }
+}
+
+func TestServerCloseBlocking(t *testing.T) {
+ ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("hello"))
+ }))
+ dial := func() net.Conn {
+ c, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ return c
+ }
+
+ // Keep one connection in StateNew (connected, but not sending anything)
+ cnew := dial()
+ defer cnew.Close()
+
+ // Keep one connection in StateIdle (idle after a request)
+ cidle := dial()
+ defer cidle.Close()
+ cidle.Write([]byte("HEAD / HTTP/1.1\r\nHost: foo\r\n\r\n"))
+ _, err := http.ReadResponse(bufio.NewReader(cidle), nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ ts.Close() // test that we don't hang here forever.
+}
diff --git a/libgo/go/net/http/httputil/dump.go b/libgo/go/net/http/httputil/dump.go
index ca2d1cde924..e22cc66dbfc 100644
--- a/libgo/go/net/http/httputil/dump.go
+++ b/libgo/go/net/http/httputil/dump.go
@@ -25,10 +25,10 @@ import (
func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) {
var buf bytes.Buffer
if _, err = buf.ReadFrom(b); err != nil {
- return nil, nil, err
+ return nil, b, err
}
if err = b.Close(); err != nil {
- return nil, nil, err
+ return nil, b, err
}
return ioutil.NopCloser(&buf), ioutil.NopCloser(bytes.NewReader(buf.Bytes())), nil
}
@@ -55,9 +55,9 @@ func (b neverEnding) Read(p []byte) (n int, err error) {
return len(p), nil
}
-// DumpRequestOut is like DumpRequest but includes
-// headers that the standard http.Transport adds,
-// such as User-Agent.
+// DumpRequestOut is like DumpRequest but for outgoing client requests. It
+// includes any headers that the standard http.Transport adds, such as
+// User-Agent.
func DumpRequestOut(req *http.Request, body bool) ([]byte, error) {
save := req.Body
dummyBody := false
@@ -175,13 +175,22 @@ func dumpAsReceived(req *http.Request, w io.Writer) error {
return nil
}
-// DumpRequest returns the as-received wire representation of req,
-// optionally including the request body, for debugging.
-// DumpRequest is semantically a no-op, but in order to
-// dump the body, it reads the body data into memory and
-// changes req.Body to refer to the in-memory copy.
+// DumpRequest returns the given request in its HTTP/1.x wire
+// representation. It should only be used by servers to debug client
+// requests. The returned representation is an approximation only;
+// some details of the initial request are lost while parsing it into
+// an http.Request. In particular, the order and case of header field
+// names are lost. The order of values in multi-valued headers is kept
+// intact. HTTP/2 requests are dumped in HTTP/1.x form, not in their
+// original binary representations.
+//
+// If body is true, DumpRequest also returns the body. To do so, it
+// consumes req.Body and then replaces it with a new io.ReadCloser
+// that yields the same bytes. If DumpRequest returns an error,
+// the state of req is undefined.
+//
// The documentation for http.Request.Write details which fields
-// of req are used.
+// of req are included in the dump.
func DumpRequest(req *http.Request, body bool) (dump []byte, err error) {
save := req.Body
if !body || req.Body == nil {
@@ -189,21 +198,35 @@ func DumpRequest(req *http.Request, body bool) (dump []byte, err error) {
} else {
save, req.Body, err = drainBody(req.Body)
if err != nil {
- return
+ return nil, err
}
}
var b bytes.Buffer
+ // By default, print out the unmodified req.RequestURI, which
+ // is always set for incoming server requests. But because we
+ // previously used req.URL.RequestURI and the docs weren't
+ // always so clear about when to use DumpRequest vs
+ // DumpRequestOut, fall back to the old way if the caller
+ // provides a non-server Request.
+ reqURI := req.RequestURI
+ if reqURI == "" {
+ reqURI = req.URL.RequestURI()
+ }
+
fmt.Fprintf(&b, "%s %s HTTP/%d.%d\r\n", valueOrDefault(req.Method, "GET"),
- req.URL.RequestURI(), req.ProtoMajor, req.ProtoMinor)
+ reqURI, req.ProtoMajor, req.ProtoMinor)
- host := req.Host
- if host == "" && req.URL != nil {
- host = req.URL.Host
- }
- if host != "" {
- fmt.Fprintf(&b, "Host: %s\r\n", host)
+ absRequestURI := strings.HasPrefix(req.RequestURI, "http://") || strings.HasPrefix(req.RequestURI, "https://")
+ if !absRequestURI {
+ host := req.Host
+ if host == "" && req.URL != nil {
+ host = req.URL.Host
+ }
+ if host != "" {
+ fmt.Fprintf(&b, "Host: %s\r\n", host)
+ }
}
chunked := len(req.TransferEncoding) > 0 && req.TransferEncoding[0] == "chunked"
@@ -269,7 +292,7 @@ func DumpResponse(resp *http.Response, body bool) (dump []byte, err error) {
} else {
save, resp.Body, err = drainBody(resp.Body)
if err != nil {
- return
+ return nil, err
}
}
err = resp.Write(&b)
diff --git a/libgo/go/net/http/httputil/dump_test.go b/libgo/go/net/http/httputil/dump_test.go
index ae67e983ae9..46bf521723a 100644
--- a/libgo/go/net/http/httputil/dump_test.go
+++ b/libgo/go/net/http/httputil/dump_test.go
@@ -5,6 +5,7 @@
package httputil
import (
+ "bufio"
"bytes"
"fmt"
"io"
@@ -135,6 +136,14 @@ var dumpTests = []dumpTest{
"Accept-Encoding: gzip\r\n\r\n" +
strings.Repeat("a", 8193),
},
+
+ {
+ Req: *mustReadRequest("GET http://foo.com/ HTTP/1.1\r\n" +
+ "User-Agent: blah\r\n\r\n"),
+ NoBody: true,
+ WantDump: "GET http://foo.com/ HTTP/1.1\r\n" +
+ "User-Agent: blah\r\n\r\n",
+ },
}
func TestDumpRequest(t *testing.T) {
@@ -211,6 +220,14 @@ func mustNewRequest(method, url string, body io.Reader) *http.Request {
return req
}
+func mustReadRequest(s string) *http.Request {
+ req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(s)))
+ if err != nil {
+ panic(err)
+ }
+ return req
+}
+
var dumpResTests = []struct {
res *http.Response
body bool
diff --git a/libgo/go/net/http/httputil/example_test.go b/libgo/go/net/http/httputil/example_test.go
new file mode 100644
index 00000000000..8fb1a2d2792
--- /dev/null
+++ b/libgo/go/net/http/httputil/example_test.go
@@ -0,0 +1,125 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// +build ignore
+
+package httputil_test
+
+import (
+ "fmt"
+ "io/ioutil"
+ "log"
+ "net/http"
+ "net/http/httptest"
+ "net/http/httputil"
+ "net/url"
+ "strings"
+)
+
+func ExampleDumpRequest() {
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ dump, err := httputil.DumpRequest(r, true)
+ if err != nil {
+ http.Error(w, fmt.Sprint(err), http.StatusInternalServerError)
+ return
+ }
+
+ fmt.Fprintf(w, "%q", dump)
+ }))
+ defer ts.Close()
+
+ const body = "Go is a general-purpose language designed with systems programming in mind."
+ req, err := http.NewRequest("POST", ts.URL, strings.NewReader(body))
+ if err != nil {
+ log.Fatal(err)
+ }
+ req.Host = "www.example.org"
+ resp, err := http.DefaultClient.Do(req)
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer resp.Body.Close()
+
+ b, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Printf("%s", b)
+
+ // Output:
+ // "POST / HTTP/1.1\r\nHost: www.example.org\r\nAccept-Encoding: gzip\r\nUser-Agent: Go-http-client/1.1\r\n\r\nGo is a general-purpose language designed with systems programming in mind."
+}
+
+func ExampleDumpRequestOut() {
+ const body = "Go is a general-purpose language designed with systems programming in mind."
+ req, err := http.NewRequest("PUT", "http://www.example.org", strings.NewReader(body))
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ dump, err := httputil.DumpRequestOut(req, true)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Printf("%q", dump)
+
+ // Output:
+ // "PUT / HTTP/1.1\r\nHost: www.example.org\r\nUser-Agent: Go-http-client/1.1\r\nContent-Length: 75\r\nAccept-Encoding: gzip\r\n\r\nGo is a general-purpose language designed with systems programming in mind."
+}
+
+func ExampleDumpResponse() {
+ const body = "Go is a general-purpose language designed with systems programming in mind."
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Header().Set("Date", "Wed, 19 Jul 1972 19:00:00 GMT")
+ fmt.Fprintln(w, body)
+ }))
+ defer ts.Close()
+
+ resp, err := http.Get(ts.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+ defer resp.Body.Close()
+
+ dump, err := httputil.DumpResponse(resp, true)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Printf("%q", dump)
+
+ // Output:
+ // "HTTP/1.1 200 OK\r\nContent-Length: 76\r\nContent-Type: text/plain; charset=utf-8\r\nDate: Wed, 19 Jul 1972 19:00:00 GMT\r\n\r\nGo is a general-purpose language designed with systems programming in mind.\n"
+}
+
+func ExampleReverseProxy() {
+ backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintln(w, "this call was relayed by the reverse proxy")
+ }))
+ defer backendServer.Close()
+
+ rpURL, err := url.Parse(backendServer.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+ frontendProxy := httptest.NewServer(httputil.NewSingleHostReverseProxy(rpURL))
+ defer frontendProxy.Close()
+
+ resp, err := http.Get(frontendProxy.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ b, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Printf("%s", b)
+
+ // Output:
+ // this call was relayed by the reverse proxy
+}
diff --git a/libgo/go/net/http/httputil/reverseproxy.go b/libgo/go/net/http/httputil/reverseproxy.go
index c8e113221c4..4dba352a4fa 100644
--- a/libgo/go/net/http/httputil/reverseproxy.go
+++ b/libgo/go/net/http/httputil/reverseproxy.go
@@ -46,6 +46,18 @@ type ReverseProxy struct {
// If nil, logging goes to os.Stderr via the log package's
// standard logger.
ErrorLog *log.Logger
+
+ // BufferPool optionally specifies a buffer pool to
+ // get byte slices for use by io.CopyBuffer when
+ // copying HTTP response bodies.
+ BufferPool BufferPool
+}
+
+// A BufferPool is an interface for getting and returning temporary
+// byte slices for use by io.CopyBuffer.
+type BufferPool interface {
+ Get() []byte
+ Put([]byte)
}
func singleJoiningSlash(a, b string) string {
@@ -60,10 +72,13 @@ func singleJoiningSlash(a, b string) string {
return a + b
}
-// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
+// NewSingleHostReverseProxy returns a new ReverseProxy that routes
// URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir",
// the target request will be for /base/dir.
+// NewSingleHostReverseProxy does not rewrite the Host header.
+// To rewrite Host headers, use ReverseProxy directly with a custom
+// Director policy.
func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
targetQuery := target.RawQuery
director := func(req *http.Request) {
@@ -242,7 +257,14 @@ func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) {
}
}
- io.Copy(dst, src)
+ var buf []byte
+ if p.BufferPool != nil {
+ buf = p.BufferPool.Get()
+ }
+ io.CopyBuffer(dst, src, buf)
+ if p.BufferPool != nil {
+ p.BufferPool.Put(buf)
+ }
}
func (p *ReverseProxy) logf(format string, args ...interface{}) {
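
The new BufferPool field lets a ReverseProxy reuse copy buffers instead of allocating one per response. A plausible way to back it is sync.Pool; the sketch below is one such adapter, with the 32 KiB buffer size and addresses chosen arbitrarily for illustration:

package main

import (
	"log"
	"net/http"
	"net/http/httputil"
	"net/url"
	"sync"
)

// poolBuffer adapts sync.Pool to the BufferPool interface above.
type poolBuffer struct{ pool sync.Pool }

func newPoolBuffer(size int) *poolBuffer {
	return &poolBuffer{pool: sync.Pool{
		New: func() interface{} { return make([]byte, size) },
	}}
}

func (b *poolBuffer) Get() []byte  { return b.pool.Get().([]byte) }
func (b *poolBuffer) Put(p []byte) { b.pool.Put(p) }

func main() {
	target, err := url.Parse("http://127.0.0.1:9000") // assumed backend
	if err != nil {
		log.Fatal(err)
	}
	rp := httputil.NewSingleHostReverseProxy(target)
	rp.BufferPool = newPoolBuffer(32 << 10)
	log.Fatal(http.ListenAndServe("127.0.0.1:8080", rp))
}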
diff --git a/libgo/go/net/http/httputil/reverseproxy_test.go b/libgo/go/net/http/httputil/reverseproxy_test.go
index 80a26abe414..7f203d878f5 100644
--- a/libgo/go/net/http/httputil/reverseproxy_test.go
+++ b/libgo/go/net/http/httputil/reverseproxy_test.go
@@ -8,14 +8,17 @@ package httputil
import (
"bufio"
+ "bytes"
+ "io"
"io/ioutil"
"log"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
- "runtime"
+ "strconv"
"strings"
+ "sync"
"testing"
"time"
)
@@ -102,7 +105,6 @@ func TestReverseProxy(t *testing.T) {
if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e {
t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e)
}
-
}
func TestXForwardedFor(t *testing.T) {
@@ -225,10 +227,7 @@ func TestReverseProxyFlushInterval(t *testing.T) {
}
}
-func TestReverseProxyCancellation(t *testing.T) {
- if runtime.GOOS == "plan9" {
- t.Skip("skipping test; see https://golang.org/issue/9554")
- }
+func TestReverseProxyCancelation(t *testing.T) {
const backendResponse = "I am the backend"
reqInFlight := make(chan struct{})
@@ -320,3 +319,108 @@ func TestNilBody(t *testing.T) {
t.Errorf("Got %q; want %q", slurp, "hi")
}
}
+
+type bufferPool struct {
+ get func() []byte
+ put func([]byte)
+}
+
+func (bp bufferPool) Get() []byte { return bp.get() }
+func (bp bufferPool) Put(v []byte) { bp.put(v) }
+
+func TestReverseProxyGetPutBuffer(t *testing.T) {
+ const msg = "hi"
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ io.WriteString(w, msg)
+ }))
+ defer backend.Close()
+
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ var (
+ mu sync.Mutex
+ log []string
+ )
+ addLog := func(event string) {
+ mu.Lock()
+ defer mu.Unlock()
+ log = append(log, event)
+ }
+ rp := NewSingleHostReverseProxy(backendURL)
+ const size = 1234
+ rp.BufferPool = bufferPool{
+ get: func() []byte {
+ addLog("getBuf")
+ return make([]byte, size)
+ },
+ put: func(p []byte) {
+ addLog("putBuf-" + strconv.Itoa(len(p)))
+ },
+ }
+ frontend := httptest.NewServer(rp)
+ defer frontend.Close()
+
+ req, _ := http.NewRequest("GET", frontend.URL, nil)
+ req.Close = true
+ res, err := http.DefaultClient.Do(req)
+ if err != nil {
+ t.Fatalf("Get: %v", err)
+ }
+ slurp, err := ioutil.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ t.Fatalf("reading body: %v", err)
+ }
+ if string(slurp) != msg {
+ t.Errorf("msg = %q; want %q", slurp, msg)
+ }
+ wantLog := []string{"getBuf", "putBuf-" + strconv.Itoa(size)}
+ mu.Lock()
+ defer mu.Unlock()
+ if !reflect.DeepEqual(log, wantLog) {
+ t.Errorf("Log events = %q; want %q", log, wantLog)
+ }
+}
+
+func TestReverseProxy_Post(t *testing.T) {
+ const backendResponse = "I am the backend"
+ const backendStatus = 200
+ var requestBody = bytes.Repeat([]byte("a"), 1<<20)
+ backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ slurp, err := ioutil.ReadAll(r.Body)
+ if err != nil {
+ t.Errorf("Backend body read = %v", err)
+ }
+ if len(slurp) != len(requestBody) {
+ t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
+ }
+ if !bytes.Equal(slurp, requestBody) {
+ t.Error("Backend read wrong request body.") // 1MB; omitting details
+ }
+ w.Write([]byte(backendResponse))
+ }))
+ defer backend.Close()
+ backendURL, err := url.Parse(backend.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ proxyHandler := NewSingleHostReverseProxy(backendURL)
+ frontend := httptest.NewServer(proxyHandler)
+ defer frontend.Close()
+
+ postReq, _ := http.NewRequest("POST", frontend.URL, bytes.NewReader(requestBody))
+ res, err := http.DefaultClient.Do(postReq)
+ if err != nil {
+ t.Fatalf("Do: %v", err)
+ }
+ if g, e := res.StatusCode, backendStatus; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ bodyBytes, _ := ioutil.ReadAll(res.Body)
+ if g, e := string(bodyBytes), backendResponse; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+}
diff --git a/libgo/go/net/http/internal/chunked.go b/libgo/go/net/http/internal/chunked.go
index 6d7c69874d9..2e62c00d5db 100644
--- a/libgo/go/net/http/internal/chunked.go
+++ b/libgo/go/net/http/internal/chunked.go
@@ -44,7 +44,7 @@ type chunkedReader struct {
func (cr *chunkedReader) beginChunk() {
// chunk-size CRLF
var line []byte
- line, cr.err = readLine(cr.r)
+ line, cr.err = readChunkLine(cr.r)
if cr.err != nil {
return
}
@@ -104,10 +104,11 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) {
// Read a line of bytes (up to \n) from b.
// Give up if the line exceeds maxLineLength.
-// The returned bytes are a pointer into storage in
-// the bufio, so they are only valid until the next bufio read.
-func readLine(b *bufio.Reader) (p []byte, err error) {
- if p, err = b.ReadSlice('\n'); err != nil {
+// The returned bytes are owned by the bufio.Reader
+// so they are only valid until the next bufio read.
+func readChunkLine(b *bufio.Reader) ([]byte, error) {
+ p, err := b.ReadSlice('\n')
+ if err != nil {
// We always know when EOF is coming.
// If the caller asked for a line, there should be a line.
if err == io.EOF {
@@ -120,7 +121,12 @@ func readLine(b *bufio.Reader) (p []byte, err error) {
if len(p) >= maxLineLength {
return nil, ErrLineTooLong
}
- return trimTrailingWhitespace(p), nil
+ p = trimTrailingWhitespace(p)
+ p, err = removeChunkExtension(p)
+ if err != nil {
+ return nil, err
+ }
+ return p, nil
}
func trimTrailingWhitespace(b []byte) []byte {
@@ -134,6 +140,23 @@ func isASCIISpace(b byte) bool {
return b == ' ' || b == '\t' || b == '\n' || b == '\r'
}
+// removeChunkExtension removes any chunk-extension from p.
+// For example,
+// "0" => "0"
+// "0;token" => "0"
+// "0;token=val" => "0"
+// `0;token="quoted string"` => "0"
+func removeChunkExtension(p []byte) ([]byte, error) {
+ semi := bytes.IndexByte(p, ';')
+ if semi == -1 {
+ return p, nil
+ }
+ // TODO: care about exact syntax of chunk extensions? We're
+ // ignoring and stripping them anyway. For now just never
+ // return an error.
+ return p[:semi], nil
+}
+
// NewChunkedWriter returns a new chunkedWriter that translates writes into HTTP
// "chunked" format before writing them to w. Closing the returned chunkedWriter
// sends the final 0-length chunk that marks the end of the stream.
@@ -197,8 +220,7 @@ type FlushAfterChunkWriter struct {
}
func parseHexUint(v []byte) (n uint64, err error) {
- for _, b := range v {
- n <<= 4
+ for i, b := range v {
switch {
case '0' <= b && b <= '9':
b = b - '0'
@@ -209,6 +231,10 @@ func parseHexUint(v []byte) (n uint64, err error) {
default:
return 0, errors.New("invalid byte in chunk length")
}
+ if i == 16 {
+ return 0, errors.New("http chunk length too large")
+ }
+ n <<= 4
n |= uint64(b)
}
return
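
The reader changes above strip chunk extensions before the chunk size is parsed and reject hex lengths longer than 16 digits. The exported path to the same decoder is httputil.NewChunkedReader; a small sketch of decoding a body whose size lines carry extensions, mirroring the test added below:

package main

import (
	"fmt"
	"io/ioutil"
	"net/http/httputil"
	"strings"
)

func main() {
	// Chunk-size lines may carry chunk extensions; the decoder strips
	// them and returns only the chunk data.
	in := "7;ext=\"quoted string\"\r\n" +
		"hello, \r\n" +
		"6;token\r\n" +
		"world!\r\n" +
		"0\r\n"
	body, err := ioutil.ReadAll(httputil.NewChunkedReader(strings.NewReader(in)))
	if err != nil {
		fmt.Println("decode error:", err)
		return
	}
	fmt.Printf("%s\n", body) // hello, world!
}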
diff --git a/libgo/go/net/http/internal/chunked_test.go b/libgo/go/net/http/internal/chunked_test.go
index ebc626ea9d0..a136dc99a65 100644
--- a/libgo/go/net/http/internal/chunked_test.go
+++ b/libgo/go/net/http/internal/chunked_test.go
@@ -139,18 +139,49 @@ func TestChunkReaderAllocs(t *testing.T) {
}
func TestParseHexUint(t *testing.T) {
+ type testCase struct {
+ in string
+ want uint64
+ wantErr string
+ }
+ tests := []testCase{
+ {"x", 0, "invalid byte in chunk length"},
+ {"0000000000000000", 0, ""},
+ {"0000000000000001", 1, ""},
+ {"ffffffffffffffff", 1<<64 - 1, ""},
+ {"000000000000bogus", 0, "invalid byte in chunk length"},
+ {"00000000000000000", 0, "http chunk length too large"}, // could accept if we wanted
+ {"10000000000000000", 0, "http chunk length too large"},
+ {"00000000000000001", 0, "http chunk length too large"}, // could accept if we wanted
+ }
for i := uint64(0); i <= 1234; i++ {
- line := []byte(fmt.Sprintf("%x", i))
- got, err := parseHexUint(line)
- if err != nil {
- t.Fatalf("on %d: %v", i, err)
- }
- if got != i {
- t.Errorf("for input %q = %d; want %d", line, got, i)
+ tests = append(tests, testCase{in: fmt.Sprintf("%x", i), want: i})
+ }
+ for _, tt := range tests {
+ got, err := parseHexUint([]byte(tt.in))
+ if tt.wantErr != "" {
+ if !strings.Contains(fmt.Sprint(err), tt.wantErr) {
+ t.Errorf("parseHexUint(%q) = %v, %v; want error %q", tt.in, got, err, tt.wantErr)
+ }
+ } else {
+ if err != nil || got != tt.want {
+ t.Errorf("parseHexUint(%q) = %v, %v; want %v", tt.in, got, err, tt.want)
+ }
}
}
- _, err := parseHexUint([]byte("bogus"))
- if err == nil {
- t.Error("expected error on bogus input")
+}
+
+func TestChunkReadingIgnoresExtensions(t *testing.T) {
+ in := "7;ext=\"some quoted string\"\r\n" + // token=quoted string
+ "hello, \r\n" +
+ "17;someext\r\n" + // token without value
+ "world! 0123456789abcdef\r\n" +
+ "0;someextension=sometoken\r\n" // token=token
+ data, err := ioutil.ReadAll(NewChunkedReader(strings.NewReader(in)))
+ if err != nil {
+ t.Fatalf("ReadAll = %q, %v", data, err)
+ }
+ if g, e := string(data), "hello, world! 0123456789abcdef"; g != e {
+ t.Errorf("read %q; want %q", g, e)
}
}
diff --git a/libgo/go/net/http/internal/testcert.go b/libgo/go/net/http/internal/testcert.go
new file mode 100644
index 00000000000..407890920fa
--- /dev/null
+++ b/libgo/go/net/http/internal/testcert.go
@@ -0,0 +1,41 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package internal
+
+// LocalhostCert is a PEM-encoded TLS cert with SAN IPs
+// "127.0.0.1" and "[::1]", expiring at Jan 29 16:00:00 2084 GMT.
+// generated from src/crypto/tls:
+// go run generate_cert.go --rsa-bits 1024 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
+var LocalhostCert = []byte(`-----BEGIN CERTIFICATE-----
+MIICEzCCAXygAwIBAgIQMIMChMLGrR+QvmQvpwAU6zANBgkqhkiG9w0BAQsFADAS
+MRAwDgYDVQQKEwdBY21lIENvMCAXDTcwMDEwMTAwMDAwMFoYDzIwODQwMTI5MTYw
+MDAwWjASMRAwDgYDVQQKEwdBY21lIENvMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCB
+iQKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9SjY1bIw4
+iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZBl2+XsDul
+rKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQABo2gwZjAO
+BgNVHQ8BAf8EBAMCAqQwEwYDVR0lBAwwCgYIKwYBBQUHAwEwDwYDVR0TAQH/BAUw
+AwEB/zAuBgNVHREEJzAlggtleGFtcGxlLmNvbYcEfwAAAYcQAAAAAAAAAAAAAAAA
+AAAAATANBgkqhkiG9w0BAQsFAAOBgQCEcetwO59EWk7WiJsG4x8SY+UIAA+flUI9
+tyC4lNhbcF2Idq9greZwbYCqTTTr2XiRNSMLCOjKyI7ukPoPjo16ocHj+P3vZGfs
+h1fIw3cSS2OolhloGw/XM6RWPWtPAlGykKLciQrBru5NAPvCMsb/I1DAceTiotQM
+fblo6RBxUQ==
+-----END CERTIFICATE-----`)
+
+// LocalhostKey is the private key for LocalhostCert.
+var LocalhostKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
+MIICXgIBAAKBgQDuLnQAI3mDgey3VBzWnB2L39JUU4txjeVE6myuDqkM/uGlfjb9
+SjY1bIw4iA5sBBZzHi3z0h1YV8QPuxEbi4nW91IJm2gsvvZhIrCHS3l6afab4pZB
+l2+XsDulrKBxKKtD1rGxlG4LjncdabFn9gvLZad2bSysqz/qTAUStTvqJQIDAQAB
+AoGAGRzwwir7XvBOAy5tM/uV6e+Zf6anZzus1s1Y1ClbjbE6HXbnWWF/wbZGOpet
+3Zm4vD6MXc7jpTLryzTQIvVdfQbRc6+MUVeLKwZatTXtdZrhu+Jk7hx0nTPy8Jcb
+uJqFk541aEw+mMogY/xEcfbWd6IOkp+4xqjlFLBEDytgbIECQQDvH/E6nk+hgN4H
+qzzVtxxr397vWrjrIgPbJpQvBsafG7b0dA4AFjwVbFLmQcj2PprIMmPcQrooz8vp
+jy4SHEg1AkEA/v13/5M47K9vCxmb8QeD/asydfsgS5TeuNi8DoUBEmiSJwma7FXY
+fFUtxuvL7XvjwjN5B30pNEbc6Iuyt7y4MQJBAIt21su4b3sjXNueLKH85Q+phy2U
+fQtuUE9txblTu14q3N7gHRZB4ZMhFYyDy8CKrN2cPg/Fvyt0Xlp/DoCzjA0CQQDU
+y2ptGsuSmgUtWj3NM9xuwYPm+Z/F84K6+ARYiZ6PYj013sovGKUFfYAqVXVlxtIX
+qyUBnu3X9ps8ZfjLZO7BAkEAlT4R5Yl6cGhaJQYZHOde3JEMhNRcVFMO8dJDaFeo
+f9Oeos0UUothgiDktdQHxdNEwLjQf7lJJBzV+5OtwswCWA==
+-----END RSA PRIVATE KEY-----`)
diff --git a/libgo/go/net/http/lex.go b/libgo/go/net/http/lex.go
index 50b14f8b325..52b6481c14e 100644
--- a/libgo/go/net/http/lex.go
+++ b/libgo/go/net/http/lex.go
@@ -167,3 +167,17 @@ func tokenEqual(t1, t2 string) bool {
}
return true
}
+
+// isLWS reports whether b is linear white space, according
+// to http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2
+// LWS = [CRLF] 1*( SP | HT )
+func isLWS(b byte) bool { return b == ' ' || b == '\t' }
+
+// isCTL reports whether b is a control byte, according
+// to http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2
+// CTL = <any US-ASCII control character
+// (octets 0 - 31) and DEL (127)>
+func isCTL(b byte) bool {
+ const del = 0x7f // a CTL
+ return b < ' ' || b == del
+}
diff --git a/libgo/go/net/http/main_test.go b/libgo/go/net/http/main_test.go
index 12eea6f0e11..299cd7b2d2f 100644
--- a/libgo/go/net/http/main_test.go
+++ b/libgo/go/net/http/main_test.go
@@ -5,6 +5,7 @@
package http_test
import (
+ "flag"
"fmt"
"net/http"
"os"
@@ -15,6 +16,8 @@ import (
"time"
)
+var flaky = flag.Bool("flaky", false, "run known-flaky tests too")
+
func TestMain(m *testing.M) {
v := m.Run()
if v == 0 && goroutineLeaked() {
@@ -79,6 +82,21 @@ func goroutineLeaked() bool {
return true
}
+// setParallel marks t as a parallel test if we're in short mode
+// (all.bash), but as a serial test otherwise. Using t.Parallel isn't
+// compatible with the afterTest func in non-short mode.
+func setParallel(t *testing.T) {
+ if testing.Short() {
+ t.Parallel()
+ }
+}
+
+func setFlaky(t *testing.T, issue int) {
+ if !*flaky {
+ t.Skipf("skipping known flaky test; see golang.org/issue/%d", issue)
+ }
+}
+
func afterTest(t testing.TB) {
http.DefaultTransport.(*http.Transport).CloseIdleConnections()
if testing.Short() {
diff --git a/libgo/go/net/http/method.go b/libgo/go/net/http/method.go
new file mode 100644
index 00000000000..b74f9604d34
--- /dev/null
+++ b/libgo/go/net/http/method.go
@@ -0,0 +1,20 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package http
+
+// Common HTTP methods.
+//
+// Unless otherwise noted, these are defined in RFC 7231 section 4.3.
+const (
+ MethodGet = "GET"
+ MethodHead = "HEAD"
+ MethodPost = "POST"
+ MethodPut = "PUT"
+ MethodPatch = "PATCH" // RFC 5789
+ MethodDelete = "DELETE"
+ MethodConnect = "CONNECT"
+ MethodOptions = "OPTIONS"
+ MethodTrace = "TRACE"
+)
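
These constants are plain strings, so they slot in wherever a method literal was written before. A trivial usage sketch:

package main

import (
	"log"
	"net/http"
)

func main() {
	// http.MethodGet replaces the bare "GET" literal.
	req, err := http.NewRequest(http.MethodGet, "http://example.com/", nil)
	if err != nil {
		log.Fatal(err)
	}
	log.Println(req.Method) // GET
}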
diff --git a/libgo/go/net/http/pprof/pprof.go b/libgo/go/net/http/pprof/pprof.go
index 8994392b1e4..7262c6c1016 100644
--- a/libgo/go/net/http/pprof/pprof.go
+++ b/libgo/go/net/http/pprof/pprof.go
@@ -79,6 +79,17 @@ func Cmdline(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, strings.Join(os.Args, "\x00"))
}
+func sleep(w http.ResponseWriter, d time.Duration) {
+ var clientGone <-chan bool
+ if cn, ok := w.(http.CloseNotifier); ok {
+ clientGone = cn.CloseNotify()
+ }
+ select {
+ case <-time.After(d):
+ case <-clientGone:
+ }
+}
+
// Profile responds with the pprof-formatted cpu profile.
// The package initialization registers it as /debug/pprof/profile.
func Profile(w http.ResponseWriter, r *http.Request) {
@@ -99,7 +110,7 @@ func Profile(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Could not enable CPU profiling: %s\n", err)
return
}
- time.Sleep(time.Duration(sec) * time.Second)
+ sleep(w, time.Duration(sec)*time.Second)
pprof.StopCPUProfile()
}
@@ -125,7 +136,7 @@ func Trace(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Could not enable tracing: %s\n", err)
return
}
- time.Sleep(time.Duration(sec) * time.Second)
+ sleep(w, time.Duration(sec)*time.Second)
trace.Stop()
*/
}
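
The sleep helper above uses http.CloseNotifier so a profile request stops waiting once the client disconnects. The same pattern applies to any long-running handler; a minimal sketch, with the route and duration as placeholders:

package main

import (
	"fmt"
	"net/http"
	"time"
)

// slowHandler waits up to ten seconds but gives up early if the client
// disconnects, using the same CloseNotifier pattern as sleep above.
func slowHandler(w http.ResponseWriter, r *http.Request) {
	var clientGone <-chan bool
	if cn, ok := w.(http.CloseNotifier); ok {
		clientGone = cn.CloseNotify()
	}
	select {
	case <-time.After(10 * time.Second):
		fmt.Fprintln(w, "done")
	case <-clientGone:
		// The client went away; there is nobody left to answer.
	}
}

func main() {
	http.HandleFunc("/slow", slowHandler)
	http.ListenAndServe("127.0.0.1:8080", nil)
}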
diff --git a/libgo/go/net/http/request.go b/libgo/go/net/http/request.go
index 31fe45a4edb..16c5bb43ac4 100644
--- a/libgo/go/net/http/request.go
+++ b/libgo/go/net/http/request.go
@@ -90,8 +90,11 @@ type Request struct {
// request.
URL *url.URL
- // The protocol version for incoming requests.
- // Client requests always use HTTP/1.1.
+ // The protocol version for incoming server requests.
+ //
+ // For client requests these fields are ignored. The HTTP
+ // client code always uses either HTTP/1.1 or HTTP/2.
+ // See the docs on Transport for details.
Proto string // "HTTP/1.0"
ProtoMajor int // 1
ProtoMinor int // 0
@@ -354,7 +357,7 @@ const defaultUserAgent = "Go-http-client/1.1"
// hasn't been set to "identity", Write adds "Transfer-Encoding:
// chunked" to the header. Body is closed after it is sent.
func (r *Request) Write(w io.Writer) error {
- return r.write(w, false, nil)
+ return r.write(w, false, nil, nil)
}
// WriteProxy is like Write but writes the request in the form
@@ -364,11 +367,16 @@ func (r *Request) Write(w io.Writer) error {
// In either case, WriteProxy also writes a Host header, using
// either r.Host or r.URL.Host.
func (r *Request) WriteProxy(w io.Writer) error {
- return r.write(w, true, nil)
+ return r.write(w, true, nil, nil)
}
+// errMissingHost is returned by Write when there is no Host or URL present in
+// the Request.
+var errMissingHost = errors.New("http: Request.Write on Request with no Host or URL set")
+
// extraHeaders may be nil
-func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) error {
+// waitForContinue may be nil
+func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitForContinue func() bool) error {
// Find the target host. Prefer the Host: header, but if that
// is not given, use the host from the request URL.
//
@@ -376,7 +384,7 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
host := cleanHost(req.Host)
if host == "" {
if req.URL == nil {
- return errors.New("http: Request.Write on Request with no Host or URL set")
+ return errMissingHost
}
host = cleanHost(req.URL.Host)
}
@@ -419,10 +427,8 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
// Use the defaultUserAgent unless the Header contains one, which
// may be blank to not send the header.
userAgent := defaultUserAgent
- if req.Header != nil {
- if ua := req.Header["User-Agent"]; len(ua) > 0 {
- userAgent = ua[0]
- }
+ if _, ok := req.Header["User-Agent"]; ok {
+ userAgent = req.Header.Get("User-Agent")
}
if userAgent != "" {
_, err = fmt.Fprintf(w, "User-Agent: %s\r\n", userAgent)
@@ -458,6 +464,21 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header) err
return err
}
+ // Flush and wait for 100-continue if expected.
+ if waitForContinue != nil {
+ if bw, ok := w.(*bufio.Writer); ok {
+ err = bw.Flush()
+ if err != nil {
+ return err
+ }
+ }
+
+ if !waitForContinue() {
+ req.closeBody()
+ return nil
+ }
+ }
+
// Write body and trailer
err = tw.WriteBody(w)
if err != nil {
@@ -531,6 +552,23 @@ func ParseHTTPVersion(vers string) (major, minor int, ok bool) {
return major, minor, true
}
+func validMethod(method string) bool {
+ /*
+ Method = "OPTIONS" ; Section 9.2
+ | "GET" ; Section 9.3
+ | "HEAD" ; Section 9.4
+ | "POST" ; Section 9.5
+ | "PUT" ; Section 9.6
+ | "DELETE" ; Section 9.7
+ | "TRACE" ; Section 9.8
+ | "CONNECT" ; Section 9.9
+ | extension-method
+ extension-method = token
+ token = 1*<any CHAR except CTLs or separators>
+ */
+ return len(method) > 0 && strings.IndexFunc(method, isNotToken) == -1
+}
+
// NewRequest returns a new Request given a method, URL, and optional body.
//
// If the provided body is also an io.Closer, the returned
@@ -544,6 +582,15 @@ func ParseHTTPVersion(vers string) (major, minor int, ok bool) {
// type's documentation for the difference between inbound and outbound
// request fields.
func NewRequest(method, urlStr string, body io.Reader) (*Request, error) {
+ if method == "" {
+ // We document that "" means "GET" for Request.Method, and people have
+ // relied on that from NewRequest, so keep that working.
+ // We still enforce validMethod for non-empty methods.
+ method = "GET"
+ }
+ if !validMethod(method) {
+ return nil, fmt.Errorf("net/http: invalid method %q", method)
+ }
u, err := url.Parse(urlStr)
if err != nil {
return nil, err
@@ -643,8 +690,15 @@ func putTextprotoReader(r *textproto.Reader) {
}
// ReadRequest reads and parses an incoming request from b.
-func ReadRequest(b *bufio.Reader) (req *Request, err error) {
+func ReadRequest(b *bufio.Reader) (req *Request, err error) { return readRequest(b, deleteHostHeader) }
+
+// Constants for readRequest's deleteHostHeader parameter.
+const (
+ deleteHostHeader = true
+ keepHostHeader = false
+)
+
+func readRequest(b *bufio.Reader, deleteHostHeader bool) (req *Request, err error) {
tp := newTextprotoReader(b)
req = new(Request)
@@ -711,7 +765,9 @@ func ReadRequest(b *bufio.Reader) (req *Request, err error) {
if req.Host == "" {
req.Host = req.Header.get("Host")
}
- delete(req.Header, "Host")
+ if deleteHostHeader {
+ delete(req.Header, "Host")
+ }
fixPragmaCacheControl(req.Header)
@@ -1006,3 +1062,102 @@ func (r *Request) closeBody() {
r.Body.Close()
}
}
+
+func (r *Request) isReplayable() bool {
+ if r.Body == nil {
+ switch valueOrDefault(r.Method, "GET") {
+ case "GET", "HEAD", "OPTIONS", "TRACE":
+ return true
+ }
+ }
+ return false
+}
+
+func validHostHeader(h string) bool {
+ // The latest spec is actually this:
+ //
+ // http://tools.ietf.org/html/rfc7230#section-5.4
+ // Host = uri-host [ ":" port ]
+ //
+ // Where uri-host is:
+ // http://tools.ietf.org/html/rfc3986#section-3.2.2
+ //
+ // But we're going to be much more lenient for now and just
+ // search for any byte that's not a valid byte in any of those
+ // expressions.
+ for i := 0; i < len(h); i++ {
+ if !validHostByte[h[i]] {
+ return false
+ }
+ }
+ return true
+}
+
+// See the validHostHeader comment.
+var validHostByte = [256]bool{
+ '0': true, '1': true, '2': true, '3': true, '4': true, '5': true, '6': true, '7': true,
+ '8': true, '9': true,
+
+ 'a': true, 'b': true, 'c': true, 'd': true, 'e': true, 'f': true, 'g': true, 'h': true,
+ 'i': true, 'j': true, 'k': true, 'l': true, 'm': true, 'n': true, 'o': true, 'p': true,
+ 'q': true, 'r': true, 's': true, 't': true, 'u': true, 'v': true, 'w': true, 'x': true,
+ 'y': true, 'z': true,
+
+ 'A': true, 'B': true, 'C': true, 'D': true, 'E': true, 'F': true, 'G': true, 'H': true,
+ 'I': true, 'J': true, 'K': true, 'L': true, 'M': true, 'N': true, 'O': true, 'P': true,
+ 'Q': true, 'R': true, 'S': true, 'T': true, 'U': true, 'V': true, 'W': true, 'X': true,
+ 'Y': true, 'Z': true,
+
+ '!': true, // sub-delims
+ '$': true, // sub-delims
+ '%': true, // pct-encoded (and used in IPv6 zones)
+ '&': true, // sub-delims
+ '(': true, // sub-delims
+ ')': true, // sub-delims
+ '*': true, // sub-delims
+ '+': true, // sub-delims
+ ',': true, // sub-delims
+ '-': true, // unreserved
+ '.': true, // unreserved
+ ':': true, // IPv6address + Host expression's optional port
+ ';': true, // sub-delims
+ '=': true, // sub-delims
+ '[': true,
+ '\'': true, // sub-delims
+ ']': true,
+ '_': true, // unreserved
+ '~': true, // unreserved
+}
+
+func validHeaderName(v string) bool {
+ if len(v) == 0 {
+ return false
+ }
+ return strings.IndexFunc(v, isNotToken) == -1
+}
+
+// validHeaderValue reports whether v is a valid "field-value" according to
+// http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 :
+//
+// message-header = field-name ":" [ field-value ]
+// field-value = *( field-content | LWS )
+// field-content = <the OCTETs making up the field-value
+// and consisting of either *TEXT or combinations
+// of token, separators, and quoted-string>
+//
+// http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2 :
+//
+// TEXT = <any OCTET except CTLs,
+// but including LWS>
+// LWS = [CRLF] 1*( SP | HT )
+// CTL = <any US-ASCII control character
+// (octets 0 - 31) and DEL (127)>
+func validHeaderValue(v string) bool {
+ for i := 0; i < len(v); i++ {
+ b := v[i]
+ if isCTL(b) && !isLWS(b) {
+ return false
+ }
+ }
+ return true
+}
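Two user-visible effects of the request.go changes above: NewRequest now rejects methods that are not valid tokens, and an empty method still defaults to GET. A small sketch of that behavior, with a placeholder URL chosen only for illustration:

// sketch_newrequest.go — illustrative only; shows the new method validation.
package main

import (
	"fmt"
	"net/http"
)

func main() {
	// A method containing a space is not a valid token and is rejected up front.
	if _, err := http.NewRequest("bad method", "http://example.com/", nil); err != nil {
		fmt.Println("rejected:", err)
	}
	// An empty method is still documented to mean GET.
	req, err := http.NewRequest("", "http://example.com/", nil)
	if err == nil {
		fmt.Println("empty method becomes:", req.Method) // prints "GET"
	}
}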
diff --git a/libgo/go/net/http/request_test.go b/libgo/go/net/http/request_test.go
index 627620c0c41..0ecdf85a563 100644
--- a/libgo/go/net/http/request_test.go
+++ b/libgo/go/net/http/request_test.go
@@ -13,7 +13,6 @@ import (
"io/ioutil"
"mime/multipart"
. "net/http"
- "net/http/httptest"
"net/url"
"os"
"reflect"
@@ -177,9 +176,11 @@ func TestParseMultipartForm(t *testing.T) {
}
}
-func TestRedirect(t *testing.T) {
+func TestRedirect_h1(t *testing.T) { testRedirect(t, h1Mode) }
+func TestRedirect_h2(t *testing.T) { testRedirect(t, h2Mode) }
+func testRedirect(t *testing.T, h2 bool) {
defer afterTest(t)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
switch r.URL.Path {
case "/":
w.Header().Set("Location", "/foo/")
@@ -190,10 +191,10 @@ func TestRedirect(t *testing.T) {
w.WriteHeader(StatusBadRequest)
}
}))
- defer ts.Close()
+ defer cst.close()
var end = regexp.MustCompile("/foo/$")
- r, err := Get(ts.URL)
+ r, err := cst.c.Get(cst.ts.URL)
if err != nil {
t.Fatal(err)
}
@@ -355,6 +356,29 @@ func TestNewRequestHost(t *testing.T) {
}
}
+func TestRequestInvalidMethod(t *testing.T) {
+ _, err := NewRequest("bad method", "http://foo.com/", nil)
+ if err == nil {
+ t.Error("expected error from NewRequest with invalid method")
+ }
+ req, err := NewRequest("GET", "http://foo.example/", nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.Method = "bad method"
+ _, err = DefaultClient.Do(req)
+ if err == nil || !strings.Contains(err.Error(), "invalid method") {
+ t.Errorf("Transport error = %v; want invalid method", err)
+ }
+
+ req, err = NewRequest("", "http://foo.com/", nil)
+ if err != nil {
+ t.Errorf("NewRequest(empty method) = %v; want nil", err)
+ } else if req.Method != "GET" {
+ t.Errorf("NewRequest(empty method) has method %q; want GET", req.Method)
+ }
+}
+
func TestNewRequestContentLength(t *testing.T) {
readByte := func(r io.Reader) io.Reader {
var b [1]byte
@@ -515,10 +539,12 @@ func TestRequestWriteBufferedWriter(t *testing.T) {
func TestRequestBadHost(t *testing.T) {
got := []string{}
- req, err := NewRequest("GET", "http://foo.com with spaces/after", nil)
+ req, err := NewRequest("GET", "http://foo/after", nil)
if err != nil {
t.Fatal(err)
}
+ req.Host = "foo.com with spaces"
+ req.URL.Host = "foo.com with spaces"
req.Write(logWrites{t, &got})
want := []string{
"GET /after HTTP/1.1\r\n",
diff --git a/libgo/go/net/http/response.go b/libgo/go/net/http/response.go
index 76b85385244..c424f61cd00 100644
--- a/libgo/go/net/http/response.go
+++ b/libgo/go/net/http/response.go
@@ -72,8 +72,18 @@ type Response struct {
// ReadResponse nor Response.Write ever closes a connection.
Close bool
- // Trailer maps trailer keys to values, in the same
- // format as the header.
+ // Trailer maps trailer keys to values in the same
+ // format as Header.
+ //
+ // The Trailer initially contains only nil values, one for
+ // each key specified in the server's "Trailer" header
+ // value. Those values are not added to Header.
+ //
+ // Trailer must not be accessed concurrently with Read calls
+ // on the Body.
+ //
+ // After Body.Read has returned io.EOF, Trailer will contain
+ // any trailer values sent by the server.
Trailer Header
// The Request that was sent to obtain this Response.
@@ -140,12 +150,14 @@ func ReadResponse(r *bufio.Reader, req *Request) (*Response, error) {
if len(f) > 2 {
reasonPhrase = f[2]
}
- resp.Status = f[1] + " " + reasonPhrase
+ if len(f[1]) != 3 {
+ return nil, &badStringError{"malformed HTTP status code", f[1]}
+ }
resp.StatusCode, err = strconv.Atoi(f[1])
- if err != nil {
+ if err != nil || resp.StatusCode < 0 {
return nil, &badStringError{"malformed HTTP status code", f[1]}
}
-
+ resp.Status = f[1] + " " + reasonPhrase
resp.Proto = f[0]
var ok bool
if resp.ProtoMajor, resp.ProtoMinor, ok = ParseHTTPVersion(resp.Proto); !ok {
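With the reordered parsing above, ReadResponse fills in resp.Status only after checking that the status code is exactly three digits and non-negative. A hedged sketch of the stricter behavior; the raw response string is invented for illustration:

// sketch_readresponse.go — illustrative only; exercises the stricter parser.
package main

import (
	"bufio"
	"fmt"
	"net/http"
	"strings"
)

func main() {
	raw := "HTTP/1.1 20 OK\r\n\r\n" // two-digit status code: now rejected
	_, err := http.ReadResponse(bufio.NewReader(strings.NewReader(raw)), nil)
	fmt.Println(err) // reports a malformed HTTP status code
}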
diff --git a/libgo/go/net/http/response_test.go b/libgo/go/net/http/response_test.go
index 421cf55f491..d8a53400cf2 100644
--- a/libgo/go/net/http/response_test.go
+++ b/libgo/go/net/http/response_test.go
@@ -456,8 +456,59 @@ some body`,
"",
},
+
+ // Issue 12785: HTTP/1.0 response with bogus (to be ignored) Transfer-Encoding.
+ // Without a Content-Length.
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "Transfer-Encoding: bogus\r\n" +
+ "\r\n" +
+ "Body here\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{},
+ Close: true,
+ ContentLength: -1,
+ },
+
+ "Body here\n",
+ },
+
+ // Issue 12785: HTTP/1.0 response with bogus (to be ignored) Transfer-Encoding.
+ // With a Content-Length.
+ {
+ "HTTP/1.0 200 OK\r\n" +
+ "Transfer-Encoding: bogus\r\n" +
+ "Content-Length: 10\r\n" +
+ "\r\n" +
+ "Body here\n",
+
+ Response{
+ Status: "200 OK",
+ StatusCode: 200,
+ Proto: "HTTP/1.0",
+ ProtoMajor: 1,
+ ProtoMinor: 0,
+ Request: dummyReq("GET"),
+ Header: Header{
+ "Content-Length": {"10"},
+ },
+ Close: true,
+ ContentLength: 10,
+ },
+
+ "Body here\n",
+ },
}
+// tests successful calls to ReadResponse, and inspects the returned Response.
+// For error cases, see TestReadResponseErrors below.
func TestReadResponse(t *testing.T) {
for i, tt := range respTests {
resp, err := ReadResponse(bufio.NewReader(strings.NewReader(tt.Raw)), tt.Resp.Request)
@@ -624,6 +675,7 @@ var responseLocationTests = []responseLocationTest{
{"/foo", "http://bar.com/baz", "http://bar.com/foo", nil},
{"http://foo.com/", "http://bar.com/baz", "http://foo.com/", nil},
{"", "http://bar.com/baz", "", ErrNoLocation},
+ {"/bar", "", "/bar", nil},
}
func TestLocationResponse(t *testing.T) {
@@ -702,13 +754,106 @@ func TestResponseContentLengthShortBody(t *testing.T) {
}
}
-func TestReadResponseUnexpectedEOF(t *testing.T) {
- br := bufio.NewReader(strings.NewReader("HTTP/1.1 301 Moved Permanently\r\n" +
- "Location: http://example.com"))
- _, err := ReadResponse(br, nil)
- if err != io.ErrUnexpectedEOF {
- t.Errorf("ReadResponse = %v; want io.ErrUnexpectedEOF", err)
+// Test various ReadResponse error cases. (also tests success cases, but mostly
+// it's about errors). This does not test anything involving the bodies, only
+// the return value from ReadResponse itself.
+func TestReadResponseErrors(t *testing.T) {
+ type testCase struct {
+ name string // optional, defaults to in
+ in string
+ wantErr interface{} // nil, err value, or string substring
+ }
+
+ status := func(s string, wantErr interface{}) testCase {
+ if wantErr == true {
+ wantErr = "malformed HTTP status code"
+ }
+ return testCase{
+ name: fmt.Sprintf("status %q", s),
+ in: "HTTP/1.1 " + s + "\r\nFoo: bar\r\n\r\n",
+ wantErr: wantErr,
+ }
+ }
+
+ version := func(s string, wantErr interface{}) testCase {
+ if wantErr == true {
+ wantErr = "malformed HTTP version"
+ }
+ return testCase{
+ name: fmt.Sprintf("version %q", s),
+ in: s + " 200 OK\r\n\r\n",
+ wantErr: wantErr,
+ }
+ }
+
+ tests := []testCase{
+ {"", "", io.ErrUnexpectedEOF},
+ {"", "HTTP/1.1 301 Moved Permanently\r\nFoo: bar", io.ErrUnexpectedEOF},
+ {"", "HTTP/1.1", "malformed HTTP response"},
+ {"", "HTTP/2.0", "malformed HTTP response"},
+ status("20X Unknown", true),
+ status("abcd Unknown", true),
+ status("二百/两百 OK", true),
+ status(" Unknown", true),
+ status("c8 OK", true),
+ status("0x12d Moved Permanently", true),
+ status("200 OK", nil),
+ status("000 OK", nil),
+ status("001 OK", nil),
+ status("404 NOTFOUND", nil),
+ status("20 OK", true),
+ status("00 OK", true),
+ status("-10 OK", true),
+ status("1000 OK", true),
+ status("999 Done", nil),
+ status("-1 OK", true),
+ status("-200 OK", true),
+ version("HTTP/1.2", nil),
+ version("HTTP/2.0", nil),
+ version("HTTP/1.100000000002", true),
+ version("HTTP/1.-1", true),
+ version("HTTP/A.B", true),
+ version("HTTP/1", true),
+ version("http/1.1", true),
+ }
+ for i, tt := range tests {
+ br := bufio.NewReader(strings.NewReader(tt.in))
+ _, rerr := ReadResponse(br, nil)
+ if err := matchErr(rerr, tt.wantErr); err != nil {
+ name := tt.name
+ if name == "" {
+ name = fmt.Sprintf("%d. input %q", i, tt.in)
+ }
+ t.Errorf("%s: %v", name, err)
+ }
+ }
+}
+
+// wantErr can be nil, an error value to match exactly, or type string to
+// match a substring.
+func matchErr(err error, wantErr interface{}) error {
+ if err == nil {
+ if wantErr == nil {
+ return nil
+ }
+ if sub, ok := wantErr.(string); ok {
+ return fmt.Errorf("unexpected success; want error with substring %q", sub)
+ }
+ return fmt.Errorf("unexpected success; want error %v", wantErr)
+ }
+ if wantErr == nil {
+ return fmt.Errorf("%v; want success", err)
+ }
+ if sub, ok := wantErr.(string); ok {
+ if strings.Contains(err.Error(), sub) {
+ return nil
+ }
+ return fmt.Errorf("error = %v; want an error with substring %q", err, sub)
+ }
+ if err == wantErr {
+ return nil
}
+ return fmt.Errorf("%v; want %v", err, wantErr)
}
func TestNeedsSniff(t *testing.T) {
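The matchErr helper above encodes a table-test convention worth noting: wantErr may be nil (success expected), a concrete error value (compared with ==), or a string (matched as a substring). A standalone sketch of the same convention follows; the helper and example values are my own naming, not part of this patch.

// sketch_wanterr.go — illustrative only; the nil / error / substring idiom.
package main

import (
	"errors"
	"fmt"
	"strings"
)

// matchWant mirrors the shape of matchErr: wantErr is nil, an exact error,
// or a substring that the error text must contain.
func matchWant(err error, wantErr interface{}) error {
	switch w := wantErr.(type) {
	case nil:
		if err != nil {
			return fmt.Errorf("got %v; want success", err)
		}
	case string:
		if err == nil || !strings.Contains(err.Error(), w) {
			return fmt.Errorf("got %v; want an error with substring %q", err, w)
		}
	case error:
		if err != w {
			return fmt.Errorf("got %v; want %v", err, w)
		}
	default:
		return fmt.Errorf("unsupported wantErr type %T", wantErr)
	}
	return nil
}

func main() {
	fmt.Println(matchWant(errors.New("malformed HTTP status code"), "malformed")) // <nil>
	fmt.Println(matchWant(nil, "malformed"))                                      // unexpected success
}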
diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go
index d51417eb4a0..f8cad802d49 100644
--- a/libgo/go/net/http/serve_test.go
+++ b/libgo/go/net/http/serve_test.go
@@ -12,6 +12,7 @@ import (
"crypto/tls"
"errors"
"fmt"
+ "internal/testenv"
"io"
"io/ioutil"
"log"
@@ -26,6 +27,8 @@ import (
"os/exec"
"reflect"
"runtime"
+ "runtime/debug"
+ "sort"
"strconv"
"strings"
"sync"
@@ -96,6 +99,7 @@ func (c *rwTestConn) Close() error {
}
type testConn struct {
+ readMu sync.Mutex // for TestHandlerBodyClose
readBuf bytes.Buffer
writeBuf bytes.Buffer
closec chan bool // if non-nil, send value to it on close
@@ -103,6 +107,8 @@ type testConn struct {
}
func (c *testConn) Read(b []byte) (int, error) {
+ c.readMu.Lock()
+ defer c.readMu.Unlock()
return c.readBuf.Read(b)
}
@@ -450,6 +456,7 @@ func TestServerTimeouts(t *testing.T) {
if runtime.GOOS == "plan9" {
t.Skip("skipping test; see https://golang.org/issue/7237")
}
+ setParallel(t)
defer afterTest(t)
reqNum := 0
ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {
@@ -734,14 +741,17 @@ func TestHandlersCanSetConnectionClose10(t *testing.T) {
}))
}
-func TestSetsRemoteAddr(t *testing.T) {
+func TestSetsRemoteAddr_h1(t *testing.T) { testSetsRemoteAddr(t, h1Mode) }
+func TestSetsRemoteAddr_h2(t *testing.T) { testSetsRemoteAddr(t, h2Mode) }
+
+func testSetsRemoteAddr(t *testing.T, h2 bool) {
defer afterTest(t)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
fmt.Fprintf(w, "%s", r.RemoteAddr)
}))
- defer ts.Close()
+ defer cst.close()
- res, err := Get(ts.URL)
+ res, err := cst.c.Get(cst.ts.URL)
if err != nil {
t.Fatalf("Get error: %v", err)
}
@@ -755,34 +765,106 @@ func TestSetsRemoteAddr(t *testing.T) {
}
}
-func TestChunkedResponseHeaders(t *testing.T) {
- defer afterTest(t)
- log.SetOutput(ioutil.Discard) // is noisy otherwise
- defer log.SetOutput(os.Stderr)
+type blockingRemoteAddrListener struct {
+ net.Listener
+ conns chan<- net.Conn
+}
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
- w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted
- w.(Flusher).Flush()
- fmt.Fprintf(w, "I am a chunked response.")
+func (l *blockingRemoteAddrListener) Accept() (net.Conn, error) {
+ c, err := l.Listener.Accept()
+ if err != nil {
+ return nil, err
+ }
+ brac := &blockingRemoteAddrConn{
+ Conn: c,
+ addrs: make(chan net.Addr, 1),
+ }
+ l.conns <- brac
+ return brac, nil
+}
+
+type blockingRemoteAddrConn struct {
+ net.Conn
+ addrs chan net.Addr
+}
+
+func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr {
+ return <-c.addrs
+}
+
+// Issue 12943
+func TestServerAllowsBlockingRemoteAddr(t *testing.T) {
+ defer afterTest(t)
+ ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ fmt.Fprintf(w, "RA:%s", r.RemoteAddr)
}))
+ conns := make(chan net.Conn)
+ ts.Listener = &blockingRemoteAddrListener{
+ Listener: ts.Listener,
+ conns: conns,
+ }
+ ts.Start()
defer ts.Close()
- res, err := Get(ts.URL)
- if err != nil {
- t.Fatalf("Get error: %v", err)
+ tr := &Transport{DisableKeepAlives: true}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr, Timeout: time.Second}
+
+ fetch := func(response chan string) {
+ resp, err := c.Get(ts.URL)
+ if err != nil {
+ t.Error(err)
+ response <- ""
+ return
+ }
+ defer resp.Body.Close()
+ body, err := ioutil.ReadAll(resp.Body)
+ if err != nil {
+ t.Error(err)
+ response <- ""
+ return
+ }
+ response <- string(body)
}
- defer res.Body.Close()
- if g, e := res.ContentLength, int64(-1); g != e {
- t.Errorf("expected ContentLength of %d; got %d", e, g)
+
+ // Start a request. The server will block on getting conn.RemoteAddr.
+ response1c := make(chan string, 1)
+ go fetch(response1c)
+
+ // Wait for the server to accept it; grab the connection.
+ conn1 := <-conns
+
+ // Start another request and grab its connection
+ response2c := make(chan string, 1)
+ go fetch(response2c)
+ var conn2 net.Conn
+
+ select {
+ case conn2 = <-conns:
+ case <-time.After(time.Second):
+ t.Fatal("Second Accept didn't happen")
}
- if g, e := res.TransferEncoding, []string{"chunked"}; !reflect.DeepEqual(g, e) {
- t.Errorf("expected TransferEncoding of %v; got %v", e, g)
+
+ // Send a response on connection 2.
+ conn2.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
+ IP: net.ParseIP("12.12.12.12"), Port: 12}
+
+ // ... and see it
+ response2 := <-response2c
+ if g, e := response2, "RA:12.12.12.12:12"; g != e {
+ t.Fatalf("response 2 addr = %q; want %q", g, e)
}
- if _, haveCL := res.Header["Content-Length"]; haveCL {
- t.Errorf("Unexpected Content-Length")
+
+ // Finish the first response.
+ conn1.(*blockingRemoteAddrConn).addrs <- &net.TCPAddr{
+ IP: net.ParseIP("21.21.21.21"), Port: 21}
+
+ // ... and see it
+ response1 := <-response1c
+ if g, e := response1, "RA:21.21.21.21:21"; g != e {
+ t.Fatalf("response 1 addr = %q; want %q", g, e)
}
}
-
func TestIdentityResponseHeaders(t *testing.T) {
defer afterTest(t)
log.SetOutput(ioutil.Discard) // is noisy otherwise
@@ -812,40 +894,14 @@ func TestIdentityResponseHeaders(t *testing.T) {
}
}
-// Test304Responses verifies that 304s don't declare that they're
-// chunking in their response headers and aren't allowed to produce
-// output.
-func Test304Responses(t *testing.T) {
- defer afterTest(t)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
- w.WriteHeader(StatusNotModified)
- _, err := w.Write([]byte("illegal body"))
- if err != ErrBodyNotAllowed {
- t.Errorf("on Write, expected ErrBodyNotAllowed, got %v", err)
- }
- }))
- defer ts.Close()
- res, err := Get(ts.URL)
- if err != nil {
- t.Error(err)
- }
- if len(res.TransferEncoding) > 0 {
- t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding)
- }
- body, err := ioutil.ReadAll(res.Body)
- if err != nil {
- t.Error(err)
- }
- if len(body) > 0 {
- t.Errorf("got unexpected body %q", string(body))
- }
-}
-
// TestHeadResponses verifies that all MIME type sniffing and Content-Length
// counting of GET requests also happens on HEAD requests.
-func TestHeadResponses(t *testing.T) {
+func TestHeadResponses_h1(t *testing.T) { testHeadResponses(t, h1Mode) }
+func TestHeadResponses_h2(t *testing.T) { testHeadResponses(t, h2Mode) }
+
+func testHeadResponses(t *testing.T, h2 bool) {
defer afterTest(t)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
_, err := w.Write([]byte("<html>"))
if err != nil {
t.Errorf("ResponseWriter.Write: %v", err)
@@ -857,8 +913,8 @@ func TestHeadResponses(t *testing.T) {
t.Errorf("Copy(ResponseWriter, ...): %v", err)
}
}))
- defer ts.Close()
- res, err := Head(ts.URL)
+ defer cst.close()
+ res, err := cst.c.Head(cst.ts.URL)
if err != nil {
t.Error(err)
}
@@ -884,6 +940,7 @@ func TestTLSHandshakeTimeout(t *testing.T) {
if runtime.GOOS == "plan9" {
t.Skip("skipping test; see https://golang.org/issue/7237")
}
+ setParallel(t)
defer afterTest(t)
ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {}))
errc := make(chanWriter, 10) // but only expecting 1
@@ -967,6 +1024,79 @@ func TestTLSServer(t *testing.T) {
})
}
+func TestAutomaticHTTP2_Serve(t *testing.T) {
+ defer afterTest(t)
+ ln := newLocalListener(t)
+ ln.Close() // immediately (not a defer!)
+ var s Server
+ if err := s.Serve(ln); err == nil {
+ t.Fatal("expected an error")
+ }
+ on := s.TLSNextProto["h2"] != nil
+ if !on {
+ t.Errorf("http2 wasn't automatically enabled")
+ }
+}
+
+func TestAutomaticHTTP2_ListenAndServe(t *testing.T) {
+ defer afterTest(t)
+ defer SetTestHookServerServe(nil)
+ cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+ var ok bool
+ var s *Server
+ const maxTries = 5
+ var ln net.Listener
+Try:
+ for try := 0; try < maxTries; try++ {
+ ln = newLocalListener(t)
+ addr := ln.Addr().String()
+ ln.Close()
+ t.Logf("Got %v", addr)
+ lnc := make(chan net.Listener, 1)
+ SetTestHookServerServe(func(s *Server, ln net.Listener) {
+ lnc <- ln
+ })
+ s = &Server{
+ Addr: addr,
+ TLSConfig: &tls.Config{
+ Certificates: []tls.Certificate{cert},
+ },
+ }
+ errc := make(chan error, 1)
+ go func() { errc <- s.ListenAndServeTLS("", "") }()
+ select {
+ case err := <-errc:
+ t.Logf("On try #%v: %v", try+1, err)
+ continue
+ case ln = <-lnc:
+ ok = true
+ t.Logf("Listening on %v", ln.Addr().String())
+ break Try
+ }
+ }
+ if !ok {
+ t.Fatalf("Failed to start up after %d tries", maxTries)
+ }
+ defer ln.Close()
+ c, err := tls.Dial("tcp", ln.Addr().String(), &tls.Config{
+ InsecureSkipVerify: true,
+ NextProtos: []string{"h2", "http/1.1"},
+ })
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer c.Close()
+ if got, want := c.ConnectionState().NegotiatedProtocol, "h2"; got != want {
+ t.Errorf("NegotiatedProtocol = %q; want %q", got, want)
+ }
+ if got, want := c.ConnectionState().NegotiatedProtocolIsMutual, true; got != want {
+ t.Errorf("NegotiatedProtocolIsMutual = %v; want %v", got, want)
+ }
+}
+
type serverExpectTest struct {
contentLength int // of request body
chunked bool
@@ -1016,6 +1146,7 @@ var serverExpectTests = []serverExpectTest{
// Tests that the server responds to the "Expect" request header
// correctly.
+// http2 test: TestServer_Response_Automatic100Continue
func TestServerExpect(t *testing.T) {
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -1122,15 +1253,21 @@ func TestServerUnreadRequestBodyLittle(t *testing.T) {
done := make(chan bool)
+ readBufLen := func() int {
+ conn.readMu.Lock()
+ defer conn.readMu.Unlock()
+ return conn.readBuf.Len()
+ }
+
ls := &oneConnListener{conn}
go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
defer close(done)
- if conn.readBuf.Len() < len(body)/2 {
- t.Errorf("on request, read buffer length is %d; expected about 100 KB", conn.readBuf.Len())
+ if bufLen := readBufLen(); bufLen < len(body)/2 {
+ t.Errorf("on request, read buffer length is %d; expected about 100 KB", bufLen)
}
rw.WriteHeader(200)
rw.(Flusher).Flush()
- if g, e := conn.readBuf.Len(), 0; g != e {
+ if g, e := readBufLen(), 0; g != e {
t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e)
}
if c := rw.Header().Get("Connection"); c != "" {
@@ -1144,6 +1281,9 @@ func TestServerUnreadRequestBodyLittle(t *testing.T) {
// should ignore client request bodies that a handler didn't read
// and close the connection.
func TestServerUnreadRequestBodyLarge(t *testing.T) {
+ if testing.Short() && testenv.Builder() == "" {
+ t.Log("skipping in short mode")
+ }
conn := new(testConn)
body := strings.Repeat("x", 1<<20)
conn.readBuf.Write([]byte(fmt.Sprintf(
@@ -1274,6 +1414,9 @@ var handlerBodyCloseTests = [...]handlerBodyCloseTest{
}
func TestHandlerBodyClose(t *testing.T) {
+ if testing.Short() && testenv.Builder() == "" {
+ t.Skip("skipping in -short mode")
+ }
for i, tt := range handlerBodyCloseTests {
testHandlerBodyClose(t, i, tt)
}
@@ -1306,15 +1449,21 @@ func testHandlerBodyClose(t *testing.T, i int, tt handlerBodyCloseTest) {
}
conn.closec = make(chan bool, 1)
+ readBufLen := func() int {
+ conn.readMu.Lock()
+ defer conn.readMu.Unlock()
+ return conn.readBuf.Len()
+ }
+
ls := &oneConnListener{conn}
var numReqs int
var size0, size1 int
go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
numReqs++
if numReqs == 1 {
- size0 = conn.readBuf.Len()
+ size0 = readBufLen()
req.Body.Close()
- size1 = conn.readBuf.Len()
+ size1 = readBufLen()
}
}))
<-conn.closec
@@ -1414,7 +1563,9 @@ type slowTestConn struct {
// over multiple calls to Read, time.Durations are slept, strings are read.
script []interface{}
closec chan bool
- rd, wd time.Time // read, write deadline
+
+ mu sync.Mutex // guards rd/wd
+ rd, wd time.Time // read, write deadline
noopConn
}
@@ -1425,16 +1576,22 @@ func (c *slowTestConn) SetDeadline(t time.Time) error {
}
func (c *slowTestConn) SetReadDeadline(t time.Time) error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
c.rd = t
return nil
}
func (c *slowTestConn) SetWriteDeadline(t time.Time) error {
+ c.mu.Lock()
+ defer c.mu.Unlock()
c.wd = t
return nil
}
func (c *slowTestConn) Read(b []byte) (n int, err error) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
restart:
if !c.rd.IsZero() && time.Now().After(c.rd) {
return 0, syscall.ETIMEDOUT
@@ -1531,7 +1688,9 @@ func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
}
}
-func TestTimeoutHandler(t *testing.T) {
+func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) }
+func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) }
+func testTimeoutHandler(t *testing.T, h2 bool) {
defer afterTest(t)
sendHi := make(chan bool, 1)
writeErrors := make(chan error, 1)
@@ -1541,12 +1700,12 @@ func TestTimeoutHandler(t *testing.T) {
writeErrors <- werr
})
timeout := make(chan time.Time, 1) // write to this to force timeouts
- ts := httptest.NewServer(NewTestTimeoutHandler(sayHi, timeout))
- defer ts.Close()
+ cst := newClientServerTest(t, h2, NewTestTimeoutHandler(sayHi, timeout))
+ defer cst.close()
// Succeed without timing out:
sendHi <- true
- res, err := Get(ts.URL)
+ res, err := cst.c.Get(cst.ts.URL)
if err != nil {
t.Error(err)
}
@@ -1563,7 +1722,7 @@ func TestTimeoutHandler(t *testing.T) {
// Times out:
timeout <- time.Time{}
- res, err = Get(ts.URL)
+ res, err = cst.c.Get(cst.ts.URL)
if err != nil {
t.Error(err)
}
@@ -1659,6 +1818,60 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) {
wg.Wait()
}
+// Issue 9162
+func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) {
+ defer afterTest(t)
+ sendHi := make(chan bool, 1)
+ writeErrors := make(chan error, 1)
+ sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Type", "text/plain")
+ <-sendHi
+ _, werr := w.Write([]byte("hi"))
+ writeErrors <- werr
+ })
+ timeout := make(chan time.Time, 1) // write to this to force timeouts
+ cst := newClientServerTest(t, h1Mode, NewTestTimeoutHandler(sayHi, timeout))
+ defer cst.close()
+
+ // Succeed without timing out:
+ sendHi <- true
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ if g, e := res.StatusCode, StatusOK; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ body, _ := ioutil.ReadAll(res.Body)
+ if g, e := string(body), "hi"; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+ if g := <-writeErrors; g != nil {
+ t.Errorf("got unexpected Write error on first request: %v", g)
+ }
+
+ // Times out:
+ timeout <- time.Time{}
+ res, err = cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ body, _ = ioutil.ReadAll(res.Body)
+ if !strings.Contains(string(body), "<title>Timeout</title>") {
+ t.Errorf("expected timeout body; got %q", string(body))
+ }
+
+ // Now make the previously-timed out handler speak again,
+ // which verifies the panic is handled:
+ sendHi <- true
+ if g, e := <-writeErrors, ErrHandlerTimeout; g != e {
+ t.Errorf("expected Write error of %v; got %v", e, g)
+ }
+}
+
// Verifies we don't path.Clean() on the wrong parts in redirects.
func TestRedirectMunging(t *testing.T) {
req, _ := NewRequest("GET", "http://example.com/", nil)
@@ -1693,15 +1906,57 @@ func TestRedirectBadPath(t *testing.T) {
}
}
+// Test different URL formats and schemes
+func TestRedirectURLFormat(t *testing.T) {
+ req, _ := NewRequest("GET", "http://example.com/qux/", nil)
+
+ var tests = []struct {
+ in string
+ want string
+ }{
+ // normal http
+ {"http://foobar.com/baz", "http://foobar.com/baz"},
+ // normal https
+ {"https://foobar.com/baz", "https://foobar.com/baz"},
+ // custom scheme
+ {"test://foobar.com/baz", "test://foobar.com/baz"},
+ // schemeless
+ {"//foobar.com/baz", "//foobar.com/baz"},
+ // relative to the root
+ {"/foobar.com/baz", "/foobar.com/baz"},
+ // relative to the current path
+ {"foobar.com/baz", "/qux/foobar.com/baz"},
+ // relative to the current path (+ going upwards)
+ {"../quux/foobar.com/baz", "/quux/foobar.com/baz"},
+ // incorrect number of slashes
+ {"///foobar.com/baz", "/foobar.com/baz"},
+ }
+
+ for _, tt := range tests {
+ rec := httptest.NewRecorder()
+ Redirect(rec, req, tt.in, 302)
+ if got := rec.Header().Get("Location"); got != tt.want {
+ t.Errorf("Redirect(%q) generated Location header %q; want %q", tt.in, got, tt.want)
+ }
+ }
+}
+
// TestZeroLengthPostAndResponse exercises an optimization done by the Transport:
// when there is no body (either because the method doesn't permit a body, or an
// explicit Content-Length of zero is present), then the transport can re-use the
// connection immediately. But when it re-uses the connection, it typically closes
// the previous request's body, which is not optimal for zero-lengthed bodies,
// as the client would then see http.ErrBodyReadAfterClose and not 0, io.EOF.
-func TestZeroLengthPostAndResponse(t *testing.T) {
+func TestZeroLengthPostAndResponse_h1(t *testing.T) {
+ testZeroLengthPostAndResponse(t, h1Mode)
+}
+func TestZeroLengthPostAndResponse_h2(t *testing.T) {
+ testZeroLengthPostAndResponse(t, h2Mode)
+}
+
+func testZeroLengthPostAndResponse(t *testing.T, h2 bool) {
defer afterTest(t)
- ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, r *Request) {
all, err := ioutil.ReadAll(r.Body)
if err != nil {
t.Fatalf("handler ReadAll: %v", err)
@@ -1711,9 +1966,9 @@ func TestZeroLengthPostAndResponse(t *testing.T) {
}
rw.Header().Set("Content-Length", "0")
}))
- defer ts.Close()
+ defer cst.close()
- req, err := NewRequest("POST", ts.URL, strings.NewReader(""))
+ req, err := NewRequest("POST", cst.ts.URL, strings.NewReader(""))
if err != nil {
t.Fatal(err)
}
@@ -1721,7 +1976,7 @@ func TestZeroLengthPostAndResponse(t *testing.T) {
var resp [5]*Response
for i := range resp {
- resp[i], err = DefaultClient.Do(req)
+ resp[i], err = cst.c.Do(req)
if err != nil {
t.Fatalf("client post #%d: %v", i, err)
}
@@ -1738,19 +1993,22 @@ func TestZeroLengthPostAndResponse(t *testing.T) {
}
}
-func TestHandlerPanicNil(t *testing.T) {
- testHandlerPanic(t, false, nil)
-}
+func TestHandlerPanicNil_h1(t *testing.T) { testHandlerPanic(t, false, h1Mode, nil) }
+func TestHandlerPanicNil_h2(t *testing.T) { testHandlerPanic(t, false, h2Mode, nil) }
-func TestHandlerPanic(t *testing.T) {
- testHandlerPanic(t, false, "intentional death for testing")
+func TestHandlerPanic_h1(t *testing.T) {
+ testHandlerPanic(t, false, h1Mode, "intentional death for testing")
+}
+func TestHandlerPanic_h2(t *testing.T) {
+ testHandlerPanic(t, false, h2Mode, "intentional death for testing")
}
func TestHandlerPanicWithHijack(t *testing.T) {
- testHandlerPanic(t, true, "intentional death for testing")
+ // Only testing HTTP/1, and our http2 server doesn't support hijacking.
+ testHandlerPanic(t, true, h1Mode, "intentional death for testing")
}
-func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) {
+func testHandlerPanic(t *testing.T, withHijack, h2 bool, panicValue interface{}) {
defer afterTest(t)
// Unlike the other tests that set the log output to ioutil.Discard
// to quiet the output, this test uses a pipe. The pipe serves three
@@ -1773,7 +2031,7 @@ func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) {
defer log.SetOutput(os.Stderr)
defer pw.Close()
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
if withHijack {
rwc, _, err := w.(Hijacker).Hijack()
if err != nil {
@@ -1783,7 +2041,7 @@ func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) {
}
panic(panicValue)
}))
- defer ts.Close()
+ defer cst.close()
// Do a blocking read on the log output pipe so its logging
// doesn't bleed into the next test. But wait only 5 seconds
@@ -1799,7 +2057,7 @@ func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) {
done <- true
}()
- _, err := Get(ts.URL)
+ _, err := cst.c.Get(cst.ts.URL)
if err == nil {
t.Logf("expected an error")
}
@@ -1816,17 +2074,19 @@ func testHandlerPanic(t *testing.T, withHijack bool, panicValue interface{}) {
}
}
-func TestServerNoDate(t *testing.T) { testServerNoHeader(t, "Date") }
-func TestServerNoContentType(t *testing.T) { testServerNoHeader(t, "Content-Type") }
+func TestServerNoDate_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Date") }
+func TestServerNoDate_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Date") }
+func TestServerNoContentType_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Content-Type") }
+func TestServerNoContentType_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Content-Type") }
-func testServerNoHeader(t *testing.T, header string) {
+func testServerNoHeader(t *testing.T, h2 bool, header string) {
defer afterTest(t)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header()[header] = nil
io.WriteString(w, "<html>foo</html>") // non-empty
}))
- defer ts.Close()
- res, err := Get(ts.URL)
+ defer cst.close()
+ res, err := cst.c.Get(cst.ts.URL)
if err != nil {
t.Fatal(err)
}
@@ -1863,18 +2123,20 @@ func TestStripPrefix(t *testing.T) {
res.Body.Close()
}
-func TestRequestLimit(t *testing.T) {
+func TestRequestLimit_h1(t *testing.T) { testRequestLimit(t, h1Mode) }
+func TestRequestLimit_h2(t *testing.T) { testRequestLimit(t, h2Mode) }
+func testRequestLimit(t *testing.T, h2 bool) {
defer afterTest(t)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
t.Fatalf("didn't expect to get request in Handler")
}))
- defer ts.Close()
- req, _ := NewRequest("GET", ts.URL, nil)
+ defer cst.close()
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
var bytesPerHeader = len("header12345: val12345\r\n")
for i := 0; i < ((DefaultMaxHeaderBytes+4096)/bytesPerHeader)+1; i++ {
req.Header.Set(fmt.Sprintf("header%05d", i), fmt.Sprintf("val%05d", i))
}
- res, err := DefaultClient.Do(req)
+ res, err := cst.c.Do(req)
if err != nil {
// Some HTTP clients may fail on this undefined behavior (server replying and
// closing the connection while the request is still being written), but
@@ -1882,8 +2144,8 @@ func TestRequestLimit(t *testing.T) {
t.Fatalf("Do: %v", err)
}
defer res.Body.Close()
- if res.StatusCode != 413 {
- t.Fatalf("expected 413 response status; got: %d %s", res.StatusCode, res.Status)
+ if res.StatusCode != 431 {
+ t.Fatalf("expected 431 response status; got: %d %s", res.StatusCode, res.Status)
}
}
@@ -1907,10 +2169,12 @@ func (cr countReader) Read(p []byte) (n int, err error) {
return
}
-func TestRequestBodyLimit(t *testing.T) {
+func TestRequestBodyLimit_h1(t *testing.T) { testRequestBodyLimit(t, h1Mode) }
+func TestRequestBodyLimit_h2(t *testing.T) { testRequestBodyLimit(t, h2Mode) }
+func testRequestBodyLimit(t *testing.T, h2 bool) {
defer afterTest(t)
const limit = 1 << 20
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
r.Body = MaxBytesReader(w, r.Body, limit)
n, err := io.Copy(ioutil.Discard, r.Body)
if err == nil {
@@ -1920,10 +2184,10 @@ func TestRequestBodyLimit(t *testing.T) {
t.Errorf("io.Copy = %d, want %d", n, limit)
}
}))
- defer ts.Close()
+ defer cst.close()
nWritten := new(int64)
- req, _ := NewRequest("POST", ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200))
+ req, _ := NewRequest("POST", cst.ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200))
// Send the POST, but don't care it succeeds or not. The
// remote side is going to reply and then close the TCP
@@ -1934,7 +2198,7 @@ func TestRequestBodyLimit(t *testing.T) {
//
// But that's okay, since what we're really testing is that
// the remote side hung up on us before we wrote too much.
- _, _ = DefaultClient.Do(req)
+ _, _ = cst.c.Do(req)
if atomic.LoadInt64(nWritten) > limit*100 {
t.Errorf("handler restricted the request body to %d bytes, but client managed to write %d",
@@ -1982,7 +2246,7 @@ func TestClientWriteShutdown(t *testing.T) {
// buffered before chunk headers are added, not after chunk headers.
func TestServerBufferedChunking(t *testing.T) {
conn := new(testConn)
- conn.readBuf.Write([]byte("GET / HTTP/1.1\r\n\r\n"))
+ conn.readBuf.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"))
conn.closec = make(chan bool, 1)
ls := &oneConnListener{conn}
go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) {
@@ -2045,20 +2309,23 @@ func TestServerGracefulClose(t *testing.T) {
<-writeErr
}
-func TestCaseSensitiveMethod(t *testing.T) {
+func TestCaseSensitiveMethod_h1(t *testing.T) { testCaseSensitiveMethod(t, h1Mode) }
+func TestCaseSensitiveMethod_h2(t *testing.T) { testCaseSensitiveMethod(t, h2Mode) }
+func testCaseSensitiveMethod(t *testing.T, h2 bool) {
defer afterTest(t)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
if r.Method != "get" {
t.Errorf(`Got method %q; want "get"`, r.Method)
}
}))
- defer ts.Close()
- req, _ := NewRequest("get", ts.URL, nil)
- res, err := DefaultClient.Do(req)
+ defer cst.close()
+ req, _ := NewRequest("get", cst.ts.URL, nil)
+ res, err := cst.c.Do(req)
if err != nil {
t.Error(err)
return
}
+
res.Body.Close()
}
@@ -2131,6 +2398,49 @@ For:
ts.Close()
}
+// Tests that a pipelined request causes the first request's Handler's CloseNotify
+// channel to fire. Previously it deadlocked.
+//
+// Issue 13165
+func TestCloseNotifierPipelined(t *testing.T) {
+ defer afterTest(t)
+ gotReq := make(chan bool, 2)
+ sawClose := make(chan bool, 2)
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ gotReq <- true
+ cc := rw.(CloseNotifier).CloseNotify()
+ <-cc
+ sawClose <- true
+ }))
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatalf("error dialing: %v", err)
+ }
+ diec := make(chan bool, 2)
+ go func() {
+ const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n"
+ _, err = io.WriteString(conn, req+req) // two requests
+ if err != nil {
+ t.Fatal(err)
+ }
+ <-diec
+ conn.Close()
+ }()
+For:
+ for {
+ select {
+ case <-gotReq:
+ diec <- true
+ case <-sawClose:
+ break For
+ case <-time.After(5 * time.Second):
+ ts.CloseClientConnections()
+ t.Fatal("timeout")
+ }
+ }
+ ts.Close()
+}
+
func TestCloseNotifierChanLeak(t *testing.T) {
defer afterTest(t)
req := reqBytes("GET / HTTP/1.0\nHost: golang.org")
@@ -2153,6 +2463,114 @@ func TestCloseNotifierChanLeak(t *testing.T) {
}
}
+// Tests that we can use CloseNotifier in one request, and later call Hijack
+// on a second request on the same connection.
+//
+// It also tests that the connReader stitches together its background
+// 1-byte read for CloseNotifier when CloseNotifier doesn't fire with
+// the rest of the second HTTP request later.
+//
+// Issue 9763.
+// HTTP/1-only test. (http2 doesn't have Hijack)
+func TestHijackAfterCloseNotifier(t *testing.T) {
+ defer afterTest(t)
+ script := make(chan string, 2)
+ script <- "closenotify"
+ script <- "hijack"
+ close(script)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ plan := <-script
+ switch plan {
+ default:
+ panic("bogus plan; too many requests")
+ case "closenotify":
+ w.(CloseNotifier).CloseNotify() // discard result
+ w.Header().Set("X-Addr", r.RemoteAddr)
+ case "hijack":
+ c, _, err := w.(Hijacker).Hijack()
+ if err != nil {
+ t.Errorf("Hijack in Handler: %v", err)
+ return
+ }
+ if _, ok := c.(*net.TCPConn); !ok {
+ // Verify it's not wrapped in some type.
+ // Not strictly a go1 compat issue, but in practice it probably is.
+ t.Errorf("type of hijacked conn is %T; want *net.TCPConn", c)
+ }
+ fmt.Fprintf(c, "HTTP/1.0 200 OK\r\nX-Addr: %v\r\nContent-Length: 0\r\n\r\n", r.RemoteAddr)
+ c.Close()
+ return
+ }
+ }))
+ defer ts.Close()
+ res1, err := Get(ts.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+ res2, err := Get(ts.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+ addr1 := res1.Header.Get("X-Addr")
+ addr2 := res2.Header.Get("X-Addr")
+ if addr1 == "" || addr1 != addr2 {
+ t.Errorf("addr1, addr2 = %q, %q; want same", addr1, addr2)
+ }
+}
+
+func TestHijackBeforeRequestBodyRead(t *testing.T) {
+ defer afterTest(t)
+ var requestBody = bytes.Repeat([]byte("a"), 1<<20)
+ bodyOkay := make(chan bool, 1)
+ gotCloseNotify := make(chan bool, 1)
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ defer close(bodyOkay) // caller will read false if nothing else
+
+ reqBody := r.Body
+ r.Body = nil // to test that server.go doesn't use this value.
+
+ gone := w.(CloseNotifier).CloseNotify()
+ slurp, err := ioutil.ReadAll(reqBody)
+ if err != nil {
+ t.Errorf("Body read: %v", err)
+ return
+ }
+ if len(slurp) != len(requestBody) {
+ t.Errorf("Backend read %d request body bytes; want %d", len(slurp), len(requestBody))
+ return
+ }
+ if !bytes.Equal(slurp, requestBody) {
+ t.Error("Backend read wrong request body.") // 1MB; omitting details
+ return
+ }
+ bodyOkay <- true
+ select {
+ case <-gone:
+ gotCloseNotify <- true
+ case <-time.After(5 * time.Second):
+ gotCloseNotify <- false
+ }
+ }))
+ defer ts.Close()
+
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer conn.Close()
+
+ fmt.Fprintf(conn, "POST / HTTP/1.1\r\nHost: foo\r\nContent-Length: %d\r\n\r\n%s",
+ len(requestBody), requestBody)
+ if !<-bodyOkay {
+ // already failed.
+ return
+ }
+ conn.Close()
+ if !<-gotCloseNotify {
+ t.Error("timeout waiting for CloseNotify")
+ }
+}
+
func TestOptions(t *testing.T) {
uric := make(chan string, 2) // only expect 1, but leave space for 2
mux := NewServeMux()
@@ -2230,7 +2648,7 @@ func TestHeaderToWire(t *testing.T) {
return errors.New("no content-length")
}
if !strings.Contains(got, "Content-Type: text/plain") {
- return errors.New("no content-length")
+ return errors.New("no content-type")
}
return nil
},
@@ -2302,7 +2720,7 @@ func TestHeaderToWire(t *testing.T) {
return errors.New("header appeared from after WriteHeader")
}
if !strings.Contains(got, "Content-Type: some/type") {
- return errors.New("wrong content-length")
+ return errors.New("wrong content-type")
}
return nil
},
@@ -2315,7 +2733,7 @@ func TestHeaderToWire(t *testing.T) {
},
check: func(got string) error {
if !strings.Contains(got, "Content-Type: text/html") {
- return errors.New("wrong content-length; want html")
+ return errors.New("wrong content-type; want html")
}
return nil
},
@@ -2328,7 +2746,7 @@ func TestHeaderToWire(t *testing.T) {
},
check: func(got string) error {
if !strings.Contains(got, "Content-Type: some/type") {
- return errors.New("wrong content-length; want html")
+ return errors.New("wrong content-type; want html")
}
return nil
},
@@ -2339,7 +2757,7 @@ func TestHeaderToWire(t *testing.T) {
},
check: func(got string) error {
if !strings.Contains(got, "Content-Type: text/plain") {
- return errors.New("wrong content-length; want text/plain")
+ return errors.New("wrong content-type; want text/plain")
}
if !strings.Contains(got, "Content-Length: 0") {
return errors.New("want 0 content-length")
@@ -2369,7 +2787,7 @@ func TestHeaderToWire(t *testing.T) {
if !strings.Contains(got, "404") {
return errors.New("wrong status")
}
- if strings.Contains(got, "Some-Header") {
+ if strings.Contains(got, "Too-Late") {
return errors.New("shouldn't have seen Too-Late")
}
return nil
@@ -2503,7 +2921,7 @@ func TestHTTP10ConnectionHeader(t *testing.T) {
defer afterTest(t)
mux := NewServeMux()
- mux.Handle("/", HandlerFunc(func(resp ResponseWriter, req *Request) {}))
+ mux.Handle("/", HandlerFunc(func(ResponseWriter, *Request) {}))
ts := httptest.NewServer(mux)
defer ts.Close()
@@ -2552,11 +2970,13 @@ func TestHTTP10ConnectionHeader(t *testing.T) {
}
// See golang.org/issue/5660
-func TestServerReaderFromOrder(t *testing.T) {
+func TestServerReaderFromOrder_h1(t *testing.T) { testServerReaderFromOrder(t, h1Mode) }
+func TestServerReaderFromOrder_h2(t *testing.T) { testServerReaderFromOrder(t, h2Mode) }
+func testServerReaderFromOrder(t *testing.T, h2 bool) {
defer afterTest(t)
pr, pw := io.Pipe()
const size = 3 << 20
- ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) {
rw.Header().Set("Content-Type", "text/plain") // prevent sniffing path
done := make(chan bool)
go func() {
@@ -2576,13 +2996,13 @@ func TestServerReaderFromOrder(t *testing.T) {
pw.Close()
<-done
}))
- defer ts.Close()
+ defer cst.close()
- req, err := NewRequest("POST", ts.URL, io.LimitReader(neverEnding('a'), size))
+ req, err := NewRequest("POST", cst.ts.URL, io.LimitReader(neverEnding('a'), size))
if err != nil {
t.Fatal(err)
}
- res, err := DefaultClient.Do(req)
+ res, err := cst.c.Do(req)
if err != nil {
t.Fatal(err)
}
@@ -2612,9 +3032,9 @@ func TestCodesPreventingContentTypeAndBody(t *testing.T) {
"GET / HTTP/1.0",
"GET /header HTTP/1.0",
"GET /more HTTP/1.0",
- "GET / HTTP/1.1",
- "GET /header HTTP/1.1",
- "GET /more HTTP/1.1",
+ "GET / HTTP/1.1\nHost: foo",
+ "GET /header HTTP/1.1\nHost: foo",
+ "GET /more HTTP/1.1\nHost: foo",
} {
got := ht.rawResponse(req)
wantStatus := fmt.Sprintf("%d %s", code, StatusText(code))
@@ -2635,7 +3055,7 @@ func TestContentTypeOkayOn204(t *testing.T) {
w.Header().Set("Content-Type", "foo/bar")
w.WriteHeader(204)
}))
- got := ht.rawResponse("GET / HTTP/1.1")
+ got := ht.rawResponse("GET / HTTP/1.1\nHost: foo")
if !strings.Contains(got, "Content-Type: foo/bar") {
t.Errorf("Response = %q; want Content-Type: foo/bar", got)
}
@@ -2650,45 +3070,101 @@ func TestContentTypeOkayOn204(t *testing.T) {
// proxy). So then two people own that Request.Body (both the server
// and the http client), and both think they can close it on failure.
// Therefore, all incoming server requests Bodies need to be thread-safe.
-func TestTransportAndServerSharedBodyRace(t *testing.T) {
+func TestTransportAndServerSharedBodyRace_h1(t *testing.T) {
+ testTransportAndServerSharedBodyRace(t, h1Mode)
+}
+func TestTransportAndServerSharedBodyRace_h2(t *testing.T) {
+ testTransportAndServerSharedBodyRace(t, h2Mode)
+}
+func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) {
defer afterTest(t)
const bodySize = 1 << 20
+ // errorf is like t.Errorf, but also writes to println. When
+ // this test fails, it hangs. This helps debugging and I've
+ // added this enough times "temporarily". It now gets added
+ // full time.
+ errorf := func(format string, args ...interface{}) {
+ v := fmt.Sprintf(format, args...)
+ println(v)
+ t.Error(v)
+ }
+
unblockBackend := make(chan bool)
- backend := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
- io.CopyN(rw, req.Body, bodySize)
+ backend := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ gone := rw.(CloseNotifier).CloseNotify()
+ didCopy := make(chan interface{})
+ go func() {
+ n, err := io.CopyN(rw, req.Body, bodySize)
+ didCopy <- []interface{}{n, err}
+ }()
+ isGone := false
+ Loop:
+ for {
+ select {
+ case <-didCopy:
+ break Loop
+ case <-gone:
+ isGone = true
+ case <-time.After(time.Second):
+ println("1 second passes in backend, proxygone=", isGone)
+ }
+ }
<-unblockBackend
}))
- defer backend.Close()
+ var quitTimer *time.Timer
+ defer func() { quitTimer.Stop() }()
+ defer backend.close()
backendRespc := make(chan *Response, 1)
- proxy := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
- req2, _ := NewRequest("POST", backend.URL, req.Body)
+ var proxy *clientServerTest
+ proxy = newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) {
+ req2, _ := NewRequest("POST", backend.ts.URL, req.Body)
req2.ContentLength = bodySize
+ cancel := make(chan struct{})
+ req2.Cancel = cancel
- bresp, err := DefaultClient.Do(req2)
+ bresp, err := proxy.c.Do(req2)
if err != nil {
- t.Errorf("Proxy outbound request: %v", err)
+ errorf("Proxy outbound request: %v", err)
return
}
_, err = io.CopyN(ioutil.Discard, bresp.Body, bodySize/2)
if err != nil {
- t.Errorf("Proxy copy error: %v", err)
+ errorf("Proxy copy error: %v", err)
return
}
backendRespc <- bresp // to close later
- // Try to cause a race: Both the DefaultTransport and the proxy handler's Server
+ // Try to cause a race: Both the Transport and the proxy handler's Server
// will try to read/close req.Body (aka req2.Body)
- DefaultTransport.(*Transport).CancelRequest(req2)
+ if h2 {
+ close(cancel)
+ } else {
+ proxy.c.Transport.(*Transport).CancelRequest(req2)
+ }
rw.Write([]byte("OK"))
}))
- defer proxy.Close()
+ defer proxy.close()
+ defer func() {
+ // Before we shut down our two httptest.Servers, start a timer.
+ // We choose 7 seconds because httptest.Server starts logging
+ // warnings to stderr at 5 seconds. If we don't disarm this bomb
+ // in 7 seconds (after the two httptest.Server.Close calls above),
+ // then we explode with stacks.
+ quitTimer = time.AfterFunc(7*time.Second, func() {
+ debug.SetTraceback("ALL")
+ stacks := make([]byte, 1<<20)
+ stacks = stacks[:runtime.Stack(stacks, true)]
+ fmt.Fprintf(os.Stderr, "%s", stacks)
+ log.Fatalf("Timeout.")
+ })
+ }()
defer close(unblockBackend)
- req, _ := NewRequest("POST", proxy.URL, io.LimitReader(neverEnding('a'), bodySize))
- res, err := DefaultClient.Do(req)
+ req, _ := NewRequest("POST", proxy.ts.URL, io.LimitReader(neverEnding('a'), bodySize))
+ res, err := proxy.c.Do(req)
if err != nil {
t.Fatalf("Original request: %v", err)
}
@@ -2699,7 +3175,7 @@ func TestTransportAndServerSharedBodyRace(t *testing.T) {
case res := <-backendRespc:
res.Body.Close()
default:
- // We failed earlier. (e.g. on DefaultClient.Do(req2))
+ // We failed earlier. (e.g. on proxy.c.Do(req2))
}
}
@@ -2863,6 +3339,7 @@ func TestServerConnState(t *testing.T) {
if _, err := io.WriteString(c, "BOGUS REQUEST\r\n\r\n"); err != nil {
t.Fatal(err)
}
+ c.Read(make([]byte, 1)) // block until server hangs up on us
c.Close()
}
@@ -2896,9 +3373,14 @@ func TestServerConnState(t *testing.T) {
}
logString := func(m map[int][]ConnState) string {
var b bytes.Buffer
- for id, l := range m {
+ var keys []int
+ for id := range m {
+ keys = append(keys, id)
+ }
+ sort.Ints(keys)
+ for _, id := range keys {
fmt.Fprintf(&b, "Conn %d: ", id)
- for _, s := range l {
+ for _, s := range m[id] {
fmt.Fprintf(&b, "%s ", s)
}
b.WriteString("\n")
@@ -2959,20 +3441,22 @@ func TestServerKeepAlivesEnabled(t *testing.T) {
}
// golang.org/issue/7856
-func TestServerEmptyBodyRace(t *testing.T) {
+func TestServerEmptyBodyRace_h1(t *testing.T) { testServerEmptyBodyRace(t, h1Mode) }
+func TestServerEmptyBodyRace_h2(t *testing.T) { testServerEmptyBodyRace(t, h2Mode) }
+func testServerEmptyBodyRace(t *testing.T, h2 bool) {
defer afterTest(t)
var n int32
- ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) {
atomic.AddInt32(&n, 1)
}))
- defer ts.Close()
+ defer cst.close()
var wg sync.WaitGroup
const reqs = 20
for i := 0; i < reqs; i++ {
wg.Add(1)
go func() {
defer wg.Done()
- res, err := Get(ts.URL)
+ res, err := cst.c.Get(cst.ts.URL)
if err != nil {
t.Error(err)
return
@@ -3025,10 +3509,7 @@ func (c *closeWriteTestConn) CloseWrite() error {
func TestCloseWrite(t *testing.T) {
var srv Server
var testConn closeWriteTestConn
- c, err := ExportServerNewConn(&srv, &testConn)
- if err != nil {
- t.Fatal(err)
- }
+ c := ExportServerNewConn(&srv, &testConn)
ExportCloseWriteAndWait(c)
if !testConn.didCloseWrite {
t.Error("didn't see CloseWrite call")
@@ -3193,6 +3674,33 @@ func TestTolerateCRLFBeforeRequestLine(t *testing.T) {
}
}
+func TestIssue13893_Expect100(t *testing.T) {
+ // test that the Server doesn't filter out Expect headers.
+ req := reqBytes(`PUT /readbody HTTP/1.1
+User-Agent: PycURL/7.22.0
+Host: 127.0.0.1:9000
+Accept: */*
+Expect: 100-continue
+Content-Length: 10
+
+HelloWorld
+
+`)
+ var buf bytes.Buffer
+ conn := &rwTestConn{
+ Reader: bytes.NewReader(req),
+ Writer: &buf,
+ closec: make(chan bool, 1),
+ }
+ ln := &oneConnListener{conn: conn}
+ go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) {
+ if _, ok := r.Header["Expect"]; !ok {
+ t.Error("Expect header should not be filtered out")
+ }
+ }))
+ <-conn.closec
+}
+
func TestIssue11549_Expect100(t *testing.T) {
req := reqBytes(`PUT /readbody HTTP/1.1
User-Agent: PycURL/7.22.0
@@ -3260,6 +3768,122 @@ func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) {
}
}
+func TestHandlerSetsBodyNil_h1(t *testing.T) { testHandlerSetsBodyNil(t, h1Mode) }
+func TestHandlerSetsBodyNil_h2(t *testing.T) { testHandlerSetsBodyNil(t, h2Mode) }
+func testHandlerSetsBodyNil(t *testing.T, h2 bool) {
+ defer afterTest(t)
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
+ r.Body = nil
+ fmt.Fprintf(w, "%v", r.RemoteAddr)
+ }))
+ defer cst.close()
+ get := func() string {
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+ slurp, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ return string(slurp)
+ }
+ a, b := get(), get()
+ if a != b {
+ t.Errorf("Failed to reuse connections between requests: %v vs %v", a, b)
+ }
+}
+
+// Test that we validate the Host header.
+// Issue 11206 (invalid bytes in Host) and 13624 (Host present in HTTP/1.1)
+func TestServerValidatesHostHeader(t *testing.T) {
+ tests := []struct {
+ proto string
+ host string
+ want int
+ }{
+ {"HTTP/1.1", "", 400},
+ {"HTTP/1.1", "Host: \r\n", 200},
+ {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
+ {"HTTP/1.1", "Host: foo.com\r\n", 200},
+ {"HTTP/1.1", "Host: foo-bar_baz.com\r\n", 200},
+ {"HTTP/1.1", "Host: foo.com:80\r\n", 200},
+ {"HTTP/1.1", "Host: ::1\r\n", 200},
+ {"HTTP/1.1", "Host: [::1]\r\n", 200}, // questionable without port, but accept it
+ {"HTTP/1.1", "Host: [::1]:80\r\n", 200},
+ {"HTTP/1.1", "Host: [::1%25en0]:80\r\n", 200},
+ {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200},
+ {"HTTP/1.1", "Host: \x06\r\n", 400},
+ {"HTTP/1.1", "Host: \xff\r\n", 400},
+ {"HTTP/1.1", "Host: {\r\n", 400},
+ {"HTTP/1.1", "Host: }\r\n", 400},
+ {"HTTP/1.1", "Host: first\r\nHost: second\r\n", 400},
+
+ // HTTP/1.0 can lack a host header, but if present
+ // must play by the rules too:
+ {"HTTP/1.0", "", 200},
+ {"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400},
+ {"HTTP/1.0", "Host: \xff\r\n", 400},
+ }
+ for _, tt := range tests {
+ conn := &testConn{closec: make(chan bool, 1)}
+ io.WriteString(&conn.readBuf, "GET / "+tt.proto+"\r\n"+tt.host+"\r\n")
+
+ ln := &oneConnListener{conn}
+ go Serve(ln, HandlerFunc(func(ResponseWriter, *Request) {}))
+ <-conn.closec
+ res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
+ if err != nil {
+ t.Errorf("For %s %q, ReadResponse: %v", tt.proto, tt.host, res)
+ continue
+ }
+ if res.StatusCode != tt.want {
+ t.Errorf("For %s %q, Status = %d; want %d", tt.proto, tt.host, res.StatusCode, tt.want)
+ }
+ }
+}
+
+// Test that we validate the valid bytes in HTTP/1 headers.
+// Issue 11207.
+func TestServerValidatesHeaders(t *testing.T) {
+ tests := []struct {
+ header string
+ want int
+ }{
+ {"", 200},
+ {"Foo: bar\r\n", 200},
+ {"X-Foo: bar\r\n", 200},
+ {"Foo: a space\r\n", 200},
+
+ {"A space: foo\r\n", 400}, // space in header
+ {"foo\xffbar: foo\r\n", 400}, // binary in header
+ {"foo\x00bar: foo\r\n", 400}, // binary in header
+
+ {"foo: foo foo\r\n", 200}, // LWS space is okay
+ {"foo: foo\tfoo\r\n", 200}, // LWS tab is okay
+ {"foo: foo\x00foo\r\n", 400}, // CTL 0x00 in value is bad
+ {"foo: foo\x7ffoo\r\n", 400}, // CTL 0x7f in value is bad
+ {"foo: foo\xfffoo\r\n", 200}, // non-ASCII high octets in value are fine
+ }
+ for _, tt := range tests {
+ conn := &testConn{closec: make(chan bool, 1)}
+ io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n")
+
+ ln := &oneConnListener{conn}
+ go Serve(ln, HandlerFunc(func(ResponseWriter, *Request) {}))
+ <-conn.closec
+ res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil)
+ if err != nil {
+ t.Errorf("For %q, ReadResponse: %v", tt.header, res)
+ continue
+ }
+ if res.StatusCode != tt.want {
+ t.Errorf("For %q, Status = %d; want %d", tt.header, res.StatusCode, tt.want)
+ }
+ }
+}
+
func BenchmarkClientServer(b *testing.B) {
b.ReportAllocs()
b.StopTimer()
@@ -3685,3 +4309,35 @@ Host: golang.org
<-conn.closec
}
}
+
+func BenchmarkCloseNotifier(b *testing.B) {
+ b.ReportAllocs()
+ b.StopTimer()
+ sawClose := make(chan bool)
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ <-rw.(CloseNotifier).CloseNotify()
+ sawClose <- true
+ }))
+ defer ts.Close()
+ tot := time.NewTimer(5 * time.Second)
+ defer tot.Stop()
+ b.StartTimer()
+ for i := 0; i < b.N; i++ {
+ conn, err := net.Dial("tcp", ts.Listener.Addr().String())
+ if err != nil {
+ b.Fatalf("error dialing: %v", err)
+ }
+ _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n")
+ if err != nil {
+ b.Fatal(err)
+ }
+ conn.Close()
+ tot.Reset(5 * time.Second)
+ select {
+ case <-sawClose:
+ case <-tot.C:
+ b.Fatal("timeout")
+ }
+ }
+ b.StopTimer()
+}
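
A minimal sketch (not part of this patch) of the Host-header check that the new TestServerValidatesHostHeader table exercises above, run against a live server over TCP instead of the in-memory oneConnListener; the httptest server, raw request bytes, and expected status follow the test table (an HTTP/1.1 request with no Host header is rejected with 400).

package main

import (
	"bufio"
	"fmt"
	"log"
	"net"
	"net/http"
	"net/http/httptest"
)

func main() {
	// Any handler will do; the request below should be rejected before it runs.
	ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {}))
	defer ts.Close()

	conn, err := net.Dial("tcp", ts.Listener.Addr().String())
	if err != nil {
		log.Fatal(err)
	}
	defer conn.Close()

	// HTTP/1.1 request with no Host header: with this patch the server
	// replies "400 Bad Request: missing required Host header".
	fmt.Fprintf(conn, "GET / HTTP/1.1\r\n\r\n")

	res, err := http.ReadResponse(bufio.NewReader(conn), nil)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(res.StatusCode) // expected: 400
}
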
diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go
index a3e43555bb3..004a1f92fc4 100644
--- a/libgo/go/net/http/server.go
+++ b/libgo/go/net/http/server.go
@@ -8,6 +8,7 @@ package http
import (
"bufio"
+ "bytes"
"crypto/tls"
"errors"
"fmt"
@@ -35,26 +36,33 @@ var (
ErrContentLength = errors.New("Conn.Write wrote more than the declared Content-Length")
)
-// Objects implementing the Handler interface can be
-// registered to serve a particular path or subtree
-// in the HTTP server.
+// A Handler responds to an HTTP request.
//
// ServeHTTP should write reply headers and data to the ResponseWriter
-// and then return. Returning signals that the request is finished
-// and that the HTTP server can move on to the next request on
-// the connection.
+// and then return. Returning signals that the request is finished; it
+// is not valid to use the ResponseWriter or read from the
+// Request.Body after or concurrently with the completion of the
+// ServeHTTP call.
+//
+// Depending on the HTTP client software, HTTP protocol version, and
+// any intermediaries between the client and the Go server, it may not
+// be possible to read from the Request.Body after writing to the
+// ResponseWriter. Cautious handlers should read the Request.Body
+// first, and then reply.
//
// If ServeHTTP panics, the server (the caller of ServeHTTP) assumes
// that the effect of the panic was isolated to the active request.
// It recovers the panic, logs a stack trace to the server error log,
// and hangs up the connection.
-//
type Handler interface {
ServeHTTP(ResponseWriter, *Request)
}
// A ResponseWriter interface is used by an HTTP handler to
// construct an HTTP response.
+//
+// A ResponseWriter may not be used after the Handler.ServeHTTP method
+// has returned.
type ResponseWriter interface {
// Header returns the header map that will be sent by
// WriteHeader. Changing the header after a call to
@@ -114,28 +122,76 @@ type Hijacker interface {
// This mechanism can be used to cancel long operations on the server
// if the client has disconnected before the response is ready.
type CloseNotifier interface {
- // CloseNotify returns a channel that receives a single value
- // when the client connection has gone away.
+ // CloseNotify returns a channel that receives at most a
+ // single value (true) when the client connection has gone
+ // away.
+ //
+ // CloseNotify may wait to notify until Request.Body has been
+ // fully read.
+ //
+ // After the Handler has returned, there is no guarantee
+ // that the channel receives a value.
+ //
+ // If the protocol is HTTP/1.1 and CloseNotify is called while
+ // processing an idempotent request (such as a GET) while
+ // HTTP/1.1 pipelining is in use, the arrival of a subsequent
+ // pipelined request may cause a value to be sent on the
+ // returned channel. In practice HTTP/1.1 pipelining is not
+ // enabled in browsers and not seen often in the wild. If this
+ // is a problem, use HTTP/2 or only use CloseNotify on methods
+ // such as POST.
CloseNotify() <-chan bool
}
// A conn represents the server side of an HTTP connection.
type conn struct {
- remoteAddr string // network address of remote side
- server *Server // the Server on which the connection arrived
- rwc net.Conn // i/o connection
- w io.Writer // checkConnErrorWriter's copy of wrc, not zeroed on Hijack
- werr error // any errors writing to w
- sr liveSwitchReader // where the LimitReader reads from; usually the rwc
- lr *io.LimitedReader // io.LimitReader(sr)
- buf *bufio.ReadWriter // buffered(lr,rwc), reading from bufio->limitReader->sr->rwc
- tlsState *tls.ConnectionState // or nil when not using TLS
- lastMethod string // method of previous request, or ""
-
- mu sync.Mutex // guards the following
- clientGone bool // if client has disconnected mid-request
- closeNotifyc chan bool // made lazily
- hijackedv bool // connection has been hijacked by handler
+ // server is the server on which the connection arrived.
+ // Immutable; never nil.
+ server *Server
+
+ // rwc is the underlying network connection.
+ // This is never wrapped by other types and is the value given out
+ // to CloseNotifier callers. It is usually of type *net.TCPConn or
+ // *tls.Conn.
+ rwc net.Conn
+
+ // remoteAddr is rwc.RemoteAddr().String(). It is not populated synchronously
+ // inside the Listener's Accept goroutine, as some implementations block.
+ // It is populated immediately inside the (*conn).serve goroutine.
+ // This is the value of a Handler's (*Request).RemoteAddr.
+ remoteAddr string
+
+ // tlsState is the TLS connection state when using TLS.
+ // nil means not TLS.
+ tlsState *tls.ConnectionState
+
+ // werr is set to the first write error to rwc.
+ // It is set via checkConnErrorWriter{c}, where bufw writes.
+ werr error
+
+ // r is bufr's read source. It's a wrapper around rwc that provides
+ // io.LimitedReader-style limiting (while reading request headers)
+ // and functionality to support CloseNotifier. See *connReader docs.
+ r *connReader
+
+ // bufr reads from r.
+ // Users of bufr must hold mu.
+ bufr *bufio.Reader
+
+ // bufw writes to checkConnErrorWriter{c}, which populates werr on error.
+ bufw *bufio.Writer
+
+ // lastMethod is the method of the most recent request
+ // on this connection, if any.
+ lastMethod string
+
+ // mu guards hijackedv, use of bufr, (*response).closeNotifyCh.
+ mu sync.Mutex
+
+ // hijackedv is whether this connection has been hijacked
+ // by a Handler with the Hijacker interface.
+ // It is guarded by mu.
+ hijackedv bool
}
func (c *conn) hijacked() bool {
@@ -144,81 +200,18 @@ func (c *conn) hijacked() bool {
return c.hijackedv
}
-func (c *conn) hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
- c.mu.Lock()
- defer c.mu.Unlock()
+// c.mu must be held.
+func (c *conn) hijackLocked() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
if c.hijackedv {
return nil, nil, ErrHijacked
}
- if c.closeNotifyc != nil {
- return nil, nil, errors.New("http: Hijack is incompatible with use of CloseNotifier")
- }
c.hijackedv = true
rwc = c.rwc
- buf = c.buf
- c.rwc = nil
- c.buf = nil
+ buf = bufio.NewReadWriter(c.bufr, bufio.NewWriter(rwc))
c.setState(rwc, StateHijacked)
return
}
-func (c *conn) closeNotify() <-chan bool {
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.closeNotifyc == nil {
- c.closeNotifyc = make(chan bool, 1)
- if c.hijackedv {
- // to obey the function signature, even though
- // it'll never receive a value.
- return c.closeNotifyc
- }
- pr, pw := io.Pipe()
-
- readSource := c.sr.r
- c.sr.Lock()
- c.sr.r = pr
- c.sr.Unlock()
- go func() {
- _, err := io.Copy(pw, readSource)
- if err == nil {
- err = io.EOF
- }
- pw.CloseWithError(err)
- c.noteClientGone()
- }()
- }
- return c.closeNotifyc
-}
-
-func (c *conn) noteClientGone() {
- c.mu.Lock()
- defer c.mu.Unlock()
- if c.closeNotifyc != nil && !c.clientGone {
- c.closeNotifyc <- true
- }
- c.clientGone = true
-}
-
-// A switchWriter can have its Writer changed at runtime.
-// It's not safe for concurrent Writes and switches.
-type switchWriter struct {
- io.Writer
-}
-
-// A liveSwitchReader can have its Reader changed at runtime. It's
-// safe for concurrent reads and switches, if its mutex is held.
-type liveSwitchReader struct {
- sync.Mutex
- r io.Reader
-}
-
-func (sr *liveSwitchReader) Read(p []byte) (n int, err error) {
- sr.Lock()
- r := sr.r
- sr.Unlock()
- return r.Read(p)
-}
-
// This should be >= 512 bytes for DetectContentType,
// but otherwise it's somewhat arbitrary.
const bufferBeforeChunkingSize = 2048
@@ -265,15 +258,15 @@ func (cw *chunkWriter) Write(p []byte) (n int, err error) {
return len(p), nil
}
if cw.chunking {
- _, err = fmt.Fprintf(cw.res.conn.buf, "%x\r\n", len(p))
+ _, err = fmt.Fprintf(cw.res.conn.bufw, "%x\r\n", len(p))
if err != nil {
cw.res.conn.rwc.Close()
return
}
}
- n, err = cw.res.conn.buf.Write(p)
+ n, err = cw.res.conn.bufw.Write(p)
if cw.chunking && err == nil {
- _, err = cw.res.conn.buf.Write(crlf)
+ _, err = cw.res.conn.bufw.Write(crlf)
}
if err != nil {
cw.res.conn.rwc.Close()
@@ -285,7 +278,7 @@ func (cw *chunkWriter) flush() {
if !cw.wroteHeader {
cw.writeHeader(nil)
}
- cw.res.conn.buf.Flush()
+ cw.res.conn.bufw.Flush()
}
func (cw *chunkWriter) close() {
@@ -293,7 +286,7 @@ func (cw *chunkWriter) close() {
cw.writeHeader(nil)
}
if cw.chunking {
- bw := cw.res.conn.buf // conn's bufio writer
+ bw := cw.res.conn.bufw // conn's bufio writer
// zero chunk to mark EOF
bw.WriteString("0\r\n")
if len(cw.res.trailers) > 0 {
@@ -315,12 +308,12 @@ func (cw *chunkWriter) close() {
type response struct {
conn *conn
req *Request // request for this response
- wroteHeader bool // reply header has been (logically) written
- wroteContinue bool // 100 Continue response was written
+ reqBody io.ReadCloser
+ wroteHeader bool // reply header has been (logically) written
+ wroteContinue bool // 100 Continue response was written
w *bufio.Writer // buffers output in chunks to chunkWriter
cw chunkWriter
- sw *switchWriter // of the bufio.Writer, for return to putBufioWriter
// handlerHeader is the Header that Handlers get access to,
// which may be retained and mutated even after WriteHeader.
@@ -354,13 +347,22 @@ type response struct {
// written.
trailers []string
- handlerDone bool // set true when the handler exits
+ handlerDone atomicBool // set true when the handler exits
// Buffers for Date and Content-Length
dateBuf [len(TimeFormat)]byte
clenBuf [10]byte
+
+ // closeNotifyCh is non-nil once CloseNotify is called.
+ // Guarded by conn.mu
+ closeNotifyCh <-chan bool
}
+type atomicBool int32
+
+func (b *atomicBool) isSet() bool { return atomic.LoadInt32((*int32)(b)) != 0 }
+func (b *atomicBool) setTrue() { atomic.StoreInt32((*int32)(b), 1) }
+
// declareTrailer is called for each Trailer header when the
// response header is written. It notes that a header will need to be
// written in the trailers at the end of the response.
@@ -423,7 +425,9 @@ func (w *response) ReadFrom(src io.Reader) (n int64, err error) {
return 0, err
}
if !ok || !regFile {
- return io.Copy(writerOnly{w}, src)
+ bufp := copyBufPool.Get().(*[]byte)
+ defer copyBufPool.Put(bufp)
+ return io.CopyBuffer(writerOnly{w}, src, *bufp)
}
// sendfile path:
@@ -456,29 +460,88 @@ func (w *response) ReadFrom(src io.Reader) (n int64, err error) {
return n, err
}
-// noLimit is an effective infinite upper bound for io.LimitedReader
-const noLimit int64 = (1 << 63) - 1
-
// debugServerConnections controls whether all server connections are wrapped
// with a verbose logging wrapper.
const debugServerConnections = false
// Create new connection from rwc.
-func (srv *Server) newConn(rwc net.Conn) (c *conn, err error) {
- c = new(conn)
- c.remoteAddr = rwc.RemoteAddr().String()
- c.server = srv
- c.rwc = rwc
- c.w = rwc
+func (srv *Server) newConn(rwc net.Conn) *conn {
+ c := &conn{
+ server: srv,
+ rwc: rwc,
+ }
if debugServerConnections {
c.rwc = newLoggingConn("server", c.rwc)
}
- c.sr.r = c.rwc
- c.lr = io.LimitReader(&c.sr, noLimit).(*io.LimitedReader)
- br := newBufioReader(c.lr)
- bw := newBufioWriterSize(checkConnErrorWriter{c}, 4<<10)
- c.buf = bufio.NewReadWriter(br, bw)
- return c, nil
+ return c
+}
+
+type readResult struct {
+ n int
+ err error
+ b byte // byte read, if n == 1
+}
+
+// connReader is the io.Reader wrapper used by *conn. It combines a
+// selectively-activated io.LimitedReader (to bound request header
+// read sizes) with support for selectively keeping an io.Reader.Read
+// call blocked in a background goroutine to wait for activity and
+// trigger a CloseNotifier channel.
+type connReader struct {
+ r io.Reader
+ remain int64 // bytes remaining
+
+ // ch is non-nil if a background read is in progress.
+ // It is guarded by conn.mu.
+ ch chan readResult
+}
+
+func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain }
+func (cr *connReader) setInfiniteReadLimit() { cr.remain = 1<<63 - 1 }
+func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 }
+
+func (cr *connReader) Read(p []byte) (n int, err error) {
+ if cr.hitReadLimit() {
+ return 0, io.EOF
+ }
+ if len(p) == 0 {
+ return
+ }
+ if int64(len(p)) > cr.remain {
+ p = p[:cr.remain]
+ }
+
+ // Is a background read (started by CloseNotifier) already in
+ // flight? If so, wait for it and use its result.
+ ch := cr.ch
+ if ch != nil {
+ cr.ch = nil
+ res := <-ch
+ if res.n == 1 {
+ p[0] = res.b
+ cr.remain -= 1
+ }
+ return res.n, res.err
+ }
+ n, err = cr.r.Read(p)
+ cr.remain -= int64(n)
+ return
+}
+
+func (cr *connReader) startBackgroundRead(onReadComplete func()) {
+ if cr.ch != nil {
+ // Background read already started.
+ return
+ }
+ cr.ch = make(chan readResult, 1)
+ go cr.closeNotifyAwaitActivityRead(cr.ch, onReadComplete)
+}
+
+func (cr *connReader) closeNotifyAwaitActivityRead(ch chan<- readResult, onReadComplete func()) {
+ var buf [1]byte
+ n, err := cr.r.Read(buf[:1])
+ onReadComplete()
+ ch <- readResult{n, err, buf[0]}
}
var (
@@ -487,6 +550,13 @@ var (
bufioWriter4kPool sync.Pool
)
+var copyBufPool = sync.Pool{
+ New: func() interface{} {
+ b := make([]byte, 32*1024)
+ return &b
+ },
+}
+
func bufioWriterPool(size int) *sync.Pool {
switch size {
case 2 << 10:
@@ -544,7 +614,7 @@ func (srv *Server) maxHeaderBytes() int {
return DefaultMaxHeaderBytes
}
-func (srv *Server) initialLimitedReaderSize() int64 {
+func (srv *Server) initialReadLimitSize() int64 {
return int64(srv.maxHeaderBytes()) + 4096 // bufio slop
}
@@ -563,8 +633,8 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err error) {
}
if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked() {
ecr.resp.wroteContinue = true
- ecr.resp.conn.buf.WriteString("HTTP/1.1 100 Continue\r\n\r\n")
- ecr.resp.conn.buf.Flush()
+ ecr.resp.conn.bufw.WriteString("HTTP/1.1 100 Continue\r\n\r\n")
+ ecr.resp.conn.bufw.Flush()
}
n, err = ecr.readCloser.Read(p)
if err == io.EOF {
@@ -578,10 +648,12 @@ func (ecr *expectContinueReader) Close() error {
return ecr.readCloser.Close()
}
-// TimeFormat is the time format to use with
-// time.Parse and time.Time.Format when parsing
-// or generating times in HTTP headers.
-// It is like time.RFC1123 but hard codes GMT as the time zone.
+// TimeFormat is the time format to use when generating times in HTTP
+// headers. It is like time.RFC1123 but hard-codes GMT as the time
+// zone. The time being formatted must be in UTC for Format to
+// generate the correct format.
+//
+// For parsing this time format, see ParseTime.
const TimeFormat = "Mon, 02 Jan 2006 15:04:05 GMT"
// appendTime is a non-allocating version of []byte(t.UTC().Format(TimeFormat))
@@ -623,21 +695,45 @@ func (c *conn) readRequest() (w *response, err error) {
}()
}
- c.lr.N = c.server.initialLimitedReaderSize()
+ c.r.setReadLimit(c.server.initialReadLimitSize())
+ c.mu.Lock() // while using bufr
if c.lastMethod == "POST" {
// RFC 2616 section 4.1 tolerance for old buggy clients.
- peek, _ := c.buf.Reader.Peek(4) // ReadRequest will get err below
- c.buf.Reader.Discard(numLeadingCRorLF(peek))
+ peek, _ := c.bufr.Peek(4) // ReadRequest will get err below
+ c.bufr.Discard(numLeadingCRorLF(peek))
}
- var req *Request
- if req, err = ReadRequest(c.buf.Reader); err != nil {
- if c.lr.N == 0 {
+ req, err := readRequest(c.bufr, keepHostHeader)
+ c.mu.Unlock()
+ if err != nil {
+ if c.r.hitReadLimit() {
return nil, errTooLarge
}
return nil, err
}
- c.lr.N = noLimit
c.lastMethod = req.Method
+ c.r.setInfiniteReadLimit()
+
+ hosts, haveHost := req.Header["Host"]
+ if req.ProtoAtLeast(1, 1) && (!haveHost || len(hosts) == 0) {
+ return nil, badRequestError("missing required Host header")
+ }
+ if len(hosts) > 1 {
+ return nil, badRequestError("too many Host headers")
+ }
+ if len(hosts) == 1 && !validHostHeader(hosts[0]) {
+ return nil, badRequestError("malformed Host header")
+ }
+ for k, vv := range req.Header {
+ if !validHeaderName(k) {
+ return nil, badRequestError("invalid header name")
+ }
+ for _, v := range vv {
+ if !validHeaderValue(v) {
+ return nil, badRequestError("invalid header value")
+ }
+ }
+ }
+ delete(req.Header, "Host")
req.RemoteAddr = c.remoteAddr
req.TLS = c.tlsState
@@ -648,6 +744,7 @@ func (c *conn) readRequest() (w *response, err error) {
w = &response{
conn: c,
req: req,
+ reqBody: req.Body,
handlerHeader: make(Header),
contentLength: -1,
}
@@ -755,7 +852,7 @@ func (h extraHeader) Write(w *bufio.Writer) {
}
// writeHeader finalizes the header sent to the client and writes it
-// to cw.res.conn.buf.
+// to cw.res.conn.bufw.
//
// p is not written by writeHeader, but is the first chunk of the body
// that will be written. It is sniffed for a Content-Type if none is
@@ -821,7 +918,7 @@ func (cw *chunkWriter) writeHeader(p []byte) {
// send a Content-Length header.
// Further, we don't send an automatic Content-Length if they
// set a Transfer-Encoding, because they're generally incompatible.
- if w.handlerDone && !trailers && !hasTE && bodyAllowedForStatus(w.status) && header.get("Content-Length") == "" && (!isHEAD || len(p) > 0) {
+ if w.handlerDone.isSet() && !trailers && !hasTE && bodyAllowedForStatus(w.status) && header.get("Content-Length") == "" && (!isHEAD || len(p) > 0) {
w.contentLength = int64(len(p))
setHeader.contentLength = strconv.AppendInt(cw.res.clenBuf[:0], int64(len(p)), 10)
}
@@ -898,7 +995,7 @@ func (cw *chunkWriter) writeHeader(p []byte) {
}
if discard {
- _, err := io.CopyN(ioutil.Discard, w.req.Body, maxPostHandlerReadBytes+1)
+ _, err := io.CopyN(ioutil.Discard, w.reqBody, maxPostHandlerReadBytes+1)
switch err {
case nil:
// There must be even more data left over.
@@ -907,7 +1004,7 @@ func (cw *chunkWriter) writeHeader(p []byte) {
// Body was already consumed and closed.
case io.EOF:
// The remaining body was just consumed, close it.
- err = w.req.Body.Close()
+ err = w.reqBody.Close()
if err != nil {
w.closeAfterReply = true
}
@@ -996,10 +1093,10 @@ func (cw *chunkWriter) writeHeader(p []byte) {
}
}
- w.conn.buf.WriteString(statusLine(w.req, code))
- cw.header.WriteSubset(w.conn.buf, excludeHeader)
- setHeader.Write(w.conn.buf.Writer)
- w.conn.buf.Write(crlf)
+ w.conn.bufw.WriteString(statusLine(w.req, code))
+ cw.header.WriteSubset(w.conn.bufw, excludeHeader)
+ setHeader.Write(w.conn.bufw)
+ w.conn.bufw.Write(crlf)
}
// foreachHeaderElement splits v according to the "#rule" construction
@@ -1144,7 +1241,7 @@ func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err er
}
func (w *response) finishRequest() {
- w.handlerDone = true
+ w.handlerDone.setTrue()
if !w.wroteHeader {
w.WriteHeader(StatusOK)
@@ -1153,11 +1250,11 @@ func (w *response) finishRequest() {
w.w.Flush()
putBufioWriter(w.w)
w.cw.close()
- w.conn.buf.Flush()
+ w.conn.bufw.Flush()
// Close the body (regardless of w.closeAfterReply) so we can
// re-use its bufio.Reader later safely.
- w.req.Body.Close()
+ w.reqBody.Close()
if w.req.MultipartForm != nil {
w.req.MultipartForm.RemoveAll()
@@ -1206,28 +1303,26 @@ func (w *response) Flush() {
}
func (c *conn) finalFlush() {
- if c.buf != nil {
- c.buf.Flush()
-
+ if c.bufr != nil {
// Steal the bufio.Reader (~4KB worth of memory) and its associated
// reader for a future connection.
- putBufioReader(c.buf.Reader)
+ putBufioReader(c.bufr)
+ c.bufr = nil
+ }
+ if c.bufw != nil {
+ c.bufw.Flush()
// Steal the bufio.Writer (~4KB worth of memory) and its associated
// writer for a future connection.
- putBufioWriter(c.buf.Writer)
-
- c.buf = nil
+ putBufioWriter(c.bufw)
+ c.bufw = nil
}
}
// Close the connection.
func (c *conn) close() {
c.finalFlush()
- if c.rwc != nil {
- c.rwc.Close()
- c.rwc = nil
- }
+ c.rwc.Close()
}
// rstAvoidanceDelay is the amount of time we sleep after closing the
@@ -1277,9 +1372,16 @@ func (c *conn) setState(nc net.Conn, state ConnState) {
}
}
+// badRequestError is a literal string (used by the server in HTML,
+// unescaped) to tell the user why their request was bad. It should
+// be plain text without user info or other embedded errors.
+type badRequestError string
+
+func (e badRequestError) Error() string { return "Bad Request: " + string(e) }
+
// Serve a new connection.
func (c *conn) serve() {
- origConn := c.rwc // copy it before it's set nil on Close or Hijack
+ c.remoteAddr = c.rwc.RemoteAddr().String()
defer func() {
if err := recover(); err != nil {
const size = 64 << 10
@@ -1289,7 +1391,7 @@ func (c *conn) serve() {
}
if !c.hijacked() {
c.close()
- c.setState(origConn, StateClosed)
+ c.setState(c.rwc, StateClosed)
}
}()
@@ -1315,9 +1417,13 @@ func (c *conn) serve() {
}
}
+ c.r = &connReader{r: c.rwc}
+ c.bufr = newBufioReader(c.r)
+ c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10)
+
for {
w, err := c.readRequest()
- if c.lr.N != c.server.initialLimitedReaderSize() {
+ if c.r.remain != c.server.initialReadLimitSize() {
// If we read any bytes off the wire, we're active.
c.setState(c.rwc, StateActive)
}
@@ -1328,16 +1434,22 @@ func (c *conn) serve() {
// responding to them and hanging up
// while they're still writing their
// request. Undefined behavior.
- io.WriteString(c.rwc, "HTTP/1.1 413 Request Entity Too Large\r\n\r\n")
+ io.WriteString(c.rwc, "HTTP/1.1 431 Request Header Fields Too Large\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n431 Request Header Fields Too Large")
c.closeWriteAndWait()
- break
- } else if err == io.EOF {
- break // Don't reply
- } else if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
- break // Don't reply
+ return
}
- io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\n\r\n")
- break
+ if err == io.EOF {
+ return // don't reply
+ }
+ if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
+ return // don't reply
+ }
+ var publicErr string
+ if v, ok := err.(badRequestError); ok {
+ publicErr = ": " + string(v)
+ }
+ io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n400 Bad Request"+publicErr)
+ return
}
// Expect 100 Continue support
@@ -1347,10 +1459,9 @@ func (c *conn) serve() {
// Wrap the Body reader with one that replies on the connection
req.Body = &expectContinueReader{readCloser: req.Body, resp: w}
}
- req.Header.Del("Expect")
} else if req.Header.get("Expect") != "" {
w.sendExpectationFailed()
- break
+ return
}
// HTTP cannot have multiple simultaneous active requests.[*]
@@ -1367,7 +1478,7 @@ func (c *conn) serve() {
if w.requestBodyLimitHit || w.closedRequestBodyEarly() {
c.closeWriteAndWait()
}
- break
+ return
}
c.setState(c.rwc, StateIdle)
}
@@ -1394,12 +1505,24 @@ func (w *response) sendExpectationFailed() {
// Hijack implements the Hijacker.Hijack method. Our response is both a ResponseWriter
// and a Hijacker.
func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
+ if w.handlerDone.isSet() {
+ panic("net/http: Hijack called after ServeHTTP finished")
+ }
if w.wroteHeader {
w.cw.flush()
}
+
+ c := w.conn
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if w.closeNotifyCh != nil {
+ return nil, nil, errors.New("http: Hijack is incompatible with use of CloseNotifier in same ServeHTTP call")
+ }
+
// Release the bufioWriter that writes to the chunk writer, it is not
// used after a connection has been hijacked.
- rwc, buf, err = w.conn.hijack()
+ rwc, buf, err = c.hijackLocked()
if err == nil {
putBufioWriter(w.w)
w.w = nil
@@ -1408,13 +1531,86 @@ func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) {
}
func (w *response) CloseNotify() <-chan bool {
- return w.conn.closeNotify()
+ if w.handlerDone.isSet() {
+ panic("net/http: CloseNotify called after ServeHTTP finished")
+ }
+ c := w.conn
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ if w.closeNotifyCh != nil {
+ return w.closeNotifyCh
+ }
+ ch := make(chan bool, 1)
+ w.closeNotifyCh = ch
+
+ if w.conn.hijackedv {
+ // CloseNotify is undefined after a hijack, but we have
+ // no place to return an error, so just return a channel,
+ // even though it'll never receive a value.
+ return ch
+ }
+
+ var once sync.Once
+ notify := func() { once.Do(func() { ch <- true }) }
+
+ if requestBodyRemains(w.reqBody) {
+ // They're still consuming the request body, so we
+ // shouldn't notify yet.
+ registerOnHitEOF(w.reqBody, func() {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+ startCloseNotifyBackgroundRead(c, notify)
+ })
+ } else {
+ startCloseNotifyBackgroundRead(c, notify)
+ }
+ return ch
+}
+
+// c.mu must be held.
+func startCloseNotifyBackgroundRead(c *conn, notify func()) {
+ if c.bufr.Buffered() > 0 {
+ // They've consumed the request body, so anything
+ // remaining is a pipelined request, which we
+ // document as firing on.
+ notify()
+ } else {
+ c.r.startBackgroundRead(notify)
+ }
+}
+
+func registerOnHitEOF(rc io.ReadCloser, fn func()) {
+ switch v := rc.(type) {
+ case *expectContinueReader:
+ registerOnHitEOF(v.readCloser, fn)
+ case *body:
+ v.registerOnHitEOF(fn)
+ default:
+ panic("unexpected type " + fmt.Sprintf("%T", rc))
+ }
+}
+
+// requestBodyRemains reports whether future calls to Read
+// on rc might yield more data.
+func requestBodyRemains(rc io.ReadCloser) bool {
+ if rc == eofReader {
+ return false
+ }
+ switch v := rc.(type) {
+ case *expectContinueReader:
+ return requestBodyRemains(v.readCloser)
+ case *body:
+ return v.bodyRemains()
+ default:
+ panic("unexpected type " + fmt.Sprintf("%T", rc))
+ }
}
// The HandlerFunc type is an adapter to allow the use of
// ordinary functions as HTTP handlers. If f is a function
// with the appropriate signature, HandlerFunc(f) is a
-// Handler object that calls f.
+// Handler that calls f.
type HandlerFunc func(ResponseWriter, *Request)
// ServeHTTP calls f(w, r).
@@ -1461,6 +1657,9 @@ func StripPrefix(prefix string, h Handler) Handler {
// Redirect replies to the request with a redirect to url,
// which may be a path relative to the request path.
+//
+// The provided code should be in the 3xx range and is usually
+// StatusMovedPermanently, StatusFound or StatusSeeOther.
func Redirect(w ResponseWriter, r *Request, urlStr string, code int) {
if u, err := url.Parse(urlStr); err == nil {
// If url was relative, make absolute by
@@ -1479,11 +1678,12 @@ func Redirect(w ResponseWriter, r *Request, urlStr string, code int) {
// Because of this problem, no one pays attention
// to the RFC; they all send back just a new path.
// So do we.
- oldpath := r.URL.Path
- if oldpath == "" { // should not happen, but avoid a crash if it does
- oldpath = "/"
- }
- if u.Scheme == "" {
+ if u.Scheme == "" && u.Host == "" {
+ oldpath := r.URL.Path
+ if oldpath == "" { // should not happen, but avoid a crash if it does
+ oldpath = "/"
+ }
+
// no leading http://server
if urlStr == "" || urlStr[0] != '/' {
// make relative path absolute
@@ -1545,6 +1745,9 @@ func (rh *redirectHandler) ServeHTTP(w ResponseWriter, r *Request) {
// RedirectHandler returns a request handler that redirects
// each request it receives to the given url using the given
// status code.
+//
+// The provided code should be in the 3xx range and is usually
+// StatusMovedPermanently, StatusFound or StatusSeeOther.
func RedirectHandler(url string, code int) Handler {
return &redirectHandler{url, code}
}
@@ -1567,6 +1770,14 @@ func RedirectHandler(url string, code int) Handler {
// the pattern "/" matches all paths not matched by other registered
// patterns, not just the URL with Path == "/".
//
+// If a subtree has been registered and a request is received naming the
+// subtree root without its trailing slash, ServeMux redirects that
+// request to the subtree root (adding the trailing slash). This behavior can
+// be overridden with a separate registration for the path without
+// the trailing slash. For example, registering "/images/" causes ServeMux
+// to redirect a request for "/images" to "/images/", unless "/images" has
+// been registered separately.
+//
// Patterns may optionally begin with a host name, restricting matches to
// URLs on that host only. Host-specific patterns take precedence over
// general patterns, so that a handler might register for the two patterns
@@ -1574,8 +1785,8 @@ func RedirectHandler(url string, code int) Handler {
// requests for "http://www.google.com/".
//
// ServeMux also takes care of sanitizing the URL request path,
-// redirecting any request containing . or .. elements to an
-// equivalent .- and ..-free URL.
+// redirecting any request containing . or .. elements or repeated slashes
+// to an equivalent, cleaner URL.
type ServeMux struct {
mu sync.RWMutex
m map[string]muxEntry
@@ -1782,6 +1993,7 @@ type Server struct {
// handle HTTP requests and will initialize the Request's TLS
// and RemoteAddr if not already set. The connection is
// automatically closed when the function returns.
+ // If TLSNextProto is nil, HTTP/2 support is enabled automatically.
TLSNextProto map[string]func(*Server, *tls.Conn, Handler)
// ConnState specifies an optional callback function that is
@@ -1795,7 +2007,9 @@ type Server struct {
// standard logger.
ErrorLog *log.Logger
- disableKeepAlives int32 // accessed atomically.
+ disableKeepAlives int32 // accessed atomically.
+ nextProtoOnce sync.Once // guards initialization of TLSNextProto in Serve
+ nextProtoErr error
}
// A ConnState represents the state of a client connection to a server.
@@ -1815,6 +2029,11 @@ const (
// and doesn't fire again until the request has been
// handled. After the request is handled, the state
// transitions to StateClosed, StateHijacked, or StateIdle.
+ // For HTTP/2, StateActive fires on the transition from zero
+ // to one active request, and only transitions away once all
+ // active requests are complete. That means that ConnState
+ // can not be used to do per-request work; ConnState only notes
+ // the overall state of the connection.
StateActive
// StateIdle represents a connection that has finished
@@ -1863,8 +2082,10 @@ func (sh serverHandler) ServeHTTP(rw ResponseWriter, req *Request) {
}
// ListenAndServe listens on the TCP network address srv.Addr and then
-// calls Serve to handle requests on incoming connections. If
-// srv.Addr is blank, ":http" is used.
+// calls Serve to handle requests on incoming connections.
+// Accepted connections are configured to enable TCP keep-alives.
+// If srv.Addr is blank, ":http" is used.
+// ListenAndServe always returns a non-nil error.
func (srv *Server) ListenAndServe() error {
addr := srv.Addr
if addr == "" {
@@ -1877,12 +2098,21 @@ func (srv *Server) ListenAndServe() error {
return srv.Serve(tcpKeepAliveListener{ln.(*net.TCPListener)})
}
+var testHookServerServe func(*Server, net.Listener) // used if non-nil
+
// Serve accepts incoming connections on the Listener l, creating a
-// new service goroutine for each. The service goroutines read requests and
+// new service goroutine for each. The service goroutines read requests and
// then call srv.Handler to reply to them.
+// Serve always returns a non-nil error.
func (srv *Server) Serve(l net.Listener) error {
defer l.Close()
+ if fn := testHookServerServe; fn != nil {
+ fn(srv, l)
+ }
var tempDelay time.Duration // how long to sleep on accept failure
+ if err := srv.setupHTTP2(); err != nil {
+ return err
+ }
for {
rw, e := l.Accept()
if e != nil {
@@ -1902,10 +2132,7 @@ func (srv *Server) Serve(l net.Listener) error {
return e
}
tempDelay = 0
- c, err := srv.newConn(rw)
- if err != nil {
- continue
- }
+ c := srv.newConn(rw)
c.setState(c.rwc, StateNew) // before Serve can return
go c.serve()
}
@@ -1937,8 +2164,10 @@ func (s *Server) logf(format string, args ...interface{}) {
// ListenAndServe listens on the TCP network address addr
// and then calls Serve with handler to handle requests
-// on incoming connections. Handler is typically nil,
-// in which case the DefaultServeMux is used.
+// on incoming connections.
+// Accepted connections are configured to enable TCP keep-alives.
+// Handler is typically nil, in which case the DefaultServeMux is
+// used.
//
// A trivial example server is:
//
@@ -1957,11 +2186,10 @@ func (s *Server) logf(format string, args ...interface{}) {
//
// func main() {
// http.HandleFunc("/hello", HelloServer)
-// err := http.ListenAndServe(":12345", nil)
-// if err != nil {
-// log.Fatal("ListenAndServe: ", err)
-// }
+// log.Fatal(http.ListenAndServe(":12345", nil))
// }
+//
+// ListenAndServe always returns a non-nil error.
func ListenAndServe(addr string, handler Handler) error {
server := &Server{Addr: addr, Handler: handler}
return server.ListenAndServe()
@@ -1989,19 +2217,20 @@ func ListenAndServe(addr string, handler Handler) error {
// http.HandleFunc("/", handler)
// log.Printf("About to listen on 10443. Go to https://127.0.0.1:10443/")
// err := http.ListenAndServeTLS(":10443", "cert.pem", "key.pem", nil)
-// if err != nil {
-// log.Fatal(err)
-// }
+// log.Fatal(err)
// }
//
// One can use generate_cert.go in crypto/tls to generate cert.pem and key.pem.
-func ListenAndServeTLS(addr string, certFile string, keyFile string, handler Handler) error {
+//
+// ListenAndServeTLS always returns a non-nil error.
+func ListenAndServeTLS(addr, certFile, keyFile string, handler Handler) error {
server := &Server{Addr: addr, Handler: handler}
return server.ListenAndServeTLS(certFile, keyFile)
}
// ListenAndServeTLS listens on the TCP network address srv.Addr and
// then calls Serve to handle requests on incoming TLS connections.
+// Accepted connections are configured to enable TCP keep-alives.
//
// Filenames containing a certificate and matching private key for the
// server must be provided if the Server's TLSConfig.Certificates is
@@ -2010,14 +2239,23 @@ func ListenAndServeTLS(addr string, certFile string, keyFile string, handler Han
// certificate, any intermediates, and the CA's certificate.
//
// If srv.Addr is blank, ":https" is used.
+//
+// ListenAndServeTLS always returns a non-nil error.
func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error {
addr := srv.Addr
if addr == "" {
addr = ":https"
}
+
+ // Setup HTTP/2 before srv.Serve, to initialize srv.TLSConfig
+ // before we clone it and create the TLS Listener.
+ if err := srv.setupHTTP2(); err != nil {
+ return err
+ }
+
config := cloneTLSConfig(srv.TLSConfig)
- if config.NextProtos == nil {
- config.NextProtos = []string{"http/1.1"}
+ if !strSliceContains(config.NextProtos, "http/1.1") {
+ config.NextProtos = append(config.NextProtos, "http/1.1")
}
if len(config.Certificates) == 0 || certFile != "" || keyFile != "" {
@@ -2038,6 +2276,25 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error {
return srv.Serve(tlsListener)
}
+func (srv *Server) setupHTTP2() error {
+ srv.nextProtoOnce.Do(srv.onceSetNextProtoDefaults)
+ return srv.nextProtoErr
+}
+
+// onceSetNextProtoDefaults configures HTTP/2 unless the user has
+// configured otherwise (by setting srv.TLSNextProto non-nil).
+// It must only be called via srv.nextProtoOnce (use srv.setupHTTP2).
+func (srv *Server) onceSetNextProtoDefaults() {
+ if strings.Contains(os.Getenv("GODEBUG"), "http2server=0") {
+ return
+ }
+ // Enable HTTP/2 by default if the user hasn't otherwise
+ // configured their TLSNextProto map.
+ if srv.TLSNextProto == nil {
+ srv.nextProtoErr = http2ConfigureServer(srv, nil)
+ }
+}
+
// TimeoutHandler returns a Handler that runs h with the given time limit.
//
// The new Handler calls h.ServeHTTP to handle each request, but if a
@@ -2046,11 +2303,20 @@ func (srv *Server) ListenAndServeTLS(certFile, keyFile string) error {
// (If msg is empty, a suitable default message will be sent.)
// After such a timeout, writes by h to its ResponseWriter will return
// ErrHandlerTimeout.
+//
+// TimeoutHandler buffers all Handler writes to memory and does not
+// support the Hijacker or Flusher interfaces.
func TimeoutHandler(h Handler, dt time.Duration, msg string) Handler {
- f := func() <-chan time.Time {
- return time.After(dt)
+ t := time.NewTimer(dt)
+ return &timeoutHandler{
+ handler: h,
+ body: msg,
+
+ // Effectively storing a *time.Timer, but decomposed
+ // for testing:
+ timeout: func() <-chan time.Time { return t.C },
+ cancelTimer: t.Stop,
}
- return &timeoutHandler{h, f, msg}
}
// ErrHandlerTimeout is returned on ResponseWriter Write calls
@@ -2059,8 +2325,13 @@ var ErrHandlerTimeout = errors.New("http: Handler timeout")
type timeoutHandler struct {
handler Handler
- timeout func() <-chan time.Time // returns channel producing a timeout
body string
+
+ // timeout returns the channel of a *time.Timer and
+ // cancelTimer cancels it. They're stored separately for
+ // testing purposes.
+ timeout func() <-chan time.Time // returns channel producing a timeout
+ cancelTimer func() bool // optional
}
func (h *timeoutHandler) errorBody() string {
@@ -2071,46 +2342,61 @@ func (h *timeoutHandler) errorBody() string {
}
func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) {
- done := make(chan bool, 1)
- tw := &timeoutWriter{w: w}
+ done := make(chan struct{})
+ tw := &timeoutWriter{
+ w: w,
+ h: make(Header),
+ }
go func() {
h.handler.ServeHTTP(tw, r)
- done <- true
+ close(done)
}()
select {
case <-done:
- return
- case <-h.timeout():
tw.mu.Lock()
defer tw.mu.Unlock()
- if !tw.wroteHeader {
- tw.w.WriteHeader(StatusServiceUnavailable)
- tw.w.Write([]byte(h.errorBody()))
+ dst := w.Header()
+ for k, vv := range tw.h {
+ dst[k] = vv
+ }
+ w.WriteHeader(tw.code)
+ w.Write(tw.wbuf.Bytes())
+ if h.cancelTimer != nil {
+ h.cancelTimer()
}
+ case <-h.timeout():
+ tw.mu.Lock()
+ defer tw.mu.Unlock()
+ w.WriteHeader(StatusServiceUnavailable)
+ io.WriteString(w, h.errorBody())
tw.timedOut = true
+ return
}
}
type timeoutWriter struct {
- w ResponseWriter
+ w ResponseWriter
+ h Header
+ wbuf bytes.Buffer
mu sync.Mutex
timedOut bool
wroteHeader bool
+ code int
}
-func (tw *timeoutWriter) Header() Header {
- return tw.w.Header()
-}
+func (tw *timeoutWriter) Header() Header { return tw.h }
func (tw *timeoutWriter) Write(p []byte) (int, error) {
tw.mu.Lock()
defer tw.mu.Unlock()
- tw.wroteHeader = true // implicitly at least
if tw.timedOut {
return 0, ErrHandlerTimeout
}
- return tw.w.Write(p)
+ if !tw.wroteHeader {
+ tw.writeHeader(StatusOK)
+ }
+ return tw.wbuf.Write(p)
}
func (tw *timeoutWriter) WriteHeader(code int) {
@@ -2119,8 +2405,12 @@ func (tw *timeoutWriter) WriteHeader(code int) {
if tw.timedOut || tw.wroteHeader {
return
}
+ tw.writeHeader(code)
+}
+
+func (tw *timeoutWriter) writeHeader(code int) {
tw.wroteHeader = true
- tw.w.WriteHeader(code)
+ tw.code = code
}
// tcpKeepAliveListener sets TCP keep-alive timeouts on accepted
@@ -2247,7 +2537,7 @@ type checkConnErrorWriter struct {
}
func (w checkConnErrorWriter) Write(p []byte) (n int, err error) {
- n, err = w.c.w.Write(p) // c.w == c.rwc, except after a hijack, when rwc is nil.
+ n, err = w.c.rwc.Write(p)
if err != nil && w.c.werr == nil {
w.c.werr = err
}
@@ -2265,3 +2555,12 @@ func numLeadingCRorLF(v []byte) (n int) {
return
}
+
+func strSliceContains(ss []string, s string) bool {
+ for _, v := range ss {
+ if v == s {
+ return true
+ }
+ }
+ return false
+}
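
The server.go changes above rebuild CloseNotifier on top of connReader.startBackgroundRead and a per-response closeNotifyCh. A minimal handler-side sketch of the intended use, not part of this patch; the route, listen address, and sleep duration are illustrative. Per the revised docs, notification may be delayed until the request body has been fully read, and CloseNotify cannot be combined with Hijack in the same ServeHTTP call.

package main

import (
	"fmt"
	"io"
	"io/ioutil"
	"log"
	"net/http"
	"time"
)

func slow(w http.ResponseWriter, r *http.Request) {
	// Drain the body first; CloseNotify may wait until it has been read.
	io.Copy(ioutil.Discard, r.Body)

	closed := w.(http.CloseNotifier).CloseNotify()
	select {
	case <-time.After(5 * time.Second): // stand-in for real work
		fmt.Fprintln(w, "done")
	case <-closed:
		log.Print("client went away; abandoning work")
	}
}

func main() {
	http.HandleFunc("/slow", slow)
	log.Fatal(http.ListenAndServe(":8080", nil))
}
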
diff --git a/libgo/go/net/http/sniff.go b/libgo/go/net/http/sniff.go
index 3be8c865d3b..18810bad068 100644
--- a/libgo/go/net/http/sniff.go
+++ b/libgo/go/net/http/sniff.go
@@ -102,10 +102,9 @@ var sniffSignatures = []sniffSig{
&exactSig{[]byte("\x50\x4B\x03\x04"), "application/zip"},
&exactSig{[]byte("\x1F\x8B\x08"), "application/x-gzip"},
- // TODO(dsymonds): Re-enable this when the spec is sorted w.r.t. MP4.
- //mp4Sig(0),
+ mp4Sig{},
- textSig(0), // should be last
+ textSig{}, // should be last
}
type exactSig struct {
@@ -166,12 +165,14 @@ func (h htmlSig) match(data []byte, firstNonWS int) string {
}
var mp4ftype = []byte("ftyp")
+var mp4 = []byte("mp4")
-type mp4Sig int
+type mp4Sig struct{}
func (mp4Sig) match(data []byte, firstNonWS int) string {
- // c.f. section 6.1.
- if len(data) < 8 {
+ // https://mimesniff.spec.whatwg.org/#signature-for-mp4
+ // c.f. section 6.2.1
+ if len(data) < 12 {
return ""
}
boxSize := int(binary.BigEndian.Uint32(data[:4]))
@@ -186,30 +187,20 @@ func (mp4Sig) match(data []byte, firstNonWS int) string {
// minor version number
continue
}
- seg := string(data[st : st+3])
- switch seg {
- case "mp4", "iso", "M4V", "M4P", "M4B":
+ if bytes.Equal(data[st:st+3], mp4) {
return "video/mp4"
- /* The remainder are not in the spec.
- case "M4A":
- return "audio/mp4"
- case "3gp":
- return "video/3gpp"
- case "jp2":
- return "image/jp2" // JPEG 2000
- */
}
}
return ""
}
-type textSig int
+type textSig struct{}
func (textSig) match(data []byte, firstNonWS int) string {
// c.f. section 5, step 4.
for _, b := range data[firstNonWS:] {
switch {
- case 0x00 <= b && b <= 0x08,
+ case b <= 0x08,
b == 0x0B,
0x0E <= b && b <= 0x1A,
0x1C <= b && b <= 0x1F:
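
The sniff.go hunk above re-enables the MP4 signature per the WHATWG mimesniff spec. A one-line sketch (not part of this patch) of the observable effect through the public API; the byte prefix is the "MP4 video" case re-enabled in sniff_test.go below.

package main

import (
	"fmt"
	"net/http"
)

func main() {
	data := []byte("\x00\x00\x00\x18ftypmp42\x00\x00\x00\x00mp42isom<\x06t\xbfmdat")
	// With this patch: video/mp4 (previously no signature matched this prefix).
	fmt.Println(http.DetectContentType(data))
}
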
diff --git a/libgo/go/net/http/sniff_test.go b/libgo/go/net/http/sniff_test.go
index 24ca27afc16..e0085516da3 100644
--- a/libgo/go/net/http/sniff_test.go
+++ b/libgo/go/net/http/sniff_test.go
@@ -11,7 +11,6 @@ import (
"io/ioutil"
"log"
. "net/http"
- "net/http/httptest"
"reflect"
"strconv"
"strings"
@@ -40,9 +39,7 @@ var sniffTests = []struct {
{"GIF 87a", []byte(`GIF87a`), "image/gif"},
{"GIF 89a", []byte(`GIF89a...`), "image/gif"},
- // TODO(dsymonds): Re-enable this when the spec is sorted w.r.t. MP4.
- //{"MP4 video", []byte("\x00\x00\x00\x18ftypmp42\x00\x00\x00\x00mp42isom<\x06t\xbfmdat"), "video/mp4"},
- //{"MP4 audio", []byte("\x00\x00\x00\x20ftypM4A \x00\x00\x00\x00M4A mp42isom\x00\x00\x00\x00"), "audio/mp4"},
+ {"MP4 video", []byte("\x00\x00\x00\x18ftypmp42\x00\x00\x00\x00mp42isom<\x06t\xbfmdat"), "video/mp4"},
}
func TestDetectContentType(t *testing.T) {
@@ -54,9 +51,12 @@ func TestDetectContentType(t *testing.T) {
}
}
-func TestServerContentType(t *testing.T) {
+func TestServerContentType_h1(t *testing.T) { testServerContentType(t, h1Mode) }
+func TestServerContentType_h2(t *testing.T) { testServerContentType(t, h2Mode) }
+
+func testServerContentType(t *testing.T, h2 bool) {
defer afterTest(t)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
i, _ := strconv.Atoi(r.FormValue("i"))
tt := sniffTests[i]
n, err := w.Write(tt.data)
@@ -64,10 +64,10 @@ func TestServerContentType(t *testing.T) {
log.Fatalf("%v: Write(%q) = %v, %v want %d, nil", tt.desc, tt.data, n, err, len(tt.data))
}
}))
- defer ts.Close()
+ defer cst.close()
for i, tt := range sniffTests {
- resp, err := Get(ts.URL + "/?i=" + strconv.Itoa(i))
+ resp, err := cst.c.Get(cst.ts.URL + "/?i=" + strconv.Itoa(i))
if err != nil {
t.Errorf("%v: %v", tt.desc, err)
continue
@@ -87,15 +87,17 @@ func TestServerContentType(t *testing.T) {
// Issue 5953: shouldn't sniff if the handler set a Content-Type header,
// even if it's the empty string.
-func TestServerIssue5953(t *testing.T) {
+func TestServerIssue5953_h1(t *testing.T) { testServerIssue5953(t, h1Mode) }
+func TestServerIssue5953_h2(t *testing.T) { testServerIssue5953(t, h2Mode) }
+func testServerIssue5953(t *testing.T, h2 bool) {
defer afterTest(t)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
w.Header()["Content-Type"] = []string{""}
fmt.Fprintf(w, "<html><head></head><body>hi</body></html>")
}))
- defer ts.Close()
+ defer cst.close()
- resp, err := Get(ts.URL)
+ resp, err := cst.c.Get(cst.ts.URL)
if err != nil {
t.Fatal(err)
}
@@ -108,7 +110,9 @@ func TestServerIssue5953(t *testing.T) {
resp.Body.Close()
}
-func TestContentTypeWithCopy(t *testing.T) {
+func TestContentTypeWithCopy_h1(t *testing.T) { testContentTypeWithCopy(t, h1Mode) }
+func TestContentTypeWithCopy_h2(t *testing.T) { testContentTypeWithCopy(t, h2Mode) }
+func testContentTypeWithCopy(t *testing.T, h2 bool) {
defer afterTest(t)
const (
@@ -116,7 +120,7 @@ func TestContentTypeWithCopy(t *testing.T) {
expected = "text/html; charset=utf-8"
)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
// Use io.Copy from a bytes.Buffer to trigger ReadFrom.
buf := bytes.NewBuffer([]byte(input))
n, err := io.Copy(w, buf)
@@ -124,9 +128,9 @@ func TestContentTypeWithCopy(t *testing.T) {
t.Errorf("io.Copy(w, %q) = %v, %v want %d, nil", input, n, err, len(input))
}
}))
- defer ts.Close()
+ defer cst.close()
- resp, err := Get(ts.URL)
+ resp, err := cst.c.Get(cst.ts.URL)
if err != nil {
t.Fatalf("Get: %v", err)
}
@@ -142,9 +146,11 @@ func TestContentTypeWithCopy(t *testing.T) {
resp.Body.Close()
}
-func TestSniffWriteSize(t *testing.T) {
+func TestSniffWriteSize_h1(t *testing.T) { testSniffWriteSize(t, h1Mode) }
+func TestSniffWriteSize_h2(t *testing.T) { testSniffWriteSize(t, h2Mode) }
+func testSniffWriteSize(t *testing.T, h2 bool) {
defer afterTest(t)
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
size, _ := strconv.Atoi(r.FormValue("size"))
written, err := io.WriteString(w, strings.Repeat("a", size))
if err != nil {
@@ -155,9 +161,9 @@ func TestSniffWriteSize(t *testing.T) {
t.Errorf("write of %d bytes wrote %d bytes", size, written)
}
}))
- defer ts.Close()
+ defer cst.close()
for _, size := range []int{0, 1, 200, 600, 999, 1000, 1023, 1024, 512 << 10, 1 << 20} {
- res, err := Get(fmt.Sprintf("%s/?size=%d", ts.URL, size))
+ res, err := cst.c.Get(fmt.Sprintf("%s/?size=%d", cst.ts.URL, size))
if err != nil {
t.Fatalf("size %d: %v", size, err)
}
diff --git a/libgo/go/net/http/status.go b/libgo/go/net/http/status.go
index d253bd5cb54..f3dacab6a92 100644
--- a/libgo/go/net/http/status.go
+++ b/libgo/go/net/http/status.go
@@ -44,20 +44,18 @@ const (
StatusRequestedRangeNotSatisfiable = 416
StatusExpectationFailed = 417
StatusTeapot = 418
+ StatusPreconditionRequired = 428
+ StatusTooManyRequests = 429
+ StatusRequestHeaderFieldsTooLarge = 431
+ StatusUnavailableForLegalReasons = 451
- StatusInternalServerError = 500
- StatusNotImplemented = 501
- StatusBadGateway = 502
- StatusServiceUnavailable = 503
- StatusGatewayTimeout = 504
- StatusHTTPVersionNotSupported = 505
-
- // New HTTP status codes from RFC 6585. Not exported yet in Go 1.1.
- // See discussion at https://codereview.appspot.com/7678043/
- statusPreconditionRequired = 428
- statusTooManyRequests = 429
- statusRequestHeaderFieldsTooLarge = 431
- statusNetworkAuthenticationRequired = 511
+ StatusInternalServerError = 500
+ StatusNotImplemented = 501
+ StatusBadGateway = 502
+ StatusServiceUnavailable = 503
+ StatusGatewayTimeout = 504
+ StatusHTTPVersionNotSupported = 505
+ StatusNetworkAuthenticationRequired = 511
)
var statusText = map[int]string{
@@ -99,18 +97,18 @@ var statusText = map[int]string{
StatusRequestedRangeNotSatisfiable: "Requested Range Not Satisfiable",
StatusExpectationFailed: "Expectation Failed",
StatusTeapot: "I'm a teapot",
+ StatusPreconditionRequired: "Precondition Required",
+ StatusTooManyRequests: "Too Many Requests",
+ StatusRequestHeaderFieldsTooLarge: "Request Header Fields Too Large",
+ StatusUnavailableForLegalReasons: "Unavailable For Legal Reasons",
- StatusInternalServerError: "Internal Server Error",
- StatusNotImplemented: "Not Implemented",
- StatusBadGateway: "Bad Gateway",
- StatusServiceUnavailable: "Service Unavailable",
- StatusGatewayTimeout: "Gateway Timeout",
- StatusHTTPVersionNotSupported: "HTTP Version Not Supported",
-
- statusPreconditionRequired: "Precondition Required",
- statusTooManyRequests: "Too Many Requests",
- statusRequestHeaderFieldsTooLarge: "Request Header Fields Too Large",
- statusNetworkAuthenticationRequired: "Network Authentication Required",
+ StatusInternalServerError: "Internal Server Error",
+ StatusNotImplemented: "Not Implemented",
+ StatusBadGateway: "Bad Gateway",
+ StatusServiceUnavailable: "Service Unavailable",
+ StatusGatewayTimeout: "Gateway Timeout",
+ StatusHTTPVersionNotSupported: "HTTP Version Not Supported",
+ StatusNetworkAuthenticationRequired: "Network Authentication Required",
}
// StatusText returns a text for the HTTP status code. It returns the empty
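
The status.go hunk above promotes the RFC 6585 codes (428, 429, 431, 511) to exported constants and adds StatusUnavailableForLegalReasons (451). A minimal sketch of a handler using one of them, not part of this patch; the Retry-After value and listen address are illustrative.

package main

import (
	"log"
	"net/http"
)

func rateLimited(w http.ResponseWriter, r *http.Request) {
	w.Header().Set("Retry-After", "60")
	// StatusText now knows the reason phrase "Too Many Requests".
	http.Error(w, http.StatusText(http.StatusTooManyRequests), http.StatusTooManyRequests)
}

func main() {
	http.HandleFunc("/", rateLimited)
	log.Fatal(http.ListenAndServe(":8080", nil))
}
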
diff --git a/libgo/go/net/http/transfer.go b/libgo/go/net/http/transfer.go
index a8736b28e16..6e59af8f6f4 100644
--- a/libgo/go/net/http/transfer.go
+++ b/libgo/go/net/http/transfer.go
@@ -56,7 +56,7 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) {
if rr.ContentLength != 0 && rr.Body == nil {
return nil, fmt.Errorf("http: Request.ContentLength=%d with nil Body", rr.ContentLength)
}
- t.Method = rr.Method
+ t.Method = valueOrDefault(rr.Method, "GET")
t.Body = rr.Body
t.BodyCloser = rr.Body
t.ContentLength = rr.ContentLength
@@ -271,6 +271,10 @@ type transferReader struct {
Trailer Header
}
+func (t *transferReader) protoAtLeast(m, n int) bool {
+ return t.ProtoMajor > m || (t.ProtoMajor == m && t.ProtoMinor >= n)
+}
+
// bodyAllowedForStatus reports whether a given response status code
// permits a body. See RFC2616, section 4.4.
func bodyAllowedForStatus(status int) bool {
@@ -337,7 +341,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) {
}
// Transfer encoding, content length
- t.TransferEncoding, err = fixTransferEncoding(isResponse, t.RequestMethod, t.Header)
+ err = t.fixTransferEncoding()
if err != nil {
return err
}
@@ -424,13 +428,18 @@ func chunked(te []string) bool { return len(te) > 0 && te[0] == "chunked" }
// Checks whether the encoding is explicitly "identity".
func isIdentity(te []string) bool { return len(te) == 1 && te[0] == "identity" }
-// Sanitize transfer encoding
-func fixTransferEncoding(isResponse bool, requestMethod string, header Header) ([]string, error) {
- raw, present := header["Transfer-Encoding"]
+// fixTransferEncoding sanitizes t.TransferEncoding, if needed.
+func (t *transferReader) fixTransferEncoding() error {
+ raw, present := t.Header["Transfer-Encoding"]
if !present {
- return nil, nil
+ return nil
+ }
+ delete(t.Header, "Transfer-Encoding")
+
+ // Issue 12785; ignore Transfer-Encoding on HTTP/1.0 requests.
+ if !t.protoAtLeast(1, 1) {
+ return nil
}
- delete(header, "Transfer-Encoding")
encodings := strings.Split(raw[0], ",")
te := make([]string, 0, len(encodings))
@@ -445,13 +454,13 @@ func fixTransferEncoding(isResponse bool, requestMethod string, header Header) (
break
}
if encoding != "chunked" {
- return nil, &badStringError{"unsupported transfer encoding", encoding}
+ return &badStringError{"unsupported transfer encoding", encoding}
}
te = te[0 : len(te)+1]
te[len(te)-1] = encoding
}
if len(te) > 1 {
- return nil, &badStringError{"too many transfer encodings", strings.Join(te, ",")}
+ return &badStringError{"too many transfer encodings", strings.Join(te, ",")}
}
if len(te) > 0 {
// RFC 7230 3.3.2 says "A sender MUST NOT send a
@@ -470,11 +479,12 @@ func fixTransferEncoding(isResponse bool, requestMethod string, header Header) (
// such a message downstream."
//
// Reportedly, these appear in the wild.
- delete(header, "Content-Length")
- return te, nil
+ delete(t.Header, "Content-Length")
+ t.TransferEncoding = te
+ return nil
}
- return nil, nil
+ return nil
}
// Determine the expected body length, using RFC 2616 Section 4.4. This
@@ -567,21 +577,29 @@ func shouldClose(major, minor int, header Header, removeCloseHeader bool) bool {
// Parse the trailer header
func fixTrailer(header Header, te []string) (Header, error) {
- raw := header.get("Trailer")
- if raw == "" {
+ vv, ok := header["Trailer"]
+ if !ok {
return nil, nil
}
-
header.Del("Trailer")
+
trailer := make(Header)
- keys := strings.Split(raw, ",")
- for _, key := range keys {
- key = CanonicalHeaderKey(strings.TrimSpace(key))
- switch key {
- case "Transfer-Encoding", "Trailer", "Content-Length":
- return nil, &badStringError{"bad trailer key", key}
- }
- trailer[key] = nil
+ var err error
+ for _, v := range vv {
+ foreachHeaderElement(v, func(key string) {
+ key = CanonicalHeaderKey(key)
+ switch key {
+ case "Transfer-Encoding", "Trailer", "Content-Length":
+ if err == nil {
+ err = &badStringError{"bad trailer key", key}
+ return
+ }
+ }
+ trailer[key] = nil
+ })
+ }
+ if err != nil {
+ return nil, err
}
if len(trailer) == 0 {
return nil, nil
@@ -603,10 +621,11 @@ type body struct {
closing bool // is the connection to be closed after reading body?
doEarlyClose bool // whether Close should stop early
- mu sync.Mutex // guards closed, and calls to Read and Close
+ mu sync.Mutex // guards following, and calls to Read and Close
sawEOF bool
closed bool
- earlyClose bool // Close called and we didn't read to the end of src
+ earlyClose bool // Close called and we didn't read to the end of src
+ onHitEOF func() // if non-nil, func to call when EOF is Read
}
// ErrBodyReadAfterClose is returned when reading a Request or Response
@@ -666,6 +685,10 @@ func (b *body) readLocked(p []byte) (n int, err error) {
}
}
+ if b.sawEOF && b.onHitEOF != nil {
+ b.onHitEOF()
+ }
+
return n, err
}
@@ -800,6 +823,20 @@ func (b *body) didEarlyClose() bool {
return b.earlyClose
}
+// bodyRemains reports whether future Read calls might
+// yield data.
+func (b *body) bodyRemains() bool {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+ return !b.sawEOF
+}
+
+func (b *body) registerOnHitEOF(fn func()) {
+ b.mu.Lock()
+ defer b.mu.Unlock()
+ b.onHitEOF = fn
+}
+
// bodyLocked is a io.Reader reading from a *body when its mutex is
// already held.
type bodyLocked struct {
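For readers skimming the transfer.go hunks above: the new protoAtLeast method is a plain version comparison gating the Issue 12785 fix. A minimal standalone sketch of the same check (the free function here is illustrative, not part of the patch):

package main

import "fmt"

// protoAtLeast reports whether an HTTP version major.minor is at least m.n,
// the same comparison the patch adds as a transferReader method.
func protoAtLeast(major, minor, m, n int) bool {
	return major > m || (major == m && minor >= n)
}

func main() {
	// Per the change above (Issue 12785), Transfer-Encoding is only honored
	// on HTTP/1.1 or newer messages.
	fmt.Println(protoAtLeast(1, 0, 1, 1)) // false: HTTP/1.0, header is ignored
	fmt.Println(protoAtLeast(1, 1, 1, 1)) // true
	fmt.Println(protoAtLeast(2, 0, 1, 1)) // true
}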
diff --git a/libgo/go/net/http/transport.go b/libgo/go/net/http/transport.go
index 70d18646059..41df906cf2d 100644
--- a/libgo/go/net/http/transport.go
+++ b/libgo/go/net/http/transport.go
@@ -36,7 +36,8 @@ var DefaultTransport RoundTripper = &Transport{
Timeout: 30 * time.Second,
KeepAlive: 30 * time.Second,
}).Dial,
- TLSHandshakeTimeout: 10 * time.Second,
+ TLSHandshakeTimeout: 10 * time.Second,
+ ExpectContinueTimeout: 1 * time.Second,
}
// DefaultMaxIdleConnsPerHost is the default value of Transport's
@@ -45,7 +46,21 @@ const DefaultMaxIdleConnsPerHost = 2
// Transport is an implementation of RoundTripper that supports HTTP,
// HTTPS, and HTTP proxies (for either HTTP or HTTPS with CONNECT).
-// Transport can also cache connections for future re-use.
+//
+// By default, Transport caches connections for future re-use.
+// This may leave many open connections when accessing many hosts.
+// This behavior can be managed using Transport's CloseIdleConnections method
+// and the MaxIdleConnsPerHost and DisableKeepAlives fields.
+//
+// Transports should be reused instead of created as needed.
+// Transports are safe for concurrent use by multiple goroutines.
+//
+// A Transport is a low-level primitive for making HTTP and HTTPS requests.
+// For high-level functionality, such as cookies and redirects, see Client.
+//
+// Transport uses HTTP/1.1 for HTTP URLs and either HTTP/1.1 or HTTP/2
+// for HTTPS URLs, depending on whether the server supports HTTP/2.
+// See the package docs for more about HTTP/2.
type Transport struct {
idleMu sync.Mutex
wantIdle bool // user has requested to close all idle conns
@@ -113,8 +128,49 @@ type Transport struct {
// time does not include the time to read the response body.
ResponseHeaderTimeout time.Duration
+ // ExpectContinueTimeout, if non-zero, specifies the amount of
+ // time to wait for a server's first response headers after fully
+ // writing the request headers if the request has an
+ // "Expect: 100-continue" header. Zero means no timeout.
+ // This time does not include the time to send the request header.
+ ExpectContinueTimeout time.Duration
+
+ // TLSNextProto specifies how the Transport switches to an
+ // alternate protocol (such as HTTP/2) after a TLS NPN/ALPN
+ // protocol negotiation. If Transport dials a TLS connection
+ // with a non-empty protocol name and TLSNextProto contains a
+ // map entry for that key (such as "h2"), then the func is
+ // called with the request's authority (such as "example.com"
+ // or "example.com:1234") and the TLS connection. The function
+ // must return a RoundTripper that then handles the request.
+ // If TLSNextProto is nil, HTTP/2 support is enabled automatically.
+ TLSNextProto map[string]func(authority string, c *tls.Conn) RoundTripper
+
+ // nextProtoOnce guards initialization of TLSNextProto and
+ // h2transport (via onceSetNextProtoDefaults)
+ nextProtoOnce sync.Once
+ h2transport *http2Transport // non-nil if http2 wired up
+
// TODO: tunable on global max cached connections
// TODO: tunable on timeout on cached connections
+ // TODO: tunable on max per-host TCP dials in flight (Issue 13957)
+}
+
+// onceSetNextProtoDefaults initializes TLSNextProto.
+// It must be called via t.nextProtoOnce.Do.
+func (t *Transport) onceSetNextProtoDefaults() {
+ if strings.Contains(os.Getenv("GODEBUG"), "http2client=0") {
+ return
+ }
+ if t.TLSNextProto != nil {
+ return
+ }
+ t2, err := http2configureTransport(t)
+ if err != nil {
+ log.Printf("Error enabling Transport HTTP/2 support: %v", err)
+ } else {
+ t.h2transport = t2
+ }
}
// ProxyFromEnvironment returns the URL of the proxy to use for a
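The expanded Transport documentation and the onceSetNextProtoDefaults hook above are easiest to see from the caller's side. A small usage sketch, with a placeholder URL, assuming the Go 1.6 net/http API shown in this patch:

package main

import (
	"fmt"
	"net/http"
	"time"
)

// One Transport, reused for every request: it caches idle connections and,
// with TLSNextProto left nil, enables HTTP/2 for HTTPS URLs automatically.
var tr = &http.Transport{
	TLSHandshakeTimeout:   10 * time.Second,
	ExpectContinueTimeout: 1 * time.Second, // same value the new DefaultTransport uses
}

var client = &http.Client{Transport: tr}

func main() {
	resp, err := client.Get("https://example.com/") // placeholder URL
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	resp.Body.Close()
	fmt.Println("proto:", resp.Proto) // e.g. "HTTP/2.0" if the server negotiated h2
	tr.CloseIdleConnections()         // also closes idle HTTP/2 connections once wired up
}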
@@ -188,7 +244,8 @@ func (tr *transportRequest) extraHeaders() Header {
//
// For higher-level HTTP client support (such as handling of cookies
// and redirects), see Get, Post, and the Client type.
-func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) {
+func (t *Transport) RoundTrip(req *Request) (*Response, error) {
+ t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
if req.URL == nil {
req.closeBody()
return nil, errors.New("http: nil Request.URL")
@@ -197,54 +254,114 @@ func (t *Transport) RoundTrip(req *Request) (resp *Response, err error) {
req.closeBody()
return nil, errors.New("http: nil Request.Header")
}
- if req.URL.Scheme != "http" && req.URL.Scheme != "https" {
- t.altMu.RLock()
- var rt RoundTripper
- if t.altProto != nil {
- rt = t.altProto[req.URL.Scheme]
- }
- t.altMu.RUnlock()
- if rt == nil {
- req.closeBody()
- return nil, &badStringError{"unsupported protocol scheme", req.URL.Scheme}
+ // TODO(bradfitz): switch to atomic.Value for this map instead of RWMutex
+ t.altMu.RLock()
+ altRT := t.altProto[req.URL.Scheme]
+ t.altMu.RUnlock()
+ if altRT != nil {
+ if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol {
+ return resp, err
}
- return rt.RoundTrip(req)
}
- if req.URL.Host == "" {
+ if s := req.URL.Scheme; s != "http" && s != "https" {
req.closeBody()
- return nil, errors.New("http: no Host in request URL")
+ return nil, &badStringError{"unsupported protocol scheme", s}
}
- treq := &transportRequest{Request: req}
- cm, err := t.connectMethodForRequest(treq)
- if err != nil {
+ if req.Method != "" && !validMethod(req.Method) {
+ return nil, fmt.Errorf("net/http: invalid method %q", req.Method)
+ }
+ if req.URL.Host == "" {
req.closeBody()
- return nil, err
+ return nil, errors.New("http: no Host in request URL")
}
- // Get the cached or newly-created connection to either the
- // host (for http or https), the http proxy, or the http proxy
- // pre-CONNECTed to https server. In any case, we'll be ready
- // to send it requests.
- pconn, err := t.getConn(req, cm)
- if err != nil {
- t.setReqCanceler(req, nil)
- req.closeBody()
- return nil, err
+ for {
+ // treq gets modified by roundTrip, so we need to recreate for each retry.
+ treq := &transportRequest{Request: req}
+ cm, err := t.connectMethodForRequest(treq)
+ if err != nil {
+ req.closeBody()
+ return nil, err
+ }
+
+ // Get the cached or newly-created connection to either the
+ // host (for http or https), the http proxy, or the http proxy
+ // pre-CONNECTed to https server. In any case, we'll be ready
+ // to send it requests.
+ pconn, err := t.getConn(req, cm)
+ if err != nil {
+ t.setReqCanceler(req, nil)
+ req.closeBody()
+ return nil, err
+ }
+
+ var resp *Response
+ if pconn.alt != nil {
+ // HTTP/2 path.
+ t.setReqCanceler(req, nil) // not cancelable with CancelRequest
+ resp, err = pconn.alt.RoundTrip(req)
+ } else {
+ resp, err = pconn.roundTrip(treq)
+ }
+ if err == nil {
+ return resp, nil
+ }
+ if err := checkTransportResend(err, req, pconn); err != nil {
+ return nil, err
+ }
+ testHookRoundTripRetried()
}
+}
- return pconn.roundTrip(treq)
+// checkTransportResend checks whether a failed HTTP request can be
+// resent on a new connection. The non-nil input error is the error from
+// roundTrip, which might be wrapped in a beforeRespHeaderError error.
+//
+// The return value is err or the unwrapped error inside a
+// beforeRespHeaderError.
+func checkTransportResend(err error, req *Request, pconn *persistConn) error {
+ brhErr, ok := err.(beforeRespHeaderError)
+ if !ok {
+ return err
+ }
+ err = brhErr.error // unwrap the custom error in case we return it
+ if err != errMissingHost && pconn.isReused() && req.isReplayable() {
+ // If we try to reuse a connection that the server is in the process of
+ // closing, we may end up successfully writing out our request (or a
+ // portion of our request) only to find a connection error when we try to
+ // read from (or finish writing to) the socket.
+
+ // There can be a race between the socket pool checking whether a socket
+ // is still connected, receiving the FIN, and sending/reading data on a
+ // reused socket. If we receive the FIN between the connectedness check
+ // and writing/reading from the socket, we may first learn the socket is
+ // disconnected when we get an ERR_SOCKET_NOT_CONNECTED. This will most
+ // likely happen when trying to retrieve its IP address. See
+ // http://crbug.com/105824 for more details.
+
+ // We resend a request only if we reused a keep-alive connection and did
+ // not yet receive any header data. This automatically prevents an
+ // infinite resend loop because we'll run out of the cached keep-alive
+ // connections eventually.
+ return nil
+ }
+ return err
}
+// ErrSkipAltProtocol is a sentinel error value defined by Transport.RegisterProtocol.
+var ErrSkipAltProtocol = errors.New("net/http: skip alternate protocol")
+
// RegisterProtocol registers a new protocol with scheme.
// The Transport will pass requests using the given scheme to rt.
// It is rt's responsibility to simulate HTTP request semantics.
//
// RegisterProtocol can be used by other packages to provide
// implementations of protocol schemes like "ftp" or "file".
+//
+// If rt.RoundTrip returns ErrSkipAltProtocol, the Transport will
+// handle the RoundTrip itself for that one request, as if the
+// protocol were not registered.
func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) {
- if scheme == "http" || scheme == "https" {
- panic("protocol " + scheme + " already registered")
- }
t.altMu.Lock()
defer t.altMu.Unlock()
if t.altProto == nil {
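A hedged sketch of how RegisterProtocol and the new ErrSkipAltProtocol sentinel can be used; the "echo" scheme and its RoundTripper are made up for illustration:

package main

import (
	"fmt"
	"io/ioutil"
	"net/http"
	"strings"
)

// echoRoundTripper is a toy RoundTripper for a made-up "echo" scheme.
// Returning http.ErrSkipAltProtocol instead would hand the request back to
// the Transport, as if the scheme had never been registered.
type echoRoundTripper struct{}

func (echoRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) {
	return &http.Response{
		StatusCode: 200,
		Proto:      "HTTP/1.1",
		ProtoMajor: 1,
		ProtoMinor: 1,
		Header:     make(http.Header),
		Body:       ioutil.NopCloser(strings.NewReader(req.URL.Path)),
		Request:    req,
	}, nil
}

func main() {
	tr := &http.Transport{}
	tr.RegisterProtocol("echo", echoRoundTripper{})

	req, err := http.NewRequest("GET", "echo://host/hello", nil)
	if err != nil {
		panic(err)
	}
	resp, err := tr.RoundTrip(req)
	if err != nil {
		panic(err)
	}
	defer resp.Body.Close()
	b, _ := ioutil.ReadAll(resp.Body)
	fmt.Println(string(b)) // "/hello"
}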
@@ -261,6 +378,7 @@ func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) {
// a "keep-alive" state. It does not interrupt any connections currently
// in use.
func (t *Transport) CloseIdleConnections() {
+ t.nextProtoOnce.Do(t.onceSetNextProtoDefaults)
t.idleMu.Lock()
m := t.idleConn
t.idleConn = nil
@@ -269,13 +387,19 @@ func (t *Transport) CloseIdleConnections() {
t.idleMu.Unlock()
for _, conns := range m {
for _, pconn := range conns {
- pconn.close()
+ pconn.close(errCloseIdleConns)
}
}
+ if t2 := t.h2transport; t2 != nil {
+ t2.CloseIdleConnections()
+ }
}
// CancelRequest cancels an in-flight request by closing its connection.
// CancelRequest should only be called after RoundTrip has returned.
+//
+// Deprecated: Use Request.Cancel instead. CancelRequest can not cancel
+// HTTP/2 requests.
func (t *Transport) CancelRequest(req *Request) {
t.reqMu.Lock()
cancel := t.reqCanceler[req]
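Since CancelRequest is now documented as deprecated in favor of Request.Cancel, a short client-side sketch of the replacement (placeholder URL; the timing is arbitrary):

package main

import (
	"fmt"
	"net/http"
	"time"
)

func main() {
	client := &http.Client{}

	req, err := http.NewRequest("GET", "https://example.com/slow", nil) // placeholder URL
	if err != nil {
		panic(err)
	}

	// Instead of Transport.CancelRequest, set Request.Cancel before Do.
	// Closing the channel cancels the in-flight request; unlike
	// CancelRequest, this path also covers HTTP/2 requests.
	cancel := make(chan struct{})
	req.Cancel = cancel

	go func() {
		time.Sleep(100 * time.Millisecond) // arbitrary deadline for the sketch
		close(cancel)
	}()

	resp, err := client.Do(req)
	if err != nil {
		fmt.Println("request canceled or failed:", err)
		return
	}
	defer resp.Body.Close()
	fmt.Println("status:", resp.Status)
}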
@@ -354,23 +478,41 @@ func (cm *connectMethod) proxyAuth() string {
return ""
}
-// putIdleConn adds pconn to the list of idle persistent connections awaiting
+// error values for debugging and testing, not seen by users.
+var (
+ errKeepAlivesDisabled = errors.New("http: putIdleConn: keep alives disabled")
+ errConnBroken = errors.New("http: putIdleConn: connection is in bad state")
+ errWantIdle = errors.New("http: putIdleConn: CloseIdleConnections was called")
+ errTooManyIdle = errors.New("http: putIdleConn: too many idle connections")
+ errCloseIdleConns = errors.New("http: CloseIdleConnections called")
+ errReadLoopExiting = errors.New("http: persistConn.readLoop exiting")
+ errServerClosedIdle = errors.New("http: server closed idle conn")
+)
+
+func (t *Transport) putOrCloseIdleConn(pconn *persistConn) {
+ if err := t.tryPutIdleConn(pconn); err != nil {
+ pconn.close(err)
+ }
+}
+
+// tryPutIdleConn adds pconn to the list of idle persistent connections awaiting
// a new request.
-// If pconn is no longer needed or not in a good state, putIdleConn
-// returns false.
-func (t *Transport) putIdleConn(pconn *persistConn) bool {
+// If pconn is no longer needed or not in a good state, tryPutIdleConn returns
+// an error explaining why it wasn't registered.
+// tryPutIdleConn does not close pconn. Use putOrCloseIdleConn instead for that.
+func (t *Transport) tryPutIdleConn(pconn *persistConn) error {
if t.DisableKeepAlives || t.MaxIdleConnsPerHost < 0 {
- pconn.close()
- return false
+ return errKeepAlivesDisabled
}
if pconn.isBroken() {
- return false
+ return errConnBroken
}
key := pconn.cacheKey
max := t.MaxIdleConnsPerHost
if max == 0 {
max = DefaultMaxIdleConnsPerHost
}
+ pconn.markReused()
t.idleMu.Lock()
waitingDialer := t.idleConnCh[key]
@@ -382,7 +524,7 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool {
// first). Chrome calls this socket late binding. See
// https://insouciant.org/tech/connection-management-in-chromium/
t.idleMu.Unlock()
- return true
+ return nil
default:
if waitingDialer != nil {
// They had populated this, but their dial won
@@ -392,16 +534,14 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool {
}
if t.wantIdle {
t.idleMu.Unlock()
- pconn.close()
- return false
+ return errWantIdle
}
if t.idleConn == nil {
t.idleConn = make(map[connectMethodKey][]*persistConn)
}
if len(t.idleConn[key]) >= max {
t.idleMu.Unlock()
- pconn.close()
- return false
+ return errTooManyIdle
}
for _, exist := range t.idleConn[key] {
if exist == pconn {
@@ -410,7 +550,7 @@ func (t *Transport) putIdleConn(pconn *persistConn) bool {
}
t.idleConn[key] = append(t.idleConn[key], pconn)
t.idleMu.Unlock()
- return true
+ return nil
}
// getIdleConnCh returns a channel to receive and return idle
@@ -494,16 +634,17 @@ func (t *Transport) replaceReqCanceler(r *Request, fn func()) bool {
return true
}
-func (t *Transport) dial(network, addr string) (c net.Conn, err error) {
+func (t *Transport) dial(network, addr string) (net.Conn, error) {
if t.Dial != nil {
- return t.Dial(network, addr)
+ c, err := t.Dial(network, addr)
+ if c == nil && err == nil {
+ err = errors.New("net/http: Transport.Dial hook returned (nil, nil)")
+ }
+ return c, err
}
return net.Dial(network, addr)
}
-// Testing hooks:
-var prePendingDial, postPendingDial func()
-
// getConn dials and creates a new persistConn to the target as
// specified in the connectMethod. This includes doing a proxy CONNECT
// and/or setting up TLS. If this doesn't return an error, the persistConn
@@ -525,20 +666,16 @@ func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error
// Copy these hooks so we don't race on the postPendingDial in
// the goroutine we launch. Issue 11136.
- prePendingDial := prePendingDial
- postPendingDial := postPendingDial
+ testHookPrePendingDial := testHookPrePendingDial
+ testHookPostPendingDial := testHookPostPendingDial
handlePendingDial := func() {
- if prePendingDial != nil {
- prePendingDial()
- }
+ testHookPrePendingDial()
go func() {
if v := <-dialc; v.err == nil {
- t.putIdleConn(v.pc)
- }
- if postPendingDial != nil {
- postPendingDial()
+ t.putOrCloseIdleConn(v.pc)
}
+ testHookPostPendingDial()
}()
}
@@ -565,10 +702,10 @@ func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error
return pc, nil
case <-req.Cancel:
handlePendingDial()
- return nil, errors.New("net/http: request canceled while waiting for connection")
+ return nil, errRequestCanceledConn
case <-cancelc:
handlePendingDial()
- return nil, errors.New("net/http: request canceled while waiting for connection")
+ return nil, errRequestCanceledConn
}
}
@@ -588,7 +725,16 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
if err != nil {
return nil, err
}
+ if pconn.conn == nil {
+ return nil, errors.New("net/http: Transport.DialTLS returned (nil, nil)")
+ }
if tc, ok := pconn.conn.(*tls.Conn); ok {
+ // Handshake here, in case DialTLS didn't. TLSNextProto below
+ // depends on it for knowing the connection state.
+ if err := tc.Handshake(); err != nil {
+ go pconn.conn.Close()
+ return nil, err
+ }
cs := tc.ConnectionState()
pconn.tlsState = &cs
}
@@ -680,6 +826,12 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) {
pconn.conn = tlsConn
}
+ if s := pconn.tlsState; s != nil && s.NegotiatedProtocolIsMutual && s.NegotiatedProtocol != "" {
+ if next, ok := t.TLSNextProto[s.NegotiatedProtocol]; ok {
+ return &persistConn{alt: next(cm.targetAddr, pconn.conn.(*tls.Conn))}, nil
+ }
+ }
+
pconn.br = bufio.NewReader(noteEOFReader{pconn.conn, &pconn.sawEOF})
pconn.bw = bufio.NewWriter(pconn.conn)
go pconn.readLoop()
@@ -809,6 +961,11 @@ func (k connectMethodKey) String() string {
// persistConn wraps a connection, usually a persistent one
// (but may be used for non-keep-alive requests as well)
type persistConn struct {
+ // alt optionally specifies the TLS NextProto RoundTripper.
+ // This is used for HTTP/2 today and future protocols later.
+ // If it's non-nil, the rest of the fields are unused.
+ alt RoundTripper
+
t *Transport
cacheKey connectMethodKey
conn net.Conn
@@ -828,9 +985,10 @@ type persistConn struct {
lk sync.Mutex // guards following fields
numExpectedResponses int
- closed bool // whether conn has been closed
- broken bool // an error has happened on this connection; marked broken so it's not reused.
- canceled bool // whether this conn was broken due a CancelRequest
+ closed error // set non-nil when conn is closed, before closech is closed
+ broken bool // an error has happened on this connection; marked broken so it's not reused.
+ canceled bool // whether this conn was broken due a CancelRequest
+ reused bool // whether conn has had successful request/response and is being reused.
// mutateHeaderFunc is an optional func to modify extra
// headers on each outbound request before it's written. (the
// original Request given to RoundTrip is not modified)
@@ -852,15 +1010,34 @@ func (pc *persistConn) isCanceled() bool {
return pc.canceled
}
+// isReused reports whether this connection has been used before.
+func (pc *persistConn) isReused() bool {
+ pc.lk.Lock()
+ r := pc.reused
+ pc.lk.Unlock()
+ return r
+}
+
func (pc *persistConn) cancelRequest() {
pc.lk.Lock()
defer pc.lk.Unlock()
pc.canceled = true
- pc.closeLocked()
+ pc.closeLocked(errRequestCanceled)
}
func (pc *persistConn) readLoop() {
- // eofc is used to block http.Handler goroutines reading from Response.Body
+ closeErr := errReadLoopExiting // default value, if not changed below
+ defer func() { pc.close(closeErr) }()
+
+ tryPutIdleConn := func() bool {
+ if err := pc.t.tryPutIdleConn(pc); err != nil {
+ closeErr = err
+ return false
+ }
+ return true
+ }
+
+ // eofc is used to block caller goroutines reading from Response.Body
// at EOF until this goroutine has (potentially) added the connection
// back to the idle pool.
eofc := make(chan struct{})
@@ -873,17 +1050,14 @@ func (pc *persistConn) readLoop() {
alive := true
for alive {
- pb, err := pc.br.Peek(1)
+ _, err := pc.br.Peek(1)
+ if err != nil {
+ err = beforeRespHeaderError{err}
+ }
pc.lk.Lock()
if pc.numExpectedResponses == 0 {
- if !pc.closed {
- pc.closeLocked()
- if len(pb) > 0 {
- log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v",
- string(pb), err)
- }
- }
+ pc.readLoopPeekFailLocked(err)
pc.lk.Unlock()
return
}
@@ -893,115 +1067,189 @@ func (pc *persistConn) readLoop() {
var resp *Response
if err == nil {
- resp, err = ReadResponse(pc.br, rc.req)
- if err == nil && resp.StatusCode == 100 {
- // Skip any 100-continue for now.
- // TODO(bradfitz): if rc.req had "Expect: 100-continue",
- // actually block the request body write and signal the
- // writeLoop now to begin sending it. (Issue 2184) For now we
- // eat it, since we're never expecting one.
- resp, err = ReadResponse(pc.br, rc.req)
- }
- }
-
- if resp != nil {
- resp.TLS = pc.tlsState
+ resp, err = pc.readResponse(rc)
}
- hasBody := resp != nil && rc.req.Method != "HEAD" && resp.ContentLength != 0
-
if err != nil {
- pc.close()
- } else {
- if rc.addedGzip && hasBody && resp.Header.Get("Content-Encoding") == "gzip" {
- resp.Header.Del("Content-Encoding")
- resp.Header.Del("Content-Length")
- resp.ContentLength = -1
- resp.Body = &gzipReader{body: resp.Body}
+ // If we won't be able to retry this request later (from the
+ // roundTrip goroutine), mark it as done now.
+ // BEFORE the send on rc.ch, as the client might re-use the
+ // same *Request pointer, and we don't want to call
+ // t.setReqCanceler from this persistConn while the Transport
+ // potentially spins up a different persistConn for the
+ // caller's subsequent request.
+ if checkTransportResend(err, rc.req, pc) != nil {
+ pc.t.setReqCanceler(rc.req, nil)
+ }
+ select {
+ case rc.ch <- responseAndError{err: err}:
+ case <-rc.callerGone:
+ return
}
- resp.Body = &bodyEOFSignal{body: resp.Body}
+ return
}
- if err != nil || resp.Close || rc.req.Close || resp.StatusCode <= 199 {
+ pc.lk.Lock()
+ pc.numExpectedResponses--
+ pc.lk.Unlock()
+
+ hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0
+
+ if resp.Close || rc.req.Close || resp.StatusCode <= 199 {
// Don't do keep-alive on error if either party requested a close
// or we get an unexpected informational (1xx) response.
// StatusCode 100 is already handled above.
alive = false
}
- var waitForBodyRead chan bool // channel is nil when there's no body
- if hasBody {
- waitForBodyRead = make(chan bool, 2)
- resp.Body.(*bodyEOFSignal).earlyCloseFn = func() error {
- waitForBodyRead <- false
- return nil
- }
- resp.Body.(*bodyEOFSignal).fn = func(err error) error {
- isEOF := err == io.EOF
- waitForBodyRead <- isEOF
- if isEOF {
- <-eofc // see comment at top
- } else if err != nil && pc.isCanceled() {
- return errRequestCanceled
- }
- return err
- }
- } else {
- // Before send on rc.ch, as client might re-use the
- // same *Request pointer, and we don't want to set this
- // on t from this persistConn while the Transport
- // potentially spins up a different persistConn for the
- // caller's subsequent request.
+ if !hasBody {
pc.t.setReqCanceler(rc.req, nil)
- }
- pc.lk.Lock()
- pc.numExpectedResponses--
- pc.lk.Unlock()
+ // Put the idle conn back into the pool before we send the response
+ // so if they process it quickly and make another request, they'll
+ // get this same conn. But we use the unbuffered channel 'rc'
+ // to guarantee that persistConn.roundTrip got out of its select
+ // potentially waiting for this persistConn to close.
+ alive = alive &&
+ !pc.sawEOF &&
+ pc.wroteRequest() &&
+ tryPutIdleConn()
- // The connection might be going away when we put the
- // idleConn below. When that happens, we close the response channel to signal
- // to roundTrip that the connection is gone. roundTrip waits for
- // both closing and a response in a select, so it might choose
- // the close channel, rather than the response.
- // We send the response first so that roundTrip can check
- // if there is a pending one with a non-blocking select
- // on the response channel before erroring out.
- rc.ch <- responseAndError{resp, err}
-
- if hasBody {
- // To avoid a race, wait for the just-returned
- // response body to be fully consumed before peek on
- // the underlying bufio reader.
select {
- case <-rc.req.Cancel:
- alive = false
- pc.t.CancelRequest(rc.req)
- case bodyEOF := <-waitForBodyRead:
- pc.t.setReqCanceler(rc.req, nil) // before pc might return to idle pool
- alive = alive &&
- bodyEOF &&
- !pc.sawEOF &&
- pc.wroteRequest() &&
- pc.t.putIdleConn(pc)
- if bodyEOF {
- eofc <- struct{}{}
- }
- case <-pc.closech:
- alive = false
+ case rc.ch <- responseAndError{res: resp}:
+ case <-rc.callerGone:
+ return
}
- } else {
+
+ // Now that they've read from the unbuffered channel, they're safely
+ // out of the select that also waits on this goroutine to die, so
+ // we're allowed to exit now if needed (if alive is false)
+ testHookReadLoopBeforeNextRead()
+ continue
+ }
+
+ if rc.addedGzip {
+ maybeUngzipResponse(resp)
+ }
+ resp.Body = &bodyEOFSignal{body: resp.Body}
+
+ waitForBodyRead := make(chan bool, 2)
+ resp.Body.(*bodyEOFSignal).earlyCloseFn = func() error {
+ waitForBodyRead <- false
+ return nil
+ }
+ resp.Body.(*bodyEOFSignal).fn = func(err error) error {
+ isEOF := err == io.EOF
+ waitForBodyRead <- isEOF
+ if isEOF {
+ <-eofc // see comment above eofc declaration
+ } else if err != nil && pc.isCanceled() {
+ return errRequestCanceled
+ }
+ return err
+ }
+
+ select {
+ case rc.ch <- responseAndError{res: resp}:
+ case <-rc.callerGone:
+ return
+ }
+
+ // Before looping back to the top of this function and peeking on
+ // the bufio.Reader, wait for the caller goroutine to finish
+ // reading the response body. (or for cancelation or death)
+ select {
+ case bodyEOF := <-waitForBodyRead:
+ pc.t.setReqCanceler(rc.req, nil) // before pc might return to idle pool
alive = alive &&
+ bodyEOF &&
!pc.sawEOF &&
pc.wroteRequest() &&
- pc.t.putIdleConn(pc)
+ tryPutIdleConn()
+ if bodyEOF {
+ eofc <- struct{}{}
+ }
+ case <-rc.req.Cancel:
+ alive = false
+ pc.t.CancelRequest(rc.req)
+ case <-pc.closech:
+ alive = false
+ }
+
+ testHookReadLoopBeforeNextRead()
+ }
+}
+
+func maybeUngzipResponse(resp *Response) {
+ if resp.Header.Get("Content-Encoding") == "gzip" {
+ resp.Header.Del("Content-Encoding")
+ resp.Header.Del("Content-Length")
+ resp.ContentLength = -1
+ resp.Body = &gzipReader{body: resp.Body}
+ }
+}
+
+func (pc *persistConn) readLoopPeekFailLocked(peekErr error) {
+ if pc.closed != nil {
+ return
+ }
+ if n := pc.br.Buffered(); n > 0 {
+ buf, _ := pc.br.Peek(n)
+ log.Printf("Unsolicited response received on idle HTTP channel starting with %q; err=%v", buf, peekErr)
+ }
+ if peekErr == io.EOF {
+ // common case.
+ pc.closeLocked(errServerClosedIdle)
+ } else {
+ pc.closeLocked(fmt.Errorf("readLoopPeekFailLocked: %v", peekErr))
+ }
+}
+
+// readResponse reads an HTTP response (or two, in the case of "Expect:
+// 100-continue") from the server. It returns the final non-100 one.
+func (pc *persistConn) readResponse(rc requestAndChan) (resp *Response, err error) {
+ resp, err = ReadResponse(pc.br, rc.req)
+ if err != nil {
+ return
+ }
+ if rc.continueCh != nil {
+ if resp.StatusCode == 100 {
+ rc.continueCh <- struct{}{}
+ } else {
+ close(rc.continueCh)
+ }
+ }
+ if resp.StatusCode == 100 {
+ resp, err = ReadResponse(pc.br, rc.req)
+ if err != nil {
+ return
}
+ }
+ resp.TLS = pc.tlsState
+ return
+}
+
+// waitForContinue returns the function to block until
+// any response, timeout or connection close. After any of them,
+// the function returns a bool which indicates if the body should be sent.
+func (pc *persistConn) waitForContinue(continueCh <-chan struct{}) func() bool {
+ if continueCh == nil {
+ return nil
+ }
+ return func() bool {
+ timer := time.NewTimer(pc.t.ExpectContinueTimeout)
+ defer timer.Stop()
- if hook := testHookReadLoopBeforeNextRead; hook != nil {
- hook()
+ select {
+ case _, ok := <-continueCh:
+ return ok
+ case <-timer.C:
+ return true
+ case <-pc.closech:
+ return false
}
}
- pc.close()
}
func (pc *persistConn) writeLoop() {
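The continueCh/waitForContinue plumbing above is what holds back the body of an "Expect: 100-continue" request until the server answers (or ExpectContinueTimeout fires). A client-side sketch of that behavior, with a placeholder URL:

package main

import (
	"fmt"
	"net/http"
	"strings"
	"time"
)

func main() {
	tr := &http.Transport{
		// If no 100 Continue (or other response) arrives within this
		// window, the body is sent anyway, matching the timer case in
		// waitForContinue above.
		ExpectContinueTimeout: 1 * time.Second,
	}
	client := &http.Client{Transport: tr}

	body := strings.NewReader("hello")
	req, err := http.NewRequest("PUT", "https://example.com/upload", body) // placeholder URL
	if err != nil {
		panic(err)
	}
	// With this header set, the Transport writes the headers, then waits
	// on continueCh before streaming the body.
	req.Header.Set("Expect", "100-continue")

	resp, err := client.Do(req)
	if err != nil {
		fmt.Println("request failed:", err)
		return
	}
	resp.Body.Close()
	fmt.Println("status:", resp.Status)
}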
@@ -1012,7 +1260,7 @@ func (pc *persistConn) writeLoop() {
wr.ch <- errors.New("http: can't write HTTP request on broken connection")
continue
}
- err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra)
+ err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra, pc.waitForContinue(wr.continueCh))
if err == nil {
err = pc.bw.Flush()
}
@@ -1056,19 +1304,29 @@ func (pc *persistConn) wroteRequest() bool {
}
}
+// responseAndError is how the goroutine reading from an HTTP/1 server
+// communicates with the goroutine doing the RoundTrip.
type responseAndError struct {
- res *Response
+ res *Response // else use this response (see res method)
err error
}
type requestAndChan struct {
req *Request
- ch chan responseAndError
+ ch chan responseAndError // unbuffered; always send in select on callerGone
// did the Transport (as opposed to the client code) add an
// Accept-Encoding gzip header? only if we set it do
// we transparently decode the gzip.
addedGzip bool
+
+ // Optional blocking chan for Expect: 100-continue (for send).
+ // If the request has an "Expect: 100-continue" header and
+ // the server responds 100 Continue, readLoop send a value
+ // to writeLoop via this chan.
+ continueCh chan<- struct{}
+
+ callerGone <-chan struct{} // closed when roundTrip caller has returned
}
// A writeRequest is sent by the readLoop's goroutine to the
@@ -1078,6 +1336,11 @@ type requestAndChan struct {
type writeRequest struct {
req *transportRequest
ch chan<- error
+
+ // Optional blocking chan for Expect: 100-continue (for receive).
+ // If not nil, writeLoop blocks sending request body until
+ // it receives from this chan.
+ continueCh <-chan struct{}
}
type httpError struct {
@@ -1090,23 +1353,34 @@ func (e *httpError) Timeout() bool { return e.timeout }
func (e *httpError) Temporary() bool { return true }
var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true}
-var errClosed error = &httpError{err: "net/http: transport closed before response was received"}
+var errClosed error = &httpError{err: "net/http: server closed connection before response was received"}
var errRequestCanceled = errors.New("net/http: request canceled")
+var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify?
+
+func nop() {}
-// nil except for tests
+// testHooks. Always non-nil.
var (
- testHookPersistConnClosedGotRes func()
- testHookEnterRoundTrip func()
- testHookMu sync.Locker = fakeLocker{} // guards following
- testHookReadLoopBeforeNextRead func()
+ testHookEnterRoundTrip = nop
+ testHookWaitResLoop = nop
+ testHookRoundTripRetried = nop
+ testHookPrePendingDial = nop
+ testHookPostPendingDial = nop
+
+ testHookMu sync.Locker = fakeLocker{} // guards following
+ testHookReadLoopBeforeNextRead = nop
)
+// beforeRespHeaderError is used to indicate when an IO error has occurred before
+// any header data was received.
+type beforeRespHeaderError struct {
+ error
+}
+
func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) {
- if hook := testHookEnterRoundTrip; hook != nil {
- hook()
- }
+ testHookEnterRoundTrip()
if !pc.t.replaceReqCanceler(req.Request, pc.cancelRequest) {
- pc.t.putIdleConn(pc)
+ pc.t.putOrCloseIdleConn(pc)
return nil, errRequestCanceled
}
pc.lk.Lock()
@@ -1143,42 +1417,47 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err
req.extraHeaders().Set("Accept-Encoding", "gzip")
}
+ var continueCh chan struct{}
+ if req.ProtoAtLeast(1, 1) && req.Body != nil && req.expectsContinue() {
+ continueCh = make(chan struct{}, 1)
+ }
+
if pc.t.DisableKeepAlives {
req.extraHeaders().Set("Connection", "close")
}
+ gone := make(chan struct{})
+ defer close(gone)
+
// Write the request concurrently with waiting for a response,
// in case the server decides to reply before reading our full
// request body.
writeErrCh := make(chan error, 1)
- pc.writech <- writeRequest{req, writeErrCh}
+ pc.writech <- writeRequest{req, writeErrCh, continueCh}
- resc := make(chan responseAndError, 1)
- pc.reqch <- requestAndChan{req.Request, resc, requestedGzip}
+ resc := make(chan responseAndError)
+ pc.reqch <- requestAndChan{
+ req: req.Request,
+ ch: resc,
+ addedGzip: requestedGzip,
+ continueCh: continueCh,
+ callerGone: gone,
+ }
var re responseAndError
var respHeaderTimer <-chan time.Time
cancelChan := req.Request.Cancel
WaitResponse:
for {
+ testHookWaitResLoop()
select {
case err := <-writeErrCh:
- if isNetWriteError(err) {
- // Issue 11745. If we failed to write the request
- // body, it's possible the server just heard enough
- // and already wrote to us. Prioritize the server's
- // response over returning a body write error.
- select {
- case re = <-resc:
- pc.close()
- break WaitResponse
- case <-time.After(50 * time.Millisecond):
- // Fall through.
- }
- }
if err != nil {
- re = responseAndError{nil, err}
- pc.close()
+ if pc.isCanceled() {
+ err = errRequestCanceled
+ }
+ re = responseAndError{err: beforeRespHeaderError{err}}
+ pc.close(fmt.Errorf("write error: %v", err))
break WaitResponse
}
if d := pc.t.ResponseHeaderTimeout; d > 0 {
@@ -1187,33 +1466,22 @@ WaitResponse:
respHeaderTimer = timer.C
}
case <-pc.closech:
- // The persist connection is dead. This shouldn't
- // usually happen (only with Connection: close responses
- // with no response bodies), but if it does happen it
- // means either a) the remote server hung up on us
- // prematurely, or b) the readLoop sent us a response &
- // closed its closech at roughly the same time, and we
- // selected this case first. If we got a response, readLoop makes sure
- // to send it before it puts the conn and closes the channel.
- // That way, we can fetch the response, if there is one,
- // with a non-blocking receive.
- select {
- case re = <-resc:
- if fn := testHookPersistConnClosedGotRes; fn != nil {
- fn()
- }
- default:
- re = responseAndError{err: errClosed}
- if pc.isCanceled() {
- re = responseAndError{err: errRequestCanceled}
- }
+ var err error
+ if pc.isCanceled() {
+ err = errRequestCanceled
+ } else {
+ err = beforeRespHeaderError{fmt.Errorf("net/http: HTTP/1 transport connection broken: %v", pc.closed)}
}
+ re = responseAndError{err: err}
break WaitResponse
case <-respHeaderTimer:
- pc.close()
+ pc.close(errTimeout)
re = responseAndError{err: errTimeout}
break WaitResponse
case re = <-resc:
+ if re.err != nil && pc.isCanceled() {
+ re.err = errRequestCanceled
+ }
break WaitResponse
case <-cancelChan:
pc.t.CancelRequest(req.Request)
@@ -1224,6 +1492,9 @@ WaitResponse:
if re.err != nil {
pc.t.setReqCanceler(req.Request, nil)
}
+ if (re.res == nil) == (re.err == nil) {
+ panic("internal error: exactly one of res or err should be set")
+ }
return re.res, re.err
}
@@ -1236,18 +1507,44 @@ func (pc *persistConn) markBroken() {
pc.broken = true
}
-func (pc *persistConn) close() {
+// markReused marks this connection as having been successfully used for a
+// request and response.
+func (pc *persistConn) markReused() {
+ pc.lk.Lock()
+ pc.reused = true
+ pc.lk.Unlock()
+}
+
+// close closes the underlying TCP connection and closes
+// the pc.closech channel.
+//
+// The provided err is only for testing and debugging; in normal
+// circumstances it should never be seen by users.
+func (pc *persistConn) close(err error) {
pc.lk.Lock()
defer pc.lk.Unlock()
- pc.closeLocked()
+ pc.closeLocked(err)
}
-func (pc *persistConn) closeLocked() {
+func (pc *persistConn) closeLocked(err error) {
+ if err == nil {
+ panic("nil error")
+ }
pc.broken = true
- if !pc.closed {
- pc.conn.Close()
- pc.closed = true
- close(pc.closech)
+ if pc.closed == nil {
+ pc.closed = err
+ if pc.alt != nil {
+ // Do nothing; can only get here via getConn's
+ // handlePendingDial's putOrCloseIdleConn when
+ // it turns out the abandoned connection in
+ // flight ended up negotiating an alternate
+ // protocol. We don't use the connection
+ // freelist for http2. That's done by the
+ // alternate protocol's RoundTripper.
+ } else {
+ pc.conn.Close()
+ close(pc.closech)
+ }
}
pc.mutateHeaderFunc = nil
}
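The tryPutIdleConn/putOrCloseIdleConn split above trades a bool for a descriptive error explaining why a connection was not kept. A generic sketch of that pattern; the pool type here is invented for illustration and is not part of net/http:

package main

import (
	"errors"
	"fmt"
)

var errPoolFull = errors.New("pool: too many idle items")

// pool is a toy free-list illustrating the "try to keep, else close with a
// reason" shape of tryPutIdleConn and putOrCloseIdleConn.
type pool struct {
	max  int
	idle []*conn
}

type conn struct{ closedWith error }

func (c *conn) close(reason error) { c.closedWith = reason }

// tryPut returns nil on success, or an error explaining why the item was
// not kept; it never closes the item itself.
func (p *pool) tryPut(c *conn) error {
	if len(p.idle) >= p.max {
		return errPoolFull
	}
	p.idle = append(p.idle, c)
	return nil
}

// putOrClose keeps the item if possible, otherwise closes it and records
// why, mirroring Transport.putOrCloseIdleConn.
func (p *pool) putOrClose(c *conn) {
	if err := p.tryPut(c); err != nil {
		c.close(err)
	}
}

func main() {
	p := &pool{max: 1}
	a, b := &conn{}, &conn{}
	p.putOrClose(a)
	p.putOrClose(b)
	fmt.Println(a.closedWith, b.closedWith) // <nil> pool: too many idle items
}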
diff --git a/libgo/go/net/http/transport_test.go b/libgo/go/net/http/transport_test.go
index c21d4afa87f..3b2a5f978e2 100644
--- a/libgo/go/net/http/transport_test.go
+++ b/libgo/go/net/http/transport_test.go
@@ -2,7 +2,10 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
-// Tests for transport.go
+// Tests for transport.go.
+//
+// More tests are in clientserver_test.go (for things testing both client & server for both
+// HTTP/1 and HTTP/2).
package http_test
@@ -20,6 +23,8 @@ import (
"net"
. "net/http"
"net/http/httptest"
+ "net/http/httputil"
+ "net/http/internal"
"net/url"
"os"
"reflect"
@@ -256,6 +261,7 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) {
// if the Transport's DisableKeepAlives is set, all requests should
// send Connection: close.
+// HTTP/1-only (Connection: close doesn't exist in h2)
func TestTransportConnectionCloseOnRequestDisableKeepAlive(t *testing.T) {
defer afterTest(t)
ts := httptest.NewServer(hostPortHandler)
@@ -431,6 +437,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) {
}
func TestTransportServerClosingUnexpectedly(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
ts := httptest.NewServer(hostPortHandler)
defer ts.Close()
@@ -597,6 +604,7 @@ func TestTransportHeadChunkedResponse(t *testing.T) {
tr := &Transport{DisableKeepAlives: false}
c := &Client{Transport: tr}
+ defer tr.CloseIdleConnections()
// Ensure that we wait for the readLoop to complete before
// calling Head again
@@ -790,6 +798,94 @@ func TestTransportGzip(t *testing.T) {
}
}
+// If a request has an Expect: 100-continue header, sending of the request body is blocked until the first response arrives.
+// The request body must not be consumed prematurely.
+func TestTransportExpect100Continue(t *testing.T) {
+ defer afterTest(t)
+
+ ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) {
+ switch req.URL.Path {
+ case "/100":
+ // This endpoint implicitly responds 100 Continue and reads body.
+ if _, err := io.Copy(ioutil.Discard, req.Body); err != nil {
+ t.Error("Failed to read Body", err)
+ }
+ rw.WriteHeader(StatusOK)
+ case "/200":
+ // Go 1.5 adds a Connection: close header if the client expects a
+ // continue but the entire request body is not consumed.
+ rw.WriteHeader(StatusOK)
+ case "/500":
+ rw.WriteHeader(StatusInternalServerError)
+ case "/keepalive":
+ // This hijacked endpoint responds with an error, without Connection: close.
+ _, bufrw, err := rw.(Hijacker).Hijack()
+ if err != nil {
+ log.Fatal(err)
+ }
+ bufrw.WriteString("HTTP/1.1 500 Internal Server Error\r\n")
+ bufrw.WriteString("Content-Length: 0\r\n\r\n")
+ bufrw.Flush()
+ case "/timeout":
+ // This endpoint tries to read the body without sending a 100 Continue response.
+ // Reading starts only after ExpectContinueTimeout has elapsed.
+ conn, bufrw, err := rw.(Hijacker).Hijack()
+ if err != nil {
+ log.Fatal(err)
+ }
+ if _, err := io.CopyN(ioutil.Discard, bufrw, req.ContentLength); err != nil {
+ t.Error("Failed to read Body", err)
+ }
+ bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n")
+ bufrw.Flush()
+ conn.Close()
+ }
+
+ }))
+ defer ts.Close()
+
+ tests := []struct {
+ path string
+ body []byte
+ sent int
+ status int
+ }{
+ {path: "/100", body: []byte("hello"), sent: 5, status: 200}, // Got 100 followed by 200, entire body is sent.
+ {path: "/200", body: []byte("hello"), sent: 0, status: 200}, // Got 200 without 100. body isn't sent.
+ {path: "/500", body: []byte("hello"), sent: 0, status: 500}, // Got 500 without 100. body isn't sent.
+ {path: "/keepalive", body: []byte("hello"), sent: 0, status: 500}, // Althogh without Connection:close, body isn't sent.
+ {path: "/timeout", body: []byte("hello"), sent: 5, status: 200}, // Timeout exceeded and entire body is sent.
+ }
+
+ for i, v := range tests {
+ tr := &Transport{ExpectContinueTimeout: 2 * time.Second}
+ defer tr.CloseIdleConnections()
+ c := &Client{Transport: tr}
+
+ body := bytes.NewReader(v.body)
+ req, err := NewRequest("PUT", ts.URL+v.path, body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ req.Header.Set("Expect", "100-continue")
+ req.ContentLength = int64(len(v.body))
+
+ resp, err := c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ resp.Body.Close()
+
+ sent := len(v.body) - body.Len()
+ if v.status != resp.StatusCode {
+ t.Errorf("test %d: status code should be %d but got %d. (%s)", i, v.status, resp.StatusCode, v.path)
+ }
+ if v.sent != sent {
+ t.Errorf("test %d: sent body should be %d but sent %d. (%s)", i, v.sent, sent, v.path)
+ }
+ }
+}
+
func TestTransportProxy(t *testing.T) {
defer afterTest(t)
ch := make(chan string, 1)
@@ -874,9 +970,7 @@ func TestTransportGzipShort(t *testing.T) {
// tests that persistent goroutine connections shut down when no longer desired.
func TestTransportPersistConnLeak(t *testing.T) {
- if runtime.GOOS == "plan9" {
- t.Skip("skipping test; see https://golang.org/issue/7237")
- }
+ setParallel(t)
defer afterTest(t)
gotReqCh := make(chan bool)
unblockCh := make(chan bool)
@@ -943,9 +1037,7 @@ func TestTransportPersistConnLeak(t *testing.T) {
// golang.org/issue/4531: Transport leaks goroutines when
// request.ContentLength is explicitly short
func TestTransportPersistConnLeakShortBody(t *testing.T) {
- if runtime.GOOS == "plan9" {
- t.Skip("skipping test; see https://golang.org/issue/7237")
- }
+ setParallel(t)
defer afterTest(t)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
}))
@@ -1286,6 +1378,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) {
}
func TestTransportResponseHeaderTimeout(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
if testing.Short() {
t.Skip("skipping timeout test in -short mode")
@@ -1357,6 +1450,7 @@ func TestTransportResponseHeaderTimeout(t *testing.T) {
}
func TestTransportCancelRequest(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
if testing.Short() {
t.Skip("skipping test in -short mode")
@@ -1466,6 +1560,7 @@ Get = Get http://something.no-network.tld/: net/http: request canceled while wai
}
func TestCancelRequestWithChannel(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
if testing.Short() {
t.Skip("skipping test in -short mode")
@@ -1523,6 +1618,7 @@ func TestCancelRequestWithChannel(t *testing.T) {
}
func TestCancelRequestWithChannelBeforeDo(t *testing.T) {
+ setParallel(t)
defer afterTest(t)
unblockc := make(chan bool)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -1554,7 +1650,6 @@ func TestCancelRequestWithChannelBeforeDo(t *testing.T) {
// Issue 11020. The returned error message should be errRequestCanceled
func TestTransportCancelBeforeResponseHeaders(t *testing.T) {
- t.Skip("Skipping flaky test; see Issue 11894")
defer afterTest(t)
serverConnCh := make(chan net.Conn, 1)
@@ -1704,6 +1799,19 @@ func TestTransportNoHost(t *testing.T) {
}
}
+// Issue 13311
+func TestTransportEmptyMethod(t *testing.T) {
+ req, _ := NewRequest("GET", "http://foo.com/", nil)
+ req.Method = "" // docs say "For client requests an empty string means GET"
+ got, err := httputil.DumpRequestOut(req, false) // DumpRequestOut uses Transport
+ if err != nil {
+ t.Fatal(err)
+ }
+ if !strings.Contains(string(got), "GET ") {
+ t.Fatalf("expected substring 'GET '; got: %s", got)
+ }
+}
+
func TestTransportSocketLateBinding(t *testing.T) {
defer afterTest(t)
@@ -2291,15 +2399,103 @@ type errorReader struct {
func (e errorReader) Read(p []byte) (int, error) { return 0, e.err }
+type plan9SleepReader struct{}
+
+func (plan9SleepReader) Read(p []byte) (int, error) {
+ if runtime.GOOS == "plan9" {
+ // After the fix to unblock TCP Reads in
+ // https://golang.org/cl/15941, this sleep is required
+ // on plan9 to make sure TCP Writes before an
+ // immediate TCP close go out on the wire. On Plan 9,
+ // it seems that a hangup of a TCP connection with
+ // queued data doesn't send the queued data first.
+ // https://golang.org/issue/9554
+ time.Sleep(50 * time.Millisecond)
+ }
+ return 0, io.EOF
+}
+
type closerFunc func() error
func (f closerFunc) Close() error { return f() }
+// Issue 4677. If we try to reuse a connection that the server is in the
+// process of closing, we may end up successfully writing out our request (or a
+// portion of our request) only to find a connection error when we try to read
+// from (or finish writing to) the socket.
+//
+// NOTE: we resend a request only if the request is idempotent, we reused a
+// keep-alive connection, and we haven't yet received any header data. This
+// automatically prevents an infinite resend loop because we'll run out of the
+// cached keep-alive connections eventually.
+func TestRetryIdempotentRequestsOnError(t *testing.T) {
+ defer afterTest(t)
+
+ ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
+ }))
+ defer ts.Close()
+
+ tr := &Transport{}
+ c := &Client{Transport: tr}
+
+ const N = 2
+ retryc := make(chan struct{}, N)
+ SetRoundTripRetried(func() {
+ retryc <- struct{}{}
+ })
+ defer SetRoundTripRetried(nil)
+
+ for n := 0; n < 100; n++ {
+ // open 2 conns
+ errc := make(chan error, N)
+ for i := 0; i < N; i++ {
+ // start goroutines, send on errc
+ go func() {
+ res, err := c.Get(ts.URL)
+ if err == nil {
+ res.Body.Close()
+ }
+ errc <- err
+ }()
+ }
+ for i := 0; i < N; i++ {
+ if err := <-errc; err != nil {
+ t.Fatal(err)
+ }
+ }
+
+ ts.CloseClientConnections()
+ for i := 0; i < N; i++ {
+ go func() {
+ res, err := c.Get(ts.URL)
+ if err == nil {
+ res.Body.Close()
+ }
+ errc <- err
+ }()
+ }
+
+ for i := 0; i < N; i++ {
+ if err := <-errc; err != nil {
+ t.Fatal(err)
+ }
+ }
+ for i := 0; i < N; i++ {
+ select {
+ case <-retryc:
+ // we triggered a retry, test was successful
+ t.Logf("finished after %d runs\n", n)
+ return
+ default:
+ }
+ }
+ }
+ t.Fatal("did not trigger any retries")
+}
+
// Issue 6981
func TestTransportClosesBodyOnError(t *testing.T) {
- if runtime.GOOS == "plan9" {
- t.Skip("skipping test; see https://golang.org/issue/7782")
- }
+ setParallel(t)
defer afterTest(t)
readBody := make(chan error, 1)
ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
@@ -2313,7 +2509,7 @@ func TestTransportClosesBodyOnError(t *testing.T) {
io.Reader
io.Closer
}{
- io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), errorReader{fakeErr}),
+ io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), plan9SleepReader{}, errorReader{fakeErr}),
closerFunc(func() error {
select {
case didClose <- true:
@@ -2474,52 +2670,6 @@ func TestTransportRangeAndGzip(t *testing.T) {
res.Body.Close()
}
-// Previously, we used to handle a logical race within RoundTrip by waiting for 100ms
-// in the case of an error. Changing the order of the channel operations got rid of this
-// race.
-//
-// In order to test that the channel op reordering works, we install a hook into the
-// roundTrip function which gets called if we saw the connection go away and
-// we subsequently received a response.
-func TestTransportResponseCloseRace(t *testing.T) {
- if testing.Short() {
- t.Skip("skipping in short mode")
- }
- defer afterTest(t)
-
- ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {
- }))
- defer ts.Close()
- sawRace := false
- SetInstallConnClosedHook(func() {
- sawRace = true
- })
- defer SetInstallConnClosedHook(nil)
- tr := &Transport{
- DisableKeepAlives: true,
- }
- req, err := NewRequest("GET", ts.URL, nil)
- if err != nil {
- t.Fatal(err)
- }
- // selects are not deterministic, so do this a bunch
- // and see if we handle the logical race at least once.
- for i := 0; i < 10000; i++ {
- resp, err := tr.RoundTrip(req)
- if err != nil {
- t.Fatalf("unexpected error: %s", err)
- continue
- }
- resp.Body.Close()
- if sawRace {
- break
- }
- }
- if !sawRace {
- t.Errorf("didn't see response/connection going away race")
- }
-}
-
// Test for issue 10474
func TestTransportResponseCancelRace(t *testing.T) {
defer afterTest(t)
@@ -2645,7 +2795,7 @@ func TestTransportFlushesBodyChunks(t *testing.T) {
req.Header.Set("User-Agent", "x") // known value for test
res, err := tr.RoundTrip(req)
if err != nil {
- t.Error("RoundTrip: %v", err)
+ t.Errorf("RoundTrip: %v", err)
close(resc)
return
}
@@ -2735,6 +2885,153 @@ func TestTransportPrefersResponseOverWriteError(t *testing.T) {
}
}
+func TestTransportAutomaticHTTP2(t *testing.T) {
+ tr := &Transport{}
+ _, err := tr.RoundTrip(new(Request))
+ if err == nil {
+ t.Error("expected error from RoundTrip")
+ }
+ if tr.TLSNextProto["h2"] == nil {
+ t.Errorf("HTTP/2 not registered.")
+ }
+
+ // Now with TLSNextProto set:
+ tr = &Transport{TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper)}
+ _, err = tr.RoundTrip(new(Request))
+ if err == nil {
+ t.Error("expected error from RoundTrip")
+ }
+ if tr.TLSNextProto["h2"] != nil {
+ t.Errorf("HTTP/2 registered, despite non-nil TLSNextProto field")
+ }
+}
+
+// Issue 13633: there was a race where we returned bodyless responses
+// to callers before recycling the persistent connection, which meant
+// a client doing two subsequent requests could end up on different
+// connections. It's somewhat harmless but enough tests assume it's
+// not true in order to test other things that it's worth fixing.
+// Plus it's nice to be consistent and not have timing-dependent
+// behavior.
+func TestTransportReuseConnEmptyResponseBody(t *testing.T) {
+ defer afterTest(t)
+ cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("X-Addr", r.RemoteAddr)
+ // Empty response body.
+ }))
+ defer cst.close()
+ n := 100
+ if testing.Short() {
+ n = 10
+ }
+ var firstAddr string
+ for i := 0; i < n; i++ {
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+ addr := res.Header.Get("X-Addr")
+ if i == 0 {
+ firstAddr = addr
+ } else if addr != firstAddr {
+ t.Fatalf("On request %d, addr %q != original addr %q", i+1, addr, firstAddr)
+ }
+ res.Body.Close()
+ }
+}
+
+// Issue 13839
+func TestNoCrashReturningTransportAltConn(t *testing.T) {
+ cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ handledPendingDial := make(chan bool, 1)
+ SetPendingDialHooks(nil, func() { handledPendingDial <- true })
+ defer SetPendingDialHooks(nil, nil)
+
+ testDone := make(chan struct{})
+ defer close(testDone)
+ go func() {
+ tln := tls.NewListener(ln, &tls.Config{
+ NextProtos: []string{"foo"},
+ Certificates: []tls.Certificate{cert},
+ })
+ sc, err := tln.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ if err := sc.(*tls.Conn).Handshake(); err != nil {
+ t.Error(err)
+ return
+ }
+ <-testDone
+ sc.Close()
+ }()
+
+ addr := ln.Addr().String()
+
+ req, _ := NewRequest("GET", "https://fake.tld/", nil)
+ cancel := make(chan struct{})
+ req.Cancel = cancel
+
+ doReturned := make(chan bool, 1)
+ madeRoundTripper := make(chan bool, 1)
+
+ tr := &Transport{
+ DisableKeepAlives: true,
+ TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
+ "foo": func(authority string, c *tls.Conn) RoundTripper {
+ madeRoundTripper <- true
+ return funcRoundTripper(func() {
+ t.Error("foo RoundTripper should not be called")
+ })
+ },
+ },
+ Dial: func(_, _ string) (net.Conn, error) {
+ panic("shouldn't be called")
+ },
+ DialTLS: func(_, _ string) (net.Conn, error) {
+ tc, err := tls.Dial("tcp", addr, &tls.Config{
+ InsecureSkipVerify: true,
+ NextProtos: []string{"foo"},
+ })
+ if err != nil {
+ return nil, err
+ }
+ if err := tc.Handshake(); err != nil {
+ return nil, err
+ }
+ close(cancel)
+ <-doReturned
+ return tc, nil
+ },
+ }
+ c := &Client{Transport: tr}
+
+ _, err = c.Do(req)
+ if ue, ok := err.(*url.Error); !ok || ue.Err != ExportErrRequestCanceledConn {
+ t.Fatalf("Do error = %v; want url.Error with errRequestCanceledConn", err)
+ }
+
+ doReturned <- true
+ <-madeRoundTripper
+ <-handledPendingDial
+}
+
+var errFakeRoundTrip = errors.New("fake roundtrip")
+
+type funcRoundTripper func()
+
+func (fn funcRoundTripper) RoundTrip(*Request) (*Response, error) {
+ fn()
+ return nil, errFakeRoundTrip
+}
+
func wantBody(res *Response, err error, want string) error {
if err != nil {
return err
diff --git a/libgo/go/net/http/triv.go b/libgo/go/net/http/triv.go
index 232d6508906..cfbc5778c1c 100644
--- a/libgo/go/net/http/triv.go
+++ b/libgo/go/net/http/triv.go
@@ -134,8 +134,5 @@ func main() {
http.HandleFunc("/args", ArgServer)
http.HandleFunc("/go/hello", HelloServer)
http.HandleFunc("/date", DateServer)
- err := http.ListenAndServe(":12345", nil)
- if err != nil {
- log.Panicln("ListenAndServe:", err)
- }
+ log.Fatal(http.ListenAndServe(":12345", nil))
}
diff --git a/libgo/go/net/interface_test.go b/libgo/go/net/interface_test.go
index 567d18de448..7bdd9241503 100644
--- a/libgo/go/net/interface_test.go
+++ b/libgo/go/net/interface_test.go
@@ -185,21 +185,50 @@ func testAddrs(t *testing.T, ifat []Addr) (naf4, naf6 int) {
t.Errorf("unexpected value: %#v", ifa)
continue
}
+ if len(ifa.IP) != IPv6len {
+ t.Errorf("should be internal representation either IPv6 or IPv6 IPv4-mapped address: %#v", ifa)
+ continue
+ }
prefixLen, maxPrefixLen := ifa.Mask.Size()
if ifa.IP.To4() != nil {
if 0 >= prefixLen || prefixLen > 8*IPv4len || maxPrefixLen != 8*IPv4len {
- t.Errorf("unexpected prefix length: %v/%v", prefixLen, maxPrefixLen)
+ t.Errorf("unexpected prefix length: %d/%d", prefixLen, maxPrefixLen)
+ continue
+ }
+ if ifa.IP.IsLoopback() && (prefixLen != 8 && prefixLen != 8*IPv4len) { // see RFC 1122
+ t.Errorf("unexpected prefix length for IPv4 loopback: %d/%d", prefixLen, maxPrefixLen)
continue
}
naf4++
- } else if ifa.IP.To16() != nil {
+ }
+ if ifa.IP.To16() != nil && ifa.IP.To4() == nil {
if 0 >= prefixLen || prefixLen > 8*IPv6len || maxPrefixLen != 8*IPv6len {
- t.Errorf("unexpected prefix length: %v/%v", prefixLen, maxPrefixLen)
+ t.Errorf("unexpected prefix length: %d/%d", prefixLen, maxPrefixLen)
+ continue
+ }
+ if ifa.IP.IsLoopback() && prefixLen != 8*IPv6len { // see RFC 4291
+ t.Errorf("unexpected prefix length for IPv6 loopback: %d/%d", prefixLen, maxPrefixLen)
continue
}
naf6++
}
t.Logf("interface address %q", ifa.String())
+ case *IPAddr:
+ if ifa == nil || ifa.IP == nil || ifa.IP.IsUnspecified() || ifa.IP.IsMulticast() {
+ t.Errorf("unexpected value: %#v", ifa)
+ continue
+ }
+ if len(ifa.IP) != IPv6len {
+ t.Errorf("should be internal representation either IPv6 or IPv6 IPv4-mapped address: %#v", ifa)
+ continue
+ }
+ if ifa.IP.To4() != nil {
+ naf4++
+ }
+ if ifa.IP.To16() != nil && ifa.IP.To4() == nil {
+ naf6++
+ }
+ t.Logf("interface address %s", ifa.String())
default:
t.Errorf("unexpected type: %T", ifa)
}
@@ -212,12 +241,17 @@ func testMulticastAddrs(t *testing.T, ifmat []Addr) (nmaf4, nmaf6 int) {
switch ifma := ifma.(type) {
case *IPAddr:
if ifma == nil || ifma.IP == nil || ifma.IP.IsUnspecified() || !ifma.IP.IsMulticast() {
- t.Errorf("unexpected value: %#v", ifma)
+ t.Errorf("unexpected value: %+v", ifma)
+ continue
+ }
+ if len(ifma.IP) != IPv6len {
+ t.Errorf("should be internal representation either IPv6 or IPv6 IPv4-mapped address: %#v", ifma)
continue
}
if ifma.IP.To4() != nil {
nmaf4++
- } else if ifma.IP.To16() != nil {
+ }
+ if ifma.IP.To16() != nil && ifma.IP.To4() == nil {
nmaf6++
}
t.Logf("joined group address %q", ifma.String())
diff --git a/libgo/go/net/interface_windows.go b/libgo/go/net/interface_windows.go
index e25c1ed560b..4d6bcdf4c76 100644
--- a/libgo/go/net/interface_windows.go
+++ b/libgo/go/net/interface_windows.go
@@ -11,131 +11,103 @@ import (
"unsafe"
)
-func getAdapters() (*windows.IpAdapterAddresses, error) {
- block := uint32(unsafe.Sizeof(windows.IpAdapterAddresses{}))
+// supportsVistaIP reports whether the platform implements the new IP
+// stack and ABIs supported on Windows Vista and above.
+var supportsVistaIP bool
- // pre-allocate a 15KB working buffer pointed to by the AdapterAddresses
- // parameter.
- // https://msdn.microsoft.com/en-us/library/windows/desktop/aa365915(v=vs.85).aspx
- size := uint32(15000)
-
- var addrs []windows.IpAdapterAddresses
- for {
- addrs = make([]windows.IpAdapterAddresses, size/block+1)
- err := windows.GetAdaptersAddresses(syscall.AF_UNSPEC, windows.GAA_FLAG_INCLUDE_PREFIX, 0, &addrs[0], &size)
- if err == nil {
- break
- }
- if err.(syscall.Errno) != syscall.ERROR_BUFFER_OVERFLOW {
- return nil, os.NewSyscallError("getadaptersaddresses", err)
- }
- }
- return &addrs[0], nil
+func init() {
+ supportsVistaIP = probeWindowsIPStack()
}
-func getInterfaceInfos() ([]syscall.InterfaceInfo, error) {
- s, err := sysSocket(syscall.AF_INET, syscall.SOCK_DGRAM, syscall.IPPROTO_UDP)
- if err != nil {
- return nil, err
- }
- defer closeFunc(s)
-
- iia := [20]syscall.InterfaceInfo{}
- ret := uint32(0)
- size := uint32(unsafe.Sizeof(iia))
- err = syscall.WSAIoctl(s, syscall.SIO_GET_INTERFACE_LIST, nil, 0, (*byte)(unsafe.Pointer(&iia[0])), size, &ret, nil, 0)
+func probeWindowsIPStack() (supportsVistaIP bool) {
+ v, err := syscall.GetVersion()
if err != nil {
- return nil, os.NewSyscallError("wsaioctl", err)
+ return true // Windows 10 and above will deprecate this API
}
- iilen := ret / uint32(unsafe.Sizeof(iia[0]))
- return iia[:iilen-1], nil
-}
-
-func bytesEqualIP(a []byte, b []int8) bool {
- for i := 0; i < len(a); i++ {
- if a[i] != byte(b[i]) {
- return false
- }
+ if byte(v) < 6 { // the major version of Windows Vista is 6
+ return false
}
return true
}
-func findInterfaceInfo(iis []syscall.InterfaceInfo, paddr *windows.IpAdapterAddresses) *syscall.InterfaceInfo {
- for _, ii := range iis {
- iaddr := (*syscall.RawSockaddr)(unsafe.Pointer(&ii.Address))
- puni := paddr.FirstUnicastAddress
- for ; puni != nil; puni = puni.Next {
- if iaddr.Family == puni.Address.Sockaddr.Addr.Family {
- switch iaddr.Family {
- case syscall.AF_INET:
- a := (*syscall.RawSockaddrInet4)(unsafe.Pointer(&ii.Address)).Addr
- if bytesEqualIP(a[:], puni.Address.Sockaddr.Addr.Data[2:]) {
- return &ii
- }
- case syscall.AF_INET6:
- a := (*syscall.RawSockaddrInet6)(unsafe.Pointer(&ii.Address)).Addr
- if bytesEqualIP(a[:], puni.Address.Sockaddr.Addr.Data[2:]) {
- return &ii
- }
- default:
- continue
- }
+// adapterAddresses returns a list of IP adapter and address
+// structures. Each structure contains an IP adapter and a flattened
+// list of its IP addresses, including unicast, anycast and multicast
+// addresses.
+func adapterAddresses() ([]*windows.IpAdapterAddresses, error) {
+ var b []byte
+ l := uint32(15000) // recommended initial size
+ for {
+ b = make([]byte, l)
+ err := windows.GetAdaptersAddresses(syscall.AF_UNSPEC, windows.GAA_FLAG_INCLUDE_PREFIX, 0, (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])), &l)
+ if err == nil {
+ if l == 0 {
+ return nil, nil
}
+ break
}
+ if err.(syscall.Errno) != syscall.ERROR_BUFFER_OVERFLOW {
+ return nil, os.NewSyscallError("getadaptersaddresses", err)
+ }
+ if l <= uint32(len(b)) {
+ return nil, os.NewSyscallError("getadaptersaddresses", err)
+ }
+ }
+ var aas []*windows.IpAdapterAddresses
+ for aa := (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])); aa != nil; aa = aa.Next {
+ aas = append(aas, aa)
}
- return nil
+ return aas, nil
}
// If the ifindex is zero, interfaceTable returns mappings of all
// network interfaces. Otherwise it returns a mapping of a specific
// interface.
func interfaceTable(ifindex int) ([]Interface, error) {
- paddr, err := getAdapters()
- if err != nil {
- return nil, err
- }
-
- iis, err := getInterfaceInfos()
+ aas, err := adapterAddresses()
if err != nil {
return nil, err
}
-
var ift []Interface
- for ; paddr != nil; paddr = paddr.Next {
- index := paddr.IfIndex
- if paddr.Ipv6IfIndex != 0 {
- index = paddr.Ipv6IfIndex
+ for _, aa := range aas {
+ index := aa.IfIndex
+ if index == 0 { // ipv6IfIndex is a substitute for ifIndex
+ index = aa.Ipv6IfIndex
}
if ifindex == 0 || ifindex == int(index) {
- ii := findInterfaceInfo(iis, paddr)
- if ii == nil {
- continue
- }
- var flags Flags
- if paddr.Flags&windows.IfOperStatusUp != 0 {
- flags |= FlagUp
- }
- if paddr.IfType&windows.IF_TYPE_SOFTWARE_LOOPBACK != 0 {
- flags |= FlagLoopback
+ ifi := Interface{
+ Index: int(index),
+ Name: syscall.UTF16ToString((*(*[10000]uint16)(unsafe.Pointer(aa.FriendlyName)))[:]),
}
- if ii.Flags&syscall.IFF_BROADCAST != 0 {
- flags |= FlagBroadcast
+ if aa.OperStatus == windows.IfOperStatusUp {
+ ifi.Flags |= FlagUp
}
- if ii.Flags&syscall.IFF_POINTTOPOINT != 0 {
- flags |= FlagPointToPoint
+ // For now we need to infer link-layer service
+ // capabilities from media types.
+ // We will be able to use
+ // MIB_IF_ROW2.AccessType once we drop support
+ // for Windows XP.
+ switch aa.IfType {
+ case windows.IF_TYPE_ETHERNET_CSMACD, windows.IF_TYPE_ISO88025_TOKENRING, windows.IF_TYPE_IEEE80211, windows.IF_TYPE_IEEE1394:
+ ifi.Flags |= FlagBroadcast | FlagMulticast
+ case windows.IF_TYPE_PPP, windows.IF_TYPE_TUNNEL:
+ ifi.Flags |= FlagPointToPoint | FlagMulticast
+ case windows.IF_TYPE_SOFTWARE_LOOPBACK:
+ ifi.Flags |= FlagLoopback | FlagMulticast
+ case windows.IF_TYPE_ATM:
+ ifi.Flags |= FlagBroadcast | FlagPointToPoint | FlagMulticast // assume all services available; LANE, point-to-point and point-to-multipoint
}
- if ii.Flags&syscall.IFF_MULTICAST != 0 {
- flags |= FlagMulticast
+ if aa.Mtu == 0xffffffff {
+ ifi.MTU = -1
+ } else {
+ ifi.MTU = int(aa.Mtu)
}
- ifi := Interface{
- Index: int(index),
- MTU: int(paddr.Mtu),
- Name: syscall.UTF16ToString((*(*[10000]uint16)(unsafe.Pointer(paddr.FriendlyName)))[:]),
- HardwareAddr: HardwareAddr(paddr.PhysicalAddress[:]),
- Flags: flags,
+ if aa.PhysicalAddressLength > 0 {
+ ifi.HardwareAddr = make(HardwareAddr, aa.PhysicalAddressLength)
+ copy(ifi.HardwareAddr, aa.PhysicalAddress[:])
}
ift = append(ift, ifi)
- if ifindex == int(ifi.Index) {
+ if ifindex == ifi.Index {
break
}
}
@@ -147,86 +119,150 @@ func interfaceTable(ifindex int) ([]Interface, error) {
// network interfaces. Otherwise it returns addresses for a specific
// interface.
func interfaceAddrTable(ifi *Interface) ([]Addr, error) {
- paddr, err := getAdapters()
+ aas, err := adapterAddresses()
if err != nil {
return nil, err
}
-
var ifat []Addr
- for ; paddr != nil; paddr = paddr.Next {
- index := paddr.IfIndex
- if paddr.Ipv6IfIndex != 0 {
- index = paddr.Ipv6IfIndex
+ for _, aa := range aas {
+ index := aa.IfIndex
+ if index == 0 { // ipv6IfIndex is a substitute for ifIndex
+ index = aa.Ipv6IfIndex
+ }
+ var pfx4, pfx6 []IPNet
+ if !supportsVistaIP {
+ pfx4, pfx6, err = addrPrefixTable(aa)
+ if err != nil {
+ return nil, err
+ }
}
if ifi == nil || ifi.Index == int(index) {
- puni := paddr.FirstUnicastAddress
- for ; puni != nil; puni = puni.Next {
- if sa, err := puni.Address.Sockaddr.Sockaddr(); err == nil {
- switch sav := sa.(type) {
- case *syscall.SockaddrInet4:
- ifa := &IPNet{IP: make(IP, IPv4len), Mask: CIDRMask(int(puni.Address.SockaddrLength), 8*IPv4len)}
- copy(ifa.IP, sav.Addr[:])
- ifat = append(ifat, ifa)
- case *syscall.SockaddrInet6:
- ifa := &IPNet{IP: make(IP, IPv6len), Mask: CIDRMask(int(puni.Address.SockaddrLength), 8*IPv6len)}
- copy(ifa.IP, sav.Addr[:])
- ifat = append(ifat, ifa)
+ for puni := aa.FirstUnicastAddress; puni != nil; puni = puni.Next {
+ sa, err := puni.Address.Sockaddr.Sockaddr()
+ if err != nil {
+ return nil, os.NewSyscallError("sockaddr", err)
+ }
+ var l int
+ switch sa := sa.(type) {
+ case *syscall.SockaddrInet4:
+ if supportsVistaIP {
+ l = int(puni.OnLinkPrefixLength)
+ } else {
+ l = addrPrefixLen(pfx4, IP(sa.Addr[:]))
}
+ ifat = append(ifat, &IPNet{IP: IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3]), Mask: CIDRMask(l, 8*IPv4len)})
+ case *syscall.SockaddrInet6:
+ if supportsVistaIP {
+ l = int(puni.OnLinkPrefixLength)
+ } else {
+ l = addrPrefixLen(pfx6, IP(sa.Addr[:]))
+ }
+ ifa := &IPNet{IP: make(IP, IPv6len), Mask: CIDRMask(l, 8*IPv6len)}
+ copy(ifa.IP, sa.Addr[:])
+ ifat = append(ifat, ifa)
}
}
- pany := paddr.FirstAnycastAddress
- for ; pany != nil; pany = pany.Next {
- if sa, err := pany.Address.Sockaddr.Sockaddr(); err == nil {
- switch sav := sa.(type) {
- case *syscall.SockaddrInet4:
- ifa := &IPNet{IP: make(IP, IPv4len), Mask: CIDRMask(int(pany.Address.SockaddrLength), 8*IPv4len)}
- copy(ifa.IP, sav.Addr[:])
- ifat = append(ifat, ifa)
- case *syscall.SockaddrInet6:
- ifa := &IPNet{IP: make(IP, IPv6len), Mask: CIDRMask(int(pany.Address.SockaddrLength), 8*IPv6len)}
- copy(ifa.IP, sav.Addr[:])
- ifat = append(ifat, ifa)
- }
+ for pany := aa.FirstAnycastAddress; pany != nil; pany = pany.Next {
+ sa, err := pany.Address.Sockaddr.Sockaddr()
+ if err != nil {
+ return nil, os.NewSyscallError("sockaddr", err)
+ }
+ switch sa := sa.(type) {
+ case *syscall.SockaddrInet4:
+ ifat = append(ifat, &IPAddr{IP: IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3])})
+ case *syscall.SockaddrInet6:
+ ifa := &IPAddr{IP: make(IP, IPv6len)}
+ copy(ifa.IP, sa.Addr[:])
+ ifat = append(ifat, ifa)
}
}
}
}
-
return ifat, nil
}
+func addrPrefixTable(aa *windows.IpAdapterAddresses) (pfx4, pfx6 []IPNet, err error) {
+ for p := aa.FirstPrefix; p != nil; p = p.Next {
+ sa, err := p.Address.Sockaddr.Sockaddr()
+ if err != nil {
+ return nil, nil, os.NewSyscallError("sockaddr", err)
+ }
+ switch sa := sa.(type) {
+ case *syscall.SockaddrInet4:
+ pfx := IPNet{IP: IP(sa.Addr[:]), Mask: CIDRMask(int(p.PrefixLength), 8*IPv4len)}
+ pfx4 = append(pfx4, pfx)
+ case *syscall.SockaddrInet6:
+ pfx := IPNet{IP: IP(sa.Addr[:]), Mask: CIDRMask(int(p.PrefixLength), 8*IPv6len)}
+ pfx6 = append(pfx6, pfx)
+ }
+ }
+ return
+}
+
+// addrPrefixLen returns an appropriate prefix length in bits for ip
+// from pfxs. It returns 32 or 128 when no appropriate on-link address
+// prefix is found.
+//
+// NOTE: This is a pretty naive implementation that performs many
+// allocations and an inefficient linear search, so it should not be
+// used freely.
+func addrPrefixLen(pfxs []IPNet, ip IP) int {
+ var l int
+ var cand *IPNet
+ for i := range pfxs {
+ if !pfxs[i].Contains(ip) {
+ continue
+ }
+ if cand == nil {
+ l, _ = pfxs[i].Mask.Size()
+ cand = &pfxs[i]
+ continue
+ }
+ m, _ := pfxs[i].Mask.Size()
+ if m > l {
+ l = m
+ cand = &pfxs[i]
+ continue
+ }
+ }
+ if l > 0 {
+ return l
+ }
+ if ip.To4() != nil {
+ return 8 * IPv4len
+ }
+ return 8 * IPv6len
+}
+
// interfaceMulticastAddrTable returns addresses for a specific
// interface.
func interfaceMulticastAddrTable(ifi *Interface) ([]Addr, error) {
- paddr, err := getAdapters()
+ aas, err := adapterAddresses()
if err != nil {
return nil, err
}
-
var ifat []Addr
- for ; paddr != nil; paddr = paddr.Next {
- index := paddr.IfIndex
- if paddr.Ipv6IfIndex != 0 {
- index = paddr.Ipv6IfIndex
+ for _, aa := range aas {
+ index := aa.IfIndex
+ if index == 0 { // ipv6IfIndex is a substitute for ifIndex
+ index = aa.Ipv6IfIndex
}
if ifi == nil || ifi.Index == int(index) {
- pmul := paddr.FirstMulticastAddress
- for ; pmul != nil; pmul = pmul.Next {
- if sa, err := pmul.Address.Sockaddr.Sockaddr(); err == nil {
- switch sav := sa.(type) {
- case *syscall.SockaddrInet4:
- ifa := &IPAddr{IP: make(IP, IPv4len)}
- copy(ifa.IP, sav.Addr[:])
- ifat = append(ifat, ifa)
- case *syscall.SockaddrInet6:
- ifa := &IPAddr{IP: make(IP, IPv6len)}
- copy(ifa.IP, sav.Addr[:])
- ifat = append(ifat, ifa)
- }
+ for pmul := aa.FirstMulticastAddress; pmul != nil; pmul = pmul.Next {
+ sa, err := pmul.Address.Sockaddr.Sockaddr()
+ if err != nil {
+ return nil, os.NewSyscallError("sockaddr", err)
+ }
+ switch sa := sa.(type) {
+ case *syscall.SockaddrInet4:
+ ifat = append(ifat, &IPAddr{IP: IPv4(sa.Addr[0], sa.Addr[1], sa.Addr[2], sa.Addr[3])})
+ case *syscall.SockaddrInet6:
+ ifa := &IPAddr{IP: make(IP, IPv6len)}
+ copy(ifa.IP, sa.Addr[:])
+ ifat = append(ifat, ifa)
}
}
}
}
-
return ifat, nil
}
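
The new addrPrefixLen above selects the longest on-link prefix that contains the address and falls back to a host-route length otherwise. A rough standalone sketch of that selection using only the public net API (the helper name is illustrative, not the package's code):

    package main

    import (
        "fmt"
        "net"
    )

    func longestPrefixLen(pfxs []*net.IPNet, ip net.IP) int {
        best := -1
        for _, p := range pfxs {
            if !p.Contains(ip) {
                continue
            }
            if l, _ := p.Mask.Size(); l > best {
                best = l
            }
        }
        if best >= 0 {
            return best
        }
        if ip.To4() != nil {
            return 32 // fall back to a host route for IPv4
        }
        return 128 // fall back to a host route for IPv6
    }

    func main() {
        var pfxs []*net.IPNet
        for _, s := range []string{"192.168.0.0/24", "192.168.0.0/26"} {
            _, p, _ := net.ParseCIDR(s)
            pfxs = append(pfxs, p)
        }
        fmt.Println(longestPrefixLen(pfxs, net.ParseIP("192.168.0.1"))) // 26
    }
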
diff --git a/libgo/go/net/interface_windows_test.go b/libgo/go/net/interface_windows_test.go
new file mode 100644
index 00000000000..03f9168b482
--- /dev/null
+++ b/libgo/go/net/interface_windows_test.go
@@ -0,0 +1,132 @@
+// Copyright 2015 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package net
+
+import (
+ "bytes"
+ "internal/syscall/windows"
+ "sort"
+ "testing"
+)
+
+func TestWindowsInterfaces(t *testing.T) {
+ aas, err := adapterAddresses()
+ if err != nil {
+ t.Fatal(err)
+ }
+ ift, err := Interfaces()
+ if err != nil {
+ t.Fatal(err)
+ }
+ for i, ifi := range ift {
+ aa := aas[i]
+ if len(ifi.HardwareAddr) != int(aa.PhysicalAddressLength) {
+ t.Errorf("got %d; want %d", len(ifi.HardwareAddr), aa.PhysicalAddressLength)
+ }
+ if ifi.MTU > 0x7fffffff {
+ t.Errorf("%s: got %d; want less than or equal to 1<<31 - 1", ifi.Name, ifi.MTU)
+ }
+ if ifi.Flags&FlagUp != 0 && aa.OperStatus != windows.IfOperStatusUp {
+ t.Errorf("%s: got %v; should not include FlagUp", ifi.Name, ifi.Flags)
+ }
+ if ifi.Flags&FlagLoopback != 0 && aa.IfType != windows.IF_TYPE_SOFTWARE_LOOPBACK {
+ t.Errorf("%s: got %v; should not include FlagLoopback", ifi.Name, ifi.Flags)
+ }
+ if _, _, err := addrPrefixTable(aa); err != nil {
+ t.Errorf("%s: %v", ifi.Name, err)
+ }
+ }
+}
+
+type byAddrLen []IPNet
+
+func (ps byAddrLen) Len() int { return len(ps) }
+
+func (ps byAddrLen) Less(i, j int) bool {
+ if n := bytes.Compare(ps[i].IP, ps[j].IP); n != 0 {
+ return n < 0
+ }
+ if n := bytes.Compare(ps[i].Mask, ps[j].Mask); n != 0 {
+ return n < 0
+ }
+ return false
+}
+
+func (ps byAddrLen) Swap(i, j int) { ps[i], ps[j] = ps[j], ps[i] }
+
+var windowsAddrPrefixLenTests = []struct {
+ pfxs []IPNet
+ ip IP
+ out int
+}{
+ {
+ []IPNet{
+ {IP: IPv4(172, 16, 0, 0), Mask: IPv4Mask(255, 255, 0, 0)},
+ {IP: IPv4(192, 168, 0, 0), Mask: IPv4Mask(255, 255, 255, 0)},
+ {IP: IPv4(192, 168, 0, 0), Mask: IPv4Mask(255, 255, 255, 128)},
+ {IP: IPv4(192, 168, 0, 0), Mask: IPv4Mask(255, 255, 255, 192)},
+ },
+ IPv4(192, 168, 0, 1),
+ 26,
+ },
+ {
+ []IPNet{
+ {IP: ParseIP("2001:db8::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fff0"))},
+ {IP: ParseIP("2001:db8::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fff8"))},
+ {IP: ParseIP("2001:db8::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffc"))},
+ },
+ ParseIP("2001:db8::1"),
+ 126,
+ },
+
+ // Fallback cases. These may happen on Windows XP or Windows Server 2003.
+ {
+ []IPNet{
+ {IP: IPv4(127, 0, 0, 0).To4(), Mask: IPv4Mask(255, 0, 0, 0)},
+ {IP: IPv4(10, 0, 0, 0).To4(), Mask: IPv4Mask(255, 0, 0, 0)},
+ {IP: IPv4(172, 16, 0, 0).To4(), Mask: IPv4Mask(255, 255, 0, 0)},
+ {IP: IPv4(192, 168, 255, 0), Mask: IPv4Mask(255, 255, 255, 0)},
+ {IP: IPv4zero, Mask: IPv4Mask(0, 0, 0, 0)},
+ },
+ IPv4(192, 168, 0, 1),
+ 8 * IPv4len,
+ },
+ {
+ nil,
+ IPv4(192, 168, 0, 1),
+ 8 * IPv4len,
+ },
+ {
+ []IPNet{
+ {IP: IPv6loopback, Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:ffff"))},
+ {IP: ParseIP("2001:db8:1::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fff0"))},
+ {IP: ParseIP("2001:db8:2::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fff8"))},
+ {IP: ParseIP("2001:db8:3::"), Mask: IPMask(ParseIP("ffff:ffff:ffff:ffff:ffff:ffff:ffff:fffc"))},
+ {IP: IPv6unspecified, Mask: IPMask(ParseIP("::"))},
+ },
+ ParseIP("2001:db8::1"),
+ 8 * IPv6len,
+ },
+ {
+ nil,
+ ParseIP("2001:db8::1"),
+ 8 * IPv6len,
+ },
+}
+
+func TestWindowsAddrPrefixLen(t *testing.T) {
+ for i, tt := range windowsAddrPrefixLenTests {
+ sort.Sort(byAddrLen(tt.pfxs))
+ l := addrPrefixLen(tt.pfxs, tt.ip)
+ if l != tt.out {
+ t.Errorf("#%d: got %d; want %d", i, l, tt.out)
+ }
+ sort.Sort(sort.Reverse(byAddrLen(tt.pfxs)))
+ l = addrPrefixLen(tt.pfxs, tt.ip)
+ if l != tt.out {
+ t.Errorf("#%d: got %d; want %d", i, l, tt.out)
+ }
+ }
+}
diff --git a/libgo/go/net/internal/socktest/switch.go b/libgo/go/net/internal/socktest/switch.go
index 4e38c7a85f3..8bef06b97c8 100644
--- a/libgo/go/net/internal/socktest/switch.go
+++ b/libgo/go/net/internal/socktest/switch.go
@@ -77,7 +77,7 @@ type Status struct {
}
func (so Status) String() string {
- return fmt.Sprintf("(%s, %s, %s): syscallerr=%v, socketerr=%v", familyString(so.Cookie.Family()), typeString(so.Cookie.Type()), protocolString(so.Cookie.Protocol()), so.Err, so.SocketErr)
+ return fmt.Sprintf("(%s, %s, %s): syscallerr=%v socketerr=%v", familyString(so.Cookie.Family()), typeString(so.Cookie.Type()), protocolString(so.Cookie.Protocol()), so.Err, so.SocketErr)
}
// A Stat represents a per-cookie socket statistics.
@@ -100,7 +100,7 @@ type Stat struct {
}
func (st Stat) String() string {
- return fmt.Sprintf("(%s, %s, %s): opened=%d, connected=%d, listened=%d, accepted=%d, closed=%d, openfailed=%d, connectfailed=%d, listenfailed=%d, acceptfailed=%d, closefailed=%d", familyString(st.Family), typeString(st.Type), protocolString(st.Protocol), st.Opened, st.Connected, st.Listened, st.Accepted, st.Closed, st.OpenFailed, st.ConnectFailed, st.ListenFailed, st.AcceptFailed, st.CloseFailed)
+ return fmt.Sprintf("(%s, %s, %s): opened=%d connected=%d listened=%d accepted=%d closed=%d openfailed=%d connectfailed=%d listenfailed=%d acceptfailed=%d closefailed=%d", familyString(st.Family), typeString(st.Type), protocolString(st.Protocol), st.Opened, st.Connected, st.Listened, st.Accepted, st.Closed, st.OpenFailed, st.ConnectFailed, st.ListenFailed, st.AcceptFailed, st.CloseFailed)
}
type stats map[Cookie]*Stat
diff --git a/libgo/go/net/iprawsock_posix.go b/libgo/go/net/iprawsock_posix.go
index 9417606ce94..93fee3e232e 100644
--- a/libgo/go/net/iprawsock_posix.go
+++ b/libgo/go/net/iprawsock_posix.go
@@ -220,7 +220,7 @@ func dialIP(netProto string, laddr, raddr *IPAddr, deadline time.Time) (*IPConn,
if raddr == nil {
return nil, &OpError{Op: "dial", Net: netProto, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
}
- fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_RAW, proto, "dial")
+ fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_RAW, proto, "dial", noCancel)
if err != nil {
return nil, &OpError{Op: "dial", Net: netProto, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}
@@ -241,7 +241,7 @@ func ListenIP(netProto string, laddr *IPAddr) (*IPConn, error) {
default:
return nil, &OpError{Op: "listen", Net: netProto, Source: nil, Addr: laddr.opAddr(), Err: UnknownNetworkError(netProto)}
}
- fd, err := internetSocket(net, laddr, nil, noDeadline, syscall.SOCK_RAW, proto, "listen")
+ fd, err := internetSocket(net, laddr, nil, noDeadline, syscall.SOCK_RAW, proto, "listen", noCancel)
if err != nil {
return nil, &OpError{Op: "listen", Net: netProto, Source: nil, Addr: laddr.opAddr(), Err: err}
}
diff --git a/libgo/go/net/ipsock.go b/libgo/go/net/ipsock.go
index 6e75c33d534..55f697f622e 100644
--- a/libgo/go/net/ipsock.go
+++ b/libgo/go/net/ipsock.go
@@ -213,7 +213,7 @@ func internetAddrList(net, addr string, deadline time.Time) (addrList, error) {
if host, port, err = SplitHostPort(addr); err != nil {
return nil, err
}
- if portnum, err = parsePort(net, port); err != nil {
+ if portnum, err = LookupPort(net, port); err != nil {
return nil, err
}
}
diff --git a/libgo/go/net/ipsock_posix.go b/libgo/go/net/ipsock_posix.go
index 83eaf855b4c..2bddd46a156 100644
--- a/libgo/go/net/ipsock_posix.go
+++ b/libgo/go/net/ipsock_posix.go
@@ -101,10 +101,11 @@ func probeIPv6Stack() (supportsIPv6, supportsIPv4map bool) {
//
// 1. A wild-wild listen, "tcp" + ""
// If the platform supports both IPv6 and IPv6 IPv4-mapping
-// capabilities, we assume that the user want to listen on
-// both IPv4 and IPv6 wildcard address over an AF_INET6
-// socket with IPV6_V6ONLY=0. Otherwise we prefer an IPv4
-// wildcard address listen over an AF_INET socket.
+// capabilities, or does not support IPv4, we assume that
+// the user wants to listen on both IPv4 and IPv6 wildcard
+// addresses over an AF_INET6 socket with IPV6_V6ONLY=0.
+// Otherwise we prefer an IPv4 wildcard address listen over
+// an AF_INET socket.
//
// 2. A wild-ipv4wild listen, "tcp" + "0.0.0.0"
// Same as 1.
@@ -137,7 +138,7 @@ func favoriteAddrFamily(net string, laddr, raddr sockaddr, mode string) (family
}
if mode == "listen" && (laddr == nil || laddr.isWildcard()) {
- if supportsIPv4map {
+ if supportsIPv4map || !supportsIPv4 {
return syscall.AF_INET6, false
}
if laddr == nil {
@@ -155,9 +156,9 @@ func favoriteAddrFamily(net string, laddr, raddr sockaddr, mode string) (family
// Internet sockets (TCP, UDP, IP)
-func internetSocket(net string, laddr, raddr sockaddr, deadline time.Time, sotype, proto int, mode string) (fd *netFD, err error) {
+func internetSocket(net string, laddr, raddr sockaddr, deadline time.Time, sotype, proto int, mode string, cancel <-chan struct{}) (fd *netFD, err error) {
family, ipv6only := favoriteAddrFamily(net, laddr, raddr, mode)
- return socket(net, family, sotype, proto, ipv6only, laddr, raddr, deadline)
+ return socket(net, family, sotype, proto, ipv6only, laddr, raddr, deadline, cancel)
}
func ipToSockaddr(family int, ip IP, port int, zone string) (syscall.Sockaddr, error) {
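
A quick way to observe the wildcard-listen policy described in the comment above is to inspect the chosen listener address; the result depends on the host's IP stack, so treat this as an illustration rather than a guarantee:

    package main

    import (
        "fmt"
        "log"
        "net"
    )

    func main() {
        ln, err := net.Listen("tcp", ":0") // a wild-wild listen: no host, any port
        if err != nil {
            log.Fatal(err)
        }
        defer ln.Close()
        // On a dual-stack host this typically binds an AF_INET6 socket with
        // IPV6_V6ONLY=0, so the address prints as "[::]:<port>".
        fmt.Println(ln.Addr())
    }
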
diff --git a/libgo/go/net/lookup.go b/libgo/go/net/lookup.go
index 9008322dc5a..7aa111ba929 100644
--- a/libgo/go/net/lookup.go
+++ b/libgo/go/net/lookup.go
@@ -123,10 +123,22 @@ func lookupIPDeadline(host string, deadline time.Time) (addrs []IPAddr, err erro
// LookupPort looks up the port for the given network and service.
func LookupPort(network, service string) (port int, err error) {
- if n, i, ok := dtoi(service, 0); ok && i == len(service) {
- return n, nil
+ if service == "" {
+ // Lock in the legacy behavior that an empty string
+ // means port 0. See Issue 13610.
+ return 0, nil
}
- return lookupPort(network, service)
+ port, _, ok := dtoi(service, 0)
+ if !ok && port != big && port != -big {
+ port, err = lookupPort(network, service)
+ if err != nil {
+ return 0, err
+ }
+ }
+ if 0 > port || port > 65535 {
+ return 0, &AddrError{Err: "invalid port", Addr: service}
+ }
+ return port, nil
}
// LookupCNAME returns the canonical DNS host for the given name.
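
With the change above, LookupPort accepts an empty service name, numeric strings, and named services, and rejects anything outside 0-65535. A small usage sketch; the "https" lookup assumes a populated services database:

    package main

    import (
        "fmt"
        "net"
    )

    func main() {
        for _, svc := range []string{"", "80", "https"} {
            port, err := net.LookupPort("tcp", svc)
            fmt.Printf("%q -> %d, %v\n", svc, port, err)
        }
        // "" -> 0 (legacy behavior, Issue 13610), "80" -> 80, and "https"
        // -> 443 where the services database defines it.
    }
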
diff --git a/libgo/go/net/lookup_plan9.go b/libgo/go/net/lookup_plan9.go
index c6274640bb7..a33162882b7 100644
--- a/libgo/go/net/lookup_plan9.go
+++ b/libgo/go/net/lookup_plan9.go
@@ -225,8 +225,8 @@ func lookupSRV(service, proto, name string) (cname string, addrs []*SRV, err err
if !(portOk && priorityOk && weightOk) {
continue
}
- addrs = append(addrs, &SRV{f[5], uint16(port), uint16(priority), uint16(weight)})
- cname = f[0]
+ addrs = append(addrs, &SRV{absDomainName([]byte(f[5])), uint16(port), uint16(priority), uint16(weight)})
+ cname = absDomainName([]byte(f[0]))
}
byPriorityWeight(addrs).sort()
return
@@ -243,7 +243,7 @@ func lookupMX(name string) (mx []*MX, err error) {
continue
}
if pref, _, ok := dtoi(f[2], 0); ok {
- mx = append(mx, &MX{f[3], uint16(pref)})
+ mx = append(mx, &MX{absDomainName([]byte(f[3])), uint16(pref)})
}
}
byPref(mx).sort()
@@ -260,7 +260,7 @@ func lookupNS(name string) (ns []*NS, err error) {
if len(f) < 3 {
continue
}
- ns = append(ns, &NS{f[2]})
+ ns = append(ns, &NS{absDomainName([]byte(f[2]))})
}
return
}
@@ -272,7 +272,7 @@ func lookupTXT(name string) (txt []string, err error) {
}
for _, line := range lines {
if i := byteIndex(line, '\t'); i >= 0 {
- txt = append(txt, line[i+1:])
+ txt = append(txt, absDomainName([]byte(line[i+1:])))
}
}
return
@@ -292,7 +292,7 @@ func lookupAddr(addr string) (name []string, err error) {
if len(f) < 3 {
continue
}
- name = append(name, f[2])
+ name = append(name, absDomainName([]byte(f[2])))
}
return
}
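
The Plan 9 lookups above now wrap every returned name in absDomainName so results come back fully qualified. A simplified approximation of that normalization, written against the behavior the updated lookup tests expect (the helper name and body here are illustrative, not the package's code):

    package main

    import (
        "bytes"
        "fmt"
    )

    // absDomainNameSketch approximates the normalization: names containing a
    // dot gain a trailing dot, bare names such as "localhost" are left alone.
    func absDomainNameSketch(b []byte) string {
        if bytes.IndexByte(b, '.') >= 0 && b[len(b)-1] != '.' {
            b = append(b, '.')
        }
        return string(b)
    }

    func main() {
        fmt.Println(absDomainNameSketch([]byte("google.com")))  // google.com.
        fmt.Println(absDomainNameSketch([]byte("google.com."))) // google.com.
        fmt.Println(absDomainNameSketch([]byte("localhost")))   // localhost
    }
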
diff --git a/libgo/go/net/lookup_test.go b/libgo/go/net/lookup_test.go
index 86957b55756..439496ac81d 100644
--- a/libgo/go/net/lookup_test.go
+++ b/libgo/go/net/lookup_test.go
@@ -7,6 +7,8 @@ package net
import (
"bytes"
"fmt"
+ "internal/testenv"
+ "runtime"
"strings"
"testing"
"time"
@@ -37,26 +39,26 @@ var lookupGoogleSRVTests = []struct {
}{
{
"xmpp-server", "tcp", "google.com",
- "google.com", "google.com",
+ "google.com.", "google.com.",
},
{
"xmpp-server", "tcp", "google.com.",
- "google.com", "google.com",
+ "google.com.", "google.com.",
},
// non-standard back door
{
"", "", "_xmpp-server._tcp.google.com",
- "google.com", "google.com",
+ "google.com.", "google.com.",
},
{
"", "", "_xmpp-server._tcp.google.com.",
- "google.com", "google.com",
+ "google.com.", "google.com.",
},
}
func TestLookupGoogleSRV(t *testing.T) {
- if testing.Short() || !*testExternal {
+ if testing.Short() && testenv.Builder() == "" || !*testExternal {
t.Skip("avoid external network")
}
if !supportsIPv4 || !*testIPv4 {
@@ -71,11 +73,11 @@ func TestLookupGoogleSRV(t *testing.T) {
if len(srvs) == 0 {
t.Error("got no record")
}
- if !strings.HasSuffix(cname, tt.cname) && !strings.HasSuffix(cname, tt.cname+".") {
+ if !strings.HasSuffix(cname, tt.cname) {
t.Errorf("got %s; want %s", cname, tt.cname)
}
for _, srv := range srvs {
- if !strings.HasSuffix(srv.Target, tt.target) && !strings.HasSuffix(srv.Target, tt.target+".") {
+ if !strings.HasSuffix(srv.Target, tt.target) {
t.Errorf("got %v; want a record containing %s", srv, tt.target)
}
}
@@ -85,12 +87,12 @@ func TestLookupGoogleSRV(t *testing.T) {
var lookupGmailMXTests = []struct {
name, host string
}{
- {"gmail.com", "google.com"},
- {"gmail.com.", "google.com"},
+ {"gmail.com", "google.com."},
+ {"gmail.com.", "google.com."},
}
func TestLookupGmailMX(t *testing.T) {
- if testing.Short() || !*testExternal {
+ if testing.Short() && testenv.Builder() == "" || !*testExternal {
t.Skip("avoid external network")
}
if !supportsIPv4 || !*testIPv4 {
@@ -106,7 +108,7 @@ func TestLookupGmailMX(t *testing.T) {
t.Error("got no record")
}
for _, mx := range mxs {
- if !strings.HasSuffix(mx.Host, tt.host) && !strings.HasSuffix(mx.Host, tt.host+".") {
+ if !strings.HasSuffix(mx.Host, tt.host) {
t.Errorf("got %v; want a record containing %s", mx, tt.host)
}
}
@@ -116,12 +118,12 @@ func TestLookupGmailMX(t *testing.T) {
var lookupGmailNSTests = []struct {
name, host string
}{
- {"gmail.com", "google.com"},
- {"gmail.com.", "google.com"},
+ {"gmail.com", "google.com."},
+ {"gmail.com.", "google.com."},
}
func TestLookupGmailNS(t *testing.T) {
- if testing.Short() || !*testExternal {
+ if testing.Short() && testenv.Builder() == "" || !*testExternal {
t.Skip("avoid external network")
}
if !supportsIPv4 || !*testIPv4 {
@@ -137,7 +139,7 @@ func TestLookupGmailNS(t *testing.T) {
t.Error("got no record")
}
for _, ns := range nss {
- if !strings.HasSuffix(ns.Host, tt.host) && !strings.HasSuffix(ns.Host, tt.host+".") {
+ if !strings.HasSuffix(ns.Host, tt.host) {
t.Errorf("got %v; want a record containing %s", ns, tt.host)
}
}
@@ -152,7 +154,7 @@ var lookupGmailTXTTests = []struct {
}
func TestLookupGmailTXT(t *testing.T) {
- if testing.Short() || !*testExternal {
+ if testing.Short() && testenv.Builder() == "" || !*testExternal {
t.Skip("avoid external network")
}
if !supportsIPv4 || !*testIPv4 {
@@ -178,14 +180,15 @@ func TestLookupGmailTXT(t *testing.T) {
var lookupGooglePublicDNSAddrTests = []struct {
addr, name string
}{
- {"8.8.8.8", ".google.com"},
- {"8.8.4.4", ".google.com"},
- {"2001:4860:4860::8888", ".google.com"},
- {"2001:4860:4860::8844", ".google.com"},
+ {"8.8.8.8", ".google.com."},
+ {"8.8.4.4", ".google.com."},
+
+ {"2001:4860:4860::8888", ".google.com."},
+ {"2001:4860:4860::8844", ".google.com."},
}
func TestLookupGooglePublicDNSAddr(t *testing.T) {
- if testing.Short() || !*testExternal {
+ if testing.Short() && testenv.Builder() == "" || !*testExternal {
t.Skip("avoid external network")
}
if !supportsIPv4 || !supportsIPv6 || !*testIPv4 || !*testIPv6 {
@@ -201,22 +204,46 @@ func TestLookupGooglePublicDNSAddr(t *testing.T) {
t.Error("got no record")
}
for _, name := range names {
- if !strings.HasSuffix(name, tt.name) && !strings.HasSuffix(name, tt.name+".") {
+ if !strings.HasSuffix(name, tt.name) {
t.Errorf("got %s; want a record containing %s", name, tt.name)
}
}
}
}
+func TestLookupIPv6LinkLocalAddr(t *testing.T) {
+ if !supportsIPv6 || !*testIPv6 {
+ t.Skip("IPv6 is required")
+ }
+
+ addrs, err := LookupHost("localhost")
+ if err != nil {
+ t.Fatal(err)
+ }
+ found := false
+ for _, addr := range addrs {
+ if addr == "fe80::1%lo0" {
+ found = true
+ break
+ }
+ }
+ if !found {
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+ if _, err := LookupAddr("fe80::1%lo0"); err != nil {
+ t.Error(err)
+ }
+}
+
var lookupIANACNAMETests = []struct {
name, cname string
}{
- {"www.iana.org", "icann.org"},
- {"www.iana.org.", "icann.org"},
+ {"www.iana.org", "icann.org."},
+ {"www.iana.org.", "icann.org."},
}
func TestLookupIANACNAME(t *testing.T) {
- if testing.Short() || !*testExternal {
+ if testing.Short() && testenv.Builder() == "" || !*testExternal {
t.Skip("avoid external network")
}
if !supportsIPv4 || !*testIPv4 {
@@ -228,7 +255,7 @@ func TestLookupIANACNAME(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- if !strings.HasSuffix(cname, tt.cname) && !strings.HasSuffix(cname, tt.cname+".") {
+ if !strings.HasSuffix(cname, tt.cname) {
t.Errorf("got %s; want a record containing %s", cname, tt.cname)
}
}
@@ -242,7 +269,7 @@ var lookupGoogleHostTests = []struct {
}
func TestLookupGoogleHost(t *testing.T) {
- if testing.Short() || !*testExternal {
+ if testing.Short() && testenv.Builder() == "" || !*testExternal {
t.Skip("avoid external network")
}
if !supportsIPv4 || !*testIPv4 {
@@ -273,7 +300,7 @@ var lookupGoogleIPTests = []struct {
}
func TestLookupGoogleIP(t *testing.T) {
- if testing.Short() || !*testExternal {
+ if testing.Short() && testenv.Builder() == "" || !*testExternal {
t.Skip("avoid external network")
}
if !supportsIPv4 || !*testIPv4 {
@@ -394,17 +421,62 @@ func TestLookupIPDeadline(t *testing.T) {
t.Logf("%v succeeded, %v failed (%v timeout, %v temporary, %v other, %v unknown)", qstats.succeeded, qstats.failed, qstats.timeout, qstats.temporary, qstats.other, qstats.unknown)
}
-func TestLookupDots(t *testing.T) {
- if testing.Short() || !*testExternal {
- t.Skipf("skipping external network test")
+func TestLookupDotsWithLocalSource(t *testing.T) {
+ if !supportsIPv4 || !*testIPv4 {
+ t.Skip("IPv4 is required")
}
- fixup := forceGoDNS()
- defer fixup()
- testDots(t, "go")
+ for i, fn := range []func() func(){forceGoDNS, forceCgoDNS} {
+ fixup := fn()
+ if fixup == nil {
+ continue
+ }
+ names, err := LookupAddr("127.0.0.1")
+ fixup()
+ if err != nil {
+ t.Logf("#%d: %v", i, err)
+ continue
+ }
+ mode := "netgo"
+ if i == 1 {
+ mode = "netcgo"
+ }
+ loop:
+ for i, name := range names {
+ if strings.Index(name, ".") == len(name)-1 { // "localhost" not "localhost."
+ for j := range names {
+ if j == i {
+ continue
+ }
+ if names[j] == name[:len(name)-1] {
+ // It's OK if we find the name without the dot,
+ // as some systems say 127.0.0.1 localhost localhost.
+ continue loop
+ }
+ }
+ t.Errorf("%s: got %s; want %s", mode, name, name[:len(name)-1])
+ } else if strings.Contains(name, ".") && !strings.HasSuffix(name, ".") { // "localhost.localdomain." not "localhost.localdomain"
+ t.Errorf("%s: got %s; want name ending with trailing dot", mode, name)
+ }
+ }
+ }
+}
+
+func TestLookupDotsWithRemoteSource(t *testing.T) {
+ if testing.Short() && testenv.Builder() == "" || !*testExternal {
+ t.Skip("avoid external network")
+ }
+ if !supportsIPv4 || !*testIPv4 {
+ t.Skip("IPv4 is required")
+ }
- if forceCgoDNS() {
+ if fixup := forceGoDNS(); fixup != nil {
+ testDots(t, "go")
+ fixup()
+ }
+ if fixup := forceCgoDNS(); fixup != nil {
testDots(t, "cgo")
+ fixup()
}
}
@@ -501,3 +573,62 @@ func srvString(srvs []*SRV) string {
fmt.Fprintf(&buf, "]")
return buf.String()
}
+
+var lookupPortTests = []struct {
+ network string
+ name string
+ port int
+ ok bool
+}{
+ {"tcp", "0", 0, true},
+ {"tcp", "echo", 7, true},
+ {"tcp", "discard", 9, true},
+ {"tcp", "systat", 11, true},
+ {"tcp", "daytime", 13, true},
+ {"tcp", "chargen", 19, true},
+ {"tcp", "ftp-data", 20, true},
+ {"tcp", "ftp", 21, true},
+ {"tcp", "telnet", 23, true},
+ {"tcp", "smtp", 25, true},
+ {"tcp", "time", 37, true},
+ {"tcp", "domain", 53, true},
+ {"tcp", "finger", 79, true},
+ {"tcp", "42", 42, true},
+
+ {"udp", "0", 0, true},
+ {"udp", "echo", 7, true},
+ {"udp", "tftp", 69, true},
+ {"udp", "bootpc", 68, true},
+ {"udp", "bootps", 67, true},
+ {"udp", "domain", 53, true},
+ {"udp", "ntp", 123, true},
+ {"udp", "snmp", 161, true},
+ {"udp", "syslog", 514, true},
+ {"udp", "42", 42, true},
+
+ {"--badnet--", "zzz", 0, false},
+ {"tcp", "--badport--", 0, false},
+ {"tcp", "-1", 0, false},
+ {"tcp", "65536", 0, false},
+ {"udp", "-1", 0, false},
+ {"udp", "65536", 0, false},
+
+ // Issue 13610: LookupPort("tcp", "")
+ {"tcp", "", 0, true},
+ {"tcp6", "", 0, true},
+ {"tcp4", "", 0, true},
+ {"udp", "", 0, true},
+}
+
+func TestLookupPort(t *testing.T) {
+ switch runtime.GOOS {
+ case "nacl":
+ t.Skipf("not supported on %s", runtime.GOOS)
+ }
+
+ for _, tt := range lookupPortTests {
+ if port, err := LookupPort(tt.network, tt.name); port != tt.port || (err == nil) != tt.ok {
+ t.Errorf("LookupPort(%q, %q) = %d, %v; want %d", tt.network, tt.name, port, err, tt.port)
+ }
+ }
+}
diff --git a/libgo/go/net/lookup_windows.go b/libgo/go/net/lookup_windows.go
index 1b6d392f660..13edc264e82 100644
--- a/libgo/go/net/lookup_windows.go
+++ b/libgo/go/net/lookup_windows.go
@@ -19,7 +19,7 @@ var (
func getprotobyname(name string) (proto int, err error) {
p, err := syscall.GetProtoByName(name)
if err != nil {
- return 0, os.NewSyscallError("getorotobyname", err)
+ return 0, os.NewSyscallError("getprotobyname", err)
}
return int(p.Proto), nil
}
@@ -221,10 +221,7 @@ func lookupCNAME(name string) (string, error) {
// windows returns DNS_INFO_NO_RECORDS if there are no CNAME-s
if errno, ok := e.(syscall.Errno); ok && errno == syscall.DNS_INFO_NO_RECORDS {
// if there are no aliases, the canonical name is the input name
- if name == "" || name[len(name)-1] != '.' {
- return name + ".", nil
- }
- return name, nil
+ return absDomainName([]byte(name)), nil
}
if e != nil {
return "", &DNSError{Err: os.NewSyscallError("dnsquery", e).Error(), Name: name}
@@ -232,8 +229,8 @@ func lookupCNAME(name string) (string, error) {
defer syscall.DnsRecordListFree(r, 1)
resolved := resolveCNAME(syscall.StringToUTF16Ptr(name), r)
- cname := syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(resolved))[:]) + "."
- return cname, nil
+ cname := syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(resolved))[:])
+ return absDomainName([]byte(cname)), nil
}
func lookupSRV(service, proto, name string) (string, []*SRV, error) {
@@ -255,10 +252,10 @@ func lookupSRV(service, proto, name string) (string, []*SRV, error) {
srvs := make([]*SRV, 0, 10)
for _, p := range validRecs(r, syscall.DNS_TYPE_SRV, target) {
v := (*syscall.DNSSRVData)(unsafe.Pointer(&p.Data[0]))
- srvs = append(srvs, &SRV{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:]), v.Port, v.Priority, v.Weight})
+ srvs = append(srvs, &SRV{absDomainName([]byte(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Target))[:]))), v.Port, v.Priority, v.Weight})
}
byPriorityWeight(srvs).sort()
- return name, srvs, nil
+ return absDomainName([]byte(target)), srvs, nil
}
func lookupMX(name string) ([]*MX, error) {
@@ -274,7 +271,7 @@ func lookupMX(name string) ([]*MX, error) {
mxs := make([]*MX, 0, 10)
for _, p := range validRecs(r, syscall.DNS_TYPE_MX, name) {
v := (*syscall.DNSMXData)(unsafe.Pointer(&p.Data[0]))
- mxs = append(mxs, &MX{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.NameExchange))[:]) + ".", v.Preference})
+ mxs = append(mxs, &MX{absDomainName([]byte(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.NameExchange))[:]))), v.Preference})
}
byPref(mxs).sort()
return mxs, nil
@@ -293,7 +290,7 @@ func lookupNS(name string) ([]*NS, error) {
nss := make([]*NS, 0, 10)
for _, p := range validRecs(r, syscall.DNS_TYPE_NS, name) {
v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
- nss = append(nss, &NS{syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]) + "."})
+ nss = append(nss, &NS{absDomainName([]byte(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:])))})
}
return nss, nil
}
@@ -336,7 +333,7 @@ func lookupAddr(addr string) ([]string, error) {
ptrs := make([]string, 0, 10)
for _, p := range validRecs(r, syscall.DNS_TYPE_PTR, arpa) {
v := (*syscall.DNSPTRData)(unsafe.Pointer(&p.Data[0]))
- ptrs = append(ptrs, syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]))
+ ptrs = append(ptrs, absDomainName([]byte(syscall.UTF16ToString((*[256]uint16)(unsafe.Pointer(v.Host))[:]))))
}
return ptrs, nil
}
diff --git a/libgo/go/net/mac.go b/libgo/go/net/mac.go
index 8594a9146ae..93f0b091219 100644
--- a/libgo/go/net/mac.go
+++ b/libgo/go/net/mac.go
@@ -24,14 +24,17 @@ func (a HardwareAddr) String() string {
return string(buf)
}
-// ParseMAC parses s as an IEEE 802 MAC-48, EUI-48, or EUI-64 using one of the
-// following formats:
+// ParseMAC parses s as an IEEE 802 MAC-48, EUI-48, EUI-64, or a 20-octet
+// IP over InfiniBand link-layer address using one of the following formats:
// 01:23:45:67:89:ab
// 01:23:45:67:89:ab:cd:ef
+// 01:23:45:67:89:ab:cd:ef:00:00:01:23:45:67:89:ab:cd:ef:00:00
// 01-23-45-67-89-ab
// 01-23-45-67-89-ab-cd-ef
+// 01-23-45-67-89-ab-cd-ef-00-00-01-23-45-67-89-ab-cd-ef-00-00
// 0123.4567.89ab
// 0123.4567.89ab.cdef
+// 0123.4567.89ab.cdef.0000.0123.4567.89ab.cdef.0000
func ParseMAC(s string) (hw HardwareAddr, err error) {
if len(s) < 14 {
goto error
@@ -42,7 +45,7 @@ func ParseMAC(s string) (hw HardwareAddr, err error) {
goto error
}
n := (len(s) + 1) / 3
- if n != 6 && n != 8 {
+ if n != 6 && n != 8 && n != 20 {
goto error
}
hw = make(HardwareAddr, n)
@@ -58,7 +61,7 @@ func ParseMAC(s string) (hw HardwareAddr, err error) {
goto error
}
n := 2 * (len(s) + 1) / 5
- if n != 6 && n != 8 {
+ if n != 6 && n != 8 && n != 20 {
goto error
}
hw = make(HardwareAddr, n)
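
Once this change is in, ParseMAC accepts the 20-octet IP-over-InfiniBand form in all three separator styles. A quick check, which requires a net package with this patch applied:

    package main

    import (
        "fmt"
        "net"
    )

    func main() {
        hw, err := net.ParseMAC("01:23:45:67:89:ab:cd:ef:00:00:01:23:45:67:89:ab:cd:ef:00:00")
        if err != nil {
            panic(err)
        }
        fmt.Println(len(hw), hw) // 20 octets, printed back in colon form
    }
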
diff --git a/libgo/go/net/mac_test.go b/libgo/go/net/mac_test.go
index 0af0c014f50..1ec6b287ac3 100644
--- a/libgo/go/net/mac_test.go
+++ b/libgo/go/net/mac_test.go
@@ -34,6 +34,30 @@ var parseMACTests = []struct {
{"01:23:45:67:89:AB:CD:EF", HardwareAddr{1, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}, ""},
{"01-23-45-67-89-AB-CD-EF", HardwareAddr{1, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}, ""},
{"0123.4567.89AB.CDEF", HardwareAddr{1, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef}, ""},
+ {
+ "01:23:45:67:89:ab:cd:ef:00:00:01:23:45:67:89:ab:cd:ef:00:00",
+ HardwareAddr{
+ 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x00, 0x00,
+ 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x00, 0x00,
+ },
+ "",
+ },
+ {
+ "01-23-45-67-89-ab-cd-ef-00-00-01-23-45-67-89-ab-cd-ef-00-00",
+ HardwareAddr{
+ 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x00, 0x00,
+ 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x00, 0x00,
+ },
+ "",
+ },
+ {
+ "0123.4567.89ab.cdef.0000.0123.4567.89ab.cdef.0000",
+ HardwareAddr{
+ 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x00, 0x00,
+ 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x00, 0x00,
+ },
+ "",
+ },
}
func TestParseMAC(t *testing.T) {
diff --git a/libgo/go/net/mail/message.go b/libgo/go/net/mail/message.go
index 266ac50a38d..923630c49ce 100644
--- a/libgo/go/net/mail/message.go
+++ b/libgo/go/net/mail/message.go
@@ -234,6 +234,12 @@ func (a *Address) String() string {
return b.String()
}
+ // Text in an encoded-word in a display-name must not contain certain
+ // characters like quotes or parentheses (see RFC 2047 section 5.3).
+ // When this is the case, encode the name using base64 encoding.
+ if strings.ContainsAny(a.Name, "\"#$%&'(),.:;<>@[]^`{|}~") {
+ return mime.BEncoding.Encode("utf-8", a.Name) + " " + s
+ }
return mime.QEncoding.Encode("utf-8", a.Name) + " " + s
}
@@ -386,10 +392,9 @@ func (p *addrParser) consumePhrase() (phrase string, err error) {
// We actually parse dot-atom here to be more permissive
// than what RFC 5322 specifies.
word, err = p.consumeAtom(true, true)
- }
-
- if err == nil {
- word, err = p.decodeRFC2047Word(word)
+ if err == nil {
+ word, err = p.decodeRFC2047Word(word)
+ }
}
if err != nil {
@@ -442,17 +447,25 @@ Loop:
return string(qsb), nil
}
+var errNonASCII = errors.New("mail: unencoded non-ASCII text in address")
+
// consumeAtom parses an RFC 5322 atom at the start of p.
// If dot is true, consumeAtom parses an RFC 5322 dot-atom instead.
// If permissive is true, consumeAtom will not fail on
// leading/trailing/double dots in the atom (see golang.org/issue/4938).
func (p *addrParser) consumeAtom(dot bool, permissive bool) (atom string, err error) {
- if !isAtext(p.peek(), false) {
+ if c := p.peek(); !isAtext(c, false) {
+ if c > 127 {
+ return "", errNonASCII
+ }
return "", errors.New("mail: invalid string")
}
i := 1
for ; i < p.len() && isAtext(p.s[i], dot); i++ {
}
+ if i < p.len() && p.s[i] > 127 {
+ return "", errNonASCII
+ }
atom, p.s = string(p.s[:i]), p.s[i:]
if !permissive {
if strings.HasPrefix(atom, ".") {
diff --git a/libgo/go/net/mail/message_test.go b/libgo/go/net/mail/message_test.go
index 1b422743f95..4e718e26367 100644
--- a/libgo/go/net/mail/message_test.go
+++ b/libgo/go/net/mail/message_test.go
@@ -127,6 +127,14 @@ func TestAddressParsingError(t *testing.T) {
}
}
+func TestAddressParsingErrorUnquotedNonASCII(t *testing.T) {
+ const txt = "µ <micro@example.net>"
+ _, err := ParseAddress(txt)
+ if err == nil || !strings.Contains(err.Error(), "unencoded non-ASCII text in address") {
+ t.Errorf(`mail.ParseAddress(%q) err: %q, want ".*unencoded non-ASCII text in address.*"`, txt, err)
+ }
+}
+
func TestAddressParsing(t *testing.T) {
tests := []struct {
addrsStr string
@@ -449,7 +457,7 @@ func TestAddressParser(t *testing.T) {
}
}
-func TestAddressFormatting(t *testing.T) {
+func TestAddressString(t *testing.T) {
tests := []struct {
addr *Address
exp string
@@ -491,11 +499,40 @@ func TestAddressFormatting(t *testing.T) {
&Address{Name: "Rob", Address: "@"},
`"Rob" <@>`,
},
+ {
+ &Address{Name: "Böb, Jacöb", Address: "bob@example.com"},
+ `=?utf-8?b?QsO2YiwgSmFjw7Zi?= <bob@example.com>`,
+ },
+ {
+ &Address{Name: "=??Q?x?=", Address: "hello@world.com"},
+ `"=??Q?x?=" <hello@world.com>`,
+ },
+ {
+ &Address{Name: "=?hello", Address: "hello@world.com"},
+ `"=?hello" <hello@world.com>`,
+ },
+ {
+ &Address{Name: "world?=", Address: "hello@world.com"},
+ `"world?=" <hello@world.com>`,
+ },
}
for _, test := range tests {
s := test.addr.String()
if s != test.exp {
t.Errorf("Address%+v.String() = %v, want %v", *test.addr, s, test.exp)
+ continue
+ }
+
+ // Check round-trip.
+ if test.addr.Address != "" && test.addr.Address != "@" {
+ a, err := ParseAddress(test.exp)
+ if err != nil {
+ t.Errorf("ParseAddress(%#q): %v", test.exp, err)
+ continue
+ }
+ if a.Name != test.addr.Name || a.Address != test.addr.Address {
+ t.Errorf("ParseAddress(%#q) = %#v, want %#v", test.exp, a, test.addr)
+ }
}
}
}
@@ -586,3 +623,32 @@ func TestAddressParsingAndFormatting(t *testing.T) {
}
}
+
+func TestAddressFormattingAndParsing(t *testing.T) {
+ tests := []*Address{
+ {Name: "@lïce", Address: "alice@example.com"},
+ {Name: "Böb O'Connor", Address: "bob@example.com"},
+ {Name: "???", Address: "bob@example.com"},
+ {Name: "Böb ???", Address: "bob@example.com"},
+ {Name: "Böb (Jacöb)", Address: "bob@example.com"},
+ {Name: "à#$%&'(),.:;<>@[]^`{|}~'", Address: "bob@example.com"},
+ // https://golang.org/issue/11292
+ {Name: "\"\\\x1f,\"", Address: "0@0"},
+ // https://golang.org/issue/12782
+ {Name: "naé, mée", Address: "test.mail@gmail.com"},
+ }
+
+ for i, test := range tests {
+ parsed, err := ParseAddress(test.String())
+ if err != nil {
+ t.Errorf("test #%d: ParseAddr(%q) error: %v", i, test.String(), err)
+ continue
+ }
+ if parsed.Name != test.Name {
+ t.Errorf("test #%d: Parsed name = %q; want %q", i, parsed.Name, test.Name)
+ }
+ if parsed.Address != test.Address {
+ t.Errorf("test #%d: Parsed address = %q; want %q", i, parsed.Address, test.Address)
+ }
+ }
+}
diff --git a/libgo/go/net/net.go b/libgo/go/net/net.go
index 6e84c3a100e..d9d23fae8f6 100644
--- a/libgo/go/net/net.go
+++ b/libgo/go/net/net.go
@@ -345,7 +345,7 @@ var listenerBacklog = maxListenerBacklog()
// Multiple goroutines may invoke methods on a Listener simultaneously.
type Listener interface {
// Accept waits for and returns the next connection to the listener.
- Accept() (c Conn, err error)
+ Accept() (Conn, error)
// Close closes the listener.
// Any blocked Accept operations will be unblocked and return errors.
@@ -426,7 +426,16 @@ func (e *OpError) Error() string {
return s
}
-var noDeadline = time.Time{}
+var (
+ // aLongTimeAgo is a non-zero time, far in the past, used for
+ // immediate cancelation of dials.
+ aLongTimeAgo = time.Unix(233431200, 0)
+
+ // noDeadline and noCancel are just zero values for
+ // readability with functions taking too many parameters.
+ noDeadline = time.Time{}
+ noCancel = (chan struct{})(nil)
+)
type timeout interface {
Timeout() bool
@@ -520,10 +529,11 @@ var (
// DNSError represents a DNS lookup error.
type DNSError struct {
- Err string // description of the error
- Name string // name looked for
- Server string // server used
- IsTimeout bool // if true, timed out; not all timeouts set this
+ Err string // description of the error
+ Name string // name looked for
+ Server string // server used
+ IsTimeout bool // if true, timed out; not all timeouts set this
+ IsTemporary bool // if true, error is temporary; not all errors set this
}
func (e *DNSError) Error() string {
@@ -546,7 +556,7 @@ func (e *DNSError) Timeout() bool { return e.IsTimeout }
// Temporary reports whether the DNS error is known to be temporary.
// This is not always known; a DNS lookup may fail due to a temporary
// error and return a DNSError for which Temporary returns false.
-func (e *DNSError) Temporary() bool { return e.IsTimeout }
+func (e *DNSError) Temporary() bool { return e.IsTimeout || e.IsTemporary }
type writerOnly struct {
io.Writer
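
The DNSError change above means Temporary no longer depends on IsTimeout alone; a caller can mark a failure as retryable explicitly. A small sketch, which requires a net package with the IsTemporary field:

    package main

    import (
        "fmt"
        "net"
    )

    func main() {
        err := &net.DNSError{Err: "server misbehaving", Name: "example.com", IsTemporary: true}
        fmt.Println(err.Temporary(), err.Timeout()) // true false
    }
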
diff --git a/libgo/go/net/net_test.go b/libgo/go/net/net_test.go
index 3907ce4aa56..6dcfc2190e0 100644
--- a/libgo/go/net/net_test.go
+++ b/libgo/go/net/net_test.go
@@ -208,7 +208,6 @@ func TestListenerClose(t *testing.T) {
case "unix", "unixpacket":
defer os.Remove(ln.Addr().String())
}
- defer ln.Close()
if err := ln.Close(); err != nil {
if perr := parseCloseError(err); perr != nil {
@@ -221,6 +220,14 @@ func TestListenerClose(t *testing.T) {
c.Close()
t.Fatal("should fail")
}
+
+ if network == "tcp" {
+ cc, err := Dial("tcp", ln.Addr().String())
+ if err == nil {
+ t.Error("Dial to closed TCP listener succeeeded.")
+ cc.Close()
+ }
+ }
}
}
@@ -254,3 +261,26 @@ func TestPacketConnClose(t *testing.T) {
}
}
}
+
+// nacl was previously failing to reuse an address.
+func TestListenCloseListen(t *testing.T) {
+ const maxTries = 10
+ for tries := 0; tries < maxTries; tries++ {
+ ln, err := newLocalListener("tcp")
+ if err != nil {
+ t.Fatal(err)
+ }
+ addr := ln.Addr().String()
+ if err := ln.Close(); err != nil {
+ t.Fatal(err)
+ }
+ ln, err = Listen("tcp", addr)
+ if err == nil {
+ // Success. nacl couldn't do this before.
+ ln.Close()
+ return
+ }
+ t.Errorf("failed on try %d/%d: %v", tries+1, maxTries, err)
+ }
+ t.Fatalf("failed to listen/close/listen on same address after %d tries", maxTries)
+}
diff --git a/libgo/go/net/non_unix_test.go b/libgo/go/net/non_unix_test.go
index eddca562f98..db3427e7cb4 100644
--- a/libgo/go/net/non_unix_test.go
+++ b/libgo/go/net/non_unix_test.go
@@ -6,6 +6,17 @@
package net
+import "runtime"
+
+// See unix_test.go for what these (don't) do.
+func forceGoDNS() func() {
+ switch runtime.GOOS {
+ case "plan9", "windows":
+ return func() {}
+ default:
+ return nil
+ }
+}
+
// See unix_test.go for what these (don't) do.
-func forceGoDNS() func() { return func() {} }
-func forceCgoDNS() bool { return false }
+func forceCgoDNS() func() { return nil }
diff --git a/libgo/go/net/parse.go b/libgo/go/net/parse.go
index c72e1c2eaf0..80836a108c9 100644
--- a/libgo/go/net/parse.go
+++ b/libgo/go/net/parse.go
@@ -10,6 +10,8 @@ package net
import (
"io"
"os"
+ "time"
+ _ "unsafe" // For go:linkname
)
type file struct {
@@ -70,15 +72,21 @@ func open(name string) (*file, error) {
return &file{fd, make([]byte, 0, os.Getpagesize()), false}, nil
}
-func byteIndex(s string, c byte) int {
- for i := 0; i < len(s); i++ {
- if s[i] == c {
- return i
- }
+func stat(name string) (mtime time.Time, size int64, err error) {
+ st, err := os.Stat(name)
+ if err != nil {
+ return time.Time{}, 0, err
}
- return -1
+ return st.ModTime(), st.Size(), nil
}
+// byteIndex is strings.IndexByte. It returns the index of the
+// first instance of c in s, or -1 if c is not present in s.
+// strings.IndexByte is implemented in runtime/asm_$GOARCH.s
+//go:linkname byteIndex strings.IndexByte
+//extern strings.IndexByte
+func byteIndex(s string, c byte) int
+
// Count occurrences in s of any bytes in t.
func countAnyByte(s string, t string) int {
n := 0
@@ -120,15 +128,27 @@ const big = 0xFFFFFF
// Returns number, new offset, success.
func dtoi(s string, i0 int) (n int, i int, ok bool) {
n = 0
+ neg := false
+ if len(s) > 0 && s[0] == '-' {
+ neg = true
+ s = s[1:]
+ }
for i = i0; i < len(s) && '0' <= s[i] && s[i] <= '9'; i++ {
n = n*10 + int(s[i]-'0')
if n >= big {
- return 0, i, false
+ if neg {
+ return -big, i + 1, false
+ }
+ return big, i, false
}
}
if i == i0 {
return 0, i, false
}
+ if neg {
+ n = -n
+ i++
+ }
return n, i, true
}
@@ -314,14 +334,10 @@ func foreachField(x []byte, fn func(field []byte) error) error {
// bytesIndexByte is bytes.IndexByte. It returns the index of the
// first instance of c in s, or -1 if c is not present in s.
-func bytesIndexByte(s []byte, c byte) int {
- for i, b := range s {
- if b == c {
- return i
- }
- }
- return -1
-}
+// bytes.IndexByte is implemented in runtime/asm_$GOARCH.s
+//go:linkname bytesIndexByte bytes.IndexByte
+//extern bytes.IndexByte
+func bytesIndexByte(s []byte, c byte) int
// stringsHasSuffix is strings.HasSuffix. It reports whether s ends in
// suffix.
diff --git a/libgo/go/net/parse_test.go b/libgo/go/net/parse_test.go
index 0f048fcea0b..fec92009465 100644
--- a/libgo/go/net/parse_test.go
+++ b/libgo/go/net/parse_test.go
@@ -77,3 +77,25 @@ func TestGoDebugString(t *testing.T) {
}
}
}
+
+func TestDtoi(t *testing.T) {
+ for _, tt := range []struct {
+ in string
+ out int
+ off int
+ ok bool
+ }{
+ {"", 0, 0, false},
+
+ {"-123456789", -big, 9, false},
+ {"-1", -1, 2, true},
+ {"0", 0, 1, true},
+ {"65536", 65536, 5, true},
+ {"123456789", big, 8, false},
+ } {
+ n, i, ok := dtoi(tt.in, 0)
+ if n != tt.out || i != tt.off || ok != tt.ok {
+ t.Errorf("got %d, %d, %v; want %d, %d, %v", n, i, ok, tt.out, tt.off, tt.ok)
+ }
+ }
+}
diff --git a/libgo/go/net/platform_test.go b/libgo/go/net/platform_test.go
index d6248520f33..76c53138cdd 100644
--- a/libgo/go/net/platform_test.go
+++ b/libgo/go/net/platform_test.go
@@ -32,7 +32,7 @@ func testableNetwork(network string) bool {
}
case "unix", "unixgram":
switch runtime.GOOS {
- case "nacl", "plan9", "windows":
+ case "android", "nacl", "plan9", "windows":
return false
}
// iOS does not support unix, unixgram.
@@ -134,7 +134,7 @@ func testableListenArgs(network, address, client string) bool {
// Test functionality of IPv4 communication using AF_INET6
// sockets.
- if !supportsIPv4map && (network == "tcp" || network == "udp" || network == "ip") && wildcard {
+ if !supportsIPv4map && supportsIPv4 && (network == "tcp" || network == "udp" || network == "ip") && wildcard {
// At this point, we prefer IPv4 when ip is nil.
// See favoriteAddrFamily for further information.
if ip.To16() != nil && ip.To4() == nil && cip.To4() != nil { // a pair of IPv6 server and IPv4 client
diff --git a/libgo/go/net/port.go b/libgo/go/net/port.go
deleted file mode 100644
index a2a538789e1..00000000000
--- a/libgo/go/net/port.go
+++ /dev/null
@@ -1,24 +0,0 @@
-// Copyright 2012 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// Network service port manipulations
-
-package net
-
-// parsePort parses port as a network service port number for both
-// TCP and UDP.
-func parsePort(net, port string) (int, error) {
- p, i, ok := dtoi(port, 0)
- if !ok || i != len(port) {
- var err error
- p, err = LookupPort(net, port)
- if err != nil {
- return 0, err
- }
- }
- if p < 0 || p > 0xFFFF {
- return 0, &AddrError{Err: "invalid port", Addr: port}
- }
- return p, nil
-}
diff --git a/libgo/go/net/port_test.go b/libgo/go/net/port_test.go
deleted file mode 100644
index 258a5bda48f..00000000000
--- a/libgo/go/net/port_test.go
+++ /dev/null
@@ -1,57 +0,0 @@
-// Copyright 2009 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-package net
-
-import (
- "runtime"
- "testing"
-)
-
-var portTests = []struct {
- network string
- name string
- port int
- ok bool
-}{
- {"tcp", "echo", 7, true},
- {"tcp", "discard", 9, true},
- {"tcp", "systat", 11, true},
- {"tcp", "daytime", 13, true},
- {"tcp", "chargen", 19, true},
- {"tcp", "ftp-data", 20, true},
- {"tcp", "ftp", 21, true},
- {"tcp", "telnet", 23, true},
- {"tcp", "smtp", 25, true},
- {"tcp", "time", 37, true},
- {"tcp", "domain", 53, true},
- {"tcp", "finger", 79, true},
- {"tcp", "42", 42, true},
-
- {"udp", "echo", 7, true},
- {"udp", "tftp", 69, true},
- {"udp", "bootpc", 68, true},
- {"udp", "bootps", 67, true},
- {"udp", "domain", 53, true},
- {"udp", "ntp", 123, true},
- {"udp", "snmp", 161, true},
- {"udp", "syslog", 514, true},
- {"udp", "42", 42, true},
-
- {"--badnet--", "zzz", 0, false},
- {"tcp", "--badport--", 0, false},
-}
-
-func TestLookupPort(t *testing.T) {
- switch runtime.GOOS {
- case "nacl":
- t.Skipf("not supported on %s", runtime.GOOS)
- }
-
- for _, tt := range portTests {
- if port, err := LookupPort(tt.network, tt.name); port != tt.port || (err == nil) != tt.ok {
- t.Errorf("LookupPort(%q, %q) = %v, %v; want %v", tt.network, tt.name, port, err, tt.port)
- }
- }
-}
diff --git a/libgo/go/net/race.go b/libgo/go/net/race.go
deleted file mode 100644
index 2f02a6c226b..00000000000
--- a/libgo/go/net/race.go
+++ /dev/null
@@ -1,31 +0,0 @@
-// Copyright 2013 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build race
-// +build windows
-
-package net
-
-import (
- "runtime"
- "unsafe"
-)
-
-const raceenabled = true
-
-func raceAcquire(addr unsafe.Pointer) {
- runtime.RaceAcquire(addr)
-}
-
-func raceReleaseMerge(addr unsafe.Pointer) {
- runtime.RaceReleaseMerge(addr)
-}
-
-func raceReadRange(addr unsafe.Pointer, len int) {
- runtime.RaceReadRange(addr, len)
-}
-
-func raceWriteRange(addr unsafe.Pointer, len int) {
- runtime.RaceWriteRange(addr, len)
-}
diff --git a/libgo/go/net/race0.go b/libgo/go/net/race0.go
deleted file mode 100644
index f5042977931..00000000000
--- a/libgo/go/net/race0.go
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright 2013 The Go Authors. All rights reserved.
-// Use of this source code is governed by a BSD-style
-// license that can be found in the LICENSE file.
-
-// +build !race
-// +build windows
-
-package net
-
-import (
- "unsafe"
-)
-
-const raceenabled = false
-
-func raceAcquire(addr unsafe.Pointer) {
-}
-
-func raceReleaseMerge(addr unsafe.Pointer) {
-}
-
-func raceReadRange(addr unsafe.Pointer, len int) {
-}
-
-func raceWriteRange(addr unsafe.Pointer, len int) {
-}
diff --git a/libgo/go/net/rpc/server.go b/libgo/go/net/rpc/server.go
index 6e6e8819174..c4d4479958a 100644
--- a/libgo/go/net/rpc/server.go
+++ b/libgo/go/net/rpc/server.go
@@ -611,13 +611,15 @@ func (server *Server) readRequestHeader(codec ServerCodec) (service *service, mt
}
// Accept accepts connections on the listener and serves requests
-// for each incoming connection. Accept blocks; the caller typically
-// invokes it in a go statement.
+// for each incoming connection. Accept blocks until the listener
+// returns a non-nil error. The caller typically invokes Accept in a
+// go statement.
func (server *Server) Accept(lis net.Listener) {
for {
conn, err := lis.Accept()
if err != nil {
- log.Fatal("rpc.Serve: accept:", err.Error()) // TODO(r): exit?
+ log.Print("rpc.Serve: accept:", err.Error())
+ return
}
go server.ServeConn(conn)
}
diff --git a/libgo/go/net/rpc/server_test.go b/libgo/go/net/rpc/server_test.go
index 0dc4ddc2de0..8871c88133c 100644
--- a/libgo/go/net/rpc/server_test.go
+++ b/libgo/go/net/rpc/server_test.go
@@ -74,6 +74,17 @@ func (t *Arith) Error(args *Args, reply *Reply) error {
panic("ERROR")
}
+type hidden int
+
+func (t *hidden) Exported(args Args, reply *Reply) error {
+ reply.C = args.A + args.B
+ return nil
+}
+
+type Embed struct {
+ hidden
+}
+
func listenTCP() (net.Listener, string) {
l, e := net.Listen("tcp", "127.0.0.1:0") // any available address
if e != nil {
@@ -84,6 +95,7 @@ func listenTCP() (net.Listener, string) {
func startServer() {
Register(new(Arith))
+ Register(new(Embed))
RegisterName("net.rpc.Arith", new(Arith))
var l net.Listener
@@ -98,6 +110,7 @@ func startServer() {
func startNewServer() {
newServer = NewServer()
newServer.Register(new(Arith))
+ newServer.Register(new(Embed))
newServer.RegisterName("net.rpc.Arith", new(Arith))
newServer.RegisterName("newServer.Arith", new(Arith))
@@ -142,6 +155,17 @@ func testRPC(t *testing.T, addr string) {
t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
}
+ // Methods exported from unexported embedded structs
+ args = &Args{7, 0}
+ reply = new(Reply)
+ err = client.Call("Embed.Exported", args, reply)
+ if err != nil {
+ t.Errorf("Add: expected no error but got string %q", err.Error())
+ }
+ if reply.C != args.A+args.B {
+ t.Errorf("Add: expected %d got %d", reply.C, args.A+args.B)
+ }
+
// Nonexistent method
args = &Args{7, 0}
reply = new(Reply)
@@ -593,6 +617,19 @@ func TestErrorAfterClientClose(t *testing.T) {
}
}
+// Tests the fix to issue 11221. Without the fix, this loops forever or crashes.
+func TestAcceptExitAfterListenerClose(t *testing.T) {
+ newServer = NewServer()
+ newServer.Register(new(Arith))
+ newServer.RegisterName("net.rpc.Arith", new(Arith))
+ newServer.RegisterName("newServer.Arith", new(Arith))
+
+ var l net.Listener
+ l, newServerAddr = listenTCP()
+ l.Close()
+ newServer.Accept(l)
+}
+
func benchmarkEndToEnd(dial func() (*Client, error), b *testing.B) {
once.Do(startServer)
client, err := dial()
diff --git a/libgo/go/net/sendfile_solaris.go b/libgo/go/net/sendfile_solaris.go
index 0966575696b..f6833813fd0 100644
--- a/libgo/go/net/sendfile_solaris.go
+++ b/libgo/go/net/sendfile_solaris.go
@@ -26,6 +26,8 @@ const maxSendfileSize int = 4 << 20
//
// if handled == false, sendFile performed no work.
func sendFile(c *netFD, r io.Reader) (written int64, err error, handled bool) {
+ return // Solaris sendfile is disabled until Issue 13892 is understood and fixed
+
// Solaris uses 0 as the "until EOF" value. If you pass in more bytes than the
// file contains, it will loop back to the beginning ad nauseam until it's sent
// exactly the number of bytes told to. As such, we need to know exactly how many
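
The early return is safe because of the handled result documented just above: when handled is false the caller falls back to a generic copy, so disabling Solaris sendfile only gives up the zero-copy optimization. A self-contained sketch of that contract (stub names; the real caller is unexported):

	package main

	import (
		"bytes"
		"fmt"
		"io"
		"strings"
	)

	// sendFileStub follows the contract described above: handled == false means it
	// did no work, which is exactly what the Solaris early return now produces.
	func sendFileStub(dst io.Writer, src io.Reader) (written int64, err error, handled bool) {
		return 0, nil, false
	}

	func copyTo(dst io.Writer, src io.Reader) (int64, error) {
		if n, err, handled := sendFileStub(dst, src); handled {
			return n, err
		}
		return io.Copy(dst, src) // generic fallback when sendFile declines the work
	}

	func main() {
		var buf bytes.Buffer
		n, err := copyTo(&buf, strings.NewReader("hello"))
		fmt.Println(n, err, buf.String())
	}
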
diff --git a/libgo/go/net/server_test.go b/libgo/go/net/server_test.go
index fe0006b11fb..2e998e23a8a 100644
--- a/libgo/go/net/server_test.go
+++ b/libgo/go/net/server_test.go
@@ -55,7 +55,7 @@ func TestTCPServer(t *testing.T) {
for i, tt := range tcpServerTests {
if !testableListenArgs(tt.snet, tt.saddr, tt.taddr) {
- t.Logf("skipping %s test", tt.snet+" "+tt.saddr+"->"+tt.taddr)
+ t.Logf("skipping %s test", tt.snet+" "+tt.saddr+"<-"+tt.taddr)
continue
}
@@ -251,7 +251,7 @@ var udpServerTests = []struct {
func TestUDPServer(t *testing.T) {
for i, tt := range udpServerTests {
if !testableListenArgs(tt.snet, tt.saddr, tt.taddr) {
- t.Logf("skipping %s test", tt.snet+" "+tt.saddr+"->"+tt.taddr)
+ t.Logf("skipping %s test", tt.snet+" "+tt.saddr+"<-"+tt.taddr)
continue
}
@@ -329,7 +329,7 @@ var unixgramServerTests = []struct {
func TestUnixgramServer(t *testing.T) {
for i, tt := range unixgramServerTests {
if !testableListenArgs("unixgram", tt.saddr, "") {
- t.Logf("skipping %s test", "unixgram "+tt.saddr+"->"+tt.caddr)
+ t.Logf("skipping %s test", "unixgram "+tt.saddr+"<-"+tt.caddr)
continue
}
diff --git a/libgo/go/net/sock_posix.go b/libgo/go/net/sock_posix.go
index 4d2cfde3f1c..46767215672 100644
--- a/libgo/go/net/sock_posix.go
+++ b/libgo/go/net/sock_posix.go
@@ -34,7 +34,7 @@ type sockaddr interface {
// socket returns a network file descriptor that is ready for
// asynchronous I/O using the network poller.
-func socket(net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, deadline time.Time) (fd *netFD, err error) {
+func socket(net string, family, sotype, proto int, ipv6only bool, laddr, raddr sockaddr, deadline time.Time, cancel <-chan struct{}) (fd *netFD, err error) {
s, err := sysSocket(family, sotype, proto)
if err != nil {
return nil, err
@@ -86,7 +86,7 @@ func socket(net string, family, sotype, proto int, ipv6only bool, laddr, raddr s
return fd, nil
}
}
- if err := fd.dial(laddr, raddr, deadline); err != nil {
+ if err := fd.dial(laddr, raddr, deadline, cancel); err != nil {
fd.Close()
return nil, err
}
@@ -117,7 +117,7 @@ func (fd *netFD) addrFunc() func(syscall.Sockaddr) Addr {
return func(syscall.Sockaddr) Addr { return nil }
}
-func (fd *netFD) dial(laddr, raddr sockaddr, deadline time.Time) error {
+func (fd *netFD) dial(laddr, raddr sockaddr, deadline time.Time, cancel <-chan struct{}) error {
var err error
var lsa syscall.Sockaddr
if laddr != nil {
@@ -134,7 +134,7 @@ func (fd *netFD) dial(laddr, raddr sockaddr, deadline time.Time) error {
if rsa, err = raddr.sockaddr(fd.family); err != nil {
return err
}
- if err := fd.connect(lsa, rsa, deadline); err != nil {
+ if err := fd.connect(lsa, rsa, deadline, cancel); err != nil {
return err
}
fd.isConnected = true
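
The cancel channel threaded through socket and dial here is the plumbing behind the Dialer.Cancel field added in Go 1.6. A minimal usage sketch (the unroutable address is only illustrative):

	package main

	import (
		"fmt"
		"net"
		"time"
	)

	func main() {
		cancel := make(chan struct{})
		d := net.Dialer{Cancel: cancel} // new in Go 1.6; closing cancel aborts the dial

		go func() {
			time.Sleep(100 * time.Millisecond)
			close(cancel)
		}()

		conn, err := d.Dial("tcp", "10.255.255.1:80")
		if err != nil {
			fmt.Println("dial aborted:", err)
			return
		}
		conn.Close()
	}
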
diff --git a/libgo/go/net/tcp_test.go b/libgo/go/net/tcp_test.go
index 25ae9b9d771..9da6f6cdd05 100644
--- a/libgo/go/net/tcp_test.go
+++ b/libgo/go/net/tcp_test.go
@@ -540,9 +540,12 @@ func TestTCPStress(t *testing.T) {
if err != nil {
t.Fatal(err)
}
- defer ln.Close()
+ done := make(chan bool)
// Acceptor.
go func() {
+ defer func() {
+ done <- true
+ }()
for {
c, err := ln.Accept()
if err != nil {
@@ -560,7 +563,6 @@ func TestTCPStress(t *testing.T) {
}(c)
}
}()
- done := make(chan bool)
for i := 0; i < conns; i++ {
// Client connection.
go func() {
@@ -584,4 +586,6 @@ func TestTCPStress(t *testing.T) {
for i := 0; i < conns; i++ {
<-done
}
+ ln.Close()
+ <-done
}
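
The reshuffled done channel above is the usual shutdown pattern for an acceptor goroutine: close the listener so Accept fails, then wait on the channel so the goroutine cannot outlive the test. A stand-alone sketch of the same shape:

	package main

	import (
		"log"
		"net"
	)

	func main() {
		ln, err := net.Listen("tcp", "127.0.0.1:0")
		if err != nil {
			log.Fatal(err)
		}
		done := make(chan bool)
		go func() {
			defer func() { done <- true }()
			for {
				c, err := ln.Accept()
				if err != nil {
					return // ln.Close() below makes Accept fail and ends the loop
				}
				c.Close()
			}
		}()
		// ... clients come and go ...
		ln.Close() // unblock Accept
		<-done     // wait for the acceptor goroutine instead of leaking it
	}
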
diff --git a/libgo/go/net/tcpsock_plan9.go b/libgo/go/net/tcpsock_plan9.go
index 9f23703abb4..afccbfe8a74 100644
--- a/libgo/go/net/tcpsock_plan9.go
+++ b/libgo/go/net/tcpsock_plan9.go
@@ -107,13 +107,14 @@ func (c *TCPConn) SetNoDelay(noDelay bool) error {
// which must be "tcp", "tcp4", or "tcp6". If laddr is not nil, it is
// used as the local address for the connection.
func DialTCP(net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
- return dialTCP(net, laddr, raddr, noDeadline)
+ return dialTCP(net, laddr, raddr, noDeadline, noCancel)
}
-func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time) (*TCPConn, error) {
+func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) {
if !deadline.IsZero() {
panic("net.dialTCP: deadline not implemented on Plan 9")
}
+ // TODO(bradfitz,0intro): also use the cancel channel.
switch net {
case "tcp", "tcp4", "tcp6":
default:
diff --git a/libgo/go/net/tcpsock_posix.go b/libgo/go/net/tcpsock_posix.go
index 7e49b769e1c..0e12d543002 100644
--- a/libgo/go/net/tcpsock_posix.go
+++ b/libgo/go/net/tcpsock_posix.go
@@ -164,11 +164,11 @@ func DialTCP(net string, laddr, raddr *TCPAddr) (*TCPConn, error) {
if raddr == nil {
return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: nil, Err: errMissingAddress}
}
- return dialTCP(net, laddr, raddr, noDeadline)
+ return dialTCP(net, laddr, raddr, noDeadline, noCancel)
}
-func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time) (*TCPConn, error) {
- fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_STREAM, 0, "dial")
+func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time, cancel <-chan struct{}) (*TCPConn, error) {
+ fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_STREAM, 0, "dial", cancel)
// TCP has a rarely used mechanism called a 'simultaneous connection' in
// which Dial("tcp", addr1, addr2) run on the machine at addr1 can
@@ -198,7 +198,7 @@ func dialTCP(net string, laddr, raddr *TCPAddr, deadline time.Time) (*TCPConn, e
if err == nil {
fd.Close()
}
- fd, err = internetSocket(net, laddr, raddr, deadline, syscall.SOCK_STREAM, 0, "dial")
+ fd, err = internetSocket(net, laddr, raddr, deadline, syscall.SOCK_STREAM, 0, "dial", cancel)
}
if err != nil {
@@ -326,7 +326,7 @@ func ListenTCP(net string, laddr *TCPAddr) (*TCPListener, error) {
if laddr == nil {
laddr = &TCPAddr{}
}
- fd, err := internetSocket(net, laddr, nil, noDeadline, syscall.SOCK_STREAM, 0, "listen")
+ fd, err := internetSocket(net, laddr, nil, noDeadline, syscall.SOCK_STREAM, 0, "listen", noCancel)
if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr, Err: err}
}
diff --git a/libgo/go/net/tcpsockopt_plan9.go b/libgo/go/net/tcpsockopt_plan9.go
index 9abe186cecb..157282abd34 100644
--- a/libgo/go/net/tcpsockopt_plan9.go
+++ b/libgo/go/net/tcpsockopt_plan9.go
@@ -7,13 +7,12 @@
package net
import (
- "strconv"
"time"
)
// Set keep alive period.
func setKeepAlivePeriod(fd *netFD, d time.Duration) error {
- cmd := "keepalive " + strconv.Itoa(int(d/time.Millisecond))
+ cmd := "keepalive " + itoa(int(d/time.Millisecond))
_, e := fd.ctl.WriteAt([]byte(cmd), 0)
return e
}
diff --git a/libgo/go/net/testdata/case-hosts b/libgo/go/net/testdata/case-hosts
new file mode 100644
index 00000000000..1f30df1179d
--- /dev/null
+++ b/libgo/go/net/testdata/case-hosts
@@ -0,0 +1,2 @@
+127.0.0.1 PreserveMe PreserveMe.local
+::1 PreserveMe PreserveMe.local
diff --git a/libgo/go/net/testdata/hosts b/libgo/go/net/testdata/hosts
index b601763898b..3ed83ff8a83 100644
--- a/libgo/go/net/testdata/hosts
+++ b/libgo/go/net/testdata/hosts
@@ -5,8 +5,7 @@
127.1.1.1 thor
# aliases
127.1.1.2 ullr ullrhost
+fe80::1%lo0 localhost
# Bogus entries that must be ignored.
123.123.123 loki
321.321.321.321
-# TODO(yvesj): Should we be able to parse this? From a Darwin system.
-fe80::1%lo0 localhost
diff --git a/libgo/go/net/textproto/reader.go b/libgo/go/net/textproto/reader.go
index 91303fec612..91bbb573043 100644
--- a/libgo/go/net/textproto/reader.go
+++ b/libgo/go/net/textproto/reader.go
@@ -150,7 +150,7 @@ func (r *Reader) readContinuedLineSlice() ([]byte, error) {
break
}
r.buf = append(r.buf, ' ')
- r.buf = append(r.buf, line...)
+ r.buf = append(r.buf, trim(line)...)
}
return r.buf, nil
}
@@ -237,7 +237,12 @@ func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err err
// separated by a newline (\n).
//
// See page 36 of RFC 959 (http://www.ietf.org/rfc/rfc959.txt) for
-// details.
+// details of another form of response accepted:
+//
+// code-message line 1
+// message line 2
+// ...
+// code message line n
//
// If the prefix of the status does not match the digits in expectCode,
// ReadResponse returns with err set to &Error{code, message}.
@@ -248,7 +253,8 @@ func (r *Reader) ReadCodeLine(expectCode int) (code int, message string, err err
//
func (r *Reader) ReadResponse(expectCode int) (code int, message string, err error) {
code, continued, message, err := r.readCodeLine(expectCode)
- for err == nil && continued {
+ multi := continued
+ for continued {
line, err := r.ReadLine()
if err != nil {
return 0, "", err
@@ -256,7 +262,7 @@ func (r *Reader) ReadResponse(expectCode int) (code int, message string, err err
var code2 int
var moreMessage string
- code2, continued, moreMessage, err = parseCodeLine(line, expectCode)
+ code2, continued, moreMessage, err = parseCodeLine(line, 0)
if err != nil || code2 != code {
message += "\n" + strings.TrimRight(line, "\r\n")
continued = true
@@ -264,6 +270,10 @@ func (r *Reader) ReadResponse(expectCode int) (code int, message string, err err
}
message += "\n" + moreMessage
}
+ if err != nil && multi && message != "" {
+ // replace one line error message with all lines (full message)
+ err = &Error{code, message}
+ }
return
}
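
To illustrate the ReadResponse changes, a small sketch driving it with an FTP/SMTP-style multi-line reply: with the parseCodeLine and error changes above, a mismatched expectCode now yields an error carrying the full multi-line message rather than only the first line.

	package main

	import (
		"bufio"
		"fmt"
		"net/textproto"
		"strings"
	)

	func main() {
		// "550-" marks continuation lines, "550 " marks the final line.
		raw := "550-first line\r\n" +
			"550-second line\r\n" +
			"550 last line\r\n"
		r := textproto.NewReader(bufio.NewReader(strings.NewReader(raw)))

		// Expecting 2xx, so err is set; it now contains all three lines.
		code, msg, err := r.ReadResponse(2)
		fmt.Println(code, err)
		fmt.Println(msg)
	}
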
diff --git a/libgo/go/net/textproto/reader_test.go b/libgo/go/net/textproto/reader_test.go
index 91550f74934..93d7939bb5f 100644
--- a/libgo/go/net/textproto/reader_test.go
+++ b/libgo/go/net/textproto/reader_test.go
@@ -205,6 +205,32 @@ func TestReadMIMEHeaderNonCompliant(t *testing.T) {
}
}
+// Test that continued lines are properly trimmed. Issue 11204.
+func TestReadMIMEHeaderTrimContinued(t *testing.T) {
+ // In this header, \n and \r\n terminated lines are mixed on purpose.
+ // We expect each line to be trimmed (prefix and suffix) before being concatenated.
+ // Keep the spaces as they are.
+ r := reader("" + // for code formatting purpose.
+ "a:\n" +
+ " 0 \r\n" +
+ "b:1 \t\r\n" +
+ "c: 2\r\n" +
+ " 3\t\n" +
+ " \t 4 \r\n\n")
+ m, err := r.ReadMIMEHeader()
+ if err != nil {
+ t.Fatal(err)
+ }
+ want := MIMEHeader{
+ "A": {"0"},
+ "B": {"1"},
+ "C": {"2 3 4"},
+ }
+ if !reflect.DeepEqual(m, want) {
+ t.Fatalf("ReadMIMEHeader mismatch.\n got: %q\nwant: %q", m, want)
+ }
+}
+
type readResponseTest struct {
in string
inCode int
@@ -258,6 +284,35 @@ func TestRFC959Lines(t *testing.T) {
}
}
+// Test that multi-line errors are appropriately and fully read. Issue 10230.
+func TestReadMultiLineError(t *testing.T) {
+ r := reader("550-5.1.1 The email account that you tried to reach does not exist. Please try\n" +
+ "550-5.1.1 double-checking the recipient's email address for typos or\n" +
+ "550-5.1.1 unnecessary spaces. Learn more at\n" +
+ "Unexpected but legal text!\n" +
+ "550 5.1.1 https://support.google.com/mail/answer/6596 h20si25154304pfd.166 - gsmtp\n")
+
+ wantMsg := "5.1.1 The email account that you tried to reach does not exist. Please try\n" +
+ "5.1.1 double-checking the recipient's email address for typos or\n" +
+ "5.1.1 unnecessary spaces. Learn more at\n" +
+ "Unexpected but legal text!\n" +
+ "5.1.1 https://support.google.com/mail/answer/6596 h20si25154304pfd.166 - gsmtp"
+
+ code, msg, err := r.ReadResponse(250)
+ if err == nil {
+ t.Errorf("ReadResponse: no error, want error")
+ }
+ if code != 550 {
+ t.Errorf("ReadResponse: code=%d, want %d", code, 550)
+ }
+ if msg != wantMsg {
+ t.Errorf("ReadResponse: msg=%q, want %q", msg, wantMsg)
+ }
+ if err.Error() != "550 "+wantMsg {
+ t.Errorf("ReadResponse: error=%q, want %q", err.Error(), "550 "+wantMsg)
+ }
+}
+
func TestCommonHeaders(t *testing.T) {
for h := range commonHeader {
if h != CanonicalMIMEHeaderKey(h) {
diff --git a/libgo/go/net/timeout_test.go b/libgo/go/net/timeout_test.go
index ca94e24c816..98e3164fb98 100644
--- a/libgo/go/net/timeout_test.go
+++ b/libgo/go/net/timeout_test.go
@@ -33,6 +33,7 @@ var dialTimeoutTests = []struct {
}
func TestDialTimeout(t *testing.T) {
+ // Cannot use t.Parallel - modifies global hooks.
origTestHookDialChannel := testHookDialChannel
defer func() { testHookDialChannel = origTestHookDialChannel }()
defer sw.Set(socktest.FilterConnect, nil)
@@ -110,6 +111,8 @@ var acceptTimeoutTests = []struct {
}
func TestAcceptTimeout(t *testing.T) {
+ t.Parallel()
+
switch runtime.GOOS {
case "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
@@ -161,6 +164,8 @@ func TestAcceptTimeout(t *testing.T) {
}
func TestAcceptTimeoutMustReturn(t *testing.T) {
+ t.Parallel()
+
switch runtime.GOOS {
case "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
@@ -205,6 +210,8 @@ func TestAcceptTimeoutMustReturn(t *testing.T) {
}
func TestAcceptTimeoutMustNotReturn(t *testing.T) {
+ t.Parallel()
+
switch runtime.GOOS {
case "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
@@ -254,6 +261,8 @@ var readTimeoutTests = []struct {
}
func TestReadTimeout(t *testing.T) {
+ t.Parallel()
+
switch runtime.GOOS {
case "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
@@ -313,6 +322,8 @@ func TestReadTimeout(t *testing.T) {
}
func TestReadTimeoutMustNotReturn(t *testing.T) {
+ t.Parallel()
+
switch runtime.GOOS {
case "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
@@ -454,6 +465,8 @@ var writeTimeoutTests = []struct {
}
func TestWriteTimeout(t *testing.T) {
+ t.Parallel()
+
switch runtime.GOOS {
case "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
@@ -500,6 +513,8 @@ func TestWriteTimeout(t *testing.T) {
}
func TestWriteTimeoutMustNotReturn(t *testing.T) {
+ t.Parallel()
+
switch runtime.GOOS {
case "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
@@ -569,6 +584,8 @@ var writeToTimeoutTests = []struct {
}
func TestWriteToTimeout(t *testing.T) {
+ t.Parallel()
+
switch runtime.GOOS {
case "nacl", "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
@@ -620,6 +637,8 @@ func TestWriteToTimeout(t *testing.T) {
}
func TestReadTimeoutFluctuation(t *testing.T) {
+ t.Parallel()
+
switch runtime.GOOS {
case "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
@@ -656,6 +675,8 @@ func TestReadTimeoutFluctuation(t *testing.T) {
}
func TestReadFromTimeoutFluctuation(t *testing.T) {
+ t.Parallel()
+
switch runtime.GOOS {
case "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
@@ -692,6 +713,8 @@ func TestReadFromTimeoutFluctuation(t *testing.T) {
}
func TestWriteTimeoutFluctuation(t *testing.T) {
+ t.Parallel()
+
switch runtime.GOOS {
case "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
@@ -731,12 +754,27 @@ func TestWriteTimeoutFluctuation(t *testing.T) {
}
}
+func TestVariousDeadlines(t *testing.T) {
+ t.Parallel()
+ testVariousDeadlines(t)
+}
+
func TestVariousDeadlines1Proc(t *testing.T) {
- testVariousDeadlines(t, 1)
+ // Cannot use t.Parallel - modifies global GOMAXPROCS.
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1))
+ testVariousDeadlines(t)
}
func TestVariousDeadlines4Proc(t *testing.T) {
- testVariousDeadlines(t, 4)
+ // Cannot use t.Parallel - modifies global GOMAXPROCS.
+ if testing.Short() {
+ t.Skip("skipping in short mode")
+ }
+ defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(4))
+ testVariousDeadlines(t)
}
type neverEnding byte
@@ -748,14 +786,12 @@ func (b neverEnding) Read(p []byte) (int, error) {
return len(p), nil
}
-func testVariousDeadlines(t *testing.T, maxProcs int) {
+func testVariousDeadlines(t *testing.T) {
switch runtime.GOOS {
case "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
}
- defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(maxProcs))
-
type result struct {
n int64
err error
@@ -869,6 +905,8 @@ func testVariousDeadlines(t *testing.T, maxProcs int) {
// TestReadWriteProlongedTimeout tests concurrent deadline
// modification. Known to cause data races in the past.
func TestReadWriteProlongedTimeout(t *testing.T) {
+ t.Parallel()
+
switch runtime.GOOS {
case "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
@@ -947,6 +985,8 @@ func TestReadWriteProlongedTimeout(t *testing.T) {
}
func TestReadWriteDeadlineRace(t *testing.T) {
+ t.Parallel()
+
switch runtime.GOOS {
case "nacl", "plan9":
t.Skipf("not supported on %s", runtime.GOOS)
@@ -956,7 +996,6 @@ func TestReadWriteDeadlineRace(t *testing.T) {
if testing.Short() {
N = 50
}
- defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(4))
ln, err := newLocalListener("tcp")
if err != nil {
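
Most of these tests can now run in parallel because they only touch their own listeners and connections; the ones that tweak dial hooks or GOMAXPROCS stay serial. The distinction, in miniature (made-up test names):

	package example

	import (
		"runtime"
		"testing"
	)

	func TestIsolated(t *testing.T) {
		t.Parallel() // fine: no shared process-global state
	}

	func TestTweaksGOMAXPROCS(t *testing.T) {
		// Cannot use t.Parallel - restores a process-wide setting on exit,
		// which would race with any test running concurrently.
		defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(1))
	}
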
diff --git a/libgo/go/net/udpsock_posix.go b/libgo/go/net/udpsock_posix.go
index 61868c4b0cf..932c6ce713f 100644
--- a/libgo/go/net/udpsock_posix.go
+++ b/libgo/go/net/udpsock_posix.go
@@ -189,7 +189,7 @@ func DialUDP(net string, laddr, raddr *UDPAddr) (*UDPConn, error) {
}
func dialUDP(net string, laddr, raddr *UDPAddr, deadline time.Time) (*UDPConn, error) {
- fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_DGRAM, 0, "dial")
+ fd, err := internetSocket(net, laddr, raddr, deadline, syscall.SOCK_DGRAM, 0, "dial", noCancel)
if err != nil {
return nil, &OpError{Op: "dial", Net: net, Source: laddr.opAddr(), Addr: raddr.opAddr(), Err: err}
}
@@ -212,7 +212,7 @@ func ListenUDP(net string, laddr *UDPAddr) (*UDPConn, error) {
if laddr == nil {
laddr = &UDPAddr{}
}
- fd, err := internetSocket(net, laddr, nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen")
+ fd, err := internetSocket(net, laddr, nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen", noCancel)
if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr, Err: err}
}
@@ -239,7 +239,7 @@ func ListenMulticastUDP(network string, ifi *Interface, gaddr *UDPAddr) (*UDPCon
if gaddr == nil || gaddr.IP == nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr.opAddr(), Err: errMissingAddress}
}
- fd, err := internetSocket(network, gaddr, nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen")
+ fd, err := internetSocket(network, gaddr, nil, noDeadline, syscall.SOCK_DGRAM, 0, "listen", noCancel)
if err != nil {
return nil, &OpError{Op: "listen", Net: network, Source: nil, Addr: gaddr, Err: err}
}
diff --git a/libgo/go/net/unix_test.go b/libgo/go/net/unix_test.go
index 358ff310725..f0c583068ec 100644
--- a/libgo/go/net/unix_test.go
+++ b/libgo/go/net/unix_test.go
@@ -405,6 +405,42 @@ func TestUnixgramConnLocalAndRemoteNames(t *testing.T) {
}
}
+func TestUnixUnlink(t *testing.T) {
+ if !testableNetwork("unix") {
+ t.Skip("unix test")
+ }
+ name := testUnixAddr()
+ l, err := Listen("unix", name)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := os.Stat(name); err != nil {
+ t.Fatalf("cannot stat unix socket after ListenUnix: %v", err)
+ }
+ f, _ := l.(*UnixListener).File()
+ l1, err := FileListener(f)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if _, err := os.Stat(name); err != nil {
+ t.Fatalf("cannot stat unix socket after FileListener: %v", err)
+ }
+ if err := l1.Close(); err != nil {
+ t.Fatalf("closing file listener: %v", err)
+ }
+ if _, err := os.Stat(name); err != nil {
+ t.Fatalf("cannot stat unix socket after closing FileListener: %v", err)
+ }
+ f.Close()
+ if _, err := os.Stat(name); err != nil {
+ t.Fatalf("cannot stat unix socket after closing FileListener and fd: %v", err)
+ }
+ l.Close()
+ if _, err := os.Stat(name); err == nil {
+ t.Fatal("closing unix listener did not remove unix socket")
+ }
+}
+
// forceGoDNS forces the resolver configuration to use the pure Go resolver
// and returns a fixup function to restore the old settings.
func forceGoDNS() func() {
@@ -421,11 +457,17 @@ func forceGoDNS() func() {
}
// forceCgoDNS forces the resolver configuration to use the cgo resolver
-// and returns true to indicate that it did so.
-// (On non-Unix systems forceCgoDNS returns false.)
-func forceCgoDNS() bool {
+// and returns a fixup function to restore the old settings.
+// (On non-Unix systems forceCgoDNS returns nil.)
+func forceCgoDNS() func() {
c := systemConf()
+ oldGo := c.netGo
+ oldCgo := c.netCgo
+ fixup := func() {
+ c.netGo = oldGo
+ c.netCgo = oldCgo
+ }
c.netGo = false
c.netCgo = true
- return true
+ return fixup
}
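
forceCgoDNS now mirrors forceGoDNS: flip the resolver knobs and hand back a fixup closure so the caller can defer the restore. The same save-and-restore shape as a stand-alone sketch, with made-up names:

	package main

	import "fmt"

	var debugResolver bool // stand-in for a package-level knob like conf.netCgo

	// forceDebugResolver flips the knob and returns a fixup function that restores
	// the previous value, in the style of forceGoDNS/forceCgoDNS above.
	func forceDebugResolver() func() {
		old := debugResolver
		debugResolver = true
		return func() { debugResolver = old }
	}

	func main() {
		defer forceDebugResolver()() // set now, restore when main returns
		fmt.Println("debugResolver =", debugResolver)
	}
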
diff --git a/libgo/go/net/unixsock_posix.go b/libgo/go/net/unixsock_posix.go
index 351d9b3a39a..fb2397e26f2 100644
--- a/libgo/go/net/unixsock_posix.go
+++ b/libgo/go/net/unixsock_posix.go
@@ -42,7 +42,7 @@ func unixSocket(net string, laddr, raddr sockaddr, mode string, deadline time.Ti
return nil, errors.New("unknown mode: " + mode)
}
- fd, err := socket(net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, deadline)
+ fd, err := socket(net, syscall.AF_UNIX, sotype, 0, false, laddr, raddr, deadline, noCancel)
if err != nil {
return nil, err
}
@@ -273,8 +273,9 @@ func dialUnix(net string, laddr, raddr *UnixAddr, deadline time.Time) (*UnixConn
// typically use variables of type Listener instead of assuming Unix
// domain sockets.
type UnixListener struct {
- fd *netFD
- path string
+ fd *netFD
+ path string
+ unlink bool
}
// ListenUnix announces on the Unix domain socket laddr and returns a
@@ -292,7 +293,7 @@ func ListenUnix(net string, laddr *UnixAddr) (*UnixListener, error) {
if err != nil {
return nil, &OpError{Op: "listen", Net: net, Source: nil, Addr: laddr.opAddr(), Err: err}
}
- return &UnixListener{fd: fd, path: fd.laddr.String()}, nil
+ return &UnixListener{fd: fd, path: fd.laddr.String(), unlink: true}, nil
}
// AcceptUnix accepts the next incoming call and returns the new
@@ -335,7 +336,7 @@ func (l *UnixListener) Close() error {
// is at least compatible with the auto-remove
// sequence in ListenUnix. It's only non-Go
// programs that can mess us up.
- if l.path[0] != '@' {
+ if l.path[0] != '@' && l.unlink {
syscall.Unlink(l.path)
}
err := l.fd.Close()
diff --git a/libgo/go/net/url/url.go b/libgo/go/net/url/url.go
index 8ffad663d5c..1a93e3496ed 100644
--- a/libgo/go/net/url/url.go
+++ b/libgo/go/net/url/url.go
@@ -24,6 +24,24 @@ type Error struct {
func (e *Error) Error() string { return e.Op + " " + e.URL + ": " + e.Err.Error() }
+type timeout interface {
+ Timeout() bool
+}
+
+func (e *Error) Timeout() bool {
+ t, ok := e.Err.(timeout)
+ return ok && t.Timeout()
+}
+
+type temporary interface {
+ Temporary() bool
+}
+
+func (e *Error) Temporary() bool {
+ t, ok := e.Err.(temporary)
+ return ok && t.Temporary()
+}
+
func ishex(c byte) bool {
switch {
case '0' <= c && c <= '9':
@@ -53,6 +71,7 @@ type encoding int
const (
encodePath encoding = 1 + iota
encodeHost
+ encodeZone
encodeUserPassword
encodeQueryComponent
encodeFragment
@@ -64,6 +83,12 @@ func (e EscapeError) Error() string {
return "invalid URL escape " + strconv.Quote(string(e))
}
+type InvalidHostError string
+
+func (e InvalidHostError) Error() string {
+ return "invalid character " + strconv.Quote(string(e)) + " in host name"
+}
+
// Return true if the specified character should be escaped when
// appearing in a URL string, according to RFC 3986.
//
@@ -75,14 +100,18 @@ func shouldEscape(c byte, mode encoding) bool {
return false
}
- if mode == encodeHost {
+ if mode == encodeHost || mode == encodeZone {
// §3.2.2 Host allows
// sub-delims = "!" / "$" / "&" / "'" / "(" / ")" / "*" / "+" / "," / ";" / "="
// as part of reg-name.
// We add : because we include :port as part of host.
- // We add [ ] because we include [ipv6]:port as part of host
+ // We add [ ] because we include [ipv6]:port as part of host.
+ // We add < > because they're the only characters left that
+ // we could possibly allow, and Parse will reject them if we
+ // escape them (because hosts can't use %-encoding for
+ // ASCII bytes).
switch c {
- case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=', ':', '[', ']':
+ case '!', '$', '&', '\'', '(', ')', '*', '+', ',', ';', '=', ':', '[', ']', '<', '>', '"':
return false
}
}
@@ -148,11 +177,36 @@ func unescape(s string, mode encoding) (string, error) {
}
return "", EscapeError(s)
}
+ // Per https://tools.ietf.org/html/rfc3986#page-21
+ // in the host component %-encoding can only be used
+ // for non-ASCII bytes.
+ // But https://tools.ietf.org/html/rfc6874#section-2
+ // introduces %25 being allowed to escape a percent sign
+ // in IPv6 scoped-address literals. Yay.
+ if mode == encodeHost && unhex(s[i+1]) < 8 && s[i:i+3] != "%25" {
+ return "", EscapeError(s[i : i+3])
+ }
+ if mode == encodeZone {
+ // RFC 6874 says basically "anything goes" for zone identifiers
+ // and that even non-ASCII can be redundantly escaped,
+ // but it seems prudent to restrict %-escaped bytes here to those
+ // that are valid host name bytes in their unescaped form.
+ // That is, you can use escaping in the zone identifier but not
+ // to introduce bytes you couldn't just write directly.
+ // But Windows puts spaces here! Yay.
+ v := unhex(s[i+1])<<4 | unhex(s[i+2])
+ if s[i:i+3] != "%25" && v != ' ' && shouldEscape(v, encodeHost) {
+ return "", EscapeError(s[i : i+3])
+ }
+ }
i += 3
case '+':
hasPlus = mode == encodeQueryComponent
i++
default:
+ if (mode == encodeHost || mode == encodeZone) && s[i] < 0x80 && shouldEscape(s[i], mode) {
+ return "", InvalidHostError(s[i : i+1])
+ }
i++
}
}
@@ -246,7 +300,7 @@ func escape(s string, mode encoding) string {
// Go 1.5 introduced the RawPath field to hold the encoded form of Path.
// The Parse function sets both Path and RawPath in the URL it returns,
// and URL's String method uses RawPath if it is a valid encoding of Path,
-// by calling the EncodedPath method.
+// by calling the EscapedPath method.
//
// In earlier versions of Go, the more indirect workarounds were that an
// HTTP server could consult req.RequestURI and an HTTP client could
@@ -431,7 +485,7 @@ func parse(rawurl string, viaRequest bool) (url *URL, err error) {
goto Error
}
// RawPath is a hint as to the encoding of Path to use
- // in url.EncodedPath. If that method already gets the
+ // in url.EscapedPath. If that method already gets the
// right answer without RawPath, leave it empty.
// This will help make sure that people don't rely on it in general.
if url.EscapedPath() != rest && validEncodedPath(rest) {
@@ -478,14 +532,9 @@ func parseAuthority(authority string) (user *Userinfo, host string, err error) {
// parseHost parses host as an authority without user
// information. That is, as host[:port].
func parseHost(host string) (string, error) {
- litOrName := host
if strings.HasPrefix(host, "[") {
// Parse an IP-Literal in RFC 3986 and RFC 6874.
- // E.g., "[fe80::1], "[fe80::1%25en0]"
- //
- // RFC 4007 defines "%" as a delimiter character in
- // the textual representation of IPv6 addresses.
- // Per RFC 6874, in URIs that "%" is encoded as "%25".
+ // E.g., "[fe80::1]", "[fe80::1%25en0]", "[fe80::1]:80".
i := strings.LastIndex(host, "]")
if i < 0 {
return "", errors.New("missing ']' in host")
@@ -494,29 +543,31 @@ func parseHost(host string) (string, error) {
if !validOptionalPort(colonPort) {
return "", fmt.Errorf("invalid port %q after host", colonPort)
}
- // Parse a host subcomponent without a ZoneID in RFC
- // 6874 because the ZoneID is allowed to use the
- // percent encoded form.
- j := strings.Index(host[:i], "%25")
- if j < 0 {
- litOrName = host[1:i]
- } else {
- litOrName = host[1:j]
+
+ // RFC 6874 defines that %25 (%-encoded percent) introduces
+ // the zone identifier, and the zone identifier can use basically
+ // any %-encoding it likes. That's different from the host, which
+ // can only %-encode non-ASCII bytes.
+ // We do impose some restrictions on the zone, to avoid stupidity
+ // like newlines.
+ zone := strings.Index(host[:i], "%25")
+ if zone >= 0 {
+ host1, err := unescape(host[:zone], encodeHost)
+ if err != nil {
+ return "", err
+ }
+ host2, err := unescape(host[zone:i], encodeZone)
+ if err != nil {
+ return "", err
+ }
+ host3, err := unescape(host[i:], encodeHost)
+ if err != nil {
+ return "", err
+ }
+ return host1 + host2 + host3, nil
}
}
- // A URI containing an IP-Literal without a ZoneID or
- // IPv4address in RFC 3986 and RFC 6847 must not be
- // percent-encoded.
- //
- // A URI containing a DNS registered name in RFC 3986 is
- // allowed to be percent-encoded, though we don't use it for
- // now to avoid messing up with the gap between allowed
- // characters in URI and allowed characters in DNS.
- // See golang.org/issue/7991.
- if strings.Contains(litOrName, "%") {
- return "", errors.New("percent-encoded characters in host")
- }
var err error
if host, err = unescape(host, encodeHost); err != nil {
return "", err
@@ -572,12 +623,12 @@ func validEncodedPath(s string) bool {
}
// validOptionalPort reports whether port is either an empty string
-// or matches /^:\d+$/
+// or matches /^:\d*$/
func validOptionalPort(port string) bool {
if port == "" {
return true
}
- if port[0] != ':' || len(port) == 1 {
+ if port[0] != ':' {
return false
}
for _, b := range port[1:] {
@@ -596,7 +647,7 @@ func validOptionalPort(port string) bool {
//
// If u.Opaque is non-empty, String uses the first form;
// otherwise it uses the second form.
-// To obtain the path, String uses u.EncodedPath().
+// To obtain the path, String uses u.EscapedPath().
//
// In the second form, the following rules apply:
// - if u.Scheme is empty, scheme: is omitted.
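
With the Timeout and Temporary methods added above, a *url.Error (as returned by the http client, for example) satisfies net.Error and forwards those properties from the wrapped error. A minimal sketch, assuming an unreachable example address:

	package main

	import (
		"fmt"
		"net"
		"net/http"
		"time"
	)

	func main() {
		c := &http.Client{Timeout: 50 * time.Millisecond}
		_, err := c.Get("http://10.255.255.1/")
		if nerr, ok := err.(net.Error); ok { // *url.Error now implements net.Error
			fmt.Println("timeout:", nerr.Timeout(), "temporary:", nerr.Temporary())
			return
		}
		fmt.Println(err)
	}
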
diff --git a/libgo/go/net/url/url_test.go b/libgo/go/net/url/url_test.go
index ff6e9e4541a..d3f8487bd7c 100644
--- a/libgo/go/net/url/url_test.go
+++ b/libgo/go/net/url/url_test.go
@@ -6,6 +6,8 @@ package url
import (
"fmt"
+ "io"
+ "net"
"reflect"
"strings"
"testing"
@@ -330,7 +332,7 @@ var urltests = []URLTest{
},
"",
},
- // host subcomponent; IPv6 address with zone identifier in RFC 6847
+ // host subcomponent; IPv6 address with zone identifier in RFC 6874
{
"http://[fe80::1%25en0]/", // alphanum zone identifier
&URL{
@@ -340,7 +342,7 @@ var urltests = []URLTest{
},
"",
},
- // host and port subcomponents; IPv6 address with zone identifier in RFC 6847
+ // host and port subcomponents; IPv6 address with zone identifier in RFC 6874
{
"http://[fe80::1%25en0]:8080/", // alphanum zone identifier
&URL{
@@ -350,7 +352,7 @@ var urltests = []URLTest{
},
"",
},
- // host subcomponent; IPv6 address with zone identifier in RFC 6847
+ // host subcomponent; IPv6 address with zone identifier in RFC 6874
{
"http://[fe80::1%25%65%6e%301-._~]/", // percent-encoded+unreserved zone identifier
&URL{
@@ -360,7 +362,7 @@ var urltests = []URLTest{
},
"http://[fe80::1%25en01-._~]/",
},
- // host and port subcomponents; IPv6 address with zone identifier in RFC 6847
+ // host and port subcomponents; IPv6 address with zone identifier in RFC 6874
{
"http://[fe80::1%25%65%6e%301-._~]:8080/", // percent-encoded+unreserved zone identifier
&URL{
@@ -424,6 +426,122 @@ var urltests = []URLTest{
},
"",
},
+ // golang.org/issue/12200 (colon with empty port)
+ {
+ "http://192.168.0.2:8080/foo",
+ &URL{
+ Scheme: "http",
+ Host: "192.168.0.2:8080",
+ Path: "/foo",
+ },
+ "",
+ },
+ {
+ "http://192.168.0.2:/foo",
+ &URL{
+ Scheme: "http",
+ Host: "192.168.0.2:",
+ Path: "/foo",
+ },
+ "",
+ },
+ {
+ // Malformed IPv6 but still accepted.
+ "http://2b01:e34:ef40:7730:8e70:5aff:fefe:edac:8080/foo",
+ &URL{
+ Scheme: "http",
+ Host: "2b01:e34:ef40:7730:8e70:5aff:fefe:edac:8080",
+ Path: "/foo",
+ },
+ "",
+ },
+ {
+ // Malformed IPv6 but still accepted.
+ "http://2b01:e34:ef40:7730:8e70:5aff:fefe:edac:/foo",
+ &URL{
+ Scheme: "http",
+ Host: "2b01:e34:ef40:7730:8e70:5aff:fefe:edac:",
+ Path: "/foo",
+ },
+ "",
+ },
+ {
+ "http://[2b01:e34:ef40:7730:8e70:5aff:fefe:edac]:8080/foo",
+ &URL{
+ Scheme: "http",
+ Host: "[2b01:e34:ef40:7730:8e70:5aff:fefe:edac]:8080",
+ Path: "/foo",
+ },
+ "",
+ },
+ {
+ "http://[2b01:e34:ef40:7730:8e70:5aff:fefe:edac]:/foo",
+ &URL{
+ Scheme: "http",
+ Host: "[2b01:e34:ef40:7730:8e70:5aff:fefe:edac]:",
+ Path: "/foo",
+ },
+ "",
+ },
+ // golang.org/issue/7991 and golang.org/issue/12719 (non-ascii %-encoded in host)
+ {
+ "http://hello.世界.com/foo",
+ &URL{
+ Scheme: "http",
+ Host: "hello.世界.com",
+ Path: "/foo",
+ },
+ "http://hello.%E4%B8%96%E7%95%8C.com/foo",
+ },
+ {
+ "http://hello.%e4%b8%96%e7%95%8c.com/foo",
+ &URL{
+ Scheme: "http",
+ Host: "hello.世界.com",
+ Path: "/foo",
+ },
+ "http://hello.%E4%B8%96%E7%95%8C.com/foo",
+ },
+ {
+ "http://hello.%E4%B8%96%E7%95%8C.com/foo",
+ &URL{
+ Scheme: "http",
+ Host: "hello.世界.com",
+ Path: "/foo",
+ },
+ "",
+ },
+ // golang.org/issue/10433 (path beginning with //)
+ {
+ "http://example.com//foo",
+ &URL{
+ Scheme: "http",
+ Host: "example.com",
+ Path: "//foo",
+ },
+ "",
+ },
+ // test that we can reparse the host names we accept.
+ {
+ "myscheme://authority<\"hi\">/foo",
+ &URL{
+ Scheme: "myscheme",
+ Host: "authority<\"hi\">",
+ Path: "/foo",
+ },
+ "",
+ },
+ // spaces in hosts are disallowed but escaped spaces in IPv6 scope IDs are grudgingly OK.
+ // This happens on Windows.
+ // golang.org/issue/14002
+ {
+ "tcp://[2020::2020:20:2020:2020%25Windows%20Loves%20Spaces]:2020",
+ &URL{
+ Scheme: "tcp",
+ Host: "[2020::2020:20:2020:2020%Windows Loves Spaces]:2020",
+ },
+ "",
+ },
}
// more useful string for debugging than fmt's struct printer
@@ -1091,6 +1209,14 @@ var requritests = []RequestURITest{
},
"opaque?q=go+language",
},
+ {
+ &URL{
+ Scheme: "http",
+ Host: "example.com",
+ Path: "//foo",
+ },
+ "//foo",
+ },
}
func TestRequestURI(t *testing.T) {
@@ -1124,16 +1250,17 @@ func TestParseAuthority(t *testing.T) {
{"http://[::1]a", true},
{"http://[::1]%23", true},
{"http://[::1%25en0]", false}, // valid zone id
- {"http://[::1]:", true}, // colon, but no port
- {"http://[::1]:%38%30", true}, // no hex in port
- {"http://[::1%25%10]", false}, // TODO: reject the %10 after the valid zone %25 separator?
+ {"http://[::1]:", false}, // colon, but no port OK
+ {"http://[::1]:%38%30", true}, // not allowed: % encoding only for non-ASCII
+ {"http://[::1%25%41]", false}, // RFC 6874 allows over-escaping in zone
{"http://[%10::1]", true}, // no %xx escapes in IP address
{"http://[::1]/%48", false}, // %xx in path is fine
- {"http://%41:8080/", true}, // TODO: arguably we should accept reg-name with %xx
+ {"http://%41:8080/", true}, // not allowed: % encoding only for non-ASCII
{"mysql://x@y(z:123)/foo", false}, // golang.org/issue/12023
{"mysql://x@y(1.2.3.4:123)/foo", false},
{"mysql://x@y([2001:db8::1]:123)/foo", false},
{"http://[]%20%48%54%54%50%2f%31%2e%31%0a%4d%79%48%65%61%64%65%72%3a%20%31%32%33%0a%0a/", true}, // golang.org/issue/11208
+ {"http://a b.com/", true}, // no space in host name please
}
for _, tt := range tests {
u, err := Parse(tt.in)
@@ -1229,3 +1356,84 @@ func TestShouldEscape(t *testing.T) {
}
}
}
+
+type timeoutError struct {
+ timeout bool
+}
+
+func (e *timeoutError) Error() string { return "timeout error" }
+func (e *timeoutError) Timeout() bool { return e.timeout }
+
+type temporaryError struct {
+ temporary bool
+}
+
+func (e *temporaryError) Error() string { return "temporary error" }
+func (e *temporaryError) Temporary() bool { return e.temporary }
+
+type timeoutTemporaryError struct {
+ timeoutError
+ temporaryError
+}
+
+func (e *timeoutTemporaryError) Error() string { return "timeout/temporary error" }
+
+var netErrorTests = []struct {
+ err error
+ timeout bool
+ temporary bool
+}{{
+ err: &Error{"Get", "http://google.com/", &timeoutError{timeout: true}},
+ timeout: true,
+ temporary: false,
+}, {
+ err: &Error{"Get", "http://google.com/", &timeoutError{timeout: false}},
+ timeout: false,
+ temporary: false,
+}, {
+ err: &Error{"Get", "http://google.com/", &temporaryError{temporary: true}},
+ timeout: false,
+ temporary: true,
+}, {
+ err: &Error{"Get", "http://google.com/", &temporaryError{temporary: false}},
+ timeout: false,
+ temporary: false,
+}, {
+ err: &Error{"Get", "http://google.com/", &timeoutTemporaryError{timeoutError{timeout: true}, temporaryError{temporary: true}}},
+ timeout: true,
+ temporary: true,
+}, {
+ err: &Error{"Get", "http://google.com/", &timeoutTemporaryError{timeoutError{timeout: false}, temporaryError{temporary: true}}},
+ timeout: false,
+ temporary: true,
+}, {
+ err: &Error{"Get", "http://google.com/", &timeoutTemporaryError{timeoutError{timeout: true}, temporaryError{temporary: false}}},
+ timeout: true,
+ temporary: false,
+}, {
+ err: &Error{"Get", "http://google.com/", &timeoutTemporaryError{timeoutError{timeout: false}, temporaryError{temporary: false}}},
+ timeout: false,
+ temporary: false,
+}, {
+ err: &Error{"Get", "http://google.com/", io.EOF},
+ timeout: false,
+ temporary: false,
+}}
+
+// Test that url.Error implements net.Error and that it forwards
+func TestURLErrorImplementsNetError(t *testing.T) {
+ for i, tt := range netErrorTests {
+ err, ok := tt.err.(net.Error)
+ if !ok {
+ t.Errorf("%d: %T does not implement net.Error", i+1, tt.err)
+ continue
+ }
+ if err.Timeout() != tt.timeout {
+ t.Errorf("%d: err.Timeout(): want %v, have %v", i+1, tt.timeout, err.Timeout())
+ continue
+ }
+ if err.Temporary() != tt.temporary {
+ t.Errorf("%d: err.Temporary(): want %v, have %v", i+1, tt.temporary, err.Temporary())
+ }
+ }
+}