summaryrefslogtreecommitdiff
path: root/paste/proxy.py
blob: aff824d029415313c03c3130081bc67eb5fed28e (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
# (c) 2005 Ian Bicking and contributors; written for Paste (http://pythonpaste.org)
# Licensed under the MIT license: http://www.opensource.org/licenses/mit-license.php
"""
An application that proxies WSGI requests to a remote server.

TODO:

* Send ``Via`` header?  It's not clear to me this is a Via in the
  style of a typical proxy.

* Other headers or metadata?  I put in X-Forwarded-For, but that's it.

* Signed data of non-HTTP keys?  This would be for things like
  REMOTE_USER.

* Something to indicate what the original URL was?  The original host,
  scheme, and base path.

* Rewriting ``Location`` headers?  mod_proxy does this.

* Rewriting body?  (Probably not on this one -- that can be done with
  a different middleware that wraps this middleware)

* Example::  
    
    use = egg:Paste#proxy
    address = http://server3:8680/exist/rest/db/orgs/sch/config/
    allowed_request_methods = GET
  
"""

import httplib
import urlparse

from paste import httpexceptions

# Remove these headers from the remote response before handing it to
# start_response (specify lower case header names).  These are the
# HTTP/1.1 hop-by-hop headers: they describe a single transport-level
# connection, so a proxy must not relay them, and WSGI applications
# are not allowed to return them.
filtered_headers = (
    'transfer-encoding',
    'connection',
    'keep-alive',
    'proxy-authenticate',
    'proxy-authorization',
    'te',
    'trailers',
    'upgrade',
)

class Proxy(object):

    """
    A WSGI application that forwards each request to a fixed remote
    address and relays the remote response.

    ``address`` is the base URL of the remote server; its scheme must
    be ``http`` or ``https`` (checked when a request is handled).

    ``allowed_request_methods`` is an optional sequence of HTTP method
    names (compared case-insensitively).  When it is non-empty, any
    other request method is answered with a 400 "Disallowed" response
    instead of being proxied.
    """

    def __init__(self, address, allowed_request_methods=()):
        self.address = address
        self.parsed = urlparse.urlsplit(address)
        self.scheme = self.parsed[0].lower()
        self.host = self.parsed[1]
        self.path = self.parsed[2]
        # Empty entries are dropped so that a blank configuration value
        # means "no restriction".
        self.allowed_request_methods = [
            x.lower() for x in allowed_request_methods if x]

    def __call__(self, environ, start_response):
        """
        Proxy one WSGI request and return the response body as a
        one-element list.

        Raises ``ValueError`` if the configured address uses a scheme
        other than http/https.
        """
        if (self.allowed_request_methods and
            environ['REQUEST_METHOD'].lower() not in self.allowed_request_methods):
            return httpexceptions.HTTPBadRequest("Disallowed")(environ, start_response)

        if self.scheme == 'http':
            ConnClass = httplib.HTTPConnection
        elif self.scheme == 'https':
            ConnClass = httplib.HTTPSConnection
        else:
            raise ValueError(
                "Unknown scheme for %r: %r" % (self.address, self.scheme))

        headers = self._request_headers(environ)
        body = self._request_body(environ)
        path = self._request_path(environ)

        conn = ConnClass(self.host)
        try:
            conn.request(environ['REQUEST_METHOD'],
                         path,
                         body, headers)
            res = conn.getresponse()
            headers_out = [
                (header, value) for header, value in res.getheaders()
                if header.lower() not in filtered_headers]

            status = '%s %s' % (res.status, res.reason)
            start_response(status, headers_out)
            # @@: Default?
            length = res.getheader('content-length')
            if length is not None:
                body = res.read(int(length))
            else:
                body = res.read()
        finally:
            # Release the socket even when the request or the read fails.
            conn.close()
        return [body]

    def _request_headers(self, environ):
        """Build the outgoing header dict from the WSGI environ."""
        headers = {}
        for key, value in environ.items():
            if key.startswith('HTTP_'):
                # HTTP_ACCEPT_LANGUAGE -> accept-language
                headers[key[5:].lower().replace('_', '-')] = value
        # The remote server must see its own host name, not the one the
        # client used to reach this proxy.
        headers['host'] = self.host
        if 'REMOTE_ADDR' in environ:
            headers['x-forwarded-for'] = environ['REMOTE_ADDR']
        if environ.get('CONTENT_TYPE'):
            headers['content-type'] = environ['CONTENT_TYPE']
        return headers

    def _request_body(self, environ):
        """Read the request body; empty when no CONTENT_LENGTH is given."""
        if environ.get('CONTENT_LENGTH'):
            length = int(environ['CONTENT_LENGTH'])
            return environ['wsgi.input'].read(length)
        return ''

    def _request_path(self, environ):
        """
        Return the path (plus query string) to request from the remote
        server: PATH_INFO resolved relative to the configured base path.
        """
        # PATH_INFO may legitimately be empty or omitted from the environ.
        # NOTE(review): PATH_INFO is forwarded as-is, without re-quoting.
        request_path = environ.get('PATH_INFO', '')
        if self.path:
            # Strip the leading slash so urljoin treats the path as
            # relative to the base path instead of replacing it.
            if request_path.startswith('/'):
                request_path = request_path[1:]
            path = urlparse.urljoin(self.path, request_path)
        else:
            path = request_path
        # Never build a request target that starts with '?'.
        path = path or '/'
        if environ.get('QUERY_STRING'):
            path += '?' + environ['QUERY_STRING']
        return path

def make_proxy(global_conf, address, allowed_request_methods=""):
    """
    Build a ``Proxy`` WSGI application that forwards requests to
    ``address``.

    ``address`` should be the full URL of the remote server, ending
    with a trailing ``/``.

    ``allowed_request_methods`` is a space-separated list of request
    methods; it is converted to a list before being handed to
    ``Proxy``.

    ``global_conf`` is accepted but not used.
    """
    from paste.deploy import converters
    methods = converters.aslist(allowed_request_methods)
    return Proxy(address, allowed_request_methods=methods)


class TransparentProxy(object):

    """
    A proxy that sends the request just as it was given, including
    respecting HTTP_HOST, wsgi.url_scheme, etc.

    This is a way of translating WSGI requests directly to real HTTP
    requests.  All information goes in the environment; modify it to
    modify the way the request is made.
    """

    def __init__(self):
        pass

    def __call__(self, environ, start_response):
        """
        Perform the HTTP request described by ``environ`` and return
        the response body as a one-element list.

        Raises ``ValueError`` when ``wsgi.url_scheme`` is neither http
        nor https, or when ``HTTP_HOST`` is missing.
        """
        scheme = environ['wsgi.url_scheme']
        if scheme == 'http':
            ConnClass = httplib.HTTPConnection
        elif scheme == 'https':
            ConnClass = httplib.HTTPSConnection
        else:
            raise ValueError(
                "Unknown scheme %r" % scheme)
        if 'HTTP_HOST' not in environ:
            raise ValueError(
                "WSGI environ must contain an HTTP_HOST key")
        host = environ['HTTP_HOST']

        headers = self._request_headers(environ, host)
        body = self._request_body(environ)
        path = self._request_path(environ)

        conn = ConnClass(host)
        try:
            conn.request(environ['REQUEST_METHOD'],
                         path, body, headers)
            res = conn.getresponse()
            headers_out = [
                (header, value) for header, value in res.getheaders()
                if header.lower() not in filtered_headers]

            status = '%s %s' % (res.status, res.reason)
            start_response(status, headers_out)
            # @@: Default?
            length = res.getheader('content-length')
            if length is not None:
                body = res.read(int(length))
            else:
                body = res.read()
        finally:
            # Release the socket even when the request or the read fails.
            conn.close()
        return [body]

    def _request_headers(self, environ, host):
        """Build the outgoing header dict from the WSGI environ."""
        headers = {}
        for key, value in environ.items():
            if key.startswith('HTTP_'):
                # HTTP_ACCEPT_LANGUAGE -> accept-language
                headers[key[5:].lower().replace('_', '-')] = value
        headers['host'] = host
        if 'REMOTE_ADDR' in environ:
            headers['x-forwarded-for'] = environ['REMOTE_ADDR']
        if environ.get('CONTENT_TYPE'):
            headers['content-type'] = environ['CONTENT_TYPE']
        return headers

    def _request_body(self, environ):
        """
        Read the request body.

        Without a CONTENT_LENGTH the body is treated as empty: the
        amount of data available on ``wsgi.input`` is unknown, and an
        unbounded read may block.
        """
        if environ.get('CONTENT_LENGTH'):
            length = int(environ['CONTENT_LENGTH'])
            return environ['wsgi.input'].read(length)
        return ''

    def _request_path(self, environ):
        """Return SCRIPT_NAME + PATH_INFO, plus any non-empty query string."""
        path = (environ.get('SCRIPT_NAME', '')
                + environ.get('PATH_INFO', ''))
        # Never build a request target that starts with '?'.
        path = path or '/'
        # An empty QUERY_STRING must not leave a dangling '?'.
        if environ.get('QUERY_STRING'):
            path += '?' + environ['QUERY_STRING']
        return path