import logging
log = logging.getLogger(__name__)

from Queue import Queue, Empty, Full
try:
    from cStringIO import StringIO
except ImportError, e:
    from StringIO import StringIO

try:
    import ssl
except ImportError, e:
    class ssl(object):
        SSLError = None

from urllib import urlencode
from httplib import HTTPConnection, HTTPSConnection, HTTPException
import socket
from socket import error as SocketError, timeout as SocketTimeout

import gzip
import zlib

from filepost import encode_multipart_formdata


## Exceptions

class HTTPError(Exception):
    "Base exception used by this module."
    pass

class SSLError(Exception):
    "Raised when SSL certificate fails in an HTTPS connection."
    pass

class MaxRetryError(HTTPError):
    "Raised when the maximum number of retries is exceeded."
    pass

class TimeoutError(HTTPError):
    "Raised when a socket timeout occurs."
    pass

class HostChangedError(HTTPError):
    "Raised when an existing pool gets a request for a foreign host."
    pass

## Response objects

class HTTPResponse(object):
    """
    HTTP Response container.

    Similar to httplib's HTTPResponse but the data is pre-loaded.
    """
    def __init__(self, data='', headers=None, status=0, version=0, reason=None, strict=0):
        self.data = data
        self.headers = headers or {}
        self.status = status
        self.version = version
        self.reason = reason
        self.strict = strict

    @staticmethod
    def from_httplib(r):
        """
        Given an httplib.HTTPResponse instance, return a corresponding
        urllib3.HTTPResponse object.

        NOTE: This method will perform r.read() which will have side effects
        on the original httplib.HTTPResponse object.
        """
        tmp_data = StringIO(r.read())
        try:
            if r.getheader('content-encoding') == 'gzip':
                log.debug("Received response with content-encoding: gzip, decompressing with gzip.")

                gzipper = gzip.GzipFile(fileobj=tmp_data)
                data = gzipper.read()
            elif r.getheader('content-encoding') == 'deflate':
                log.debug("Received response with content-encoding: deflate, decompressing with zlib.")
                # zlib.decompress() needs a string, not the StringIO object
                raw_data = tmp_data.read()
                try:
                    data = zlib.decompress(raw_data)
                except zlib.error, e:
                    # Retry with a negative window size to accept raw deflate
                    # streams sent without the zlib header.
                    data = zlib.decompress(raw_data, -zlib.MAX_WBITS)
            else:
                data = tmp_data.read()

        except IOError:
            raise HTTPError("Received response with content-encoding: %s, but failed to decompress it." % (r.getheader('content-encoding')))

        return HTTPResponse(data=data,
                    headers=dict(r.getheaders()),
                    status=r.status,
                    version=r.version,
                    reason=r.reason,
                    strict=r.strict)

    # Backwards-compatibility methods for httplib.HTTPResponse
    def getheaders(self):
        return self.headers

    def getheader(self, name, default=None):
        return self.headers.get(name, default)


## Connection objects

class StreamableMixin(object):
    def send(self, data):
        if isinstance(data, str):
            HTTPConnection.send(self, data)
        elif hasattr(data, '__iter__'):
            for chunk in data:
                HTTPConnection.send(self, chunk)
        else:
            raise TypeError("data object is not an iterable", data)

class StreamableHTTPConnection(HTTPConnection, StreamableMixin):
    # FIXME: Hack for old-style Python classes with broken inheritance 
    def send(self, data):
        StreamableMixin.send(self, data)

class StreamableHTTPSConnection(HTTPSConnection, StreamableMixin):
    # FIXME: Hack for old-style Python classes with broken inheritance 
    def send(self, data):
        StreamableMixin.send(self, data)


class VerifiedHTTPSConnection(StreamableHTTPSConnection):
    """
    Based on httplib.HTTPSConnection but wraps the socket with SSL certificate verification.
    """

    def set_cert(self, key_file=None, cert_file=None, cert_reqs='CERT_NONE', ca_certs=None):
        ssl_req_scheme = {
            'CERT_NONE': ssl.CERT_NONE,
            'CERT_OPTIONAL': ssl.CERT_OPTIONAL,
            'CERT_REQUIRED': ssl.CERT_REQUIRED,
        }

        self.key_file = key_file
        self.cert_file = cert_file
        self.cert_reqs = ssl_req_scheme.get(cert_reqs) or ssl.CERT_NONE
        self.ca_certs = ca_certs

    def connect(self):
        # Add certificate verification
        sock = socket.create_connection((self.host, self.port), self.timeout)

        # Wrap socket using verification with the root certs in trusted_root_certs
        self.sock = ssl.wrap_socket(sock, self.key_file, self.cert_file, cert_reqs=self.cert_reqs, ca_certs=self.ca_certs)


## Pool objects

class HTTPConnectionPool(object):
    """
    Thread-safe connection pool for one host.

    host
        Host used for this HTTP Connection (e.g. "localhost"), passed into
        httplib.HTTPConnection()

    port
        Port used for this HTTP Connection (None is equivalent to 80), passed
        into httplib.HTTPConnection()

    timeout
        Socket timeout for each individual connection, can be a float. None
        disables timeout.

    maxsize
        Number of connections to keep open for reuse. More than 1 is useful
        in multithreaded situations. If ``block`` is set to False, additional
        connections will be created as needed, but they will not be saved once
        they've been used.

    block
        If set to True, no more than ``maxsize`` connections will be used at
        a time. When no free connections are available, the call will block
        until a connection has been released. This is useful in multithreaded
        situations where you do not want to use more than ``maxsize``
        connections per host, e.g. to avoid flooding it.

    headers
        Headers to include with all requests, unless other headers are given
        explicitly.
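
    Example (an illustrative sketch; assumes 'example.com' is reachable over
    plain HTTP)::

        pool = HTTPConnectionPool('example.com', maxsize=2, block=True)
        r = pool.get_url('/', fields={'q': 'urllib3'})
        print r.status, len(r.data)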
    """

    scheme = 'http'

    def __init__(self, host, port=None, timeout=None, maxsize=1, block=False, headers=None):
        self.host = host
        self.port = port
        self.timeout = timeout
        self.pool = Queue(maxsize)
        self.block = block
        self.headers = headers or {}

        # Fill the queue up so that doing get() on it will block properly
        for _ in xrange(maxsize):
            self.pool.put(None)

        self.num_connections = 0
        self.num_requests = 0

    def _new_conn(self):
        """
        Return a fresh HTTPConnection.
        """
        self.num_connections += 1
        log.info("Starting new HTTP connection (%d): %s" % (self.num_connections, self.host))
        return StreamableHTTPConnection(host=self.host, port=self.port)

    def _get_conn(self, timeout=None):
        """
        Get a connection. Will return a pooled connection if one is available.
        Otherwise, a fresh connection is returned.
        """
        conn = None
        try:
            conn = self.pool.get(block=self.block, timeout=timeout)
        except Empty, e:
            pass # Oh well, we'll create a new connection then

        return conn or self._new_conn()

    def _put_conn(self, conn):
        """
        Put a connection back into the pool.
        If the pool is already full, the connection is discarded because we
        exceeded maxsize. If connections are discarded frequently, then maxsize
        should be increased.
        """
        try:
            self.pool.put(conn, block=False)
        except Full, e:
            # This should never happen if self.block == True
            log.warning("HttpConnectionPool is full, discarding connection: %s" % self.host)

    def is_same_host(self, url):
        return url.startswith('/') or get_host(url) == (self.scheme, self.host, self.port)

    def urlopen(self, method, url, body=None, headers=None, retries=3, redirect=True, assert_same_host=True):
        """
        Get a connection from the pool and perform an HTTP request.

        method
            HTTP request method (such as GET, POST, PUT, etc.)

        body
            Data to send in the request body (useful for creating POST requests,
            see HTTPConnectionPool.post_url for more convenience).

        headers
            Dictionary of custom headers to send, such as User-Agent, If-None-Match,
            etc. If None, pool headers are used. If provided, these headers completely
            replace any pool-specific headers.

        retries
            Number of retries to allow before raising a MaxRetryError exception.

        redirect
            Automatically handle redirects (status codes 301, 302, 303, 307),
            each redirect counts as a retry.

        assert_same_host
            If True, ensure that the host of the request matches the pool's
            host, and raise HostChangedError otherwise. When False, you can
            use the pool against an HTTP proxy and request foreign hosts.
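
        Example (an illustrative sketch; assumes an HTTP proxy listening on
        the hypothetical address 'localhost:8080')::

            proxy = HTTPConnectionPool('localhost', port=8080)
            r = proxy.urlopen('GET', 'http://example.com/',
                              assert_same_host=False)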
        """
        if headers is None:
            headers = self.headers

        if retries < 0:
            raise MaxRetryError("Max retries exceeded for url: %s" % url)

        # Check host
        if assert_same_host and not self.is_same_host(url):
            host = "%s://%s" % (self.scheme, self.host)
            if self.port:
                host = "%s:%d" % (host, self.port)

            raise HostChangedError("Connection pool with host '%s' tried to open a foreign host: %s" % (host, url))

        try:
            # Request a connection from the queue
            conn = self._get_conn()

            # Make the request
            self.num_requests += 1
            conn.request(method, url, body=body, headers=headers)
            conn.sock.settimeout(self.timeout)
            httplib_response = conn.getresponse()
            log.debug("\"%s %s %s\" %s %s" % (method, url, conn._http_vsn_str, httplib_response.status, httplib_response.length))

            # from_httplib will perform httplib_response.read() which will have
            # the side effect of letting us use this connection for another
            # request.
            response = HTTPResponse.from_httplib(httplib_response)

            # Put the connection back to be reused
            self._put_conn(conn)

        except (SocketTimeout, Empty), e:
            # Timed out either by socket or queue
            raise TimeoutError("Request timed out after %s seconds" % self.timeout)

        except (ssl.SSLError), e:
            # SSL certificate error
            raise SSLError(e)

        except (HTTPException, SocketError), e:
            log.warn("Retrying (%d attempts remain) after connection broken by '%r': %s" % (retries, e, url))
            return self.urlopen(method, url, body, headers, retries - 1, redirect, assert_same_host) # Try again

        # Handle redirection
        if redirect and response.status in [301, 302, 303, 307] and 'location' in response.headers: # Redirect, retry
            log.info("Redirecting %s -> %s" % (url, response.headers.get('location')))
            return self.urlopen(method, response.headers.get('location'), body, headers, retries-1, redirect)

        return response

    def get_url(self, url, fields=None, headers=None, retries=3, redirect=True):
        """
        Wrapper for performing GET with urlopen (see urlopen for more details).

        Supports an optional ``fields`` dictionary of key/value strings. If
        provided, it is urlencoded and appended to the url as a query string.
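
        For example, ``pool.get_url('/search', fields={'q': 'urllib3'})``
        performs a GET request for ``/search?q=urllib3``.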
        """
        if fields:
            url += '?' + urlencode(fields)
        return self.urlopen('GET', url, headers=headers, retries=retries, redirect=redirect)

    def post_url(self, url, fields=None, headers=None, retries=3, redirect=True):
        """
        Wrapper for performing POST with urlopen (see urlopen for more details).

        Supports an optional ``fields`` parameter of key/value strings AND
        key/filetuple. A filetuple is a (filename, data) tuple. For example:

        fields = {
            'foo': 'bar',
            'foofile': ('foofile.txt', 'contents of foofile'),
        }

        NOTE: If ``headers`` are supplied, the 'Content-Type' value will be
        overwritten because it depends on the dynamic random boundary string
        which is used to compose the body of the request.
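
        An illustrative call with the ``fields`` dictionary above, given a
        ``pool`` instance::

            r = pool.post_url('/upload', fields=fields)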
        """
        body, content_type = encode_multipart_formdata(fields or {})

        # Copy so we don't mutate the caller's headers dict
        headers = dict(headers or {})
        headers['Content-Type'] = content_type

        return self.urlopen('POST', url, body, headers=headers, retries=retries, redirect=redirect)


class HTTPSConnectionPool(HTTPConnectionPool):
    """
    Same as HTTPConnectionPool, but HTTPS.
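
    When the ``ssl`` module is available, VerifiedHTTPSConnection is used so
    that certificates can be checked against ``ca_certs``. An illustrative
    sketch (the CA bundle path is hypothetical)::

        pool = HTTPSConnectionPool('example.com',
                                   cert_reqs='CERT_REQUIRED',
                                   ca_certs='/etc/ssl/certs/ca-bundle.crt')
        r = pool.get_url('/')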
    """

    scheme = 'https'

    def __init__(self, host, port=None, timeout=None, maxsize=1, block=False, headers=None, key_file=None, cert_file=None, cert_reqs='CERT_NONE', ca_certs=None):
        self.host = host
        self.port = port
        self.timeout = timeout
        self.pool = Queue(maxsize)
        self.block = block
        self.headers = headers or {}

        self.key_file = key_file
        self.cert_file = cert_file
        self.cert_reqs = cert_reqs
        self.ca_certs = ca_certs

        # Fill the queue up so that doing get() on it will block properly
        for _ in xrange(maxsize):
            self.pool.put(None)

        self.num_connections = 0
        self.num_requests = 0

    def _new_conn(self):
        """
        Return a fresh HTTPSConnection.
        """
        self.num_connections += 1
        log.info("Starting new HTTPS connection (%d): %s" % (self.num_connections, self.host))

        # The dummy ``ssl`` fallback class defined above is truthy, so check
        # for the attribute we actually need rather than the object itself.
        if not hasattr(ssl, 'wrap_socket'):
            return StreamableHTTPSConnection(host=self.host, port=self.port)

        connection = VerifiedHTTPSConnection(host=self.host, port=self.port)
        connection.set_cert(key_file=self.key_file, cert_file=self.cert_file, cert_reqs=self.cert_reqs, ca_certs=self.ca_certs)
        return connection


## Helpers


def make_headers(keep_alive=None, accept_encoding=None, user_agent=None, basic_auth=None):
    """
    Shortcuts for generating request headers.

    keep_alive
        If true, adds 'connection: keep-alive' header.

    accept_encoding
        Can be a boolean, list, or string.
        True translates to 'gzip,deflate'.
        List will get joined by comma.
        String will be used as provided.

    user_agent
        String representing the user-agent you want, such as "python-urllib3/0.6"

    basic_auth
        Colon-separated username:password string for 'authorization: basic ...'
        auth header.
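
    Example (illustrative)::

        headers = make_headers(keep_alive=True, accept_encoding=True)
        # headers == {'connection': 'keep-alive',
        #             'accept-encoding': 'gzip,deflate'}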
    """
    headers = {}
    if accept_encoding:
        if isinstance(accept_encoding, str):
            pass
        elif isinstance(accept_encoding, list):
            accept_encoding = ','.join(accept_encoding)
        else:
            accept_encoding = 'gzip,deflate'
        headers['accept-encoding'] = accept_encoding

    if user_agent:
        headers['user-agent'] = user_agent

    if keep_alive:
        headers['connection'] = 'keep-alive'

    if basic_auth:
        headers['authorization'] = 'Basic ' + basic_auth.encode('base64').strip()

    return headers


def get_host(url):
    """
    Given a url, return a (scheme, host, port) tuple. The port is None if it
    is not specified in the url.

    For example:
    >>> get_host('http://google.com/mail/')
    ('http', 'google.com', None)
    >>> get_host('google.com:80')
    ('http', 'google.com', 80)
    """
    # This code is actually similar to urlparse.urlsplit, but much
    # simplified for our needs.
    port = None
    scheme = 'http'
    if '://' in url:
        scheme, url = url.split('://', 1)
    if '/' in url:
        url, path = url.split('/', 1)
    if ':' in url:
        url, port = url.split(':', 1)
        port = int(port)
    return scheme, url, port

def connection_from_url(url, **kw):
    """
    Given a url, return an HTTP(S)ConnectionPool instance for its host.

    This is a shortcut for not having to determine the host of the url
    before creating an HTTP(S)ConnectionPool instance.

    Passes on any keyword arguments to the constructor of
    HTTP(S)ConnectionPool (e.g. timeout, maxsize, block).
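
    Example (illustrative)::

        pool = connection_from_url('http://google.com/', timeout=10)
        # pool is an HTTPConnectionPool for host 'google.com'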
    """
    scheme, host, port = get_host(url)
    if scheme == 'https':
        return HTTPSConnectionPool(host, port=port, **kw)
    else:
        return HTTPConnectionPool(host, port=port, **kw)