summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJohan Brandhorst <johan.brandhorst@gmail.com>2016-12-21 13:49:04 +0000
committerBrad Fitzpatrick <bradfitz@golang.org>2017-03-03 21:02:17 +0000
commitfbf4dd91b9762eeb038c141b21712534310795f1 (patch)
tree0fc405660a54afc7ee7f20e3cce85f6f978c0c19
parent02e36f8c87901fde841cb5e82e0409b7914572b8 (diff)
downloadgo-git-fbf4dd91b9762eeb038c141b21712534310795f1.tar.gz
net/http/httptest: add Client and Certificate methods to Server
Adds a function for easily accessing the x509.Certificate of a Server, if there is one. Also adds a helper function for getting a http.Client suitable for use with the server. This makes the steps required to test a httptest TLS server simpler. Fixes #18411 Change-Id: I2e78fe1e54e31bed9c641be2d9a099f698c7bbde Reviewed-on: https://go-review.googlesource.com/34639 Reviewed-by: Brad Fitzpatrick <bradfitz@golang.org>
-rw-r--r--src/go/build/deps_test.go2
-rw-r--r--src/net/http/httptest/example_test.go22
-rw-r--r--src/net/http/httptest/server.go38
-rw-r--r--src/net/http/httptest/server_test.go23
4 files changed, 84 insertions, 1 deletions
diff --git a/src/go/build/deps_test.go b/src/go/build/deps_test.go
index f8ba53288e..2adc06f39b 100644
--- a/src/go/build/deps_test.go
+++ b/src/go/build/deps_test.go
@@ -411,7 +411,7 @@ var pkgDeps = map[string][]string{
"net/http/cgi": {"L4", "NET", "OS", "crypto/tls", "net/http", "regexp"},
"net/http/cookiejar": {"L4", "NET", "net/http"},
"net/http/fcgi": {"L4", "NET", "OS", "net/http", "net/http/cgi"},
- "net/http/httptest": {"L4", "NET", "OS", "crypto/tls", "flag", "net/http", "net/http/internal"},
+ "net/http/httptest": {"L4", "NET", "OS", "crypto/tls", "flag", "net/http", "net/http/internal", "crypto/x509"},
"net/http/httputil": {"L4", "NET", "OS", "context", "net/http", "net/http/internal"},
"net/http/pprof": {"L4", "OS", "html/template", "net/http", "runtime/pprof", "runtime/trace"},
"net/rpc": {"L4", "NET", "encoding/gob", "html/template", "net/http"},
diff --git a/src/net/http/httptest/example_test.go b/src/net/http/httptest/example_test.go
index bd2c49642b..e3d392130e 100644
--- a/src/net/http/httptest/example_test.go
+++ b/src/net/http/httptest/example_test.go
@@ -54,3 +54,25 @@ func ExampleServer() {
fmt.Printf("%s", greeting)
// Output: Hello, client
}
+
+func ExampleNewTLSServer() {
+ ts := httptest.NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ fmt.Fprintln(w, "Hello, client")
+ }))
+ defer ts.Close()
+
+ client := ts.Client()
+ res, err := client.Get(ts.URL)
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ greeting, err := ioutil.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ log.Fatal(err)
+ }
+
+ fmt.Printf("%s", greeting)
+ // Output: Hello, client
+}
diff --git a/src/net/http/httptest/server.go b/src/net/http/httptest/server.go
index 711821433b..56ad18ee9b 100644
--- a/src/net/http/httptest/server.go
+++ b/src/net/http/httptest/server.go
@@ -9,6 +9,7 @@ package httptest
import (
"bytes"
"crypto/tls"
+ "crypto/x509"
"flag"
"fmt"
"log"
@@ -35,6 +36,9 @@ type Server struct {
// before Start or StartTLS.
Config *http.Server
+ // certificate is a parsed version of the TLS config certificate, if present.
+ certificate *x509.Certificate
+
// wg counts the number of outstanding HTTP requests on this server.
// Close blocks until all requests are finished.
wg sync.WaitGroup
@@ -42,6 +46,10 @@ type Server struct {
mu sync.Mutex // guards closed and conns
closed bool
conns map[net.Conn]http.ConnState // except terminal states
+
+ // client is configured for use with the server.
+ // Its transport is automatically closed when Close is called.
+ client *http.Client
}
func newLocalListener() net.Listener {
@@ -85,6 +93,7 @@ func NewUnstartedServer(handler http.Handler) *Server {
return &Server{
Listener: newLocalListener(),
Config: &http.Server{Handler: handler},
+ client: &http.Client{},
}
}
@@ -124,6 +133,17 @@ func (s *Server) StartTLS() {
if len(s.TLS.Certificates) == 0 {
s.TLS.Certificates = []tls.Certificate{cert}
}
+ s.certificate, err = x509.ParseCertificate(s.TLS.Certificates[0].Certificate[0])
+ if err != nil {
+ panic(fmt.Sprintf("httptest: NewTLSServer: %v", err))
+ }
+ certpool := x509.NewCertPool()
+ certpool.AddCert(s.certificate)
+ s.client.Transport = &http.Transport{
+ TLSClientConfig: &tls.Config{
+ RootCAs: certpool,
+ },
+ }
s.Listener = tls.NewListener(s.Listener, s.TLS)
s.URL = "https://" + s.Listener.Addr().String()
s.wrap()
@@ -186,6 +206,11 @@ func (s *Server) Close() {
t.CloseIdleConnections()
}
+ // Also close the client idle connections.
+ if t, ok := s.client.Transport.(closeIdleTransport); ok {
+ t.CloseIdleConnections()
+ }
+
s.wg.Wait()
}
@@ -228,6 +253,19 @@ func (s *Server) CloseClientConnections() {
}
}
+// Certificate returns the certificate used by the server, or nil if
+// the server doesn't use TLS.
+func (s *Server) Certificate() *x509.Certificate {
+ return s.certificate
+}
+
+// Client returns an HTTP client configured for making requests to the server.
+// It is configured to trust the server's TLS test certificate and will
+// close its idle connections on Server.Close.
+func (s *Server) Client() *http.Client {
+ return s.client
+}
+
func (s *Server) goServe() {
s.wg.Add(1)
go func() {
diff --git a/src/net/http/httptest/server_test.go b/src/net/http/httptest/server_test.go
index d032c5983b..7d80fa15dd 100644
--- a/src/net/http/httptest/server_test.go
+++ b/src/net/http/httptest/server_test.go
@@ -22,6 +22,7 @@ func TestServer(t *testing.T) {
t.Fatal(err)
}
got, err := ioutil.ReadAll(res.Body)
+ res.Body.Close()
if err != nil {
t.Fatal(err)
}
@@ -98,3 +99,25 @@ func TestServerCloseClientConnections(t *testing.T) {
t.Fatalf("Unexpected response: %#v", res)
}
}
+
+// Tests that the Server.Client method works and returns an http.Client that can hit
+// NewTLSServer without cert warnings.
+func TestServerClient(t *testing.T) {
+ ts := NewTLSServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
+ w.Write([]byte("hello"))
+ }))
+ defer ts.Close()
+ client := ts.Client()
+ res, err := client.Get(ts.URL)
+ if err != nil {
+ t.Fatal(err)
+ }
+ got, err := ioutil.ReadAll(res.Body)
+ res.Body.Close()
+ if err != nil {
+ t.Fatal(err)
+ }
+ if string(got) != "hello" {
+ t.Errorf("got %q, want hello", string(got))
+ }
+}