summaryrefslogtreecommitdiff
path: root/src/net/http/httptest
diff options
context:
space:
mode:
Diffstat (limited to 'src/net/http/httptest')
-rw-r--r--src/net/http/httptest/example_test.go50
-rw-r--r--src/net/http/httptest/recorder.go72
-rw-r--r--src/net/http/httptest/recorder_test.go90
-rw-r--r--src/net/http/httptest/server.go228
-rw-r--r--src/net/http/httptest/server_test.go29
5 files changed, 469 insertions, 0 deletions
diff --git a/src/net/http/httptest/example_test.go b/src/net/http/httptest/example_test.go
new file mode 100644
index 000000000..42a0ec953
--- /dev/null
+++ b/src/net/http/httptest/example_test.go
@@ -0,0 +1,50 @@
+// Copyright 2013 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httptest_test
+
+import (
+ "fmt"
+ "io/ioutil"
+ "log"
+ "net/http"
+ "net/http/httptest"
+)
+
+func ExampleResponseRecorder() {
+ handler := func(w http.ResponseWriter, r *http.Request) {
+ http.Error(w, "something failed", http.StatusInternalServerError)
+ }
+
+ req, err := http.NewRequest("GET", "http://example.com/foo", nil)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ w := httptest.NewRecorder()
+ handler(w, req)
+
+ fmt.Printf("%d - %s", w.Code, w.Body.String())
+ // Output: 500 - something failed
+}
+
+func ExampleServer() {
+ ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintln(w, "Hello, client")
+ }))
+ defer ts.Close()
+
+ res, err := http.Get(ts.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+ greeting, err := ioutil.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Printf("%s", greeting)
+ // Output: Hello, client
+}
diff --git a/src/net/http/httptest/recorder.go b/src/net/http/httptest/recorder.go
new file mode 100644
index 000000000..5451f5423
--- /dev/null
+++ b/src/net/http/httptest/recorder.go
@@ -0,0 +1,72 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Package httptest provides utilities for HTTP testing.
+package httptest
+
+import (
+ "bytes"
+ "net/http"
+)
+
+// ResponseRecorder is an implementation of http.ResponseWriter that
+// records its mutations for later inspection in tests.
+type ResponseRecorder struct {
+ Code int // the HTTP response code from WriteHeader
+ HeaderMap http.Header // the HTTP response headers
+ Body *bytes.Buffer // if non-nil, the bytes.Buffer to append written data to
+ Flushed bool
+
+ wroteHeader bool
+}
+
+// NewRecorder returns an initialized ResponseRecorder.
+func NewRecorder() *ResponseRecorder {
+ return &ResponseRecorder{
+ HeaderMap: make(http.Header),
+ Body: new(bytes.Buffer),
+ Code: 200,
+ }
+}
+
+// DefaultRemoteAddr is the default remote address to return in RemoteAddr if
+// an explicit DefaultRemoteAddr isn't set on ResponseRecorder.
+const DefaultRemoteAddr = "1.2.3.4"
+
+// Header returns the response headers.
+func (rw *ResponseRecorder) Header() http.Header {
+ m := rw.HeaderMap
+ if m == nil {
+ m = make(http.Header)
+ rw.HeaderMap = m
+ }
+ return m
+}
+
+// Write always succeeds and writes to rw.Body, if not nil.
+func (rw *ResponseRecorder) Write(buf []byte) (int, error) {
+ if !rw.wroteHeader {
+ rw.WriteHeader(200)
+ }
+ if rw.Body != nil {
+ rw.Body.Write(buf)
+ }
+ return len(buf), nil
+}
+
+// WriteHeader sets rw.Code.
+func (rw *ResponseRecorder) WriteHeader(code int) {
+ if !rw.wroteHeader {
+ rw.Code = code
+ }
+ rw.wroteHeader = true
+}
+
+// Flush sets rw.Flushed to true.
+func (rw *ResponseRecorder) Flush() {
+ if !rw.wroteHeader {
+ rw.WriteHeader(200)
+ }
+ rw.Flushed = true
+}
diff --git a/src/net/http/httptest/recorder_test.go b/src/net/http/httptest/recorder_test.go
new file mode 100644
index 000000000..2b563260c
--- /dev/null
+++ b/src/net/http/httptest/recorder_test.go
@@ -0,0 +1,90 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httptest
+
+import (
+ "fmt"
+ "net/http"
+ "testing"
+)
+
+func TestRecorder(t *testing.T) {
+ type checkFunc func(*ResponseRecorder) error
+ check := func(fns ...checkFunc) []checkFunc { return fns }
+
+ hasStatus := func(wantCode int) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if rec.Code != wantCode {
+ return fmt.Errorf("Status = %d; want %d", rec.Code, wantCode)
+ }
+ return nil
+ }
+ }
+ hasContents := func(want string) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if rec.Body.String() != want {
+ return fmt.Errorf("wrote = %q; want %q", rec.Body.String(), want)
+ }
+ return nil
+ }
+ }
+ hasFlush := func(want bool) checkFunc {
+ return func(rec *ResponseRecorder) error {
+ if rec.Flushed != want {
+ return fmt.Errorf("Flushed = %v; want %v", rec.Flushed, want)
+ }
+ return nil
+ }
+ }
+
+ tests := []struct {
+ name string
+ h func(w http.ResponseWriter, r *http.Request)
+ checks []checkFunc
+ }{
+ {
+ "200 default",
+ func(w http.ResponseWriter, r *http.Request) {},
+ check(hasStatus(200), hasContents("")),
+ },
+ {
+ "first code only",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.WriteHeader(201)
+ w.WriteHeader(202)
+ w.Write([]byte("hi"))
+ },
+ check(hasStatus(201), hasContents("hi")),
+ },
+ {
+ "write sends 200",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("hi first"))
+ w.WriteHeader(201)
+ w.WriteHeader(202)
+ },
+ check(hasStatus(200), hasContents("hi first"), hasFlush(false)),
+ },
+ {
+ "flush",
+ func(w http.ResponseWriter, r *http.Request) {
+ w.(http.Flusher).Flush() // also sends a 200
+ w.WriteHeader(201)
+ },
+ check(hasStatus(200), hasFlush(true)),
+ },
+ }
+ r, _ := http.NewRequest("GET", "http://foo.com/", nil)
+ for _, tt := range tests {
+ h := http.HandlerFunc(tt.h)
+ rec := NewRecorder()
+ h.ServeHTTP(rec, r)
+ for _, check := range tt.checks {
+ if err := check(rec); err != nil {
+ t.Errorf("%s: %v", tt.name, err)
+ }
+ }
+ }
+}
diff --git a/src/net/http/httptest/server.go b/src/net/http/httptest/server.go
new file mode 100644
index 000000000..789e7bf41
--- /dev/null
+++ b/src/net/http/httptest/server.go
@@ -0,0 +1,228 @@
+// Copyright 2011 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+// Implementation of Server
+
+package httptest
+
+import (
+ "crypto/tls"
+ "flag"
+ "fmt"
+ "net"
+ "net/http"
+ "os"
+ "sync"
+)
+
+// A Server is an HTTP server listening on a system-chosen port on the
+// local loopback interface, for use in end-to-end HTTP tests.
+type Server struct {
+ URL string // base URL of form http://ipaddr:port with no trailing slash
+ Listener net.Listener
+
+ // TLS is the optional TLS configuration, populated with a new config
+ // after TLS is started. If set on an unstarted server before StartTLS
+ // is called, existing fields are copied into the new config.
+ TLS *tls.Config
+
+ // Config may be changed after calling NewUnstartedServer and
+ // before Start or StartTLS.
+ Config *http.Server
+
+ // wg counts the number of outstanding HTTP requests on this server.
+ // Close blocks until all requests are finished.
+ wg sync.WaitGroup
+}
+
+// historyListener keeps track of all connections that it's ever
+// accepted.
+type historyListener struct {
+ net.Listener
+ sync.Mutex // protects history
+ history []net.Conn
+}
+
+func (hs *historyListener) Accept() (c net.Conn, err error) {
+ c, err = hs.Listener.Accept()
+ if err == nil {
+ hs.Lock()
+ hs.history = append(hs.history, c)
+ hs.Unlock()
+ }
+ return
+}
+
+func newLocalListener() net.Listener {
+ if *serve != "" {
+ l, err := net.Listen("tcp", *serve)
+ if err != nil {
+ panic(fmt.Sprintf("httptest: failed to listen on %v: %v", *serve, err))
+ }
+ return l
+ }
+ l, err := net.Listen("tcp", "127.0.0.1:0")
+ if err != nil {
+ if l, err = net.Listen("tcp6", "[::1]:0"); err != nil {
+ panic(fmt.Sprintf("httptest: failed to listen on a port: %v", err))
+ }
+ }
+ return l
+}
+
+// When debugging a particular http server-based test,
+// this flag lets you run
+// go test -run=BrokenTest -httptest.serve=127.0.0.1:8000
+// to start the broken server so you can interact with it manually.
+var serve = flag.String("httptest.serve", "", "if non-empty, httptest.NewServer serves on this address and blocks")
+
+// NewServer starts and returns a new Server.
+// The caller should call Close when finished, to shut it down.
+func NewServer(handler http.Handler) *Server {
+ ts := NewUnstartedServer(handler)
+ ts.Start()
+ return ts
+}
+
+// NewUnstartedServer returns a new Server but doesn't start it.
+//
+// After changing its configuration, the caller should call Start or
+// StartTLS.
+//
+// The caller should call Close when finished, to shut it down.
+func NewUnstartedServer(handler http.Handler) *Server {
+ return &Server{
+ Listener: newLocalListener(),
+ Config: &http.Server{Handler: handler},
+ }
+}
+
+// Start starts a server from NewUnstartedServer.
+func (s *Server) Start() {
+ if s.URL != "" {
+ panic("Server already started")
+ }
+ s.Listener = &historyListener{Listener: s.Listener}
+ s.URL = "http://" + s.Listener.Addr().String()
+ s.wrapHandler()
+ go s.Config.Serve(s.Listener)
+ if *serve != "" {
+ fmt.Fprintln(os.Stderr, "httptest: serving on", s.URL)
+ select {}
+ }
+}
+
+// StartTLS starts TLS on a server from NewUnstartedServer.
+func (s *Server) StartTLS() {
+ if s.URL != "" {
+ panic("Server already started")
+ }
+ cert, err := tls.X509KeyPair(localhostCert, localhostKey)
+ if err != nil {
+ panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
+ }
+
+ existingConfig := s.TLS
+ s.TLS = new(tls.Config)
+ if existingConfig != nil {
+ *s.TLS = *existingConfig
+ }
+ if s.TLS.NextProtos == nil {
+ s.TLS.NextProtos = []string{"http/1.1"}
+ }
+ if len(s.TLS.Certificates) == 0 {
+ s.TLS.Certificates = []tls.Certificate{cert}
+ }
+ tlsListener := tls.NewListener(s.Listener, s.TLS)
+
+ s.Listener = &historyListener{Listener: tlsListener}
+ s.URL = "https://" + s.Listener.Addr().String()
+ s.wrapHandler()
+ go s.Config.Serve(s.Listener)
+}
+
+func (s *Server) wrapHandler() {
+ h := s.Config.Handler
+ if h == nil {
+ h = http.DefaultServeMux
+ }
+ s.Config.Handler = &waitGroupHandler{
+ s: s,
+ h: h,
+ }
+}
+
+// NewTLSServer starts and returns a new Server using TLS.
+// The caller should call Close when finished, to shut it down.
+func NewTLSServer(handler http.Handler) *Server {
+ ts := NewUnstartedServer(handler)
+ ts.StartTLS()
+ return ts
+}
+
+// Close shuts down the server and blocks until all outstanding
+// requests on this server have completed.
+func (s *Server) Close() {
+ s.Listener.Close()
+ s.wg.Wait()
+ s.CloseClientConnections()
+ if t, ok := http.DefaultTransport.(*http.Transport); ok {
+ t.CloseIdleConnections()
+ }
+}
+
+// CloseClientConnections closes any currently open HTTP connections
+// to the test Server.
+func (s *Server) CloseClientConnections() {
+ hl, ok := s.Listener.(*historyListener)
+ if !ok {
+ return
+ }
+ hl.Lock()
+ for _, conn := range hl.history {
+ conn.Close()
+ }
+ hl.Unlock()
+}
+
+// waitGroupHandler wraps a handler, incrementing and decrementing a
+// sync.WaitGroup on each request, to enable Server.Close to block
+// until outstanding requests are finished.
+type waitGroupHandler struct {
+ s *Server
+ h http.Handler // non-nil
+}
+
+func (h *waitGroupHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
+ h.s.wg.Add(1)
+ defer h.s.wg.Done() // a defer, in case ServeHTTP below panics
+ h.h.ServeHTTP(w, r)
+}
+
+// localhostCert is a PEM-encoded TLS cert with SAN IPs
+// "127.0.0.1" and "[::1]", expiring at the last second of 2049 (the end
+// of ASN.1 time).
+// generated from src/crypto/tls:
+// go run generate_cert.go --rsa-bits 512 --host 127.0.0.1,::1,example.com --ca --start-date "Jan 1 00:00:00 1970" --duration=1000000h
+var localhostCert = []byte(`-----BEGIN CERTIFICATE-----
+MIIBdzCCASOgAwIBAgIBADALBgkqhkiG9w0BAQUwEjEQMA4GA1UEChMHQWNtZSBD
+bzAeFw03MDAxMDEwMDAwMDBaFw00OTEyMzEyMzU5NTlaMBIxEDAOBgNVBAoTB0Fj
+bWUgQ28wWjALBgkqhkiG9w0BAQEDSwAwSAJBAN55NcYKZeInyTuhcCwFMhDHCmwa
+IUSdtXdcbItRB/yfXGBhiex00IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEA
+AaNoMGYwDgYDVR0PAQH/BAQDAgCkMBMGA1UdJQQMMAoGCCsGAQUFBwMBMA8GA1Ud
+EwEB/wQFMAMBAf8wLgYDVR0RBCcwJYILZXhhbXBsZS5jb22HBH8AAAGHEAAAAAAA
+AAAAAAAAAAAAAAEwCwYJKoZIhvcNAQEFA0EAAoQn/ytgqpiLcZu9XKbCJsJcvkgk
+Se6AbGXgSlq+ZCEVo0qIwSgeBqmsJxUu7NCSOwVJLYNEBO2DtIxoYVk+MA==
+-----END CERTIFICATE-----`)
+
+// localhostKey is the private key for localhostCert.
+var localhostKey = []byte(`-----BEGIN RSA PRIVATE KEY-----
+MIIBPAIBAAJBAN55NcYKZeInyTuhcCwFMhDHCmwaIUSdtXdcbItRB/yfXGBhiex0
+0IaLXQnSU+QZPRZWYqeTEbFSgihqi1PUDy8CAwEAAQJBAQdUx66rfh8sYsgfdcvV
+NoafYpnEcB5s4m/vSVe6SU7dCK6eYec9f9wpT353ljhDUHq3EbmE4foNzJngh35d
+AekCIQDhRQG5Li0Wj8TM4obOnnXUXf1jRv0UkzE9AHWLG5q3AwIhAPzSjpYUDjVW
+MCUXgckTpKCuGwbJk7424Nb8bLzf3kllAiA5mUBgjfr/WtFSJdWcPQ4Zt9KTMNKD
+EUO0ukpTwEIl6wIhAMbGqZK3zAAFdq8DD2jPx+UJXnh0rnOkZBzDtJ6/iN69AiEA
+1Aq8MJgTaYsDQWyU/hDq5YkDJc9e9DSCvUIzqxQWMQE=
+-----END RSA PRIVATE KEY-----`)
diff --git a/src/net/http/httptest/server_test.go b/src/net/http/httptest/server_test.go
new file mode 100644
index 000000000..500a9f0b8
--- /dev/null
+++ b/src/net/http/httptest/server_test.go
@@ -0,0 +1,29 @@
+// Copyright 2012 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package httptest
+
+import (
+ "io/ioutil"
+ "net/http"
+ "testing"
+)
+
+func TestServer(t *testing.T) {
+ ts := NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("hello"))
+ }))
+ defer ts.Close()
+ res, err := http.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ got, err := ioutil.ReadAll(res.Body)
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(got) != "hello" {
+ t.Errorf("got %q, want hello", string(got))
+ }
+}