diff options
Diffstat (limited to 'libgo/go/net/http')
55 files changed, 5048 insertions, 1942 deletions
diff --git a/libgo/go/net/http/cgi/host.go b/libgo/go/net/http/cgi/host.go index 9b4d8754183..58e9f7132a8 100644 --- a/libgo/go/net/http/cgi/host.go +++ b/libgo/go/net/http/cgi/host.go @@ -10,7 +10,7 @@ // // Note that using CGI means starting a new process to handle each // request, which is typically less efficient than using a -// long-running server. This package is intended primarily for +// long-running server. This package is intended primarily for // compatibility with existing systems. package cgi @@ -58,6 +58,7 @@ type Handler struct { InheritEnv []string // environment variables to inherit from host, as "key" Logger *log.Logger // optional log for errors or nil to use log.Print Args []string // optional arguments to pass to child process + Stderr io.Writer // optional stderr for the child process; nil means os.Stderr // PathLocationHandler specifies the root http Handler that // should handle internal redirects when the CGI process @@ -70,6 +71,13 @@ type Handler struct { PathLocationHandler http.Handler } +func (h *Handler) stderr() io.Writer { + if h.Stderr != nil { + return h.Stderr + } + return os.Stderr +} + // removeLeadingDuplicates remove leading duplicate in environments. // It's possible to override environment like following. 
// cgi.Handler{ @@ -145,6 +153,10 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { for k, v := range req.Header { k = strings.Map(upperCaseAndUnderscore, k) + if k == "PROXY" { + // See Issue 16405 + continue + } joinStr := ", " if k == "COOKIE" { joinStr = "; " @@ -204,7 +216,7 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { Args: append([]string{h.Path}, h.Args...), Dir: cwd, Env: env, - Stderr: os.Stderr, // for now + Stderr: h.stderr(), } if req.ContentLength != 0 { cmd.Stdin = req.Body diff --git a/libgo/go/net/http/cgi/host_test.go b/libgo/go/net/http/cgi/host_test.go index fb7d66adb9f..f0583729eba 100644 --- a/libgo/go/net/http/cgi/host_test.go +++ b/libgo/go/net/http/cgi/host_test.go @@ -8,6 +8,7 @@ package cgi import ( "bufio" + "bytes" "fmt" "io" "net" @@ -34,15 +35,18 @@ func newRequest(httpreq string) *http.Request { return req } -func runCgiTest(t *testing.T, h *Handler, httpreq string, expectedMap map[string]string) *httptest.ResponseRecorder { +func runCgiTest(t *testing.T, h *Handler, + httpreq string, + expectedMap map[string]string, checks ...func(reqInfo map[string]string)) *httptest.ResponseRecorder { rw := httptest.NewRecorder() req := newRequest(httpreq) h.ServeHTTP(rw, req) - runResponseChecks(t, rw, expectedMap) + runResponseChecks(t, rw, expectedMap, checks...) return rw } -func runResponseChecks(t *testing.T, rw *httptest.ResponseRecorder, expectedMap map[string]string) { +func runResponseChecks(t *testing.T, rw *httptest.ResponseRecorder, + expectedMap map[string]string, checks ...func(reqInfo map[string]string)) { // Make a map to hold the test map that the CGI returns. 
m := make(map[string]string) m["_body"] = rw.Body.String() @@ -80,6 +84,9 @@ readlines: t.Errorf("for key %q got %q; expected %q", key, got, expected) } } + for _, check := range checks { + check(m) + } } var cgiTested, cgiWorks bool @@ -235,6 +242,31 @@ func TestDupHeaders(t *testing.T) { expectedMap) } +// Issue 16405: CGI+http.Transport differing uses of HTTP_PROXY. +// Verify we don't set the HTTP_PROXY environment variable. +// Hope nobody was depending on it. It's not a known header, though. +func TestDropProxyHeader(t *testing.T) { + check(t) + h := &Handler{ + Path: "testdata/test.cgi", + } + expectedMap := map[string]string{ + "env-REQUEST_URI": "/myscript/bar?a=b", + "env-SCRIPT_FILENAME": "testdata/test.cgi", + "env-HTTP_X_FOO": "a", + } + runCgiTest(t, h, "GET /myscript/bar?a=b HTTP/1.0\n"+ + "X-Foo: a\n"+ + "Proxy: should_be_stripped\n"+ + "Host: example.com\n\n", + expectedMap, + func(reqInfo map[string]string) { + if v, ok := reqInfo["env-HTTP_PROXY"]; ok { + t.Errorf("HTTP_PROXY = %q; should be absent", v) + } + }) +} + func TestPathInfoNoRoot(t *testing.T) { check(t) h := &Handler{ @@ -501,6 +533,23 @@ func TestEnvOverride(t *testing.T) { runCgiTest(t, h, "GET /test.cgi HTTP/1.0\nHost: example.com\n\n", expectedMap) } +func TestHandlerStderr(t *testing.T) { + check(t) + var stderr bytes.Buffer + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + Stderr: &stderr, + } + + rw := httptest.NewRecorder() + req := newRequest("GET /test.cgi?writestderr=1 HTTP/1.0\nHost: example.com\n\n") + h.ServeHTTP(rw, req) + if got, want := stderr.String(), "Hello, stderr!\n"; got != want { + t.Errorf("Stderr = %q; want %q", got, want) + } +} + func TestRemoveLeadingDuplicates(t *testing.T) { tests := []struct { env []string diff --git a/libgo/go/net/http/cgi/testdata/test.cgi b/libgo/go/net/http/cgi/testdata/test.cgi index ec7ee6f3864..667fce217ea 100644 --- a/libgo/go/net/http/cgi/testdata/test.cgi +++ b/libgo/go/net/http/cgi/testdata/test.cgi @@ 
-23,6 +23,10 @@ print "X-CGI-Pid: $$\r\n"; print "X-Test-Header: X-Test-Value\r\n"; print "\r\n"; +if ($params->{"writestderr"}) { + print STDERR "Hello, stderr!\n"; +} + if ($params->{"bigresponse"}) { # 17 MB, for OS X: golang.org/issue/4958 for (1..(17 * 1024)) { diff --git a/libgo/go/net/http/client.go b/libgo/go/net/http/client.go index 3106d229da6..993c247eef5 100644 --- a/libgo/go/net/http/client.go +++ b/libgo/go/net/http/client.go @@ -44,9 +44,12 @@ type Client struct { // following an HTTP redirect. The arguments req and via are // the upcoming request and the requests made already, oldest // first. If CheckRedirect returns an error, the Client's Get - // method returns both the previous Response and - // CheckRedirect's error (wrapped in a url.Error) instead of - // issuing the Request req. + // method returns both the previous Response (with its Body + // closed) and CheckRedirect's error (wrapped in a url.Error) + // instead of issuing the Request req. + // As a special case, if CheckRedirect returns ErrUseLastResponse, + // then the most recent response is returned with its body + // unclosed, along with a nil error. // // If CheckRedirect is nil, the Client uses its default policy, // which is to stop after 10 consecutive requests. @@ -110,10 +113,6 @@ type RoundTripper interface { RoundTrip(*Request) (*Response, error) } -// Given a string of the form "host", "host:port", or "[ipv6::address]:port", -// return true if the string includes a port. -func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") } - // refererForURL returns a referer without any authentication info or // an empty string if lastReq scheme is https and newReq scheme is http. 
func refererForURL(lastReq, newReq *url.URL) string { @@ -138,14 +137,6 @@ func refererForURL(lastReq, newReq *url.URL) string { return referer } -// Used in Send to implement io.ReadCloser by bundling together the -// bufio.Reader through which we read the response, and the underlying -// network connection. -type readClose struct { - io.Reader - io.Closer -} - func (c *Client) send(req *Request, deadline time.Time) (*Response, error) { if c.Jar != nil { for _, cookie := range c.Jar.Cookies(req.URL) { @@ -161,28 +152,33 @@ func (c *Client) send(req *Request, deadline time.Time) (*Response, error) { c.Jar.SetCookies(req.URL, rc) } } - return resp, err + return resp, nil } // Do sends an HTTP request and returns an HTTP response, following -// policy (e.g. redirects, cookies, auth) as configured on the client. +// policy (such as redirects, cookies, auth) as configured on the +// client. // // An error is returned if caused by client policy (such as -// CheckRedirect), or if there was an HTTP protocol error. -// A non-2xx response doesn't cause an error. -// -// When err is nil, resp always contains a non-nil resp.Body. +// CheckRedirect), or failure to speak HTTP (such as a network +// connectivity problem). A non-2xx status code doesn't cause an +// error. // -// Callers should close resp.Body when done reading from it. If -// resp.Body is not closed, the Client's underlying RoundTripper -// (typically Transport) may not be able to re-use a persistent TCP -// connection to the server for a subsequent "keep-alive" request. +// If the returned error is nil, the Response will contain a non-nil +// Body which the user is expected to close. If the Body is not +// closed, the Client's underlying RoundTripper (typically Transport) +// may not be able to re-use a persistent TCP connection to the server +// for a subsequent "keep-alive" request. // // The request Body, if non-nil, will be closed by the underlying // Transport, even on errors. 
// +// On error, any Response can be ignored. A non-nil Response with a +// non-nil error only occurs when CheckRedirect fails, and even then +// the returned Response.Body is already closed. +// // Generally Get, Post, or PostForm will be used instead of Do. -func (c *Client) Do(req *Request) (resp *Response, err error) { +func (c *Client) Do(req *Request) (*Response, error) { method := valueOrDefault(req.Method, "GET") if method == "GET" || method == "HEAD" { return c.doFollowingRedirects(req, shouldRedirectGet) @@ -237,7 +233,7 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (*Response, error) } // Most the callers of send (Get, Post, et al) don't need - // Headers, leaving it uninitialized. We guarantee to the + // Headers, leaving it uninitialized. We guarantee to the // Transport that this has been initialized, though. if req.Header == nil { forkReq() @@ -424,101 +420,125 @@ func (c *Client) Get(url string) (resp *Response, err error) { func alwaysFalse() bool { return false } -func (c *Client) doFollowingRedirects(ireq *Request, shouldRedirect func(int) bool) (resp *Response, err error) { - var base *url.URL - redirectChecker := c.CheckRedirect - if redirectChecker == nil { - redirectChecker = defaultCheckRedirect +// ErrUseLastResponse can be returned by Client.CheckRedirect hooks to +// control how redirects are processed. If returned, the next request +// is not sent and the most recent response is returned with its body +// unclosed. +var ErrUseLastResponse = errors.New("net/http: use last response") + +// checkRedirect calls either the user's configured CheckRedirect +// function, or the default. 
+func (c *Client) checkRedirect(req *Request, via []*Request) error { + fn := c.CheckRedirect + if fn == nil { + fn = defaultCheckRedirect } - var via []*Request + return fn(req, via) +} - if ireq.URL == nil { - ireq.closeBody() +func (c *Client) doFollowingRedirects(req *Request, shouldRedirect func(int) bool) (*Response, error) { + if req.URL == nil { + req.closeBody() return nil, errors.New("http: nil Request.URL") } - req := ireq - deadline := c.deadline() - - urlStr := "" // next relative or absolute URL to fetch (after first request) - redirectFailed := false - for redirect := 0; ; redirect++ { - if redirect != 0 { - nreq := new(Request) - nreq.Cancel = ireq.Cancel - nreq.Method = ireq.Method - if ireq.Method == "POST" || ireq.Method == "PUT" { - nreq.Method = "GET" + var ( + deadline = c.deadline() + reqs []*Request + resp *Response + ) + uerr := func(err error) error { + req.closeBody() + method := valueOrDefault(reqs[0].Method, "GET") + var urlStr string + if resp != nil && resp.Request != nil { + urlStr = resp.Request.URL.String() + } else { + urlStr = req.URL.String() + } + return &url.Error{ + Op: method[:1] + strings.ToLower(method[1:]), + URL: urlStr, + Err: err, + } + } + for { + // For all but the first request, create the next + // request hop and replace req. + if len(reqs) > 0 { + loc := resp.Header.Get("Location") + if loc == "" { + return nil, uerr(fmt.Errorf("%d response missing Location header", resp.StatusCode)) } - nreq.Header = make(Header) - nreq.URL, err = base.Parse(urlStr) + u, err := req.URL.Parse(loc) if err != nil { - break + return nil, uerr(fmt.Errorf("failed to parse Location header %q: %v", loc, err)) } - if len(via) > 0 { - // Add the Referer header. 
- lastReq := via[len(via)-1] - if ref := refererForURL(lastReq.URL, nreq.URL); ref != "" { - nreq.Header.Set("Referer", ref) - } - - err = redirectChecker(nreq, via) - if err != nil { - redirectFailed = true - break - } + ireq := reqs[0] + req = &Request{ + Method: ireq.Method, + Response: resp, + URL: u, + Header: make(Header), + Cancel: ireq.Cancel, + ctx: ireq.ctx, } - req = nreq - } + if ireq.Method == "POST" || ireq.Method == "PUT" { + req.Method = "GET" + } + // Add the Referer header from the most recent + // request URL to the new one, if it's not https->http: + if ref := refererForURL(reqs[len(reqs)-1].URL, req.URL); ref != "" { + req.Header.Set("Referer", ref) + } + err = c.checkRedirect(req, reqs) - urlStr = req.URL.String() - if resp, err = c.send(req, deadline); err != nil { - if !deadline.IsZero() && !time.Now().Before(deadline) { - err = &httpError{ - err: err.Error() + " (Client.Timeout exceeded while awaiting headers)", - timeout: true, - } + // Sentinel error to let users select the + // previous response, without closing its + // body. See Issue 10069. + if err == ErrUseLastResponse { + return resp, nil } - break - } - if shouldRedirect(resp.StatusCode) { - // Read the body if small so underlying TCP connection will be re-used. - // No need to check for errors: if it fails, Transport won't reuse it anyway. + // Close the previous response's body. But + // read at least some of the body so if it's + // small the underlying TCP connection will be + // re-used. No need to check for errors: if it + // fails, the Transport won't reuse it anyway. 
const maxBodySlurpSize = 2 << 10 if resp.ContentLength == -1 || resp.ContentLength <= maxBodySlurpSize { io.CopyN(ioutil.Discard, resp.Body, maxBodySlurpSize) } resp.Body.Close() - if urlStr = resp.Header.Get("Location"); urlStr == "" { - err = fmt.Errorf("%d response missing Location header", resp.StatusCode) - break + + if err != nil { + // Special case for Go 1 compatibility: return both the response + // and an error if the CheckRedirect function failed. + // See https://golang.org/issue/3795 + // The resp.Body has already been closed. + ue := uerr(err) + ue.(*url.Error).URL = loc + return resp, ue } - base = req.URL - via = append(via, req) - continue } - return resp, nil - } - method := valueOrDefault(ireq.Method, "GET") - urlErr := &url.Error{ - Op: method[:1] + strings.ToLower(method[1:]), - URL: urlStr, - Err: err, - } + reqs = append(reqs, req) - if redirectFailed { - // Special case for Go 1 compatibility: return both the response - // and an error if the CheckRedirect function failed. 
- // See https://golang.org/issue/3795 - return resp, urlErr - } + var err error + if resp, err = c.send(req, deadline); err != nil { + if !deadline.IsZero() && !time.Now().Before(deadline) { + err = &httpError{ + err: err.Error() + " (Client.Timeout exceeded while awaiting headers)", + timeout: true, + } + } + return nil, uerr(err) + } - if resp != nil { - resp.Body.Close() + if !shouldRedirect(resp.StatusCode) { + return resp, nil + } } - return nil, urlErr } func defaultCheckRedirect(req *Request, via []*Request) error { diff --git a/libgo/go/net/http/client_test.go b/libgo/go/net/http/client_test.go index 8939dc8baf9..a9b1948005c 100644 --- a/libgo/go/net/http/client_test.go +++ b/libgo/go/net/http/client_test.go @@ -8,6 +8,7 @@ package http_test import ( "bytes" + "context" "crypto/tls" "crypto/x509" "encoding/base64" @@ -273,7 +274,7 @@ func TestClientRedirects(t *testing.T) { t.Fatal("didn't see redirect") } if lastReq.Cancel != cancel { - t.Errorf("expected lastReq to have the cancel channel set on the inital req") + t.Errorf("expected lastReq to have the cancel channel set on the initial req") } checkErr = errors.New("no redirects allowed") @@ -290,6 +291,33 @@ func TestClientRedirects(t *testing.T) { } } +func TestClientRedirectContext(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + Redirect(w, r, "/", StatusFound) + })) + defer ts.Close() + + ctx, cancel := context.WithCancel(context.Background()) + c := &Client{CheckRedirect: func(req *Request, via []*Request) error { + cancel() + if len(via) > 2 { + return errors.New("too many redirects") + } + return nil + }} + req, _ := NewRequest("GET", ts.URL, nil) + req = req.WithContext(ctx) + _, err := c.Do(req) + ue, ok := err.(*url.Error) + if !ok { + t.Fatalf("got error %T; want *url.Error", err) + } + if ue.Err != ExportErrRequestCanceled && ue.Err != ExportErrRequestCanceledConn { + t.Errorf("url.Error.Err = %v; want errRequestCanceled or 
errRequestCanceledConn", ue.Err) + } +} + func TestPostRedirects(t *testing.T) { defer afterTest(t) var log struct { @@ -338,6 +366,44 @@ func TestPostRedirects(t *testing.T) { } } +func TestClientRedirectUseResponse(t *testing.T) { + defer afterTest(t) + const body = "Hello, world." + var ts *httptest.Server + ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if strings.Contains(r.URL.Path, "/other") { + io.WriteString(w, "wrong body") + } else { + w.Header().Set("Location", ts.URL+"/other") + w.WriteHeader(StatusFound) + io.WriteString(w, body) + } + })) + defer ts.Close() + + c := &Client{CheckRedirect: func(req *Request, via []*Request) error { + if req.Response == nil { + t.Error("expected non-nil Request.Response") + } + return ErrUseLastResponse + }} + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != StatusFound { + t.Errorf("status = %d; want %d", res.StatusCode, StatusFound) + } + defer res.Body.Close() + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if string(slurp) != body { + t.Errorf("body = %q; want %q", slurp, body) + } +} + var expectedCookies = []*Cookie{ {Name: "ChocolateChip", Value: "tasty"}, {Name: "First", Value: "Hit"}, @@ -1140,3 +1206,26 @@ func TestReferer(t *testing.T) { } } } + +// issue15577Tripper returns a Response with a redirect response +// header and doesn't populate its Response.Request field. +type issue15577Tripper struct{} + +func (issue15577Tripper) RoundTrip(*Request) (*Response, error) { + resp := &Response{ + StatusCode: 303, + Header: map[string][]string{"Location": {"http://www.example.com/"}}, + Body: ioutil.NopCloser(strings.NewReader("")), + } + return resp, nil +} + +// Issue 15577: don't assume the roundtripper's response populates its Request field. 
+func TestClientRedirectResponseWithoutRequest(t *testing.T) { + c := &Client{ + CheckRedirect: func(*Request, []*Request) error { return fmt.Errorf("no redirects!") }, + Transport: issue15577Tripper{}, + } + // Check that this doesn't crash: + c.Get("http://dummy.tld") +} diff --git a/libgo/go/net/http/clientserver_test.go b/libgo/go/net/http/clientserver_test.go index aa2473a2773..3d1f09cae83 100644 --- a/libgo/go/net/http/clientserver_test.go +++ b/libgo/go/net/http/clientserver_test.go @@ -17,6 +17,7 @@ import ( "net" . "net/http" "net/http/httptest" + "net/http/httputil" "net/url" "os" "reflect" @@ -43,6 +44,13 @@ func (t *clientServerTest) close() { t.ts.Close() } +func (t *clientServerTest) scheme() string { + if t.h2 { + return "https" + } + return "http" +} + const ( h1Mode = false h2Mode = true @@ -147,10 +155,11 @@ type reqFunc func(c *Client, url string) (*Response, error) // h12Compare is a test that compares HTTP/1 and HTTP/2 behavior // against each other. type h12Compare struct { - Handler func(ResponseWriter, *Request) // required - ReqFunc reqFunc // optional - CheckResponse func(proto string, res *Response) // optional - Opts []interface{} + Handler func(ResponseWriter, *Request) // required + ReqFunc reqFunc // optional + CheckResponse func(proto string, res *Response) // optional + EarlyCheckResponse func(proto string, res *Response) // optional; pre-normalize + Opts []interface{} } func (tt h12Compare) reqFunc() reqFunc { @@ -176,6 +185,12 @@ func (tt h12Compare) run(t *testing.T) { t.Errorf("HTTP/2 request: %v", err) return } + + if fn := tt.EarlyCheckResponse; fn != nil { + fn("HTTP/1.1", res1) + fn("HTTP/2.0", res2) + } + tt.normalizeRes(t, res1, "HTTP/1.1") tt.normalizeRes(t, res2, "HTTP/2.0") res1body, res2body := res1.Body, res2.Body @@ -220,6 +235,7 @@ func (tt h12Compare) normalizeRes(t *testing.T, res *Response, wantProto string) t.Errorf("got %q response; want %q", res.Proto, wantProto) } slurp, err := ioutil.ReadAll(res.Body) + 
res.Body.Close() res.Body = slurpResult{ ReadCloser: ioutil.NopCloser(bytes.NewReader(slurp)), @@ -356,7 +372,7 @@ func TestH12_HandlerWritesTooLittle(t *testing.T) { } // Tests that the HTTP/1 and HTTP/2 servers prevent handlers from -// writing more than they declared. This test does not test whether +// writing more than they declared. This test does not test whether // the transport deals with too much data, though, since the server // doesn't make it possible to send bogus data. For those tests, see // transport_test.go (for HTTP/1) or x/net/http2/transport_test.go @@ -1049,6 +1065,170 @@ func testTransportGCRequest(t *testing.T, h2, body bool) { } } +func TestTransportRejectsInvalidHeaders_h1(t *testing.T) { + testTransportRejectsInvalidHeaders(t, h1Mode) +} +func TestTransportRejectsInvalidHeaders_h2(t *testing.T) { + testTransportRejectsInvalidHeaders(t, h2Mode) +} +func testTransportRejectsInvalidHeaders(t *testing.T, h2 bool) { + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "Handler saw headers: %q", r.Header) + })) + defer cst.close() + cst.tr.DisableKeepAlives = true + + tests := []struct { + key, val string + ok bool + }{ + {"Foo", "capital-key", true}, // verify h2 allows capital keys + {"Foo", "foo\x00bar", false}, // \x00 byte in value not allowed + {"Foo", "two\nlines", false}, // \n byte in value not allowed + {"bogus\nkey", "v", false}, // \n byte also not allowed in key + {"A space", "v", false}, // spaces in keys not allowed + {"имя", "v", false}, // key must be ascii + {"name", "валю", true}, // value may be non-ascii + {"", "v", false}, // key must be non-empty + {"k", "", true}, // value may be empty + } + for _, tt := range tests { + dialedc := make(chan bool, 1) + cst.tr.Dial = func(netw, addr string) (net.Conn, error) { + dialedc <- true + return net.Dial(netw, addr) + } + req, _ := NewRequest("GET", cst.ts.URL, nil) + req.Header[tt.key] = []string{tt.val} + res, 
err := cst.c.Do(req) + var body []byte + if err == nil { + body, _ = ioutil.ReadAll(res.Body) + res.Body.Close() + } + var dialed bool + select { + case <-dialedc: + dialed = true + default: + } + + if !tt.ok && dialed { + t.Errorf("For key %q, value %q, transport dialed. Expected local failure. Response was: (%v, %v)\nServer replied with: %s", tt.key, tt.val, res, err, body) + } else if (err == nil) != tt.ok { + t.Errorf("For key %q, value %q; got err = %v; want ok=%v", tt.key, tt.val, err, tt.ok) + } + } +} + +// Tests that we support bogus under-100 HTTP statuses, because we historically +// have. This might change at some point, but not yet in Go 1.6. +func TestBogusStatusWorks_h1(t *testing.T) { testBogusStatusWorks(t, h1Mode) } +func TestBogusStatusWorks_h2(t *testing.T) { testBogusStatusWorks(t, h2Mode) } +func testBogusStatusWorks(t *testing.T, h2 bool) { + defer afterTest(t) + const code = 7 + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + w.WriteHeader(code) + })) + defer cst.close() + + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + if res.StatusCode != code { + t.Errorf("StatusCode = %d; want %d", res.StatusCode, code) + } +} + +func TestInterruptWithPanic_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode) } +func TestInterruptWithPanic_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode) } +func testInterruptWithPanic(t *testing.T, h2 bool) { + log.SetOutput(ioutil.Discard) // is noisy otherwise + defer log.SetOutput(os.Stderr) + + const msg = "hello" + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + io.WriteString(w, msg) + w.(Flusher).Flush() + panic("no more") + })) + defer cst.close() + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + slurp, err := ioutil.ReadAll(res.Body) + if string(slurp) != msg { + t.Errorf("client read %q; want %q", slurp, msg) + } + if err == nil { + 
t.Errorf("client read all successfully; want some error") + } +} + +// Issue 15366 +func TestH12_AutoGzipWithDumpResponse(t *testing.T) { + h12Compare{ + Handler: func(w ResponseWriter, r *Request) { + h := w.Header() + h.Set("Content-Encoding", "gzip") + h.Set("Content-Length", "23") + h.Set("Connection", "keep-alive") + io.WriteString(w, "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00") + }, + EarlyCheckResponse: func(proto string, res *Response) { + if !res.Uncompressed { + t.Errorf("%s: expected Uncompressed to be set", proto) + } + dump, err := httputil.DumpResponse(res, true) + if err != nil { + t.Errorf("%s: DumpResponse: %v", proto, err) + return + } + if strings.Contains(string(dump), "Connection: close") { + t.Errorf("%s: should not see \"Connection: close\" in dump; got:\n%s", proto, dump) + } + if !strings.Contains(string(dump), "FOO") { + t.Errorf("%s: should see \"FOO\" in response; got:\n%s", proto, dump) + } + }, + }.run(t) +} + +// Issue 14607 +func TestCloseIdleConnections_h1(t *testing.T) { testCloseIdleConnections(t, h1Mode) } +func TestCloseIdleConnections_h2(t *testing.T) { testCloseIdleConnections(t, h2Mode) } +func testCloseIdleConnections(t *testing.T, h2 bool) { + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("X-Addr", r.RemoteAddr) + })) + defer cst.close() + get := func() string { + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + v := res.Header.Get("X-Addr") + if v == "" { + t.Fatal("didn't get X-Addr") + } + return v + } + a1 := get() + cst.tr.CloseIdleConnections() + a2 := get() + if a1 == a2 { + t.Errorf("didn't close connection") + } +} + type noteCloseConn struct { net.Conn closeFunc func() diff --git a/libgo/go/net/http/cookie.go b/libgo/go/net/http/cookie.go index 648709dd997..1ea0e9397a3 100644 --- a/libgo/go/net/http/cookie.go +++ b/libgo/go/net/http/cookie.go @@ -223,7 
+223,7 @@ func readCookies(h Header, filter string) []*Cookie { return cookies } -// validCookieDomain returns wheter v is a valid cookie domain-value. +// validCookieDomain returns whether v is a valid cookie domain-value. func validCookieDomain(v string) bool { if isCookieDomainName(v) { return true diff --git a/libgo/go/net/http/cookie_test.go b/libgo/go/net/http/cookie_test.go index d474f313476..95e61479a15 100644 --- a/libgo/go/net/http/cookie_test.go +++ b/libgo/go/net/http/cookie_test.go @@ -1,4 +1,4 @@ -// Copyright 2010 The Go Authors. All rights reserved. +// Copyright 2010 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. diff --git a/libgo/go/net/http/cookiejar/punycode.go b/libgo/go/net/http/cookiejar/punycode.go index ea7ceb5ef3f..a9cc666e8c9 100644 --- a/libgo/go/net/http/cookiejar/punycode.go +++ b/libgo/go/net/http/cookiejar/punycode.go @@ -37,7 +37,7 @@ func encode(prefix, s string) (string, error) { delta, n, bias := int32(0), initialN, initialBias b, remaining := int32(0), int32(0) for _, r := range s { - if r < 0x80 { + if r < utf8.RuneSelf { b++ output = append(output, byte(r)) } else { diff --git a/libgo/go/net/http/export_test.go b/libgo/go/net/http/export_test.go index 52bccbdce31..9c5ba0809ad 100644 --- a/libgo/go/net/http/export_test.go +++ b/libgo/go/net/http/export_test.go @@ -1,4 +1,4 @@ -// Copyright 2011 The Go Authors. All rights reserved. +// Copyright 2011 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
@@ -9,22 +9,22 @@ package http import ( "net" + "sort" "sync" "time" ) var ( - DefaultUserAgent = defaultUserAgent - NewLoggingConn = newLoggingConn - ExportAppendTime = appendTime - ExportRefererForURL = refererForURL - ExportServerNewConn = (*Server).newConn - ExportCloseWriteAndWait = (*conn).closeWriteAndWait - ExportErrRequestCanceled = errRequestCanceled - ExportErrRequestCanceledConn = errRequestCanceledConn - ExportServeFile = serveFile - ExportHttp2ConfigureTransport = http2ConfigureTransport - ExportHttp2ConfigureServer = http2ConfigureServer + DefaultUserAgent = defaultUserAgent + NewLoggingConn = newLoggingConn + ExportAppendTime = appendTime + ExportRefererForURL = refererForURL + ExportServerNewConn = (*Server).newConn + ExportCloseWriteAndWait = (*conn).closeWriteAndWait + ExportErrRequestCanceled = errRequestCanceled + ExportErrRequestCanceledConn = errRequestCanceledConn + ExportServeFile = serveFile + ExportHttp2ConfigureServer = http2ConfigureServer ) func init() { @@ -35,9 +35,8 @@ func init() { } var ( - SetEnterRoundTripHook = hookSetter(&testHookEnterRoundTrip) - SetTestHookWaitResLoop = hookSetter(&testHookWaitResLoop) - SetRoundTripRetried = hookSetter(&testHookRoundTripRetried) + SetEnterRoundTripHook = hookSetter(&testHookEnterRoundTrip) + SetRoundTripRetried = hookSetter(&testHookRoundTripRetried) ) func SetReadLoopBeforeNextReadHook(f func()) { @@ -59,9 +58,9 @@ func SetTestHookServerServe(fn func(*Server, net.Listener)) { testHookServerServ func NewTestTimeoutHandler(handler Handler, ch <-chan time.Time) Handler { return &timeoutHandler{ - handler: handler, - timeout: func() <-chan time.Time { return ch }, - // (no body and nil cancelTimer) + handler: handler, + testTimeout: ch, + // (no body) } } @@ -81,21 +80,29 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) { keys = make([]string, 0) t.idleMu.Lock() defer t.idleMu.Unlock() - if t.idleConn == nil { - return - } for key := range t.idleConn { keys = append(keys, 
key.String()) } + sort.Strings(keys) return } -func (t *Transport) IdleConnCountForTesting(cacheKey string) int { +func (t *Transport) IdleConnStrsForTesting() []string { + var ret []string t.idleMu.Lock() defer t.idleMu.Unlock() - if t.idleConn == nil { - return 0 + for _, conns := range t.idleConn { + for _, pc := range conns { + ret = append(ret, pc.conn.LocalAddr().String()+"/"+pc.conn.RemoteAddr().String()) + } } + sort.Strings(ret) + return ret +} + +func (t *Transport) IdleConnCountForTesting(cacheKey string) int { + t.idleMu.Lock() + defer t.idleMu.Unlock() for k, conns := range t.idleConn { if k.String() == cacheKey { return len(conns) @@ -144,3 +151,12 @@ func hookSetter(dst *func()) func(func()) { *dst = fn } } + +func ExportHttp2ConfigureTransport(t *Transport) error { + t2, err := http2configureTransport(t) + if err != nil { + return err + } + t.h2transport = t2 + return nil +} diff --git a/libgo/go/net/http/fcgi/fcgi.go b/libgo/go/net/http/fcgi/fcgi.go index 06bba0488a2..337484139d3 100644 --- a/libgo/go/net/http/fcgi/fcgi.go +++ b/libgo/go/net/http/fcgi/fcgi.go @@ -58,8 +58,6 @@ const ( statusUnknownRole ) -const headerLen = 8 - type header struct { Version uint8 Type recType @@ -158,11 +156,6 @@ func (c *conn) writeRecord(recType recType, reqId uint16, b []byte) error { return err } -func (c *conn) writeBeginRequest(reqId uint16, role uint16, flags uint8) error { - b := [8]byte{byte(role >> 8), byte(role), flags} - return c.writeRecord(typeBeginRequest, reqId, b[:]) -} - func (c *conn) writeEndRequest(reqId uint16, appStatus int, protocolStatus uint8) error { b := make([]byte, 8) binary.BigEndian.PutUint32(b, uint32(appStatus)) diff --git a/libgo/go/net/http/filetransport.go b/libgo/go/net/http/filetransport.go index 821787e0c4b..32126d7ec0f 100644 --- a/libgo/go/net/http/filetransport.go +++ b/libgo/go/net/http/filetransport.go @@ -33,7 +33,7 @@ func NewFileTransport(fs FileSystem) RoundTripper { func (t fileTransport) RoundTrip(req *Request) (resp 
*Response, err error) { // We start ServeHTTP in a goroutine, which may take a long - // time if the file is large. The newPopulateResponseWriter + // time if the file is large. The newPopulateResponseWriter // call returns a channel which either ServeHTTP or finish() // sends our *Response on, once the *Response itself has been // populated (even if the body itself is still being diff --git a/libgo/go/net/http/fs.go b/libgo/go/net/http/fs.go index f61c138c1d9..c7a58a61dff 100644 --- a/libgo/go/net/http/fs.go +++ b/libgo/go/net/http/fs.go @@ -34,7 +34,7 @@ import ( type Dir string func (d Dir) Open(name string) (File, error) { - if filepath.Separator != '/' && strings.IndexRune(name, filepath.Separator) >= 0 || + if filepath.Separator != '/' && strings.ContainsRune(name, filepath.Separator) || strings.Contains(name, "\x00") { return nil, errors.New("http: invalid character in file path") } @@ -96,7 +96,7 @@ func dirList(w ResponseWriter, f File) { } // ServeContent replies to the request using the content in the -// provided ReadSeeker. The main benefit of ServeContent over io.Copy +// provided ReadSeeker. The main benefit of ServeContent over io.Copy // is that it handles Range requests properly, sets the MIME type, and // handles If-Modified-Since requests. // @@ -108,7 +108,7 @@ func dirList(w ResponseWriter, f File) { // never sent in the response. // // If modtime is not the zero time or Unix epoch, ServeContent -// includes it in a Last-Modified header in the response. If the +// includes it in a Last-Modified header in the response. If the // request includes an If-Modified-Since header, ServeContent uses // modtime to decide whether the content needs to be sent at all. // @@ -121,11 +121,11 @@ func dirList(w ResponseWriter, f File) { // Note that *os.File implements the io.ReadSeeker interface. 
func ServeContent(w ResponseWriter, req *Request, name string, modtime time.Time, content io.ReadSeeker) { sizeFunc := func() (int64, error) { - size, err := content.Seek(0, os.SEEK_END) + size, err := content.Seek(0, io.SeekEnd) if err != nil { return 0, errSeeker } - _, err = content.Seek(0, os.SEEK_SET) + _, err = content.Seek(0, io.SeekStart) if err != nil { return 0, errSeeker } @@ -166,7 +166,7 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, var buf [sniffLen]byte n, _ := io.ReadFull(content, buf[:]) ctype = DetectContentType(buf[:n]) - _, err := content.Seek(0, os.SEEK_SET) // rewind to output whole file + _, err := content.Seek(0, io.SeekStart) // rewind to output whole file if err != nil { Error(w, "seeker can't seek", StatusInternalServerError) return @@ -196,7 +196,7 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, // The total number of bytes in all the ranges // is larger than the size of the file by // itself, so this is probably an attack, or a - // dumb client. Ignore the range request. + // dumb client. Ignore the range request. ranges = nil } switch { @@ -213,7 +213,7 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, // A response to a request for a single range MUST NOT // be sent using the multipart/byteranges media type." 
ra := ranges[0] - if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil { + if _, err := content.Seek(ra.start, io.SeekStart); err != nil { Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) return } @@ -236,7 +236,7 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, pw.CloseWithError(err) return } - if _, err := content.Seek(ra.start, os.SEEK_SET); err != nil { + if _, err := content.Seek(ra.start, io.SeekStart); err != nil { pw.CloseWithError(err) return } @@ -291,7 +291,7 @@ func checkLastModified(w ResponseWriter, r *Request, modtime time.Time) bool { // checkETag implements If-None-Match and If-Range checks. // // The ETag or modtime must have been previously set in the -// ResponseWriter's headers. The modtime is only compared at second +// ResponseWriter's headers. The modtime is only compared at second // granularity and may be the zero value to mean unknown. // // The return value is the effective request "Range" header to use and @@ -336,7 +336,7 @@ func checkETag(w ResponseWriter, r *Request, modtime time.Time) (rangeReq string } // TODO(bradfitz): deal with comma-separated or multiple-valued - // list of If-None-match values. For now just handle the common + // list of If-None-match values. For now just handle the common // case of a single item. if inm == etag || inm == "*" { h := w.Header() @@ -393,6 +393,15 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec } } + // redirect if the directory name doesn't end in a slash + if d.IsDir() { + url := r.URL.Path + if url[len(url)-1] != '/' { + localRedirect(w, r, path.Base(url)+"/") + return + } + } + // use contents of index.html for directory, if present if d.IsDir() { index := strings.TrimSuffix(name, "/") + indexPage @@ -451,7 +460,7 @@ func localRedirect(w ResponseWriter, r *Request, newPath string) { // ServeFile replies to the request with the contents of the named // file or directory. 
// -// If the provided file or direcory name is a relative path, it is +// If the provided file or directory name is a relative path, it is // interpreted relative to the current directory and may ascend to parent // directories. If the provided name is constructed from user input, it // should be sanitized before calling ServeFile. As a precaution, ServeFile diff --git a/libgo/go/net/http/fs_test.go b/libgo/go/net/http/fs_test.go index cf5b63c9f75..22be3899223 100644 --- a/libgo/go/net/http/fs_test.go +++ b/libgo/go/net/http/fs_test.go @@ -24,7 +24,6 @@ import ( "reflect" "regexp" "runtime" - "strconv" "strings" "testing" "time" @@ -39,8 +38,6 @@ type wantRange struct { start, end int64 // range [start,end) } -var itoa = strconv.Itoa - var ServeFileRangeTests = []struct { r string code int @@ -508,6 +505,24 @@ func TestServeFileFromCWD(t *testing.T) { } } +// Issue 13996 +func TestServeDirWithoutTrailingSlash(t *testing.T) { + e := "/testdata/" + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + ServeFile(w, r, ".") + })) + defer ts.Close() + r, err := Get(ts.URL + "/testdata") + if err != nil { + t.Fatal(err) + } + r.Body.Close() + if g := r.Request.URL.Path; g != e { + t.Errorf("got %s, want %s", g, e) + } +} + // Tests that ServeFile doesn't add a Content-Length if a Content-Encoding is // specified. func TestServeFileWithContentEncoding_h1(t *testing.T) { testServeFileWithContentEncoding(t, h1Mode) } @@ -963,9 +978,9 @@ func TestLinuxSendfile(t *testing.T) { syscalls := "sendfile,sendfile64" switch runtime.GOARCH { - case "mips64", "mips64le", "alpha": - // mips64 strace doesn't support sendfile64 and will error out - // if we specify that with `-e trace='. + case "mips64", "mips64le", "s390x", "alpha": + // strace on the above platforms doesn't support sendfile64 + // and will error out if we specify that with `-e trace='. 
syscalls = "sendfile" } diff --git a/libgo/go/net/http/h2_bundle.go b/libgo/go/net/http/h2_bundle.go index 4e19b3e71f7..db774554b2c 100644 --- a/libgo/go/net/http/h2_bundle.go +++ b/libgo/go/net/http/h2_bundle.go @@ -1,5 +1,5 @@ // Code generated by golang.org/x/tools/cmd/bundle. -//go:generate bundle -o h2_bundle.go -prefix http2 -import golang.org/x/net/http2/hpack=internal/golang.org/x/net/http2/hpack golang.org/x/net/http2 +//go:generate bundle -o h2_bundle.go -prefix http2 golang.org/x/net/http2 // Package http2 implements the HTTP/2 protocol. // @@ -20,15 +20,16 @@ import ( "bufio" "bytes" "compress/gzip" + "context" "crypto/tls" "encoding/binary" "errors" "fmt" - "internal/golang.org/x/net/http2/hpack" "io" "io/ioutil" "log" "net" + "net/http/httptrace" "net/textproto" "net/url" "os" @@ -39,6 +40,9 @@ import ( "strings" "sync" "time" + + "golang_org/x/net/http2/hpack" + "golang_org/x/net/lex/httplex" ) // ClientConnPool manages a pool of HTTP/2 client connections. @@ -47,6 +51,18 @@ type http2ClientConnPool interface { MarkDead(*http2ClientConn) } +// clientConnPoolIdleCloser is the interface implemented by ClientConnPool +// implementations which can close their idle connections. +type http2clientConnPoolIdleCloser interface { + http2ClientConnPool + closeIdleConnections() +} + +var ( + _ http2clientConnPoolIdleCloser = (*http2clientConnPool)(nil) + _ http2clientConnPoolIdleCloser = http2noDialClientConnPool{} +) + // TODO: use singleflight for dialing and addConnCalls? type http2clientConnPool struct { t *http2Transport @@ -247,6 +263,15 @@ func http2filterOutClientConn(in []*http2ClientConn, exclude *http2ClientConn) [ return out } +// noDialClientConnPool is an implementation of http2.ClientConnPool +// which never dials. We let the HTTP/1.1 client dial and use its TLS +// connection instead. 
+type http2noDialClientConnPool struct{ *http2clientConnPool } + +func (p http2noDialClientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) { + return p.getClientConn(req, addr, http2noDialOnMiss) +} + func http2configureTransport(t1 *Transport) (*http2Transport, error) { connPool := new(http2clientConnPool) t2 := &http2Transport{ @@ -267,7 +292,7 @@ func http2configureTransport(t1 *Transport) (*http2Transport, error) { t1.TLSClientConfig.NextProtos = append(t1.TLSClientConfig.NextProtos, "http/1.1") } upgradeFn := func(authority string, c *tls.Conn) RoundTripper { - addr := http2authorityAddr(authority) + addr := http2authorityAddr("https", authority) if used, err := connPool.addConnIfNeeded(addr, t2, c); err != nil { go c.Close() return http2erringRoundTripper{err} @@ -299,15 +324,6 @@ func http2registerHTTPSProtocol(t *Transport, rt RoundTripper) (err error) { return nil } -// noDialClientConnPool is an implementation of http2.ClientConnPool -// which never dials. We let the HTTP/1.1 client dial and use its TLS -// connection instead. -type http2noDialClientConnPool struct{ *http2clientConnPool } - -func (p http2noDialClientConnPool) GetClientConn(req *Request, addr string) (*http2ClientConn, error) { - return p.getClientConn(req, addr, http2noDialOnMiss) -} - // noDialH2RoundTripper is a RoundTripper which only tries to complete the request // if there's already has a cached connection to the host. 
type http2noDialH2RoundTripper struct{ t *http2Transport } @@ -403,6 +419,35 @@ func (e http2connError) Error() string { return fmt.Sprintf("http2: connection error: %v: %v", e.Code, e.Reason) } +type http2pseudoHeaderError string + +func (e http2pseudoHeaderError) Error() string { + return fmt.Sprintf("invalid pseudo-header %q", string(e)) +} + +type http2duplicatePseudoHeaderError string + +func (e http2duplicatePseudoHeaderError) Error() string { + return fmt.Sprintf("duplicate pseudo-header %q", string(e)) +} + +type http2headerFieldNameError string + +func (e http2headerFieldNameError) Error() string { + return fmt.Sprintf("invalid header field name %q", string(e)) +} + +type http2headerFieldValueError string + +func (e http2headerFieldValueError) Error() string { + return fmt.Sprintf("invalid header field value %q", string(e)) +} + +var ( + http2errMixPseudoHeaderTypes = errors.New("mix of request and response pseudo headers") + http2errPseudoAfterRegular = errors.New("pseudo header field after regular") +) + // fixedBuffer is an io.ReadWriter backed by a fixed size buffer. // It never allocates, but moves old data as new data is written. type http2fixedBuffer struct { @@ -743,7 +788,7 @@ type http2Frame interface { type http2Framer struct { r io.Reader lastFrame http2Frame - errReason string + errDetail error // lastHeaderStream is non-zero if the last frame was an // unfinished HEADERS/CONTINUATION. @@ -775,14 +820,33 @@ type http2Framer struct { // to return non-compliant frames or frame orders. // This is for testing and permits using the Framer to test // other HTTP/2 implementations' conformance to the spec. + // It is not compatible with ReadMetaHeaders. AllowIllegalReads bool + // ReadMetaHeaders if non-nil causes ReadFrame to merge + // HEADERS and CONTINUATION frames together and return + // MetaHeadersFrame instead. + ReadMetaHeaders *hpack.Decoder + + // MaxHeaderListSize is the http2 MAX_HEADER_LIST_SIZE. 
+ // It's used only if ReadMetaHeaders is set; 0 means a sane default + // (currently 16MB) + // If the limit is hit, MetaHeadersFrame.Truncated is set true. + MaxHeaderListSize uint32 + logReads bool debugFramer *http2Framer // only use for logging written writes debugFramerBuf *bytes.Buffer } +func (fr *http2Framer) maxHeaderListSize() uint32 { + if fr.MaxHeaderListSize == 0 { + return 16 << 20 + } + return fr.MaxHeaderListSize +} + func (f *http2Framer) startWrite(ftype http2FrameType, flags http2Flags, streamID uint32) { f.wbuf = append(f.wbuf[:0], @@ -879,6 +943,17 @@ func (fr *http2Framer) SetMaxReadFrameSize(v uint32) { fr.maxReadSize = v } +// ErrorDetail returns a more detailed error of the last error +// returned by Framer.ReadFrame. For instance, if ReadFrame +// returns a StreamError with code PROTOCOL_ERROR, ErrorDetail +// will say exactly what was invalid. ErrorDetail is not guaranteed +// to return a non-nil value and like the rest of the http2 package, +// its return value is not protected by an API compatibility promise. +// ErrorDetail is reset after the next call to ReadFrame. +func (fr *http2Framer) ErrorDetail() error { + return fr.errDetail +} + // ErrFrameTooLarge is returned from Framer.ReadFrame when the peer // sends a frame that is larger than declared with SetMaxReadFrameSize. var http2ErrFrameTooLarge = errors.New("http2: frame too large") @@ -897,9 +972,10 @@ func http2terminalReadFrameError(err error) bool { // // If the frame is larger than previously set with SetMaxReadFrameSize, the // returned error is ErrFrameTooLarge. Other errors may be of type -// ConnectionError, StreamError, or anything else from from the underlying +// ConnectionError, StreamError, or anything else from the underlying // reader. 
func (fr *http2Framer) ReadFrame() (http2Frame, error) { + fr.errDetail = nil if fr.lastFrame != nil { fr.lastFrame.invalidate() } @@ -927,6 +1003,9 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) { if fr.logReads { log.Printf("http2: Framer %p: read %v", fr, http2summarizeFrame(f)) } + if fh.Type == http2FrameHeaders && fr.ReadMetaHeaders != nil { + return fr.readMetaFrame(f.(*http2HeadersFrame)) + } return f, nil } @@ -935,7 +1014,7 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) { // to the peer before hanging up on them. This might help others debug // their implementations. func (fr *http2Framer) connError(code http2ErrCode, reason string) error { - fr.errReason = reason + fr.errDetail = errors.New(reason) return http2ConnectionError(code) } @@ -1023,7 +1102,14 @@ func http2parseDataFrame(fh http2FrameHeader, payload []byte) (http2Frame, error return f, nil } -var http2errStreamID = errors.New("invalid streamid") +var ( + http2errStreamID = errors.New("invalid stream ID") + http2errDepStreamID = errors.New("invalid dependent stream ID") +) + +func http2validStreamIDOrZero(streamID uint32) bool { + return streamID&(1<<31) == 0 +} func http2validStreamID(streamID uint32) bool { return streamID != 0 && streamID&(1<<31) == 0 @@ -1389,8 +1475,8 @@ func (f *http2Framer) WriteHeaders(p http2HeadersFrameParam) error { } if !p.Priority.IsZero() { v := p.Priority.StreamDep - if !http2validStreamID(v) && !f.AllowIllegalWrites { - return errors.New("invalid dependent stream id") + if !http2validStreamIDOrZero(v) && !f.AllowIllegalWrites { + return http2errDepStreamID } if p.Priority.Exclusive { v |= 1 << 31 @@ -1458,6 +1544,9 @@ func (f *http2Framer) WritePriority(streamID uint32, p http2PriorityParam) error if !http2validStreamID(streamID) && !f.AllowIllegalWrites { return http2errStreamID } + if !http2validStreamIDOrZero(p.StreamDep) { + return http2errDepStreamID + } f.startWrite(http2FramePriority, 0, streamID) v := p.StreamDep if p.Exclusive { @@ 
-1669,6 +1758,193 @@ type http2headersEnder interface { HeadersEnded() bool } +type http2headersOrContinuation interface { + http2headersEnder + HeaderBlockFragment() []byte +} + +// A MetaHeadersFrame is the representation of one HEADERS frame and +// zero or more contiguous CONTINUATION frames and the decoding of +// their HPACK-encoded contents. +// +// This type of frame does not appear on the wire and is only returned +// by the Framer when Framer.ReadMetaHeaders is set. +type http2MetaHeadersFrame struct { + *http2HeadersFrame + + // Fields are the fields contained in the HEADERS and + // CONTINUATION frames. The underlying slice is owned by the + // Framer and must not be retained after the next call to + // ReadFrame. + // + // Fields are guaranteed to be in the correct http2 order and + // not have unknown pseudo header fields or invalid header + // field names or values. Required pseudo header fields may be + // missing, however. Use the MetaHeadersFrame.Pseudo accessor + // method access pseudo headers. + Fields []hpack.HeaderField + + // Truncated is whether the max header list size limit was hit + // and Fields is incomplete. The hpack decoder state is still + // valid, however. + Truncated bool +} + +// PseudoValue returns the given pseudo header field's value. +// The provided pseudo field should not contain the leading colon. +func (mh *http2MetaHeadersFrame) PseudoValue(pseudo string) string { + for _, hf := range mh.Fields { + if !hf.IsPseudo() { + return "" + } + if hf.Name[1:] == pseudo { + return hf.Value + } + } + return "" +} + +// RegularFields returns the regular (non-pseudo) header fields of mh. +// The caller does not own the returned slice. +func (mh *http2MetaHeadersFrame) RegularFields() []hpack.HeaderField { + for i, hf := range mh.Fields { + if !hf.IsPseudo() { + return mh.Fields[i:] + } + } + return nil +} + +// PseudoFields returns the pseudo header fields of mh. +// The caller does not own the returned slice. 
+func (mh *http2MetaHeadersFrame) PseudoFields() []hpack.HeaderField { + for i, hf := range mh.Fields { + if !hf.IsPseudo() { + return mh.Fields[:i] + } + } + return mh.Fields +} + +func (mh *http2MetaHeadersFrame) checkPseudos() error { + var isRequest, isResponse bool + pf := mh.PseudoFields() + for i, hf := range pf { + switch hf.Name { + case ":method", ":path", ":scheme", ":authority": + isRequest = true + case ":status": + isResponse = true + default: + return http2pseudoHeaderError(hf.Name) + } + + for _, hf2 := range pf[:i] { + if hf.Name == hf2.Name { + return http2duplicatePseudoHeaderError(hf.Name) + } + } + } + if isRequest && isResponse { + return http2errMixPseudoHeaderTypes + } + return nil +} + +func (fr *http2Framer) maxHeaderStringLen() int { + v := fr.maxHeaderListSize() + if uint32(int(v)) == v { + return int(v) + } + + return 0 +} + +// readMetaFrame returns 0 or more CONTINUATION frames from fr and +// merge them into into the provided hf and returns a MetaHeadersFrame +// with the decoded hpack values. 
+func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFrame, error) { + if fr.AllowIllegalReads { + return nil, errors.New("illegal use of AllowIllegalReads with ReadMetaHeaders") + } + mh := &http2MetaHeadersFrame{ + http2HeadersFrame: hf, + } + var remainSize = fr.maxHeaderListSize() + var sawRegular bool + + var invalid error // pseudo header field errors + hdec := fr.ReadMetaHeaders + hdec.SetEmitEnabled(true) + hdec.SetMaxStringLength(fr.maxHeaderStringLen()) + hdec.SetEmitFunc(func(hf hpack.HeaderField) { + if !httplex.ValidHeaderFieldValue(hf.Value) { + invalid = http2headerFieldValueError(hf.Value) + } + isPseudo := strings.HasPrefix(hf.Name, ":") + if isPseudo { + if sawRegular { + invalid = http2errPseudoAfterRegular + } + } else { + sawRegular = true + if !http2validWireHeaderFieldName(hf.Name) { + invalid = http2headerFieldNameError(hf.Name) + } + } + + if invalid != nil { + hdec.SetEmitEnabled(false) + return + } + + size := hf.Size() + if size > remainSize { + hdec.SetEmitEnabled(false) + mh.Truncated = true + return + } + remainSize -= size + + mh.Fields = append(mh.Fields, hf) + }) + + defer hdec.SetEmitFunc(func(hf hpack.HeaderField) {}) + + var hc http2headersOrContinuation = hf + for { + frag := hc.HeaderBlockFragment() + if _, err := hdec.Write(frag); err != nil { + return nil, http2ConnectionError(http2ErrCodeCompression) + } + + if hc.HeadersEnded() { + break + } + if f, err := fr.ReadFrame(); err != nil { + return nil, err + } else { + hc = f.(*http2ContinuationFrame) + } + } + + mh.http2HeadersFrame.headerFragBuf = nil + mh.http2HeadersFrame.invalidate() + + if err := hdec.Close(); err != nil { + return nil, http2ConnectionError(http2ErrCodeCompression) + } + if invalid != nil { + fr.errDetail = invalid + return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol} + } + if err := mh.checkPseudos(); err != nil { + fr.errDetail = err + return nil, http2StreamError{mh.StreamID, http2ErrCodeProtocol} + } + return mh, 
nil +} + func http2summarizeFrame(f http2Frame) string { var buf bytes.Buffer f.Header().writeDebug(&buf) @@ -1712,7 +1988,111 @@ func http2summarizeFrame(f http2Frame) string { return buf.String() } -func http2requestCancel(req *Request) <-chan struct{} { return req.Cancel } +func http2transportExpectContinueTimeout(t1 *Transport) time.Duration { + return t1.ExpectContinueTimeout +} + +// isBadCipher reports whether the cipher is blacklisted by the HTTP/2 spec. +func http2isBadCipher(cipher uint16) bool { + switch cipher { + case tls.TLS_RSA_WITH_RC4_128_SHA, + tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, + tls.TLS_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_RSA_WITH_AES_256_CBC_SHA, + tls.TLS_RSA_WITH_AES_128_GCM_SHA256, + tls.TLS_RSA_WITH_AES_256_GCM_SHA384, + tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA, + tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, + tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: + + return true + default: + return false + } +} + +type http2contextContext interface { + context.Context +} + +func http2serverConnBaseContext(c net.Conn, opts *http2ServeConnOpts) (ctx http2contextContext, cancel func()) { + ctx, cancel = context.WithCancel(context.Background()) + ctx = context.WithValue(ctx, LocalAddrContextKey, c.LocalAddr()) + if hs := opts.baseConfig(); hs != nil { + ctx = context.WithValue(ctx, ServerContextKey, hs) + } + return +} + +func http2contextWithCancel(ctx http2contextContext) (_ http2contextContext, cancel func()) { + return context.WithCancel(ctx) +} + +func http2requestWithContext(req *Request, ctx http2contextContext) *Request { + return req.WithContext(ctx) +} + +type http2clientTrace httptrace.ClientTrace + +func http2reqContext(r *Request) context.Context { return r.Context() } + +func http2setResponseUncompressed(res *Response) { res.Uncompressed = true } + +func http2traceGotConn(req *Request, 
cc *http2ClientConn) { + trace := httptrace.ContextClientTrace(req.Context()) + if trace == nil || trace.GotConn == nil { + return + } + ci := httptrace.GotConnInfo{Conn: cc.tconn} + cc.mu.Lock() + ci.Reused = cc.nextStreamID > 1 + ci.WasIdle = len(cc.streams) == 0 && ci.Reused + if ci.WasIdle && !cc.lastActive.IsZero() { + ci.IdleTime = time.Now().Sub(cc.lastActive) + } + cc.mu.Unlock() + + trace.GotConn(ci) +} + +func http2traceWroteHeaders(trace *http2clientTrace) { + if trace != nil && trace.WroteHeaders != nil { + trace.WroteHeaders() + } +} + +func http2traceGot100Continue(trace *http2clientTrace) { + if trace != nil && trace.Got100Continue != nil { + trace.Got100Continue() + } +} + +func http2traceWait100Continue(trace *http2clientTrace) { + if trace != nil && trace.Wait100Continue != nil { + trace.Wait100Continue() + } +} + +func http2traceWroteRequest(trace *http2clientTrace, err error) { + if trace != nil && trace.WroteRequest != nil { + trace.WroteRequest(httptrace.WroteRequestInfo{Err: err}) + } +} + +func http2traceFirstResponseByte(trace *http2clientTrace) { + if trace != nil && trace.GotFirstResponseByte != nil { + trace.GotFirstResponseByte() + } +} + +func http2requestTrace(req *Request) *http2clientTrace { + trace := httptrace.ContextClientTrace(req.Context()) + return (*http2clientTrace)(trace) +} var http2DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1" @@ -2070,57 +2450,23 @@ var ( http2errInvalidHeaderFieldValue = errors.New("http2: invalid header field value") ) -// validHeaderFieldName reports whether v is a valid header field name (key). -// RFC 7230 says: -// header-field = field-name ":" OWS field-value OWS -// field-name = token -// tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" / "+" / "-" / "." / -// "^" / "_" / " +// validWireHeaderFieldName reports whether v is a valid header field +// name (key). See httplex.ValidHeaderName for the base rules. 
+// // Further, http2 says: // "Just as in HTTP/1.x, header field names are strings of ASCII // characters that are compared in a case-insensitive // fashion. However, header field names MUST be converted to // lowercase prior to their encoding in HTTP/2. " -func http2validHeaderFieldName(v string) bool { +func http2validWireHeaderFieldName(v string) bool { if len(v) == 0 { return false } for _, r := range v { - if int(r) >= len(http2isTokenTable) || ('A' <= r && r <= 'Z') { - return false - } - if !http2isTokenTable[byte(r)] { + if !httplex.IsTokenRune(r) { return false } - } - return true -} - -// validHeaderFieldValue reports whether v is a valid header field value. -// -// RFC 7230 says: -// field-value = *( field-content / obs-fold ) -// obj-fold = N/A to http2, and deprecated -// field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] -// field-vchar = VCHAR / obs-text -// obs-text = %x80-FF -// VCHAR = "any visible [USASCII] character" -// -// http2 further says: "Similarly, HTTP/2 allows header field values -// that are not valid. While most of the values that can be encoded -// will not alter header field parsing, carriage return (CR, ASCII -// 0xd), line feed (LF, ASCII 0xa), and the zero character (NUL, ASCII -// 0x0) might be exploited by an attacker if they are translated -// verbatim. Any request or response that contains a character not -// permitted in a header field value MUST be treated as malformed -// (Section 8.1.2.6). Valid characters are defined by the -// field-content ABNF rule in Section 3.2 of [RFC7230]." -// -// This function does not (yet?) properly handle the rejection of -// strings that begin or end with SP or HTAB. 
-func http2validHeaderFieldValue(v string) bool { - for i := 0; i < len(v); i++ { - if b := v[i]; b < ' ' && b != '\t' || b == 0x7f { + if 'A' <= r && r <= 'Z' { return false } } @@ -2225,7 +2571,7 @@ func http2mustUint31(v int32) uint32 { } // bodyAllowedForStatus reports whether a given response status code -// permits a body. See RFC2616, section 4.4. +// permits a body. See RFC 2616, section 4.4. func http2bodyAllowedForStatus(status int) bool { switch { case status >= 100 && status <= 199: @@ -2251,90 +2597,44 @@ func (e *http2httpError) Temporary() bool { return true } var http2errTimeout error = &http2httpError{msg: "http2: timeout awaiting response headers", timeout: true} -var http2isTokenTable = [127]bool{ - '!': true, - '#': true, - '$': true, - '%': true, - '&': true, - '\'': true, - '*': true, - '+': true, - '-': true, - '.': true, - '0': true, - '1': true, - '2': true, - '3': true, - '4': true, - '5': true, - '6': true, - '7': true, - '8': true, - '9': true, - 'A': true, - 'B': true, - 'C': true, - 'D': true, - 'E': true, - 'F': true, - 'G': true, - 'H': true, - 'I': true, - 'J': true, - 'K': true, - 'L': true, - 'M': true, - 'N': true, - 'O': true, - 'P': true, - 'Q': true, - 'R': true, - 'S': true, - 'T': true, - 'U': true, - 'W': true, - 'V': true, - 'X': true, - 'Y': true, - 'Z': true, - '^': true, - '_': true, - '`': true, - 'a': true, - 'b': true, - 'c': true, - 'd': true, - 'e': true, - 'f': true, - 'g': true, - 'h': true, - 'i': true, - 'j': true, - 'k': true, - 'l': true, - 'm': true, - 'n': true, - 'o': true, - 'p': true, - 'q': true, - 'r': true, - 's': true, - 't': true, - 'u': true, - 'v': true, - 'w': true, - 'x': true, - 'y': true, - 'z': true, - '|': true, - '~': true, -} - type http2connectionStater interface { ConnectionState() tls.ConnectionState } +var http2sorterPool = sync.Pool{New: func() interface{} { return new(http2sorter) }} + +type http2sorter struct { + v []string // owned by sorter +} + +func (s *http2sorter) Len() int { 
return len(s.v) } + +func (s *http2sorter) Swap(i, j int) { s.v[i], s.v[j] = s.v[j], s.v[i] } + +func (s *http2sorter) Less(i, j int) bool { return s.v[i] < s.v[j] } + +// Keys returns the sorted keys of h. +// +// The returned slice is only valid until s used again or returned to +// its pool. +func (s *http2sorter) Keys(h Header) []string { + keys := s.v[:0] + for k := range h { + keys = append(keys, k) + } + s.v = keys + sort.Sort(s) + return keys +} + +func (s *http2sorter) SortStrings(ss []string) { + + save := s.v + s.v = ss + sort.Sort(s) + s.v = save +} + // pipe is a goroutine-safe io.Reader/io.Writer pair. It's like // io.Pipe except there are no PipeReader/PipeWriter halves, and the // underlying buffer is an interface. (io.Pipe is always unbuffered) @@ -2354,6 +2654,12 @@ type http2pipeBuffer interface { io.Reader } +func (p *http2pipe) Len() int { + p.mu.Lock() + defer p.mu.Unlock() + return p.b.Len() +} + // Read waits until data is available and copies bytes // from the buffer into p. func (p *http2pipe) Read(d []byte) (n int, err error) { @@ -2653,10 +2959,14 @@ func (o *http2ServeConnOpts) handler() Handler { // // The opts parameter is optional. If nil, default values are used. 
func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { + baseCtx, cancel := http2serverConnBaseContext(c, opts) + defer cancel() + sc := &http2serverConn{ srv: s, hs: opts.baseConfig(), conn: c, + baseCtx: baseCtx, remoteAddrStr: c.RemoteAddr().String(), bw: http2newBufferedWriter(c), handler: opts.handler(), @@ -2675,13 +2985,14 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { serveG: http2newGoroutineLock(), pushEnabled: true, } + sc.flow.add(http2initialWindowSize) sc.inflow.add(http2initialWindowSize) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) - sc.hpackDecoder = hpack.NewDecoder(http2initialHeaderTableSize, nil) - sc.hpackDecoder.SetMaxStringLength(sc.maxHeaderStringLen()) fr := http2NewFramer(sc.bw, c) + fr.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil) + fr.MaxHeaderListSize = sc.maxHeaderListSize() fr.SetMaxReadFrameSize(s.maxReadFrameSize()) sc.framer = fr @@ -2711,27 +3022,6 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { sc.serve() } -// isBadCipher reports whether the cipher is blacklisted by the HTTP/2 spec. 
-func http2isBadCipher(cipher uint16) bool { - switch cipher { - case tls.TLS_RSA_WITH_RC4_128_SHA, - tls.TLS_RSA_WITH_3DES_EDE_CBC_SHA, - tls.TLS_RSA_WITH_AES_128_CBC_SHA, - tls.TLS_RSA_WITH_AES_256_CBC_SHA, - tls.TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, - tls.TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, - tls.TLS_ECDHE_RSA_WITH_RC4_128_SHA, - tls.TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, - tls.TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, - tls.TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA: - - return true - default: - return false - } -} - func (sc *http2serverConn) rejectConn(err http2ErrCode, debug string) { sc.vlogf("http2: server rejecting conn: %v, %s", err, debug) @@ -2747,8 +3037,8 @@ type http2serverConn struct { conn net.Conn bw *http2bufferedWriter // writing to conn handler Handler + baseCtx http2contextContext framer *http2Framer - hpackDecoder *hpack.Decoder doneServing chan struct{} // closed when serverConn.serve ends readFrameCh chan http2readFrameResult // written by serverConn.readFrames wantWriteFrameCh chan http2frameWriteMsg // from handlers -> serve @@ -2775,7 +3065,6 @@ type http2serverConn struct { headerTableSize uint32 peerMaxHeaderListSize uint32 // zero means unknown (default) canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case - req http2requestParam // non-zero while reading request headers writingFrame bool // started write goroutine but haven't heard back on wroteFrameCh needsFrameFlush bool // last frame write wasn't a flush writeSched http2writeScheduler @@ -2784,21 +3073,13 @@ type http2serverConn struct { goAwayCode http2ErrCode shutdownTimerCh <-chan time.Time // nil until used shutdownTimer *time.Timer // nil until used + freeRequestBodyBuf []byte // if non-nil, a free initialWindowSize buffer for getRequestBodyBuf // Owned by the writeFrameAsync goroutine: headerWriteBuf bytes.Buffer hpackEncoder *hpack.Encoder } -func (sc *http2serverConn) maxHeaderStringLen() int { - v := sc.maxHeaderListSize() - if 
uint32(int(v)) == v { - return int(v) - } - - return 0 -} - func (sc *http2serverConn) maxHeaderListSize() uint32 { n := sc.hs.MaxHeaderBytes if n <= 0 { @@ -2811,21 +3092,6 @@ func (sc *http2serverConn) maxHeaderListSize() uint32 { return uint32(n + typicalHeaders*perFieldOverhead) } -// requestParam is the state of the next request, initialized over -// potentially several frames HEADERS + zero or more CONTINUATION -// frames. -type http2requestParam struct { - // stream is non-nil if we're reading (HEADER or CONTINUATION) - // frames for a request (but not DATA). - stream *http2stream - header Header - method, path string - scheme, authority string - sawRegularHeader bool // saw a non-pseudo header already - invalidHeader bool // an invalid header was seen - headerListSize int64 // actually uint32, but easier math this way -} - // stream represents a stream. This is the minimal metadata needed by // the serve goroutine. Most of the actual stream state is owned by // the http.Handler's goroutine in the responseWriter. Because the @@ -2835,10 +3101,12 @@ type http2requestParam struct { // responseWriter's state field. 
type http2stream struct { // immutable: - sc *http2serverConn - id uint32 - body *http2pipe // non-nil if expecting DATA frames - cw http2closeWaiter // closed wait stream transitions to closed state + sc *http2serverConn + id uint32 + body *http2pipe // non-nil if expecting DATA frames + cw http2closeWaiter // closed wait stream transitions to closed state + ctx http2contextContext + cancelCtx func() // owned by serverConn's serve loop: bodyBytes int64 // body bytes seen so far @@ -2852,6 +3120,8 @@ type http2stream struct { sentReset bool // only true once detached from streams map gotReset bool // only true once detacted from streams map gotTrailerHeader bool // HEADER frame for trailers was seen + wroteHeaders bool // whether we wrote headers (not status 100) + reqBuf []byte trailer Header // accumulated trailers reqTrailer Header // handler's Request.Trailer @@ -2952,83 +3222,6 @@ func (sc *http2serverConn) condlogf(err error, format string, args ...interface{ } } -func (sc *http2serverConn) onNewHeaderField(f hpack.HeaderField) { - sc.serveG.check() - if http2VerboseLogs { - sc.vlogf("http2: server decoded %v", f) - } - switch { - case !http2validHeaderFieldValue(f.Value): - sc.req.invalidHeader = true - case strings.HasPrefix(f.Name, ":"): - if sc.req.sawRegularHeader { - sc.logf("pseudo-header after regular header") - sc.req.invalidHeader = true - return - } - var dst *string - switch f.Name { - case ":method": - dst = &sc.req.method - case ":path": - dst = &sc.req.path - case ":scheme": - dst = &sc.req.scheme - case ":authority": - dst = &sc.req.authority - default: - - sc.logf("invalid pseudo-header %q", f.Name) - sc.req.invalidHeader = true - return - } - if *dst != "" { - sc.logf("duplicate pseudo-header %q sent", f.Name) - sc.req.invalidHeader = true - return - } - *dst = f.Value - case !http2validHeaderFieldName(f.Name): - sc.req.invalidHeader = true - default: - sc.req.sawRegularHeader = true - sc.req.header.Add(sc.canonicalHeader(f.Name), f.Value) - 
const headerFieldOverhead = 32 // per spec - sc.req.headerListSize += int64(len(f.Name)) + int64(len(f.Value)) + headerFieldOverhead - if sc.req.headerListSize > int64(sc.maxHeaderListSize()) { - sc.hpackDecoder.SetEmitEnabled(false) - } - } -} - -func (st *http2stream) onNewTrailerField(f hpack.HeaderField) { - sc := st.sc - sc.serveG.check() - if http2VerboseLogs { - sc.vlogf("http2: server decoded trailer %v", f) - } - switch { - case strings.HasPrefix(f.Name, ":"): - sc.req.invalidHeader = true - return - case !http2validHeaderFieldName(f.Name) || !http2validHeaderFieldValue(f.Value): - sc.req.invalidHeader = true - return - default: - key := sc.canonicalHeader(f.Name) - if st.trailer != nil { - vv := append(st.trailer[key], f.Value) - st.trailer[key] = vv - - // arbitrary; TODO: read spec about header list size limits wrt trailers - const tooBig = 1000 - if len(vv) >= tooBig { - sc.hpackDecoder.SetEmitEnabled(false) - } - } - } -} - func (sc *http2serverConn) canonicalHeader(v string) string { sc.serveG.check() cv, ok := http2commonCanonHeader[v] @@ -3063,10 +3256,11 @@ type http2readFrameResult struct { // It's run on its own goroutine. func (sc *http2serverConn) readFrames() { gate := make(http2gate) + gateDone := gate.Done for { f, err := sc.framer.ReadFrame() select { - case sc.readFrameCh <- http2readFrameResult{f, err, gate.Done}: + case sc.readFrameCh <- http2readFrameResult{f, err, gateDone}: case <-sc.doneServing: return } @@ -3290,7 +3484,21 @@ func (sc *http2serverConn) writeFrameFromHandler(wm http2frameWriteMsg) error { // If you're not on the serve goroutine, use writeFrameFromHandler instead. 
func (sc *http2serverConn) writeFrame(wm http2frameWriteMsg) { sc.serveG.check() - sc.writeSched.add(wm) + + var ignoreWrite bool + + switch wm.write.(type) { + case *http2writeResHeaders: + wm.stream.wroteHeaders = true + case http2write100ContinueHeadersFrame: + if wm.stream.wroteHeaders { + ignoreWrite = true + } + } + + if !ignoreWrite { + sc.writeSched.add(wm) + } sc.scheduleFrameWrite() } @@ -3511,10 +3719,8 @@ func (sc *http2serverConn) processFrame(f http2Frame) error { switch f := f.(type) { case *http2SettingsFrame: return sc.processSettings(f) - case *http2HeadersFrame: + case *http2MetaHeadersFrame: return sc.processHeaders(f) - case *http2ContinuationFrame: - return sc.processContinuation(f) case *http2WindowUpdateFrame: return sc.processWindowUpdate(f) case *http2PingFrame: @@ -3579,6 +3785,7 @@ func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error { } if st != nil { st.gotReset = true + st.cancelCtx() sc.closeStream(st, http2StreamError{f.StreamID, f.ErrCode}) } return nil @@ -3600,6 +3807,10 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) { } st.cw.Close() sc.writeSched.forgetStream(st.id) + if st.reqBuf != nil { + + sc.freeRequestBodyBuf = st.reqBuf + } } func (sc *http2serverConn) processSettings(f *http2SettingsFrame) error { @@ -3732,7 +3943,7 @@ func (st *http2stream) copyTrailersToHandlerRequest() { } } -func (sc *http2serverConn) processHeaders(f *http2HeadersFrame) error { +func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { sc.serveG.check() id := f.Header().StreamID if sc.inGoAway { @@ -3749,17 +3960,18 @@ func (sc *http2serverConn) processHeaders(f *http2HeadersFrame) error { return st.processTrailerHeaders(f) } - if id <= sc.maxStreamID || sc.req.stream != nil { + if id <= sc.maxStreamID { return http2ConnectionError(http2ErrCodeProtocol) } + sc.maxStreamID = id - if id > sc.maxStreamID { - sc.maxStreamID = id - } + ctx, cancelCtx := http2contextWithCancel(sc.baseCtx) st 
= &http2stream{ - sc: sc, - id: id, - state: http2stateOpen, + sc: sc, + id: id, + state: http2stateOpen, + ctx: ctx, + cancelCtx: cancelCtx, } if f.StreamEnded() { st.state = http2stateHalfClosedRemote @@ -3779,50 +3991,6 @@ func (sc *http2serverConn) processHeaders(f *http2HeadersFrame) error { if sc.curOpenStreams == 1 { sc.setConnState(StateActive) } - sc.req = http2requestParam{ - stream: st, - header: make(Header), - } - sc.hpackDecoder.SetEmitFunc(sc.onNewHeaderField) - sc.hpackDecoder.SetEmitEnabled(true) - return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded()) -} - -func (st *http2stream) processTrailerHeaders(f *http2HeadersFrame) error { - sc := st.sc - sc.serveG.check() - if st.gotTrailerHeader { - return http2ConnectionError(http2ErrCodeProtocol) - } - st.gotTrailerHeader = true - if !f.StreamEnded() { - return http2StreamError{st.id, http2ErrCodeProtocol} - } - sc.resetPendingRequest() - return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded()) -} - -func (sc *http2serverConn) processContinuation(f *http2ContinuationFrame) error { - sc.serveG.check() - st := sc.streams[f.Header().StreamID] - if st.gotTrailerHeader { - return st.processTrailerHeaderBlockFragment(f.HeaderBlockFragment(), f.HeadersEnded()) - } - return sc.processHeaderBlockFragment(st, f.HeaderBlockFragment(), f.HeadersEnded()) -} - -func (sc *http2serverConn) processHeaderBlockFragment(st *http2stream, frag []byte, end bool) error { - sc.serveG.check() - if _, err := sc.hpackDecoder.Write(frag); err != nil { - return http2ConnectionError(http2ErrCodeCompression) - } - if !end { - return nil - } - if err := sc.hpackDecoder.Close(); err != nil { - return http2ConnectionError(http2ErrCodeCompression) - } - defer sc.resetPendingRequest() if sc.curOpenStreams > sc.advMaxStreams { if sc.unackedSettings == 0 { @@ -3833,7 +4001,7 @@ func (sc *http2serverConn) processHeaderBlockFragment(st *http2stream, frag []by return 
http2StreamError{st.id, http2ErrCodeRefusedStream} } - rw, req, err := sc.newWriterAndRequest() + rw, req, err := sc.newWriterAndRequest(st, f) if err != nil { return err } @@ -3845,36 +4013,42 @@ func (sc *http2serverConn) processHeaderBlockFragment(st *http2stream, frag []by st.declBodyBytes = req.ContentLength handler := sc.handler.ServeHTTP - if !sc.hpackDecoder.EmitEnabled() { + if f.Truncated { handler = http2handleHeaderListTooLong + } else if err := http2checkValidHTTP2Request(req); err != nil { + handler = http2new400Handler(err) } go sc.runHandler(rw, req, handler) return nil } -func (st *http2stream) processTrailerHeaderBlockFragment(frag []byte, end bool) error { +func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error { sc := st.sc sc.serveG.check() - sc.hpackDecoder.SetEmitFunc(st.onNewTrailerField) - if _, err := sc.hpackDecoder.Write(frag); err != nil { - return http2ConnectionError(http2ErrCodeCompression) + if st.gotTrailerHeader { + return http2ConnectionError(http2ErrCodeProtocol) } - if !end { - return nil + st.gotTrailerHeader = true + if !f.StreamEnded() { + return http2StreamError{st.id, http2ErrCodeProtocol} } - rp := &sc.req - if rp.invalidHeader { - return http2StreamError{rp.stream.id, http2ErrCodeProtocol} + if len(f.PseudoFields()) > 0 { + return http2StreamError{st.id, http2ErrCodeProtocol} } + if st.trailer != nil { + for _, hf := range f.RegularFields() { + key := sc.canonicalHeader(hf.Name) + if !http2ValidTrailerHeader(key) { - err := sc.hpackDecoder.Close() - st.endStream() - if err != nil { - return http2ConnectionError(http2ErrCodeCompression) + return http2StreamError{st.id, http2ErrCodeProtocol} + } + st.trailer[key] = append(st.trailer[key], hf.Value) + } } + st.endStream() return nil } @@ -3912,59 +4086,56 @@ func http2adjustStreamPriority(streams map[uint32]*http2stream, streamID uint32, } } -// resetPendingRequest zeros out all state related to a HEADERS frame -// and its zero or more CONTINUATION 
frames sent to start a new -// request. -func (sc *http2serverConn) resetPendingRequest() { +func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHeadersFrame) (*http2responseWriter, *Request, error) { sc.serveG.check() - sc.req = http2requestParam{} -} -func (sc *http2serverConn) newWriterAndRequest() (*http2responseWriter, *Request, error) { - sc.serveG.check() - rp := &sc.req - - if rp.invalidHeader { - return nil, nil, http2StreamError{rp.stream.id, http2ErrCodeProtocol} - } + method := f.PseudoValue("method") + path := f.PseudoValue("path") + scheme := f.PseudoValue("scheme") + authority := f.PseudoValue("authority") - isConnect := rp.method == "CONNECT" + isConnect := method == "CONNECT" if isConnect { - if rp.path != "" || rp.scheme != "" || rp.authority == "" { - return nil, nil, http2StreamError{rp.stream.id, http2ErrCodeProtocol} + if path != "" || scheme != "" || authority == "" { + return nil, nil, http2StreamError{f.StreamID, http2ErrCodeProtocol} } - } else if rp.method == "" || rp.path == "" || - (rp.scheme != "https" && rp.scheme != "http") { + } else if method == "" || path == "" || + (scheme != "https" && scheme != "http") { - return nil, nil, http2StreamError{rp.stream.id, http2ErrCodeProtocol} + return nil, nil, http2StreamError{f.StreamID, http2ErrCodeProtocol} } - bodyOpen := rp.stream.state == http2stateOpen - if rp.method == "HEAD" && bodyOpen { + bodyOpen := !f.StreamEnded() + if method == "HEAD" && bodyOpen { - return nil, nil, http2StreamError{rp.stream.id, http2ErrCodeProtocol} + return nil, nil, http2StreamError{f.StreamID, http2ErrCodeProtocol} } var tlsState *tls.ConnectionState // nil if not scheme https - if rp.scheme == "https" { + if scheme == "https" { tlsState = sc.tlsState } - authority := rp.authority + + header := make(Header) + for _, hf := range f.RegularFields() { + header.Add(sc.canonicalHeader(hf.Name), hf.Value) + } + if authority == "" { - authority = rp.header.Get("Host") + authority = 
header.Get("Host") } - needsContinue := rp.header.Get("Expect") == "100-continue" + needsContinue := header.Get("Expect") == "100-continue" if needsContinue { - rp.header.Del("Expect") + header.Del("Expect") } - if cookies := rp.header["Cookie"]; len(cookies) > 1 { - rp.header.Set("Cookie", strings.Join(cookies, "; ")) + if cookies := header["Cookie"]; len(cookies) > 1 { + header.Set("Cookie", strings.Join(cookies, "; ")) } // Setup Trailers var trailer Header - for _, v := range rp.header["Trailer"] { + for _, v := range header["Trailer"] { for _, key := range strings.Split(v, ",") { key = CanonicalHeaderKey(strings.TrimSpace(key)) switch key { @@ -3978,31 +4149,31 @@ func (sc *http2serverConn) newWriterAndRequest() (*http2responseWriter, *Request } } } - delete(rp.header, "Trailer") + delete(header, "Trailer") body := &http2requestBody{ conn: sc, - stream: rp.stream, + stream: st, needsContinue: needsContinue, } var url_ *url.URL var requestURI string if isConnect { - url_ = &url.URL{Host: rp.authority} - requestURI = rp.authority + url_ = &url.URL{Host: authority} + requestURI = authority } else { var err error - url_, err = url.ParseRequestURI(rp.path) + url_, err = url.ParseRequestURI(path) if err != nil { - return nil, nil, http2StreamError{rp.stream.id, http2ErrCodeProtocol} + return nil, nil, http2StreamError{f.StreamID, http2ErrCodeProtocol} } - requestURI = rp.path + requestURI = path } req := &Request{ - Method: rp.method, + Method: method, URL: url_, RemoteAddr: sc.remoteAddrStr, - Header: rp.header, + Header: header, RequestURI: requestURI, Proto: "HTTP/2.0", ProtoMajor: 2, @@ -4012,12 +4183,16 @@ func (sc *http2serverConn) newWriterAndRequest() (*http2responseWriter, *Request Body: body, Trailer: trailer, } + req = http2requestWithContext(req, st.ctx) if bodyOpen { + + buf := make([]byte, http2initialWindowSize) + body.pipe = &http2pipe{ - b: &http2fixedBuffer{buf: make([]byte, http2initialWindowSize)}, + b: &http2fixedBuffer{buf: buf}, } - if vv, ok 
:= rp.header["Content-Length"]; ok { + if vv, ok := header["Content-Length"]; ok { req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64) } else { req.ContentLength = -1 @@ -4030,7 +4205,7 @@ func (sc *http2serverConn) newWriterAndRequest() (*http2responseWriter, *Request rws.conn = sc rws.bw = bwSave rws.bw.Reset(http2chunkWriter{rws}) - rws.stream = rp.stream + rws.stream = st rws.req = req rws.body = body @@ -4038,10 +4213,20 @@ func (sc *http2serverConn) newWriterAndRequest() (*http2responseWriter, *Request return rw, req, nil } +func (sc *http2serverConn) getRequestBodyBuf() []byte { + sc.serveG.check() + if buf := sc.freeRequestBodyBuf; buf != nil { + sc.freeRequestBodyBuf = nil + return buf + } + return make([]byte, http2initialWindowSize) +} + // Run on its own goroutine. func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *Request, handler func(ResponseWriter, *Request)) { didPanic := true defer func() { + rw.rws.stream.cancelCtx() if didPanic { e := recover() // Same as net/http: @@ -4190,7 +4375,7 @@ type http2requestBody struct { func (b *http2requestBody) Close() error { if b.pipe != nil { - b.pipe.CloseWithError(http2errClosedBody) + b.pipe.BreakWithError(http2errClosedBody) } b.closed = true return nil @@ -4265,9 +4450,9 @@ func (rws *http2responseWriterState) hasTrailers() bool { return len(rws.trailer // written in the trailers at the end of the response. 
func (rws *http2responseWriterState) declareTrailer(k string) { k = CanonicalHeaderKey(k) - switch k { - case "Transfer-Encoding", "Content-Length", "Trailer": + if !http2ValidTrailerHeader(k) { + rws.conn.logf("ignoring invalid trailer %q", k) return } if !http2strSliceContains(rws.trailers, k) { @@ -4408,7 +4593,12 @@ func (rws *http2responseWriterState) promoteUndeclaredTrailers() { rws.declareTrailer(trailerKey) rws.handlerHeader[CanonicalHeaderKey(trailerKey)] = vv } - sort.Strings(rws.trailers) + + if len(rws.trailers) > 1 { + sorter := http2sorterPool.Get().(*http2sorter) + sorter.SortStrings(rws.trailers) + http2sorterPool.Put(sorter) + } } func (w *http2responseWriter) Flush() { @@ -4552,6 +4742,72 @@ func http2foreachHeaderElement(v string, fn func(string)) { } } +// From http://httpwg.org/specs/rfc7540.html#rfc.section.8.1.2.2 +var http2connHeaders = []string{ + "Connection", + "Keep-Alive", + "Proxy-Connection", + "Transfer-Encoding", + "Upgrade", +} + +// checkValidHTTP2Request checks whether req is a valid HTTP/2 request, +// per RFC 7540 Section 8.1.2.2. +// The returned error is reported to users. +func http2checkValidHTTP2Request(req *Request) error { + for _, h := range http2connHeaders { + if _, ok := req.Header[h]; ok { + return fmt.Errorf("request header %q is not valid in HTTP/2", h) + } + } + te := req.Header["Te"] + if len(te) > 0 && (len(te) > 1 || (te[0] != "trailers" && te[0] != "")) { + return errors.New(`request header "TE" may only be "trailers" in HTTP/2`) + } + return nil +} + +func http2new400Handler(err error) HandlerFunc { + return func(w ResponseWriter, r *Request) { + Error(w, err.Error(), StatusBadRequest) + } +} + +// ValidTrailerHeader reports whether name is a valid header field name to appear +// in trailers. 
+// See: http://tools.ietf.org/html/rfc7230#section-4.1.2 +func http2ValidTrailerHeader(name string) bool { + name = CanonicalHeaderKey(name) + if strings.HasPrefix(name, "If-") || http2badTrailer[name] { + return false + } + return true +} + +var http2badTrailer = map[string]bool{ + "Authorization": true, + "Cache-Control": true, + "Connection": true, + "Content-Encoding": true, + "Content-Length": true, + "Content-Range": true, + "Content-Type": true, + "Expect": true, + "Host": true, + "Keep-Alive": true, + "Max-Forwards": true, + "Pragma": true, + "Proxy-Authenticate": true, + "Proxy-Authorization": true, + "Proxy-Connection": true, + "Range": true, + "Realm": true, + "Te": true, + "Trailer": true, + "Transfer-Encoding": true, + "Www-Authenticate": true, +} + const ( // transportDefaultConnFlow is how many connection-level flow control // tokens we give the server at start-up, past the default 64k. @@ -4601,6 +4857,10 @@ type http2Transport struct { // uncompressed. DisableCompression bool + // AllowHTTP, if true, permits HTTP/2 requests using the insecure, + // plain-text "http" scheme. Note that this does not enable h2c support. + AllowHTTP bool + // MaxHeaderListSize is the http2 SETTINGS_MAX_HEADER_LIST_SIZE to // send in the initial settings frame. It is how many bytes // of response headers are allow. 
Unlike the http2 spec, zero here @@ -4673,11 +4933,14 @@ type http2ClientConn struct { inflow http2flow // peer's conn-level flow control closed bool goAway *http2GoAwayFrame // if non-nil, the GoAwayFrame we received + goAwayDebug string // goAway frame's debug data, retained as a string streams map[uint32]*http2clientStream // client-initiated nextStreamID uint32 bw *bufio.Writer br *bufio.Reader fr *http2Framer + lastActive time.Time + // Settings from peer: maxFrameSize uint32 maxConcurrentStreams uint32 @@ -4695,10 +4958,12 @@ type http2ClientConn struct { type http2clientStream struct { cc *http2ClientConn req *Request + trace *http2clientTrace // or nil ID uint32 resc chan http2resAndError bufPipe http2pipe // buffered pipe with the flow-controlled response payload requestedGzip bool + on100 func() // optional code to run if get a 100 continue response flow http2flow // guarded by cc.mu inflow http2flow // guarded by cc.mu @@ -4712,36 +4977,43 @@ type http2clientStream struct { done chan struct{} // closed when stream remove from cc.streams map; close calls guarded by cc.mu // owned by clientConnReadLoop: - pastHeaders bool // got HEADERS w/ END_HEADERS - pastTrailers bool // got second HEADERS frame w/ END_HEADERS + firstByte bool // got the first response byte + pastHeaders bool // got first MetaHeadersFrame (actual headers) + pastTrailers bool // got optional second MetaHeadersFrame (trailers) trailer Header // accumulated trailers resTrailer *Header // client's Response.Trailer } // awaitRequestCancel runs in its own goroutine and waits for the user -// to either cancel a RoundTrip request (using the provided -// Request.Cancel channel), or for the request to be done (any way it -// might be removed from the cc.streams map: peer reset, successful -// completion, TCP connection breakage, etc) -func (cs *http2clientStream) awaitRequestCancel(cancel <-chan struct{}) { - if cancel == nil { +// to cancel a RoundTrip request, its context to expire, or for the 
+// request to be done (any way it might be removed from the cc.streams +// map: peer reset, successful completion, TCP connection breakage, +// etc) +func (cs *http2clientStream) awaitRequestCancel(req *Request) { + ctx := http2reqContext(req) + if req.Cancel == nil && ctx.Done() == nil { return } select { - case <-cancel: + case <-req.Cancel: cs.bufPipe.CloseWithError(http2errRequestCanceled) cs.cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + case <-ctx.Done(): + cs.bufPipe.CloseWithError(ctx.Err()) + cs.cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) case <-cs.done: } } -// checkReset reports any error sent in a RST_STREAM frame by the -// server. -func (cs *http2clientStream) checkReset() error { +// checkResetOrDone reports any error sent in a RST_STREAM frame by the +// server, or errStreamClosed if the stream is complete. +func (cs *http2clientStream) checkResetOrDone() error { select { case <-cs.peerReset: return cs.resetErr + case <-cs.done: + return http2errStreamClosed default: return nil } @@ -4789,26 +5061,31 @@ func (t *http2Transport) RoundTrip(req *Request) (*Response, error) { // authorityAddr returns a given authority (a host/IP, or host:port / ip:port) // and returns a host:port. The port 443 is added if needed. -func http2authorityAddr(authority string) (addr string) { +func http2authorityAddr(scheme string, authority string) (addr string) { if _, _, err := net.SplitHostPort(authority); err == nil { return authority } - return net.JoinHostPort(authority, "443") + port := "443" + if scheme == "http" { + port = "80" + } + return net.JoinHostPort(authority, port) } // RoundTripOpt is like RoundTrip, but takes options. 
func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Response, error) { - if req.URL.Scheme != "https" { + if !(req.URL.Scheme == "https" || (req.URL.Scheme == "http" && t.AllowHTTP)) { return nil, errors.New("http2: unsupported scheme") } - addr := http2authorityAddr(req.URL.Host) + addr := http2authorityAddr(req.URL.Scheme, req.URL.Host) for { cc, err := t.connPool().GetClientConn(req, addr) if err != nil { t.vlogf("http2: Transport failed to get client conn for %s: %v", addr, err) return nil, err } + http2traceGotConn(req, cc) res, err := cc.RoundTrip(req) if http2shouldRetryRequest(req, err) { continue @@ -4825,7 +5102,7 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res // connected from previous requests but are now sitting idle. // It does not interrupt any connections currently in use. func (t *http2Transport) CloseIdleConnections() { - if cp, ok := t.connPool().(*http2clientConnPool); ok { + if cp, ok := t.connPool().(http2clientConnPoolIdleCloser); ok { cp.closeIdleConnections() } } @@ -4857,8 +5134,12 @@ func (t *http2Transport) newTLSConfig(host string) *tls.Config { if t.TLSClientConfig != nil { *cfg = *t.TLSClientConfig } - cfg.NextProtos = []string{http2NextProtoTLS} - cfg.ServerName = host + if !http2strSliceContains(cfg.NextProtos, http2NextProtoTLS) { + cfg.NextProtos = append([]string{http2NextProtoTLS}, cfg.NextProtos...) 
+ } + if cfg.ServerName == "" { + cfg.ServerName = host + } return cfg } @@ -4898,6 +5179,13 @@ func (t *http2Transport) disableKeepAlives() bool { return t.t1 != nil && t.t1.DisableKeepAlives } +func (t *http2Transport) expectContinueTimeout() time.Duration { + if t.t1 == nil { + return 0 + } + return http2transportExpectContinueTimeout(t.t1) +} + func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { if http2VerboseLogs { t.vlogf("http2: Transport creating client conn to %v", c.RemoteAddr()) @@ -4923,6 +5211,8 @@ func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { cc.bw = bufio.NewWriter(http2stickyErrWriter{c, &cc.werr}) cc.br = bufio.NewReader(c) cc.fr = http2NewFramer(cc.bw, cc.br) + cc.fr.ReadMetaHeaders = hpack.NewDecoder(http2initialHeaderTableSize, nil) + cc.fr.MaxHeaderListSize = t.maxHeaderListSize() cc.henc = hpack.NewEncoder(&cc.hbuf) @@ -4932,8 +5222,8 @@ func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { } initialSettings := []http2Setting{ - http2Setting{ID: http2SettingEnablePush, Val: 0}, - http2Setting{ID: http2SettingInitialWindowSize, Val: http2transportDefaultStreamFlow}, + {ID: http2SettingEnablePush, Val: 0}, + {ID: http2SettingInitialWindowSize, Val: http2transportDefaultStreamFlow}, } if max := t.maxHeaderListSize(); max != 0 { initialSettings = append(initialSettings, http2Setting{ID: http2SettingMaxHeaderListSize, Val: max}) @@ -4979,7 +5269,16 @@ func (t *http2Transport) NewClientConn(c net.Conn) (*http2ClientConn, error) { func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) { cc.mu.Lock() defer cc.mu.Unlock() + + old := cc.goAway cc.goAway = f + + if cc.goAwayDebug == "" { + cc.goAwayDebug = string(f.DebugData()) + } + if old != nil && old.ErrCode != http2ErrCodeNo { + cc.goAway.ErrCode = old.ErrCode + } } func (cc *http2ClientConn) CanTakeNewRequest() bool { @@ -5093,6 +5392,30 @@ func http2checkConnHeaders(req *Request) error { return nil } +func 
http2bodyAndLength(req *Request) (body io.Reader, contentLen int64) { + body = req.Body + if body == nil { + return nil, 0 + } + if req.ContentLength != 0 { + return req.Body, req.ContentLength + } + + // We have a body but a zero content length. Test to see if + // it's actually zero or just unset. + var buf [1]byte + n, rerr := io.ReadFull(body, buf[:]) + if rerr != nil && rerr != io.EOF { + return http2errorReader{rerr}, -1 + } + if n == 1 { + + return io.MultiReader(bytes.NewReader(buf[:]), body), -1 + } + + return nil, 0 +} + func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { if err := http2checkConnHeaders(req); err != nil { return nil, err @@ -5104,67 +5427,62 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { } hasTrailers := trailers != "" - var body io.Reader = req.Body - contentLen := req.ContentLength - if req.Body != nil && contentLen == 0 { - // Test to see if it's actually zero or just unset. - var buf [1]byte - n, rerr := io.ReadFull(body, buf[:]) - if rerr != nil && rerr != io.EOF { - contentLen = -1 - body = http2errorReader{rerr} - } else if n == 1 { - - contentLen = -1 - body = io.MultiReader(bytes.NewReader(buf[:]), body) - } else { - - body = nil - } - } + body, contentLen := http2bodyAndLength(req) + hasBody := body != nil cc.mu.Lock() + cc.lastActive = time.Now() if cc.closed || !cc.canTakeNewRequestLocked() { cc.mu.Unlock() return nil, http2errClientConnUnusable } - cs := cc.newStream() - cs.req = req - hasBody := body != nil - + // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? 
+ var requestedGzip bool if !cc.t.disableCompression() && req.Header.Get("Accept-Encoding") == "" && req.Header.Get("Range") == "" && req.Method != "HEAD" { - cs.requestedGzip = true + requestedGzip = true } - hdrs := cc.encodeHeaders(req, cs.requestedGzip, trailers, contentLen) + hdrs, err := cc.encodeHeaders(req, requestedGzip, trailers, contentLen) + if err != nil { + cc.mu.Unlock() + return nil, err + } + + cs := cc.newStream() + cs.req = req + cs.trace = http2requestTrace(req) + cs.requestedGzip = requestedGzip + bodyWriter := cc.t.getBodyWriterState(cs, body) + cs.on100 = bodyWriter.on100 + cc.wmu.Lock() endStream := !hasBody && !hasTrailers werr := cc.writeHeaders(cs.ID, endStream, hdrs) cc.wmu.Unlock() + http2traceWroteHeaders(cs.trace) cc.mu.Unlock() if werr != nil { if hasBody { req.Body.Close() + bodyWriter.cancel() } cc.forgetStreamID(cs.ID) + http2traceWroteRequest(cs.trace, werr) return nil, werr } var respHeaderTimer <-chan time.Time - var bodyCopyErrc chan error // result of body copy if hasBody { - bodyCopyErrc = make(chan error, 1) - go func() { - bodyCopyErrc <- cs.writeRequestBody(body, req.Body) - }() + bodyWriter.scheduleBodyWrite() } else { + http2traceWroteRequest(cs.trace, nil) if d := cc.responseHeaderTimeout(); d != 0 { timer := time.NewTimer(d) defer timer.Stop() @@ -5173,44 +5491,78 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { } readLoopResCh := cs.resc - requestCanceledCh := http2requestCancel(req) bodyWritten := false + ctx := http2reqContext(req) + + reFunc := func(re http2resAndError) (*Response, error) { + res := re.res + if re.err != nil || res.StatusCode > 299 { + bodyWriter.cancel() + cs.abortRequestBodyWrite(http2errStopReqBodyWrite) + } + if re.err != nil { + cc.forgetStreamID(cs.ID) + return nil, re.err + } + res.Request = req + res.TLS = cc.tlsState + return res, nil + } for { select { case re := <-readLoopResCh: - res := re.res - if re.err != nil || res.StatusCode > 299 { - - 
cs.abortRequestBodyWrite(http2errStopReqBodyWrite) - } - if re.err != nil { - cc.forgetStreamID(cs.ID) - return nil, re.err - } - res.Request = req - res.TLS = cc.tlsState - return res, nil + return reFunc(re) case <-respHeaderTimer: cc.forgetStreamID(cs.ID) if !hasBody || bodyWritten { cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) } else { + bodyWriter.cancel() cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) } return nil, http2errTimeout - case <-requestCanceledCh: + case <-ctx.Done(): + select { + case re := <-readLoopResCh: + return reFunc(re) + default: + } + cc.forgetStreamID(cs.ID) + if !hasBody || bodyWritten { + cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + } else { + bodyWriter.cancel() + cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) + } + return nil, ctx.Err() + case <-req.Cancel: + select { + case re := <-readLoopResCh: + return reFunc(re) + default: + } cc.forgetStreamID(cs.ID) if !hasBody || bodyWritten { cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) } else { + bodyWriter.cancel() cs.abortRequestBodyWrite(http2errStopReqBodyWriteAndCancel) } return nil, http2errRequestCanceled case <-cs.peerReset: - + select { + case re := <-readLoopResCh: + return reFunc(re) + default: + } return nil, cs.resetErr - case err := <-bodyCopyErrc: + case err := <-bodyWriter.resc: + select { + case re := <-readLoopResCh: + return reFunc(re) + default: + } if err != nil { return nil, err } @@ -5268,6 +5620,7 @@ func (cs *http2clientStream) writeRequestBody(body io.Reader, bodyCloser io.Clos defer cc.putFrameScratchBuffer(buf) defer func() { + http2traceWroteRequest(cs.trace, err) cerr := bodyCloser.Close() if err == nil { @@ -5355,7 +5708,7 @@ func (cs *http2clientStream) awaitFlowControl(maxBytes int) (taken int32, err er if cs.stopReqBody != nil { return 0, cs.stopReqBody } - if err := cs.checkReset(); err != nil { + if err := cs.checkResetOrDone(); err != nil { return 0, err } if a := cs.flow.available(); a > 0 { @@ -5382,7 
+5735,7 @@ type http2badStringError struct { func (e *http2badStringError) Error() string { return fmt.Sprintf("%s %q", e.what, e.str) } // requires cc.mu be held. -func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trailers string, contentLength int64) []byte { +func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trailers string, contentLength int64) ([]byte, error) { cc.hbuf.Reset() host := req.Host @@ -5390,6 +5743,17 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail host = req.URL.Host } + for k, vv := range req.Header { + if !httplex.ValidHeaderFieldName(k) { + return nil, fmt.Errorf("invalid HTTP header name %q", k) + } + for _, v := range vv { + if !httplex.ValidHeaderFieldValue(v) { + return nil, fmt.Errorf("invalid HTTP header value %q for header %q", v, k) + } + } + } + cc.writeHeader(":authority", host) cc.writeHeader(":method", req.Method) if req.Method != "CONNECT" { @@ -5407,7 +5771,7 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail case "host", "content-length": continue - case "connection", "proxy-connection", "transfer-encoding", "upgrade": + case "connection", "proxy-connection", "transfer-encoding", "upgrade", "keep-alive": continue case "user-agent": @@ -5434,7 +5798,7 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail if !didUA { cc.writeHeader("user-agent", http2defaultUserAgent) } - return cc.hbuf.Bytes() + return cc.hbuf.Bytes(), nil } // shouldSendReqContentLength reports whether the http2.Transport should send @@ -5510,8 +5874,10 @@ func (cc *http2ClientConn) streamByID(id uint32, andRemove bool) *http2clientStr defer cc.mu.Unlock() cs := cc.streams[id] if andRemove && cs != nil && !cc.closed { + cc.lastActive = time.Now() delete(cc.streams, id) close(cs.done) + cc.cond.Broadcast() } return cs } @@ -5521,15 +5887,6 @@ type http2clientConnReadLoop struct { cc *http2ClientConn activeRes 
map[uint32]*http2clientStream // keyed by streamID closeWhenIdle bool - - hdec *hpack.Decoder - - // Fields reset on each HEADERS: - nextRes *Response - sawRegHeader bool // saw non-pseudo header - reqMalformed error // non-nil once known to be malformed - lastHeaderEndsStream bool - headerListSize int64 // actually uint32, but easier math this way } // readLoop runs in its own goroutine and reads and dispatches frames. @@ -5538,7 +5895,6 @@ func (cc *http2ClientConn) readLoop() { cc: cc, activeRes: make(map[uint32]*http2clientStream), } - rl.hdec = hpack.NewDecoder(http2initialHeaderTableSize, rl.onNewHeaderField) defer rl.cleanup() cc.readerErr = rl.run() @@ -5549,6 +5905,19 @@ func (cc *http2ClientConn) readLoop() { } } +// GoAwayError is returned by the Transport when the server closes the +// TCP connection after sending a GOAWAY frame. +type http2GoAwayError struct { + LastStreamID uint32 + ErrCode http2ErrCode + DebugData string +} + +func (e http2GoAwayError) Error() string { + return fmt.Sprintf("http2: server sent GOAWAY and closed the connection; LastStreamID=%v, ErrCode=%v, debug=%q", + e.LastStreamID, e.ErrCode, e.DebugData) +} + func (rl *http2clientConnReadLoop) cleanup() { cc := rl.cc defer cc.tconn.Close() @@ -5556,10 +5925,18 @@ func (rl *http2clientConnReadLoop) cleanup() { defer close(cc.readerDone) err := cc.readerErr + cc.mu.Lock() if err == io.EOF { - err = io.ErrUnexpectedEOF + if cc.goAway != nil { + err = http2GoAwayError{ + LastStreamID: cc.goAway.LastStreamID, + ErrCode: cc.goAway.ErrCode, + DebugData: cc.goAwayDebug, + } + } else { + err = io.ErrUnexpectedEOF + } } - cc.mu.Lock() for _, cs := range rl.activeRes { cs.bufPipe.CloseWithError(err) } @@ -5585,8 +5962,10 @@ func (rl *http2clientConnReadLoop) run() error { cc.vlogf("Transport readFrame error: (%T) %v", err, err) } if se, ok := err.(http2StreamError); ok { - - return se + if cs := cc.streamByID(se.StreamID, true); cs != nil { + rl.endStreamError(cs, cc.fr.errDetail) + } + 
continue } else if err != nil { return err } @@ -5596,13 +5975,10 @@ func (rl *http2clientConnReadLoop) run() error { maybeIdle := false switch f := f.(type) { - case *http2HeadersFrame: + case *http2MetaHeadersFrame: err = rl.processHeaders(f) maybeIdle = true gotReply = true - case *http2ContinuationFrame: - err = rl.processContinuation(f) - maybeIdle = true case *http2DataFrame: err = rl.processData(f) maybeIdle = true @@ -5632,83 +6008,105 @@ func (rl *http2clientConnReadLoop) run() error { } } -func (rl *http2clientConnReadLoop) processHeaders(f *http2HeadersFrame) error { - rl.sawRegHeader = false - rl.reqMalformed = nil - rl.lastHeaderEndsStream = f.StreamEnded() - rl.headerListSize = 0 - rl.nextRes = &Response{ - Proto: "HTTP/2.0", - ProtoMajor: 2, - Header: make(Header), - } - rl.hdec.SetEmitEnabled(true) - return rl.processHeaderBlockFragment(f.HeaderBlockFragment(), f.StreamID, f.HeadersEnded()) -} - -func (rl *http2clientConnReadLoop) processContinuation(f *http2ContinuationFrame) error { - return rl.processHeaderBlockFragment(f.HeaderBlockFragment(), f.StreamID, f.HeadersEnded()) -} - -func (rl *http2clientConnReadLoop) processHeaderBlockFragment(frag []byte, streamID uint32, finalFrag bool) error { +func (rl *http2clientConnReadLoop) processHeaders(f *http2MetaHeadersFrame) error { cc := rl.cc - streamEnded := rl.lastHeaderEndsStream - cs := cc.streamByID(streamID, streamEnded && finalFrag) + cs := cc.streamByID(f.StreamID, f.StreamEnded()) if cs == nil { return nil } - if cs.pastHeaders { - rl.hdec.SetEmitFunc(func(f hpack.HeaderField) { rl.onNewTrailerField(cs, f) }) - } else { - rl.hdec.SetEmitFunc(rl.onNewHeaderField) - } - _, err := rl.hdec.Write(frag) - if err != nil { - return http2ConnectionError(http2ErrCodeCompression) - } - if finalFrag { - if err := rl.hdec.Close(); err != nil { - return http2ConnectionError(http2ErrCodeCompression) - } - } + if !cs.firstByte { + if cs.trace != nil { - if !finalFrag { - return nil + 
http2traceFirstResponseByte(cs.trace) + } + cs.firstByte = true } - if !cs.pastHeaders { cs.pastHeaders = true } else { + return rl.processTrailers(cs, f) + } - if cs.pastTrailers { - - return http2ConnectionError(http2ErrCodeProtocol) + res, err := rl.handleResponse(cs, f) + if err != nil { + if _, ok := err.(http2ConnectionError); ok { + return err } - cs.pastTrailers = true - if !streamEnded { - return http2ConnectionError(http2ErrCodeProtocol) - } - rl.endStream(cs) + cs.cc.writeStreamReset(f.StreamID, http2ErrCodeProtocol, err) + cs.resc <- http2resAndError{err: err} return nil } + if res == nil { - if rl.reqMalformed != nil { - cs.resc <- http2resAndError{err: rl.reqMalformed} - rl.cc.writeStreamReset(cs.ID, http2ErrCodeProtocol, rl.reqMalformed) return nil } + if res.Body != http2noBody { + rl.activeRes[cs.ID] = cs + } + cs.resTrailer = &res.Trailer + cs.resc <- http2resAndError{res: res} + return nil +} - res := rl.nextRes +// may return error types nil, or ConnectionError. Any other error value +// is a StreamError of type ErrCodeProtocol. The returned error in that case +// is the detail. +// +// As a special case, handleResponse may return (nil, nil) to skip the +// frame (currently only used for 100 expect continue). This special +// case is going away after Issue 13851 is fixed. 
+func (rl *http2clientConnReadLoop) handleResponse(cs *http2clientStream, f *http2MetaHeadersFrame) (*Response, error) { + if f.Truncated { + return nil, http2errResponseHeaderListSize + } - if res.StatusCode == 100 { + status := f.PseudoValue("status") + if status == "" { + return nil, errors.New("missing status pseudo header") + } + statusCode, err := strconv.Atoi(status) + if err != nil { + return nil, errors.New("malformed non-numeric status pseudo header") + } + if statusCode == 100 { + http2traceGot100Continue(cs.trace) + if cs.on100 != nil { + cs.on100() + } cs.pastHeaders = false - return nil + return nil, nil + } + + header := make(Header) + res := &Response{ + Proto: "HTTP/2.0", + ProtoMajor: 2, + Header: header, + StatusCode: statusCode, + Status: status + " " + StatusText(statusCode), + } + for _, hf := range f.RegularFields() { + key := CanonicalHeaderKey(hf.Name) + if key == "Trailer" { + t := res.Trailer + if t == nil { + t = make(Header) + res.Trailer = t + } + http2foreachHeaderElement(hf.Value, func(v string) { + t[CanonicalHeaderKey(v)] = nil + }) + } else { + header[key] = append(header[key], hf.Value) + } } - if !streamEnded || cs.req.Method == "HEAD" { + streamEnded := f.StreamEnded() + isHead := cs.req.Method == "HEAD" + if !streamEnded || isHead { res.ContentLength = -1 if clens := res.Header["Content-Length"]; len(clens) == 1 { if clen64, err := strconv.ParseInt(clens[0], 10, 64); err == nil { @@ -5721,27 +6119,50 @@ func (rl *http2clientConnReadLoop) processHeaderBlockFragment(frag []byte, strea } } - if streamEnded { + if streamEnded || isHead { res.Body = http2noBody - } else { - buf := new(bytes.Buffer) - cs.bufPipe = http2pipe{b: buf} - cs.bytesRemain = res.ContentLength - res.Body = http2transportResponseBody{cs} - go cs.awaitRequestCancel(http2requestCancel(cs.req)) - - if cs.requestedGzip && res.Header.Get("Content-Encoding") == "gzip" { - res.Header.Del("Content-Encoding") - res.Header.Del("Content-Length") - res.ContentLength = -1 
- res.Body = &http2gzipReader{body: res.Body} - } - rl.activeRes[cs.ID] = cs + return res, nil } - cs.resTrailer = &res.Trailer - cs.resc <- http2resAndError{res: res} - rl.nextRes = nil + buf := new(bytes.Buffer) + cs.bufPipe = http2pipe{b: buf} + cs.bytesRemain = res.ContentLength + res.Body = http2transportResponseBody{cs} + go cs.awaitRequestCancel(cs.req) + + if cs.requestedGzip && res.Header.Get("Content-Encoding") == "gzip" { + res.Header.Del("Content-Encoding") + res.Header.Del("Content-Length") + res.ContentLength = -1 + res.Body = &http2gzipReader{body: res.Body} + http2setResponseUncompressed(res) + } + return res, nil +} + +func (rl *http2clientConnReadLoop) processTrailers(cs *http2clientStream, f *http2MetaHeadersFrame) error { + if cs.pastTrailers { + + return http2ConnectionError(http2ErrCodeProtocol) + } + cs.pastTrailers = true + if !f.StreamEnded() { + + return http2ConnectionError(http2ErrCodeProtocol) + } + if len(f.PseudoFields()) > 0 { + + return http2ConnectionError(http2ErrCodeProtocol) + } + + trailer := make(Header) + for _, hf := range f.RegularFields() { + key := CanonicalHeaderKey(hf.Name) + trailer[key] = append(trailer[key], hf.Value) + } + cs.trailer = trailer + + rl.endStream(cs) return nil } @@ -5792,8 +6213,10 @@ func (b http2transportResponseBody) Read(p []byte) (n int, err error) { cc.inflow.add(connAdd) } if err == nil { - if v := cs.inflow.available(); v < http2transportDefaultStreamFlow-http2transportDefaultStreamMinRefresh { - streamAdd = http2transportDefaultStreamFlow - v + + v := int(cs.inflow.available()) + cs.bufPipe.Len() + if v < http2transportDefaultStreamFlow-http2transportDefaultStreamMinRefresh { + streamAdd = int32(http2transportDefaultStreamFlow - v) cs.inflow.add(streamAdd) } } @@ -5855,6 +6278,7 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { cc.mu.Unlock() if _, err := cs.bufPipe.Write(data); err != nil { + rl.endStreamError(cs, err) return err } } @@ -5869,11 +6293,14 @@ var 
http2errInvalidTrailers = errors.New("http2: invalid trailers") func (rl *http2clientConnReadLoop) endStream(cs *http2clientStream) { - err := io.EOF - code := cs.copyTrailers - if rl.reqMalformed != nil { - err = rl.reqMalformed - code = nil + rl.endStreamError(cs, nil) +} + +func (rl *http2clientConnReadLoop) endStreamError(cs *http2clientStream, err error) { + var code func() + if err == nil { + err = io.EOF + code = cs.copyTrailers } cs.bufPipe.closeWithErrorAndCode(err, code) delete(rl.activeRes, cs.ID) @@ -5997,113 +6424,6 @@ var ( http2errPseudoTrailers = errors.New("http2: invalid pseudo header in trailers") ) -func (rl *http2clientConnReadLoop) checkHeaderField(f hpack.HeaderField) bool { - if rl.reqMalformed != nil { - return false - } - - const headerFieldOverhead = 32 // per spec - rl.headerListSize += int64(len(f.Name)) + int64(len(f.Value)) + headerFieldOverhead - if max := rl.cc.t.maxHeaderListSize(); max != 0 && rl.headerListSize > int64(max) { - rl.hdec.SetEmitEnabled(false) - rl.reqMalformed = http2errResponseHeaderListSize - return false - } - - if !http2validHeaderFieldValue(f.Value) { - rl.reqMalformed = http2errInvalidHeaderFieldValue - return false - } - - isPseudo := strings.HasPrefix(f.Name, ":") - if isPseudo { - if rl.sawRegHeader { - rl.reqMalformed = errors.New("http2: invalid pseudo header after regular header") - return false - } - } else { - if !http2validHeaderFieldName(f.Name) { - rl.reqMalformed = http2errInvalidHeaderFieldName - return false - } - rl.sawRegHeader = true - } - - return true -} - -// onNewHeaderField runs on the readLoop goroutine whenever a new -// hpack header field is decoded. 
-func (rl *http2clientConnReadLoop) onNewHeaderField(f hpack.HeaderField) { - cc := rl.cc - if http2VerboseLogs { - cc.logf("http2: Transport decoded %v", f) - } - - if !rl.checkHeaderField(f) { - return - } - - isPseudo := strings.HasPrefix(f.Name, ":") - if isPseudo { - switch f.Name { - case ":status": - code, err := strconv.Atoi(f.Value) - if err != nil { - rl.reqMalformed = errors.New("http2: invalid :status") - return - } - rl.nextRes.Status = f.Value + " " + StatusText(code) - rl.nextRes.StatusCode = code - default: - - rl.reqMalformed = fmt.Errorf("http2: unknown response pseudo header %q", f.Name) - } - return - } - - key := CanonicalHeaderKey(f.Name) - if key == "Trailer" { - t := rl.nextRes.Trailer - if t == nil { - t = make(Header) - rl.nextRes.Trailer = t - } - http2foreachHeaderElement(f.Value, func(v string) { - t[CanonicalHeaderKey(v)] = nil - }) - } else { - rl.nextRes.Header.Add(key, f.Value) - } -} - -func (rl *http2clientConnReadLoop) onNewTrailerField(cs *http2clientStream, f hpack.HeaderField) { - if http2VerboseLogs { - rl.cc.logf("http2: Transport decoded trailer %v", f) - } - if !rl.checkHeaderField(f) { - return - } - if strings.HasPrefix(f.Name, ":") { - - rl.reqMalformed = http2errPseudoTrailers - return - } - - key := CanonicalHeaderKey(f.Name) - - // The spec says one must predeclare their trailers but in practice - // popular users (which is to say the only user we found) do not so we - // violate the spec and accept all of them. - const acceptAllTrailers = true - if _, ok := (*cs.resTrailer)[key]; ok || acceptAllTrailers { - if cs.trailer == nil { - cs.trailer = make(Header) - } - cs.trailer[key] = append(cs.trailer[key], f.Value) - } -} - func (cc *http2ClientConn) logf(format string, args ...interface{}) { cc.t.logf(format, args...) 
} @@ -6167,6 +6487,79 @@ type http2errorReader struct{ err error } func (r http2errorReader) Read(p []byte) (int, error) { return 0, r.err } +// bodyWriterState encapsulates various state around the Transport's writing +// of the request body, particularly regarding doing delayed writes of the body +// when the request contains "Expect: 100-continue". +type http2bodyWriterState struct { + cs *http2clientStream + timer *time.Timer // if non-nil, we're doing a delayed write + fnonce *sync.Once // to call fn with + fn func() // the code to run in the goroutine, writing the body + resc chan error // result of fn's execution + delay time.Duration // how long we should delay a delayed write for +} + +func (t *http2Transport) getBodyWriterState(cs *http2clientStream, body io.Reader) (s http2bodyWriterState) { + s.cs = cs + if body == nil { + return + } + resc := make(chan error, 1) + s.resc = resc + s.fn = func() { + resc <- cs.writeRequestBody(body, cs.req.Body) + } + s.delay = t.expectContinueTimeout() + if s.delay == 0 || + !httplex.HeaderValuesContainsToken( + cs.req.Header["Expect"], + "100-continue") { + return + } + s.fnonce = new(sync.Once) + + // Arm the timer with a very large duration, which we'll + // intentionally lower later. It has to be large now because + // we need a handle to it before writing the headers, but the + // s.delay value is defined to not start until after the + // request headers were written. + const hugeDuration = 365 * 24 * time.Hour + s.timer = time.AfterFunc(hugeDuration, func() { + s.fnonce.Do(s.fn) + }) + return +} + +func (s http2bodyWriterState) cancel() { + if s.timer != nil { + s.timer.Stop() + } +} + +func (s http2bodyWriterState) on100() { + if s.timer == nil { + + return + } + s.timer.Stop() + go func() { s.fnonce.Do(s.fn) }() +} + +// scheduleBodyWrite starts writing the body, either immediately (in +// the common case) or after the delay timeout. It should not be +// called until after the headers have been written. 
+func (s http2bodyWriterState) scheduleBodyWrite() { + if s.timer == nil { + + go s.fn() + return + } + http2traceWait100Continue(s.cs.trace) + if s.timer.Stop() { + s.timer.Reset(s.delay) + } +} + // writeFramer is implemented by any type that is used to write frames. type http2writeFramer interface { writeFrame(http2writeContext) error @@ -6380,24 +6773,22 @@ func (wu http2writeWindowUpdate) writeFrame(ctx http2writeContext) error { } func http2encodeHeaders(enc *hpack.Encoder, h Header, keys []string) { - if keys == nil { - keys = make([]string, 0, len(h)) - for k := range h { - keys = append(keys, k) - } - sort.Strings(keys) + sorter := http2sorterPool.Get().(*http2sorter) + + defer http2sorterPool.Put(sorter) + keys = sorter.Keys(h) } for _, k := range keys { vv := h[k] k = http2lowerHeader(k) - if !http2validHeaderFieldName(k) { + if !http2validWireHeaderFieldName(k) { continue } isTE := k == "transfer-encoding" for _, v := range vv { - if !http2validHeaderFieldValue(v) { + if !httplex.ValidHeaderFieldValue(v) { continue } diff --git a/libgo/go/net/http/header.go b/libgo/go/net/http/header.go index 049f32f27dc..6343165a840 100644 --- a/libgo/go/net/http/header.go +++ b/libgo/go/net/http/header.go @@ -1,4 +1,4 @@ -// Copyright 2010 The Go Authors. All rights reserved. +// Copyright 2010 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. @@ -25,7 +25,7 @@ func (h Header) Add(key, value string) { } // Set sets the header entries associated with key to -// the single element value. It replaces any existing +// the single element value. It replaces any existing // values associated with key. func (h Header) Set(key, value string) { textproto.MIMEHeader(h).Set(key, value) @@ -164,9 +164,9 @@ func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) error { } // CanonicalHeaderKey returns the canonical format of the -// header key s. 
The canonicalization converts the first +// header key s. The canonicalization converts the first // letter and any letter following a hyphen to upper case; -// the rest are converted to lowercase. For example, the +// the rest are converted to lowercase. For example, the // canonical key for "accept-encoding" is "Accept-Encoding". // If s contains a space or invalid header field bytes, it is // returned without modifications. @@ -186,7 +186,7 @@ func hasToken(v, token string) bool { for sp := 0; sp <= len(v)-len(token); sp++ { // Check that first character is good. // The token is ASCII, so checking only a single byte - // is sufficient. We skip this potential starting + // is sufficient. We skip this potential starting // position if both the first byte and its potential // ASCII uppercase equivalent (b|0x20) don't match. // False positives ('^' => '~') are caught by EqualFold. diff --git a/libgo/go/net/http/header_test.go b/libgo/go/net/http/header_test.go index 299576ba8cf..5c0de15b731 100644 --- a/libgo/go/net/http/header_test.go +++ b/libgo/go/net/http/header_test.go @@ -1,4 +1,4 @@ -// Copyright 2011 The Go Authors. All rights reserved. +// Copyright 2011 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. diff --git a/libgo/go/net/http/http.go b/libgo/go/net/http/http.go new file mode 100644 index 00000000000..b34ae41ec51 --- /dev/null +++ b/libgo/go/net/http/http.go @@ -0,0 +1,43 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "strings" + + "golang_org/x/net/lex/httplex" +) + +// maxInt64 is the effective "infinite" value for the Server and +// Transport's byte-limiting readers. +const maxInt64 = 1<<63 - 1 + +// TODO(bradfitz): move common stuff here. The other files have accumulated +// generic http stuff in random places. 
+ +// contextKey is a value for use with context.WithValue. It's used as +// a pointer so it fits in an interface{} without allocation. +type contextKey struct { + name string +} + +func (k *contextKey) String() string { return "net/http context value " + k.name } + +// Given a string of the form "host", "host:port", or "[ipv6::address]:port", +// return true if the string includes a port. +func hasPort(s string) bool { return strings.LastIndex(s, ":") > strings.LastIndex(s, "]") } + +// removeEmptyPort strips the empty port in ":port" to "" +// as mandated by RFC 3986 Section 6.2.3. +func removeEmptyPort(host string) string { + if hasPort(host) { + return strings.TrimSuffix(host, ":") + } + return host +} + +func isNotToken(r rune) bool { + return !httplex.IsTokenRune(r) +} diff --git a/libgo/go/net/http/http_test.go b/libgo/go/net/http/http_test.go index dead3b04542..34da4bbb59e 100644 --- a/libgo/go/net/http/http_test.go +++ b/libgo/go/net/http/http_test.go @@ -2,11 +2,14 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Tests of internal functions with no better homes. +// Tests of internal functions and things with no better homes. package http import ( + "bytes" + "internal/testenv" + "os/exec" "reflect" "testing" ) @@ -56,3 +59,35 @@ func TestCleanHost(t *testing.T) { } } } + +// Test that cmd/go doesn't link in the HTTP server. +// +// This catches accidental dependencies between the HTTP transport and +// server code. 
+func TestCmdGoNoHTTPServer(t *testing.T) { + goBin := testenv.GoToolPath(t) + out, err := exec.Command("go", "tool", "nm", goBin).CombinedOutput() + if err != nil { + t.Fatalf("go tool nm: %v: %s", err, out) + } + wantSym := map[string]bool{ + // Verify these exist: (sanity checking this test) + "net/http.(*Client).Get": true, + "net/http.(*Transport).RoundTrip": true, + + // Verify these don't exist: + "net/http.http2Server": false, + "net/http.(*Server).Serve": false, + "net/http.(*ServeMux).ServeHTTP": false, + "net/http.DefaultServeMux": false, + } + for sym, want := range wantSym { + got := bytes.Contains(out, []byte(sym)) + if !want && got { + t.Errorf("cmd/go unexpectedly links in HTTP server code; found symbol %q in cmd/go", sym) + } + if want && !got { + t.Errorf("expected to find symbol %q in cmd/go; not found", sym) + } + } +} diff --git a/libgo/go/net/http/httptest/httptest.go b/libgo/go/net/http/httptest/httptest.go new file mode 100644 index 00000000000..e2148a659c1 --- /dev/null +++ b/libgo/go/net/http/httptest/httptest.go @@ -0,0 +1,88 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package httptest provides utilities for HTTP testing. +package httptest + +import ( + "bufio" + "bytes" + "crypto/tls" + "io" + "io/ioutil" + "net/http" + "strings" +) + +// NewRequest returns a new incoming server Request, suitable +// for passing to an http.Handler for testing. +// +// The target is the RFC 7230 "request-target": it may be either a +// path or an absolute URL. If target is an absolute URL, the host name +// from the URL is used. Otherwise, "example.com" is used. +// +// The TLS field is set to a non-nil dummy value if target has scheme +// "https". +// +// The Request.Proto is always HTTP/1.1. +// +// An empty method means "GET". +// +// The provided body may be nil. 
If the body is of type *bytes.Reader, +// *strings.Reader, or *bytes.Buffer, the Request.ContentLength is +// set. +// +// NewRequest panics on error for ease of use in testing, where a +// panic is acceptable. +func NewRequest(method, target string, body io.Reader) *http.Request { + if method == "" { + method = "GET" + } + req, err := http.ReadRequest(bufio.NewReader(strings.NewReader(method + " " + target + " HTTP/1.0\r\n\r\n"))) + if err != nil { + panic("invalid NewRequest arguments; " + err.Error()) + } + + // HTTP/1.0 was used above to avoid needing a Host field. Change it to 1.1 here. + req.Proto = "HTTP/1.1" + req.ProtoMinor = 1 + req.Close = false + + if body != nil { + switch v := body.(type) { + case *bytes.Buffer: + req.ContentLength = int64(v.Len()) + case *bytes.Reader: + req.ContentLength = int64(v.Len()) + case *strings.Reader: + req.ContentLength = int64(v.Len()) + default: + req.ContentLength = -1 + } + if rc, ok := body.(io.ReadCloser); ok { + req.Body = rc + } else { + req.Body = ioutil.NopCloser(body) + } + } + + // 192.0.2.0/24 is "TEST-NET" in RFC 5737 for use solely in + // documentation and example source code and should not be + // used publicly. + req.RemoteAddr = "192.0.2.1:1234" + + if req.Host == "" { + req.Host = "example.com" + } + + if strings.HasPrefix(target, "https://") { + req.TLS = &tls.ConnectionState{ + Version: tls.VersionTLS12, + HandshakeComplete: true, + ServerName: req.Host, + } + } + + return req +} diff --git a/libgo/go/net/http/httptest/httptest_test.go b/libgo/go/net/http/httptest/httptest_test.go new file mode 100644 index 00000000000..4f9ecbd8bbc --- /dev/null +++ b/libgo/go/net/http/httptest/httptest_test.go @@ -0,0 +1,177 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +package httptest + +import ( + "crypto/tls" + "io" + "io/ioutil" + "net/http" + "net/url" + "reflect" + "strings" + "testing" +) + +func TestNewRequest(t *testing.T) { + tests := [...]struct { + method, uri string + body io.Reader + + want *http.Request + wantBody string + }{ + // Empty method means GET: + 0: { + method: "", + uri: "/", + body: nil, + want: &http.Request{ + Method: "GET", + Host: "example.com", + URL: &url.URL{Path: "/"}, + Header: http.Header{}, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + RemoteAddr: "192.0.2.1:1234", + RequestURI: "/", + }, + wantBody: "", + }, + + // GET with full URL: + 1: { + method: "GET", + uri: "http://foo.com/path/%2f/bar/", + body: nil, + want: &http.Request{ + Method: "GET", + Host: "foo.com", + URL: &url.URL{ + Scheme: "http", + Path: "/path///bar/", + RawPath: "/path/%2f/bar/", + Host: "foo.com", + }, + Header: http.Header{}, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + RemoteAddr: "192.0.2.1:1234", + RequestURI: "http://foo.com/path/%2f/bar/", + }, + wantBody: "", + }, + + // GET with full https URL: + 2: { + method: "GET", + uri: "https://foo.com/path/", + body: nil, + want: &http.Request{ + Method: "GET", + Host: "foo.com", + URL: &url.URL{ + Scheme: "https", + Path: "/path/", + Host: "foo.com", + }, + Header: http.Header{}, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + RemoteAddr: "192.0.2.1:1234", + RequestURI: "https://foo.com/path/", + TLS: &tls.ConnectionState{ + Version: tls.VersionTLS12, + HandshakeComplete: true, + ServerName: "foo.com", + }, + }, + wantBody: "", + }, + + // Post with known length + 3: { + method: "POST", + uri: "/", + body: strings.NewReader("foo"), + want: &http.Request{ + Method: "POST", + Host: "example.com", + URL: &url.URL{Path: "/"}, + Header: http.Header{}, + Proto: "HTTP/1.1", + ContentLength: 3, + ProtoMajor: 1, + ProtoMinor: 1, + RemoteAddr: "192.0.2.1:1234", + RequestURI: "/", + }, + wantBody: "foo", + }, + + // Post with unknown length + 4: 
{ + method: "POST", + uri: "/", + body: struct{ io.Reader }{strings.NewReader("foo")}, + want: &http.Request{ + Method: "POST", + Host: "example.com", + URL: &url.URL{Path: "/"}, + Header: http.Header{}, + Proto: "HTTP/1.1", + ContentLength: -1, + ProtoMajor: 1, + ProtoMinor: 1, + RemoteAddr: "192.0.2.1:1234", + RequestURI: "/", + }, + wantBody: "foo", + }, + + // OPTIONS * + 5: { + method: "OPTIONS", + uri: "*", + want: &http.Request{ + Method: "OPTIONS", + Host: "example.com", + URL: &url.URL{Path: "*"}, + Header: http.Header{}, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + RemoteAddr: "192.0.2.1:1234", + RequestURI: "*", + }, + }, + } + for i, tt := range tests { + got := NewRequest(tt.method, tt.uri, tt.body) + slurp, err := ioutil.ReadAll(got.Body) + if err != nil { + t.Errorf("%d. ReadAll: %v", i, err) + } + if string(slurp) != tt.wantBody { + t.Errorf("%d. Body = %q; want %q", i, slurp, tt.wantBody) + } + got.Body = nil // before DeepEqual + if !reflect.DeepEqual(got.URL, tt.want.URL) { + t.Errorf("%d. Request.URL mismatch:\n got: %#v\nwant: %#v", i, got.URL, tt.want.URL) + } + if !reflect.DeepEqual(got.Header, tt.want.Header) { + t.Errorf("%d. Request.Header mismatch:\n got: %#v\nwant: %#v", i, got.Header, tt.want.Header) + } + if !reflect.DeepEqual(got.TLS, tt.want.TLS) { + t.Errorf("%d. Request.TLS mismatch:\n got: %#v\nwant: %#v", i, got.TLS, tt.want.TLS) + } + if !reflect.DeepEqual(got, tt.want) { + t.Errorf("%d. Request mismatch:\n got: %#v\nwant: %#v", i, got, tt.want) + } + } +} diff --git a/libgo/go/net/http/httptest/recorder.go b/libgo/go/net/http/httptest/recorder.go index 7c51af1867a..0ad26a3d418 100644 --- a/libgo/go/net/http/httptest/recorder.go +++ b/libgo/go/net/http/httptest/recorder.go @@ -2,11 +2,11 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// Package httptest provides utilities for HTTP testing. 
package httptest import ( "bytes" + "io/ioutil" "net/http" ) @@ -18,6 +18,8 @@ type ResponseRecorder struct { Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to Flushed bool + result *http.Response // cache of Result's return value + snapHeader http.Header // snapshot of HeaderMap at first Write wroteHeader bool } @@ -59,16 +61,15 @@ func (rw *ResponseRecorder) writeHeader(b []byte, str string) { str = str[:512] } - _, hasType := rw.HeaderMap["Content-Type"] - hasTE := rw.HeaderMap.Get("Transfer-Encoding") != "" + m := rw.Header() + + _, hasType := m["Content-Type"] + hasTE := m.Get("Transfer-Encoding") != "" if !hasType && !hasTE { if b == nil { b = []byte(str) } - if rw.HeaderMap == nil { - rw.HeaderMap = make(http.Header) - } - rw.HeaderMap.Set("Content-Type", http.DetectContentType(b)) + m.Set("Content-Type", http.DetectContentType(b)) } rw.WriteHeader(200) @@ -92,12 +93,28 @@ func (rw *ResponseRecorder) WriteString(str string) (int, error) { return len(str), nil } -// WriteHeader sets rw.Code. +// WriteHeader sets rw.Code. After it is called, changing rw.Header +// will not affect rw.HeaderMap. func (rw *ResponseRecorder) WriteHeader(code int) { - if !rw.wroteHeader { - rw.Code = code - rw.wroteHeader = true + if rw.wroteHeader { + return + } + rw.Code = code + rw.wroteHeader = true + if rw.HeaderMap == nil { + rw.HeaderMap = make(http.Header) } + rw.snapHeader = cloneHeader(rw.HeaderMap) +} + +func cloneHeader(h http.Header) http.Header { + h2 := make(http.Header, len(h)) + for k, vv := range h { + vv2 := make([]string, len(vv)) + copy(vv2, vv) + h2[k] = vv2 + } + return h2 } // Flush sets rw.Flushed to true. @@ -107,3 +124,62 @@ func (rw *ResponseRecorder) Flush() { } rw.Flushed = true } + +// Result returns the response generated by the handler. +// +// The returned Response will have at least its StatusCode, +// Header, Body, and optionally Trailer populated. 
+// More fields may be populated in the future, so callers should +// not DeepEqual the result in tests. +// +// The Response.Header is a snapshot of the headers at the time of the +// first write call, or at the time of this call, if the handler never +// did a write. +// +// Result must only be called after the handler has finished running. +func (rw *ResponseRecorder) Result() *http.Response { + if rw.result != nil { + return rw.result + } + if rw.snapHeader == nil { + rw.snapHeader = cloneHeader(rw.HeaderMap) + } + res := &http.Response{ + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + StatusCode: rw.Code, + Header: rw.snapHeader, + } + rw.result = res + if res.StatusCode == 0 { + res.StatusCode = 200 + } + res.Status = http.StatusText(res.StatusCode) + if rw.Body != nil { + res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes())) + } + + if trailers, ok := rw.snapHeader["Trailer"]; ok { + res.Trailer = make(http.Header, len(trailers)) + for _, k := range trailers { + // TODO: use http2.ValidTrailerHeader, but we can't + // get at it easily because it's bundled into net/http + // unexported. This is good enough for now: + switch k { + case "Transfer-Encoding", "Content-Length", "Trailer": + // Ignore since forbidden by RFC 2616 14.40. 
+ continue + } + k = http.CanonicalHeaderKey(k) + vv, ok := rw.HeaderMap[k] + if !ok { + continue + } + vv2 := make([]string, len(vv)) + copy(vv2, vv) + res.Trailer[k] = vv2 + } + } + return res +} diff --git a/libgo/go/net/http/httptest/recorder_test.go b/libgo/go/net/http/httptest/recorder_test.go index c29b6d4cf91..d4e7137913e 100644 --- a/libgo/go/net/http/httptest/recorder_test.go +++ b/libgo/go/net/http/httptest/recorder_test.go @@ -23,6 +23,14 @@ func TestRecorder(t *testing.T) { return nil } } + hasResultStatus := func(wantCode int) checkFunc { + return func(rec *ResponseRecorder) error { + if rec.Result().StatusCode != wantCode { + return fmt.Errorf("Result().StatusCode = %d; want %d", rec.Result().StatusCode, wantCode) + } + return nil + } + } hasContents := func(want string) checkFunc { return func(rec *ResponseRecorder) error { if rec.Body.String() != want { @@ -39,10 +47,49 @@ func TestRecorder(t *testing.T) { return nil } } - hasHeader := func(key, want string) checkFunc { + hasOldHeader := func(key, want string) checkFunc { return func(rec *ResponseRecorder) error { if got := rec.HeaderMap.Get(key); got != want { - return fmt.Errorf("header %s = %q; want %q", key, got, want) + return fmt.Errorf("HeaderMap header %s = %q; want %q", key, got, want) + } + return nil + } + } + hasHeader := func(key, want string) checkFunc { + return func(rec *ResponseRecorder) error { + if got := rec.Result().Header.Get(key); got != want { + return fmt.Errorf("final header %s = %q; want %q", key, got, want) + } + return nil + } + } + hasNotHeaders := func(keys ...string) checkFunc { + return func(rec *ResponseRecorder) error { + for _, k := range keys { + v, ok := rec.Result().Header[http.CanonicalHeaderKey(k)] + if ok { + return fmt.Errorf("unexpected header %s with value %q", k, v) + } + } + return nil + } + } + hasTrailer := func(key, want string) checkFunc { + return func(rec *ResponseRecorder) error { + if got := rec.Result().Trailer.Get(key); got != want { + return 
fmt.Errorf("trailer %s = %q; want %q", key, got, want) + } + return nil + } + } + hasNotTrailers := func(keys ...string) checkFunc { + return func(rec *ResponseRecorder) error { + trailers := rec.Result().Trailer + for _, k := range keys { + _, ok := trailers[http.CanonicalHeaderKey(k)] + if ok { + return fmt.Errorf("unexpected trailer %s", k) + } } return nil } @@ -130,6 +177,73 @@ func TestRecorder(t *testing.T) { }, check(hasHeader("Content-Type", "text/html; charset=utf-8")), }, + { + "Header is not changed after write", + func(w http.ResponseWriter, r *http.Request) { + hdr := w.Header() + hdr.Set("Key", "correct") + w.WriteHeader(200) + hdr.Set("Key", "incorrect") + }, + check(hasHeader("Key", "correct")), + }, + { + "Trailer headers are correctly recorded", + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Non-Trailer", "correct") + w.Header().Set("Trailer", "Trailer-A") + w.Header().Add("Trailer", "Trailer-B") + w.Header().Add("Trailer", "Trailer-C") + io.WriteString(w, "<html>") + w.Header().Set("Non-Trailer", "incorrect") + w.Header().Set("Trailer-A", "valuea") + w.Header().Set("Trailer-C", "valuec") + w.Header().Set("Trailer-NotDeclared", "should be omitted") + }, + check( + hasStatus(200), + hasHeader("Content-Type", "text/html; charset=utf-8"), + hasHeader("Non-Trailer", "correct"), + hasNotHeaders("Trailer-A", "Trailer-B", "Trailer-C", "Trailer-NotDeclared"), + hasTrailer("Trailer-A", "valuea"), + hasTrailer("Trailer-C", "valuec"), + hasNotTrailers("Non-Trailer", "Trailer-B", "Trailer-NotDeclared"), + ), + }, + { + "Header set without any write", // Issue 15560 + func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("X-Foo", "1") + + // Simulate somebody using + // new(ResponseRecorder) instead of + // using the constructor which sets + // this to 200 + w.(*ResponseRecorder).Code = 0 + }, + check( + hasOldHeader("X-Foo", "1"), + hasStatus(0), + hasHeader("X-Foo", "1"), + hasResultStatus(200), + ), + }, + { + "HeaderMap vs 
FinalHeaders", // more for Issue 15560 + func(w http.ResponseWriter, r *http.Request) { + h := w.Header() + h.Set("X-Foo", "1") + w.Write([]byte("hi")) + h.Set("X-Foo", "2") + h.Set("X-Bar", "2") + }, + check( + hasOldHeader("X-Foo", "2"), + hasOldHeader("X-Bar", "2"), + hasHeader("X-Foo", "1"), + hasNotHeaders("X-Bar"), + ), + }, } r, _ := http.NewRequest("GET", "http://foo.com/", nil) for _, tt := range tests { diff --git a/libgo/go/net/http/httptest/server.go b/libgo/go/net/http/httptest/server.go index bbe323396f5..e27526a937a 100644 --- a/libgo/go/net/http/httptest/server.go +++ b/libgo/go/net/http/httptest/server.go @@ -158,7 +158,7 @@ func (s *Server) Close() { // previously-flaky tests) in the case of // socket-late-binding races from the http Client // dialing this server and then getting an idle - // connection before the dial completed. There is thus + // connection before the dial completed. There is thus // a connected connection in StateNew with no // associated Request. We only close StateIdle and // StateNew because they're not doing anything. It's @@ -167,7 +167,7 @@ func (s *Server) Close() { // few milliseconds wasn't liked (early versions of // https://golang.org/cl/15151) so now we just // forcefully close StateNew. The docs for Server.Close say - // we wait for "oustanding requests", so we don't close things + // we wait for "outstanding requests", so we don't close things // in StateActive. if st == http.StateIdle || st == http.StateNew { s.closeConn(c) @@ -202,12 +202,10 @@ func (s *Server) logCloseHangDebugInfo() { // CloseClientConnections closes any open HTTP connections to the test Server. func (s *Server) CloseClientConnections() { - var conns int - ch := make(chan bool) - s.mu.Lock() + nconn := len(s.conns) + ch := make(chan struct{}, nconn) for c := range s.conns { - conns++ s.closeConnChan(c, ch) } s.mu.Unlock() @@ -220,7 +218,7 @@ func (s *Server) CloseClientConnections() { // in tests. 
timer := time.NewTimer(5 * time.Second) defer timer.Stop() - for i := 0; i < conns; i++ { + for i := 0; i < nconn; i++ { select { case <-ch: case <-timer.C: @@ -294,30 +292,20 @@ func (s *Server) closeConn(c net.Conn) { s.closeConnChan(c, nil) } // closeConnChan is like closeConn, but takes an optional channel to receive a value // when the goroutine closing c is done. -func (s *Server) closeConnChan(c net.Conn, done chan<- bool) { +func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) { if runtime.GOOS == "plan9" { // Go's Plan 9 net package isn't great at unblocking reads when - // their underlying TCP connections are closed. Don't trust + // their underlying TCP connections are closed. Don't trust // that that the ConnState state machine will get to // StateClosed. Instead, just go there directly. Plan 9 may leak // resources if the syscall doesn't end up returning. Oh well. s.forgetConn(c) } - // Somewhere in the chaos of https://golang.org/cl/15151 we found that - // some types of conns were blocking in Close too long (or deadlocking?) - // and we had to call Close in a goroutine. I (bradfitz) forget what - // that was at this point, but I suspect it was *tls.Conns, which - // were later fixed in https://golang.org/cl/18572, so this goroutine - // is _probably_ unnecessary now. But it's too late in Go 1.6 too remove - // it with confidence. - // TODO(bradfitz): try to remove it for Go 1.7. 
(golang.org/issue/14291) - go func() { - c.Close() - if done != nil { - done <- true - } - }() + c.Close() + if done != nil { + done <- struct{}{} + } } // forgetConn removes c from the set of tracked conns and decrements it from the diff --git a/libgo/go/net/http/httptest/server_test.go b/libgo/go/net/http/httptest/server_test.go index c9606f24198..d032c5983b7 100644 --- a/libgo/go/net/http/httptest/server_test.go +++ b/libgo/go/net/http/httptest/server_test.go @@ -53,7 +53,7 @@ func TestGetAfterClose(t *testing.T) { res, err = http.Get(ts.URL) if err == nil { body, _ := ioutil.ReadAll(res.Body) - t.Fatalf("Unexected response after close: %v, %v, %s", res.Status, res.Header, body) + t.Fatalf("Unexpected response after close: %v, %v, %s", res.Status, res.Header, body) } } @@ -95,6 +95,6 @@ func TestServerCloseClientConnections(t *testing.T) { res, err := http.Get(s.URL) if err == nil { res.Body.Close() - t.Fatal("Unexpected response: %#v", res) + t.Fatalf("Unexpected response: %#v", res) } } diff --git a/libgo/go/net/http/httptrace/trace.go b/libgo/go/net/http/httptrace/trace.go new file mode 100644 index 00000000000..6f187a7b694 --- /dev/null +++ b/libgo/go/net/http/httptrace/trace.go @@ -0,0 +1,226 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file.h + +// Package httptrace provides mechanisms to trace the events within +// HTTP client requests. +package httptrace + +import ( + "context" + "internal/nettrace" + "net" + "reflect" + "time" +) + +// unique type to prevent assignment. +type clientEventContextKey struct{} + +// ContextClientTrace returns the ClientTrace associated with the +// provided context. If none, it returns nil. +func ContextClientTrace(ctx context.Context) *ClientTrace { + trace, _ := ctx.Value(clientEventContextKey{}).(*ClientTrace) + return trace +} + +// WithClientTrace returns a new context based on the provided parent +// ctx. 
HTTP client requests made with the returned context will use +// the provided trace hooks, in addition to any previous hooks +// registered with ctx. Any hooks defined in the provided trace will +// be called first. +func WithClientTrace(ctx context.Context, trace *ClientTrace) context.Context { + if trace == nil { + panic("nil trace") + } + old := ContextClientTrace(ctx) + trace.compose(old) + + ctx = context.WithValue(ctx, clientEventContextKey{}, trace) + if trace.hasNetHooks() { + nt := &nettrace.Trace{ + ConnectStart: trace.ConnectStart, + ConnectDone: trace.ConnectDone, + } + if trace.DNSStart != nil { + nt.DNSStart = func(name string) { + trace.DNSStart(DNSStartInfo{Host: name}) + } + } + if trace.DNSDone != nil { + nt.DNSDone = func(netIPs []interface{}, coalesced bool, err error) { + addrs := make([]net.IPAddr, len(netIPs)) + for i, ip := range netIPs { + addrs[i] = ip.(net.IPAddr) + } + trace.DNSDone(DNSDoneInfo{ + Addrs: addrs, + Coalesced: coalesced, + Err: err, + }) + } + } + ctx = context.WithValue(ctx, nettrace.TraceKey{}, nt) + } + return ctx +} + +// ClientTrace is a set of hooks to run at various stages of an HTTP +// client request. Any particular hook may be nil. Functions may be +// called concurrently from different goroutines, starting after the +// call to Transport.RoundTrip and ending either when RoundTrip +// returns an error, or when the Response.Body is closed. +type ClientTrace struct { + // GetConn is called before a connection is created or + // retrieved from an idle pool. The hostPort is the + // "host:port" of the target or proxy. GetConn is called even + // if there's already an idle cached connection available. + GetConn func(hostPort string) + + // GotConn is called after a successful connection is + // obtained. There is no hook for failure to obtain a + // connection; instead, use the error from + // Transport.RoundTrip. 
+ GotConn func(GotConnInfo) + + // PutIdleConn is called when the connection is returned to + // the idle pool. If err is nil, the connection was + // successfully returned to the idle pool. If err is non-nil, + // it describes why not. PutIdleConn is not called if + // connection reuse is disabled via Transport.DisableKeepAlives. + // PutIdleConn is called before the caller's Response.Body.Close + // call returns. + // For HTTP/2, this hook is not currently used. + PutIdleConn func(err error) + + // GotFirstResponseByte is called when the first byte of the response + // headers is available. + GotFirstResponseByte func() + + // Got100Continue is called if the server replies with a "100 + // Continue" response. + Got100Continue func() + + // DNSStart is called when a DNS lookup begins. + DNSStart func(DNSStartInfo) + + // DNSDone is called when a DNS lookup ends. + DNSDone func(DNSDoneInfo) + + // ConnectStart is called when a new connection's Dial begins. + // If net.Dialer.DualStack (IPv6 "Happy Eyeballs") support is + // enabled, this may be called multiple times. + ConnectStart func(network, addr string) + + // ConnectDone is called when a new connection's Dial + // completes. The provided err indicates whether the + // connection completedly successfully. + // If net.Dialer.DualStack ("Happy Eyeballs") support is + // enabled, this may be called multiple times. + ConnectDone func(network, addr string, err error) + + // WroteHeaders is called after the Transport has written + // the request headers. + WroteHeaders func() + + // Wait100Continue is called if the Request specified + // "Expected: 100-continue" and the Transport has written the + // request headers but is waiting for "100 Continue" from the + // server before writing the request body. + Wait100Continue func() + + // WroteRequest is called with the result of writing the + // request and any body. 
+ WroteRequest func(WroteRequestInfo) +} + +// WroteRequestInfo contains information provided to the WroteRequest +// hook. +type WroteRequestInfo struct { + // Err is any error encountered while writing the Request. + Err error +} + +// compose modifies t such that it respects the previously-registered hooks in old, +// subject to the composition policy requested in t.Compose. +func (t *ClientTrace) compose(old *ClientTrace) { + if old == nil { + return + } + tv := reflect.ValueOf(t).Elem() + ov := reflect.ValueOf(old).Elem() + structType := tv.Type() + for i := 0; i < structType.NumField(); i++ { + tf := tv.Field(i) + hookType := tf.Type() + if hookType.Kind() != reflect.Func { + continue + } + of := ov.Field(i) + if of.IsNil() { + continue + } + if tf.IsNil() { + tf.Set(of) + continue + } + + // Make a copy of tf for tf to call. (Otherwise it + // creates a recursive call cycle and stack overflows) + tfCopy := reflect.ValueOf(tf.Interface()) + + // We need to call both tf and of in some order. + newFunc := reflect.MakeFunc(hookType, func(args []reflect.Value) []reflect.Value { + tfCopy.Call(args) + return of.Call(args) + }) + tv.Field(i).Set(newFunc) + } +} + +// DNSStartInfo contains information about a DNS request. +type DNSStartInfo struct { + Host string +} + +// DNSDoneInfo contains information about the results of a DNS lookup. +type DNSDoneInfo struct { + // Addrs are the IPv4 and/or IPv6 addresses found in the DNS + // lookup. The contents of the slice should not be mutated. + Addrs []net.IPAddr + + // Err is any error that occurred during the DNS lookup. + Err error + + // Coalesced is whether the Addrs were shared with another + // caller who was doing the same DNS lookup concurrently. 
+ Coalesced bool +} + +func (t *ClientTrace) hasNetHooks() bool { + if t == nil { + return false + } + return t.DNSStart != nil || t.DNSDone != nil || t.ConnectStart != nil || t.ConnectDone != nil +} + +// GotConnInfo is the argument to the ClientTrace.GotConn function and +// contains information about the obtained connection. +type GotConnInfo struct { + // Conn is the connection that was obtained. It is owned by + // the http.Transport and should not be read, written or + // closed by users of ClientTrace. + Conn net.Conn + + // Reused is whether this connection has been previously + // used for another HTTP request. + Reused bool + + // WasIdle is whether this connection was obtained from an + // idle pool. + WasIdle bool + + // IdleTime reports how long the connection was previously + // idle, if WasIdle is true. + IdleTime time.Duration +} diff --git a/libgo/go/net/http/httptrace/trace_test.go b/libgo/go/net/http/httptrace/trace_test.go new file mode 100644 index 00000000000..c7eaed83d47 --- /dev/null +++ b/libgo/go/net/http/httptrace/trace_test.go @@ -0,0 +1,62 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file.h + +package httptrace + +import ( + "bytes" + "testing" +) + +func TestCompose(t *testing.T) { + var buf bytes.Buffer + var testNum int + + connectStart := func(b byte) func(network, addr string) { + return func(network, addr string) { + if addr != "addr" { + t.Errorf(`%d. 
args for %q case = %q, %q; want addr of "addr"`, testNum, b, network, addr) + } + buf.WriteByte(b) + } + } + + tests := [...]struct { + trace, old *ClientTrace + want string + }{ + 0: { + want: "T", + trace: &ClientTrace{ + ConnectStart: connectStart('T'), + }, + }, + 1: { + want: "TO", + trace: &ClientTrace{ + ConnectStart: connectStart('T'), + }, + old: &ClientTrace{ConnectStart: connectStart('O')}, + }, + 2: { + want: "O", + trace: &ClientTrace{}, + old: &ClientTrace{ConnectStart: connectStart('O')}, + }, + } + for i, tt := range tests { + testNum = i + buf.Reset() + + tr := *tt.trace + tr.compose(tt.old) + if tr.ConnectStart != nil { + tr.ConnectStart("net", "addr") + } + if got := buf.String(); got != tt.want { + t.Errorf("%d. got = %q; want %q", i, got, tt.want) + } + } + +} diff --git a/libgo/go/net/http/httputil/dump.go b/libgo/go/net/http/httputil/dump.go index e22cc66dbfc..15116816328 100644 --- a/libgo/go/net/http/httputil/dump.go +++ b/libgo/go/net/http/httputil/dump.go @@ -128,7 +128,7 @@ func DumpRequestOut(req *http.Request, body bool) ([]byte, error) { // If we used a dummy body above, remove it now. // TODO: if the req.ContentLength is large, we allocate memory - // unnecessarily just to slice it off here. But this is just + // unnecessarily just to slice it off here. But this is just // a debug function, so this is acceptable for now. We could // discard the body earlier if this matters. if dummyBody { @@ -163,18 +163,10 @@ func valueOrDefault(value, def string) string { var reqWriteExcludeHeaderDump = map[string]bool{ "Host": true, // not in Header map anyway - "Content-Length": true, "Transfer-Encoding": true, "Trailer": true, } -// dumpAsReceived writes req to w in the form as it was received, or -// at least as accurately as possible from the information retained in -// the request. -func dumpAsReceived(req *http.Request, w io.Writer) error { - return nil -} - // DumpRequest returns the given request in its HTTP/1.x wire // representation. 
It should only be used by servers to debug client // requests. The returned representation is an approximation only; @@ -191,7 +183,8 @@ func dumpAsReceived(req *http.Request, w io.Writer) error { // // The documentation for http.Request.Write details which fields // of req are included in the dump. -func DumpRequest(req *http.Request, body bool) (dump []byte, err error) { +func DumpRequest(req *http.Request, body bool) ([]byte, error) { + var err error save := req.Body if !body || req.Body == nil { req.Body = nil @@ -239,7 +232,7 @@ func DumpRequest(req *http.Request, body bool) (dump []byte, err error) { err = req.Header.WriteSubset(&b, reqWriteExcludeHeaderDump) if err != nil { - return + return nil, err } io.WriteString(&b, "\r\n") @@ -258,35 +251,42 @@ func DumpRequest(req *http.Request, body bool) (dump []byte, err error) { req.Body = save if err != nil { - return + return nil, err } - dump = b.Bytes() - return + return b.Bytes(), nil } -// errNoBody is a sentinel error value used by failureToReadBody so we can detect -// that the lack of body was intentional. +// errNoBody is a sentinel error value used by failureToReadBody so we +// can detect that the lack of body was intentional. var errNoBody = errors.New("sentinel error value") // failureToReadBody is a io.ReadCloser that just returns errNoBody on -// Read. It's swapped in when we don't actually want to consume the -// body, but need a non-nil one, and want to distinguish the error -// from reading the dummy body. +// Read. It's swapped in when we don't actually want to consume +// the body, but need a non-nil one, and want to distinguish the +// error from reading the dummy body. type failureToReadBody struct{} func (failureToReadBody) Read([]byte) (int, error) { return 0, errNoBody } func (failureToReadBody) Close() error { return nil } +// emptyBody is an instance of empty reader. var emptyBody = ioutil.NopCloser(strings.NewReader("")) // DumpResponse is like DumpRequest but dumps a response. 
-func DumpResponse(resp *http.Response, body bool) (dump []byte, err error) { +func DumpResponse(resp *http.Response, body bool) ([]byte, error) { var b bytes.Buffer + var err error save := resp.Body savecl := resp.ContentLength if !body { - resp.Body = failureToReadBody{} + // For content length of zero. Make sure the body is an empty + // reader, instead of returning error through failureToReadBody{}. + if resp.ContentLength == 0 { + resp.Body = emptyBody + } else { + resp.Body = failureToReadBody{} + } } else if resp.Body == nil { resp.Body = emptyBody } else { diff --git a/libgo/go/net/http/httputil/dump_test.go b/libgo/go/net/http/httputil/dump_test.go index 46bf521723a..2e980d39f8a 100644 --- a/libgo/go/net/http/httputil/dump_test.go +++ b/libgo/go/net/http/httputil/dump_test.go @@ -122,6 +122,10 @@ var dumpTests = []dumpTest{ Host: "post.tld", Path: "/", }, + Header: http.Header{ + "Content-Length": []string{"8193"}, + }, + ContentLength: 8193, ProtoMajor: 1, ProtoMinor: 1, @@ -135,6 +139,10 @@ var dumpTests = []dumpTest{ "Content-Length: 8193\r\n" + "Accept-Encoding: gzip\r\n\r\n" + strings.Repeat("a", 8193), + WantDump: "POST / HTTP/1.1\r\n" + + "Host: post.tld\r\n" + + "Content-Length: 8193\r\n\r\n" + + strings.Repeat("a", 8193), }, { @@ -144,6 +152,38 @@ var dumpTests = []dumpTest{ WantDump: "GET http://foo.com/ HTTP/1.1\r\n" + "User-Agent: blah\r\n\r\n", }, + + // Issue #7215. DumpRequest should return the "Content-Length" when set + { + Req: *mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" + + "Host: passport.myhost.com\r\n" + + "Content-Length: 3\r\n" + + "\r\nkey1=name1&key2=name2"), + WantDump: "POST /v2/api/?login HTTP/1.1\r\n" + + "Host: passport.myhost.com\r\n" + + "Content-Length: 3\r\n" + + "\r\nkey", + }, + + // Issue #7215. 
DumpRequest should return the "Content-Length" in ReadRequest + { + Req: *mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" + + "Host: passport.myhost.com\r\n" + + "Content-Length: 0\r\n" + + "\r\nkey1=name1&key2=name2"), + WantDump: "POST /v2/api/?login HTTP/1.1\r\n" + + "Host: passport.myhost.com\r\n" + + "Content-Length: 0\r\n\r\n", + }, + + // Issue #7215. DumpRequest should not return the "Content-Length" if unset + { + Req: *mustReadRequest("POST /v2/api/?login HTTP/1.1\r\n" + + "Host: passport.myhost.com\r\n" + + "\r\nkey1=name1&key2=name2"), + WantDump: "POST /v2/api/?login HTTP/1.1\r\n" + + "Host: passport.myhost.com\r\n\r\n", + }, } func TestDumpRequest(t *testing.T) { @@ -288,6 +328,27 @@ Transfer-Encoding: chunked foo 0`, }, + { + res: &http.Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 0, + Header: http.Header{ + // To verify if headers are not filtered out. + "Foo1": []string{"Bar1"}, + "Foo2": []string{"Bar2"}, + }, + Body: nil, + }, + body: false, // to verify we see 0, not empty. + want: `HTTP/1.1 200 OK +Foo1: Bar1 +Foo2: Bar2 +Content-Length: 0`, + }, } func TestDumpResponse(t *testing.T) { diff --git a/libgo/go/net/http/httputil/example_test.go b/libgo/go/net/http/httputil/example_test.go index 8fb1a2d2792..e8dc962d3e3 100644 --- a/libgo/go/net/http/httputil/example_test.go +++ b/libgo/go/net/http/httputil/example_test.go @@ -49,7 +49,7 @@ func ExampleDumpRequest() { fmt.Printf("%s", b) // Output: - // "POST / HTTP/1.1\r\nHost: www.example.org\r\nAccept-Encoding: gzip\r\nUser-Agent: Go-http-client/1.1\r\n\r\nGo is a general-purpose language designed with systems programming in mind." + // "POST / HTTP/1.1\r\nHost: www.example.org\r\nAccept-Encoding: gzip\r\nContent-Length: 75\r\nUser-Agent: Go-http-client/1.1\r\n\r\nGo is a general-purpose language designed with systems programming in mind." 
} func ExampleDumpRequestOut() { diff --git a/libgo/go/net/http/httputil/persist.go b/libgo/go/net/http/httputil/persist.go index 987bcc96ba1..87ddd52cd96 100644 --- a/libgo/go/net/http/httputil/persist.go +++ b/libgo/go/net/http/httputil/persist.go @@ -24,17 +24,13 @@ var ( // ErrPersistEOF (above) reports that the remote side is closed. var errClosed = errors.New("i/o operation on closed connection") -// A ServerConn reads requests and sends responses over an underlying -// connection, until the HTTP keepalive logic commands an end. ServerConn -// also allows hijacking the underlying connection by calling Hijack -// to regain control over the connection. ServerConn supports pipe-lining, -// i.e. requests can be read out of sync (but in the same order) while the -// respective responses are sent. +// ServerConn is an artifact of Go's early HTTP implementation. +// It is low-level, old, and unused by Go's current HTTP stack. +// We should have deleted it before Go 1. // -// ServerConn is low-level and old. Applications should instead use Server -// in the net/http package. +// Deprecated: Use the Server in package net/http instead. type ServerConn struct { - lk sync.Mutex // read-write protects the following fields + mu sync.Mutex // read-write protects the following fields c net.Conn r *bufio.Reader re, we error // read/write errors @@ -45,11 +41,11 @@ type ServerConn struct { pipe textproto.Pipeline } -// NewServerConn returns a new ServerConn reading and writing c. If r is not -// nil, it is the buffer to use when reading c. +// NewServerConn is an artifact of Go's early HTTP implementation. +// It is low-level, old, and unused by Go's current HTTP stack. +// We should have deleted it before Go 1. // -// ServerConn is low-level and old. Applications should instead use Server -// in the net/http package. +// Deprecated: Use the Server in package net/http instead. 
func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn { if r == nil { r = bufio.NewReader(c) @@ -61,17 +57,17 @@ func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn { // as the read-side bufio which may have some left over data. Hijack may be // called before Read has signaled the end of the keep-alive logic. The user // should not call Hijack while Read or Write is in progress. -func (sc *ServerConn) Hijack() (c net.Conn, r *bufio.Reader) { - sc.lk.Lock() - defer sc.lk.Unlock() - c = sc.c - r = sc.r +func (sc *ServerConn) Hijack() (net.Conn, *bufio.Reader) { + sc.mu.Lock() + defer sc.mu.Unlock() + c := sc.c + r := sc.r sc.c = nil sc.r = nil - return + return c, r } -// Close calls Hijack and then also closes the underlying connection +// Close calls Hijack and then also closes the underlying connection. func (sc *ServerConn) Close() error { c, _ := sc.Hijack() if c != nil { @@ -84,7 +80,9 @@ func (sc *ServerConn) Close() error { // it is gracefully determined that there are no more requests (e.g. after the // first request on an HTTP/1.0 connection, or after a Connection:close on a // HTTP/1.1 connection). 
-func (sc *ServerConn) Read() (req *http.Request, err error) { +func (sc *ServerConn) Read() (*http.Request, error) { + var req *http.Request + var err error // Ensure ordered execution of Reads and Writes id := sc.pipe.Next() @@ -96,29 +94,29 @@ func (sc *ServerConn) Read() (req *http.Request, err error) { sc.pipe.EndResponse(id) } else { // Remember the pipeline id of this request - sc.lk.Lock() + sc.mu.Lock() sc.pipereq[req] = id - sc.lk.Unlock() + sc.mu.Unlock() } }() - sc.lk.Lock() + sc.mu.Lock() if sc.we != nil { // no point receiving if write-side broken or closed - defer sc.lk.Unlock() + defer sc.mu.Unlock() return nil, sc.we } if sc.re != nil { - defer sc.lk.Unlock() + defer sc.mu.Unlock() return nil, sc.re } if sc.r == nil { // connection closed by user in the meantime - defer sc.lk.Unlock() + defer sc.mu.Unlock() return nil, errClosed } r := sc.r lastbody := sc.lastbody sc.lastbody = nil - sc.lk.Unlock() + sc.mu.Unlock() // Make sure body is fully consumed, even if user does not call body.Close if lastbody != nil { @@ -127,16 +125,16 @@ func (sc *ServerConn) Read() (req *http.Request, err error) { // returned. err = lastbody.Close() if err != nil { - sc.lk.Lock() - defer sc.lk.Unlock() + sc.mu.Lock() + defer sc.mu.Unlock() sc.re = err return nil, err } } req, err = http.ReadRequest(r) - sc.lk.Lock() - defer sc.lk.Unlock() + sc.mu.Lock() + defer sc.mu.Unlock() if err != nil { if err == io.ErrUnexpectedEOF { // A close from the opposing client is treated as a @@ -161,8 +159,8 @@ func (sc *ServerConn) Read() (req *http.Request, err error) { // Pending returns the number of unanswered requests // that have been received on the connection. 
func (sc *ServerConn) Pending() int { - sc.lk.Lock() - defer sc.lk.Unlock() + sc.mu.Lock() + defer sc.mu.Unlock() return sc.nread - sc.nwritten } @@ -172,31 +170,31 @@ func (sc *ServerConn) Pending() int { func (sc *ServerConn) Write(req *http.Request, resp *http.Response) error { // Retrieve the pipeline ID of this request/response pair - sc.lk.Lock() + sc.mu.Lock() id, ok := sc.pipereq[req] delete(sc.pipereq, req) if !ok { - sc.lk.Unlock() + sc.mu.Unlock() return ErrPipeline } - sc.lk.Unlock() + sc.mu.Unlock() // Ensure pipeline order sc.pipe.StartResponse(id) defer sc.pipe.EndResponse(id) - sc.lk.Lock() + sc.mu.Lock() if sc.we != nil { - defer sc.lk.Unlock() + defer sc.mu.Unlock() return sc.we } if sc.c == nil { // connection closed by user in the meantime - defer sc.lk.Unlock() + defer sc.mu.Unlock() return ErrClosed } c := sc.c if sc.nread <= sc.nwritten { - defer sc.lk.Unlock() + defer sc.mu.Unlock() return errors.New("persist server pipe count") } if resp.Close { @@ -205,11 +203,11 @@ func (sc *ServerConn) Write(req *http.Request, resp *http.Response) error { // before signaling. sc.re = ErrPersistEOF } - sc.lk.Unlock() + sc.mu.Unlock() err := resp.Write(c) - sc.lk.Lock() - defer sc.lk.Unlock() + sc.mu.Lock() + defer sc.mu.Unlock() if err != nil { sc.we = err return err @@ -219,15 +217,13 @@ func (sc *ServerConn) Write(req *http.Request, resp *http.Response) error { return nil } -// A ClientConn sends request and receives headers over an underlying -// connection, while respecting the HTTP keepalive logic. ClientConn -// supports hijacking the connection calling Hijack to -// regain control of the underlying net.Conn and deal with it as desired. +// ClientConn is an artifact of Go's early HTTP implementation. +// It is low-level, old, and unused by Go's current HTTP stack. +// We should have deleted it before Go 1. // -// ClientConn is low-level and old. Applications should instead use -// Client or Transport in the net/http package. 
+// Deprecated: Use Client or Transport in package net/http instead. type ClientConn struct { - lk sync.Mutex // read-write protects the following fields + mu sync.Mutex // read-write protects the following fields c net.Conn r *bufio.Reader re, we error // read/write errors @@ -239,11 +235,11 @@ type ClientConn struct { writeReq func(*http.Request, io.Writer) error } -// NewClientConn returns a new ClientConn reading and writing c. If r is not -// nil, it is the buffer to use when reading c. +// NewClientConn is an artifact of Go's early HTTP implementation. +// It is low-level, old, and unused by Go's current HTTP stack. +// We should have deleted it before Go 1. // -// ClientConn is low-level and old. Applications should use Client or -// Transport in the net/http package. +// Deprecated: Use the Client or Transport in package net/http instead. func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn { if r == nil { r = bufio.NewReader(c) @@ -256,11 +252,11 @@ func NewClientConn(c net.Conn, r *bufio.Reader) *ClientConn { } } -// NewProxyClientConn works like NewClientConn but writes Requests -// using Request's WriteProxy method. +// NewProxyClientConn is an artifact of Go's early HTTP implementation. +// It is low-level, old, and unused by Go's current HTTP stack. +// We should have deleted it before Go 1. // -// New code should not use NewProxyClientConn. See Client or -// Transport in the net/http package instead. +// Deprecated: Use the Client or Transport in package net/http instead. func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn { cc := NewClientConn(c, r) cc.writeReq = (*http.Request).WriteProxy @@ -272,8 +268,8 @@ func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn { // called before the user or Read have signaled the end of the keep-alive // logic. The user should not call Hijack while Read or Write is in progress. 
func (cc *ClientConn) Hijack() (c net.Conn, r *bufio.Reader) { - cc.lk.Lock() - defer cc.lk.Unlock() + cc.mu.Lock() + defer cc.mu.Unlock() c = cc.c r = cc.r cc.c = nil @@ -281,7 +277,7 @@ func (cc *ClientConn) Hijack() (c net.Conn, r *bufio.Reader) { return } -// Close calls Hijack and then also closes the underlying connection +// Close calls Hijack and then also closes the underlying connection. func (cc *ClientConn) Close() error { c, _ := cc.Hijack() if c != nil { @@ -295,7 +291,8 @@ func (cc *ClientConn) Close() error { // keepalive connection is logically closed after this request and the opposing // server is informed. An ErrUnexpectedEOF indicates the remote closed the // underlying TCP connection, which is usually considered as graceful close. -func (cc *ClientConn) Write(req *http.Request) (err error) { +func (cc *ClientConn) Write(req *http.Request) error { + var err error // Ensure ordered execution of Writes id := cc.pipe.Next() @@ -307,23 +304,23 @@ func (cc *ClientConn) Write(req *http.Request) (err error) { cc.pipe.EndResponse(id) } else { // Remember the pipeline id of this request - cc.lk.Lock() + cc.mu.Lock() cc.pipereq[req] = id - cc.lk.Unlock() + cc.mu.Unlock() } }() - cc.lk.Lock() + cc.mu.Lock() if cc.re != nil { // no point sending if read-side closed or broken - defer cc.lk.Unlock() + defer cc.mu.Unlock() return cc.re } if cc.we != nil { - defer cc.lk.Unlock() + defer cc.mu.Unlock() return cc.we } if cc.c == nil { // connection closed by user in the meantime - defer cc.lk.Unlock() + defer cc.mu.Unlock() return errClosed } c := cc.c @@ -332,11 +329,11 @@ func (cc *ClientConn) Write(req *http.Request) (err error) { // still might be some pipelined reads cc.we = ErrPersistEOF } - cc.lk.Unlock() + cc.mu.Unlock() err = cc.writeReq(req, c) - cc.lk.Lock() - defer cc.lk.Unlock() + cc.mu.Lock() + defer cc.mu.Unlock() if err != nil { cc.we = err return err @@ -349,8 +346,8 @@ func (cc *ClientConn) Write(req *http.Request) (err error) { // Pending 
returns the number of unanswered requests // that have been sent on the connection. func (cc *ClientConn) Pending() int { - cc.lk.Lock() - defer cc.lk.Unlock() + cc.mu.Lock() + defer cc.mu.Unlock() return cc.nwritten - cc.nread } @@ -360,32 +357,32 @@ func (cc *ClientConn) Pending() int { // concurrently with Write, but not with another Read. func (cc *ClientConn) Read(req *http.Request) (resp *http.Response, err error) { // Retrieve the pipeline ID of this request/response pair - cc.lk.Lock() + cc.mu.Lock() id, ok := cc.pipereq[req] delete(cc.pipereq, req) if !ok { - cc.lk.Unlock() + cc.mu.Unlock() return nil, ErrPipeline } - cc.lk.Unlock() + cc.mu.Unlock() // Ensure pipeline order cc.pipe.StartResponse(id) defer cc.pipe.EndResponse(id) - cc.lk.Lock() + cc.mu.Lock() if cc.re != nil { - defer cc.lk.Unlock() + defer cc.mu.Unlock() return nil, cc.re } if cc.r == nil { // connection closed by user in the meantime - defer cc.lk.Unlock() + defer cc.mu.Unlock() return nil, errClosed } r := cc.r lastbody := cc.lastbody cc.lastbody = nil - cc.lk.Unlock() + cc.mu.Unlock() // Make sure body is fully consumed, even if user does not call body.Close if lastbody != nil { @@ -394,16 +391,16 @@ func (cc *ClientConn) Read(req *http.Request) (resp *http.Response, err error) { // returned. err = lastbody.Close() if err != nil { - cc.lk.Lock() - defer cc.lk.Unlock() + cc.mu.Lock() + defer cc.mu.Unlock() cc.re = err return nil, err } } resp, err = http.ReadResponse(r, req) - cc.lk.Lock() - defer cc.lk.Unlock() + cc.mu.Lock() + defer cc.mu.Unlock() if err != nil { cc.re = err return resp, err @@ -420,10 +417,10 @@ func (cc *ClientConn) Read(req *http.Request) (resp *http.Response, err error) { } // Do is convenience method that writes a request and reads a response. 
-func (cc *ClientConn) Do(req *http.Request) (resp *http.Response, err error) { - err = cc.Write(req) +func (cc *ClientConn) Do(req *http.Request) (*http.Response, error) { + err := cc.Write(req) if err != nil { - return + return nil, err } return cc.Read(req) } diff --git a/libgo/go/net/http/httputil/reverseproxy.go b/libgo/go/net/http/httputil/reverseproxy.go index 54411caeca8..49c120afde1 100644 --- a/libgo/go/net/http/httputil/reverseproxy.go +++ b/libgo/go/net/http/httputil/reverseproxy.go @@ -90,6 +90,10 @@ func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy { } else { req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery } + if _, ok := req.Header["User-Agent"]; !ok { + // explicitly disable User-Agent so it's not set to default value + req.Header.Set("User-Agent", "") + } } return &ReverseProxy{Director: director} } @@ -180,9 +184,9 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { outreq.ProtoMinor = 1 outreq.Close = false - // Remove hop-by-hop headers to the backend. Especially + // Remove hop-by-hop headers to the backend. Especially // important is "Connection" because we want a persistent - // connection, regardless of what the client sent to us. This + // connection, regardless of what the client sent to us. This // is modifying the same underlying map from req (shallow // copied above) so we only copy it if necessary. 
copiedHeaders := false @@ -210,7 +214,7 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { res, err := transport.RoundTrip(outreq) if err != nil { p.logf("http: proxy error: %v", err) - rw.WriteHeader(http.StatusInternalServerError) + rw.WriteHeader(http.StatusBadGateway) return } @@ -285,13 +289,13 @@ type maxLatencyWriter struct { dst writeFlusher latency time.Duration - lk sync.Mutex // protects Write + Flush + mu sync.Mutex // protects Write + Flush done chan bool } func (m *maxLatencyWriter) Write(p []byte) (int, error) { - m.lk.Lock() - defer m.lk.Unlock() + m.mu.Lock() + defer m.mu.Unlock() return m.dst.Write(p) } @@ -306,9 +310,9 @@ func (m *maxLatencyWriter) flushLoop() { } return case <-t.C: - m.lk.Lock() + m.mu.Lock() m.dst.Flush() - m.lk.Unlock() + m.mu.Unlock() } } } diff --git a/libgo/go/net/http/httputil/reverseproxy_test.go b/libgo/go/net/http/httputil/reverseproxy_test.go index 0849427b85c..fe7cdb888f5 100644 --- a/libgo/go/net/http/httputil/reverseproxy_test.go +++ b/libgo/go/net/http/httputil/reverseproxy_test.go @@ -33,6 +33,11 @@ func TestReverseProxy(t *testing.T) { const backendResponse = "I am the backend" const backendStatus = 404 backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method == "GET" && r.FormValue("mode") == "hangup" { + c, _, _ := w.(http.Hijacker).Hijack() + c.Close() + return + } if len(r.TransferEncoding) > 0 { t.Errorf("backend got unexpected TransferEncoding: %v", r.TransferEncoding) } @@ -69,6 +74,7 @@ func TestReverseProxy(t *testing.T) { t.Fatal(err) } proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests frontend := httptest.NewServer(proxyHandler) defer frontend.Close() @@ -113,6 +119,20 @@ func TestReverseProxy(t *testing.T) { if g, e := res.Trailer.Get("X-Trailer"), "trailer_value"; g != e { t.Errorf("Trailer(X-Trailer) = %q ; want %q", g, e) } + + // Test that 
a backend failing to be reached or one which doesn't return + // a response results in a StatusBadGateway. + getReq, _ = http.NewRequest("GET", frontend.URL+"/?mode=hangup", nil) + getReq.Close = true + res, err = http.DefaultClient.Do(getReq) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != http.StatusBadGateway { + t.Errorf("request to bad proxy = %v; want 502 StatusBadGateway", res.Status) + } + } func TestXForwardedFor(t *testing.T) { @@ -328,6 +348,49 @@ func TestNilBody(t *testing.T) { } } +// Issue 15524 +func TestUserAgentHeader(t *testing.T) { + const explicitUA = "explicit UA" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path == "/noua" { + if c := r.Header.Get("User-Agent"); c != "" { + t.Errorf("handler got non-empty User-Agent header %q", c) + } + return + } + if c := r.Header.Get("User-Agent"); c != explicitUA { + t.Errorf("handler got unexpected User-Agent header %q", c) + } + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Header.Set("User-Agent", explicitUA) + getReq.Close = true + res, err := http.DefaultClient.Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + res.Body.Close() + + getReq, _ = http.NewRequest("GET", frontend.URL+"/noua", nil) + getReq.Header.Set("User-Agent", "") + getReq.Close = true + res, err = http.DefaultClient.Do(getReq) + if err != nil { + t.Fatalf("Get: %v", err) + } + res.Body.Close() +} + type bufferPool struct { get func() []byte put func([]byte) diff --git a/libgo/go/net/http/internal/chunked_test.go b/libgo/go/net/http/internal/chunked_test.go index a136dc99a65..9abe1ab6d9d 100644 --- 
a/libgo/go/net/http/internal/chunked_test.go +++ b/libgo/go/net/http/internal/chunked_test.go @@ -122,7 +122,7 @@ func TestChunkReaderAllocs(t *testing.T) { byter := bytes.NewReader(buf.Bytes()) bufr := bufio.NewReader(byter) mallocs := testing.AllocsPerRun(100, func() { - byter.Seek(0, 0) + byter.Seek(0, io.SeekStart) bufr.Reset(byter) r := NewChunkedReader(bufr) n, err := io.ReadFull(r, readBuf) diff --git a/libgo/go/net/http/lex.go b/libgo/go/net/http/lex.go deleted file mode 100644 index 52b6481c14e..00000000000 --- a/libgo/go/net/http/lex.go +++ /dev/null @@ -1,183 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package http - -import ( - "strings" - "unicode/utf8" -) - -// This file deals with lexical matters of HTTP - -var isTokenTable = [127]bool{ - '!': true, - '#': true, - '$': true, - '%': true, - '&': true, - '\'': true, - '*': true, - '+': true, - '-': true, - '.': true, - '0': true, - '1': true, - '2': true, - '3': true, - '4': true, - '5': true, - '6': true, - '7': true, - '8': true, - '9': true, - 'A': true, - 'B': true, - 'C': true, - 'D': true, - 'E': true, - 'F': true, - 'G': true, - 'H': true, - 'I': true, - 'J': true, - 'K': true, - 'L': true, - 'M': true, - 'N': true, - 'O': true, - 'P': true, - 'Q': true, - 'R': true, - 'S': true, - 'T': true, - 'U': true, - 'W': true, - 'V': true, - 'X': true, - 'Y': true, - 'Z': true, - '^': true, - '_': true, - '`': true, - 'a': true, - 'b': true, - 'c': true, - 'd': true, - 'e': true, - 'f': true, - 'g': true, - 'h': true, - 'i': true, - 'j': true, - 'k': true, - 'l': true, - 'm': true, - 'n': true, - 'o': true, - 'p': true, - 'q': true, - 'r': true, - 's': true, - 't': true, - 'u': true, - 'v': true, - 'w': true, - 'x': true, - 'y': true, - 'z': true, - '|': true, - '~': true, -} - -func isToken(r rune) bool { - i := int(r) - return i < len(isTokenTable) && isTokenTable[i] -} - 
-func isNotToken(r rune) bool { - return !isToken(r) -} - -// headerValuesContainsToken reports whether any string in values -// contains the provided token, ASCII case-insensitively. -func headerValuesContainsToken(values []string, token string) bool { - for _, v := range values { - if headerValueContainsToken(v, token) { - return true - } - } - return false -} - -// isOWS reports whether b is an optional whitespace byte, as defined -// by RFC 7230 section 3.2.3. -func isOWS(b byte) bool { return b == ' ' || b == '\t' } - -// trimOWS returns x with all optional whitespace removes from the -// beginning and end. -func trimOWS(x string) string { - // TODO: consider using strings.Trim(x, " \t") instead, - // if and when it's fast enough. See issue 10292. - // But this ASCII-only code will probably always beat UTF-8 - // aware code. - for len(x) > 0 && isOWS(x[0]) { - x = x[1:] - } - for len(x) > 0 && isOWS(x[len(x)-1]) { - x = x[:len(x)-1] - } - return x -} - -// headerValueContainsToken reports whether v (assumed to be a -// 0#element, in the ABNF extension described in RFC 7230 section 7) -// contains token amongst its comma-separated tokens, ASCII -// case-insensitively. -func headerValueContainsToken(v string, token string) bool { - v = trimOWS(v) - if comma := strings.IndexByte(v, ','); comma != -1 { - return tokenEqual(trimOWS(v[:comma]), token) || headerValueContainsToken(v[comma+1:], token) - } - return tokenEqual(v, token) -} - -// lowerASCII returns the ASCII lowercase version of b. -func lowerASCII(b byte) byte { - if 'A' <= b && b <= 'Z' { - return b + ('a' - 'A') - } - return b -} - -// tokenEqual reports whether t1 and t2 are equal, ASCII case-insensitively. -func tokenEqual(t1, t2 string) bool { - if len(t1) != len(t2) { - return false - } - for i, b := range t1 { - if b >= utf8.RuneSelf { - // No UTF-8 or non-ASCII allowed in tokens. 
- return false - } - if lowerASCII(byte(b)) != lowerASCII(t2[i]) { - return false - } - } - return true -} - -// isLWS reports whether b is linear white space, according -// to http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2 -// LWS = [CRLF] 1*( SP | HT ) -func isLWS(b byte) bool { return b == ' ' || b == '\t' } - -// isCTL reports whether b is a control byte, according -// to http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2 -// CTL = <any US-ASCII control character -// (octets 0 - 31) and DEL (127)> -func isCTL(b byte) bool { - const del = 0x7f // a CTL - return b < ' ' || b == del -} diff --git a/libgo/go/net/http/lex_test.go b/libgo/go/net/http/lex_test.go deleted file mode 100644 index 986fda17dcd..00000000000 --- a/libgo/go/net/http/lex_test.go +++ /dev/null @@ -1,101 +0,0 @@ -// Copyright 2009 The Go Authors. All rights reserved. -// Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file. - -package http - -import ( - "testing" -) - -func isChar(c rune) bool { return c <= 127 } - -func isCtl(c rune) bool { return c <= 31 || c == 127 } - -func isSeparator(c rune) bool { - switch c { - case '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ', '\t': - return true - } - return false -} - -func TestIsToken(t *testing.T) { - for i := 0; i <= 130; i++ { - r := rune(i) - expected := isChar(r) && !isCtl(r) && !isSeparator(r) - if isToken(r) != expected { - t.Errorf("isToken(0x%x) = %v", r, !expected) - } - } -} - -func TestHeaderValuesContainsToken(t *testing.T) { - tests := []struct { - vals []string - token string - want bool - }{ - { - vals: []string{"foo"}, - token: "foo", - want: true, - }, - { - vals: []string{"bar", "foo"}, - token: "foo", - want: true, - }, - { - vals: []string{"foo"}, - token: "FOO", - want: true, - }, - { - vals: []string{"foo"}, - token: "bar", - want: false, - }, - { - vals: []string{" foo "}, - token: "FOO", - want: true, - }, - { 
- vals: []string{"foo,bar"}, - token: "FOO", - want: true, - }, - { - vals: []string{"bar,foo,bar"}, - token: "FOO", - want: true, - }, - { - vals: []string{"bar , foo"}, - token: "FOO", - want: true, - }, - { - vals: []string{"foo ,bar "}, - token: "FOO", - want: true, - }, - { - vals: []string{"bar, foo ,bar"}, - token: "FOO", - want: true, - }, - { - vals: []string{"bar , foo"}, - token: "FOO", - want: true, - }, - } - for _, tt := range tests { - got := headerValuesContainsToken(tt.vals, tt.token) - if got != tt.want { - t.Errorf("headerValuesContainsToken(%q, %q) = %v; want %v", tt.vals, tt.token, got, tt.want) - } - } -} diff --git a/libgo/go/net/http/main_test.go b/libgo/go/net/http/main_test.go index 299cd7b2d2f..aea6e12744b 100644 --- a/libgo/go/net/http/main_test.go +++ b/libgo/go/net/http/main_test.go @@ -5,7 +5,6 @@ package http_test import ( - "flag" "fmt" "net/http" "os" @@ -16,8 +15,6 @@ import ( "time" ) -var flaky = flag.Bool("flaky", false, "run known-flaky tests too") - func TestMain(m *testing.M) { v := m.Run() if v == 0 && goroutineLeaked() { @@ -91,12 +88,6 @@ func setParallel(t *testing.T) { } } -func setFlaky(t *testing.T, issue int) { - if !*flaky { - t.Skipf("skipping known flaky test; see golang.org/issue/%d", issue) - } -} - func afterTest(t testing.TB) { http.DefaultTransport.(*http.Transport).CloseIdleConnections() if testing.Short() { @@ -129,3 +120,17 @@ func afterTest(t testing.TB) { } t.Errorf("Test appears to have leaked %s:\n%s", bad, stacks) } + +// waitCondition reports whether fn eventually returned true, +// checking immediately and then every checkEvery amount, +// until waitFor has elapsed, at which point it returns false. 
+func waitCondition(waitFor, checkEvery time.Duration, fn func() bool) bool { + deadline := time.Now().Add(waitFor) + for time.Now().Before(deadline) { + if fn() { + return true + } + time.Sleep(checkEvery) + } + return false +} diff --git a/libgo/go/net/http/method.go b/libgo/go/net/http/method.go index b74f9604d34..6f46155069f 100644 --- a/libgo/go/net/http/method.go +++ b/libgo/go/net/http/method.go @@ -12,7 +12,7 @@ const ( MethodHead = "HEAD" MethodPost = "POST" MethodPut = "PUT" - MethodPatch = "PATCH" // RFC 5741 + MethodPatch = "PATCH" // RFC 5789 MethodDelete = "DELETE" MethodConnect = "CONNECT" MethodOptions = "OPTIONS" diff --git a/libgo/go/net/http/pprof/pprof.go b/libgo/go/net/http/pprof/pprof.go index 7262c6c1016..05d0890fdf3 100644 --- a/libgo/go/net/http/pprof/pprof.go +++ b/libgo/go/net/http/pprof/pprof.go @@ -1,11 +1,9 @@ -// Copyright 2010 The Go Authors. All rights reserved. +// Copyright 2010 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // Package pprof serves via its HTTP server runtime profiling data // in the format expected by the pprof visualization tool. -// For more information about pprof, see -// http://code.google.com/p/google-perftools/. // // The package is typically only imported for the side effect of // registering its HTTP handlers. @@ -15,7 +13,7 @@ // import _ "net/http/pprof" // // If your application is not already running an http server, you -// need to start one. Add "net/http" and "log" to your imports and +// need to start one. 
Add "net/http" and "log" to your imports and // the following code to your main function: // // go func() { @@ -30,7 +28,8 @@ // // go tool pprof http://localhost:6060/debug/pprof/profile // -// Or to look at the goroutine blocking profile: +// Or to look at the goroutine blocking profile, after calling +// runtime.SetBlockProfileRate in your program: // // go tool pprof http://localhost:6060/debug/pprof/block // @@ -118,8 +117,8 @@ func Profile(w http.ResponseWriter, r *http.Request) { // Tracing lasts for duration specified in seconds GET parameter, or for 1 second if not specified. // The package initialization registers it as /debug/pprof/trace. func Trace(w http.ResponseWriter, r *http.Request) { - sec, _ := strconv.ParseInt(r.FormValue("seconds"), 10, 64) - if sec == 0 { + sec, err := strconv.ParseFloat(r.FormValue("seconds"), 64) + if sec <= 0 || err != nil { sec = 1 } @@ -127,18 +126,16 @@ func Trace(w http.ResponseWriter, r *http.Request) { // because if it does it starts writing. w.Header().Set("Content-Type", "application/octet-stream") w.Write([]byte("tracing not yet supported with gccgo")) - /* - if err := trace.Start(w); err != nil { - // trace.Start failed, so no writes yet. - // Can change header back to text content and send error code. - w.Header().Set("Content-Type", "text/plain; charset=utf-8") - w.WriteHeader(http.StatusInternalServerError) - fmt.Fprintf(w, "Could not enable tracing: %s\n", err) - return - } - sleep(w, time.Duration(sec)*time.Second) - trace.Stop() - */ + // if err := trace.Start(w); err != nil { + // // trace.Start failed, so no writes yet. + // // Can change header back to text content and send error code. 
+ // w.Header().Set("Content-Type", "text/plain; charset=utf-8") + // w.WriteHeader(http.StatusInternalServerError) + // fmt.Fprintf(w, "Could not enable tracing: %s\n", err) + // return + // } + // sleep(w, time.Duration(sec*float64(time.Second))) + // trace.Stop() } // Symbol looks up the program counters listed in the request, @@ -148,11 +145,11 @@ func Symbol(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain; charset=utf-8") // We have to read the whole POST body before - // writing any output. Buffer the output here. + // writing any output. Buffer the output here. var buf bytes.Buffer // We don't know how many symbols we have, but we - // do have symbol information. Pprof only cares whether + // do have symbol information. Pprof only cares whether // this number is 0 (no symbols available) or > 0. fmt.Fprintf(&buf, "num_symbols: 1\n") diff --git a/libgo/go/net/http/readrequest_test.go b/libgo/go/net/http/readrequest_test.go index 60e2be41d17..4bf646b0a63 100644 --- a/libgo/go/net/http/readrequest_test.go +++ b/libgo/go/net/http/readrequest_test.go @@ -1,4 +1,4 @@ -// Copyright 2010 The Go Authors. All rights reserved. +// Copyright 2010 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
@@ -380,6 +380,27 @@ var reqTests = []reqTest{ noTrailer, noError, }, + + // http2 client preface: + { + "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n", + &Request{ + Method: "PRI", + URL: &url.URL{ + Path: "*", + }, + Header: Header{}, + Proto: "HTTP/2.0", + ProtoMajor: 2, + ProtoMinor: 0, + RequestURI: "*", + ContentLength: -1, + Close: true, + }, + noBody, + noTrailer, + noError, + }, } func TestReadRequest(t *testing.T) { diff --git a/libgo/go/net/http/request.go b/libgo/go/net/http/request.go index 8cdab02af5a..dc5559282d0 100644 --- a/libgo/go/net/http/request.go +++ b/libgo/go/net/http/request.go @@ -9,6 +9,7 @@ package http import ( "bufio" "bytes" + "context" "crypto/tls" "encoding/base64" "errors" @@ -17,6 +18,7 @@ import ( "io/ioutil" "mime" "mime/multipart" + "net/http/httptrace" "net/textproto" "net/url" "strconv" @@ -247,7 +249,52 @@ type Request struct { // RoundTripper may support Cancel. // // For server requests, this field is not applicable. + // + // Deprecated: Use the Context and WithContext methods + // instead. If a Request's Cancel field and context are both + // set, it is undefined whether Cancel is respected. Cancel <-chan struct{} + + // Response is the redirect response which caused this request + // to be created. This field is only populated during client + // redirects. + Response *Response + + // ctx is either the client or server context. It should only + // be modified via copying the whole Request using WithContext. + // It is unexported to prevent people from using Context wrong + // and mutating the contexts held by callers of the same request. + ctx context.Context +} + +// Context returns the request's context. To change the context, use +// WithContext. +// +// The returned context is always non-nil; it defaults to the +// background context. +// +// For outgoing client requests, the context controls cancelation. +// +// For incoming server requests, the context is canceled when the +// ServeHTTP method returns. 
For its associated values, see +// ServerContextKey and LocalAddrContextKey. +func (r *Request) Context() context.Context { + if r.ctx != nil { + return r.ctx + } + return context.Background() +} + +// WithContext returns a shallow copy of r with its context changed +// to ctx. The provided ctx must be non-nil. +func (r *Request) WithContext(ctx context.Context) *Request { + if ctx == nil { + panic("nil context") + } + r2 := new(Request) + *r2 = *r + r2.ctx = ctx + return r2 } // ProtoAtLeast reports whether the HTTP protocol used @@ -279,8 +326,8 @@ func (r *Request) Cookie(name string) (*Cookie, error) { return nil, ErrNoCookie } -// AddCookie adds a cookie to the request. Per RFC 6265 section 5.4, -// AddCookie does not attach more than one Cookie header field. That +// AddCookie adds a cookie to the request. Per RFC 6265 section 5.4, +// AddCookie does not attach more than one Cookie header field. That // means all cookies, if any, are written into the same line, // separated by semicolon. func (r *Request) AddCookie(c *Cookie) { @@ -343,6 +390,12 @@ func (r *Request) multipartReader() (*multipart.Reader, error) { return multipart.NewReader(r.Body, boundary), nil } +// isH2Upgrade reports whether r represents the http2 "client preface" +// magic string. +func (r *Request) isH2Upgrade() bool { + return r.Method == "PRI" && len(r.Header) == 0 && r.URL.Path == "*" && r.Proto == "HTTP/2.0" +} + // Return value if nonempty, def otherwise. func valueOrDefault(value, def string) string { if value != "" { @@ -375,7 +428,7 @@ func (r *Request) Write(w io.Writer) error { } // WriteProxy is like Write but writes the request in the form -// expected by an HTTP proxy. In particular, WriteProxy writes the +// expected by an HTTP proxy. In particular, WriteProxy writes the // initial Request-URI line of the request with an absolute URI, per // section 5.1.2 of RFC 2616, including the scheme and host. 
// In either case, WriteProxy also writes a Host header, using @@ -390,7 +443,16 @@ var errMissingHost = errors.New("http: Request.Write on Request with no Host or // extraHeaders may be nil // waitForContinue may be nil -func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitForContinue func() bool) error { +func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, waitForContinue func() bool) (err error) { + trace := httptrace.ContextClientTrace(req.Context()) + if trace != nil && trace.WroteRequest != nil { + defer func() { + trace.WroteRequest(httptrace.WroteRequestInfo{ + Err: err, + }) + }() + } + // Find the target host. Prefer the Host: header, but if that // is not given, use the host from the request URL. // @@ -427,7 +489,7 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai w = bw } - _, err := fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri) + _, err = fmt.Fprintf(w, "%s %s HTTP/1.1\r\n", valueOrDefault(req.Method, "GET"), ruri) if err != nil { return err } @@ -478,6 +540,10 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai return err } + if trace != nil && trace.WroteHeaders != nil { + trace.WroteHeaders() + } + // Flush and wait for 100-continue if expected. if waitForContinue != nil { if bw, ok := w.(*bufio.Writer); ok { @@ -486,7 +552,9 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai return err } } - + if trace != nil && trace.Wait100Continue != nil { + trace.Wait100Continue() + } if !waitForContinue() { req.closeBody() return nil @@ -521,7 +589,7 @@ func cleanHost(in string) string { return in } -// removeZone removes IPv6 zone identifer from host. +// removeZone removes IPv6 zone identifier from host. 
// E.g., "[fe80::1%en0]:8080" to "[fe80::1]:8080" func removeZone(host string) string { if !strings.HasPrefix(host, "[") { @@ -613,6 +681,8 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { if !ok && body != nil { rc = ioutil.NopCloser(body) } + // The host's colon:port should be normalized. See Issue 14836. + u.Host = removeEmptyPort(u.Host) req := &Request{ Method: method, URL: u, @@ -704,7 +774,9 @@ func putTextprotoReader(r *textproto.Reader) { } // ReadRequest reads and parses an incoming request from b. -func ReadRequest(b *bufio.Reader) (req *Request, err error) { return readRequest(b, deleteHostHeader) } +func ReadRequest(b *bufio.Reader) (*Request, error) { + return readRequest(b, deleteHostHeader) +} // Constants for readRequest's deleteHostHeader parameter. const ( @@ -768,13 +840,13 @@ func readRequest(b *bufio.Reader, deleteHostHeader bool) (req *Request, err erro } req.Header = Header(mimeHeader) - // RFC2616: Must treat + // RFC 2616: Must treat // GET /index.html HTTP/1.1 // Host: www.google.com // and // GET http://www.google.com/index.html HTTP/1.1 // Host: doesntmatter - // the same. In the second case, any Host line is ignored. + // the same. In the second case, any Host line is ignored. req.Host = req.URL.Host if req.Host == "" { req.Host = req.Header.get("Host") @@ -792,6 +864,16 @@ func readRequest(b *bufio.Reader, deleteHostHeader bool) (req *Request, err erro return nil, err } + if req.isH2Upgrade() { + // Because it's neither chunked, nor declared: + req.ContentLength = -1 + + // We want to give handlers a chance to hijack the + // connection, but we need to prevent the Server from + // dealing with the connection further if it's not + // hijacked. 
Set Close to ensure that: + req.Close = true + } return req, nil } @@ -808,57 +890,56 @@ func MaxBytesReader(w ResponseWriter, r io.ReadCloser, n int64) io.ReadCloser { } type maxBytesReader struct { - w ResponseWriter - r io.ReadCloser // underlying reader - n int64 // max bytes remaining - stopped bool - sawEOF bool + w ResponseWriter + r io.ReadCloser // underlying reader + n int64 // max bytes remaining + err error // sticky error } func (l *maxBytesReader) tooLarge() (n int, err error) { - if !l.stopped { - l.stopped = true - if res, ok := l.w.(*response); ok { - res.requestTooLarge() - } - } - return 0, errors.New("http: request body too large") + l.err = errors.New("http: request body too large") + return 0, l.err } func (l *maxBytesReader) Read(p []byte) (n int, err error) { - toRead := l.n - if l.n == 0 { - if l.sawEOF { - return l.tooLarge() - } - // The underlying io.Reader may not return (0, io.EOF) - // at EOF if the requested size is 0, so read 1 byte - // instead. The io.Reader docs are a bit ambiguous - // about the return value of Read when 0 bytes are - // requested, and {bytes,strings}.Reader gets it wrong - // too (it returns (0, nil) even at EOF). - toRead = 1 + if l.err != nil { + return 0, l.err } - if int64(len(p)) > toRead { - p = p[:toRead] + if len(p) == 0 { + return 0, nil + } + // If they asked for a 32KB byte read but only 5 bytes are + // remaining, no need to read 32KB. 6 bytes will answer the + // question of the whether we hit the limit or go past it. + if int64(len(p)) > l.n+1 { + p = p[:l.n+1] } n, err = l.r.Read(p) - if err == io.EOF { - l.sawEOF = true - } - if l.n == 0 { - // If we had zero bytes to read remaining (but hadn't seen EOF) - // and we get a byte here, that means we went over our limit. 
- if n > 0 { - return l.tooLarge() - } - return 0, err + + if int64(n) <= l.n { + l.n -= int64(n) + l.err = err + return n, err } - l.n -= int64(n) - if l.n < 0 { - l.n = 0 + + n = int(l.n) + l.n = 0 + + // The server code and client code both use + // maxBytesReader. This "requestTooLarge" check is + // only used by the server code. To prevent binaries + // which only using the HTTP Client code (such as + // cmd/go) from also linking in the HTTP server, don't + // use a static type assertion to the server + // "*response" type. Check this interface instead: + type requestTooLarger interface { + requestTooLarge() } - return + if res, ok := l.w.(requestTooLarger); ok { + res.requestTooLarge() + } + l.err = errors.New("http: request body too large") + return n, l.err } func (l *maxBytesReader) Close() error { @@ -995,9 +1076,16 @@ func (r *Request) ParseMultipartForm(maxMemory int64) error { if err != nil { return err } + + if r.PostForm == nil { + r.PostForm = make(url.Values) + } for k, v := range f.Value { r.Form[k] = append(r.Form[k], v...) + // r.PostForm should also be populated. See Issue 9305. + r.PostForm[k] = append(r.PostForm[k], v...) } + r.MultipartForm = f return nil @@ -1086,92 +1174,3 @@ func (r *Request) isReplayable() bool { } return false } - -func validHostHeader(h string) bool { - // The latests spec is actually this: - // - // http://tools.ietf.org/html/rfc7230#section-5.4 - // Host = uri-host [ ":" port ] - // - // Where uri-host is: - // http://tools.ietf.org/html/rfc3986#section-3.2.2 - // - // But we're going to be much more lenient for now and just - // search for any byte that's not a valid byte in any of those - // expressions. - for i := 0; i < len(h); i++ { - if !validHostByte[h[i]] { - return false - } - } - return true -} - -// See the validHostHeader comment. 
-var validHostByte = [256]bool{ - '0': true, '1': true, '2': true, '3': true, '4': true, '5': true, '6': true, '7': true, - '8': true, '9': true, - - 'a': true, 'b': true, 'c': true, 'd': true, 'e': true, 'f': true, 'g': true, 'h': true, - 'i': true, 'j': true, 'k': true, 'l': true, 'm': true, 'n': true, 'o': true, 'p': true, - 'q': true, 'r': true, 's': true, 't': true, 'u': true, 'v': true, 'w': true, 'x': true, - 'y': true, 'z': true, - - 'A': true, 'B': true, 'C': true, 'D': true, 'E': true, 'F': true, 'G': true, 'H': true, - 'I': true, 'J': true, 'K': true, 'L': true, 'M': true, 'N': true, 'O': true, 'P': true, - 'Q': true, 'R': true, 'S': true, 'T': true, 'U': true, 'V': true, 'W': true, 'X': true, - 'Y': true, 'Z': true, - - '!': true, // sub-delims - '$': true, // sub-delims - '%': true, // pct-encoded (and used in IPv6 zones) - '&': true, // sub-delims - '(': true, // sub-delims - ')': true, // sub-delims - '*': true, // sub-delims - '+': true, // sub-delims - ',': true, // sub-delims - '-': true, // unreserved - '.': true, // unreserved - ':': true, // IPv6address + Host expression's optional port - ';': true, // sub-delims - '=': true, // sub-delims - '[': true, - '\'': true, // sub-delims - ']': true, - '_': true, // unreserved - '~': true, // unreserved -} - -func validHeaderName(v string) bool { - if len(v) == 0 { - return false - } - return strings.IndexFunc(v, isNotToken) == -1 -} - -// validHeaderValue reports whether v is a valid "field-value" according to -// http://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 : -// -// message-header = field-name ":" [ field-value ] -// field-value = *( field-content | LWS ) -// field-content = <the OCTETs making up the field-value -// and consisting of either *TEXT or combinations -// of token, separators, and quoted-string> -// -// http://www.w3.org/Protocols/rfc2616/rfc2616-sec2.html#sec2.2 : -// -// TEXT = <any OCTET except CTLs, -// but including LWS> -// LWS = [CRLF] 1*( SP | HT ) -// CTL = <any 
US-ASCII control character -// (octets 0 - 31) and DEL (127)> -func validHeaderValue(v string) bool { - for i := 0; i < len(v); i++ { - b := v[i] - if isCTL(b) && !isLWS(b) { - return false - } - } - return true -} diff --git a/libgo/go/net/http/request_test.go b/libgo/go/net/http/request_test.go index 0ecdf85a563..a4c88c02915 100644 --- a/libgo/go/net/http/request_test.go +++ b/libgo/go/net/http/request_test.go @@ -99,7 +99,7 @@ type parseContentTypeTest struct { var parseContentTypeTests = []parseContentTypeTest{ {false, stringMap{"Content-Type": {"text/plain"}}}, - // Empty content type is legal - shoult be treated as + // Empty content type is legal - should be treated as // application/octet-stream (RFC 2616, section 7.2.1) {false, stringMap{}}, {true, stringMap{"Content-Type": {"text/plain; boundary="}}}, @@ -158,6 +158,68 @@ func TestMultipartReader(t *testing.T) { } } +// Issue 9305: ParseMultipartForm should populate PostForm too +func TestParseMultipartFormPopulatesPostForm(t *testing.T) { + postData := + `--xxx +Content-Disposition: form-data; name="field1" + +value1 +--xxx +Content-Disposition: form-data; name="field2" + +value2 +--xxx +Content-Disposition: form-data; name="file"; filename="file" +Content-Type: application/octet-stream +Content-Transfer-Encoding: binary + +binary data +--xxx-- +` + req := &Request{ + Method: "POST", + Header: Header{"Content-Type": {`multipart/form-data; boundary=xxx`}}, + Body: ioutil.NopCloser(strings.NewReader(postData)), + } + + initialFormItems := map[string]string{ + "language": "Go", + "name": "gopher", + "skill": "go-ing", + "field2": "initial-value2", + } + + req.Form = make(url.Values) + for k, v := range initialFormItems { + req.Form.Add(k, v) + } + + err := req.ParseMultipartForm(10000) + if err != nil { + t.Fatalf("unexpected multipart error %v", err) + } + + wantForm := url.Values{ + "language": []string{"Go"}, + "name": []string{"gopher"}, + "skill": []string{"go-ing"}, + "field1": []string{"value1"}, + 
"field2": []string{"initial-value2", "value2"}, + } + if !reflect.DeepEqual(req.Form, wantForm) { + t.Fatalf("req.Form = %v, want %v", req.Form, wantForm) + } + + wantPostForm := url.Values{ + "field1": []string{"value1"}, + "field2": []string{"value2"}, + } + if !reflect.DeepEqual(req.PostForm, wantPostForm) { + t.Fatalf("req.PostForm = %v, want %v", req.PostForm, wantPostForm) + } +} + func TestParseMultipartForm(t *testing.T) { req := &Request{ Method: "POST", @@ -336,11 +398,13 @@ var newRequestHostTests = []struct { {"http://192.168.0.1/", "192.168.0.1"}, {"http://192.168.0.1:8080/", "192.168.0.1:8080"}, + {"http://192.168.0.1:/", "192.168.0.1"}, {"http://[fe80::1]/", "[fe80::1]"}, {"http://[fe80::1]:8080/", "[fe80::1]:8080"}, {"http://[fe80::1%25en0]/", "[fe80::1%en0]"}, {"http://[fe80::1%25en0]:8080/", "[fe80::1%en0]:8080"}, + {"http://[fe80::1%25en0]:/", "[fe80::1%en0]"}, } func TestNewRequestHost(t *testing.T) { @@ -615,6 +679,46 @@ func TestIssue10884_MaxBytesEOF(t *testing.T) { } } +// Issue 14981: MaxBytesReader's return error wasn't sticky. It +// doesn't technically need to be, but people expected it to be. +func TestMaxBytesReaderStickyError(t *testing.T) { + isSticky := func(r io.Reader) error { + var log bytes.Buffer + buf := make([]byte, 1000) + var firstErr error + for { + n, err := r.Read(buf) + fmt.Fprintf(&log, "Read(%d) = %d, %v\n", len(buf), n, err) + if err == nil { + continue + } + if firstErr == nil { + firstErr = err + continue + } + if !reflect.DeepEqual(err, firstErr) { + return fmt.Errorf("non-sticky error. got log:\n%s", log.Bytes()) + } + t.Logf("Got log: %s", log.Bytes()) + return nil + } + } + tests := [...]struct { + readable int + limit int64 + }{ + 0: {99, 100}, + 1: {100, 100}, + 2: {101, 100}, + } + for i, tt := range tests { + rc := MaxBytesReader(nil, ioutil.NopCloser(bytes.NewReader(make([]byte, tt.readable))), tt.limit) + if err := isSticky(rc); err != nil { + t.Errorf("%d. 
error: %v", i, err) + } + } +} + func testMissingFile(t *testing.T, req *Request) { f, fh, err := req.FormFile("missing") if f != nil { diff --git a/libgo/go/net/http/requestwrite_test.go b/libgo/go/net/http/requestwrite_test.go index cfb95b0a800..2545f6f4c22 100644 --- a/libgo/go/net/http/requestwrite_test.go +++ b/libgo/go/net/http/requestwrite_test.go @@ -1,4 +1,4 @@ -// Copyright 2010 The Go Authors. All rights reserved. +// Copyright 2010 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. @@ -573,7 +573,7 @@ func TestRequestWriteClosesBody(t *testing.T) { "Transfer-Encoding: chunked\r\n\r\n" + // TODO: currently we don't buffer before chunking, so we get a // single "m" chunk before the other chunks, as this was the 1-byte - // read from our MultiReader where we stiched the Body back together + // read from our MultiReader where we stitched the Body back together // after sniffing whether the Body was 0 bytes or not. chunk("m") + chunk("y body") + @@ -604,7 +604,7 @@ func TestRequestWriteError(t *testing.T) { failAfter, writeCount := 0, 0 errFail := errors.New("fake write failure") - // w is the buffered io.Writer to write the request to. It + // w is the buffered io.Writer to write the request to. It // fails exactly once on its Nth Write call, as controlled by // failAfter. It also tracks the number of calls in // writeCount. diff --git a/libgo/go/net/http/response.go b/libgo/go/net/http/response.go index c424f61cd00..5450d50c3ce 100644 --- a/libgo/go/net/http/response.go +++ b/libgo/go/net/http/response.go @@ -11,6 +11,7 @@ import ( "bytes" "crypto/tls" "errors" + "fmt" "io" "net/textproto" "net/url" @@ -33,7 +34,7 @@ type Response struct { ProtoMajor int // e.g. 1 ProtoMinor int // e.g. 0 - // Header maps header keys to values. If the response had multiple + // Header maps header keys to values. 
If the response had multiple // headers with the same key, they may be concatenated, with comma // delimiters. (Section 4.2 of RFC 2616 requires that multiple headers // be semantically equivalent to a comma-delimited sequence.) Values @@ -57,8 +58,8 @@ type Response struct { // with a "chunked" Transfer-Encoding. Body io.ReadCloser - // ContentLength records the length of the associated content. The - // value -1 indicates that the length is unknown. Unless Request.Method + // ContentLength records the length of the associated content. The + // value -1 indicates that the length is unknown. Unless Request.Method // is "HEAD", values >= 0 indicate that the given number of bytes may // be read from Body. ContentLength int64 @@ -68,10 +69,19 @@ type Response struct { TransferEncoding []string // Close records whether the header directed that the connection be - // closed after reading Body. The value is advice for clients: neither + // closed after reading Body. The value is advice for clients: neither // ReadResponse nor Response.Write ever closes a connection. Close bool + // Uncompressed reports whether the response was sent compressed but + // was decompressed by the http package. When true, reading from + // Body yields the uncompressed content instead of the compressed + // content actually set from the server, ContentLength is set to -1, + // and the "Content-Length" and "Content-Encoding" fields are deleted + // from the responseHeader. To get the original response from + // the server, set Transport.DisableCompression to true. + Uncompressed bool + // Trailer maps trailer keys to values in the same // format as Header. // @@ -86,7 +96,7 @@ type Response struct { // any trailer values sent by the server. Trailer Header - // The Request that was sent to obtain this Response. + // Request is the request that was sent to obtain this Response. // Request's Body is nil (having already been consumed). // This is only populated for Client requests. 
Request *Request @@ -108,8 +118,8 @@ func (r *Response) Cookies() []*Cookie { var ErrNoLocation = errors.New("http: no Location header in response") // Location returns the URL of the response's "Location" header, -// if present. Relative redirects are resolved relative to -// the Response's Request. ErrNoLocation is returned if no +// if present. Relative redirects are resolved relative to +// the Response's Request. ErrNoLocation is returned if no // Location header is present. func (r *Response) Location() (*url.URL, error) { lv := r.Header.Get("Location") @@ -184,7 +194,7 @@ func ReadResponse(r *bufio.Reader, req *Request) (*Response, error) { return resp, nil } -// RFC2616: Should treat +// RFC 2616: Should treat // Pragma: no-cache // like // Cache-Control: no-cache @@ -203,7 +213,7 @@ func (r *Response) ProtoAtLeast(major, minor int) bool { r.ProtoMajor == major && r.ProtoMinor >= minor } -// Write writes r to w in the HTTP/1.n server response format, +// Write writes r to w in the HTTP/1.x server response format, // including the status line, headers, body, and optional trailer. // // This method consults the following fields of the response r: @@ -228,11 +238,13 @@ func (r *Response) Write(w io.Writer) error { if !ok { text = "status code " + strconv.Itoa(r.StatusCode) } + } else { + // Just to reduce stutter, if user set r.Status to "200 OK" and StatusCode to 200. + // Not important. 
+ text = strings.TrimPrefix(text, strconv.Itoa(r.StatusCode)+" ") } - protoMajor, protoMinor := strconv.Itoa(r.ProtoMajor), strconv.Itoa(r.ProtoMinor) - statusCode := strconv.Itoa(r.StatusCode) + " " - text = strings.TrimPrefix(text, statusCode) - if _, err := io.WriteString(w, "HTTP/"+protoMajor+"."+protoMinor+" "+statusCode+text+"\r\n"); err != nil { + + if _, err := fmt.Fprintf(w, "HTTP/%d.%d %03d %s\r\n", r.ProtoMajor, r.ProtoMinor, r.StatusCode, text); err != nil { return err } @@ -265,7 +277,7 @@ func (r *Response) Write(w io.Writer) error { // content-length, the only way to do that is the old HTTP/1.0 // way, by noting the EOF with a connection close, so we need // to set Close. - if r1.ContentLength == -1 && !r1.Close && r1.ProtoAtLeast(1, 1) && !chunked(r1.TransferEncoding) { + if r1.ContentLength == -1 && !r1.Close && r1.ProtoAtLeast(1, 1) && !chunked(r1.TransferEncoding) && !r1.Uncompressed { r1.Close = true } diff --git a/libgo/go/net/http/response_test.go b/libgo/go/net/http/response_test.go index d8a53400cf2..126da927355 100644 --- a/libgo/go/net/http/response_test.go +++ b/libgo/go/net/http/response_test.go @@ -1,4 +1,4 @@ -// Copyright 2010 The Go Authors. All rights reserved. +// Copyright 2010 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
@@ -10,6 +10,7 @@ import ( "compress/gzip" "crypto/rand" "fmt" + "go/ast" "io" "io/ioutil" "net/http/internal" @@ -505,6 +506,32 @@ some body`, "Body here\n", }, + + { + "HTTP/1.1 200 OK\r\n" + + "Content-Encoding: gzip\r\n" + + "Content-Length: 23\r\n" + + "Connection: keep-alive\r\n" + + "Keep-Alive: timeout=7200\r\n\r\n" + + "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00", + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Request: dummyReq("GET"), + Header: Header{ + "Content-Length": {"23"}, + "Content-Encoding": {"gzip"}, + "Connection": {"keep-alive"}, + "Keep-Alive": {"timeout=7200"}, + }, + Close: false, + ContentLength: 23, + }, + "\x1f\x8b\b\x00\x00\x00\x00\x00\x00\x00s\xf3\xf7\a\x00\xab'\xd4\x1a\x03\x00\x00\x00", + }, } // tests successful calls to ReadResponse, and inspects the returned Response. @@ -656,10 +683,14 @@ func diff(t *testing.T, prefix string, have, want interface{}) { t.Errorf("%s: type mismatch %v want %v", prefix, hv.Type(), wv.Type()) } for i := 0; i < hv.NumField(); i++ { + name := hv.Type().Field(i).Name + if !ast.IsExported(name) { + continue + } hf := hv.Field(i).Interface() wf := wv.Field(i).Interface() if !reflect.DeepEqual(hf, wf) { - t.Errorf("%s: %s = %v want %v", prefix, hv.Type().Field(i).Name, hf, wf) + t.Errorf("%s: %s = %v want %v", prefix, name, hf, wf) } } } diff --git a/libgo/go/net/http/responsewrite_test.go b/libgo/go/net/http/responsewrite_test.go index 5b8d47ab581..90f6767d96b 100644 --- a/libgo/go/net/http/responsewrite_test.go +++ b/libgo/go/net/http/responsewrite_test.go @@ -1,4 +1,4 @@ -// Copyright 2010 The Go Authors. All rights reserved. +// Copyright 2010 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. 
@@ -222,6 +222,39 @@ func TestResponseWrite(t *testing.T) { }, "HTTP/1.1 200 OK\r\nConnection: close\r\n\r\nabcdef", }, + + // Status code under 100 should be zero-padded to + // three digits. Still bogus, but less bogus. (be + // consistent with generating three digits, since the + // Transport requires it) + { + Response{ + StatusCode: 7, + Status: "license to violate specs", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{}, + Body: nil, + }, + + "HTTP/1.0 007 license to violate specs\r\nContent-Length: 0\r\n\r\n", + }, + + // No stutter. + { + Response{ + StatusCode: 123, + Status: "123 Sesame Street", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{}, + Body: nil, + }, + + "HTTP/1.0 123 Sesame Street\r\nContent-Length: 0\r\n\r\n", + }, } for i := range respWriteTests { diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go index 384b453ce0a..139ce3eafc7 100644 --- a/libgo/go/net/http/serve_test.go +++ b/libgo/go/net/http/serve_test.go @@ -9,7 +9,10 @@ package http_test import ( "bufio" "bytes" + "compress/gzip" + "context" "crypto/tls" + "encoding/json" "errors" "fmt" "internal/testenv" @@ -617,7 +620,7 @@ func TestIdentityResponse(t *testing.T) { defer ts.Close() // Note: this relies on the assumption (which is true) that - // Get sends HTTP/1.1 or greater requests. Otherwise the + // Get sends HTTP/1.1 or greater requests. Otherwise the // server wouldn't have the choice to send back chunked // responses. 
for _, te := range []string{"", "identity"} { @@ -713,6 +716,31 @@ func testTCPConnectionCloses(t *testing.T, req string, h Handler) { } } +func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) { + defer afterTest(t) + ts := httptest.NewServer(handler) + defer ts.Close() + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + br := bufio.NewReader(conn) + for i := 0; i < 2; i++ { + if _, err := io.WriteString(conn, req); err != nil { + t.Fatal(err) + } + res, err := ReadResponse(br, nil) + if err != nil { + t.Fatalf("res %d: %v", i+1, err) + } + if _, err := io.Copy(ioutil.Discard, res.Body); err != nil { + t.Fatalf("res %d body copy: %v", i+1, err) + } + res.Body.Close() + } +} + // TestServeHTTP10Close verifies that HTTP/1.0 requests won't be kept alive. func TestServeHTTP10Close(t *testing.T) { testTCPConnectionCloses(t, "GET / HTTP/1.0\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { @@ -741,6 +769,61 @@ func TestHandlersCanSetConnectionClose10(t *testing.T) { })) } +func TestHTTP2UpgradeClosesConnection(t *testing.T) { + testTCPConnectionCloses(t, "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + // Nothing. (if not hijacked, the server should close the connection + // afterwards) + })) +} + +func send204(w ResponseWriter, r *Request) { w.WriteHeader(204) } +func send304(w ResponseWriter, r *Request) { w.WriteHeader(304) } + +// Issue 15647: 204 responses can't have bodies, so HTTP/1.0 keep-alive conns should stay open. 
+func TestHTTP10KeepAlive204Response(t *testing.T) { + testTCPConnectionStaysOpen(t, "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", HandlerFunc(send204)) +} + +func TestHTTP11KeepAlive204Response(t *testing.T) { + testTCPConnectionStaysOpen(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n", HandlerFunc(send204)) +} + +func TestHTTP10KeepAlive304Response(t *testing.T) { + testTCPConnectionStaysOpen(t, + "GET / HTTP/1.0\r\nConnection: keep-alive\r\nIf-Modified-Since: Mon, 02 Jan 2006 15:04:05 GMT\r\n\r\n", + HandlerFunc(send304)) +} + +// Issue 15703 +func TestKeepAliveFinalChunkWithEOF(t *testing.T) { + defer afterTest(t) + cst := newClientServerTest(t, false /* h1 */, HandlerFunc(func(w ResponseWriter, r *Request) { + w.(Flusher).Flush() // force chunked encoding + w.Write([]byte("{\"Addr\": \"" + r.RemoteAddr + "\"}")) + })) + defer cst.close() + type data struct { + Addr string + } + var addrs [2]data + for i := range addrs { + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + if err := json.NewDecoder(res.Body).Decode(&addrs[i]); err != nil { + t.Fatal(err) + } + if addrs[i].Addr == "" { + t.Fatal("no address") + } + res.Body.Close() + } + if addrs[0] != addrs[1] { + t.Fatalf("connection not reused") + } +} + func TestSetsRemoteAddr_h1(t *testing.T) { testSetsRemoteAddr(t, h1Mode) } func TestSetsRemoteAddr_h2(t *testing.T) { testSetsRemoteAddr(t, h2Mode) } @@ -984,7 +1067,7 @@ func TestTLSServer(t *testing.T) { defer ts.Close() // Connect an idle TCP connection to this server before we run - // our real tests. This idle connection used to block forever + // our real tests. This idle connection used to block forever // in the TLS handshake, preventing future connections from // being accepted. It may prevent future accidental blocking // in newConn. 
@@ -1024,11 +1107,44 @@ func TestTLSServer(t *testing.T) { }) } -func TestAutomaticHTTP2_Serve(t *testing.T) { +// Issue 15908 +func TestAutomaticHTTP2_Serve_NoTLSConfig(t *testing.T) { + testAutomaticHTTP2_Serve(t, nil, true) +} + +func TestAutomaticHTTP2_Serve_NonH2TLSConfig(t *testing.T) { + testAutomaticHTTP2_Serve(t, &tls.Config{}, false) +} + +func TestAutomaticHTTP2_Serve_H2TLSConfig(t *testing.T) { + testAutomaticHTTP2_Serve(t, &tls.Config{NextProtos: []string{"h2"}}, true) +} + +func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) { defer afterTest(t) ln := newLocalListener(t) ln.Close() // immediately (not a defer!) var s Server + s.TLSConfig = tlsConf + if err := s.Serve(ln); err == nil { + t.Fatal("expected an error") + } + gotH2 := s.TLSNextProto["h2"] != nil + if gotH2 != wantH2 { + t.Errorf("http2 configured = %v; want %v", gotH2, wantH2) + } +} + +func TestAutomaticHTTP2_Serve_WithTLSConfig(t *testing.T) { + defer afterTest(t) + ln := newLocalListener(t) + ln.Close() // immediately (not a defer!) + var s Server + // Set the TLSConfig. In reality, this would be the + // *tls.Config given to tls.NewListener. + s.TLSConfig = &tls.Config{ + NextProtos: []string{"h2"}, + } if err := s.Serve(ln); err == nil { t.Fatal("expected an error") } @@ -1888,6 +2004,52 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { } } +// Issue 14568. +func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { + if testing.Short() { + t.Skip("skipping sleeping test in -short mode") + } + defer afterTest(t) + var handler HandlerFunc = func(w ResponseWriter, _ *Request) { + w.WriteHeader(StatusNoContent) + } + timeout := 300 * time.Millisecond + ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) + defer ts.Close() + // Issue was caused by the timeout handler starting the timer when + // was created, not when the request. So wait for more than the timeout + // to ensure that's not the case. 
+ time.Sleep(2 * timeout) + res, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != StatusNoContent { + t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusNoContent) + } +} + +// https://golang.org/issue/15948 +func TestTimeoutHandlerEmptyResponse(t *testing.T) { + defer afterTest(t) + var handler HandlerFunc = func(w ResponseWriter, _ *Request) { + // No response. + } + timeout := 300 * time.Millisecond + ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) + defer ts.Close() + + res, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != StatusOK { + t.Errorf("got res.StatusCode %d, want %v", res.StatusCode, StatusOK) + } +} + // Verifies we don't path.Clean() on the wrong parts in redirects. func TestRedirectMunging(t *testing.T) { req, _ := NewRequest("GET", "http://example.com/", nil) @@ -2027,7 +2189,7 @@ func TestHandlerPanicWithHijack(t *testing.T) { func testHandlerPanic(t *testing.T, withHijack, h2 bool, panicValue interface{}) { defer afterTest(t) // Unlike the other tests that set the log output to ioutil.Discard - // to quiet the output, this test uses a pipe. The pipe serves three + // to quiet the output, this test uses a pipe. The pipe serves three // purposes: // // 1) The log.Print from the http server (generated by the caught @@ -2060,7 +2222,7 @@ func testHandlerPanic(t *testing.T, withHijack, h2 bool, panicValue interface{}) defer cst.close() // Do a blocking read on the log output pipe so its logging - // doesn't bleed into the next test. But wait only 5 seconds + // doesn't bleed into the next test. But wait only 5 seconds // for it. 
done := make(chan bool, 1) go func() { @@ -2205,10 +2367,10 @@ func testRequestBodyLimit(t *testing.T, h2 bool) { nWritten := new(int64) req, _ := NewRequest("POST", cst.ts.URL, io.LimitReader(countReader{neverEnding('a'), nWritten}, limit*200)) - // Send the POST, but don't care it succeeds or not. The + // Send the POST, but don't care it succeeds or not. The // remote side is going to reply and then close the TCP // connection, and HTTP doesn't really define if that's - // allowed or not. Some HTTP clients will get the response + // allowed or not. Some HTTP clients will get the response // and some (like ours, currently) will complain that the // request write failed, without reading the response. // @@ -2650,7 +2812,7 @@ func TestOptions(t *testing.T) { } // Tests regarding the ordering of Write, WriteHeader, Header, and -// Flush calls. In Go 1.0, rw.WriteHeader immediately flushed the +// Flush calls. In Go 1.0, rw.WriteHeader immediately flushed the // (*response).header to the wire. In Go 1.1, the actual wire flush is // delayed, so we could maybe tack on a Content-Length and better // Content-Type after we see more (or all) of the output. To preserve @@ -3107,7 +3269,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { const bodySize = 1 << 20 - // errorf is like t.Errorf, but also writes to println. When + // errorf is like t.Errorf, but also writes to println. When // this test fails, it hangs. This helps debugging and I've // added this enough times "temporarily". It now gets added // full time. 
@@ -3829,6 +3991,8 @@ func TestServerValidatesHostHeader(t *testing.T) { host string want int }{ + {"HTTP/0.9", "", 400}, + {"HTTP/1.1", "", 400}, {"HTTP/1.1", "Host: \r\n", 200}, {"HTTP/1.1", "Host: 1.2.3.4\r\n", 200}, @@ -3851,10 +4015,22 @@ func TestServerValidatesHostHeader(t *testing.T) { {"HTTP/1.0", "", 200}, {"HTTP/1.0", "Host: first\r\nHost: second\r\n", 400}, {"HTTP/1.0", "Host: \xff\r\n", 400}, + + // Make an exception for HTTP upgrade requests: + {"PRI * HTTP/2.0", "", 200}, + + // But not other HTTP/2 stuff: + {"PRI / HTTP/2.0", "", 400}, + {"GET / HTTP/2.0", "", 400}, + {"GET / HTTP/3.0", "", 400}, } for _, tt := range tests { conn := &testConn{closec: make(chan bool, 1)} - io.WriteString(&conn.readBuf, "GET / "+tt.proto+"\r\n"+tt.host+"\r\n") + methodTarget := "GET / " + if !strings.HasPrefix(tt.proto, "HTTP/") { + methodTarget = "" + } + io.WriteString(&conn.readBuf, methodTarget+tt.proto+"\r\n"+tt.host+"\r\n") ln := &oneConnListener{conn} go Serve(ln, HandlerFunc(func(ResponseWriter, *Request) {})) @@ -3870,6 +4046,45 @@ func TestServerValidatesHostHeader(t *testing.T) { } } +func TestServerHandlersCanHandleH2PRI(t *testing.T) { + const upgradeResponse = "upgrade here" + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + conn, br, err := w.(Hijacker).Hijack() + defer conn.Close() + if r.Method != "PRI" || r.RequestURI != "*" { + t.Errorf("Got method/target %q %q; want PRI *", r.Method, r.RequestURI) + return + } + if !r.Close { + t.Errorf("Request.Close = true; want false") + } + const want = "SM\r\n\r\n" + buf := make([]byte, len(want)) + n, err := io.ReadFull(br, buf) + if err != nil || string(buf[:n]) != want { + t.Errorf("Read = %v, %v (%q), want %q", n, err, buf[:n], want) + return + } + io.WriteString(conn, upgradeResponse) + })) + defer ts.Close() + + c, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer c.Close() + io.WriteString(c, "PRI 
* HTTP/2.0\r\n\r\nSM\r\n\r\n") + slurp, err := ioutil.ReadAll(c) + if err != nil { + t.Fatal(err) + } + if string(slurp) != upgradeResponse { + t.Errorf("Handler response = %q; want %q", slurp, upgradeResponse) + } +} + // Test that we validate the valid bytes in HTTP/1 headers. // Issue 11207. func TestServerValidatesHeaders(t *testing.T) { @@ -3910,6 +4125,140 @@ func TestServerValidatesHeaders(t *testing.T) { } } +func TestServerRequestContextCancel_ServeHTTPDone_h1(t *testing.T) { + testServerRequestContextCancel_ServeHTTPDone(t, h1Mode) +} +func TestServerRequestContextCancel_ServeHTTPDone_h2(t *testing.T) { + testServerRequestContextCancel_ServeHTTPDone(t, h2Mode) +} +func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) { + defer afterTest(t) + ctxc := make(chan context.Context, 1) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + ctx := r.Context() + select { + case <-ctx.Done(): + t.Error("should not be Done in ServeHTTP") + default: + } + ctxc <- ctx + })) + defer cst.close() + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + ctx := <-ctxc + select { + case <-ctx.Done(): + default: + t.Error("context should be done after ServeHTTP completes") + } +} + +func TestServerRequestContextCancel_ConnClose(t *testing.T) { + // Currently the context is not canceled when the connection + // is closed because we're not reading from the connection + // until after ServeHTTP for the previous handler is done. + // Until the server code is modified to always be in a read + // (Issue 15224), this test doesn't work yet. 
+ t.Skip("TODO(bradfitz): this test doesn't yet work; golang.org/issue/15224") + defer afterTest(t) + inHandler := make(chan struct{}) + handlerDone := make(chan struct{}) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + close(inHandler) + select { + case <-r.Context().Done(): + case <-time.After(3 * time.Second): + t.Errorf("timeout waiting for context to be done") + } + close(handlerDone) + })) + defer ts.Close() + c, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer c.Close() + io.WriteString(c, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n") + select { + case <-inHandler: + case <-time.After(3 * time.Second): + t.Fatalf("timeout waiting to see ServeHTTP get called") + } + c.Close() // this should trigger the context being done + + select { + case <-handlerDone: + case <-time.After(3 * time.Second): + t.Fatalf("timeout waiting to see ServeHTTP exit") + } +} + +func TestServerContext_ServerContextKey_h1(t *testing.T) { + testServerContext_ServerContextKey(t, h1Mode) +} +func TestServerContext_ServerContextKey_h2(t *testing.T) { + testServerContext_ServerContextKey(t, h2Mode) +} +func testServerContext_ServerContextKey(t *testing.T, h2 bool) { + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + ctx := r.Context() + got := ctx.Value(ServerContextKey) + if _, ok := got.(*Server); !ok { + t.Errorf("context value = %T; want *http.Server", got) + } + + got = ctx.Value(LocalAddrContextKey) + if addr, ok := got.(net.Addr); !ok { + t.Errorf("local addr value = %T; want net.Addr", got) + } else if fmt.Sprint(addr) != r.Host { + t.Errorf("local addr = %v; want %v", addr, r.Host) + } + })) + defer cst.close() + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() +} + +// https://golang.org/issue/15960 +func TestHandlerSetTransferEncodingChunked(t *testing.T) { + defer afterTest(t) + ht := 
newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Transfer-Encoding", "chunked") + w.Write([]byte("hello")) + })) + resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo") + const hdr = "Transfer-Encoding: chunked" + if n := strings.Count(resp, hdr); n != 1 { + t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp) + } +} + +// https://golang.org/issue/16063 +func TestHandlerSetTransferEncodingGzip(t *testing.T) { + defer afterTest(t) + ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Transfer-Encoding", "gzip") + gz := gzip.NewWriter(w) + gz.Write([]byte("hello")) + gz.Close() + })) + resp := ht.rawResponse("GET / HTTP/1.1\nHost: foo") + for _, v := range []string{"gzip", "chunked"} { + hdr := "Transfer-Encoding: " + v + if n := strings.Count(resp, hdr); n != 1 { + t.Errorf("want 1 occurrence of %q in response, got %v\nresponse: %v", hdr, n, resp) + } + } +} + func BenchmarkClientServer(b *testing.B) { b.ReportAllocs() b.StopTimer() @@ -4100,7 +4449,7 @@ func BenchmarkClient(b *testing.B) { // Wait for the server process to respond. url := "http://localhost:" + port + "/" for i := 0; i < 100; i++ { - time.Sleep(50 * time.Millisecond) + time.Sleep(100 * time.Millisecond) if _, err := getNoBody(url); err == nil { break } @@ -4121,7 +4470,7 @@ func BenchmarkClient(b *testing.B) { if err != nil { b.Fatalf("ReadAll: %v", err) } - if bytes.Compare(body, data) != 0 { + if !bytes.Equal(body, data) { b.Fatalf("Got body: %q", body) } } diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go index 5e3b6084ae3..7b2b4b2f423 100644 --- a/libgo/go/net/http/server.go +++ b/libgo/go/net/http/server.go @@ -2,13 +2,14 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// HTTP server. See RFC 2616. +// HTTP server. See RFC 2616. 
 package http
 
 import (
 	"bufio"
 	"bytes"
+	"context"
 	"crypto/tls"
 	"errors"
 	"fmt"
@@ -26,14 +27,30 @@ import (
 	"sync"
 	"sync/atomic"
 	"time"
+
+	"golang_org/x/net/lex/httplex"
 )
 
-// Errors introduced by the HTTP server.
+// Errors used by the HTTP server.
 var (
-	ErrWriteAfterFlush = errors.New("Conn.Write called after Flush")
-	ErrBodyNotAllowed  = errors.New("http: request method or response status code does not allow body")
-	ErrHijacked        = errors.New("Conn has been hijacked")
-	ErrContentLength   = errors.New("Conn.Write wrote more than the declared Content-Length")
+	// ErrBodyNotAllowed is returned by ResponseWriter.Write calls
+	// when the HTTP method or response code does not permit a
+	// body.
+	ErrBodyNotAllowed = errors.New("http: request method or response status code does not allow body")
+
+	// ErrHijacked is returned by ResponseWriter.Write calls when
+	// the underlying connection has been hijacked using the
+	// Hijacker interface.
+	ErrHijacked = errors.New("http: connection has been hijacked")
+
+	// ErrContentLength is returned by ResponseWriter.Write calls
+	// when a Handler set a Content-Length response header with a
+	// declared size and then attempted to write more bytes than
+	// declared.
+	ErrContentLength = errors.New("http: wrote more than the declared Content-Length")
+
+	// Deprecated: ErrWriteAfterFlush is no longer used.
+	ErrWriteAfterFlush = errors.New("unused")
 )
 
 // A Handler responds to an HTTP request.
@@ -50,6 +67,9 @@ var (
 // ResponseWriter. Cautious handlers should read the Request.Body
 // first, and then reply.
 //
+// Except for reading the body, handlers should not modify the
+// provided Request.
+//
 // If ServeHTTP panics, the server (the caller of ServeHTTP) assumes
 // that the effect of the panic was isolated to the active request.
// It recovers the panic, logs a stack trace to the server error log, @@ -73,10 +93,24 @@ type ResponseWriter interface { Header() Header // Write writes the data to the connection as part of an HTTP reply. - // If WriteHeader has not yet been called, Write calls WriteHeader(http.StatusOK) - // before writing the data. If the Header does not contain a - // Content-Type line, Write adds a Content-Type set to the result of passing - // the initial 512 bytes of written data to DetectContentType. + // + // If WriteHeader has not yet been called, Write calls + // WriteHeader(http.StatusOK) before writing the data. If the Header + // does not contain a Content-Type line, Write adds a Content-Type set + // to the result of passing the initial 512 bytes of written data to + // DetectContentType. + // + // Depending on the HTTP protocol version and the client, calling + // Write or WriteHeader may prevent future reads on the + // Request.Body. For HTTP/1.x requests, handlers should read any + // needed request body data before writing the response. Once the + // headers have been flushed (due to either an explicit Flusher.Flush + // call or writing enough data to trigger a flush), the request body + // may be unavailable. For HTTP/2 requests, the Go HTTP server permits + // handlers to continue to read the request body while concurrently + // writing the response. However, such behavior may not be supported + // by all HTTP/2 clients. Handlers should read before writing if + // possible to maximize compatibility. Write([]byte) (int, error) // WriteHeader sends an HTTP response header with status code. @@ -90,6 +124,10 @@ type ResponseWriter interface { // The Flusher interface is implemented by ResponseWriters that allow // an HTTP handler to flush buffered data to the client. // +// The default HTTP/1.x and HTTP/2 ResponseWriter implementations +// support Flusher, but ResponseWriter wrappers may not. Handlers +// should always test for this ability at runtime. 
+//
 // Note that even for ResponseWriters that support Flush,
 // if the client is connected through an HTTP proxy,
 // the buffered data may not reach the client until the response
@@ -101,6 +139,11 @@ type Flusher interface {
 
 // The Hijacker interface is implemented by ResponseWriters that allow
 // an HTTP handler to take over the connection.
+//
+// The default ResponseWriter for HTTP/1.x connections supports
+// Hijacker, but HTTP/2 connections intentionally do not.
+// ResponseWriter wrappers may also not support Hijacker. Handlers
+// should always test for this ability at runtime.
 type Hijacker interface {
 	// Hijack lets the caller take over the connection.
 	// After a call to Hijack(), the HTTP server library
@@ -143,6 +186,20 @@ type CloseNotifier interface {
 	CloseNotify() <-chan bool
 }
 
+var (
+	// ServerContextKey is a context key. It can be used in HTTP
+	// handlers with context.WithValue to access the server that
+	// started the handler. The associated value will be of
+	// type *Server.
+	ServerContextKey = &contextKey{"http-server"}
+
+	// LocalAddrContextKey is a context key. It can be used in
+	// HTTP handlers with context.WithValue to access the local
+	// address the connection arrived on.
+	// The associated value will be of type net.Addr.
+	LocalAddrContextKey = &contextKey{"local-addr"}
+)
+
 // A conn represents the server side of an HTTP connection.
 type conn struct {
 	// server is the server on which the connection arrived.
@@ -306,11 +363,14 @@ func (cw *chunkWriter) close() {
 
 // A response represents the server side of an HTTP response.
type response struct { - conn *conn - req *Request // request for this response - reqBody io.ReadCloser - wroteHeader bool // reply header has been (logically) written - wroteContinue bool // 100 Continue response was written + conn *conn + req *Request // request for this response + reqBody io.ReadCloser + cancelCtx context.CancelFunc // when ServeHTTP exits + wroteHeader bool // reply header has been (logically) written + wroteContinue bool // 100 Continue response was written + wants10KeepAlive bool // HTTP/1.0 w/ Connection "keep-alive" + wantsClose bool // HTTP request has Connection "close" w *bufio.Writer // buffers output in chunks to chunkWriter cw chunkWriter @@ -342,7 +402,7 @@ type response struct { requestBodyLimitHit bool // trailers are the headers to be sent after the handler - // finishes writing the body. This field is initialized from + // finishes writing the body. This field is initialized from // the Trailer response header when the response header is // written. trailers []string @@ -497,7 +557,7 @@ type connReader struct { } func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain } -func (cr *connReader) setInfiniteReadLimit() { cr.remain = 1<<63 - 1 } +func (cr *connReader) setInfiniteReadLimit() { cr.remain = maxInt64 } func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 } func (cr *connReader) Read(p []byte) (n int, err error) { @@ -681,7 +741,7 @@ func appendTime(b []byte, t time.Time) []byte { var errTooLarge = errors.New("http: request too large") // Read next request from connection. 
-func (c *conn) readRequest() (w *response, err error) { +func (c *conn) readRequest(ctx context.Context) (w *response, err error) { if c.hijacked() { return nil, ErrHijacked } @@ -710,31 +770,39 @@ func (c *conn) readRequest() (w *response, err error) { } return nil, err } + + if !http1ServerSupportsRequest(req) { + return nil, badRequestError("unsupported protocol version") + } + c.lastMethod = req.Method c.r.setInfiniteReadLimit() hosts, haveHost := req.Header["Host"] - if req.ProtoAtLeast(1, 1) && (!haveHost || len(hosts) == 0) { + isH2Upgrade := req.isH2Upgrade() + if req.ProtoAtLeast(1, 1) && (!haveHost || len(hosts) == 0) && !isH2Upgrade { return nil, badRequestError("missing required Host header") } if len(hosts) > 1 { return nil, badRequestError("too many Host headers") } - if len(hosts) == 1 && !validHostHeader(hosts[0]) { + if len(hosts) == 1 && !httplex.ValidHostHeader(hosts[0]) { return nil, badRequestError("malformed Host header") } for k, vv := range req.Header { - if !validHeaderName(k) { + if !httplex.ValidHeaderFieldName(k) { return nil, badRequestError("invalid header name") } for _, v := range vv { - if !validHeaderValue(v) { + if !httplex.ValidHeaderFieldValue(v) { return nil, badRequestError("invalid header value") } } } delete(req.Header, "Host") + ctx, cancelCtx := context.WithCancel(ctx) + req.ctx = ctx req.RemoteAddr = c.remoteAddr req.TLS = c.tlsState if body, ok := req.Body.(*body); ok { @@ -743,16 +811,43 @@ func (c *conn) readRequest() (w *response, err error) { w = &response{ conn: c, + cancelCtx: cancelCtx, req: req, reqBody: req.Body, handlerHeader: make(Header), contentLength: -1, + + // We populate these ahead of time so we're not + // reading from req.Header after their Handler starts + // and maybe mutates it (Issue 14940) + wants10KeepAlive: req.wantsHttp10KeepAlive(), + wantsClose: req.wantsClose(), + } + if isH2Upgrade { + w.closeAfterReply = true } w.cw.res = w w.w = newBufioWriterSize(&w.cw, bufferBeforeChunkingSize) return 
w, nil } +// http1ServerSupportsRequest reports whether Go's HTTP/1.x server +// supports the given request. +func http1ServerSupportsRequest(req *Request) bool { + if req.ProtoMajor == 1 { + return true + } + // Accept "PRI * HTTP/2.0" upgrade requests, so Handlers can + // wire up their own HTTP/2 upgrades. + if req.ProtoMajor == 2 && req.ProtoMinor == 0 && + req.Method == "PRI" && req.RequestURI == "*" { + return true + } + // Reject HTTP/0.x, and all other HTTP/2+ requests (which + // aren't encoded in ASCII anyway). + return false +} + func (w *response) Header() Header { if w.cw.header == nil && w.wroteHeader && !w.cw.wroteHeader { // Accessing the header between logically writing it @@ -766,7 +861,7 @@ func (w *response) Header() Header { // maxPostHandlerReadBytes is the max number of Request.Body bytes not // consumed by a handler that the server will read from the client -// in order to keep a connection alive. If there are more bytes than +// in order to keep a connection alive. If there are more bytes than // this then the server to be paranoid instead sends a "Connection: // close" response. // @@ -855,8 +950,8 @@ func (h extraHeader) Write(w *bufio.Writer) { // to cw.res.conn.bufw. // // p is not written by writeHeader, but is the first chunk of the body -// that will be written. It is sniffed for a Content-Type if none is -// set explicitly. It's also used to set the Content-Length, if the +// that will be written. It is sniffed for a Content-Type if none is +// set explicitly. It's also used to set the Content-Length, if the // total body size was small and the handler has already finished // running. 
func (cw *chunkWriter) writeHeader(p []byte) { @@ -911,9 +1006,9 @@ func (cw *chunkWriter) writeHeader(p []byte) { // Exceptions: 304/204/1xx responses never get Content-Length, and if // it was a HEAD request, we don't know the difference between // 0 actual bytes and 0 bytes because the handler noticed it - // was a HEAD request and chose not to write anything. So for + // was a HEAD request and chose not to write anything. So for // HEAD, the handler should either write the Content-Length or - // write non-zero bytes. If it's actually 0 bytes and the + // write non-zero bytes. If it's actually 0 bytes and the // handler never looked at the Request.Method, we just don't // send a Content-Length header. // Further, we don't send an automatic Content-Length if they @@ -925,7 +1020,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { // If this was an HTTP/1.0 request with keep-alive and we sent a // Content-Length back, we can make this a keep-alive response ... - if w.req.wantsHttp10KeepAlive() && keepAlivesEnabled { + if w.wants10KeepAlive && keepAlivesEnabled { sentLength := header.get("Content-Length") != "" if sentLength && header.get("Connection") == "keep-alive" { w.closeAfterReply = false @@ -935,12 +1030,12 @@ func (cw *chunkWriter) writeHeader(p []byte) { // Check for a explicit (and valid) Content-Length header. hasCL := w.contentLength != -1 - if w.req.wantsHttp10KeepAlive() && (isHEAD || hasCL) { + if w.wants10KeepAlive && (isHEAD || hasCL || !bodyAllowedForStatus(w.status)) { _, connectionHeaderSet := header["Connection"] if !connectionHeaderSet { setHeader.connection = "keep-alive" } - } else if !w.req.ProtoAtLeast(1, 1) || w.req.wantsClose() { + } else if !w.req.ProtoAtLeast(1, 1) || w.wantsClose { w.closeAfterReply = true } @@ -965,9 +1060,12 @@ func (cw *chunkWriter) writeHeader(p []byte) { } // Per RFC 2616, we should consume the request body before - // replying, if the handler hasn't already done so. 
But we + // replying, if the handler hasn't already done so. But we // don't want to do an unbounded amount of reading here for // DoS reasons, so we only try up to a threshold. + // TODO(bradfitz): where does RFC 2616 say that? See Issue 15527 + // about HTTP/1.x Handlers concurrently reading and writing, like + // HTTP/2 handlers can do. Maybe this code should be relaxed? if w.req.ContentLength != 0 && !w.closeAfterReply { var discard, tooBig bool @@ -1009,7 +1107,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { w.closeAfterReply = true } default: - // Some other kind of error occured, like a read timeout, or + // Some other kind of error occurred, like a read timeout, or // corrupt chunked encoding. In any case, whatever remains // on the wire must not be parsed as another HTTP request. w.closeAfterReply = true @@ -1069,6 +1167,10 @@ func (cw *chunkWriter) writeHeader(p []byte) { // to avoid closing the connection at EOF. cw.chunking = true setHeader.transferEncoding = "chunked" + if hasTE && te == "chunked" { + // We will send the chunked Transfer-Encoding header later. + delHeader("Transfer-Encoding") + } } } else { // HTTP version < 1.1: cannot do chunked transfer @@ -1148,7 +1250,7 @@ func statusLine(req *Request, code int) string { if proto11 { proto = "HTTP/1.1" } - codestring := strconv.Itoa(code) + codestring := fmt.Sprintf("%03d", code) text, ok := statusText[code] if !ok { text = "status code " + codestring @@ -1174,7 +1276,7 @@ func (w *response) bodyAllowed() bool { // The Life Of A Write is like this: // // Handler starts. No header has been sent. The handler can either -// write a header, or just start writing. Writing before sending a header +// write a header, or just start writing. Writing before sending a header // sends an implicitly empty 200 OK header. 
// // If the handler didn't declare a Content-Length up front, we either @@ -1200,7 +1302,7 @@ func (w *response) bodyAllowed() bool { // initial header contains both a Content-Type and Content-Length. // Also short-circuit in (1) when the header's been sent and not in // chunking mode, writing directly to (4) instead, if (2) has no -// buffered data. More generally, we could short-circuit from (1) to +// buffered data. More generally, we could short-circuit from (1) to // (3) even in chunking mode if the write size from (1) is over some // threshold and nothing is in (2). The answer might be mostly making // bufferBeforeChunkingSize smaller and having bufio's fast-paths deal @@ -1341,7 +1443,7 @@ type closeWriter interface { var _ closeWriter = (*net.TCPConn)(nil) // closeWrite flushes any outstanding data and sends a FIN packet (if -// client is connected via TCP), signalling that we're done. We then +// client is connected via TCP), signalling that we're done. We then // pause for a bit, hoping the client processes it before any // subsequent RST. // @@ -1355,7 +1457,7 @@ func (c *conn) closeWriteAndWait() { } // validNPN reports whether the proto is not a blacklisted Next -// Protocol Negotiation protocol. Empty and built-in protocol types +// Protocol Negotiation protocol. Empty and built-in protocol types // are blacklisted and can't be overridden with alternate // implementations. func validNPN(proto string) bool { @@ -1374,13 +1476,13 @@ func (c *conn) setState(nc net.Conn, state ConnState) { // badRequestError is a literal string (used by in the server in HTML, // unescaped) to tell the user why their request was bad. It should -// be plain text without user info or other embeddded errors. +// be plain text without user info or other embedded errors. type badRequestError string func (e badRequestError) Error() string { return "Bad Request: " + string(e) } // Serve a new connection. 
-func (c *conn) serve() { +func (c *conn) serve(ctx context.Context) { c.remoteAddr = c.rwc.RemoteAddr().String() defer func() { if err := recover(); err != nil { @@ -1417,12 +1519,17 @@ func (c *conn) serve() { } } + // HTTP/1.x from here on. + c.r = &connReader{r: c.rwc} c.bufr = newBufioReader(c.r) c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10) + ctx, cancelCtx := context.WithCancel(ctx) + defer cancelCtx() + for { - w, err := c.readRequest() + w, err := c.readRequest(ctx) if c.r.remain != c.server.initialReadLimitSize() { // If we read any bytes off the wire, we're active. c.setState(c.rwc, StateActive) @@ -1433,7 +1540,7 @@ func (c *conn) serve() { // able to read this if we're // responding to them and hanging up // while they're still writing their - // request. Undefined behavior. + // request. Undefined behavior. io.WriteString(c.rwc, "HTTP/1.1 431 Request Header Fields Too Large\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n431 Request Header Fields Too Large") c.closeWriteAndWait() return @@ -1467,9 +1574,10 @@ func (c *conn) serve() { // HTTP cannot have multiple simultaneous active requests.[*] // Until the server replies to this request, it can't read another, // so we might as well run the handler in this goroutine. - // [*] Not strictly true: HTTP pipelining. We could let them all process + // [*] Not strictly true: HTTP pipelining. We could let them all process // in parallel even if their responses need to be serialized. serverHandler{c.server}.ServeHTTP(w, w.req) + w.cancelCtx() if c.hijacked() { return } @@ -1488,7 +1596,7 @@ func (w *response) sendExpectationFailed() { // TODO(bradfitz): let ServeHTTP handlers handle // requests with non-standard expectation[s]? Seems // theoretical at best, and doesn't fit into the - // current ServeHTTP model anyway. We'd need to + // current ServeHTTP model anyway. We'd need to // make the ResponseWriter an optional // "ExpectReplier" interface or something. 
// @@ -1608,7 +1716,7 @@ func requestBodyRemains(rc io.ReadCloser) bool { } // The HandlerFunc type is an adapter to allow the use of -// ordinary functions as HTTP handlers. If f is a function +// ordinary functions as HTTP handlers. If f is a function // with the appropriate signature, HandlerFunc(f) is a // Handler that calls f. type HandlerFunc func(ResponseWriter, *Request) @@ -1621,6 +1729,8 @@ func (f HandlerFunc) ServeHTTP(w ResponseWriter, r *Request) { // Helper handlers // Error replies to the request with the specified error message and HTTP code. +// It does not otherwise end the request; the caller should ensure no further +// writes are done to w. // The error message should be plain text. func Error(w ResponseWriter, error string, code int) { w.Header().Set("Content-Type", "text/plain; charset=utf-8") @@ -1709,7 +1819,7 @@ func Redirect(w ResponseWriter, r *Request, urlStr string, code int) { w.Header().Set("Location", urlStr) w.WriteHeader(code) - // RFC2616 recommends that a short note "SHOULD" be included in the + // RFC 2616 recommends that a short note "SHOULD" be included in the // response because older user agents may not understand 301/307. // Shouldn't send the response for POST or HEAD; that leaves GET. if r.Method == "GET" { @@ -1779,7 +1889,7 @@ func RedirectHandler(url string, code int) Handler { // been registered separately. // // Patterns may optionally begin with a host name, restricting matches to -// URLs on that host only. Host-specific patterns take precedence over +// URLs on that host only. Host-specific patterns take precedence over // general patterns, so that a handler might register for the two patterns // "/codesearch" and "codesearch.google.com/" without also taking over // requests for "http://www.google.com/". @@ -1800,10 +1910,12 @@ type muxEntry struct { } // NewServeMux allocates and returns a new ServeMux. 
-func NewServeMux() *ServeMux { return &ServeMux{m: make(map[string]muxEntry)} } +func NewServeMux() *ServeMux { return new(ServeMux) } // DefaultServeMux is the default ServeMux used by Serve. -var DefaultServeMux = NewServeMux() +var DefaultServeMux = &defaultServeMux + +var defaultServeMux ServeMux // Does path match pattern? func pathMatch(pattern, path string) bool { @@ -1926,6 +2038,9 @@ func (mux *ServeMux) Handle(pattern string, handler Handler) { panic("http: multiple registrations for " + pattern) } + if mux.m == nil { + mux.m = make(map[string]muxEntry) + } mux.m[pattern] = muxEntry{explicit: true, h: handler, pattern: pattern} if pattern[0] != '/' { @@ -1968,7 +2083,7 @@ func HandleFunc(pattern string, handler func(ResponseWriter, *Request)) { } // Serve accepts incoming HTTP connections on the listener l, -// creating a new service goroutine for each. The service goroutines +// creating a new service goroutine for each. The service goroutines // read requests and then call handler to reply to them. // Handler is typically nil, in which case the DefaultServeMux is used. func Serve(l net.Listener, handler Handler) error { @@ -1979,19 +2094,25 @@ func Serve(l net.Listener, handler Handler) error { // A Server defines parameters for running an HTTP server. // The zero value for Server is a valid configuration. 
type Server struct { - Addr string // TCP address to listen on, ":http" if empty - Handler Handler // handler to invoke, http.DefaultServeMux if nil - ReadTimeout time.Duration // maximum duration before timing out read of the request - WriteTimeout time.Duration // maximum duration before timing out write of the response - MaxHeaderBytes int // maximum size of request headers, DefaultMaxHeaderBytes if 0 - TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS + Addr string // TCP address to listen on, ":http" if empty + Handler Handler // handler to invoke, http.DefaultServeMux if nil + ReadTimeout time.Duration // maximum duration before timing out read of the request + WriteTimeout time.Duration // maximum duration before timing out write of the response + TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS + + // MaxHeaderBytes controls the maximum number of bytes the + // server will read parsing the request header's keys and + // values, including the request line. It does not limit the + // size of the request body. + // If zero, DefaultMaxHeaderBytes is used. + MaxHeaderBytes int // TLSNextProto optionally specifies a function to take over - // ownership of the provided TLS connection when an NPN - // protocol upgrade has occurred. The map key is the protocol + // ownership of the provided TLS connection when an NPN/ALPN + // protocol upgrade has occurred. The map key is the protocol // name negotiated. The Handler argument should be used to // handle HTTP requests and will initialize the Request's TLS - // and RemoteAddr if not already set. The connection is + // and RemoteAddr if not already set. The connection is // automatically closed when the function returns. // If TLSNextProto is nil, HTTP/2 support is enabled automatically. 
TLSNextProto map[string]func(*Server, *tls.Conn, Handler) @@ -2032,7 +2153,7 @@ const ( // For HTTP/2, StateActive fires on the transition from zero // to one active request, and only transitions away once all // active requests are complete. That means that ConnState - // can not be used to do per-request work; ConnState only notes + // cannot be used to do per-request work; ConnState only notes // the overall state of the connection. StateActive @@ -2100,9 +2221,37 @@ func (srv *Server) ListenAndServe() error { var testHookServerServe func(*Server, net.Listener) // used if non-nil +// shouldDoServeHTTP2 reports whether Server.Serve should configure +// automatic HTTP/2. (which sets up the srv.TLSNextProto map) +func (srv *Server) shouldConfigureHTTP2ForServe() bool { + if srv.TLSConfig == nil { + // Compatibility with Go 1.6: + // If there's no TLSConfig, it's possible that the user just + // didn't set it on the http.Server, but did pass it to + // tls.NewListener and passed that listener to Serve. + // So we should configure HTTP/2 (to set up srv.TLSNextProto) + // in case the listener returns an "h2" *tls.Conn. + return true + } + // The user specified a TLSConfig on their http.Server. + // In this, case, only configure HTTP/2 if their tls.Config + // explicitly mentions "h2". Otherwise http2.ConfigureServer + // would modify the tls.Config to add it, but they probably already + // passed this tls.Config to tls.NewListener. And if they did, + // it's too late anyway to fix it. It would only be potentially racy. + // See Issue 15908. + return strSliceContains(srv.TLSConfig.NextProtos, http2NextProtoTLS) +} + // Serve accepts incoming connections on the Listener l, creating a // new service goroutine for each. The service goroutines read requests and // then call srv.Handler to reply to them. +// +// For HTTP/2 support, srv.TLSConfig should be initialized to the +// provided listener's TLS Config before calling Serve. 
If +// srv.TLSConfig is non-nil and doesn't include the string "h2" in +// Config.NextProtos, HTTP/2 support is not enabled. +// // Serve always returns a non-nil error. func (srv *Server) Serve(l net.Listener) error { defer l.Close() @@ -2110,9 +2259,18 @@ func (srv *Server) Serve(l net.Listener) error { fn(srv, l) } var tempDelay time.Duration // how long to sleep on accept failure - if err := srv.setupHTTP2(); err != nil { - return err + + if srv.shouldConfigureHTTP2ForServe() { + if err := srv.setupHTTP2(); err != nil { + return err + } } + + // TODO: allow changing base context? can't imagine concrete + // use cases yet. + baseCtx := context.Background() + ctx := context.WithValue(baseCtx, ServerContextKey, srv) + ctx = context.WithValue(ctx, LocalAddrContextKey, l.Addr()) for { rw, e := l.Accept() if e != nil { @@ -2134,7 +2292,7 @@ func (srv *Server) Serve(l net.Listener) error { tempDelay = 0 c := srv.newConn(rw) c.setState(c.rwc, StateNew) // before Serve can return - go c.serve() + go c.serve(ctx) } } @@ -2309,15 +2467,10 @@ func (srv *Server) onceSetNextProtoDefaults() { // TimeoutHandler buffers all Handler writes to memory and does not // support the Hijacker or Flusher interfaces. func TimeoutHandler(h Handler, dt time.Duration, msg string) Handler { - t := time.NewTimer(dt) return &timeoutHandler{ handler: h, body: msg, - - // Effectively storing a *time.Timer, but decomposed - // for testing: - timeout: func() <-chan time.Time { return t.C }, - cancelTimer: t.Stop, + dt: dt, } } @@ -2328,12 +2481,11 @@ var ErrHandlerTimeout = errors.New("http: Handler timeout") type timeoutHandler struct { handler Handler body string + dt time.Duration - // timeout returns the channel of a *time.Timer and - // cancelTimer cancels it. They're stored separately for - // testing purposes. 
- timeout func() <-chan time.Time // returns channel producing a timeout - cancelTimer func() bool // optional + // When set, no timer will be created and this channel will + // be used instead. + testTimeout <-chan time.Time } func (h *timeoutHandler) errorBody() string { @@ -2344,6 +2496,12 @@ func (h *timeoutHandler) errorBody() string { } func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { + var t *time.Timer + timeout := h.testTimeout + if timeout == nil { + t = time.NewTimer(h.dt) + timeout = t.C + } done := make(chan struct{}) tw := &timeoutWriter{ w: w, @@ -2361,12 +2519,15 @@ func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { for k, vv := range tw.h { dst[k] = vv } + if !tw.wroteHeader { + tw.code = StatusOK + } w.WriteHeader(tw.code) w.Write(tw.wbuf.Bytes()) - if h.cancelTimer != nil { - h.cancelTimer() + if t != nil { + t.Stop() } - case <-h.timeout(): + case <-timeout: tw.mu.Lock() defer tw.mu.Unlock() w.WriteHeader(StatusServiceUnavailable) diff --git a/libgo/go/net/http/sniff.go b/libgo/go/net/http/sniff.go index 18810bad068..0d21b44a560 100644 --- a/libgo/go/net/http/sniff.go +++ b/libgo/go/net/http/sniff.go @@ -1,4 +1,4 @@ -// Copyright 2011 The Go Authors. All rights reserved. +// Copyright 2011 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. @@ -14,8 +14,8 @@ const sniffLen = 512 // DetectContentType implements the algorithm described // at http://mimesniff.spec.whatwg.org/ to determine the -// Content-Type of the given data. It considers at most the -// first 512 bytes of data. DetectContentType always returns +// Content-Type of the given data. It considers at most the +// first 512 bytes of data. DetectContentType always returns // a valid MIME type: if it cannot determine a more specific one, it // returns "application/octet-stream". 
func DetectContentType(data []byte) string { @@ -91,12 +91,41 @@ var sniffSignatures = []sniffSig{ ct: "image/webp", }, &exactSig{[]byte("\x00\x00\x01\x00"), "image/vnd.microsoft.icon"}, - &exactSig{[]byte("\x4F\x67\x67\x53\x00"), "application/ogg"}, &maskedSig{ mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF"), pat: []byte("RIFF\x00\x00\x00\x00WAVE"), ct: "audio/wave", }, + &maskedSig{ + mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF"), + pat: []byte("FORM\x00\x00\x00\x00AIFF"), + ct: "audio/aiff", + }, + &maskedSig{ + mask: []byte("\xFF\xFF\xFF\xFF"), + pat: []byte(".snd"), + ct: "audio/basic", + }, + &maskedSig{ + mask: []byte("OggS\x00"), + pat: []byte("\x4F\x67\x67\x53\x00"), + ct: "application/ogg", + }, + &maskedSig{ + mask: []byte("\xFF\xFF\xFF\xFF\xFF\xFF\xFF\xFF"), + pat: []byte("MThd\x00\x00\x00\x06"), + ct: "audio/midi", + }, + &maskedSig{ + mask: []byte("\xFF\xFF\xFF"), + pat: []byte("ID3"), + ct: "audio/mpeg", + }, + &maskedSig{ + mask: []byte("\xFF\xFF\xFF\xFF\x00\x00\x00\x00\xFF\xFF\xFF\xFF"), + pat: []byte("RIFF\x00\x00\x00\x00AVI "), + ct: "video/avi", + }, &exactSig{[]byte("\x1A\x45\xDF\xA3"), "video/webm"}, &exactSig{[]byte("\x52\x61\x72\x20\x1A\x07\x00"), "application/x-rar-compressed"}, &exactSig{[]byte("\x50\x4B\x03\x04"), "application/zip"}, @@ -126,9 +155,15 @@ type maskedSig struct { } func (m *maskedSig) match(data []byte, firstNonWS int) string { + // pattern matching algorithm section 6 + // https://mimesniff.spec.whatwg.org/#pattern-matching-algorithm + if m.skipWS { data = data[firstNonWS:] } + if len(m.pat) != len(m.mask) { + return "" + } if len(data) < len(m.mask) { return "" } diff --git a/libgo/go/net/http/sniff_test.go b/libgo/go/net/http/sniff_test.go index e0085516da3..ac404bfa723 100644 --- a/libgo/go/net/http/sniff_test.go +++ b/libgo/go/net/http/sniff_test.go @@ -39,7 +39,18 @@ var sniffTests = []struct { {"GIF 87a", []byte(`GIF87a`), "image/gif"}, {"GIF 89a", []byte(`GIF89a...`), 
"image/gif"}, + // Audio types. + {"MIDI audio", []byte("MThd\x00\x00\x00\x06\x00\x01"), "audio/midi"}, + {"MP3 audio/MPEG audio", []byte("ID3\x03\x00\x00\x00\x00\x0f"), "audio/mpeg"}, + {"WAV audio #1", []byte("RIFFb\xb8\x00\x00WAVEfmt \x12\x00\x00\x00\x06"), "audio/wave"}, + {"WAV audio #2", []byte("RIFF,\x00\x00\x00WAVEfmt \x12\x00\x00\x00\x06"), "audio/wave"}, + {"AIFF audio #1", []byte("FORM\x00\x00\x00\x00AIFFCOMM\x00\x00\x00\x12\x00\x01\x00\x00\x57\x55\x00\x10\x40\x0d\xf3\x34"), "audio/aiff"}, + {"OGG audio", []byte("OggS\x00\x02\x00\x00\x00\x00\x00\x00\x00\x00\x7e\x46\x00\x00\x00\x00\x00\x00\x1f\xf6\xb4\xfc\x01\x1e\x01\x76\x6f\x72"), "application/ogg"}, + + // Video types. {"MP4 video", []byte("\x00\x00\x00\x18ftypmp42\x00\x00\x00\x00mp42isom<\x06t\xbfmdat"), "video/mp4"}, + {"AVI video #1", []byte("RIFF,O\n\x00AVI LISTÀ"), "video/avi"}, + {"AVI video #2", []byte("RIFF,\n\x00\x00AVI LISTÀ"), "video/avi"}, } func TestDetectContentType(t *testing.T) { diff --git a/libgo/go/net/http/status.go b/libgo/go/net/http/status.go index f3dacab6a92..98645b7d746 100644 --- a/libgo/go/net/http/status.go +++ b/libgo/go/net/http/status.go @@ -4,63 +4,79 @@ package http -// HTTP status codes, defined in RFC 2616. +// HTTP status codes as registered with IANA. 
+// See: http://www.iana.org/assignments/http-status-codes/http-status-codes.xhtml const ( - StatusContinue = 100 - StatusSwitchingProtocols = 101 + StatusContinue = 100 // RFC 7231, 6.2.1 + StatusSwitchingProtocols = 101 // RFC 7231, 6.2.2 + StatusProcessing = 102 // RFC 2518, 10.1 - StatusOK = 200 - StatusCreated = 201 - StatusAccepted = 202 - StatusNonAuthoritativeInfo = 203 - StatusNoContent = 204 - StatusResetContent = 205 - StatusPartialContent = 206 + StatusOK = 200 // RFC 7231, 6.3.1 + StatusCreated = 201 // RFC 7231, 6.3.2 + StatusAccepted = 202 // RFC 7231, 6.3.3 + StatusNonAuthoritativeInfo = 203 // RFC 7231, 6.3.4 + StatusNoContent = 204 // RFC 7231, 6.3.5 + StatusResetContent = 205 // RFC 7231, 6.3.6 + StatusPartialContent = 206 // RFC 7233, 4.1 + StatusMultiStatus = 207 // RFC 4918, 11.1 + StatusAlreadyReported = 208 // RFC 5842, 7.1 + StatusIMUsed = 226 // RFC 3229, 10.4.1 - StatusMultipleChoices = 300 - StatusMovedPermanently = 301 - StatusFound = 302 - StatusSeeOther = 303 - StatusNotModified = 304 - StatusUseProxy = 305 - StatusTemporaryRedirect = 307 + StatusMultipleChoices = 300 // RFC 7231, 6.4.1 + StatusMovedPermanently = 301 // RFC 7231, 6.4.2 + StatusFound = 302 // RFC 7231, 6.4.3 + StatusSeeOther = 303 // RFC 7231, 6.4.4 + StatusNotModified = 304 // RFC 7232, 4.1 + StatusUseProxy = 305 // RFC 7231, 6.4.5 + _ = 306 // RFC 7231, 6.4.6 (Unused) + StatusTemporaryRedirect = 307 // RFC 7231, 6.4.7 + StatusPermanentRedirect = 308 // RFC 7538, 3 - StatusBadRequest = 400 - StatusUnauthorized = 401 - StatusPaymentRequired = 402 - StatusForbidden = 403 - StatusNotFound = 404 - StatusMethodNotAllowed = 405 - StatusNotAcceptable = 406 - StatusProxyAuthRequired = 407 - StatusRequestTimeout = 408 - StatusConflict = 409 - StatusGone = 410 - StatusLengthRequired = 411 - StatusPreconditionFailed = 412 - StatusRequestEntityTooLarge = 413 - StatusRequestURITooLong = 414 - StatusUnsupportedMediaType = 415 - StatusRequestedRangeNotSatisfiable = 416 - 
StatusExpectationFailed = 417 - StatusTeapot = 418 - StatusPreconditionRequired = 428 - StatusTooManyRequests = 429 - StatusRequestHeaderFieldsTooLarge = 431 - StatusUnavailableForLegalReasons = 451 + StatusBadRequest = 400 // RFC 7231, 6.5.1 + StatusUnauthorized = 401 // RFC 7235, 3.1 + StatusPaymentRequired = 402 // RFC 7231, 6.5.2 + StatusForbidden = 403 // RFC 7231, 6.5.3 + StatusNotFound = 404 // RFC 7231, 6.5.4 + StatusMethodNotAllowed = 405 // RFC 7231, 6.5.5 + StatusNotAcceptable = 406 // RFC 7231, 6.5.6 + StatusProxyAuthRequired = 407 // RFC 7235, 3.2 + StatusRequestTimeout = 408 // RFC 7231, 6.5.7 + StatusConflict = 409 // RFC 7231, 6.5.8 + StatusGone = 410 // RFC 7231, 6.5.9 + StatusLengthRequired = 411 // RFC 7231, 6.5.10 + StatusPreconditionFailed = 412 // RFC 7232, 4.2 + StatusRequestEntityTooLarge = 413 // RFC 7231, 6.5.11 + StatusRequestURITooLong = 414 // RFC 7231, 6.5.12 + StatusUnsupportedMediaType = 415 // RFC 7231, 6.5.13 + StatusRequestedRangeNotSatisfiable = 416 // RFC 7233, 4.4 + StatusExpectationFailed = 417 // RFC 7231, 6.5.14 + StatusTeapot = 418 // RFC 7168, 2.3.3 + StatusUnprocessableEntity = 422 // RFC 4918, 11.2 + StatusLocked = 423 // RFC 4918, 11.3 + StatusFailedDependency = 424 // RFC 4918, 11.4 + StatusUpgradeRequired = 426 // RFC 7231, 6.5.15 + StatusPreconditionRequired = 428 // RFC 6585, 3 + StatusTooManyRequests = 429 // RFC 6585, 4 + StatusRequestHeaderFieldsTooLarge = 431 // RFC 6585, 5 + StatusUnavailableForLegalReasons = 451 // RFC 7725, 3 - StatusInternalServerError = 500 - StatusNotImplemented = 501 - StatusBadGateway = 502 - StatusServiceUnavailable = 503 - StatusGatewayTimeout = 504 - StatusHTTPVersionNotSupported = 505 - StatusNetworkAuthenticationRequired = 511 + StatusInternalServerError = 500 // RFC 7231, 6.6.1 + StatusNotImplemented = 501 // RFC 7231, 6.6.2 + StatusBadGateway = 502 // RFC 7231, 6.6.3 + StatusServiceUnavailable = 503 // RFC 7231, 6.6.4 + StatusGatewayTimeout = 504 // RFC 7231, 6.6.5 + 
StatusHTTPVersionNotSupported = 505 // RFC 7231, 6.6.6 + StatusVariantAlsoNegotiates = 506 // RFC 2295, 8.1 + StatusInsufficientStorage = 507 // RFC 4918, 11.5 + StatusLoopDetected = 508 // RFC 5842, 7.2 + StatusNotExtended = 510 // RFC 2774, 7 + StatusNetworkAuthenticationRequired = 511 // RFC 6585, 6 ) var statusText = map[int]string{ StatusContinue: "Continue", StatusSwitchingProtocols: "Switching Protocols", + StatusProcessing: "Processing", StatusOK: "OK", StatusCreated: "Created", @@ -69,6 +85,9 @@ var statusText = map[int]string{ StatusNoContent: "No Content", StatusResetContent: "Reset Content", StatusPartialContent: "Partial Content", + StatusMultiStatus: "Multi-Status", + StatusAlreadyReported: "Already Reported", + StatusIMUsed: "IM Used", StatusMultipleChoices: "Multiple Choices", StatusMovedPermanently: "Moved Permanently", @@ -77,6 +96,7 @@ var statusText = map[int]string{ StatusNotModified: "Not Modified", StatusUseProxy: "Use Proxy", StatusTemporaryRedirect: "Temporary Redirect", + StatusPermanentRedirect: "Permanent Redirect", StatusBadRequest: "Bad Request", StatusUnauthorized: "Unauthorized", @@ -97,6 +117,10 @@ var statusText = map[int]string{ StatusRequestedRangeNotSatisfiable: "Requested Range Not Satisfiable", StatusExpectationFailed: "Expectation Failed", StatusTeapot: "I'm a teapot", + StatusUnprocessableEntity: "Unprocessable Entity", + StatusLocked: "Locked", + StatusFailedDependency: "Failed Dependency", + StatusUpgradeRequired: "Upgrade Required", StatusPreconditionRequired: "Precondition Required", StatusTooManyRequests: "Too Many Requests", StatusRequestHeaderFieldsTooLarge: "Request Header Fields Too Large", @@ -108,6 +132,10 @@ var statusText = map[int]string{ StatusServiceUnavailable: "Service Unavailable", StatusGatewayTimeout: "Gateway Timeout", StatusHTTPVersionNotSupported: "HTTP Version Not Supported", + StatusVariantAlsoNegotiates: "Variant Also Negotiates", + StatusInsufficientStorage: "Insufficient Storage", + 
StatusLoopDetected: "Loop Detected", + StatusNotExtended: "Not Extended", StatusNetworkAuthenticationRequired: "Network Authentication Required", } diff --git a/libgo/go/net/http/transfer.go b/libgo/go/net/http/transfer.go index 6e59af8f6f4..c653467098c 100644 --- a/libgo/go/net/http/transfer.go +++ b/libgo/go/net/http/transfer.go @@ -17,6 +17,8 @@ import ( "strconv" "strings" "sync" + + "golang_org/x/net/lex/httplex" ) // ErrLineTooLong is returned when reading request or response bodies @@ -276,7 +278,7 @@ func (t *transferReader) protoAtLeast(m, n int) bool { } // bodyAllowedForStatus reports whether a given response status code -// permits a body. See RFC2616, section 4.4. +// permits a body. See RFC 2616, section 4.4. func bodyAllowedForStatus(status int) bool { switch { case status >= 100 && status <= 199: @@ -368,7 +370,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { // If there is no Content-Length or chunked Transfer-Encoding on a *Response // and the status is not 1xx, 204 or 304, then the body is unbounded. - // See RFC2616, section 4.4. + // See RFC 2616, section 4.4. switch msg.(type) { case *Response: if realLength == -1 && @@ -379,7 +381,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { } } - // Prepare body reader. ContentLength < 0 means chunked encoding + // Prepare body reader. 
ContentLength < 0 means chunked encoding // or close connection when finished, since multipart is not supported yet switch { case chunked(t.TransferEncoding): @@ -558,21 +560,19 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header, func shouldClose(major, minor int, header Header, removeCloseHeader bool) bool { if major < 1 { return true - } else if major == 1 && minor == 0 { - vv := header["Connection"] - if headerValuesContainsToken(vv, "close") || !headerValuesContainsToken(vv, "keep-alive") { - return true - } - return false - } else { - if headerValuesContainsToken(header["Connection"], "close") { - if removeCloseHeader { - header.Del("Connection") - } - return true - } } - return false + + conv := header["Connection"] + hasClose := httplex.HeaderValuesContainsToken(conv, "close") + if major == 1 && minor == 0 { + return hasClose || !httplex.HeaderValuesContainsToken(conv, "keep-alive") + } + + if hasClose && removeCloseHeader { + header.Del("Connection") + } + + return hasClose } // Parse the trailer header @@ -729,11 +729,11 @@ func (b *body) readTrailer() error { } // Make sure there's a header terminator coming up, to prevent - // a DoS with an unbounded size Trailer. It's not easy to + // a DoS with an unbounded size Trailer. It's not easy to // slip in a LimitReader here, as textproto.NewReader requires - // a concrete *bufio.Reader. Also, we can't get all the way + // a concrete *bufio.Reader. Also, we can't get all the way // back up to our conn's LimitedReader that *might* be backing - // this bufio.Reader. Instead, a hack: we iteratively Peek up + // this bufio.Reader. Instead, a hack: we iteratively Peek up // to the bufio.Reader's max size, looking for a double CRLF. // This limits the trailer to the underlying buffer size, typically 4kB. 
if !seeUpcomingDoubleCRLF(b.r) { diff --git a/libgo/go/net/http/transport.go b/libgo/go/net/http/transport.go index baf71d5e85e..9164d0d827c 100644 --- a/libgo/go/net/http/transport.go +++ b/libgo/go/net/http/transport.go @@ -12,17 +12,22 @@ package http import ( "bufio" "compress/gzip" + "container/list" + "context" "crypto/tls" "errors" "fmt" "io" "log" "net" + "net/http/httptrace" "net/url" "os" "strings" "sync" "time" + + "golang_org/x/net/lex/httplex" ) // DefaultTransport is the default implementation of Transport and is @@ -32,10 +37,12 @@ import ( // $no_proxy) environment variables. var DefaultTransport RoundTripper = &Transport{ Proxy: ProxyFromEnvironment, - Dial: (&net.Dialer{ + DialContext: (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, - }).Dial, + }).DialContext, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } @@ -63,9 +70,10 @@ const DefaultMaxIdleConnsPerHost = 2 // See the package docs for more about HTTP/2. type Transport struct { idleMu sync.Mutex - wantIdle bool // user has requested to close all idle conns - idleConn map[connectMethodKey][]*persistConn + wantIdle bool // user has requested to close all idle conns + idleConn map[connectMethodKey][]*persistConn // most recently used at end idleConnCh map[connectMethodKey]chan *persistConn + idleLRU connLRU reqMu sync.Mutex reqCanceler map[*Request]func() @@ -79,9 +87,16 @@ type Transport struct { // If Proxy is nil or returns a nil *URL, no proxy is used. Proxy func(*Request) (*url.URL, error) - // Dial specifies the dial function for creating unencrypted - // TCP connections. - // If Dial is nil, net.Dial is used. + // DialContext specifies the dial function for creating unencrypted TCP connections. + // If DialContext is nil (and the deprecated Dial below is also nil), + // then the transport dials using package net. 
+ DialContext func(ctx context.Context, network, addr string) (net.Conn, error) + + // Dial specifies the dial function for creating unencrypted TCP connections. + // + // Deprecated: Use DialContext instead, which allows the transport + // to cancel dials as soon as they are no longer needed. + // If both are set, DialContext takes priority. Dial func(network, addr string) (net.Conn, error) // DialTLS specifies an optional dial function for creating @@ -117,11 +132,21 @@ type Transport struct { // uncompressed. DisableCompression bool + // MaxIdleConns controls the maximum number of idle (keep-alive) + // connections across all hosts. Zero means no limit. + MaxIdleConns int + // MaxIdleConnsPerHost, if non-zero, controls the maximum idle - // (keep-alive) to keep per-host. If zero, + // (keep-alive) connections to keep per-host. If zero, // DefaultMaxIdleConnsPerHost is used. MaxIdleConnsPerHost int + // IdleConnTimeout is the maximum amount of time an idle + // (keep-alive) connection will remain idle before closing + // itself. + // Zero means no limit. + IdleConnTimeout time.Duration + // ResponseHeaderTimeout, if non-zero, specifies the amount of // time to wait for a server's response headers after fully // writing the request (including its body, if any). This @@ -137,7 +162,7 @@ type Transport struct { // TLSNextProto specifies how the Transport switches to an // alternate protocol (such as HTTP/2) after a TLS NPN/ALPN - // protocol negotiation. If Transport dials an TLS connection + // protocol negotiation. If Transport dials an TLS connection // with a non-empty protocol name and TLSNextProto contains a // map entry for that key (such as "h2"), then the func is // called with the request's authority (such as "example.com" @@ -146,13 +171,18 @@ type Transport struct { // If TLSNextProto is nil, HTTP/2 support is enabled automatically. 
TLSNextProto map[string]func(authority string, c *tls.Conn) RoundTripper + // MaxResponseHeaderBytes specifies a limit on how many + // response bytes are allowed in the server's response + // header. + // + // Zero means to use a default limit. + MaxResponseHeaderBytes int64 + // nextProtoOnce guards initialization of TLSNextProto and // h2transport (via onceSetNextProtoDefaults) nextProtoOnce sync.Once h2transport *http2Transport // non-nil if http2 wired up - // TODO: tunable on global max cached connections - // TODO: tunable on timeout on cached connections // TODO: tunable on max per-host TCP dials in flight (Issue 13957) } @@ -167,25 +197,34 @@ func (t *Transport) onceSetNextProtoDefaults() { // Transport. return } - if t.TLSClientConfig != nil { - // Be conservative for now (for Go 1.6) at least and - // don't automatically enable http2 if they've - // specified a custom TLS config. Let them opt-in - // themselves via http2.ConfigureTransport so we don't - // surprise them by modifying their tls.Config. - // Issue 14275. - return - } - if t.ExpectContinueTimeout != 0 { - // Unsupported in http2, so disable http2 for now. - // Issue 13851. + if t.TLSClientConfig != nil || t.Dial != nil || t.DialTLS != nil { + // Be conservative and don't automatically enable + // http2 if they've specified a custom TLS config or + // custom dialers. Let them opt-in themselves via + // http2.ConfigureTransport so we don't surprise them + // by modifying their tls.Config. Issue 14275. return } t2, err := http2configureTransport(t) if err != nil { log.Printf("Error enabling Transport HTTP/2 support: %v", err) - } else { - t.h2transport = t2 + return + } + t.h2transport = t2 + + // Auto-configure the http2.Transport's MaxHeaderListSize from + // the http.Transport's MaxResponseHeaderBytes. They don't + // exactly mean the same thing, but they're close. 
+ // + // TODO: also add this to x/net/http2.Configure Transport, behind + // a +build go1.7 build tag: + if limit1 := t.MaxResponseHeaderBytes; limit1 != 0 && t2.MaxHeaderListSize == 0 { + const h2max = 1<<32 - 1 + if limit1 >= h2max { + t2.MaxHeaderListSize = h2max + } else { + t2.MaxHeaderListSize = uint32(limit1) + } } } @@ -212,6 +251,9 @@ func ProxyFromEnvironment(req *Request) (*url.URL, error) { } if proxy == "" { proxy = httpProxyEnv.Get() + if proxy != "" && os.Getenv("REQUEST_METHOD") != "" { + return nil, errors.New("net/http: refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy") + } } if proxy == "" { return nil, nil @@ -245,8 +287,9 @@ func ProxyURL(fixedURL *url.URL) func(*Request) (*url.URL, error) { // transportRequest is a wrapper around a *Request that adds // optional extra headers to write. type transportRequest struct { - *Request // original request, not to be mutated - extra Header // extra headers to write, or nil + *Request // original request, not to be mutated + extra Header // extra headers to write, or nil + trace *httptrace.ClientTrace // optional } func (tr *transportRequest) extraHeaders() Header { @@ -262,6 +305,9 @@ func (tr *transportRequest) extraHeaders() Header { // and redirects), see Get, Post, and the Client type. 
func (t *Transport) RoundTrip(req *Request) (*Response, error) { t.nextProtoOnce.Do(t.onceSetNextProtoDefaults) + ctx := req.Context() + trace := httptrace.ContextClientTrace(ctx) + if req.URL == nil { req.closeBody() return nil, errors.New("http: nil Request.URL") @@ -270,18 +316,32 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { req.closeBody() return nil, errors.New("http: nil Request.Header") } + scheme := req.URL.Scheme + isHTTP := scheme == "http" || scheme == "https" + if isHTTP { + for k, vv := range req.Header { + if !httplex.ValidHeaderFieldName(k) { + return nil, fmt.Errorf("net/http: invalid header field name %q", k) + } + for _, v := range vv { + if !httplex.ValidHeaderFieldValue(v) { + return nil, fmt.Errorf("net/http: invalid header field value %q for key %v", v, k) + } + } + } + } // TODO(bradfitz): switch to atomic.Value for this map instead of RWMutex t.altMu.RLock() - altRT := t.altProto[req.URL.Scheme] + altRT := t.altProto[scheme] t.altMu.RUnlock() if altRT != nil { if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol { return resp, err } } - if s := req.URL.Scheme; s != "http" && s != "https" { + if !isHTTP { req.closeBody() - return nil, &badStringError{"unsupported protocol scheme", s} + return nil, &badStringError{"unsupported protocol scheme", scheme} } if req.Method != "" && !validMethod(req.Method) { return nil, fmt.Errorf("net/http: invalid method %q", req.Method) @@ -293,7 +353,7 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { for { // treq gets modified by roundTrip, so we need to recreate for each retry. 
- treq := &transportRequest{Request: req} + treq := &transportRequest{Request: req, trace: trace} cm, err := t.connectMethodForRequest(treq) if err != nil { req.closeBody() @@ -302,9 +362,9 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { // Get the cached or newly-created connection to either the // host (for http or https), the http proxy, or the http proxy - // pre-CONNECTed to https server. In any case, we'll be ready + // pre-CONNECTed to https server. In any case, we'll be ready // to send it requests. - pconn, err := t.getConn(req, cm) + pconn, err := t.getConn(treq, cm) if err != nil { t.setReqCanceler(req, nil) req.closeBody() @@ -322,46 +382,47 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { if err == nil { return resp, nil } - if err := checkTransportResend(err, req, pconn); err != nil { + if !pconn.shouldRetryRequest(req, err) { return nil, err } testHookRoundTripRetried() } } -// checkTransportResend checks whether a failed HTTP request can be -// resent on a new connection. The non-nil input error is the error from -// roundTrip, which might be wrapped in a beforeRespHeaderError error. -// -// The return value is err or the unwrapped error inside a -// beforeRespHeaderError. -func checkTransportResend(err error, req *Request, pconn *persistConn) error { - brhErr, ok := err.(beforeRespHeaderError) - if !ok { - return err +// shouldRetryRequest reports whether we should retry sending a failed +// HTTP request on a new connection. The non-nil input error is the +// error from roundTrip. +func (pc *persistConn) shouldRetryRequest(req *Request, err error) bool { + if err == errMissingHost { + // User error. 
+ return false } - err = brhErr.error // unwrap the custom error in case we return it - if err != errMissingHost && pconn.isReused() && req.isReplayable() { - // If we try to reuse a connection that the server is in the process of - // closing, we may end up successfully writing out our request (or a - // portion of our request) only to find a connection error when we try to - // read from (or finish writing to) the socket. - - // There can be a race between the socket pool checking whether a socket - // is still connected, receiving the FIN, and sending/reading data on a - // reused socket. If we receive the FIN between the connectedness check - // and writing/reading from the socket, we may first learn the socket is - // disconnected when we get a ERR_SOCKET_NOT_CONNECTED. This will most - // likely happen when trying to retrieve its IP address. See - // http://crbug.com/105824 for more details. - - // We resend a request only if we reused a keep-alive connection and did - // not yet receive any header data. This automatically prevents an - // infinite resend loop because we'll run out of the cached keep-alive - // connections eventually. - return nil + if !pc.isReused() { + // This was a fresh connection. There's no reason the server + // should've hung up on us. + // + // Also, if we retried now, we could loop forever + // creating new connections and retrying if the server + // is just hanging up on us because it doesn't like + // our request (as opposed to sending an error). + return false } - return err + if !req.isReplayable() { + // Don't retry non-idempotent requests. + + // TODO: swap the nothingWrittenError and isReplayable checks, + // putting the "if nothingWrittenError => return true" case + // first, per golang.org/issue/15723 + return false + } + if _, ok := err.(nothingWrittenError); ok { + // We never wrote anything, so it's safe to retry. 
+ return true + } + if err == errServerClosedIdle || err == errServerClosedConn { + return true + } + return false // conservatively } // ErrSkipAltProtocol is a sentinel error value defined by Transport.RegisterProtocol. @@ -400,6 +461,7 @@ func (t *Transport) CloseIdleConnections() { t.idleConn = nil t.idleConnCh = nil t.wantIdle = true + t.idleLRU = connLRU{} t.idleMu.Unlock() for _, conns := range m { for _, pconn := range conns { @@ -414,7 +476,7 @@ func (t *Transport) CloseIdleConnections() { // CancelRequest cancels an in-flight request by closing its connection. // CancelRequest should only be called after RoundTrip has returned. // -// Deprecated: Use Request.Cancel instead. CancelRequest can not cancel +// Deprecated: Use Request.Cancel instead. CancelRequest cannot cancel // HTTP/2 requests. func (t *Transport) CancelRequest(req *Request) { t.reqMu.Lock() @@ -500,9 +562,12 @@ var ( errConnBroken = errors.New("http: putIdleConn: connection is in bad state") errWantIdle = errors.New("http: putIdleConn: CloseIdleConnections was called") errTooManyIdle = errors.New("http: putIdleConn: too many idle connections") + errTooManyIdleHost = errors.New("http: putIdleConn: too many idle connections for host") errCloseIdleConns = errors.New("http: CloseIdleConnections called") errReadLoopExiting = errors.New("http: persistConn.readLoop exiting") - errServerClosedIdle = errors.New("http: server closed idle conn") + errServerClosedIdle = errors.New("http: server closed idle connection") + errServerClosedConn = errors.New("http: server closed connection") + errIdleConnTimeout = errors.New("http: idle connection timeout") ) func (t *Transport) putOrCloseIdleConn(pconn *persistConn) { @@ -511,6 +576,13 @@ func (t *Transport) putOrCloseIdleConn(pconn *persistConn) { } } +func (t *Transport) maxIdleConnsPerHost() int { + if v := t.MaxIdleConnsPerHost; v != 0 { + return v + } + return DefaultMaxIdleConnsPerHost +} + // tryPutIdleConn adds pconn to the list of idle persistent 
connections awaiting // a new request. // If pconn is no longer needed or not in a good state, tryPutIdleConn returns @@ -523,13 +595,11 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error { if pconn.isBroken() { return errConnBroken } - key := pconn.cacheKey - max := t.MaxIdleConnsPerHost - if max == 0 { - max = DefaultMaxIdleConnsPerHost - } pconn.markReused() + key := pconn.cacheKey + t.idleMu.Lock() + defer t.idleMu.Unlock() waitingDialer := t.idleConnCh[key] select { @@ -537,9 +607,8 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error { // We're done with this pconn and somebody else is // currently waiting for a conn of this type (they're // actively dialing, but this conn is ready - // first). Chrome calls this socket late binding. See + // first). Chrome calls this socket late binding. See // https://insouciant.org/tech/connection-management-in-chromium/ - t.idleMu.Unlock() return nil default: if waitingDialer != nil { @@ -549,23 +618,35 @@ func (t *Transport) tryPutIdleConn(pconn *persistConn) error { } } if t.wantIdle { - t.idleMu.Unlock() return errWantIdle } if t.idleConn == nil { t.idleConn = make(map[connectMethodKey][]*persistConn) } - if len(t.idleConn[key]) >= max { - t.idleMu.Unlock() - return errTooManyIdle + idles := t.idleConn[key] + if len(idles) >= t.maxIdleConnsPerHost() { + return errTooManyIdleHost } - for _, exist := range t.idleConn[key] { + for _, exist := range idles { if exist == pconn { log.Fatalf("dup idle pconn %p in freelist", pconn) } } - t.idleConn[key] = append(t.idleConn[key], pconn) - t.idleMu.Unlock() + t.idleConn[key] = append(idles, pconn) + t.idleLRU.add(pconn) + if t.MaxIdleConns != 0 && t.idleLRU.len() > t.MaxIdleConns { + oldest := t.idleLRU.removeOldest() + oldest.close(errTooManyIdle) + t.removeIdleConnLocked(oldest) + } + if t.IdleConnTimeout > 0 { + if pconn.idleTimer != nil { + pconn.idleTimer.Reset(t.IdleConnTimeout) + } else { + pconn.idleTimer = time.AfterFunc(t.IdleConnTimeout, 
pconn.closeConnIfStillIdle) + } + } + pconn.idleAt = time.Now() return nil } @@ -591,29 +672,75 @@ func (t *Transport) getIdleConnCh(cm connectMethod) chan *persistConn { return ch } -func (t *Transport) getIdleConn(cm connectMethod) (pconn *persistConn) { +func (t *Transport) getIdleConn(cm connectMethod) (pconn *persistConn, idleSince time.Time) { key := cm.key() t.idleMu.Lock() defer t.idleMu.Unlock() - if t.idleConn == nil { - return nil - } for { pconns, ok := t.idleConn[key] if !ok { - return nil + return nil, time.Time{} } if len(pconns) == 1 { pconn = pconns[0] delete(t.idleConn, key) } else { - // 2 or more cached connections; pop last - // TODO: queue? + // 2 or more cached connections; use the most + // recently used one at the end. pconn = pconns[len(pconns)-1] t.idleConn[key] = pconns[:len(pconns)-1] } - if !pconn.isBroken() { - return + t.idleLRU.remove(pconn) + if pconn.isBroken() { + // There is a tiny window where this is + // possible, between the connecting dying and + // the persistConn readLoop calling + // Transport.removeIdleConn. Just skip it and + // carry on. + continue + } + if pconn.idleTimer != nil && !pconn.idleTimer.Stop() { + // We picked this conn at the ~same time it + // was expiring and it's trying to close + // itself in another goroutine. Don't use it. + continue + } + return pconn, pconn.idleAt + } +} + +// removeIdleConn marks pconn as dead. +func (t *Transport) removeIdleConn(pconn *persistConn) { + t.idleMu.Lock() + defer t.idleMu.Unlock() + t.removeIdleConnLocked(pconn) +} + +// t.idleMu must be held. 
+func (t *Transport) removeIdleConnLocked(pconn *persistConn) { + if pconn.idleTimer != nil { + pconn.idleTimer.Stop() + } + t.idleLRU.remove(pconn) + key := pconn.cacheKey + pconns, _ := t.idleConn[key] + switch len(pconns) { + case 0: + // Nothing + case 1: + if pconns[0] == pconn { + delete(t.idleConn, key) + } + default: + for i, v := range pconns { + if v != pconn { + continue + } + // Slide down, keeping most recently-used + // conns at the end. + copy(pconns[i:], pconns[i+1:]) + t.idleConn[key] = pconns[:len(pconns)-1] + break } } } @@ -650,7 +777,12 @@ func (t *Transport) replaceReqCanceler(r *Request, fn func()) bool { return true } -func (t *Transport) dial(network, addr string) (net.Conn, error) { +var zeroDialer net.Dialer + +func (t *Transport) dial(ctx context.Context, network, addr string) (net.Conn, error) { + if t.DialContext != nil { + return t.DialContext(ctx, network, addr) + } if t.Dial != nil { c, err := t.Dial(network, addr) if c == nil && err == nil { @@ -658,15 +790,24 @@ func (t *Transport) dial(network, addr string) (net.Conn, error) { } return c, err } - return net.Dial(network, addr) + return zeroDialer.DialContext(ctx, network, addr) } // getConn dials and creates a new persistConn to the target as -// specified in the connectMethod. This includes doing a proxy CONNECT +// specified in the connectMethod. This includes doing a proxy CONNECT // and/or setting up TLS. If this doesn't return an error, the persistConn // is ready to write requests to. 
-func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error) { - if pc := t.getIdleConn(cm); pc != nil { +func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistConn, error) { + req := treq.Request + trace := treq.trace + ctx := req.Context() + if trace != nil && trace.GetConn != nil { + trace.GetConn(cm.addr()) + } + if pc, idleSince := t.getIdleConn(cm); pc != nil { + if trace != nil && trace.GotConn != nil { + trace.GotConn(pc.gotIdleConnTrace(idleSince)) + } // set request canceler to some non-nil function so we // can detect whether it was cleared between now and when // we enter roundTrip @@ -699,7 +840,7 @@ func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error t.setReqCanceler(req, func() { close(cancelc) }) go func() { - pc, err := t.dialConn(cm) + pc, err := t.dialConn(ctx, cm) dialc <- dialRes{pc, err} }() @@ -707,7 +848,26 @@ func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error select { case v := <-dialc: // Our dial finished. - return v.pc, v.err + if v.pc != nil { + if trace != nil && trace.GotConn != nil && v.pc.alt == nil { + trace.GotConn(httptrace.GotConnInfo{Conn: v.pc.conn}) + } + return v.pc, nil + } + // Our dial failed. See why to return a nicer error + // value. + select { + case <-req.Cancel: + case <-req.Context().Done(): + case <-cancelc: + default: + // It wasn't an error due to cancelation, so + // return the original error message: + return nil, v.err + } + // It was an error due to cancelation, so prioritize that + // error value. (Issue 16049) + return nil, errRequestCanceledConn case pc := <-idleConnCh: // Another request finished first and its net.Conn // became available before our dial. 
Or somebody @@ -715,24 +875,31 @@ func (t *Transport) getConn(req *Request, cm connectMethod) (*persistConn, error // But our dial is still going, so give it away // when it finishes: handlePendingDial() + if trace != nil && trace.GotConn != nil { + trace.GotConn(httptrace.GotConnInfo{Conn: pc.conn, Reused: pc.isReused()}) + } return pc, nil case <-req.Cancel: handlePendingDial() return nil, errRequestCanceledConn + case <-req.Context().Done(): + handlePendingDial() + return nil, errRequestCanceledConn case <-cancelc: handlePendingDial() return nil, errRequestCanceledConn } } -func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) { +func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistConn, error) { pconn := &persistConn{ - t: t, - cacheKey: cm.key(), - reqch: make(chan requestAndChan, 1), - writech: make(chan writeRequest, 1), - closech: make(chan struct{}), - writeErrCh: make(chan error, 1), + t: t, + cacheKey: cm.key(), + reqch: make(chan requestAndChan, 1), + writech: make(chan writeRequest, 1), + closech: make(chan struct{}), + writeErrCh: make(chan error, 1), + writeLoopDone: make(chan struct{}), } tlsDial := t.DialTLS != nil && cm.targetScheme == "https" && cm.proxyURL == nil if tlsDial { @@ -755,7 +922,7 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) { pconn.tlsState = &cs } } else { - conn, err := t.dial("tcp", cm.addr()) + conn, err := t.dial(ctx, "tcp", cm.addr()) if err != nil { if cm.proxyURL != nil { err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err) @@ -848,13 +1015,29 @@ func (t *Transport) dialConn(cm connectMethod) (*persistConn, error) { } } - pconn.br = bufio.NewReader(noteEOFReader{pconn.conn, &pconn.sawEOF}) - pconn.bw = bufio.NewWriter(pconn.conn) + pconn.br = bufio.NewReader(pconn) + pconn.bw = bufio.NewWriter(persistConnWriter{pconn}) go pconn.readLoop() go pconn.writeLoop() return pconn, nil } +// persistConnWriter is the io.Writer written to by 
pc.bw. +// It accumulates the number of bytes written to the underlying conn, +// so the retry logic can determine whether any bytes made it across +// the wire. +// This is exactly 1 pointer field wide so it can go into an interface +// without allocation. +type persistConnWriter struct { + pc *persistConn +} + +func (w persistConnWriter) Write(p []byte) (n int, err error) { + n, err = w.pc.conn.Write(p) + w.pc.nwrite += int64(n) + return +} + // useProxy reports whether requests to addr should use a proxy, // according to the NO_PROXY or no_proxy environment variable. // addr is always a canonicalAddr with a host and port. @@ -978,28 +1161,36 @@ func (k connectMethodKey) String() string { // (but may be used for non-keep-alive requests as well) type persistConn struct { // alt optionally specifies the TLS NextProto RoundTripper. - // This is used for HTTP/2 today and future protocol laters. + // This is used for HTTP/2 today and future protocols later. // If it's non-nil, the rest of the fields are unused. 
alt RoundTripper - t *Transport - cacheKey connectMethodKey - conn net.Conn - tlsState *tls.ConnectionState - br *bufio.Reader // from conn - sawEOF bool // whether we've seen EOF from conn; owned by readLoop - bw *bufio.Writer // to conn - reqch chan requestAndChan // written by roundTrip; read by readLoop - writech chan writeRequest // written by roundTrip; read by writeLoop - closech chan struct{} // closed when conn closed - isProxy bool + t *Transport + cacheKey connectMethodKey + conn net.Conn + tlsState *tls.ConnectionState + br *bufio.Reader // from conn + bw *bufio.Writer // to conn + nwrite int64 // bytes written + reqch chan requestAndChan // written by roundTrip; read by readLoop + writech chan writeRequest // written by roundTrip; read by writeLoop + closech chan struct{} // closed when conn closed + isProxy bool + sawEOF bool // whether we've seen EOF from conn; owned by readLoop + readLimit int64 // bytes allowed to be read; owned by readLoop // writeErrCh passes the request write error (usually nil) // from the writeLoop goroutine to the readLoop which passes // it off to the res.Body reader, which then uses it to decide // whether or not a connection can be reused. Issue 7569. writeErrCh chan error - lk sync.Mutex // guards following fields + writeLoopDone chan struct{} // closed when write loop ends + + // Both guarded by Transport.idleMu: + idleAt time.Time // time it last become idle + idleTimer *time.Timer // holding an AfterFunc to close it + + mu sync.Mutex // guards following fields numExpectedResponses int closed error // set non-nil when conn is closed, before closech is closed broken bool // an error has happened on this connection; marked broken so it's not reused. 
@@ -1011,45 +1202,153 @@ type persistConn struct { mutateHeaderFunc func(Header) } +func (pc *persistConn) maxHeaderResponseSize() int64 { + if v := pc.t.MaxResponseHeaderBytes; v != 0 { + return v + } + return 10 << 20 // conservative default; same as http2 +} + +func (pc *persistConn) Read(p []byte) (n int, err error) { + if pc.readLimit <= 0 { + return 0, fmt.Errorf("read limit of %d bytes exhausted", pc.maxHeaderResponseSize()) + } + if int64(len(p)) > pc.readLimit { + p = p[:pc.readLimit] + } + n, err = pc.conn.Read(p) + if err == io.EOF { + pc.sawEOF = true + } + pc.readLimit -= int64(n) + return +} + // isBroken reports whether this connection is in a known broken state. func (pc *persistConn) isBroken() bool { - pc.lk.Lock() - b := pc.broken - pc.lk.Unlock() + pc.mu.Lock() + b := pc.closed != nil + pc.mu.Unlock() return b } // isCanceled reports whether this connection was closed due to CancelRequest. func (pc *persistConn) isCanceled() bool { - pc.lk.Lock() - defer pc.lk.Unlock() + pc.mu.Lock() + defer pc.mu.Unlock() return pc.canceled } // isReused reports whether this connection is in a known broken state. func (pc *persistConn) isReused() bool { - pc.lk.Lock() + pc.mu.Lock() r := pc.reused - pc.lk.Unlock() + pc.mu.Unlock() return r } +func (pc *persistConn) gotIdleConnTrace(idleAt time.Time) (t httptrace.GotConnInfo) { + pc.mu.Lock() + defer pc.mu.Unlock() + t.Reused = pc.reused + t.Conn = pc.conn + t.WasIdle = true + if !idleAt.IsZero() { + t.IdleTime = time.Since(idleAt) + } + return +} + func (pc *persistConn) cancelRequest() { - pc.lk.Lock() - defer pc.lk.Unlock() + pc.mu.Lock() + defer pc.mu.Unlock() pc.canceled = true pc.closeLocked(errRequestCanceled) } +// closeConnIfStillIdle closes the connection if it's still sitting idle. +// This is what's called by the persistConn's idleTimer, and is run in its +// own goroutine. 
+func (pc *persistConn) closeConnIfStillIdle() { + t := pc.t + t.idleMu.Lock() + defer t.idleMu.Unlock() + if _, ok := t.idleLRU.m[pc]; !ok { + // Not idle. + return + } + t.removeIdleConnLocked(pc) + pc.close(errIdleConnTimeout) +} + +// mapRoundTripErrorFromReadLoop maps the provided readLoop error into +// the error value that should be returned from persistConn.roundTrip. +// +// The startBytesWritten value should be the value of pc.nwrite before the roundTrip +// started writing the request. +func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, err error) (out error) { + if err == nil { + return nil + } + if pc.isCanceled() { + return errRequestCanceled + } + if err == errServerClosedIdle || err == errServerClosedConn { + return err + } + if pc.isBroken() { + <-pc.writeLoopDone + if pc.nwrite == startBytesWritten { + return nothingWrittenError{err} + } + } + return err +} + +// mapRoundTripErrorAfterClosed returns the error value to be propagated +// up to Transport.RoundTrip method when persistConn.roundTrip sees +// its pc.closech channel close, indicating the persistConn is dead. +// (after closech is closed, pc.closed is valid). +func (pc *persistConn) mapRoundTripErrorAfterClosed(startBytesWritten int64) error { + if pc.isCanceled() { + return errRequestCanceled + } + err := pc.closed + if err == errServerClosedIdle || err == errServerClosedConn { + // Don't decorate + return err + } + + // Wait for the writeLoop goroutine to terminated, and then + // see if we actually managed to write anything. If not, we + // can retry the request. 
+ <-pc.writeLoopDone + if pc.nwrite == startBytesWritten { + return nothingWrittenError{err} + } + + return fmt.Errorf("net/http: HTTP/1.x transport connection broken: %v", err) + +} + func (pc *persistConn) readLoop() { closeErr := errReadLoopExiting // default value, if not changed below - defer func() { pc.close(closeErr) }() + defer func() { + pc.close(closeErr) + pc.t.removeIdleConn(pc) + }() - tryPutIdleConn := func() bool { + tryPutIdleConn := func(trace *httptrace.ClientTrace) bool { if err := pc.t.tryPutIdleConn(pc); err != nil { closeErr = err + if trace != nil && trace.PutIdleConn != nil && err != errKeepAlivesDisabled { + trace.PutIdleConn(err) + } return false } + if trace != nil && trace.PutIdleConn != nil { + trace.PutIdleConn(nil) + } return true } @@ -1066,27 +1365,33 @@ func (pc *persistConn) readLoop() { alive := true for alive { + pc.readLimit = pc.maxHeaderResponseSize() _, err := pc.br.Peek(1) - if err != nil { - err = beforeRespHeaderError{err} - } - pc.lk.Lock() + pc.mu.Lock() if pc.numExpectedResponses == 0 { pc.readLoopPeekFailLocked(err) - pc.lk.Unlock() + pc.mu.Unlock() return } - pc.lk.Unlock() + pc.mu.Unlock() rc := <-pc.reqch + trace := httptrace.ContextClientTrace(rc.req.Context()) var resp *Response if err == nil { - resp, err = pc.readResponse(rc) + resp, err = pc.readResponse(rc, trace) + } else { + err = errServerClosedConn + closeErr = err } if err != nil { + if pc.readLimit <= 0 { + err = fmt.Errorf("net/http: server response headers exceeded %d bytes; aborted", pc.maxHeaderResponseSize()) + } + // If we won't be able to retry this request later (from the // roundTrip goroutine), mark it as done now. // BEFORE the send on rc.ch, as the client might re-use the @@ -1094,7 +1399,7 @@ func (pc *persistConn) readLoop() { // t.setReqCanceler from this persistConn while the Transport // potentially spins up a different persistConn for the // caller's subsequent request. 
- if checkTransportResend(err, rc.req, pc) != nil { + if !pc.shouldRetryRequest(rc.req, err) { pc.t.setReqCanceler(rc.req, nil) } select { @@ -1104,10 +1409,11 @@ func (pc *persistConn) readLoop() { } return } + pc.readLimit = maxInt64 // effictively no limit for response bodies - pc.lk.Lock() + pc.mu.Lock() pc.numExpectedResponses-- - pc.lk.Unlock() + pc.mu.Unlock() hasBody := rc.req.Method != "HEAD" && resp.ContentLength != 0 @@ -1130,7 +1436,7 @@ func (pc *persistConn) readLoop() { alive = alive && !pc.sawEOF && pc.wroteRequest() && - tryPutIdleConn() + tryPutIdleConn(trace) select { case rc.ch <- responseAndError{res: resp}: @@ -1145,25 +1451,33 @@ func (pc *persistConn) readLoop() { continue } - if rc.addedGzip { - maybeUngzipResponse(resp) + waitForBodyRead := make(chan bool, 2) + body := &bodyEOFSignal{ + body: resp.Body, + earlyCloseFn: func() error { + waitForBodyRead <- false + return nil + + }, + fn: func(err error) error { + isEOF := err == io.EOF + waitForBodyRead <- isEOF + if isEOF { + <-eofc // see comment above eofc declaration + } else if err != nil && pc.isCanceled() { + return errRequestCanceled + } + return err + }, } - resp.Body = &bodyEOFSignal{body: resp.Body} - waitForBodyRead := make(chan bool, 2) - resp.Body.(*bodyEOFSignal).earlyCloseFn = func() error { - waitForBodyRead <- false - return nil - } - resp.Body.(*bodyEOFSignal).fn = func(err error) error { - isEOF := err == io.EOF - waitForBodyRead <- isEOF - if isEOF { - <-eofc // see comment above eofc declaration - } else if err != nil && pc.isCanceled() { - return errRequestCanceled - } - return err + resp.Body = body + if rc.addedGzip && resp.Header.Get("Content-Encoding") == "gzip" { + resp.Body = &gzipReader{body: body} + resp.Header.Del("Content-Encoding") + resp.Header.Del("Content-Length") + resp.ContentLength = -1 + resp.Uncompressed = true } select { @@ -1182,13 +1496,16 @@ func (pc *persistConn) readLoop() { bodyEOF && !pc.sawEOF && pc.wroteRequest() && - tryPutIdleConn() + 
tryPutIdleConn(trace) if bodyEOF { eofc <- struct{}{} } case <-rc.req.Cancel: alive = false pc.t.CancelRequest(rc.req) + case <-rc.req.Context().Done(): + alive = false + pc.t.CancelRequest(rc.req) case <-pc.closech: alive = false } @@ -1197,15 +1514,6 @@ func (pc *persistConn) readLoop() { } } -func maybeUngzipResponse(resp *Response) { - if resp.Header.Get("Content-Encoding") == "gzip" { - resp.Header.Del("Content-Encoding") - resp.Header.Del("Content-Length") - resp.ContentLength = -1 - resp.Body = &gzipReader{body: resp.Body} - } -} - func (pc *persistConn) readLoopPeekFailLocked(peekErr error) { if pc.closed != nil { return @@ -1224,19 +1532,29 @@ func (pc *persistConn) readLoopPeekFailLocked(peekErr error) { // readResponse reads an HTTP response (or two, in the case of "Expect: // 100-continue") from the server. It returns the final non-100 one. -func (pc *persistConn) readResponse(rc requestAndChan) (resp *Response, err error) { +// trace is optional. +func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTrace) (resp *Response, err error) { + if trace != nil && trace.GotFirstResponseByte != nil { + if peek, err := pc.br.Peek(1); err == nil && len(peek) == 1 { + trace.GotFirstResponseByte() + } + } resp, err = ReadResponse(pc.br, rc.req) if err != nil { return } if rc.continueCh != nil { if resp.StatusCode == 100 { + if trace != nil && trace.Got100Continue != nil { + trace.Got100Continue() + } rc.continueCh <- struct{}{} } else { close(rc.continueCh) } } if resp.StatusCode == 100 { + pc.readLimit = pc.maxHeaderResponseSize() // reset the limit resp, err = ReadResponse(pc.br, rc.req) if err != nil { return @@ -1268,24 +1586,33 @@ func (pc *persistConn) waitForContinue(continueCh <-chan struct{}) func() bool { } } +// nothingWrittenError wraps a write errors which ended up writing zero bytes. 
+type nothingWrittenError struct { + error +} + func (pc *persistConn) writeLoop() { + defer close(pc.writeLoopDone) for { select { case wr := <-pc.writech: - if pc.isBroken() { - wr.ch <- errors.New("http: can't write HTTP request on broken connection") - continue - } + startBytesWritten := pc.nwrite err := wr.req.Request.write(pc.bw, pc.isProxy, wr.req.extra, pc.waitForContinue(wr.continueCh)) if err == nil { err = pc.bw.Flush() } if err != nil { - pc.markBroken() wr.req.Request.closeBody() + if pc.nwrite == startBytesWritten { + err = nothingWrittenError{err} + } } pc.writeErrCh <- err // to the body reader, which might recycle us wr.ch <- err // to the roundTrip function + if err != nil { + pc.close(err) + return + } case <-pc.closech: return } @@ -1331,9 +1658,9 @@ type requestAndChan struct { req *Request ch chan responseAndError // unbuffered; always send in select on callerGone - // did the Transport (as opposed to the client code) add an - // Accept-Encoding gzip header? only if it we set it do - // we transparently decode the gzip. + // whether the Transport (as opposed to the user client code) + // added the Accept-Encoding gzip header. If the Transport + // set it, only then do we transparently decode the gzip. addedGzip bool // Optional blocking chan for Expect: 100-continue (for send). @@ -1353,7 +1680,7 @@ type writeRequest struct { req *transportRequest ch chan<- error - // Optional blocking chan for Expect: 100-continue (for recieve). + // Optional blocking chan for Expect: 100-continue (for receive). // If not nil, writeLoop blocks sending request body until // it receives from this chan. 
continueCh <-chan struct{} @@ -1369,7 +1696,6 @@ func (e *httpError) Timeout() bool { return e.timeout } func (e *httpError) Temporary() bool { return true } var errTimeout error = &httpError{err: "net/http: timeout awaiting response headers", timeout: true} -var errClosed error = &httpError{err: "net/http: server closed connection before response was received"} var errRequestCanceled = errors.New("net/http: request canceled") var errRequestCanceledConn = errors.New("net/http: request canceled while waiting for connection") // TODO: unify? @@ -1387,22 +1713,16 @@ var ( testHookReadLoopBeforeNextRead = nop ) -// beforeRespHeaderError is used to indicate when an IO error has occurred before -// any header data was received. -type beforeRespHeaderError struct { - error -} - func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err error) { testHookEnterRoundTrip() if !pc.t.replaceReqCanceler(req.Request, pc.cancelRequest) { pc.t.putOrCloseIdleConn(pc) return nil, errRequestCanceled } - pc.lk.Lock() + pc.mu.Lock() pc.numExpectedResponses++ headerFn := pc.mutateHeaderFunc - pc.lk.Unlock() + pc.mu.Unlock() if headerFn != nil { headerFn(req.extraHeaders()) @@ -1448,6 +1768,7 @@ func (pc *persistConn) roundTrip(req *transportRequest) (resp *Response, err err // Write the request concurrently with waiting for a response, // in case the server decides to reply before reading our full // request body. 
+ startBytesWritten := pc.nwrite writeErrCh := make(chan error, 1) pc.writech <- writeRequest{req, writeErrCh, continueCh} @@ -1472,7 +1793,7 @@ WaitResponse: if pc.isCanceled() { err = errRequestCanceled } - re = responseAndError{err: beforeRespHeaderError{err}} + re = responseAndError{err: err} pc.close(fmt.Errorf("write error: %v", err)) break WaitResponse } @@ -1482,26 +1803,21 @@ WaitResponse: respHeaderTimer = timer.C } case <-pc.closech: - var err error - if pc.isCanceled() { - err = errRequestCanceled - } else { - err = beforeRespHeaderError{fmt.Errorf("net/http: HTTP/1 transport connection broken: %v", pc.closed)} - } - re = responseAndError{err: err} + re = responseAndError{err: pc.mapRoundTripErrorAfterClosed(startBytesWritten)} break WaitResponse case <-respHeaderTimer: pc.close(errTimeout) re = responseAndError{err: errTimeout} break WaitResponse case re = <-resc: - if re.err != nil && pc.isCanceled() { - re.err = errRequestCanceled - } + re.err = pc.mapRoundTripErrorFromReadLoop(startBytesWritten, re.err) break WaitResponse case <-cancelChan: pc.t.CancelRequest(req.Request) cancelChan = nil + case <-req.Context().Done(): + pc.t.CancelRequest(req.Request) + cancelChan = nil } } @@ -1514,21 +1830,12 @@ WaitResponse: return re.res, re.err } -// markBroken marks a connection as broken (so it's not reused). -// It differs from close in that it doesn't close the underlying -// connection for use when it's still being read. -func (pc *persistConn) markBroken() { - pc.lk.Lock() - defer pc.lk.Unlock() - pc.broken = true -} - // markReused marks this connection as having been successfully used for a // request and response. 
func (pc *persistConn) markReused() { - pc.lk.Lock() + pc.mu.Lock() pc.reused = true - pc.lk.Unlock() + pc.mu.Unlock() } // close closes the underlying TCP connection and closes @@ -1537,8 +1844,8 @@ func (pc *persistConn) markReused() { // The provided err is only for testing and debugging; in normal // circumstances it should never be seen by users. func (pc *persistConn) close(err error) { - pc.lk.Lock() - defer pc.lk.Unlock() + pc.mu.Lock() + defer pc.mu.Unlock() pc.closeLocked(err) } @@ -1554,7 +1861,7 @@ func (pc *persistConn) closeLocked(err error) { // handlePendingDial's putOrCloseIdleConn when // it turns out the abandoned connection in // flight ended up negotiating an alternate - // protocol. We don't use the connection + // protocol. We don't use the connection // freelist for http2. That's done by the // alternate protocol's RoundTripper. } else { @@ -1579,7 +1886,11 @@ func canonicalAddr(url *url.URL) string { return addr } -// bodyEOFSignal wraps a ReadCloser but runs fn (if non-nil) at most +// bodyEOFSignal is used by the HTTP/1 transport when reading response +// bodies to make sure we see the end of a response body before +// proceeding and reading on the connection again. +// +// It wraps a ReadCloser but runs fn (if non-nil) at most // once, right before its final (error-producing) Read or Close call // returns. fn should return the new error to return from Read or Close. 
// @@ -1595,12 +1906,14 @@ type bodyEOFSignal struct { earlyCloseFn func() error // optional alt Close func used if io.EOF not seen } +var errReadOnClosedResBody = errors.New("http: read on closed response body") + func (es *bodyEOFSignal) Read(p []byte) (n int, err error) { es.mu.Lock() closed, rerr := es.closed, es.rerr es.mu.Unlock() if closed { - return 0, errors.New("http: read on closed response body") + return 0, errReadOnClosedResBody } if rerr != nil { return 0, rerr @@ -1645,16 +1958,29 @@ func (es *bodyEOFSignal) condfn(err error) error { // gzipReader wraps a response body so it can lazily // call gzip.NewReader on the first call to Read type gzipReader struct { - body io.ReadCloser // underlying Response.Body - zr io.Reader // lazily-initialized gzip reader + body *bodyEOFSignal // underlying HTTP/1 response body framing + zr *gzip.Reader // lazily-initialized gzip reader + zerr error // any error from gzip.NewReader; sticky } func (gz *gzipReader) Read(p []byte) (n int, err error) { if gz.zr == nil { - gz.zr, err = gzip.NewReader(gz.body) - if err != nil { - return 0, err + if gz.zerr == nil { + gz.zr, gz.zerr = gzip.NewReader(gz.body) } + if gz.zerr != nil { + return 0, gz.zerr + } + } + + gz.body.mu.Lock() + if gz.body.closed { + err = errReadOnClosedResBody + } + gz.body.mu.Unlock() + + if err != nil { + return 0, err } return gz.zr.Read(p) } @@ -1674,19 +2000,6 @@ func (tlsHandshakeTimeoutError) Timeout() bool { return true } func (tlsHandshakeTimeoutError) Temporary() bool { return true } func (tlsHandshakeTimeoutError) Error() string { return "net/http: TLS handshake timeout" } -type noteEOFReader struct { - r io.Reader - sawEOF *bool -} - -func (nr noteEOFReader) Read(p []byte) (n int, err error) { - n, err = nr.r.Read(p) - if err == io.EOF { - *nr.sawEOF = true - } - return -} - // fakeLocker is a sync.Locker which does nothing. It's used to guard // test-only fields when not under test, to avoid runtime atomic // overhead. 
@@ -1695,17 +2008,6 @@ type fakeLocker struct{} func (fakeLocker) Lock() {} func (fakeLocker) Unlock() {} -func isNetWriteError(err error) bool { - switch e := err.(type) { - case *url.Error: - return isNetWriteError(e.Err) - case *net.OpError: - return e.Op == "write" - default: - return false - } -} - // cloneTLSConfig returns a shallow clone of the exported // fields of cfg, ignoring the unexported sync.Once, which // contains a mutex and must not be copied. @@ -1722,25 +2024,27 @@ func cloneTLSConfig(cfg *tls.Config) *tls.Config { return &tls.Config{} } return &tls.Config{ - Rand: cfg.Rand, - Time: cfg.Time, - Certificates: cfg.Certificates, - NameToCertificate: cfg.NameToCertificate, - GetCertificate: cfg.GetCertificate, - RootCAs: cfg.RootCAs, - NextProtos: cfg.NextProtos, - ServerName: cfg.ServerName, - ClientAuth: cfg.ClientAuth, - ClientCAs: cfg.ClientCAs, - InsecureSkipVerify: cfg.InsecureSkipVerify, - CipherSuites: cfg.CipherSuites, - PreferServerCipherSuites: cfg.PreferServerCipherSuites, - SessionTicketsDisabled: cfg.SessionTicketsDisabled, - SessionTicketKey: cfg.SessionTicketKey, - ClientSessionCache: cfg.ClientSessionCache, - MinVersion: cfg.MinVersion, - MaxVersion: cfg.MaxVersion, - CurvePreferences: cfg.CurvePreferences, + Rand: cfg.Rand, + Time: cfg.Time, + Certificates: cfg.Certificates, + NameToCertificate: cfg.NameToCertificate, + GetCertificate: cfg.GetCertificate, + RootCAs: cfg.RootCAs, + NextProtos: cfg.NextProtos, + ServerName: cfg.ServerName, + ClientAuth: cfg.ClientAuth, + ClientCAs: cfg.ClientCAs, + InsecureSkipVerify: cfg.InsecureSkipVerify, + CipherSuites: cfg.CipherSuites, + PreferServerCipherSuites: cfg.PreferServerCipherSuites, + SessionTicketsDisabled: cfg.SessionTicketsDisabled, + SessionTicketKey: cfg.SessionTicketKey, + ClientSessionCache: cfg.ClientSessionCache, + MinVersion: cfg.MinVersion, + MaxVersion: cfg.MaxVersion, + CurvePreferences: cfg.CurvePreferences, + DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled, 
+ Renegotiation: cfg.Renegotiation, } } @@ -1753,22 +2057,63 @@ func cloneTLSClientConfig(cfg *tls.Config) *tls.Config { return &tls.Config{} } return &tls.Config{ - Rand: cfg.Rand, - Time: cfg.Time, - Certificates: cfg.Certificates, - NameToCertificate: cfg.NameToCertificate, - GetCertificate: cfg.GetCertificate, - RootCAs: cfg.RootCAs, - NextProtos: cfg.NextProtos, - ServerName: cfg.ServerName, - ClientAuth: cfg.ClientAuth, - ClientCAs: cfg.ClientCAs, - InsecureSkipVerify: cfg.InsecureSkipVerify, - CipherSuites: cfg.CipherSuites, - PreferServerCipherSuites: cfg.PreferServerCipherSuites, - ClientSessionCache: cfg.ClientSessionCache, - MinVersion: cfg.MinVersion, - MaxVersion: cfg.MaxVersion, - CurvePreferences: cfg.CurvePreferences, + Rand: cfg.Rand, + Time: cfg.Time, + Certificates: cfg.Certificates, + NameToCertificate: cfg.NameToCertificate, + GetCertificate: cfg.GetCertificate, + RootCAs: cfg.RootCAs, + NextProtos: cfg.NextProtos, + ServerName: cfg.ServerName, + ClientAuth: cfg.ClientAuth, + ClientCAs: cfg.ClientCAs, + InsecureSkipVerify: cfg.InsecureSkipVerify, + CipherSuites: cfg.CipherSuites, + PreferServerCipherSuites: cfg.PreferServerCipherSuites, + ClientSessionCache: cfg.ClientSessionCache, + MinVersion: cfg.MinVersion, + MaxVersion: cfg.MaxVersion, + CurvePreferences: cfg.CurvePreferences, + DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled, + Renegotiation: cfg.Renegotiation, } } + +type connLRU struct { + ll *list.List // list.Element.Value type is of *persistConn + m map[*persistConn]*list.Element +} + +// add adds pc to the head of the linked list. 
+func (cl *connLRU) add(pc *persistConn) { + if cl.ll == nil { + cl.ll = list.New() + cl.m = make(map[*persistConn]*list.Element) + } + ele := cl.ll.PushFront(pc) + if _, ok := cl.m[pc]; ok { + panic("persistConn was already in LRU") + } + cl.m[pc] = ele +} + +func (cl *connLRU) removeOldest() *persistConn { + ele := cl.ll.Back() + pc := ele.Value.(*persistConn) + cl.ll.Remove(ele) + delete(cl.m, pc) + return pc +} + +// remove removes pc from cl. +func (cl *connLRU) remove(pc *persistConn) { + if ele, ok := cl.m[pc]; ok { + cl.ll.Remove(ele) + delete(cl.m, pc) + } +} + +// len returns the number of items in the cache. +func (cl *connLRU) len() int { + return len(cl.m) +} diff --git a/libgo/go/net/http/transport_internal_test.go b/libgo/go/net/http/transport_internal_test.go new file mode 100644 index 00000000000..a157d906300 --- /dev/null +++ b/libgo/go/net/http/transport_internal_test.go @@ -0,0 +1,69 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// White-box tests for transport.go (in package http instead of http_test). + +package http + +import ( + "errors" + "net" + "testing" +) + +// Issue 15446: incorrect wrapping of errors when server closes an idle connection. +func TestTransportPersistConnReadLoopEOF(t *testing.T) { + ln := newLocalListener(t) + defer ln.Close() + + connc := make(chan net.Conn, 1) + go func() { + defer close(connc) + c, err := ln.Accept() + if err != nil { + t.Error(err) + return + } + connc <- c + }() + + tr := new(Transport) + req, _ := NewRequest("GET", "http://"+ln.Addr().String(), nil) + treq := &transportRequest{Request: req} + cm := connectMethod{targetScheme: "http", targetAddr: ln.Addr().String()} + pc, err := tr.getConn(treq, cm) + if err != nil { + t.Fatal(err) + } + defer pc.close(errors.New("test over")) + + conn := <-connc + if conn == nil { + // Already called t.Error in the accept goroutine. 
+ return + } + conn.Close() // simulate the server hanging up on the client + + _, err = pc.roundTrip(treq) + if err != errServerClosedConn && err != errServerClosedIdle { + t.Fatalf("roundTrip = %#v, %v; want errServerClosedConn or errServerClosedIdle", err, err) + } + + <-pc.closech + err = pc.closed + if err != errServerClosedConn && err != errServerClosedIdle { + t.Fatalf("pc.closed = %#v, %v; want errServerClosedConn or errServerClosedIdle", err, err) + } +} + +func newLocalListener(t *testing.T) net.Listener { + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + ln, err = net.Listen("tcp6", "[::1]:0") + } + if err != nil { + t.Fatal(err) + } + return ln +} diff --git a/libgo/go/net/http/transport_test.go b/libgo/go/net/http/transport_test.go index 0c901b30a44..72b98f16d7e 100644 --- a/libgo/go/net/http/transport_test.go +++ b/libgo/go/net/http/transport_test.go @@ -13,16 +13,20 @@ import ( "bufio" "bytes" "compress/gzip" + "context" "crypto/rand" "crypto/tls" "errors" "fmt" + "internal/nettrace" + "internal/testenv" "io" "io/ioutil" "log" "net" . "net/http" "net/http/httptest" + "net/http/httptrace" "net/http/httputil" "net/http/internal" "net/url" @@ -379,8 +383,8 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { } })) defer ts.Close() - maxIdleConns := 2 - tr := &Transport{DisableKeepAlives: false, MaxIdleConnsPerHost: maxIdleConns} + maxIdleConnsPerHost := 2 + tr := &Transport{DisableKeepAlives: false, MaxIdleConnsPerHost: maxIdleConnsPerHost} c := &Client{Transport: tr} // Start 3 outstanding requests and wait for the server to get them. 
@@ -425,14 +429,65 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { resch <- "res2" <-donech - if e, g := 2, tr.IdleConnCountForTesting(cacheKey); e != g { - t.Errorf("after second response, expected %d idle conns; got %d", e, g) + if g, w := tr.IdleConnCountForTesting(cacheKey), 2; g != w { + t.Errorf("after second response, idle conns = %d; want %d", g, w) } resch <- "res3" <-donech - if e, g := maxIdleConns, tr.IdleConnCountForTesting(cacheKey); e != g { - t.Errorf("after third response, still expected %d idle conns; got %d", e, g) + if g, w := tr.IdleConnCountForTesting(cacheKey), maxIdleConnsPerHost; g != w { + t.Errorf("after third response, idle conns = %d; want %d", g, w) + } +} + +func TestTransportRemovesDeadIdleConnections(t *testing.T) { + if runtime.GOOS == "plan9" { + t.Skip("skipping test; see https://golang.org/issue/15464") + } + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + io.WriteString(w, r.RemoteAddr) + })) + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + doReq := func(name string) string { + // Do a POST instead of a GET to prevent the Transport's + // idempotent request retry logic from kicking in... 
+ res, err := c.Post(ts.URL, "", nil) + if err != nil { + t.Fatalf("%s: %v", name, err) + } + if res.StatusCode != 200 { + t.Fatalf("%s: %v", name, res.Status) + } + defer res.Body.Close() + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("%s: %v", name, err) + } + return string(slurp) + } + + first := doReq("first") + keys1 := tr.IdleConnKeysForTesting() + + ts.CloseClientConnections() + + var keys2 []string + if !waitCondition(3*time.Second, 50*time.Millisecond, func() bool { + keys2 = tr.IdleConnKeysForTesting() + return len(keys2) == 0 + }) { + t.Fatalf("Transport didn't notice idle connection's death.\nbefore: %q\n after: %q\n", keys1, keys2) + } + + second := doReq("second") + if first == second { + t.Errorf("expected a different connection between requests. got %q both times", first) } } @@ -478,7 +533,7 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) { // This test has an expected race. Sleeping for 25 ms prevents // it on most fast machines, causing the next fetch() call to - // succeed quickly. But if we do get errors, fetch() will retry 5 + // succeed quickly. But if we do get errors, fetch() will retry 5 // times with some delays between. time.Sleep(25 * time.Millisecond) @@ -518,7 +573,7 @@ func TestStressSurpriseServerCloses(t *testing.T) { // after each request completes, regardless of whether it failed. // If these are too high, OS X exhausts its ephemeral ports // and hangs waiting for them to transition TCP states. That's - // not what we want to test. TODO(bradfitz): use an io.Pipe + // not what we want to test. TODO(bradfitz): use an io.Pipe // dialer for this test instead? const ( numClients = 20 @@ -853,7 +908,7 @@ func TestTransportExpect100Continue(t *testing.T) { {path: "/100", body: []byte("hello"), sent: 5, status: 200}, // Got 100 followed by 200, entire body is sent. {path: "/200", body: []byte("hello"), sent: 0, status: 200}, // Got 200 without 100. body isn't sent. 
{path: "/500", body: []byte("hello"), sent: 0, status: 500}, // Got 500 without 100. body isn't sent. - {path: "/keepalive", body: []byte("hello"), sent: 0, status: 500}, // Althogh without Connection:close, body isn't sent. + {path: "/keepalive", body: []byte("hello"), sent: 0, status: 500}, // Although without Connection:close, body isn't sent. {path: "/timeout", body: []byte("hello"), sent: 5, status: 200}, // Timeout exceeded and entire body is sent. } @@ -923,7 +978,9 @@ func TestTransportGzipRecursive(t *testing.T) { })) defer ts.Close() - c := &Client{Transport: &Transport{}} + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) @@ -968,6 +1025,17 @@ func TestTransportGzipShort(t *testing.T) { } } +// Wait until number of goroutines is no greater than nmax, or time out. +func waitNumGoroutine(nmax int) int { + nfinal := runtime.NumGoroutine() + for ntries := 10; ntries > 0 && nfinal > nmax; ntries-- { + time.Sleep(50 * time.Millisecond) + runtime.GC() + nfinal = runtime.NumGoroutine() + } + return nfinal +} + // tests that persistent goroutine connections shut down when no longer desired. func TestTransportPersistConnLeak(t *testing.T) { setParallel(t) @@ -1019,14 +1087,11 @@ func TestTransportPersistConnLeak(t *testing.T) { } tr.CloseIdleConnections() - time.Sleep(100 * time.Millisecond) - runtime.GC() - runtime.GC() // even more. - nfinal := runtime.NumGoroutine() + nfinal := waitNumGoroutine(n0 + 5) growth := nfinal - n0 - // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. + // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. // Previously we were leaking one per numReq. 
if int(growth) > 5 { t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) @@ -1061,13 +1126,11 @@ func TestTransportPersistConnLeakShortBody(t *testing.T) { } nhigh := runtime.NumGoroutine() tr.CloseIdleConnections() - time.Sleep(400 * time.Millisecond) - runtime.GC() - nfinal := runtime.NumGoroutine() + nfinal := waitNumGoroutine(n0 + 5) growth := nfinal - n0 - // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. + // We expect 0 or 1 extra goroutine, empirically. Allow up to 5. // Previously we were leaking one per numReq. t.Logf("goroutine growth: %d -> %d -> %d (delta: %d)", n0, nhigh, nfinal, growth) if int(growth) > 5 { @@ -1103,8 +1166,8 @@ func TestTransportIdleConnCrash(t *testing.T) { } // Test that the transport doesn't close the TCP connection early, -// before the response body has been read. This was a regression -// which sadly lacked a triggering test. The large response body made +// before the response body has been read. This was a regression +// which sadly lacked a triggering test. The large response body made // the old race easier to trigger. func TestIssue3644(t *testing.T) { defer afterTest(t) @@ -1199,7 +1262,7 @@ func TestTransportConcurrency(t *testing.T) { // Due to the Transport's "socket late binding" (see // idleConnCh in transport.go), the numReqs HTTP requests - // below can finish with a dial still outstanding. To keep + // below can finish with a dial still outstanding. To keep // the leak checker happy, keep track of pending dials and // wait for them to finish (and be closed or returned to the // idle pool) before we close idle connections. 
@@ -1617,7 +1680,13 @@ func TestCancelRequestWithChannel(t *testing.T) { } } -func TestCancelRequestWithChannelBeforeDo(t *testing.T) { +func TestCancelRequestWithChannelBeforeDo_Cancel(t *testing.T) { + testCancelRequestWithChannelBeforeDo(t, false) +} +func TestCancelRequestWithChannelBeforeDo_Context(t *testing.T) { + testCancelRequestWithChannelBeforeDo(t, true) +} +func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) { setParallel(t) defer afterTest(t) unblockc := make(chan bool) @@ -1638,9 +1707,15 @@ func TestCancelRequestWithChannelBeforeDo(t *testing.T) { c := &Client{Transport: tr} req, _ := NewRequest("GET", ts.URL, nil) - ch := make(chan struct{}) - req.Cancel = ch - close(ch) + if withCtx { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + req = req.WithContext(ctx) + } else { + ch := make(chan struct{}) + req.Cancel = ch + close(ch) + } _, err := c.Do(req) if err == nil || !strings.Contains(err.Error(), "canceled") { @@ -1985,7 +2060,8 @@ type proxyFromEnvTest struct { env string // HTTP_PROXY httpsenv string // HTTPS_PROXY - noenv string // NO_RPXY + noenv string // NO_PROXY + reqmeth string // REQUEST_METHOD want string wanterr error @@ -2009,6 +2085,10 @@ func (t proxyFromEnvTest) String() string { space() fmt.Fprintf(&buf, "no_proxy=%q", t.noenv) } + if t.reqmeth != "" { + space() + fmt.Fprintf(&buf, "request_method=%q", t.reqmeth) + } req := "http://example.com" if t.req != "" { req = t.req @@ -2032,6 +2112,12 @@ var proxyFromEnvTests = []proxyFromEnvTest{ {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "secure.proxy.tld", want: "http://secure.proxy.tld"}, {req: "https://secure.tld/", env: "http.proxy.tld", httpsenv: "https://secure.proxy.tld", want: "https://secure.proxy.tld"}, + // Issue 16405: don't use HTTP_PROXY in a CGI environment, + // where HTTP_PROXY can be attacker-controlled. 
+ {env: "http://10.1.2.3:8080", reqmeth: "POST", + want: "<nil>", + wanterr: errors.New("net/http: refusing to use HTTP_PROXY value in CGI environment; see golang.org/s/cgihttpproxy")}, + {want: "<nil>"}, {noenv: "example.com", req: "http://example.com/", env: "proxy", want: "<nil>"}, @@ -2047,6 +2133,7 @@ func TestProxyFromEnvironment(t *testing.T) { os.Setenv("HTTP_PROXY", tt.env) os.Setenv("HTTPS_PROXY", tt.httpsenv) os.Setenv("NO_PROXY", tt.noenv) + os.Setenv("REQUEST_METHOD", tt.reqmeth) ResetCachedEnvironment() reqURL := tt.req if reqURL == "" { @@ -2208,7 +2295,7 @@ func TestTransportTLSHandshakeTimeout(t *testing.T) { // Trying to repro golang.org/issue/3514 func TestTLSServerClosesConnection(t *testing.T) { defer afterTest(t) - setFlaky(t, 7634) + testenv.SkipFlaky(t, 7634) closedc := make(chan bool, 1) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -2273,7 +2360,7 @@ func TestTLSServerClosesConnection(t *testing.T) { } // byteFromChanReader is an io.Reader that reads a single byte at a -// time from the channel. When the channel is closed, the reader +// time from the channel. When the channel is closed, the reader // returns io.EOF. type byteFromChanReader chan byte @@ -2405,7 +2492,7 @@ func (plan9SleepReader) Read(p []byte) (int, error) { // After the fix to unblock TCP Reads in // https://golang.org/cl/15941, this sleep is required // on plan9 to make sure TCP Writes before an - // immediate TCP close go out on the wire. On Plan 9, + // immediate TCP close go out on the wire. On Plan 9, // it seems that a hangup of a TCP connection with // queued data doesn't send the queued data first. // https://golang.org/issue/9554 @@ -2424,7 +2511,7 @@ func (f closerFunc) Close() error { return f() } // from (or finish writing to) the socket. // // NOTE: we resend a request only if the request is idempotent, we reused a -// keep-alive connection, and we haven't yet received any header data. 
This +// keep-alive connection, and we haven't yet received any header data. This // automatically prevents an infinite resend loop because we'll run out of the // cached keep-alive connections eventually. func TestRetryIdempotentRequestsOnError(t *testing.T) { @@ -2888,6 +2975,11 @@ func TestTransportAutomaticHTTP2(t *testing.T) { testTransportAutoHTTP(t, &Transport{}, true) } +// golang.org/issue/14391: also check DefaultTransport +func TestTransportAutomaticHTTP2_DefaultTransport(t *testing.T) { + testTransportAutoHTTP(t, DefaultTransport.(*Transport), true) +} + func TestTransportAutomaticHTTP2_TLSNextProto(t *testing.T) { testTransportAutoHTTP(t, &Transport{ TLSNextProto: make(map[string]func(string, *tls.Conn) RoundTripper), @@ -2903,6 +2995,21 @@ func TestTransportAutomaticHTTP2_TLSConfig(t *testing.T) { func TestTransportAutomaticHTTP2_ExpectContinueTimeout(t *testing.T) { testTransportAutoHTTP(t, &Transport{ ExpectContinueTimeout: 1 * time.Second, + }, true) +} + +func TestTransportAutomaticHTTP2_Dial(t *testing.T) { + var d net.Dialer + testTransportAutoHTTP(t, &Transport{ + Dial: d.Dial, + }, false) +} + +func TestTransportAutomaticHTTP2_DialTLS(t *testing.T) { + testTransportAutoHTTP(t, &Transport{ + DialTLS: func(network, addr string) (net.Conn, error) { + panic("unused") + }, }, false) } @@ -3033,6 +3140,377 @@ func TestNoCrashReturningTransportAltConn(t *testing.T) { <-handledPendingDial } +func TestTransportReuseConnection_Gzip_Chunked(t *testing.T) { + testTransportReuseConnection_Gzip(t, true) +} + +func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) { + testTransportReuseConnection_Gzip(t, false) +} + +// Make sure we re-use underlying TCP connection for gzipped responses too. 
+func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) { + defer afterTest(t) + addr := make(chan string, 2) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + addr <- r.RemoteAddr + w.Header().Set("Content-Encoding", "gzip") + if chunked { + w.(Flusher).Flush() + } + w.Write(rgz) // arbitrary gzip response + })) + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + for i := 0; i < 2; i++ { + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + buf := make([]byte, len(rgz)) + if n, err := io.ReadFull(res.Body, buf); err != nil { + t.Errorf("%d. ReadFull = %v, %v", i, n, err) + } + // Note: no res.Body.Close call. It should work without it, + // since the flate.Reader's internal buffering will hit EOF + // and that should be sufficient. + } + a1, a2 := <-addr, <-addr + if a1 != a2 { + t.Fatalf("didn't reuse connection") + } +} + +func TestTransportResponseHeaderLength(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.URL.Path == "/long" { + w.Header().Set("Long", strings.Repeat("a", 1<<20)) + } + })) + defer ts.Close() + + tr := &Transport{ + MaxResponseHeaderBytes: 512 << 10, + } + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + if res, err := c.Get(ts.URL); err != nil { + t.Fatal(err) + } else { + res.Body.Close() + } + + res, err := c.Get(ts.URL + "/long") + if err == nil { + defer res.Body.Close() + var n int64 + for k, vv := range res.Header { + for _, v := range vv { + n += int64(len(k)) + int64(len(v)) + } + } + t.Fatalf("Unexpected success. 
Got %v and %d bytes of response headers", res.Status, n) + } + if want := "server response headers exceeded 524288 bytes"; !strings.Contains(err.Error(), want) { + t.Errorf("got error: %v; want %q", err, want) + } +} + +func TestTransportEventTrace(t *testing.T) { testTransportEventTrace(t, h1Mode, false) } +func TestTransportEventTrace_h2(t *testing.T) { testTransportEventTrace(t, h2Mode, false) } + +// test a non-nil httptrace.ClientTrace but with all hooks set to zero. +func TestTransportEventTrace_NoHooks(t *testing.T) { testTransportEventTrace(t, h1Mode, true) } +func TestTransportEventTrace_NoHooks_h2(t *testing.T) { testTransportEventTrace(t, h2Mode, true) } + +func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { + defer afterTest(t) + const resBody = "some body" + gotWroteReqEvent := make(chan struct{}) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + if _, err := ioutil.ReadAll(r.Body); err != nil { + t.Error(err) + } + if !noHooks { + select { + case <-gotWroteReqEvent: + case <-time.After(5 * time.Second): + t.Error("timeout waiting for WroteRequest event") + } + } + io.WriteString(w, resBody) + })) + defer cst.close() + + cst.tr.ExpectContinueTimeout = 1 * time.Second + + var mu sync.Mutex + var buf bytes.Buffer + logf := func(format string, args ...interface{}) { + mu.Lock() + defer mu.Unlock() + fmt.Fprintf(&buf, format, args...) + buf.WriteByte('\n') + } + + addrStr := cst.ts.Listener.Addr().String() + ip, port, err := net.SplitHostPort(addrStr) + if err != nil { + t.Fatal(err) + } + + // Install a fake DNS server. 
+ ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, host string) ([]net.IPAddr, error) { + if host != "dns-is-faked.golang" { + t.Errorf("unexpected DNS host lookup for %q", host) + return nil, nil + } + return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil + }) + + req, _ := NewRequest("POST", cst.scheme()+"://dns-is-faked.golang:"+port, strings.NewReader("some body")) + trace := &httptrace.ClientTrace{ + GetConn: func(hostPort string) { logf("Getting conn for %v ...", hostPort) }, + GotConn: func(ci httptrace.GotConnInfo) { logf("got conn: %+v", ci) }, + GotFirstResponseByte: func() { logf("first response byte") }, + PutIdleConn: func(err error) { logf("PutIdleConn = %v", err) }, + DNSStart: func(e httptrace.DNSStartInfo) { logf("DNS start: %+v", e) }, + DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNS done: %+v", e) }, + ConnectStart: func(network, addr string) { logf("ConnectStart: Connecting to %s %s ...", network, addr) }, + ConnectDone: func(network, addr string, err error) { + if err != nil { + t.Errorf("ConnectDone: %v", err) + } + logf("ConnectDone: connected to %s %s = %v", network, addr, err) + }, + Wait100Continue: func() { logf("Wait100Continue") }, + Got100Continue: func() { logf("Got100Continue") }, + WroteRequest: func(e httptrace.WroteRequestInfo) { + close(gotWroteReqEvent) + logf("WroteRequest: %+v", e) + }, + } + if noHooks { + // zero out all func pointers, trying to get some path to crash + *trace = httptrace.ClientTrace{} + } + req = req.WithContext(httptrace.WithClientTrace(ctx, trace)) + + req.Header.Set("Expect", "100-continue") + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + logf("got roundtrip.response") + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + logf("consumed body") + if string(slurp) != resBody || res.StatusCode != 200 { + t.Fatalf("Got %q, %v; want %q, 200 OK", slurp, res.Status, resBody) + } + res.Body.Close() + + if 
noHooks { + // Done at this point. Just testing a full HTTP + // request can happen with a trace pointing to a zero + // ClientTrace, full of nil func pointers. + return + } + + got := buf.String() + wantOnce := func(sub string) { + if strings.Count(got, sub) != 1 { + t.Errorf("expected substring %q exactly once in output.", sub) + } + } + wantOnceOrMore := func(sub string) { + if strings.Count(got, sub) == 0 { + t.Errorf("expected substring %q at least once in output.", sub) + } + } + wantOnce("Getting conn for dns-is-faked.golang:" + port) + wantOnce("DNS start: {Host:dns-is-faked.golang}") + wantOnce("DNS done: {Addrs:[{IP:" + ip + " Zone:}] Err:<nil> Coalesced:false}") + wantOnce("got conn: {") + wantOnceOrMore("Connecting to tcp " + addrStr) + wantOnceOrMore("connected to tcp " + addrStr + " = <nil>") + wantOnce("Reused:false WasIdle:false IdleTime:0s") + wantOnce("first response byte") + if !h2 { + wantOnce("PutIdleConn = <nil>") + } + wantOnce("Wait100Continue") + wantOnce("Got100Continue") + wantOnce("WroteRequest: {Err:<nil>}") + if strings.Contains(got, " to udp ") { + t.Errorf("should not see UDP (DNS) connections") + } + if t.Failed() { + t.Errorf("Output:\n%s", got) + } +} + +func TestTransportEventTraceRealDNS(t *testing.T) { + defer afterTest(t) + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + var mu sync.Mutex + var buf bytes.Buffer + logf := func(format string, args ...interface{}) { + mu.Lock() + defer mu.Unlock() + fmt.Fprintf(&buf, format, args...)
+ buf.WriteByte('\n') + } + + req, _ := NewRequest("GET", "http://dns-should-not-resolve.golang:80", nil) + trace := &httptrace.ClientTrace{ + DNSStart: func(e httptrace.DNSStartInfo) { logf("DNSStart: %+v", e) }, + DNSDone: func(e httptrace.DNSDoneInfo) { logf("DNSDone: %+v", e) }, + ConnectStart: func(network, addr string) { logf("ConnectStart: %s %s", network, addr) }, + ConnectDone: func(network, addr string, err error) { logf("ConnectDone: %s %s %v", network, addr, err) }, + } + req = req.WithContext(httptrace.WithClientTrace(context.Background(), trace)) + + resp, err := c.Do(req) + if err == nil { + resp.Body.Close() + t.Fatal("expected error during DNS lookup") + } + + got := buf.String() + wantSub := func(sub string) { + if !strings.Contains(got, sub) { + t.Errorf("expected substring %q in output.", sub) + } + } + wantSub("DNSStart: {Host:dns-should-not-resolve.golang}") + wantSub("DNSDone: {Addrs:[] Err:") + if strings.Contains(got, "ConnectStart") || strings.Contains(got, "ConnectDone") { + t.Errorf("should not see Connect events") + } + if t.Failed() { + t.Errorf("Output:\n%s", got) + } +} + +func TestTransportMaxIdleConns(t *testing.T) { + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + // No body for convenience. 
+ })) + defer ts.Close() + tr := &Transport{ + MaxIdleConns: 4, + } + defer tr.CloseIdleConnections() + + ip, port, err := net.SplitHostPort(ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + c := &Client{Transport: tr} + ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, host string) ([]net.IPAddr, error) { + return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil + }) + + hitHost := func(n int) { + req, _ := NewRequest("GET", fmt.Sprintf("http://host-%d.dns-is-faked.golang:"+port, n), nil) + req = req.WithContext(ctx) + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + } + for i := 0; i < 4; i++ { + hitHost(i) + } + want := []string{ + "|http|host-0.dns-is-faked.golang:" + port, + "|http|host-1.dns-is-faked.golang:" + port, + "|http|host-2.dns-is-faked.golang:" + port, + "|http|host-3.dns-is-faked.golang:" + port, + } + if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) { + t.Fatalf("idle conn keys mismatch.\n got: %q\nwant: %q\n", got, want) + } + + // Now hitting the 5th host should kick out the first host: + hitHost(4) + want = []string{ + "|http|host-1.dns-is-faked.golang:" + port, + "|http|host-2.dns-is-faked.golang:" + port, + "|http|host-3.dns-is-faked.golang:" + port, + "|http|host-4.dns-is-faked.golang:" + port, + } + if got := tr.IdleConnKeysForTesting(); !reflect.DeepEqual(got, want) { + t.Fatalf("idle conn keys mismatch after 5th host.\n got: %q\nwant: %q\n", got, want) + } +} + +func TestTransportIdleConnTimeout(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + defer afterTest(t) + + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + // No body for convenience. 
+ })) + defer ts.Close() + + const timeout = 1 * time.Second + tr := &Transport{ + IdleConnTimeout: timeout, + } + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + var conn string + doReq := func(n int) { + req, _ := NewRequest("GET", ts.URL, nil) + req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ + PutIdleConn: func(err error) { + if err != nil { + t.Errorf("failed to keep idle conn: %v", err) + } + }, + })) + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + conns := tr.IdleConnStrsForTesting() + if len(conns) != 1 { + t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns) + } + if conn == "" { + conn = conns[0] + } + if conn != conns[0] { + t.Fatalf("req %v: cached connection changed; expected the same one throughout the test", n) + } + } + for i := 0; i < 3; i++ { + doReq(i) + time.Sleep(timeout / 2) + } + time.Sleep(timeout * 3 / 2) + if got := tr.IdleConnStrsForTesting(); len(got) != 0 { + t.Errorf("idle conns = %q; want none", got) + } +} + var errFakeRoundTrip = errors.New("fake roundtrip") type funcRoundTripper func() |