diff options
Diffstat (limited to 'libgo/go/http')
36 files changed, 2985 insertions, 218 deletions
diff --git a/libgo/go/http/cgi/child.go b/libgo/go/http/cgi/child.go index c7d48b9eb3f..e1ad7ad3221 100644 --- a/libgo/go/http/cgi/child.go +++ b/libgo/go/http/cgi/child.go @@ -9,10 +9,12 @@ package cgi import ( "bufio" + "crypto/tls" "fmt" "http" "io" "io/ioutil" + "net" "os" "strconv" "strings" @@ -21,8 +23,16 @@ import ( // Request returns the HTTP request as represented in the current // environment. This assumes the current program is being run // by a web server in a CGI environment. +// The returned Request's Body is populated, if applicable. func Request() (*http.Request, os.Error) { - return requestFromEnvironment(envMap(os.Environ())) + r, err := RequestFromMap(envMap(os.Environ())) + if err != nil { + return nil, err + } + if r.ContentLength > 0 { + r.Body = ioutil.NopCloser(io.LimitReader(os.Stdin, r.ContentLength)) + } + return r, nil } func envMap(env []string) map[string]string { @@ -42,37 +52,44 @@ var skipHeader = map[string]bool{ "HTTP_USER_AGENT": true, } -func requestFromEnvironment(env map[string]string) (*http.Request, os.Error) { +// RequestFromMap creates an http.Request from CGI variables. +// The returned Request's Body field is not populated. +func RequestFromMap(params map[string]string) (*http.Request, os.Error) { r := new(http.Request) - r.Method = env["REQUEST_METHOD"] + r.Method = params["REQUEST_METHOD"] if r.Method == "" { return nil, os.NewError("cgi: no REQUEST_METHOD in environment") } + + r.Proto = params["SERVER_PROTOCOL"] + var ok bool + r.ProtoMajor, r.ProtoMinor, ok = http.ParseHTTPVersion(r.Proto) + if !ok { + return nil, os.NewError("cgi: invalid SERVER_PROTOCOL version") + } + r.Close = true r.Trailer = http.Header{} r.Header = http.Header{} - r.Host = env["HTTP_HOST"] - r.Referer = env["HTTP_REFERER"] - r.UserAgent = env["HTTP_USER_AGENT"] + r.Host = params["HTTP_HOST"] + r.Referer = params["HTTP_REFERER"] + r.UserAgent = params["HTTP_USER_AGENT"] - // CGI doesn't allow chunked requests, so these should all be accurate: - r.Proto = "HTTP/1.0" - r.ProtoMajor = 1 - r.ProtoMinor = 0 - r.TransferEncoding = nil - - if lenstr := env["CONTENT_LENGTH"]; lenstr != "" { + if lenstr := params["CONTENT_LENGTH"]; lenstr != "" { clen, err := strconv.Atoi64(lenstr) if err != nil { return nil, os.NewError("cgi: bad CONTENT_LENGTH in environment: " + lenstr) } r.ContentLength = clen - r.Body = ioutil.NopCloser(io.LimitReader(os.Stdin, clen)) + } + + if ct := params["CONTENT_TYPE"]; ct != "" { + r.Header.Set("Content-Type", ct) } // Copy "HTTP_FOO_BAR" variables to "Foo-Bar" Headers - for k, v := range env { + for k, v := range params { if !strings.HasPrefix(k, "HTTP_") || skipHeader[k] { continue } @@ -84,7 +101,7 @@ func requestFromEnvironment(env map[string]string) (*http.Request, os.Error) { if r.Host != "" { // Hostname is provided, so we can reasonably construct a URL, // even if we have to assume 'http' for the scheme. - r.RawURL = "http://" + r.Host + env["REQUEST_URI"] + r.RawURL = "http://" + r.Host + params["REQUEST_URI"] url, err := http.ParseURL(r.RawURL) if err != nil { return nil, os.NewError("cgi: failed to parse host and REQUEST_URI into a URL: " + r.RawURL) @@ -94,13 +111,25 @@ func requestFromEnvironment(env map[string]string) (*http.Request, os.Error) { // Fallback logic if we don't have a Host header or the URL // failed to parse if r.URL == nil { - r.RawURL = env["REQUEST_URI"] + r.RawURL = params["REQUEST_URI"] url, err := http.ParseURL(r.RawURL) if err != nil { return nil, os.NewError("cgi: failed to parse REQUEST_URI into a URL: " + r.RawURL) } r.URL = url } + + // There's apparently a de-facto standard for this. + // http://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636 + if s := params["HTTPS"]; s == "on" || s == "ON" || s == "1" { + r.TLS = &tls.ConnectionState{HandshakeComplete: true} + } + + // Request.RemoteAddr has its port set by Go's standard http + // server, so we do here too. We don't have one, though, so we + // use a dummy one. + r.RemoteAddr = net.JoinHostPort(params["REMOTE_ADDR"], "0") + return r, nil } @@ -139,10 +168,6 @@ func (r *response) Flush() { r.bufw.Flush() } -func (r *response) RemoteAddr() string { - return os.Getenv("REMOTE_ADDR") -} - func (r *response) Header() http.Header { return r.header } @@ -168,25 +193,7 @@ func (r *response) WriteHeader(code int) { r.header.Add("Content-Type", "text/html; charset=utf-8") } - // TODO: add a method on http.Header to write itself to an io.Writer? - // This is duplicated code. - for k, vv := range r.header { - for _, v := range vv { - v = strings.Replace(v, "\n", "", -1) - v = strings.Replace(v, "\r", "", -1) - v = strings.TrimSpace(v) - fmt.Fprintf(r.bufw, "%s: %s\r\n", k, v) - } - } - r.bufw.Write([]byte("\r\n")) + r.header.Write(r.bufw) + r.bufw.WriteString("\r\n") r.bufw.Flush() } - -func (r *response) UsingTLS() bool { - // There's apparently a de-facto standard for this. - // http://docstore.mik.ua/orelly/linux/cgi/ch03_02.htm#ch03-35636 - if s := os.Getenv("HTTPS"); s == "on" || s == "ON" || s == "1" { - return true - } - return false -} diff --git a/libgo/go/http/cgi/child_test.go b/libgo/go/http/cgi/child_test.go index db0e09cf66a..d12947814e1 100644 --- a/libgo/go/http/cgi/child_test.go +++ b/libgo/go/http/cgi/child_test.go @@ -12,6 +12,7 @@ import ( func TestRequest(t *testing.T) { env := map[string]string{ + "SERVER_PROTOCOL": "HTTP/1.1", "REQUEST_METHOD": "GET", "HTTP_HOST": "example.com", "HTTP_REFERER": "elsewhere", @@ -19,10 +20,13 @@ func TestRequest(t *testing.T) { "HTTP_FOO_BAR": "baz", "REQUEST_URI": "/path?a=b", "CONTENT_LENGTH": "123", + "CONTENT_TYPE": "text/xml", + "HTTPS": "1", + "REMOTE_ADDR": "5.6.7.8", } - req, err := requestFromEnvironment(env) + req, err := RequestFromMap(env) if err != nil { - t.Fatalf("requestFromEnvironment: %v", err) + t.Fatalf("RequestFromMap: %v", err) } if g, e := req.UserAgent, "goclient"; e != g { t.Errorf("expected UserAgent %q; got %q", e, g) @@ -34,6 +38,9 @@ func TestRequest(t *testing.T) { // Tests that we don't put recognized headers in the map t.Errorf("expected User-Agent %q; got %q", e, g) } + if g, e := req.Header.Get("Content-Type"), "text/xml"; e != g { + t.Errorf("expected Content-Type %q; got %q", e, g) + } if g, e := req.ContentLength, int64(123); e != g { t.Errorf("expected ContentLength %d; got %d", e, g) } @@ -58,18 +65,25 @@ func TestRequest(t *testing.T) { if req.Trailer == nil { t.Errorf("unexpected nil Trailer") } + if req.TLS == nil { + t.Errorf("expected non-nil TLS") + } + if e, g := "5.6.7.8:0", req.RemoteAddr; e != g { + t.Errorf("RemoteAddr: got %q; want %q", g, e) + } } func TestRequestWithoutHost(t *testing.T) { env := map[string]string{ - "HTTP_HOST": "", - "REQUEST_METHOD": "GET", - "REQUEST_URI": "/path?a=b", - "CONTENT_LENGTH": "123", + "SERVER_PROTOCOL": "HTTP/1.1", + "HTTP_HOST": "", + "REQUEST_METHOD": "GET", + "REQUEST_URI": "/path?a=b", + "CONTENT_LENGTH": "123", } - req, err := requestFromEnvironment(env) + req, err := RequestFromMap(env) if err != nil { - t.Fatalf("requestFromEnvironment: %v", err) + t.Fatalf("RequestFromMap: %v", err) } if g, e := req.RawURL, "/path?a=b"; e != g { t.Errorf("expected RawURL %q; got %q", e, g) diff --git a/libgo/go/http/cgi/host.go b/libgo/go/http/cgi/host.go index 862acb6000e..7e4ccf881d9 100644 --- a/libgo/go/http/cgi/host.go +++ b/libgo/go/http/cgi/host.go @@ -15,8 +15,8 @@ package cgi import ( + "bufio" "bytes" - "encoding/line" "exec" "fmt" "http" @@ -51,6 +51,16 @@ type Handler struct { InheritEnv []string // environment variables to inherit from host, as "key" Logger *log.Logger // optional log for errors or nil to use log.Print Args []string // optional arguments to pass to child process + + // PathLocationHandler specifies the root http Handler that + // should handle internal redirects when the CGI process + // returns a Location header value starting with a "/", as + // specified in RFC 3875 ยง 6.3.2. This will likely be + // http.DefaultServeMux. + // + // If nil, a CGI response with a local URI path is instead sent + // back to the client and not redirected internally. + PathLocationHandler http.Handler } func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { @@ -78,6 +88,7 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { env := []string{ "SERVER_SOFTWARE=go", "SERVER_NAME=" + req.Host, + "SERVER_PROTOCOL=HTTP/1.1", "HTTP_HOST=" + req.Host, "GATEWAY_INTERFACE=CGI/1.1", "REQUEST_METHOD=" + req.Method, @@ -172,14 +183,14 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { go io.Copy(cmd.Stdin, req.Body) } - linebody := line.NewReader(cmd.Stdout, 1024) - headers := rw.Header() - statusCode := http.StatusOK + linebody, _ := bufio.NewReaderSize(cmd.Stdout, 1024) + headers := make(http.Header) + statusCode := 0 for { line, isPrefix, err := linebody.ReadLine() if isPrefix { rw.WriteHeader(http.StatusInternalServerError) - h.printf("CGI: long header line from subprocess.") + h.printf("cgi: long header line from subprocess.") return } if err == os.EOF { @@ -187,7 +198,7 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } if err != nil { rw.WriteHeader(http.StatusInternalServerError) - h.printf("CGI: error reading headers: %v", err) + h.printf("cgi: error reading headers: %v", err) return } if len(line) == 0 { @@ -195,7 +206,7 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { } parts := strings.Split(string(line), ":", 2) if len(parts) < 2 { - h.printf("CGI: bogus header line: %s", string(line)) + h.printf("cgi: bogus header line: %s", string(line)) continue } header, val := parts[0], parts[1] @@ -204,13 +215,13 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { switch { case header == "Status": if len(val) < 3 { - h.printf("CGI: bogus status (short): %q", val) + h.printf("cgi: bogus status (short): %q", val) return } code, err := strconv.Atoi(val[0:3]) if err != nil { - h.printf("CGI: bogus status: %q", val) - h.printf("CGI: line was %q", line) + h.printf("cgi: bogus status: %q", val) + h.printf("cgi: line was %q", line) return } statusCode = code @@ -218,11 +229,35 @@ func (h *Handler) ServeHTTP(rw http.ResponseWriter, req *http.Request) { headers.Add(header, val) } } + + if loc := headers.Get("Location"); loc != "" { + if strings.HasPrefix(loc, "/") && h.PathLocationHandler != nil { + h.handleInternalRedirect(rw, req, loc) + return + } + if statusCode == 0 { + statusCode = http.StatusFound + } + } + + if statusCode == 0 { + statusCode = http.StatusOK + } + + // Copy headers to rw's headers, after we've decided not to + // go into handleInternalRedirect, which won't want its rw + // headers to have been touched. + for k, vv := range headers { + for _, v := range vv { + rw.Header().Add(k, v) + } + } + rw.WriteHeader(statusCode) _, err = io.Copy(rw, linebody) if err != nil { - h.printf("CGI: copy error: %v", err) + h.printf("cgi: copy error: %v", err) } } @@ -234,6 +269,37 @@ func (h *Handler) printf(format string, v ...interface{}) { } } +func (h *Handler) handleInternalRedirect(rw http.ResponseWriter, req *http.Request, path string) { + url, err := req.URL.ParseURL(path) + if err != nil { + rw.WriteHeader(http.StatusInternalServerError) + h.printf("cgi: error resolving local URI path %q: %v", path, err) + return + } + // TODO: RFC 3875 isn't clear if only GET is supported, but it + // suggests so: "Note that any message-body attached to the + // request (such as for a POST request) may not be available + // to the resource that is the target of the redirect." We + // should do some tests against Apache to see how it handles + // POST, HEAD, etc. Does the internal redirect get the same + // method or just GET? What about incoming headers? + // (e.g. Cookies) Which headers, if any, are copied into the + // second request? + newReq := &http.Request{ + Method: "GET", + URL: url, + RawURL: path, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Host: url.Host, + RemoteAddr: req.RemoteAddr, + TLS: req.TLS, + } + h.PathLocationHandler.ServeHTTP(rw, newReq) +} + func upperCaseAndUnderscore(rune int) int { switch { case rune >= 'a' && rune <= 'z': diff --git a/libgo/go/http/cgi/host_test.go b/libgo/go/http/cgi/host_test.go index e8084b1134e..9ac085f2f3a 100644 --- a/libgo/go/http/cgi/host_test.go +++ b/libgo/go/http/cgi/host_test.go @@ -271,3 +271,40 @@ Transfer-Encoding: chunked expected, got) } } + +func TestRedirect(t *testing.T) { + if skipTest(t) { + return + } + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + } + rec := runCgiTest(t, h, "GET /test.cgi?loc=http://foo.com/ HTTP/1.0\nHost: example.com\n\n", nil) + if e, g := 302, rec.Code; e != g { + t.Errorf("expected status code %d; got %d", e, g) + } + if e, g := "http://foo.com/", rec.Header().Get("Location"); e != g { + t.Errorf("expected Location header of %q; got %q", e, g) + } +} + +func TestInternalRedirect(t *testing.T) { + if skipTest(t) { + return + } + baseHandler := http.HandlerFunc(func(rw http.ResponseWriter, req *http.Request) { + fmt.Fprintf(rw, "basepath=%s\n", req.URL.Path) + fmt.Fprintf(rw, "remoteaddr=%s\n", req.RemoteAddr) + }) + h := &Handler{ + Path: "testdata/test.cgi", + Root: "/test.cgi", + PathLocationHandler: baseHandler, + } + expectedMap := map[string]string{ + "basepath": "/foo", + "remoteaddr": "1.2.3.4", + } + runCgiTest(t, h, "GET /test.cgi?loc=/foo HTTP/1.0\nHost: example.com\n\n", expectedMap) +} diff --git a/libgo/go/http/client.go b/libgo/go/http/client.go index daba3a89b0c..d73cbc8550c 100644 --- a/libgo/go/http/client.go +++ b/libgo/go/http/client.go @@ -22,6 +22,16 @@ import ( // Client is not yet very configurable. type Client struct { Transport RoundTripper // if nil, DefaultTransport is used + + // If CheckRedirect is not nil, the client calls it before + // following an HTTP redirect. The arguments req and via + // are the upcoming request and the requests made already, + // oldest first. If CheckRedirect returns an error, the client + // returns that error instead of issue the Request req. + // + // If CheckRedirect is nil, the Client uses its default policy, + // which is to stop after 10 consecutive requests. + CheckRedirect func(req *Request, via []*Request) os.Error } // DefaultClient is the default Client and is used by Get, Head, and Post. @@ -109,7 +119,7 @@ func shouldRedirect(statusCode int) bool { } // Get issues a GET to the specified URL. If the response is one of the following -// redirect codes, it follows the redirect, up to a maximum of 10 redirects: +// redirect codes, Get follows the redirect, up to a maximum of 10 redirects: // // 301 (Moved Permanently) // 302 (Found) @@ -126,35 +136,33 @@ func Get(url string) (r *Response, finalURL string, err os.Error) { return DefaultClient.Get(url) } -// Get issues a GET to the specified URL. If the response is one of the following -// redirect codes, it follows the redirect, up to a maximum of 10 redirects: +// Get issues a GET to the specified URL. If the response is one of the +// following redirect codes, Get follows the redirect after calling the +// Client's CheckRedirect function. // // 301 (Moved Permanently) // 302 (Found) // 303 (See Other) // 307 (Temporary Redirect) // -// finalURL is the URL from which the response was fetched -- identical to the -// input URL unless redirects were followed. +// finalURL is the URL from which the response was fetched -- identical +// to the input URL unless redirects were followed. // // Caller should close r.Body when done reading from it. func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { // TODO: if/when we add cookie support, the redirected request shouldn't // necessarily supply the same cookies as the original. - // TODO: set referrer header on redirects. var base *URL - // TODO: remove this hard-coded 10 and use the Client's policy - // (ClientConfig) instead. - for redirect := 0; ; redirect++ { - if redirect >= 10 { - err = os.ErrorString("stopped after 10 redirects") - break - } + redirectChecker := c.CheckRedirect + if redirectChecker == nil { + redirectChecker = defaultCheckRedirect + } + var via []*Request + for redirect := 0; ; redirect++ { var req Request req.Method = "GET" - req.ProtoMajor = 1 - req.ProtoMinor = 1 + req.Header = make(Header) if base == nil { req.URL, err = ParseURL(url) } else { @@ -163,6 +171,19 @@ func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { if err != nil { break } + if len(via) > 0 { + // Add the Referer header. + lastReq := via[len(via)-1] + if lastReq.URL.Scheme != "https" { + req.Referer = lastReq.URL.String() + } + + err = redirectChecker(&req, via) + if err != nil { + break + } + } + url = req.URL.String() if r, err = send(&req, c.Transport); err != nil { break @@ -174,6 +195,7 @@ func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { break } base = req.URL + via = append(via, &req) continue } finalURL = url @@ -184,6 +206,13 @@ func (c *Client) Get(url string) (r *Response, finalURL string, err os.Error) { return } +func defaultCheckRedirect(req *Request, via []*Request) os.Error { + if len(via) >= 10 { + return os.ErrorString("stopped after 10 redirects") + } + return nil +} + // Post issues a POST to the specified URL. // // Caller should close r.Body when done reading from it. diff --git a/libgo/go/http/client_test.go b/libgo/go/http/client_test.go index 3a6f834253b..59d62c1c9d4 100644 --- a/libgo/go/http/client_test.go +++ b/libgo/go/http/client_test.go @@ -12,6 +12,7 @@ import ( "http/httptest" "io/ioutil" "os" + "strconv" "strings" "testing" ) @@ -75,3 +76,51 @@ func TestGetRequestFormat(t *testing.T) { t.Errorf("expected non-nil request Header") } } + +func TestRedirects(t *testing.T) { + var ts *httptest.Server + ts = httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + n, _ := strconv.Atoi(r.FormValue("n")) + // Test Referer header. (7 is arbitrary position to test at) + if n == 7 { + if g, e := r.Referer, ts.URL+"/?n=6"; e != g { + t.Errorf("on request ?n=7, expected referer of %q; got %q", e, g) + } + } + if n < 15 { + Redirect(w, r, fmt.Sprintf("/?n=%d", n+1), StatusFound) + return + } + fmt.Fprintf(w, "n=%d", n) + })) + defer ts.Close() + + c := &Client{} + _, _, err := c.Get(ts.URL) + if e, g := "Get /?n=10: stopped after 10 redirects", fmt.Sprintf("%v", err); e != g { + t.Errorf("with default client, expected error %q, got %q", e, g) + } + + var checkErr os.Error + var lastVia []*Request + c = &Client{CheckRedirect: func(_ *Request, via []*Request) os.Error { + lastVia = via + return checkErr + }} + _, finalUrl, err := c.Get(ts.URL) + if e, g := "<nil>", fmt.Sprintf("%v", err); e != g { + t.Errorf("with custom client, expected error %q, got %q", e, g) + } + if !strings.HasSuffix(finalUrl, "/?n=15") { + t.Errorf("expected final url to end in /?n=15; got url %q", finalUrl) + } + if e, g := 15, len(lastVia); e != g { + t.Errorf("expected lastVia to have contained %d elements; got %d", e, g) + } + + checkErr = os.NewError("no redirects allowed") + _, finalUrl, err = c.Get(ts.URL) + if e, g := "Get /?n=1: no redirects allowed", fmt.Sprintf("%v", err); e != g { + t.Errorf("with redirects forbidden, expected error %q, got %q", e, g) + } +} diff --git a/libgo/go/http/cookie.go b/libgo/go/http/cookie.go index 2bb66e58e5c..cc51316438a 100644 --- a/libgo/go/http/cookie.go +++ b/libgo/go/http/cookie.go @@ -15,9 +15,9 @@ import ( "time" ) -// This implementation is done according to IETF draft-ietf-httpstate-cookie-23, found at +// This implementation is done according to RFC 6265: // -// http://tools.ietf.org/html/draft-ietf-httpstate-cookie-23 +// http://tools.ietf.org/html/rfc6265 // A Cookie represents an HTTP cookie as sent in the Set-Cookie header of an // HTTP response or the Cookie header of an HTTP request. @@ -142,12 +142,12 @@ func writeSetCookies(w io.Writer, kk []*Cookie) os.Error { var b bytes.Buffer for _, c := range kk { b.Reset() - fmt.Fprintf(&b, "%s=%s", c.Name, c.Value) + fmt.Fprintf(&b, "%s=%s", sanitizeName(c.Name), sanitizeValue(c.Value)) if len(c.Path) > 0 { - fmt.Fprintf(&b, "; Path=%s", URLEscape(c.Path)) + fmt.Fprintf(&b, "; Path=%s", sanitizeValue(c.Path)) } if len(c.Domain) > 0 { - fmt.Fprintf(&b, "; Domain=%s", URLEscape(c.Domain)) + fmt.Fprintf(&b, "; Domain=%s", sanitizeValue(c.Domain)) } if len(c.Expires.Zone) > 0 { fmt.Fprintf(&b, "; Expires=%s", c.Expires.Format(time.RFC1123)) @@ -225,7 +225,7 @@ func readCookies(h Header) []*Cookie { func writeCookies(w io.Writer, kk []*Cookie) os.Error { lines := make([]string, 0, len(kk)) for _, c := range kk { - lines = append(lines, fmt.Sprintf("Cookie: %s=%s\r\n", c.Name, c.Value)) + lines = append(lines, fmt.Sprintf("Cookie: %s=%s\r\n", sanitizeName(c.Name), sanitizeValue(c.Value))) } sort.SortStrings(lines) for _, l := range lines { @@ -236,6 +236,19 @@ func writeCookies(w io.Writer, kk []*Cookie) os.Error { return nil } +func sanitizeName(n string) string { + n = strings.Replace(n, "\n", "-", -1) + n = strings.Replace(n, "\r", "-", -1) + return n +} + +func sanitizeValue(v string) string { + v = strings.Replace(v, "\n", " ", -1) + v = strings.Replace(v, "\r", " ", -1) + v = strings.Replace(v, ";", " ", -1) + return v +} + func unquoteCookieValue(v string) string { if len(v) > 1 && v[0] == '"' && v[len(v)-1] == '"' { return v[1 : len(v)-1] diff --git a/libgo/go/http/cookie_test.go b/libgo/go/http/cookie_test.go index db09970406b..a3ae85cd6c9 100644 --- a/libgo/go/http/cookie_test.go +++ b/libgo/go/http/cookie_test.go @@ -21,9 +21,13 @@ var writeSetCookiesTests = []struct { []*Cookie{ &Cookie{Name: "cookie-1", Value: "v$1"}, &Cookie{Name: "cookie-2", Value: "two", MaxAge: 3600}, + &Cookie{Name: "cookie-3", Value: "three", Domain: ".example.com"}, + &Cookie{Name: "cookie-4", Value: "four", Path: "/restricted/"}, }, "Set-Cookie: cookie-1=v$1\r\n" + - "Set-Cookie: cookie-2=two; Max-Age=3600\r\n", + "Set-Cookie: cookie-2=two; Max-Age=3600\r\n" + + "Set-Cookie: cookie-3=three; Domain=.example.com\r\n" + + "Set-Cookie: cookie-4=four; Path=/restricted/\r\n", }, } diff --git a/libgo/go/http/dump.go b/libgo/go/http/dump.go index 306c45bc2c9..358980f7cae 100644 --- a/libgo/go/http/dump.go +++ b/libgo/go/http/dump.go @@ -31,6 +31,8 @@ func drainBody(b io.ReadCloser) (r1, r2 io.ReadCloser, err os.Error) { // DumpRequest is semantically a no-op, but in order to // dump the body, it reads the body data into memory and // changes req.Body to refer to the in-memory copy. +// The documentation for Request.Write details which fields +// of req are used. func DumpRequest(req *Request, body bool) (dump []byte, err os.Error) { var b bytes.Buffer save := req.Body diff --git a/libgo/go/http/export_test.go b/libgo/go/http/export_test.go index a76b70760df..3fe658641f8 100644 --- a/libgo/go/http/export_test.go +++ b/libgo/go/http/export_test.go @@ -14,7 +14,7 @@ func (t *Transport) IdleConnKeysForTesting() (keys []string) { if t.idleConn == nil { return } - for key, _ := range t.idleConn { + for key := range t.idleConn { keys = append(keys, key) } return @@ -32,3 +32,10 @@ func (t *Transport) IdleConnCountForTesting(cacheKey string) int { } return len(conns) } + +func NewTestTimeoutHandler(handler Handler, ch <-chan int64) Handler { + f := func() <-chan int64 { + return ch + } + return &timeoutHandler{handler, f, ""} +} diff --git a/libgo/go/http/fcgi/child.go b/libgo/go/http/fcgi/child.go new file mode 100644 index 00000000000..19718824c96 --- /dev/null +++ b/libgo/go/http/fcgi/child.go @@ -0,0 +1,258 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fcgi + +// This file implements FastCGI from the perspective of a child process. + +import ( + "fmt" + "http" + "http/cgi" + "io" + "net" + "os" + "time" +) + +// request holds the state for an in-progress request. As soon as it's complete, +// it's converted to an http.Request. +type request struct { + pw *io.PipeWriter + reqId uint16 + params map[string]string + buf [1024]byte + rawParams []byte + keepConn bool +} + +func newRequest(reqId uint16, flags uint8) *request { + r := &request{ + reqId: reqId, + params: map[string]string{}, + keepConn: flags&flagKeepConn != 0, + } + r.rawParams = r.buf[:0] + return r +} + +// parseParams reads an encoded []byte into Params. +func (r *request) parseParams() { + text := r.rawParams + r.rawParams = nil + for len(text) > 0 { + keyLen, n := readSize(text) + if n == 0 { + return + } + text = text[n:] + valLen, n := readSize(text) + if n == 0 { + return + } + text = text[n:] + key := readString(text, keyLen) + text = text[keyLen:] + val := readString(text, valLen) + text = text[valLen:] + r.params[key] = val + } +} + +// response implements http.ResponseWriter. +type response struct { + req *request + header http.Header + w *bufWriter + wroteHeader bool +} + +func newResponse(c *child, req *request) *response { + return &response{ + req: req, + header: http.Header{}, + w: newWriter(c.conn, typeStdout, req.reqId), + } +} + +func (r *response) Header() http.Header { + return r.header +} + +func (r *response) Write(data []byte) (int, os.Error) { + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + return r.w.Write(data) +} + +func (r *response) WriteHeader(code int) { + if r.wroteHeader { + return + } + r.wroteHeader = true + if code == http.StatusNotModified { + // Must not have body. + r.header.Del("Content-Type") + r.header.Del("Content-Length") + r.header.Del("Transfer-Encoding") + } else if r.header.Get("Content-Type") == "" { + r.header.Set("Content-Type", "text/html; charset=utf-8") + } + + if r.header.Get("Date") == "" { + r.header.Set("Date", time.UTC().Format(http.TimeFormat)) + } + + fmt.Fprintf(r.w, "Status: %d %s\r\n", code, http.StatusText(code)) + r.header.Write(r.w) + r.w.WriteString("\r\n") +} + +func (r *response) Flush() { + if !r.wroteHeader { + r.WriteHeader(http.StatusOK) + } + r.w.Flush() +} + +func (r *response) Close() os.Error { + r.Flush() + return r.w.Close() +} + +type child struct { + conn *conn + handler http.Handler +} + +func newChild(rwc net.Conn, handler http.Handler) *child { + return &child{newConn(rwc), handler} +} + +func (c *child) serve() { + requests := map[uint16]*request{} + defer c.conn.Close() + var rec record + var br beginRequest + for { + if err := rec.read(c.conn.rwc); err != nil { + return + } + + req, ok := requests[rec.h.Id] + if !ok && rec.h.Type != typeBeginRequest && rec.h.Type != typeGetValues { + // The spec says to ignore unknown request IDs. + continue + } + if ok && rec.h.Type == typeBeginRequest { + // The server is trying to begin a request with the same ID + // as an in-progress request. This is an error. + return + } + + switch rec.h.Type { + case typeBeginRequest: + if err := br.read(rec.content()); err != nil { + return + } + if br.role != roleResponder { + c.conn.writeEndRequest(rec.h.Id, 0, statusUnknownRole) + break + } + requests[rec.h.Id] = newRequest(rec.h.Id, br.flags) + case typeParams: + // NOTE(eds): Technically a key-value pair can straddle the boundary + // between two packets. We buffer until we've received all parameters. + if len(rec.content()) > 0 { + req.rawParams = append(req.rawParams, rec.content()...) + break + } + req.parseParams() + case typeStdin: + content := rec.content() + if req.pw == nil { + var body io.ReadCloser + if len(content) > 0 { + // body could be an io.LimitReader, but it shouldn't matter + // as long as both sides are behaving. + body, req.pw = io.Pipe() + } + go c.serveRequest(req, body) + } + if len(content) > 0 { + // TODO(eds): This blocks until the handler reads from the pipe. + // If the handler takes a long time, it might be a problem. + req.pw.Write(content) + } else if req.pw != nil { + req.pw.Close() + } + case typeGetValues: + values := map[string]string{"FCGI_MPXS_CONNS": "1"} + c.conn.writePairs(0, typeGetValuesResult, values) + case typeData: + // If the filter role is implemented, read the data stream here. + case typeAbortRequest: + requests[rec.h.Id] = nil, false + c.conn.writeEndRequest(rec.h.Id, 0, statusRequestComplete) + if !req.keepConn { + // connection will close upon return + return + } + default: + b := make([]byte, 8) + b[0] = rec.h.Type + c.conn.writeRecord(typeUnknownType, 0, b) + } + } +} + +func (c *child) serveRequest(req *request, body io.ReadCloser) { + r := newResponse(c, req) + httpReq, err := cgi.RequestFromMap(req.params) + if err != nil { + // there was an error reading the request + r.WriteHeader(http.StatusInternalServerError) + c.conn.writeRecord(typeStderr, req.reqId, []byte(err.String())) + } else { + httpReq.Body = body + c.handler.ServeHTTP(r, httpReq) + } + if body != nil { + body.Close() + } + r.Close() + c.conn.writeEndRequest(req.reqId, 0, statusRequestComplete) + if !req.keepConn { + c.conn.Close() + } +} + +// Serve accepts incoming FastCGI connections on the listener l, creating a new +// service thread for each. The service threads read requests and then call handler +// to reply to them. +// If l is nil, Serve accepts connections on stdin. +// If handler is nil, http.DefaultServeMux is used. +func Serve(l net.Listener, handler http.Handler) os.Error { + if l == nil { + var err os.Error + l, err = net.FileListener(os.Stdin) + if err != nil { + return err + } + defer l.Close() + } + if handler == nil { + handler = http.DefaultServeMux + } + for { + rw, err := l.Accept() + if err != nil { + return err + } + c := newChild(rw, handler) + go c.serve() + } + panic("unreachable") +} diff --git a/libgo/go/http/fcgi/fcgi.go b/libgo/go/http/fcgi/fcgi.go new file mode 100644 index 00000000000..8e2e1cd3cb3 --- /dev/null +++ b/libgo/go/http/fcgi/fcgi.go @@ -0,0 +1,271 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package fcgi implements the FastCGI protocol. +// Currently only the responder role is supported. +// The protocol is defined at http://www.fastcgi.com/drupal/node/6?q=node/22 +package fcgi + +// This file defines the raw protocol and some utilities used by the child and +// the host. + +import ( + "bufio" + "bytes" + "encoding/binary" + "io" + "os" + "sync" +) + +const ( + // Packet Types + typeBeginRequest = iota + 1 + typeAbortRequest + typeEndRequest + typeParams + typeStdin + typeStdout + typeStderr + typeData + typeGetValues + typeGetValuesResult + typeUnknownType +) + +// keep the connection between web-server and responder open after request +const flagKeepConn = 1 + +const ( + maxWrite = 65535 // maximum record body + maxPad = 255 +) + +const ( + roleResponder = iota + 1 // only Responders are implemented. + roleAuthorizer + roleFilter +) + +const ( + statusRequestComplete = iota + statusCantMultiplex + statusOverloaded + statusUnknownRole +) + +const headerLen = 8 + +type header struct { + Version uint8 + Type uint8 + Id uint16 + ContentLength uint16 + PaddingLength uint8 + Reserved uint8 +} + +type beginRequest struct { + role uint16 + flags uint8 + reserved [5]uint8 +} + +func (br *beginRequest) read(content []byte) os.Error { + if len(content) != 8 { + return os.NewError("fcgi: invalid begin request record") + } + br.role = binary.BigEndian.Uint16(content) + br.flags = content[2] + return nil +} + +// for padding so we don't have to allocate all the time +// not synchronized because we don't care what the contents are +var pad [maxPad]byte + +func (h *header) init(recType uint8, reqId uint16, contentLength int) { + h.Version = 1 + h.Type = recType + h.Id = reqId + h.ContentLength = uint16(contentLength) + h.PaddingLength = uint8(-contentLength & 7) +} + +// conn sends records over rwc +type conn struct { + mutex sync.Mutex + rwc io.ReadWriteCloser + + // to avoid allocations + buf bytes.Buffer + h header +} + +func newConn(rwc io.ReadWriteCloser) *conn { + return &conn{rwc: rwc} +} + +func (c *conn) Close() os.Error { + c.mutex.Lock() + defer c.mutex.Unlock() + return c.rwc.Close() +} + +type record struct { + h header + buf [maxWrite + maxPad]byte +} + +func (rec *record) read(r io.Reader) (err os.Error) { + if err = binary.Read(r, binary.BigEndian, &rec.h); err != nil { + return err + } + if rec.h.Version != 1 { + return os.NewError("fcgi: invalid header version") + } + n := int(rec.h.ContentLength) + int(rec.h.PaddingLength) + if _, err = io.ReadFull(r, rec.buf[:n]); err != nil { + return err + } + return nil +} + +func (r *record) content() []byte { + return r.buf[:r.h.ContentLength] +} + +// writeRecord writes and sends a single record. +func (c *conn) writeRecord(recType uint8, reqId uint16, b []byte) os.Error { + c.mutex.Lock() + defer c.mutex.Unlock() + c.buf.Reset() + c.h.init(recType, reqId, len(b)) + if err := binary.Write(&c.buf, binary.BigEndian, c.h); err != nil { + return err + } + if _, err := c.buf.Write(b); err != nil { + return err + } + if _, err := c.buf.Write(pad[:c.h.PaddingLength]); err != nil { + return err + } + _, err := c.rwc.Write(c.buf.Bytes()) + return err +} + +func (c *conn) writeBeginRequest(reqId uint16, role uint16, flags uint8) os.Error { + b := [8]byte{byte(role >> 8), byte(role), flags} + return c.writeRecord(typeBeginRequest, reqId, b[:]) +} + +func (c *conn) writeEndRequest(reqId uint16, appStatus int, protocolStatus uint8) os.Error { + b := make([]byte, 8) + binary.BigEndian.PutUint32(b, uint32(appStatus)) + b[4] = protocolStatus + return c.writeRecord(typeEndRequest, reqId, b) +} + +func (c *conn) writePairs(recType uint8, reqId uint16, pairs map[string]string) os.Error { + w := newWriter(c, recType, reqId) + b := make([]byte, 8) + for k, v := range pairs { + n := encodeSize(b, uint32(len(k))) + n += encodeSize(b[n:], uint32(len(k))) + if _, err := w.Write(b[:n]); err != nil { + return err + } + if _, err := w.WriteString(k); err != nil { + return err + } + if _, err := w.WriteString(v); err != nil { + return err + } + } + w.Close() + return nil +} + +func readSize(s []byte) (uint32, int) { + if len(s) == 0 { + return 0, 0 + } + size, n := uint32(s[0]), 1 + if size&(1<<7) != 0 { + if len(s) < 4 { + return 0, 0 + } + n = 4 + size = binary.BigEndian.Uint32(s) + size &^= 1 << 31 + } + return size, n +} + +func readString(s []byte, size uint32) string { + if size > uint32(len(s)) { + return "" + } + return string(s[:size]) +} + +func encodeSize(b []byte, size uint32) int { + if size > 127 { + size |= 1 << 31 + binary.BigEndian.PutUint32(b, size) + return 4 + } + b[0] = byte(size) + return 1 +} + +// bufWriter encapsulates bufio.Writer but also closes the underlying stream when +// Closed. +type bufWriter struct { + closer io.Closer + *bufio.Writer +} + +func (w *bufWriter) Close() os.Error { + if err := w.Writer.Flush(); err != nil { + w.closer.Close() + return err + } + return w.closer.Close() +} + +func newWriter(c *conn, recType uint8, reqId uint16) *bufWriter { + s := &streamWriter{c: c, recType: recType, reqId: reqId} + w, _ := bufio.NewWriterSize(s, maxWrite) + return &bufWriter{s, w} +} + +// streamWriter abstracts out the separation of a stream into discrete records. +// It only writes maxWrite bytes at a time. +type streamWriter struct { + c *conn + recType uint8 + reqId uint16 +} + +func (w *streamWriter) Write(p []byte) (int, os.Error) { + nn := 0 + for len(p) > 0 { + n := len(p) + if n > maxWrite { + n = maxWrite + } + if err := w.c.writeRecord(w.recType, w.reqId, p[:n]); err != nil { + return nn, err + } + nn += n + p = p[n:] + } + return nn, nil +} + +func (w *streamWriter) Close() os.Error { + // send empty record to close the stream + return w.c.writeRecord(w.recType, w.reqId, nil) +} diff --git a/libgo/go/http/fcgi/fcgi_test.go b/libgo/go/http/fcgi/fcgi_test.go new file mode 100644 index 00000000000..16a6243295e --- /dev/null +++ b/libgo/go/http/fcgi/fcgi_test.go @@ -0,0 +1,114 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package fcgi + +import ( + "bytes" + "io" + "os" + "testing" +) + +var sizeTests = []struct { + size uint32 + bytes []byte +}{ + {0, []byte{0x00}}, + {127, []byte{0x7F}}, + {128, []byte{0x80, 0x00, 0x00, 0x80}}, + {1000, []byte{0x80, 0x00, 0x03, 0xE8}}, + {33554431, []byte{0x81, 0xFF, 0xFF, 0xFF}}, +} + +func TestSize(t *testing.T) { + b := make([]byte, 4) + for i, test := range sizeTests { + n := encodeSize(b, test.size) + if !bytes.Equal(b[:n], test.bytes) { + t.Errorf("%d expected %x, encoded %x", i, test.bytes, b) + } + size, n := readSize(test.bytes) + if size != test.size { + t.Errorf("%d expected %d, read %d", i, test.size, size) + } + if len(test.bytes) != n { + t.Errorf("%d did not consume all the bytes", i) + } + } +} + +var streamTests = []struct { + desc string + recType uint8 + reqId uint16 + content []byte + raw []byte +}{ + {"single record", typeStdout, 1, nil, + []byte{1, typeStdout, 0, 1, 0, 0, 0, 0}, + }, + // this data will have to be split into two records + {"two records", typeStdin, 300, make([]byte, 66000), + bytes.Join([][]byte{ + // header for the first record + []byte{1, typeStdin, 0x01, 0x2C, 0xFF, 0xFF, 1, 0}, + make([]byte, 65536), + // header for the second + []byte{1, typeStdin, 0x01, 0x2C, 0x01, 0xD1, 7, 0}, + make([]byte, 472), + // header for the empty record + []byte{1, typeStdin, 0x01, 0x2C, 0, 0, 0, 0}, + }, + nil), + }, +} + +type nilCloser struct { + io.ReadWriter +} + +func (c *nilCloser) Close() os.Error { return nil } + +func TestStreams(t *testing.T) { + var rec record +outer: + for _, test := range streamTests { + buf := bytes.NewBuffer(test.raw) + var content []byte + for buf.Len() > 0 { + if err := rec.read(buf); err != nil { + t.Errorf("%s: error reading record: %v", test.desc, err) + continue outer + } + content = append(content, rec.content()...) + } + if rec.h.Type != test.recType { + t.Errorf("%s: got type %d expected %d", test.desc, rec.h.Type, test.recType) + continue + } + if rec.h.Id != test.reqId { + t.Errorf("%s: got request ID %d expected %d", test.desc, rec.h.Id, test.reqId) + continue + } + if !bytes.Equal(content, test.content) { + t.Errorf("%s: read wrong content", test.desc) + continue + } + buf.Reset() + c := newConn(&nilCloser{buf}) + w := newWriter(c, test.recType, test.reqId) + if _, err := w.Write(test.content); err != nil { + t.Errorf("%s: error writing record: %v", test.desc, err) + continue + } + if err := w.Close(); err != nil { + t.Errorf("%s: error closing stream: %v", test.desc, err) + continue + } + if !bytes.Equal(buf.Bytes(), test.raw) { + t.Errorf("%s: wrote wrong content", test.desc) + } + } +} diff --git a/libgo/go/http/fs.go b/libgo/go/http/fs.go index c5efffca9cd..17d5297b82c 100644 --- a/libgo/go/http/fs.go +++ b/libgo/go/http/fs.go @@ -143,7 +143,7 @@ func serveFile(w ResponseWriter, r *Request, name string, redirect bool) { n, _ := io.ReadFull(f, buf[:]) b := buf[:n] if isText(b) { - ctype = "text-plain; charset=utf-8" + ctype = "text/plain; charset=utf-8" } else { // generic binary ctype = "application/octet-stream" diff --git a/libgo/go/http/fs_test.go b/libgo/go/http/fs_test.go index 692b9863e82..09d0981f26e 100644 --- a/libgo/go/http/fs_test.go +++ b/libgo/go/http/fs_test.go @@ -104,7 +104,7 @@ func TestServeFileContentType(t *testing.T) { t.Errorf("Content-Type mismatch: got %q, want %q", h, want) } } - get("text-plain; charset=utf-8") + get("text/plain; charset=utf-8") override = true get(ctype) } diff --git a/libgo/go/http/header.go b/libgo/go/http/header.go index 95b0f3db6bb..95140b01f2a 100644 --- a/libgo/go/http/header.go +++ b/libgo/go/http/header.go @@ -4,7 +4,14 @@ package http -import "net/textproto" +import ( + "fmt" + "io" + "net/textproto" + "os" + "sort" + "strings" +) // A Header represents the key-value pairs in an HTTP header. type Header map[string][]string @@ -35,6 +42,37 @@ func (h Header) Del(key string) { textproto.MIMEHeader(h).Del(key) } +// Write writes a header in wire format. +func (h Header) Write(w io.Writer) os.Error { + return h.WriteSubset(w, nil) +} + +// WriteSubset writes a header in wire format. +// If exclude is not nil, keys where exclude[key] == true are not written. +func (h Header) WriteSubset(w io.Writer, exclude map[string]bool) os.Error { + keys := make([]string, 0, len(h)) + for k := range h { + if exclude == nil || !exclude[k] { + keys = append(keys, k) + } + } + sort.SortStrings(keys) + for _, k := range keys { + for _, v := range h[k] { + v = strings.Replace(v, "\n", " ", -1) + v = strings.Replace(v, "\r", " ", -1) + v = strings.TrimSpace(v) + if v == "" { + continue + } + if _, err := fmt.Fprintf(w, "%s: %s\r\n", k, v); err != nil { + return err + } + } + } + return nil +} + // CanonicalHeaderKey returns the canonical format of the // header key s. The canonicalization converts the first // letter and any letter following a hyphen to upper case; diff --git a/libgo/go/http/header_test.go b/libgo/go/http/header_test.go new file mode 100644 index 00000000000..7e24cb069c6 --- /dev/null +++ b/libgo/go/http/header_test.go @@ -0,0 +1,71 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package http + +import ( + "bytes" + "testing" +) + +var headerWriteTests = []struct { + h Header + exclude map[string]bool + expected string +}{ + {Header{}, nil, ""}, + { + Header{ + "Content-Type": {"text/html; charset=UTF-8"}, + "Content-Length": {"0"}, + }, + nil, + "Content-Length: 0\r\nContent-Type: text/html; charset=UTF-8\r\n", + }, + { + Header{ + "Content-Length": {"0", "1", "2"}, + }, + nil, + "Content-Length: 0\r\nContent-Length: 1\r\nContent-Length: 2\r\n", + }, + { + Header{ + "Expires": {"-1"}, + "Content-Length": {"0"}, + "Content-Encoding": {"gzip"}, + }, + map[string]bool{"Content-Length": true}, + "Content-Encoding: gzip\r\nExpires: -1\r\n", + }, + { + Header{ + "Expires": {"-1"}, + "Content-Length": {"0", "1", "2"}, + "Content-Encoding": {"gzip"}, + }, + map[string]bool{"Content-Length": true}, + "Content-Encoding: gzip\r\nExpires: -1\r\n", + }, + { + Header{ + "Expires": {"-1"}, + "Content-Length": {"0"}, + "Content-Encoding": {"gzip"}, + }, + map[string]bool{"Content-Length": true, "Expires": true, "Content-Encoding": true}, + "", + }, +} + +func TestHeaderWrite(t *testing.T) { + var buf bytes.Buffer + for i, test := range headerWriteTests { + test.h.WriteSubset(&buf, test.exclude) + if buf.String() != test.expected { + t.Errorf("#%d:\n got: %q\nwant: %q", i, buf.String(), test.expected) + } + buf.Reset() + } +} diff --git a/libgo/go/http/httptest/recorder.go b/libgo/go/http/httptest/recorder.go index 0dd19a617cc..f2fedefcfd1 100644 --- a/libgo/go/http/httptest/recorder.go +++ b/libgo/go/http/httptest/recorder.go @@ -2,7 +2,7 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -// The httptest package provides utilities for HTTP testing. +// Package httptest provides utilities for HTTP testing. package httptest import ( diff --git a/libgo/go/http/persist.go b/libgo/go/http/persist.go index b93c5fe4855..e4eea6815d0 100644 --- a/libgo/go/http/persist.go +++ b/libgo/go/http/persist.go @@ -20,8 +20,8 @@ var ( // A ServerConn reads requests and sends responses over an underlying // connection, until the HTTP keepalive logic commands an end. ServerConn -// does not close the underlying connection. Instead, the user calls Close -// and regains control over the connection. ServerConn supports pipe-lining, +// also allows hijacking the underlying connection by calling Hijack +// to regain control over the connection. ServerConn supports pipe-lining, // i.e. requests can be read out of sync (but in the same order) while the // respective responses are sent. type ServerConn struct { @@ -45,11 +45,11 @@ func NewServerConn(c net.Conn, r *bufio.Reader) *ServerConn { return &ServerConn{c: c, r: r, pipereq: make(map[*Request]uint)} } -// Close detaches the ServerConn and returns the underlying connection as well -// as the read-side bufio which may have some left over data. Close may be +// Hijack detaches the ServerConn and returns the underlying connection as well +// as the read-side bufio which may have some left over data. Hijack may be // called before Read has signaled the end of the keep-alive logic. The user -// should not call Close while Read or Write is in progress. -func (sc *ServerConn) Close() (c net.Conn, r *bufio.Reader) { +// should not call Hijack while Read or Write is in progress. +func (sc *ServerConn) Hijack() (c net.Conn, r *bufio.Reader) { sc.lk.Lock() defer sc.lk.Unlock() c = sc.c @@ -59,6 +59,15 @@ func (sc *ServerConn) Close() (c net.Conn, r *bufio.Reader) { return } +// Close calls Hijack and then also closes the underlying connection +func (sc *ServerConn) Close() os.Error { + c, _ := sc.Hijack() + if c != nil { + return c.Close() + } + return nil +} + // Read returns the next request on the wire. An ErrPersistEOF is returned if // it is gracefully determined that there are no more requests (e.g. after the // first request on an HTTP/1.0 connection, or after a Connection:close on a @@ -199,9 +208,9 @@ func (sc *ServerConn) Write(req *Request, resp *Response) os.Error { } // A ClientConn sends request and receives headers over an underlying -// connection, while respecting the HTTP keepalive logic. ClientConn is not -// responsible for closing the underlying connection. One must call Close to -// regain control of that connection and deal with it as desired. +// connection, while respecting the HTTP keepalive logic. ClientConn +// supports hijacking the connection calling Hijack to +// regain control of the underlying net.Conn and deal with it as desired. type ClientConn struct { lk sync.Mutex // read-write protects the following fields c net.Conn @@ -239,11 +248,11 @@ func NewProxyClientConn(c net.Conn, r *bufio.Reader) *ClientConn { return cc } -// Close detaches the ClientConn and returns the underlying connection as well -// as the read-side bufio which may have some left over data. Close may be +// Hijack detaches the ClientConn and returns the underlying connection as well +// as the read-side bufio which may have some left over data. Hijack may be // called before the user or Read have signaled the end of the keep-alive -// logic. The user should not call Close while Read or Write is in progress. -func (cc *ClientConn) Close() (c net.Conn, r *bufio.Reader) { +// logic. The user should not call Hijack while Read or Write is in progress. +func (cc *ClientConn) Hijack() (c net.Conn, r *bufio.Reader) { cc.lk.Lock() defer cc.lk.Unlock() c = cc.c @@ -253,6 +262,15 @@ func (cc *ClientConn) Close() (c net.Conn, r *bufio.Reader) { return } +// Close calls Hijack and then also closes the underlying connection +func (cc *ClientConn) Close() os.Error { + c, _ := cc.Hijack() + if c != nil { + return c.Close() + } + return nil +} + // Write writes a request. An ErrPersistEOF error is returned if the connection // has been closed in an HTTP keepalive sense. If req.Close equals true, the // keepalive connection is logically closed after this request and the opposing diff --git a/libgo/go/http/pprof/pprof.go b/libgo/go/http/pprof/pprof.go index bc79e218320..917c7f877a3 100644 --- a/libgo/go/http/pprof/pprof.go +++ b/libgo/go/http/pprof/pprof.go @@ -26,6 +26,7 @@ package pprof import ( "bufio" + "bytes" "fmt" "http" "os" @@ -88,10 +89,14 @@ func Profile(w http.ResponseWriter, r *http.Request) { func Symbol(w http.ResponseWriter, r *http.Request) { w.Header().Set("Content-Type", "text/plain; charset=utf-8") + // We have to read the whole POST body before + // writing any output. Buffer the output here. + var buf bytes.Buffer + // We don't know how many symbols we have, but we // do have symbol information. Pprof only cares whether // this number is 0 (no symbols available) or > 0. - fmt.Fprintf(w, "num_symbols: 1\n") + fmt.Fprintf(&buf, "num_symbols: 1\n") var b *bufio.Reader if r.Method == "POST" { @@ -109,14 +114,19 @@ func Symbol(w http.ResponseWriter, r *http.Request) { if pc != 0 { f := runtime.FuncForPC(uintptr(pc)) if f != nil { - fmt.Fprintf(w, "%#x %s\n", pc, f.Name()) + fmt.Fprintf(&buf, "%#x %s\n", pc, f.Name()) } } // Wait until here to check for err; the last // symbol will have an err because it doesn't end in +. if err != nil { + if err != os.EOF { + fmt.Fprintf(&buf, "reading request: %v\n", err) + } break } } + + w.Write(buf.Bytes()) } diff --git a/libgo/go/http/proxy_test.go b/libgo/go/http/proxy_test.go index 7050ef5ed06..308bf44b48a 100644 --- a/libgo/go/http/proxy_test.go +++ b/libgo/go/http/proxy_test.go @@ -16,9 +16,15 @@ var UseProxyTests = []struct { host string match bool }{ - {"localhost", false}, // match completely + // Never proxy localhost: + {"localhost:80", false}, + {"127.0.0.1", false}, + {"127.0.0.2", false}, + {"[::1]", false}, + {"[::2]", true}, // not a loopback address + {"barbaz.net", false}, // match as .barbaz.net - {"foobar.com:443", false}, // have a port but match + {"foobar.com", false}, // have a port but match {"foofoobar.com", true}, // not match as a part of foobar.com {"baz.com", true}, // not match as a part of barbaz.com {"localhost.net", true}, // not match as suffix of address @@ -29,19 +35,16 @@ var UseProxyTests = []struct { func TestUseProxy(t *testing.T) { oldenv := os.Getenv("NO_PROXY") - no_proxy := "foobar.com, .barbaz.net , localhost" - os.Setenv("NO_PROXY", no_proxy) defer os.Setenv("NO_PROXY", oldenv) + no_proxy := "foobar.com, .barbaz.net" + os.Setenv("NO_PROXY", no_proxy) + tr := &Transport{} for _, test := range UseProxyTests { - if tr.useProxy(test.host) != test.match { - if test.match { - t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match) - } else { - t.Errorf("not expected: '%s' shouldn't match as '%s'", test.host, no_proxy) - } + if tr.useProxy(test.host+":80") != test.match { + t.Errorf("useProxy(%v) = %v, want %v", test.host, !test.match, test.match) } } } diff --git a/libgo/go/http/request.go b/libgo/go/http/request.go index d82894fab08..8545d75660a 100644 --- a/libgo/go/http/request.go +++ b/libgo/go/http/request.go @@ -4,9 +4,8 @@ // HTTP Request reading and parsing. -// The http package implements parsing of HTTP requests, replies, -// and URLs and provides an extensible HTTP server and a basic -// HTTP client. +// Package http implements parsing of HTTP requests, replies, and URLs and +// provides an extensible HTTP server and a basic HTTP client. package http import ( @@ -25,12 +24,17 @@ import ( ) const ( - maxLineLength = 4096 // assumed <= bufio.defaultBufSize - maxValueLength = 4096 - maxHeaderLines = 1024 - chunkSize = 4 << 10 // 4 KB chunks + maxLineLength = 4096 // assumed <= bufio.defaultBufSize + maxValueLength = 4096 + maxHeaderLines = 1024 + chunkSize = 4 << 10 // 4 KB chunks + defaultMaxMemory = 32 << 20 // 32 MB ) +// ErrMissingFile is returned by FormFile when the provided file field name +// is either not present in the request or not a file field. +var ErrMissingFile = os.ErrorString("http: no such file") + // HTTP request parsing errors. type ProtocolError struct { os.ErrorString @@ -65,9 +69,12 @@ var reqExcludeHeader = map[string]bool{ // A Request represents a parsed HTTP request header. type Request struct { - Method string // GET, POST, PUT, etc. - RawURL string // The raw URL given in the request. - URL *URL // Parsed URL. + Method string // GET, POST, PUT, etc. + RawURL string // The raw URL given in the request. + URL *URL // Parsed URL. + + // The protocol version for incoming requests. + // Outgoing requests always use HTTP/1.1. Proto string // "HTTP/1.0" ProtoMajor int // 1 ProtoMinor int // 0 @@ -134,6 +141,10 @@ type Request struct { // The parsed form. Only available after ParseForm is called. Form map[string][]string + // The parsed multipart form, including file uploads. + // Only available after ParseMultipartForm is called. + MultipartForm *multipart.Form + // Trailer maps trailer keys to values. Like for Header, if the // response has multiple trailer lines with the same key, they will be // concatenated, delimited by commas. @@ -163,9 +174,30 @@ func (r *Request) ProtoAtLeast(major, minor int) bool { r.ProtoMajor == major && r.ProtoMinor >= minor } +// multipartByReader is a sentinel value. +// Its presence in Request.MultipartForm indicates that parsing of the request +// body has been handed off to a MultipartReader instead of ParseMultipartFrom. +var multipartByReader = &multipart.Form{ + Value: make(map[string][]string), + File: make(map[string][]*multipart.FileHeader), +} + // MultipartReader returns a MIME multipart reader if this is a // multipart/form-data POST request, else returns nil and an error. +// Use this function instead of ParseMultipartForm to +// process the request body as a stream. func (r *Request) MultipartReader() (multipart.Reader, os.Error) { + if r.MultipartForm == multipartByReader { + return nil, os.NewError("http: MultipartReader called twice") + } + if r.MultipartForm != nil { + return nil, os.NewError("http: multipart handled by ParseMultipartForm") + } + r.MultipartForm = multipartByReader + return r.multipartReader() +} + +func (r *Request) multipartReader() (multipart.Reader, os.Error) { v := r.Header.Get("Content-Type") if v == "" { return nil, ErrNotMultipart @@ -199,10 +231,14 @@ const defaultUserAgent = "Go http package" // UserAgent (defaults to defaultUserAgent) // Referer // Header +// Cookie +// ContentLength +// TransferEncoding // Body // -// If Body is present, Write forces "Transfer-Encoding: chunked" as a header -// and then closes Body when finished sending it. +// If Body is present but Content-Length is <= 0, Write adds +// "Transfer-Encoding: chunked" to the header. Body is closed after +// it is sent. func (req *Request) Write(w io.Writer) os.Error { return req.write(w, false) } @@ -264,7 +300,7 @@ func (req *Request) write(w io.Writer, usingProxy bool) os.Error { // from Request, and introduce Request methods along the lines of // Response.{GetHeader,AddHeader} and string constants for "Host", // "User-Agent" and "Referer". - err = writeSortedHeader(w, req.Header, reqExcludeHeader) + err = req.Header.WriteSubset(w, reqExcludeHeader) if err != nil { return err } @@ -420,6 +456,29 @@ func (cr *chunkedReader) Read(b []uint8) (n int, err os.Error) { return n, cr.err } +// NewRequest returns a new Request given a method, URL, and optional body. +func NewRequest(method, url string, body io.Reader) (*Request, os.Error) { + u, err := ParseURL(url) + if err != nil { + return nil, err + } + rc, ok := body.(io.ReadCloser) + if !ok && body != nil { + rc = ioutil.NopCloser(body) + } + req := &Request{ + Method: method, + URL: u, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(Header), + Body: rc, + Host: u.Host, + } + return req, nil +} + // ReadRequest reads and parses a request from b. func ReadRequest(b *bufio.Reader) (req *Request, err os.Error) { @@ -549,7 +608,9 @@ func parseQuery(m map[string][]string, query string) (err os.Error) { return err } -// ParseForm parses the request body as a form for POST requests, or the raw query for GET requests. +// ParseForm parses the raw query. +// For POST requests, it also parses the request body as a form. +// ParseMultipartForm calls ParseForm automatically. // It is idempotent. func (r *Request) ParseForm() (err os.Error) { if r.Form != nil { @@ -567,18 +628,23 @@ func (r *Request) ParseForm() (err os.Error) { ct := r.Header.Get("Content-Type") switch strings.Split(ct, ";", 2)[0] { case "text/plain", "application/x-www-form-urlencoded", "": - b, e := ioutil.ReadAll(r.Body) + const maxFormSize = int64(10 << 20) // 10 MB is a lot of text. + b, e := ioutil.ReadAll(io.LimitReader(r.Body, maxFormSize+1)) if e != nil { if err == nil { err = e } break } + if int64(len(b)) > maxFormSize { + return os.NewError("http: POST too large") + } e = parseQuery(r.Form, string(b)) if err == nil { err = e } - // TODO(dsymonds): Handle multipart/form-data + case "multipart/form-data": + // handled by ParseMultipartForm default: return &badStringError{"unknown Content-Type", ct} } @@ -586,11 +652,50 @@ func (r *Request) ParseForm() (err os.Error) { return err } +// ParseMultipartForm parses a request body as multipart/form-data. +// The whole request body is parsed and up to a total of maxMemory bytes of +// its file parts are stored in memory, with the remainder stored on +// disk in temporary files. +// ParseMultipartForm calls ParseForm if necessary. +// After one call to ParseMultipartForm, subsequent calls have no effect. +func (r *Request) ParseMultipartForm(maxMemory int64) os.Error { + if r.Form == nil { + err := r.ParseForm() + if err != nil { + return err + } + } + if r.MultipartForm != nil { + return nil + } + if r.MultipartForm == multipartByReader { + return os.NewError("http: multipart handled by MultipartReader") + } + + mr, err := r.multipartReader() + if err == ErrNotMultipart { + return nil + } else if err != nil { + return err + } + + f, err := mr.ReadForm(maxMemory) + if err != nil { + return err + } + for k, v := range f.Value { + r.Form[k] = append(r.Form[k], v...) + } + r.MultipartForm = f + + return nil +} + // FormValue returns the first value for the named component of the query. -// FormValue calls ParseForm if necessary. +// FormValue calls ParseMultipartForm and ParseForm if necessary. func (r *Request) FormValue(key string) string { if r.Form == nil { - r.ParseForm() + r.ParseMultipartForm(defaultMaxMemory) } if vs := r.Form[key]; len(vs) > 0 { return vs[0] @@ -598,6 +703,27 @@ func (r *Request) FormValue(key string) string { return "" } +// FormFile returns the first file for the provided form key. +// FormFile calls ParseMultipartForm and ParseForm if necessary. +func (r *Request) FormFile(key string) (multipart.File, *multipart.FileHeader, os.Error) { + if r.MultipartForm == multipartByReader { + return nil, nil, os.NewError("http: multipart handled by MultipartReader") + } + if r.MultipartForm == nil { + err := r.ParseMultipartForm(defaultMaxMemory) + if err != nil { + return nil, nil, err + } + } + if r.MultipartForm != nil && r.MultipartForm.File != nil { + if fhs := r.MultipartForm.File[key]; len(fhs) > 0 { + f, err := fhs[0].Open() + return f, fhs[0], err + } + } + return nil, nil, ErrMissingFile +} + func (r *Request) expectsContinue() bool { return strings.ToLower(r.Header.Get("Expect")) == "100-continue" } diff --git a/libgo/go/http/request_test.go b/libgo/go/http/request_test.go index 19083adf624..f79d3a24240 100644 --- a/libgo/go/http/request_test.go +++ b/libgo/go/http/request_test.go @@ -10,6 +10,8 @@ import ( . "http" "http/httptest" "io" + "io/ioutil" + "mime/multipart" "os" "reflect" "regexp" @@ -82,7 +84,7 @@ func TestPostQuery(t *testing.T) { req.Header = Header{ "Content-Type": {"application/x-www-form-urlencoded; boo!"}, } - req.Body = nopCloser{strings.NewReader("z=post&both=y")} + req.Body = ioutil.NopCloser(strings.NewReader("z=post&both=y")) if q := req.FormValue("q"); q != "foo" { t.Errorf(`req.FormValue("q") = %q, want "foo"`, q) } @@ -115,7 +117,7 @@ func TestPostContentTypeParsing(t *testing.T) { req := &Request{ Method: "POST", Header: Header(test.contentType), - Body: nopCloser{bytes.NewBufferString("body")}, + Body: ioutil.NopCloser(bytes.NewBufferString("body")), } err := req.ParseForm() if !test.error && err != nil { @@ -131,7 +133,7 @@ func TestMultipartReader(t *testing.T) { req := &Request{ Method: "POST", Header: Header{"Content-Type": {`multipart/form-data; boundary="foo123"`}}, - Body: nopCloser{new(bytes.Buffer)}, + Body: ioutil.NopCloser(new(bytes.Buffer)), } multipart, err := req.MultipartReader() if multipart == nil { @@ -170,9 +172,143 @@ func TestRedirect(t *testing.T) { } } -// TODO: stop copy/pasting this around. move to io/ioutil? -type nopCloser struct { - io.Reader +func TestMultipartRequest(t *testing.T) { + // Test that we can read the values and files of a + // multipart request with FormValue and FormFile, + // and that ParseMultipartForm can be called multiple times. + req := newTestMultipartRequest(t) + if err := req.ParseMultipartForm(25); err != nil { + t.Fatal("ParseMultipartForm first call:", err) + } + defer req.MultipartForm.RemoveAll() + validateTestMultipartContents(t, req, false) + if err := req.ParseMultipartForm(25); err != nil { + t.Fatal("ParseMultipartForm second call:", err) + } + validateTestMultipartContents(t, req, false) +} + +func TestMultipartRequestAuto(t *testing.T) { + // Test that FormValue and FormFile automatically invoke + // ParseMultipartForm and return the right values. + req := newTestMultipartRequest(t) + defer func() { + if req.MultipartForm != nil { + req.MultipartForm.RemoveAll() + } + }() + validateTestMultipartContents(t, req, true) +} + +func TestEmptyMultipartRequest(t *testing.T) { + // Test that FormValue and FormFile automatically invoke + // ParseMultipartForm and return the right values. + req, err := NewRequest("GET", "/", nil) + if err != nil { + t.Errorf("NewRequest err = %q", err) + } + testMissingFile(t, req) +} + +func testMissingFile(t *testing.T, req *Request) { + f, fh, err := req.FormFile("missing") + if f != nil { + t.Errorf("FormFile file = %q, want nil", f, nil) + } + if fh != nil { + t.Errorf("FormFile file header = %q, want nil", fh, nil) + } + if err != ErrMissingFile { + t.Errorf("FormFile err = %q, want nil", err, ErrMissingFile) + } } -func (nopCloser) Close() os.Error { return nil } +func newTestMultipartRequest(t *testing.T) *Request { + b := bytes.NewBufferString(strings.Replace(message, "\n", "\r\n", -1)) + req, err := NewRequest("POST", "/", b) + if err != nil { + t.Fatalf("NewRequest:", err) + } + ctype := fmt.Sprintf(`multipart/form-data; boundary="%s"`, boundary) + req.Header.Set("Content-type", ctype) + return req +} + +func validateTestMultipartContents(t *testing.T, req *Request, allMem bool) { + if g, e := req.FormValue("texta"), textaValue; g != e { + t.Errorf("texta value = %q, want %q", g, e) + } + if g, e := req.FormValue("texta"), textaValue; g != e { + t.Errorf("texta value = %q, want %q", g, e) + } + if g := req.FormValue("missing"); g != "" { + t.Errorf("missing value = %q, want empty string", g) + } + + assertMem := func(n string, fd multipart.File) { + if _, ok := fd.(*os.File); ok { + t.Error(n, " is *os.File, should not be") + } + } + fd := testMultipartFile(t, req, "filea", "filea.txt", fileaContents) + assertMem("filea", fd) + fd = testMultipartFile(t, req, "fileb", "fileb.txt", filebContents) + if allMem { + assertMem("fileb", fd) + } else { + if _, ok := fd.(*os.File); !ok { + t.Errorf("fileb has unexpected underlying type %T", fd) + } + } + + testMissingFile(t, req) +} + +func testMultipartFile(t *testing.T, req *Request, key, expectFilename, expectContent string) multipart.File { + f, fh, err := req.FormFile(key) + if err != nil { + t.Fatalf("FormFile(%q):", key, err) + } + if fh.Filename != expectFilename { + t.Errorf("filename = %q, want %q", fh.Filename, expectFilename) + } + var b bytes.Buffer + _, err = io.Copy(&b, f) + if err != nil { + t.Fatal("copying contents:", err) + } + if g := b.String(); g != expectContent { + t.Errorf("contents = %q, want %q", g, expectContent) + } + return f +} + +const ( + fileaContents = "This is a test file." + filebContents = "Another test file." + textaValue = "foo" + textbValue = "bar" + boundary = `MyBoundary` +) + +const message = ` +--MyBoundary +Content-Disposition: form-data; name="filea"; filename="filea.txt" +Content-Type: text/plain + +` + fileaContents + ` +--MyBoundary +Content-Disposition: form-data; name="fileb"; filename="fileb.txt" +Content-Type: text/plain + +` + filebContents + ` +--MyBoundary +Content-Disposition: form-data; name="texta" + +` + textaValue + ` +--MyBoundary +Content-Disposition: form-data; name="textb" + +` + textbValue + ` +--MyBoundary-- +` diff --git a/libgo/go/http/requestwrite_test.go b/libgo/go/http/requestwrite_test.go index 726baa26686..bb000c701ff 100644 --- a/libgo/go/http/requestwrite_test.go +++ b/libgo/go/http/requestwrite_test.go @@ -6,7 +6,10 @@ package http import ( "bytes" + "io" "io/ioutil" + "os" + "strings" "testing" ) @@ -133,6 +136,41 @@ var reqWriteTests = []reqWriteTest{ "Transfer-Encoding: chunked\r\n\r\n" + "6\r\nabcdef\r\n0\r\n\r\n", }, + + // HTTP/1.1 POST with Content-Length, no chunking + { + Request{ + Method: "POST", + URL: &URL{ + Scheme: "http", + Host: "www.google.com", + Path: "/search", + }, + ProtoMajor: 1, + ProtoMinor: 1, + Header: Header{}, + Close: true, + ContentLength: 6, + }, + + []byte("abcdef"), + + "POST /search HTTP/1.1\r\n" + + "Host: www.google.com\r\n" + + "User-Agent: Go http package\r\n" + + "Connection: close\r\n" + + "Content-Length: 6\r\n" + + "\r\n" + + "abcdef", + + "POST http://www.google.com/search HTTP/1.1\r\n" + + "User-Agent: Go http package\r\n" + + "Connection: close\r\n" + + "Content-Length: 6\r\n" + + "\r\n" + + "abcdef", + }, + // default to HTTP/1.1 { Request{ @@ -189,3 +227,26 @@ func TestRequestWrite(t *testing.T) { } } } + +type closeChecker struct { + io.Reader + closed bool +} + +func (rc *closeChecker) Close() os.Error { + rc.closed = true + return nil +} + +// TestRequestWriteClosesBody tests that Request.Write does close its request.Body. +// It also indirectly tests NewRequest and that it doesn't wrap an existing Closer +// inside a NopCloser. +func TestRequestWriteClosesBody(t *testing.T) { + rc := &closeChecker{Reader: strings.NewReader("my body")} + req, _ := NewRequest("GET", "http://foo.com/", rc) + buf := new(bytes.Buffer) + req.Write(buf) + if !rc.closed { + t.Error("body not closed after write") + } +} diff --git a/libgo/go/http/response.go b/libgo/go/http/response.go index 1f725ecdddd..a65c2b14df6 100644 --- a/libgo/go/http/response.go +++ b/libgo/go/http/response.go @@ -8,11 +8,9 @@ package http import ( "bufio" - "fmt" "io" "net/textproto" "os" - "sort" "strconv" "strings" ) @@ -192,7 +190,7 @@ func (resp *Response) Write(w io.Writer) os.Error { } // Rest of header - err = writeSortedHeader(w, resp.Header, respExcludeHeader) + err = resp.Header.WriteSubset(w, respExcludeHeader) if err != nil { return err } @@ -213,27 +211,3 @@ func (resp *Response) Write(w io.Writer) os.Error { // Success return nil } - -func writeSortedHeader(w io.Writer, h Header, exclude map[string]bool) os.Error { - keys := make([]string, 0, len(h)) - for k := range h { - if exclude == nil || !exclude[k] { - keys = append(keys, k) - } - } - sort.SortStrings(keys) - for _, k := range keys { - for _, v := range h[k] { - v = strings.Replace(v, "\n", " ", -1) - v = strings.Replace(v, "\r", " ", -1) - v = strings.TrimSpace(v) - if v == "" { - continue - } - if _, err := fmt.Fprintf(w, "%s: %s\r\n", k, v); err != nil { - return err - } - } - } - return nil -} diff --git a/libgo/go/http/response_test.go b/libgo/go/http/response_test.go index ef67fdd2dc3..9e77c20c40b 100644 --- a/libgo/go/http/response_test.go +++ b/libgo/go/http/response_test.go @@ -7,8 +7,12 @@ package http import ( "bufio" "bytes" + "compress/gzip" + "crypto/rand" "fmt" + "os" "io" + "io/ioutil" "reflect" "testing" ) @@ -117,7 +121,9 @@ var respTests = []respTest{ "Transfer-Encoding: chunked\r\n" + "\r\n" + "0a\r\n" + - "Body here\n" + + "Body here\n\r\n" + + "09\r\n" + + "continued\r\n" + "0\r\n" + "\r\n", @@ -134,7 +140,7 @@ var respTests = []respTest{ TransferEncoding: []string{"chunked"}, }, - "Body here\n", + "Body here\ncontinued", }, // Chunked response with Content-Length. @@ -186,6 +192,29 @@ var respTests = []respTest{ "", }, + // explicit Content-Length of 0. + { + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 0\r\n" + + "\r\n", + + Response{ + Status: "200 OK", + StatusCode: 200, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + RequestMethod: "GET", + Header: Header{ + "Content-Length": {"0"}, + }, + Close: false, + ContentLength: 0, + }, + + "", + }, + // Status line without a Reason-Phrase, but trailing space. // (permitted by RFC 2616) { @@ -250,9 +279,107 @@ func TestReadResponse(t *testing.T) { } } +var readResponseCloseInMiddleTests = []struct { + chunked, compressed bool +}{ + {false, false}, + {true, false}, + {true, true}, +} + +// TestReadResponseCloseInMiddle tests that closing a body after +// reading only part of its contents advances the read to the end of +// the request, right up until the next request. +func TestReadResponseCloseInMiddle(t *testing.T) { + for _, test := range readResponseCloseInMiddleTests { + fatalf := func(format string, args ...interface{}) { + args = append([]interface{}{test.chunked, test.compressed}, args...) + t.Fatalf("on test chunked=%v, compressed=%v: "+format, args...) + } + checkErr := func(err os.Error, msg string) { + if err == nil { + return + } + fatalf(msg+": %v", err) + } + var buf bytes.Buffer + buf.WriteString("HTTP/1.1 200 OK\r\n") + if test.chunked { + buf.WriteString("Transfer-Encoding: chunked\r\n") + } else { + buf.WriteString("Content-Length: 1000000\r\n") + } + var wr io.Writer = &buf + if test.chunked { + wr = &chunkedWriter{wr} + } + if test.compressed { + buf.WriteString("Content-Encoding: gzip\r\n") + var err os.Error + wr, err = gzip.NewWriter(wr) + checkErr(err, "gzip.NewWriter") + } + buf.WriteString("\r\n") + + chunk := bytes.Repeat([]byte{'x'}, 1000) + for i := 0; i < 1000; i++ { + if test.compressed { + // Otherwise this compresses too well. + _, err := io.ReadFull(rand.Reader, chunk) + checkErr(err, "rand.Reader ReadFull") + } + wr.Write(chunk) + } + if test.compressed { + err := wr.(*gzip.Compressor).Close() + checkErr(err, "compressor close") + } + if test.chunked { + buf.WriteString("0\r\n\r\n") + } + buf.WriteString("Next Request Here") + + bufr := bufio.NewReader(&buf) + resp, err := ReadResponse(bufr, "GET") + checkErr(err, "ReadResponse") + expectedLength := int64(-1) + if !test.chunked { + expectedLength = 1000000 + } + if resp.ContentLength != expectedLength { + fatalf("expected response length %d, got %d", expectedLength, resp.ContentLength) + } + if resp.Body == nil { + fatalf("nil body") + } + if test.compressed { + gzReader, err := gzip.NewReader(resp.Body) + checkErr(err, "gzip.NewReader") + resp.Body = &readFirstCloseBoth{gzReader, resp.Body} + } + + rbuf := make([]byte, 2500) + n, err := io.ReadFull(resp.Body, rbuf) + checkErr(err, "2500 byte ReadFull") + if n != 2500 { + fatalf("ReadFull only read %d bytes", n) + } + if test.compressed == false && !bytes.Equal(bytes.Repeat([]byte{'x'}, 2500), rbuf) { + fatalf("ReadFull didn't read 2500 'x'; got %q", string(rbuf)) + } + resp.Body.Close() + + rest, err := ioutil.ReadAll(bufr) + checkErr(err, "ReadAll on remainder") + if e, g := "Next Request Here", string(rest); e != g { + fatalf("for chunked=%v remainder = %q, expected %q", g, e) + } + } +} + func diff(t *testing.T, prefix string, have, want interface{}) { - hv := reflect.NewValue(have).(*reflect.PtrValue).Elem().(*reflect.StructValue) - wv := reflect.NewValue(want).(*reflect.PtrValue).Elem().(*reflect.StructValue) + hv := reflect.ValueOf(have).Elem() + wv := reflect.ValueOf(want).Elem() if hv.Type() != wv.Type() { t.Errorf("%s: type mismatch %v vs %v", prefix, hv.Type(), wv.Type()) } @@ -260,7 +387,7 @@ func diff(t *testing.T, prefix string, have, want interface{}) { hf := hv.Field(i).Interface() wf := wv.Field(i).Interface() if !reflect.DeepEqual(hf, wf) { - t.Errorf("%s: %s = %v want %v", prefix, hv.Type().(*reflect.StructType).Field(i).Name, hf, wf) + t.Errorf("%s: %s = %v want %v", prefix, hv.Type().Field(i).Name, hf, wf) } } } diff --git a/libgo/go/http/reverseproxy.go b/libgo/go/http/reverseproxy.go new file mode 100644 index 00000000000..e4ce1e34c79 --- /dev/null +++ b/libgo/go/http/reverseproxy.go @@ -0,0 +1,100 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// HTTP reverse proxy handler + +package http + +import ( + "io" + "log" + "net" + "strings" +) + +// ReverseProxy is an HTTP Handler that takes an incoming request and +// sends it to another server, proxying the response back to the +// client. +type ReverseProxy struct { + // Director must be a function which modifies + // the request into a new request to be sent + // using Transport. Its response is then copied + // back to the original client unmodified. + Director func(*Request) + + // The Transport used to perform proxy requests. + // If nil, DefaultTransport is used. + Transport RoundTripper +} + +func singleJoiningSlash(a, b string) string { + aslash := strings.HasSuffix(a, "/") + bslash := strings.HasPrefix(b, "/") + switch { + case aslash && bslash: + return a + b[1:] + case !aslash && !bslash: + return a + "/" + b + } + return a + b +} + +// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites +// URLs to the scheme, host, and base path provided in target. If the +// target's path is "/base" and the incoming request was for "/dir", +// the target request will be for /base/dir. +func NewSingleHostReverseProxy(target *URL) *ReverseProxy { + director := func(req *Request) { + req.URL.Scheme = target.Scheme + req.URL.Host = target.Host + req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path) + if q := req.URL.RawQuery; q != "" { + req.URL.RawPath = req.URL.Path + "?" + q + } else { + req.URL.RawPath = req.URL.Path + } + req.URL.RawQuery = target.RawQuery + } + return &ReverseProxy{Director: director} +} + +func (p *ReverseProxy) ServeHTTP(rw ResponseWriter, req *Request) { + transport := p.Transport + if transport == nil { + transport = DefaultTransport + } + + outreq := new(Request) + *outreq = *req // includes shallow copies of maps, but okay + + p.Director(outreq) + outreq.Proto = "HTTP/1.1" + outreq.ProtoMajor = 1 + outreq.ProtoMinor = 1 + outreq.Close = false + + if clientIp, _, err := net.SplitHostPort(req.RemoteAddr); err == nil { + outreq.Header.Set("X-Forwarded-For", clientIp) + } + + res, err := transport.RoundTrip(outreq) + if err != nil { + log.Printf("http: proxy error: %v", err) + rw.WriteHeader(StatusInternalServerError) + return + } + + hdr := rw.Header() + for k, vv := range res.Header { + for _, v := range vv { + hdr.Add(k, v) + } + } + + rw.WriteHeader(res.StatusCode) + + if res.Body != nil { + io.Copy(rw, res.Body) + } +} diff --git a/libgo/go/http/reverseproxy_test.go b/libgo/go/http/reverseproxy_test.go new file mode 100644 index 00000000000..8cf7705d745 --- /dev/null +++ b/libgo/go/http/reverseproxy_test.go @@ -0,0 +1,50 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Reverse proxy tests. + +package http_test + +import ( + . "http" + "http/httptest" + "io/ioutil" + "testing" +) + +func TestReverseProxy(t *testing.T) { + const backendResponse = "I am the backend" + const backendStatus = 404 + backend := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + if r.Header.Get("X-Forwarded-For") == "" { + t.Errorf("didn't get X-Forwarded-For header") + } + w.Header().Set("X-Foo", "bar") + w.WriteHeader(backendStatus) + w.Write([]byte(backendResponse)) + })) + defer backend.Close() + backendURL, err := ParseURL(backend.URL) + if err != nil { + t.Fatal(err) + } + proxyHandler := NewSingleHostReverseProxy(backendURL) + frontend := httptest.NewServer(proxyHandler) + defer frontend.Close() + + res, _, err := Get(frontend.URL) + if err != nil { + t.Fatalf("Get: %v", err) + } + if g, e := res.StatusCode, backendStatus; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + if g, e := res.Header.Get("X-Foo"), "bar"; g != e { + t.Errorf("got X-Foo %q; expected %q", g, e) + } + bodyBytes, _ := ioutil.ReadAll(res.Body) + if g, e := string(bodyBytes), backendResponse; g != e { + t.Errorf("got body %q; expected %q", g, e) + } +} diff --git a/libgo/go/http/serve_test.go b/libgo/go/http/serve_test.go index cf889553fb7..7ff6ef04b1a 100644 --- a/libgo/go/http/serve_test.go +++ b/libgo/go/http/serve_test.go @@ -231,7 +231,7 @@ func TestMuxRedirectLeadingSlashes(t *testing.T) { func TestServerTimeouts(t *testing.T) { // TODO(bradfitz): convert this to use httptest.Server - l, err := net.ListenTCP("tcp", &net.TCPAddr{Port: 0}) + l, err := net.Listen("tcp", "127.0.0.1:0") if err != nil { t.Fatalf("listen error: %v", err) } @@ -247,7 +247,7 @@ func TestServerTimeouts(t *testing.T) { server := &Server{Handler: handler, ReadTimeout: 0.25 * second, WriteTimeout: 0.25 * second} go server.Serve(l) - url := fmt.Sprintf("http://localhost:%d/", addr.Port) + url := fmt.Sprintf("http://%s/", addr) // Hit the HTTP server successfully. tr := &Transport{DisableKeepAlives: true} // they interfere with this test @@ -265,7 +265,7 @@ func TestServerTimeouts(t *testing.T) { // Slow client that should timeout. t1 := time.Nanoseconds() - conn, err := net.Dial("tcp", fmt.Sprintf("localhost:%d", addr.Port)) + conn, err := net.Dial("tcp", addr.String()) if err != nil { t.Fatalf("Dial: %v", err) } @@ -534,3 +534,162 @@ func TestTLSServer(t *testing.T) { t.Errorf("expected body %q; got %q", e, g) } } + +type serverExpectTest struct { + contentLength int // of request body + expectation string // e.g. "100-continue" + readBody bool // whether handler should read the body (if false, sends StatusUnauthorized) + expectedResponse string // expected substring in first line of http response +} + +var serverExpectTests = []serverExpectTest{ + // Normal 100-continues, case-insensitive. + {100, "100-continue", true, "100 Continue"}, + {100, "100-cOntInUE", true, "100 Continue"}, + + // No 100-continue. + {100, "", true, "200 OK"}, + + // 100-continue but requesting client to deny us, + // so it never eads the body. + {100, "100-continue", false, "401 Unauthorized"}, + // Likewise without 100-continue: + {100, "", false, "401 Unauthorized"}, + + // Non-standard expectations are failures + {0, "a-pony", false, "417 Expectation Failed"}, + + // Expect-100 requested but no body + {0, "100-continue", true, "400 Bad Request"}, +} + +// Tests that the server responds to the "Expect" request header +// correctly. +func TestServerExpect(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + // Note using r.FormValue("readbody") because for POST + // requests that would read from r.Body, which we only + // conditionally want to do. + if strings.Contains(r.URL.RawPath, "readbody=true") { + ioutil.ReadAll(r.Body) + w.Write([]byte("Hi")) + } else { + w.WriteHeader(StatusUnauthorized) + } + })) + defer ts.Close() + + runTest := func(test serverExpectTest) { + conn, err := net.Dial("tcp", ts.Listener.Addr().String()) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer conn.Close() + sendf := func(format string, args ...interface{}) { + _, err := fmt.Fprintf(conn, format, args...) + if err != nil { + t.Fatalf("On test %#v, error writing %q: %v", test, format, err) + } + } + go func() { + sendf("POST /?readbody=%v HTTP/1.1\r\n"+ + "Connection: close\r\n"+ + "Content-Length: %d\r\n"+ + "Expect: %s\r\nHost: foo\r\n\r\n", + test.readBody, test.contentLength, test.expectation) + if test.contentLength > 0 && strings.ToLower(test.expectation) != "100-continue" { + body := strings.Repeat("A", test.contentLength) + sendf(body) + } + }() + bufr := bufio.NewReader(conn) + line, err := bufr.ReadString('\n') + if err != nil { + t.Fatalf("ReadString: %v", err) + } + if !strings.Contains(line, test.expectedResponse) { + t.Errorf("for test %#v got first line=%q", test, line) + } + } + + for _, test := range serverExpectTests { + runTest(test) + } +} + +func TestServerConsumesRequestBody(t *testing.T) { + conn := new(testConn) + body := strings.Repeat("x", 1<<20) + conn.readBuf.Write([]byte(fmt.Sprintf( + "POST / HTTP/1.1\r\n"+ + "Host: test\r\n"+ + "Content-Length: %d\r\n"+ + "\r\n",len(body)))) + conn.readBuf.Write([]byte(body)) + + done := make(chan bool) + + ls := &oneConnListener{conn} + go Serve(ls, HandlerFunc(func(rw ResponseWriter, req *Request) { + if conn.readBuf.Len() < len(body)/2 { + t.Errorf("on request, read buffer length is %d; expected about 1MB", conn.readBuf.Len()) + } + rw.WriteHeader(200) + if g, e := conn.readBuf.Len(), 0; g != e { + t.Errorf("after WriteHeader, read buffer length is %d; want %d", g, e) + } + done <- true + })) + <-done +} + +func TestTimeoutHandler(t *testing.T) { + sendHi := make(chan bool, 1) + writeErrors := make(chan os.Error, 1) + sayHi := HandlerFunc(func(w ResponseWriter, r *Request) { + <-sendHi + _, werr := w.Write([]byte("hi")) + writeErrors <- werr + }) + timeout := make(chan int64, 1) // write to this to force timeouts + ts := httptest.NewServer(NewTestTimeoutHandler(sayHi, timeout)) + defer ts.Close() + + // Succeed without timing out: + sendHi <- true + res, _, err := Get(ts.URL) + if err != nil { + t.Error(err) + } + if g, e := res.StatusCode, StatusOK; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + body, _ := ioutil.ReadAll(res.Body) + if g, e := string(body), "hi"; g != e { + t.Errorf("got body %q; expected %q", g, e) + } + if g := <-writeErrors; g != nil { + t.Errorf("got unexpected Write error on first request: %v", g) + } + + // Times out: + timeout <- 1 + res, _, err = Get(ts.URL) + if err != nil { + t.Error(err) + } + if g, e := res.StatusCode, StatusServiceUnavailable; g != e { + t.Errorf("got res.StatusCode %d; expected %d", g, e) + } + body, _ = ioutil.ReadAll(res.Body) + if !strings.Contains(string(body), "<title>Timeout</title>") { + t.Errorf("expected timeout body; got %q", string(body)) + } + + // Now make the previously-timed out handler speak again, + // which verifies the panic is handled: + sendHi <- true + if g, e := <-writeErrors, ErrHandlerTimeout; g != e { + t.Errorf("expected Write error of %v; got %v", e, g) + } +} diff --git a/libgo/go/http/server.go b/libgo/go/http/server.go index 8e7039371ae..d155f06a2d2 100644 --- a/libgo/go/http/server.go +++ b/libgo/go/http/server.go @@ -22,6 +22,7 @@ import ( "path" "strconv" "strings" + "sync" "time" ) @@ -141,9 +142,13 @@ func newConn(rwc net.Conn, handler Handler) (c *conn, err os.Error) { type expectContinueReader struct { resp *response readCloser io.ReadCloser + closed bool } func (ecr *expectContinueReader) Read(p []byte) (n int, err os.Error) { + if ecr.closed { + return 0, os.NewError("http: Read after Close on request Body") + } if !ecr.resp.wroteContinue && !ecr.resp.conn.hijacked { ecr.resp.wroteContinue = true io.WriteString(ecr.resp.conn.buf, "HTTP/1.1 100 Continue\r\n\r\n") @@ -153,6 +158,7 @@ func (ecr *expectContinueReader) Read(p []byte) (n int, err os.Error) { } func (ecr *expectContinueReader) Close() os.Error { + ecr.closed = true return ecr.readCloser.Close() } @@ -180,12 +186,6 @@ func (c *conn) readRequest() (w *response, err os.Error) { w.req = req w.header = make(Header) w.contentLength = -1 - - // Expect 100 Continue support - if req.expectsContinue() && req.ProtoAtLeast(1, 1) { - // Wrap the Body reader with one that replies on the connection - req.Body = &expectContinueReader{readCloser: req.Body, resp: w} - } return w, nil } @@ -202,6 +202,16 @@ func (w *response) WriteHeader(code int) { log.Print("http: multiple response.WriteHeader calls") return } + + // Per RFC 2616, we should consume the request body before + // replying, if the handler hasn't already done so. + if w.req.ContentLength != 0 { + ecr, isExpecter := w.req.Body.(*expectContinueReader) + if !isExpecter || ecr.resp.wroteContinue { + w.req.Body.Close() + } + } + w.wroteHeader = true w.status = code if code == StatusNotModified { @@ -299,7 +309,7 @@ func (w *response) WriteHeader(code int) { text = "status code " + codestring } io.WriteString(w.conn.buf, proto+" "+codestring+" "+text+"\r\n") - writeSortedHeader(w.conn.buf, w.header, nil) + w.header.Write(w.conn.buf) io.WriteString(w.conn.buf, "\r\n") } @@ -413,6 +423,9 @@ func (w *response) finishRequest() { } w.conn.buf.Flush() w.req.Body.Close() + if w.req.MultipartForm != nil { + w.req.MultipartForm.RemoveAll() + } if w.contentLength != -1 && w.contentLength != w.written { // Did not write enough. Avoid getting out of sync. @@ -446,6 +459,38 @@ func (c *conn) serve() { if err != nil { break } + + // Expect 100 Continue support + req := w.req + if req.expectsContinue() { + if req.ProtoAtLeast(1, 1) { + // Wrap the Body reader with one that replies on the connection + req.Body = &expectContinueReader{readCloser: req.Body, resp: w} + } + if req.ContentLength == 0 { + w.Header().Set("Connection", "close") + w.WriteHeader(StatusBadRequest) + break + } + req.Header.Del("Expect") + } else if req.Header.Get("Expect") != "" { + // TODO(bradfitz): let ServeHTTP handlers handle + // requests with non-standard expectation[s]? Seems + // theoretical at best, and doesn't fit into the + // current ServeHTTP model anyway. We'd need to + // make the ResponseWriter an optional + // "ExpectReplier" interface or something. + // + // For now we'll just obey RFC 2616 14.20 which says + // "If a server receives a request containing an + // Expect field that includes an expectation- + // extension that it does not support, it MUST + // respond with a 417 (Expectation Failed) status." + w.Header().Set("Connection", "close") + w.WriteHeader(StatusExpectationFailed) + break + } + // HTTP cannot have multiple simultaneous active requests.[*] // Until the server replies to this request, it can't read another, // so we might as well run the handler in this goroutine. @@ -857,3 +902,89 @@ func ListenAndServeTLS(addr string, certFile string, keyFile string, handler Han tlsListener := tls.NewListener(conn, config) return Serve(tlsListener, handler) } + +// TimeoutHandler returns a Handler that runs h with the given time limit. +// +// The new Handler calls h.ServeHTTP to handle each request, but if a +// call runs for more than ns nanoseconds, the handler responds with +// a 503 Service Unavailable error and the given message in its body. +// (If msg is empty, a suitable default message will be sent.) +// After such a timeout, writes by h to its ResponseWriter will return +// ErrHandlerTimeout. +func TimeoutHandler(h Handler, ns int64, msg string) Handler { + f := func() <-chan int64 { + return time.After(ns) + } + return &timeoutHandler{h, f, msg} +} + +// ErrHandlerTimeout is returned on ResponseWriter Write calls +// in handlers which have timed out. +var ErrHandlerTimeout = os.NewError("http: Handler timeout") + +type timeoutHandler struct { + handler Handler + timeout func() <-chan int64 // returns channel producing a timeout + body string +} + +func (h *timeoutHandler) errorBody() string { + if h.body != "" { + return h.body + } + return "<html><head><title>Timeout</title></head><body><h1>Timeout</h1></body></html>" +} + +func (h *timeoutHandler) ServeHTTP(w ResponseWriter, r *Request) { + done := make(chan bool) + tw := &timeoutWriter{w: w} + go func() { + h.handler.ServeHTTP(tw, r) + done <- true + }() + select { + case <-done: + return + case <-h.timeout(): + tw.mu.Lock() + defer tw.mu.Unlock() + if !tw.wroteHeader { + tw.w.WriteHeader(StatusServiceUnavailable) + tw.w.Write([]byte(h.errorBody())) + } + tw.timedOut = true + } +} + +type timeoutWriter struct { + w ResponseWriter + + mu sync.Mutex + timedOut bool + wroteHeader bool +} + +func (tw *timeoutWriter) Header() Header { + return tw.w.Header() +} + +func (tw *timeoutWriter) Write(p []byte) (int, os.Error) { + tw.mu.Lock() + timedOut := tw.timedOut + tw.mu.Unlock() + if timedOut { + return 0, ErrHandlerTimeout + } + return tw.w.Write(p) +} + +func (tw *timeoutWriter) WriteHeader(code int) { + tw.mu.Lock() + if tw.timedOut || tw.wroteHeader { + tw.mu.Unlock() + return + } + tw.wroteHeader = true + tw.mu.Unlock() + tw.w.WriteHeader(code) +} diff --git a/libgo/go/http/spdy/protocol.go b/libgo/go/http/spdy/protocol.go new file mode 100644 index 00000000000..d584ea232ea --- /dev/null +++ b/libgo/go/http/spdy/protocol.go @@ -0,0 +1,367 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package spdy is an incomplete implementation of the SPDY protocol. +// +// The implementation follows draft 2 of the spec: +// https://sites.google.com/a/chromium.org/dev/spdy/spdy-protocol/spdy-protocol-draft2 +package spdy + +import ( + "bytes" + "compress/zlib" + "encoding/binary" + "http" + "io" + "os" + "strconv" + "strings" + "sync" +) + +// Version is the protocol version number that this package implements. +const Version = 2 + +// ControlFrameType stores the type field in a control frame header. +type ControlFrameType uint16 + +// Control frame type constants +const ( + TypeSynStream ControlFrameType = 0x0001 + TypeSynReply = 0x0002 + TypeRstStream = 0x0003 + TypeSettings = 0x0004 + TypeNoop = 0x0005 + TypePing = 0x0006 + TypeGoaway = 0x0007 + TypeHeaders = 0x0008 + TypeWindowUpdate = 0x0009 +) + +func (t ControlFrameType) String() string { + switch t { + case TypeSynStream: + return "SYN_STREAM" + case TypeSynReply: + return "SYN_REPLY" + case TypeRstStream: + return "RST_STREAM" + case TypeSettings: + return "SETTINGS" + case TypeNoop: + return "NOOP" + case TypePing: + return "PING" + case TypeGoaway: + return "GOAWAY" + case TypeHeaders: + return "HEADERS" + case TypeWindowUpdate: + return "WINDOW_UPDATE" + } + return "Type(" + strconv.Itoa(int(t)) + ")" +} + +type FrameFlags uint8 + +// Stream frame flags +const ( + FlagFin FrameFlags = 0x01 + FlagUnidirectional = 0x02 +) + +// SETTINGS frame flags +const ( + FlagClearPreviouslyPersistedSettings FrameFlags = 0x01 +) + +// MaxDataLength is the maximum number of bytes that can be stored in one frame. +const MaxDataLength = 1<<24 - 1 + +// A Frame is a framed message as sent between clients and servers. +// There are two types of frames: control frames and data frames. +type Frame struct { + Header [4]byte + Flags FrameFlags + Data []byte +} + +// ControlFrame creates a control frame with the given information. +func ControlFrame(t ControlFrameType, f FrameFlags, data []byte) Frame { + return Frame{ + Header: [4]byte{ + (Version&0xff00)>>8 | 0x80, + (Version & 0x00ff), + byte((t & 0xff00) >> 8), + byte((t & 0x00ff) >> 0), + }, + Flags: f, + Data: data, + } +} + +// DataFrame creates a data frame with the given information. +func DataFrame(streamId uint32, f FrameFlags, data []byte) Frame { + return Frame{ + Header: [4]byte{ + byte(streamId & 0x7f000000 >> 24), + byte(streamId & 0x00ff0000 >> 16), + byte(streamId & 0x0000ff00 >> 8), + byte(streamId & 0x000000ff >> 0), + }, + Flags: f, + Data: data, + } +} + +// ReadFrame reads an entire frame into memory. +func ReadFrame(r io.Reader) (f Frame, err os.Error) { + _, err = io.ReadFull(r, f.Header[:]) + if err != nil { + return + } + err = binary.Read(r, binary.BigEndian, &f.Flags) + if err != nil { + return + } + var lengthField [3]byte + _, err = io.ReadFull(r, lengthField[:]) + if err != nil { + if err == os.EOF { + err = io.ErrUnexpectedEOF + } + return + } + var length uint32 + length |= uint32(lengthField[0]) << 16 + length |= uint32(lengthField[1]) << 8 + length |= uint32(lengthField[2]) << 0 + if length > 0 { + f.Data = make([]byte, int(length)) + _, err = io.ReadFull(r, f.Data) + if err == os.EOF { + err = io.ErrUnexpectedEOF + } + } else { + f.Data = []byte{} + } + return +} + +// IsControl returns whether the frame holds a control frame. +func (f Frame) IsControl() bool { + return f.Header[0]&0x80 != 0 +} + +// Type obtains the type field if the frame is a control frame, otherwise it returns zero. +func (f Frame) Type() ControlFrameType { + if !f.IsControl() { + return 0 + } + return (ControlFrameType(f.Header[2])<<8 | ControlFrameType(f.Header[3])) +} + +// StreamId returns the stream ID field if the frame is a data frame, otherwise it returns zero. +func (f Frame) StreamId() (id uint32) { + if f.IsControl() { + return 0 + } + id |= uint32(f.Header[0]) << 24 + id |= uint32(f.Header[1]) << 16 + id |= uint32(f.Header[2]) << 8 + id |= uint32(f.Header[3]) << 0 + return +} + +// WriteTo writes the frame in the SPDY format. +func (f Frame) WriteTo(w io.Writer) (n int64, err os.Error) { + var nn int + // Header + nn, err = w.Write(f.Header[:]) + n += int64(nn) + if err != nil { + return + } + // Flags + nn, err = w.Write([]byte{byte(f.Flags)}) + n += int64(nn) + if err != nil { + return + } + // Length + nn, err = w.Write([]byte{ + byte(len(f.Data) & 0x00ff0000 >> 16), + byte(len(f.Data) & 0x0000ff00 >> 8), + byte(len(f.Data) & 0x000000ff), + }) + n += int64(nn) + if err != nil { + return + } + // Data + if len(f.Data) > 0 { + nn, err = w.Write(f.Data) + n += int64(nn) + } + return +} + +// headerDictionary is the dictionary sent to the zlib compressor/decompressor. +// Even though the specification states there is no null byte at the end, Chrome sends it. +const headerDictionary = "optionsgetheadpostputdeletetrace" + + "acceptaccept-charsetaccept-encodingaccept-languageauthorizationexpectfromhost" + + "if-modified-sinceif-matchif-none-matchif-rangeif-unmodifiedsince" + + "max-forwardsproxy-authorizationrangerefererteuser-agent" + + "100101200201202203204205206300301302303304305306307400401402403404405406407408409410411412413414415416417500501502503504505" + + "accept-rangesageetaglocationproxy-authenticatepublicretry-after" + + "servervarywarningwww-authenticateallowcontent-basecontent-encodingcache-control" + + "connectiondatetrailertransfer-encodingupgradeviawarning" + + "content-languagecontent-lengthcontent-locationcontent-md5content-rangecontent-typeetagexpireslast-modifiedset-cookie" + + "MondayTuesdayWednesdayThursdayFridaySaturdaySunday" + + "JanFebMarAprMayJunJulAugSepOctNovDec" + + "chunkedtext/htmlimage/pngimage/jpgimage/gifapplication/xmlapplication/xhtmltext/plainpublicmax-age" + + "charset=iso-8859-1utf-8gzipdeflateHTTP/1.1statusversionurl\x00" + +// hrSource is a reader that passes through reads from another reader. +// When the underlying reader reaches EOF, Read will block until another reader is added via change. +type hrSource struct { + r io.Reader + m sync.RWMutex + c *sync.Cond +} + +func (src *hrSource) Read(p []byte) (n int, err os.Error) { + src.m.RLock() + for src.r == nil { + src.c.Wait() + } + n, err = src.r.Read(p) + src.m.RUnlock() + if err == os.EOF { + src.change(nil) + err = nil + } + return +} + +func (src *hrSource) change(r io.Reader) { + src.m.Lock() + defer src.m.Unlock() + src.r = r + src.c.Broadcast() +} + +// A HeaderReader reads zlib-compressed headers. +type HeaderReader struct { + source hrSource + decompressor io.ReadCloser +} + +// NewHeaderReader creates a HeaderReader with the initial dictionary. +func NewHeaderReader() (hr *HeaderReader) { + hr = new(HeaderReader) + hr.source.c = sync.NewCond(hr.source.m.RLocker()) + return +} + +// ReadHeader reads a set of headers from a reader. +func (hr *HeaderReader) ReadHeader(r io.Reader) (h http.Header, err os.Error) { + hr.source.change(r) + h, err = hr.read() + return +} + +// Decode reads a set of headers from a block of bytes. +func (hr *HeaderReader) Decode(data []byte) (h http.Header, err os.Error) { + hr.source.change(bytes.NewBuffer(data)) + h, err = hr.read() + return +} + +func (hr *HeaderReader) read() (h http.Header, err os.Error) { + var count uint16 + if hr.decompressor == nil { + hr.decompressor, err = zlib.NewReaderDict(&hr.source, []byte(headerDictionary)) + if err != nil { + return + } + } + err = binary.Read(hr.decompressor, binary.BigEndian, &count) + if err != nil { + return + } + h = make(http.Header, int(count)) + for i := 0; i < int(count); i++ { + var name, value string + name, err = readHeaderString(hr.decompressor) + if err != nil { + return + } + value, err = readHeaderString(hr.decompressor) + if err != nil { + return + } + valueList := strings.Split(string(value), "\x00", -1) + for _, v := range valueList { + h.Add(name, v) + } + } + return +} + +func readHeaderString(r io.Reader) (s string, err os.Error) { + var length uint16 + err = binary.Read(r, binary.BigEndian, &length) + if err != nil { + return + } + data := make([]byte, int(length)) + _, err = io.ReadFull(r, data) + if err != nil { + return + } + return string(data), nil +} + +// HeaderWriter will write zlib-compressed headers on different streams. +type HeaderWriter struct { + compressor *zlib.Writer + buffer *bytes.Buffer +} + +// NewHeaderWriter creates a HeaderWriter ready to compress headers. +func NewHeaderWriter(level int) (hw *HeaderWriter) { + hw = &HeaderWriter{buffer: new(bytes.Buffer)} + hw.compressor, _ = zlib.NewWriterDict(hw.buffer, level, []byte(headerDictionary)) + return +} + +// WriteHeader writes a header block directly to an output. +func (hw *HeaderWriter) WriteHeader(w io.Writer, h http.Header) (err os.Error) { + hw.write(h) + _, err = io.Copy(w, hw.buffer) + hw.buffer.Reset() + return +} + +// Encode returns a compressed header block. +func (hw *HeaderWriter) Encode(h http.Header) (data []byte) { + hw.write(h) + data = make([]byte, hw.buffer.Len()) + hw.buffer.Read(data) + return +} + +func (hw *HeaderWriter) write(h http.Header) { + binary.Write(hw.compressor, binary.BigEndian, uint16(len(h))) + for k, vals := range h { + k = strings.ToLower(k) + binary.Write(hw.compressor, binary.BigEndian, uint16(len(k))) + binary.Write(hw.compressor, binary.BigEndian, []byte(k)) + v := strings.Join(vals, "\x00") + binary.Write(hw.compressor, binary.BigEndian, uint16(len(v))) + binary.Write(hw.compressor, binary.BigEndian, []byte(v)) + } + hw.compressor.Flush() +} diff --git a/libgo/go/http/spdy/protocol_test.go b/libgo/go/http/spdy/protocol_test.go new file mode 100644 index 00000000000..998ff998bc7 --- /dev/null +++ b/libgo/go/http/spdy/protocol_test.go @@ -0,0 +1,259 @@ +// Copyright 2011 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package spdy + +import ( + "bytes" + "compress/zlib" + "http" + "os" + "testing" +) + +type frameIoTest struct { + desc string + data []byte + frame Frame + readError os.Error + readOnly bool +} + +var frameIoTests = []frameIoTest{ + { + "noop frame", + []byte{ + 0x80, 0x02, 0x00, 0x05, + 0x00, 0x00, 0x00, 0x00, + }, + ControlFrame( + TypeNoop, + 0x00, + []byte{}, + ), + nil, + false, + }, + { + "ping frame", + []byte{ + 0x80, 0x02, 0x00, 0x06, + 0x00, 0x00, 0x00, 0x04, + 0x00, 0x00, 0x00, 0x01, + }, + ControlFrame( + TypePing, + 0x00, + []byte{0x00, 0x00, 0x00, 0x01}, + ), + nil, + false, + }, + { + "syn_stream frame", + []byte{ + 0x80, 0x02, 0x00, 0x01, + 0x01, 0x00, 0x00, 0x53, + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x78, 0xbb, + 0xdf, 0xa2, 0x51, 0xb2, + 0x62, 0x60, 0x66, 0x60, + 0xcb, 0x4d, 0x2d, 0xc9, + 0xc8, 0x4f, 0x61, 0x60, + 0x4e, 0x4f, 0x2d, 0x61, + 0x60, 0x2e, 0x2d, 0xca, + 0x61, 0x10, 0xcb, 0x28, + 0x29, 0x29, 0xb0, 0xd2, + 0xd7, 0x2f, 0x2f, 0x2f, + 0xd7, 0x4b, 0xcf, 0xcf, + 0x4f, 0xcf, 0x49, 0xd5, + 0x4b, 0xce, 0xcf, 0xd5, + 0x67, 0x60, 0x2f, 0x4b, + 0x2d, 0x2a, 0xce, 0xcc, + 0xcf, 0x63, 0xe0, 0x00, + 0x29, 0xd0, 0x37, 0xd4, + 0x33, 0x04, 0x00, 0x00, + 0x00, 0xff, 0xff, + }, + ControlFrame( + TypeSynStream, + 0x01, + []byte{ + 0x00, 0x00, 0x00, 0x01, + 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x78, 0xbb, + 0xdf, 0xa2, 0x51, 0xb2, + 0x62, 0x60, 0x66, 0x60, + 0xcb, 0x4d, 0x2d, 0xc9, + 0xc8, 0x4f, 0x61, 0x60, + 0x4e, 0x4f, 0x2d, 0x61, + 0x60, 0x2e, 0x2d, 0xca, + 0x61, 0x10, 0xcb, 0x28, + 0x29, 0x29, 0xb0, 0xd2, + 0xd7, 0x2f, 0x2f, 0x2f, + 0xd7, 0x4b, 0xcf, 0xcf, + 0x4f, 0xcf, 0x49, 0xd5, + 0x4b, 0xce, 0xcf, 0xd5, + 0x67, 0x60, 0x2f, 0x4b, + 0x2d, 0x2a, 0xce, 0xcc, + 0xcf, 0x63, 0xe0, 0x00, + 0x29, 0xd0, 0x37, 0xd4, + 0x33, 0x04, 0x00, 0x00, + 0x00, 0xff, 0xff, + }, + ), + nil, + false, + }, + { + "data frame", + []byte{ + 0x00, 0x00, 0x00, 0x05, + 0x01, 0x00, 0x00, 0x04, + 0x01, 0x02, 0x03, 0x04, + }, + DataFrame( + 5, + 0x01, + []byte{0x01, 0x02, 0x03, 0x04}, + ), + nil, + false, + }, + { + "too much data", + []byte{ + 0x00, 0x00, 0x00, 0x05, + 0x01, 0x00, 0x00, 0x04, + 0x01, 0x02, 0x03, 0x04, + 0x05, 0x06, 0x07, 0x08, + }, + DataFrame( + 5, + 0x01, + []byte{0x01, 0x02, 0x03, 0x04}, + ), + nil, + true, + }, + { + "not enough data", + []byte{ + 0x00, 0x00, 0x00, 0x05, + }, + Frame{}, + os.EOF, + true, + }, +} + +func TestReadFrame(t *testing.T) { + for _, tt := range frameIoTests { + f, err := ReadFrame(bytes.NewBuffer(tt.data)) + if err != tt.readError { + t.Errorf("%s: ReadFrame: %s", tt.desc, err) + continue + } + if err == nil { + if !bytes.Equal(f.Header[:], tt.frame.Header[:]) { + t.Errorf("%s: header %q != %q", tt.desc, string(f.Header[:]), string(tt.frame.Header[:])) + } + if f.Flags != tt.frame.Flags { + t.Errorf("%s: flags %#02x != %#02x", tt.desc, f.Flags, tt.frame.Flags) + } + if !bytes.Equal(f.Data, tt.frame.Data) { + t.Errorf("%s: data %q != %q", tt.desc, string(f.Data), string(tt.frame.Data)) + } + } + } +} + +func TestWriteTo(t *testing.T) { + for _, tt := range frameIoTests { + if tt.readOnly { + continue + } + b := new(bytes.Buffer) + _, err := tt.frame.WriteTo(b) + if err != nil { + t.Errorf("%s: WriteTo: %s", tt.desc, err) + } + if !bytes.Equal(b.Bytes(), tt.data) { + t.Errorf("%s: data %q != %q", tt.desc, string(b.Bytes()), string(tt.data)) + } + } +} + +var headerDataTest = []byte{ + 0x78, 0xbb, 0xdf, 0xa2, + 0x51, 0xb2, 0x62, 0x60, + 0x66, 0x60, 0xcb, 0x4d, + 0x2d, 0xc9, 0xc8, 0x4f, + 0x61, 0x60, 0x4e, 0x4f, + 0x2d, 0x61, 0x60, 0x2e, + 0x2d, 0xca, 0x61, 0x10, + 0xcb, 0x28, 0x29, 0x29, + 0xb0, 0xd2, 0xd7, 0x2f, + 0x2f, 0x2f, 0xd7, 0x4b, + 0xcf, 0xcf, 0x4f, 0xcf, + 0x49, 0xd5, 0x4b, 0xce, + 0xcf, 0xd5, 0x67, 0x60, + 0x2f, 0x4b, 0x2d, 0x2a, + 0xce, 0xcc, 0xcf, 0x63, + 0xe0, 0x00, 0x29, 0xd0, + 0x37, 0xd4, 0x33, 0x04, + 0x00, 0x00, 0x00, 0xff, + 0xff, +} + +func TestReadHeader(t *testing.T) { + r := NewHeaderReader() + h, err := r.Decode(headerDataTest) + if err != nil { + t.Fatalf("Error: %v", err) + return + } + if len(h) != 3 { + t.Errorf("Header count = %d (expected 3)", len(h)) + } + if h.Get("Url") != "http://www.google.com/" { + t.Errorf("Url: %q != %q", h.Get("Url"), "http://www.google.com/") + } + if h.Get("Method") != "get" { + t.Errorf("Method: %q != %q", h.Get("Method"), "get") + } + if h.Get("Version") != "http/1.1" { + t.Errorf("Version: %q != %q", h.Get("Version"), "http/1.1") + } +} + +func TestWriteHeader(t *testing.T) { + for level := zlib.NoCompression; level <= zlib.BestCompression; level++ { + r := NewHeaderReader() + w := NewHeaderWriter(level) + for i := 0; i < 100; i++ { + b := new(bytes.Buffer) + gold := http.Header{ + "Url": []string{"http://www.google.com/"}, + "Method": []string{"get"}, + "Version": []string{"http/1.1"}, + } + w.WriteHeader(b, gold) + h, err := r.Decode(b.Bytes()) + if err != nil { + t.Errorf("(level=%d i=%d) Error: %v", level, i, err) + return + } + if len(h) != len(gold) { + t.Errorf("(level=%d i=%d) Header count = %d (expected %d)", level, i, len(h), len(gold)) + } + for k, _ := range h { + if h.Get(k) != gold.Get(k) { + t.Errorf("(level=%d i=%d) %s: %q != %q", level, i, k, h.Get(k), gold.Get(k)) + } + } + } + } +} diff --git a/libgo/go/http/transfer.go b/libgo/go/http/transfer.go index 41614f144fe..0fa8bed43aa 100644 --- a/libgo/go/http/transfer.go +++ b/libgo/go/http/transfer.go @@ -7,6 +7,7 @@ package http import ( "bufio" "io" + "io/ioutil" "os" "strconv" "strings" @@ -438,26 +439,39 @@ type body struct { hdr interface{} // non-nil (Response or Request) value means read trailer r *bufio.Reader // underlying wire-format reader for the trailer closing bool // is the connection to be closed after reading body? + closed bool +} + +// ErrBodyReadAfterClose is returned when reading a Request Body after +// the body has been closed. This typically happens when the body is +// read after an HTTP Handler calls WriteHeader or Write on its +// ResponseWriter. +var ErrBodyReadAfterClose = os.NewError("http: invalid Read on closed request Body") + +func (b *body) Read(p []byte) (n int, err os.Error) { + if b.closed { + return 0, ErrBodyReadAfterClose + } + return b.Reader.Read(p) } func (b *body) Close() os.Error { + if b.closed { + return nil + } + defer func() { + b.closed = true + }() if b.hdr == nil && b.closing { // no trailer and closing the connection next. // no point in reading to EOF. return nil } - trashBuf := make([]byte, 1024) // local for thread safety - for { - _, err := b.Read(trashBuf) - if err == nil { - continue - } - if err == os.EOF { - break - } + if _, err := io.Copy(ioutil.Discard, b); err != nil { return err } + if b.hdr == nil { // not reading trailer return nil } diff --git a/libgo/go/http/transport.go b/libgo/go/http/transport.go index 797d134aa85..73a2c2191ea 100644 --- a/libgo/go/http/transport.go +++ b/libgo/go/http/transport.go @@ -6,6 +6,8 @@ package http import ( "bufio" + "bytes" + "compress/gzip" "crypto/tls" "encoding/base64" "fmt" @@ -39,8 +41,9 @@ type Transport struct { // TODO: tunable on timeout on cached connections // TODO: optional pipelining - IgnoreEnvironment bool // don't look at environment variables for proxy configuration - DisableKeepAlives bool + IgnoreEnvironment bool // don't look at environment variables for proxy configuration + DisableKeepAlives bool + DisableCompression bool // MaxIdleConnsPerHost, if non-zero, controls the maximum idle // (keep-alive) to keep to keep per-host. If zero, @@ -215,6 +218,9 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { conn, err := net.Dial("tcp", cm.addr()) if err != nil { + if cm.proxyURL != nil { + err = fmt.Errorf("http: error connecting to proxy %s: %v", cm.proxyURL, err) + } return nil, err } @@ -286,10 +292,28 @@ func (t *Transport) getConn(cm *connectMethod) (*persistConn, os.Error) { // useProxy returns true if requests to addr should use a proxy, // according to the NO_PROXY or no_proxy environment variable. +// addr is always a canonicalAddr with a host and port. func (t *Transport) useProxy(addr string) bool { if len(addr) == 0 { return true } + host, _, err := net.SplitHostPort(addr) + if err != nil { + return false + } + if host == "localhost" { + return false + } + if ip := net.ParseIP(host); ip != nil { + if ip4 := ip.To4(); ip4 != nil && ip4[0] == 127 { + // 127.0.0.0/8 loopback isn't proxied. + return false + } + if bytes.Equal(ip, net.IPv6loopback) { + return false + } + } + no_proxy := t.getenvEitherCase("NO_PROXY") if no_proxy == "*" { return false @@ -474,6 +498,19 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) { pc.mutateRequestFunc(req) } + // Ask for a compressed version if the caller didn't set their + // own value for Accept-Encoding. We only attempted to + // uncompress the gzip stream if we were the layer that + // requested it. + requestedGzip := false + if !pc.t.DisableCompression && req.Header.Get("Accept-Encoding") == "" { + // Request gzip only, not deflate. Deflate is ambiguous and + // as universally supported anyway. + // See: http://www.gzip.org/zlib/zlib_faq.html#faq38 + requestedGzip = true + req.Header.Set("Accept-Encoding", "gzip") + } + pc.lk.Lock() pc.numExpectedResponses++ pc.lk.Unlock() @@ -490,6 +527,20 @@ func (pc *persistConn) roundTrip(req *Request) (resp *Response, err os.Error) { pc.lk.Lock() pc.numExpectedResponses-- pc.lk.Unlock() + + if re.err == nil && requestedGzip && re.res.Header.Get("Content-Encoding") == "gzip" { + re.res.Header.Del("Content-Encoding") + re.res.Header.Del("Content-Length") + re.res.ContentLength = -1 + esb := re.res.Body.(*bodyEOFSignal) + gzReader, err := gzip.NewReader(esb.body) + if err != nil { + pc.close() + return nil, err + } + esb.body = &readFirstCloseBoth{gzReader, esb.body} + } + return re.res, re.err } @@ -526,7 +577,7 @@ func responseIsKeepAlive(res *Response) bool { func readResponseWithEOFSignal(r *bufio.Reader, requestMethod string) (resp *Response, err os.Error) { resp, err = ReadResponse(r, requestMethod) if err == nil && resp.ContentLength != 0 { - resp.Body = &bodyEOFSignal{resp.Body, nil} + resp.Body = &bodyEOFSignal{body: resp.Body} } return } @@ -535,12 +586,16 @@ func readResponseWithEOFSignal(r *bufio.Reader, requestMethod string) (resp *Res // once, right before the final Read() or Close() call returns, but after // EOF has been seen. type bodyEOFSignal struct { - body io.ReadCloser - fn func() + body io.ReadCloser + fn func() + isClosed bool } func (es *bodyEOFSignal) Read(p []byte) (n int, err os.Error) { n, err = es.body.Read(p) + if es.isClosed && n > 0 { + panic("http: unexpected bodyEOFSignal Read after Close; see issue 1725") + } if err == os.EOF && es.fn != nil { es.fn() es.fn = nil @@ -549,6 +604,7 @@ func (es *bodyEOFSignal) Read(p []byte) (n int, err os.Error) { } func (es *bodyEOFSignal) Close() (err os.Error) { + es.isClosed = true err = es.body.Close() if err == nil && es.fn != nil { es.fn() @@ -556,3 +612,19 @@ func (es *bodyEOFSignal) Close() (err os.Error) { } return } + +type readFirstCloseBoth struct { + io.ReadCloser + io.Closer +} + +func (r *readFirstCloseBoth) Close() os.Error { + if err := r.ReadCloser.Close(); err != nil { + r.Closer.Close() + return err + } + if err := r.Closer.Close(); err != nil { + return err + } + return nil +} diff --git a/libgo/go/http/transport_test.go b/libgo/go/http/transport_test.go index e46f830c828..7610856738d 100644 --- a/libgo/go/http/transport_test.go +++ b/libgo/go/http/transport_test.go @@ -7,11 +7,16 @@ package http_test import ( + "bytes" + "compress/gzip" + "crypto/rand" "fmt" . "http" "http/httptest" + "io" "io/ioutil" "os" + "strconv" "testing" "time" ) @@ -24,7 +29,7 @@ var hostPortHandler = HandlerFunc(func(w ResponseWriter, r *Request) { if r.FormValue("close") == "true" { w.Header().Set("Connection", "close") } - fmt.Fprintf(w, "%s", r.RemoteAddr) + w.Write([]byte(r.RemoteAddr)) }) // Two subsequent requests and verify their response is the same. @@ -177,35 +182,47 @@ func TestTransportIdleCacheKeys(t *testing.T) { } func TestTransportMaxPerHostIdleConns(t *testing.T) { - ch := make(chan string) + resch := make(chan string) + gotReq := make(chan bool) ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { - fmt.Fprintf(w, "%s", <-ch) + gotReq <- true + msg := <-resch + _, err := w.Write([]byte(msg)) + if err != nil { + t.Fatalf("Write: %v", err) + } })) defer ts.Close() maxIdleConns := 2 tr := &Transport{DisableKeepAlives: false, MaxIdleConnsPerHost: maxIdleConns} c := &Client{Transport: tr} - // Start 3 outstanding requests (will hang until we write to - // ch) + // Start 3 outstanding requests and wait for the server to get them. + // Their responses will hang until we we write to resch, though. donech := make(chan bool) doReq := func() { resp, _, err := c.Get(ts.URL) if err != nil { t.Error(err) } - ioutil.ReadAll(resp.Body) + _, err = ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ReadAll: %v", err) + } donech <- true } go doReq() + <-gotReq go doReq() + <-gotReq go doReq() + <-gotReq if e, g := 0, len(tr.IdleConnKeysForTesting()); e != g { t.Fatalf("Before writes, expected %d idle conn cache keys; got %d", e, g) } - ch <- "res1" + resch <- "res1" <-donech keys := tr.IdleConnKeysForTesting() if e, g := 1, len(keys); e != g { @@ -219,13 +236,13 @@ func TestTransportMaxPerHostIdleConns(t *testing.T) { t.Errorf("after first response, expected %d idle conns; got %d", e, g) } - ch <- "res2" + resch <- "res2" <-donech if e, g := 2, tr.IdleConnCountForTesting(cacheKey); e != g { t.Errorf("after second response, expected %d idle conns; got %d", e, g) } - ch <- "res3" + resch <- "res3" <-donech if e, g := maxIdleConns, tr.IdleConnCountForTesting(cacheKey); e != g { t.Errorf("after third response, still expected %d idle conns; got %d", e, g) @@ -239,26 +256,44 @@ func TestTransportServerClosingUnexpectedly(t *testing.T) { tr := &Transport{} c := &Client{Transport: tr} - fetch := func(n int) string { - res, _, err := c.Get(ts.URL) - if err != nil { - t.Fatalf("error in req #%d, GET: %v", n, err) + fetch := func(n, retries int) string { + condFatalf := func(format string, arg ...interface{}) { + if retries <= 0 { + t.Fatalf(format, arg...) + } + t.Logf("retrying shortly after expected error: "+format, arg...) + time.Sleep(1e9 / int64(retries)) } - body, err := ioutil.ReadAll(res.Body) - if err != nil { - t.Fatalf("error in req #%d, ReadAll: %v", n, err) + for retries >= 0 { + retries-- + res, _, err := c.Get(ts.URL) + if err != nil { + condFatalf("error in req #%d, GET: %v", n, err) + continue + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + condFatalf("error in req #%d, ReadAll: %v", n, err) + continue + } + res.Body.Close() + return string(body) } - res.Body.Close() - return string(body) + panic("unreachable") } - body1 := fetch(1) - body2 := fetch(2) + body1 := fetch(1, 0) + body2 := fetch(2, 0) ts.CloseClientConnections() // surprise! - time.Sleep(25e6) // idle for a bit (test is inherently racey, but expectedly) - body3 := fetch(3) + // This test has an expected race. Sleeping for 25 ms prevents + // it on most fast machines, causing the next fetch() call to + // succeed quickly. But if we do get errors, fetch() will retry 5 + // times with some delays between. + time.Sleep(25e6) + + body3 := fetch(3, 5) if body1 != body2 { t.Errorf("expected body1 and body2 to be equal") @@ -288,10 +323,10 @@ func TestTransportHeadResponses(t *testing.T) { t.Errorf("error on loop %d: %v", i, err) } if e, g := "123", res.Header.Get("Content-Length"); e != g { - t.Errorf("loop %d: expected Content-Length header of %q, got %q", e, g) + t.Errorf("loop %d: expected Content-Length header of %q, got %q", i, e, g) } if e, g := int64(0), res.ContentLength; e != g { - t.Errorf("loop %d: expected res.ContentLength of %v, got %v", e, g) + t.Errorf("loop %d: expected res.ContentLength of %v, got %v", i, e, g) } } } @@ -338,6 +373,7 @@ func TestTransportNilURL(t *testing.T) { req.Proto = "HTTP/1.1" req.ProtoMajor = 1 req.ProtoMinor = 1 + req.Header = make(Header) tr := &Transport{} res, err := tr.RoundTrip(req) @@ -349,3 +385,147 @@ func TestTransportNilURL(t *testing.T) { t.Fatalf("Expected response body of %q; got %q", e, g) } } + +func TestTransportGzip(t *testing.T) { + const testString = "The test string aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaa" + const nRandBytes = 1024 * 1024 + ts := httptest.NewServer(HandlerFunc(func(rw ResponseWriter, req *Request) { + if g, e := req.Header.Get("Accept-Encoding"), "gzip"; g != e { + t.Errorf("Accept-Encoding = %q, want %q", g, e) + } + rw.Header().Set("Content-Encoding", "gzip") + + var w io.Writer = rw + var buf bytes.Buffer + if req.FormValue("chunked") == "0" { + w = &buf + defer io.Copy(rw, &buf) + defer func() { + rw.Header().Set("Content-Length", strconv.Itoa(buf.Len())) + }() + } + gz, _ := gzip.NewWriter(w) + gz.Write([]byte(testString)) + if req.FormValue("body") == "large" { + io.Copyn(gz, rand.Reader, nRandBytes) + } + gz.Close() + })) + defer ts.Close() + + for _, chunked := range []string{"1", "0"} { + c := &Client{Transport: &Transport{}} + + // First fetch something large, but only read some of it. + res, _, err := c.Get(ts.URL + "?body=large&chunked=" + chunked) + if err != nil { + t.Fatalf("large get: %v", err) + } + buf := make([]byte, len(testString)) + n, err := io.ReadFull(res.Body, buf) + if err != nil { + t.Fatalf("partial read of large response: size=%d, %v", n, err) + } + if e, g := testString, string(buf); e != g { + t.Errorf("partial read got %q, expected %q", g, e) + } + res.Body.Close() + // Read on the body, even though it's closed + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected error post-closed large Read; got = %d, %v", n, err) + } + + // Then something small. + res, _, err = c.Get(ts.URL + "?chunked=" + chunked) + if err != nil { + t.Fatal(err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if g, e := string(body), testString; g != e { + t.Fatalf("body = %q; want %q", g, e) + } + if g, e := res.Header.Get("Content-Encoding"), ""; g != e { + t.Fatalf("Content-Encoding = %q; want %q", g, e) + } + + // Read on the body after it's been fully read: + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected Read error after exhausted reads; got %d, %v", n, err) + } + res.Body.Close() + n, err = res.Body.Read(buf) + if n != 0 || err == nil { + t.Errorf("expected Read error after Close; got %d, %v", n, err) + } + } +} + +// TestTransportGzipRecursive sends a gzip quine and checks that the +// client gets the same value back. This is more cute than anything, +// but checks that we don't recurse forever, and checks that +// Content-Encoding is removed. +func TestTransportGzipRecursive(t *testing.T) { + ts := httptest.NewServer(HandlerFunc(func(w ResponseWriter, r *Request) { + w.Header().Set("Content-Encoding", "gzip") + w.Write(rgz) + })) + defer ts.Close() + + c := &Client{Transport: &Transport{}} + res, _, err := c.Get(ts.URL) + if err != nil { + t.Fatal(err) + } + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Fatal(err) + } + if !bytes.Equal(body, rgz) { + t.Fatalf("Incorrect result from recursive gz:\nhave=%x\nwant=%x", + body, rgz) + } + if g, e := res.Header.Get("Content-Encoding"), ""; g != e { + t.Fatalf("Content-Encoding = %q; want %q", g, e) + } +} + +// rgz is a gzip quine that uncompresses to itself. +var rgz = []byte{ + 0x1f, 0x8b, 0x08, 0x08, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x72, 0x65, 0x63, 0x75, 0x72, 0x73, + 0x69, 0x76, 0x65, 0x00, 0x92, 0xef, 0xe6, 0xe0, + 0x60, 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, + 0xe2, 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, + 0x00, 0xe8, 0xff, 0x92, 0xef, 0xe6, 0xe0, 0x60, + 0x00, 0x83, 0xa2, 0xd4, 0xe4, 0xd2, 0xa2, 0xe2, + 0xcc, 0xb2, 0x54, 0x06, 0x00, 0x00, 0x17, 0x00, + 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, + 0x05, 0x00, 0xfa, 0xff, 0x42, 0x12, 0x46, 0x16, + 0x06, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, 0x05, + 0x00, 0xfa, 0xff, 0x00, 0x14, 0x00, 0xeb, 0xff, + 0x42, 0x12, 0x46, 0x16, 0x06, 0x00, 0x05, 0x00, + 0xfa, 0xff, 0x00, 0x05, 0x00, 0xfa, 0xff, 0x00, + 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, + 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, + 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, 0xeb, 0xff, + 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x14, 0x00, + 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, + 0x14, 0x00, 0xeb, 0xff, 0x42, 0x88, 0x21, 0xc4, + 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, 0x00, + 0x00, 0xff, 0xff, 0x00, 0x17, 0x00, 0xe8, 0xff, + 0x42, 0x88, 0x21, 0xc4, 0x00, 0x00, 0x00, 0x00, + 0xff, 0xff, 0x00, 0x00, 0x00, 0xff, 0xff, 0x00, + 0x17, 0x00, 0xe8, 0xff, 0x42, 0x12, 0x46, 0x16, + 0x06, 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, + 0x00, 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, + 0x00, 0x00, 0x00, 0x42, 0x12, 0x46, 0x16, 0x06, + 0x00, 0x00, 0x00, 0xff, 0xff, 0x01, 0x08, 0x00, + 0xf7, 0xff, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, + 0x00, 0x00, 0x3d, 0xb1, 0x20, 0x85, 0xfa, 0x00, + 0x00, 0x00, +} diff --git a/libgo/go/http/url.go b/libgo/go/http/url.go index 0fc0cb2d76e..d7ee14ee84a 100644 --- a/libgo/go/http/url.go +++ b/libgo/go/http/url.go @@ -449,7 +449,7 @@ func ParseURLReference(rawurlref string) (url *URL, err os.Error) { // // There are redundant fields stored in the URL structure: // the String method consults Scheme, Path, Host, RawUserinfo, -// RawQuery, and Fragment, but not Raw, RawPath or Authority. +// RawQuery, and Fragment, but not Raw, RawPath or RawAuthority. func (url *URL) String() string { result := "" if url.Scheme != "" { |