1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
|
// Copyright 2011 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
// HTTP reverse proxy handler
package httputil
import (
"io"
"log"
"net"
"net/http"
"net/url"
"strings"
"sync"
"time"
)
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
type ReverseProxy struct {
// Director must be a function which modifies
// the request into a new request to be sent
// using Transport. Its response is then copied
// back to the original client unmodified.
Director func(*http.Request)
// The transport used to perform proxy requests.
// If nil, http.DefaultTransport is used.
Transport http.RoundTripper
// FlushInterval specifies the flush interval, in
// nanoseconds, to flush to the client while
// coping the response body.
// If zero, no periodic flushing is done.
FlushInterval int64
}
func singleJoiningSlash(a, b string) string {
aslash := strings.HasSuffix(a, "/")
bslash := strings.HasPrefix(b, "/")
switch {
case aslash && bslash:
return a + b[1:]
case !aslash && !bslash:
return a + "/" + b
}
return a + b
}
// NewSingleHostReverseProxy returns a new ReverseProxy that rewrites
// URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir",
// the target request will be for /base/dir.
func NewSingleHostReverseProxy(target *url.URL) *ReverseProxy {
director := func(req *http.Request) {
req.URL.Scheme = target.Scheme
req.URL.Host = target.Host
req.URL.Path = singleJoiningSlash(target.Path, req.URL.Path)
if q := req.URL.RawQuery; q != "" {
req.URL.RawPath = req.URL.Path + "?" + q
} else {
req.URL.RawPath = req.URL.Path
}
req.URL.RawQuery = target.RawQuery
}
return &ReverseProxy{Director: director}
}
func copyHeader(dst, src http.Header) {
for k, vv := range src {
for _, v := range vv {
dst.Add(k, v)
}
}
}
func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) {
transport := p.Transport
if transport == nil {
transport = http.DefaultTransport
}
outreq := new(http.Request)
*outreq = *req // includes shallow copies of maps, but okay
p.Director(outreq)
outreq.Proto = "HTTP/1.1"
outreq.ProtoMajor = 1
outreq.ProtoMinor = 1
outreq.Close = false
// Remove the connection header to the backend. We want a
// persistent connection, regardless of what the client sent
// to us. This is modifying the same underlying map from req
// (shallow copied above) so we only copy it if necessary.
if outreq.Header.Get("Connection") != "" {
outreq.Header = make(http.Header)
copyHeader(outreq.Header, req.Header)
outreq.Header.Del("Connection")
}
if clientIp, _, err := net.SplitHostPort(req.RemoteAddr); err == nil {
outreq.Header.Set("X-Forwarded-For", clientIp)
}
res, err := transport.RoundTrip(outreq)
if err != nil {
log.Printf("http: proxy error: %v", err)
rw.WriteHeader(http.StatusInternalServerError)
return
}
copyHeader(rw.Header(), res.Header)
rw.WriteHeader(res.StatusCode)
if res.Body != nil {
var dst io.Writer = rw
if p.FlushInterval != 0 {
if wf, ok := rw.(writeFlusher); ok {
dst = &maxLatencyWriter{dst: wf, latency: p.FlushInterval}
}
}
io.Copy(dst, res.Body)
}
}
type writeFlusher interface {
io.Writer
http.Flusher
}
type maxLatencyWriter struct {
dst writeFlusher
latency int64 // nanos
lk sync.Mutex // protects init of done, as well Write + Flush
done chan bool
}
func (m *maxLatencyWriter) Write(p []byte) (n int, err error) {
m.lk.Lock()
defer m.lk.Unlock()
if m.done == nil {
m.done = make(chan bool)
go m.flushLoop()
}
n, err = m.dst.Write(p)
if err != nil {
m.done <- true
}
return
}
func (m *maxLatencyWriter) flushLoop() {
t := time.NewTicker(m.latency)
defer t.Stop()
for {
select {
case <-t.C:
m.lk.Lock()
m.dst.Flush()
m.lk.Unlock()
case <-m.done:
return
}
}
panic("unreached")
}
|