summaryrefslogtreecommitdiff
path: root/libgo/go/net/http/serve_test.go
diff options
context:
space:
mode:
Diffstat (limited to 'libgo/go/net/http/serve_test.go')
-rw-r--r--libgo/go/net/http/serve_test.go219
1 files changed, 170 insertions, 49 deletions
diff --git a/libgo/go/net/http/serve_test.go b/libgo/go/net/http/serve_test.go
index 6394da3bb7c..fb18cb2c6f5 100644
--- a/libgo/go/net/http/serve_test.go
+++ b/libgo/go/net/http/serve_test.go
@@ -23,6 +23,7 @@ import (
"net"
. "net/http"
"net/http/httptest"
+ "net/http/httptrace"
"net/http/httputil"
"net/http/internal"
"net/http/internal/testcert"
@@ -2146,7 +2147,7 @@ func TestInvalidTrailerClosesConnection(t *testing.T) {
// Read and Write.
type slowTestConn struct {
// over multiple calls to Read, time.Durations are slept, strings are read.
- script []interface{}
+ script []any
closec chan bool
mu sync.Mutex // guards rd/wd
@@ -2238,7 +2239,7 @@ func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
defer afterTest(t)
for _, handler := range testHandlerBodyConsumers {
conn := &slowTestConn{
- script: []interface{}{
+ script: []any{
"POST /public HTTP/1.1\r\n" +
"Host: test\r\n" +
"Content-Length: 10000\r\n" +
@@ -2273,6 +2274,18 @@ func TestRequestBodyTimeoutClosesConnection(t *testing.T) {
}
}
+// cancelableTimeoutContext overwrites the error message to DeadlineExceeded
+type cancelableTimeoutContext struct {
+ context.Context
+}
+
+func (c cancelableTimeoutContext) Err() error {
+ if c.Context.Err() != nil {
+ return context.DeadlineExceeded
+ }
+ return nil
+}
+
func TestTimeoutHandler_h1(t *testing.T) { testTimeoutHandler(t, h1Mode) }
func TestTimeoutHandler_h2(t *testing.T) { testTimeoutHandler(t, h2Mode) }
func testTimeoutHandler(t *testing.T, h2 bool) {
@@ -2285,8 +2298,9 @@ func testTimeoutHandler(t *testing.T, h2 bool) {
_, werr := w.Write([]byte("hi"))
writeErrors <- werr
})
- timeout := make(chan time.Time, 1) // write to this to force timeouts
- cst := newClientServerTest(t, h2, NewTestTimeoutHandler(sayHi, timeout))
+ ctx, cancel := context.WithCancel(context.Background())
+ h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
+ cst := newClientServerTest(t, h2, h)
defer cst.close()
// Succeed without timing out:
@@ -2307,7 +2321,8 @@ func testTimeoutHandler(t *testing.T, h2 bool) {
}
// Times out:
- timeout <- time.Time{}
+ cancel()
+
res, err = cst.c.Get(cst.ts.URL)
if err != nil {
t.Error(err)
@@ -2428,8 +2443,9 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) {
_, werr := w.Write([]byte("hi"))
writeErrors <- werr
})
- timeout := make(chan time.Time, 1) // write to this to force timeouts
- cst := newClientServerTest(t, h1Mode, NewTestTimeoutHandler(sayHi, timeout))
+ ctx, cancel := context.WithCancel(context.Background())
+ h := NewTestTimeoutHandler(sayHi, cancelableTimeoutContext{ctx})
+ cst := newClientServerTest(t, h1Mode, h)
defer cst.close()
// Succeed without timing out:
@@ -2450,7 +2466,8 @@ func TestTimeoutHandlerRaceHeaderTimeout(t *testing.T) {
}
// Times out:
- timeout <- time.Time{}
+ cancel()
+
res, err = cst.c.Get(cst.ts.URL)
if err != nil {
t.Error(err)
@@ -2500,6 +2517,47 @@ func TestTimeoutHandlerStartTimerWhenServing(t *testing.T) {
}
}
+func TestTimeoutHandlerContextCanceled(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+ writeErrors := make(chan error, 1)
+ sayHi := HandlerFunc(func(w ResponseWriter, r *Request) {
+ w.Header().Set("Content-Type", "text/plain")
+ var err error
+ // The request context has already been canceled, but
+ // retry the write for a while to give the timeout handler
+ // a chance to notice.
+ for i := 0; i < 100; i++ {
+ _, err = w.Write([]byte("a"))
+ if err != nil {
+ break
+ }
+ time.Sleep(1 * time.Millisecond)
+ }
+ writeErrors <- err
+ })
+ ctx, cancel := context.WithCancel(context.Background())
+ cancel()
+ h := NewTestTimeoutHandler(sayHi, ctx)
+ cst := newClientServerTest(t, h1Mode, h)
+ defer cst.close()
+
+ res, err := cst.c.Get(cst.ts.URL)
+ if err != nil {
+ t.Error(err)
+ }
+ if g, e := res.StatusCode, StatusServiceUnavailable; g != e {
+ t.Errorf("got res.StatusCode %d; expected %d", g, e)
+ }
+ body, _ := io.ReadAll(res.Body)
+ if g, e := string(body), ""; g != e {
+ t.Errorf("got body %q; expected %q", g, e)
+ }
+ if g, e := <-writeErrors, context.Canceled; g != e {
+ t.Errorf("got unexpected Write in handler: %v, want %g", g, e)
+ }
+}
+
// https://golang.org/issue/15948
func TestTimeoutHandlerEmptyResponse(t *testing.T) {
setParallel(t)
@@ -2708,7 +2766,7 @@ func TestHandlerPanicWithHijack(t *testing.T) {
testHandlerPanic(t, true, h1Mode, nil, "intentional death for testing")
}
-func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) Handler, panicValue interface{}) {
+func testHandlerPanic(t *testing.T, withHijack, h2 bool, wrapper func(Handler) Handler, panicValue any) {
defer afterTest(t)
// Unlike the other tests that set the log output to io.Discard
// to quiet the output, this test uses a pipe. The pipe serves three
@@ -3017,22 +3075,14 @@ func TestClientWriteShutdown(t *testing.T) {
if err != nil {
t.Fatalf("CloseWrite: %v", err)
}
- donec := make(chan bool)
- go func() {
- defer close(donec)
- bs, err := io.ReadAll(conn)
- if err != nil {
- t.Errorf("ReadAll: %v", err)
- }
- got := string(bs)
- if got != "" {
- t.Errorf("read %q from server; want nothing", got)
- }
- }()
- select {
- case <-donec:
- case <-time.After(10 * time.Second):
- t.Fatalf("timeout")
+
+ bs, err := io.ReadAll(conn)
+ if err != nil {
+ t.Errorf("ReadAll: %v", err)
+ }
+ got := string(bs)
+ if got != "" {
+ t.Errorf("read %q from server; want nothing", got)
}
}
@@ -3884,7 +3934,7 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) {
// this test fails, it hangs. This helps debugging and I've
// added this enough times "temporarily". It now gets added
// full time.
- errorf := func(format string, args ...interface{}) {
+ errorf := func(format string, args ...any) {
v := fmt.Sprintf(format, args...)
println(v)
t.Error(v)
@@ -3893,10 +3943,10 @@ func testTransportAndServerSharedBodyRace(t *testing.T, h2 bool) {
unblockBackend := make(chan bool)
backend := newClientServerTest(t, h2, HandlerFunc(func(rw ResponseWriter, req *Request) {
gone := rw.(CloseNotifier).CloseNotify()
- didCopy := make(chan interface{})
+ didCopy := make(chan any)
go func() {
n, err := io.CopyN(rw, req.Body, bodySize)
- didCopy <- []interface{}{n, err}
+ didCopy <- []any{n, err}
}()
isGone := false
Loop:
@@ -4888,7 +4938,7 @@ func TestServerContext_LocalAddrContextKey_h2(t *testing.T) {
func testServerContext_LocalAddrContextKey(t *testing.T, h2 bool) {
setParallel(t)
defer afterTest(t)
- ch := make(chan interface{}, 1)
+ ch := make(chan any, 1)
cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
ch <- r.Context().Value(LocalAddrContextKey)
}))
@@ -5689,22 +5739,37 @@ func testServerKeepAlivesEnabled(t *testing.T, h2 bool) {
}
// Not parallel: messes with global variable. (http2goAwayTimeout)
defer afterTest(t)
- cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {
- fmt.Fprintf(w, "%v", r.RemoteAddr)
- }))
+ cst := newClientServerTest(t, h2, HandlerFunc(func(w ResponseWriter, r *Request) {}))
defer cst.close()
srv := cst.ts.Config
srv.SetKeepAlivesEnabled(false)
- a := cst.getURL(cst.ts.URL)
- if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) {
- t.Fatalf("test server has active conns")
- }
- b := cst.getURL(cst.ts.URL)
- if a == b {
- t.Errorf("got same connection between first and second requests")
- }
- if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) {
- t.Fatalf("test server has active conns")
+ for try := 0; try < 2; try++ {
+ if !waitCondition(2*time.Second, 10*time.Millisecond, srv.ExportAllConnsIdle) {
+ t.Fatalf("request %v: test server has active conns", try)
+ }
+ conns := 0
+ var info httptrace.GotConnInfo
+ ctx := httptrace.WithClientTrace(context.Background(), &httptrace.ClientTrace{
+ GotConn: func(v httptrace.GotConnInfo) {
+ conns++
+ info = v
+ },
+ })
+ req, err := NewRequestWithContext(ctx, "GET", cst.ts.URL, nil)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res, err := cst.c.Do(req)
+ if err != nil {
+ t.Fatal(err)
+ }
+ res.Body.Close()
+ if conns != 1 {
+ t.Fatalf("request %v: got %v conns, want 1", try, conns)
+ }
+ if info.Reused || info.WasIdle {
+ t.Fatalf("request %v: Reused=%v (want false), WasIdle=%v (want false)", try, info.Reused, info.WasIdle)
+ }
}
}
@@ -5933,11 +5998,7 @@ func TestServerHijackGetsBackgroundByte_big(t *testing.T) {
t.Fatal(err)
}
- select {
- case <-done:
- case <-time.After(2 * time.Second):
- t.Error("timeout")
- }
+ <-done
}
// Issue 18319: test that the Server validates the request method.
@@ -6232,7 +6293,7 @@ func testContentEncodingNoSniffing(t *testing.T, h2 bool) {
// setting contentEncoding as an interface instead of a string
// directly, so as to differentiate between 3 states:
// unset, empty string "" and set string "foo/bar".
- contentEncoding interface{}
+ contentEncoding any
wantContentType string
}
@@ -6490,7 +6551,7 @@ func TestDisableKeepAliveUpgrade(t *testing.T) {
rwc, ok := resp.Body.(io.ReadWriteCloser)
if !ok {
- t.Fatalf("Response.Body is not a io.ReadWriteCloser: %T", resp.Body)
+ t.Fatalf("Response.Body is not an io.ReadWriteCloser: %T", resp.Body)
}
_, err = rwc.Write([]byte("hello"))
@@ -6609,3 +6670,63 @@ func testQuerySemicolon(t *testing.T, query string, wantX string, allowSemicolon
}
}
}
+
+func TestMaxBytesHandler(t *testing.T) {
+ setParallel(t)
+ defer afterTest(t)
+
+ for _, maxSize := range []int64{100, 1_000, 1_000_000} {
+ for _, requestSize := range []int64{100, 1_000, 1_000_000} {
+ t.Run(fmt.Sprintf("max size %d request size %d", maxSize, requestSize),
+ func(t *testing.T) {
+ testMaxBytesHandler(t, maxSize, requestSize)
+ })
+ }
+ }
+}
+
+func testMaxBytesHandler(t *testing.T, maxSize, requestSize int64) {
+ var (
+ handlerN int64
+ handlerErr error
+ )
+ echo := HandlerFunc(func(w ResponseWriter, r *Request) {
+ var buf bytes.Buffer
+ handlerN, handlerErr = io.Copy(&buf, r.Body)
+ io.Copy(w, &buf)
+ })
+
+ ts := httptest.NewServer(MaxBytesHandler(echo, maxSize))
+ defer ts.Close()
+
+ c := ts.Client()
+ var buf strings.Builder
+ body := strings.NewReader(strings.Repeat("a", int(requestSize)))
+ res, err := c.Post(ts.URL, "text/plain", body)
+ if err != nil {
+ t.Errorf("unexpected connection error: %v", err)
+ } else {
+ _, err = io.Copy(&buf, res.Body)
+ res.Body.Close()
+ if err != nil {
+ t.Errorf("unexpected read error: %v", err)
+ }
+ }
+ if handlerN > maxSize {
+ t.Errorf("expected max request body %d; got %d", maxSize, handlerN)
+ }
+ if requestSize > maxSize && handlerErr == nil {
+ t.Error("expected error on handler side; got nil")
+ }
+ if requestSize <= maxSize {
+ if handlerErr != nil {
+ t.Errorf("%d expected nil error on handler side; got %v", requestSize, handlerErr)
+ }
+ if handlerN != requestSize {
+ t.Errorf("expected request of size %d; got %d", requestSize, handlerN)
+ }
+ }
+ if buf.Len() != int(handlerN) {
+ t.Errorf("expected echo of size %d; got %d", handlerN, buf.Len())
+ }
+}