summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/go/build/deps_test.go2
-rw-r--r--src/net/http/httptrace/trace.go7
-rw-r--r--src/net/http/transport.go49
-rw-r--r--src/net/http/transport_test.go68
4 files changed, 103 insertions, 23 deletions
diff --git a/src/go/build/deps_test.go b/src/go/build/deps_test.go
index 67d1115017..9d667b6107 100644
--- a/src/go/build/deps_test.go
+++ b/src/go/build/deps_test.go
@@ -416,7 +416,7 @@ var pkgDeps = map[string][]string{
"syscall/js",
},
"net/http/internal": {"L4"},
- "net/http/httptrace": {"context", "crypto/tls", "internal/nettrace", "net", "reflect", "time"},
+ "net/http/httptrace": {"context", "crypto/tls", "internal/nettrace", "net", "net/textproto", "reflect", "time"},
// HTTP-using packages.
"expvar": {"L4", "OS", "encoding/json", "net/http"},
diff --git a/src/net/http/httptrace/trace.go b/src/net/http/httptrace/trace.go
index ea7b38c8fc..8033535670 100644
--- a/src/net/http/httptrace/trace.go
+++ b/src/net/http/httptrace/trace.go
@@ -11,6 +11,7 @@ import (
"crypto/tls"
"internal/nettrace"
"net"
+ "net/textproto"
"reflect"
"time"
)
@@ -107,6 +108,12 @@ type ClientTrace struct {
// Continue" response.
Got100Continue func()
+ // Got1xxResponse is called for each 1xx informational response header
+ // returned before the final non-1xx response. Got1xxResponse is called
+ // for "100 Continue" responses, even if Got100Continue is also defined.
+ // If it returns an error, the client request is aborted with that error value.
+ Got1xxResponse func(code int, header textproto.MIMEHeader) error
+
// DNSStart is called when a DNS lookup begins.
DNSStart func(DNSStartInfo)
diff --git a/src/net/http/transport.go b/src/net/http/transport.go
index 3890f19af3..9b5ea52c9b 100644
--- a/src/net/http/transport.go
+++ b/src/net/http/transport.go
@@ -21,6 +21,7 @@ import (
"log"
"net"
"net/http/httptrace"
+ "net/textproto"
"net/url"
"os"
"strings"
@@ -1641,26 +1642,42 @@ func (pc *persistConn) readResponse(rc requestAndChan, trace *httptrace.ClientTr
trace.GotFirstResponseByte()
}
}
- resp, err = ReadResponse(pc.br, rc.req)
- if err != nil {
- return
- }
- if rc.continueCh != nil {
- if resp.StatusCode == 100 {
- if trace != nil && trace.Got100Continue != nil {
- trace.Got100Continue()
- }
- rc.continueCh <- struct{}{}
- } else {
- close(rc.continueCh)
- }
- }
- if resp.StatusCode == 100 {
- pc.readLimit = pc.maxHeaderResponseSize() // reset the limit
+ num1xx := 0 // number of informational 1xx headers received
+ const max1xxResponses = 5 // arbitrary bound on number of informational responses
+
+ continueCh := rc.continueCh
+ for {
resp, err = ReadResponse(pc.br, rc.req)
if err != nil {
return
}
+ resCode := resp.StatusCode
+ if continueCh != nil {
+ if resCode == 100 {
+ if trace != nil && trace.Got100Continue != nil {
+ trace.Got100Continue()
+ }
+ continueCh <- struct{}{}
+ continueCh = nil
+ } else if resCode >= 200 {
+ close(continueCh)
+ continueCh = nil
+ }
+ }
+ if 100 <= resCode && resCode <= 199 {
+ num1xx++
+ if num1xx > max1xxResponses {
+ return nil, errors.New("net/http: too many 1xx informational responses")
+ }
+ pc.readLimit = pc.maxHeaderResponseSize() // reset the limit
+ if trace != nil && trace.Got1xxResponse != nil {
+ if err := trace.Got1xxResponse(resCode, textproto.MIMEHeader(resp.Header)); err != nil {
+ return nil, err
+ }
+ }
+ continue
+ }
+ break
}
resp.TLS = pc.tlsState
return
diff --git a/src/net/http/transport_test.go b/src/net/http/transport_test.go
index 57309bbac1..01a209c633 100644
--- a/src/net/http/transport_test.go
+++ b/src/net/http/transport_test.go
@@ -31,6 +31,7 @@ import (
"net/http/httptrace"
"net/http/httputil"
"net/http/internal"
+ "net/textproto"
"net/url"
"os"
"reflect"
@@ -2287,6 +2288,7 @@ Content-Length: %d
c := &Client{Transport: tr}
testResponse := func(req *Request, name string, wantCode int) {
+ t.Helper()
res, err := c.Do(req)
if err != nil {
t.Fatalf("%s: Do: %v", name, err)
@@ -2309,13 +2311,67 @@ Content-Length: %d
req.Header.Set("Request-Id", reqID(i))
testResponse(req, fmt.Sprintf("100, %d/%d", i, numReqs), 200)
}
+}
- // And some other informational 1xx but non-100 responses, to test
- // we return them but don't re-use the connection.
- for i := 1; i <= numReqs; i++ {
- req, _ := NewRequest("POST", "http://other.tld/", strings.NewReader(reqBody(i)))
- req.Header.Set("X-Want-Response-Code", "123 Sesame Street")
- testResponse(req, fmt.Sprintf("123, %d/%d", i, numReqs), 123)
+// Issue 17739: the HTTP client must ignore any unknown 1xx
+// informational responses before the actual response.
+func TestTransportIgnore1xxResponses(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, buf, _ := w.(Hijacker).Hijack()
+ buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\nFoo: bar\r\n\r\nHTTP/1.1 200 OK\r\nBar: baz\r\nContent-Length: 5\r\n\r\nHello"))
+ buf.Flush()
+ conn.Close()
+ }))
+ defer cst.close()
+ cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway
+
+ var got bytes.Buffer
+
+ req, _ := NewRequest("GET", cst.ts.URL, nil)
+ req = req.WithContext(httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
+ Got1xxResponse: func(code int, header textproto.MIMEHeader) error {
+ fmt.Fprintf(&got, "1xx: code=%v, header=%v\n", code, header)
+ return nil
+ },
+ }))
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ defer res.Body.Close()
+
+ res.Write(&got)
+ want := "1xx: code=123, header=map[Foo:[bar]]\nHTTP/1.1 200 OK\r\nContent-Length: 5\r\nBar: baz\r\n\r\nHello"
+ if got.String() != want {
+ t.Errorf(" got: %q\nwant: %q\n", got.Bytes(), want)
+ }
+}
+
+func TestTransportLimits1xxResponses(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ cst := newClientServerTest(t, h1Mode, HandlerFunc(func(w ResponseWriter, r *Request) {
+ conn, buf, _ := w.(Hijacker).Hijack()
+ for i := 0; i < 10; i++ {
+ buf.Write([]byte("HTTP/1.1 123 OneTwoThree\r\n\r\n"))
+ }
+ buf.Write([]byte("HTTP/1.1 204 No Content\r\n\r\n"))
+ buf.Flush()
+ conn.Close()
+ }))
+ defer cst.close()
+ cst.tr.DisableKeepAlives = true // prevent log spam; our test server is hanging up anyway
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if res != nil {
+ defer res.Body.Close()
+ }
+ got := fmt.Sprint(err)
+ wantSub := "too many 1xx informational responses"
+ if !strings.Contains(got, wantSub) {
+ t.Errorf("Get error = %v; want substring %q", err, wantSub)
}
}