diff options
| -rw-r--r-- | src/go/build/deps_test.go | 2 | ||||
| -rw-r--r-- | src/net/http/httptrace/trace.go | 7 | ||||
| -rw-r--r-- | src/net/http/transport.go | 49 | ||||
| -rw-r--r-- | src/net/http/transport_test.go | 68 |
4 files changed, 103 insertions, 23 deletions
diff --git a/src/go/build/deps_test.go b/src/go/build/deps_test.go index 67d1115017..9d667b6107 100644 --- a/src/go/build/deps_test.go +++ b/src/go/build/deps_test.go @@ -416,7 +416,7 @@ var pkgDeps = map[string][]string{ "syscall/js", }, "net/http/internal": {"L4"}, - "net/http/httptrace": {"context", "crypto/tls", "internal/nettrace", "net", "reflect", "time"}, + "net/http/httptrace": {"context", "crypto/tls", "internal/nettrace", "net", "net/textproto", "reflect", "time"}, // HTTP-using packages. "expvar": {"L4", "OS", "encoding/json", "net/http"}, diff --git a/src/net/http/httptrace/trace.go b/src/net/http/httptrace/trace.go index ea7b38c8fc..8033535670 100644 --- a/src/net/http/httptrace/trace.go +++ b/src/net/http/httptrace/trace.go @@ -11,6 +11,7 @@ import ( "crypto/tls" "internal/nettrace" "net" + "net/textproto" "reflect" "time" ) @@ -107,6 +108,12 @@ type ClientTrace struct { // Continue" response. Got100Continue func() + // Got1xxResponse is called for each 1xx informational response header + // returned before the final non-1xx response. Got1xxResponse is called + // for "100 Continue" responses, even if Got100Continue is also defined. + // If it returns an error, the client request is aborted with that error value. + Got1xxResponse func(code int, header textproto.MIMEHeader) error + // DNSStart is called when a DNS lookup begins. DNSStart func(DNSStartInfo) diff --git a/src/net/http/transport.go b/src/net/http/transport.go index 3890f19af3..9b5ea52c9b 100644 --- a/src/net/http/transport.go +++ b/src/net/http/transport.go @@ -21,6 +21,7 @@ import ( "log" "net" "net/http/httptrace" + "net/textproto" "net/url" "os" "strings" @@ -1641,26 +1642,42 @@ func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTr trace.GotFirstResponseByte() } } - resp, err = ReadResponse(pc.br, rc.req) - if err != nil { - return - } - if rc.continueCh != nil { - if resp.StatusCode == 100 { - if trace != nil && trace.Got100Continue != nil { - trace.Got100Continue() - } - rc.continueCh <- struct{}{} - } else { - close(rc.continueCh) - } - } - if resp.StatusCode == 100 { - pc.readLimit = pc.maxHeaderResponseSize() // reset the limit + num1xx := 0 // number of informational 1xx headers received + const max1xxResponses = 5 // arbitrary bound on number of informational responses + + continueCh := rc.continueCh + for { resp, err = ReadResponse(pc.br, rc.req) if err != nil { return } + resCode := resp.StatusCode + if continueCh != nil { + if resCode == 100 { + if trace != nil && trace.Got100Continue != nil { + trace.Got100Continue() + } + continueCh <- struct{}{} + continueCh = nil + } else if resCode >= 200 { + close(continueCh) + continueCh = nil + } + } + if 100 <= resCode && resCode <= 199 { + num1xx++ + if num1xx > max1xxResponses { + return nil, errors.New("net/http: too many 1xx informational responses") + } + pc.readLimit = pc.maxHeaderResponseSize() // reset the limit + if trace != nil && trace.Got1xxResponse != nil { + if err := trace.Got1xxResponse(resCode, textproto.MIMEHeader(resp.Header)); err != nil { + return nil, err + } + } + continue + } + break } resp.TLS = pc.tlsState return diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go index 57309bbac1..01a209c633 100644 --- a/src/net/http/transport_test.go +++ b/src/net/http/transport_test.go @@ -31,6 +31,7 @@ import ( "net/http/httptrace" "net/http/httputil" "net/http/internal" + "net/textproto" "net/url" "os" "reflect" @@ -2287,6 +2288,7 @@ Content-Length: %d c := &Client{Transport: tr} testResponse := func(req *Request, name string, wantCode int) { + t.Helper() res, err := c.Do(req) if err != nil { t.Fatalf("%s: Do: %v", name, err) @@ -2309,13 +2311,67 @@ Content-Length: %d req.Header.Set("Request-Id", reqID(i)) testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200) } +} - // And some other informational 1xx but non-100 responses, to test - // we return them but don't re-use the connection. - for i := 1; i <= numReqs; i++ { - req, _ := NewRequest("POST", "http://other.tld/", strings.NewReader(reqBody(i))) - req.Header.Set("X-Want-Response-Code", "123 Sesame Street") - testResponse(req, fmt.Sprintf("123, %d/%d", i, numReqs), 123) +// Issue 17739: the HTTP client must ignore any unknown 1xx +// informational responses before the actual response. +func TestTransportIgnore1xxResponses(t *testing.T) { + setParallel(t) + defer afterTest(t) + cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + conn, buf, _ := w.(Hijacker).Hijack() + buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello")) + buf.Flush() + conn.Close() + })) + defer cst.close() + cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway + + var got bytes.Buffer + + req, _ := NewRequest("GET", cst.ts.URL, nil) + req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{ + Got1xxResponse: func(code int, header textproto.MIMEHeader) error { + fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header) + return nil + }, + })) + res, err := cst.c.Do(req) + if err != nil { + t.Fatal(err) + } + defer res.Body.Close() + + res.Write(&got) + want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello" + if got.String() != want { + t.Errorf(" got: %q\nwant: %q\n", got.Bytes(), want) + } +} + +func TestTransportLimits1xxResponses(t *testing.T) { + setParallel(t) + defer afterTest(t) + cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) { + conn, buf, _ := w.(Hijacker).Hijack() + for i := 0; i < 10; i++ { + buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n")) + } + buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n")) + buf.Flush() + conn.Close() + })) + defer cst.close() + cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway + + res, err := cst.c.Get(cst.ts.URL) + if res != nil { + defer res.Body.Close() + } + got := fmt.Sprint(err) + wantSub := "too many 1xx informational responses" + if !strings.Contains(got, wantSub) { + t.Errorf("Get error = %v; want substring %q", err, wantSub) } } |
