diff options
Diffstat (limited to 'libgo/go/net/http')
48 files changed, 6704 insertions, 1614 deletions
diff --git a/libgo/go/net/http/client.go b/libgo/go/net/http/client.go index 993c247eef5..d368bae861e 100644 --- a/libgo/go/net/http/client.go +++ b/libgo/go/net/http/client.go @@ -18,6 +18,7 @@ import ( "io/ioutil" "log" "net/url" + "sort" "strings" "sync" "time" @@ -33,6 +34,25 @@ import ( // A Client is higher-level than a RoundTripper (such as Transport) // and additionally handles HTTP details such as cookies and // redirects. +// +// When following redirects, the Client will forward all headers set on the +// initial Request except: +// +// * when forwarding sensitive headers like "Authorization", +// "WWW-Authenticate", and "Cookie" to untrusted targets. +// These headers will be ignored when following a redirect to a domain +// that is not a subdomain match or exact match of the initial domain. +// For example, a redirect from "foo.com" to either "foo.com" or "sub.foo.com" +// will forward the sensitive headers, but a redirect to "bar.com" will not. +// +// * when forwarding the "Cookie" header with a non-nil cookie Jar. +// Since each redirect may mutate the state of the cookie jar, +// a redirect may possibly alter a cookie set in the initial request. +// When forwarding the "Cookie" header, any mutated cookies will be omitted, +// with the expectation that the Jar will insert those mutated cookies +// with the updated values (assuming the origin matches). +// If Jar is nil, the initial cookies are forwarded without change. +// type Client struct { // Transport specifies the mechanism by which individual // HTTP requests are made. @@ -56,8 +76,14 @@ type Client struct { CheckRedirect func(req *Request, via []*Request) error // Jar specifies the cookie jar. - // If Jar is nil, cookies are not sent in requests and ignored - // in responses. + // + // The Jar is used to insert relevant cookies into every + // outbound Request and is updated with the cookie values + // of every inbound Response. 
The Jar is consulted for every + // redirect that the Client follows. + // + // If Jar is nil, cookies are only sent if they are explicitly + // set on the Request. Jar CookieJar // Timeout specifies a time limit for requests made by this @@ -137,56 +163,23 @@ func refererForURL(lastReq, newReq *url.URL) string { return referer } -func (c *Client) send(req *Request, deadline time.Time) (*Response, error) { +// didTimeout is non-nil only if err != nil. +func (c *Client) send(req *Request, deadline time.Time) (resp *Response, didTimeout func() bool, err error) { if c.Jar != nil { for _, cookie := range c.Jar.Cookies(req.URL) { req.AddCookie(cookie) } } - resp, err := send(req, c.transport(), deadline) + resp, didTimeout, err = send(req, c.transport(), deadline) if err != nil { - return nil, err + return nil, didTimeout, err } if c.Jar != nil { if rc := resp.Cookies(); len(rc) > 0 { c.Jar.SetCookies(req.URL, rc) } } - return resp, nil -} - -// Do sends an HTTP request and returns an HTTP response, following -// policy (such as redirects, cookies, auth) as configured on the -// client. -// -// An error is returned if caused by client policy (such as -// CheckRedirect), or failure to speak HTTP (such as a network -// connectivity problem). A non-2xx status code doesn't cause an -// error. -// -// If the returned error is nil, the Response will contain a non-nil -// Body which the user is expected to close. If the Body is not -// closed, the Client's underlying RoundTripper (typically Transport) -// may not be able to re-use a persistent TCP connection to the server -// for a subsequent "keep-alive" request. -// -// The request Body, if non-nil, will be closed by the underlying -// Transport, even on errors. -// -// On error, any Response can be ignored. A non-nil Response with a -// non-nil error only occurs when CheckRedirect fails, and even then -// the returned Response.Body is already closed. -// -// Generally Get, Post, or PostForm will be used instead of Do. 
-func (c *Client) Do(req *Request) (*Response, error) { - method := valueOrDefault(req.Method, "GET") - if method == "GET" || method == "HEAD" { - return c.doFollowingRedirects(req, shouldRedirectGet) - } - if method == "POST" || method == "PUT" { - return c.doFollowingRedirects(req, shouldRedirectPost) - } - return c.send(req, c.deadline()) + return resp, nil, nil } func (c *Client) deadline() time.Time { @@ -205,22 +198,22 @@ func (c *Client) transport() RoundTripper { // send issues an HTTP request. // Caller should close resp.Body when done reading from it. -func send(ireq *Request, rt RoundTripper, deadline time.Time) (*Response, error) { +func send(ireq *Request, rt RoundTripper, deadline time.Time) (resp *Response, didTimeout func() bool, err error) { req := ireq // req is either the original request, or a modified fork if rt == nil { req.closeBody() - return nil, errors.New("http: no Client.Transport or DefaultTransport") + return nil, alwaysFalse, errors.New("http: no Client.Transport or DefaultTransport") } if req.URL == nil { req.closeBody() - return nil, errors.New("http: nil Request.URL") + return nil, alwaysFalse, errors.New("http: nil Request.URL") } if req.RequestURI != "" { req.closeBody() - return nil, errors.New("http: Request.RequestURI can't be set in client requests.") + return nil, alwaysFalse, errors.New("http: Request.RequestURI can't be set in client requests.") } // forkReq forks req into a shallow clone of ireq the first @@ -251,9 +244,9 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (*Response, error) if !deadline.IsZero() { forkReq() } - stopTimer, wasCanceled := setRequestCancel(req, rt, deadline) + stopTimer, didTimeout := setRequestCancel(req, rt, deadline) - resp, err := rt.RoundTrip(req) + resp, err = rt.RoundTrip(req) if err != nil { stopTimer() if resp != nil { @@ -267,22 +260,27 @@ func send(ireq *Request, rt RoundTripper, deadline time.Time) (*Response, error) err = errors.New("http: server gave HTTP response 
to HTTPS client") } } - return nil, err + return nil, didTimeout, err } if !deadline.IsZero() { resp.Body = &cancelTimerBody{ - stop: stopTimer, - rc: resp.Body, - reqWasCanceled: wasCanceled, + stop: stopTimer, + rc: resp.Body, + reqDidTimeout: didTimeout, } } - return resp, nil + return resp, nil, nil } // setRequestCancel sets the Cancel field of req, if deadline is // non-zero. The RoundTripper's type is used to determine whether the legacy // CancelRequest behavior should be used. -func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTimer func(), wasCanceled func() bool) { +// +// As background, there are three ways to cancel a request: +// First was Transport.CancelRequest. (deprecated) +// Second was Request.Cancel (this mechanism). +// Third was Request.Context. +func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTimer func(), didTimeout func() bool) { if deadline.IsZero() { return nop, alwaysFalse } @@ -292,17 +290,8 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi cancel := make(chan struct{}) req.Cancel = cancel - wasCanceled = func() bool { - select { - case <-cancel: - return true - default: - return false - } - } - doCancel := func() { - // The new way: + // The newer way (the second way in the func comment): close(cancel) // The legacy compatibility way, used only @@ -324,19 +313,23 @@ func setRequestCancel(req *Request, rt RoundTripper, deadline time.Time) (stopTi var once sync.Once stopTimer = func() { once.Do(func() { close(stopTimerCh) }) } - timer := time.NewTimer(deadline.Sub(time.Now())) + timer := time.NewTimer(time.Until(deadline)) + var timedOut atomicBool + go func() { select { case <-initialReqCancel: doCancel() + timer.Stop() case <-timer.C: + timedOut.setTrue() doCancel() case <-stopTimerCh: timer.Stop() } }() - return stopTimer, wasCanceled + return stopTimer, timedOut.isSet } // See 2 (end of page 4) http://www.ietf.org/rfc/rfc2617.txt @@ -349,26 
+342,6 @@ func basicAuth(username, password string) string { return base64.StdEncoding.EncodeToString([]byte(auth)) } -// True if the specified HTTP status code is one for which the Get utility should -// automatically redirect. -func shouldRedirectGet(statusCode int) bool { - switch statusCode { - case StatusMovedPermanently, StatusFound, StatusSeeOther, StatusTemporaryRedirect: - return true - } - return false -} - -// True if the specified HTTP status code is one for which the Post utility should -// automatically redirect. -func shouldRedirectPost(statusCode int) bool { - switch statusCode { - case StatusFound, StatusSeeOther: - return true - } - return false -} - // Get issues a GET to the specified URL. If the response is one of // the following redirect codes, Get follows the redirect, up to a // maximum of 10 redirects: @@ -377,6 +350,7 @@ func shouldRedirectPost(statusCode int) bool { // 302 (Found) // 303 (See Other) // 307 (Temporary Redirect) +// 308 (Permanent Redirect) // // An error is returned if there were too many redirects or if there // was an HTTP protocol error. A non-2xx response doesn't cause an @@ -401,6 +375,7 @@ func Get(url string) (resp *Response, err error) { // 302 (Found) // 303 (See Other) // 307 (Temporary Redirect) +// 308 (Permanent Redirect) // // An error is returned if the Client's CheckRedirect function fails // or if there was an HTTP protocol error. 
A non-2xx response doesn't @@ -415,7 +390,7 @@ func (c *Client) Get(url string) (resp *Response, err error) { if err != nil { return nil, err } - return c.doFollowingRedirects(req, shouldRedirectGet) + return c.Do(req) } func alwaysFalse() bool { return false } @@ -436,16 +411,92 @@ func (c *Client) checkRedirect(req *Request, via []*Request) error { return fn(req, via) } -func (c *Client) doFollowingRedirects(req *Request, shouldRedirect func(int) bool) (*Response, error) { +// redirectBehavior describes what should happen when the +// client encounters a 3xx status code from the server +func redirectBehavior(reqMethod string, resp *Response, ireq *Request) (redirectMethod string, shouldRedirect bool) { + switch resp.StatusCode { + case 301, 302, 303: + redirectMethod = reqMethod + shouldRedirect = true + + // RFC 2616 allowed automatic redirection only with GET and + // HEAD requests. RFC 7231 lifts this restriction, but we still + // restrict other methods to GET to maintain compatibility. + // See Issue 18570. + if reqMethod != "GET" && reqMethod != "HEAD" { + redirectMethod = "GET" + } + case 307, 308: + redirectMethod = reqMethod + shouldRedirect = true + + // Treat 307 and 308 specially, since they're new in + // Go 1.8, and they also require re-sending the request body. + if resp.Header.Get("Location") == "" { + // 308s have been observed in the wild being served + // without Location headers. Since Go 1.7 and earlier + // didn't follow these codes, just stop here instead + // of returning an error. + // See Issue 17773. + shouldRedirect = false + break + } + if ireq.GetBody == nil && ireq.outgoingLength() != 0 { + // We had a request body, and 307/308 require + // re-sending it, but GetBody is not defined. So just + // return this response to the user instead of an + // error, like we did in Go 1.7 and earlier. 
+ shouldRedirect = false + } + } + return redirectMethod, shouldRedirect +} + +// Do sends an HTTP request and returns an HTTP response, following +// policy (such as redirects, cookies, auth) as configured on the +// client. +// +// An error is returned if caused by client policy (such as +// CheckRedirect), or failure to speak HTTP (such as a network +// connectivity problem). A non-2xx status code doesn't cause an +// error. +// +// If the returned error is nil, the Response will contain a non-nil +// Body which the user is expected to close. If the Body is not +// closed, the Client's underlying RoundTripper (typically Transport) +// may not be able to re-use a persistent TCP connection to the server +// for a subsequent "keep-alive" request. +// +// The request Body, if non-nil, will be closed by the underlying +// Transport, even on errors. +// +// On error, any Response can be ignored. A non-nil Response with a +// non-nil error only occurs when CheckRedirect fails, and even then +// the returned Response.Body is already closed. +// +// Generally Get, Post, or PostForm will be used instead of Do. +// +// If the server replies with a redirect, the Client first uses the +// CheckRedirect function to determine whether the redirect should be +// followed. If permitted, a 301, 302, or 303 redirect causes +// subsequent requests to use HTTP method GET +// (or HEAD if the original request was HEAD), with no body. +// A 307 or 308 redirect preserves the original HTTP method and body, +// provided that the Request.GetBody function is defined. +// The NewRequest function automatically sets GetBody for common +// standard library body types. 
+func (c *Client) Do(req *Request) (*Response, error) { if req.URL == nil { req.closeBody() return nil, errors.New("http: nil Request.URL") } var ( - deadline = c.deadline() - reqs []*Request - resp *Response + deadline = c.deadline() + reqs []*Request + resp *Response + copyHeaders = c.makeHeadersCopier(req) + redirectMethod string ) uerr := func(err error) error { req.closeBody() @@ -476,16 +527,27 @@ func (c *Client) doFollowingRedirects(req *Request, shouldRedirect func(int) boo } ireq := reqs[0] req = &Request{ - Method: ireq.Method, + Method: redirectMethod, Response: resp, URL: u, Header: make(Header), Cancel: ireq.Cancel, ctx: ireq.ctx, } - if ireq.Method == "POST" || ireq.Method == "PUT" { - req.Method = "GET" + if ireq.GetBody != nil { + req.Body, err = ireq.GetBody() + if err != nil { + return nil, uerr(err) + } + req.ContentLength = ireq.ContentLength } + + // Copy original headers before setting the Referer, + // in case the user set Referer on their first request. + // If they really want to override, they can do it in + // their CheckRedirect func. 
+ copyHeaders(req) + // Add the Referer header from the most recent // request URL to the new one, if it's not https->http: if ref := refererForURL(reqs[len(reqs)-1].URL, req.URL); ref != "" { @@ -523,10 +585,10 @@ func (c *Client) doFollowingRedirects(req *Request, shouldRedirect func(int) boo } reqs = append(reqs, req) - var err error - if resp, err = c.send(req, deadline); err != nil { - if !deadline.IsZero() && !time.Now().Before(deadline) { + var didTimeout func() bool + if resp, didTimeout, err = c.send(req, deadline); err != nil { + if !deadline.IsZero() && didTimeout() { err = &httpError{ err: err.Error() + " (Client.Timeout exceeded while awaiting headers)", timeout: true, @@ -535,9 +597,77 @@ func (c *Client) doFollowingRedirects(req *Request, shouldRedirect func(int) boo return nil, uerr(err) } - if !shouldRedirect(resp.StatusCode) { + var shouldRedirect bool + redirectMethod, shouldRedirect = redirectBehavior(req.Method, resp, reqs[0]) + if !shouldRedirect { return resp, nil } + + req.closeBody() + } +} + +// makeHeadersCopier makes a function that copies headers from the +// initial Request, ireq. For every redirect, this function must be called +// so that it can copy headers into the upcoming Request. +func (c *Client) makeHeadersCopier(ireq *Request) func(*Request) { + // The headers to copy are from the very initial request. + // We use a closured callback to keep a reference to these original headers. 
+ var ( + ireqhdr = ireq.Header.clone() + icookies map[string][]*Cookie + ) + if c.Jar != nil && ireq.Header.Get("Cookie") != "" { + icookies = make(map[string][]*Cookie) + for _, c := range ireq.Cookies() { + icookies[c.Name] = append(icookies[c.Name], c) + } + } + + preq := ireq // The previous request + return func(req *Request) { + // If Jar is present and there was some initial cookies provided + // via the request header, then we may need to alter the initial + // cookies as we follow redirects since each redirect may end up + // modifying a pre-existing cookie. + // + // Since cookies already set in the request header do not contain + // information about the original domain and path, the logic below + // assumes any new set cookies override the original cookie + // regardless of domain or path. + // + // See https://golang.org/issue/17494 + if c.Jar != nil && icookies != nil { + var changed bool + resp := req.Response // The response that caused the upcoming redirect + for _, c := range resp.Cookies() { + if _, ok := icookies[c.Name]; ok { + delete(icookies, c.Name) + changed = true + } + } + if changed { + ireqhdr.Del("Cookie") + var ss []string + for _, cs := range icookies { + for _, c := range cs { + ss = append(ss, c.Name+"="+c.Value) + } + } + sort.Strings(ss) // Ensure deterministic headers + ireqhdr.Set("Cookie", strings.Join(ss, "; ")) + } + } + + // Copy the initial request's Header values + // (at least the safe ones). + for k, vv := range ireqhdr { + if shouldCopyHeaderOnRedirect(k, preq.URL, req.URL) { + req.Header[k] = vv + } + } + + preq = req // Update previous Request with the current request } } @@ -558,8 +688,11 @@ func defaultCheckRedirect(req *Request, via []*Request) error { // Post is a wrapper around DefaultClient.Post. // // To set custom headers, use NewRequest and DefaultClient.Do. 
-func Post(url string, bodyType string, body io.Reader) (resp *Response, err error) { - return DefaultClient.Post(url, bodyType, body) +// +// See the Client.Do method documentation for details on how redirects +// are handled. +func Post(url string, contentType string, body io.Reader) (resp *Response, err error) { + return DefaultClient.Post(url, contentType, body) } // Post issues a POST to the specified URL. @@ -570,13 +703,16 @@ func Post(url string, bodyType string, body io.Reader) (resp *Response, err erro // request. // // To set custom headers, use NewRequest and Client.Do. -func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Response, err error) { +// +// See the Client.Do method documentation for details on how redirects +// are handled. +func (c *Client) Post(url string, contentType string, body io.Reader) (resp *Response, err error) { req, err := NewRequest("POST", url, body) if err != nil { return nil, err } - req.Header.Set("Content-Type", bodyType) - return c.doFollowingRedirects(req, shouldRedirectPost) + req.Header.Set("Content-Type", contentType) + return c.Do(req) } // PostForm issues a POST to the specified URL, with data's keys and @@ -589,6 +725,9 @@ func (c *Client) Post(url string, bodyType string, body io.Reader) (resp *Respon // Caller should close resp.Body when done reading from it. // // PostForm is a wrapper around DefaultClient.PostForm. +// +// See the Client.Do method documentation for details on how redirects +// are handled. func PostForm(url string, data url.Values) (resp *Response, err error) { return DefaultClient.PostForm(url, data) } @@ -601,11 +740,14 @@ func PostForm(url string, data url.Values) (resp *Response, err error) { // // When err is nil, resp always contains a non-nil resp.Body. // Caller should close resp.Body when done reading from it. +// +// See the Client.Do method documentation for details on how redirects +// are handled. 
func (c *Client) PostForm(url string, data url.Values) (resp *Response, err error) { return c.Post(url, "application/x-www-form-urlencoded", strings.NewReader(data.Encode())) } -// Head issues a HEAD to the specified URL. If the response is one of +// Head issues a HEAD to the specified URL. If the response is one of // the following redirect codes, Head follows the redirect, up to a // maximum of 10 redirects: // @@ -613,13 +755,14 @@ func (c *Client) PostForm(url string, data url.Values) (resp *Response, err erro // 302 (Found) // 303 (See Other) // 307 (Temporary Redirect) +// 308 (Permanent Redirect) // // Head is a wrapper around DefaultClient.Head func Head(url string) (resp *Response, err error) { return DefaultClient.Head(url) } -// Head issues a HEAD to the specified URL. If the response is one of the +// Head issues a HEAD to the specified URL. If the response is one of the // following redirect codes, Head follows the redirect after calling the // Client's CheckRedirect function: // @@ -627,22 +770,23 @@ func Head(url string) (resp *Response, err error) { // 302 (Found) // 303 (See Other) // 307 (Temporary Redirect) +// 308 (Permanent Redirect) func (c *Client) Head(url string) (resp *Response, err error) { req, err := NewRequest("HEAD", url, nil) if err != nil { return nil, err } - return c.doFollowingRedirects(req, shouldRedirectGet) + return c.Do(req) } // cancelTimerBody is an io.ReadCloser that wraps rc with two features: // 1) on Read error or close, the stop func is called. -// 2) On Read failure, if reqWasCanceled is true, the error is wrapped and +// 2) On Read failure, if reqDidTimeout is true, the error is wrapped and // marked as net.Error that hit its timeout. 
type cancelTimerBody struct { - stop func() // stops the time.Timer waiting to cancel the request - rc io.ReadCloser - reqWasCanceled func() bool + stop func() // stops the time.Timer waiting to cancel the request + rc io.ReadCloser + reqDidTimeout func() bool } func (b *cancelTimerBody) Read(p []byte) (n int, err error) { @@ -654,7 +798,7 @@ func (b *cancelTimerBody) Read(p []byte) (n int, err error) { if err == io.EOF { return n, err } - if b.reqWasCanceled() { + if b.reqDidTimeout() { err = &httpError{ err: err.Error() + " (Client.Timeout exceeded while reading body)", timeout: true, @@ -668,3 +812,52 @@ func (b *cancelTimerBody) Close() error { b.stop() return err } + +func shouldCopyHeaderOnRedirect(headerKey string, initial, dest *url.URL) bool { + switch CanonicalHeaderKey(headerKey) { + case "Authorization", "Www-Authenticate", "Cookie", "Cookie2": + // Permit sending auth/cookie headers from "foo.com" + // to "sub.foo.com". + + // Note that we don't send all cookies to subdomains + // automatically. This function is only used for + // Cookies set explicitly on the initial outgoing + // client request. Cookies automatically added via the + // CookieJar mechanism continue to follow each + // cookie's scope as set by Set-Cookie. But for + // outgoing requests with the Cookie header set + // directly, we don't know their scope, so we assume + // it's for *.domain.com. + + // TODO(bradfitz): once issue 16142 is fixed, make + // this code use those URL accessors, and consider + // "http://foo.com" and "http://foo.com:80" as + // equivalent? + + // TODO(bradfitz): better hostname canonicalization, + // at least once we figure out IDNA/Punycode (issue + // 13835). + ihost := strings.ToLower(initial.Host) + dhost := strings.ToLower(dest.Host) + return isDomainOrSubdomain(dhost, ihost) + } + // All other headers are copied: + return true +} + +// isDomainOrSubdomain reports whether sub is a subdomain (or exact +// match) of the parent domain. 
+// +// Both domains must already be in canonical form. +func isDomainOrSubdomain(sub, parent string) bool { + if sub == parent { + return true + } + // If sub is "foo.example.com" and parent is "example.com", + // that means sub must end in "."+parent. + // Do it without allocating. + if !strings.HasSuffix(sub, parent) { + return false + } + return sub[len(sub)-len(parent)-1] == '.' +} diff --git a/libgo/go/net/http/client_test.go b/libgo/go/net/http/client_test.go index a9b1948005c..eaf2cdca8ee 100644 --- a/libgo/go/net/http/client_test.go +++ b/libgo/go/net/http/client_test.go @@ -19,11 +19,14 @@ import ( "log" "net" . "net/http" + "net/http/cookiejar" "net/http/httptest" "net/url" + "reflect" "strconv" "strings" "sync" + "sync/atomic" "testing" "time" ) @@ -65,11 +68,13 @@ func (w chanWriter) Write(p []byte) (n int, err error) { } func TestClient(t *testing.T) { + setParallel(t) defer afterTest(t) ts := httptest.NewServer(robotsTxtHandler) defer ts.Close() - r, err := Get(ts.URL) + c := &Client{Transport: &Transport{DisableKeepAlives: true}} + r, err := c.Get(ts.URL) var b []byte if err == nil { b, err = pedanticReadAll(r.Body) @@ -109,6 +114,7 @@ func (t *recordingTransport) RoundTrip(req *Request) (resp *Response, err error) } func TestGetRequestFormat(t *testing.T) { + setParallel(t) defer afterTest(t) tr := &recordingTransport{} client := &Client{Transport: tr} @@ -195,6 +201,7 @@ func TestPostFormRequestFormat(t *testing.T) { } func TestClientRedirects(t *testing.T) { + setParallel(t) defer afterTest(t) var ts *httptest.Server ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -206,14 +213,17 @@ func TestClientRedirects(t *testing.T) { } } if n < 15 { - Redirect(w, r, fmt.Sprintf("/?n=%d", n+1), StatusFound) + Redirect(w, r, fmt.Sprintf("/?n=%d", n+1), StatusTemporaryRedirect) return } fmt.Fprintf(w, "n=%d", n) })) defer ts.Close() - c := &Client{} + tr := &Transport{} + defer tr.CloseIdleConnections() + + c := &Client{Transport: 
tr} _, err := c.Get(ts.URL) if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g { t.Errorf("with default client Get, expected error %q, got %q", e, g) @@ -242,11 +252,14 @@ func TestClientRedirects(t *testing.T) { var checkErr error var lastVia []*Request var lastReq *Request - c = &Client{CheckRedirect: func(req *Request, via []*Request) error { - lastReq = req - lastVia = via - return checkErr - }} + c = &Client{ + Transport: tr, + CheckRedirect: func(req *Request, via []*Request) error { + lastReq = req + lastVia = via + return checkErr + }, + } res, err := c.Get(ts.URL) if err != nil { t.Fatalf("Get error: %v", err) @@ -292,20 +305,27 @@ func TestClientRedirects(t *testing.T) { } func TestClientRedirectContext(t *testing.T) { + setParallel(t) defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - Redirect(w, r, "/", StatusFound) + Redirect(w, r, "/", StatusTemporaryRedirect) })) defer ts.Close() + tr := &Transport{} + defer tr.CloseIdleConnections() + ctx, cancel := context.WithCancel(context.Background()) - c := &Client{CheckRedirect: func(req *Request, via []*Request) error { - cancel() - if len(via) > 2 { - return errors.New("too many redirects") - } - return nil - }} + c := &Client{ + Transport: tr, + CheckRedirect: func(req *Request, via []*Request) error { + cancel() + if len(via) > 2 { + return errors.New("too many redirects") + } + return nil + }, + } req, _ := NewRequest("GET", ts.URL, nil) req = req.WithContext(ctx) _, err := c.Do(req) @@ -313,12 +333,96 @@ func TestClientRedirectContext(t *testing.T) { if !ok { t.Fatalf("got error %T; want *url.Error", err) } - if ue.Err != ExportErrRequestCanceled && ue.Err != ExportErrRequestCanceledConn { - t.Errorf("url.Error.Err = %v; want errRequestCanceled or errRequestCanceledConn", ue.Err) + if ue.Err != context.Canceled { + t.Errorf("url.Error.Err = %v; want %v", ue.Err, context.Canceled) } } +type redirectTest struct { + suffix string 
+ want int // response code + redirectBody string +} + func TestPostRedirects(t *testing.T) { + postRedirectTests := []redirectTest{ + {"/", 200, "first"}, + {"/?code=301&next=302", 200, "c301"}, + {"/?code=302&next=302", 200, "c302"}, + {"/?code=303&next=301", 200, "c303wc301"}, // Issue 9348 + {"/?code=304", 304, "c304"}, + {"/?code=305", 305, "c305"}, + {"/?code=307&next=303,308,302", 200, "c307"}, + {"/?code=308&next=302,301", 200, "c308"}, + {"/?code=404", 404, "c404"}, + } + + wantSegments := []string{ + `POST / "first"`, + `POST /?code=301&next=302 "c301"`, + `GET /?code=302 "c301"`, + `GET / "c301"`, + `POST /?code=302&next=302 "c302"`, + `GET /?code=302 "c302"`, + `GET / "c302"`, + `POST /?code=303&next=301 "c303wc301"`, + `GET /?code=301 "c303wc301"`, + `GET / "c303wc301"`, + `POST /?code=304 "c304"`, + `POST /?code=305 "c305"`, + `POST /?code=307&next=303,308,302 "c307"`, + `POST /?code=303&next=308,302 "c307"`, + `GET /?code=308&next=302 "c307"`, + `GET /?code=302 "c307"`, + `GET / "c307"`, + `POST /?code=308&next=302,301 "c308"`, + `POST /?code=302&next=301 "c308"`, + `GET /?code=301 "c308"`, + `GET / "c308"`, + `POST /?code=404 "c404"`, + } + want := strings.Join(wantSegments, "\n") + testRedirectsByMethod(t, "POST", postRedirectTests, want) +} + +func TestDeleteRedirects(t *testing.T) { + deleteRedirectTests := []redirectTest{ + {"/", 200, "first"}, + {"/?code=301&next=302,308", 200, "c301"}, + {"/?code=302&next=302", 200, "c302"}, + {"/?code=303", 200, "c303"}, + {"/?code=307&next=301,308,303,302,304", 304, "c307"}, + {"/?code=308&next=307", 200, "c308"}, + {"/?code=404", 404, "c404"}, + } + + wantSegments := []string{ + `DELETE / "first"`, + `DELETE /?code=301&next=302,308 "c301"`, + `GET /?code=302&next=308 "c301"`, + `GET /?code=308 "c301"`, + `GET / "c301"`, + `DELETE /?code=302&next=302 "c302"`, + `GET /?code=302 "c302"`, + `GET / "c302"`, + `DELETE /?code=303 "c303"`, + `GET / "c303"`, + `DELETE /?code=307&next=301,308,303,302,304 "c307"`, + 
`DELETE /?code=301&next=308,303,302,304 "c307"`, + `GET /?code=308&next=303,302,304 "c307"`, + `GET /?code=303&next=302,304 "c307"`, + `GET /?code=302&next=304 "c307"`, + `GET /?code=304 "c307"`, + `DELETE /?code=308&next=307 "c308"`, + `DELETE /?code=307 "c308"`, + `DELETE / "c308"`, + `DELETE /?code=404 "c404"`, + } + want := strings.Join(wantSegments, "\n") + testRedirectsByMethod(t, "DELETE", deleteRedirectTests, want) +} + +func testRedirectsByMethod(t *testing.T, method string, table []redirectTest, want string) { defer afterTest(t) var log struct { sync.Mutex @@ -327,29 +431,35 @@ func TestPostRedirects(t *testing.T) { var ts *httptest.Server ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { log.Lock() - fmt.Fprintf(&log.Buffer, "%s %s ", r.Method, r.RequestURI) + slurp, _ := ioutil.ReadAll(r.Body) + fmt.Fprintf(&log.Buffer, "%s %s %q\n", r.Method, r.RequestURI, slurp) log.Unlock() - if v := r.URL.Query().Get("code"); v != "" { + urlQuery := r.URL.Query() + if v := urlQuery.Get("code"); v != "" { + location := ts.URL + if final := urlQuery.Get("next"); final != "" { + splits := strings.Split(final, ",") + first, rest := splits[0], splits[1:] + location = fmt.Sprintf("%s?code=%s", location, first) + if len(rest) > 0 { + location = fmt.Sprintf("%s&next=%s", location, strings.Join(rest, ",")) + } + } code, _ := strconv.Atoi(v) if code/100 == 3 { - w.Header().Set("Location", ts.URL) + w.Header().Set("Location", location) } w.WriteHeader(code) } })) defer ts.Close() - tests := []struct { - suffix string - want int // response code - }{ - {"/", 200}, - {"/?code=301", 301}, - {"/?code=302", 200}, - {"/?code=303", 200}, - {"/?code=404", 404}, - } - for _, tt := range tests { - res, err := Post(ts.URL+tt.suffix, "text/plain", strings.NewReader("Some content")) + + for _, tt := range table { + content := tt.redirectBody + req, _ := NewRequest(method, ts.URL+tt.suffix, strings.NewReader(content)) + req.GetBody = func() (io.ReadCloser, error) { 
return ioutil.NopCloser(strings.NewReader(content)), nil } + res, err := DefaultClient.Do(req) + if err != nil { t.Fatal(err) } @@ -360,13 +470,17 @@ func TestPostRedirects(t *testing.T) { log.Lock() got := log.String() log.Unlock() - want := "POST / POST /?code=301 POST /?code=302 GET / POST /?code=303 GET / POST /?code=404 " + + got = strings.TrimSpace(got) + want = strings.TrimSpace(want) + if got != want { - t.Errorf("Log differs.\n Got: %q\nWant: %q", got, want) + t.Errorf("Log differs.\n Got:\n%s\nWant:\n%s\n", got, want) } } func TestClientRedirectUseResponse(t *testing.T) { + setParallel(t) defer afterTest(t) const body = "Hello, world." var ts *httptest.Server @@ -381,12 +495,18 @@ func TestClientRedirectUseResponse(t *testing.T) { })) defer ts.Close() - c := &Client{CheckRedirect: func(req *Request, via []*Request) error { - if req.Response == nil { - t.Error("expected non-nil Request.Response") - } - return ErrUseLastResponse - }} + tr := &Transport{} + defer tr.CloseIdleConnections() + + c := &Client{ + Transport: tr, + CheckRedirect: func(req *Request, via []*Request) error { + if req.Response == nil { + t.Error("expected non-nil Request.Response") + } + return ErrUseLastResponse + }, + } res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) @@ -404,6 +524,57 @@ func TestClientRedirectUseResponse(t *testing.T) { } } +// Issue 17773: don't follow a 308 (or 307) if the response doesn't +// have a Location header. 
+func TestClientRedirect308NoLocation(t *testing.T) { + setParallel(t) + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Foo", "Bar") + w.WriteHeader(308) + })) + defer ts.Close() + res, err := Get(ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != 308 { + t.Errorf("status = %d; want %d", res.StatusCode, 308) + } + if got := res.Header.Get("Foo"); got != "Bar" { + t.Errorf("Foo header = %q; want Bar", got) + } +} + +// Don't follow a 307/308 if we can't resend the request body. +func TestClientRedirect308NoGetBody(t *testing.T) { + setParallel(t) + defer afterTest(t) + const fakeURL = "https://localhost:1234/" // won't be hit + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Location", fakeURL) + w.WriteHeader(308) + })) + defer ts.Close() + req, err := NewRequest("POST", ts.URL, strings.NewReader("some body")) + if err != nil { + t.Fatal(err) + } + req.GetBody = nil // so it can't rewind. 
+ res, err := DefaultClient.Do(req) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + if res.StatusCode != 308 { + t.Errorf("status = %d; want %d", res.StatusCode, 308) + } + if got := res.Header.Get("Location"); got != fakeURL { + t.Errorf("Location header = %q; want %q", got, fakeURL) + } +} + var expectedCookies = []*Cookie{ {Name: "ChocolateChip", Value: "tasty"}, {Name: "First", Value: "Hit"}, @@ -476,12 +647,16 @@ func (j *TestJar) Cookies(u *url.URL) []*Cookie { } func TestRedirectCookiesJar(t *testing.T) { + setParallel(t) defer afterTest(t) var ts *httptest.Server ts = httptest.NewServer(echoCookiesRedirectHandler) defer ts.Close() + tr := &Transport{} + defer tr.CloseIdleConnections() c := &Client{ - Jar: new(TestJar), + Transport: tr, + Jar: new(TestJar), } u, _ := url.Parse(ts.URL) c.Jar.SetCookies(u, []*Cookie{expectedCookies[0]}) @@ -665,6 +840,7 @@ func TestClientWrites(t *testing.T) { } func TestClientInsecureTransport(t *testing.T) { + setParallel(t) defer afterTest(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { w.Write([]byte("Hello")) @@ -842,6 +1018,7 @@ func TestResponseSetsTLSConnectionState(t *testing.T) { func TestHTTPSClientDetectsHTTPServer(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + ts.Config.ErrorLog = quietLog defer ts.Close() _, err := Get(strings.Replace(ts.URL, "http", "https", 1)) @@ -895,6 +1072,7 @@ func testClientHeadContentLength(t *testing.T, h2 bool) { } func TestEmptyPasswordAuth(t *testing.T) { + setParallel(t) defer afterTest(t) gopher := "gopher" ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -915,7 +1093,9 @@ func TestEmptyPasswordAuth(t *testing.T) { } })) defer ts.Close() - c := &Client{} + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} req, err := NewRequest("GET", ts.URL, nil) if err != nil { t.Fatal(err) @@ -1007,10 +1187,10 @@ func 
TestClientTimeout_h1(t *testing.T) { testClientTimeout(t, h1Mode) } func TestClientTimeout_h2(t *testing.T) { testClientTimeout(t, h2Mode) } func testClientTimeout(t *testing.T, h2 bool) { - if testing.Short() { - t.Skip("skipping in short mode") - } + setParallel(t) defer afterTest(t) + testDone := make(chan struct{}) // closed in defer below + sawRoot := make(chan bool, 1) sawSlow := make(chan bool, 1) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { @@ -1020,19 +1200,26 @@ func testClientTimeout(t *testing.T, h2 bool) { return } if r.URL.Path == "/slow" { + sawSlow <- true w.Write([]byte("Hello")) w.(Flusher).Flush() - sawSlow <- true - time.Sleep(2 * time.Second) + <-testDone return } })) defer cst.close() - const timeout = 500 * time.Millisecond + defer close(testDone) // before cst.close, to unblock /slow handler + + // 200ms should be long enough to get a normal request (the / + // handler), but not so long that it makes the test slow. + const timeout = 200 * time.Millisecond cst.c.Timeout = timeout res, err := cst.c.Get(cst.ts.URL) if err != nil { + if strings.Contains(err.Error(), "Client.Timeout") { + t.Skipf("host too slow to get fast resource in %v", timeout) + } t.Fatal(err) } @@ -1057,7 +1244,7 @@ func testClientTimeout(t *testing.T, h2 bool) { res.Body.Close() }() - const failTime = timeout * 2 + const failTime = 5 * time.Second select { case err := <-errc: if err == nil { @@ -1082,11 +1269,9 @@ func TestClientTimeout_Headers_h2(t *testing.T) { testClientTimeout_Headers(t, h // Client.Timeout firing before getting to the body func testClientTimeout_Headers(t *testing.T, h2 bool) { - if testing.Short() { - t.Skip("skipping in short mode") - } + setParallel(t) defer afterTest(t) - donec := make(chan bool) + donec := make(chan bool, 1) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { <-donec })) @@ -1100,9 +1285,10 @@ func testClientTimeout_Headers(t *testing.T, h2 bool) { // doesn't 
know this, so synchronize explicitly. defer func() { donec <- true }() - cst.c.Timeout = 500 * time.Millisecond - _, err := cst.c.Get(cst.ts.URL) + cst.c.Timeout = 5 * time.Millisecond + res, err := cst.c.Get(cst.ts.URL) if err == nil { + res.Body.Close() t.Fatal("got response from Get; expected error") } if _, ok := err.(*url.Error); !ok { @@ -1120,9 +1306,40 @@ func testClientTimeout_Headers(t *testing.T, h2 bool) { } } +// Issue 16094: if Client.Timeout is set but not hit, a Timeout error shouldn't be +// returned. +func TestClientTimeoutCancel(t *testing.T) { + setParallel(t) + defer afterTest(t) + + testDone := make(chan struct{}) + ctx, cancel := context.WithCancel(context.Background()) + + cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + w.(Flusher).Flush() + <-testDone + })) + defer cst.close() + defer close(testDone) + + cst.c.Timeout = 1 * time.Hour + req, _ := NewRequest("GET", cst.ts.URL, nil) + req.Cancel = ctx.Done() + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + cancel() + _, err = io.Copy(ioutil.Discard, res.Body) + if err != ExportErrRequestCanceled { + t.Fatalf("error = %v; want errRequestCanceled", err) + } +} + func TestClientRedirectEatsBody_h1(t *testing.T) { testClientRedirectEatsBody(t, h1Mode) } func TestClientRedirectEatsBody_h2(t *testing.T) { testClientRedirectEatsBody(t, h2Mode) } func testClientRedirectEatsBody(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) saw := make(chan string, 2) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { @@ -1138,10 +1355,10 @@ func testClientRedirectEatsBody(t *testing.T, h2 bool) { t.Fatal(err) } _, err = ioutil.ReadAll(res.Body) + res.Body.Close() if err != nil { t.Fatal(err) } - res.Body.Close() var first string select { @@ -1229,3 +1446,369 @@ func TestClientRedirectResponseWithoutRequest(t *testing.T) { // Check that this doesn't crash: c.Get("http://dummy.tld") } + +// Issue 4800: copy (some) 
headers when Client follows a redirect +func TestClientCopyHeadersOnRedirect(t *testing.T) { + const ( + ua = "some-agent/1.2" + xfoo = "foo-val" + ) + var ts2URL string + ts1 := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + want := Header{ + "User-Agent": []string{ua}, + "X-Foo": []string{xfoo}, + "Referer": []string{ts2URL}, + "Accept-Encoding": []string{"gzip"}, + } + if !reflect.DeepEqual(r.Header, want) { + t.Errorf("Request.Header = %#v; want %#v", r.Header, want) + } + if t.Failed() { + w.Header().Set("Result", "got errors") + } else { + w.Header().Set("Result", "ok") + } + })) + defer ts1.Close() + ts2 := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + Redirect(w, r, ts1.URL, StatusFound) + })) + defer ts2.Close() + ts2URL = ts2.URL + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{ + Transport: tr, + CheckRedirect: func(r *Request, via []*Request) error { + want := Header{ + "User-Agent": []string{ua}, + "X-Foo": []string{xfoo}, + "Referer": []string{ts2URL}, + } + if !reflect.DeepEqual(r.Header, want) { + t.Errorf("CheckRedirect Request.Header = %#v; want %#v", r.Header, want) + } + return nil + }, + } + + req, _ := NewRequest("GET", ts2.URL, nil) + req.Header.Add("User-Agent", ua) + req.Header.Add("X-Foo", xfoo) + req.Header.Add("Cookie", "foo=bar") + req.Header.Add("Authorization", "secretpassword") + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + t.Fatal(res.Status) + } + if got := res.Header.Get("Result"); got != "ok" { + t.Errorf("result = %q; want ok", got) + } +} + +// Issue 17494: cookies should be altered when Client follows redirects. 
+func TestClientAltersCookiesOnRedirect(t *testing.T) { + cookieMap := func(cs []*Cookie) map[string][]string { + m := make(map[string][]string) + for _, c := range cs { + m[c.Name] = append(m[c.Name], c.Value) + } + return m + } + + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + var want map[string][]string + got := cookieMap(r.Cookies()) + + c, _ := r.Cookie("Cycle") + switch c.Value { + case "0": + want = map[string][]string{ + "Cookie1": {"OldValue1a", "OldValue1b"}, + "Cookie2": {"OldValue2"}, + "Cookie3": {"OldValue3a", "OldValue3b"}, + "Cookie4": {"OldValue4"}, + "Cycle": {"0"}, + } + SetCookie(w, &Cookie{Name: "Cycle", Value: "1", Path: "/"}) + SetCookie(w, &Cookie{Name: "Cookie2", Path: "/", MaxAge: -1}) // Delete cookie from Header + Redirect(w, r, "/", StatusFound) + case "1": + want = map[string][]string{ + "Cookie1": {"OldValue1a", "OldValue1b"}, + "Cookie3": {"OldValue3a", "OldValue3b"}, + "Cookie4": {"OldValue4"}, + "Cycle": {"1"}, + } + SetCookie(w, &Cookie{Name: "Cycle", Value: "2", Path: "/"}) + SetCookie(w, &Cookie{Name: "Cookie3", Value: "NewValue3", Path: "/"}) // Modify cookie in Header + SetCookie(w, &Cookie{Name: "Cookie4", Value: "NewValue4", Path: "/"}) // Modify cookie in Jar + Redirect(w, r, "/", StatusFound) + case "2": + want = map[string][]string{ + "Cookie1": {"OldValue1a", "OldValue1b"}, + "Cookie3": {"NewValue3"}, + "Cookie4": {"NewValue4"}, + "Cycle": {"2"}, + } + SetCookie(w, &Cookie{Name: "Cycle", Value: "3", Path: "/"}) + SetCookie(w, &Cookie{Name: "Cookie5", Value: "NewValue5", Path: "/"}) // Insert cookie into Jar + Redirect(w, r, "/", StatusFound) + case "3": + want = map[string][]string{ + "Cookie1": {"OldValue1a", "OldValue1b"}, + "Cookie3": {"NewValue3"}, + "Cookie4": {"NewValue4"}, + "Cookie5": {"NewValue5"}, + "Cycle": {"3"}, + } + // Don't redirect to ensure the loop ends. 
+ default: + t.Errorf("unexpected redirect cycle") + return + } + + if !reflect.DeepEqual(got, want) { + t.Errorf("redirect %s, Cookie = %v, want %v", c.Value, got, want) + } + })) + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + jar, _ := cookiejar.New(nil) + c := &Client{ + Transport: tr, + Jar: jar, + } + + u, _ := url.Parse(ts.URL) + req, _ := NewRequest("GET", ts.URL, nil) + req.AddCookie(&Cookie{Name: "Cookie1", Value: "OldValue1a"}) + req.AddCookie(&Cookie{Name: "Cookie1", Value: "OldValue1b"}) + req.AddCookie(&Cookie{Name: "Cookie2", Value: "OldValue2"}) + req.AddCookie(&Cookie{Name: "Cookie3", Value: "OldValue3a"}) + req.AddCookie(&Cookie{Name: "Cookie3", Value: "OldValue3b"}) + jar.SetCookies(u, []*Cookie{{Name: "Cookie4", Value: "OldValue4", Path: "/"}}) + jar.SetCookies(u, []*Cookie{{Name: "Cycle", Value: "0", Path: "/"}}) + res, err := c.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != 200 { + t.Fatal(res.Status) + } +} + +// Part of Issue 4800 +func TestShouldCopyHeaderOnRedirect(t *testing.T) { + tests := []struct { + header string + initialURL string + destURL string + want bool + }{ + {"User-Agent", "http://foo.com/", "http://bar.com/", true}, + {"X-Foo", "http://foo.com/", "http://bar.com/", true}, + + // Sensitive headers: + {"cookie", "http://foo.com/", "http://bar.com/", false}, + {"cookie2", "http://foo.com/", "http://bar.com/", false}, + {"authorization", "http://foo.com/", "http://bar.com/", false}, + {"www-authenticate", "http://foo.com/", "http://bar.com/", false}, + + // But subdomains should work: + {"www-authenticate", "http://foo.com/", "http://foo.com/", true}, + {"www-authenticate", "http://foo.com/", "http://sub.foo.com/", true}, + {"www-authenticate", "http://foo.com/", "http://notfoo.com/", false}, + // TODO(bradfitz): make this test work, once issue 16142 is fixed: + // {"www-authenticate", "http://foo.com:80/", "http://foo.com/", true}, + } + for i, tt := 
range tests { + u0, err := url.Parse(tt.initialURL) + if err != nil { + t.Errorf("%d. initial URL %q parse error: %v", i, tt.initialURL, err) + continue + } + u1, err := url.Parse(tt.destURL) + if err != nil { + t.Errorf("%d. dest URL %q parse error: %v", i, tt.destURL, err) + continue + } + got := Export_shouldCopyHeaderOnRedirect(tt.header, u0, u1) + if got != tt.want { + t.Errorf("%d. shouldCopyHeaderOnRedirect(%q, %q => %q) = %v; want %v", + i, tt.header, tt.initialURL, tt.destURL, got, tt.want) + } + } +} + +func TestClientRedirectTypes(t *testing.T) { + setParallel(t) + defer afterTest(t) + + tests := [...]struct { + method string + serverStatus int + wantMethod string // desired subsequent client method + }{ + 0: {method: "POST", serverStatus: 301, wantMethod: "GET"}, + 1: {method: "POST", serverStatus: 302, wantMethod: "GET"}, + 2: {method: "POST", serverStatus: 303, wantMethod: "GET"}, + 3: {method: "POST", serverStatus: 307, wantMethod: "POST"}, + 4: {method: "POST", serverStatus: 308, wantMethod: "POST"}, + + 5: {method: "HEAD", serverStatus: 301, wantMethod: "HEAD"}, + 6: {method: "HEAD", serverStatus: 302, wantMethod: "HEAD"}, + 7: {method: "HEAD", serverStatus: 303, wantMethod: "HEAD"}, + 8: {method: "HEAD", serverStatus: 307, wantMethod: "HEAD"}, + 9: {method: "HEAD", serverStatus: 308, wantMethod: "HEAD"}, + + 10: {method: "GET", serverStatus: 301, wantMethod: "GET"}, + 11: {method: "GET", serverStatus: 302, wantMethod: "GET"}, + 12: {method: "GET", serverStatus: 303, wantMethod: "GET"}, + 13: {method: "GET", serverStatus: 307, wantMethod: "GET"}, + 14: {method: "GET", serverStatus: 308, wantMethod: "GET"}, + + 15: {method: "DELETE", serverStatus: 301, wantMethod: "GET"}, + 16: {method: "DELETE", serverStatus: 302, wantMethod: "GET"}, + 17: {method: "DELETE", serverStatus: 303, wantMethod: "GET"}, + 18: {method: "DELETE", serverStatus: 307, wantMethod: "DELETE"}, + 19: {method: "DELETE", serverStatus: 308, wantMethod: "DELETE"}, + + 20: {method: 
"PUT", serverStatus: 301, wantMethod: "GET"}, + 21: {method: "PUT", serverStatus: 302, wantMethod: "GET"}, + 22: {method: "PUT", serverStatus: 303, wantMethod: "GET"}, + 23: {method: "PUT", serverStatus: 307, wantMethod: "PUT"}, + 24: {method: "PUT", serverStatus: 308, wantMethod: "PUT"}, + + 25: {method: "MADEUPMETHOD", serverStatus: 301, wantMethod: "GET"}, + 26: {method: "MADEUPMETHOD", serverStatus: 302, wantMethod: "GET"}, + 27: {method: "MADEUPMETHOD", serverStatus: 303, wantMethod: "GET"}, + 28: {method: "MADEUPMETHOD", serverStatus: 307, wantMethod: "MADEUPMETHOD"}, + 29: {method: "MADEUPMETHOD", serverStatus: 308, wantMethod: "MADEUPMETHOD"}, + } + + handlerc := make(chan HandlerFunc, 1) + + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + h := <-handlerc + h(rw, req) + })) + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + + for i, tt := range tests { + handlerc <- func(w ResponseWriter, r *Request) { + w.Header().Set("Location", ts.URL) + w.WriteHeader(tt.serverStatus) + } + + req, err := NewRequest(tt.method, ts.URL, nil) + if err != nil { + t.Errorf("#%d: NewRequest: %v", i, err) + continue + } + + c := &Client{Transport: tr} + c.CheckRedirect = func(req *Request, via []*Request) error { + if got, want := req.Method, tt.wantMethod; got != want { + return fmt.Errorf("#%d: got next method %q; want %q", i, got, want) + } + handlerc <- func(rw ResponseWriter, req *Request) { + // TODO: Check that the body is valid when we do 307 and 308 support + } + return nil + } + + res, err := c.Do(req) + if err != nil { + t.Errorf("#%d: Response: %v", i, err) + continue + } + + res.Body.Close() + } +} + +// issue18239Body is an io.ReadCloser for TestTransportBodyReadError. +// Its Read returns readErr and increments *readCalls atomically. +// Its Close returns nil and increments *closeCalls atomically. 
+type issue18239Body struct { + readCalls *int32 + closeCalls *int32 + readErr error +} + +func (b issue18239Body) Read([]byte) (int, error) { + atomic.AddInt32(b.readCalls, 1) + return 0, b.readErr +} + +func (b issue18239Body) Close() error { + atomic.AddInt32(b.closeCalls, 1) + return nil +} + +// Issue 18239: make sure the Transport doesn't retry requests with bodies. +// (Especially if Request.GetBody is not defined.) +func TestTransportBodyReadError(t *testing.T) { + setParallel(t) + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.URL.Path == "/ping" { + return + } + buf := make([]byte, 1) + n, err := r.Body.Read(buf) + w.Header().Set("X-Body-Read", fmt.Sprintf("%v, %v", n, err)) + })) + defer ts.Close() + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + // Do one initial successful request to create an idle TCP connection + // for the subsequent request to reuse. (The Transport only retries + // requests on reused connections.) + res, err := c.Get(ts.URL + "/ping") + if err != nil { + t.Fatal(err) + } + res.Body.Close() + + var readCallsAtomic int32 + var closeCallsAtomic int32 // atomic + someErr := errors.New("some body read error") + body := issue18239Body{&readCallsAtomic, &closeCallsAtomic, someErr} + + req, err := NewRequest("POST", ts.URL, body) + if err != nil { + t.Fatal(err) + } + _, err = tr.RoundTrip(req) + if err != someErr { + t.Errorf("Got error: %v; want Request.Body read error: %v", err, someErr) + } + + // And verify that our Body wasn't used multiple times, which + // would indicate retries. 
(as it buggily was during part of + // Go 1.8's dev cycle) + readCalls := atomic.LoadInt32(&readCallsAtomic) + closeCalls := atomic.LoadInt32(&closeCallsAtomic) + if readCalls != 1 { + t.Errorf("read calls = %d; want 1", readCalls) + } + if closeCalls != 1 { + t.Errorf("close calls = %d; want 1", closeCalls) + } +} diff --git a/libgo/go/net/http/clientserver_test.go b/libgo/go/net/http/clientserver_test.go index 3d1f09cae83..580115ca9c0 100644 --- a/libgo/go/net/http/clientserver_test.go +++ b/libgo/go/net/http/clientserver_test.go @@ -44,6 +44,19 @@ func (t *clientServerTest) close() { t.ts.Close() } +func (t *clientServerTest) getURL(u string) string { + res, err := t.c.Get(u) + if err != nil { + t.t.Fatal(err) + } + defer res.Body.Close() + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.t.Fatal(err) + } + return string(slurp) +} + func (t *clientServerTest) scheme() string { if t.h2 { return "https" @@ -56,6 +69,10 @@ const ( h2Mode = true ) +var optQuietLog = func(ts *httptest.Server) { + ts.Config.ErrorLog = quietLog +} + func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...interface{}) *clientServerTest { cst := &clientServerTest{ t: t, @@ -64,21 +81,23 @@ func newClientServerTest(t *testing.T, h2 bool, h Handler, opts ...interface{}) tr: &Transport{}, } cst.c = &Client{Transport: cst.tr} + cst.ts = httptest.NewUnstartedServer(h) for _, opt := range opts { switch opt := opt.(type) { case func(*Transport): opt(cst.tr) + case func(*httptest.Server): + opt(cst.ts) default: t.Fatalf("unhandled option type %T", opt) } } if !h2 { - cst.ts = httptest.NewServer(h) + cst.ts.Start() return cst } - cst.ts = httptest.NewUnstartedServer(h) ExportHttp2ConfigureServer(cst.ts.Config, nil) cst.ts.TLS = cst.ts.Config.TLSConfig cst.ts.StartTLS() @@ -170,6 +189,7 @@ func (tt h12Compare) reqFunc() reqFunc { } func (tt h12Compare) run(t *testing.T) { + setParallel(t) cst1 := newClientServerTest(t, false, HandlerFunc(tt.Handler), tt.Opts...) 
defer cst1.close() cst2 := newClientServerTest(t, true, HandlerFunc(tt.Handler), tt.Opts...) @@ -468,7 +488,7 @@ func TestH12_RequestContentLength_Known_NonZero(t *testing.T) { } func TestH12_RequestContentLength_Known_Zero(t *testing.T) { - h12requestContentLength(t, func() io.Reader { return strings.NewReader("") }, 0) + h12requestContentLength(t, func() io.Reader { return nil }, 0) } func TestH12_RequestContentLength_Unknown(t *testing.T) { @@ -938,6 +958,7 @@ func testStarRequest(t *testing.T, method string, h2 bool) { // Issue 13957 func TestTransportDiscardsUnneededConns(t *testing.T) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Hello, %v", r.RemoteAddr) @@ -1026,6 +1047,7 @@ func testTransportGCRequest(t *testing.T, h2, body bool) { t.Skip("skipping on gccgo because conservative GC means that finalizer may never run") } + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { ioutil.ReadAll(r.Body) @@ -1072,10 +1094,11 @@ func TestTransportRejectsInvalidHeaders_h2(t *testing.T) { testTransportRejectsInvalidHeaders(t, h2Mode) } func testTransportRejectsInvalidHeaders(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "Handler saw headers: %q", r.Header) - })) + }), optQuietLog) defer cst.close() cst.tr.DisableKeepAlives = true @@ -1143,24 +1166,44 @@ func testBogusStatusWorks(t *testing.T, h2 bool) { } } -func TestInterruptWithPanic_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode) } -func TestInterruptWithPanic_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode) } -func testInterruptWithPanic(t *testing.T, h2 bool) { - log.SetOutput(ioutil.Discard) // is noisy otherwise - defer log.SetOutput(os.Stderr) - +func TestInterruptWithPanic_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode, "boom") 
} +func TestInterruptWithPanic_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, "boom") } +func TestInterruptWithPanic_nil_h1(t *testing.T) { testInterruptWithPanic(t, h1Mode, nil) } +func TestInterruptWithPanic_nil_h2(t *testing.T) { testInterruptWithPanic(t, h2Mode, nil) } +func TestInterruptWithPanic_ErrAbortHandler_h1(t *testing.T) { + testInterruptWithPanic(t, h1Mode, ErrAbortHandler) +} +func TestInterruptWithPanic_ErrAbortHandler_h2(t *testing.T) { + testInterruptWithPanic(t, h2Mode, ErrAbortHandler) +} +func testInterruptWithPanic(t *testing.T, h2 bool, panicValue interface{}) { + setParallel(t) const msg = "hello" defer afterTest(t) + + testDone := make(chan struct{}) + defer close(testDone) + + var errorLog lockedBytesBuffer + gotHeaders := make(chan bool, 1) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, msg) w.(Flusher).Flush() - panic("no more") - })) + + select { + case <-gotHeaders: + case <-testDone: + } + panic(panicValue) + }), func(ts *httptest.Server) { + ts.Config.ErrorLog = log.New(&errorLog, "", 0) + }) defer cst.close() res, err := cst.c.Get(cst.ts.URL) if err != nil { t.Fatal(err) } + gotHeaders <- true defer res.Body.Close() slurp, err := ioutil.ReadAll(res.Body) if string(slurp) != msg { @@ -1169,6 +1212,42 @@ func testInterruptWithPanic(t *testing.T, h2 bool) { if err == nil { t.Errorf("client read all successfully; want some error") } + logOutput := func() string { + errorLog.Lock() + defer errorLog.Unlock() + return errorLog.String() + } + wantStackLogged := panicValue != nil && panicValue != ErrAbortHandler + + if err := waitErrCondition(5*time.Second, 10*time.Millisecond, func() error { + gotLog := logOutput() + if !wantStackLogged { + if gotLog == "" { + return nil + } + return fmt.Errorf("want no log output; got: %s", gotLog) + } + if gotLog == "" { + return fmt.Errorf("wanted a stack trace logged; got nothing") + } + if !strings.Contains(gotLog, "created by ") && 
strings.Count(gotLog, "\n") < 6 { + return fmt.Errorf("output doesn't look like a panic stack trace. Got: %s", gotLog) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +type lockedBytesBuffer struct { + sync.Mutex + bytes.Buffer +} + +func (b *lockedBytesBuffer) Write(p []byte) (int, error) { + b.Lock() + defer b.Unlock() + return b.Buffer.Write(p) } // Issue 15366 @@ -1204,6 +1283,7 @@ func TestH12_AutoGzipWithDumpResponse(t *testing.T) { func TestCloseIdleConnections_h1(t *testing.T) { testCloseIdleConnections(t, h1Mode) } func TestCloseIdleConnections_h2(t *testing.T) { testCloseIdleConnections(t, h2Mode) } func testCloseIdleConnections(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Addr", r.RemoteAddr) @@ -1238,3 +1318,70 @@ func (x noteCloseConn) Close() error { x.closeFunc() return x.Conn.Close() } + +type testErrorReader struct{ t *testing.T } + +func (r testErrorReader) Read(p []byte) (n int, err error) { + r.t.Error("unexpected Read call") + return 0, io.EOF +} + +func TestNoSniffExpectRequestBody_h1(t *testing.T) { testNoSniffExpectRequestBody(t, h1Mode) } +func TestNoSniffExpectRequestBody_h2(t *testing.T) { testNoSniffExpectRequestBody(t, h2Mode) } + +func testNoSniffExpectRequestBody(t *testing.T, h2 bool) { + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + w.WriteHeader(StatusUnauthorized) + })) + defer cst.close() + + // Set ExpectContinueTimeout non-zero so RoundTrip won't try to write it. 
+ cst.tr.ExpectContinueTimeout = 10 * time.Second + + req, err := NewRequest("POST", cst.ts.URL, testErrorReader{t}) + if err != nil { + t.Fatal(err) + } + req.ContentLength = 0 // so transport is tempted to sniff it + req.Header.Set("Expect", "100-continue") + res, err := cst.tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != StatusUnauthorized { + t.Errorf("status code = %v; want %v", res.StatusCode, StatusUnauthorized) + } +} + +func TestServerUndeclaredTrailers_h1(t *testing.T) { testServerUndeclaredTrailers(t, h1Mode) } +func TestServerUndeclaredTrailers_h2(t *testing.T) { testServerUndeclaredTrailers(t, h2Mode) } +func testServerUndeclaredTrailers(t *testing.T, h2 bool) { + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Foo", "Bar") + w.Header().Set("Trailer:Foo", "Baz") + w.(Flusher).Flush() + w.Header().Add("Trailer:Foo", "Baz2") + w.Header().Set("Trailer:Bar", "Quux") + })) + defer cst.close() + res, err := cst.c.Get(cst.ts.URL) + if err != nil { + t.Fatal(err) + } + if _, err := io.Copy(ioutil.Discard, res.Body); err != nil { + t.Fatal(err) + } + res.Body.Close() + delete(res.Header, "Date") + delete(res.Header, "Content-Type") + + if want := (Header{"Foo": {"Bar"}}); !reflect.DeepEqual(res.Header, want) { + t.Errorf("Header = %#v; want %#v", res.Header, want) + } + if want := (Header{"Foo": {"Baz", "Baz2"}, "Bar": {"Quux"}}); !reflect.DeepEqual(res.Trailer, want) { + t.Errorf("Trailer = %#v; want %#v", res.Trailer, want) + } +} diff --git a/libgo/go/net/http/cookie.go b/libgo/go/net/http/cookie.go index 1ea0e9397a3..5a67476cd42 100644 --- a/libgo/go/net/http/cookie.go +++ b/libgo/go/net/http/cookie.go @@ -6,7 +6,6 @@ package http import ( "bytes" - "fmt" "log" "net" "strconv" @@ -40,7 +39,11 @@ type Cookie struct { // readSetCookies parses all "Set-Cookie" values from // the header h and returns the successfully parsed Cookies. 
func readSetCookies(h Header) []*Cookie { - cookies := []*Cookie{} + cookieCount := len(h["Set-Cookie"]) + if cookieCount == 0 { + return []*Cookie{} + } + cookies := make([]*Cookie, 0, cookieCount) for _, line := range h["Set-Cookie"] { parts := strings.Split(strings.TrimSpace(line), ";") if len(parts) == 1 && parts[0] == "" { @@ -55,8 +58,8 @@ func readSetCookies(h Header) []*Cookie { if !isCookieNameValid(name) { continue } - value, success := parseCookieValue(value, true) - if !success { + value, ok := parseCookieValue(value, true) + if !ok { continue } c := &Cookie{ @@ -75,8 +78,8 @@ func readSetCookies(h Header) []*Cookie { attr, val = attr[:j], attr[j+1:] } lowerAttr := strings.ToLower(attr) - val, success = parseCookieValue(val, false) - if !success { + val, ok = parseCookieValue(val, false) + if !ok { c.Unparsed = append(c.Unparsed, parts[i]) continue } @@ -96,10 +99,9 @@ func readSetCookies(h Header) []*Cookie { break } if secs <= 0 { - c.MaxAge = -1 - } else { - c.MaxAge = secs + secs = -1 } + c.MaxAge = secs continue case "expires": c.RawExpires = val @@ -142,9 +144,13 @@ func (c *Cookie) String() string { return "" } var b bytes.Buffer - fmt.Fprintf(&b, "%s=%s", sanitizeCookieName(c.Name), sanitizeCookieValue(c.Value)) + b.WriteString(sanitizeCookieName(c.Name)) + b.WriteRune('=') + b.WriteString(sanitizeCookieValue(c.Value)) + if len(c.Path) > 0 { - fmt.Fprintf(&b, "; Path=%s", sanitizeCookiePath(c.Path)) + b.WriteString("; Path=") + b.WriteString(sanitizeCookiePath(c.Path)) } if len(c.Domain) > 0 { if validCookieDomain(c.Domain) { @@ -156,25 +162,31 @@ func (c *Cookie) String() string { if d[0] == '.' 
{ d = d[1:] } - fmt.Fprintf(&b, "; Domain=%s", d) + b.WriteString("; Domain=") + b.WriteString(d) } else { - log.Printf("net/http: invalid Cookie.Domain %q; dropping domain attribute", - c.Domain) + log.Printf("net/http: invalid Cookie.Domain %q; dropping domain attribute", c.Domain) } } - if c.Expires.Unix() > 0 { - fmt.Fprintf(&b, "; Expires=%s", c.Expires.UTC().Format(TimeFormat)) + if validCookieExpires(c.Expires) { + b.WriteString("; Expires=") + b2 := b.Bytes() + b.Reset() + b.Write(c.Expires.UTC().AppendFormat(b2, TimeFormat)) } if c.MaxAge > 0 { - fmt.Fprintf(&b, "; Max-Age=%d", c.MaxAge) + b.WriteString("; Max-Age=") + b2 := b.Bytes() + b.Reset() + b.Write(strconv.AppendInt(b2, int64(c.MaxAge), 10)) } else if c.MaxAge < 0 { - fmt.Fprintf(&b, "; Max-Age=0") + b.WriteString("; Max-Age=0") } if c.HttpOnly { - fmt.Fprintf(&b, "; HttpOnly") + b.WriteString("; HttpOnly") } if c.Secure { - fmt.Fprintf(&b, "; Secure") + b.WriteString("; Secure") } return b.String() } @@ -184,12 +196,12 @@ func (c *Cookie) String() string { // // if filter isn't empty, only cookies of that name are returned func readCookies(h Header, filter string) []*Cookie { - cookies := []*Cookie{} lines, ok := h["Cookie"] if !ok { - return cookies + return []*Cookie{} } + cookies := []*Cookie{} for _, line := range lines { parts := strings.Split(strings.TrimSpace(line), ";") if len(parts) == 1 && parts[0] == "" { @@ -212,8 +224,8 @@ func readCookies(h Header, filter string) []*Cookie { if filter != "" && filter != name { continue } - val, success := parseCookieValue(val, true) - if !success { + val, ok := parseCookieValue(val, true) + if !ok { continue } cookies = append(cookies, &Cookie{Name: name, Value: val}) @@ -234,6 +246,12 @@ func validCookieDomain(v string) bool { return false } +// validCookieExpires returns whether v is a valid cookie expires-value. 
+func validCookieExpires(t time.Time) bool { + // IETF RFC 6265 Section 5.1.1.5, the year must not be less than 1601 + return t.Year() >= 1601 +} + // isCookieDomainName returns whether s is a valid domain name or a valid // domain name with a leading dot '.'. It is almost a direct copy of // package net's isDomainName. diff --git a/libgo/go/net/http/cookie_test.go b/libgo/go/net/http/cookie_test.go index 95e61479a15..b3e54f8db32 100644 --- a/libgo/go/net/http/cookie_test.go +++ b/libgo/go/net/http/cookie_test.go @@ -56,6 +56,15 @@ var writeSetCookiesTests = []struct { &Cookie{Name: "cookie-9", Value: "expiring", Expires: time.Unix(1257894000, 0)}, "cookie-9=expiring; Expires=Tue, 10 Nov 2009 23:00:00 GMT", }, + // According to IETF 6265 Section 5.1.1.5, the year cannot be less than 1601 + { + &Cookie{Name: "cookie-10", Value: "expiring-1601", Expires: time.Date(1601, 1, 1, 1, 1, 1, 1, time.UTC)}, + "cookie-10=expiring-1601; Expires=Mon, 01 Jan 1601 01:01:01 GMT", + }, + { + &Cookie{Name: "cookie-11", Value: "invalid-expiry", Expires: time.Date(1600, 1, 1, 1, 1, 1, 1, time.UTC)}, + "cookie-11=invalid-expiry", + }, // The "special" cookies have values containing commas or spaces which // are disallowed by RFC 6265 but are common in the wild. { @@ -426,3 +435,92 @@ func TestCookieSanitizePath(t *testing.T) { t.Errorf("Expected substring %q in log output. 
Got:\n%s", sub, got) } } + +func BenchmarkCookieString(b *testing.B) { + const wantCookieString = `cookie-9=i3e01nf61b6t23bvfmplnanol3; Path=/restricted/; Domain=example.com; Expires=Tue, 10 Nov 2009 23:00:00 GMT; Max-Age=3600` + c := &Cookie{ + Name: "cookie-9", + Value: "i3e01nf61b6t23bvfmplnanol3", + Expires: time.Unix(1257894000, 0), + Path: "/restricted/", + Domain: ".example.com", + MaxAge: 3600, + } + var benchmarkCookieString string + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + benchmarkCookieString = c.String() + } + if have, want := benchmarkCookieString, wantCookieString; have != want { + b.Fatalf("Have: %v Want: %v", have, want) + } +} + +func BenchmarkReadSetCookies(b *testing.B) { + header := Header{ + "Set-Cookie": { + "NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly", + ".ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly", + }, + } + wantCookies := []*Cookie{ + { + Name: "NID", + Value: "99=YsDT5i3E-CXax-", + Path: "/", + Domain: ".google.ch", + HttpOnly: true, + Expires: time.Date(2011, 11, 23, 1, 5, 3, 0, time.UTC), + RawExpires: "Wed, 23-Nov-2011 01:05:03 GMT", + Raw: "NID=99=YsDT5i3E-CXax-; expires=Wed, 23-Nov-2011 01:05:03 GMT; path=/; domain=.google.ch; HttpOnly", + }, + { + Name: ".ASPXAUTH", + Value: "7E3AA", + Path: "/", + Expires: time.Date(2012, 3, 7, 14, 25, 6, 0, time.UTC), + RawExpires: "Wed, 07-Mar-2012 14:25:06 GMT", + HttpOnly: true, + Raw: ".ASPXAUTH=7E3AA; expires=Wed, 07-Mar-2012 14:25:06 GMT; path=/; HttpOnly", + }, + } + var c []*Cookie + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + c = readSetCookies(header) + } + if !reflect.DeepEqual(c, wantCookies) { + b.Fatalf("readSetCookies:\nhave: %s\nwant: %s\n", toJSON(c), toJSON(wantCookies)) + } +} + +func BenchmarkReadCookies(b *testing.B) { + header := Header{ + "Cookie": { + `de=; client_region=0; rpld1=0:hispeed.ch|20:che|21:zh|22:zurich|23:47.36|24:8.53|; 
rpld0=1:08|; backplane-channel=newspaper.com:1471; devicetype=0; osfam=0; rplmct=2; s_pers=%20s_vmonthnum%3D1472680800496%2526vn%253D1%7C1472680800496%3B%20s_nr%3D1471686767664-New%7C1474278767664%3B%20s_lv%3D1471686767669%7C1566294767669%3B%20s_lv_s%3DFirst%2520Visit%7C1471688567669%3B%20s_monthinvisit%3Dtrue%7C1471688567677%3B%20gvp_p5%3Dsports%253Ablog%253Aearly-lead%2520-%2520184693%2520-%252020160820%2520-%2520u-s%7C1471688567681%3B%20gvp_p51%3Dwp%2520-%2520sports%7C1471688567684%3B; s_sess=%20s_wp_ep%3Dhomepage%3B%20s._ref%3Dhttps%253A%252F%252Fwww.google.ch%252F%3B%20s_cc%3Dtrue%3B%20s_ppvl%3Dsports%25253Ablog%25253Aearly-lead%252520-%252520184693%252520-%25252020160820%252520-%252520u-lawyer%252C12%252C12%252C502%252C1231%252C502%252C1680%252C1050%252C2%252CP%3B%20s_ppv%3Dsports%25253Ablog%25253Aearly-lead%252520-%252520184693%252520-%25252020160820%252520-%252520u-s-lawyer%252C12%252C12%252C502%252C1231%252C502%252C1680%252C1050%252C2%252CP%3B%20s_dslv%3DFirst%2520Visit%3B%20s_sq%3Dwpninewspapercom%253D%252526pid%25253Dsports%2525253Ablog%2525253Aearly-lead%25252520-%25252520184693%25252520-%2525252020160820%25252520-%25252520u-s%252526pidt%25253D1%252526oid%25253Dhttps%2525253A%2525252F%2525252Fwww.newspaper.com%2525252F%2525253Fnid%2525253Dmenu_nav_homepage%252526ot%25253DA%3B`, + }, + } + wantCookies := []*Cookie{ + {Name: "de", Value: ""}, + {Name: "client_region", Value: "0"}, + {Name: "rpld1", Value: "0:hispeed.ch|20:che|21:zh|22:zurich|23:47.36|24:8.53|"}, + {Name: "rpld0", Value: "1:08|"}, + {Name: "backplane-channel", Value: "newspaper.com:1471"}, + {Name: "devicetype", Value: "0"}, + {Name: "osfam", Value: "0"}, + {Name: "rplmct", Value: "2"}, + {Name: "s_pers", Value: 
"%20s_vmonthnum%3D1472680800496%2526vn%253D1%7C1472680800496%3B%20s_nr%3D1471686767664-New%7C1474278767664%3B%20s_lv%3D1471686767669%7C1566294767669%3B%20s_lv_s%3DFirst%2520Visit%7C1471688567669%3B%20s_monthinvisit%3Dtrue%7C1471688567677%3B%20gvp_p5%3Dsports%253Ablog%253Aearly-lead%2520-%2520184693%2520-%252020160820%2520-%2520u-s%7C1471688567681%3B%20gvp_p51%3Dwp%2520-%2520sports%7C1471688567684%3B"}, + {Name: "s_sess", Value: "%20s_wp_ep%3Dhomepage%3B%20s._ref%3Dhttps%253A%252F%252Fwww.google.ch%252F%3B%20s_cc%3Dtrue%3B%20s_ppvl%3Dsports%25253Ablog%25253Aearly-lead%252520-%252520184693%252520-%25252020160820%252520-%252520u-lawyer%252C12%252C12%252C502%252C1231%252C502%252C1680%252C1050%252C2%252CP%3B%20s_ppv%3Dsports%25253Ablog%25253Aearly-lead%252520-%252520184693%252520-%25252020160820%252520-%252520u-s-lawyer%252C12%252C12%252C502%252C1231%252C502%252C1680%252C1050%252C2%252CP%3B%20s_dslv%3DFirst%2520Visit%3B%20s_sq%3Dwpninewspapercom%253D%252526pid%25253Dsports%2525253Ablog%2525253Aearly-lead%25252520-%25252520184693%25252520-%2525252020160820%25252520-%25252520u-s%252526pidt%25253D1%252526oid%25253Dhttps%2525253A%2525252F%2525252Fwww.newspaper.com%2525252F%2525253Fnid%2525253Dmenu_nav_homepage%252526ot%25253DA%3B"}, + } + var c []*Cookie + b.ReportAllocs() + b.ResetTimer() + for i := 0; i < b.N; i++ { + c = readCookies(header, "") + } + if !reflect.DeepEqual(c, wantCookies) { + b.Fatalf("readCookies:\nhave: %s\nwant: %s\n", toJSON(c), toJSON(wantCookies)) + } +} diff --git a/libgo/go/net/http/cookiejar/dummy_publicsuffix_test.go b/libgo/go/net/http/cookiejar/dummy_publicsuffix_test.go new file mode 100644 index 00000000000..748ec5cc431 --- /dev/null +++ b/libgo/go/net/http/cookiejar/dummy_publicsuffix_test.go @@ -0,0 +1,23 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +// +build ignore + +package cookiejar_test + +import "net/http/cookiejar" + +type dummypsl struct { + List cookiejar.PublicSuffixList +} + +func (dummypsl) PublicSuffix(domain string) string { + return domain +} + +func (dummypsl) String() string { + return "dummy" +} + +var publicsuffix = dummypsl{} diff --git a/libgo/go/net/http/cookiejar/example_test.go b/libgo/go/net/http/cookiejar/example_test.go new file mode 100644 index 00000000000..19a57465ff6 --- /dev/null +++ b/libgo/go/net/http/cookiejar/example_test.go @@ -0,0 +1,67 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build ignore + +package cookiejar_test + +import ( + "fmt" + "log" + "net/http" + "net/http/cookiejar" + "net/http/httptest" + "net/url" +) + +func ExampleNew() { + // Start a server to give us cookies. + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if cookie, err := r.Cookie("Flavor"); err != nil { + http.SetCookie(w, &http.Cookie{Name: "Flavor", Value: "Chocolate Chip"}) + } else { + cookie.Value = "Oatmeal Raisin" + http.SetCookie(w, cookie) + } + })) + defer ts.Close() + + u, err := url.Parse(ts.URL) + if err != nil { + log.Fatal(err) + } + + // All users of cookiejar should import "golang.org/x/net/publicsuffix" + jar, err := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List}) + if err != nil { + log.Fatal(err) + } + + client := &http.Client{ + Jar: jar, + } + + if _, err = client.Get(u.String()); err != nil { + log.Fatal(err) + } + + fmt.Println("After 1st request:") + for _, cookie := range jar.Cookies(u) { + fmt.Printf(" %s: %s\n", cookie.Name, cookie.Value) + } + + if _, err = client.Get(u.String()); err != nil { + log.Fatal(err) + } + + fmt.Println("After 2nd request:") + for _, cookie := range jar.Cookies(u) { + fmt.Printf(" %s: %s\n", cookie.Name, cookie.Value) + } + // Output: + // After 1st 
request: + // Flavor: Chocolate Chip + // After 2nd request: + // Flavor: Oatmeal Raisin +} diff --git a/libgo/go/net/http/cookiejar/jar.go b/libgo/go/net/http/cookiejar/jar.go index 0e0fac9286e..f89abbcd186 100644 --- a/libgo/go/net/http/cookiejar/jar.go +++ b/libgo/go/net/http/cookiejar/jar.go @@ -107,7 +107,7 @@ type entry struct { seqNum uint64 } -// Id returns the domain;path;name triple of e as an id. +// id returns the domain;path;name triple of e as an id. func (e *entry) id() string { return fmt.Sprintf("%s;%s;%s", e.Domain, e.Path, e.Name) } @@ -147,24 +147,6 @@ func hasDotSuffix(s, suffix string) bool { return len(s) > len(suffix) && s[len(s)-len(suffix)-1] == '.' && s[len(s)-len(suffix):] == suffix } -// byPathLength is a []entry sort.Interface that sorts according to RFC 6265 -// section 5.4 point 2: by longest path and then by earliest creation time. -type byPathLength []entry - -func (s byPathLength) Len() int { return len(s) } - -func (s byPathLength) Less(i, j int) bool { - if len(s[i].Path) != len(s[j].Path) { - return len(s[i].Path) > len(s[j].Path) - } - if !s[i].Creation.Equal(s[j].Creation) { - return s[i].Creation.Before(s[j].Creation) - } - return s[i].seqNum < s[j].seqNum -} - -func (s byPathLength) Swap(i, j int) { s[i], s[j] = s[j], s[i] } - // Cookies implements the Cookies method of the http.CookieJar interface. // // It returns an empty slice if the URL's scheme is not HTTP or HTTPS. @@ -221,7 +203,18 @@ func (j *Jar) cookies(u *url.URL, now time.Time) (cookies []*http.Cookie) { } } - sort.Sort(byPathLength(selected)) + // sort according to RFC 6265 section 5.4 point 2: by longest + // path and then by earliest creation time. 
+ sort.Slice(selected, func(i, j int) bool { + s := selected + if len(s[i].Path) != len(s[j].Path) { + return len(s[i].Path) > len(s[j].Path) + } + if !s[i].Creation.Equal(s[j].Creation) { + return s[i].Creation.Before(s[j].Creation) + } + return s[i].seqNum < s[j].seqNum + }) for _, e := range selected { cookies = append(cookies, &http.Cookie{Name: e.Name, Value: e.Value}) } diff --git a/libgo/go/net/http/doc.go b/libgo/go/net/http/doc.go index 4ec8272f628..7855feaaa99 100644 --- a/libgo/go/net/http/doc.go +++ b/libgo/go/net/http/doc.go @@ -44,7 +44,8 @@ For control over proxies, TLS configuration, keep-alives, compression, and other settings, create a Transport: tr := &http.Transport{ - TLSClientConfig: &tls.Config{RootCAs: pool}, + MaxIdleConns: 10, + IdleConnTimeout: 30 * time.Second, DisableCompression: true, } client := &http.Client{Transport: tr} @@ -77,19 +78,30 @@ custom Server: } log.Fatal(s.ListenAndServe()) -The http package has transparent support for the HTTP/2 protocol when -using HTTPS. Programs that must disable HTTP/2 can do so by setting -Transport.TLSNextProto (for clients) or Server.TLSNextProto (for -servers) to a non-nil, empty map. Alternatively, the following GODEBUG -environment variables are currently supported: +Starting with Go 1.6, the http package has transparent support for the +HTTP/2 protocol when using HTTPS. Programs that must disable HTTP/2 +can do so by setting Transport.TLSNextProto (for clients) or +Server.TLSNextProto (for servers) to a non-nil, empty +map. Alternatively, the following GODEBUG environment variables are +currently supported: GODEBUG=http2client=0 # disable HTTP/2 client support GODEBUG=http2server=0 # disable HTTP/2 server support GODEBUG=http2debug=1 # enable verbose HTTP/2 debug logs GODEBUG=http2debug=2 # ... even more verbose, with frame dumps -The GODEBUG variables are not covered by Go's API compatibility promise. -HTTP/2 support was added in Go 1.6. 
Please report any issues instead of -disabling HTTP/2 support: https://golang.org/s/http2bug +The GODEBUG variables are not covered by Go's API compatibility +promise. Please report any issues before disabling HTTP/2 +support: https://golang.org/s/http2bug + +The http package's Transport and Server both automatically enable +HTTP/2 support for simple configurations. To enable HTTP/2 for more +complex configurations, to use lower-level HTTP/2 features, or to use +a newer version of Go's http2 package, import "golang.org/x/net/http2" +directly and use its ConfigureTransport and/or ConfigureServer +functions. Manually configuring HTTP/2 via the golang.org/x/net/http2 +package takes precedence over the net/http package's built-in HTTP/2 +support. + */ package http diff --git a/libgo/go/net/http/export_test.go b/libgo/go/net/http/export_test.go index 9c5ba0809ad..b61f58b2db4 100644 --- a/libgo/go/net/http/export_test.go +++ b/libgo/go/net/http/export_test.go @@ -24,6 +24,7 @@ var ( ExportErrRequestCanceled = errRequestCanceled ExportErrRequestCanceledConn = errRequestCanceledConn ExportServeFile = serveFile + ExportScanETag = scanETag ExportHttp2ConfigureServer = http2ConfigureServer ) @@ -87,6 +88,12 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) { return } +func (t *Transport) IdleConnKeyCountForTesting() int { + t.idleMu.Lock() + defer t.idleMu.Unlock() + return len(t.idleConn) +} + func (t *Transport) IdleConnStrsForTesting() []string { var ret []string t.idleMu.Lock() @@ -100,6 +107,24 @@ func (t *Transport) IdleConnStrsForTesting() []string { return ret } +func (t *Transport) IdleConnStrsForTesting_h2() []string { + var ret []string + noDialPool := t.h2transport.ConnPool.(http2noDialClientConnPool) + pool := noDialPool.http2clientConnPool + + pool.mu.Lock() + defer pool.mu.Unlock() + + for k, cc := range pool.conns { + for range cc { + ret = append(ret, k) + } + } + + sort.Strings(ret) + return ret +} + func (t *Transport) 
IdleConnCountForTesting(cacheKey string) int { t.idleMu.Lock() defer t.idleMu.Unlock() @@ -160,3 +185,17 @@ func ExportHttp2ConfigureTransport(t *Transport) error { t.h2transport = t2 return nil } + +var Export_shouldCopyHeaderOnRedirect = shouldCopyHeaderOnRedirect + +func (s *Server) ExportAllConnsIdle() bool { + s.mu.Lock() + defer s.mu.Unlock() + for c := range s.activeConn { + st, ok := c.curState.Load().(ConnState) + if !ok || st != StateIdle { + return false + } + } + return true +} diff --git a/libgo/go/net/http/fcgi/fcgi.go b/libgo/go/net/http/fcgi/fcgi.go index 337484139d3..5057d700981 100644 --- a/libgo/go/net/http/fcgi/fcgi.go +++ b/libgo/go/net/http/fcgi/fcgi.go @@ -3,8 +3,12 @@ // license that can be found in the LICENSE file. // Package fcgi implements the FastCGI protocol. +// +// The protocol is not an official standard and the original +// documentation is no longer online. See the Internet Archive's +// mirror at: https://web.archive.org/web/20150420080736/http://www.fastcgi.com/drupal/node/6?q=node/22 +// // Currently only the responder role is supported. -// The protocol is defined at http://www.fastcgi.com/drupal/node/6?q=node/22 package fcgi // This file defines the raw protocol and some utilities used by the child and diff --git a/libgo/go/net/http/fs.go b/libgo/go/net/http/fs.go index c7a58a61dff..bf63bb5441f 100644 --- a/libgo/go/net/http/fs.go +++ b/libgo/go/net/http/fs.go @@ -77,7 +77,7 @@ func dirList(w ResponseWriter, f File) { Error(w, "Error reading directory", StatusInternalServerError) return } - sort.Sort(byName(dirs)) + sort.Slice(dirs, func(i, j int) bool { return dirs[i].Name() < dirs[j].Name() }) w.Header().Set("Content-Type", "text/html; charset=utf-8") fmt.Fprintf(w, "<pre>\n") @@ -98,7 +98,8 @@ func dirList(w ResponseWriter, f File) { // ServeContent replies to the request using the content in the // provided ReadSeeker. 
The main benefit of ServeContent over io.Copy // is that it handles Range requests properly, sets the MIME type, and -// handles If-Modified-Since requests. +// handles If-Match, If-Unmodified-Since, If-None-Match, If-Modified-Since, +// and If-Range requests. // // If the response's Content-Type header is not set, ServeContent // first tries to deduce the type from name's file extension and, @@ -115,8 +116,8 @@ func dirList(w ResponseWriter, f File) { // The content's Seek method must work: ServeContent uses // a seek to the end of the content to determine its size. // -// If the caller has set w's ETag header, ServeContent uses it to -// handle requests using If-Range and If-None-Match. +// If the caller has set w's ETag header formatted per RFC 7232, section 2.3, +// ServeContent uses it to handle requests using If-Match, If-None-Match, or If-Range. // // Note that *os.File implements the io.ReadSeeker interface. func ServeContent(w ResponseWriter, req *Request, name string, modtime time.Time, content io.ReadSeeker) { @@ -140,15 +141,17 @@ func ServeContent(w ResponseWriter, req *Request, name string, modtime time.Time // users. var errSeeker = errors.New("seeker can't seek") +// errNoOverlap is returned by serveContent's parseRange if first-byte-pos of +// all of the byte-range-spec values is greater than the content size. +var errNoOverlap = errors.New("invalid range: failed to overlap") + // if name is empty, filename is unknown. (used for mime type, before sniffing) // if modtime.IsZero(), modtime is unknown. // content must be seeked to the beginning of the file. // The sizeFunc is called at most once. Its error, if any, is sent in the HTTP response. 
func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, sizeFunc func() (int64, error), content io.ReadSeeker) { - if checkLastModified(w, r, modtime) { - return - } - rangeReq, done := checkETag(w, r, modtime) + setLastModified(w, modtime) + done, rangeReq := checkPreconditions(w, r, modtime) if done { return } @@ -189,6 +192,9 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, if size >= 0 { ranges, err := parseRange(rangeReq, size) if err != nil { + if err == errNoOverlap { + w.Header().Set("Content-Range", fmt.Sprintf("bytes */%d", size)) + } Error(w, err.Error(), StatusRequestedRangeNotSatisfiable) return } @@ -263,90 +269,245 @@ func serveContent(w ResponseWriter, r *Request, name string, modtime time.Time, } } -var unixEpochTime = time.Unix(0, 0) - -// modtime is the modification time of the resource to be served, or IsZero(). -// return value is whether this request is now complete. -func checkLastModified(w ResponseWriter, r *Request, modtime time.Time) bool { - if modtime.IsZero() || modtime.Equal(unixEpochTime) { - // If the file doesn't have a modtime (IsZero), or the modtime - // is obviously garbage (Unix time == 0), then ignore modtimes - // and don't process the If-Modified-Since header. - return false +// scanETag determines if a syntactically valid ETag is present at s. If so, +// the ETag and remaining text after consuming ETag is returned. Otherwise, +// it returns "", "". +func scanETag(s string) (etag string, remain string) { + s = textproto.TrimString(s) + start := 0 + if strings.HasPrefix(s, "W/") { + start = 2 + } + if len(s[start:]) < 2 || s[start] != '"' { + return "", "" + } + // ETag is either W/"text" or "text". + // See RFC 7232 2.3. + for i := start + 1; i < len(s); i++ { + c := s[i] + switch { + // Character values allowed in ETags. 
+ case c == 0x21 || c >= 0x23 && c <= 0x7E || c >= 0x80: + case c == '"': + return string(s[:i+1]), s[i+1:] + default: + break + } } + return "", "" +} - // The Date-Modified header truncates sub-second precision, so - // use mtime < t+1s instead of mtime <= t to check for unmodified. - if t, err := time.Parse(TimeFormat, r.Header.Get("If-Modified-Since")); err == nil && modtime.Before(t.Add(1*time.Second)) { - h := w.Header() - delete(h, "Content-Type") - delete(h, "Content-Length") - w.WriteHeader(StatusNotModified) - return true - } - w.Header().Set("Last-Modified", modtime.UTC().Format(TimeFormat)) - return false +// etagStrongMatch reports whether a and b match using strong ETag comparison. +// Assumes a and b are valid ETags. +func etagStrongMatch(a, b string) bool { + return a == b && a != "" && a[0] == '"' } -// checkETag implements If-None-Match and If-Range checks. -// -// The ETag or modtime must have been previously set in the -// ResponseWriter's headers. The modtime is only compared at second -// granularity and may be the zero value to mean unknown. -// -// The return value is the effective request "Range" header to use and -// whether this request is now considered done. -func checkETag(w ResponseWriter, r *Request, modtime time.Time) (rangeReq string, done bool) { - etag := w.Header().get("Etag") - rangeReq = r.Header.get("Range") - - // Invalidate the range request if the entity doesn't match the one - // the client was expecting. - // "If-Range: version" means "ignore the Range: header unless version matches the - // current file." - // We only support ETag versions. - // The caller must have set the ETag on the response already. - if ir := r.Header.get("If-Range"); ir != "" && ir != etag { - // The If-Range value is typically the ETag value, but it may also be - // the modtime date. See golang.org/issue/8367. 
- timeMatches := false - if !modtime.IsZero() { - if t, err := ParseTime(ir); err == nil && t.Unix() == modtime.Unix() { - timeMatches = true - } +// etagWeakMatch reports whether a and b match using weak ETag comparison. +// Assumes a and b are valid ETags. +func etagWeakMatch(a, b string) bool { + return strings.TrimPrefix(a, "W/") == strings.TrimPrefix(b, "W/") +} + +// condResult is the result of an HTTP request precondition check. +// See https://tools.ietf.org/html/rfc7232 section 3. +type condResult int + +const ( + condNone condResult = iota + condTrue + condFalse +) + +func checkIfMatch(w ResponseWriter, r *Request) condResult { + im := r.Header.Get("If-Match") + if im == "" { + return condNone + } + for { + im = textproto.TrimString(im) + if len(im) == 0 { + break + } + if im[0] == ',' { + im = im[1:] + continue + } + if im[0] == '*' { + return condTrue } - if !timeMatches { - rangeReq = "" + etag, remain := scanETag(im) + if etag == "" { + break + } + if etagStrongMatch(etag, w.Header().get("Etag")) { + return condTrue } + im = remain } - if inm := r.Header.get("If-None-Match"); inm != "" { - // Must know ETag. + return condFalse +} + +func checkIfUnmodifiedSince(w ResponseWriter, r *Request, modtime time.Time) condResult { + ius := r.Header.Get("If-Unmodified-Since") + if ius == "" || isZeroTime(modtime) { + return condNone + } + if t, err := ParseTime(ius); err == nil { + // The Date-Modified header truncates sub-second precision, so + // use mtime < t+1s instead of mtime <= t to check for unmodified. 
+ if modtime.Before(t.Add(1 * time.Second)) { + return condTrue + } + return condFalse + } + return condNone +} + +func checkIfNoneMatch(w ResponseWriter, r *Request) condResult { + inm := r.Header.get("If-None-Match") + if inm == "" { + return condNone + } + buf := inm + for { + buf = textproto.TrimString(buf) + if len(buf) == 0 { + break + } + if buf[0] == ',' { + buf = buf[1:] + } + if buf[0] == '*' { + return condFalse + } + etag, remain := scanETag(buf) if etag == "" { - return rangeReq, false + break + } + if etagWeakMatch(etag, w.Header().get("Etag")) { + return condFalse } + buf = remain + } + return condTrue +} + +func checkIfModifiedSince(w ResponseWriter, r *Request, modtime time.Time) condResult { + if r.Method != "GET" && r.Method != "HEAD" { + return condNone + } + ims := r.Header.Get("If-Modified-Since") + if ims == "" || isZeroTime(modtime) { + return condNone + } + t, err := ParseTime(ims) + if err != nil { + return condNone + } + // The Date-Modified header truncates sub-second precision, so + // use mtime < t+1s instead of mtime <= t to check for unmodified. + if modtime.Before(t.Add(1 * time.Second)) { + return condFalse + } + return condTrue +} + +func checkIfRange(w ResponseWriter, r *Request, modtime time.Time) condResult { + if r.Method != "GET" { + return condNone + } + ir := r.Header.get("If-Range") + if ir == "" { + return condNone + } + etag, _ := scanETag(ir) + if etag != "" { + if etagStrongMatch(etag, w.Header().Get("Etag")) { + return condTrue + } else { + return condFalse + } + } + // The If-Range value is typically the ETag value, but it may also be + // the modtime date. See golang.org/issue/8367. + if modtime.IsZero() { + return condFalse + } + t, err := ParseTime(ir) + if err != nil { + return condFalse + } + if t.Unix() == modtime.Unix() { + return condTrue + } + return condFalse +} + +var unixEpochTime = time.Unix(0, 0) + +// isZeroTime reports whether t is obviously unspecified (either zero or Unix()=0). 
+func isZeroTime(t time.Time) bool { + return t.IsZero() || t.Equal(unixEpochTime) +} + +func setLastModified(w ResponseWriter, modtime time.Time) { + if !isZeroTime(modtime) { + w.Header().Set("Last-Modified", modtime.UTC().Format(TimeFormat)) + } +} - // TODO(bradfitz): non-GET/HEAD requests require more work: - // sending a different status code on matches, and - // also can't use weak cache validators (those with a "W/ - // prefix). But most users of ServeContent will be using - // it on GET or HEAD, so only support those for now. - if r.Method != "GET" && r.Method != "HEAD" { - return rangeReq, false +func writeNotModified(w ResponseWriter) { + // RFC 7232 section 4.1: + // a sender SHOULD NOT generate representation metadata other than the + // above listed fields unless said metadata exists for the purpose of + // guiding cache updates (e.g., Last-Modified might be useful if the + // response does not have an ETag field). + h := w.Header() + delete(h, "Content-Type") + delete(h, "Content-Length") + if h.Get("Etag") != "" { + delete(h, "Last-Modified") + } + w.WriteHeader(StatusNotModified) +} + +// checkPreconditions evaluates request preconditions and reports whether a precondition +// resulted in sending StatusNotModified or StatusPreconditionFailed. +func checkPreconditions(w ResponseWriter, r *Request, modtime time.Time) (done bool, rangeHeader string) { + // This function carefully follows RFC 7232 section 6. 
+ ch := checkIfMatch(w, r) + if ch == condNone { + ch = checkIfUnmodifiedSince(w, r, modtime) + } + if ch == condFalse { + w.WriteHeader(StatusPreconditionFailed) + return true, "" + } + switch checkIfNoneMatch(w, r) { + case condFalse: + if r.Method == "GET" || r.Method == "HEAD" { + writeNotModified(w) + return true, "" + } else { + w.WriteHeader(StatusPreconditionFailed) + return true, "" } + case condNone: + if checkIfModifiedSince(w, r, modtime) == condFalse { + writeNotModified(w) + return true, "" + } + } - // TODO(bradfitz): deal with comma-separated or multiple-valued - // list of If-None-match values. For now just handle the common - // case of a single item. - if inm == etag || inm == "*" { - h := w.Header() - delete(h, "Content-Type") - delete(h, "Content-Length") - w.WriteHeader(StatusNotModified) - return "", true + rangeHeader = r.Header.get("Range") + if rangeHeader != "" { + if checkIfRange(w, r, modtime) == condFalse { + rangeHeader = "" } } - return rangeReq, false + return false, rangeHeader } // name is '/'-separated, not filepath.Separator. @@ -419,9 +580,11 @@ func serveFile(w ResponseWriter, r *Request, fs FileSystem, name string, redirec // Still a directory? (we didn't find an index.html file) if d.IsDir() { - if checkLastModified(w, r, d.ModTime()) { + if checkIfModifiedSince(w, r, d.ModTime()) == condFalse { + writeNotModified(w) return } + w.Header().Set("Last-Modified", d.ModTime().UTC().Format(TimeFormat)) dirList(w, f) return } @@ -543,6 +706,7 @@ func (r httpRange) mimeHeader(contentType string, size int64) textproto.MIMEHead } // parseRange parses a Range header string as per RFC 2616. +// errNoOverlap is returned if none of the ranges overlap. 
func parseRange(s string, size int64) ([]httpRange, error) { if s == "" { return nil, nil // header not present @@ -552,6 +716,7 @@ func parseRange(s string, size int64) ([]httpRange, error) { return nil, errors.New("invalid range") } var ranges []httpRange + noOverlap := false for _, ra := range strings.Split(s[len(b):], ",") { ra = strings.TrimSpace(ra) if ra == "" { @@ -577,9 +742,15 @@ func parseRange(s string, size int64) ([]httpRange, error) { r.length = size - r.start } else { i, err := strconv.ParseInt(start, 10, 64) - if err != nil || i >= size || i < 0 { + if err != nil || i < 0 { return nil, errors.New("invalid range") } + if i >= size { + // If the range begins after the size of the content, + // then it does not overlap. + noOverlap = true + continue + } r.start = i if end == "" { // If no end is specified, range extends to end of the file. @@ -597,6 +768,10 @@ func parseRange(s string, size int64) ([]httpRange, error) { } ranges = append(ranges, r) } + if noOverlap && len(ranges) == 0 { + // The specified ranges did not overlap with the content. 
+ return nil, errNoOverlap + } return ranges, nil } @@ -628,9 +803,3 @@ func sumRangesSize(ranges []httpRange) (size int64) { } return } - -type byName []os.FileInfo - -func (s byName) Len() int { return len(s) } -func (s byName) Less(i, j int) bool { return s[i].Name() < s[j].Name() } -func (s byName) Swap(i, j int) { s[i], s[j] = s[j], s[i] } diff --git a/libgo/go/net/http/fs_test.go b/libgo/go/net/http/fs_test.go index 22be3899223..bba56821156 100644 --- a/libgo/go/net/http/fs_test.go +++ b/libgo/go/net/http/fs_test.go @@ -68,6 +68,7 @@ var ServeFileRangeTests = []struct { } func TestServeFile(t *testing.T) { + setParallel(t) defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { ServeFile(w, r, "testdata/file") @@ -274,6 +275,7 @@ func TestFileServerEscapesNames(t *testing.T) { {`"'<>&`, `<a href="%22%27%3C%3E&">"'<>&</a>`}, {`?foo=bar#baz`, `<a href="%3Ffoo=bar%23baz">?foo=bar#baz</a>`}, {`<combo>?foo`, `<a href="%3Ccombo%3E%3Ffoo"><combo>?foo</a>`}, + {`foo:bar`, `<a href="./foo:bar">foo:bar</a>`}, } // We put each test file in its own directory in the fakeFS so we can look at it in isolation. 
@@ -765,6 +767,7 @@ func TestServeContent(t *testing.T) { reqHeader map[string]string wantLastMod string wantContentType string + wantContentRange string wantStatus int } htmlModTime := mustStat(t, "testdata/index.html").ModTime() @@ -782,8 +785,9 @@ func TestServeContent(t *testing.T) { wantStatus: 200, }, "not_modified_modtime": { - file: "testdata/style.css", - modtime: htmlModTime, + file: "testdata/style.css", + serveETag: `"foo"`, // Last-Modified sent only when no ETag + modtime: htmlModTime, reqHeader: map[string]string{ "If-Modified-Since": htmlModTime.UTC().Format(TimeFormat), }, @@ -792,6 +796,7 @@ func TestServeContent(t *testing.T) { "not_modified_modtime_with_contenttype": { file: "testdata/style.css", serveContentType: "text/css", // explicit content type + serveETag: `"foo"`, // Last-Modified sent only when no ETag modtime: htmlModTime, reqHeader: map[string]string{ "If-Modified-Since": htmlModTime.UTC().Format(TimeFormat), @@ -808,21 +813,62 @@ func TestServeContent(t *testing.T) { }, "not_modified_etag_no_seek": { content: panicOnSeek{nil}, // should never be called - serveETag: `"foo"`, + serveETag: `W/"foo"`, // If-None-Match uses weak ETag comparison reqHeader: map[string]string{ - "If-None-Match": `"foo"`, + "If-None-Match": `"baz", W/"foo"`, }, wantStatus: 304, }, + "if_none_match_mismatch": { + file: "testdata/style.css", + serveETag: `"foo"`, + reqHeader: map[string]string{ + "If-None-Match": `"Foo"`, + }, + wantStatus: 200, + wantContentType: "text/css; charset=utf-8", + }, "range_good": { file: "testdata/style.css", serveETag: `"A"`, reqHeader: map[string]string{ "Range": "bytes=0-4", }, - wantStatus: StatusPartialContent, + wantStatus: StatusPartialContent, + wantContentType: "text/css; charset=utf-8", + wantContentRange: "bytes 0-4/8", + }, + "range_match": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "Range": "bytes=0-4", + "If-Range": `"A"`, + }, + wantStatus: StatusPartialContent, + 
wantContentType: "text/css; charset=utf-8", + wantContentRange: "bytes 0-4/8", + }, + "range_match_weak_etag": { + file: "testdata/style.css", + serveETag: `W/"A"`, + reqHeader: map[string]string{ + "Range": "bytes=0-4", + "If-Range": `W/"A"`, + }, + wantStatus: 200, wantContentType: "text/css; charset=utf-8", }, + "range_no_overlap": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "Range": "bytes=10-20", + }, + wantStatus: StatusRequestedRangeNotSatisfiable, + wantContentType: "text/plain; charset=utf-8", + wantContentRange: "bytes */8", + }, // An If-Range resource for entity "A", but entity "B" is now current. // The Range request should be ignored. "range_no_match": { @@ -842,9 +888,10 @@ func TestServeContent(t *testing.T) { "Range": "bytes=0-4", "If-Range": "Wed, 25 Jun 2014 17:12:18 GMT", }, - wantStatus: StatusPartialContent, - wantContentType: "text/css; charset=utf-8", - wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT", + wantStatus: StatusPartialContent, + wantContentType: "text/css; charset=utf-8", + wantContentRange: "bytes 0-4/8", + wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT", }, "range_with_modtime_nanos": { file: "testdata/style.css", @@ -853,9 +900,10 @@ func TestServeContent(t *testing.T) { "Range": "bytes=0-4", "If-Range": "Wed, 25 Jun 2014 17:12:18 GMT", }, - wantStatus: StatusPartialContent, - wantContentType: "text/css; charset=utf-8", - wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT", + wantStatus: StatusPartialContent, + wantContentType: "text/css; charset=utf-8", + wantContentRange: "bytes 0-4/8", + wantLastMod: "Wed, 25 Jun 2014 17:12:18 GMT", }, "unix_zero_modtime": { content: strings.NewReader("<html>foo"), @@ -863,6 +911,62 @@ func TestServeContent(t *testing.T) { wantStatus: StatusOK, wantContentType: "text/html; charset=utf-8", }, + "ifmatch_matches": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "If-Match": `"Z", "A"`, + }, + wantStatus: 200, + wantContentType: 
"text/css; charset=utf-8", + }, + "ifmatch_star": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "If-Match": `*`, + }, + wantStatus: 200, + wantContentType: "text/css; charset=utf-8", + }, + "ifmatch_failed": { + file: "testdata/style.css", + serveETag: `"A"`, + reqHeader: map[string]string{ + "If-Match": `"B"`, + }, + wantStatus: 412, + wantContentType: "text/plain; charset=utf-8", + }, + "ifmatch_fails_on_weak_etag": { + file: "testdata/style.css", + serveETag: `W/"A"`, + reqHeader: map[string]string{ + "If-Match": `W/"A"`, + }, + wantStatus: 412, + wantContentType: "text/plain; charset=utf-8", + }, + "if_unmodified_since_true": { + file: "testdata/style.css", + modtime: htmlModTime, + reqHeader: map[string]string{ + "If-Unmodified-Since": htmlModTime.UTC().Format(TimeFormat), + }, + wantStatus: 200, + wantContentType: "text/css; charset=utf-8", + wantLastMod: htmlModTime.UTC().Format(TimeFormat), + }, + "if_unmodified_since_false": { + file: "testdata/style.css", + modtime: htmlModTime, + reqHeader: map[string]string{ + "If-Unmodified-Since": htmlModTime.Add(-2 * time.Second).UTC().Format(TimeFormat), + }, + wantStatus: 412, + wantContentType: "text/plain; charset=utf-8", + wantLastMod: htmlModTime.UTC().Format(TimeFormat), + }, } for testName, tt := range tests { var content io.ReadSeeker @@ -903,6 +1007,9 @@ func TestServeContent(t *testing.T) { if g, e := res.Header.Get("Content-Type"), tt.wantContentType; g != e { t.Errorf("test %q: content-type = %q, want %q", testName, g, e) } + if g, e := res.Header.Get("Content-Range"), tt.wantContentRange; g != e { + t.Errorf("test %q: content-range = %q, want %q", testName, g, e) + } if g, e := res.Header.Get("Last-Modified"), tt.wantLastMod; g != e { t.Errorf("test %q: last-modified = %q, want %q", testName, g, e) } @@ -958,6 +1065,7 @@ func TestServeContentErrorMessages(t *testing.T) { // verifies that sendfile is being used on Linux func TestLinuxSendfile(t *testing.T) { + 
setParallel(t) defer afterTest(t) if runtime.GOOS != "linux" { t.Skip("skipping; linux-only test") @@ -982,6 +1090,8 @@ func TestLinuxSendfile(t *testing.T) { // strace on the above platforms doesn't support sendfile64 // and will error out if we specify that with `-e trace='. syscalls = "sendfile" + case "mips64": + t.Skip("TODO: update this test to be robust against various versions of strace on mips64. See golang.org/issue/33430") } var buf bytes.Buffer @@ -1008,10 +1118,9 @@ func TestLinuxSendfile(t *testing.T) { Post(fmt.Sprintf("http://%s/quit", ln.Addr()), "", nil) child.Wait() - rx := regexp.MustCompile(`sendfile(64)?\(\d+,\s*\d+,\s*NULL,\s*\d+\)\s*=\s*\d+\s*\n`) - rxResume := regexp.MustCompile(`<\.\.\. sendfile(64)? resumed> \)\s*=\s*\d+\s*\n`) + rx := regexp.MustCompile(`sendfile(64)?\(\d+,\s*\d+,\s*NULL,\s*\d+`) out := buf.String() - if !rx.MatchString(out) && !rxResume.MatchString(out) { + if !rx.MatchString(out) { t.Errorf("no sendfile system call found in:\n%s", out) } } @@ -1090,3 +1199,26 @@ func (d fileServerCleanPathDir) Open(path string) (File, error) { } type panicOnSeek struct{ io.ReadSeeker } + +func Test_scanETag(t *testing.T) { + tests := []struct { + in string + wantETag string + wantRemain string + }{ + {`W/"etag-1"`, `W/"etag-1"`, ""}, + {`"etag-2"`, `"etag-2"`, ""}, + {`"etag-1", "etag-2"`, `"etag-1"`, `, "etag-2"`}, + {"", "", ""}, + {"", "", ""}, + {"W/", "", ""}, + {`W/"truc`, "", ""}, + {`w/"case-sensitive"`, "", ""}, + } + for _, test := range tests { + etag, remain := ExportScanETag(test.in) + if etag != test.wantETag || remain != test.wantRemain { + t.Errorf("scanETag(%q)=%q %q, want %q %q", test.in, etag, remain, test.wantETag, test.wantRemain) + } + } +} diff --git a/libgo/go/net/http/h2_bundle.go b/libgo/go/net/http/h2_bundle.go index 5826bb7d858..25fdf09d92b 100644 --- a/libgo/go/net/http/h2_bundle.go +++ b/libgo/go/net/http/h2_bundle.go @@ -1,5 +1,5 @@ // Code generated by golang.org/x/tools/cmd/bundle. 
-//go:generate bundle -o h2_bundle.go -prefix http2 golang.org/x/net/http2 +//go:generate bundle -o h2_bundle.go -prefix http2 -underscore golang.org/x/net/http2 // Package http2 implements the HTTP/2 protocol. // @@ -21,6 +21,7 @@ import ( "bytes" "compress/gzip" "context" + "crypto/rand" "crypto/tls" "encoding/binary" "errors" @@ -43,6 +44,7 @@ import ( "time" "golang_org/x/net/http2/hpack" + "golang_org/x/net/idna" "golang_org/x/net/lex/httplex" ) @@ -853,10 +855,12 @@ type http2Framer struct { // If the limit is hit, MetaHeadersFrame.Truncated is set true. MaxHeaderListSize uint32 - logReads bool + logReads, logWrites bool - debugFramer *http2Framer // only use for logging written writes - debugFramerBuf *bytes.Buffer + debugFramer *http2Framer // only use for logging written writes + debugFramerBuf *bytes.Buffer + debugReadLoggerf func(string, ...interface{}) + debugWriteLoggerf func(string, ...interface{}) } func (fr *http2Framer) maxHeaderListSize() uint32 { @@ -890,7 +894,7 @@ func (f *http2Framer) endWrite() error { byte(length>>16), byte(length>>8), byte(length)) - if http2logFrameWrites { + if f.logWrites { f.logWrite() } @@ -912,10 +916,10 @@ func (f *http2Framer) logWrite() { f.debugFramerBuf.Write(f.wbuf) fr, err := f.debugFramer.ReadFrame() if err != nil { - log.Printf("http2: Framer %p: failed to decode just-written frame", f) + f.debugWriteLoggerf("http2: Framer %p: failed to decode just-written frame", f) return } - log.Printf("http2: Framer %p: wrote %v", f, http2summarizeFrame(fr)) + f.debugWriteLoggerf("http2: Framer %p: wrote %v", f, http2summarizeFrame(fr)) } func (f *http2Framer) writeByte(v byte) { f.wbuf = append(f.wbuf, v) } @@ -936,9 +940,12 @@ const ( // NewFramer returns a Framer that writes frames to w and reads them from r. 
func http2NewFramer(w io.Writer, r io.Reader) *http2Framer { fr := &http2Framer{ - w: w, - r: r, - logReads: http2logFrameReads, + w: w, + r: r, + logReads: http2logFrameReads, + logWrites: http2logFrameWrites, + debugReadLoggerf: log.Printf, + debugWriteLoggerf: log.Printf, } fr.getReadBuf = func(size uint32) []byte { if cap(fr.readBuf) >= int(size) { @@ -1020,7 +1027,7 @@ func (fr *http2Framer) ReadFrame() (http2Frame, error) { return nil, err } if fr.logReads { - log.Printf("http2: Framer %p: read %v", fr, http2summarizeFrame(f)) + fr.debugReadLoggerf("http2: Framer %p: read %v", fr, http2summarizeFrame(f)) } if fh.Type == http2FrameHeaders && fr.ReadMetaHeaders != nil { return fr.readMetaFrame(f.(*http2HeadersFrame)) @@ -1254,7 +1261,7 @@ func (f *http2Framer) WriteSettings(settings ...http2Setting) error { return f.endWrite() } -// WriteSettings writes an empty SETTINGS frame with the ACK bit set. +// WriteSettingsAck writes an empty SETTINGS frame with the ACK bit set. // // It will perform exactly one Write to the underlying Writer. // It is the caller's responsibility to not call other Write methods concurrently. 
@@ -1920,8 +1927,8 @@ func (fr *http2Framer) readMetaFrame(hf *http2HeadersFrame) (*http2MetaHeadersFr hdec.SetEmitEnabled(true) hdec.SetMaxStringLength(fr.maxHeaderStringLen()) hdec.SetEmitFunc(func(hf hpack.HeaderField) { - if http2VerboseLogs && http2logFrameReads { - log.Printf("http2: decoded hpack field %+v", hf) + if http2VerboseLogs && fr.logReads { + fr.debugReadLoggerf("http2: decoded hpack field %+v", hf) } if !httplex.ValidHeaderFieldValue(hf.Value) { invalid = http2headerFieldValueError(hf.Value) @@ -2091,6 +2098,13 @@ type http2clientTrace httptrace.ClientTrace func http2reqContext(r *Request) context.Context { return r.Context() } +func (t *http2Transport) idleConnTimeout() time.Duration { + if t.t1 != nil { + return t.t1.IdleConnTimeout + } + return 0 +} + func http2setResponseUncompressed(res *Response) { res.Uncompressed = true } func http2traceGotConn(req *Request, cc *http2ClientConn) { @@ -2145,6 +2159,48 @@ func http2requestTrace(req *Request) *http2clientTrace { return (*http2clientTrace)(trace) } +// Ping sends a PING frame to the server and waits for the ack. +func (cc *http2ClientConn) Ping(ctx context.Context) error { + return cc.ping(ctx) +} + +func http2cloneTLSConfig(c *tls.Config) *tls.Config { return c.Clone() } + +var _ Pusher = (*http2responseWriter)(nil) + +// Push implements http.Pusher. 
+func (w *http2responseWriter) Push(target string, opts *PushOptions) error { + internalOpts := http2pushOptions{} + if opts != nil { + internalOpts.Method = opts.Method + internalOpts.Header = opts.Header + } + return w.push(target, internalOpts) +} + +func http2configureServer18(h1 *Server, h2 *http2Server) error { + if h2.IdleTimeout == 0 { + if h1.IdleTimeout != 0 { + h2.IdleTimeout = h1.IdleTimeout + } else { + h2.IdleTimeout = h1.ReadTimeout + } + } + return nil +} + +func http2shouldLogPanic(panicValue interface{}) bool { + return panicValue != nil && panicValue != ErrAbortHandler +} + +func http2reqGetBody(req *Request) func() (io.ReadCloser, error) { + return req.GetBody +} + +func http2reqBodyIsNoBody(body io.ReadCloser) bool { + return body == NoBody +} + var http2DebugGoroutines = os.Getenv("DEBUG_HTTP2_GOROUTINES") == "1" type http2goroutineLock uint64 @@ -2368,6 +2424,7 @@ var ( http2VerboseLogs bool http2logFrameWrites bool http2logFrameReads bool + http2inTests bool ) func init() { @@ -2409,13 +2466,23 @@ var ( type http2streamState int +// HTTP/2 stream states. +// +// See http://tools.ietf.org/html/rfc7540#section-5.1. +// +// For simplicity, the server code merges "reserved (local)" into +// "half-closed (remote)". This is one less state transition to track. +// The only downside is that we send PUSH_PROMISEs slightly less +// liberally than allowable. More discussion here: +// https://lists.w3.org/Archives/Public/ietf-http-wg/2016JulSep/0599.html +// +// "reserved (remote)" is omitted since the client code does not +// support server push. 
const ( http2stateIdle http2streamState = iota http2stateOpen http2stateHalfClosedLocal http2stateHalfClosedRemote - http2stateResvLocal - http2stateResvRemote http2stateClosed ) @@ -2424,8 +2491,6 @@ var http2stateName = [...]string{ http2stateOpen: "Open", http2stateHalfClosedLocal: "HalfClosedLocal", http2stateHalfClosedRemote: "HalfClosedRemote", - http2stateResvLocal: "ResvLocal", - http2stateResvRemote: "ResvRemote", http2stateClosed: "Closed", } @@ -2586,13 +2651,27 @@ func http2newBufferedWriter(w io.Writer) *http2bufferedWriter { return &http2bufferedWriter{w: w} } +// bufWriterPoolBufferSize is the size of bufio.Writer's +// buffers created using bufWriterPool. +// +// TODO: pick a less arbitrary value? this is a bit under +// (3 x typical 1500 byte MTU) at least. Other than that, +// not much thought went into it. +const http2bufWriterPoolBufferSize = 4 << 10 + var http2bufWriterPool = sync.Pool{ New: func() interface{} { - - return bufio.NewWriterSize(nil, 4<<10) + return bufio.NewWriterSize(nil, http2bufWriterPoolBufferSize) }, } +func (w *http2bufferedWriter) Available() int { + if w.bw == nil { + return http2bufWriterPoolBufferSize + } + return w.bw.Available() +} + func (w *http2bufferedWriter) Write(p []byte) (n int, err error) { if w.bw == nil { bw := http2bufWriterPool.Get().(*bufio.Writer) @@ -2686,6 +2765,19 @@ func (s *http2sorter) SortStrings(ss []string) { s.v = save } +// validPseudoPath reports whether v is a valid :path pseudo-header +// value. It must be either: +// +// *) a non-empty string starting with '/', but not with with "//", +// *) the string '*', for OPTIONS requests. +// +// For now this is only used a quick check for deciding when to clean +// up Opaque URLs before sending requests from the Transport. +// See golang.org/issue/16847 +func http2validPseudoPath(v string) bool { + return (len(v) > 0 && v[0] == '/' && (len(v) == 1 || v[1] != '/')) || v == "*" +} + // pipe is a goroutine-safe io.Reader/io.Writer pair. 
It's like // io.Pipe except there are no PipeReader/PipeWriter halves, and the // underlying buffer is an interface. (io.Pipe is always unbuffered) @@ -2882,6 +2974,15 @@ type http2Server struct { // PermitProhibitedCipherSuites, if true, permits the use of // cipher suites prohibited by the HTTP/2 spec. PermitProhibitedCipherSuites bool + + // IdleTimeout specifies how long until idle clients should be + // closed with a GOAWAY frame. PING frames are not considered + // activity for the purposes of IdleTimeout. + IdleTimeout time.Duration + + // NewWriteScheduler constructs a write scheduler for a connection. + // If nil, a default scheduler is chosen. + NewWriteScheduler func() http2WriteScheduler } func (s *http2Server) maxReadFrameSize() uint32 { @@ -2904,9 +3005,15 @@ func (s *http2Server) maxConcurrentStreams() uint32 { // // ConfigureServer must be called before s begins serving. func http2ConfigureServer(s *Server, conf *http2Server) error { + if s == nil { + panic("nil *http.Server") + } if conf == nil { conf = new(http2Server) } + if err := http2configureServer18(s, conf); err != nil { + return err + } if s.TLSConfig == nil { s.TLSConfig = new(tls.Config) @@ -2945,8 +3052,6 @@ func http2ConfigureServer(s *Server, conf *http2Server) error { s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, http2NextProtoTLS) } - s.TLSConfig.NextProtos = append(s.TLSConfig.NextProtos, "h2-14") - if s.TLSNextProto == nil { s.TLSNextProto = map[string]func(*Server, *tls.Conn, Handler){} } @@ -2960,7 +3065,6 @@ func http2ConfigureServer(s *Server, conf *http2Server) error { }) } s.TLSNextProto[http2NextProtoTLS] = protoHandler - s.TLSNextProto["h2-14"] = protoHandler return nil } @@ -3014,29 +3118,39 @@ func (s *http2Server) ServeConn(c net.Conn, opts *http2ServeConnOpts) { defer cancel() sc := &http2serverConn{ - srv: s, - hs: opts.baseConfig(), - conn: c, - baseCtx: baseCtx, - remoteAddrStr: c.RemoteAddr().String(), - bw: http2newBufferedWriter(c), - handler: 
opts.handler(), - streams: make(map[uint32]*http2stream), - readFrameCh: make(chan http2readFrameResult), - wantWriteFrameCh: make(chan http2frameWriteMsg, 8), - wroteFrameCh: make(chan http2frameWriteResult, 1), - bodyReadCh: make(chan http2bodyReadMsg), - doneServing: make(chan struct{}), - advMaxStreams: s.maxConcurrentStreams(), - writeSched: http2writeScheduler{ - maxFrameSize: http2initialMaxFrameSize, - }, + srv: s, + hs: opts.baseConfig(), + conn: c, + baseCtx: baseCtx, + remoteAddrStr: c.RemoteAddr().String(), + bw: http2newBufferedWriter(c), + handler: opts.handler(), + streams: make(map[uint32]*http2stream), + readFrameCh: make(chan http2readFrameResult), + wantWriteFrameCh: make(chan http2FrameWriteRequest, 8), + wantStartPushCh: make(chan http2startPushRequest, 8), + wroteFrameCh: make(chan http2frameWriteResult, 1), + bodyReadCh: make(chan http2bodyReadMsg), + doneServing: make(chan struct{}), + clientMaxStreams: math.MaxUint32, + advMaxStreams: s.maxConcurrentStreams(), initialWindowSize: http2initialWindowSize, + maxFrameSize: http2initialMaxFrameSize, headerTableSize: http2initialHeaderTableSize, serveG: http2newGoroutineLock(), pushEnabled: true, } + if sc.hs.WriteTimeout != 0 { + sc.conn.SetWriteDeadline(time.Time{}) + } + + if s.NewWriteScheduler != nil { + sc.writeSched = s.NewWriteScheduler() + } else { + sc.writeSched = http2NewRandomWriteScheduler() + } + sc.flow.add(http2initialWindowSize) sc.inflow.add(http2initialWindowSize) sc.hpackEncoder = hpack.NewEncoder(&sc.headerWriteBuf) @@ -3090,16 +3204,18 @@ type http2serverConn struct { handler Handler baseCtx http2contextContext framer *http2Framer - doneServing chan struct{} // closed when serverConn.serve ends - readFrameCh chan http2readFrameResult // written by serverConn.readFrames - wantWriteFrameCh chan http2frameWriteMsg // from handlers -> serve - wroteFrameCh chan http2frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes - bodyReadCh chan http2bodyReadMsg // 
from handlers -> serve - testHookCh chan func(int) // code to run on the serve loop - flow http2flow // conn-wide (not stream-specific) outbound flow control - inflow http2flow // conn-wide inbound flow control - tlsState *tls.ConnectionState // shared by all handlers, like net/http + doneServing chan struct{} // closed when serverConn.serve ends + readFrameCh chan http2readFrameResult // written by serverConn.readFrames + wantWriteFrameCh chan http2FrameWriteRequest // from handlers -> serve + wantStartPushCh chan http2startPushRequest // from handlers -> serve + wroteFrameCh chan http2frameWriteResult // from writeFrameAsync -> serve, tickles more frame writes + bodyReadCh chan http2bodyReadMsg // from handlers -> serve + testHookCh chan func(int) // code to run on the serve loop + flow http2flow // conn-wide (not stream-specific) outbound flow control + inflow http2flow // conn-wide inbound flow control + tlsState *tls.ConnectionState // shared by all handlers, like net/http remoteAddrStr string + writeSched http2WriteScheduler // Everything following is owned by the serve loop; use serveG.check(): serveG http2goroutineLock // used to verify funcs are on serve() @@ -3109,22 +3225,27 @@ type http2serverConn struct { unackedSettings int // how many SETTINGS have we sent without ACKs? 
clientMaxStreams uint32 // SETTINGS_MAX_CONCURRENT_STREAMS from client (our PUSH_PROMISE limit) advMaxStreams uint32 // our SETTINGS_MAX_CONCURRENT_STREAMS advertised the client - curOpenStreams uint32 // client's number of open streams - maxStreamID uint32 // max ever seen + curClientStreams uint32 // number of open streams initiated by the client + curPushedStreams uint32 // number of open streams initiated by server push + maxClientStreamID uint32 // max ever seen from client (odd), or 0 if there have been no client requests + maxPushPromiseID uint32 // ID of the last push promise (even), or 0 if there have been no pushes streams map[uint32]*http2stream initialWindowSize int32 + maxFrameSize int32 headerTableSize uint32 peerMaxHeaderListSize uint32 // zero means unknown (default) canonHeader map[string]string // http2-lower-case -> Go-Canonical-Case - writingFrame bool // started write goroutine but haven't heard back on wroteFrameCh + writingFrame bool // started writing a frame (on serve goroutine or separate) + writingFrameAsync bool // started a frame on its own goroutine but haven't heard back on wroteFrameCh needsFrameFlush bool // last frame write wasn't a flush - writeSched http2writeScheduler - inGoAway bool // we've started to or sent GOAWAY - needToSendGoAway bool // we need to schedule a GOAWAY frame write + inGoAway bool // we've started to or sent GOAWAY + inFrameScheduleLoop bool // whether we're in the scheduleFrameWrite loop + needToSendGoAway bool // we need to schedule a GOAWAY frame write goAwayCode http2ErrCode shutdownTimerCh <-chan time.Time // nil until used shutdownTimer *time.Timer // nil until used - freeRequestBodyBuf []byte // if non-nil, a free initialWindowSize buffer for getRequestBodyBuf + idleTimer *time.Timer // nil if unused + idleTimerCh <-chan time.Time // nil if unused // Owned by the writeFrameAsync goroutine: headerWriteBuf bytes.Buffer @@ -3143,6 +3264,11 @@ func (sc *http2serverConn) maxHeaderListSize() uint32 { return 
uint32(n + typicalHeaders*perFieldOverhead) } +func (sc *http2serverConn) curOpenStreams() uint32 { + sc.serveG.check() + return sc.curClientStreams + sc.curPushedStreams +} + // stream represents a stream. This is the minimal metadata needed by // the serve goroutine. Most of the actual stream state is owned by // the http.Handler's goroutine in the responseWriter. Because the @@ -3168,11 +3294,10 @@ type http2stream struct { numTrailerValues int64 weight uint8 state http2streamState - sentReset bool // only true once detached from streams map - gotReset bool // only true once detacted from streams map - gotTrailerHeader bool // HEADER frame for trailers was seen - wroteHeaders bool // whether we wrote headers (not status 100) - reqBuf []byte + resetQueued bool // RST_STREAM queued for write; set by sc.resetStream + gotTrailerHeader bool // HEADER frame for trailers was seen + wroteHeaders bool // whether we wrote headers (not status 100) + reqBuf []byte // if non-nil, body pipe buffer to return later at EOF trailer Header // accumulated trailers reqTrailer Header // handler's Request.Trailer @@ -3195,8 +3320,14 @@ func (sc *http2serverConn) state(streamID uint32) (http2streamState, *http2strea return st.state, st } - if streamID <= sc.maxStreamID { - return http2stateClosed, nil + if streamID%2 == 1 { + if streamID <= sc.maxClientStreamID { + return http2stateClosed, nil + } + } else { + if streamID <= sc.maxPushPromiseID { + return http2stateClosed, nil + } } return http2stateIdle, nil } @@ -3328,17 +3459,17 @@ func (sc *http2serverConn) readFrames() { // frameWriteResult is the message passed from writeFrameAsync to the serve goroutine. 
type http2frameWriteResult struct { - wm http2frameWriteMsg // what was written (or attempted) - err error // result of the writeFrame call + wr http2FrameWriteRequest // what was written (or attempted) + err error // result of the writeFrame call } // writeFrameAsync runs in its own goroutine and writes a single frame // and then reports when it's done. // At most one goroutine can be running writeFrameAsync at a time per // serverConn. -func (sc *http2serverConn) writeFrameAsync(wm http2frameWriteMsg) { - err := wm.write.writeFrame(sc) - sc.wroteFrameCh <- http2frameWriteResult{wm, err} +func (sc *http2serverConn) writeFrameAsync(wr http2FrameWriteRequest) { + err := wr.write.writeFrame(sc) + sc.wroteFrameCh <- http2frameWriteResult{wr, err} } func (sc *http2serverConn) closeAllStreamsOnConnClose() { @@ -3382,7 +3513,7 @@ func (sc *http2serverConn) serve() { sc.vlogf("http2: server connection from %v on %p", sc.conn.RemoteAddr(), sc.hs) } - sc.writeFrame(http2frameWriteMsg{ + sc.writeFrame(http2FrameWriteRequest{ write: http2writeSettings{ {http2SettingMaxFrameSize, sc.srv.maxReadFrameSize()}, {http2SettingMaxConcurrentStreams, sc.advMaxStreams}, @@ -3399,6 +3530,17 @@ func (sc *http2serverConn) serve() { sc.setConnState(StateActive) sc.setConnState(StateIdle) + if sc.srv.IdleTimeout != 0 { + sc.idleTimer = time.NewTimer(sc.srv.IdleTimeout) + defer sc.idleTimer.Stop() + sc.idleTimerCh = sc.idleTimer.C + } + + var gracefulShutdownCh <-chan struct{} + if sc.hs != nil { + gracefulShutdownCh = http2h1ServerShutdownChan(sc.hs) + } + go sc.readFrames() settingsTimer := time.NewTimer(http2firstSettingsTimeout) @@ -3406,8 +3548,10 @@ func (sc *http2serverConn) serve() { for { loopNum++ select { - case wm := <-sc.wantWriteFrameCh: - sc.writeFrame(wm) + case wr := <-sc.wantWriteFrameCh: + sc.writeFrame(wr) + case spr := <-sc.wantStartPushCh: + sc.startPush(spr) case res := <-sc.wroteFrameCh: sc.wroteFrame(res) case res := <-sc.readFrameCh: @@ -3424,12 +3568,22 @@ func (sc 
*http2serverConn) serve() { case <-settingsTimer.C: sc.logf("timeout waiting for SETTINGS frames from %v", sc.conn.RemoteAddr()) return + case <-gracefulShutdownCh: + gracefulShutdownCh = nil + sc.startGracefulShutdown() case <-sc.shutdownTimerCh: sc.vlogf("GOAWAY close timer fired; closing conn from %v", sc.conn.RemoteAddr()) return + case <-sc.idleTimerCh: + sc.vlogf("connection is idle") + sc.goAway(http2ErrCodeNo) case fn := <-sc.testHookCh: fn(loopNum) } + + if sc.inGoAway && sc.curOpenStreams() == 0 && !sc.needToSendGoAway && !sc.writingFrame { + return + } } } @@ -3477,7 +3631,7 @@ func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte ch := http2errChanPool.Get().(chan error) writeArg := http2writeDataPool.Get().(*http2writeData) *writeArg = http2writeData{stream.id, data, endStream} - err := sc.writeFrameFromHandler(http2frameWriteMsg{ + err := sc.writeFrameFromHandler(http2FrameWriteRequest{ write: writeArg, stream: stream, done: ch, @@ -3507,17 +3661,17 @@ func (sc *http2serverConn) writeDataFromHandler(stream *http2stream, data []byte return err } -// writeFrameFromHandler sends wm to sc.wantWriteFrameCh, but aborts +// writeFrameFromHandler sends wr to sc.wantWriteFrameCh, but aborts // if the connection has gone away. // // This must not be run from the serve goroutine itself, else it might // deadlock writing to sc.wantWriteFrameCh (which is only mildly // buffered and is read by serve itself). If you're on the serve // goroutine, call writeFrame instead. 
-func (sc *http2serverConn) writeFrameFromHandler(wm http2frameWriteMsg) error { +func (sc *http2serverConn) writeFrameFromHandler(wr http2FrameWriteRequest) error { sc.serveG.checkNotOn() select { - case sc.wantWriteFrameCh <- wm: + case sc.wantWriteFrameCh <- wr: return nil case <-sc.doneServing: @@ -3533,53 +3687,81 @@ func (sc *http2serverConn) writeFrameFromHandler(wm http2frameWriteMsg) error { // make it onto the wire // // If you're not on the serve goroutine, use writeFrameFromHandler instead. -func (sc *http2serverConn) writeFrame(wm http2frameWriteMsg) { +func (sc *http2serverConn) writeFrame(wr http2FrameWriteRequest) { sc.serveG.check() + // If true, wr will not be written and wr.done will not be signaled. var ignoreWrite bool - switch wm.write.(type) { + if wr.StreamID() != 0 { + _, isReset := wr.write.(http2StreamError) + if state, _ := sc.state(wr.StreamID()); state == http2stateClosed && !isReset { + ignoreWrite = true + } + } + + switch wr.write.(type) { case *http2writeResHeaders: - wm.stream.wroteHeaders = true + wr.stream.wroteHeaders = true case http2write100ContinueHeadersFrame: - if wm.stream.wroteHeaders { + if wr.stream.wroteHeaders { + + if wr.done != nil { + panic("wr.done != nil for write100ContinueHeadersFrame") + } ignoreWrite = true } } if !ignoreWrite { - sc.writeSched.add(wm) + sc.writeSched.Push(wr) } sc.scheduleFrameWrite() } -// startFrameWrite starts a goroutine to write wm (in a separate +// startFrameWrite starts a goroutine to write wr (in a separate // goroutine since that might block on the network), and updates the -// serve goroutine's state about the world, updated from info in wm. -func (sc *http2serverConn) startFrameWrite(wm http2frameWriteMsg) { +// serve goroutine's state about the world, updated from info in wr. 
+func (sc *http2serverConn) startFrameWrite(wr http2FrameWriteRequest) { sc.serveG.check() if sc.writingFrame { panic("internal error: can only be writing one frame at a time") } - st := wm.stream + st := wr.stream if st != nil { switch st.state { case http2stateHalfClosedLocal: - panic("internal error: attempt to send frame on half-closed-local stream") - case http2stateClosed: - if st.sentReset || st.gotReset { + switch wr.write.(type) { + case http2StreamError, http2handlerPanicRST, http2writeWindowUpdate: - sc.scheduleFrameWrite() - return + default: + panic(fmt.Sprintf("internal error: attempt to send frame on a half-closed-local stream: %v", wr)) } - panic(fmt.Sprintf("internal error: attempt to send a write %v on a closed stream", wm)) + case http2stateClosed: + panic(fmt.Sprintf("internal error: attempt to send frame on a closed stream: %v", wr)) + } + } + if wpp, ok := wr.write.(*http2writePushPromise); ok { + var err error + wpp.promisedID, err = wpp.allocatePromisedID() + if err != nil { + sc.writingFrameAsync = false + wr.replyToWriter(err) + return } } sc.writingFrame = true sc.needsFrameFlush = true - go sc.writeFrameAsync(wm) + if wr.write.staysWithinBuffer(sc.bw.Available()) { + sc.writingFrameAsync = false + err := wr.write.writeFrame(sc) + sc.wroteFrame(http2frameWriteResult{wr, err}) + } else { + sc.writingFrameAsync = true + go sc.writeFrameAsync(wr) + } } // errHandlerPanicked is the error given to any callers blocked in a read from @@ -3595,26 +3777,12 @@ func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { panic("internal error: expected to be already writing a frame") } sc.writingFrame = false + sc.writingFrameAsync = false - wm := res.wm - st := wm.stream - - closeStream := http2endsStream(wm.write) - - if _, ok := wm.write.(http2handlerPanicRST); ok { - sc.closeStream(st, http2errHandlerPanicked) - } + wr := res.wr - if ch := wm.done; ch != nil { - select { - case ch <- res.err: - default: - panic(fmt.Sprintf("unbuffered done 
channel passed in for type %T", wm.write)) - } - } - wm.write = nil - - if closeStream { + if http2writeEndsStream(wr.write) { + st := wr.stream if st == nil { panic("internal error: expecting non-nil stream") } @@ -3622,13 +3790,24 @@ func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { case http2stateOpen: st.state = http2stateHalfClosedLocal - errCancel := http2streamError(st.id, http2ErrCodeCancel) - sc.resetStream(errCancel) + sc.resetStream(http2streamError(st.id, http2ErrCodeCancel)) case http2stateHalfClosedRemote: sc.closeStream(st, http2errHandlerComplete) } + } else { + switch v := wr.write.(type) { + case http2StreamError: + + if st, ok := sc.streams[v.StreamID]; ok { + sc.closeStream(st, v) + } + case http2handlerPanicRST: + sc.closeStream(wr.stream, http2errHandlerPanicked) + } } + wr.replyToWriter(res.err) + sc.scheduleFrameWrite() } @@ -3646,47 +3825,68 @@ func (sc *http2serverConn) wroteFrame(res http2frameWriteResult) { // flush the write buffer. func (sc *http2serverConn) scheduleFrameWrite() { sc.serveG.check() - if sc.writingFrame { - return - } - if sc.needToSendGoAway { - sc.needToSendGoAway = false - sc.startFrameWrite(http2frameWriteMsg{ - write: &http2writeGoAway{ - maxStreamID: sc.maxStreamID, - code: sc.goAwayCode, - }, - }) - return - } - if sc.needToSendSettingsAck { - sc.needToSendSettingsAck = false - sc.startFrameWrite(http2frameWriteMsg{write: http2writeSettingsAck{}}) + if sc.writingFrame || sc.inFrameScheduleLoop { return } - if !sc.inGoAway { - if wm, ok := sc.writeSched.take(); ok { - sc.startFrameWrite(wm) - return + sc.inFrameScheduleLoop = true + for !sc.writingFrameAsync { + if sc.needToSendGoAway { + sc.needToSendGoAway = false + sc.startFrameWrite(http2FrameWriteRequest{ + write: &http2writeGoAway{ + maxStreamID: sc.maxClientStreamID, + code: sc.goAwayCode, + }, + }) + continue } + if sc.needToSendSettingsAck { + sc.needToSendSettingsAck = false + sc.startFrameWrite(http2FrameWriteRequest{write: 
http2writeSettingsAck{}}) + continue + } + if !sc.inGoAway || sc.goAwayCode == http2ErrCodeNo { + if wr, ok := sc.writeSched.Pop(); ok { + sc.startFrameWrite(wr) + continue + } + } + if sc.needsFrameFlush { + sc.startFrameWrite(http2FrameWriteRequest{write: http2flushFrameWriter{}}) + sc.needsFrameFlush = false + continue + } + break } - if sc.needsFrameFlush { - sc.startFrameWrite(http2frameWriteMsg{write: http2flushFrameWriter{}}) - sc.needsFrameFlush = false - return - } + sc.inFrameScheduleLoop = false +} + +// startGracefulShutdown sends a GOAWAY with ErrCodeNo to tell the +// client we're gracefully shutting down. The connection isn't closed +// until all current streams are done. +func (sc *http2serverConn) startGracefulShutdown() { + sc.goAwayIn(http2ErrCodeNo, 0) } func (sc *http2serverConn) goAway(code http2ErrCode) { sc.serveG.check() - if sc.inGoAway { - return - } + var forceCloseIn time.Duration if code != http2ErrCodeNo { - sc.shutDownIn(250 * time.Millisecond) + forceCloseIn = 250 * time.Millisecond } else { - sc.shutDownIn(1 * time.Second) + forceCloseIn = 1 * time.Second + } + sc.goAwayIn(code, forceCloseIn) +} + +func (sc *http2serverConn) goAwayIn(code http2ErrCode, forceCloseIn time.Duration) { + sc.serveG.check() + if sc.inGoAway { + return + } + if forceCloseIn != 0 { + sc.shutDownIn(forceCloseIn) } sc.inGoAway = true sc.needToSendGoAway = true @@ -3702,10 +3902,9 @@ func (sc *http2serverConn) shutDownIn(d time.Duration) { func (sc *http2serverConn) resetStream(se http2StreamError) { sc.serveG.check() - sc.writeFrame(http2frameWriteMsg{write: se}) + sc.writeFrame(http2FrameWriteRequest{write: se}) if st, ok := sc.streams[se.StreamID]; ok { - st.sentReset = true - sc.closeStream(st, se) + st.resetQueued = true } } @@ -3782,6 +3981,8 @@ func (sc *http2serverConn) processFrame(f http2Frame) error { return sc.processResetStream(f) case *http2PriorityFrame: return sc.processPriority(f) + case *http2GoAwayFrame: + return sc.processGoAway(f) case 
*http2PushPromiseFrame: return http2ConnectionError(http2ErrCodeProtocol) @@ -3801,7 +4002,10 @@ func (sc *http2serverConn) processPing(f *http2PingFrame) error { return http2ConnectionError(http2ErrCodeProtocol) } - sc.writeFrame(http2frameWriteMsg{write: http2writePingAck{f}}) + if sc.inGoAway && sc.goAwayCode != http2ErrCodeNo { + return nil + } + sc.writeFrame(http2FrameWriteRequest{write: http2writePingAck{f}}) return nil } @@ -3809,7 +4013,11 @@ func (sc *http2serverConn) processWindowUpdate(f *http2WindowUpdateFrame) error sc.serveG.check() switch { case f.StreamID != 0: - st := sc.streams[f.StreamID] + state, st := sc.state(f.StreamID) + if state == http2stateIdle { + + return http2ConnectionError(http2ErrCodeProtocol) + } if st == nil { return nil @@ -3835,7 +4043,6 @@ func (sc *http2serverConn) processResetStream(f *http2RSTStreamFrame) error { return http2ConnectionError(http2ErrCodeProtocol) } if st != nil { - st.gotReset = true st.cancelCtx() sc.closeStream(st, http2streamError(f.StreamID, f.ErrCode)) } @@ -3848,11 +4055,21 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) { panic(fmt.Sprintf("invariant; can't close stream in state %v", st.state)) } st.state = http2stateClosed - sc.curOpenStreams-- - if sc.curOpenStreams == 0 { - sc.setConnState(StateIdle) + if st.isPushed() { + sc.curPushedStreams-- + } else { + sc.curClientStreams-- } delete(sc.streams, st.id) + if len(sc.streams) == 0 { + sc.setConnState(StateIdle) + if sc.srv.IdleTimeout != 0 { + sc.idleTimer.Reset(sc.srv.IdleTimeout) + } + if http2h1ServerKeepAlivesDisabled(sc.hs) { + sc.startGracefulShutdown() + } + } if p := st.body; p != nil { sc.sendWindowUpdate(nil, p.Len()) @@ -3860,11 +4077,7 @@ func (sc *http2serverConn) closeStream(st *http2stream, err error) { p.CloseWithError(err) } st.cw.Close() - sc.writeSched.forgetStream(st.id) - if st.reqBuf != nil { - - sc.freeRequestBodyBuf = st.reqBuf - } + sc.writeSched.CloseStream(st.id) } func (sc *http2serverConn) 
processSettings(f *http2SettingsFrame) error { @@ -3904,7 +4117,7 @@ func (sc *http2serverConn) processSetting(s http2Setting) error { case http2SettingInitialWindowSize: return sc.processSettingInitialWindowSize(s.Val) case http2SettingMaxFrameSize: - sc.writeSched.maxFrameSize = s.Val + sc.maxFrameSize = int32(s.Val) case http2SettingMaxHeaderListSize: sc.peerMaxHeaderListSize = s.Val default: @@ -3933,11 +4146,18 @@ func (sc *http2serverConn) processSettingInitialWindowSize(val uint32) error { func (sc *http2serverConn) processData(f *http2DataFrame) error { sc.serveG.check() + if sc.inGoAway && sc.goAwayCode != http2ErrCodeNo { + return nil + } data := f.Data() id := f.Header().StreamID - st, ok := sc.streams[id] - if !ok || st.state != http2stateOpen || st.gotTrailerHeader { + state, st := sc.state(id) + if id == 0 || state == http2stateIdle { + + return http2ConnectionError(http2ErrCodeProtocol) + } + if st == nil || state != http2stateOpen || st.gotTrailerHeader || st.resetQueued { if sc.inflow.available() < int32(f.Length) { return http2streamError(id, http2ErrCodeFlowControl) @@ -3946,6 +4166,10 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { sc.inflow.take(int32(f.Length)) sc.sendWindowUpdate(nil, int(f.Length)) + if st != nil && st.resetQueued { + + return nil + } return http2streamError(id, http2ErrCodeStreamClosed) } if st.body == nil { @@ -3985,6 +4209,24 @@ func (sc *http2serverConn) processData(f *http2DataFrame) error { return nil } +func (sc *http2serverConn) processGoAway(f *http2GoAwayFrame) error { + sc.serveG.check() + if f.ErrCode != http2ErrCodeNo { + sc.logf("http2: received GOAWAY %+v, starting graceful shutdown", f) + } else { + sc.vlogf("http2: received GOAWAY %+v, starting graceful shutdown", f) + } + sc.startGracefulShutdown() + + sc.pushEnabled = false + return nil +} + +// isPushed reports whether the stream is server-initiated. 
+func (st *http2stream) isPushed() bool { + return st.id%2 == 0 +} + // endStream closes a Request.Body's pipe. It is called when a DATA // frame says a request body is over (or after trailers). func (st *http2stream) endStream() { @@ -4014,7 +4256,7 @@ func (st *http2stream) copyTrailersToHandlerRequest() { func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { sc.serveG.check() - id := f.Header().StreamID + id := f.StreamID if sc.inGoAway { return nil @@ -4024,50 +4266,43 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { return http2ConnectionError(http2ErrCodeProtocol) } - st := sc.streams[f.Header().StreamID] - if st != nil { + if st := sc.streams[f.StreamID]; st != nil { + if st.resetQueued { + + return nil + } return st.processTrailerHeaders(f) } - if id <= sc.maxStreamID { + if id <= sc.maxClientStreamID { return http2ConnectionError(http2ErrCodeProtocol) } - sc.maxStreamID = id + sc.maxClientStreamID = id - ctx, cancelCtx := http2contextWithCancel(sc.baseCtx) - st = &http2stream{ - sc: sc, - id: id, - state: http2stateOpen, - ctx: ctx, - cancelCtx: cancelCtx, - } - if f.StreamEnded() { - st.state = http2stateHalfClosedRemote + if sc.idleTimer != nil { + sc.idleTimer.Stop() } - st.cw.Init() - st.flow.conn = &sc.flow - st.flow.add(sc.initialWindowSize) - st.inflow.conn = &sc.inflow - st.inflow.add(http2initialWindowSize) + if sc.curClientStreams+1 > sc.advMaxStreams { + if sc.unackedSettings == 0 { - sc.streams[id] = st - if f.HasPriority() { - http2adjustStreamPriority(sc.streams, st.id, f.Priority) - } - sc.curOpenStreams++ - if sc.curOpenStreams == 1 { - sc.setConnState(StateActive) + return http2streamError(id, http2ErrCodeProtocol) + } + + return http2streamError(id, http2ErrCodeRefusedStream) } - if sc.curOpenStreams > sc.advMaxStreams { - if sc.unackedSettings == 0 { + initialState := http2stateOpen + if f.StreamEnded() { + initialState = http2stateHalfClosedRemote + } + st := sc.newStream(id, 0, 
initialState) - return http2streamError(st.id, http2ErrCodeProtocol) + if f.HasPriority() { + if err := http2checkPriority(f.StreamID, f.Priority); err != nil { + return err } - - return http2streamError(st.id, http2ErrCodeRefusedStream) + sc.writeSched.AdjustStream(st.id, f.Priority) } rw, req, err := sc.newWriterAndRequest(st, f) @@ -4085,10 +4320,14 @@ func (sc *http2serverConn) processHeaders(f *http2MetaHeadersFrame) error { if f.Truncated { handler = http2handleHeaderListTooLong - } else if err := http2checkValidHTTP2Request(req); err != nil { + } else if err := http2checkValidHTTP2RequestHeaders(req.Header); err != nil { handler = http2new400Handler(err) } + if sc.hs.ReadTimeout != 0 { + sc.conn.SetReadDeadline(time.Time{}) + } + go sc.runHandler(rw, req, handler) return nil } @@ -4121,90 +4360,138 @@ func (st *http2stream) processTrailerHeaders(f *http2MetaHeadersFrame) error { return nil } -func (sc *http2serverConn) processPriority(f *http2PriorityFrame) error { - http2adjustStreamPriority(sc.streams, f.StreamID, f.http2PriorityParam) +func http2checkPriority(streamID uint32, p http2PriorityParam) error { + if streamID == p.StreamDep { + + return http2streamError(streamID, http2ErrCodeProtocol) + } return nil } -func http2adjustStreamPriority(streams map[uint32]*http2stream, streamID uint32, priority http2PriorityParam) { - st, ok := streams[streamID] - if !ok { +func (sc *http2serverConn) processPriority(f *http2PriorityFrame) error { + if sc.inGoAway { + return nil + } + if err := http2checkPriority(f.StreamID, f.http2PriorityParam); err != nil { + return err + } + sc.writeSched.AdjustStream(f.StreamID, f.http2PriorityParam) + return nil +} - return +func (sc *http2serverConn) newStream(id, pusherID uint32, state http2streamState) *http2stream { + sc.serveG.check() + if id == 0 { + panic("internal error: cannot create stream with id 0") } - st.weight = priority.Weight - parent := streams[priority.StreamDep] - if parent == st { - return + ctx, cancelCtx 
:= http2contextWithCancel(sc.baseCtx) + st := &http2stream{ + sc: sc, + id: id, + state: state, + ctx: ctx, + cancelCtx: cancelCtx, } + st.cw.Init() + st.flow.conn = &sc.flow + st.flow.add(sc.initialWindowSize) + st.inflow.conn = &sc.inflow + st.inflow.add(http2initialWindowSize) - for piter := parent; piter != nil; piter = piter.parent { - if piter == st { - parent.parent = st.parent - break - } + sc.streams[id] = st + sc.writeSched.OpenStream(st.id, http2OpenStreamOptions{PusherID: pusherID}) + if st.isPushed() { + sc.curPushedStreams++ + } else { + sc.curClientStreams++ } - st.parent = parent - if priority.Exclusive && (st.parent != nil || priority.StreamDep == 0) { - for _, openStream := range streams { - if openStream != st && openStream.parent == st.parent { - openStream.parent = st - } - } + if sc.curOpenStreams() == 1 { + sc.setConnState(StateActive) } + + return st } func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHeadersFrame) (*http2responseWriter, *Request, error) { sc.serveG.check() - method := f.PseudoValue("method") - path := f.PseudoValue("path") - scheme := f.PseudoValue("scheme") - authority := f.PseudoValue("authority") + rp := http2requestParam{ + method: f.PseudoValue("method"), + scheme: f.PseudoValue("scheme"), + authority: f.PseudoValue("authority"), + path: f.PseudoValue("path"), + } - isConnect := method == "CONNECT" + isConnect := rp.method == "CONNECT" if isConnect { - if path != "" || scheme != "" || authority == "" { + if rp.path != "" || rp.scheme != "" || rp.authority == "" { return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) } - } else if method == "" || path == "" || - (scheme != "https" && scheme != "http") { + } else if rp.method == "" || rp.path == "" || (rp.scheme != "https" && rp.scheme != "http") { return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) } bodyOpen := !f.StreamEnded() - if method == "HEAD" && bodyOpen { + if rp.method == "HEAD" && bodyOpen { return nil, 
nil, http2streamError(f.StreamID, http2ErrCodeProtocol) } - var tlsState *tls.ConnectionState // nil if not scheme https - if scheme == "https" { - tlsState = sc.tlsState + rp.header = make(Header) + for _, hf := range f.RegularFields() { + rp.header.Add(sc.canonicalHeader(hf.Name), hf.Value) + } + if rp.authority == "" { + rp.authority = rp.header.Get("Host") } - header := make(Header) - for _, hf := range f.RegularFields() { - header.Add(sc.canonicalHeader(hf.Name), hf.Value) + rw, req, err := sc.newWriterAndRequestNoBody(st, rp) + if err != nil { + return nil, nil, err } + if bodyOpen { + st.reqBuf = http2getRequestBodyBuf() + req.Body.(*http2requestBody).pipe = &http2pipe{ + b: &http2fixedBuffer{buf: st.reqBuf}, + } - if authority == "" { - authority = header.Get("Host") + if vv, ok := rp.header["Content-Length"]; ok { + req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64) + } else { + req.ContentLength = -1 + } } - needsContinue := header.Get("Expect") == "100-continue" + return rw, req, nil +} + +type http2requestParam struct { + method string + scheme, authority, path string + header Header +} + +func (sc *http2serverConn) newWriterAndRequestNoBody(st *http2stream, rp http2requestParam) (*http2responseWriter, *Request, error) { + sc.serveG.check() + + var tlsState *tls.ConnectionState // nil if not scheme https + if rp.scheme == "https" { + tlsState = sc.tlsState + } + + needsContinue := rp.header.Get("Expect") == "100-continue" if needsContinue { - header.Del("Expect") + rp.header.Del("Expect") } - if cookies := header["Cookie"]; len(cookies) > 1 { - header.Set("Cookie", strings.Join(cookies, "; ")) + if cookies := rp.header["Cookie"]; len(cookies) > 1 { + rp.header.Set("Cookie", strings.Join(cookies, "; ")) } // Setup Trailers var trailer Header - for _, v := range header["Trailer"] { + for _, v := range rp.header["Trailer"] { for _, key := range strings.Split(v, ",") { key = CanonicalHeaderKey(strings.TrimSpace(key)) switch key { @@ -4218,55 +4505,42 @@ 
func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead } } } - delete(header, "Trailer") + delete(rp.header, "Trailer") - body := &http2requestBody{ - conn: sc, - stream: st, - needsContinue: needsContinue, - } var url_ *url.URL var requestURI string - if isConnect { - url_ = &url.URL{Host: authority} - requestURI = authority + if rp.method == "CONNECT" { + url_ = &url.URL{Host: rp.authority} + requestURI = rp.authority } else { var err error - url_, err = url.ParseRequestURI(path) + url_, err = url.ParseRequestURI(rp.path) if err != nil { - return nil, nil, http2streamError(f.StreamID, http2ErrCodeProtocol) + return nil, nil, http2streamError(st.id, http2ErrCodeProtocol) } - requestURI = path + requestURI = rp.path + } + + body := &http2requestBody{ + conn: sc, + stream: st, + needsContinue: needsContinue, } req := &Request{ - Method: method, + Method: rp.method, URL: url_, RemoteAddr: sc.remoteAddrStr, - Header: header, + Header: rp.header, RequestURI: requestURI, Proto: "HTTP/2.0", ProtoMajor: 2, ProtoMinor: 0, TLS: tlsState, - Host: authority, + Host: rp.authority, Body: body, Trailer: trailer, } req = http2requestWithContext(req, st.ctx) - if bodyOpen { - - buf := make([]byte, http2initialWindowSize) - - body.pipe = &http2pipe{ - b: &http2fixedBuffer{buf: buf}, - } - - if vv, ok := header["Content-Length"]; ok { - req.ContentLength, _ = strconv.ParseInt(vv[0], 10, 64) - } else { - req.ContentLength = -1 - } - } rws := http2responseWriterStatePool.Get().(*http2responseWriterState) bwSave := rws.bw @@ -4282,13 +4556,22 @@ func (sc *http2serverConn) newWriterAndRequest(st *http2stream, f *http2MetaHead return rw, req, nil } -func (sc *http2serverConn) getRequestBodyBuf() []byte { - sc.serveG.check() - if buf := sc.freeRequestBodyBuf; buf != nil { - sc.freeRequestBodyBuf = nil - return buf +var http2reqBodyCache = make(chan []byte, 8) + +func http2getRequestBodyBuf() []byte { + select { + case b := <-http2reqBodyCache: + return b + 
default: + return make([]byte, http2initialWindowSize) + } +} + +func http2putRequestBodyBuf(b []byte) { + select { + case http2reqBodyCache <- b: + default: } - return make([]byte, http2initialWindowSize) } // Run on its own goroutine. @@ -4298,15 +4581,17 @@ func (sc *http2serverConn) runHandler(rw *http2responseWriter, req *Request, han rw.rws.stream.cancelCtx() if didPanic { e := recover() - // Same as net/http: - const size = 64 << 10 - buf := make([]byte, size) - buf = buf[:runtime.Stack(buf, false)] - sc.writeFrameFromHandler(http2frameWriteMsg{ + sc.writeFrameFromHandler(http2FrameWriteRequest{ write: http2handlerPanicRST{rw.rws.stream.id}, stream: rw.rws.stream, }) - sc.logf("http2: panic serving %v: %v\n%s", sc.conn.RemoteAddr(), e, buf) + + if http2shouldLogPanic(e) { + const size = 64 << 10 + buf := make([]byte, size) + buf = buf[:runtime.Stack(buf, false)] + sc.logf("http2: panic serving %v: %v\n%s", sc.conn.RemoteAddr(), e, buf) + } return } rw.handlerDone() @@ -4334,7 +4619,7 @@ func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeR errc = http2errChanPool.Get().(chan error) } - if err := sc.writeFrameFromHandler(http2frameWriteMsg{ + if err := sc.writeFrameFromHandler(http2FrameWriteRequest{ write: headerData, stream: st, done: errc, @@ -4357,7 +4642,7 @@ func (sc *http2serverConn) writeHeaders(st *http2stream, headerData *http2writeR // called from handler goroutines. func (sc *http2serverConn) write100ContinueHeaders(st *http2stream) { - sc.writeFrameFromHandler(http2frameWriteMsg{ + sc.writeFrameFromHandler(http2FrameWriteRequest{ write: http2write100ContinueHeadersFrame{st.id}, stream: st, }) @@ -4373,11 +4658,19 @@ type http2bodyReadMsg struct { // called from handler goroutines. // Notes that the handler for the given stream ID read n bytes of its body // and schedules flow control tokens to be sent. 
-func (sc *http2serverConn) noteBodyReadFromHandler(st *http2stream, n int) { +func (sc *http2serverConn) noteBodyReadFromHandler(st *http2stream, n int, err error) { sc.serveG.checkNotOn() - select { - case sc.bodyReadCh <- http2bodyReadMsg{st, n}: - case <-sc.doneServing: + if n > 0 { + select { + case sc.bodyReadCh <- http2bodyReadMsg{st, n}: + case <-sc.doneServing: + } + } + if err == io.EOF { + if buf := st.reqBuf; buf != nil { + st.reqBuf = nil + http2putRequestBodyBuf(buf) + } } } @@ -4419,7 +4712,7 @@ func (sc *http2serverConn) sendWindowUpdate32(st *http2stream, n int32) { if st != nil { streamID = st.id } - sc.writeFrame(http2frameWriteMsg{ + sc.writeFrame(http2FrameWriteRequest{ write: http2writeWindowUpdate{streamID: streamID, n: uint32(n)}, stream: st, }) @@ -4434,16 +4727,19 @@ func (sc *http2serverConn) sendWindowUpdate32(st *http2stream, n int32) { } } +// requestBody is the Handler's Request.Body type. +// Read and Close may be called concurrently. type http2requestBody struct { stream *http2stream conn *http2serverConn - closed bool + closed bool // for use by Close only + sawEOF bool // for use by Read only pipe *http2pipe // non-nil if we have a HTTP entity message body needsContinue bool // need to send a 100-continue } func (b *http2requestBody) Close() error { - if b.pipe != nil { + if b.pipe != nil && !b.closed { b.pipe.BreakWithError(http2errClosedBody) } b.closed = true @@ -4455,13 +4751,17 @@ func (b *http2requestBody) Read(p []byte) (n int, err error) { b.needsContinue = false b.conn.write100ContinueHeaders(b.stream) } - if b.pipe == nil { + if b.pipe == nil || b.sawEOF { return 0, io.EOF } n, err = b.pipe.Read(p) - if n > 0 { - b.conn.noteBodyReadFromHandler(b.stream, n) + if err == io.EOF { + b.sawEOF = true } + if b.conn == nil && http2inTests { + return + } + b.conn.noteBodyReadFromHandler(b.stream, n, err) return } @@ -4696,8 +4996,9 @@ func (w *http2responseWriter) CloseNotify() <-chan bool { if ch == nil { ch = make(chan bool, 1) 
rws.closeNotifierCh = ch + cw := rws.stream.cw go func() { - rws.stream.cw.Wait() + cw.Wait() ch <- true }() } @@ -4793,6 +5094,172 @@ func (w *http2responseWriter) handlerDone() { http2responseWriterStatePool.Put(rws) } +// Push errors. +var ( + http2ErrRecursivePush = errors.New("http2: recursive push not allowed") + http2ErrPushLimitReached = errors.New("http2: push would exceed peer's SETTINGS_MAX_CONCURRENT_STREAMS") +) + +// pushOptions is the internal version of http.PushOptions, which we +// cannot include here because it's only defined in Go 1.8 and later. +type http2pushOptions struct { + Method string + Header Header +} + +func (w *http2responseWriter) push(target string, opts http2pushOptions) error { + st := w.rws.stream + sc := st.sc + sc.serveG.checkNotOn() + + if st.isPushed() { + return http2ErrRecursivePush + } + + if opts.Method == "" { + opts.Method = "GET" + } + if opts.Header == nil { + opts.Header = Header{} + } + wantScheme := "http" + if w.rws.req.TLS != nil { + wantScheme = "https" + } + + u, err := url.Parse(target) + if err != nil { + return err + } + if u.Scheme == "" { + if !strings.HasPrefix(target, "/") { + return fmt.Errorf("target must be an absolute URL or an absolute path: %q", target) + } + u.Scheme = wantScheme + u.Host = w.rws.req.Host + } else { + if u.Scheme != wantScheme { + return fmt.Errorf("cannot push URL with scheme %q from request with scheme %q", u.Scheme, wantScheme) + } + if u.Host == "" { + return errors.New("URL must have a host") + } + } + for k := range opts.Header { + if strings.HasPrefix(k, ":") { + return fmt.Errorf("promised request headers cannot include pseudo header %q", k) + } + + switch strings.ToLower(k) { + case "content-length", "content-encoding", "trailer", "te", "expect", "host": + return fmt.Errorf("promised request headers cannot include %q", k) + } + } + if err := http2checkValidHTTP2RequestHeaders(opts.Header); err != nil { + return err + } + + if opts.Method != "GET" && opts.Method != "HEAD" 
{ + return fmt.Errorf("method %q must be GET or HEAD", opts.Method) + } + + msg := http2startPushRequest{ + parent: st, + method: opts.Method, + url: u, + header: http2cloneHeader(opts.Header), + done: http2errChanPool.Get().(chan error), + } + + select { + case <-sc.doneServing: + return http2errClientDisconnected + case <-st.cw: + return http2errStreamClosed + case sc.wantStartPushCh <- msg: + } + + select { + case <-sc.doneServing: + return http2errClientDisconnected + case <-st.cw: + return http2errStreamClosed + case err := <-msg.done: + http2errChanPool.Put(msg.done) + return err + } +} + +type http2startPushRequest struct { + parent *http2stream + method string + url *url.URL + header Header + done chan error +} + +func (sc *http2serverConn) startPush(msg http2startPushRequest) { + sc.serveG.check() + + if msg.parent.state != http2stateOpen && msg.parent.state != http2stateHalfClosedRemote { + + msg.done <- http2errStreamClosed + return + } + + if !sc.pushEnabled { + msg.done <- ErrNotSupported + return + } + + allocatePromisedID := func() (uint32, error) { + sc.serveG.check() + + if !sc.pushEnabled { + return 0, ErrNotSupported + } + + if sc.curPushedStreams+1 > sc.clientMaxStreams { + return 0, http2ErrPushLimitReached + } + + if sc.maxPushPromiseID+2 >= 1<<31 { + sc.startGracefulShutdown() + return 0, http2ErrPushLimitReached + } + sc.maxPushPromiseID += 2 + promisedID := sc.maxPushPromiseID + + promised := sc.newStream(promisedID, msg.parent.id, http2stateHalfClosedRemote) + rw, req, err := sc.newWriterAndRequestNoBody(promised, http2requestParam{ + method: msg.method, + scheme: msg.url.Scheme, + authority: msg.url.Host, + path: msg.url.RequestURI(), + header: http2cloneHeader(msg.header), + }) + if err != nil { + + panic(fmt.Sprintf("newWriterAndRequestNoBody(%+v): %v", msg.url, err)) + } + + go sc.runHandler(rw, req, sc.handler.ServeHTTP) + return promisedID, nil + } + + sc.writeFrame(http2FrameWriteRequest{ + write: &http2writePushPromise{ + streamID: 
msg.parent.id, + method: msg.method, + url: msg.url, + h: msg.header, + allocatePromisedID: allocatePromisedID, + }, + stream: msg.parent, + done: msg.done, + }) +} + // foreachHeaderElement splits v according to the "#rule" construction // in RFC 2616 section 2.1 and calls fn for each non-empty element. func http2foreachHeaderElement(v string, fn func(string)) { @@ -4820,16 +5287,16 @@ var http2connHeaders = []string{ "Upgrade", } -// checkValidHTTP2Request checks whether req is a valid HTTP/2 request, +// checkValidHTTP2RequestHeaders checks whether h is a valid HTTP/2 request, // per RFC 7540 Section 8.1.2.2. // The returned error is reported to users. -func http2checkValidHTTP2Request(req *Request) error { - for _, h := range http2connHeaders { - if _, ok := req.Header[h]; ok { - return fmt.Errorf("request header %q is not valid in HTTP/2", h) +func http2checkValidHTTP2RequestHeaders(h Header) error { + for _, k := range http2connHeaders { + if _, ok := h[k]; ok { + return fmt.Errorf("request header %q is not valid in HTTP/2", k) } } - te := req.Header["Te"] + te := h["Te"] if len(te) > 0 && (len(te) > 1 || (te[0] != "trailers" && te[0] != "")) { return errors.New(`request header "TE" may only be "trailers" in HTTP/2`) } @@ -4877,6 +5344,45 @@ var http2badTrailer = map[string]bool{ "Www-Authenticate": true, } +// h1ServerShutdownChan returns a channel that will be closed when the +// provided *http.Server wants to shut down. +// +// This is a somewhat hacky way to get at http1 innards. It works +// when the http2 code is bundled into the net/http package in the +// standard library. The alternatives ended up making the cmd/go tool +// depend on http Servers. This is the lightest option for now. +// This is tested via the TestServeShutdown* tests in net/http. 
+func http2h1ServerShutdownChan(hs *Server) <-chan struct{} { + if fn := http2testh1ServerShutdownChan; fn != nil { + return fn(hs) + } + var x interface{} = hs + type I interface { + getDoneChan() <-chan struct{} + } + if hs, ok := x.(I); ok { + return hs.getDoneChan() + } + return nil +} + +// optional test hook for h1ServerShutdownChan. +var http2testh1ServerShutdownChan func(hs *Server) <-chan struct{} + +// h1ServerKeepAlivesDisabled reports whether hs has its keep-alives +// disabled. See comments on h1ServerShutdownChan above for why +// the code is written this way. +func http2h1ServerKeepAlivesDisabled(hs *Server) bool { + var x interface{} = hs + type I interface { + doKeepAlives() bool + } + if hs, ok := x.(I); ok { + return !hs.doKeepAlives() + } + return false +} + const ( // transportDefaultConnFlow is how many connection-level flow control // tokens we give the server at start-up, past the default 64k. @@ -4997,6 +5503,9 @@ type http2ClientConn struct { readerDone chan struct{} // closed on error readerErr error // set before readerDone is closed + idleTimeout time.Duration // or 0 for never + idleTimer *time.Timer + mu sync.Mutex // guards following cond *sync.Cond // hold mu; broadcast on flow/closed changes flow http2flow // our conn-level flow control quota (cs.flow is per stream) @@ -5007,6 +5516,7 @@ type http2ClientConn struct { goAwayDebug string // goAway frame's debug data, retained as a string streams map[uint32]*http2clientStream // client-initiated nextStreamID uint32 + pings map[[8]byte]chan struct{} // in flight ping data to notification channel bw *bufio.Writer br *bufio.Reader fr *http2Framer @@ -5033,6 +5543,7 @@ type http2clientStream struct { ID uint32 resc chan http2resAndError bufPipe http2pipe // buffered pipe with the flow-controlled response payload + startedWrite bool // started request body write; guarded by cc.mu requestedGzip bool on100 func() // optional code to run if get a 100 continue response @@ -5041,6 +5552,7 @@ 
type http2clientStream struct { bytesRemain int64 // -1 means unknown; owned by transportResponseBody.Read readErr error // sticky read error; owned by transportResponseBody.Read stopReqBody error // if non-nil, stop writing req body; guarded by cc.mu + didReset bool // whether we sent a RST_STREAM to the server; guarded by cc.mu peerReset chan struct{} // closed on peer reset resetErr error // populated before peerReset is closed @@ -5068,15 +5580,26 @@ func (cs *http2clientStream) awaitRequestCancel(req *Request) { } select { case <-req.Cancel: + cs.cancelStream() cs.bufPipe.CloseWithError(http2errRequestCanceled) - cs.cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) case <-ctx.Done(): + cs.cancelStream() cs.bufPipe.CloseWithError(ctx.Err()) - cs.cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) case <-cs.done: } } +func (cs *http2clientStream) cancelStream() { + cs.cc.mu.Lock() + didReset := cs.didReset + cs.didReset = true + cs.cc.mu.Unlock() + + if !didReset { + cs.cc.writeStreamReset(cs.ID, http2ErrCodeCancel, nil) + } +} + // checkResetOrDone reports any error sent in a RST_STREAM frame by the // server, or errStreamClosed if the stream is complete. func (cs *http2clientStream) checkResetOrDone() error { @@ -5133,14 +5656,22 @@ func (t *http2Transport) RoundTrip(req *Request) (*Response, error) { // authorityAddr returns a given authority (a host/IP, or host:port / ip:port) // and returns a host:port. The port 443 is added if needed. 
func http2authorityAddr(scheme string, authority string) (addr string) { - if _, _, err := net.SplitHostPort(authority); err == nil { - return authority + host, port, err := net.SplitHostPort(authority) + if err != nil { + port = "443" + if scheme == "http" { + port = "80" + } + host = authority + } + if a, err := idna.ToASCII(host); err == nil { + host = a } - port := "443" - if scheme == "http" { - port = "80" + + if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") { + return host + ":" + port } - return net.JoinHostPort(authority, port) + return net.JoinHostPort(host, port) } // RoundTripOpt is like RoundTrip, but takes options. @@ -5158,8 +5689,10 @@ func (t *http2Transport) RoundTripOpt(req *Request, opt http2RoundTripOpt) (*Res } http2traceGotConn(req, cc) res, err := cc.RoundTrip(req) - if http2shouldRetryRequest(req, err) { - continue + if err != nil { + if req, err = http2shouldRetryRequest(req, err); err == nil { + continue + } } if err != nil { t.vlogf("RoundTrip failure: %v", err) @@ -5181,11 +5714,39 @@ func (t *http2Transport) CloseIdleConnections() { var ( http2errClientConnClosed = errors.New("http2: client conn is closed") http2errClientConnUnusable = errors.New("http2: client conn not usable") + + http2errClientConnGotGoAway = errors.New("http2: Transport received Server's graceful shutdown GOAWAY") + http2errClientConnGotGoAwayAfterSomeReqBody = errors.New("http2: Transport received Server's graceful shutdown GOAWAY; some request body already written") ) -func http2shouldRetryRequest(req *Request, err error) bool { +// shouldRetryRequest is called by RoundTrip when a request fails to get +// response headers. It is always called with a non-nil error. +// It returns either a request to retry (either the same request, or a +// modified clone), or an error if the request can't be replayed. 
+func http2shouldRetryRequest(req *Request, err error) (*Request, error) { + switch err { + default: + return nil, err + case http2errClientConnUnusable, http2errClientConnGotGoAway: + return req, nil + case http2errClientConnGotGoAwayAfterSomeReqBody: + + if req.Body == nil || http2reqBodyIsNoBody(req.Body) { + return req, nil + } - return err == http2errClientConnUnusable + getBody := http2reqGetBody(req) + if getBody == nil { + return nil, errors.New("http2: Transport: peer server initiated graceful shutdown after some of Request.Body was written; define Request.GetBody to avoid this error") + } + body, err := getBody() + if err != nil { + return nil, err + } + newReq := *req + newReq.Body = body + return &newReq, nil + } } func (t *http2Transport) dialClientConn(addr string, singleUse bool) (*http2ClientConn, error) { @@ -5203,7 +5764,7 @@ func (t *http2Transport) dialClientConn(addr string, singleUse bool) (*http2Clie func (t *http2Transport) newTLSConfig(host string) *tls.Config { cfg := new(tls.Config) if t.TLSClientConfig != nil { - *cfg = *t.TLSClientConfig + *cfg = *http2cloneTLSConfig(t.TLSClientConfig) } if !http2strSliceContains(cfg.NextProtos, http2NextProtoTLS) { cfg.NextProtos = append([]string{http2NextProtoTLS}, cfg.NextProtos...) 
@@ -5273,6 +5834,11 @@ func (t *http2Transport) newClientConn(c net.Conn, singleUse bool) (*http2Client streams: make(map[uint32]*http2clientStream), singleUse: singleUse, wantSettingsAck: true, + pings: make(map[[8]byte]chan struct{}), + } + if d := t.idleConnTimeout(); d != 0 { + cc.idleTimeout = d + cc.idleTimer = time.AfterFunc(d, cc.onIdleTimeout) } if http2VerboseLogs { t.vlogf("http2: Transport creating client conn %p to %v", cc, c.RemoteAddr()) @@ -5328,6 +5894,15 @@ func (cc *http2ClientConn) setGoAway(f *http2GoAwayFrame) { if old != nil && old.ErrCode != http2ErrCodeNo { cc.goAway.ErrCode = old.ErrCode } + last := f.LastStreamID + for streamID, cs := range cc.streams { + if streamID > last { + select { + case cs.resc <- http2resAndError{err: http2errClientConnGotGoAway}: + default: + } + } + } } func (cc *http2ClientConn) CanTakeNewRequest() bool { @@ -5345,6 +5920,16 @@ func (cc *http2ClientConn) canTakeNewRequestLocked() bool { cc.nextStreamID < math.MaxInt32 } +// onIdleTimeout is called from a time.AfterFunc goroutine. It will +// only be called when we're idle, but because we're coming from a new +// goroutine, there could be a new request coming in at the same time, +// so this simply calls the synchronized closeIfIdle to shut down this +// connection. The timer could just call closeIfIdle, but this is more +// clear. +func (cc *http2ClientConn) onIdleTimeout() { + cc.closeIfIdle() +} + func (cc *http2ClientConn) closeIfIdle() { cc.mu.Lock() if len(cc.streams) > 0 { @@ -5437,48 +6022,37 @@ func (cc *http2ClientConn) responseHeaderTimeout() time.Duration { // Certain headers are special-cased as okay but not transmitted later. 
func http2checkConnHeaders(req *Request) error { if v := req.Header.Get("Upgrade"); v != "" { - return errors.New("http2: invalid Upgrade request header") + return fmt.Errorf("http2: invalid Upgrade request header: %q", req.Header["Upgrade"]) } - if v := req.Header.Get("Transfer-Encoding"); (v != "" && v != "chunked") || len(req.Header["Transfer-Encoding"]) > 1 { - return errors.New("http2: invalid Transfer-Encoding request header") + if vv := req.Header["Transfer-Encoding"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "chunked") { + return fmt.Errorf("http2: invalid Transfer-Encoding request header: %q", vv) } - if v := req.Header.Get("Connection"); (v != "" && v != "close" && v != "keep-alive") || len(req.Header["Connection"]) > 1 { - return errors.New("http2: invalid Connection request header") + if vv := req.Header["Connection"]; len(vv) > 0 && (len(vv) > 1 || vv[0] != "" && vv[0] != "close" && vv[0] != "keep-alive") { + return fmt.Errorf("http2: invalid Connection request header: %q", vv) } return nil } -func http2bodyAndLength(req *Request) (body io.Reader, contentLen int64) { - body = req.Body - if body == nil { - return nil, 0 +// actualContentLength returns a sanitized version of +// req.ContentLength, where 0 actually means zero (not unknown) and -1 +// means unknown. +func http2actualContentLength(req *Request) int64 { + if req.Body == nil { + return 0 } if req.ContentLength != 0 { - return req.Body, req.ContentLength - } - - // We have a body but a zero content length. Test to see if - // it's actually zero or just unset. 
- var buf [1]byte - n, rerr := body.Read(buf[:]) - if rerr != nil && rerr != io.EOF { - return http2errorReader{rerr}, -1 - } - if n == 1 { - - if rerr == io.EOF { - return bytes.NewReader(buf[:]), 1 - } - return io.MultiReader(bytes.NewReader(buf[:]), body), -1 + return req.ContentLength } - - return nil, 0 + return -1 } func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { if err := http2checkConnHeaders(req); err != nil { return nil, err } + if cc.idleTimer != nil { + cc.idleTimer.Stop() + } trailers, err := http2commaSeparatedTrailers(req) if err != nil { @@ -5486,9 +6060,6 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { } hasTrailers := trailers != "" - body, contentLen := http2bodyAndLength(req) - hasBody := body != nil - cc.mu.Lock() cc.lastActive = time.Now() if cc.closed || !cc.canTakeNewRequestLocked() { @@ -5496,6 +6067,10 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { return nil, http2errClientConnUnusable } + body := req.Body + hasBody := body != nil + contentLen := http2actualContentLength(req) + // TODO(bradfitz): this is a copy of the logic in net/http. Unify somewhere? 
var requestedGzip bool if !cc.t.disableCompression() && @@ -5561,6 +6136,13 @@ func (cc *http2ClientConn) RoundTrip(req *Request) (*Response, error) { cs.abortRequestBodyWrite(http2errStopReqBodyWrite) } if re.err != nil { + if re.err == http2errClientConnGotGoAway { + cc.mu.Lock() + if cs.startedWrite { + re.err = http2errClientConnGotGoAwayAfterSomeReqBody + } + cc.mu.Unlock() + } cc.forgetStreamID(cs.ID) return nil, re.err } @@ -5806,6 +6388,26 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail if host == "" { host = req.URL.Host } + host, err := httplex.PunycodeHostPort(host) + if err != nil { + return nil, err + } + + var path string + if req.Method != "CONNECT" { + path = req.URL.RequestURI() + if !http2validPseudoPath(path) { + orig := path + path = strings.TrimPrefix(path, req.URL.Scheme+"://"+host) + if !http2validPseudoPath(path) { + if req.URL.Opaque != "" { + return nil, fmt.Errorf("invalid request :path %q from URL.Opaque = %q", orig, req.URL.Opaque) + } else { + return nil, fmt.Errorf("invalid request :path %q", orig) + } + } + } + } for k, vv := range req.Header { if !httplex.ValidHeaderFieldName(k) { @@ -5821,8 +6423,8 @@ func (cc *http2ClientConn) encodeHeaders(req *Request, addGzipHeader bool, trail cc.writeHeader(":authority", host) cc.writeHeader(":method", req.Method) if req.Method != "CONNECT" { - cc.writeHeader(":path", req.URL.RequestURI()) - cc.writeHeader(":scheme", "https") + cc.writeHeader(":path", path) + cc.writeHeader(":scheme", req.URL.Scheme) } if trailers != "" { cc.writeHeader("trailer", trailers) @@ -5940,6 +6542,9 @@ func (cc *http2ClientConn) streamByID(id uint32, andRemove bool) *http2clientStr if andRemove && cs != nil && !cc.closed { cc.lastActive = time.Now() delete(cc.streams, id) + if len(cc.streams) == 0 && cc.idleTimer != nil { + cc.idleTimer.Reset(cc.idleTimeout) + } close(cs.done) cc.cond.Broadcast() } @@ -5996,6 +6601,10 @@ func (rl *http2clientConnReadLoop) cleanup() { defer 
cc.t.connPool().MarkDead(cc) defer close(cc.readerDone) + if cc.idleTimer != nil { + cc.idleTimer.Stop() + } + err := cc.readerErr cc.mu.Lock() if cc.goAway != nil && http2isEOFOrNetReadError(err) { @@ -6398,9 +7007,10 @@ func (rl *http2clientConnReadLoop) processData(f *http2DataFrame) error { cc.bw.Flush() cc.wmu.Unlock() } + didReset := cs.didReset cc.mu.Unlock() - if len(data) > 0 { + if len(data) > 0 && !didReset { if _, err := cs.bufPipe.Write(data); err != nil { rl.endStreamError(cs, err) return err @@ -6551,9 +7161,56 @@ func (rl *http2clientConnReadLoop) processResetStream(f *http2RSTStreamFrame) er return nil } +// Ping sends a PING frame to the server and waits for the ack. +// Public implementation is in go17.go and not_go17.go +func (cc *http2ClientConn) ping(ctx http2contextContext) error { + c := make(chan struct{}) + // Generate a random payload + var p [8]byte + for { + if _, err := rand.Read(p[:]); err != nil { + return err + } + cc.mu.Lock() + + if _, found := cc.pings[p]; !found { + cc.pings[p] = c + cc.mu.Unlock() + break + } + cc.mu.Unlock() + } + cc.wmu.Lock() + if err := cc.fr.WritePing(false, p); err != nil { + cc.wmu.Unlock() + return err + } + if err := cc.bw.Flush(); err != nil { + cc.wmu.Unlock() + return err + } + cc.wmu.Unlock() + select { + case <-c: + return nil + case <-ctx.Done(): + return ctx.Err() + case <-cc.readerDone: + + return cc.readerErr + } +} + func (rl *http2clientConnReadLoop) processPing(f *http2PingFrame) error { if f.IsAck() { + cc := rl.cc + cc.mu.Lock() + defer cc.mu.Unlock() + if c, ok := cc.pings[f.Data]; ok { + close(c) + delete(cc.pings, f.Data) + } return nil } cc := rl.cc @@ -6666,6 +7323,9 @@ func (t *http2Transport) getBodyWriterState(cs *http2clientStream, body io.Reade resc := make(chan error, 1) s.resc = resc s.fn = func() { + cs.cc.mu.Lock() + cs.startedWrite = true + cs.cc.mu.Unlock() resc <- cs.writeRequestBody(body, cs.req.Body) } s.delay = t.expectContinueTimeout() @@ -6728,6 +7388,11 @@ func 
http2isConnectionCloseRequest(req *Request) bool { // writeFramer is implemented by any type that is used to write frames. type http2writeFramer interface { writeFrame(http2writeContext) error + + // staysWithinBuffer reports whether this writer promises that + // it will only write less than or equal to size bytes, and it + // won't Flush the write context. + staysWithinBuffer(size int) bool } // writeContext is the interface needed by the various frame writer @@ -6749,9 +7414,10 @@ type http2writeContext interface { HeaderEncoder() (*hpack.Encoder, *bytes.Buffer) } -// endsStream reports whether the given frame writer w will locally -// close the stream. -func http2endsStream(w http2writeFramer) bool { +// writeEndsStream reports whether w writes a frame that will transition +// the stream to a half-closed local state. This returns false for RST_STREAM, +// which closes the entire stream (not just the local half). +func http2writeEndsStream(w http2writeFramer) bool { switch v := w.(type) { case *http2writeData: return v.endStream @@ -6759,7 +7425,7 @@ func http2endsStream(w http2writeFramer) bool { return v.endStream case nil: - panic("endsStream called on nil writeFramer") + panic("writeEndsStream called on nil writeFramer") } return false } @@ -6770,8 +7436,16 @@ func (http2flushFrameWriter) writeFrame(ctx http2writeContext) error { return ctx.Flush() } +func (http2flushFrameWriter) staysWithinBuffer(max int) bool { return false } + type http2writeSettings []http2Setting +func (s http2writeSettings) staysWithinBuffer(max int) bool { + const settingSize = 6 // uint16 + uint32 + return http2frameHeaderLen+settingSize*len(s) <= max + +} + func (s http2writeSettings) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteSettings([]http2Setting(s)...) 
} @@ -6791,6 +7465,8 @@ func (p *http2writeGoAway) writeFrame(ctx http2writeContext) error { return err } +func (*http2writeGoAway) staysWithinBuffer(max int) bool { return false } + type http2writeData struct { streamID uint32 p []byte @@ -6805,6 +7481,10 @@ func (w *http2writeData) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteData(w.streamID, w.endStream, w.p) } +func (w *http2writeData) staysWithinBuffer(max int) bool { + return http2frameHeaderLen+len(w.p) <= max +} + // handlerPanicRST is the message sent from handler goroutines when // the handler panics. type http2handlerPanicRST struct { @@ -6815,22 +7495,59 @@ func (hp http2handlerPanicRST) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteRSTStream(hp.StreamID, http2ErrCodeInternal) } +func (hp http2handlerPanicRST) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } + func (se http2StreamError) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteRSTStream(se.StreamID, se.Code) } +func (se http2StreamError) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } + type http2writePingAck struct{ pf *http2PingFrame } func (w http2writePingAck) writeFrame(ctx http2writeContext) error { return ctx.Framer().WritePing(true, w.pf.Data) } +func (w http2writePingAck) staysWithinBuffer(max int) bool { + return http2frameHeaderLen+len(w.pf.Data) <= max +} + type http2writeSettingsAck struct{} func (http2writeSettingsAck) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteSettingsAck() } +func (http2writeSettingsAck) staysWithinBuffer(max int) bool { return http2frameHeaderLen <= max } + +// splitHeaderBlock splits headerBlock into fragments so that each fragment fits +// in a single frame, then calls fn for each fragment. firstFrag/lastFrag are true +// for the first/last fragment, respectively. 
+func http2splitHeaderBlock(ctx http2writeContext, headerBlock []byte, fn func(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error) error { + // For now we're lazy and just pick the minimum MAX_FRAME_SIZE + // that all peers must support (16KB). Later we could care + // more and send larger frames if the peer advertised it, but + // there's little point. Most headers are small anyway (so we + // generally won't have CONTINUATION frames), and extra frames + // only waste 9 bytes anyway. + const maxFrameSize = 16384 + + first := true + for len(headerBlock) > 0 { + frag := headerBlock + if len(frag) > maxFrameSize { + frag = frag[:maxFrameSize] + } + headerBlock = headerBlock[len(frag):] + if err := fn(ctx, frag, first, len(headerBlock) == 0); err != nil { + return err + } + first = false + } + return nil +} + // writeResHeaders is a request to write a HEADERS and 0+ CONTINUATION frames // for HTTP response headers or trailers from a server handler. type http2writeResHeaders struct { @@ -6852,6 +7569,11 @@ func http2encKV(enc *hpack.Encoder, k, v string) { enc.WriteField(hpack.HeaderField{Name: k, Value: v}) } +func (w *http2writeResHeaders) staysWithinBuffer(max int) bool { + + return false +} + func (w *http2writeResHeaders) writeFrame(ctx http2writeContext) error { enc, buf := ctx.HeaderEncoder() buf.Reset() @@ -6877,39 +7599,69 @@ func (w *http2writeResHeaders) writeFrame(ctx http2writeContext) error { panic("unexpected empty hpack") } - // For now we're lazy and just pick the minimum MAX_FRAME_SIZE - // that all peers must support (16KB). Later we could care - // more and send larger frames if the peer advertised it, but - // there's little point. Most headers are small anyway (so we - // generally won't have CONTINUATION frames), and extra frames - // only waste 9 bytes anyway. 
- const maxFrameSize = 16384 + return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) +} - first := true - for len(headerBlock) > 0 { - frag := headerBlock - if len(frag) > maxFrameSize { - frag = frag[:maxFrameSize] - } - headerBlock = headerBlock[len(frag):] - endHeaders := len(headerBlock) == 0 - var err error - if first { - first = false - err = ctx.Framer().WriteHeaders(http2HeadersFrameParam{ - StreamID: w.streamID, - BlockFragment: frag, - EndStream: w.endStream, - EndHeaders: endHeaders, - }) - } else { - err = ctx.Framer().WriteContinuation(w.streamID, endHeaders, frag) - } - if err != nil { - return err - } +func (w *http2writeResHeaders) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error { + if firstFrag { + return ctx.Framer().WriteHeaders(http2HeadersFrameParam{ + StreamID: w.streamID, + BlockFragment: frag, + EndStream: w.endStream, + EndHeaders: lastFrag, + }) + } else { + return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) + } +} + +// writePushPromise is a request to write a PUSH_PROMISE and 0+ CONTINUATION frames. +type http2writePushPromise struct { + streamID uint32 // pusher stream + method string // for :method + url *url.URL // for :scheme, :authority, :path + h Header + + // Creates an ID for a pushed stream. This runs on serveG just before + // the frame is written. The returned ID is copied to promisedID. 
+ allocatePromisedID func() (uint32, error) + promisedID uint32 +} + +func (w *http2writePushPromise) staysWithinBuffer(max int) bool { + + return false +} + +func (w *http2writePushPromise) writeFrame(ctx http2writeContext) error { + enc, buf := ctx.HeaderEncoder() + buf.Reset() + + http2encKV(enc, ":method", w.method) + http2encKV(enc, ":scheme", w.url.Scheme) + http2encKV(enc, ":authority", w.url.Host) + http2encKV(enc, ":path", w.url.RequestURI()) + http2encodeHeaders(enc, w.h, nil) + + headerBlock := buf.Bytes() + if len(headerBlock) == 0 { + panic("unexpected empty hpack") + } + + return http2splitHeaderBlock(ctx, headerBlock, w.writeHeaderBlock) +} + +func (w *http2writePushPromise) writeHeaderBlock(ctx http2writeContext, frag []byte, firstFrag, lastFrag bool) error { + if firstFrag { + return ctx.Framer().WritePushPromise(http2PushPromiseParam{ + StreamID: w.streamID, + PromiseID: w.promisedID, + BlockFragment: frag, + EndHeaders: lastFrag, + }) + } else { + return ctx.Framer().WriteContinuation(w.streamID, lastFrag, frag) } - return nil } type http2write100ContinueHeadersFrame struct { @@ -6928,15 +7680,24 @@ func (w http2write100ContinueHeadersFrame) writeFrame(ctx http2writeContext) err }) } +func (w http2write100ContinueHeadersFrame) staysWithinBuffer(max int) bool { + + return 9+2*(len(":status")+len("100")) <= max +} + type http2writeWindowUpdate struct { streamID uint32 // or 0 for conn-level n uint32 } +func (wu http2writeWindowUpdate) staysWithinBuffer(max int) bool { return http2frameHeaderLen+4 <= max } + func (wu http2writeWindowUpdate) writeFrame(ctx http2writeContext) error { return ctx.Framer().WriteWindowUpdate(wu.streamID, wu.n) } +// encodeHeaders encodes an http.Header. If keys is not nil, then (k, h[k]) +// is encoded only if k is in keys. 
func http2encodeHeaders(enc *hpack.Encoder, h Header, keys []string) { if keys == nil { sorter := http2sorterPool.Get().(*http2sorter) @@ -6966,14 +7727,53 @@ func http2encodeHeaders(enc *hpack.Encoder, h Header, keys []string) { } } -// frameWriteMsg is a request to write a frame. -type http2frameWriteMsg struct { +// WriteScheduler is the interface implemented by HTTP/2 write schedulers. +// Methods are never called concurrently. +type http2WriteScheduler interface { + // OpenStream opens a new stream in the write scheduler. + // It is illegal to call this with streamID=0 or with a streamID that is + // already open -- the call may panic. + OpenStream(streamID uint32, options http2OpenStreamOptions) + + // CloseStream closes a stream in the write scheduler. Any frames queued on + // this stream should be discarded. It is illegal to call this on a stream + // that is not open -- the call may panic. + CloseStream(streamID uint32) + + // AdjustStream adjusts the priority of the given stream. This may be called + // on a stream that has not yet been opened or has been closed. Note that + // RFC 7540 allows PRIORITY frames to be sent on streams in any state. See: + // https://tools.ietf.org/html/rfc7540#section-5.1 + AdjustStream(streamID uint32, priority http2PriorityParam) + + // Push queues a frame in the scheduler. In most cases, this will not be + // called with wr.StreamID()!=0 unless that stream is currently open. The one + // exception is RST_STREAM frames, which may be sent on idle or closed streams. + Push(wr http2FrameWriteRequest) + + // Pop dequeues the next frame to write. Returns false if no frames can + // be written. Frames with a given wr.StreamID() are Pop'd in the same + // order they are Push'd. + Pop() (wr http2FrameWriteRequest, ok bool) +} + +// OpenStreamOptions specifies extra options for WriteScheduler.OpenStream. +type http2OpenStreamOptions struct { + // PusherID is zero if the stream was initiated by the client. 
Otherwise, + // PusherID names the stream that pushed the newly opened stream. + PusherID uint32 +} + +// FrameWriteRequest is a request to write a frame. +type http2FrameWriteRequest struct { // write is the interface value that does the writing, once the - // writeScheduler (below) has decided to select this frame - // to write. The write functions are all defined in write.go. + // WriteScheduler has selected this frame to write. The write + // functions are all defined in write.go. write http2writeFramer - stream *http2stream // used for prioritization. nil for non-stream frames. + // stream is the stream on which this frame will be written. + // nil for non-stream frames like PING and SETTINGS. + stream *http2stream // done, if non-nil, must be a buffered channel with space for // 1 message and is sent the return value from write (or an @@ -6981,247 +7781,644 @@ type http2frameWriteMsg struct { done chan error } -// for debugging only: -func (wm http2frameWriteMsg) String() string { - var streamID uint32 - if wm.stream != nil { - streamID = wm.stream.id +// StreamID returns the id of the stream this frame will be written to. +// 0 is used for non-stream frames such as PING and SETTINGS. +func (wr http2FrameWriteRequest) StreamID() uint32 { + if wr.stream == nil { + if se, ok := wr.write.(http2StreamError); ok { + + return se.StreamID + } + return 0 + } + return wr.stream.id +} + +// DataSize returns the number of flow control bytes that must be consumed +// to write this entire frame. This is 0 for non-DATA frames. +func (wr http2FrameWriteRequest) DataSize() int { + if wd, ok := wr.write.(*http2writeData); ok { + return len(wd.p) + } + return 0 +} + +// Consume consumes min(n, available) bytes from this frame, where available +// is the number of flow control bytes available on the stream. Consume returns +// 0, 1, or 2 frames, where the integer return value gives the number of frames +// returned. 
+// +// If flow control prevents consuming any bytes, this returns (_, _, 0). If +// the entire frame was consumed, this returns (wr, _, 1). Otherwise, this +// returns (consumed, rest, 2), where 'consumed' contains the consumed bytes and +// 'rest' contains the remaining bytes. The consumed bytes are deducted from the +// underlying stream's flow control budget. +func (wr http2FrameWriteRequest) Consume(n int32) (http2FrameWriteRequest, http2FrameWriteRequest, int) { + var empty http2FrameWriteRequest + + wd, ok := wr.write.(*http2writeData) + if !ok || len(wd.p) == 0 { + return wr, empty, 1 + } + + allowed := wr.stream.flow.available() + if n < allowed { + allowed = n + } + if wr.stream.sc.maxFrameSize < allowed { + allowed = wr.stream.sc.maxFrameSize + } + if allowed <= 0 { + return empty, empty, 0 + } + if len(wd.p) > int(allowed) { + wr.stream.flow.take(allowed) + consumed := http2FrameWriteRequest{ + stream: wr.stream, + write: &http2writeData{ + streamID: wd.streamID, + p: wd.p[:allowed], + + endStream: false, + }, + + done: nil, + } + rest := http2FrameWriteRequest{ + stream: wr.stream, + write: &http2writeData{ + streamID: wd.streamID, + p: wd.p[allowed:], + endStream: wd.endStream, + }, + done: wr.done, + } + return consumed, rest, 2 } + + wr.stream.flow.take(int32(len(wd.p))) + return wr, empty, 1 +} + +// String is for debugging only. +func (wr http2FrameWriteRequest) String() string { var des string - if s, ok := wm.write.(fmt.Stringer); ok { + if s, ok := wr.write.(fmt.Stringer); ok { des = s.String() } else { - des = fmt.Sprintf("%T", wm.write) + des = fmt.Sprintf("%T", wr.write) } - return fmt.Sprintf("[frameWriteMsg stream=%d, ch=%v, type: %v]", streamID, wm.done != nil, des) + return fmt.Sprintf("[FrameWriteRequest stream=%d, ch=%v, writer=%v]", wr.StreamID(), wr.done != nil, des) } -// writeScheduler tracks pending frames to write, priorities, and decides -// the next one to use. It is not thread-safe. 
-type http2writeScheduler struct { - // zero are frames not associated with a specific stream. - // They're sent before any stream-specific freams. - zero http2writeQueue +// replyToWriter sends err to wr.done and panics if the send must block +// This does nothing if wr.done is nil. +func (wr *http2FrameWriteRequest) replyToWriter(err error) { + if wr.done == nil { + return + } + select { + case wr.done <- err: + default: + panic(fmt.Sprintf("unbuffered done channel passed in for type %T", wr.write)) + } + wr.write = nil +} - // maxFrameSize is the maximum size of a DATA frame - // we'll write. Must be non-zero and between 16K-16M. - maxFrameSize uint32 +// writeQueue is used by implementations of WriteScheduler. +type http2writeQueue struct { + s []http2FrameWriteRequest +} - // sq contains the stream-specific queues, keyed by stream ID. - // when a stream is idle, it's deleted from the map. - sq map[uint32]*http2writeQueue +func (q *http2writeQueue) empty() bool { return len(q.s) == 0 } - // canSend is a slice of memory that's reused between frame - // scheduling decisions to hold the list of writeQueues (from sq) - // which have enough flow control data to send. After canSend is - // built, the best is selected. - canSend []*http2writeQueue +func (q *http2writeQueue) push(wr http2FrameWriteRequest) { + q.s = append(q.s, wr) +} - // pool of empty queues for reuse. - queuePool []*http2writeQueue +func (q *http2writeQueue) shift() http2FrameWriteRequest { + if len(q.s) == 0 { + panic("invalid use of queue") + } + wr := q.s[0] + + copy(q.s, q.s[1:]) + q.s[len(q.s)-1] = http2FrameWriteRequest{} + q.s = q.s[:len(q.s)-1] + return wr } -func (ws *http2writeScheduler) putEmptyQueue(q *http2writeQueue) { - if len(q.s) != 0 { - panic("queue must be empty") +// consume consumes up to n bytes from q.s[0]. If the frame is +// entirely consumed, it is removed from the queue. If the frame +// is partially consumed, the frame is kept with the consumed +// bytes removed. 
Returns true iff any bytes were consumed. +func (q *http2writeQueue) consume(n int32) (http2FrameWriteRequest, bool) { + if len(q.s) == 0 { + return http2FrameWriteRequest{}, false } - ws.queuePool = append(ws.queuePool, q) + consumed, rest, numresult := q.s[0].Consume(n) + switch numresult { + case 0: + return http2FrameWriteRequest{}, false + case 1: + q.shift() + case 2: + q.s[0] = rest + } + return consumed, true } -func (ws *http2writeScheduler) getEmptyQueue() *http2writeQueue { - ln := len(ws.queuePool) +type http2writeQueuePool []*http2writeQueue + +// put inserts an unused writeQueue into the pool. +func (p *http2writeQueuePool) put(q *http2writeQueue) { + for i := range q.s { + q.s[i] = http2FrameWriteRequest{} + } + q.s = q.s[:0] + *p = append(*p, q) +} + +// get returns an empty writeQueue. +func (p *http2writeQueuePool) get() *http2writeQueue { + ln := len(*p) if ln == 0 { return new(http2writeQueue) } - q := ws.queuePool[ln-1] - ws.queuePool = ws.queuePool[:ln-1] + x := ln - 1 + q := (*p)[x] + (*p)[x] = nil + *p = (*p)[:x] return q } -func (ws *http2writeScheduler) empty() bool { return ws.zero.empty() && len(ws.sq) == 0 } +// RFC 7540, Section 5.3.5: the default weight is 16. +const http2priorityDefaultWeight = 15 // 16 = 15 + 1 -func (ws *http2writeScheduler) add(wm http2frameWriteMsg) { - st := wm.stream - if st == nil { - ws.zero.push(wm) +// PriorityWriteSchedulerConfig configures a priorityWriteScheduler. +type http2PriorityWriteSchedulerConfig struct { + // MaxClosedNodesInTree controls the maximum number of closed streams to + // retain in the priority tree. Setting this to zero saves a small amount + // of memory at the cost of performance. + // + // See RFC 7540, Section 5.3.4: + // "It is possible for a stream to become closed while prioritization + // information ... is in transit. ... This potentially creates suboptimal + // prioritization, since the stream could be given a priority that is + // different from what is intended. 
To avoid these problems, an endpoint + // SHOULD retain stream prioritization state for a period after streams + // become closed. The longer state is retained, the lower the chance that + // streams are assigned incorrect or default priority values." + MaxClosedNodesInTree int + + // MaxIdleNodesInTree controls the maximum number of idle streams to + // retain in the priority tree. Setting this to zero saves a small amount + // of memory at the cost of performance. + // + // See RFC 7540, Section 5.3.4: + // Similarly, streams that are in the "idle" state can be assigned + // priority or become a parent of other streams. This allows for the + // creation of a grouping node in the dependency tree, which enables + // more flexible expressions of priority. Idle streams begin with a + // default priority (Section 5.3.5). + MaxIdleNodesInTree int + + // ThrottleOutOfOrderWrites enables write throttling to help ensure that + // data is delivered in priority order. This works around a race where + // stream B depends on stream A and both streams are about to call Write + // to queue DATA frames. If B wins the race, a naive scheduler would eagerly + // write as much data from B as possible, but this is suboptimal because A + // is a higher-priority stream. With throttling enabled, we write a small + // amount of data from B to minimize the amount of bandwidth that B can + // steal from A. + ThrottleOutOfOrderWrites bool +} + +// NewPriorityWriteScheduler constructs a WriteScheduler that schedules +// frames by following HTTP/2 priorities as described in RFC 7340 Section 5.3. +// If cfg is nil, default options are used. 
+func http2NewPriorityWriteScheduler(cfg *http2PriorityWriteSchedulerConfig) http2WriteScheduler { + if cfg == nil { + + cfg = &http2PriorityWriteSchedulerConfig{ + MaxClosedNodesInTree: 10, + MaxIdleNodesInTree: 10, + ThrottleOutOfOrderWrites: false, + } + } + + ws := &http2priorityWriteScheduler{ + nodes: make(map[uint32]*http2priorityNode), + maxClosedNodesInTree: cfg.MaxClosedNodesInTree, + maxIdleNodesInTree: cfg.MaxIdleNodesInTree, + enableWriteThrottle: cfg.ThrottleOutOfOrderWrites, + } + ws.nodes[0] = &ws.root + if cfg.ThrottleOutOfOrderWrites { + ws.writeThrottleLimit = 1024 } else { - ws.streamQueue(st.id).push(wm) + ws.writeThrottleLimit = math.MaxInt32 } + return ws +} + +type http2priorityNodeState int + +const ( + http2priorityNodeOpen http2priorityNodeState = iota + http2priorityNodeClosed + http2priorityNodeIdle +) + +// priorityNode is a node in an HTTP/2 priority tree. +// Each node is associated with a single stream ID. +// See RFC 7540, Section 5.3. +type http2priorityNode struct { + q http2writeQueue // queue of pending frames to write + id uint32 // id of the stream, or 0 for the root of the tree + weight uint8 // the actual weight is weight+1, so the value is in [1,256] + state http2priorityNodeState // open | closed | idle + bytes int64 // number of bytes written by this node, or 0 if closed + subtreeBytes int64 // sum(node.bytes) of all nodes in this subtree + + // These links form the priority tree. 
+ parent *http2priorityNode + kids *http2priorityNode // start of the kids list + prev, next *http2priorityNode // doubly-linked list of siblings } -func (ws *http2writeScheduler) streamQueue(streamID uint32) *http2writeQueue { - if q, ok := ws.sq[streamID]; ok { - return q +func (n *http2priorityNode) setParent(parent *http2priorityNode) { + if n == parent { + panic("setParent to self") } - if ws.sq == nil { - ws.sq = make(map[uint32]*http2writeQueue) + if n.parent == parent { + return + } + + if parent := n.parent; parent != nil { + if n.prev == nil { + parent.kids = n.next + } else { + n.prev.next = n.next + } + if n.next != nil { + n.next.prev = n.prev + } + } + + n.parent = parent + if parent == nil { + n.next = nil + n.prev = nil + } else { + n.next = parent.kids + n.prev = nil + if n.next != nil { + n.next.prev = n + } + parent.kids = n } - q := ws.getEmptyQueue() - ws.sq[streamID] = q - return q } -// take returns the most important frame to write and removes it from the scheduler. -// It is illegal to call this if the scheduler is empty or if there are no connection-level -// flow control bytes available. -func (ws *http2writeScheduler) take() (wm http2frameWriteMsg, ok bool) { - if ws.maxFrameSize == 0 { - panic("internal error: ws.maxFrameSize not initialized or invalid") +func (n *http2priorityNode) addBytes(b int64) { + n.bytes += b + for ; n != nil; n = n.parent { + n.subtreeBytes += b } +} - if !ws.zero.empty() { - return ws.zero.shift(), true +// walkReadyInOrder iterates over the tree in priority order, calling f for each node +// with a non-empty write queue. When f returns true, this funcion returns true and the +// walk halts. tmp is used as scratch space for sorting. +// +// f(n, openParent) takes two arguments: the node to visit, n, and a bool that is true +// if any ancestor p of n is still open (ignoring the root node). 
+func (n *http2priorityNode) walkReadyInOrder(openParent bool, tmp *[]*http2priorityNode, f func(*http2priorityNode, bool) bool) bool { + if !n.q.empty() && f(n, openParent) { + return true } - if len(ws.sq) == 0 { - return + if n.kids == nil { + return false + } + + if n.id != 0 { + openParent = openParent || (n.state == http2priorityNodeOpen) } - for id, q := range ws.sq { - if q.firstIsNoCost() { - return ws.takeFrom(id, q) + w := n.kids.weight + needSort := false + for k := n.kids.next; k != nil; k = k.next { + if k.weight != w { + needSort = true + break } } + if !needSort { + for k := n.kids; k != nil; k = k.next { + if k.walkReadyInOrder(openParent, tmp, f) { + return true + } + } + return false + } - if len(ws.canSend) != 0 { - panic("should be empty") + *tmp = (*tmp)[:0] + for n.kids != nil { + *tmp = append(*tmp, n.kids) + n.kids.setParent(nil) } - for _, q := range ws.sq { - if n := ws.streamWritableBytes(q); n > 0 { - ws.canSend = append(ws.canSend, q) + sort.Sort(http2sortPriorityNodeSiblings(*tmp)) + for i := len(*tmp) - 1; i >= 0; i-- { + (*tmp)[i].setParent(n) + } + for k := n.kids; k != nil; k = k.next { + if k.walkReadyInOrder(openParent, tmp, f) { + return true } } - if len(ws.canSend) == 0 { - return + return false +} + +type http2sortPriorityNodeSiblings []*http2priorityNode + +func (z http2sortPriorityNodeSiblings) Len() int { return len(z) } + +func (z http2sortPriorityNodeSiblings) Swap(i, k int) { z[i], z[k] = z[k], z[i] } + +func (z http2sortPriorityNodeSiblings) Less(i, k int) bool { + + wi, bi := float64(z[i].weight+1), float64(z[i].subtreeBytes) + wk, bk := float64(z[k].weight+1), float64(z[k].subtreeBytes) + if bi == 0 && bk == 0 { + return wi >= wk } - defer ws.zeroCanSend() + if bk == 0 { + return false + } + return bi/bk <= wi/wk +} - q := ws.canSend[0] +type http2priorityWriteScheduler struct { + // root is the root of the priority tree, where root.id = 0. + // The root queues control frames that are not associated with any stream. 
+ root http2priorityNode - return ws.takeFrom(q.streamID(), q) + // nodes maps stream ids to priority tree nodes. + nodes map[uint32]*http2priorityNode + + // maxID is the maximum stream id in nodes. + maxID uint32 + + // lists of nodes that have been closed or are idle, but are kept in + // the tree for improved prioritization. When the lengths exceed either + // maxClosedNodesInTree or maxIdleNodesInTree, old nodes are discarded. + closedNodes, idleNodes []*http2priorityNode + + // From the config. + maxClosedNodesInTree int + maxIdleNodesInTree int + writeThrottleLimit int32 + enableWriteThrottle bool + + // tmp is scratch space for priorityNode.walkReadyInOrder to reduce allocations. + tmp []*http2priorityNode + + // pool of empty queues for reuse. + queuePool http2writeQueuePool } -// zeroCanSend is defered from take. -func (ws *http2writeScheduler) zeroCanSend() { - for i := range ws.canSend { - ws.canSend[i] = nil +func (ws *http2priorityWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) { + + if curr := ws.nodes[streamID]; curr != nil { + if curr.state != http2priorityNodeIdle { + panic(fmt.Sprintf("stream %d already opened", streamID)) + } + curr.state = http2priorityNodeOpen + return + } + + parent := ws.nodes[options.PusherID] + if parent == nil { + parent = &ws.root + } + n := &http2priorityNode{ + q: *ws.queuePool.get(), + id: streamID, + weight: http2priorityDefaultWeight, + state: http2priorityNodeOpen, + } + n.setParent(parent) + ws.nodes[streamID] = n + if streamID > ws.maxID { + ws.maxID = streamID } - ws.canSend = ws.canSend[:0] } -// streamWritableBytes returns the number of DATA bytes we could write -// from the given queue's stream, if this stream/queue were -// selected. It is an error to call this if q's head isn't a -// *writeData. 
-func (ws *http2writeScheduler) streamWritableBytes(q *http2writeQueue) int32 { - wm := q.head() - ret := wm.stream.flow.available() - if ret == 0 { - return 0 +func (ws *http2priorityWriteScheduler) CloseStream(streamID uint32) { + if streamID == 0 { + panic("violation of WriteScheduler interface: cannot close stream 0") } - if int32(ws.maxFrameSize) < ret { - ret = int32(ws.maxFrameSize) + if ws.nodes[streamID] == nil { + panic(fmt.Sprintf("violation of WriteScheduler interface: unknown stream %d", streamID)) } - if ret == 0 { - panic("internal error: ws.maxFrameSize not initialized or invalid") + if ws.nodes[streamID].state != http2priorityNodeOpen { + panic(fmt.Sprintf("violation of WriteScheduler interface: stream %d already closed", streamID)) } - wd := wm.write.(*http2writeData) - if len(wd.p) < int(ret) { - ret = int32(len(wd.p)) + + n := ws.nodes[streamID] + n.state = http2priorityNodeClosed + n.addBytes(-n.bytes) + + q := n.q + ws.queuePool.put(&q) + n.q.s = nil + if ws.maxClosedNodesInTree > 0 { + ws.addClosedOrIdleNode(&ws.closedNodes, ws.maxClosedNodesInTree, n) + } else { + ws.removeNode(n) } - return ret } -func (ws *http2writeScheduler) takeFrom(id uint32, q *http2writeQueue) (wm http2frameWriteMsg, ok bool) { - wm = q.head() - - if wd, ok := wm.write.(*http2writeData); ok && len(wd.p) > 0 { - allowed := wm.stream.flow.available() - if allowed == 0 { +func (ws *http2priorityWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) { + if streamID == 0 { + panic("adjustPriority on root") + } - return http2frameWriteMsg{}, false + n := ws.nodes[streamID] + if n == nil { + if streamID <= ws.maxID || ws.maxIdleNodesInTree == 0 { + return } - if int32(ws.maxFrameSize) < allowed { - allowed = int32(ws.maxFrameSize) + ws.maxID = streamID + n = &http2priorityNode{ + q: *ws.queuePool.get(), + id: streamID, + weight: http2priorityDefaultWeight, + state: http2priorityNodeIdle, } + n.setParent(&ws.root) + ws.nodes[streamID] = n + 
ws.addClosedOrIdleNode(&ws.idleNodes, ws.maxIdleNodesInTree, n) + } - if len(wd.p) > int(allowed) { - wm.stream.flow.take(allowed) - chunk := wd.p[:allowed] - wd.p = wd.p[allowed:] + parent := ws.nodes[priority.StreamDep] + if parent == nil { + n.setParent(&ws.root) + n.weight = http2priorityDefaultWeight + return + } - return http2frameWriteMsg{ - stream: wm.stream, - write: &http2writeData{ - streamID: wd.streamID, - p: chunk, + if n == parent { + return + } - endStream: false, - }, + for x := parent.parent; x != nil; x = x.parent { + if x == n { + parent.setParent(n.parent) + break + } + } - done: nil, - }, true + if priority.Exclusive { + k := parent.kids + for k != nil { + next := k.next + if k != n { + k.setParent(n) + } + k = next } - wm.stream.flow.take(int32(len(wd.p))) } - q.shift() - if q.empty() { - ws.putEmptyQueue(q) - delete(ws.sq, id) + n.setParent(parent) + n.weight = priority.Weight +} + +func (ws *http2priorityWriteScheduler) Push(wr http2FrameWriteRequest) { + var n *http2priorityNode + if id := wr.StreamID(); id == 0 { + n = &ws.root + } else { + n = ws.nodes[id] + if n == nil { + + if wr.DataSize() > 0 { + panic("add DATA on non-open stream") + } + n = &ws.root + } } - return wm, true + n.q.push(wr) } -func (ws *http2writeScheduler) forgetStream(id uint32) { - q, ok := ws.sq[id] - if !ok { +func (ws *http2priorityWriteScheduler) Pop() (wr http2FrameWriteRequest, ok bool) { + ws.root.walkReadyInOrder(false, &ws.tmp, func(n *http2priorityNode, openParent bool) bool { + limit := int32(math.MaxInt32) + if openParent { + limit = ws.writeThrottleLimit + } + wr, ok = n.q.consume(limit) + if !ok { + return false + } + n.addBytes(int64(wr.DataSize())) + + if openParent { + ws.writeThrottleLimit += 1024 + if ws.writeThrottleLimit < 0 { + ws.writeThrottleLimit = math.MaxInt32 + } + } else if ws.enableWriteThrottle { + ws.writeThrottleLimit = 1024 + } + return true + }) + return wr, ok +} + +func (ws *http2priorityWriteScheduler) addClosedOrIdleNode(list 
*[]*http2priorityNode, maxSize int, n *http2priorityNode) { + if maxSize == 0 { return } - delete(ws.sq, id) + if len(*list) == maxSize { - for i := range q.s { - q.s[i] = http2frameWriteMsg{} + ws.removeNode((*list)[0]) + x := (*list)[1:] + copy(*list, x) + *list = (*list)[:len(x)] } - q.s = q.s[:0] - ws.putEmptyQueue(q) + *list = append(*list, n) } -type http2writeQueue struct { - s []http2frameWriteMsg +func (ws *http2priorityWriteScheduler) removeNode(n *http2priorityNode) { + for k := n.kids; k != nil; k = k.next { + k.setParent(n.parent) + } + n.setParent(nil) + delete(ws.nodes, n.id) } -// streamID returns the stream ID for a non-empty stream-specific queue. -func (q *http2writeQueue) streamID() uint32 { return q.s[0].stream.id } +// NewRandomWriteScheduler constructs a WriteScheduler that ignores HTTP/2 +// priorities. Control frames like SETTINGS and PING are written before DATA +// frames, but if no control frames are queued and multiple streams have queued +// HEADERS or DATA frames, Pop selects a ready stream arbitrarily. +func http2NewRandomWriteScheduler() http2WriteScheduler { + return &http2randomWriteScheduler{sq: make(map[uint32]*http2writeQueue)} +} -func (q *http2writeQueue) empty() bool { return len(q.s) == 0 } +type http2randomWriteScheduler struct { + // zero are frames not associated with a specific stream. + zero http2writeQueue + + // sq contains the stream-specific queues, keyed by stream ID. + // When a stream is idle or closed, it's deleted from the map. + sq map[uint32]*http2writeQueue -func (q *http2writeQueue) push(wm http2frameWriteMsg) { - q.s = append(q.s, wm) + // pool of empty queues for reuse. + queuePool http2writeQueuePool } -// head returns the next item that would be removed by shift. 
-func (q *http2writeQueue) head() http2frameWriteMsg { - if len(q.s) == 0 { - panic("invalid use of queue") - } - return q.s[0] +func (ws *http2randomWriteScheduler) OpenStream(streamID uint32, options http2OpenStreamOptions) { + } -func (q *http2writeQueue) shift() http2frameWriteMsg { - if len(q.s) == 0 { - panic("invalid use of queue") +func (ws *http2randomWriteScheduler) CloseStream(streamID uint32) { + q, ok := ws.sq[streamID] + if !ok { + return } - wm := q.s[0] + delete(ws.sq, streamID) + ws.queuePool.put(q) +} + +func (ws *http2randomWriteScheduler) AdjustStream(streamID uint32, priority http2PriorityParam) { - copy(q.s, q.s[1:]) - q.s[len(q.s)-1] = http2frameWriteMsg{} - q.s = q.s[:len(q.s)-1] - return wm } -func (q *http2writeQueue) firstIsNoCost() bool { - if df, ok := q.s[0].write.(*http2writeData); ok { - return len(df.p) == 0 +func (ws *http2randomWriteScheduler) Push(wr http2FrameWriteRequest) { + id := wr.StreamID() + if id == 0 { + ws.zero.push(wr) + return } - return true + q, ok := ws.sq[id] + if !ok { + q = ws.queuePool.get() + ws.sq[id] = q + } + q.push(wr) +} + +func (ws *http2randomWriteScheduler) Pop() (http2FrameWriteRequest, bool) { + + if !ws.zero.empty() { + return ws.zero.shift(), true + } + + for _, q := range ws.sq { + if wr, ok := q.consume(math.MaxInt32); ok { + return wr, true + } + } + return http2FrameWriteRequest{}, false } diff --git a/libgo/go/net/http/header.go b/libgo/go/net/http/header.go index 6343165a840..832169247fe 100644 --- a/libgo/go/net/http/header.go +++ b/libgo/go/net/http/header.go @@ -32,9 +32,11 @@ func (h Header) Set(key, value string) { } // Get gets the first value associated with the given key. +// It is case insensitive; textproto.CanonicalMIMEHeaderKey is used +// to canonicalize the provided key. // If there are no values associated with the key, Get returns "". -// To access multiple values of a key, access the map directly -// with CanonicalHeaderKey. 
// aLongTimeAgo is a non-zero time, far in the past, used for
// immediate cancelation of network operations.
var aLongTimeAgo = time.Unix(233431200, 0)

// isASCII reports whether s contains only ASCII bytes (every byte < 0x80).
func isASCII(s string) bool {
	for i := 0; i < len(s); i++ {
		if s[i] >= utf8.RuneSelf {
			return false
		}
	}
	return true
}

// hexEscapeNonASCII returns s with each non-ASCII byte replaced by its
// %xx lowercase-hex escape. If s is already all-ASCII, s is returned
// unchanged without allocating.
func hexEscapeNonASCII(s string) string {
	// First pass: compute the escaped length so the result is built in
	// exactly one allocation.
	newLen := 0
	for i := 0; i < len(s); i++ {
		if s[i] >= utf8.RuneSelf {
			newLen += 3
		} else {
			newLen++
		}
	}
	if newLen == len(s) {
		return s
	}
	b := make([]byte, 0, newLen)
	for i := 0; i < len(s); i++ {
		if s[i] >= utf8.RuneSelf {
			// Bytes >= 0x80 always produce two hex digits, matching the
			// three-byte budget reserved above.
			b = append(b, '%')
			b = strconv.AppendInt(b, int64(s[i]), 16)
		} else {
			b = append(b, s[i])
		}
	}
	return string(b)
}
+var NoBody = noBody{} + +type noBody struct{} + +func (noBody) Read([]byte) (int, error) { return 0, io.EOF } +func (noBody) Close() error { return nil } +func (noBody) WriteTo(io.Writer) (int64, error) { return 0, nil } + +var ( + // verify that an io.Copy from NoBody won't require a buffer: + _ io.WriterTo = NoBody + _ io.ReadCloser = NoBody +) + +// PushOptions describes options for Pusher.Push. +type PushOptions struct { + // Method specifies the HTTP method for the promised request. + // If set, it must be "GET" or "HEAD". Empty means "GET". + Method string + + // Header specifies additional promised request headers. This cannot + // include HTTP/2 pseudo header fields like ":path" and ":scheme", + // which will be added automatically. + Header Header +} + +// Pusher is the interface implemented by ResponseWriters that support +// HTTP/2 server push. For more background, see +// https://tools.ietf.org/html/rfc7540#section-8.2. +type Pusher interface { + // Push initiates an HTTP/2 server push. This constructs a synthetic + // request using the given target and options, serializes that request + // into a PUSH_PROMISE frame, then dispatches that request using the + // server's request handler. If opts is nil, default options are used. + // + // The target must either be an absolute path (like "/path") or an absolute + // URL that contains a valid host and the same scheme as the parent request. + // If the target is a path, it will inherit the scheme and host of the + // parent request. + // + // The HTTP/2 spec disallows recursive pushes and cross-authority pushes. + // Push may or may not detect these invalid pushes; however, invalid + // pushes will be detected and canceled by conforming clients. + // + // Handlers that wish to push URL X should call Push before sending any + // data that may trigger a request for URL X. This avoids a race where the + // client issues requests for X before receiving the PUSH_PROMISE for X. 
+ // + // Push returns ErrNotSupported if the client has disabled push or if push + // is not supported on the underlying connection. + Push(target string, opts *PushOptions) error +} diff --git a/libgo/go/net/http/http_test.go b/libgo/go/net/http/http_test.go index 34da4bbb59e..8f466bb3668 100644 --- a/libgo/go/net/http/http_test.go +++ b/libgo/go/net/http/http_test.go @@ -12,8 +12,13 @@ import ( "os/exec" "reflect" "testing" + "time" ) +func init() { + shutdownPollInterval = 5 * time.Millisecond +} + func TestForeachHeaderElement(t *testing.T) { tests := []struct { in string @@ -51,6 +56,18 @@ func TestCleanHost(t *testing.T) { {"www.google.com foo", "www.google.com"}, {"www.google.com/foo", "www.google.com"}, {" first character is a space", ""}, + {"[1::6]:8080", "[1::6]:8080"}, + + // Punycode: + {"гофер.рф/foo", "xn--c1ae0ajs.xn--p1ai"}, + {"bücher.de", "xn--bcher-kva.de"}, + {"bücher.de:8080", "xn--bcher-kva.de:8080"}, + // Verify we convert to lowercase before punycode: + {"BÜCHER.de", "xn--bcher-kva.de"}, + {"BÜCHER.de:8080", "xn--bcher-kva.de:8080"}, + // Verify we normalize to NFC before punycode: + {"gophér.nfc", "xn--gophr-esa.nfc"}, // NFC input; no work needed + {"goph\u0065\u0301r.nfd", "xn--gophr-esa.nfd"}, // NFD input } for _, tt := range tests { got := cleanHost(tt.in) @@ -65,8 +82,9 @@ func TestCleanHost(t *testing.T) { // This catches accidental dependencies between the HTTP transport and // server code. 
func TestCmdGoNoHTTPServer(t *testing.T) { + t.Parallel() goBin := testenv.GoToolPath(t) - out, err := exec.Command("go", "tool", "nm", goBin).CombinedOutput() + out, err := exec.Command(goBin, "tool", "nm", goBin).CombinedOutput() if err != nil { t.Fatalf("go tool nm: %v: %s", err, out) } diff --git a/libgo/go/net/http/httptest/httptest.go b/libgo/go/net/http/httptest/httptest.go index e2148a659c1..f7202da92ff 100644 --- a/libgo/go/net/http/httptest/httptest.go +++ b/libgo/go/net/http/httptest/httptest.go @@ -35,6 +35,9 @@ import ( // // NewRequest panics on error for ease of use in testing, where a // panic is acceptable. +// +// To generate a client HTTP request instead of a server request, see +// the NewRequest function in the net/http package. func NewRequest(method, target string, body io.Reader) *http.Request { if method == "" { method = "GET" diff --git a/libgo/go/net/http/httptest/recorder.go b/libgo/go/net/http/httptest/recorder.go index 0ad26a3d418..5f1aa6af479 100644 --- a/libgo/go/net/http/httptest/recorder.go +++ b/libgo/go/net/http/httptest/recorder.go @@ -8,15 +8,33 @@ import ( "bytes" "io/ioutil" "net/http" + "strconv" + "strings" ) // ResponseRecorder is an implementation of http.ResponseWriter that // records its mutations for later inspection in tests. type ResponseRecorder struct { - Code int // the HTTP response code from WriteHeader - HeaderMap http.Header // the HTTP response headers - Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to - Flushed bool + // Code is the HTTP response code set by WriteHeader. + // + // Note that if a Handler never calls WriteHeader or Write, + // this might end up being 0, rather than the implicit + // http.StatusOK. To get the implicit value, use the Result + // method. + Code int + + // HeaderMap contains the headers explicitly set by the Handler. + // + // To get the implicit headers set by the server (such as + // automatic Content-Type), use the Result method. 
+ HeaderMap http.Header + + // Body is the buffer to which the Handler's Write calls are sent. + // If nil, the Writes are silently discarded. + Body *bytes.Buffer + + // Flushed is whether the Handler called Flush. + Flushed bool result *http.Response // cache of Result's return value snapHeader http.Header // snapshot of HeaderMap at first Write @@ -136,6 +154,9 @@ func (rw *ResponseRecorder) Flush() { // first write call, or at the time of this call, if the handler never // did a write. // +// The Response.Body is guaranteed to be non-nil and Body.Read call is +// guaranteed to not return any error other than io.EOF. +// // Result must only be called after the handler has finished running. func (rw *ResponseRecorder) Result() *http.Response { if rw.result != nil { @@ -159,6 +180,7 @@ func (rw *ResponseRecorder) Result() *http.Response { if rw.Body != nil { res.Body = ioutil.NopCloser(bytes.NewReader(rw.Body.Bytes())) } + res.ContentLength = parseContentLength(res.Header.Get("Content-Length")) if trailers, ok := rw.snapHeader["Trailer"]; ok { res.Trailer = make(http.Header, len(trailers)) @@ -181,5 +203,33 @@ func (rw *ResponseRecorder) Result() *http.Response { res.Trailer[k] = vv2 } } + for k, vv := range rw.HeaderMap { + if !strings.HasPrefix(k, http.TrailerPrefix) { + continue + } + if res.Trailer == nil { + res.Trailer = make(http.Header) + } + for _, v := range vv { + res.Trailer.Add(strings.TrimPrefix(k, http.TrailerPrefix), v) + } + } return res } + +// parseContentLength trims whitespace from s and returns -1 if no value +// is set, or the value if it's >= 0. +// +// This a modified version of same function found in net/http/transfer.go. This +// one just ignores an invalid header. 
// parseContentLength trims whitespace from cl and returns -1 if no value
// is set, or the value if it's >= 0.
//
// This is a modified version of the same function found in
// net/http/transfer.go. This one just ignores an invalid header.
func parseContentLength(cl string) int64 {
	cl = strings.TrimSpace(cl)
	if cl == "" {
		return -1
	}
	n, err := strconv.ParseInt(cl, 10, 64)
	if err != nil {
		return -1
	}
	// NOTE(review): a syntactically valid negative value is returned as-is,
	// despite the ">= 0" wording above — confirm callers never supply one.
	return n
}
nil) for _, tt := range tests { diff --git a/libgo/go/net/http/httptest/server.go b/libgo/go/net/http/httptest/server.go index e27526a937a..5e9ace591f3 100644 --- a/libgo/go/net/http/httptest/server.go +++ b/libgo/go/net/http/httptest/server.go @@ -16,7 +16,6 @@ import ( "net/http" "net/http/internal" "os" - "runtime" "sync" "time" ) @@ -114,9 +113,10 @@ func (s *Server) StartTLS() { } existingConfig := s.TLS - s.TLS = new(tls.Config) if existingConfig != nil { - *s.TLS = *existingConfig + s.TLS = existingConfig.Clone() + } else { + s.TLS = new(tls.Config) } if s.TLS.NextProtos == nil { s.TLS.NextProtos = []string{"http/1.1"} @@ -293,15 +293,6 @@ func (s *Server) closeConn(c net.Conn) { s.closeConnChan(c, nil) } // closeConnChan is like closeConn, but takes an optional channel to receive a value // when the goroutine closing c is done. func (s *Server) closeConnChan(c net.Conn, done chan<- struct{}) { - if runtime.GOOS == "plan9" { - // Go's Plan 9 net package isn't great at unblocking reads when - // their underlying TCP connections are closed. Don't trust - // that that the ConnState state machine will get to - // StateClosed. Instead, just go there directly. Plan 9 may leak - // resources if the syscall doesn't end up returning. Oh well. - s.forgetConn(c) - } - c.Close() if done != nil { done <- struct{}{} diff --git a/libgo/go/net/http/httptrace/example_test.go b/libgo/go/net/http/httptrace/example_test.go new file mode 100644 index 00000000000..27cdcdec31b --- /dev/null +++ b/libgo/go/net/http/httptrace/example_test.go @@ -0,0 +1,31 @@ +// Copyright 2016 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. 
+ +// +build ignore + +package httptrace_test + +import ( + "fmt" + "log" + "net/http" + "net/http/httptrace" +) + +func Example() { + req, _ := http.NewRequest("GET", "http://example.com", nil) + trace := &httptrace.ClientTrace{ + GotConn: func(connInfo httptrace.GotConnInfo) { + fmt.Printf("Got Conn: %+v\n", connInfo) + }, + DNSDone: func(dnsInfo httptrace.DNSDoneInfo) { + fmt.Printf("DNS Info: %+v\n", dnsInfo) + }, + } + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + _, err := http.DefaultTransport.RoundTrip(req) + if err != nil { + log.Fatal(err) + } +} diff --git a/libgo/go/net/http/httptrace/trace.go b/libgo/go/net/http/httptrace/trace.go index 6f187a7b694..ea7b38c8fc6 100644 --- a/libgo/go/net/http/httptrace/trace.go +++ b/libgo/go/net/http/httptrace/trace.go @@ -1,6 +1,6 @@ // Copyright 2016 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file.h +// license that can be found in the LICENSE file. // Package httptrace provides mechanisms to trace the events within // HTTP client requests. @@ -8,6 +8,7 @@ package httptrace import ( "context" + "crypto/tls" "internal/nettrace" "net" "reflect" @@ -65,11 +66,16 @@ func WithClientTrace(ctx context.Context, trace *ClientTrace) context.Context { return ctx } -// ClientTrace is a set of hooks to run at various stages of an HTTP -// client request. Any particular hook may be nil. Functions may be -// called concurrently from different goroutines, starting after the -// call to Transport.RoundTrip and ending either when RoundTrip -// returns an error, or when the Response.Body is closed. +// ClientTrace is a set of hooks to run at various stages of an outgoing +// HTTP request. Any particular hook may be nil. Functions may be +// called concurrently from different goroutines and some may be called +// after the request has completed or failed. 
+// +// ClientTrace currently traces a single HTTP request & response +// during a single round trip and has no hooks that span a series +// of redirected requests. +// +// See https://blog.golang.org/http-tracing for more. type ClientTrace struct { // GetConn is called before a connection is created or // retrieved from an idle pool. The hostPort is the @@ -119,6 +125,16 @@ type ClientTrace struct { // enabled, this may be called multiple times. ConnectDone func(network, addr string, err error) + // TLSHandshakeStart is called when the TLS handshake is started. When + // connecting to a HTTPS site via a HTTP proxy, the handshake happens after + // the CONNECT request is processed by the proxy. + TLSHandshakeStart func() + + // TLSHandshakeDone is called after the TLS handshake with either the + // successful handshake's connection state, or a non-nil error on handshake + // failure. + TLSHandshakeDone func(tls.ConnectionState, error) + // WroteHeaders is called after the Transport has written // the request headers. WroteHeaders func() @@ -130,7 +146,8 @@ type ClientTrace struct { Wait100Continue func() // WroteRequest is called with the result of writing the - // request and any body. + // request and any body. It may be called multiple times + // in the case of retried requests. WroteRequest func(WroteRequestInfo) } diff --git a/libgo/go/net/http/httptrace/trace_test.go b/libgo/go/net/http/httptrace/trace_test.go index c7eaed83d47..bb57ada8531 100644 --- a/libgo/go/net/http/httptrace/trace_test.go +++ b/libgo/go/net/http/httptrace/trace_test.go @@ -1,14 +1,41 @@ // Copyright 2016 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style -// license that can be found in the LICENSE file.h +// license that can be found in the LICENSE file. 
package httptrace import ( "bytes" + "context" "testing" ) +func TestWithClientTrace(t *testing.T) { + var buf bytes.Buffer + connectStart := func(b byte) func(network, addr string) { + return func(network, addr string) { + buf.WriteByte(b) + } + } + + ctx := context.Background() + oldtrace := &ClientTrace{ + ConnectStart: connectStart('O'), + } + ctx = WithClientTrace(ctx, oldtrace) + newtrace := &ClientTrace{ + ConnectStart: connectStart('N'), + } + ctx = WithClientTrace(ctx, newtrace) + trace := ContextClientTrace(ctx) + + buf.Reset() + trace.ConnectStart("net", "addr") + if got, want := buf.String(), "NO"; got != want { + t.Errorf("got %q; want %q", got, want) + } +} + func TestCompose(t *testing.T) { var buf bytes.Buffer var testNum int diff --git a/libgo/go/net/http/httputil/dump.go b/libgo/go/net/http/httputil/dump.go index 15116816328..7104c374545 100644 --- a/libgo/go/net/http/httputil/dump.go +++ b/libgo/go/net/http/httputil/dump.go @@ -18,11 +18,16 @@ import ( "time" ) -// One of the copies, say from b to r2, could be avoided by using a more -// elaborate trick where the other copy is made during Request/Response.Write. -// This would complicate things too much, given that these functions are for -// debugging only. +// drainBody reads all of b to memory and then returns two equivalent +// ReadClosers yielding the same bytes. +// +// It returns an error if the initial slurp of all bytes fails. It does not attempt +// to make the returned ReadClosers have identical error-matching behavior. func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err error) { + if b == http.NoBody { + // No copying needed. Preserve the magic sentinel meaning of NoBody. 
+ return http.NoBody, http.NoBody, nil + } var buf bytes.Buffer if _, err = buf.ReadFrom(b); err != nil { return nil, b, err diff --git a/libgo/go/net/http/httputil/dump_test.go b/libgo/go/net/http/httputil/dump_test.go index 2e980d39f8a..f881020fef7 100644 --- a/libgo/go/net/http/httputil/dump_test.go +++ b/libgo/go/net/http/httputil/dump_test.go @@ -184,6 +184,18 @@ var dumpTests = []dumpTest{ WantDump: "POST /v2/api/?login HTTP/1.1\r\n" + "Host: passport.myhost.com\r\n\r\n", }, + + // Issue 18506: make drainBody recognize NoBody. Otherwise + // this was turning into a chunked request. + { + Req: *mustNewRequest("POST", "http://example.com/foo", http.NoBody), + + WantDumpOut: "POST /foo HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go-http-client/1.1\r\n" + + "Content-Length: 0\r\n" + + "Accept-Encoding: gzip\r\n\r\n", + }, } func TestDumpRequest(t *testing.T) { diff --git a/libgo/go/net/http/httputil/persist.go b/libgo/go/net/http/httputil/persist.go index 87ddd52cd96..cbedf25ad1b 100644 --- a/libgo/go/net/http/httputil/persist.go +++ b/libgo/go/net/http/httputil/persist.go @@ -15,9 +15,14 @@ import ( ) var ( + // Deprecated: No longer used. ErrPersistEOF = &http.ProtocolError{ErrorString: "persistent connection closed"} - ErrClosed = &http.ProtocolError{ErrorString: "connection closed by user"} - ErrPipeline = &http.ProtocolError{ErrorString: "pipeline error"} + + // Deprecated: No longer used. + ErrClosed = &http.ProtocolError{ErrorString: "connection closed by user"} + + // Deprecated: No longer used. + ErrPipeline = &http.ProtocolError{ErrorString: "pipeline error"} ) // This is an API usage error - the local side is closed. 
diff --git a/libgo/go/net/http/httputil/reverseproxy.go b/libgo/go/net/http/httputil/reverseproxy.go index 49c120afde1..79c8fe27702 100644 --- a/libgo/go/net/http/httputil/reverseproxy.go +++ b/libgo/go/net/http/httputil/reverseproxy.go @@ -7,6 +7,7 @@ package httputil import ( + "context" "io" "log" "net" @@ -29,6 +30,8 @@ type ReverseProxy struct { // the request into a new request to be sent // using Transport. Its response is then copied // back to the original client unmodified. + // Director must not access the provided Request + // after returning. Director func(*http.Request) // The transport used to perform proxy requests. @@ -51,6 +54,11 @@ type ReverseProxy struct { // get byte slices for use by io.CopyBuffer when // copying HTTP response bodies. BufferPool BufferPool + + // ModifyResponse is an optional function that + // modifies the Response from the backend. + // If it returns an error, the proxy returns a StatusBadGateway error. + ModifyResponse func(*http.Response) error } // A BufferPool is an interface for getting and returning temporary @@ -120,76 +128,59 @@ var hopHeaders = []string{ "Upgrade", } -type requestCanceler interface { - CancelRequest(*http.Request) -} - -type runOnFirstRead struct { - io.Reader // optional; nil means empty body - - fn func() // Run before first Read, then set to nil -} - -func (c *runOnFirstRead) Read(bs []byte) (int, error) { - if c.fn != nil { - c.fn() - c.fn = nil - } - if c.Reader == nil { - return 0, io.EOF - } - return c.Reader.Read(bs) -} - func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { transport := p.Transport if transport == nil { transport = http.DefaultTransport } + ctx := req.Context() + if cn, ok := rw.(http.CloseNotifier); ok { + var cancel context.CancelFunc + ctx, cancel = context.WithCancel(ctx) + defer cancel() + notifyChan := cn.CloseNotify() + go func() { + select { + case <-notifyChan: + cancel() + case <-ctx.Done(): + } + }() + } + outreq := new(http.Request) 
*outreq = *req // includes shallow copies of maps, but okay - - if closeNotifier, ok := rw.(http.CloseNotifier); ok { - if requestCanceler, ok := transport.(requestCanceler); ok { - reqDone := make(chan struct{}) - defer close(reqDone) - - clientGone := closeNotifier.CloseNotify() - - outreq.Body = struct { - io.Reader - io.Closer - }{ - Reader: &runOnFirstRead{ - Reader: outreq.Body, - fn: func() { - go func() { - select { - case <-clientGone: - requestCanceler.CancelRequest(outreq) - case <-reqDone: - } - }() - }, - }, - Closer: outreq.Body, - } - } + if req.ContentLength == 0 { + outreq.Body = nil // Issue 16036: nil Body for http.Transport retries } + outreq = outreq.WithContext(ctx) p.Director(outreq) - outreq.Proto = "HTTP/1.1" - outreq.ProtoMajor = 1 - outreq.ProtoMinor = 1 outreq.Close = false - // Remove hop-by-hop headers to the backend. Especially - // important is "Connection" because we want a persistent - // connection, regardless of what the client sent to us. This - // is modifying the same underlying map from req (shallow + // We are modifying the same underlying map from req (shallow // copied above) so we only copy it if necessary. copiedHeaders := false + + // Remove hop-by-hop headers listed in the "Connection" header. + // See RFC 2616, section 14.10. + if c := outreq.Header.Get("Connection"); c != "" { + for _, f := range strings.Split(c, ",") { + if f = strings.TrimSpace(f); f != "" { + if !copiedHeaders { + outreq.Header = make(http.Header) + copyHeader(outreq.Header, req.Header) + copiedHeaders = true + } + outreq.Header.Del(f) + } + } + } + + // Remove hop-by-hop headers to the backend. Especially + // important is "Connection" because we want a persistent + // connection, regardless of what the client sent to us. 
for _, h := range hopHeaders { if outreq.Header.Get(h) != "" { if !copiedHeaders { @@ -218,16 +209,34 @@ func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) { return } + // Remove hop-by-hop headers listed in the + // "Connection" header of the response. + if c := res.Header.Get("Connection"); c != "" { + for _, f := range strings.Split(c, ",") { + if f = strings.TrimSpace(f); f != "" { + res.Header.Del(f) + } + } + } + for _, h := range hopHeaders { res.Header.Del(h) } + if p.ModifyResponse != nil { + if err := p.ModifyResponse(res); err != nil { + p.logf("http: proxy error: %v", err) + rw.WriteHeader(http.StatusBadGateway) + return + } + } + copyHeader(rw.Header(), res.Header) // The "Trailer" header isn't included in the Transport's response, // at least for *http.Transport. Build it up from Trailer. if len(res.Trailer) > 0 { - var trailerKeys []string + trailerKeys := make([]string, 0, len(res.Trailer)) for k := range res.Trailer { trailerKeys = append(trailerKeys, k) } @@ -266,12 +275,40 @@ func (p *ReverseProxy) copyResponse(dst io.Writer, src io.Reader) { if p.BufferPool != nil { buf = p.BufferPool.Get() } - io.CopyBuffer(dst, src, buf) + p.copyBuffer(dst, src, buf) if p.BufferPool != nil { p.BufferPool.Put(buf) } } +func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader, buf []byte) (int64, error) { + if len(buf) == 0 { + buf = make([]byte, 32*1024) + } + var written int64 + for { + nr, rerr := src.Read(buf) + if rerr != nil && rerr != io.EOF { + p.logf("httputil: ReverseProxy read error during body copy: %v", rerr) + } + if nr > 0 { + nw, werr := dst.Write(buf[:nr]) + if nw > 0 { + written += int64(nw) + } + if werr != nil { + return written, werr + } + if nr != nw { + return written, io.ErrShortWrite + } + } + if rerr != nil { + return written, rerr + } + } +} + func (p *ReverseProxy) logf(format string, args ...interface{}) { if p.ErrorLog != nil { p.ErrorLog.Printf(format, args...) 
diff --git a/libgo/go/net/http/httputil/reverseproxy_test.go b/libgo/go/net/http/httputil/reverseproxy_test.go index fe7cdb888f5..20c4e16bcb8 100644 --- a/libgo/go/net/http/httputil/reverseproxy_test.go +++ b/libgo/go/net/http/httputil/reverseproxy_test.go @@ -9,6 +9,8 @@ package httputil import ( "bufio" "bytes" + "errors" + "fmt" "io" "io/ioutil" "log" @@ -135,6 +137,61 @@ func TestReverseProxy(t *testing.T) { } +// Issue 16875: remove any proxied headers mentioned in the "Connection" +// header value. +func TestReverseProxyStripHeadersPresentInConnection(t *testing.T) { + const fakeConnectionToken = "X-Fake-Connection-Token" + const backendResponse = "I am the backend" + backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if c := r.Header.Get(fakeConnectionToken); c != "" { + t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) + } + if c := r.Header.Get("Upgrade"); c != "" { + t.Errorf("handler got header %q = %q; want empty", "Upgrade", c) + } + w.Header().Set("Connection", "Upgrade, "+fakeConnectionToken) + w.Header().Set("Upgrade", "should be deleted") + w.Header().Set(fakeConnectionToken, "should be deleted") + io.WriteString(w, backendResponse) + })) + defer backend.Close() + backendURL, err := url.Parse(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + proxyHandler.ServeHTTP(w, r) + if c := r.Header.Get("Upgrade"); c != "original value" { + t.Errorf("handler modified header %q = %q; want %q", "Upgrade", c, "original value") + } + })) + defer frontend.Close() + + getReq, _ := http.NewRequest("GET", frontend.URL, nil) + getReq.Header.Set("Connection", "Upgrade, "+fakeConnectionToken) + getReq.Header.Set("Upgrade", "original value") + getReq.Header.Set(fakeConnectionToken, "should be deleted") + res, err := http.DefaultClient.Do(getReq) + if 
err != nil { + t.Fatalf("Get: %v", err) + } + defer res.Body.Close() + bodyBytes, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatalf("reading body: %v", err) + } + if got, want := string(bodyBytes), backendResponse; got != want { + t.Errorf("got body %q; want %q", got, want) + } + if c := res.Header.Get("Upgrade"); c != "" { + t.Errorf("handler got header %q = %q; want empty", "Upgrade", c) + } + if c := res.Header.Get(fakeConnectionToken); c != "" { + t.Errorf("handler got header %q = %q; want empty", fakeConnectionToken, c) + } +} + func TestXForwardedFor(t *testing.T) { const prevForwardedFor = "client ip" const backendResponse = "I am the backend" @@ -260,14 +317,14 @@ func TestReverseProxyCancelation(t *testing.T) { reqInFlight := make(chan struct{}) backend := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - close(reqInFlight) + close(reqInFlight) // cause the client to cancel its request select { case <-time.After(10 * time.Second): // Note: this should only happen in broken implementations, and the // closenotify case should be instantaneous. 
- t.Log("Failed to close backend connection") - t.Fail() + t.Error("Handler never saw CloseNotify") + return case <-w.(http.CloseNotifier).CloseNotify(): } @@ -300,13 +357,13 @@ func TestReverseProxyCancelation(t *testing.T) { }() res, err := http.DefaultClient.Do(getReq) if res != nil { - t.Fatal("Non-nil response") + t.Errorf("got response %v; want nil", res.Status) } if err == nil { // This should be an error like: // Get http://127.0.0.1:58079: read tcp 127.0.0.1:58079: // use of closed network connection - t.Fatal("DefaultClient.Do() returned nil error") + t.Error("DefaultClient.Do() returned nil error; want non-nil error") } } @@ -495,3 +552,115 @@ func TestReverseProxy_Post(t *testing.T) { t.Errorf("got body %q; expected %q", g, e) } } + +type RoundTripperFunc func(*http.Request) (*http.Response, error) + +func (fn RoundTripperFunc) RoundTrip(req *http.Request) (*http.Response, error) { + return fn(req) +} + +// Issue 16036: send a Request with a nil Body when possible +func TestReverseProxy_NilBody(t *testing.T) { + backendURL, _ := url.Parse("http://fake.tld/") + proxyHandler := NewSingleHostReverseProxy(backendURL) + proxyHandler.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + proxyHandler.Transport = RoundTripperFunc(func(req *http.Request) (*http.Response, error) { + if req.Body != nil { + t.Error("Body != nil; want a nil Body") + } + return nil, errors.New("done testing the interesting part; so force a 502 Gateway error") + }) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + res, err := http.DefaultClient.Get(frontend.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.StatusCode != 502 { + t.Errorf("status code = %v; want 502 (Gateway Error)", res.Status) + } +} + +// Issue 14237. Test ModifyResponse and that an error from it +// causes the proxy to return StatusBadGateway, or StatusOK otherwise. 
+func TestReverseProxyModifyResponse(t *testing.T) { + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Add("X-Hit-Mod", fmt.Sprintf("%v", r.URL.Path == "/mod")) + })) + defer backendServer.Close() + + rpURL, _ := url.Parse(backendServer.URL) + rproxy := NewSingleHostReverseProxy(rpURL) + rproxy.ErrorLog = log.New(ioutil.Discard, "", 0) // quiet for tests + rproxy.ModifyResponse = func(resp *http.Response) error { + if resp.Header.Get("X-Hit-Mod") != "true" { + return fmt.Errorf("tried to by-pass proxy") + } + return nil + } + + frontendProxy := httptest.NewServer(rproxy) + defer frontendProxy.Close() + + tests := []struct { + url string + wantCode int + }{ + {frontendProxy.URL + "/mod", http.StatusOK}, + {frontendProxy.URL + "/schedule", http.StatusBadGateway}, + } + + for i, tt := range tests { + resp, err := http.Get(tt.url) + if err != nil { + t.Fatalf("failed to reach proxy: %v", err) + } + if g, e := resp.StatusCode, tt.wantCode; g != e { + t.Errorf("#%d: got res.StatusCode %d; expected %d", i, g, e) + } + resp.Body.Close() + } +} + +// Issue 16659: log errors from short read +func TestReverseProxy_CopyBuffer(t *testing.T) { + backendServer := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + out := "this call was relayed by the reverse proxy" + // Coerce a wrong content length to induce io.UnexpectedEOF + w.Header().Set("Content-Length", fmt.Sprintf("%d", len(out)*2)) + fmt.Fprintln(w, out) + })) + defer backendServer.Close() + + rpURL, err := url.Parse(backendServer.URL) + if err != nil { + t.Fatal(err) + } + + var proxyLog bytes.Buffer + rproxy := NewSingleHostReverseProxy(rpURL) + rproxy.ErrorLog = log.New(&proxyLog, "", log.Lshortfile) + frontendProxy := httptest.NewServer(rproxy) + defer frontendProxy.Close() + + resp, err := http.Get(frontendProxy.URL) + if err != nil { + t.Fatalf("failed to reach proxy: %v", err) + } + defer resp.Body.Close() + + if _, 
err := ioutil.ReadAll(resp.Body); err == nil { + t.Fatalf("want non-nil error") + } + expected := []string{ + "EOF", + "read", + } + for _, phrase := range expected { + if !bytes.Contains(proxyLog.Bytes(), []byte(phrase)) { + t.Errorf("expected log to contain phrase %q", phrase) + } + } +} diff --git a/libgo/go/net/http/internal/chunked.go b/libgo/go/net/http/internal/chunked.go index 2e62c00d5db..63f321d03b9 100644 --- a/libgo/go/net/http/internal/chunked.go +++ b/libgo/go/net/http/internal/chunked.go @@ -35,10 +35,11 @@ func NewChunkedReader(r io.Reader) io.Reader { } type chunkedReader struct { - r *bufio.Reader - n uint64 // unread bytes in chunk - err error - buf [2]byte + r *bufio.Reader + n uint64 // unread bytes in chunk + err error + buf [2]byte + checkEnd bool // whether need to check for \r\n chunk footer } func (cr *chunkedReader) beginChunk() { @@ -68,6 +69,21 @@ func (cr *chunkedReader) chunkHeaderAvailable() bool { func (cr *chunkedReader) Read(b []uint8) (n int, err error) { for cr.err == nil { + if cr.checkEnd { + if n > 0 && cr.r.Buffered() < 2 { + // We have some data. Return early (per the io.Reader + // contract) instead of potentially blocking while + // reading more. + break + } + if _, cr.err = io.ReadFull(cr.r, cr.buf[:2]); cr.err == nil { + if string(cr.buf[:]) != "\r\n" { + cr.err = errors.New("malformed chunked encoding") + break + } + } + cr.checkEnd = false + } if cr.n == 0 { if n > 0 && !cr.chunkHeaderAvailable() { // We've read enough. Don't potentially block @@ -92,11 +108,7 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err error) { // If we're at the end of a chunk, read the next two // bytes to verify they are "\r\n". 
if cr.n == 0 && cr.err == nil { - if _, cr.err = io.ReadFull(cr.r, cr.buf[:2]); cr.err == nil { - if cr.buf[0] != '\r' || cr.buf[1] != '\n' { - cr.err = errors.New("malformed chunked encoding") - } - } + cr.checkEnd = true } } return n, cr.err diff --git a/libgo/go/net/http/internal/chunked_test.go b/libgo/go/net/http/internal/chunked_test.go index 9abe1ab6d9d..d06716591ab 100644 --- a/libgo/go/net/http/internal/chunked_test.go +++ b/libgo/go/net/http/internal/chunked_test.go @@ -185,3 +185,30 @@ func TestChunkReadingIgnoresExtensions(t *testing.T) { t.Errorf("read %q; want %q", g, e) } } + +// Issue 17355: ChunkedReader shouldn't block waiting for more data +// if it can return something. +func TestChunkReadPartial(t *testing.T) { + pr, pw := io.Pipe() + go func() { + pw.Write([]byte("7\r\n1234567")) + }() + cr := NewChunkedReader(pr) + readBuf := make([]byte, 7) + n, err := cr.Read(readBuf) + if err != nil { + t.Fatal(err) + } + want := "1234567" + if n != 7 || string(readBuf) != want { + t.Fatalf("Read: %v %q; want %d, %q", n, readBuf[:n], len(want), want) + } + go func() { + pw.Write([]byte("xx")) + }() + _, err = cr.Read(readBuf) + if got := fmt.Sprint(err); !strings.Contains(got, "malformed") { + t.Fatalf("second read = %v; want malformed error", err) + } + +} diff --git a/libgo/go/net/http/main_test.go b/libgo/go/net/http/main_test.go index aea6e12744b..438bd2e58fd 100644 --- a/libgo/go/net/http/main_test.go +++ b/libgo/go/net/http/main_test.go @@ -6,6 +6,8 @@ package http_test import ( "fmt" + "io/ioutil" + "log" "net/http" "os" "runtime" @@ -15,6 +17,8 @@ import ( "time" ) +var quietLog = log.New(ioutil.Discard, "", 0) + func TestMain(m *testing.M) { v := m.Run() if v == 0 && goroutineLeaked() { @@ -134,3 +138,20 @@ func waitCondition(waitFor, checkEvery time.Duration, fn func() bool) bool { } return false } + +// waitErrCondition is like waitCondition but with errors instead of bools. 
+func waitErrCondition(waitFor, checkEvery time.Duration, fn func() error) error { + deadline := time.Now().Add(waitFor) + var err error + for time.Now().Before(deadline) { + if err = fn(); err == nil { + return nil + } + time.Sleep(checkEvery) + } + return err +} + +func closeClient(c *http.Client) { + c.Transport.(*http.Transport).CloseIdleConnections() +} diff --git a/libgo/go/net/http/npn_test.go b/libgo/go/net/http/npn_test.go index e2e911d3dd1..4c1f6b573df 100644 --- a/libgo/go/net/http/npn_test.go +++ b/libgo/go/net/http/npn_test.go @@ -18,6 +18,7 @@ import ( ) func TestNextProtoUpgrade(t *testing.T) { + setParallel(t) defer afterTest(t) ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "path=%s,proto=", r.URL.Path) diff --git a/libgo/go/net/http/range_test.go b/libgo/go/net/http/range_test.go index ef911af7b08..114987ed2c6 100644 --- a/libgo/go/net/http/range_test.go +++ b/libgo/go/net/http/range_test.go @@ -38,7 +38,7 @@ var ParseRangeTests = []struct { {"bytes=0-", 10, []httpRange{{0, 10}}}, {"bytes=5-", 10, []httpRange{{5, 5}}}, {"bytes=0-20", 10, []httpRange{{0, 10}}}, - {"bytes=15-,0-5", 10, nil}, + {"bytes=15-,0-5", 10, []httpRange{{0, 6}}}, {"bytes=1-2,5-", 10, []httpRange{{1, 2}, {5, 5}}}, {"bytes=-2 , 7-", 11, []httpRange{{9, 2}, {7, 4}}}, {"bytes=0-0 ,2-2, 7-", 11, []httpRange{{0, 1}, {2, 1}, {7, 4}}}, diff --git a/libgo/go/net/http/readrequest_test.go b/libgo/go/net/http/readrequest_test.go index 4bf646b0a63..28a148b9acb 100644 --- a/libgo/go/net/http/readrequest_test.go +++ b/libgo/go/net/http/readrequest_test.go @@ -25,7 +25,7 @@ type reqTest struct { } var noError = "" -var noBody = "" +var noBodyStr = "" var noTrailer Header = nil var reqTests = []reqTest{ @@ -95,7 +95,7 @@ var reqTests = []reqTest{ RequestURI: "/", }, - noBody, + noBodyStr, noTrailer, noError, }, @@ -121,7 +121,7 @@ var reqTests = []reqTest{ RequestURI: "//user@host/is/actually/a/path/", }, - noBody, + noBodyStr, noTrailer, 
noError, }, @@ -131,7 +131,7 @@ var reqTests = []reqTest{ "GET ../../../../etc/passwd HTTP/1.1\r\n" + "Host: test\r\n\r\n", nil, - noBody, + noBodyStr, noTrailer, "parse ../../../../etc/passwd: invalid URI for request", }, @@ -141,7 +141,7 @@ var reqTests = []reqTest{ "GET HTTP/1.1\r\n" + "Host: test\r\n\r\n", nil, - noBody, + noBodyStr, noTrailer, "parse : empty url", }, @@ -227,7 +227,7 @@ var reqTests = []reqTest{ RequestURI: "www.google.com:443", }, - noBody, + noBodyStr, noTrailer, noError, }, @@ -251,7 +251,7 @@ var reqTests = []reqTest{ RequestURI: "127.0.0.1:6060", }, - noBody, + noBodyStr, noTrailer, noError, }, @@ -275,7 +275,7 @@ var reqTests = []reqTest{ RequestURI: "/_goRPC_", }, - noBody, + noBodyStr, noTrailer, noError, }, @@ -299,7 +299,7 @@ var reqTests = []reqTest{ RequestURI: "*", }, - noBody, + noBodyStr, noTrailer, noError, }, @@ -323,7 +323,7 @@ var reqTests = []reqTest{ RequestURI: "*", }, - noBody, + noBodyStr, noTrailer, noError, }, @@ -350,7 +350,7 @@ var reqTests = []reqTest{ RequestURI: "/", }, - noBody, + noBodyStr, noTrailer, noError, }, @@ -376,7 +376,7 @@ var reqTests = []reqTest{ RequestURI: "/", }, - noBody, + noBodyStr, noTrailer, noError, }, @@ -397,7 +397,7 @@ var reqTests = []reqTest{ ContentLength: -1, Close: true, }, - noBody, + noBodyStr, noTrailer, noError, }, diff --git a/libgo/go/net/http/request.go b/libgo/go/net/http/request.go index dc5559282d0..fb6bb0aab58 100644 --- a/libgo/go/net/http/request.go +++ b/libgo/go/net/http/request.go @@ -18,12 +18,17 @@ import ( "io/ioutil" "mime" "mime/multipart" + "net" "net/http/httptrace" "net/textproto" "net/url" "strconv" "strings" "sync" + + "golang_org/x/net/idna" + "golang_org/x/text/unicode/norm" + "golang_org/x/text/width" ) const ( @@ -34,21 +39,40 @@ const ( // is either not present in the request or not a file field. var ErrMissingFile = errors.New("http: no such file") -// HTTP request parsing errors. +// ProtocolError represents an HTTP protocol error. 
+// +// Deprecated: Not all errors in the http package related to protocol errors +// are of type ProtocolError. type ProtocolError struct { ErrorString string } -func (err *ProtocolError) Error() string { return err.ErrorString } +func (pe *ProtocolError) Error() string { return pe.ErrorString } var ( - ErrHeaderTooLong = &ProtocolError{"header too long"} - ErrShortBody = &ProtocolError{"entity body too short"} - ErrNotSupported = &ProtocolError{"feature not supported"} - ErrUnexpectedTrailer = &ProtocolError{"trailer header without chunked transfer encoding"} + // ErrNotSupported is returned by the Push method of Pusher + // implementations to indicate that HTTP/2 Push support is not + // available. + ErrNotSupported = &ProtocolError{"feature not supported"} + + // ErrUnexpectedTrailer is returned by the Transport when a server + // replies with a Trailer header, but without a chunked reply. + ErrUnexpectedTrailer = &ProtocolError{"trailer header without chunked transfer encoding"} + + // ErrMissingBoundary is returned by Request.MultipartReader when the + // request's Content-Type does not include a "boundary" parameter. + ErrMissingBoundary = &ProtocolError{"no multipart boundary param in Content-Type"} + + // ErrNotMultipart is returned by Request.MultipartReader when the + // request's Content-Type is not multipart/form-data. + ErrNotMultipart = &ProtocolError{"request Content-Type isn't multipart/form-data"} + + // Deprecated: ErrHeaderTooLong is not used. + ErrHeaderTooLong = &ProtocolError{"header too long"} + // Deprecated: ErrShortBody is not used. + ErrShortBody = &ProtocolError{"entity body too short"} + // Deprecated: ErrMissingContentLength is not used. 
ErrMissingContentLength = &ProtocolError{"missing ContentLength in HEAD response"} - ErrNotMultipart = &ProtocolError{"request Content-Type isn't multipart/form-data"} - ErrMissingBoundary = &ProtocolError{"no multipart boundary param in Content-Type"} ) type badStringError struct { @@ -146,11 +170,20 @@ type Request struct { // Handler does not need to. Body io.ReadCloser + // GetBody defines an optional func to return a new copy of + // Body. It is used for client requests when a redirect requires + // reading the body more than once. Use of GetBody still + // requires setting Body. + // + // For server requests it is unused. + GetBody func() (io.ReadCloser, error) + // ContentLength records the length of the associated content. // The value -1 indicates that the length is unknown. // Values >= 0 indicate that the given number of bytes may // be read from Body. - // For client requests, a value of 0 means unknown if Body is not nil. + // For client requests, a value of 0 with a non-nil Body is + // also treated as unknown. ContentLength int64 // TransferEncoding lists the transfer encodings from outermost to @@ -175,11 +208,15 @@ type Request struct { // For server requests Host specifies the host on which the // URL is sought. Per RFC 2616, this is either the value of // the "Host" header or the host name given in the URL itself. - // It may be of the form "host:port". + // It may be of the form "host:port". For international domain + // names, Host may be in Punycode or Unicode form. Use + // golang.org/x/net/idna to convert it to either format if + // needed. // // For client requests Host optionally overrides the Host // header to send. If empty, the Request.Write method uses - // the value of URL.Host. + // the value of URL.Host. Host may contain an international + // domain name. 
Host string // Form contains the parsed form data, including both the URL @@ -276,8 +313,8 @@ type Request struct { // For outgoing client requests, the context controls cancelation. // // For incoming server requests, the context is canceled when the -// ServeHTTP method returns. For its associated values, see -// ServerContextKey and LocalAddrContextKey. +// client's connection closes, the request is canceled (with HTTP/2), +// or when the ServeHTTP method returns. func (r *Request) Context() context.Context { if r.ctx != nil { return r.ctx @@ -304,6 +341,18 @@ func (r *Request) ProtoAtLeast(major, minor int) bool { r.ProtoMajor == major && r.ProtoMinor >= minor } +// protoAtLeastOutgoing is like ProtoAtLeast, but is for outgoing +// requests (see issue 18407) where these fields aren't supposed to +// matter. As a minor fix for Go 1.8, at least treat (0, 0) as +// matching HTTP/1.1 or HTTP/1.0. Only HTTP/1.1 is used. +// TODO(bradfitz): ideally remove this whole method. It shouldn't be used. +func (r *Request) protoAtLeastOutgoing(major, minor int) bool { + if r.ProtoMajor == 0 && r.ProtoMinor == 0 && major == 1 && minor <= 1 { + return true + } + return r.ProtoAtLeast(major, minor) +} + // UserAgent returns the client's User-Agent, if sent in the request. func (r *Request) UserAgent() string { return r.Header.Get("User-Agent") @@ -319,6 +368,8 @@ var ErrNoCookie = errors.New("http: named cookie not present") // Cookie returns the named cookie provided in the request or // ErrNoCookie if not found. +// If multiple cookies match the given name, only one cookie will +// be returned. 
func (r *Request) Cookie(name string) (*Cookie, error) { for _, c := range readCookies(r.Header, name) { return c, nil @@ -561,6 +612,12 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai } } + if bw, ok := w.(*bufio.Writer); ok && tw.FlushHeaders { + if err := bw.Flush(); err != nil { + return err + } + } + // Write body and trailer err = tw.WriteBody(w) if err != nil { @@ -573,7 +630,24 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai return nil } -// cleanHost strips anything after '/' or ' '. +func idnaASCII(v string) (string, error) { + if isASCII(v) { + return v, nil + } + // The idna package doesn't do everything from + // https://tools.ietf.org/html/rfc5895 so we do it here. + // TODO(bradfitz): should the idna package do this instead? + v = strings.ToLower(v) + v = width.Fold.String(v) + v = norm.NFC.String(v) + return idna.ToASCII(v) +} + +// cleanHost cleans up the host sent in request's Host header. +// +// It both strips anything after '/' or ' ', and puts the value +// into Punycode form, if necessary. +// // Ideally we'd clean the Host header according to the spec: // https://tools.ietf.org/html/rfc7230#section-5.4 (Host = uri-host [ ":" port ]") // https://tools.ietf.org/html/rfc7230#section-2.7 (uri-host -> rfc3986's host) @@ -584,9 +658,21 @@ func (req *Request) write(w io.Writer, usingProxy bool, extraHeaders Header, wai // first offending character. func cleanHost(in string) string { if i := strings.IndexAny(in, " /"); i != -1 { - return in[:i] + in = in[:i] + } + host, port, err := net.SplitHostPort(in) + if err != nil { // input was just a host + a, err := idnaASCII(in) + if err != nil { + return in // garbage in, garbage out + } + return a } - return in + a, err := idnaASCII(host) + if err != nil { + return in // garbage in, garbage out + } + return net.JoinHostPort(a, port) } // removeZone removes IPv6 zone identifier from host. 
@@ -658,11 +744,17 @@ func validMethod(method string) bool { // methods Do, Post, and PostForm, and Transport.RoundTrip. // // NewRequest returns a Request suitable for use with Client.Do or -// Transport.RoundTrip. -// To create a request for use with testing a Server Handler use either -// ReadRequest or manually update the Request fields. See the Request -// type's documentation for the difference between inbound and outbound -// request fields. +// Transport.RoundTrip. To create a request for use with testing a +// Server Handler, either use the NewRequest function in the +// net/http/httptest package, use ReadRequest, or manually update the +// Request fields. See the Request type's documentation for the +// difference between inbound and outbound request fields. +// +// If body is of type *bytes.Buffer, *bytes.Reader, or +// *strings.Reader, the returned request's ContentLength is set to its +// exact value (instead of -1), GetBody is populated (so 307 and 308 +// redirects can replay the body), and Body is set to NoBody if the +// ContentLength is 0. 
func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { if method == "" { // We document that "" means "GET" for Request.Method, and people have @@ -697,10 +789,43 @@ func NewRequest(method, urlStr string, body io.Reader) (*Request, error) { switch v := body.(type) { case *bytes.Buffer: req.ContentLength = int64(v.Len()) + buf := v.Bytes() + req.GetBody = func() (io.ReadCloser, error) { + r := bytes.NewReader(buf) + return ioutil.NopCloser(r), nil + } case *bytes.Reader: req.ContentLength = int64(v.Len()) + snapshot := *v + req.GetBody = func() (io.ReadCloser, error) { + r := snapshot + return ioutil.NopCloser(&r), nil + } case *strings.Reader: req.ContentLength = int64(v.Len()) + snapshot := *v + req.GetBody = func() (io.ReadCloser, error) { + r := snapshot + return ioutil.NopCloser(&r), nil + } + default: + // This is where we'd set it to -1 (at least + // if body != NoBody) to mean unknown, but + // that broke people during the Go 1.8 testing + // period. People depend on it being 0 I + // guess. Maybe retry later. See Issue 18117. + } + // For client requests, Request.ContentLength of 0 + // means either actually 0, or unknown. The only way + // to explicitly say that the ContentLength is zero is + // to set the Body to nil. But turns out too much code + // depends on NewRequest returning a non-nil Body, + // so we use a well-known ReadCloser variable instead + // and have the http package also treat that sentinel + // variable to mean explicitly zero. + if req.GetBody != nil && req.ContentLength == 0 { + req.Body = NoBody + req.GetBody = func() (io.ReadCloser, error) { return NoBody, nil } } } @@ -1000,18 +1125,24 @@ func parsePostForm(r *Request) (vs url.Values, err error) { return } -// ParseForm parses the raw query from the URL and updates r.Form. +// ParseForm populates r.Form and r.PostForm. +// +// For all requests, ParseForm parses the raw query from the URL and updates +// r.Form. 
+// +// For POST, PUT, and PATCH requests, it also parses the request body as a form +// and puts the results into both r.PostForm and r.Form. Request body parameters +// take precedence over URL query string values in r.Form. // -// For POST or PUT requests, it also parses the request body as a form and -// put the results into both r.PostForm and r.Form. -// POST and PUT body parameters take precedence over URL query string values -// in r.Form. +// For other HTTP methods, or when the Content-Type is not +// application/x-www-form-urlencoded, the request Body is not read, and +// r.PostForm is initialized to a non-nil, empty value. // // If the request Body's size has not already been limited by MaxBytesReader, // the size is capped at 10MB. // // ParseMultipartForm calls ParseForm automatically. -// It is idempotent. +// ParseForm is idempotent. func (r *Request) ParseForm() error { var err error if r.PostForm == nil { @@ -1174,3 +1305,30 @@ func (r *Request) isReplayable() bool { } return false } + +// outgoingLength reports the Content-Length of this outgoing (Client) request. +// It maps 0 into -1 (unknown) when the Body is non-nil. +func (r *Request) outgoingLength() int64 { + if r.Body == nil || r.Body == NoBody { + return 0 + } + if r.ContentLength != 0 { + return r.ContentLength + } + return -1 +} + +// requestMethodUsuallyLacksBody reports whether the given request +// method is one that typically does not involve a request body. +// This is used by the Transport (via +// transferWriter.shouldSendChunkedRequestBody) to determine whether +// we try to test-read a byte from a non-nil Request.Body when +// Request.outgoingLength() returns -1. See the comments in +// shouldSendChunkedRequestBody. 
+func requestMethodUsuallyLacksBody(method string) bool { + switch method { + case "GET", "HEAD", "DELETE", "OPTIONS", "PROPFIND", "SEARCH": + return true + } + return false +} diff --git a/libgo/go/net/http/request_test.go b/libgo/go/net/http/request_test.go index a4c88c02915..e6748375b58 100644 --- a/libgo/go/net/http/request_test.go +++ b/libgo/go/net/http/request_test.go @@ -29,9 +29,9 @@ func TestQuery(t *testing.T) { } } -func TestPostQuery(t *testing.T) { - req, _ := NewRequest("POST", "http://www.google.com/search?q=foo&q=bar&both=x&prio=1&empty=not", - strings.NewReader("z=post&both=y&prio=2&empty=")) +func TestParseFormQuery(t *testing.T) { + req, _ := NewRequest("POST", "http://www.google.com/search?q=foo&q=bar&both=x&prio=1&orphan=nope&empty=not", + strings.NewReader("z=post&both=y&prio=2&=nokey&orphan;empty=&")) req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") if q := req.FormValue("q"); q != "foo" { @@ -55,39 +55,30 @@ func TestPostQuery(t *testing.T) { if prio := req.FormValue("prio"); prio != "2" { t.Errorf(`req.FormValue("prio") = %q, want "2" (from body)`, prio) } - if empty := req.FormValue("empty"); empty != "" { + if orphan := req.Form["orphan"]; !reflect.DeepEqual(orphan, []string{"", "nope"}) { + t.Errorf(`req.FormValue("orphan") = %q, want "" (from body)`, orphan) + } + if empty := req.Form["empty"]; !reflect.DeepEqual(empty, []string{"", "not"}) { t.Errorf(`req.FormValue("empty") = %q, want "" (from body)`, empty) } + if nokey := req.Form[""]; !reflect.DeepEqual(nokey, []string{"nokey"}) { + t.Errorf(`req.FormValue("nokey") = %q, want "nokey" (from body)`, nokey) + } } -func TestPatchQuery(t *testing.T) { - req, _ := NewRequest("PATCH", "http://www.google.com/search?q=foo&q=bar&both=x&prio=1&empty=not", - strings.NewReader("z=post&both=y&prio=2&empty=")) - req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") - - if q := req.FormValue("q"); q != "foo" { - 
t.Errorf(`req.FormValue("q") = %q, want "foo"`, q) - } - if z := req.FormValue("z"); z != "post" { - t.Errorf(`req.FormValue("z") = %q, want "post"`, z) - } - if bq, found := req.PostForm["q"]; found { - t.Errorf(`req.PostForm["q"] = %q, want no entry in map`, bq) - } - if bz := req.PostFormValue("z"); bz != "post" { - t.Errorf(`req.PostFormValue("z") = %q, want "post"`, bz) - } - if qs := req.Form["q"]; !reflect.DeepEqual(qs, []string{"foo", "bar"}) { - t.Errorf(`req.Form["q"] = %q, want ["foo", "bar"]`, qs) - } - if both := req.Form["both"]; !reflect.DeepEqual(both, []string{"y", "x"}) { - t.Errorf(`req.Form["both"] = %q, want ["y", "x"]`, both) - } - if prio := req.FormValue("prio"); prio != "2" { - t.Errorf(`req.FormValue("prio") = %q, want "2" (from body)`, prio) - } - if empty := req.FormValue("empty"); empty != "" { - t.Errorf(`req.FormValue("empty") = %q, want "" (from body)`, empty) +// Tests that we only parse the form automatically for certain methods. +func TestParseFormQueryMethods(t *testing.T) { + for _, method := range []string{"POST", "PATCH", "PUT", "FOO"} { + req, _ := NewRequest(method, "http://www.google.com/search", + strings.NewReader("foo=bar")) + req.Header.Set("Content-Type", "application/x-www-form-urlencoded; param=value") + want := "bar" + if method == "FOO" { + want = "" + } + if got := req.FormValue("foo"); got != want { + t.Errorf(`for method %s, FormValue("foo") = %q; want %q`, method, got, want) + } } } @@ -374,18 +365,68 @@ func TestFormFileOrder(t *testing.T) { var readRequestErrorTests = []struct { in string - err error + err string + + header Header }{ - {"GET / HTTP/1.1\r\nheader:foo\r\n\r\n", nil}, - {"GET / HTTP/1.1\r\nheader:foo\r\n", io.ErrUnexpectedEOF}, - {"", io.EOF}, + 0: {"GET / HTTP/1.1\r\nheader:foo\r\n\r\n", "", Header{"Header": {"foo"}}}, + 1: {"GET / HTTP/1.1\r\nheader:foo\r\n", io.ErrUnexpectedEOF.Error(), nil}, + 2: {"", io.EOF.Error(), nil}, + 3: { + in: "HEAD / HTTP/1.1\r\nContent-Length:4\r\n\r\n", + err: 
"http: method cannot contain a Content-Length", + }, + 4: { + in: "HEAD / HTTP/1.1\r\n\r\n", + header: Header{}, + }, + + // Multiple Content-Length values should either be + // deduplicated if same or reject otherwise + // See Issue 16490. + 5: { + in: "POST / HTTP/1.1\r\nContent-Length: 10\r\nContent-Length: 0\r\n\r\nGopher hey\r\n", + err: "cannot contain multiple Content-Length headers", + }, + 6: { + in: "POST / HTTP/1.1\r\nContent-Length: 10\r\nContent-Length: 6\r\n\r\nGopher\r\n", + err: "cannot contain multiple Content-Length headers", + }, + 7: { + in: "PUT / HTTP/1.1\r\nContent-Length: 6 \r\nContent-Length: 6\r\nContent-Length:6\r\n\r\nGopher\r\n", + err: "", + header: Header{"Content-Length": {"6"}}, + }, + 8: { + in: "PUT / HTTP/1.1\r\nContent-Length: 1\r\nContent-Length: 6 \r\n\r\n", + err: "cannot contain multiple Content-Length headers", + }, + 9: { + in: "POST / HTTP/1.1\r\nContent-Length:\r\nContent-Length: 3\r\n\r\n", + err: "cannot contain multiple Content-Length headers", + }, + 10: { + in: "HEAD / HTTP/1.1\r\nContent-Length:0\r\nContent-Length: 0\r\n\r\n", + header: Header{"Content-Length": {"0"}}, + }, } func TestReadRequestErrors(t *testing.T) { for i, tt := range readRequestErrorTests { - _, err := ReadRequest(bufio.NewReader(strings.NewReader(tt.in))) - if err != tt.err { - t.Errorf("%d. 
got error = %v; want %v", i, err, tt.err) + req, err := ReadRequest(bufio.NewReader(strings.NewReader(tt.in))) + if err == nil { + if tt.err != "" { + t.Errorf("#%d: got nil err; want %q", i, tt.err) + } + + if !reflect.DeepEqual(tt.header, req.Header) { + t.Errorf("#%d: gotHeader: %q wantHeader: %q", i, req.Header, tt.header) + } + continue + } + + if tt.err == "" || !strings.Contains(err.Error(), tt.err) { + t.Errorf("%d: got error = %v; want %v", i, err, tt.err) } } } @@ -456,18 +497,23 @@ func TestNewRequestContentLength(t *testing.T) { {bytes.NewReader([]byte("123")), 3}, {bytes.NewBuffer([]byte("1234")), 4}, {strings.NewReader("12345"), 5}, - // Not detected: + {strings.NewReader(""), 0}, + {NoBody, 0}, + + // Not detected. During Go 1.8 we tried to make these set to -1, but + // due to Issue 18117, we keep these returning 0, even though they're + // unknown. {struct{ io.Reader }{strings.NewReader("xyz")}, 0}, {io.NewSectionReader(strings.NewReader("x"), 0, 6), 0}, {readByte(io.NewSectionReader(strings.NewReader("xy"), 0, 6)), 0}, } - for _, tt := range tests { + for i, tt := range tests { req, err := NewRequest("POST", "http://localhost/", tt.r) if err != nil { t.Fatal(err) } if req.ContentLength != tt.want { - t.Errorf("ContentLength(%T) = %d; want %d", tt.r, req.ContentLength, tt.want) + t.Errorf("test[%d]: ContentLength(%T) = %d; want %d", i, tt.r, req.ContentLength, tt.want) } } } @@ -626,11 +672,31 @@ func TestStarRequest(t *testing.T) { if err != nil { return } + if req.ContentLength != 0 { + t.Errorf("ContentLength = %d; want 0", req.ContentLength) + } + if req.Body == nil { + t.Errorf("Body = nil; want non-nil") + } + + // Request.Write has Client semantics for Body/ContentLength, + // where ContentLength 0 means unknown if Body is non-nil, and + // thus chunking will happen unless we change semantics and + // signal that we want to serialize it as exactly zero. 
The + // only way to do that for outbound requests is with a nil + // Body: + clientReq := *req + clientReq.Body = nil + var out bytes.Buffer - if err := req.Write(&out); err != nil { + if err := clientReq.Write(&out); err != nil { t.Fatal(err) } - back, err := ReadRequest(bufio.NewReader(&out)) + + if strings.Contains(out.String(), "chunked") { + t.Error("wrote chunked request; want no body") + } + back, err := ReadRequest(bufio.NewReader(bytes.NewReader(out.Bytes()))) if err != nil { t.Fatal(err) } @@ -719,6 +785,47 @@ func TestMaxBytesReaderStickyError(t *testing.T) { } } +// verify that NewRequest sets Request.GetBody and that it works +func TestNewRequestGetBody(t *testing.T) { + tests := []struct { + r io.Reader + }{ + {r: strings.NewReader("hello")}, + {r: bytes.NewReader([]byte("hello"))}, + {r: bytes.NewBuffer([]byte("hello"))}, + } + for i, tt := range tests { + req, err := NewRequest("POST", "http://foo.tld/", tt.r) + if err != nil { + t.Errorf("test[%d]: %v", i, err) + continue + } + if req.Body == nil { + t.Errorf("test[%d]: Body = nil", i) + continue + } + if req.GetBody == nil { + t.Errorf("test[%d]: GetBody = nil", i) + continue + } + slurp1, err := ioutil.ReadAll(req.Body) + if err != nil { + t.Errorf("test[%d]: ReadAll(Body) = %v", i, err) + } + newBody, err := req.GetBody() + if err != nil { + t.Errorf("test[%d]: GetBody = %v", i, err) + } + slurp2, err := ioutil.ReadAll(newBody) + if err != nil { + t.Errorf("test[%d]: ReadAll(GetBody()) = %v", i, err) + } + if string(slurp1) != string(slurp2) { + t.Errorf("test[%d]: Body %q != GetBody %q", i, slurp1, slurp2) + } + } +} + func testMissingFile(t *testing.T, req *Request) { f, fh, err := req.FormFile("missing") if f != nil { diff --git a/libgo/go/net/http/requestwrite_test.go b/libgo/go/net/http/requestwrite_test.go index 2545f6f4c22..eb65b9f736f 100644 --- a/libgo/go/net/http/requestwrite_test.go +++ b/libgo/go/net/http/requestwrite_test.go @@ -5,14 +5,17 @@ package http import ( + "bufio" "bytes" 
"errors" "fmt" "io" "io/ioutil" + "net" "net/url" "strings" "testing" + "time" ) type reqWriteTest struct { @@ -28,7 +31,7 @@ type reqWriteTest struct { var reqWriteTests = []reqWriteTest{ // HTTP/1.1 => chunked coding; no body; no trailer - { + 0: { Req: Request{ Method: "GET", URL: &url.URL{ @@ -75,7 +78,7 @@ var reqWriteTests = []reqWriteTest{ "Proxy-Connection: keep-alive\r\n\r\n", }, // HTTP/1.1 => chunked coding; body; empty trailer - { + 1: { Req: Request{ Method: "GET", URL: &url.URL{ @@ -104,7 +107,7 @@ var reqWriteTests = []reqWriteTest{ chunk("abcdef") + chunk(""), }, // HTTP/1.1 POST => chunked coding; body; empty trailer - { + 2: { Req: Request{ Method: "POST", URL: &url.URL{ @@ -137,7 +140,7 @@ var reqWriteTests = []reqWriteTest{ }, // HTTP/1.1 POST with Content-Length, no chunking - { + 3: { Req: Request{ Method: "POST", URL: &url.URL{ @@ -172,7 +175,7 @@ var reqWriteTests = []reqWriteTest{ }, // HTTP/1.1 POST with Content-Length in headers - { + 4: { Req: Request{ Method: "POST", URL: mustParseURL("http://example.com/"), @@ -201,7 +204,7 @@ var reqWriteTests = []reqWriteTest{ }, // default to HTTP/1.1 - { + 5: { Req: Request{ Method: "GET", URL: mustParseURL("/search"), @@ -215,7 +218,7 @@ var reqWriteTests = []reqWriteTest{ }, // Request with a 0 ContentLength and a 0 byte body. - { + 6: { Req: Request{ Method: "POST", URL: mustParseURL("/"), @@ -227,9 +230,32 @@ var reqWriteTests = []reqWriteTest{ Body: func() io.ReadCloser { return ioutil.NopCloser(io.LimitReader(strings.NewReader("xx"), 0)) }, - // RFC 2616 Section 14.13 says Content-Length should be specified - // unless body is prohibited by the request method. - // Also, nginx expects it for POST and PUT. 
+ WantWrite: "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go-http-client/1.1\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n0\r\n\r\n", + + WantProxy: "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "User-Agent: Go-http-client/1.1\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n0\r\n\r\n", + }, + + // Request with a 0 ContentLength and a nil body. + 7: { + Req: Request{ + Method: "POST", + URL: mustParseURL("/"), + Host: "example.com", + ProtoMajor: 1, + ProtoMinor: 1, + ContentLength: 0, // as if unset by user + }, + + Body: func() io.ReadCloser { return nil }, + WantWrite: "POST / HTTP/1.1\r\n" + "Host: example.com\r\n" + "User-Agent: Go-http-client/1.1\r\n" + @@ -244,7 +270,7 @@ var reqWriteTests = []reqWriteTest{ }, // Request with a 0 ContentLength and a 1 byte body. - { + 8: { Req: Request{ Method: "POST", URL: mustParseURL("/"), @@ -270,7 +296,7 @@ var reqWriteTests = []reqWriteTest{ }, // Request with a ContentLength of 10 but a 5 byte body. - { + 9: { Req: Request{ Method: "POST", URL: mustParseURL("/"), @@ -284,7 +310,7 @@ var reqWriteTests = []reqWriteTest{ }, // Request with a ContentLength of 4 but an 8 byte body. - { + 10: { Req: Request{ Method: "POST", URL: mustParseURL("/"), @@ -298,7 +324,7 @@ var reqWriteTests = []reqWriteTest{ }, // Request with a 5 ContentLength and nil body. - { + 11: { Req: Request{ Method: "POST", URL: mustParseURL("/"), @@ -311,7 +337,7 @@ var reqWriteTests = []reqWriteTest{ }, // Request with a 0 ContentLength and a body with 1 byte content and an error. - { + 12: { Req: Request{ Method: "POST", URL: mustParseURL("/"), @@ -331,7 +357,7 @@ var reqWriteTests = []reqWriteTest{ }, // Request with a 0 ContentLength and a body without content and an error. - { + 13: { Req: Request{ Method: "POST", URL: mustParseURL("/"), @@ -352,7 +378,7 @@ var reqWriteTests = []reqWriteTest{ // Verify that DumpRequest preserves the HTTP version number, doesn't add a Host, // and doesn't add a User-Agent. 
- { + 14: { Req: Request{ Method: "GET", URL: mustParseURL("/foo"), @@ -373,7 +399,7 @@ var reqWriteTests = []reqWriteTest{ // an empty Host header, and don't use // Request.Header["Host"]. This is just testing that // we don't change Go 1.0 behavior. - { + 15: { Req: Request{ Method: "GET", Host: "", @@ -395,7 +421,7 @@ var reqWriteTests = []reqWriteTest{ }, // Opaque test #1 from golang.org/issue/4860 - { + 16: { Req: Request{ Method: "GET", URL: &url.URL{ @@ -414,7 +440,7 @@ var reqWriteTests = []reqWriteTest{ }, // Opaque test #2 from golang.org/issue/4860 - { + 17: { Req: Request{ Method: "GET", URL: &url.URL{ @@ -433,7 +459,7 @@ var reqWriteTests = []reqWriteTest{ }, // Testing custom case in header keys. Issue 5022. - { + 18: { Req: Request{ Method: "GET", URL: &url.URL{ @@ -457,7 +483,7 @@ var reqWriteTests = []reqWriteTest{ }, // Request with host header field; IPv6 address with zone identifier - { + 19: { Req: Request{ Method: "GET", URL: &url.URL{ @@ -472,7 +498,7 @@ var reqWriteTests = []reqWriteTest{ }, // Request with optional host header field; IPv6 address with zone identifier - { + 20: { Req: Request{ Method: "GET", URL: &url.URL{ @@ -543,6 +569,138 @@ func TestRequestWrite(t *testing.T) { } } +func TestRequestWriteTransport(t *testing.T) { + t.Parallel() + + matchSubstr := func(substr string) func(string) error { + return func(written string) error { + if !strings.Contains(written, substr) { + return fmt.Errorf("expected substring %q in request: %s", substr, written) + } + return nil + } + } + + noContentLengthOrTransferEncoding := func(req string) error { + if strings.Contains(req, "Content-Length: ") { + return fmt.Errorf("unexpected Content-Length in request: %s", req) + } + if strings.Contains(req, "Transfer-Encoding: ") { + return fmt.Errorf("unexpected Transfer-Encoding in request: %s", req) + } + return nil + } + + all := func(checks ...func(string) error) func(string) error { + return func(req string) error { + for _, c := range checks { + 
if err := c(req); err != nil { + return err + } + } + return nil + } + } + + type testCase struct { + method string + clen int64 // ContentLength + body io.ReadCloser + want func(string) error + + // optional: + init func(*testCase) + afterReqRead func() + } + + tests := []testCase{ + { + method: "GET", + want: noContentLengthOrTransferEncoding, + }, + { + method: "GET", + body: ioutil.NopCloser(strings.NewReader("")), + want: noContentLengthOrTransferEncoding, + }, + { + method: "GET", + clen: -1, + body: ioutil.NopCloser(strings.NewReader("")), + want: noContentLengthOrTransferEncoding, + }, + // A GET with a body, with explicit content length: + { + method: "GET", + clen: 7, + body: ioutil.NopCloser(strings.NewReader("foobody")), + want: all(matchSubstr("Content-Length: 7"), + matchSubstr("foobody")), + }, + // A GET with a body, sniffing the leading "f" from "foobody". + { + method: "GET", + clen: -1, + body: ioutil.NopCloser(strings.NewReader("foobody")), + want: all(matchSubstr("Transfer-Encoding: chunked"), + matchSubstr("\r\n1\r\nf\r\n"), + matchSubstr("oobody")), + }, + // But a POST request is expected to have a body, so + // no sniffing happens: + { + method: "POST", + clen: -1, + body: ioutil.NopCloser(strings.NewReader("foobody")), + want: all(matchSubstr("Transfer-Encoding: chunked"), + matchSubstr("foobody")), + }, + { + method: "POST", + clen: -1, + body: ioutil.NopCloser(strings.NewReader("")), + want: all(matchSubstr("Transfer-Encoding: chunked")), + }, + // Verify that a blocking Request.Body doesn't block forever. 
+ { + method: "GET", + clen: -1, + init: func(tt *testCase) { + pr, pw := io.Pipe() + tt.afterReqRead = func() { + pw.Close() + } + tt.body = ioutil.NopCloser(pr) + }, + want: matchSubstr("Transfer-Encoding: chunked"), + }, + } + + for i, tt := range tests { + if tt.init != nil { + tt.init(&tt) + } + req := &Request{ + Method: tt.method, + URL: &url.URL{ + Scheme: "http", + Host: "example.com", + }, + Header: make(Header), + ContentLength: tt.clen, + Body: tt.body, + } + got, err := dumpRequestOut(req, tt.afterReqRead) + if err != nil { + t.Errorf("test[%d]: %v", i, err) + continue + } + if err := tt.want(string(got)); err != nil { + t.Errorf("test[%d]: %v", i, err) + } + } +} + type closeChecker struct { io.Reader closed bool @@ -553,17 +711,19 @@ func (rc *closeChecker) Close() error { return nil } -// TestRequestWriteClosesBody tests that Request.Write does close its request.Body. +// TestRequestWriteClosesBody tests that Request.Write closes its request.Body. // It also indirectly tests NewRequest and that it doesn't wrap an existing Closer // inside a NopCloser, and that it serializes it correctly. 
func TestRequestWriteClosesBody(t *testing.T) { rc := &closeChecker{Reader: strings.NewReader("my body")} - req, _ := NewRequest("POST", "http://foo.com/", rc) - if req.ContentLength != 0 { - t.Errorf("got req.ContentLength %d, want 0", req.ContentLength) + req, err := NewRequest("POST", "http://foo.com/", rc) + if err != nil { + t.Fatal(err) } buf := new(bytes.Buffer) - req.Write(buf) + if err := req.Write(buf); err != nil { + t.Error(err) + } if !rc.closed { t.Error("body not closed after write") } @@ -571,12 +731,7 @@ func TestRequestWriteClosesBody(t *testing.T) { "Host: foo.com\r\n" + "User-Agent: Go-http-client/1.1\r\n" + "Transfer-Encoding: chunked\r\n\r\n" + - // TODO: currently we don't buffer before chunking, so we get a - // single "m" chunk before the other chunks, as this was the 1-byte - // read from our MultiReader where we stitched the Body back together - // after sniffing whether the Body was 0 bytes or not. - chunk("m") + - chunk("y body") + + chunk("my body") + chunk("") if buf.String() != expected { t.Errorf("write:\n got: %s\nwant: %s", buf.String(), expected) @@ -652,3 +807,76 @@ func TestRequestWriteError(t *testing.T) { t.Fatalf("writeCalls constant is outdated in test") } } + +// dumpRequestOut is a modified copy of net/http/httputil.DumpRequestOut. +// Unlike the original, this version doesn't mutate the req.Body and +// try to restore it. It always dumps the whole body. +// And it doesn't support https. +func dumpRequestOut(req *Request, onReadHeaders func()) ([]byte, error) { + + // Use the actual Transport code to record what we would send + // on the wire, but not using TCP. Use a Transport with a + // custom dialer that returns a fake net.Conn that waits + // for the full input (and recording it), and then responds + // with a dummy response. 
+ var buf bytes.Buffer // records the output + pr, pw := io.Pipe() + defer pr.Close() + defer pw.Close() + dr := &delegateReader{c: make(chan io.Reader)} + + t := &Transport{ + Dial: func(net, addr string) (net.Conn, error) { + return &dumpConn{io.MultiWriter(&buf, pw), dr}, nil + }, + } + defer t.CloseIdleConnections() + + // Wait for the request before replying with a dummy response: + go func() { + req, err := ReadRequest(bufio.NewReader(pr)) + if err == nil { + if onReadHeaders != nil { + onReadHeaders() + } + // Ensure all the body is read; otherwise + // we'll get a partial dump. + io.Copy(ioutil.Discard, req.Body) + req.Body.Close() + } + dr.c <- strings.NewReader("HTTP/1.1 204 No Content\r\nConnection: close\r\n\r\n") + }() + + _, err := t.RoundTrip(req) + if err != nil { + return nil, err + } + return buf.Bytes(), nil +} + +// delegateReader is a reader that delegates to another reader, +// once it arrives on a channel. +type delegateReader struct { + c chan io.Reader + r io.Reader // nil until received from c +} + +func (r *delegateReader) Read(p []byte) (int, error) { + if r.r == nil { + r.r = <-r.c + } + return r.r.Read(p) +} + +// dumpConn is a net.Conn that writes to Writer and reads from Reader. 
+type dumpConn struct { + io.Writer + io.Reader +} + +func (c *dumpConn) Close() error { return nil } +func (c *dumpConn) LocalAddr() net.Addr { return nil } +func (c *dumpConn) RemoteAddr() net.Addr { return nil } +func (c *dumpConn) SetDeadline(t time.Time) error { return nil } +func (c *dumpConn) SetReadDeadline(t time.Time) error { return nil } +func (c *dumpConn) SetWriteDeadline(t time.Time) error { return nil } diff --git a/libgo/go/net/http/response.go b/libgo/go/net/http/response.go index 5450d50c3ce..ae118fb386d 100644 --- a/libgo/go/net/http/response.go +++ b/libgo/go/net/http/response.go @@ -261,7 +261,7 @@ func (r *Response) Write(w io.Writer) error { if n == 0 { // Reset it to a known zero reader, in case underlying one // is unhappy being read repeatedly. - r1.Body = eofReader + r1.Body = NoBody } else { r1.ContentLength = -1 r1.Body = struct { @@ -300,7 +300,7 @@ func (r *Response) Write(w io.Writer) error { // contentLengthAlreadySent may have been already sent for // POST/PUT requests, even if zero length. See Issue 8180. contentLengthAlreadySent := tw.shouldSendContentLength() - if r1.ContentLength == 0 && !chunked(r1.TransferEncoding) && !contentLengthAlreadySent { + if r1.ContentLength == 0 && !chunked(r1.TransferEncoding) && !contentLengthAlreadySent && bodyAllowedForStatus(r.StatusCode) { if _, err := io.WriteString(w, "Content-Length: 0\r\n"); err != nil { return err } diff --git a/libgo/go/net/http/response_test.go b/libgo/go/net/http/response_test.go index 126da927355..660d51791b7 100644 --- a/libgo/go/net/http/response_test.go +++ b/libgo/go/net/http/response_test.go @@ -589,6 +589,7 @@ var readResponseCloseInMiddleTests = []struct { // reading only part of its contents advances the read to the end of // the request, right up until the next request. 
func TestReadResponseCloseInMiddle(t *testing.T) { + t.Parallel() for _, test := range readResponseCloseInMiddleTests { fatalf := func(format string, args ...interface{}) { args = append([]interface{}{test.chunked, test.compressed}, args...) @@ -792,6 +793,7 @@ func TestReadResponseErrors(t *testing.T) { type testCase struct { name string // optional, defaults to in in string + header Header wantErr interface{} // nil, err value, or string substring } @@ -817,11 +819,22 @@ func TestReadResponseErrors(t *testing.T) { } } + contentLength := func(status, body string, wantErr interface{}, header Header) testCase { + return testCase{ + name: fmt.Sprintf("status %q %q", status, body), + in: fmt.Sprintf("HTTP/1.1 %s\r\n%s", status, body), + wantErr: wantErr, + header: header, + } + } + + errMultiCL := "message cannot contain multiple Content-Length headers" + tests := []testCase{ - {"", "", io.ErrUnexpectedEOF}, - {"", "HTTP/1.1 301 Moved Permanently\r\nFoo: bar", io.ErrUnexpectedEOF}, - {"", "HTTP/1.1", "malformed HTTP response"}, - {"", "HTTP/2.0", "malformed HTTP response"}, + {"", "", nil, io.ErrUnexpectedEOF}, + {"", "HTTP/1.1 301 Moved Permanently\r\nFoo: bar", nil, io.ErrUnexpectedEOF}, + {"", "HTTP/1.1", nil, "malformed HTTP response"}, + {"", "HTTP/2.0", nil, "malformed HTTP response"}, status("20X Unknown", true), status("abcd Unknown", true), status("二百/两百 OK", true), @@ -846,7 +859,21 @@ func TestReadResponseErrors(t *testing.T) { version("HTTP/A.B", true), version("HTTP/1", true), version("http/1.1", true), + + contentLength("200 OK", "Content-Length: 10\r\nContent-Length: 7\r\n\r\nGopher hey\r\n", errMultiCL, nil), + contentLength("200 OK", "Content-Length: 7\r\nContent-Length: 7\r\n\r\nGophers\r\n", nil, Header{"Content-Length": {"7"}}), + contentLength("201 OK", "Content-Length: 0\r\nContent-Length: 7\r\n\r\nGophers\r\n", errMultiCL, nil), + contentLength("300 OK", "Content-Length: 0\r\nContent-Length: 0 \r\n\r\nGophers\r\n", nil, Header{"Content-Length": 
{"0"}}), + contentLength("200 OK", "Content-Length:\r\nContent-Length:\r\n\r\nGophers\r\n", nil, nil), + contentLength("206 OK", "Content-Length:\r\nContent-Length: 0 \r\nConnection: close\r\n\r\nGophers\r\n", errMultiCL, nil), + + // multiple content-length headers for 204 and 304 should still be checked + contentLength("204 OK", "Content-Length: 7\r\nContent-Length: 8\r\n\r\n", errMultiCL, nil), + contentLength("204 OK", "Content-Length: 3\r\nContent-Length: 3\r\n\r\n", nil, nil), + contentLength("304 OK", "Content-Length: 880\r\nContent-Length: 1\r\n\r\n", errMultiCL, nil), + contentLength("304 OK", "Content-Length: 961\r\nContent-Length: 961\r\n\r\n", nil, nil), } + for i, tt := range tests { br := bufio.NewReader(strings.NewReader(tt.in)) _, rerr := ReadResponse(br, nil) diff --git a/libgo/go/net/http/responsewrite_test.go b/libgo/go/net/http/responsewrite_test.go index 90f6767d96b..d41d89896ef 100644 --- a/libgo/go/net/http/responsewrite_test.go +++ b/libgo/go/net/http/responsewrite_test.go @@ -241,7 +241,8 @@ func TestResponseWrite(t *testing.T) { "HTTP/1.0 007 license to violate specs\r\nContent-Length: 0\r\n\r\n", }, - // No stutter. + // No stutter. Status code in 1xx range response should + // not include a Content-Length header. See issue #16942. { Response{ StatusCode: 123, @@ -253,7 +254,23 @@ func TestResponseWrite(t *testing.T) { Body: nil, }, - "HTTP/1.0 123 Sesame Street\r\nContent-Length: 0\r\n\r\n", + "HTTP/1.0 123 Sesame Street\r\n\r\n", + }, + + // Status code 204 (No content) response should not include a + // Content-Length header. See issue #16942. 
+ { + Response{ + StatusCode: 204, + Status: "No Content", + ProtoMajor: 1, + ProtoMinor: 0, + Request: dummyReq("GET"), + Header: Header{}, + Body: nil, + }, + + "HTTP/1.0 204 No Content\r\n\r\n", }, } diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go index 13e5f283e4c..072da2552bc 100644 --- a/libgo/go/net/http/serve_test.go +++ b/libgo/go/net/http/serve_test.go @@ -156,6 +156,7 @@ func (ht handlerTest) rawResponse(req string) string { } func TestConsumingBodyOnNextConn(t *testing.T) { + t.Parallel() defer afterTest(t) conn := new(testConn) for i := 0; i < 2; i++ { @@ -237,6 +238,7 @@ var vtests = []struct { } func TestHostHandlers(t *testing.T) { + setParallel(t) defer afterTest(t) mux := NewServeMux() for _, h := range handlers { @@ -353,6 +355,7 @@ var serveMuxTests = []struct { } func TestServeMuxHandler(t *testing.T) { + setParallel(t) mux := NewServeMux() for _, e := range serveMuxRegister { mux.Handle(e.pattern, e.h) @@ -390,15 +393,16 @@ var serveMuxTests2 = []struct { // TestServeMuxHandlerRedirects tests that automatic redirects generated by // mux.Handler() shouldn't clear the request's query string. func TestServeMuxHandlerRedirects(t *testing.T) { + setParallel(t) mux := NewServeMux() for _, e := range serveMuxRegister { mux.Handle(e.pattern, e.h) } for _, tt := range serveMuxTests2 { - tries := 1 + tries := 1 // expect at most 1 redirection if redirOk is true. 
turl := tt.url - for tries > 0 { + for { u, e := url.Parse(turl) if e != nil { t.Fatal(e) @@ -432,6 +436,7 @@ func TestServeMuxHandlerRedirects(t *testing.T) { // Tests for https://golang.org/issue/900 func TestMuxRedirectLeadingSlashes(t *testing.T) { + setParallel(t) paths := []string{"//foo.txt", "///foo.txt", "/../../foo.txt"} for _, path := range paths { req, err := ReadRequest(bufio.NewReader(strings.NewReader("GET " + path + " HTTP/1.1\r\nHost: test\r\n\r\n"))) @@ -456,9 +461,6 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { } func TestServerTimeouts(t *testing.T) { - if runtime.GOOS == "plan9" { - t.Skip("skipping test; see https://golang.org/issue/7237") - } setParallel(t) defer afterTest(t) reqNum := 0 @@ -479,11 +481,11 @@ func TestServerTimeouts(t *testing.T) { if err != nil { t.Fatalf("http Get #1: %v", err) } - got, _ := ioutil.ReadAll(r.Body) + got, err := ioutil.ReadAll(r.Body) expected := "req=1" - if string(got) != expected { - t.Errorf("Unexpected response for request #1; got %q; expected %q", - string(got), expected) + if string(got) != expected || err != nil { + t.Errorf("Unexpected response for request #1; got %q ,%v; expected %q, nil", + string(got), err, expected) } // Slow client that should timeout. @@ -494,6 +496,7 @@ func TestServerTimeouts(t *testing.T) { } buf := make([]byte, 1) n, err := conn.Read(buf) + conn.Close() latency := time.Since(t1) if n != 0 || err != io.EOF { t.Errorf("Read = %v, %v, wanted %v, %v", n, err, 0, io.EOF) @@ -505,14 +508,14 @@ func TestServerTimeouts(t *testing.T) { // Hit the HTTP server successfully again, verifying that the // previous slow connection didn't run our handler. 
(that we // get "req=2", not "req=3") - r, err = Get(ts.URL) + r, err = c.Get(ts.URL) if err != nil { t.Fatalf("http Get #2: %v", err) } - got, _ = ioutil.ReadAll(r.Body) + got, err = ioutil.ReadAll(r.Body) expected = "req=2" - if string(got) != expected { - t.Errorf("Get #2 got %q, want %q", string(got), expected) + if string(got) != expected || err != nil { + t.Errorf("Get #2 got %q, %v, want %q, nil", string(got), err, expected) } if !testing.Short() { @@ -532,13 +535,61 @@ func TestServerTimeouts(t *testing.T) { } } +// Test that the HTTP/2 server handles Server.WriteTimeout (Issue 18437) +func TestHTTP2WriteDeadlineExtendedOnNewRequest(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + setParallel(t) + defer afterTest(t) + ts := httptest.NewUnstartedServer(HandlerFunc(func(res ResponseWriter, req *Request) {})) + ts.Config.WriteTimeout = 250 * time.Millisecond + ts.TLS = &tls.Config{NextProtos: []string{"h2"}} + ts.StartTLS() + defer ts.Close() + + tr := newTLSTransport(t, ts) + defer tr.CloseIdleConnections() + if err := ExportHttp2ConfigureTransport(tr); err != nil { + t.Fatal(err) + } + c := &Client{Transport: tr} + + for i := 1; i <= 3; i++ { + req, err := NewRequest("GET", ts.URL, nil) + if err != nil { + t.Fatal(err) + } + + // fail test if no response after 1 second + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + req = req.WithContext(ctx) + + r, err := c.Do(req) + select { + case <-ctx.Done(): + if ctx.Err() == context.DeadlineExceeded { + t.Fatalf("http2 Get #%d response timed out", i) + } + default: + } + if err != nil { + t.Fatalf("http2 Get #%d: %v", i, err) + } + r.Body.Close() + if r.ProtoMajor != 2 { + t.Fatalf("http2 Get expected HTTP/2.0, got %q", r.Proto) + } + time.Sleep(ts.Config.WriteTimeout / 2) + } +} + // golang.org/issue/4741 -- setting only a write timeout that triggers // shouldn't cause a handler to block forever on reads (next HTTP // request) that will never 
happen. func TestOnlyWriteTimeout(t *testing.T) { - if runtime.GOOS == "plan9" { - t.Skip("skipping test; see https://golang.org/issue/7237") - } + setParallel(t) defer afterTest(t) var conn net.Conn var afterTimeoutErrc = make(chan error, 1) @@ -598,6 +649,7 @@ func (l trackLastConnListener) Accept() (c net.Conn, err error) { // TestIdentityResponse verifies that a handler can unset func TestIdentityResponse(t *testing.T) { + setParallel(t) defer afterTest(t) handler := HandlerFunc(func(rw ResponseWriter, req *Request) { rw.Header().Set("Content-Length", "3") @@ -619,13 +671,16 @@ func TestIdentityResponse(t *testing.T) { ts := httptest.NewServer(handler) defer ts.Close() + c := &Client{Transport: new(Transport)} + defer closeClient(c) + // Note: this relies on the assumption (which is true) that // Get sends HTTP/1.1 or greater requests. Otherwise the // server wouldn't have the choice to send back chunked // responses. for _, te := range []string{"", "identity"} { url := ts.URL + "/?te=" + te - res, err := Get(url) + res, err := c.Get(url) if err != nil { t.Fatalf("error with Get of %s: %v", url, err) } @@ -644,7 +699,7 @@ func TestIdentityResponse(t *testing.T) { // Verify that ErrContentLength is returned url := ts.URL + "/?overwrite=1" - res, err := Get(url) + res, err := c.Get(url) if err != nil { t.Fatalf("error with Get of %s: %v", url, err) } @@ -674,6 +729,7 @@ func TestIdentityResponse(t *testing.T) { } func testTCPConnectionCloses(t *testing.T, req string, h Handler) { + setParallel(t) defer afterTest(t) s := httptest.NewServer(h) defer s.Close() @@ -717,6 +773,7 @@ func testTCPConnectionCloses(t *testing.T, req string, h Handler) { } func testTCPConnectionStaysOpen(t *testing.T, req string, handler Handler) { + setParallel(t) defer afterTest(t) ts := httptest.NewServer(handler) defer ts.Close() @@ -750,7 +807,7 @@ func TestServeHTTP10Close(t *testing.T) { // TestClientCanClose verifies that clients can also force a connection to close. 
func TestClientCanClose(t *testing.T) { - testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\nConnection: close\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { // Nothing. })) } @@ -758,7 +815,7 @@ func TestClientCanClose(t *testing.T) { // TestHandlersCanSetConnectionClose verifies that handlers can force a connection to close, // even for HTTP/1.1 requests. func TestHandlersCanSetConnectionClose11(t *testing.T) { - testTCPConnectionCloses(t, "GET / HTTP/1.1\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { + testTCPConnectionCloses(t, "GET / HTTP/1.1\r\nHost: foo\r\n\r\n\r\n", HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Connection", "close") })) } @@ -796,6 +853,7 @@ func TestHTTP10KeepAlive304Response(t *testing.T) { // Issue 15703 func TestKeepAliveFinalChunkWithEOF(t *testing.T) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, false /* h1 */, HandlerFunc(func(w ResponseWriter, r *Request) { w.(Flusher).Flush() // force chunked encoding @@ -828,6 +886,7 @@ func TestSetsRemoteAddr_h1(t *testing.T) { testSetsRemoteAddr(t, h1Mode) } func TestSetsRemoteAddr_h2(t *testing.T) { testSetsRemoteAddr(t, h2Mode) } func testSetsRemoteAddr(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "%s", r.RemoteAddr) @@ -877,6 +936,7 @@ func (c *blockingRemoteAddrConn) RemoteAddr() net.Addr { // Issue 12943 func TestServerAllowsBlockingRemoteAddr(t *testing.T) { + setParallel(t) defer afterTest(t) ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { fmt.Fprintf(w, "RA:%s", r.RemoteAddr) @@ -948,7 +1008,9 @@ func TestServerAllowsBlockingRemoteAddr(t *testing.T) { t.Fatalf("response 1 addr = %q; want %q", g, e) } } + func TestIdentityResponseHeaders(t 
*testing.T) { + // Not parallel; changes log output. defer afterTest(t) log.SetOutput(ioutil.Discard) // is noisy otherwise defer log.SetOutput(os.Stderr) @@ -960,7 +1022,10 @@ func TestIdentityResponseHeaders(t *testing.T) { })) defer ts.Close() - res, err := Get(ts.URL) + c := &Client{Transport: new(Transport)} + defer closeClient(c) + + res, err := c.Get(ts.URL) if err != nil { t.Fatalf("Get error: %v", err) } @@ -983,6 +1048,7 @@ func TestHeadResponses_h1(t *testing.T) { testHeadResponses(t, h1Mode) } func TestHeadResponses_h2(t *testing.T) { testHeadResponses(t, h2Mode) } func testHeadResponses(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { _, err := w.Write([]byte("<html>")) @@ -1020,9 +1086,6 @@ func testHeadResponses(t *testing.T, h2 bool) { } func TestTLSHandshakeTimeout(t *testing.T) { - if runtime.GOOS == "plan9" { - t.Skip("skipping test; see https://golang.org/issue/7237") - } setParallel(t) defer afterTest(t) ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) @@ -1054,6 +1117,7 @@ func TestTLSHandshakeTimeout(t *testing.T) { } func TestTLSServer(t *testing.T) { + setParallel(t) defer afterTest(t) ts := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.TLS != nil { @@ -1121,6 +1185,7 @@ func TestAutomaticHTTP2_Serve_H2TLSConfig(t *testing.T) { } func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) { + setParallel(t) defer afterTest(t) ln := newLocalListener(t) ln.Close() // immediately (not a defer!) @@ -1136,6 +1201,7 @@ func testAutomaticHTTP2_Serve(t *testing.T, tlsConf *tls.Config, wantH2 bool) { } func TestAutomaticHTTP2_Serve_WithTLSConfig(t *testing.T) { + setParallel(t) defer afterTest(t) ln := newLocalListener(t) ln.Close() // immediately (not a defer!) 
@@ -1177,6 +1243,7 @@ func TestAutomaticHTTP2_ListenAndServe_GetCertificate(t *testing.T) { } func testAutomaticHTTP2_ListenAndServe(t *testing.T, tlsConf *tls.Config) { + // Not parallel: uses global test hooks. defer afterTest(t) defer SetTestHookServerServe(nil) var ok bool @@ -1280,6 +1347,7 @@ var serverExpectTests = []serverExpectTest{ // correctly. // http2 test: TestServer_Response_Automatic100Continue func TestServerExpect(t *testing.T) { + setParallel(t) defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { // Note using r.FormValue("readbody") because for POST @@ -1373,6 +1441,7 @@ func TestServerExpect(t *testing.T) { // Under a ~256KB (maxPostHandlerReadBytes) threshold, the server // should consume client request bodies that a handler didn't read. func TestServerUnreadRequestBodyLittle(t *testing.T) { + setParallel(t) defer afterTest(t) conn := new(testConn) body := strings.Repeat("x", 100<<10) @@ -1413,6 +1482,7 @@ func TestServerUnreadRequestBodyLittle(t *testing.T) { // should ignore client request bodies that a handler didn't read // and close the connection. 
func TestServerUnreadRequestBodyLarge(t *testing.T) { + setParallel(t) if testing.Short() && testenv.Builder() == "" { t.Log("skipping in short mode") } @@ -1546,6 +1616,7 @@ var handlerBodyCloseTests = [...]handlerBodyCloseTest{ } func TestHandlerBodyClose(t *testing.T) { + setParallel(t) if testing.Short() && testenv.Builder() == "" { t.Skip("skipping in -short mode") } @@ -1625,6 +1696,7 @@ var testHandlerBodyConsumers = []testHandlerBodyConsumer{ } func TestRequestBodyReadErrorClosesConnection(t *testing.T) { + setParallel(t) defer afterTest(t) for _, handler := range testHandlerBodyConsumers { conn := new(testConn) @@ -1655,6 +1727,7 @@ func TestRequestBodyReadErrorClosesConnection(t *testing.T) { } func TestInvalidTrailerClosesConnection(t *testing.T) { + setParallel(t) defer afterTest(t) for _, handler := range testHandlerBodyConsumers { conn := new(testConn) @@ -1737,7 +1810,7 @@ restart: if !c.rd.IsZero() { // If the deadline falls in the middle of our sleep window, deduct // part of the sleep, then return a timeout. - if remaining := c.rd.Sub(time.Now()); remaining < cue { + if remaining := time.Until(c.rd); remaining < cue { c.script[0] = cue - remaining time.Sleep(remaining) return 0, syscall.ETIMEDOUT @@ -1823,6 +1896,7 @@ func TestRequestBodyTimeoutClosesConnection(t *testing.T) { func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) } func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) } func testTimeoutHandler(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) sendHi := make(chan bool, 1) writeErrors := make(chan error, 1) @@ -1876,6 +1950,7 @@ func testTimeoutHandler(t *testing.T, h2 bool) { // See issues 8209 and 8414. 
func TestTimeoutHandlerRace(t *testing.T) { + setParallel(t) defer afterTest(t) delayHi := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -1892,6 +1967,9 @@ func TestTimeoutHandlerRace(t *testing.T) { ts := httptest.NewServer(TimeoutHandler(delayHi, 20*time.Millisecond, "")) defer ts.Close() + c := &Client{Transport: new(Transport)} + defer closeClient(c) + var wg sync.WaitGroup gate := make(chan bool, 10) n := 50 @@ -1905,7 +1983,7 @@ func TestTimeoutHandlerRace(t *testing.T) { go func() { defer wg.Done() defer func() { <-gate }() - res, err := Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50))) + res, err := c.Get(fmt.Sprintf("%s/%d", ts.URL, rand.Intn(50))) if err == nil { io.Copy(ioutil.Discard, res.Body) res.Body.Close() @@ -1917,6 +1995,7 @@ func TestTimeoutHandlerRace(t *testing.T) { // See issues 8209 and 8414. func TestTimeoutHandlerRaceHeader(t *testing.T) { + setParallel(t) defer afterTest(t) delay204 := HandlerFunc(func(w ResponseWriter, r *Request) { @@ -1932,13 +2011,15 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) { if testing.Short() { n = 10 } + c := &Client{Transport: new(Transport)} + defer closeClient(c) for i := 0; i < n; i++ { gate <- true wg.Add(1) go func() { defer wg.Done() defer func() { <-gate }() - res, err := Get(ts.URL) + res, err := c.Get(ts.URL) if err != nil { t.Error(err) return @@ -1952,6 +2033,7 @@ func TestTimeoutHandlerRaceHeader(t *testing.T) { // Issue 9162 func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) { + setParallel(t) defer afterTest(t) sendHi := make(chan bool, 1) writeErrors := make(chan error, 1) @@ -2016,11 +2098,15 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { timeout := 300 * time.Millisecond ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) defer ts.Close() + + c := &Client{Transport: new(Transport)} + defer closeClient(c) + // Issue was caused by the timeout handler starting the timer when // was created, not when the request. 
So wait for more than the timeout // to ensure that's not the case. time.Sleep(2 * timeout) - res, err := Get(ts.URL) + res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) } @@ -2032,6 +2118,7 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) { // https://golang.org/issue/15948 func TestTimeoutHandlerEmptyResponse(t *testing.T) { + setParallel(t) defer afterTest(t) var handler HandlerFunc = func(w ResponseWriter, _ *Request) { // No response. @@ -2040,7 +2127,10 @@ func TestTimeoutHandlerEmptyResponse(t *testing.T) { ts := httptest.NewServer(TimeoutHandler(handler, timeout, "")) defer ts.Close() - res, err := Get(ts.URL) + c := &Client{Transport: new(Transport)} + defer closeClient(c) + + res, err := c.Get(ts.URL) if err != nil { t.Fatal(err) } @@ -2050,23 +2140,6 @@ func TestTimeoutHandlerEmptyResponse(t *testing.T) { } } -// Verifies we don't path.Clean() on the wrong parts in redirects. -func TestRedirectMunging(t *testing.T) { - req, _ := NewRequest("GET", "http://example.com/", nil) - - resp := httptest.NewRecorder() - Redirect(resp, req, "/foo?next=http://bar.com/", 302) - if g, e := resp.Header().Get("Location"), "/foo?next=http://bar.com/"; g != e { - t.Errorf("Location header was %q; want %q", g, e) - } - - resp = httptest.NewRecorder() - Redirect(resp, req, "http://localhost:8080/_ah/login?continue=http://localhost:8080/", 302) - if g, e := resp.Header().Get("Location"), "http://localhost:8080/_ah/login?continue=http://localhost:8080/"; g != e { - t.Errorf("Location header was %q; want %q", g, e) - } -} - func TestRedirectBadPath(t *testing.T) { // This used to crash. It's not valid input (bad path), but it // shouldn't crash. 
@@ -2085,7 +2158,7 @@ func TestRedirectBadPath(t *testing.T) { } // Test different URL formats and schemes -func TestRedirectURLFormat(t *testing.T) { +func TestRedirect(t *testing.T) { req, _ := NewRequest("GET", "http://example.com/qux/", nil) var tests = []struct { @@ -2108,6 +2181,14 @@ func TestRedirectURLFormat(t *testing.T) { {"../quux/foobar.com/baz", "/quux/foobar.com/baz"}, // incorrect number of slashes {"///foobar.com/baz", "/foobar.com/baz"}, + + // Verifies we don't path.Clean() on the wrong parts in redirects: + {"/foo?next=http://bar.com/", "/foo?next=http://bar.com/"}, + {"http://localhost:8080/_ah/login?continue=http://localhost:8080/", + "http://localhost:8080/_ah/login?continue=http://localhost:8080/"}, + + {"/фубар", "/%d1%84%d1%83%d0%b1%d0%b0%d1%80"}, + {"http://foo.com/фубар", "http://foo.com/%d1%84%d1%83%d0%b1%d0%b0%d1%80"}, } for _, tt := range tests { @@ -2133,6 +2214,7 @@ func TestZeroLengthPostAndResponse_h2(t *testing.T) { } func testZeroLengthPostAndResponse(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, r *Request) { all, err := ioutil.ReadAll(r.Body) @@ -2252,12 +2334,58 @@ func testHandlerPanic(t *testing.T, withHijack, h2 bool, panicValue interface{}) } } +type terrorWriter struct{ t *testing.T } + +func (w terrorWriter) Write(p []byte) (int, error) { + w.t.Errorf("%s", p) + return len(p), nil +} + +// Issue 16456: allow writing 0 bytes on hijacked conn to test hijack +// without any log spam. 
+func TestServerWriteHijackZeroBytes(t *testing.T) { + defer afterTest(t) + done := make(chan struct{}) + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + defer close(done) + w.(Flusher).Flush() + conn, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Errorf("Hijack: %v", err) + return + } + defer conn.Close() + _, err = w.Write(nil) + if err != ErrHijacked { + t.Errorf("Write error = %v; want ErrHijacked", err) + } + })) + ts.Config.ErrorLog = log.New(terrorWriter{t}, "Unexpected write: ", 0) + ts.Start() + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + select { + case <-done: + case <-time.After(5 * time.Second): + t.Fatal("timeout") + } +} + func TestServerNoDate_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Date") } func TestServerNoDate_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Date") } func TestServerNoContentType_h1(t *testing.T) { testServerNoHeader(t, h1Mode, "Content-Type") } func TestServerNoContentType_h2(t *testing.T) { testServerNoHeader(t, h2Mode, "Content-Type") } func testServerNoHeader(t *testing.T, h2 bool, header string) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { w.Header()[header] = nil @@ -2275,6 +2403,7 @@ func testServerNoHeader(t *testing.T, h2 bool, header string) { } func TestStripPrefix(t *testing.T) { + setParallel(t) defer afterTest(t) h := HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("X-Path", r.URL.Path) @@ -2282,7 +2411,10 @@ func TestStripPrefix(t *testing.T) { ts := httptest.NewServer(StripPrefix("/foo", h)) defer ts.Close() - res, err := Get(ts.URL + "/foo/bar") + c := &Client{Transport: new(Transport)} + defer closeClient(c) + + res, err := c.Get(ts.URL + "/foo/bar") if err != nil { t.Fatal(err) } @@ -2304,10 +2436,11 @@ func 
TestStripPrefix(t *testing.T) { func TestRequestLimit_h1(t *testing.T) { testRequestLimit(t, h1Mode) } func TestRequestLimit_h2(t *testing.T) { testRequestLimit(t, h2Mode) } func testRequestLimit(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { t.Fatalf("didn't expect to get request in Handler") - })) + }), optQuietLog) defer cst.close() req, _ := NewRequest("GET", cst.ts.URL, nil) var bytesPerHeader = len("header12345: val12345\r\n") @@ -2350,6 +2483,7 @@ func (cr countReader) Read(p []byte) (n int, err error) { func TestRequestBodyLimit_h1(t *testing.T) { testRequestBodyLimit(t, h1Mode) } func TestRequestBodyLimit_h2(t *testing.T) { testRequestBodyLimit(t, h2Mode) } func testRequestBodyLimit(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) const limit = 1 << 20 cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { @@ -2399,14 +2533,14 @@ func TestClientWriteShutdown(t *testing.T) { } err = conn.(*net.TCPConn).CloseWrite() if err != nil { - t.Fatalf("Dial: %v", err) + t.Fatalf("CloseWrite: %v", err) } donec := make(chan bool) go func() { defer close(donec) bs, err := ioutil.ReadAll(conn) if err != nil { - t.Fatalf("ReadAll: %v", err) + t.Errorf("ReadAll: %v", err) } got := string(bs) if got != "" { @@ -2445,6 +2579,7 @@ func TestServerBufferedChunking(t *testing.T) { // closing the TCP connection, causing the client to get a RST. 
// See https://golang.org/issue/3595 func TestServerGracefulClose(t *testing.T) { + setParallel(t) defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { Error(w, "bye", StatusUnauthorized) @@ -2557,7 +2692,8 @@ func TestCloseNotifier(t *testing.T) { go func() { _, err = fmt.Fprintf(conn, "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n") if err != nil { - t.Fatal(err) + t.Error(err) + return } <-diec conn.Close() @@ -2599,7 +2735,8 @@ func TestCloseNotifierPipelined(t *testing.T) { const req = "GET / HTTP/1.1\r\nConnection: keep-alive\r\nHost: foo\r\n\r\n" _, err = io.WriteString(conn, req+req) // two requests if err != nil { - t.Fatal(err) + t.Error(err) + return } <-diec conn.Close() @@ -2707,6 +2844,7 @@ func TestHijackAfterCloseNotifier(t *testing.T) { } func TestHijackBeforeRequestBodyRead(t *testing.T) { + setParallel(t) defer afterTest(t) var requestBody = bytes.Repeat([]byte("a"), 1<<20) bodyOkay := make(chan bool, 1) @@ -3028,15 +3166,18 @@ func (l *errorListener) Addr() net.Addr { } func TestAcceptMaxFds(t *testing.T) { - log.SetOutput(ioutil.Discard) // is noisy otherwise - defer log.SetOutput(os.Stderr) + setParallel(t) ln := &errorListener{[]error{ &net.OpError{ Op: "accept", Err: syscall.EMFILE, }}} - err := Serve(ln, HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {}))) + server := &Server{ + Handler: HandlerFunc(HandlerFunc(func(ResponseWriter, *Request) {})), + ErrorLog: log.New(ioutil.Discard, "", 0), // noisy otherwise + } + err := server.Serve(ln) if err != io.EOF { t.Errorf("got error %v, want EOF", err) } @@ -3161,6 +3302,7 @@ func TestHTTP10ConnectionHeader(t *testing.T) { func TestServerReaderFromOrder_h1(t *testing.T) { testServerReaderFromOrder(t, h1Mode) } func TestServerReaderFromOrder_h2(t *testing.T) { testServerReaderFromOrder(t, h2Mode) } func testServerReaderFromOrder(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) pr, pw := io.Pipe() const size = 3 << 20 @@ 
-3265,6 +3407,7 @@ func TestTransportAndServerSharedBodyRace_h2(t *testing.T) { testTransportAndServerSharedBodyRace(t, h2Mode) } func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) const bodySize = 1 << 20 @@ -3453,6 +3596,7 @@ func TestAppendTime(t *testing.T) { } func TestServerConnState(t *testing.T) { + setParallel(t) defer afterTest(t) handler := map[string]func(w ResponseWriter, r *Request){ "/": func(w ResponseWriter, r *Request) { @@ -3500,14 +3644,39 @@ func TestServerConnState(t *testing.T) { } ts.Start() - mustGet(t, ts.URL+"/") - mustGet(t, ts.URL+"/close") + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + mustGet := func(url string, headers ...string) { + req, err := NewRequest("GET", url, nil) + if err != nil { + t.Fatal(err) + } + for len(headers) > 0 { + req.Header.Add(headers[0], headers[1]) + headers = headers[2:] + } + res, err := c.Do(req) + if err != nil { + t.Errorf("Error fetching %s: %v", url, err) + return + } + _, err = ioutil.ReadAll(res.Body) + defer res.Body.Close() + if err != nil { + t.Errorf("Error reading %s: %v", url, err) + } + } + + mustGet(ts.URL + "/") + mustGet(ts.URL + "/close") - mustGet(t, ts.URL+"/") - mustGet(t, ts.URL+"/", "Connection", "close") + mustGet(ts.URL + "/") + mustGet(ts.URL+"/", "Connection", "close") - mustGet(t, ts.URL+"/hijack") - mustGet(t, ts.URL+"/hijack-panic") + mustGet(ts.URL + "/hijack") + mustGet(ts.URL + "/hijack-panic") // New->Closed { @@ -3587,31 +3756,10 @@ func TestServerConnState(t *testing.T) { } mu.Lock() - t.Errorf("Unexpected events.\nGot log: %s\n Want: %s\n", logString(stateLog), logString(want)) + t.Errorf("Unexpected events.\nGot log:\n%s\n Want:\n%s\n", logString(stateLog), logString(want)) mu.Unlock() } -func mustGet(t *testing.T, url string, headers ...string) { - req, err := NewRequest("GET", url, nil) - if err != nil { - t.Fatal(err) - } - for len(headers) > 0 { - 
req.Header.Add(headers[0], headers[1]) - headers = headers[2:] - } - res, err := DefaultClient.Do(req) - if err != nil { - t.Errorf("Error fetching %s: %v", url, err) - return - } - _, err = ioutil.ReadAll(res.Body) - defer res.Body.Close() - if err != nil { - t.Errorf("Error reading %s: %v", url, err) - } -} - func TestServerKeepAlivesEnabled(t *testing.T) { defer afterTest(t) ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) @@ -3632,6 +3780,7 @@ func TestServerKeepAlivesEnabled(t *testing.T) { func TestServerEmptyBodyRace_h1(t *testing.T) { testServerEmptyBodyRace(t, h1Mode) } func TestServerEmptyBodyRace_h2(t *testing.T) { testServerEmptyBodyRace(t, h2Mode) } func testServerEmptyBodyRace(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) var n int32 cst := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) { @@ -3695,6 +3844,7 @@ func (c *closeWriteTestConn) CloseWrite() error { } func TestCloseWrite(t *testing.T) { + setParallel(t) var srv Server var testConn closeWriteTestConn c := ExportServerNewConn(&srv, &testConn) @@ -3935,6 +4085,7 @@ Host: foo // If a Handler finishes and there's an unread request body, // verify the server try to do implicit read on it before replying. 
func TestHandlerFinishSkipBigContentLengthRead(t *testing.T) { + setParallel(t) conn := &testConn{closec: make(chan bool)} conn.readBuf.Write([]byte(fmt.Sprintf( "POST / HTTP/1.1\r\n" + @@ -4033,7 +4184,11 @@ func TestServerValidatesHostHeader(t *testing.T) { io.WriteString(&conn.readBuf, methodTarget+tt.proto+"\r\n"+tt.host+"\r\n") ln := &oneConnListener{conn} - go Serve(ln, HandlerFunc(func(ResponseWriter, *Request) {})) + srv := Server{ + ErrorLog: quietLog, + Handler: HandlerFunc(func(ResponseWriter, *Request) {}), + } + go srv.Serve(ln) <-conn.closec res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil) if err != nil { @@ -4088,6 +4243,7 @@ func TestServerHandlersCanHandleH2PRI(t *testing.T) { // Test that we validate the valid bytes in HTTP/1 headers. // Issue 11207. func TestServerValidatesHeaders(t *testing.T) { + setParallel(t) tests := []struct { header string want int @@ -4097,9 +4253,10 @@ func TestServerValidatesHeaders(t *testing.T) { {"X-Foo: bar\r\n", 200}, {"Foo: a space\r\n", 200}, - {"A space: foo\r\n", 400}, // space in header - {"foo\xffbar: foo\r\n", 400}, // binary in header - {"foo\x00bar: foo\r\n", 400}, // binary in header + {"A space: foo\r\n", 400}, // space in header + {"foo\xffbar: foo\r\n", 400}, // binary in header + {"foo\x00bar: foo\r\n", 400}, // binary in header + {"Foo: " + strings.Repeat("x", 1<<21) + "\r\n", 431}, // header too large {"foo: foo foo\r\n", 200}, // LWS space is okay {"foo: foo\tfoo\r\n", 200}, // LWS tab is okay @@ -4112,7 +4269,11 @@ func TestServerValidatesHeaders(t *testing.T) { io.WriteString(&conn.readBuf, "GET / HTTP/1.1\r\nHost: foo\r\n"+tt.header+"\r\n") ln := &oneConnListener{conn} - go Serve(ln, HandlerFunc(func(ResponseWriter, *Request) {})) + srv := Server{ + ErrorLog: quietLog, + Handler: HandlerFunc(func(ResponseWriter, *Request) {}), + } + go srv.Serve(ln) <-conn.closec res, err := ReadResponse(bufio.NewReader(&conn.writeBuf), nil) if err != nil { @@ -4132,6 +4293,7 @@ func 
TestServerRequestContextCancel_ServeHTTPDone_h2(t *testing.T) { testServerRequestContextCancel_ServeHTTPDone(t, h2Mode) } func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) ctxc := make(chan context.Context, 1) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { @@ -4157,13 +4319,12 @@ func testServerRequestContextCancel_ServeHTTPDone(t *testing.T, h2 bool) { } } +// Tests that the Request.Context available to the Handler is canceled +// if the peer closes their TCP connection. This requires that the server +// is always blocked in a Read call so it notices the EOF from the client. +// See issues 15927 and 15224. func TestServerRequestContextCancel_ConnClose(t *testing.T) { - // Currently the context is not canceled when the connection - // is closed because we're not reading from the connection - // until after ServeHTTP for the previous handler is done. - // Until the server code is modified to always be in a read - // (Issue 15224), this test doesn't work yet. 
- t.Skip("TODO(bradfitz): this test doesn't yet work; golang.org/issue/15224") + setParallel(t) defer afterTest(t) inHandler := make(chan struct{}) handlerDone := make(chan struct{}) @@ -4192,7 +4353,7 @@ func TestServerRequestContextCancel_ConnClose(t *testing.T) { select { case <-handlerDone: - case <-time.After(3 * time.Second): + case <-time.After(4 * time.Second): t.Fatalf("timeout waiting to see ServeHTTP exit") } } @@ -4204,6 +4365,7 @@ func TestServerContext_ServerContextKey_h2(t *testing.T) { testServerContext_ServerContextKey(t, h2Mode) } func testServerContext_ServerContextKey(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { ctx := r.Context() @@ -4229,6 +4391,7 @@ func testServerContext_ServerContextKey(t *testing.T, h2 bool) { // https://golang.org/issue/15960 func TestHandlerSetTransferEncodingChunked(t *testing.T) { + setParallel(t) defer afterTest(t) ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Transfer-Encoding", "chunked") @@ -4243,6 +4406,7 @@ func TestHandlerSetTransferEncodingChunked(t *testing.T) { // https://golang.org/issue/16063 func TestHandlerSetTransferEncodingGzip(t *testing.T) { + setParallel(t) defer afterTest(t) ht := newHandlerTest(HandlerFunc(func(w ResponseWriter, r *Request) { w.Header().Set("Transfer-Encoding", "gzip") @@ -4416,13 +4580,19 @@ func BenchmarkClient(b *testing.B) { b.StopTimer() defer afterTest(b) - port := os.Getenv("TEST_BENCH_SERVER_PORT") // can be set by user - if port == "" { - port = "39207" - } var data = []byte("Hello world.\n") if server := os.Getenv("TEST_BENCH_SERVER"); server != "" { // Server process mode. 
+ port := os.Getenv("TEST_BENCH_SERVER_PORT") // can be set by user + if port == "" { + port = "0" + } + ln, err := net.Listen("tcp", "localhost:"+port) + if err != nil { + fmt.Fprintln(os.Stderr, err.Error()) + os.Exit(1) + } + fmt.Println(ln.Addr().String()) HandleFunc("/", func(w ResponseWriter, r *Request) { r.ParseForm() if r.Form.Get("stop") != "" { @@ -4431,33 +4601,44 @@ func BenchmarkClient(b *testing.B) { w.Header().Set("Content-Type", "text/html; charset=utf-8") w.Write(data) }) - log.Fatal(ListenAndServe("localhost:"+port, nil)) + var srv Server + log.Fatal(srv.Serve(ln)) } // Start server process. cmd := exec.Command(os.Args[0], "-test.run=XXXX", "-test.bench=BenchmarkClient$") cmd.Env = append(os.Environ(), "TEST_BENCH_SERVER=yes") + cmd.Stderr = os.Stderr + stdout, err := cmd.StdoutPipe() + if err != nil { + b.Fatal(err) + } if err := cmd.Start(); err != nil { b.Fatalf("subprocess failed to start: %v", err) } defer cmd.Process.Kill() + + // Wait for the server in the child process to respond and tell us + // its listening address, once it's started listening: + timer := time.AfterFunc(10*time.Second, func() { + cmd.Process.Kill() + }) + defer timer.Stop() + bs := bufio.NewScanner(stdout) + if !bs.Scan() { + b.Fatalf("failed to read listening URL from child: %v", bs.Err()) + } + url := "http://" + strings.TrimSpace(bs.Text()) + "/" + timer.Stop() + if _, err := getNoBody(url); err != nil { + b.Fatalf("initial probe of child process failed: %v", err) + } + done := make(chan error) go func() { done <- cmd.Wait() }() - // Wait for the server process to respond. - url := "http://localhost:" + port + "/" - for i := 0; i < 100; i++ { - time.Sleep(100 * time.Millisecond) - if _, err := getNoBody(url); err == nil { - break - } - if i == 99 { - b.Fatalf("subprocess does not respond") - } - } - // Do b.N requests to the server. 
b.StartTimer() for i := 0; i < b.N; i++ { @@ -4719,6 +4900,7 @@ func BenchmarkCloseNotifier(b *testing.B) { // Verify this doesn't race (Issue 16505) func TestConcurrentServerServe(t *testing.T) { + setParallel(t) for i := 0; i < 100; i++ { ln1 := &oneConnListener{conn: nil} ln2 := &oneConnListener{conn: nil} @@ -4727,3 +4909,267 @@ func TestConcurrentServerServe(t *testing.T) { go func() { srv.Serve(ln2) }() } } + +func TestServerIdleTimeout(t *testing.T) { + if testing.Short() { + t.Skip("skipping in short mode") + } + setParallel(t) + defer afterTest(t) + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + io.Copy(ioutil.Discard, r.Body) + io.WriteString(w, r.RemoteAddr) + })) + ts.Config.ReadHeaderTimeout = 1 * time.Second + ts.Config.IdleTimeout = 2 * time.Second + ts.Start() + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + get := func() string { + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + return string(slurp) + } + + a1, a2 := get(), get() + if a1 != a2 { + t.Fatalf("did requests on different connections") + } + time.Sleep(3 * time.Second) + a3 := get() + if a2 == a3 { + t.Fatal("request three unexpectedly on same connection") + } + + // And test that ReadHeaderTimeout still works: + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + defer conn.Close() + conn.Write([]byte("GET / HTTP/1.1\r\nHost: foo.com\r\n")) + time.Sleep(2 * time.Second) + if _, err := io.CopyN(ioutil.Discard, conn, 1); err == nil { + t.Fatal("copy byte succeeded; want err") + } +} + +func get(t *testing.T, c *Client, url string) string { + res, err := c.Get(url) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + slurp, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + return string(slurp) +} + 
+// Tests that calls to Server.SetKeepAlivesEnabled(false) closes any +// currently-open connections. +func TestServerSetKeepAlivesEnabledClosesConns(t *testing.T) { + setParallel(t) + defer afterTest(t) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + io.WriteString(w, r.RemoteAddr) + })) + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + get := func() string { return get(t, c, ts.URL) } + + a1, a2 := get(), get() + if a1 != a2 { + t.Fatal("expected first two requests on same connection") + } + var idle0 int + if !waitCondition(2*time.Second, 10*time.Millisecond, func() bool { + idle0 = tr.IdleConnKeyCountForTesting() + return idle0 == 1 + }) { + t.Fatalf("idle count before SetKeepAlivesEnabled called = %v; want 1", idle0) + } + + ts.Config.SetKeepAlivesEnabled(false) + + var idle1 int + if !waitCondition(2*time.Second, 10*time.Millisecond, func() bool { + idle1 = tr.IdleConnKeyCountForTesting() + return idle1 == 0 + }) { + t.Fatalf("idle count after SetKeepAlivesEnabled called = %v; want 0", idle1) + } + + a3 := get() + if a3 == a2 { + t.Fatal("expected third request on new connection") + } +} + +func TestServerShutdown_h1(t *testing.T) { testServerShutdown(t, h1Mode) } +func TestServerShutdown_h2(t *testing.T) { testServerShutdown(t, h2Mode) } + +func testServerShutdown(t *testing.T, h2 bool) { + setParallel(t) + defer afterTest(t) + var doShutdown func() // set later + var shutdownRes = make(chan error, 1) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + go doShutdown() + // Shutdown is graceful, so it should not interrupt + // this in-flight response. Add a tiny sleep here to + // increase the odds of a failure if shutdown has + // bugs. 
+ time.Sleep(20 * time.Millisecond) + io.WriteString(w, r.RemoteAddr) + })) + defer cst.close() + + doShutdown = func() { + shutdownRes <- cst.ts.Config.Shutdown(context.Background()) + } + get(t, cst.c, cst.ts.URL) // calls t.Fail on failure + + if err := <-shutdownRes; err != nil { + t.Fatalf("Shutdown: %v", err) + } + + res, err := cst.c.Get(cst.ts.URL) + if err == nil { + res.Body.Close() + t.Fatal("second request should fail. server should be shut down") + } +} + +// Issue 17878: tests that we can call Close twice. +func TestServerCloseDeadlock(t *testing.T) { + var s Server + s.Close() + s.Close() +} + +// Issue 17717: tests that Server.SetKeepAlivesEnabled is respected by +// both HTTP/1 and HTTP/2. +func TestServerKeepAlivesEnabled_h1(t *testing.T) { testServerKeepAlivesEnabled(t, h1Mode) } +func TestServerKeepAlivesEnabled_h2(t *testing.T) { testServerKeepAlivesEnabled(t, h2Mode) } +func testServerKeepAlivesEnabled(t *testing.T, h2 bool) { + setParallel(t) + defer afterTest(t) + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + fmt.Fprintf(w, "%v", r.RemoteAddr) + })) + defer cst.close() + srv := cst.ts.Config + srv.SetKeepAlivesEnabled(false) + a := cst.getURL(cst.ts.URL) + if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) { + t.Fatalf("test server has active conns") + } + b := cst.getURL(cst.ts.URL) + if a == b { + t.Errorf("got same connection between first and second requests") + } + if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) { + t.Fatalf("test server has active conns") + } +} + +// Issue 18447: test that the Server's ReadTimeout is stopped while +// the server's doing its 1-byte background read between requests, +// waiting for the connection to maybe close. 
+func TestServerCancelsReadTimeoutWhenIdle(t *testing.T) { + setParallel(t) + defer afterTest(t) + const timeout = 250 * time.Millisecond + ts := httptest.NewUnstartedServer(HandlerFunc(func(w ResponseWriter, r *Request) { + select { + case <-time.After(2 * timeout): + fmt.Fprint(w, "ok") + case <-r.Context().Done(): + fmt.Fprint(w, r.Context().Err()) + } + })) + ts.Config.ReadTimeout = timeout + ts.Start() + defer ts.Close() + + tr := &Transport{} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + + res, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + slurp, err := ioutil.ReadAll(res.Body) + res.Body.Close() + if err != nil { + t.Fatal(err) + } + if string(slurp) != "ok" { + t.Fatalf("Got: %q, want ok", slurp) + } +} + +// Issue 18535: test that the Server doesn't try to do a background +// read if it's already done one. +func TestServerDuplicateBackgroundRead(t *testing.T) { + setParallel(t) + defer afterTest(t) + + const goroutines = 5 + const requests = 2000 + + hts := httptest.NewServer(HandlerFunc(NotFound)) + defer hts.Close() + + reqBytes := []byte("GET / HTTP/1.1\r\nHost: e.com\r\n\r\n") + + var wg sync.WaitGroup + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + cn, err := net.Dial("tcp", hts.Listener.Addr().String()) + if err != nil { + t.Error(err) + return + } + defer cn.Close() + + wg.Add(1) + go func() { + defer wg.Done() + io.Copy(ioutil.Discard, cn) + }() + + for j := 0; j < requests; j++ { + if t.Failed() { + return + } + _, err := cn.Write(reqBytes) + if err != nil { + t.Error(err) + return + } + } + }() + } + wg.Wait() +} diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go index 89574a8b36e..96236489bd9 100644 --- a/libgo/go/net/http/server.go +++ b/libgo/go/net/http/server.go @@ -40,7 +40,9 @@ var ( // ErrHijacked is returned by ResponseWriter.Write calls when // the underlying connection has been hijacked using the - // Hijacker interfaced. + // Hijacker interface. 
A zero-byte write on a hijacked + // connection will return ErrHijacked without any other side + // effects. ErrHijacked = errors.New("http: connection has been hijacked") // ErrContentLength is returned by ResponseWriter.Write calls @@ -73,7 +75,9 @@ var ( // If ServeHTTP panics, the server (the caller of ServeHTTP) assumes // that the effect of the panic was isolated to the active request. // It recovers the panic, logs a stack trace to the server error log, -// and hangs up the connection. +// and hangs up the connection. To abort a handler so the client sees +// an interrupted response but the server doesn't log an error, panic +// with the value ErrAbortHandler. type Handler interface { ServeHTTP(ResponseWriter, *Request) } @@ -85,11 +89,25 @@ type Handler interface { // has returned. type ResponseWriter interface { // Header returns the header map that will be sent by - // WriteHeader. Changing the header after a call to - // WriteHeader (or Write) has no effect unless the modified - // headers were declared as trailers by setting the - // "Trailer" header before the call to WriteHeader (see example). - // To suppress implicit response headers, set their value to nil. + // WriteHeader. The Header map also is the mechanism with which + // Handlers can set HTTP trailers. + // + // Changing the header map after a call to WriteHeader (or + // Write) has no effect unless the modified headers are + // trailers. + // + // There are two ways to set Trailers. The preferred way is to + // predeclare in the headers which trailers you will later + // send by setting the "Trailer" header to the names of the + // trailer keys which will come later. In this case, those + // keys of the Header map are treated as if they were + // trailers. See the example. The second way, for trailer + // keys not known to the Handler until after the first Write, + // is to prefix the Header map keys with the TrailerPrefix + // constant value. See TrailerPrefix. 
+ // + // To suppress implicit response headers (such as "Date"), set + // their value to nil. Header() Header // Write writes the data to the connection as part of an HTTP reply. @@ -206,6 +224,9 @@ type conn struct { // Immutable; never nil. server *Server + // cancelCtx cancels the connection-level context. + cancelCtx context.CancelFunc + // rwc is the underlying network connection. // This is never wrapped by other types and is the value given out // to CloseNotifier callers. It is usually of type *net.TCPConn or @@ -232,7 +253,6 @@ type conn struct { r *connReader // bufr reads from r. - // Users of bufr must hold mu. bufr *bufio.Reader // bufw writes to checkConnErrorWriter{c}, which populates werr on error. @@ -242,7 +262,11 @@ type conn struct { // on this connection, if any. lastMethod string - // mu guards hijackedv, use of bufr, (*response).closeNotifyCh. + curReq atomic.Value // of *response (which has a Request in it) + + curState atomic.Value // of ConnState + + // mu guards hijackedv mu sync.Mutex // hijackedv is whether this connection has been hijacked @@ -262,8 +286,12 @@ func (c *conn) hijackLocked() (rwc net.Conn, buf *bufio.ReadWriter, err error) { if c.hijackedv { return nil, nil, ErrHijacked } + c.r.abortPendingRead() + c.hijackedv = true rwc = c.rwc + rwc.SetDeadline(time.Time{}) + buf = bufio.NewReadWriter(c.bufr, bufio.NewWriter(rwc)) c.setState(rwc, StateHijacked) return @@ -346,13 +374,7 @@ func (cw *chunkWriter) close() { bw := cw.res.conn.bufw // conn's bufio writer // zero chunk to mark EOF bw.WriteString("0\r\n") - if len(cw.res.trailers) > 0 { - trailers := make(Header) - for _, h := range cw.res.trailers { - if vv := cw.res.handlerHeader[h]; len(vv) > 0 { - trailers[h] = vv - } - } + if trailers := cw.res.finalTrailers(); trailers != nil { trailers.Write(bw) // the writer handles noting errors } // final blank line after the trailers (whether @@ -413,9 +435,48 @@ type response struct { dateBuf [len(TimeFormat)]byte clenBuf 
[10]byte - // closeNotifyCh is non-nil once CloseNotify is called. - // Guarded by conn.mu - closeNotifyCh <-chan bool + // closeNotifyCh is the channel returned by CloseNotify. + // TODO(bradfitz): this is currently (for Go 1.8) always + // non-nil. Make this lazily-created again as it used to be? + closeNotifyCh chan bool + didCloseNotify int32 // atomic (only 0->1 winner should send) +} + +// TrailerPrefix is a magic prefix for ResponseWriter.Header map keys +// that, if present, signals that the map entry is actually for +// the response trailers, and not the response headers. The prefix +// is stripped after the ServeHTTP call finishes and the values are +// sent in the trailers. +// +// This mechanism is intended only for trailers that are not known +// prior to the headers being written. If the set of trailers is fixed +// or known before the header is written, the normal Go trailers mechanism +// is preferred: +// https://golang.org/pkg/net/http/#ResponseWriter +// https://golang.org/pkg/net/http/#example_ResponseWriter_trailers +const TrailerPrefix = "Trailer:" + +// finalTrailers is called after the Handler exits and returns a non-nil +// value if the Handler set any trailers. +func (w *response) finalTrailers() Header { + var t Header + for k, vv := range w.handlerHeader { + if strings.HasPrefix(k, TrailerPrefix) { + if t == nil { + t = make(Header) + } + t[strings.TrimPrefix(k, TrailerPrefix)] = vv + } + } + for _, k := range w.trailers { + if t == nil { + t = make(Header) + } + for _, v := range w.handlerHeader[k] { + t.Add(k, v) + } + } + return t } type atomicBool int32 @@ -548,60 +609,152 @@ type readResult struct { // call blocked in a background goroutine to wait for activity and // trigger a CloseNotifier channel. type connReader struct { - r io.Reader - remain int64 // bytes remaining + conn *conn - // ch is non-nil if a background read is in progress. - // It is guarded by conn.mu. 
- ch chan readResult + mu sync.Mutex // guards following + hasByte bool + byteBuf [1]byte + bgErr error // non-nil means error happened on background read + cond *sync.Cond + inRead bool + aborted bool // set true before conn.rwc deadline is set to past + remain int64 // bytes remaining +} + +func (cr *connReader) lock() { + cr.mu.Lock() + if cr.cond == nil { + cr.cond = sync.NewCond(&cr.mu) + } +} + +func (cr *connReader) unlock() { cr.mu.Unlock() } + +func (cr *connReader) startBackgroundRead() { + cr.lock() + defer cr.unlock() + if cr.inRead { + panic("invalid concurrent Body.Read call") + } + if cr.hasByte { + return + } + cr.inRead = true + cr.conn.rwc.SetReadDeadline(time.Time{}) + go cr.backgroundRead() +} + +func (cr *connReader) backgroundRead() { + n, err := cr.conn.rwc.Read(cr.byteBuf[:]) + cr.lock() + if n == 1 { + cr.hasByte = true + // We were at EOF already (since we wouldn't be in a + // background read otherwise), so this is a pipelined + // HTTP request. + cr.closeNotifyFromPipelinedRequest() + } + if ne, ok := err.(net.Error); ok && cr.aborted && ne.Timeout() { + // Ignore this error. It's the expected error from + // another goroutine calling abortPendingRead. + } else if err != nil { + cr.handleReadError(err) + } + cr.aborted = false + cr.inRead = false + cr.unlock() + cr.cond.Broadcast() +} + +func (cr *connReader) abortPendingRead() { + cr.lock() + defer cr.unlock() + if !cr.inRead { + return + } + cr.aborted = true + cr.conn.rwc.SetReadDeadline(aLongTimeAgo) + for cr.inRead { + cr.cond.Wait() + } + cr.conn.rwc.SetReadDeadline(time.Time{}) } func (cr *connReader) setReadLimit(remain int64) { cr.remain = remain } func (cr *connReader) setInfiniteReadLimit() { cr.remain = maxInt64 } func (cr *connReader) hitReadLimit() bool { return cr.remain <= 0 } +// may be called from multiple goroutines. 
+func (cr *connReader) handleReadError(err error) { + cr.conn.cancelCtx() + cr.closeNotify() +} + +// closeNotifyFromPipelinedRequest simply calls closeNotify. +// +// This method wrapper is here for documentation. The callers are the +// cases where we send on the closenotify channel because of a +// pipelined HTTP request, per the previous Go behavior and +// documentation (that this "MAY" happen). +// +// TODO: consider changing this behavior and making context +// cancelation and closenotify work the same. +func (cr *connReader) closeNotifyFromPipelinedRequest() { + cr.closeNotify() +} + +// may be called from multiple goroutines. +func (cr *connReader) closeNotify() { + res, _ := cr.conn.curReq.Load().(*response) + if res != nil { + if atomic.CompareAndSwapInt32(&res.didCloseNotify, 0, 1) { + res.closeNotifyCh <- true + } + } +} + func (cr *connReader) Read(p []byte) (n int, err error) { + cr.lock() + if cr.inRead { + cr.unlock() + panic("invalid concurrent Body.Read call") + } if cr.hitReadLimit() { + cr.unlock() return 0, io.EOF } + if cr.bgErr != nil { + err = cr.bgErr + cr.unlock() + return 0, err + } if len(p) == 0 { - return + cr.unlock() + return 0, nil } if int64(len(p)) > cr.remain { p = p[:cr.remain] } - - // Is a background read (started by CloseNotifier) already in - // flight? If so, wait for it and use its result. - ch := cr.ch - if ch != nil { - cr.ch = nil - res := <-ch - if res.n == 1 { - p[0] = res.b - cr.remain -= 1 - } - return res.n, res.err + if cr.hasByte { + p[0] = cr.byteBuf[0] + cr.hasByte = false + cr.unlock() + return 1, nil } - n, err = cr.r.Read(p) - cr.remain -= int64(n) - return -} + cr.inRead = true + cr.unlock() + n, err = cr.conn.rwc.Read(p) -func (cr *connReader) startBackgroundRead(onReadComplete func()) { - if cr.ch != nil { - // Background read already started. 
- return + cr.lock() + cr.inRead = false + if err != nil { + cr.handleReadError(err) } - cr.ch = make(chan readResult, 1) - go cr.closeNotifyAwaitActivityRead(cr.ch, onReadComplete) -} + cr.remain -= int64(n) + cr.unlock() -func (cr *connReader) closeNotifyAwaitActivityRead(ch chan<- readResult, onReadComplete func()) { - var buf [1]byte - n, err := cr.r.Read(buf[:1]) - onReadComplete() - ch <- readResult{n, err, buf[0]} + cr.cond.Broadcast() + return n, err } var ( @@ -633,7 +786,7 @@ func newBufioReader(r io.Reader) *bufio.Reader { br.Reset(r) return br } - // Note: if this reader size is every changed, update + // Note: if this reader size is ever changed, update // TestHandlerBodyClose's assumptions. return bufio.NewReader(r) } @@ -746,9 +899,18 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) { return nil, ErrHijacked } + var ( + wholeReqDeadline time.Time // or zero if none + hdrDeadline time.Time // or zero if none + ) + t0 := time.Now() + if d := c.server.readHeaderTimeout(); d != 0 { + hdrDeadline = t0.Add(d) + } if d := c.server.ReadTimeout; d != 0 { - c.rwc.SetReadDeadline(time.Now().Add(d)) + wholeReqDeadline = t0.Add(d) } + c.rwc.SetReadDeadline(hdrDeadline) if d := c.server.WriteTimeout; d != 0 { defer func() { c.rwc.SetWriteDeadline(time.Now().Add(d)) @@ -756,14 +918,12 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) { } c.r.setReadLimit(c.server.initialReadLimitSize()) - c.mu.Lock() // while using bufr if c.lastMethod == "POST" { // RFC 2616 section 4.1 tolerance for old buggy clients. peek, _ := c.bufr.Peek(4) // ReadRequest will get err below c.bufr.Discard(numLeadingCRorLF(peek)) } req, err := readRequest(c.bufr, keepHostHeader) - c.mu.Unlock() if err != nil { if c.r.hitReadLimit() { return nil, errTooLarge @@ -809,6 +969,11 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) { body.doEarlyClose = true } + // Adjust the read deadline if necessary. 
+ if !hdrDeadline.Equal(wholeReqDeadline) { + c.rwc.SetReadDeadline(wholeReqDeadline) + } + w = &response{ conn: c, cancelCtx: cancelCtx, @@ -816,6 +981,7 @@ func (c *conn) readRequest(ctx context.Context) (w *response, err error) { reqBody: req.Body, handlerHeader: make(Header), contentLength: -1, + closeNotifyCh: make(chan bool, 1), // We populate these ahead of time so we're not // reading from req.Header after their Handler starts @@ -990,7 +1156,17 @@ func (cw *chunkWriter) writeHeader(p []byte) { } var setHeader extraHeader + // Don't write out the fake "Trailer:foo" keys. See TrailerPrefix. trailers := false + for k := range cw.header { + if strings.HasPrefix(k, TrailerPrefix) { + if excludeHeader == nil { + excludeHeader = make(map[string]bool) + } + excludeHeader[k] = true + trailers = true + } + } for _, v := range cw.header["Trailer"] { trailers = true foreachHeaderElement(v, cw.res.declareTrailer) @@ -1318,7 +1494,9 @@ func (w *response) WriteString(data string) (n int, err error) { // either dataB or dataS is non-zero. func (w *response) write(lenData int, dataB []byte, dataS string) (n int, err error) { if w.conn.hijacked() { - w.conn.server.logf("http: response.Write on hijacked connection") + if lenData > 0 { + w.conn.server.logf("http: response.Write on hijacked connection") + } return 0, ErrHijacked } if !w.wroteHeader { @@ -1354,6 +1532,8 @@ func (w *response) finishRequest() { w.cw.close() w.conn.bufw.Flush() + w.conn.r.abortPendingRead() + // Close the body (regardless of w.closeAfterReply) so we can // re-use its bufio.Reader later safely. 
w.reqBody.Close() @@ -1469,11 +1649,30 @@ func validNPN(proto string) bool { } func (c *conn) setState(nc net.Conn, state ConnState) { - if hook := c.server.ConnState; hook != nil { + srv := c.server + switch state { + case StateNew: + srv.trackConn(c, true) + case StateHijacked, StateClosed: + srv.trackConn(c, false) + } + c.curState.Store(connStateInterface[state]) + if hook := srv.ConnState; hook != nil { hook(nc, state) } } +// connStateInterface is an array of the interface{} versions of +// ConnState values, so we can use them in atomic.Values later without +// paying the cost of shoving their integers in an interface{}. +var connStateInterface = [...]interface{}{ + StateNew: StateNew, + StateActive: StateActive, + StateIdle: StateIdle, + StateHijacked: StateHijacked, + StateClosed: StateClosed, +} + // badRequestError is a literal string (used by in the server in HTML, // unescaped) to tell the user why their request was bad. It should // be plain text without user info or other embedded errors. @@ -1481,11 +1680,34 @@ type badRequestError string func (e badRequestError) Error() string { return "Bad Request: " + string(e) } +// ErrAbortHandler is a sentinel panic value to abort a handler. +// While any panic from ServeHTTP aborts the response to the client, +// panicking with ErrAbortHandler also suppresses logging of a stack +// trace to the server's error log. +var ErrAbortHandler = errors.New("net/http: abort Handler") + +// isCommonNetReadError reports whether err is a common error +// encountered during reading a request off the network when the +// client has gone away or had its read fail somehow. This is used to +// determine which logs are interesting enough to log about. 
+func isCommonNetReadError(err error) bool { + if err == io.EOF { + return true + } + if neterr, ok := err.(net.Error); ok && neterr.Timeout() { + return true + } + if oe, ok := err.(*net.OpError); ok && oe.Op == "read" { + return true + } + return false +} + // Serve a new connection. func (c *conn) serve(ctx context.Context) { c.remoteAddr = c.rwc.RemoteAddr().String() defer func() { - if err := recover(); err != nil { + if err := recover(); err != nil && err != ErrAbortHandler { const size = 64 << 10 buf := make([]byte, size) buf = buf[:runtime.Stack(buf, false)] @@ -1521,13 +1743,14 @@ func (c *conn) serve(ctx context.Context) { // HTTP/1.x from here on. - c.r = &connReader{r: c.rwc} - c.bufr = newBufioReader(c.r) - c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10) - ctx, cancelCtx := context.WithCancel(ctx) + c.cancelCtx = cancelCtx defer cancelCtx() + c.r = &connReader{conn: c} + c.bufr = newBufioReader(c.r) + c.bufw = newBufioWriterSize(checkConnErrorWriter{c}, 4<<10) + for { w, err := c.readRequest(ctx) if c.r.remain != c.server.initialReadLimitSize() { @@ -1535,27 +1758,29 @@ func (c *conn) serve(ctx context.Context) { c.setState(c.rwc, StateActive) } if err != nil { + const errorHeaders = "\r\nContent-Type: text/plain; charset=utf-8\r\nConnection: close\r\n\r\n" + if err == errTooLarge { // Their HTTP client may or may not be // able to read this if we're // responding to them and hanging up // while they're still writing their // request. Undefined behavior. 
- io.WriteString(c.rwc, "HTTP/1.1 431 Request Header Fields Too Large\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n431 Request Header Fields Too Large") + const publicErr = "431 Request Header Fields Too Large" + fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr) c.closeWriteAndWait() return } - if err == io.EOF { - return // don't reply - } - if neterr, ok := err.(net.Error); ok && neterr.Timeout() { + if isCommonNetReadError(err) { return // don't reply } - var publicErr string + + publicErr := "400 Bad Request" if v, ok := err.(badRequestError); ok { - publicErr = ": " + string(v) + publicErr = publicErr + ": " + string(v) } - io.WriteString(c.rwc, "HTTP/1.1 400 Bad Request\r\nContent-Type: text/plain\r\nConnection: close\r\n\r\n400 Bad Request"+publicErr) + + fmt.Fprintf(c.rwc, "HTTP/1.1 "+publicErr+errorHeaders+publicErr) return } @@ -1571,11 +1796,24 @@ func (c *conn) serve(ctx context.Context) { return } + c.curReq.Store(w) + + if requestBodyRemains(req.Body) { + registerOnHitEOF(req.Body, w.conn.r.startBackgroundRead) + } else { + if w.conn.bufr.Buffered() > 0 { + w.conn.r.closeNotifyFromPipelinedRequest() + } + w.conn.r.startBackgroundRead() + } + // HTTP cannot have multiple simultaneous active requests.[*] // Until the server replies to this request, it can't read another, // so we might as well run the handler in this goroutine. // [*] Not strictly true: HTTP pipelining. We could let them all process // in parallel even if their responses need to be serialized. + // But we're not going to implement HTTP pipelining because it + // was never deployed in the wild and the answer is HTTP/2. serverHandler{c.server}.ServeHTTP(w, w.req) w.cancelCtx() if c.hijacked() { @@ -1589,6 +1827,23 @@ func (c *conn) serve(ctx context.Context) { return } c.setState(c.rwc, StateIdle) + c.curReq.Store((*response)(nil)) + + if !w.conn.server.doKeepAlives() { + // We're in shutdown mode. 
We might've replied + // to the user without "Connection: close" and + // they might think they can send another + // request, but such is life with HTTP/1.1. + return + } + + if d := c.server.idleTimeout(); d != 0 { + c.rwc.SetReadDeadline(time.Now().Add(d)) + if _, err := c.bufr.Peek(4); err != nil { + return + } + } + c.rwc.SetReadDeadline(time.Time{}) } } @@ -1624,10 +1879,6 @@ func (w *response) Hijack() (rwc net.Conn, buf *bufio.ReadWriter, err error) { c.mu.Lock() defer c.mu.Unlock() - if w.closeNotifyCh != nil { - return nil, nil, errors.New("http: Hijack is incompatible with use of CloseNotifier in same ServeHTTP call") - } - // Release the bufioWriter that writes to the chunk writer, it is not // used after a connection has been hijacked. rwc, buf, err = c.hijackLocked() @@ -1642,50 +1893,7 @@ func (w *response) CloseNotify() <-chan bool { if w.handlerDone.isSet() { panic("net/http: CloseNotify called after ServeHTTP finished") } - c := w.conn - c.mu.Lock() - defer c.mu.Unlock() - - if w.closeNotifyCh != nil { - return w.closeNotifyCh - } - ch := make(chan bool, 1) - w.closeNotifyCh = ch - - if w.conn.hijackedv { - // CloseNotify is undefined after a hijack, but we have - // no place to return an error, so just return a channel, - // even though it'll never receive a value. - return ch - } - - var once sync.Once - notify := func() { once.Do(func() { ch <- true }) } - - if requestBodyRemains(w.reqBody) { - // They're still consuming the request body, so we - // shouldn't notify yet. - registerOnHitEOF(w.reqBody, func() { - c.mu.Lock() - defer c.mu.Unlock() - startCloseNotifyBackgroundRead(c, notify) - }) - } else { - startCloseNotifyBackgroundRead(c, notify) - } - return ch -} - -// c.mu must be held. -func startCloseNotifyBackgroundRead(c *conn, notify func()) { - if c.bufr.Buffered() > 0 { - // They've consumed the request body, so anything - // remaining is a pipelined request, which we - // document as firing on. 
- notify() - } else { - c.r.startBackgroundRead(notify) - } + return w.closeNotifyCh } func registerOnHitEOF(rc io.ReadCloser, fn func()) { @@ -1702,7 +1910,7 @@ func registerOnHitEOF(rc io.ReadCloser, fn func()) { // requestBodyRemains reports whether future calls to Read // on rc might yield more data. func requestBodyRemains(rc io.ReadCloser) bool { - if rc == eofReader { + if rc == NoBody { return false } switch v := rc.(type) { @@ -1816,7 +2024,7 @@ func Redirect(w ResponseWriter, r *Request, urlStr string, code int) { } } - w.Header().Set("Location", urlStr) + w.Header().Set("Location", hexEscapeNonASCII(urlStr)) w.WriteHeader(code) // RFC 2616 recommends that a short note "SHOULD" be included in the @@ -2094,11 +2302,36 @@ func Serve(l net.Listener, handler Handler) error { // A Server defines parameters for running an HTTP server. // The zero value for Server is a valid configuration. type Server struct { - Addr string // TCP address to listen on, ":http" if empty - Handler Handler // handler to invoke, http.DefaultServeMux if nil - ReadTimeout time.Duration // maximum duration before timing out read of the request - WriteTimeout time.Duration // maximum duration before timing out write of the response - TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS + Addr string // TCP address to listen on, ":http" if empty + Handler Handler // handler to invoke, http.DefaultServeMux if nil + TLSConfig *tls.Config // optional TLS config, used by ListenAndServeTLS + + // ReadTimeout is the maximum duration for reading the entire + // request, including the body. + // + // Because ReadTimeout does not let Handlers make per-request + // decisions on each request body's acceptable deadline or + // upload rate, most users will prefer to use + // ReadHeaderTimeout. It is valid to use them both. + ReadTimeout time.Duration + + // ReadHeaderTimeout is the amount of time allowed to read + // request headers. 
The connection's read deadline is reset + // after reading the headers and the Handler can decide what + // is considered too slow for the body. + ReadHeaderTimeout time.Duration + + // WriteTimeout is the maximum duration before timing out + // writes of the response. It is reset whenever a new + // request's header is read. Like ReadTimeout, it does not + // let Handlers make decisions on a per-request basis. + WriteTimeout time.Duration + + // IdleTimeout is the maximum amount of time to wait for the + // next request when keep-alives are enabled. If IdleTimeout + // is zero, the value of ReadTimeout is used. If both are + // zero, there is no timeout. + IdleTimeout time.Duration // MaxHeaderBytes controls the maximum number of bytes the // server will read parsing the request header's keys and @@ -2114,7 +2347,8 @@ type Server struct { // handle HTTP requests and will initialize the Request's TLS // and RemoteAddr if not already set. The connection is // automatically closed when the function returns. - // If TLSNextProto is nil, HTTP/2 support is enabled automatically. + // If TLSNextProto is not nil, HTTP/2 support is not enabled + // automatically. TLSNextProto map[string]func(*Server, *tls.Conn, Handler) // ConnState specifies an optional callback function that is @@ -2129,8 +2363,132 @@ type Server struct { ErrorLog *log.Logger disableKeepAlives int32 // accessed atomically. 
+ inShutdown int32 // accessed atomically (non-zero means we're in Shutdown) nextProtoOnce sync.Once // guards setupHTTP2_* init nextProtoErr error // result of http2.ConfigureServer if used + + mu sync.Mutex + listeners map[net.Listener]struct{} + activeConn map[*conn]struct{} + doneChan chan struct{} +} + +func (s *Server) getDoneChan() <-chan struct{} { + s.mu.Lock() + defer s.mu.Unlock() + return s.getDoneChanLocked() +} + +func (s *Server) getDoneChanLocked() chan struct{} { + if s.doneChan == nil { + s.doneChan = make(chan struct{}) + } + return s.doneChan +} + +func (s *Server) closeDoneChanLocked() { + ch := s.getDoneChanLocked() + select { + case <-ch: + // Already closed. Don't close again. + default: + // Safe to close here. We're the only closer, guarded + // by s.mu. + close(ch) + } +} + +// Close immediately closes all active net.Listeners and any +// connections in state StateNew, StateActive, or StateIdle. For a +// graceful shutdown, use Shutdown. +// +// Close does not attempt to close (and does not even know about) +// any hijacked connections, such as WebSockets. +// +// Close returns any error returned from closing the Server's +// underlying Listener(s). +func (srv *Server) Close() error { + srv.mu.Lock() + defer srv.mu.Unlock() + srv.closeDoneChanLocked() + err := srv.closeListenersLocked() + for c := range srv.activeConn { + c.rwc.Close() + delete(srv.activeConn, c) + } + return err +} + +// shutdownPollInterval is how often we poll for quiescence +// during Server.Shutdown. This is lower during tests, to +// speed up tests. +// Ideally we could find a solution that doesn't involve polling, +// but which also doesn't have a high runtime cost (and doesn't +// involve any contentious mutexes), but that is left as an +// exercise for the reader. +var shutdownPollInterval = 500 * time.Millisecond + +// Shutdown gracefully shuts down the server without interrupting any +// active connections. 
Shutdown works by first closing all open +// listeners, then closing all idle connections, and then waiting +// indefinitely for connections to return to idle and then shut down. +// If the provided context expires before the shutdown is complete, +// then the context's error is returned. +// +// Shutdown does not attempt to close nor wait for hijacked +// connections such as WebSockets. The caller of Shutdown should +// separately notify such long-lived connections of shutdown and wait +// for them to close, if desired. +func (srv *Server) Shutdown(ctx context.Context) error { + atomic.AddInt32(&srv.inShutdown, 1) + defer atomic.AddInt32(&srv.inShutdown, -1) + + srv.mu.Lock() + lnerr := srv.closeListenersLocked() + srv.closeDoneChanLocked() + srv.mu.Unlock() + + ticker := time.NewTicker(shutdownPollInterval) + defer ticker.Stop() + for { + if srv.closeIdleConns() { + return lnerr + } + select { + case <-ctx.Done(): + return ctx.Err() + case <-ticker.C: + } + } +} + +// closeIdleConns closes all idle connections and reports whether the +// server is quiescent. +func (s *Server) closeIdleConns() bool { + s.mu.Lock() + defer s.mu.Unlock() + quiescent := true + for c := range s.activeConn { + st, ok := c.curState.Load().(ConnState) + if !ok || st != StateIdle { + quiescent = false + continue + } + c.rwc.Close() + delete(s.activeConn, c) + } + return quiescent +} + +func (s *Server) closeListenersLocked() error { + var err error + for ln := range s.listeners { + if cerr := ln.Close(); cerr != nil && err == nil { + err = cerr + } + delete(s.listeners, ln) + } + return err } // A ConnState represents the state of a client connection to a server. @@ -2243,6 +2601,8 @@ func (srv *Server) shouldConfigureHTTP2ForServe() bool { return strSliceContains(srv.TLSConfig.NextProtos, http2NextProtoTLS) } +var ErrServerClosed = errors.New("http: Server closed") + // Serve accepts incoming connections on the Listener l, creating a // new service goroutine for each. 
The service goroutines read requests and // then call srv.Handler to reply to them. @@ -2252,7 +2612,8 @@ func (srv *Server) shouldConfigureHTTP2ForServe() bool { // srv.TLSConfig is non-nil and doesn't include the string "h2" in // Config.NextProtos, HTTP/2 support is not enabled. // -// Serve always returns a non-nil error. +// Serve always returns a non-nil error. After Shutdown or Close, the +// returned error is ErrServerClosed. func (srv *Server) Serve(l net.Listener) error { defer l.Close() if fn := testHookServerServe; fn != nil { @@ -2264,14 +2625,20 @@ func (srv *Server) Serve(l net.Listener) error { return err } - // TODO: allow changing base context? can't imagine concrete - // use cases yet. - baseCtx := context.Background() + srv.trackListener(l, true) + defer srv.trackListener(l, false) + + baseCtx := context.Background() // base is always background, per Issue 16220 ctx := context.WithValue(baseCtx, ServerContextKey, srv) ctx = context.WithValue(ctx, LocalAddrContextKey, l.Addr()) for { rw, e := l.Accept() if e != nil { + select { + case <-srv.getDoneChan(): + return ErrServerClosed + default: + } if ne, ok := e.(net.Error); ok && ne.Temporary() { if tempDelay == 0 { tempDelay = 5 * time.Millisecond @@ -2294,8 +2661,57 @@ func (srv *Server) Serve(l net.Listener) error { } } +func (s *Server) trackListener(ln net.Listener, add bool) { + s.mu.Lock() + defer s.mu.Unlock() + if s.listeners == nil { + s.listeners = make(map[net.Listener]struct{}) + } + if add { + // If the *Server is being reused after a previous + // Close or Shutdown, reset its doneChan: + if len(s.listeners) == 0 && len(s.activeConn) == 0 { + s.doneChan = nil + } + s.listeners[ln] = struct{}{} + } else { + delete(s.listeners, ln) + } +} + +func (s *Server) trackConn(c *conn, add bool) { + s.mu.Lock() + defer s.mu.Unlock() + if s.activeConn == nil { + s.activeConn = make(map[*conn]struct{}) + } + if add { + s.activeConn[c] = struct{}{} + } else { + delete(s.activeConn, c) + } +} + 
+func (s *Server) idleTimeout() time.Duration { + if s.IdleTimeout != 0 { + return s.IdleTimeout + } + return s.ReadTimeout +} + +func (s *Server) readHeaderTimeout() time.Duration { + if s.ReadHeaderTimeout != 0 { + return s.ReadHeaderTimeout + } + return s.ReadTimeout +} + func (s *Server) doKeepAlives() bool { - return atomic.LoadInt32(&s.disableKeepAlives) == 0 + return atomic.LoadInt32(&s.disableKeepAlives) == 0 && !s.shuttingDown() +} + +func (s *Server) shuttingDown() bool { + return atomic.LoadInt32(&s.inShutdown) != 0 } // SetKeepAlivesEnabled controls whether HTTP keep-alives are enabled. @@ -2305,9 +2721,21 @@ func (s *Server) doKeepAlives() bool { func (srv *Server) SetKeepAlivesEnabled(v bool) { if v { atomic.StoreInt32(&srv.disableKeepAlives, 0) - } else { - atomic.StoreInt32(&srv.disableKeepAlives, 1) + return } + atomic.StoreInt32(&srv.disableKeepAlives, 1) + + // Close idle HTTP/1 conns: + srv.closeIdleConns() + + // Close HTTP/2 conns, as soon as they become idle, but reset + // the chan so future conns (if the listener is still active) + // still work and don't get a GOAWAY immediately, before their + // first request: + srv.mu.Lock() + defer srv.mu.Unlock() + srv.closeDoneChanLocked() // closes http2 conns + srv.doneChan = nil } func (s *Server) logf(format string, args ...interface{}) { @@ -2630,24 +3058,6 @@ func (globalOptionsHandler) ServeHTTP(w ResponseWriter, r *Request) { } } -type eofReaderWithWriteTo struct{} - -func (eofReaderWithWriteTo) WriteTo(io.Writer) (int64, error) { return 0, nil } -func (eofReaderWithWriteTo) Read([]byte) (int, error) { return 0, io.EOF } - -// eofReader is a non-nil io.ReadCloser that always returns EOF. -// It has a WriteTo method so io.Copy won't need a buffer. -var eofReader = &struct { - eofReaderWithWriteTo - io.Closer -}{ - eofReaderWithWriteTo{}, - ioutil.NopCloser(nil), -} - -// Verify that an io.Copy from an eofReader won't require a buffer. 
-var _ io.WriterTo = eofReader - // initNPNRequest is an HTTP handler that initializes certain // uninitialized fields in its *Request. Such partially-initialized // Requests come from NPN protocol handlers. @@ -2662,7 +3072,7 @@ func (h initNPNRequest) ServeHTTP(rw ResponseWriter, req *Request) { *req.TLS = h.c.ConnectionState() } if req.Body == nil { - req.Body = eofReader + req.Body = NoBody } if req.RemoteAddr == "" { req.RemoteAddr = h.c.RemoteAddr().String() @@ -2723,6 +3133,7 @@ func (w checkConnErrorWriter) Write(p []byte) (n int, err error) { n, err = w.c.rwc.Write(p) if err != nil && w.c.werr == nil { w.c.werr = err + w.c.cancelCtx() } return } diff --git a/libgo/go/net/http/sniff_test.go b/libgo/go/net/http/sniff_test.go index ac404bfa723..38f3f8197e9 100644 --- a/libgo/go/net/http/sniff_test.go +++ b/libgo/go/net/http/sniff_test.go @@ -66,6 +66,7 @@ func TestServerContentType_h1(t *testing.T) { testServerContentType(t, h1Mode) } func TestServerContentType_h2(t *testing.T) { testServerContentType(t, h2Mode) } func testServerContentType(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { i, _ := strconv.Atoi(r.FormValue("i")) @@ -160,6 +161,7 @@ func testContentTypeWithCopy(t *testing.T, h2 bool) { func TestSniffWriteSize_h1(t *testing.T) { testSniffWriteSize(t, h1Mode) } func TestSniffWriteSize_h2(t *testing.T) { testSniffWriteSize(t, h2Mode) } func testSniffWriteSize(t *testing.T, h2 bool) { + setParallel(t) defer afterTest(t) cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { size, _ := strconv.Atoi(r.FormValue("size")) diff --git a/libgo/go/net/http/transfer.go b/libgo/go/net/http/transfer.go index c653467098c..4f47637aa76 100644 --- a/libgo/go/net/http/transfer.go +++ b/libgo/go/net/http/transfer.go @@ -17,6 +17,7 @@ import ( "strconv" "strings" "sync" + "time" "golang_org/x/net/lex/httplex" ) @@ -33,6 +34,23 @@ func (r 
errorReader) Read(p []byte) (n int, err error) { return 0, r.err } +type byteReader struct { + b byte + done bool +} + +func (br *byteReader) Read(p []byte) (n int, err error) { + if br.done { + return 0, io.EOF + } + if len(p) == 0 { + return 0, nil + } + br.done = true + p[0] = br.b + return 1, io.EOF +} + // transferWriter inspects the fields of a user-supplied Request or Response, // sanitizes them without changing the user object and provides methods for // writing the respective header, body and trailer in wire format. @@ -46,6 +64,9 @@ type transferWriter struct { TransferEncoding []string Trailer Header IsResponse bool + + FlushHeaders bool // flush headers to network before body + ByteReadCh chan readResult // non-nil if probeRequestBody called } func newTransferWriter(r interface{}) (t *transferWriter, err error) { @@ -59,37 +80,15 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) { return nil, fmt.Errorf("http: Request.ContentLength=%d with nil Body", rr.ContentLength) } t.Method = valueOrDefault(rr.Method, "GET") - t.Body = rr.Body - t.BodyCloser = rr.Body - t.ContentLength = rr.ContentLength t.Close = rr.Close t.TransferEncoding = rr.TransferEncoding t.Trailer = rr.Trailer - atLeastHTTP11 = rr.ProtoAtLeast(1, 1) - if t.Body != nil && len(t.TransferEncoding) == 0 && atLeastHTTP11 { - if t.ContentLength == 0 { - // Test to see if it's actually zero or just unset. - var buf [1]byte - n, rerr := io.ReadFull(t.Body, buf[:]) - if rerr != nil && rerr != io.EOF { - t.ContentLength = -1 - t.Body = errorReader{rerr} - } else if n == 1 { - // Oh, guess there is data in this Body Reader after all. - // The ContentLength field just wasn't set. - // Stich the Body back together again, re-attaching our - // consumed byte. - t.ContentLength = -1 - t.Body = io.MultiReader(bytes.NewReader(buf[:]), t.Body) - } else { - // Body is actually empty. 
- t.Body = nil - t.BodyCloser = nil - } - } - if t.ContentLength < 0 { - t.TransferEncoding = []string{"chunked"} - } + atLeastHTTP11 = rr.protoAtLeastOutgoing(1, 1) + t.Body = rr.Body + t.BodyCloser = rr.Body + t.ContentLength = rr.outgoingLength() + if t.ContentLength < 0 && len(t.TransferEncoding) == 0 && atLeastHTTP11 && t.shouldSendChunkedRequestBody() { + t.TransferEncoding = []string{"chunked"} } case *Response: t.IsResponse = true @@ -103,7 +102,7 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) { t.TransferEncoding = rr.TransferEncoding t.Trailer = rr.Trailer atLeastHTTP11 = rr.ProtoAtLeast(1, 1) - t.ResponseToHEAD = noBodyExpected(t.Method) + t.ResponseToHEAD = noResponseBodyExpected(t.Method) } // Sanitize Body,ContentLength,TransferEncoding @@ -131,7 +130,100 @@ func newTransferWriter(r interface{}) (t *transferWriter, err error) { return t, nil } -func noBodyExpected(requestMethod string) bool { +// shouldSendChunkedRequestBody reports whether we should try to send a +// chunked request body to the server. In particular, the case we really +// want to prevent is sending a GET or other typically-bodyless request to a +// server with a chunked body when the body has zero bytes, since GETs with +// bodies (while acceptable according to specs), even zero-byte chunked +// bodies, are approximately never seen in the wild and confuse most +// servers. See Issue 18257, as one example. +// +// The only reason we'd send such a request is if the user set the Body to a +// non-nil value (say, ioutil.NopCloser(bytes.NewReader(nil))) and didn't +// set ContentLength, or NewRequest set it to -1 (unknown), so then we assume +// there's bytes to send. +// +// This code tries to read a byte from the Request.Body in such cases to see +// whether the body actually has content (super rare) or is actually just +// a non-nil content-less ReadCloser (the more common case). 
In that more +// common case, we act as if their Body were nil instead, and don't send +// a body. +func (t *transferWriter) shouldSendChunkedRequestBody() bool { + // Note that t.ContentLength is the corrected content length + // from rr.outgoingLength, so 0 actually means zero, not unknown. + if t.ContentLength >= 0 || t.Body == nil { // redundant checks; caller did them + return false + } + if requestMethodUsuallyLacksBody(t.Method) { + // Only probe the Request.Body for GET/HEAD/DELETE/etc + // requests, because it's only those types of requests + // that confuse servers. + t.probeRequestBody() // adjusts t.Body, t.ContentLength + return t.Body != nil + } + // For all other request types (PUT, POST, PATCH, or anything + // made-up we've never heard of), assume it's normal and the server + // can deal with a chunked request body. Maybe we'll adjust this + // later. + return true +} + +// probeRequestBody reads a byte from t.Body to see whether it's empty +// (returns io.EOF right away). +// +// But because we've had problems with this blocking users in the past +// (issue 17480) when the body is a pipe (perhaps waiting on the response +// headers before the pipe is fed data), we need to be careful and bound how +// long we wait for it. This delay will only affect users if all the following +// are true: +// * the request body blocks +// * the content length is not set (or set to -1) +// * the method doesn't usually have a body (GET, HEAD, DELETE, ...) +// * there is no transfer-encoding=chunked already set. +// In other words, this delay will not normally affect anybody, and there +// are workarounds if it does. 
+func (t *transferWriter) probeRequestBody() { + t.ByteReadCh = make(chan readResult, 1) + go func(body io.Reader) { + var buf [1]byte + var rres readResult + rres.n, rres.err = body.Read(buf[:]) + if rres.n == 1 { + rres.b = buf[0] + } + t.ByteReadCh <- rres + }(t.Body) + timer := time.NewTimer(200 * time.Millisecond) + select { + case rres := <-t.ByteReadCh: + timer.Stop() + if rres.n == 0 && rres.err == io.EOF { + // It was empty. + t.Body = nil + t.ContentLength = 0 + } else if rres.n == 1 { + if rres.err != nil { + t.Body = io.MultiReader(&byteReader{b: rres.b}, errorReader{rres.err}) + } else { + t.Body = io.MultiReader(&byteReader{b: rres.b}, t.Body) + } + } else if rres.err != nil { + t.Body = errorReader{rres.err} + } + case <-timer.C: + // Too slow. Don't wait. Read it later, and keep + // assuming that this is ContentLength == -1 + // (unknown), which means we'll send a + // "Transfer-Encoding: chunked" header. + t.Body = io.MultiReader(finishAsyncByteRead{t}, t.Body) + // Request that Request.Write flush the headers to the + // network before writing the body, since our body may not + // become readable until it's seen the response headers. 
+ t.FlushHeaders = true + } +} + +func noResponseBodyExpected(requestMethod string) bool { return requestMethod == "HEAD" } @@ -214,7 +306,7 @@ func (t *transferWriter) WriteBody(w io.Writer) error { if t.Body != nil { if chunked(t.TransferEncoding) { if bw, ok := w.(*bufio.Writer); ok && !t.IsResponse { - w = &internal.FlushAfterChunkWriter{bw} + w = &internal.FlushAfterChunkWriter{Writer: bw} } cw := internal.NewChunkedWriter(w) _, err = io.Copy(cw, t.Body) @@ -235,7 +327,9 @@ func (t *transferWriter) WriteBody(w io.Writer) error { if err != nil { return err } - if err = t.BodyCloser.Close(); err != nil { + } + if t.BodyCloser != nil { + if err := t.BodyCloser.Close(); err != nil { return err } } @@ -385,13 +479,13 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { // or close connection when finished, since multipart is not supported yet switch { case chunked(t.TransferEncoding): - if noBodyExpected(t.RequestMethod) { - t.Body = eofReader + if noResponseBodyExpected(t.RequestMethod) { + t.Body = NoBody } else { t.Body = &body{src: internal.NewChunkedReader(r), hdr: msg, r: r, closing: t.Close} } case realLength == 0: - t.Body = eofReader + t.Body = NoBody case realLength > 0: t.Body = &body{src: io.LimitReader(r, realLength), closing: t.Close} default: @@ -401,7 +495,7 @@ func readTransfer(msg interface{}, r *bufio.Reader) (err error) { t.Body = &body{src: r, closing: t.Close} } else { // Persistent connection (i.e. HTTP/1.1) - t.Body = eofReader + t.Body = NoBody } } @@ -493,10 +587,31 @@ func (t *transferReader) fixTransferEncoding() error { // function is not a method, because ultimately it should be shared by // ReadResponse and ReadRequest. 
func fixLength(isResponse bool, status int, requestMethod string, header Header, te []string) (int64, error) { - contentLens := header["Content-Length"] isRequest := !isResponse + contentLens := header["Content-Length"] + + // Hardening against HTTP request smuggling + if len(contentLens) > 1 { + // Per RFC 7230 Section 3.3.2, prevent multiple + // Content-Length headers if they differ in value. + // If there are dups of the value, remove the dups. + // See Issue 16490. + first := strings.TrimSpace(contentLens[0]) + for _, ct := range contentLens[1:] { + if first != strings.TrimSpace(ct) { + return 0, fmt.Errorf("http: message cannot contain multiple Content-Length headers; got %q", contentLens) + } + } + + // deduplicate Content-Length + header.Del("Content-Length") + header.Add("Content-Length", first) + + contentLens = header["Content-Length"] + } + // Logic based on response type or status - if noBodyExpected(requestMethod) { + if noResponseBodyExpected(requestMethod) { // For HTTP requests, as part of hardening against request // smuggling (RFC 7230), don't allow a Content-Length header for // methods which don't permit bodies. As an exception, allow @@ -514,11 +629,6 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header, return 0, nil } - if len(contentLens) > 1 { - // harden against HTTP request smuggling. See RFC 7230. 
- return 0, errors.New("http: message cannot contain multiple Content-Length headers") - } - // Logic based on Transfer-Encoding if chunked(te) { return -1, nil @@ -539,7 +649,7 @@ func fixLength(isResponse bool, status int, requestMethod string, header Header, header.Del("Content-Length") } - if !isResponse { + if isRequest { // RFC 2616 neither explicitly permits nor forbids an // entity-body on a GET request so we permit one if // declared, but we default to 0 here (not -1 below) @@ -864,3 +974,21 @@ func parseContentLength(cl string) (int64, error) { return n, nil } + +// finishAsyncByteRead finishes reading the 1-byte sniff +// from the ContentLength==0, Body!=nil case. +type finishAsyncByteRead struct { + tw *transferWriter +} + +func (fr finishAsyncByteRead) Read(p []byte) (n int, err error) { + if len(p) == 0 { + return + } + rres := <-fr.tw.ByteReadCh + n, err = rres.n, rres.err + if n == 1 { + p[0] = rres.b + } + return +} diff --git a/libgo/go/net/http/transport.go b/libgo/go/net/http/transport.go index 1f0763471b8..571943d6e5c 100644 --- a/libgo/go/net/http/transport.go +++ b/libgo/go/net/http/transport.go @@ -25,6 +25,7 @@ import ( "os" "strings" "sync" + "sync/atomic" "time" "golang_org/x/net/lex/httplex" @@ -40,6 +41,7 @@ var DefaultTransport RoundTripper = &Transport{ DialContext: (&net.Dialer{ Timeout: 30 * time.Second, KeepAlive: 30 * time.Second, + DualStack: true, }).DialContext, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, @@ -66,8 +68,10 @@ const DefaultMaxIdleConnsPerHost = 2 // For high-level functionality, such as cookies and redirects, see Client. // // Transport uses HTTP/1.1 for HTTP URLs and either HTTP/1.1 or HTTP/2 -// for HTTPS URLs, depending on whether the server supports HTTP/2. -// See the package docs for more about HTTP/2. +// for HTTPS URLs, depending on whether the server supports HTTP/2, +// and how the Transport is configured. The DefaultTransport supports HTTP/2. 
+// To explicitly enable HTTP/2 on a transport, use golang.org/x/net/http2 +// and call ConfigureTransport. See the package docs for more about HTTP/2. type Transport struct { idleMu sync.Mutex wantIdle bool // user has requested to close all idle conns @@ -76,10 +80,10 @@ type Transport struct { idleLRU connLRU reqMu sync.Mutex - reqCanceler map[*Request]func() + reqCanceler map[*Request]func(error) - altMu sync.RWMutex - altProto map[string]RoundTripper // nil or map of URI scheme => RoundTripper + altMu sync.Mutex // guards changing altProto only + altProto atomic.Value // of nil or map[string]RoundTripper, key is URI scheme // Proxy specifies a function to return a proxy for a given // Request. If the function returns a non-nil error, the @@ -111,7 +115,9 @@ type Transport struct { DialTLS func(network, addr string) (net.Conn, error) // TLSClientConfig specifies the TLS configuration to use with - // tls.Client. If nil, the default configuration is used. + // tls.Client. + // If nil, the default configuration is used. + // If non-nil, HTTP/2 support may not be enabled by default. TLSClientConfig *tls.Config // TLSHandshakeTimeout specifies the maximum amount of time waiting to @@ -156,7 +162,9 @@ type Transport struct { // ExpectContinueTimeout, if non-zero, specifies the amount of // time to wait for a server's first response headers after fully // writing the request headers if the request has an - // "Expect: 100-continue" header. Zero means no timeout. + // "Expect: 100-continue" header. Zero means no timeout and + // causes the body to be sent immediately, without + // waiting for the server to approve. // This time does not include the time to send the request header. ExpectContinueTimeout time.Duration @@ -168,9 +176,14 @@ type Transport struct { // called with the request's authority (such as "example.com" // or "example.com:1234") and the TLS connection. The function // must return a RoundTripper that then handles the request. 
- // If TLSNextProto is nil, HTTP/2 support is enabled automatically. + // If TLSNextProto is not nil, HTTP/2 support is not enabled + // automatically. TLSNextProto map[string]func(authority string, c *tls.Conn) RoundTripper + // ProxyConnectHeader optionally specifies headers to send to + // proxies during CONNECT requests. + ProxyConnectHeader Header + // MaxResponseHeaderBytes specifies a limit on how many // response bytes are allowed in the server's response // header. @@ -330,11 +343,9 @@ func (t *Transport) RoundTrip(req *Request) (*Response, error) { } } } - // TODO(bradfitz): switch to atomic.Value for this map instead of RWMutex - t.altMu.RLock() - altRT := t.altProto[scheme] - t.altMu.RUnlock() - if altRT != nil { + + altProto, _ := t.altProto.Load().(map[string]RoundTripper) + if altRT := altProto[scheme]; altRT != nil { if resp, err := altRT.RoundTrip(req); err != ErrSkipAltProtocol { return resp, err } @@ -421,19 +432,15 @@ func (pc *persistConn) shouldRetryRequest(req *Request, err error) bool { // our request (as opposed to sending an error). return false } + if _, ok := err.(nothingWrittenError); ok { + // We never wrote anything, so it's safe to retry. + return true + } if !req.isReplayable() { // Don't retry non-idempotent requests. - - // TODO: swap the nothingWrittenError and isReplayable checks, - // putting the "if nothingWrittenError => return true" case - // first, per golang.org/issue/15723 return false } - switch err.(type) { - case nothingWrittenError: - // We never wrote anything, so it's safe to retry. - return true - case transportReadFromServerError: + if _, ok := err.(transportReadFromServerError); ok { // We got some non-EOF net.Conn.Read failure reading // the 1st response byte from the server. 
return true @@ -463,13 +470,16 @@ var ErrSkipAltProtocol = errors.New("net/http: skip alternate protocol") func (t *Transport) RegisterProtocol(scheme string, rt RoundTripper) { t.altMu.Lock() defer t.altMu.Unlock() - if t.altProto == nil { - t.altProto = make(map[string]RoundTripper) - } - if _, exists := t.altProto[scheme]; exists { + oldMap, _ := t.altProto.Load().(map[string]RoundTripper) + if _, exists := oldMap[scheme]; exists { panic("protocol " + scheme + " already registered") } - t.altProto[scheme] = rt + newMap := make(map[string]RoundTripper) + for k, v := range oldMap { + newMap[k] = v + } + newMap[scheme] = rt + t.altProto.Store(newMap) } // CloseIdleConnections closes any connections which were previously @@ -502,12 +512,17 @@ func (t *Transport) CloseIdleConnections() { // cancelable context instead. CancelRequest cannot cancel HTTP/2 // requests. func (t *Transport) CancelRequest(req *Request) { + t.cancelRequest(req, errRequestCanceled) +} + +// Cancel an in-flight request, recording the error value. 
+func (t *Transport) cancelRequest(req *Request, err error) { t.reqMu.Lock() cancel := t.reqCanceler[req] delete(t.reqCanceler, req) t.reqMu.Unlock() if cancel != nil { - cancel() + cancel(err) } } @@ -557,10 +572,18 @@ func (e *envOnce) reset() { } func (t *Transport) connectMethodForRequest(treq *transportRequest) (cm connectMethod, err error) { + if port := treq.URL.Port(); !validPort(port) { + return cm, fmt.Errorf("invalid URL port %q", port) + } cm.targetScheme = treq.URL.Scheme cm.targetAddr = canonicalAddr(treq.URL) if t.Proxy != nil { cm.proxyURL, err = t.Proxy(treq.Request) + if err == nil && cm.proxyURL != nil { + if port := cm.proxyURL.Port(); !validPort(port) { + return cm, fmt.Errorf("invalid proxy URL port %q", port) + } + } } return cm, err } @@ -787,11 +810,11 @@ func (t *Transport) removeIdleConnLocked(pconn *persistConn) { } } -func (t *Transport) setReqCanceler(r *Request, fn func()) { +func (t *Transport) setReqCanceler(r *Request, fn func(error)) { t.reqMu.Lock() defer t.reqMu.Unlock() if t.reqCanceler == nil { - t.reqCanceler = make(map[*Request]func()) + t.reqCanceler = make(map[*Request]func(error)) } if fn != nil { t.reqCanceler[r] = fn @@ -804,7 +827,7 @@ func (t *Transport) setReqCanceler(r *Request, fn func()) { // for the request, we don't set the function and return false. // Since CancelRequest will clear the canceler, we can use the return value to detect if // the request was canceled since the last setReqCancel call. 
-func (t *Transport) replaceReqCanceler(r *Request, fn func()) bool { +func (t *Transport) replaceReqCanceler(r *Request, fn func(error)) bool { t.reqMu.Lock() defer t.reqMu.Unlock() _, ok := t.reqCanceler[r] @@ -853,7 +876,7 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC // set request canceler to some non-nil function so we // can detect whether it was cleared between now and when // we enter roundTrip - t.setReqCanceler(req, func() {}) + t.setReqCanceler(req, func(error) {}) return pc, nil } @@ -878,8 +901,8 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC }() } - cancelc := make(chan struct{}) - t.setReqCanceler(req, func() { close(cancelc) }) + cancelc := make(chan error, 1) + t.setReqCanceler(req, func(err error) { cancelc <- err }) go func() { pc, err := t.dialConn(ctx, cm) @@ -900,16 +923,21 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC // value. select { case <-req.Cancel: + // It was an error due to cancelation, so prioritize that + // error value. (Issue 16049) + return nil, errRequestCanceledConn case <-req.Context().Done(): - case <-cancelc: + return nil, req.Context().Err() + case err := <-cancelc: + if err == errRequestCanceled { + err = errRequestCanceledConn + } + return nil, err default: // It wasn't an error due to cancelation, so // return the original error message: return nil, v.err } - // It was an error due to cancelation, so prioritize that - // error value. (Issue 16049) - return nil, errRequestCanceledConn case pc := <-idleConnCh: // Another request finished first and its net.Conn // became available before our dial. 
Or somebody @@ -926,10 +954,13 @@ func (t *Transport) getConn(treq *transportRequest, cm connectMethod) (*persistC return nil, errRequestCanceledConn case <-req.Context().Done(): handlePendingDial() - return nil, errRequestCanceledConn - case <-cancelc: + return nil, req.Context().Err() + case err := <-cancelc: handlePendingDial() - return nil, errRequestCanceledConn + if err == errRequestCanceled { + err = errRequestCanceledConn + } + return nil, err } } @@ -943,6 +974,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon writeErrCh: make(chan error, 1), writeLoopDone: make(chan struct{}), } + trace := httptrace.ContextClientTrace(ctx) tlsDial := t.DialTLS != nil && cm.targetScheme == "https" && cm.proxyURL == nil if tlsDial { var err error @@ -956,18 +988,28 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon if tc, ok := pconn.conn.(*tls.Conn); ok { // Handshake here, in case DialTLS didn't. TLSNextProto below // depends on it for knowing the connection state. 
+ if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } if err := tc.Handshake(); err != nil { go pconn.conn.Close() + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tls.ConnectionState{}, err) + } return nil, err } cs := tc.ConnectionState() + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(cs, nil) + } pconn.tlsState = &cs } } else { conn, err := t.dial(ctx, "tcp", cm.addr()) if err != nil { if cm.proxyURL != nil { - err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err) + // Return a typed error, per Issue 16997: + err = &net.OpError{Op: "proxyconnect", Net: "tcp", Err: err} } return nil, err } @@ -987,11 +1029,15 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon } case cm.targetScheme == "https": conn := pconn.conn + hdr := t.ProxyConnectHeader + if hdr == nil { + hdr = make(Header) + } connectReq := &Request{ Method: "CONNECT", URL: &url.URL{Opaque: cm.targetAddr}, Host: cm.targetAddr, - Header: make(Header), + Header: hdr, } if pa := cm.proxyAuth(); pa != "" { connectReq.Header.Set("Proxy-Authorization", pa) @@ -1016,7 +1062,7 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon if cm.targetScheme == "https" && !tlsDial { // Initiate TLS and check remote host name against certificate. 
- cfg := cloneTLSClientConfig(t.TLSClientConfig) + cfg := cloneTLSConfig(t.TLSClientConfig) if cfg.ServerName == "" { cfg.ServerName = cm.tlsHost() } @@ -1030,6 +1076,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon }) } go func() { + if trace != nil && trace.TLSHandshakeStart != nil { + trace.TLSHandshakeStart() + } err := tlsConn.Handshake() if timer != nil { timer.Stop() @@ -1038,6 +1087,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon }() if err := <-errc; err != nil { plainConn.Close() + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(tls.ConnectionState{}, err) + } return nil, err } if !cfg.InsecureSkipVerify { @@ -1047,6 +1099,9 @@ func (t *Transport) dialConn(ctx context.Context, cm connectMethod) (*persistCon } } cs := tlsConn.ConnectionState() + if trace != nil && trace.TLSHandshakeDone != nil { + trace.TLSHandshakeDone(cs, nil) + } pconn.tlsState = &cs pconn.conn = tlsConn } @@ -1235,8 +1290,8 @@ type persistConn struct { mu sync.Mutex // guards following fields numExpectedResponses int closed error // set non-nil when conn is closed, before closech is closed + canceledErr error // set non-nil if conn is canceled broken bool // an error has happened on this connection; marked broken so it's not reused. - canceled bool // whether this conn was broken due a CancelRequest reused bool // whether conn has had successful request/response and is being reused. // mutateHeaderFunc is an optional func to modify extra // headers on each outbound request before it's written. (the @@ -1274,11 +1329,12 @@ func (pc *persistConn) isBroken() bool { return b } -// isCanceled reports whether this connection was closed due to CancelRequest. -func (pc *persistConn) isCanceled() bool { +// canceled returns non-nil if the connection was closed due to +// CancelRequest or due to context cancelation. 
+func (pc *persistConn) canceled() error { pc.mu.Lock() defer pc.mu.Unlock() - return pc.canceled + return pc.canceledErr } // isReused reports whether this connection is in a known broken state. @@ -1301,10 +1357,10 @@ func (pc *persistConn) gotIdleConnTrace(idleAt time.Time) (t httptrace.GotConnIn return } -func (pc *persistConn) cancelRequest() { +func (pc *persistConn) cancelRequest(err error) { pc.mu.Lock() defer pc.mu.Unlock() - pc.canceled = true + pc.canceledErr = err pc.closeLocked(errRequestCanceled) } @@ -1328,12 +1384,12 @@ func (pc *persistConn) closeConnIfStillIdle() { // // The startBytesWritten value should be the value of pc.nwrite before the roundTrip // started writing the request. -func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, err error) (out error) { +func (pc *persistConn) mapRoundTripErrorFromReadLoop(req *Request, startBytesWritten int64, err error) (out error) { if err == nil { return nil } - if pc.isCanceled() { - return errRequestCanceled + if err := pc.canceled(); err != nil { + return err } if err == errServerClosedIdle { return err @@ -1343,7 +1399,7 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, er } if pc.isBroken() { <-pc.writeLoopDone - if pc.nwrite == startBytesWritten { + if pc.nwrite == startBytesWritten && req.outgoingLength() == 0 { return nothingWrittenError{err} } } @@ -1354,9 +1410,9 @@ func (pc *persistConn) mapRoundTripErrorFromReadLoop(startBytesWritten int64, er // up to Transport.RoundTrip method when persistConn.roundTrip sees // its pc.closech channel close, indicating the persistConn is dead. // (after closech is closed, pc.closed is valid). 
-func (pc *persistConn) mapRoundTripErrorAfterClosed(startBytesWritten int64) error { - if pc.isCanceled() { - return errRequestCanceled +func (pc *persistConn) mapRoundTripErrorAfterClosed(req *Request, startBytesWritten int64) error { + if err := pc.canceled(); err != nil { + return err } err := pc.closed if err == errServerClosedIdle { @@ -1372,7 +1428,7 @@ func (pc *persistConn) mapRoundTripErrorAfterClosed(startBytesWritten int64) err // see if we actually managed to write anything. If not, we // can retry the request. <-pc.writeLoopDone - if pc.nwrite == startBytesWritten { + if pc.nwrite == startBytesWritten && req.outgoingLength() == 0 { return nothingWrittenError{err} } @@ -1513,8 +1569,10 @@ func (pc *persistConn) readLoop() { waitForBodyRead <- isEOF if isEOF { <-eofc // see comment above eofc declaration - } else if err != nil && pc.isCanceled() { - return errRequestCanceled + } else if err != nil { + if cerr := pc.canceled(); cerr != nil { + return cerr + } } return err }, @@ -1554,7 +1612,7 @@ func (pc *persistConn) readLoop() { pc.t.CancelRequest(rc.req) case <-rc.req.Context().Done(): alive = false - pc.t.CancelRequest(rc.req) + pc.t.cancelRequest(rc.req, rc.req.Context().Err()) case <-pc.closech: alive = false } @@ -1652,7 +1710,7 @@ func (pc *persistConn) writeLoop() { } if err != nil { wr.req.Request.closeBody() - if pc.nwrite == startBytesWritten { + if pc.nwrite == startBytesWritten && wr.req.outgoingLength() == 0 { err = nothingWrittenError{err} } } @@ -1840,8 +1898,8 @@ WaitResponse: select { case err := <-writeErrCh: if err != nil { - if pc.isCanceled() { - err = errRequestCanceled + if cerr := pc.canceled(); cerr != nil { + err = cerr } re = responseAndError{err: err} pc.close(fmt.Errorf("write error: %v", err)) @@ -1853,21 +1911,20 @@ WaitResponse: respHeaderTimer = timer.C } case <-pc.closech: - re = responseAndError{err: pc.mapRoundTripErrorAfterClosed(startBytesWritten)} + re = responseAndError{err: 
pc.mapRoundTripErrorAfterClosed(req.Request, startBytesWritten)} break WaitResponse case <-respHeaderTimer: pc.close(errTimeout) re = responseAndError{err: errTimeout} break WaitResponse case re = <-resc: re.err = pc.mapRoundTripErrorFromReadLoop(req.Request, startBytesWritten, re.err) break WaitResponse case <-cancelChan: pc.t.CancelRequest(req.Request) cancelChan = nil - ctxDoneChan = nil case <-ctxDoneChan: - pc.t.CancelRequest(req.Request) + pc.t.cancelRequest(req.Request, req.Context().Err()) cancelChan = nil ctxDoneChan = nil } @@ -1931,11 +1988,15 @@ var portMap = map[string]string{ // canonicalAddr returns url.Host but always with a ":port" suffix func canonicalAddr(url *url.URL) string { - addr := url.Host - if !hasPort(addr) { - return addr + ":" + portMap[url.Scheme] + addr := url.Hostname() + if v, err := idnaASCII(addr); err == nil { + addr = v + } + port := url.Port() + if port == "" { + port = portMap[url.Scheme] } - return addr + return net.JoinHostPort(addr, port) } // bodyEOFSignal is used by the HTTP/1 transport when reading response @@ -2060,75 +2121,14 @@ type fakeLocker struct{} func (fakeLocker) Lock() {} func (fakeLocker) Unlock() {} -// cloneTLSConfig returns a shallow clone of the exported -// fields of cfg, ignoring the unexported sync.Once, which -// contains a mutex and must not be copied. -// -// The cfg must not be in active use by tls.Server, or else -// there can still be a race with tls.Server updating SessionTicketKey -// and our copying it, and also a race with the server setting -// SessionTicketsDisabled=false on failure to set the random -// ticket key. -// -// If cfg is nil, a new zero tls.Config is returned. +// cloneTLSConfig returns a shallow clone of cfg, or a new zero tls.Config if +// cfg is nil. This is safe to call even if cfg is in active use by a TLS +// client or server. 
func cloneTLSConfig(cfg *tls.Config) *tls.Config { if cfg == nil { return &tls.Config{} } - return &tls.Config{ - Rand: cfg.Rand, - Time: cfg.Time, - Certificates: cfg.Certificates, - NameToCertificate: cfg.NameToCertificate, - GetCertificate: cfg.GetCertificate, - RootCAs: cfg.RootCAs, - NextProtos: cfg.NextProtos, - ServerName: cfg.ServerName, - ClientAuth: cfg.ClientAuth, - ClientCAs: cfg.ClientCAs, - InsecureSkipVerify: cfg.InsecureSkipVerify, - CipherSuites: cfg.CipherSuites, - PreferServerCipherSuites: cfg.PreferServerCipherSuites, - SessionTicketsDisabled: cfg.SessionTicketsDisabled, - SessionTicketKey: cfg.SessionTicketKey, - ClientSessionCache: cfg.ClientSessionCache, - MinVersion: cfg.MinVersion, - MaxVersion: cfg.MaxVersion, - CurvePreferences: cfg.CurvePreferences, - DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled, - Renegotiation: cfg.Renegotiation, - } -} - -// cloneTLSClientConfig is like cloneTLSConfig but omits -// the fields SessionTicketsDisabled and SessionTicketKey. -// This makes it safe to call cloneTLSClientConfig on a config -// in active use by a server. 
-func cloneTLSClientConfig(cfg *tls.Config) *tls.Config { - if cfg == nil { - return &tls.Config{} - } - return &tls.Config{ - Rand: cfg.Rand, - Time: cfg.Time, - Certificates: cfg.Certificates, - NameToCertificate: cfg.NameToCertificate, - GetCertificate: cfg.GetCertificate, - RootCAs: cfg.RootCAs, - NextProtos: cfg.NextProtos, - ServerName: cfg.ServerName, - ClientAuth: cfg.ClientAuth, - ClientCAs: cfg.ClientCAs, - InsecureSkipVerify: cfg.InsecureSkipVerify, - CipherSuites: cfg.CipherSuites, - PreferServerCipherSuites: cfg.PreferServerCipherSuites, - ClientSessionCache: cfg.ClientSessionCache, - MinVersion: cfg.MinVersion, - MaxVersion: cfg.MaxVersion, - CurvePreferences: cfg.CurvePreferences, - DynamicRecordSizingDisabled: cfg.DynamicRecordSizingDisabled, - Renegotiation: cfg.Renegotiation, - } + return cfg.Clone() } type connLRU struct { @@ -2169,3 +2169,15 @@ func (cl *connLRU) remove(pc *persistConn) { func (cl *connLRU) len() int { return len(cl.m) } + +// validPort reports whether p (without the colon) is a valid port in +// a URL, per RFC 3986 Section 3.2.3, which says the port may be +// empty, or only contain digits. 
+func validPort(p string) bool { + for _, r := range []byte(p) { + if r < '0' || r > '9' { + return false + } + } + return true +} diff --git a/libgo/go/net/http/transport_internal_test.go b/libgo/go/net/http/transport_internal_test.go index a05ca6ed0d8..3d24fc127d4 100644 --- a/libgo/go/net/http/transport_internal_test.go +++ b/libgo/go/net/http/transport_internal_test.go @@ -72,3 +72,70 @@ func newLocalListener(t *testing.T) net.Listener { } return ln } + +func dummyRequest(method string) *Request { + req, err := NewRequest(method, "http://fake.tld/", nil) + if err != nil { + panic(err) + } + return req +} + +func TestTransportShouldRetryRequest(t *testing.T) { + tests := []struct { + pc *persistConn + req *Request + + err error + want bool + }{ + 0: { + pc: &persistConn{reused: false}, + req: dummyRequest("POST"), + err: nothingWrittenError{}, + want: false, + }, + 1: { + pc: &persistConn{reused: true}, + req: dummyRequest("POST"), + err: nothingWrittenError{}, + want: true, + }, + 2: { + pc: &persistConn{reused: true}, + req: dummyRequest("POST"), + err: http2ErrNoCachedConn, + want: true, + }, + 3: { + pc: &persistConn{reused: true}, + req: dummyRequest("POST"), + err: errMissingHost, + want: false, + }, + 4: { + pc: &persistConn{reused: true}, + req: dummyRequest("POST"), + err: transportReadFromServerError{}, + want: false, + }, + 5: { + pc: &persistConn{reused: true}, + req: dummyRequest("GET"), + err: transportReadFromServerError{}, + want: true, + }, + 6: { + pc: &persistConn{reused: true}, + req: dummyRequest("GET"), + err: errServerClosedIdle, + want: true, + }, + } + for i, tt := range tests { + got := tt.pc.shouldRetryRequest(tt.req, tt.err) + if got != tt.want { + t.Errorf("%d. 
shouldRetryRequest = %v; want %v", i, got, tt.want) + } + } +} diff --git a/libgo/go/net/http/transport_test.go b/libgo/go/net/http/transport_test.go index 298682d04de..d5ddf6a1232 100644 --- a/libgo/go/net/http/transport_test.go +++ b/libgo/go/net/http/transport_test.go @@ -441,9 +441,7 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { } func TestTransportRemovesDeadIdleConnections(t *testing.T) { - if runtime.GOOS == "plan9" { - t.Skip("skipping test; see https://golang.org/issue/15464") - } + setParallel(t) defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { io.WriteString(w, r.RemoteAddr) @@ -700,6 +698,7 @@ var roundTripTests = []struct { // Test that the modification made to the Request by the RoundTripper is cleaned up func TestRoundTripGzip(t *testing.T) { + setParallel(t) defer afterTest(t) const responseBody = "test response body" ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { @@ -758,6 +757,7 @@ func TestRoundTripGzip(t *testing.T) { } func TestTransportGzip(t *testing.T) { + setParallel(t) defer afterTest(t) const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" const nRandBytes = 1024 * 1024 @@ -856,6 +856,7 @@ func TestTransportGzip(t *testing.T) { // If a request has Expect:100-continue header, the request blocks sending body until the first response. // Premature consumption of the request body should not be occurred. 
func TestTransportExpect100Continue(t *testing.T) { + setParallel(t) defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { @@ -966,6 +967,48 @@ func TestTransportProxy(t *testing.T) { } } +// Issue 16997: test transport dial preserves typed errors +func TestTransportDialPreservesNetOpProxyError(t *testing.T) { + defer afterTest(t) + + var errDial = errors.New("some dial error") + + tr := &Transport{ + Proxy: func(*Request) (*url.URL, error) { + return url.Parse("http://proxy.fake.tld/") + }, + Dial: func(string, string) (net.Conn, error) { + return nil, errDial + }, + } + defer tr.CloseIdleConnections() + + c := &Client{Transport: tr} + req, _ := NewRequest("GET", "http://fake.tld", nil) + res, err := c.Do(req) + if err == nil { + res.Body.Close() + t.Fatal("wanted a non-nil error") + } + + uerr, ok := err.(*url.Error) + if !ok { + t.Fatalf("got %T, want *url.Error", err) + } + oe, ok := uerr.Err.(*net.OpError) + if !ok { + t.Fatalf("url.Error.Err = %T; want *net.OpError", uerr.Err) + } + want := &net.OpError{ + Op: "proxyconnect", + Net: "tcp", + Err: errDial, // original error, unwrapped. + } + if !reflect.DeepEqual(oe, want) { + t.Errorf("Got error %#v; want %#v", oe, want) + } +} + // TestTransportGzipRecursive sends a gzip quine and checks that the // client gets the same value back. This is more cute than anything, // but checks that we don't recurse forever, and checks that @@ -1038,10 +1081,12 @@ func waitNumGoroutine(nmax int) int { // tests that persistent goroutine connections shut down when no longer desired. 
func TestTransportPersistConnLeak(t *testing.T) { - setParallel(t) + // Not parallel: counts goroutines defer afterTest(t) - gotReqCh := make(chan bool) - unblockCh := make(chan bool) + + const numReq = 25 + gotReqCh := make(chan bool, numReq) + unblockCh := make(chan bool, numReq) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { gotReqCh <- true <-unblockCh @@ -1055,14 +1100,15 @@ func TestTransportPersistConnLeak(t *testing.T) { n0 := runtime.NumGoroutine() - const numReq = 25 - didReqCh := make(chan bool) + didReqCh := make(chan bool, numReq) + failed := make(chan bool, numReq) for i := 0; i < numReq; i++ { go func() { res, err := c.Get(ts.URL) didReqCh <- true if err != nil { t.Errorf("client fetch error: %v", err) + failed <- true return } res.Body.Close() @@ -1071,7 +1117,13 @@ func TestTransportPersistConnLeak(t *testing.T) { // Wait for all goroutines to be stuck in the Handler. for i := 0; i < numReq; i++ { - <-gotReqCh + select { + case <-gotReqCh: + // ok + case <-failed: + close(unblockCh) + return + } } nhigh := runtime.NumGoroutine() @@ -1102,7 +1154,7 @@ func TestTransportPersistConnLeak(t *testing.T) { // golang.org/issue/4531: Transport leaks goroutines when // request.ContentLength is explicitly short func TestTransportPersistConnLeakShortBody(t *testing.T) { - setParallel(t) + // Not parallel: measures goroutines. defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { })) @@ -1198,6 +1250,7 @@ func TestIssue3644(t *testing.T) { // Test that a client receives a server's reply, even if the server doesn't read // the entire request body. func TestIssue3595(t *testing.T) { + setParallel(t) defer afterTest(t) const deniedMsg = "sorry, denied." ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -1246,6 +1299,7 @@ func TestChunkedNoContent(t *testing.T) { } func TestTransportConcurrency(t *testing.T) { + // Not parallel: uses global test hooks. 
defer afterTest(t) maxProcs, numReqs := 16, 500 if testing.Short() { @@ -1306,9 +1360,7 @@ func TestTransportConcurrency(t *testing.T) { } func TestIssue4191_InfiniteGetTimeout(t *testing.T) { - if runtime.GOOS == "plan9" { - t.Skip("skipping test; see https://golang.org/issue/7237") - } + setParallel(t) defer afterTest(t) const debug = false mux := NewServeMux() @@ -1370,9 +1422,7 @@ func TestIssue4191_InfiniteGetTimeout(t *testing.T) { } func TestIssue4191_InfiniteGetToPutTimeout(t *testing.T) { - if runtime.GOOS == "plan9" { - t.Skip("skipping test; see https://golang.org/issue/7237") - } + setParallel(t) defer afterTest(t) const debug = false mux := NewServeMux() @@ -1696,12 +1746,6 @@ func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) { defer ts.Close() defer close(unblockc) - // Don't interfere with the next test on plan9. - // Cf. https://golang.org/issues/11476 - if runtime.GOOS == "plan9" { - defer time.Sleep(500 * time.Millisecond) - } - tr := &Transport{} defer tr.CloseIdleConnections() c := &Client{Transport: tr} @@ -1718,8 +1762,17 @@ func testCancelRequestWithChannelBeforeDo(t *testing.T, withCtx bool) { } _, err := c.Do(req) - if err == nil || !strings.Contains(err.Error(), "canceled") { - t.Errorf("Do error = %v; want cancelation", err) + if ue, ok := err.(*url.Error); ok { + err = ue.Err + } + if withCtx { + if err != context.Canceled { + t.Errorf("Do error = %v; want %v", err, context.Canceled) + } + } else { + if err == nil || !strings.Contains(err.Error(), "canceled") { + t.Errorf("Do error = %v; want cancelation", err) + } } } @@ -1888,6 +1941,7 @@ func TestTransportEmptyMethod(t *testing.T) { } func TestTransportSocketLateBinding(t *testing.T) { + setParallel(t) defer afterTest(t) mux := NewServeMux() @@ -2152,6 +2206,7 @@ func TestProxyFromEnvironment(t *testing.T) { } func TestIdleConnChannelLeak(t *testing.T) { + // Not parallel: uses global test hooks. 
var mu sync.Mutex var n int @@ -2383,6 +2438,7 @@ func (c byteFromChanReader) Read(p []byte) (n int, err error) { // questionable state. // golang.org/issue/7569 func TestTransportNoReuseAfterEarlyResponse(t *testing.T) { + setParallel(t) defer afterTest(t) var sconn struct { sync.Mutex @@ -2485,22 +2541,6 @@ type errorReader struct { func (e errorReader) Read(p []byte) (int, error) { return 0, e.err } -type plan9SleepReader struct{} - -func (plan9SleepReader) Read(p []byte) (int, error) { - if runtime.GOOS == "plan9" { - // After the fix to unblock TCP Reads in - // https://golang.org/cl/15941, this sleep is required - // on plan9 to make sure TCP Writes before an - // immediate TCP close go out on the wire. On Plan 9, - // it seems that a hangup of a TCP connection with - // queued data doesn't send the queued data first. - // https://golang.org/issue/9554 - time.Sleep(50 * time.Millisecond) - } - return 0, io.EOF -} - type closerFunc func() error func (f closerFunc) Close() error { return f() } @@ -2595,7 +2635,7 @@ func TestTransportClosesBodyOnError(t *testing.T) { io.Reader io.Closer }{ - io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), plan9SleepReader{}, errorReader{fakeErr}), + io.MultiReader(io.LimitReader(neverEnding('x'), 1<<20), errorReader{fakeErr}), closerFunc(func() error { select { case didClose <- true: @@ -2627,6 +2667,8 @@ func TestTransportClosesBodyOnError(t *testing.T) { } func TestTransportDialTLS(t *testing.T) { + setParallel(t) + defer afterTest(t) var mu sync.Mutex // guards following var gotReq, didDial bool @@ -2904,14 +2946,8 @@ func TestTransportFlushesBodyChunks(t *testing.T) { defer res.Body.Close() want := []string{ - // Because Request.ContentLength = 0, the body is sniffed for 1 byte to determine whether there's content. - // That explains the initial "num0" being split into "n" and "um0". - // The first byte is included with the request headers Write. 
Perhaps in the future - // we will want to flush the headers out early if the first byte of the request body is - // taking a long time to arrive. But not yet. "POST / HTTP/1.1\r\nHost: localhost:8080\r\nUser-Agent: x\r\nTransfer-Encoding: chunked\r\nAccept-Encoding: gzip\r\n\r\n" + - "1\r\nn\r\n", - "4\r\num0\n\r\n", + "5\r\nnum0\n\r\n", "5\r\nnum1\n\r\n", "5\r\nnum2\n\r\n", "0\r\n\r\n", @@ -3150,6 +3186,7 @@ func TestTransportReuseConnection_Gzip_ContentLength(t *testing.T) { // Make sure we re-use underlying TCP connection for gzipped responses too. func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) { + setParallel(t) defer afterTest(t) addr := make(chan string, 2) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -3185,6 +3222,7 @@ func testTransportReuseConnection_Gzip(t *testing.T, chunked bool) { } func TestTransportResponseHeaderLength(t *testing.T) { + setParallel(t) defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { if r.URL.Path == "/long" { @@ -3248,7 +3286,7 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { cst.tr.ExpectContinueTimeout = 1 * time.Second - var mu sync.Mutex + var mu sync.Mutex // guards buf var buf bytes.Buffer logf := func(format string, args ...interface{}) { mu.Lock() @@ -3290,10 +3328,16 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { Wait100Continue: func() { logf("Wait100Continue") }, Got100Continue: func() { logf("Got100Continue") }, WroteRequest: func(e httptrace.WroteRequestInfo) { - close(gotWroteReqEvent) logf("WroteRequest: %+v", e) + close(gotWroteReqEvent) }, } + if h2 { + trace.TLSHandshakeStart = func() { logf("tls handshake start") } + trace.TLSHandshakeDone = func(s tls.ConnectionState, err error) { + logf("tls handshake done. 
ConnectionState = %v \n err = %v", s, err) + } + } if noHooks { // zero out all func pointers, trying to get some path to crash *trace = httptrace.ClientTrace{} @@ -3323,7 +3367,10 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { return } + mu.Lock() got := buf.String() + mu.Unlock() + wantOnce := func(sub string) { if strings.Count(got, sub) != 1 { t.Errorf("expected substring %q exactly once in output.", sub) @@ -3342,7 +3389,10 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { wantOnceOrMore("connected to tcp " + addrStr + " = <nil>") wantOnce("Reused:false WasIdle:false IdleTime:0s") wantOnce("first response byte") - if !h2 { + if h2 { + wantOnce("tls handshake start") + wantOnce("tls handshake done") + } else { wantOnce("PutIdleConn = <nil>") } wantOnce("Wait100Continue") @@ -3357,12 +3407,21 @@ func testTransportEventTrace(t *testing.T, h2 bool, noHooks bool) { } func TestTransportEventTraceRealDNS(t *testing.T) { + if testing.Short() && testenv.Builder() == "" { + // Skip this test in short mode (the default for + // all.bash), in case the user is using a shady/ISP + // DNS server hijacking queries. + // See issues 16732, 16716. + // Our builders use 8.8.8.8, though, which correctly + // returns NXDOMAIN, so still run this test there. + t.Skip("skipping in short mode") + } defer afterTest(t) tr := &Transport{} defer tr.CloseIdleConnections() c := &Client{Transport: tr} - var mu sync.Mutex + var mu sync.Mutex // guards buf var buf bytes.Buffer logf := func(format string, args ...interface{}) { mu.Lock() @@ -3386,7 +3445,10 @@ func TestTransportEventTraceRealDNS(t *testing.T) { t.Fatal("expected error during DNS lookup") } + mu.Lock() got := buf.String() + mu.Unlock() + wantSub := func(sub string) { if !strings.Contains(got, sub) { t.Errorf("expected substring %q in output.", sub) @@ -3402,6 +3464,73 @@ func TestTransportEventTraceRealDNS(t *testing.T) { } } +// Issue 14353: port can only contain digits. 
+func TestTransportRejectsAlphaPort(t *testing.T) { + res, err := Get("http://dummy.tld:123foo/bar") + if err == nil { + res.Body.Close() + t.Fatal("unexpected success") + } + ue, ok := err.(*url.Error) + if !ok { + t.Fatalf("got %#v; want *url.Error", err) + } + got := ue.Err.Error() + want := `invalid URL port "123foo"` + if got != want { + t.Errorf("got error %q; want %q", got, want) + } +} + +// Test the httptrace.TLSHandshake{Start,Done} hooks with a https http1 +// connections. The http2 test is done in TestTransportEventTrace_h2 +func TestTLSHandshakeTrace(t *testing.T) { + defer afterTest(t) + s := httptest.NewTLSServer(HandlerFunc(func(w ResponseWriter, r *Request) {})) + defer s.Close() + + var mu sync.Mutex + var start, done bool + trace := &httptrace.ClientTrace{ + TLSHandshakeStart: func() { + mu.Lock() + defer mu.Unlock() + start = true + }, + TLSHandshakeDone: func(s tls.ConnectionState, err error) { + mu.Lock() + defer mu.Unlock() + done = true + if err != nil { + t.Fatal("Expected error to be nil but was:", err) + } + }, + } + + tr := &Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + req, err := NewRequest("GET", s.URL, nil) + if err != nil { + t.Fatal("Unable to construct test request:", err) + } + req = req.WithContext(httptrace.WithClientTrace(req.Context(), trace)) + + r, err := c.Do(req) + if err != nil { + t.Fatal("Unexpected error making request:", err) + } + r.Body.Close() + mu.Lock() + defer mu.Unlock() + if !start { + t.Fatal("Expected TLSHandshakeStart to be called, but wasn't") + } + if !done { + t.Fatal("Expected TLSHandshakeDone to be called, but wasnt't") + } +} + func TestTransportMaxIdleConns(t *testing.T) { defer afterTest(t) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { @@ -3457,27 +3586,36 @@ func TestTransportMaxIdleConns(t *testing.T) { } } -func TestTransportIdleConnTimeout(t *testing.T) { +func 
TestTransportIdleConnTimeout_h1(t *testing.T) { testTransportIdleConnTimeout(t, h1Mode) } +func TestTransportIdleConnTimeout_h2(t *testing.T) { testTransportIdleConnTimeout(t, h2Mode) } +func testTransportIdleConnTimeout(t *testing.T, h2 bool) { if testing.Short() { t.Skip("skipping in short mode") } defer afterTest(t) - ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + const timeout = 1 * time.Second + + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { // No body for convenience. })) - defer ts.Close() - - const timeout = 1 * time.Second - tr := &Transport{ - IdleConnTimeout: timeout, - } + defer cst.close() + tr := cst.tr + tr.IdleConnTimeout = timeout defer tr.CloseIdleConnections() c := &Client{Transport: tr} + idleConns := func() []string { + if h2 { + return tr.IdleConnStrsForTesting_h2() + } else { + return tr.IdleConnStrsForTesting() + } + } + var conn string doReq := func(n int) { - req, _ := NewRequest("GET", ts.URL, nil) + req, _ := NewRequest("GET", cst.ts.URL, nil) req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ PutIdleConn: func(err error) { if err != nil { @@ -3490,7 +3628,7 @@ func TestTransportIdleConnTimeout(t *testing.T) { t.Fatal(err) } res.Body.Close() - conns := tr.IdleConnStrsForTesting() + conns := idleConns() if len(conns) != 1 { t.Fatalf("req %v: unexpected number of idle conns: %q", n, conns) } @@ -3506,7 +3644,7 @@ func TestTransportIdleConnTimeout(t *testing.T) { time.Sleep(timeout / 2) } time.Sleep(timeout * 3 / 2) - if got := tr.IdleConnStrsForTesting(); len(got) != 0 { + if got := idleConns(); len(got) != 0 { t.Errorf("idle conns = %q; want none", got) } } @@ -3523,6 +3661,7 @@ func TestTransportIdleConnTimeout(t *testing.T) { // know the successful tls.Dial from DialTLS will need to go into the // idle pool. Then we give it a of time to explode. 
func TestIdleConnH2Crash(t *testing.T) { + setParallel(t) cst := newClientServerTest(t, h2Mode, HandlerFunc(func(w ResponseWriter, r *Request) { // nothing })) @@ -3531,12 +3670,12 @@ func TestIdleConnH2Crash(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() - gotErr := make(chan bool, 1) + sawDoErr := make(chan bool, 1) + testDone := make(chan struct{}) + defer close(testDone) cst.tr.IdleConnTimeout = 5 * time.Millisecond cst.tr.DialTLS = func(network, addr string) (net.Conn, error) { - cancel() - <-gotErr c, err := tls.Dial(network, addr, &tls.Config{ InsecureSkipVerify: true, NextProtos: []string{"h2"}, @@ -3550,6 +3689,17 @@ func TestIdleConnH2Crash(t *testing.T) { c.Close() return nil, errors.New("bogus") } + + cancel() + + failTimer := time.NewTimer(5 * time.Second) + defer failTimer.Stop() + select { + case <-sawDoErr: + case <-testDone: + case <-failTimer.C: + t.Error("timeout in DialTLS, waiting too long for cst.c.Do to fail") + } return c, nil } @@ -3560,7 +3710,7 @@ func TestIdleConnH2Crash(t *testing.T) { res.Body.Close() t.Fatal("unexpected success") } - gotErr <- true + sawDoErr <- true // Wait for the explosion. 
time.Sleep(cst.tr.IdleConnTimeout * 10) @@ -3605,6 +3755,122 @@ func TestTransportReturnsPeekError(t *testing.T) { } } +// Issue 13835: international domain names should work +func TestTransportIDNA_h1(t *testing.T) { testTransportIDNA(t, h1Mode) } +func TestTransportIDNA_h2(t *testing.T) { testTransportIDNA(t, h2Mode) } +func testTransportIDNA(t *testing.T, h2 bool) { + defer afterTest(t) + + const uniDomain = "гофер.го" + const punyDomain = "xn--c1ae0ajs.xn--c1aw" + + var port string + cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) { + want := punyDomain + ":" + port + if r.Host != want { + t.Errorf("Host header = %q; want %q", r.Host, want) + } + if h2 { + if r.TLS == nil { + t.Errorf("r.TLS == nil") + } else if r.TLS.ServerName != punyDomain { + t.Errorf("TLS.ServerName = %q; want %q", r.TLS.ServerName, punyDomain) + } + } + w.Header().Set("Hit-Handler", "1") + })) + defer cst.close() + + ip, port, err := net.SplitHostPort(cst.ts.Listener.Addr().String()) + if err != nil { + t.Fatal(err) + } + + // Install a fake DNS server. 
+ ctx := context.WithValue(context.Background(), nettrace.LookupIPAltResolverKey{}, func(ctx context.Context, host string) ([]net.IPAddr, error) { + if host != punyDomain { + t.Errorf("got DNS host lookup for %q; want %q", host, punyDomain) + return nil, nil + } + return []net.IPAddr{{IP: net.ParseIP(ip)}}, nil + }) + + req, _ := NewRequest("GET", cst.scheme()+"://"+uniDomain+":"+port, nil) + trace := &httptrace.ClientTrace{ + GetConn: func(hostPort string) { + want := net.JoinHostPort(punyDomain, port) + if hostPort != want { + t.Errorf("getting conn for %q; want %q", hostPort, want) + } + }, + DNSStart: func(e httptrace.DNSStartInfo) { + if e.Host != punyDomain { + t.Errorf("DNSStart Host = %q; want %q", e.Host, punyDomain) + } + }, + } + req = req.WithContext(httptrace.WithClientTrace(ctx, trace)) + + res, err := cst.tr.RoundTrip(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + if res.Header.Get("Hit-Handler") != "1" { + out, err := httputil.DumpResponse(res, true) + if err != nil { + t.Fatal(err) + } + t.Errorf("Response body wasn't from Handler. 
Got:\n%s\n", out) + } +} + +// Issue 13290: send User-Agent in proxy CONNECT +func TestTransportProxyConnectHeader(t *testing.T) { + defer afterTest(t) + reqc := make(chan *Request, 1) + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Method != "CONNECT" { + t.Errorf("method = %q; want CONNECT", r.Method) + } + reqc <- r + c, _, err := w.(Hijacker).Hijack() + if err != nil { + t.Errorf("Hijack: %v", err) + return + } + c.Close() + })) + defer ts.Close() + tr := &Transport{ + ProxyConnectHeader: Header{ + "User-Agent": {"foo"}, + "Other": {"bar"}, + }, + Proxy: func(r *Request) (*url.URL, error) { + return url.Parse(ts.URL) + }, + } + defer tr.CloseIdleConnections() + c := &Client{Transport: tr} + res, err := c.Get("https://dummy.tld/") // https to force a CONNECT + if err == nil { + res.Body.Close() + t.Errorf("unexpected success") + } + select { + case <-time.After(3 * time.Second): + t.Fatal("timeout") + case r := <-reqc: + if got, want := r.Header.Get("User-Agent"), "foo"; got != want { + t.Errorf("CONNECT request User-Agent = %q; want %q", got, want) + } + if got, want := r.Header.Get("Other"), "bar"; got != want { + t.Errorf("CONNECT request Other = %q; want %q", got, want) + } + } +} + var errFakeRoundTrip = errors.New("fake roundtrip") type funcRoundTripper func() |