diff options
-rw-r--r-- | src/net/http/transport.go | 5 | ||||
-rw-r--r-- | src/net/http/transport_internal_test.go | 83 |
2 files changed, 85 insertions, 3 deletions
diff --git a/src/net/http/transport.go b/src/net/http/transport.go index 7f8fd505bd..e6493036e8 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -478,9 +478,8 @@ func (t *Transport) roundTrip(req *Request) (*Response, error) { } testHookRoundTripRetried() - // Rewind the body if we're able to. (HTTP/2 does this itself so we only - // need to do it for HTTP/1.1 connections.) - if req.GetBody != nil && pconn.alt == nil { + // Rewind the body if we're able to. + if req.GetBody != nil { newReq := *req var err error newReq.Body, err = req.GetBody() diff --git a/src/net/http/transport_internal_test.go b/src/net/http/transport_internal_test.go index a5f29c97a9..92729e65b2 100644 --- a/src/net/http/transport_internal_test.go +++ b/src/net/http/transport_internal_test.go @@ -7,8 +7,13 @@ package http import ( + "bytes" + "crypto/tls" "errors" + "io" + "io/ioutil" "net" + "net/http/internal" "strings" "testing" ) @@ -178,3 +183,81 @@ func TestTransportShouldRetryRequest(t *testing.T) { } } } + +type roundTripFunc func(r *Request) (*Response, error) + +func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { + return f(r) +} + +// Issue 25009 +func TestTransportBodyAltRewind(t *testing.T) { + cert, err := tls.X509KeyPair(internal.LocalhostCert, internal.LocalhostKey) + if err != nil { + t.Fatal(err) + } + ln := newLocalListener(t) + defer ln.Close() + + go func() { + tln := tls.NewListener(ln, &tls.Config{ + NextProtos: []string{"foo"}, + Certificates: []tls.Certificate{cert}, + }) + for i := 0; i < 2; i++ { + sc, err := tln.Accept() + if err != nil { + t.Error(err) + return + } + if err := sc.(*tls.Conn).Handshake(); err != nil { + t.Error(err) + return + } + sc.Close() + } + }() + + addr := ln.Addr().String() + req, _ := NewRequest("POST", "https://example.org/", bytes.NewBufferString("request")) + roundTripped := false + tr := &Transport{ + DisableKeepAlives: true, + TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{ + "foo": func(authority string, c *tls.Conn) RoundTripper { + return roundTripFunc(func(r *Request) (*Response, error) { + n, _ := io.Copy(ioutil.Discard, r.Body) + if n == 0 { + t.Error("body length is zero") + } + if roundTripped { + return &Response{ + Body: NoBody, + StatusCode: 200, + }, nil + } + roundTripped = true + return nil, http2noCachedConnError{} + }) + }, + }, + DialTLS: func(_, _ string) (net.Conn, error) { + tc, err := tls.Dial("tcp", addr, &tls.Config{ + InsecureSkipVerify: true, + NextProtos: []string{"foo"}, + }) + if err != nil { + return nil, err + } + if err := tc.Handshake(); err != nil { + return nil, err + } + return tc, nil + }, + } + c := &Client{Transport: tr} + _, err = c.Do(req) + if err != nil { + t.Error(err) + } +} |