diff options
Diffstat (limited to 'src/net/http')
47 files changed, 860 insertions, 435 deletions
diff --git a/src/net/http/alpn_test.go b/src/net/http/alpn_test.go index 618bdbe54a..a51038c355 100644 --- a/src/net/http/alpn_test.go +++ b/src/net/http/alpn_test.go @@ -11,7 +11,6 @@ import ( "crypto/x509" "fmt" "io" - "io/ioutil" . "net/http" "net/http/httptest" "strings" @@ -49,7 +48,7 @@ func TestNextProtoUpgrade(t *testing.T) { if err != nil { t.Fatal(err) } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -93,7 +92,7 @@ func TestNextProtoUpgrade(t *testing.T) { t.Fatal(err) } conn.Write([]byte("GET /foo\n")) - body, err := ioutil.ReadAll(conn) + body, err := io.ReadAll(conn) if err != nil { t.Fatal(err) } diff --git a/src/net/http/cgi/child.go b/src/net/http/cgi/child.go index 690986335c..0114da377b 100644 --- a/src/net/http/cgi/child.go +++ b/src/net/http/cgi/child.go @@ -13,7 +13,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net" "net/http" "net/url" @@ -32,7 +31,7 @@ func Request() (*http.Request, error) { return nil, err } if r.ContentLength > 0 { - r.Body = ioutil.NopCloser(io.LimitReader(os.Stdin, r.ContentLength)) + r.Body = io.NopCloser(io.LimitReader(os.Stdin, r.ContentLength)) } return r, nil } diff --git a/src/net/http/cgi/host.go b/src/net/http/cgi/host.go index 624044aa09..eff67caf4e 100644 --- a/src/net/http/cgi/host.go +++ b/src/net/http/cgi/host.go @@ -39,13 +39,13 @@ var osDefaultInheritEnv = func() []string { switch runtime.GOOS { case "darwin", "ios": return []string{"DYLD_LIBRARY_PATH"} - case "linux", "freebsd", "openbsd": + case "linux", "freebsd", "netbsd", "openbsd": return []string{"LD_LIBRARY_PATH"} case "hpux": return []string{"LD_LIBRARY_PATH", "SHLIB_PATH"} case "irix": return []string{"LD_LIBRARY_PATH", "LD_LIBRARYN32_PATH", "LD_LIBRARY64_PATH"} - case "solaris": + case "illumos", "solaris": return []string{"LD_LIBRARY_PATH", "LD_LIBRARY_PATH_32", "LD_LIBRARY_PATH_64"} case "windows": return []string{"SystemRoot", "COMSPEC", "PATHEXT", "WINDIR"} diff --git a/src/net/http/client.go b/src/net/http/client.go index 6ca0d2e6cf..88e2028bc3 100644 --- a/src/net/http/client.go +++ b/src/net/http/client.go @@ -16,7 +16,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "log" "net/url" "reflect" @@ -282,7 +281,7 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, d if resp.ContentLength > 0 && req.Method != "HEAD" { return nil, didTimeout, fmt.Errorf("http: RoundTripper implementation (%T) returned a *Response with content length %d but a nil Body", rt, resp.ContentLength) } - resp.Body = ioutil.NopCloser(strings.NewReader("")) + resp.Body = io.NopCloser(strings.NewReader("")) } if !deadline.IsZero() { resp.Body = &cancelTimerBody{ @@ -697,7 +696,7 @@ func (c *Client) do(req *Request) (retres *Response, reterr error) { // fails, the Transport won't reuse it anyway. const maxBodySlurpSize = 2 << 10 if resp.ContentLength == -1 || resp.ContentLength <= maxBodySlurpSize { - io.CopyN(ioutil.Discard, resp.Body, maxBodySlurpSize) + io.CopyN(io.Discard, resp.Body, maxBodySlurpSize) } resp.Body.Close() diff --git a/src/net/http/client_test.go b/src/net/http/client_test.go index 80807fae7a..d90b4841c6 100644 --- a/src/net/http/client_test.go +++ b/src/net/http/client_test.go @@ -14,7 +14,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "log" "net" . "net/http" @@ -35,7 +34,7 @@ var robotsTxtHandler = HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "User-agent: go\nDisallow: /something/") }) -// pedanticReadAll works like ioutil.ReadAll but additionally +// pedanticReadAll works like io.ReadAll but additionally // verifies that r obeys the documented io.Reader contract. func pedanticReadAll(r io.Reader) (b []byte, err error) { var bufa [64]byte @@ -190,7 +189,7 @@ func TestPostFormRequestFormat(t *testing.T) { if g, e := tr.req.ContentLength, int64(len(expectedBody)); g != e { t.Errorf("got ContentLength %d, want %d", g, e) } - bodyb, err := ioutil.ReadAll(tr.req.Body) + bodyb, err := io.ReadAll(tr.req.Body) if err != nil { t.Fatalf("ReadAll on req.Body: %v", err) } @@ -421,7 +420,7 @@ func testRedirectsByMethod(t *testing.T, method string, table []redirectTest, wa var ts *httptest.Server ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { log.Lock() - slurp, _ := ioutil.ReadAll(r.Body) + slurp, _ := io.ReadAll(r.Body) fmt.Fprintf(&log.Buffer, "%s %s %q", r.Method, r.RequestURI, slurp) if cl := r.Header.Get("Content-Length"); r.Method == "GET" && len(slurp) == 0 && (r.ContentLength != 0 || cl != "") { fmt.Fprintf(&log.Buffer, " (but with body=%T, content-length = %v, %q)", r.Body, r.ContentLength, cl) @@ -452,7 +451,7 @@ func testRedirectsByMethod(t *testing.T, method string, table []redirectTest, wa for _, tt := range table { content := tt.redirectBody req, _ := NewRequest(method, ts.URL+tt.suffix, strings.NewReader(content)) - req.GetBody = func() (io.ReadCloser, error) { return ioutil.NopCloser(strings.NewReader(content)), nil } + req.GetBody = func() (io.ReadCloser, error) { return io.NopCloser(strings.NewReader(content)), nil } res, err := c.Do(req) if err != nil { @@ -522,7 +521,7 @@ func TestClientRedirectUseResponse(t *testing.T) { t.Errorf("status = %d; want %d", res.StatusCode, StatusFound) } defer res.Body.Close() - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -1042,7 +1041,7 @@ func testClientHeadContentLength(t *testing.T, h2 bool) { if res.ContentLength != tt.want { t.Errorf("Content-Length = %d; want %d", res.ContentLength, tt.want) } - bs, err := ioutil.ReadAll(res.Body) + bs, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -1257,7 +1256,7 @@ func testClientTimeout(t *testing.T, h2 bool) { errc := make(chan error, 1) go func() { - _, err := ioutil.ReadAll(res.Body) + _, err := io.ReadAll(res.Body) errc <- err res.Body.Close() }() @@ -1348,7 +1347,7 @@ func TestClientTimeoutCancel(t *testing.T) { t.Fatal(err) } cancel() - _, err = io.Copy(ioutil.Discard, res.Body) + _, err = io.Copy(io.Discard, res.Body) if err != ExportErrRequestCanceled { t.Fatalf("error = %v; want errRequestCanceled", err) } @@ -1372,7 +1371,7 @@ func testClientRedirectEatsBody(t *testing.T, h2 bool) { if err != nil { t.Fatal(err) } - _, err = ioutil.ReadAll(res.Body) + _, err = io.ReadAll(res.Body) res.Body.Close() if err != nil { t.Fatal(err) @@ -1450,7 +1449,7 @@ func (issue15577Tripper) RoundTrip(*Request) (*Response, error) { resp := &Response{ StatusCode: 303, Header: map[string][]string{"Location": {"http://www.example.com/"}}, - Body: ioutil.NopCloser(strings.NewReader("")), + Body: io.NopCloser(strings.NewReader("")), } return resp, nil } @@ -1591,7 +1590,7 @@ func TestClientCopyHostOnRedirect(t *testing.T) { if resp.StatusCode != 200 { t.Fatal(resp.Status) } - if got, err := ioutil.ReadAll(resp.Body); err != nil || string(got) != wantBody { + if got, err := io.ReadAll(resp.Body); err != nil || string(got) != wantBody { t.Errorf("body = %q; want %q", got, wantBody) } } @@ -2020,9 +2019,66 @@ func TestClientPopulatesNilResponseBody(t *testing.T) { } }() - if b, err := ioutil.ReadAll(resp.Body); err != nil { + if b, err := io.ReadAll(resp.Body); err != nil { t.Errorf("read error from substitute Response.Body: %v", err) } else if len(b) != 0 { t.Errorf("substitute Response.Body was unexpectedly non-empty: %q", b) } } + +// Issue 40382: Client calls Close multiple times on Request.Body. +func TestClientCallsCloseOnlyOnce(t *testing.T) { + setParallel(t) + defer afterTest(t) + cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.WriteHeader(StatusNoContent) + })) + defer cst.close() + + // Issue occurred non-deterministically: needed to occur after a successful + // write (into TCP buffer) but before end of body. + for i := 0; i < 50 && !t.Failed(); i++ { + body := &issue40382Body{t: t, n: 300000} + req, err := NewRequest(MethodPost, cst.ts.URL, body) + if err != nil { + t.Fatal(err) + } + resp, err := cst.tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + resp.Body.Close() + } +} + +// issue40382Body is an io.ReadCloser for TestClientCallsCloseOnlyOnce. +// Its Read reads n bytes before returning io.EOF. +// Its Close returns nil but fails the test if called more than once. +type issue40382Body struct { + t *testing.T + n int + closeCallsAtomic int32 +} + +func (b *issue40382Body) Read(p []byte) (int, error) { + switch { + case b.n == 0: + return 0, io.EOF + case b.n < len(p): + p = p[:b.n] + fallthrough + default: + for i := range p { + p[i] = 'x' + } + b.n -= len(p) + return len(p), nil + } +} + +func (b *issue40382Body) Close() error { + if atomic.AddInt32(&b.closeCallsAtomic, 1) == 2 { + b.t.Error("Body closed more than once") + } + return nil +} diff --git a/src/net/http/clientserver_test.go b/src/net/http/clientserver_test.go index def5c424f0..5e227181ac 100644 --- a/src/net/http/clientserver_test.go +++ b/src/net/http/clientserver_test.go @@ -15,7 +15,6 @@ import ( "fmt" "hash" "io" - "io/ioutil" "log" "net" . "net/http" @@ -53,7 +52,7 @@ func (t *clientServerTest) getURL(u string) string { t.t.Fatal(err) } defer res.Body.Close() - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.t.Fatal(err) } @@ -152,7 +151,7 @@ func TestChunkedResponseHeaders_h2(t *testing.T) { testChunkedResponseHeaders(t, func testChunkedResponseHeaders(t *testing.T, h2 bool) { defer afterTest(t) - log.SetOutput(ioutil.Discard) // is noisy otherwise + log.SetOutput(io.Discard) // is noisy otherwise defer log.SetOutput(os.Stderr) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Content-Length", "intentional gibberish") // we check that this is deleted @@ -266,11 +265,11 @@ func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) } else { t.Errorf("got %q response; want %q", res.Proto, wantProto) } - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) res.Body.Close() res.Body = slurpResult{ - ReadCloser: ioutil.NopCloser(bytes.NewReader(slurp)), + ReadCloser: io.NopCloser(bytes.NewReader(slurp)), body: slurp, err: err, } @@ -477,7 +476,7 @@ func test304Responses(t *testing.T, h2 bool) { if len(res.TransferEncoding) > 0 { t.Errorf("expected no TransferEncoding; got %v", res.TransferEncoding) } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Error(err) } @@ -564,7 +563,7 @@ func testCancelRequestMidBody(t *testing.T, h2 bool) { close(cancel) - rest, err := ioutil.ReadAll(res.Body) + rest, err := io.ReadAll(res.Body) all := string(firstRead) + string(rest) if all != "Hello" { t.Errorf("Read %q (%q + %q); want Hello", all, firstRead, rest) @@ -587,7 +586,7 @@ func testTrailersClientToServer(t *testing.T, h2 bool) { } sort.Strings(decl) - slurp, err := ioutil.ReadAll(r.Body) + slurp, err := io.ReadAll(r.Body) if err != nil { t.Errorf("Server reading request body: %v", err) } @@ -721,7 +720,7 @@ func testResponseBodyReadAfterClose(t *testing.T, h2 bool) { t.Fatal(err) } res.Body.Close() - data, err := ioutil.ReadAll(res.Body) + data, err := io.ReadAll(res.Body) if len(data) != 0 || err == nil { t.Fatalf("ReadAll returned %q, %v; want error", data, err) } @@ -740,7 +739,7 @@ func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { // Read in one goroutine. go func() { defer wg.Done() - data, err := ioutil.ReadAll(r.Body) + data, err := io.ReadAll(r.Body) if string(data) != reqBody { t.Errorf("Handler read %q; want %q", data, reqBody) } @@ -770,7 +769,7 @@ func testConcurrentReadWriteReqBody(t *testing.T, h2 bool) { if err != nil { t.Fatal(err) } - data, err := ioutil.ReadAll(res.Body) + data, err := io.ReadAll(res.Body) defer res.Body.Close() if err != nil { t.Fatal(err) @@ -887,7 +886,7 @@ func testTransportUserAgent(t *testing.T, h2 bool) { t.Errorf("%d. RoundTrip = %v", i, err) continue } - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { t.Errorf("%d. read body = %v", i, err) @@ -1009,11 +1008,17 @@ func TestTransportDiscardsUnneededConns(t *testing.T) { defer wg.Done() resp, err := c.Get(cst.ts.URL) if err != nil { - t.Errorf("Get: %v", err) - return + // Try to work around spurious connection reset on loaded system. + // See golang.org/issue/33585 and golang.org/issue/36797. + time.Sleep(10 * time.Millisecond) + resp, err = c.Get(cst.ts.URL) + if err != nil { + t.Errorf("Get: %v", err) + return + } } defer resp.Body.Close() - slurp, err := ioutil.ReadAll(resp.Body) + slurp, err := io.ReadAll(resp.Body) if err != nil { t.Error(err) } @@ -1058,7 +1063,7 @@ func testTransportGCRequest(t *testing.T, h2, body bool) { setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { - ioutil.ReadAll(r.Body) + io.ReadAll(r.Body) if body { io.WriteString(w, "Hello.") } @@ -1074,7 +1079,7 @@ func testTransportGCRequest(t *testing.T, h2, body bool) { if err != nil { t.Fatal(err) } - if _, err := ioutil.ReadAll(res.Body); err != nil { + if _, err := io.ReadAll(res.Body); err != nil { t.Fatal(err) } if err := res.Body.Close(); err != nil { @@ -1135,7 +1140,7 @@ func testTransportRejectsInvalidHeaders(t *testing.T, h2 bool) { res, err := cst.c.Do(req) var body []byte if err == nil { - body, _ = ioutil.ReadAll(res.Body) + body, _ = io.ReadAll(res.Body) res.Body.Close() } var dialed bool @@ -1192,7 +1197,7 @@ func testInterruptWithPanic(t *testing.T, h2 bool, panicValue interface{}) { } gotHeaders <- true defer res.Body.Close() - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if string(slurp) != msg { t.Errorf("client read %q; want %q", slurp, msg) } @@ -1357,7 +1362,7 @@ func testServerUndeclaredTrailers(t *testing.T, h2 bool) { if err != nil { t.Fatal(err) } - if _, err := io.Copy(ioutil.Discard, res.Body); err != nil { + if _, err := io.Copy(io.Discard, res.Body); err != nil { t.Fatal(err) } res.Body.Close() @@ -1375,7 +1380,7 @@ func testServerUndeclaredTrailers(t *testing.T, h2 bool) { func TestBadResponseAfterReadingBody(t *testing.T) { defer afterTest(t) cst := newClientServerTest(t, false, HandlerFunc(func(w ResponseWriter, r *Request) { - _, err := io.Copy(ioutil.Discard, r.Body) + _, err := io.Copy(io.Discard, r.Body) if err != nil { t.Fatal(err) } @@ -1468,7 +1473,7 @@ func testWriteHeaderAfterWrite(t *testing.T, h2, hijack bool) { t.Fatal(err) } defer res.Body.Close() - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } diff --git a/src/net/http/cookie.go b/src/net/http/cookie.go index d7a8f5e94e..141bc947f6 100644 --- a/src/net/http/cookie.go +++ b/src/net/http/cookie.go @@ -220,7 +220,7 @@ func (c *Cookie) String() string { } switch c.SameSite { case SameSiteDefaultMode: - b.WriteString("; SameSite") + // Skip, default mode is obtained by not emitting the attribute. case SameSiteNoneMode: b.WriteString("; SameSite=None") case SameSiteLaxMode: diff --git a/src/net/http/cookie_test.go b/src/net/http/cookie_test.go index 9e8196ebce..959713a0dc 100644 --- a/src/net/http/cookie_test.go +++ b/src/net/http/cookie_test.go @@ -67,7 +67,7 @@ var writeSetCookiesTests = []struct { }, { &Cookie{Name: "cookie-12", Value: "samesite-default", SameSite: SameSiteDefaultMode}, - "cookie-12=samesite-default; SameSite", + "cookie-12=samesite-default", }, { &Cookie{Name: "cookie-13", Value: "samesite-lax", SameSite: SameSiteLaxMode}, @@ -283,6 +283,15 @@ var readSetCookiesTests = []struct { }}, }, { + Header{"Set-Cookie": {"samesiteinvalidisdefault=foo; SameSite=invalid"}}, + []*Cookie{{ + Name: "samesiteinvalidisdefault", + Value: "foo", + SameSite: SameSiteDefaultMode, + Raw: "samesiteinvalidisdefault=foo; SameSite=invalid", + }}, + }, + { Header{"Set-Cookie": {"samesitelax=foo; SameSite=Lax"}}, []*Cookie{{ Name: "samesitelax", diff --git a/src/net/http/doc.go b/src/net/http/doc.go index 7855feaaa9..ae9b708c69 100644 --- a/src/net/http/doc.go +++ b/src/net/http/doc.go @@ -21,7 +21,7 @@ The client must close the response body when finished with it: // handle error } defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) // ... For control over HTTP client headers, redirect policy, and other diff --git a/src/net/http/example_filesystem_test.go b/src/net/http/example_filesystem_test.go index e1fd42d049..0e81458a07 100644 --- a/src/net/http/example_filesystem_test.go +++ b/src/net/http/example_filesystem_test.go @@ -5,9 +5,9 @@ package http_test import ( + "io/fs" "log" "net/http" - "os" "strings" ) @@ -33,7 +33,7 @@ type dotFileHidingFile struct { // Readdir is a wrapper around the Readdir method of the embedded File // that filters out all files that start with a period in their name. -func (f dotFileHidingFile) Readdir(n int) (fis []os.FileInfo, err error) { +func (f dotFileHidingFile) Readdir(n int) (fis []fs.FileInfo, err error) { files, err := f.File.Readdir(n) for _, file := range files { // Filters out the dot files if !strings.HasPrefix(file.Name(), ".") { @@ -52,12 +52,12 @@ type dotFileHidingFileSystem struct { // Open is a wrapper around the Open method of the embedded FileSystem // that serves a 403 permission error when name has a file or directory // with whose name starts with a period in its path. -func (fs dotFileHidingFileSystem) Open(name string) (http.File, error) { +func (fsys dotFileHidingFileSystem) Open(name string) (http.File, error) { if containsDotFile(name) { // If dot file, return 403 response - return nil, os.ErrPermission + return nil, fs.ErrPermission } - file, err := fs.FileSystem.Open(name) + file, err := fsys.FileSystem.Open(name) if err != nil { return nil, err } @@ -65,7 +65,7 @@ func (fs dotFileHidingFileSystem) Open(name string) (http.File, error) { } func ExampleFileServer_dotFileHiding() { - fs := dotFileHidingFileSystem{http.Dir(".")} - http.Handle("/", http.FileServer(fs)) + fsys := dotFileHidingFileSystem{http.Dir(".")} + http.Handle("/", http.FileServer(fsys)) log.Fatal(http.ListenAndServe(":8080", nil)) } diff --git a/src/net/http/example_test.go b/src/net/http/example_test.go index a783b46618..c677d52238 100644 --- a/src/net/http/example_test.go +++ b/src/net/http/example_test.go @@ -8,7 +8,6 @@ import ( "context" "fmt" "io" - "io/ioutil" "log" "net/http" "os" @@ -46,7 +45,7 @@ func ExampleGet() { if err != nil { log.Fatal(err) } - robots, err := ioutil.ReadAll(res.Body) + robots, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { log.Fatal(err) diff --git a/src/net/http/fcgi/child.go b/src/net/http/fcgi/child.go index 34761f32ee..e97b8440e1 100644 --- a/src/net/http/fcgi/child.go +++ b/src/net/http/fcgi/child.go @@ -11,7 +11,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net" "net/http" "net/http/cgi" @@ -186,7 +185,7 @@ func (c *child) serve() { var errCloseConn = errors.New("fcgi: connection should be closed") -var emptyBody = ioutil.NopCloser(strings.NewReader("")) +var emptyBody = io.NopCloser(strings.NewReader("")) // ErrRequestAborted is returned by Read when a handler attempts to read the // body of a request that has been aborted by the web server. @@ -325,7 +324,7 @@ func (c *child) serveRequest(req *request, body io.ReadCloser) { // some sort of abort request to the host, so the host // can properly cut off the client sending all the data. // For now just bound it a little and - io.CopyN(ioutil.Discard, body, 100<<20) + io.CopyN(io.Discard, body, 100<<20) body.Close() if !req.keepConn { diff --git a/src/net/http/fcgi/fcgi_test.go b/src/net/http/fcgi/fcgi_test.go index 4a27a12c35..d3b704f821 100644 --- a/src/net/http/fcgi/fcgi_test.go +++ b/src/net/http/fcgi/fcgi_test.go @@ -8,7 +8,6 @@ import ( "bytes" "errors" "io" - "io/ioutil" "net/http" "strings" "testing" @@ -243,7 +242,7 @@ func TestChildServeCleansUp(t *testing.T) { r *http.Request, ) { // block on reading body of request - _, err := io.Copy(ioutil.Discard, r.Body) + _, err := io.Copy(io.Discard, r.Body) if err != tt.err { t.Errorf("Expected %#v, got %#v", tt.err, err) } @@ -275,7 +274,7 @@ func TestMalformedParams(t *testing.T) { // end of params 1, 4, 0, 1, 0, 0, 0, 0, } - rw := rwNopCloser{bytes.NewReader(input), ioutil.Discard} + rw := rwNopCloser{bytes.NewReader(input), io.Discard} c := newChild(rw, http.DefaultServeMux) c.serve() } diff --git a/src/net/http/filetransport_test.go b/src/net/http/filetransport_test.go index 2a2f32c769..fdfd44d967 100644 --- a/src/net/http/filetransport_test.go +++ b/src/net/http/filetransport_test.go @@ -5,6 +5,7 @@ package http import ( + "io" "io/ioutil" "os" "path/filepath" @@ -48,7 +49,7 @@ func TestFileTransport(t *testing.T) { if res.Body == nil { t.Fatalf("for %s, nil Body", urlstr) } - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) res.Body.Close() check("ReadAll "+urlstr, err) if string(slurp) != "Bar" { diff --git a/src/net/http/fs.go b/src/net/http/fs.go index d718fffba0..a28ae85958 100644 --- a/src/net/http/fs.go +++ b/src/net/http/fs.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + "io/fs" "mime" "mime/multipart" "net/textproto" @@ -43,7 +44,7 @@ type Dir string // mapDirOpenError maps the provided non-nil error from opening name // to a possibly better non-nil error. In particular, it turns OS-specific errors -// about opening files in non-directories into os.ErrNotExist. See Issue 18984. +// about opening files in non-directories into fs.ErrNotExist. See Issue 18984. func mapDirOpenError(originalErr error, name string) error { if os.IsNotExist(originalErr) || os.IsPermission(originalErr) { return originalErr @@ -59,7 +60,7 @@ func mapDirOpenError(originalErr error, name string) error { return originalErr } if !fi.IsDir() { - return os.ErrNotExist + return fs.ErrNotExist } } return originalErr @@ -86,6 +87,10 @@ func (d Dir) Open(name string) (File, error) { // A FileSystem implements access to a collection of named files. // The elements in a file path are separated by slash ('/', U+002F) // characters, regardless of host operating system convention. +// See the FileServer function to convert a FileSystem to a Handler. +// +// This interface predates the fs.FS interface, which can be used instead: +// the FS adapter function converts an fs.FS to a FileSystem. type FileSystem interface { Open(name string) (File, error) } @@ -98,24 +103,56 @@ type File interface { io.Closer io.Reader io.Seeker - Readdir(count int) ([]os.FileInfo, error) - Stat() (os.FileInfo, error) + Readdir(count int) ([]fs.FileInfo, error) + Stat() (fs.FileInfo, error) +} + +type anyDirs interface { + len() int + name(i int) string + isDir(i int) bool } +type fileInfoDirs []fs.FileInfo + +func (d fileInfoDirs) len() int { return len(d) } +func (d fileInfoDirs) isDir(i int) bool { return d[i].IsDir() } +func (d fileInfoDirs) name(i int) string { return d[i].Name() } + +type dirEntryDirs []fs.DirEntry + +func (d dirEntryDirs) len() int { return len(d) } +func (d dirEntryDirs) isDir(i int) bool { return d[i].IsDir() } +func (d dirEntryDirs) name(i int) string { return d[i].Name() } + func dirList(w ResponseWriter, r *Request, f File) { - dirs, err := f.Readdir(-1) + // Prefer to use ReadDir instead of Readdir, + // because the former doesn't require calling + // Stat on every entry of a directory on Unix. + var dirs anyDirs + var err error + if d, ok := f.(fs.ReadDirFile); ok { + var list dirEntryDirs + list, err = d.ReadDir(-1) + dirs = list + } else { + var list fileInfoDirs + list, err = f.Readdir(-1) + dirs = list + } + if err != nil { logf(r, "http: error reading directory: %v", err) Error(w, "Error reading directory", StatusInternalServerError) return } - sort.Slice(dirs, func(i, j int) bool { return dirs[i].Name() < dirs[j].Name() }) + sort.Slice(dirs, func(i, j int) bool { return dirs.name(i) < dirs.name(j) }) w.Header().Set("Content-Type", "text/html; charset=utf-8") fmt.Fprintf(w, "<pre>\n") - for _, d := range dirs { - name := d.Name() - if d.IsDir() { + for i, n := 0, dirs.len(); i < n; i++ { + name := dirs.name(i) + if dirs.isDir(i) { name += "/" } // name may contain '?' or '#', which must be escaped to remain @@ -706,17 +743,98 @@ type fileHandler struct { root FileSystem } +type ioFS struct { + fsys fs.FS +} + +type ioFile struct { + file fs.File +} + +func (f ioFS) Open(name string) (File, error) { + if name == "/" { + name = "." + } else { + name = strings.TrimPrefix(name, "/") + } + file, err := f.fsys.Open(name) + if err != nil { + return nil, err + } + return ioFile{file}, nil +} + +func (f ioFile) Close() error { return f.file.Close() } +func (f ioFile) Read(b []byte) (int, error) { return f.file.Read(b) } +func (f ioFile) Stat() (fs.FileInfo, error) { return f.file.Stat() } + +var errMissingSeek = errors.New("io.File missing Seek method") +var errMissingReadDir = errors.New("io.File directory missing ReadDir method") + +func (f ioFile) Seek(offset int64, whence int) (int64, error) { + s, ok := f.file.(io.Seeker) + if !ok { + return 0, errMissingSeek + } + return s.Seek(offset, whence) +} + +func (f ioFile) ReadDir(count int) ([]fs.DirEntry, error) { + d, ok := f.file.(fs.ReadDirFile) + if !ok { + return nil, errMissingReadDir + } + return d.ReadDir(count) +} + +func (f ioFile) Readdir(count int) ([]fs.FileInfo, error) { + d, ok := f.file.(fs.ReadDirFile) + if !ok { + return nil, errMissingReadDir + } + var list []fs.FileInfo + for { + dirs, err := d.ReadDir(count - len(list)) + for _, dir := range dirs { + info, err := dir.Info() + if err != nil { + // Pretend it doesn't exist, like (*os.File).Readdir does. + continue + } + list = append(list, info) + } + if err != nil { + return list, err + } + if count < 0 || len(list) >= count { + break + } + } + return list, nil +} + +// FS converts fsys to a FileSystem implementation, +// for use with FileServer and NewFileTransport. +func FS(fsys fs.FS) FileSystem { + return ioFS{fsys} +} + // FileServer returns a handler that serves HTTP requests // with the contents of the file system rooted at root. // +// As a special case, the returned file server redirects any request +// ending in "/index.html" to the same path, without the final +// "index.html". +// // To use the operating system's file system implementation, // use http.Dir: // // http.Handle("/", http.FileServer(http.Dir("/tmp"))) // -// As a special case, the returned file server redirects any request -// ending in "/index.html" to the same path, without the final -// "index.html". +// To use an fs.FS implementation, use http.FS to convert it: +// +// http.Handle("/", http.FileServer(http.FS(fsys))) +// func FileServer(root FileSystem) Handler { return &fileHandler{root} } diff --git a/src/net/http/fs_test.go b/src/net/http/fs_test.go index 4ac73b728f..2e4751114d 100644 --- a/src/net/http/fs_test.go +++ b/src/net/http/fs_test.go @@ -10,6 +10,7 @@ import ( "errors" "fmt" "io" + "io/fs" "io/ioutil" "mime" "mime/multipart" @@ -159,7 +160,7 @@ Cases: if g, w := part.Header.Get("Content-Range"), wantContentRange; g != w { t.Errorf("range=%q: part Content-Range = %q; want %q", rt.r, g, w) } - body, err := ioutil.ReadAll(part) + body, err := io.ReadAll(part) if err != nil { t.Errorf("range=%q, reading part index %d body: %v", rt.r, ri, err) continue Cases @@ -311,7 +312,7 @@ func TestFileServerEscapesNames(t *testing.T) { if err != nil { t.Fatalf("test %q: Get: %v", test.name, err) } - b, err := ioutil.ReadAll(res.Body) + b, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("test %q: read Body: %v", test.name, err) } @@ -359,7 +360,7 @@ func TestFileServerSortsNames(t *testing.T) { } defer res.Body.Close() - b, err := ioutil.ReadAll(res.Body) + b, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("read Body: %v", err) } @@ -393,7 +394,7 @@ func TestFileServerImplicitLeadingSlash(t *testing.T) { if err != nil { t.Fatalf("Get %s: %v", suffix, err) } - b, err := ioutil.ReadAll(res.Body) + b, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("ReadAll %s: %v", suffix, err) } @@ -570,6 +571,43 @@ func testServeFileWithContentEncoding(t *testing.T, h2 bool) { func TestServeIndexHtml(t *testing.T) { defer afterTest(t) + + for i := 0; i < 2; i++ { + var h Handler + var name string + switch i { + case 0: + h = FileServer(Dir(".")) + name = "Dir" + case 1: + h = FileServer(FS(os.DirFS("."))) + name = "DirFS" + } + t.Run(name, func(t *testing.T) { + const want = "index.html says hello\n" + ts := httptest.NewServer(h) + defer ts.Close() + + for _, path := range []string{"/testdata/", "/testdata/index.html"} { + res, err := Get(ts.URL + path) + if err != nil { + t.Fatal(err) + } + b, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal("reading Body:", err) + } + if s := string(b); s != want { + t.Errorf("for path %q got %q, want %q", path, s, want) + } + res.Body.Close() + } + }) + } +} + +func TestServeIndexHtmlFS(t *testing.T) { + defer afterTest(t) const want = "index.html says hello\n" ts := httptest.NewServer(FileServer(Dir("."))) defer ts.Close() @@ -579,7 +617,7 @@ func TestServeIndexHtml(t *testing.T) { if err != nil { t.Fatal(err) } - b, err := ioutil.ReadAll(res.Body) + b, err := io.ReadAll(res.Body) if err != nil { t.Fatal("reading Body:", err) } @@ -629,9 +667,9 @@ func (f *fakeFileInfo) Sys() interface{} { return nil } func (f *fakeFileInfo) ModTime() time.Time { return f.modtime } func (f *fakeFileInfo) IsDir() bool { return f.dir } func (f *fakeFileInfo) Size() int64 { return int64(len(f.contents)) } -func (f *fakeFileInfo) Mode() os.FileMode { +func (f *fakeFileInfo) Mode() fs.FileMode { if f.dir { - return 0755 | os.ModeDir + return 0755 | fs.ModeDir } return 0644 } @@ -644,12 +682,12 @@ type fakeFile struct { } func (f *fakeFile) Close() error { return nil } -func (f *fakeFile) Stat() (os.FileInfo, error) { return f.fi, nil } -func (f *fakeFile) Readdir(count int) ([]os.FileInfo, error) { +func (f *fakeFile) Stat() (fs.FileInfo, error) { return f.fi, nil } +func (f *fakeFile) Readdir(count int) ([]fs.FileInfo, error) { if !f.fi.dir { - return nil, os.ErrInvalid + return nil, fs.ErrInvalid } - var fis []os.FileInfo + var fis []fs.FileInfo limit := f.entpos + count if count <= 0 || limit > len(f.fi.ents) { @@ -668,11 +706,11 @@ func (f *fakeFile) Readdir(count int) ([]os.FileInfo, error) { type fakeFS map[string]*fakeFileInfo -func (fs fakeFS) Open(name string) (File, error) { +func (fsys fakeFS) Open(name string) (File, error) { name = path.Clean(name) - f, ok := fs[name] + f, ok := fsys[name] if !ok { - return nil, os.ErrNotExist + return nil, fs.ErrNotExist } if f.err != nil { return nil, f.err @@ -707,7 +745,7 @@ func TestDirectoryIfNotModified(t *testing.T) { if err != nil { t.Fatal(err) } - b, err := ioutil.ReadAll(res.Body) + b, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -747,7 +785,7 @@ func TestDirectoryIfNotModified(t *testing.T) { res.Body.Close() } -func mustStat(t *testing.T, fileName string) os.FileInfo { +func mustStat(t *testing.T, fileName string) fs.FileInfo { fi, err := os.Stat(fileName) if err != nil { t.Fatal(err) @@ -1044,7 +1082,7 @@ func TestServeContent(t *testing.T) { if err != nil { t.Fatal(err) } - io.Copy(ioutil.Discard, res.Body) + io.Copy(io.Discard, res.Body) res.Body.Close() if res.StatusCode != tt.wantStatus { t.Errorf("test %q using %q: got status = %d; want %d", testName, method, res.StatusCode, tt.wantStatus) @@ -1081,7 +1119,7 @@ func (issue12991FS) Open(string) (File, error) { return issue12991File{}, nil } type issue12991File struct{ File } -func (issue12991File) Stat() (os.FileInfo, error) { return nil, os.ErrPermission } +func (issue12991File) Stat() (fs.FileInfo, error) { return nil, fs.ErrPermission } func (issue12991File) Close() error { return nil } func TestServeContentErrorMessages(t *testing.T) { @@ -1091,7 +1129,7 @@ func TestServeContentErrorMessages(t *testing.T) { err: errors.New("random error"), }, "/403": &fakeFileInfo{ - err: &os.PathError{Err: os.ErrPermission}, + err: &fs.PathError{Err: fs.ErrPermission}, }, } ts := httptest.NewServer(FileServer(fs)) @@ -1158,7 +1196,7 @@ func TestLinuxSendfile(t *testing.T) { if err != nil { t.Fatalf("http client error: %v", err) } - _, err = io.Copy(ioutil.Discard, res.Body) + _, err = io.Copy(io.Discard, res.Body) if err != nil { t.Fatalf("client body read error: %v", err) } @@ -1180,7 +1218,7 @@ func getBody(t *testing.T, testName string, req Request, client *Client) (*Respo if err != nil { t.Fatalf("%s: for URL %q, send error: %v", testName, req.URL.String(), err) } - b, err := ioutil.ReadAll(r.Body) + b, err := io.ReadAll(r.Body) if err != nil { t.Fatalf("%s: for URL %q, reading body: %v", testName, req.URL.String(), err) } @@ -1289,7 +1327,7 @@ func (d fileServerCleanPathDir) Open(path string) (File, error) { // Just return back something that's a directory. return Dir(".").Open(".") } - return nil, os.ErrNotExist + return nil, fs.ErrNotExist } type panicOnSeek struct{ io.ReadSeeker } @@ -1363,7 +1401,7 @@ func testServeFileRejectsInvalidSuffixLengths(t *testing.T, h2 bool) { if g, w := res.StatusCode, tt.wantCode; g != w { t.Errorf("StatusCode mismatch: got %d want %d", g, w) } - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { t.Fatal(err) diff --git a/src/net/http/http_test.go b/src/net/http/http_test.go index 49c2b4196a..3f1d7cee71 100644 --- a/src/net/http/http_test.go +++ b/src/net/http/http_test.go @@ -13,13 +13,8 @@ import ( "os/exec" "reflect" "testing" - "time" ) -func init() { - shutdownPollInterval = 5 * time.Millisecond -} - func TestForeachHeaderElement(t *testing.T) { tests := []struct { in string diff --git a/src/net/http/httptest/example_test.go b/src/net/http/httptest/example_test.go index 54e77dbb84..a6738432eb 100644 --- a/src/net/http/httptest/example_test.go +++ b/src/net/http/httptest/example_test.go @@ -7,7 +7,6 @@ package httptest_test import ( "fmt" "io" - "io/ioutil" "log" "net/http" "net/http/httptest" @@ -23,7 +22,7 @@ func ExampleResponseRecorder() { handler(w, req) resp := w.Result() - body, _ := ioutil.ReadAll(resp.Body) + body, _ := io.ReadAll(resp.Body) fmt.Println(resp.StatusCode) fmt.Println(resp.Header.Get("Content-Type")) @@ -45,7 +44,7 @@ func ExampleServer() { if err != nil { log.Fatal(err) } - greeting, err := ioutil.ReadAll(res.Body) + greeting, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { log.Fatal(err) @@ -67,7 +66,7 @@ func ExampleServer_hTTP2() { if err != nil { log.Fatal(err) } - greeting, err := ioutil.ReadAll(res.Body) + greeting, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { log.Fatal(err) @@ -89,7 +88,7 @@ func ExampleNewTLSServer() { log.Fatal(err) } - greeting, err := ioutil.ReadAll(res.Body) + greeting, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { log.Fatal(err) diff --git a/src/net/http/httptest/httptest.go b/src/net/http/httptest/httptest.go index f7202da92f..9bedefd2bc 100644 --- a/src/net/http/httptest/httptest.go +++ b/src/net/http/httptest/httptest.go @@ -10,7 +10,6 @@ import ( "bytes" "crypto/tls" "io" - "io/ioutil" "net/http" "strings" ) @@ -66,7 +65,7 @@ func NewRequest(method, target string, body io.Reader) *http.Request { if rc, ok := body.(io.ReadCloser); ok { req.Body = rc } else { - req.Body = ioutil.NopCloser(body) + req.Body = io.NopCloser(body) } } diff --git a/src/net/http/httptest/httptest_test.go b/src/net/http/httptest/httptest_test.go index ef7d943837..071add67ea 100644 --- a/src/net/http/httptest/httptest_test.go +++ b/src/net/http/httptest/httptest_test.go @@ -7,7 +7,6 @@ package httptest import ( "crypto/tls" "io" - "io/ioutil" "net/http" "net/url" "reflect" @@ -155,7 +154,7 @@ func TestNewRequest(t *testing.T) { } { t.Run(tt.name, func(t *testing.T) { got := NewRequest(tt.method, tt.uri, tt.body) - slurp, err := ioutil.ReadAll(got.Body) + slurp, err := io.ReadAll(got.Body) if err != nil { t.Errorf("ReadAll: %v", err) } diff --git a/src/net/http/httptest/recorder.go b/src/net/http/httptest/recorder.go index 66e67e78b3..2428482612 100644 --- a/src/net/http/httptest/recorder.go +++ b/src/net/http/httptest/recorder.go @@ -7,7 +7,7 @@ package httptest import ( "bytes" "fmt" - "io/ioutil" + "io" "net/http" "net/textproto" "strconv" @@ -179,7 +179,7 @@ func (rw *ResponseRecorder) Result() *http.Response { } res.Status = fmt.Sprintf("%03d %s", res.StatusCode, http.StatusText(res.StatusCode)) if rw.Body != nil { - res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes())) + res.Body = io.NopCloser(bytes.NewReader(rw.Body.Bytes())) } else { res.Body = http.NoBody } diff --git a/src/net/http/httptest/recorder_test.go b/src/net/http/httptest/recorder_test.go index e9534894b6..a865e878b9 100644 --- a/src/net/http/httptest/recorder_test.go +++ b/src/net/http/httptest/recorder_test.go @@ -7,7 +7,6 @@ package httptest import ( "fmt" "io" - "io/ioutil" "net/http" "testing" ) @@ -42,7 +41,7 @@ func TestRecorder(t *testing.T) { } hasResultContents := func(want string) checkFunc { return func(rec *ResponseRecorder) error { - contentBytes, err := ioutil.ReadAll(rec.Result().Body) + contentBytes, err := io.ReadAll(rec.Result().Body) if err != nil { return err } diff --git a/src/net/http/httptest/server_test.go b/src/net/http/httptest/server_test.go index 0aad15c5ed..39568b358c 100644 --- a/src/net/http/httptest/server_test.go +++ b/src/net/http/httptest/server_test.go @@ -6,7 +6,7 @@ package httptest import ( "bufio" - "io/ioutil" + "io" "net" "net/http" "testing" @@ -61,7 +61,7 @@ func testServer(t *testing.T, newServer newServerFunc) { if err != nil { t.Fatal(err) } - got, err := ioutil.ReadAll(res.Body) + got, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { t.Fatal(err) @@ -81,7 +81,7 @@ func testGetAfterClose(t *testing.T, newServer newServerFunc) { if err != nil { t.Fatal(err) } - got, err := ioutil.ReadAll(res.Body) + got, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -93,7 +93,7 @@ func testGetAfterClose(t *testing.T, newServer newServerFunc) { res, err = http.Get(ts.URL) if err == nil { - body, _ := ioutil.ReadAll(res.Body) + body, _ := io.ReadAll(res.Body) t.Fatalf("Unexpected response after close: %v, %v, %s", res.Status, res.Header, body) } } @@ -152,7 +152,7 @@ func testServerClient(t *testing.T, newTLSServer newServerFunc) { if err != nil { t.Fatal(err) } - got, err := ioutil.ReadAll(res.Body) + got, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { t.Fatal(err) diff --git a/src/net/http/httputil/dump.go b/src/net/http/httputil/dump.go index c97be066d7..4c9d28bed8 100644 --- a/src/net/http/httputil/dump.go +++ b/src/net/http/httputil/dump.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net" "net/http" "net/url" @@ -35,7 +34,7 @@ func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) { if err = b.Close(); err != nil { return nil, b, err } - return ioutil.NopCloser(&buf), ioutil.NopCloser(bytes.NewReader(buf.Bytes())), nil + return io.NopCloser(&buf), io.NopCloser(bytes.NewReader(buf.Bytes())), nil } // dumpConn is a net.Conn which writes to Writer and reads from Reader @@ -81,7 +80,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) { if !body { contentLength := outgoingLength(req) if contentLength != 0 { - req.Body = ioutil.NopCloser(io.LimitReader(neverEnding('x'), contentLength)) + req.Body = io.NopCloser(io.LimitReader(neverEnding('x'), contentLength)) dummyBody = true } } else { @@ -133,7 +132,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) { if err == nil { // Ensure all the body is read; otherwise // we'll get a partial dump. - io.Copy(ioutil.Discard, req.Body) + io.Copy(io.Discard, req.Body) req.Body.Close() } select { @@ -296,7 +295,7 @@ func (failureToReadBody) Read([]byte) (int, error) { return 0, errNoBody } func (failureToReadBody) Close() error { return nil } // emptyBody is an instance of empty reader. -var emptyBody = ioutil.NopCloser(strings.NewReader("")) +var emptyBody = io.NopCloser(strings.NewReader("")) // DumpResponse is like DumpRequest but dumps a response. func DumpResponse(resp *http.Response, body bool) ([]byte, error) { diff --git a/src/net/http/httputil/dump_test.go b/src/net/http/httputil/dump_test.go index ead56bc172..7571eb0820 100644 --- a/src/net/http/httputil/dump_test.go +++ b/src/net/http/httputil/dump_test.go @@ -9,7 +9,6 @@ import ( "bytes" "fmt" "io" - "io/ioutil" "net/http" "net/url" "runtime" @@ -268,7 +267,7 @@ func TestDumpRequest(t *testing.T) { } switch b := ti.Body.(type) { case []byte: - req.Body = ioutil.NopCloser(bytes.NewReader(b)) + req.Body = io.NopCloser(bytes.NewReader(b)) case func() io.ReadCloser: req.Body = b() default: @@ -363,7 +362,7 @@ var dumpResTests = []struct { Header: http.Header{ "Foo": []string{"Bar"}, }, - Body: ioutil.NopCloser(strings.NewReader("foo")), // shouldn't be used + Body: io.NopCloser(strings.NewReader("foo")), // shouldn't be used }, body: false, // to verify we see 50, not empty or 3. want: `HTTP/1.1 200 OK @@ -379,7 +378,7 @@ Foo: Bar`, ProtoMajor: 1, ProtoMinor: 1, ContentLength: 3, - Body: ioutil.NopCloser(strings.NewReader("foo")), + Body: io.NopCloser(strings.NewReader("foo")), }, body: true, want: `HTTP/1.1 200 OK @@ -396,7 +395,7 @@ foo`, ProtoMajor: 1, ProtoMinor: 1, ContentLength: -1, - Body: ioutil.NopCloser(strings.NewReader("foo")), + Body: io.NopCloser(strings.NewReader("foo")), TransferEncoding: []string{"chunked"}, }, body: true, diff --git a/src/net/http/httputil/example_test.go b/src/net/http/httputil/example_test.go index 6191603674..b77a243ca3 100644 --- a/src/net/http/httputil/example_test.go +++ b/src/net/http/httputil/example_test.go @@ -6,7 +6,7 @@ package httputil_test import ( "fmt" - "io/ioutil" + "io" "log" "net/http" "net/http/httptest" @@ -39,7 +39,7 @@ func ExampleDumpRequest() { } defer resp.Body.Close() - b, err := ioutil.ReadAll(resp.Body) + b, err := io.ReadAll(resp.Body) if err != nil { log.Fatal(err) } @@ -111,7 +111,7 @@ func ExampleReverseProxy() { log.Fatal(err) } - b, err := ioutil.ReadAll(resp.Body) + b, err := io.ReadAll(resp.Body) if err != nil { log.Fatal(err) } diff --git a/src/net/http/httputil/reverseproxy.go b/src/net/http/httputil/reverseproxy.go index 3f48fab544..4e369580ea 100644 --- a/src/net/http/httputil/reverseproxy.go +++ b/src/net/http/httputil/reverseproxy.go @@ -58,9 +58,9 @@ type ReverseProxy struct { // A negative value means to flush immediately // after each write to the client. // The FlushInterval is ignored when ReverseProxy - // recognizes a response as a streaming response; - // for such responses, writes are flushed to the client - // immediately. + // recognizes a response as a streaming response, or + // if its ContentLength is -1; for such responses, writes + // are flushed to the client immediately. FlushInterval time.Duration // ErrorLog specifies an optional logger for errors @@ -325,7 +325,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { rw.WriteHeader(res.StatusCode) - err = p.copyResponse(rw, res.Body, p.flushInterval(req, res)) + err = p.copyResponse(rw, res.Body, p.flushInterval(res)) if err != nil { defer res.Body.Close() // Since we're streaming the response, if we run into an error all we can do @@ -397,7 +397,7 @@ func removeConnectionHeaders(h http.Header) { // flushInterval returns the p.FlushInterval value, conditionally // overriding its value for a specific request/response. -func (p *ReverseProxy) flushInterval(req *http.Request, res *http.Response) time.Duration { +func (p *ReverseProxy) flushInterval(res *http.Response) time.Duration { resCT := res.Header.Get("Content-Type") // For Server-Sent Events responses, flush immediately. @@ -406,7 +406,11 @@ func (p *ReverseProxy) flushInterval(req *http.Request, res *http.Response) time return -1 // negative means immediately } - // TODO: more specific cases? e.g. res.ContentLength == -1? + // We might have the case of streaming for which Content-Length might be unset. + if res.ContentLength == -1 { + return -1 + } + return p.FlushInterval } @@ -545,8 +549,6 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R return } - copyHeader(res.Header, rw.Header()) - hj, ok := rw.(http.Hijacker) if !ok { p.getErrorHandler()(rw, req, fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)) @@ -577,6 +579,10 @@ func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.R return } defer conn.Close() + + copyHeader(rw.Header(), res.Header) + + res.Header = rw.Header() res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above if err := res.Write(brw); err != nil { p.getErrorHandler()(rw, req, fmt.Errorf("response write: %v", err)) diff --git a/src/net/http/httputil/reverseproxy_test.go b/src/net/http/httputil/reverseproxy_test.go index 764939fb0f..3acbd940e4 100644 --- a/src/net/http/httputil/reverseproxy_test.go +++ b/src/net/http/httputil/reverseproxy_test.go @@ -13,7 +13,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "log" "net/http" "net/http/httptest" @@ -84,7 +83,7 @@ func TestReverseProxy(t *testing.T) { t.Fatal(err) } proxyHandler := NewSingleHostReverseProxy(backendURL) - proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests frontend := httptest.NewServer(proxyHandler) defer frontend.Close() frontendClient := frontend.Client() @@ -124,7 +123,7 @@ func TestReverseProxy(t *testing.T) { if cookie := res.Cookies()[0]; cookie.Name != "flavor" { t.Errorf("unexpected cookie %q", cookie.Name) } - bodyBytes, _ := ioutil.ReadAll(res.Body) + bodyBytes, _ := io.ReadAll(res.Body) if g, e := string(bodyBytes), backendResponse; g != e { t.Errorf("got body %q; expected %q", g, e) } @@ -218,7 +217,7 @@ func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { t.Fatalf("Get: %v", err) } defer res.Body.Close() - bodyBytes, err := ioutil.ReadAll(res.Body) + bodyBytes, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("reading body: %v", err) } @@ -271,7 +270,7 @@ func TestXForwardedFor(t *testing.T) { if g, e := res.StatusCode, backendStatus; g != e { t.Errorf("got res.StatusCode %d; expected %d", g, e) } - bodyBytes, _ := ioutil.ReadAll(res.Body) + bodyBytes, _ := io.ReadAll(res.Body) if g, e := string(bodyBytes), backendResponse; g != e { t.Errorf("got body %q; expected %q", g, e) } @@ -373,7 +372,7 @@ func TestReverseProxyFlushInterval(t *testing.T) { t.Fatalf("Get: %v", err) } defer res.Body.Close() - if bodyBytes, _ := ioutil.ReadAll(res.Body); string(bodyBytes) != expected { + if bodyBytes, _ := io.ReadAll(res.Body); string(bodyBytes) != expected { t.Errorf("got body %q; expected %q", bodyBytes, expected) } } @@ -441,7 +440,7 @@ func TestReverseProxyCancellation(t *testing.T) { defer backend.Close() - backend.Config.ErrorLog = log.New(ioutil.Discard, "", 0) + backend.Config.ErrorLog = log.New(io.Discard, "", 0) backendURL, err := url.Parse(backend.URL) if err != nil { @@ -452,7 +451,7 @@ func TestReverseProxyCancellation(t *testing.T) { // Discards errors of the form: // http: proxy error: read tcp 127.0.0.1:44643: use of closed network connection - proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) frontend := httptest.NewServer(proxyHandler) defer frontend.Close() @@ -504,7 +503,7 @@ func TestNilBody(t *testing.T) { t.Fatal(err) } defer res.Body.Close() - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -533,7 +532,7 @@ func TestUserAgentHeader(t *testing.T) { t.Fatal(err) } proxyHandler := NewSingleHostReverseProxy(backendURL) - proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests frontend := httptest.NewServer(proxyHandler) defer frontend.Close() frontendClient := frontend.Client() @@ -606,7 +605,7 @@ func TestReverseProxyGetPutBuffer(t *testing.T) { if err != nil { t.Fatalf("Get: %v", err) } - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { t.Fatalf("reading body: %v", err) @@ -627,7 +626,7 @@ func TestReverseProxy_Post(t *testing.T) { const backendStatus = 200 var requestBody = bytes.Repeat([]byte("a"), 1<<20) backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - slurp, err := ioutil.ReadAll(r.Body) + slurp, err := io.ReadAll(r.Body) if err != nil { t.Errorf("Backend body read = %v", err) } @@ -656,7 +655,7 @@ func TestReverseProxy_Post(t *testing.T) { if g, e := res.StatusCode, backendStatus; g != e { t.Errorf("got res.StatusCode %d; expected %d", g, e) } - bodyBytes, _ := ioutil.ReadAll(res.Body) + bodyBytes, _ := io.ReadAll(res.Body) if g, e := string(bodyBytes), backendResponse; g != e { t.Errorf("got body %q; expected %q", g, e) } @@ -672,7 +671,7 @@ func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) func TestReverseProxy_NilBody(t *testing.T) { backendURL, _ := url.Parse("http://fake.tld/") proxyHandler := NewSingleHostReverseProxy(backendURL) - proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { if req.Body != nil { t.Error("Body != nil; want a nil Body") @@ -695,8 +694,8 @@ func TestReverseProxy_NilBody(t *testing.T) { // Issue 33142: always allocate the request headers func TestReverseProxy_AllocatedHeader(t *testing.T) { proxyHandler := new(ReverseProxy) - proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests - proxyHandler.Director = func(*http.Request) {} // noop + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests + proxyHandler.Director = func(*http.Request) {} // noop proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { if req.Header == nil { t.Error("Header == nil; want a non-nil Header") @@ -722,7 +721,7 @@ func TestReverseProxyModifyResponse(t *testing.T) { rpURL, _ := url.Parse(backendServer.URL) rproxy := NewSingleHostReverseProxy(rpURL) - rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests rproxy.ModifyResponse = func(resp *http.Response) error { if resp.Header.Get("X-Hit-Mod") != "true" { return fmt.Errorf("tried to by-pass proxy") @@ -821,7 +820,7 @@ func TestReverseProxyErrorHandler(t *testing.T) { if rproxy.Transport == nil { rproxy.Transport = failingRoundTripper{} } - rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests if tt.errorHandler != nil { rproxy.ErrorHandler = tt.errorHandler } @@ -896,7 +895,7 @@ func (t *staticTransport) RoundTrip(r *http.Request) (*http.Response, error) { func BenchmarkServeHTTP(b *testing.B) { res := &http.Response{ StatusCode: 200, - Body: ioutil.NopCloser(strings.NewReader("")), + Body: io.NopCloser(strings.NewReader("")), } proxy := &ReverseProxy{ Director: func(*http.Request) {}, @@ -953,7 +952,7 @@ func TestServeHTTPDeepCopy(t *testing.T) { // Issue 18327: verify we always do a deep copy of the Request.Header map // before any mutations. func TestClonesRequestHeaders(t *testing.T) { - log.SetOutput(ioutil.Discard) + log.SetOutput(io.Discard) defer log.SetOutput(os.Stderr) req, _ := http.NewRequest("GET", "http://foo.tld/", nil) req.RemoteAddr = "1.2.3.4:56789" @@ -1031,7 +1030,7 @@ func (cc *checkCloser) Read(b []byte) (int, error) { // Issue 23643: panic on body copy error func TestReverseProxy_PanicBodyError(t *testing.T) { - log.SetOutput(ioutil.Discard) + log.SetOutput(io.Discard) defer log.SetOutput(os.Stderr) backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { out := "this call was relayed by the reverse proxy" @@ -1067,7 +1066,6 @@ func TestSelectFlushInterval(t *testing.T) { tests := []struct { name string p *ReverseProxy - req *http.Request res *http.Response want time.Duration }{ @@ -1097,10 +1095,26 @@ func TestSelectFlushInterval(t *testing.T) { p: &ReverseProxy{FlushInterval: 0}, want: -1, }, + { + name: "Content-Length: -1, overrides non-zero", + res: &http.Response{ + ContentLength: -1, + }, + p: &ReverseProxy{FlushInterval: 123}, + want: -1, + }, + { + name: "Content-Length: -1, overrides zero", + res: &http.Response{ + ContentLength: -1, + }, + p: &ReverseProxy{FlushInterval: 0}, + want: -1, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - got := tt.p.flushInterval(tt.req, tt.res) + got := tt.p.flushInterval(tt.res) if got != tt.want { t.Errorf("flushLatency = %v; want %v", got, tt.want) } @@ -1133,7 +1147,7 @@ func TestReverseProxyWebSocket(t *testing.T) { backURL, _ := url.Parse(backendServer.URL) rproxy := NewSingleHostReverseProxy(backURL) - rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests rproxy.ModifyResponse = func(res *http.Response) error { res.Header.Add("X-Modified", "true") return nil @@ -1142,6 +1156,9 @@ func TestReverseProxyWebSocket(t *testing.T) { handler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { rw.Header().Set("X-Header", "X-Value") rproxy.ServeHTTP(rw, req) + if got, want := rw.Header().Get("X-Modified"), "true"; got != want { + t.Errorf("response writer X-Modified header = %q; want %q", got, want) + } }) frontendProxy := httptest.NewServer(handler) @@ -1247,7 +1264,7 @@ func TestReverseProxyWebSocketCancelation(t *testing.T) { backendURL, _ := url.Parse(cst.URL) rproxy := NewSingleHostReverseProxy(backendURL) - rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + rproxy.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests rproxy.ModifyResponse = func(res *http.Response) error { res.Header.Add("X-Modified", "true") return nil @@ -1334,7 +1351,7 @@ func TestUnannouncedTrailer(t *testing.T) { t.Fatal(err) } proxyHandler := NewSingleHostReverseProxy(backendURL) - proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + proxyHandler.ErrorLog = log.New(io.Discard, "", 0) // quiet for tests frontend := httptest.NewServer(proxyHandler) defer frontend.Close() frontendClient := frontend.Client() @@ -1344,7 +1361,7 @@ func TestUnannouncedTrailer(t *testing.T) { t.Fatalf("Get: %v", err) } - ioutil.ReadAll(res.Body) + io.ReadAll(res.Body) if g, w := res.Trailer.Get("X-Unannounced-Trailer"), "unannounced_trailer_value"; g != w { t.Errorf("Trailer(X-Unannounced-Trailer) = %q; want %q", g, w) diff --git a/src/net/http/internal/chunked_test.go b/src/net/http/internal/chunked_test.go index d06716591a..08152ed1e2 100644 --- a/src/net/http/internal/chunked_test.go +++ b/src/net/http/internal/chunked_test.go @@ -9,7 +9,6 @@ import ( "bytes" "fmt" "io" - "io/ioutil" "strings" "testing" ) @@ -29,7 +28,7 @@ func TestChunk(t *testing.T) { } r := NewChunkedReader(&b) - data, err := ioutil.ReadAll(r) + data, err := io.ReadAll(r) if err != nil { t.Logf(`data: "%s"`, data) t.Fatalf("ReadAll from reader: %v", err) @@ -177,7 +176,7 @@ func TestChunkReadingIgnoresExtensions(t *testing.T) { "17;someext\r\n" + // token without value "world! 0123456789abcdef\r\n" + "0;someextension=sometoken\r\n" // token=token - data, err := ioutil.ReadAll(NewChunkedReader(strings.NewReader(in))) + data, err := io.ReadAll(NewChunkedReader(strings.NewReader(in))) if err != nil { t.Fatalf("ReadAll = %q, %v", data, err) } diff --git a/src/net/http/main_test.go b/src/net/http/main_test.go index 35cc80977c..6564627998 100644 --- a/src/net/http/main_test.go +++ b/src/net/http/main_test.go @@ -6,7 +6,7 @@ package http_test import ( "fmt" - "io/ioutil" + "io" "log" "net/http" "os" @@ -17,7 +17,7 @@ import ( "time" ) -var quietLog = log.New(ioutil.Discard, "", 0) +var quietLog = log.New(io.Discard, "", 0) func TestMain(m *testing.M) { v := m.Run() diff --git a/src/net/http/pprof/pprof.go b/src/net/http/pprof/pprof.go index 81df0448e9..2bfcfb9545 100644 --- a/src/net/http/pprof/pprof.go +++ b/src/net/http/pprof/pprof.go @@ -61,11 +61,12 @@ import ( "bytes" "context" "fmt" - "html/template" + "html" "internal/profile" "io" "log" "net/http" + "net/url" "os" "runtime" "runtime/pprof" @@ -93,14 +94,10 @@ func Cmdline(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, strings.Join(os.Args, "\x00")) } -func sleep(w http.ResponseWriter, d time.Duration) { - var clientGone <-chan bool - if cn, ok := w.(http.CloseNotifier); ok { - clientGone = cn.CloseNotify() - } +func sleep(r *http.Request, d time.Duration) { select { case <-time.After(d): - case <-clientGone: + case <-r.Context().Done(): } } @@ -142,7 +139,7 @@ func Profile(w http.ResponseWriter, r *http.Request) { fmt.Sprintf("Could not enable CPU profiling: %s", err)) return } - sleep(w, time.Duration(sec)*time.Second) + sleep(r, time.Duration(sec)*time.Second) pprof.StopCPUProfile() } @@ -171,7 +168,7 @@ func Trace(w http.ResponseWriter, r *http.Request) { fmt.Sprintf("Could not enable tracing: %s", err)) return } - sleep(w, time.Duration(sec*float64(time.Second))) + sleep(r, time.Duration(sec*float64(time.Second))) trace.Stop() } @@ -356,6 +353,13 @@ var profileDescriptions = map[string]string{ "trace": "A trace of execution of the current program. You can specify the duration in the seconds GET parameter. After you get the trace file, use the go tool trace command to investigate the trace.", } +type profileEntry struct { + Name string + Href string + Desc string + Count int +} + // Index responds with the pprof-formatted profile named by the request. // For example, "/debug/pprof/heap" serves the "heap" profile. // Index responds to a request for "/debug/pprof/" with an HTML page @@ -372,17 +376,11 @@ func Index(w http.ResponseWriter, r *http.Request) { w.Header().Set("X-Content-Type-Options", "nosniff") w.Header().Set("Content-Type", "text/html; charset=utf-8") - type profile struct { - Name string - Href string - Desc string - Count int - } - var profiles []profile + var profiles []profileEntry for _, p := range pprof.Profiles() { - profiles = append(profiles, profile{ + profiles = append(profiles, profileEntry{ Name: p.Name(), - Href: p.Name() + "?debug=1", + Href: p.Name(), Desc: profileDescriptions[p.Name()], Count: p.Count(), }) @@ -390,7 +388,7 @@ func Index(w http.ResponseWriter, r *http.Request) { // Adding other profiles exposed from within this package for _, p := range []string{"cmdline", "profile", "trace"} { - profiles = append(profiles, profile{ + profiles = append(profiles, profileEntry{ Name: p, Href: p, Desc: profileDescriptions[p], @@ -401,12 +399,14 @@ func Index(w http.ResponseWriter, r *http.Request) { return profiles[i].Name < profiles[j].Name }) - if err := indexTmpl.Execute(w, profiles); err != nil { + if err := indexTmplExecute(w, profiles); err != nil { log.Print(err) } } -var indexTmpl = template.Must(template.New("index").Parse(`<html> +func indexTmplExecute(w io.Writer, profiles []profileEntry) error { + var b bytes.Buffer + b.WriteString(`<html> <head> <title>/debug/pprof/</title> <style> @@ -422,22 +422,28 @@ var indexTmpl = template.Must(template.New("index").Parse(`<html> Types of profiles available: <table> <thead><td>Count</td><td>Profile</td></thead> -{{range .}} - <tr> - <td>{{.Count}}</td><td><a href={{.Href}}>{{.Name}}</a></td> - </tr> -{{end}} -</table> +`) + + for _, profile := range profiles { + link := &url.URL{Path: profile.Href, RawQuery: "debug=1"} + fmt.Fprintf(&b, "<tr><td>%d</td><td><a href='%s'>%s</a></td></tr>\n", profile.Count, link, html.EscapeString(profile.Name)) + } + + b.WriteString(`</table> <a href="goroutine?debug=2">full goroutine stack dump</a> <br/> <p> Profile Descriptions: <ul> -{{range .}} -<li><div class=profile-name>{{.Name}}:</div> {{.Desc}}</li> -{{end}} -</ul> +`) + for _, profile := range profiles { + fmt.Fprintf(&b, "<li><div class=profile-name>%s: </div> %s</li>\n", html.EscapeString(profile.Name), html.EscapeString(profile.Desc)) + } + b.WriteString(`</ul> </p> </body> -</html> -`)) +</html>`) + + _, err := w.Write(b.Bytes()) + return err +} diff --git a/src/net/http/pprof/pprof_test.go b/src/net/http/pprof/pprof_test.go index f6f9ef5b04..84757e401a 100644 --- a/src/net/http/pprof/pprof_test.go +++ b/src/net/http/pprof/pprof_test.go @@ -8,7 +8,7 @@ import ( "bytes" "fmt" "internal/profile" - "io/ioutil" + "io" "net/http" "net/http/httptest" "runtime" @@ -63,7 +63,7 @@ func TestHandlers(t *testing.T) { t.Errorf("status code: got %d; want %d", got, want) } - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { t.Errorf("when reading response body, expected non-nil err; got %v", err) } @@ -227,7 +227,7 @@ func query(endpoint string) (*profile.Profile, error) { return nil, fmt.Errorf("failed to fetch %q: %v", url, r.Status) } - b, err := ioutil.ReadAll(r.Body) + b, err := io.ReadAll(r.Body) r.Body.Close() if err != nil { return nil, fmt.Errorf("failed to read and parse the result from %q: %v", url, err) diff --git a/src/net/http/readrequest_test.go b/src/net/http/readrequest_test.go index b227bb6d38..1950f4907a 100644 --- a/src/net/http/readrequest_test.go +++ b/src/net/http/readrequest_test.go @@ -9,7 +9,6 @@ import ( "bytes" "fmt" "io" - "io/ioutil" "net/url" "reflect" "strings" @@ -468,7 +467,7 @@ func TestReadRequest_Bad(t *testing.T) { for _, tt := range badRequestTests { got, err := ReadRequest(bufio.NewReader(bytes.NewReader(tt.req))) if err == nil { - all, err := ioutil.ReadAll(got.Body) + all, err := io.ReadAll(got.Body) t.Errorf("%s: got unexpected request = %#v\n Body = %q, %v", tt.name, got, all, err) } } diff --git a/src/net/http/request.go b/src/net/http/request.go index fe6b60982c..adba5406e9 100644 --- a/src/net/http/request.go +++ b/src/net/http/request.go @@ -15,7 +15,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "mime" "mime/multipart" "net" @@ -175,6 +174,10 @@ type Request struct { // but will return EOF immediately when no body is present. // The Server will close the request body. The ServeHTTP // Handler does not need to. + // + // Body must allow Read to be called concurrently with Close. + // In particular, calling Close should unblock a Read waiting + // for input. Body io.ReadCloser // GetBody defines an optional func to return a new copy of @@ -382,7 +385,7 @@ func (r *Request) Clone(ctx context.Context) *Request { if s := r.TransferEncoding; s != nil { s2 := make([]string, len(s)) copy(s2, s) - r2.TransferEncoding = s + r2.TransferEncoding = s2 } r2.Form = cloneURLValues(r.Form) r2.PostForm = cloneURLValues(r.PostForm) @@ -540,6 +543,7 @@ var errMissingHost = errors.New("http: Request.Write on Request with no Host or // extraHeaders may be nil // waitForContinue may be nil +// always closes body func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitForContinue func() bool) (err error) { trace := httptrace.ContextClientTrace(r.Context()) if trace != nil && trace.WroteRequest != nil { @@ -549,6 +553,15 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF }) }() } + closed := false + defer func() { + if closed { + return + } + if closeErr := r.closeBody(); closeErr != nil && err == nil { + err = closeErr + } + }() // Find the target host. Prefer the Host: header, but if that // is not given, use the host from the request URL. @@ -667,6 +680,7 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF trace.Wait100Continue() } if !waitForContinue() { + closed = true r.closeBody() return nil } @@ -679,6 +693,7 @@ func (r *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitF } // Write body and trailer + closed = true err = tw.writeBody(w) if err != nil { if tw.bodyReadError == err { @@ -854,7 +869,7 @@ func NewRequestWithContext(ctx context.Context, method, url string, body io.Read } rc, ok := body.(io.ReadCloser) if !ok && body != nil { - rc = ioutil.NopCloser(body) + rc = io.NopCloser(body) } // The host's colon:port should be normalized. See Issue 14836. u.Host = removeEmptyPort(u.Host) @@ -876,21 +891,21 @@ func NewRequestWithContext(ctx context.Context, method, url string, body io.Read buf := v.Bytes() req.GetBody = func() (io.ReadCloser, error) { r := bytes.NewReader(buf) - return ioutil.NopCloser(r), nil + return io.NopCloser(r), nil } case *bytes.Reader: req.ContentLength = int64(v.Len()) snapshot := *v req.GetBody = func() (io.ReadCloser, error) { r := snapshot - return ioutil.NopCloser(&r), nil + return io.NopCloser(&r), nil } case *strings.Reader: req.ContentLength = int64(v.Len()) snapshot := *v req.GetBody = func() (io.ReadCloser, error) { r := snapshot - return ioutil.NopCloser(&r), nil + return io.NopCloser(&r), nil } default: // This is where we'd set it to -1 (at least @@ -1189,7 +1204,7 @@ func parsePostForm(r *Request) (vs url.Values, err error) { maxFormSize = int64(10 << 20) // 10 MB is a lot of text. reader = io.LimitReader(r.Body, maxFormSize+1) } - b, e := ioutil.ReadAll(reader) + b, e := io.ReadAll(reader) if e != nil { if err == nil { err = e @@ -1383,10 +1398,11 @@ func (r *Request) wantsClose() bool { return hasToken(r.Header.get("Connection"), "close") } -func (r *Request) closeBody() { - if r.Body != nil { - r.Body.Close() +func (r *Request) closeBody() error { + if r.Body == nil { + return nil } + return r.Body.Close() } func (r *Request) isReplayable() bool { diff --git a/src/net/http/request_test.go b/src/net/http/request_test.go index 42c16d00ea..19526b9ad7 100644 --- a/src/net/http/request_test.go +++ b/src/net/http/request_test.go @@ -13,6 +13,7 @@ import ( "fmt" "io" "io/ioutil" + "math" "mime/multipart" . "net/http" "net/http/httptest" @@ -103,7 +104,7 @@ func TestParseFormUnknownContentType(t *testing.T) { req := &Request{ Method: "POST", Header: test.contentType, - Body: ioutil.NopCloser(strings.NewReader("body")), + Body: io.NopCloser(strings.NewReader("body")), } err := req.ParseForm() switch { @@ -150,7 +151,7 @@ func TestMultipartReader(t *testing.T) { req := &Request{ Method: "POST", Header: Header{"Content-Type": {test.contentType}}, - Body: ioutil.NopCloser(new(bytes.Buffer)), + Body: io.NopCloser(new(bytes.Buffer)), } multipart, err := req.MultipartReader() if test.shouldError { @@ -187,7 +188,7 @@ binary data req := &Request{ Method: "POST", Header: Header{"Content-Type": {`multipart/form-data; boundary=xxx`}}, - Body: ioutil.NopCloser(strings.NewReader(postData)), + Body: io.NopCloser(strings.NewReader(postData)), } initialFormItems := map[string]string{ @@ -231,7 +232,7 @@ func TestParseMultipartForm(t *testing.T) { req := &Request{ Method: "POST", Header: Header{"Content-Type": {`multipart/form-data; boundary="foo123"`}}, - Body: ioutil.NopCloser(new(bytes.Buffer)), + Body: io.NopCloser(new(bytes.Buffer)), } err := req.ParseMultipartForm(25) if err == nil { @@ -245,6 +246,50 @@ func TestParseMultipartForm(t *testing.T) { } } +// Issue #40430: Test that if maxMemory for ParseMultipartForm when combined with +// the payload size and the internal leeway buffer size of 10MiB overflows, that we +// correctly return an error. +func TestMaxInt64ForMultipartFormMaxMemoryOverflow(t *testing.T) { + defer afterTest(t) + + payloadSize := 1 << 10 + cst := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + // The combination of: + // MaxInt64 + payloadSize + (internal spare of 10MiB) + // triggers the overflow. See issue https://golang.org/issue/40430/ + if err := req.ParseMultipartForm(math.MaxInt64); err != nil { + Error(rw, err.Error(), StatusBadRequest) + return + } + })) + defer cst.Close() + fBuf := new(bytes.Buffer) + mw := multipart.NewWriter(fBuf) + mf, err := mw.CreateFormFile("file", "myfile.txt") + if err != nil { + t.Fatal(err) + } + if _, err := mf.Write(bytes.Repeat([]byte("abc"), payloadSize)); err != nil { + t.Fatal(err) + } + if err := mw.Close(); err != nil { + t.Fatal(err) + } + req, err := NewRequest("POST", cst.URL, fBuf) + if err != nil { + t.Fatal(err) + } + req.Header.Set("Content-Type", mw.FormDataContentType()) + res, err := cst.Client().Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if g, w := res.StatusCode, StatusBadRequest; g != w { + t.Fatalf("Status code mismatch: got %d, want %d", g, w) + } +} + func TestRedirect_h1(t *testing.T) { testRedirect(t, h1Mode) } func TestRedirect_h2(t *testing.T) { testRedirect(t, h2Mode) } func testRedirect(t *testing.T, h2 bool) { @@ -756,10 +801,10 @@ func (dr delayedEOFReader) Read(p []byte) (n int, err error) { } func TestIssue10884_MaxBytesEOF(t *testing.T) { - dst := ioutil.Discard + dst := io.Discard _, err := io.Copy(dst, MaxBytesReader( responseWriterJustWriter{dst}, - ioutil.NopCloser(delayedEOFReader{strings.NewReader("12345")}), + io.NopCloser(delayedEOFReader{strings.NewReader("12345")}), 5)) if err != nil { t.Fatal(err) @@ -799,7 +844,7 @@ func TestMaxBytesReaderStickyError(t *testing.T) { 2: {101, 100}, } for i, tt := range tests { - rc := MaxBytesReader(nil, ioutil.NopCloser(bytes.NewReader(make([]byte, tt.readable))), tt.limit) + rc := MaxBytesReader(nil, io.NopCloser(bytes.NewReader(make([]byte, tt.readable))), tt.limit) if err := isSticky(rc); err != nil { t.Errorf("%d. error: %v", i, err) } @@ -828,6 +873,27 @@ func TestWithContextDeepCopiesURL(t *testing.T) { } } +// Ensure that Request.Clone creates a deep copy of TransferEncoding. +// See issue 41907. +func TestRequestCloneTransferEncoding(t *testing.T) { + body := strings.NewReader("body") + req, _ := NewRequest("POST", "https://example.org/", body) + req.TransferEncoding = []string{ + "encoding1", + } + + clonedReq := req.Clone(context.Background()) + // modify original after deep copy + req.TransferEncoding[0] = "encoding2" + + if req.TransferEncoding[0] != "encoding2" { + t.Error("expected req.TransferEncoding to be changed") + } + if clonedReq.TransferEncoding[0] != "encoding1" { + t.Error("expected clonedReq.TransferEncoding to be unchanged") + } +} + func TestNoPanicOnRoundTripWithBasicAuth_h1(t *testing.T) { testNoPanicWithBasicAuth(t, h1Mode) } @@ -879,7 +945,7 @@ func TestNewRequestGetBody(t *testing.T) { t.Errorf("test[%d]: GetBody = nil", i) continue } - slurp1, err := ioutil.ReadAll(req.Body) + slurp1, err := io.ReadAll(req.Body) if err != nil { t.Errorf("test[%d]: ReadAll(Body) = %v", i, err) } @@ -887,7 +953,7 @@ func TestNewRequestGetBody(t *testing.T) { if err != nil { t.Errorf("test[%d]: GetBody = %v", i, err) } - slurp2, err := ioutil.ReadAll(newBody) + slurp2, err := io.ReadAll(newBody) if err != nil { t.Errorf("test[%d]: ReadAll(GetBody()) = %v", i, err) } @@ -1124,7 +1190,7 @@ func benchmarkFileAndServer(b *testing.B, n int64) { func runFileAndServerBenchmarks(b *testing.B, tlsOption bool, f *os.File, n int64) { handler := HandlerFunc(func(rw ResponseWriter, req *Request) { defer req.Body.Close() - nc, err := io.Copy(ioutil.Discard, req.Body) + nc, err := io.Copy(io.Discard, req.Body) if err != nil { panic(err) } @@ -1151,7 +1217,7 @@ func runFileAndServerBenchmarks(b *testing.B, tlsOption bool, f *os.File, n int6 } b.StartTimer() - req, err := NewRequest("PUT", cst.URL, ioutil.NopCloser(f)) + req, err := NewRequest("PUT", cst.URL, io.NopCloser(f)) if err != nil { b.Fatal(err) } diff --git a/src/net/http/requestwrite_test.go b/src/net/http/requestwrite_test.go index 9ac6701cfd..1157bdfff9 100644 --- a/src/net/http/requestwrite_test.go +++ b/src/net/http/requestwrite_test.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net" "net/url" "strings" @@ -229,7 +228,7 @@ var reqWriteTests = []reqWriteTest{ ContentLength: 0, // as if unset by user }, - Body: func() io.ReadCloser { return ioutil.NopCloser(io.LimitReader(strings.NewReader("xx"), 0)) }, + Body: func() io.ReadCloser { return io.NopCloser(io.LimitReader(strings.NewReader("xx"), 0)) }, WantWrite: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + @@ -281,7 +280,7 @@ var reqWriteTests = []reqWriteTest{ ContentLength: 0, // as if unset by user }, - Body: func() io.ReadCloser { return ioutil.NopCloser(io.LimitReader(strings.NewReader("xx"), 1)) }, + Body: func() io.ReadCloser { return io.NopCloser(io.LimitReader(strings.NewReader("xx"), 1)) }, WantWrite: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + @@ -351,7 +350,7 @@ var reqWriteTests = []reqWriteTest{ Body: func() io.ReadCloser { err := errors.New("Custom reader error") errReader := iotest.ErrReader(err) - return ioutil.NopCloser(io.MultiReader(strings.NewReader("x"), errReader)) + return io.NopCloser(io.MultiReader(strings.NewReader("x"), errReader)) }, WantError: errors.New("Custom reader error"), @@ -371,7 +370,7 @@ var reqWriteTests = []reqWriteTest{ Body: func() io.ReadCloser { err := errors.New("Custom reader error") errReader := iotest.ErrReader(err) - return ioutil.NopCloser(errReader) + return io.NopCloser(errReader) }, WantError: errors.New("Custom reader error"), @@ -620,7 +619,7 @@ func TestRequestWrite(t *testing.T) { } switch b := tt.Body.(type) { case []byte: - tt.Req.Body = ioutil.NopCloser(bytes.NewReader(b)) + tt.Req.Body = io.NopCloser(bytes.NewReader(b)) case func() io.ReadCloser: tt.Req.Body = b() } @@ -716,20 +715,20 @@ func TestRequestWriteTransport(t *testing.T) { }, { method: "GET", - body: ioutil.NopCloser(strings.NewReader("")), + body: io.NopCloser(strings.NewReader("")), want: noContentLengthOrTransferEncoding, }, { method: "GET", clen: -1, - body: ioutil.NopCloser(strings.NewReader("")), + body: io.NopCloser(strings.NewReader("")), want: noContentLengthOrTransferEncoding, }, // A GET with a body, with explicit content length: { method: "GET", clen: 7, - body: ioutil.NopCloser(strings.NewReader("foobody")), + body: io.NopCloser(strings.NewReader("foobody")), want: all(matchSubstr("Content-Length: 7"), matchSubstr("foobody")), }, @@ -737,7 +736,7 @@ func TestRequestWriteTransport(t *testing.T) { { method: "GET", clen: -1, - body: ioutil.NopCloser(strings.NewReader("foobody")), + body: io.NopCloser(strings.NewReader("foobody")), want: all(matchSubstr("Transfer-Encoding: chunked"), matchSubstr("\r\n1\r\nf\r\n"), matchSubstr("oobody")), @@ -747,14 +746,14 @@ func TestRequestWriteTransport(t *testing.T) { { method: "POST", clen: -1, - body: ioutil.NopCloser(strings.NewReader("foobody")), + body: io.NopCloser(strings.NewReader("foobody")), want: all(matchSubstr("Transfer-Encoding: chunked"), matchSubstr("foobody")), }, { method: "POST", clen: -1, - body: ioutil.NopCloser(strings.NewReader("")), + body: io.NopCloser(strings.NewReader("")), want: all(matchSubstr("Transfer-Encoding: chunked")), }, // Verify that a blocking Request.Body doesn't block forever. @@ -766,7 +765,7 @@ func TestRequestWriteTransport(t *testing.T) { tt.afterReqRead = func() { pw.Close() } - tt.body = ioutil.NopCloser(pr) + tt.body = io.NopCloser(pr) }, want: matchSubstr("Transfer-Encoding: chunked"), }, @@ -937,7 +936,7 @@ func dumpRequestOut(req *Request, onReadHeaders func()) ([]byte, error) { } // Ensure all the body is read; otherwise // we'll get a partial dump. - io.Copy(ioutil.Discard, req.Body) + io.Copy(io.Discard, req.Body) req.Body.Close() } dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n") diff --git a/src/net/http/response_test.go b/src/net/http/response_test.go index ce872606b1..8eef65474e 100644 --- a/src/net/http/response_test.go +++ b/src/net/http/response_test.go @@ -12,7 +12,6 @@ import ( "fmt" "go/token" "io" - "io/ioutil" "net/http/internal" "net/url" "reflect" @@ -620,7 +619,7 @@ func TestWriteResponse(t *testing.T) { t.Errorf("#%d: %v", i, err) continue } - err = resp.Write(ioutil.Discard) + err = resp.Write(io.Discard) if err != nil { t.Errorf("#%d: %v", i, err) continue @@ -722,7 +721,7 @@ func TestReadResponseCloseInMiddle(t *testing.T) { } resp.Body.Close() - rest, err := ioutil.ReadAll(bufr) + rest, err := io.ReadAll(bufr) checkErr(err, "ReadAll on remainder") if e, g := "Next Request Here", string(rest); e != g { g = regexp.MustCompile(`(xx+)`).ReplaceAllStringFunc(g, func(match string) string { diff --git a/src/net/http/responsewrite_test.go b/src/net/http/responsewrite_test.go index d41d89896e..1cc87b942e 100644 --- a/src/net/http/responsewrite_test.go +++ b/src/net/http/responsewrite_test.go @@ -6,7 +6,7 @@ package http import ( "bytes" - "io/ioutil" + "io" "strings" "testing" ) @@ -26,7 +26,7 @@ func TestResponseWrite(t *testing.T) { ProtoMinor: 0, Request: dummyReq("GET"), Header: Header{}, - Body: ioutil.NopCloser(strings.NewReader("abcdef")), + Body: io.NopCloser(strings.NewReader("abcdef")), ContentLength: 6, }, @@ -42,7 +42,7 @@ func TestResponseWrite(t *testing.T) { ProtoMinor: 0, Request: dummyReq("GET"), Header: Header{}, - Body: ioutil.NopCloser(strings.NewReader("abcdef")), + Body: io.NopCloser(strings.NewReader("abcdef")), ContentLength: -1, }, "HTTP/1.0 200 OK\r\n" + @@ -57,7 +57,7 @@ func TestResponseWrite(t *testing.T) { ProtoMinor: 1, Request: dummyReq("GET"), Header: Header{}, - Body: ioutil.NopCloser(strings.NewReader("abcdef")), + Body: io.NopCloser(strings.NewReader("abcdef")), ContentLength: -1, Close: true, }, @@ -74,7 +74,7 @@ func TestResponseWrite(t *testing.T) { ProtoMinor: 1, Request: dummyReq11("GET"), Header: Header{}, - Body: ioutil.NopCloser(strings.NewReader("abcdef")), + Body: io.NopCloser(strings.NewReader("abcdef")), ContentLength: -1, Close: false, }, @@ -92,7 +92,7 @@ func TestResponseWrite(t *testing.T) { ProtoMinor: 1, Request: dummyReq11("GET"), Header: Header{}, - Body: ioutil.NopCloser(strings.NewReader("abcdef")), + Body: io.NopCloser(strings.NewReader("abcdef")), ContentLength: -1, TransferEncoding: []string{"chunked"}, Close: false, @@ -125,7 +125,7 @@ func TestResponseWrite(t *testing.T) { ProtoMinor: 1, Request: dummyReq11("GET"), Header: Header{}, - Body: ioutil.NopCloser(strings.NewReader("")), + Body: io.NopCloser(strings.NewReader("")), ContentLength: 0, Close: false, }, @@ -141,7 +141,7 @@ func TestResponseWrite(t *testing.T) { ProtoMinor: 1, Request: dummyReq11("GET"), Header: Header{}, - Body: ioutil.NopCloser(strings.NewReader("foo")), + Body: io.NopCloser(strings.NewReader("foo")), ContentLength: 0, Close: false, }, @@ -157,7 +157,7 @@ func TestResponseWrite(t *testing.T) { ProtoMinor: 1, Request: dummyReq("GET"), Header: Header{}, - Body: ioutil.NopCloser(strings.NewReader("abcdef")), + Body: io.NopCloser(strings.NewReader("abcdef")), ContentLength: 6, TransferEncoding: []string{"chunked"}, Close: true, @@ -218,7 +218,7 @@ func TestResponseWrite(t *testing.T) { Request: &Request{Method: "POST"}, Header: Header{}, ContentLength: -1, - Body: ioutil.NopCloser(strings.NewReader("abcdef")), + Body: io.NopCloser(strings.NewReader("abcdef")), }, "HTTP/1.1 200 OK\r\nConnection: close\r\n\r\nabcdef", }, diff --git a/src/net/http/roundtrip_js.go b/src/net/http/roundtrip_js.go index b09923c386..c6a221ac62 100644 --- a/src/net/http/roundtrip_js.go +++ b/src/net/http/roundtrip_js.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "strconv" "syscall/js" ) @@ -92,7 +91,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { // See https://github.com/web-platform-tests/wpt/issues/7693 for WHATWG tests issue. // See https://developer.mozilla.org/en-US/docs/Web/API/Streams_API for more details on the Streams API // and browser support. - body, err := ioutil.ReadAll(req.Body) + body, err := io.ReadAll(req.Body) if err != nil { req.Body.Close() // RoundTrip must always close the body, including on errors. return nil, err diff --git a/src/net/http/serve_test.go b/src/net/http/serve_test.go index 6d3317fb0c..ba54b31a29 100644 --- a/src/net/http/serve_test.go +++ b/src/net/http/serve_test.go @@ -18,7 +18,6 @@ import ( "fmt" "internal/testenv" "io" - "io/ioutil" "log" "math/rand" "net" @@ -529,7 +528,7 @@ func TestServeWithSlashRedirectKeepsQueryString(t *testing.T) { if err != nil { continue } - slurp, _ := ioutil.ReadAll(res.Body) + slurp, _ := io.ReadAll(res.Body) res.Body.Close() if !tt.statusOk { if got, want := res.StatusCode, 404; got != want { @@ -689,7 +688,7 @@ func testServerTimeouts(timeout time.Duration) error { if err != nil { return fmt.Errorf("http Get #1: %v", err) } - got, err := ioutil.ReadAll(r.Body) + got, err := io.ReadAll(r.Body) expected := "req=1" if string(got) != expected || err != nil { return fmt.Errorf("Unexpected response for request #1; got %q ,%v; expected %q, nil", @@ -721,7 +720,7 @@ func testServerTimeouts(timeout time.Duration) error { if err != nil { return fmt.Errorf("http Get #2: %v", err) } - got, err = ioutil.ReadAll(r.Body) + got, err = io.ReadAll(r.Body) r.Body.Close() expected = "req=2" if string(got) != expected || err != nil { @@ -734,7 +733,7 @@ func testServerTimeouts(timeout time.Duration) error { return fmt.Errorf("long Dial: %v", err) } defer conn.Close() - go io.Copy(ioutil.Discard, conn) + go io.Copy(io.Discard, conn) for i := 0; i < 5; i++ { _, err := conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo\r\n\r\n")) if err != nil { @@ -954,7 +953,7 @@ func TestOnlyWriteTimeout(t *testing.T) { errc <- err return } - _, err = io.Copy(ioutil.Discard, res.Body) + _, err = io.Copy(io.Discard, res.Body) res.Body.Close() errc <- err }() @@ -1058,7 +1057,7 @@ func TestIdentityResponse(t *testing.T) { } // The ReadAll will hang for a failing test. - got, _ := ioutil.ReadAll(conn) + got, _ := io.ReadAll(conn) expectedSuffix := "\r\n\r\ntoo short" if !strings.HasSuffix(string(got), expectedSuffix) { t.Errorf("Expected output to end with %q; got response body %q", @@ -1099,7 +1098,7 @@ func testTCPConnectionCloses(t *testing.T, req string, h Handler) { } }() - _, err = ioutil.ReadAll(r) + _, err = io.ReadAll(r) if err != nil { t.Fatal("read error:", err) } @@ -1129,7 +1128,7 @@ func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) { if err != nil { t.Fatalf("res %d: %v", i+1, err) } - if _, err := io.Copy(ioutil.Discard, res.Body); err != nil { + if _, err := io.Copy(io.Discard, res.Body); err != nil { t.Fatalf("res %d body copy: %v", i+1, err) } res.Body.Close() @@ -1235,7 +1234,7 @@ func testSetsRemoteAddr(t *testing.T, h2 bool) { if err != nil { t.Fatalf("Get error: %v", err) } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("ReadAll error: %v", err) } @@ -1299,7 +1298,7 @@ func TestServerAllowsBlockingRemoteAddr(t *testing.T) { return } defer resp.Body.Close() - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { t.Errorf("Request %d: %v", num, err) response <- "" @@ -1381,7 +1380,7 @@ func testHeadResponses(t *testing.T, h2 bool) { if v := res.ContentLength; v != 10 { t.Errorf("Content-Length: %d; want 10", v) } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Error(err) } @@ -1432,7 +1431,7 @@ func TestTLSServer(t *testing.T) { } } })) - ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0) + ts.Config.ErrorLog = log.New(io.Discard, "", 0) defer ts.Close() // Connect an idle TCP connection to this server before we run @@ -1540,7 +1539,7 @@ func TestTLSServerRejectHTTPRequests(t *testing.T) { } defer conn.Close() io.WriteString(conn, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n") - slurp, err := ioutil.ReadAll(conn) + slurp, err := io.ReadAll(conn) if err != nil { t.Fatal(err) } @@ -1734,7 +1733,7 @@ func TestServerExpect(t *testing.T) { // requests that would read from r.Body, which we only // conditionally want to do. if strings.Contains(r.URL.RawQuery, "readbody=true") { - ioutil.ReadAll(r.Body) + io.ReadAll(r.Body) w.Write([]byte("Hi")) } else { w.WriteHeader(StatusUnauthorized) @@ -1773,7 +1772,7 @@ func TestServerExpect(t *testing.T) { io.Closer }{ conn, - ioutil.NopCloser(nil), + io.NopCloser(nil), } if test.chunked { targ = httputil.NewChunkedWriter(conn) @@ -2072,7 +2071,7 @@ type testHandlerBodyConsumer struct { var testHandlerBodyConsumers = []testHandlerBodyConsumer{ {"nil", func(io.ReadCloser) {}}, {"close", func(r io.ReadCloser) { r.Close() }}, - {"discard", func(r io.ReadCloser) { io.Copy(ioutil.Discard, r) }}, + {"discard", func(r io.ReadCloser) { io.Copy(io.Discard, r) }}, } func TestRequestBodyReadErrorClosesConnection(t *testing.T) { @@ -2298,7 +2297,7 @@ func testTimeoutHandler(t *testing.T, h2 bool) { if g, e := res.StatusCode, StatusOK; g != e { t.Errorf("got res.StatusCode %d; expected %d", g, e) } - body, _ := ioutil.ReadAll(res.Body) + body, _ := io.ReadAll(res.Body) if g, e := string(body), "hi"; g != e { t.Errorf("got body %q; expected %q", g, e) } @@ -2315,7 +2314,7 @@ func testTimeoutHandler(t *testing.T, h2 bool) { if g, e := res.StatusCode, StatusServiceUnavailable; g != e { t.Errorf("got res.StatusCode %d; expected %d", g, e) } - body, _ = ioutil.ReadAll(res.Body) + body, _ = io.ReadAll(res.Body) if !strings.Contains(string(body), "<title>Timeout</title>") { t.Errorf("expected timeout body; got %q", string(body)) } @@ -2367,7 +2366,7 @@ func TestTimeoutHandlerRace(t *testing.T) { defer func() { <-gate }() res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50))) if err == nil { - io.Copy(ioutil.Discard, res.Body) + io.Copy(io.Discard, res.Body) res.Body.Close() } }() @@ -2410,7 +2409,7 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) { return } defer res.Body.Close() - io.Copy(ioutil.Discard, res.Body) + io.Copy(io.Discard, res.Body) }() } wg.Wait() @@ -2441,7 +2440,7 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { if g, e := res.StatusCode, StatusOK; g != e { t.Errorf("got res.StatusCode %d; expected %d", g, e) } - body, _ := ioutil.ReadAll(res.Body) + body, _ := io.ReadAll(res.Body) if g, e := string(body), "hi"; g != e { t.Errorf("got body %q; expected %q", g, e) } @@ -2458,7 +2457,7 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { if g, e := res.StatusCode, StatusServiceUnavailable; g != e { t.Errorf("got res.StatusCode %d; expected %d", g, e) } - body, _ = ioutil.ReadAll(res.Body) + body, _ = io.ReadAll(res.Body) if !strings.Contains(string(body), "<title>Timeout</title>") { t.Errorf("expected timeout body; got %q", string(body)) } @@ -2630,7 +2629,7 @@ func TestRedirectContentTypeAndBody(t *testing.T) { t.Errorf("Redirect(%q, %#v) generated Content-Type header %q; want %q", tt.method, tt.ct, got, want) } resp := rec.Result() - body, err := ioutil.ReadAll(resp.Body) + body, err := io.ReadAll(resp.Body) if err != nil { t.Fatal(err) } @@ -2657,7 +2656,7 @@ func testZeroLengthPostAndResponse(t *testing.T, h2 bool) { setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, r *Request) { - all, err := ioutil.ReadAll(r.Body) + all, err := io.ReadAll(r.Body) if err != nil { t.Fatalf("handler ReadAll: %v", err) } @@ -2683,7 +2682,7 @@ func testZeroLengthPostAndResponse(t *testing.T, h2 bool) { } for i := range resp { - all, err := ioutil.ReadAll(resp[i].Body) + all, err := io.ReadAll(resp[i].Body) if err != nil { t.Fatalf("req #%d: client ReadAll: %v", i, err) } @@ -2710,7 +2709,7 @@ func TestHandlerPanicWithHijack(t *testing.T) { func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) Handler, panicValue interface{}) { defer afterTest(t) - // Unlike the other tests that set the log output to ioutil.Discard + // Unlike the other tests that set the log output to io.Discard // to quiet the output, this test uses a pipe. The pipe serves three // purposes: // @@ -2970,7 +2969,7 @@ func testRequestBodyLimit(t *testing.T, h2 bool) { const limit = 1 << 20 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { r.Body = MaxBytesReader(w, r.Body, limit) - n, err := io.Copy(ioutil.Discard, r.Body) + n, err := io.Copy(io.Discard, r.Body) if err == nil { t.Errorf("expected error from io.Copy") } @@ -3020,7 +3019,7 @@ func TestClientWriteShutdown(t *testing.T) { donec := make(chan bool) go func() { defer close(donec) - bs, err := ioutil.ReadAll(conn) + bs, err := io.ReadAll(conn) if err != nil { t.Errorf("ReadAll: %v", err) } @@ -3341,7 +3340,7 @@ func TestHijackBeforeRequestBodyRead(t *testing.T) { r.Body = nil // to test that server.go doesn't use this value. gone := w.(CloseNotifier).CloseNotify() - slurp, err := ioutil.ReadAll(reqBody) + slurp, err := io.ReadAll(reqBody) if err != nil { t.Errorf("Body read: %v", err) return @@ -3643,7 +3642,7 @@ func TestAcceptMaxFds(t *testing.T) { }}} server := &Server{ Handler: HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})), - ErrorLog: log.New(ioutil.Discard, "", 0), // noisy otherwise + ErrorLog: log.New(io.Discard, "", 0), // noisy otherwise } err := server.Serve(ln) if err != io.EOF { @@ -3782,7 +3781,7 @@ func testServerReaderFromOrder(t *testing.T, h2 bool) { close(done) }() time.Sleep(25 * time.Millisecond) // give Copy a chance to break things - n, err := io.Copy(ioutil.Discard, req.Body) + n, err := io.Copy(io.Discard, req.Body) if err != nil { t.Errorf("handler Copy: %v", err) return @@ -3804,7 +3803,7 @@ func testServerReaderFromOrder(t *testing.T, h2 bool) { if err != nil { t.Fatal(err) } - all, err := ioutil.ReadAll(res.Body) + all, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -3929,7 +3928,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { errorf("Proxy outbound request: %v", err) return } - _, err = io.CopyN(ioutil.Discard, bresp.Body, bodySize/2) + _, err = io.CopyN(io.Discard, bresp.Body, bodySize/2) if err != nil { errorf("Proxy copy error: %v", err) return @@ -4136,7 +4135,7 @@ func TestServerConnState(t *testing.T) { ts.Close() }() - ts.Config.ErrorLog = log.New(ioutil.Discard, "", 0) + ts.Config.ErrorLog = log.New(io.Discard, "", 0) ts.Config.ConnState = func(c net.Conn, state ConnState) { if c == nil { t.Errorf("nil conn seen in state %s", state) @@ -4176,7 +4175,7 @@ func TestServerConnState(t *testing.T) { t.Errorf("Error fetching %s: %v", url, err) return } - _, err = ioutil.ReadAll(res.Body) + _, err = io.ReadAll(res.Body) defer res.Body.Close() if err != nil { t.Errorf("Error reading %s: %v", url, err) @@ -4233,7 +4232,7 @@ func TestServerConnState(t *testing.T) { if err != nil { t.Fatal(err) } - if _, err := io.Copy(ioutil.Discard, res.Body); err != nil { + if _, err := io.Copy(io.Discard, res.Body); err != nil { t.Fatal(err) } c.Close() @@ -4275,11 +4274,17 @@ func testServerEmptyBodyRace(t *testing.T, h2 bool) { defer wg.Done() res, err := cst.c.Get(cst.ts.URL) if err != nil { - t.Error(err) - return + // Try to deflake spurious "connection reset by peer" under load. + // See golang.org/issue/22540. + time.Sleep(10 * time.Millisecond) + res, err = cst.c.Get(cst.ts.URL) + if err != nil { + t.Error(err) + return + } } defer res.Body.Close() - _, err = io.Copy(ioutil.Discard, res.Body) + _, err = io.Copy(io.Discard, res.Body) if err != nil { t.Error(err) return @@ -4305,7 +4310,7 @@ func TestServerConnStateNew(t *testing.T) { srv.Serve(&oneConnListener{ conn: &rwTestConn{ Reader: strings.NewReader("GET / HTTP/1.1\r\nHost: foo\r\n\r\n"), - Writer: ioutil.Discard, + Writer: io.Discard, }, }) if !sawNew { // testing that this read isn't racy @@ -4361,7 +4366,7 @@ func TestServerFlushAndHijack(t *testing.T) { t.Fatal(err) } defer res.Body.Close() - all, err := ioutil.ReadAll(res.Body) + all, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -4549,7 +4554,7 @@ Host: foo go Serve(ln, HandlerFunc(func(w ResponseWriter, r *Request) { numReq++ if r.URL.Path == "/readbody" { - ioutil.ReadAll(r.Body) + io.ReadAll(r.Body) } io.WriteString(w, "Hello world!") })) @@ -4602,7 +4607,7 @@ func testHandlerSetsBodyNil(t *testing.T, h2 bool) { t.Fatal(err) } defer res.Body.Close() - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -4622,7 +4627,7 @@ func TestServerValidatesHostHeader(t *testing.T) { host string want int }{ - {"HTTP/0.9", "", 400}, + {"HTTP/0.9", "", 505}, {"HTTP/1.1", "", 400}, {"HTTP/1.1", "Host: \r\n", 200}, @@ -4654,9 +4659,9 @@ func TestServerValidatesHostHeader(t *testing.T) { {"CONNECT golang.org:443 HTTP/1.1", "", 200}, // But not other HTTP/2 stuff: - {"PRI / HTTP/2.0", "", 400}, - {"GET / HTTP/2.0", "", 400}, - {"GET / HTTP/3.0", "", 400}, + {"PRI / HTTP/2.0", "", 505}, + {"GET / HTTP/2.0", "", 505}, + {"GET / HTTP/3.0", "", 505}, } for _, tt := range tests { conn := &testConn{closec: make(chan bool, 1)} @@ -4718,7 +4723,7 @@ func TestServerHandlersCanHandleH2PRI(t *testing.T) { } defer c.Close() io.WriteString(c, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n") - slurp, err := ioutil.ReadAll(c) + slurp, err := io.ReadAll(c) if err != nil { t.Fatal(err) } @@ -4952,7 +4957,7 @@ func BenchmarkClientServer(b *testing.B) { if err != nil { b.Fatal("Get:", err) } - all, err := ioutil.ReadAll(res.Body) + all, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { b.Fatal("ReadAll:", err) @@ -5003,7 +5008,7 @@ func benchmarkClientServerParallel(b *testing.B, parallelism int, useTLS bool) { b.Logf("Get: %v", err) continue } - all, err := ioutil.ReadAll(res.Body) + all, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { b.Logf("ReadAll: %v", err) @@ -5038,7 +5043,7 @@ func BenchmarkServer(b *testing.B) { if err != nil { log.Panicf("Get: %v", err) } - all, err := ioutil.ReadAll(res.Body) + all, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { log.Panicf("ReadAll: %v", err) @@ -5161,7 +5166,7 @@ func BenchmarkClient(b *testing.B) { if err != nil { b.Fatalf("Get: %v", err) } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { b.Fatalf("ReadAll: %v", err) @@ -5251,7 +5256,7 @@ Accept-Charset: ISO-8859-1,utf-8;q=0.7,*;q=0.3 conn := &rwTestConn{ Reader: &repeatReader{content: req, count: b.N}, - Writer: ioutil.Discard, + Writer: io.Discard, closec: make(chan bool, 1), } handled := 0 @@ -5280,7 +5285,7 @@ Host: golang.org conn := &rwTestConn{ Reader: &repeatReader{content: req, count: b.N}, - Writer: ioutil.Discard, + Writer: io.Discard, closec: make(chan bool, 1), } handled := 0 @@ -5340,7 +5345,7 @@ Host: golang.org `) conn := &rwTestConn{ Reader: &repeatReader{content: req, count: b.N}, - Writer: ioutil.Discard, + Writer: io.Discard, closec: make(chan bool, 1), } handled := 0 @@ -5369,7 +5374,7 @@ Host: golang.org conn.Close() }) conn := &rwTestConn{ - Writer: ioutil.Discard, + Writer: io.Discard, closec: make(chan bool, 1), } ln := &oneConnListener{conn: conn} @@ -5432,7 +5437,7 @@ func TestServerIdleTimeout(t *testing.T) { setParallel(t) defer afterTest(t) ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { - io.Copy(ioutil.Discard, r.Body) + io.Copy(io.Discard, r.Body) io.WriteString(w, r.RemoteAddr) })) ts.Config.ReadHeaderTimeout = 1 * time.Second @@ -5447,7 +5452,7 @@ func TestServerIdleTimeout(t *testing.T) { t.Fatal(err) } defer res.Body.Close() - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -5472,7 +5477,7 @@ func TestServerIdleTimeout(t *testing.T) { defer conn.Close() conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n")) time.Sleep(2 * time.Second) - if _, err := io.CopyN(ioutil.Discard, conn, 1); err == nil { + if _, err := io.CopyN(io.Discard, conn, 1); err == nil { t.Fatal("copy byte succeeded; want err") } } @@ -5483,7 +5488,7 @@ func get(t *testing.T, c *Client, url string) string { t.Fatal(err) } defer res.Body.Close() - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -5733,7 +5738,7 @@ func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { if err != nil { return fmt.Errorf("Get: %v", err) } - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { return fmt.Errorf("Body ReadAll: %v", err) @@ -5796,7 +5801,7 @@ func TestServerDuplicateBackgroundRead(t *testing.T) { wg.Add(1) go func() { defer wg.Done() - io.Copy(ioutil.Discard, cn) + io.Copy(io.Discard, cn) }() for j := 0; j < requests; j++ { @@ -5896,7 +5901,7 @@ func TestServerHijackGetsBackgroundByte_big(t *testing.T) { return } defer conn.Close() - slurp, err := ioutil.ReadAll(buf.Reader) + slurp, err := io.ReadAll(buf.Reader) if err != nil { t.Errorf("Copy: %v", err) } @@ -6430,13 +6435,13 @@ func fetchWireResponse(host string, http1ReqBody []byte) ([]byte, error) { if _, err := conn.Write(http1ReqBody); err != nil { return nil, err } - return ioutil.ReadAll(conn) + return io.ReadAll(conn) } func BenchmarkResponseStatusLine(b *testing.B) { b.ReportAllocs() b.RunParallel(func(pb *testing.PB) { - bw := bufio.NewWriter(ioutil.Discard) + bw := bufio.NewWriter(io.Discard) var buf3 [3]byte for pb.Next() { Export_writeStatusLine(bw, true, 200, buf3[:]) diff --git a/src/net/http/server.go b/src/net/http/server.go index 25fab288f2..4776d960e5 100644 --- a/src/net/http/server.go +++ b/src/net/http/server.go @@ -14,8 +14,8 @@ import ( "errors" "fmt" "io" - "io/ioutil" "log" + "math/rand" "net" "net/textproto" "net/url" @@ -890,12 +890,12 @@ func (srv *Server) initialReadLimitSize() int64 { type expectContinueReader struct { resp *response readCloser io.ReadCloser - closed bool + closed atomicBool sawEOF atomicBool } func (ecr *expectContinueReader) Read(p []byte) (n int, err error) { - if ecr.closed { + if ecr.closed.isSet() { return 0, ErrBodyReadAfterClose } w := ecr.resp @@ -917,7 +917,7 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err error) { } func (ecr *expectContinueReader) Close() error { - ecr.closed = true + ecr.closed.setTrue() return ecr.readCloser.Close() } @@ -992,7 +992,7 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) { } if !http1ServerSupportsRequest(req) { - return nil, badRequestError("unsupported protocol version") + return nil, statusError{StatusHTTPVersionNotSupported, "unsupported protocol version"} } c.lastMethod = req.Method @@ -1368,7 +1368,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { } if discard { - _, err := io.CopyN(ioutil.Discard, w.reqBody, maxPostHandlerReadBytes+1) + _, err := io.CopyN(io.Discard, w.reqBody, maxPostHandlerReadBytes+1) switch err { case nil: // There must be even more data left over. @@ -1773,9 +1773,16 @@ func (c *conn) getState() (state ConnState, unixSec int64) { // badRequestError is a literal string (used by in the server in HTML, // unescaped) to tell the user why their request was bad. It should // be plain text without user info or other embedded errors. -type badRequestError string +func badRequestError(e string) error { return statusError{StatusBadRequest, e} } -func (e badRequestError) Error() string { return "Bad Request: " + string(e) } +// statusError is an error used to respond to a request with an HTTP status. +// The text should be plain text without user info or other embedded errors. +type statusError struct { + code int + text string +} + +func (e statusError) Error() string { return StatusText(e.code) + ": " + e.text } // ErrAbortHandler is a sentinel panic value to abort a handler. // While any panic from ServeHTTP aborts the response to the client, @@ -1898,11 +1905,11 @@ func (c *conn) serve(ctx context.Context) { return // don't reply default: - publicErr := "400 Bad Request" - if v, ok := err.(badRequestError); ok { - publicErr = publicErr + ": " + string(v) + if v, ok := err.(statusError); ok { + fmt.Fprintf(c.rwc, "HTTP/1.1 %d %s: %s%s%d %s: %s", v.code, StatusText(v.code), v.text, errorHeaders, v.code, StatusText(v.code), v.text) + return } - + publicErr := "400 Bad Request" fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr) return } @@ -2685,14 +2692,14 @@ func (srv *Server) Close() error { return err } -// shutdownPollInterval is how often we poll for quiescence -// during Server.Shutdown. This is lower during tests, to -// speed up tests. +// shutdownPollIntervalMax is the max polling interval when checking +// quiescence during Server.Shutdown. Polling starts with a small +// interval and backs off to the max. // Ideally we could find a solution that doesn't involve polling, // but which also doesn't have a high runtime cost (and doesn't // involve any contentious mutexes), but that is left as an // exercise for the reader. -var shutdownPollInterval = 500 * time.Millisecond +const shutdownPollIntervalMax = 500 * time.Millisecond // Shutdown gracefully shuts down the server without interrupting any // active connections. Shutdown works by first closing all open @@ -2725,8 +2732,20 @@ func (srv *Server) Shutdown(ctx context.Context) error { } srv.mu.Unlock() - ticker := time.NewTicker(shutdownPollInterval) - defer ticker.Stop() + pollIntervalBase := time.Millisecond + nextPollInterval := func() time.Duration { + // Add 10% jitter. + interval := pollIntervalBase + time.Duration(rand.Intn(int(pollIntervalBase/10))) + // Double and clamp for next time. + pollIntervalBase *= 2 + if pollIntervalBase > shutdownPollIntervalMax { + pollIntervalBase = shutdownPollIntervalMax + } + return interval + } + + timer := time.NewTimer(nextPollInterval()) + defer timer.Stop() for { if srv.closeIdleConns() && srv.numListeners() == 0 { return lnerr @@ -2734,7 +2753,8 @@ func (srv *Server) Shutdown(ctx context.Context) error { select { case <-ctx.Done(): return ctx.Err() - case <-ticker.C: + case <-timer.C: + timer.Reset(nextPollInterval()) } } } @@ -3400,7 +3420,7 @@ func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) { // (or an attack) and we abort and close the connection, // courtesy of MaxBytesReader's EOF behavior. mb := MaxBytesReader(w, r.Body, 4<<10) - io.Copy(ioutil.Discard, mb) + io.Copy(io.Discard, mb) } } diff --git a/src/net/http/sniff_test.go b/src/net/http/sniff_test.go index a1157a0823..8d5350374d 100644 --- a/src/net/http/sniff_test.go +++ b/src/net/http/sniff_test.go @@ -8,7 +8,6 @@ import ( "bytes" "fmt" "io" - "io/ioutil" "log" . "net/http" "reflect" @@ -123,7 +122,7 @@ func testServerContentType(t *testing.T, h2 bool) { if ct := resp.Header.Get("Content-Type"); ct != wantContentType { t.Errorf("%v: Content-Type = %q, want %q", tt.desc, ct, wantContentType) } - data, err := ioutil.ReadAll(resp.Body) + data, err := io.ReadAll(resp.Body) if err != nil { t.Errorf("%v: reading body: %v", tt.desc, err) } else if !bytes.Equal(data, tt.data) { @@ -185,7 +184,7 @@ func testContentTypeWithCopy(t *testing.T, h2 bool) { if ct := resp.Header.Get("Content-Type"); ct != expected { t.Errorf("Content-Type = %q, want %q", ct, expected) } - data, err := ioutil.ReadAll(resp.Body) + data, err := io.ReadAll(resp.Body) if err != nil { t.Errorf("reading body: %v", err) } else if !bytes.Equal(data, []byte(input)) { @@ -216,7 +215,7 @@ func testSniffWriteSize(t *testing.T, h2 bool) { if err != nil { t.Fatalf("size %d: %v", size, err) } - if _, err := io.Copy(ioutil.Discard, res.Body); err != nil { + if _, err := io.Copy(io.Discard, res.Body); err != nil { t.Fatalf("size %d: io.Copy of body = %v", size, err) } if err := res.Body.Close(); err != nil { diff --git a/src/net/http/transfer.go b/src/net/http/transfer.go index ab009177bc..fbb0c39829 100644 --- a/src/net/http/transfer.go +++ b/src/net/http/transfer.go @@ -10,7 +10,6 @@ import ( "errors" "fmt" "io" - "io/ioutil" "net/http/httptrace" "net/http/internal" "net/textproto" @@ -156,7 +155,7 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) { // servers. See Issue 18257, as one example. // // The only reason we'd send such a request is if the user set the Body to a -// non-nil value (say, ioutil.NopCloser(bytes.NewReader(nil))) and didn't +// non-nil value (say, io.NopCloser(bytes.NewReader(nil))) and didn't // set ContentLength, or NewRequest set it to -1 (unknown), so then we assume // there's bytes to send. // @@ -330,9 +329,18 @@ func (t *transferWriter) writeHeader(w io.Writer, trace *httptrace.ClientTrace) return nil } -func (t *transferWriter) writeBody(w io.Writer) error { - var err error +// always closes t.BodyCloser +func (t *transferWriter) writeBody(w io.Writer) (err error) { var ncopy int64 + closed := false + defer func() { + if closed || t.BodyCloser == nil { + return + } + if closeErr := t.BodyCloser.Close(); closeErr != nil && err == nil { + err = closeErr + } + }() // Write body. We "unwrap" the body first if it was wrapped in a // nopCloser or readTrackingBody. This is to ensure that we can take advantage of @@ -361,7 +369,7 @@ func (t *transferWriter) writeBody(w io.Writer) error { return err } var nextra int64 - nextra, err = t.doBodyCopy(ioutil.Discard, body) + nextra, err = t.doBodyCopy(io.Discard, body) ncopy += nextra } if err != nil { @@ -369,6 +377,7 @@ func (t *transferWriter) writeBody(w io.Writer) error { } } if t.BodyCloser != nil { + closed = true if err := t.BodyCloser.Close(); err != nil { return err } @@ -982,7 +991,7 @@ func (b *body) Close() error { var n int64 // Consume the body, or, which will also lead to us reading // the trailer headers after the body, if present. - n, err = io.CopyN(ioutil.Discard, bodyLocked{b}, maxPostHandlerReadBytes) + n, err = io.CopyN(io.Discard, bodyLocked{b}, maxPostHandlerReadBytes) if err == io.EOF { err = nil } @@ -993,7 +1002,7 @@ func (b *body) Close() error { default: // Fully consume the body, which will also lead to us reading // the trailer headers after the body, if present. - _, err = io.Copy(ioutil.Discard, bodyLocked{b}) + _, err = io.Copy(io.Discard, bodyLocked{b}) } b.closed = true return err @@ -1065,7 +1074,7 @@ func (fr finishAsyncByteRead) Read(p []byte) (n int, err error) { return } -var nopCloserType = reflect.TypeOf(ioutil.NopCloser(nil)) +var nopCloserType = reflect.TypeOf(io.NopCloser(nil)) // isKnownInMemoryReader reports whether r is a type known to not // block on Read. Its caller uses this as an optional optimization to diff --git a/src/net/http/transfer_test.go b/src/net/http/transfer_test.go index 185225fa93..1f3d32526d 100644 --- a/src/net/http/transfer_test.go +++ b/src/net/http/transfer_test.go @@ -81,11 +81,11 @@ func TestDetectInMemoryReaders(t *testing.T) { {bytes.NewBuffer(nil), true}, {strings.NewReader(""), true}, - {ioutil.NopCloser(pr), false}, + {io.NopCloser(pr), false}, - {ioutil.NopCloser(bytes.NewReader(nil)), true}, - {ioutil.NopCloser(bytes.NewBuffer(nil)), true}, - {ioutil.NopCloser(strings.NewReader("")), true}, + {io.NopCloser(bytes.NewReader(nil)), true}, + {io.NopCloser(bytes.NewBuffer(nil)), true}, + {io.NopCloser(strings.NewReader("")), true}, } for i, tt := range tests { got := isKnownInMemoryReader(tt.r) @@ -104,12 +104,12 @@ var _ io.ReaderFrom = (*mockTransferWriter)(nil) func (w *mockTransferWriter) ReadFrom(r io.Reader) (int64, error) { w.CalledReader = r - return io.Copy(ioutil.Discard, r) + return io.Copy(io.Discard, r) } func (w *mockTransferWriter) Write(p []byte) (int, error) { w.WriteCalled = true - return ioutil.Discard.Write(p) + return io.Discard.Write(p) } func TestTransferWriterWriteBodyReaderTypes(t *testing.T) { @@ -166,7 +166,7 @@ func TestTransferWriterWriteBodyReaderTypes(t *testing.T) { method: "PUT", bodyFunc: func() (io.Reader, func(), error) { r, cleanup, err := newFileFunc() - return ioutil.NopCloser(r), cleanup, err + return io.NopCloser(r), cleanup, err }, contentLength: nBytes, limitedReader: true, @@ -206,7 +206,7 @@ func TestTransferWriterWriteBodyReaderTypes(t *testing.T) { method: "PUT", bodyFunc: func() (io.Reader, func(), error) { r, cleanup, err := newBufferFunc() - return ioutil.NopCloser(r), cleanup, err + return io.NopCloser(r), cleanup, err }, contentLength: nBytes, limitedReader: true, diff --git a/src/net/http/transport.go b/src/net/http/transport.go index b97c4268b5..29d7434f2a 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -44,7 +44,6 @@ var DefaultTransport RoundTripper = &Transport{ DialContext: (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, - DualStack: true, }).DialContext, ForceAttemptHTTP2: true, MaxIdleConns: 100, @@ -240,8 +239,18 @@ type Transport struct { // ProxyConnectHeader optionally specifies headers to send to // proxies during CONNECT requests. + // To set the header dynamically, see GetProxyConnectHeader. ProxyConnectHeader Header + // GetProxyConnectHeader optionally specifies a func to return + // headers to send to proxyURL during a CONNECT request to the + // ip:port target. + // If it returns an error, the Transport's RoundTrip fails with + // that error. It can return (nil, nil) to not add headers. + // If GetProxyConnectHeader is non-nil, ProxyConnectHeader is + // ignored. + GetProxyConnectHeader func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) + // MaxResponseHeaderBytes specifies a limit on how many // response bytes are allowed in the server's response // header. @@ -313,6 +322,7 @@ func (t *Transport) Clone() *Transport { ResponseHeaderTimeout: t.ResponseHeaderTimeout, ExpectContinueTimeout: t.ExpectContinueTimeout, ProxyConnectHeader: t.ProxyConnectHeader.Clone(), + GetProxyConnectHeader: t.GetProxyConnectHeader, MaxResponseHeaderBytes: t.MaxResponseHeaderBytes, ForceAttemptHTTP2: t.ForceAttemptHTTP2, WriteBufferSize: t.WriteBufferSize, @@ -613,7 +623,8 @@ var errCannotRewind = errors.New("net/http: cannot rewind body after connection type readTrackingBody struct { io.ReadCloser - didRead bool + didRead bool + didClose bool } func (r *readTrackingBody) Read(data []byte) (int, error) { @@ -621,6 +632,11 @@ func (r *readTrackingBody) Read(data []byte) (int, error) { return r.ReadCloser.Read(data) } +func (r *readTrackingBody) Close() error { + r.didClose = true + return r.ReadCloser.Close() +} + // setupRewindBody returns a new request with a custom body wrapper // that can report whether the body needs rewinding. // This lets rewindBody avoid an error result when the request @@ -639,10 +655,12 @@ func setupRewindBody(req *Request) *Request { // rewindBody takes care of closing req.Body when appropriate // (in all cases except when rewindBody returns req unmodified). func rewindBody(req *Request) (rewound *Request, err error) { - if req.Body == nil || req.Body == NoBody || !req.Body.(*readTrackingBody).didRead { + if req.Body == nil || req.Body == NoBody || (!req.Body.(*readTrackingBody).didRead && !req.Body.(*readTrackingBody).didClose) { return req, nil // nothing to rewind } - req.closeBody() + if !req.Body.(*readTrackingBody).didClose { + req.closeBody() + } if req.GetBody == nil { return nil, errCannotRewind } @@ -1623,7 +1641,17 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (pconn *pers } case cm.targetScheme == "https": conn := pconn.conn - hdr := t.ProxyConnectHeader + var hdr Header + if t.GetProxyConnectHeader != nil { + var err error + hdr, err = t.GetProxyConnectHeader(ctx, cm.proxyURL, cm.targetAddr) + if err != nil { + conn.Close() + return nil, err + } + } else { + hdr = t.ProxyConnectHeader + } if hdr == nil { hdr = make(Header) } @@ -2359,7 +2387,7 @@ func (pc *persistConn) writeLoop() { // Request.Body are high priority. // Set it here before sending on the // channels below or calling - // pc.close() which tears town + // pc.close() which tears down // connections and causes other // errors. wr.req.setError(err) @@ -2368,7 +2396,6 @@ func (pc *persistConn) writeLoop() { err = pc.bw.Flush() } if err != nil { - wr.req.Request.closeBody() if pc.nwrite == startBytesWritten { err = nothingWrittenError{err} } diff --git a/src/net/http/transport_internal_test.go b/src/net/http/transport_internal_test.go index 92729e65b2..1097ffd173 100644 --- a/src/net/http/transport_internal_test.go +++ b/src/net/http/transport_internal_test.go @@ -11,7 +11,6 @@ import ( "crypto/tls" "errors" "io" - "io/ioutil" "net" "net/http/internal" "strings" @@ -226,7 +225,7 @@ func TestTransportBodyAltRewind(t *testing.T) { TLSNextProto: map[string]func(string, *tls.Conn) RoundTripper{ "foo": func(authority string, c *tls.Conn) RoundTripper { return roundTripFunc(func(r *Request) (*Response, error) { - n, _ := io.Copy(ioutil.Discard, r.Body) + n, _ := io.Copy(io.Discard, r.Body) if n == 0 { t.Error("body length is zero") } diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index f4b7623630..e69133e786 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -173,7 +173,7 @@ func TestTransportKeepAlives(t *testing.T) { if err != nil { t.Fatalf("error in disableKeepAlive=%v, req #%d, GET: %v", disableKeepAlive, n, err) } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("error in disableKeepAlive=%v, req #%d, ReadAll: %v", disableKeepAlive, n, err) } @@ -220,7 +220,7 @@ func TestTransportConnectionCloseOnResponse(t *testing.T) { t.Fatalf("error in connectionClose=%v, req #%d, Do: %v", connectionClose, n, err) } defer res.Body.Close() - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) } @@ -273,7 +273,7 @@ func TestTransportConnectionCloseOnRequest(t *testing.T) { t.Errorf("For connectionClose = %v; handler's X-Saw-Close was %v; want %v", connectionClose, got, !connectionClose) } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("error in connectionClose=%v, req #%d, ReadAll: %v", connectionClose, n, err) } @@ -382,7 +382,7 @@ func TestTransportIdleCacheKeys(t *testing.T) { if err != nil { t.Error(err) } - ioutil.ReadAll(resp.Body) + io.ReadAll(resp.Body) keys := tr.IdleConnKeysForTesting() if e, g := 1, len(keys); e != g { @@ -412,7 +412,7 @@ func TestTransportReadToEndReusesConn(t *testing.T) { w.WriteHeader(200) w.(Flusher).Flush() } else { - w.Header().Set("Content-Type", strconv.Itoa(len(msg))) + w.Header().Set("Content-Length", strconv.Itoa(len(msg))) w.WriteHeader(200) } w.Write([]byte(msg)) @@ -495,7 +495,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { t.Error(err) return } - if _, err := ioutil.ReadAll(resp.Body); err != nil { + if _, err := io.ReadAll(resp.Body); err != nil { t.Errorf("ReadAll: %v", err) return } @@ -575,7 +575,7 @@ func TestTransportMaxConnsPerHostIncludeDialInProgress(t *testing.T) { if err != nil { t.Errorf("unexpected error for request %s: %v", reqId, err) } - _, err = ioutil.ReadAll(resp.Body) + _, err = io.ReadAll(resp.Body) if err != nil { t.Errorf("unexpected error for request %s: %v", reqId, err) } @@ -655,7 +655,7 @@ func TestTransportMaxConnsPerHost(t *testing.T) { t.Fatalf("request failed: %v", err) } defer resp.Body.Close() - _, err = ioutil.ReadAll(resp.Body) + _, err = io.ReadAll(resp.Body) if err != nil { t.Fatalf("read body failed: %v", err) } @@ -733,7 +733,7 @@ func TestTransportRemovesDeadIdleConnections(t *testing.T) { t.Fatalf("%s: %v", name, res.Status) } defer res.Body.Close() - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("%s: %v", name, err) } @@ -783,7 +783,7 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) { condFatalf("error in req #%d, GET: %v", n, err) continue } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { condFatalf("error in req #%d, ReadAll: %v", n, err) continue @@ -903,7 +903,7 @@ func TestTransportHeadResponses(t *testing.T) { if e, g := int64(123), res.ContentLength; e != g { t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g) } - if all, err := ioutil.ReadAll(res.Body); err != nil { + if all, err := io.ReadAll(res.Body); err != nil { t.Errorf("loop %d: Body ReadAll: %v", i, err) } else if len(all) != 0 { t.Errorf("Bogus body %q", all) @@ -1006,10 +1006,10 @@ func TestRoundTripGzip(t *testing.T) { t.Errorf("%d. gzip NewReader: %v", i, err) continue } - body, err = ioutil.ReadAll(r) + body, err = io.ReadAll(r) res.Body.Close() } else { - body, err = ioutil.ReadAll(res.Body) + body, err = io.ReadAll(res.Body) } if err != nil { t.Errorf("%d. Error: %q", i, err) @@ -1090,7 +1090,7 @@ func TestTransportGzip(t *testing.T) { if err != nil { t.Fatal(err) } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -1133,7 +1133,7 @@ func TestTransportExpect100Continue(t *testing.T) { switch req.URL.Path { case "/100": // This endpoint implicitly responds 100 Continue and reads body. - if _, err := io.Copy(ioutil.Discard, req.Body); err != nil { + if _, err := io.Copy(io.Discard, req.Body); err != nil { t.Error("Failed to read Body", err) } rw.WriteHeader(StatusOK) @@ -1159,7 +1159,7 @@ func TestTransportExpect100Continue(t *testing.T) { if err != nil { log.Fatal(err) } - if _, err := io.CopyN(ioutil.Discard, bufrw, req.ContentLength); err != nil { + if _, err := io.CopyN(io.Discard, bufrw, req.ContentLength); err != nil { t.Error("Failed to read Body", err) } bufrw.WriteString("HTTP/1.1 200 OK\r\n\r\n") @@ -1625,7 +1625,7 @@ func TestTransportGzipRecursive(t *testing.T) { if err != nil { t.Fatal(err) } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -1654,7 +1654,7 @@ func TestTransportGzipShort(t *testing.T) { t.Fatal(err) } defer res.Body.Close() - _, err = ioutil.ReadAll(res.Body) + _, err = io.ReadAll(res.Body) if err == nil { t.Fatal("Expect an error from reading a body.") } @@ -1701,7 +1701,7 @@ func TestTransportPersistConnLeak(t *testing.T) { res, err := c.Get(ts.URL) didReqCh <- true if err != nil { - t.Errorf("client fetch error: %v", err) + t.Logf("client fetch error: %v", err) failed <- true return } @@ -1715,17 +1715,15 @@ func TestTransportPersistConnLeak(t *testing.T) { case <-gotReqCh: // ok case <-failed: - close(unblockCh) - return + // Not great but not what we are testing: + // sometimes an overloaded system will fail to make all the connections. } } nhigh := runtime.NumGoroutine() // Tell all handlers to unblock and reply. - for i := 0; i < numReq; i++ { - unblockCh <- true - } + close(unblockCh) // Wait for all HTTP clients to be done. for i := 0; i < numReq; i++ { @@ -2001,7 +1999,7 @@ func TestIssue3644(t *testing.T) { t.Fatal(err) } defer res.Body.Close() - bs, err := ioutil.ReadAll(res.Body) + bs, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -2026,7 +2024,7 @@ func TestIssue3595(t *testing.T) { t.Errorf("Post: %v", err) return } - got, err := ioutil.ReadAll(res.Body) + got, err := io.ReadAll(res.Body) if err != nil { t.Fatalf("Body ReadAll: %v", err) } @@ -2098,7 +2096,7 @@ func TestTransportConcurrency(t *testing.T) { wg.Done() continue } - all, err := ioutil.ReadAll(res.Body) + all, err := io.ReadAll(res.Body) if err != nil { t.Errorf("read error on req %s: %v", req, err) wg.Done() @@ -2165,7 +2163,7 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { t.Errorf("Error issuing GET: %v", err) break } - _, err = io.Copy(ioutil.Discard, sres.Body) + _, err = io.Copy(io.Discard, sres.Body) if err == nil { t.Errorf("Unexpected successful copy") break @@ -2186,7 +2184,7 @@ func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { }) mux.HandleFunc("/put", func(w ResponseWriter, r *Request) { defer r.Body.Close() - io.Copy(ioutil.Discard, r.Body) + io.Copy(io.Discard, r.Body) }) ts := httptest.NewServer(mux) timeout := 100 * time.Millisecond @@ -2340,7 +2338,7 @@ func TestTransportCancelRequest(t *testing.T) { tr.CancelRequest(req) }() t0 := time.Now() - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) d := time.Since(t0) if err != ExportErrRequestCanceled { @@ -2499,7 +2497,7 @@ func TestCancelRequestWithChannel(t *testing.T) { close(ch) }() t0 := time.Now() - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) d := time.Since(t0) if err != ExportErrRequestCanceled { @@ -2680,7 +2678,7 @@ func (fooProto) RoundTrip(req *Request) (*Response, error) { Status: "200 OK", StatusCode: 200, Header: make(Header), - Body: ioutil.NopCloser(strings.NewReader("You wanted " + req.URL.String())), + Body: io.NopCloser(strings.NewReader("You wanted " + req.URL.String())), } return res, nil } @@ -2694,7 +2692,7 @@ func TestTransportAltProto(t *testing.T) { if err != nil { t.Fatal(err) } - bodyb, err := ioutil.ReadAll(res.Body) + bodyb, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -2771,7 +2769,7 @@ func TestTransportSocketLateBinding(t *testing.T) { // let the foo response finish so we can use its // connection for /bar fooGate <- true - io.Copy(ioutil.Discard, fooRes.Body) + io.Copy(io.Discard, fooRes.Body) fooRes.Body.Close() }) @@ -2810,7 +2808,7 @@ func TestTransportReading100Continue(t *testing.T) { t.Error(err) return } - slurp, err := ioutil.ReadAll(req.Body) + slurp, err := io.ReadAll(req.Body) if err != nil { t.Errorf("Server request body slurp: %v", err) return @@ -2874,7 +2872,7 @@ Content-Length: %d if id, idBack := req.Header.Get("Request-Id"), res.Header.Get("Echo-Request-Id"); id != "" && id != idBack { t.Errorf("%s: response id %q != request id %q", name, idBack, id) } - _, err = ioutil.ReadAll(res.Body) + _, err = io.ReadAll(res.Body) if err != nil { t.Fatalf("%s: Slurp error: %v", name, err) } @@ -3153,7 +3151,7 @@ func TestIdleConnChannelLeak(t *testing.T) { func TestTransportClosesRequestBody(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - io.Copy(ioutil.Discard, r.Body) + io.Copy(io.Discard, r.Body) })) defer ts.Close() @@ -3260,7 +3258,7 @@ func TestTLSServerClosesConnection(t *testing.T) { t.Fatal(err) } <-closedc - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -3275,7 +3273,7 @@ func TestTLSServerClosesConnection(t *testing.T) { errs = append(errs, err) continue } - slurp, err = ioutil.ReadAll(res.Body) + slurp, err = io.ReadAll(res.Body) if err != nil { errs = append(errs, err) continue @@ -3346,7 +3344,7 @@ func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { sconn.c = conn sconn.Unlock() conn.Write([]byte("HTTP/1.1 200 OK\r\nContent-Length: 3\r\n\r\nfoo")) // keep-alive - go io.Copy(ioutil.Discard, conn) + go io.Copy(io.Discard, conn) })) defer ts.Close() c := ts.Client() @@ -3595,7 +3593,7 @@ func TestTransportClosesBodyOnError(t *testing.T) { defer afterTest(t) readBody := make(chan error, 1) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - _, err := ioutil.ReadAll(r.Body) + _, err := io.ReadAll(r.Body) readBody <- err })) defer ts.Close() @@ -3943,7 +3941,7 @@ func TestTransportResponseCancelRace(t *testing.T) { // If we do an early close, Transport just throws the connection away and // doesn't reuse it. In order to trigger the bug, it has to reuse the connection // so read the body - if _, err := io.Copy(ioutil.Discard, res.Body); err != nil { + if _, err := io.Copy(io.Discard, res.Body); err != nil { t.Fatal(err) } @@ -3980,7 +3978,7 @@ func TestTransportContentEncodingCaseInsensitive(t *testing.T) { t.Fatal(err) } - body, err := ioutil.ReadAll(res.Body) + body, err := io.ReadAll(res.Body) res.Body.Close() if err != nil { t.Fatal(err) @@ -4087,7 +4085,7 @@ func TestTransportFlushesBodyChunks(t *testing.T) { if err != nil { t.Fatal(err) } - io.Copy(ioutil.Discard, req.Body) + io.Copy(io.Discard, req.Body) // Unblock the transport's roundTrip goroutine. resBody <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n") @@ -4468,7 +4466,7 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { // Do nothing for the second request. return } - if _, err := ioutil.ReadAll(r.Body); err != nil { + if _, err := io.ReadAll(r.Body); err != nil { t.Error(err) } if !noHooks { @@ -4556,7 +4554,7 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { t.Fatal(err) } logf("got roundtrip.response") - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -5174,6 +5172,57 @@ func TestTransportProxyConnectHeader(t *testing.T) { } } +func TestTransportProxyGetConnectHeader(t *testing.T) { + defer afterTest(t) + reqc := make(chan *Request, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method != "CONNECT" { + t.Errorf("method = %q; want CONNECT", r.Method) + } + reqc <- r + c, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Errorf("Hijack: %v", err) + return + } + c.Close() + })) + defer ts.Close() + + c := ts.Client() + c.Transport.(*Transport).Proxy = func(r *Request) (*url.URL, error) { + return url.Parse(ts.URL) + } + // These should be ignored: + c.Transport.(*Transport).ProxyConnectHeader = Header{ + "User-Agent": {"foo"}, + "Other": {"bar"}, + } + c.Transport.(*Transport).GetProxyConnectHeader = func(ctx context.Context, proxyURL *url.URL, target string) (Header, error) { + return Header{ + "User-Agent": {"foo2"}, + "Other": {"bar2"}, + }, nil + } + + res, err := c.Get("https://dummy.tld/") // https to force a CONNECT + if err == nil { + res.Body.Close() + t.Errorf("unexpected success") + } + select { + case <-time.After(3 * time.Second): + t.Fatal("timeout") + case r := <-reqc: + if got, want := r.Header.Get("User-Agent"), "foo2"; got != want { + t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) + } + if got, want := r.Header.Get("Other"), "bar2"; got != want { + t.Errorf("CONNECT request Other = %q; want %q", got, want) + } + } +} + var errFakeRoundTrip = errors.New("fake roundtrip") type funcRoundTripper func() @@ -5187,7 +5236,7 @@ func wantBody(res *Response, err error, want string) error { if err != nil { return err } - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { return fmt.Errorf("error reading body: %v", err) } @@ -5286,7 +5335,7 @@ func TestMissingStatusNoPanic(t *testing.T) { conn, _ := ln.Accept() if conn != nil { io.WriteString(conn, raw) - ioutil.ReadAll(conn) + io.ReadAll(conn) conn.Close() } }() @@ -5304,7 +5353,7 @@ func TestMissingStatusNoPanic(t *testing.T) { t.Error("panicked, expecting an error") } if res != nil && res.Body != nil { - io.Copy(ioutil.Discard, res.Body) + io.Copy(io.Discard, res.Body) res.Body.Close() } @@ -5490,7 +5539,7 @@ func TestClientTimeoutKillsConn_AfterHeaders(t *testing.T) { } close(cancel) - got, err := ioutil.ReadAll(res.Body) + got, err := io.ReadAll(res.Body) if err == nil { t.Fatalf("unexpected success; read %q, nil", got) } @@ -5629,7 +5678,7 @@ func TestTransportCONNECTBidi(t *testing.T) { } func TestTransportRequestReplayable(t *testing.T) { - someBody := ioutil.NopCloser(strings.NewReader("")) + someBody := io.NopCloser(strings.NewReader("")) tests := []struct { name string req *Request @@ -5790,7 +5839,7 @@ func TestTransportRequestWriteRoundTrip(t *testing.T) { t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { - io.Copy(ioutil.Discard, r.Body) + io.Copy(io.Discard, r.Body) r.Body.Close() w.WriteHeader(200) }), @@ -5842,6 +5891,7 @@ func TestTransportClone(t *testing.T) { ResponseHeaderTimeout: time.Second, ExpectContinueTimeout: time.Second, ProxyConnectHeader: Header{}, + GetProxyConnectHeader: func(context.Context, *url.URL, string) (Header, error) { return nil, nil }, MaxResponseHeaderBytes: 1, ForceAttemptHTTP2: true, TLSNextProto: map[string]func(authority string, c *tls.Conn) RoundTripper{ @@ -5925,7 +5975,7 @@ func TestTransportIgnores408(t *testing.T) { if err != nil { t.Fatal(err) } - slurp, err := ioutil.ReadAll(res.Body) + slurp, err := io.ReadAll(res.Body) if err != nil { t.Fatal(err) } @@ -6187,7 +6237,7 @@ func TestTransportDecrementConnWhenIdleConnRemoved(t *testing.T) { return } defer resp.Body.Close() - _, err = ioutil.ReadAll(resp.Body) + _, err = io.ReadAll(resp.Body) if err != nil { errCh <- fmt.Errorf("read body failed: %v", err) } @@ -6249,7 +6299,7 @@ func (f roundTripFunc) RoundTrip(r *Request) (*Response, error) { return f(r) } func TestIssue32441(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - if n, _ := io.Copy(ioutil.Discard, r.Body); n == 0 { + if n, _ := io.Copy(io.Discard, r.Body); n == 0 { t.Error("body length is zero") } })) @@ -6257,7 +6307,7 @@ func TestIssue32441(t *testing.T) { c := ts.Client() c.Transport.(*Transport).RegisterProtocol("http", roundTripFunc(func(r *Request) (*Response, error) { // Draining body to trigger failure condition on actual request to server. - if n, _ := io.Copy(ioutil.Discard, r.Body); n == 0 { + if n, _ := io.Copy(io.Discard, r.Body); n == 0 { t.Error("body length is zero during round trip") } return nil, ErrSkipAltProtocol @@ -6339,7 +6389,7 @@ func testTransportRace(req *Request) { if err == nil { // Ensure all the body is read; otherwise // we'll get a partial dump. - io.Copy(ioutil.Discard, req.Body) + io.Copy(io.Discard, req.Body) req.Body.Close() } select { |