diff options
Diffstat (limited to 'libgo/go/net/http/server.go')
-rw-r--r-- | libgo/go/net/http/server.go | 101 |
1 files changed, 77 insertions, 24 deletions
diff --git a/libgo/go/net/http/server.go b/libgo/go/net/http/server.go index ce3993341d3..f5cdc3a4d34 100644 --- a/libgo/go/net/http/server.go +++ b/libgo/go/net/http/server.go @@ -13,6 +13,7 @@ import ( "crypto/tls" "errors" "fmt" + "internal/godebug" "io" "log" "math/rand" @@ -20,7 +21,6 @@ import ( "net/textproto" "net/url" urlpkg "net/url" - "os" "path" "runtime" "sort" @@ -494,8 +494,8 @@ type response struct { // prior to the headers being written. If the set of trailers is fixed // or known before the header is written, the normal Go trailers mechanism // is preferred: -// https://golang.org/pkg/net/http/#ResponseWriter -// https://golang.org/pkg/net/http/#example_ResponseWriter_trailers +// https://pkg.go.dev/net/http#ResponseWriter +// https://pkg.go.dev/net/http#example-ResponseWriter-Trailers const TrailerPrefix = "Trailer:" // finalTrailers is called after the Handler exits and returns a non-nil @@ -798,7 +798,7 @@ var ( ) var copyBufPool = sync.Pool{ - New: func() interface{} { + New: func() any { b := make([]byte, 32*1024) return &b }, @@ -865,6 +865,28 @@ func (srv *Server) initialReadLimitSize() int64 { return int64(srv.maxHeaderBytes()) + 4096 // bufio slop } +// tlsHandshakeTimeout returns the time limit permitted for the TLS +// handshake, or zero for unlimited. +// +// It returns the minimum of any positive ReadHeaderTimeout, +// ReadTimeout, or WriteTimeout. +func (srv *Server) tlsHandshakeTimeout() time.Duration { + var ret time.Duration + for _, v := range [...]time.Duration{ + srv.ReadHeaderTimeout, + srv.ReadTimeout, + srv.WriteTimeout, + } { + if v <= 0 { + continue + } + if ret == 0 || v < ret { + ret = v + } + } + return ret +} + // wrapper around io.ReadCloser which on first read, sends an // HTTP/1.1 100 Continue header type expectContinueReader struct { @@ -1409,11 +1431,11 @@ func (cw *chunkWriter) writeHeader(p []byte) { hasCL = false } - if w.req.Method == "HEAD" || !bodyAllowedForStatus(code) { - // do nothing - } else if code == StatusNoContent { + if w.req.Method == "HEAD" || !bodyAllowedForStatus(code) || code == StatusNoContent { + // Response has no body. delHeader("Transfer-Encoding") } else if hasCL { + // Content-Length has been provided, so no chunking is to be done. delHeader("Transfer-Encoding") } else if w.req.ProtoAtLeast(1, 1) { // HTTP/1.1 or greater: Transfer-Encoding has been set to identity, and no @@ -1424,6 +1446,7 @@ func (cw *chunkWriter) writeHeader(p []byte) { if hasTE && te == "identity" { cw.chunking = false w.closeAfterReply = true + delHeader("Transfer-Encoding") } else { // HTTP/1.1 or greater: use chunked transfer encoding // to avoid closing the connection at EOF. @@ -1799,6 +1822,7 @@ func isCommonNetReadError(err error) bool { func (c *conn) serve(ctx context.Context) { c.remoteAddr = c.rwc.RemoteAddr().String() ctx = context.WithValue(ctx, LocalAddrContextKey, c.rwc.LocalAddr()) + var inFlightResponse *response defer func() { if err := recover(); err != nil && err != ErrAbortHandler { const size = 64 << 10 @@ -1806,18 +1830,25 @@ func (c *conn) serve(ctx context.Context) { buf = buf[:runtime.Stack(buf, false)] c.server.logf("http: panic serving %v: %v\n%s", c.remoteAddr, err, buf) } + if inFlightResponse != nil { + inFlightResponse.cancelCtx() + } if !c.hijacked() { + if inFlightResponse != nil { + inFlightResponse.conn.r.abortPendingRead() + inFlightResponse.reqBody.Close() + } c.close() c.setState(c.rwc, StateClosed, runHooks) } }() if tlsConn, ok := c.rwc.(*tls.Conn); ok { - if d := c.server.ReadTimeout; d > 0 { - c.rwc.SetReadDeadline(time.Now().Add(d)) - } - if d := c.server.WriteTimeout; d > 0 { - c.rwc.SetWriteDeadline(time.Now().Add(d)) + tlsTO := c.server.tlsHandshakeTimeout() + if tlsTO > 0 { + dl := time.Now().Add(tlsTO) + c.rwc.SetReadDeadline(dl) + c.rwc.SetWriteDeadline(dl) } if err := tlsConn.HandshakeContext(ctx); err != nil { // If the handshake failed due to the client not speaking @@ -1831,6 +1862,11 @@ func (c *conn) serve(ctx context.Context) { c.server.logf("http: TLS handshake error from %s: %v", c.rwc.RemoteAddr(), err) return } + // Restore Conn-level deadlines. + if tlsTO > 0 { + c.rwc.SetReadDeadline(time.Time{}) + c.rwc.SetWriteDeadline(time.Time{}) + } c.tlsState = new(tls.ConnectionState) *c.tlsState = tlsConn.ConnectionState() if proto := c.tlsState.NegotiatedProtocol; validNextProto(proto) { @@ -1931,7 +1967,9 @@ func (c *conn) serve(ctx context.Context) { // in parallel even if their responses need to be serialized. // But we're not going to implement HTTP pipelining because it // was never deployed in the wild and the answer is HTTP/2. + inFlightResponse = w serverHandler{c.server}.ServeHTTP(w, w.req) + inFlightResponse = nil w.cancelCtx() if c.hijacked() { return @@ -2277,7 +2315,7 @@ func cleanPath(p string) string { // stripHostPort returns h without any trailing ":<port>". func stripHostPort(h string) string { // If no port on host, return unchanged - if strings.IndexByte(h, ':') == -1 { + if !strings.Contains(h, ":") { return h } host, _, err := net.SplitHostPort(h) @@ -3157,7 +3195,7 @@ func (srv *Server) SetKeepAlivesEnabled(v bool) { // TODO: Issue 26303: close HTTP/2 conns as soon as they become idle. } -func (s *Server) logf(format string, args ...interface{}) { +func (s *Server) logf(format string, args ...any) { if s.ErrorLog != nil { s.ErrorLog.Printf(format, args...) } else { @@ -3168,7 +3206,7 @@ func (s *Server) logf(format string, args ...interface{}) { // logf prints to the ErrorLog of the *Server associated with request r // via ServerContextKey. If there's no associated server, or if ErrorLog // is nil, logging is done via the log package's standard logger. -func logf(r *Request, format string, args ...interface{}) { +func logf(r *Request, format string, args ...any) { s, _ := r.Context().Value(ServerContextKey).(*Server) if s != nil && s.ErrorLog != nil { s.ErrorLog.Printf(format, args...) @@ -3264,7 +3302,7 @@ func (srv *Server) onceSetNextProtoDefaults_Serve() { // configured otherwise. (by setting srv.TLSNextProto non-nil) // It must only be called via srv.nextProtoOnce (use srv.setupHTTP2_*). func (srv *Server) onceSetNextProtoDefaults() { - if omitBundledHTTP2 || strings.Contains(os.Getenv("GODEBUG"), "http2server=0") { + if omitBundledHTTP2 || godebug.Get("http2server") == "0" { return } // Enable HTTP/2 by default if the user hasn't otherwise @@ -3331,7 +3369,7 @@ func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { h: make(Header), req: r, } - panicChan := make(chan interface{}, 1) + panicChan := make(chan any, 1) go func() { defer func() { if p := recover(); p != nil { @@ -3359,9 +3397,15 @@ func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { case <-ctx.Done(): tw.mu.Lock() defer tw.mu.Unlock() - w.WriteHeader(StatusServiceUnavailable) - io.WriteString(w, h.errorBody()) - tw.timedOut = true + switch err := ctx.Err(); err { + case context.DeadlineExceeded: + w.WriteHeader(StatusServiceUnavailable) + io.WriteString(w, h.errorBody()) + tw.err = ErrHandlerTimeout + default: + w.WriteHeader(StatusServiceUnavailable) + tw.err = err + } } } @@ -3372,7 +3416,7 @@ type timeoutWriter struct { req *Request mu sync.Mutex - timedOut bool + err error wroteHeader bool code int } @@ -3392,8 +3436,8 @@ func (tw *timeoutWriter) Header() Header { return tw.h } func (tw *timeoutWriter) Write(p []byte) (int, error) { tw.mu.Lock() defer tw.mu.Unlock() - if tw.timedOut { - return 0, ErrHandlerTimeout + if tw.err != nil { + return 0, tw.err } if !tw.wroteHeader { tw.writeHeaderLocked(StatusOK) @@ -3405,7 +3449,7 @@ func (tw *timeoutWriter) writeHeaderLocked(code int) { checkWriteHeaderCode(code) switch { - case tw.timedOut: + case tw.err != nil: return case tw.wroteHeader: if tw.req != nil { @@ -3572,3 +3616,12 @@ func tlsRecordHeaderLooksLikeHTTP(hdr [5]byte) bool { } return false } + +// MaxBytesHandler returns a Handler that runs h with its ResponseWriter and Request.Body wrapped by a MaxBytesReader. +func MaxBytesHandler(h Handler, n int64) Handler { + return HandlerFunc(func(w ResponseWriter, r *Request) { + r2 := *r + r2.Body = MaxBytesReader(w, r.Body, n) + h.ServeHTTP(w, &r2) + }) +} |