summaryrefslogtreecommitdiff
path: root/src
diff options
context:
space:
mode:
Diffstat (limited to 'src')
-rw-r--r--src/net/http/transport.go5
-rw-r--r--src/net/http/transport_internal_test.go83
2 files changed, 85 insertions, 3 deletions
diff --git a/src/net/http/transport.go b/src/net/http/transport.go
index 7f8fd505bd..e6493036e8 100644
--- a/src/net/http/transport.go
+++ b/src/net/http/transport.go
@@ -478,9 +478,8 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) {
}
testHookRoundTripRetried()
- // Rewind the body if we're able to. (HTTP/2 does this itself so we only
- // need to do it for HTTP/1.1 connections.)
- if req.GetBody != nil && pconn.alt == nil {
+ // Rewind the body if we're able to.
+ if req.GetBody != nil {
newReq := *req
var err error
newReq.Body, err = req.GetBody()
diff --git a/src/net/http/transport_internal_test.go b/src/net/http/transport_internal_test.go
index a5f29c97a9..92729e65b2 100644
--- a/src/net/http/transport_internal_test.go
+++ b/src/net/http/transport_internal_test.go
@@ -7,8 +7,13 @@
package http
import (
+ "bytes"
+ "crypto/tls"
"errors"
+ "io"
+ "io/ioutil"
"net"
+ "net/http/internal"
"strings"
"testing"
)
@@ -178,3 +183,81 @@ func TestTransportShouldRetryRequest(t *testing.T) {
}
}
}
+
+type roundTripFunc func(r *Request) (*Response, error)
+
+func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) {
+ return f(r)
+}
+
+// Issue 25009
+func TestTransportBodyAltRewind(t *testing.T) {
+ cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey)
+ if err != nil {
+ t.Fatal(err)
+ }
+ ln := newLocalListener(t)
+ defer ln.Close()
+
+ go func() {
+ tln := tls.NewListener(ln, &tls.Config{
+ NextProtos: []string{"foo"},
+ Certificates: []tls.Certificate{cert},
+ })
+ for i := 0; i < 2; i++ {
+ sc, err := tln.Accept()
+ if err != nil {
+ t.Error(err)
+ return
+ }
+ if err := sc.(*tls.Conn).Handshake(); err != nil {
+ t.Error(err)
+ return
+ }
+ sc.Close()
+ }
+ }()
+
+ addr := ln.Addr().String()
+ req, _ := NewRequest("POST", "https://example.org/", bytes.NewBufferString("request"))
+ roundTripped := false
+ tr := &Transport{
+ DisableKeepAlives: true,
+ TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{
+ "foo": func(authority string, c *tls.Conn) RoundTripper {
+ return roundTripFunc(func(r *Request) (*Response, error) {
+ n, _ := io.Copy(ioutil.Discard, r.Body)
+ if n == 0 {
+ t.Error("body length is zero")
+ }
+ if roundTripped {
+ return &Response{
+ Body: NoBody,
+ StatusCode: 200,
+ }, nil
+ }
+ roundTripped = true
+ return nil, http2noCachedConnError{}
+ })
+ },
+ },
+ DialTLS: func(_, _ string) (net.Conn, error) {
+ tc, err := tls.Dial("tcp", addr, &tls.Config{
+ InsecureSkipVerify: true,
+ NextProtos: []string{"foo"},
+ })
+ if err != nil {
+ return nil, err
+ }
+ if err := tc.Handshake(); err != nil {
+ return nil, err
+ }
+ return tc, nil
+ },
+ }
+ c := &Client{Transport: tr}
+ _, err = c.Do(req)
+ if err != nil {
+ t.Error(err)
+ }
+}