diff options
author | Cory Benfield <lukasaoz@gmail.com> | 2017-06-14 21:30:09 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2017-06-14 21:30:09 +0100 |
commit | 98f9b0c105b2bc5c0d00a8631db0a064918e8948 (patch) | |
tree | 912fb71bd66392d66b00d35043a396064e6643bd | |
parent | 39315237c009d6c18601449a48d5e096869f2a31 (diff) | |
parent | d74df4af794b28db61edbbacec688463bbabdd43 (diff) | |
download | urllib3-gae-to-pytest.tar.gz |
Merge branch 'master' into gae-to-pytestgae-to-pytest
-rw-r--r-- | test/contrib/test_socks.py | 3 | ||||
-rw-r--r-- | test/test_response.py | 311 | ||||
-rw-r--r-- | test/test_retry.py | 171 | ||||
-rw-r--r-- | test/test_util.py | 578 | ||||
-rw-r--r-- | test/with_dummyserver/test_https.py | 1 | ||||
-rw-r--r-- | test/with_dummyserver/test_poolmanager.py | 1 |
6 files changed, 510 insertions, 555 deletions
diff --git a/test/contrib/test_socks.py b/test/contrib/test_socks.py index da553437..d5c82744 100644 --- a/test/contrib/test_socks.py +++ b/test/contrib/test_socks.py @@ -255,6 +255,7 @@ class TestSocks5Proxy(IPV4SocketDummyServerTestCase): self._start_server(request_handler) proxy_url = "socks5://%s:%s" % (self.host, self.port) pm = socks.SOCKSProxyManager(proxy_url) + self.addCleanup(pm.clear) response = pm.request('GET', 'http://localhost') self.assertEqual(response.status, 200) @@ -511,6 +512,7 @@ class TestSOCKS4Proxy(IPV4SocketDummyServerTestCase): self._start_server(request_handler) proxy_url = "socks4://%s:%s" % (self.host, self.port) pm = socks.SOCKSProxyManager(proxy_url) + self.addCleanup(pm.clear) response = pm.request('GET', 'http://localhost') self.assertEqual(response.status, 200) @@ -664,6 +666,7 @@ class TestSOCKSWithTLS(IPV4SocketDummyServerTestCase): b'Content-Length: 0\r\n' b'\r\n') tls.close() + sock.close() self._start_server(request_handler) proxy_url = "socks5h://%s:%s" % (self.host, self.port) diff --git a/test/test_response.py b/test/test_response.py index 5146b1f0..9ec029f7 100644 --- a/test/test_response.py +++ b/test/test_response.py @@ -1,8 +1,9 @@ import socket -import sys from io import BytesIO, BufferedReader +import pytest + from urllib3.response import HTTPResponse from urllib3.exceptions import ( DecodeError, ResponseNotChunked, ProtocolError, InvalidHeader @@ -13,11 +14,6 @@ from urllib3.util.response import is_fp_closed from base64 import b64decode -if sys.version_info >= (2, 7): - import unittest -else: - import unittest2 as unittest - # A known random (i.e, not-too-compressible) payload generated with: # "".join(random.choice(string.printable) for i in xrange(512)) # .encode("zlib").encode("base64") @@ -34,63 +30,69 @@ S5moAj5HexY/g/F8TctpxwsvyZp38dXeLDjSQvEQIkF7XR3YXbeZgKk3V34KGCPOAeeuQDIgyVhV nP4HF2uWHA==""") -class TestLegacyResponse(unittest.TestCase): +@pytest.fixture +def sock(): + s = socket.socket() + yield s + s.close() + + +class TestLegacyResponse(object): def test_getheaders(self): headers = {'host': 'example.com'} r = HTTPResponse(headers=headers) - self.assertEqual(r.getheaders(), headers) + assert r.getheaders() == headers def test_getheader(self): headers = {'host': 'example.com'} r = HTTPResponse(headers=headers) - self.assertEqual(r.getheader('host'), 'example.com') + assert r.getheader('host') == 'example.com' -class TestResponse(unittest.TestCase): +class TestResponse(object): def test_cache_content(self): r = HTTPResponse('foo') - self.assertEqual(r.data, 'foo') - self.assertEqual(r._body, 'foo') + assert r.data == 'foo' + assert r._body == 'foo' def test_default(self): r = HTTPResponse() - self.assertEqual(r.data, None) + assert r.data is None def test_none(self): r = HTTPResponse(None) - self.assertEqual(r.data, None) + assert r.data is None def test_preload(self): fp = BytesIO(b'foo') r = HTTPResponse(fp, preload_content=True) - self.assertEqual(fp.tell(), len(b'foo')) - self.assertEqual(r.data, b'foo') + assert fp.tell() == len(b'foo') + assert r.data == b'foo' def test_no_preload(self): fp = BytesIO(b'foo') r = HTTPResponse(fp, preload_content=False) - self.assertEqual(fp.tell(), 0) - self.assertEqual(r.data, b'foo') - self.assertEqual(fp.tell(), len(b'foo')) + assert fp.tell() == 0 + assert r.data == b'foo' + assert fp.tell() == len(b'foo') def test_decode_bad_data(self): fp = BytesIO(b'\x00' * 10) - self.assertRaises(DecodeError, HTTPResponse, fp, headers={ - 'content-encoding': 'deflate' - }) + with pytest.raises(DecodeError): + HTTPResponse(fp, headers={'content-encoding': 'deflate'}) def test_reference_read(self): fp = BytesIO(b'foo') r = HTTPResponse(fp, preload_content=False) - self.assertEqual(r.read(1), b'f') - self.assertEqual(r.read(2), b'oo') - self.assertEqual(r.read(), b'') - self.assertEqual(r.read(), b'') + assert r.read(1) == b'f' + assert r.read(2) == b'oo' + assert r.read() == b'' + assert r.read() == b'' def test_decode_deflate(self): import zlib @@ -99,7 +101,7 @@ class TestResponse(unittest.TestCase): fp = BytesIO(data) r = HTTPResponse(fp, headers={'content-encoding': 'deflate'}) - self.assertEqual(r.data, b'foo') + assert r.data == b'foo' def test_decode_deflate_case_insensitve(self): import zlib @@ -108,7 +110,7 @@ class TestResponse(unittest.TestCase): fp = BytesIO(data) r = HTTPResponse(fp, headers={'content-encoding': 'DeFlAtE'}) - self.assertEqual(r.data, b'foo') + assert r.data == b'foo' def test_chunked_decoding_deflate(self): import zlib @@ -118,15 +120,15 @@ class TestResponse(unittest.TestCase): r = HTTPResponse(fp, headers={'content-encoding': 'deflate'}, preload_content=False) - self.assertEqual(r.read(3), b'') + assert r.read(3) == b'' # Buffer in case we need to switch to the raw stream - self.assertIsNotNone(r._decoder._data) - self.assertEqual(r.read(1), b'f') + assert r._decoder._data is not None + assert r.read(1) == b'f' # Now that we've decoded data, we just stream through the decoder - self.assertIsNone(r._decoder._data) - self.assertEqual(r.read(2), b'oo') - self.assertEqual(r.read(), b'') - self.assertEqual(r.read(), b'') + assert r._decoder._data is None + assert r.read(2) == b'oo' + assert r.read() == b'' + assert r.read() == b'' def test_chunked_decoding_deflate2(self): import zlib @@ -138,13 +140,13 @@ class TestResponse(unittest.TestCase): r = HTTPResponse(fp, headers={'content-encoding': 'deflate'}, preload_content=False) - self.assertEqual(r.read(1), b'') - self.assertEqual(r.read(1), b'f') + assert r.read(1) == b'' + assert r.read(1) == b'f' # Once we've decoded data, we just stream to the decoder; no buffering - self.assertIsNone(r._decoder._data) - self.assertEqual(r.read(2), b'oo') - self.assertEqual(r.read(), b'') - self.assertEqual(r.read(), b'') + assert r._decoder._data is None + assert r.read(2) == b'oo' + assert r.read() == b'' + assert r.read() == b'' def test_chunked_decoding_gzip(self): import zlib @@ -156,71 +158,79 @@ class TestResponse(unittest.TestCase): r = HTTPResponse(fp, headers={'content-encoding': 'gzip'}, preload_content=False) - self.assertEqual(r.read(11), b'') - self.assertEqual(r.read(1), b'f') - self.assertEqual(r.read(2), b'oo') - self.assertEqual(r.read(), b'') - self.assertEqual(r.read(), b'') + assert r.read(11) == b'' + assert r.read(1) == b'f' + assert r.read(2) == b'oo' + assert r.read() == b'' + assert r.read() == b'' def test_body_blob(self): resp = HTTPResponse(b'foo') - self.assertEqual(resp.data, b'foo') - self.assertTrue(resp.closed) + assert resp.data == b'foo' + assert resp.closed - def test_io(self): + def test_io(self, sock): fp = BytesIO(b'foo') resp = HTTPResponse(fp, preload_content=False) - self.assertEqual(resp.closed, False) - self.assertEqual(resp.readable(), True) - self.assertEqual(resp.writable(), False) - self.assertRaises(IOError, resp.fileno) + assert not resp.closed + assert resp.readable() + assert not resp.writable() + with pytest.raises(IOError): + resp.fileno() resp.close() - self.assertEqual(resp.closed, True) + assert resp.closed # Try closing with an `httplib.HTTPResponse`, because it has an # `isclosed` method. - hlr = httplib.HTTPResponse(socket.socket()) - resp2 = HTTPResponse(hlr, preload_content=False) - self.assertEqual(resp2.closed, False) - resp2.close() - self.assertEqual(resp2.closed, True) + try: + hlr = httplib.HTTPResponse(sock) + resp2 = HTTPResponse(hlr, preload_content=False) + assert not resp2.closed + resp2.close() + assert resp2.closed + finally: + hlr.close() # also try when only data is present. resp3 = HTTPResponse('foodata') - self.assertRaises(IOError, resp3.fileno) + with pytest.raises(IOError): + resp3.fileno() resp3._fp = 2 # A corner case where _fp is present but doesn't have `closed`, # `isclosed`, or `fileno`. Unlikely, but possible. - self.assertEqual(resp3.closed, True) - self.assertRaises(IOError, resp3.fileno) - - def test_io_closed_consistently(self): - hlr = httplib.HTTPResponse(socket.socket()) - hlr.fp = BytesIO(b'foo') - hlr.chunked = 0 - hlr.length = 3 - resp = HTTPResponse(hlr, preload_content=False) - - self.assertEqual(resp.closed, False) - self.assertEqual(resp._fp.isclosed(), False) - self.assertEqual(is_fp_closed(resp._fp), False) - resp.read() - self.assertEqual(resp.closed, True) - self.assertEqual(resp._fp.isclosed(), True) - self.assertEqual(is_fp_closed(resp._fp), True) + assert resp3.closed + with pytest.raises(IOError): + resp3.fileno() + + def test_io_closed_consistently(self, sock): + try: + hlr = httplib.HTTPResponse(sock) + hlr.fp = BytesIO(b'foo') + hlr.chunked = 0 + hlr.length = 3 + with HTTPResponse(hlr, preload_content=False) as resp: + assert not resp.closed + assert not resp._fp.isclosed() + assert not is_fp_closed(resp._fp) + resp.read() + assert resp.closed + assert resp._fp.isclosed() + assert is_fp_closed(resp._fp) + finally: + hlr.close() def test_io_bufferedreader(self): fp = BytesIO(b'foo') resp = HTTPResponse(fp, preload_content=False) br = BufferedReader(resp) - self.assertEqual(br.read(), b'foo') + assert br.read() == b'foo' br.close() - self.assertEqual(resp.closed, True) + assert resp.closed b = b'fooandahalf' fp = BytesIO(b) @@ -228,7 +238,7 @@ class TestResponse(unittest.TestCase): br = BufferedReader(resp, 5) br.read(1) # sets up the buffer, reading 5 - self.assertEqual(len(fp.read()), len(b) - 5) + assert len(fp.read()) == (len(b) - 5) # This is necessary to make sure the "no bytes left" part of `readinto` # gets tested. @@ -257,9 +267,10 @@ class TestResponse(unittest.TestCase): resp = HTTPResponse(fp, preload_content=False) stream = resp.stream(2, decode_content=False) - self.assertEqual(next(stream), b'fo') - self.assertEqual(next(stream), b'o') - self.assertRaises(StopIteration, next, stream) + assert next(stream) == b'fo' + assert next(stream) == b'o' + with pytest.raises(StopIteration): + next(stream) def test_streaming_tell(self): fp = BytesIO(b'foo') @@ -269,14 +280,15 @@ class TestResponse(unittest.TestCase): position = 0 position += len(next(stream)) - self.assertEqual(2, position) - self.assertEqual(position, resp.tell()) + assert 2 == position + assert position == resp.tell() position += len(next(stream)) - self.assertEqual(3, position) - self.assertEqual(position, resp.tell()) + assert 3 == position + assert position == resp.tell() - self.assertRaises(StopIteration, next, stream) + with pytest.raises(StopIteration): + next(stream) def test_gzipped_streaming(self): import zlib @@ -289,9 +301,10 @@ class TestResponse(unittest.TestCase): preload_content=False) stream = resp.stream(2) - self.assertEqual(next(stream), b'f') - self.assertEqual(next(stream), b'oo') - self.assertRaises(StopIteration, next, stream) + assert next(stream) == b'f' + assert next(stream) == b'oo' + with pytest.raises(StopIteration): + next(stream) def test_gzipped_streaming_tell(self): import zlib @@ -307,11 +320,12 @@ class TestResponse(unittest.TestCase): # Read everything payload = next(stream) - self.assertEqual(payload, uncompressed_data) + assert payload == uncompressed_data - self.assertEqual(len(data), resp.tell()) + assert len(data) == resp.tell() - self.assertRaises(StopIteration, next, stream) + with pytest.raises(StopIteration): + next(stream) def test_deflate_streaming_tell_intermediate_point(self): # Ensure that ``tell()`` returns the correct number of bytes when @@ -350,20 +364,21 @@ class TestResponse(unittest.TestCase): parts_positions = [(part, resp.tell()) for part in stream] end_of_stream = resp.tell() - self.assertRaises(StopIteration, next, stream) + with pytest.raises(StopIteration): + next(stream) parts, positions = zip(*parts_positions) # Check that the payload is equal to the uncompressed data payload = b"".join(parts) - self.assertEqual(uncompressed_data, payload) + assert uncompressed_data == payload # Check that the positions in the stream are correct expected = [(i+1)*payload_part_size for i in range(NUMBER_OF_READS)] - self.assertEqual(expected, list(positions)) + assert expected == list(positions) # Check that the end of the stream is in the correct place - self.assertEqual(len(ZLIB_PAYLOAD), end_of_stream) + assert len(ZLIB_PAYLOAD) == end_of_stream def test_deflate_streaming(self): import zlib @@ -374,9 +389,10 @@ class TestResponse(unittest.TestCase): preload_content=False) stream = resp.stream(2) - self.assertEqual(next(stream), b'f') - self.assertEqual(next(stream), b'oo') - self.assertRaises(StopIteration, next, stream) + assert next(stream) == b'f' + assert next(stream) == b'oo' + with pytest.raises(StopIteration): + next(stream) def test_deflate2_streaming(self): import zlib @@ -389,39 +405,41 @@ class TestResponse(unittest.TestCase): preload_content=False) stream = resp.stream(2) - self.assertEqual(next(stream), b'f') - self.assertEqual(next(stream), b'oo') - self.assertRaises(StopIteration, next, stream) + assert next(stream) == b'f' + assert next(stream) == b'oo' + with pytest.raises(StopIteration): + next(stream) def test_empty_stream(self): fp = BytesIO(b'') resp = HTTPResponse(fp, preload_content=False) stream = resp.stream(2, decode_content=False) - self.assertRaises(StopIteration, next, stream) + with pytest.raises(StopIteration): + next(stream) def test_length_no_header(self): fp = BytesIO(b'12345') resp = HTTPResponse(fp, preload_content=False) - self.assertEqual(resp.length_remaining, None) + assert resp.length_remaining is None def test_length_w_valid_header(self): headers = {"content-length": "5"} fp = BytesIO(b'12345') resp = HTTPResponse(fp, headers=headers, preload_content=False) - self.assertEqual(resp.length_remaining, 5) + assert resp.length_remaining == 5 def test_length_w_bad_header(self): garbage = {'content-length': 'foo'} fp = BytesIO(b'12345') resp = HTTPResponse(fp, headers=garbage, preload_content=False) - self.assertEqual(resp.length_remaining, None) + assert resp.length_remaining is None garbage['content-length'] = "-10" resp = HTTPResponse(fp, headers=garbage, preload_content=False) - self.assertEqual(resp.length_remaining, None) + assert resp.length_remaining is None def test_length_when_chunked(self): # This is expressly forbidden in RFC 7230 sec 3.3.2 @@ -432,7 +450,7 @@ class TestResponse(unittest.TestCase): fp = BytesIO(b'12345') resp = HTTPResponse(fp, headers=headers, preload_content=False) - self.assertEqual(resp.length_remaining, None) + assert resp.length_remaining is None def test_length_with_multiple_content_lengths(self): headers = {'content-length': '5, 5, 5'} @@ -440,10 +458,10 @@ class TestResponse(unittest.TestCase): fp = BytesIO(b'abcde') resp = HTTPResponse(fp, headers=headers, preload_content=False) - self.assertEqual(resp.length_remaining, 5) + assert resp.length_remaining == 5 - self.assertRaises(InvalidHeader, HTTPResponse, fp, - headers=garbage, preload_content=False) + with pytest.raises(InvalidHeader): + HTTPResponse(fp, headers=garbage, preload_content=False) def test_length_after_read(self): headers = {"content-length": "5"} @@ -452,20 +470,20 @@ class TestResponse(unittest.TestCase): fp = BytesIO(b'12345') resp = HTTPResponse(fp, preload_content=False) resp.read() - self.assertEqual(resp.length_remaining, None) + assert resp.length_remaining is None # Test our update from content-length fp = BytesIO(b'12345') resp = HTTPResponse(fp, headers=headers, preload_content=False) resp.read() - self.assertEqual(resp.length_remaining, 0) + assert resp.length_remaining == 0 # Test partial read fp = BytesIO(b'12345') resp = HTTPResponse(fp, headers=headers, preload_content=False) data = resp.stream(2) next(data) - self.assertEqual(resp.length_remaining, 3) + assert resp.length_remaining == 3 def test_mock_httpresponse_stream(self): # Mock out a HTTP Request that does enough to make it through urllib3's @@ -490,9 +508,10 @@ class TestResponse(unittest.TestCase): resp = HTTPResponse(fp, preload_content=False) stream = resp.stream(2) - self.assertEqual(next(stream), b'fo') - self.assertEqual(next(stream), b'o') - self.assertRaises(StopIteration, next, stream) + assert next(stream) == b'fo' + assert next(stream) == b'o' + with pytest.raises(StopIteration): + next(stream) def test_mock_transfer_encoding_chunked(self): stream = [b"fo", b"o", b"bar"] @@ -501,10 +520,8 @@ class TestResponse(unittest.TestCase): r.fp = fp resp = HTTPResponse(r, preload_content=False, headers={'transfer-encoding': 'chunked'}) - i = 0 - for c in resp.stream(): - self.assertEqual(c, stream[i]) - i += 1 + for i, c in enumerate(resp.stream()): + assert c == stream[i] def test_mock_gzipped_transfer_encoding_chunked_decoded(self): """Show that we can decode the gizpped and chunked body.""" @@ -527,7 +544,7 @@ class TestResponse(unittest.TestCase): for c in resp.stream(decode_content=True): data += c - self.assertEqual(b'foobar', data) + assert b'foobar' == data def test_mock_transfer_encoding_chunked_custom_read(self): stream = [b"foooo", b"bbbbaaaaar"] @@ -539,12 +556,7 @@ class TestResponse(unittest.TestCase): resp = HTTPResponse(r, preload_content=False, headers={'transfer-encoding': 'chunked'}) expected_response = [b'fo', b'oo', b'o', b'bb', b'bb', b'aa', b'aa', b'ar'] response = list(resp.read_chunked(2)) - if getattr(self, "assertListEqual", False): - self.assertListEqual(expected_response, response) - else: - for index, item in enumerate(response): - v = expected_response[index] - self.assertEqual(item, v) + assert expected_response == response def test_mock_transfer_encoding_chunked_unlmtd_read(self): stream = [b"foooo", b"bbbbaaaaar"] @@ -554,18 +566,14 @@ class TestResponse(unittest.TestCase): r.chunked = True r.chunk_left = None resp = HTTPResponse(r, preload_content=False, headers={'transfer-encoding': 'chunked'}) - if getattr(self, "assertListEqual", False): - self.assertListEqual(stream, list(resp.read_chunked())) - else: - for index, item in enumerate(resp.read_chunked()): - v = stream[index] - self.assertEqual(item, v) + assert stream == list(resp.read_chunked()) def test_read_not_chunked_response_as_chunks(self): fp = BytesIO(b'foo') resp = HTTPResponse(fp, preload_content=False) r = resp.read_chunked() - self.assertRaises(ResponseNotChunked, next, r) + with pytest.raises(ResponseNotChunked): + next(r) def test_invalid_chunks(self): stream = [b"foooo", b"bbbbaaaaar"] @@ -575,7 +583,8 @@ class TestResponse(unittest.TestCase): r.chunked = True r.chunk_left = None resp = HTTPResponse(r, preload_content=False, headers={'transfer-encoding': 'chunked'}) - self.assertRaises(ProtocolError, next, resp.read_chunked()) + with pytest.raises(ProtocolError): + next(resp.read_chunked()) def test_chunked_response_without_crlf_on_end(self): stream = [b"foo", b"bar", b"baz"] @@ -585,12 +594,7 @@ class TestResponse(unittest.TestCase): r.chunked = True r.chunk_left = None resp = HTTPResponse(r, preload_content=False, headers={'transfer-encoding': 'chunked'}) - if getattr(self, "assertListEqual", False): - self.assertListEqual(stream, list(resp.stream())) - else: - for index, item in enumerate(resp.stream()): - v = stream[index] - self.assertEqual(item, v) + assert stream == list(resp.stream()) def test_chunked_response_with_extensions(self): stream = [b"foo", b"bar"] @@ -600,26 +604,21 @@ class TestResponse(unittest.TestCase): r.chunked = True r.chunk_left = None resp = HTTPResponse(r, preload_content=False, headers={'transfer-encoding': 'chunked'}) - if getattr(self, "assertListEqual", False): - self.assertListEqual(stream, list(resp.stream())) - else: - for index, item in enumerate(resp.stream()): - v = stream[index] - self.assertEqual(item, v) + assert stream == list(resp.stream()) def test_get_case_insensitive_headers(self): headers = {'host': 'example.com'} r = HTTPResponse(headers=headers) - self.assertEqual(r.headers.get('host'), 'example.com') - self.assertEqual(r.headers.get('Host'), 'example.com') + assert r.headers.get('host') == 'example.com' + assert r.headers.get('Host') == 'example.com' def test_retries(self): fp = BytesIO(b'') resp = HTTPResponse(fp) - self.assertEqual(resp.retries, None) + assert resp.retries is None retry = Retry() resp = HTTPResponse(fp, retries=retry) - self.assertEqual(resp.retries, retry) + assert resp.retries == retry class MockChunkedEncodingResponse(object): @@ -722,7 +721,3 @@ class MockSock(object): @classmethod def makefile(cls, *args, **kwargs): return - - -if __name__ == '__main__': - unittest.main() diff --git a/test/test_retry.py b/test/test_retry.py index dbe4dc0d..9181f7ca 100644 --- a/test/test_retry.py +++ b/test/test_retry.py @@ -1,4 +1,4 @@ -import unittest +import pytest from urllib3.response import HTTPResponse from urllib3.packages.six.moves import xrange @@ -11,17 +11,15 @@ from urllib3.exceptions import ( ) -class RetryTest(unittest.TestCase): +class TestRetry(object): def test_string(self): """ Retry string representation looks the way we expect """ retry = Retry() - self.assertEqual(str(retry), - 'Retry(total=10, connect=None, read=None, redirect=None, status=None)') + assert str(retry) == 'Retry(total=10, connect=None, read=None, redirect=None, status=None)' for _ in range(3): retry = retry.increment(method='GET') - self.assertEqual(str(retry), - 'Retry(total=7, connect=None, read=None, redirect=None, status=None)') + assert str(retry) == 'Retry(total=7, connect=None, read=None, redirect=None, status=None)' def test_retry_both_specified(self): """Total can win if it's lower than the connect value""" @@ -29,11 +27,9 @@ class RetryTest(unittest.TestCase): retry = Retry(connect=3, total=2) retry = retry.increment(error=error) retry = retry.increment(error=error) - try: + with pytest.raises(MaxRetryError) as e: retry.increment(error=error) - self.fail("Failed to raise error.") - except MaxRetryError as e: - self.assertEqual(e.reason, error) + assert e.value.reason == error def test_retry_higher_total_loses(self): """ A lower connect timeout than the total is honored """ @@ -41,7 +37,8 @@ class RetryTest(unittest.TestCase): retry = Retry(connect=2, total=3) retry = retry.increment(error=error) retry = retry.increment(error=error) - self.assertRaises(MaxRetryError, retry.increment, error=error) + with pytest.raises(MaxRetryError): + retry.increment(error=error) def test_retry_higher_total_loses_vs_read(self): """ A lower read timeout than the total is honored """ @@ -49,7 +46,8 @@ class RetryTest(unittest.TestCase): retry = Retry(read=2, total=3) retry = retry.increment(method='GET', error=error) retry = retry.increment(method='GET', error=error) - self.assertRaises(MaxRetryError, retry.increment, method='GET', error=error) + with pytest.raises(MaxRetryError): + retry.increment(method='GET', error=error) def test_retry_total_none(self): """ if Total is none, connect error should take precedence """ @@ -57,105 +55,99 @@ class RetryTest(unittest.TestCase): retry = Retry(connect=2, total=None) retry = retry.increment(error=error) retry = retry.increment(error=error) - try: + with pytest.raises(MaxRetryError) as e: retry.increment(error=error) - self.fail("Failed to raise error.") - except MaxRetryError as e: - self.assertEqual(e.reason, error) + assert e.value.reason == error error = ReadTimeoutError(None, "/", "read timed out") retry = Retry(connect=2, total=None) retry = retry.increment(method='GET', error=error) retry = retry.increment(method='GET', error=error) retry = retry.increment(method='GET', error=error) - self.assertFalse(retry.is_exhausted()) + assert not retry.is_exhausted() def test_retry_default(self): """ If no value is specified, should retry connects 3 times """ retry = Retry() - self.assertEqual(retry.total, 10) - self.assertEqual(retry.connect, None) - self.assertEqual(retry.read, None) - self.assertEqual(retry.redirect, None) + assert retry.total == 10 + assert retry.connect is None + assert retry.read is None + assert retry.redirect is None error = ConnectTimeoutError() retry = Retry(connect=1) retry = retry.increment(error=error) - self.assertRaises(MaxRetryError, retry.increment, error=error) + with pytest.raises(MaxRetryError): + retry.increment(error=error) retry = Retry(connect=1) retry = retry.increment(error=error) - self.assertFalse(retry.is_exhausted()) + assert not retry.is_exhausted() - self.assertTrue(Retry(0).raise_on_redirect) - self.assertFalse(Retry(False).raise_on_redirect) + assert Retry(0).raise_on_redirect + assert not Retry(False).raise_on_redirect def test_retry_read_zero(self): """ No second chances on read timeouts, by default """ error = ReadTimeoutError(None, "/", "read timed out") retry = Retry(read=0) - try: + with pytest.raises(MaxRetryError) as e: retry.increment(method='GET', error=error) - self.fail("Failed to raise error.") - except MaxRetryError as e: - self.assertEqual(e.reason, error) + assert e.value.reason == error def test_status_counter(self): resp = HTTPResponse(status=400) retry = Retry(status=2) retry = retry.increment(response=resp) retry = retry.increment(response=resp) - try: + with pytest.raises(MaxRetryError) as e: retry.increment(response=resp) - self.fail("Failed to raise error.") - except MaxRetryError as e: - self.assertEqual(str(e.reason), - ResponseError.SPECIFIC_ERROR.format(status_code=400)) + assert str(e.value.reason) == ResponseError.SPECIFIC_ERROR.format(status_code=400) def test_backoff(self): """ Backoff is computed correctly """ max_backoff = Retry.BACKOFF_MAX retry = Retry(total=100, backoff_factor=0.2) - self.assertEqual(retry.get_backoff_time(), 0) # First request + assert retry.get_backoff_time() == 0 # First request retry = retry.increment(method='GET') - self.assertEqual(retry.get_backoff_time(), 0) # First retry + assert retry.get_backoff_time() == 0 # First retry retry = retry.increment(method='GET') - self.assertEqual(retry.backoff_factor, 0.2) - self.assertEqual(retry.total, 98) - self.assertEqual(retry.get_backoff_time(), 0.4) # Start backoff + assert retry.backoff_factor == 0.2 + assert retry.total == 98 + assert retry.get_backoff_time() == 0.4 # Start backoff retry = retry.increment(method='GET') - self.assertEqual(retry.get_backoff_time(), 0.8) + assert retry.get_backoff_time() == 0.8 retry = retry.increment(method='GET') - self.assertEqual(retry.get_backoff_time(), 1.6) + assert retry.get_backoff_time() == 1.6 - for i in xrange(10): + for _ in xrange(10): retry = retry.increment(method='GET') - self.assertEqual(retry.get_backoff_time(), max_backoff) + assert retry.get_backoff_time() == max_backoff def test_zero_backoff(self): retry = Retry() - self.assertEqual(retry.get_backoff_time(), 0) + assert retry.get_backoff_time() == 0 retry = retry.increment(method='GET') retry = retry.increment(method='GET') - self.assertEqual(retry.get_backoff_time(), 0) + assert retry.get_backoff_time() == 0 def test_backoff_reset_after_redirect(self): retry = Retry(total=100, redirect=5, backoff_factor=0.2) retry = retry.increment(method='GET') retry = retry.increment(method='GET') - self.assertEqual(retry.get_backoff_time(), 0.4) + assert retry.get_backoff_time() == 0.4 redirect_response = HTTPResponse(status=302, headers={'location': 'test'}) retry = retry.increment(method='GET', response=redirect_response) - self.assertEqual(retry.get_backoff_time(), 0) + assert retry.get_backoff_time() == 0 retry = retry.increment(method='GET') retry = retry.increment(method='GET') - self.assertEqual(retry.get_backoff_time(), 0.4) + assert retry.get_backoff_time() == 0.4 def test_sleep(self): # sleep a very small amount of time so our code coverage is happy @@ -166,101 +158,94 @@ class RetryTest(unittest.TestCase): def test_status_forcelist(self): retry = Retry(status_forcelist=xrange(500, 600)) - self.assertFalse(retry.is_retry('GET', status_code=200)) - self.assertFalse(retry.is_retry('GET', status_code=400)) - self.assertTrue(retry.is_retry('GET', status_code=500)) + assert not retry.is_retry('GET', status_code=200) + assert not retry.is_retry('GET', status_code=400) + assert retry.is_retry('GET', status_code=500) retry = Retry(total=1, status_forcelist=[418]) - self.assertFalse(retry.is_retry('GET', status_code=400)) - self.assertTrue(retry.is_retry('GET', status_code=418)) + assert not retry.is_retry('GET', status_code=400) + assert retry.is_retry('GET', status_code=418) # String status codes are not matched. retry = Retry(total=1, status_forcelist=['418']) - self.assertFalse(retry.is_retry('GET', status_code=418)) + assert not retry.is_retry('GET', status_code=418) def test_method_whitelist_with_status_forcelist(self): # Falsey method_whitelist means to retry on any method. retry = Retry(status_forcelist=[500], method_whitelist=None) - self.assertTrue(retry.is_retry('GET', status_code=500)) - self.assertTrue(retry.is_retry('POST', status_code=500)) + assert retry.is_retry('GET', status_code=500) + assert retry.is_retry('POST', status_code=500) # Criteria of method_whitelist and status_forcelist are ANDed. retry = Retry(status_forcelist=[500], method_whitelist=['POST']) - self.assertFalse(retry.is_retry('GET', status_code=500)) - self.assertTrue(retry.is_retry('POST', status_code=500)) + assert not retry.is_retry('GET', status_code=500) + assert retry.is_retry('POST', status_code=500) def test_exhausted(self): - self.assertFalse(Retry(0).is_exhausted()) - self.assertTrue(Retry(-1).is_exhausted()) - self.assertEqual(Retry(1).increment(method='GET').total, 0) + assert not Retry(0).is_exhausted() + assert Retry(-1).is_exhausted() + assert Retry(1).increment(method='GET').total == 0 - def test_disabled(self): - self.assertRaises(MaxRetryError, Retry(-1).increment, method='GET') - self.assertRaises(MaxRetryError, Retry(0).increment, method='GET') + @pytest.mark.parametrize('total', [-1, 0]) + def test_disabled(self, total): + with pytest.raises(MaxRetryError): + Retry(total).increment(method='GET') def test_error_message(self): retry = Retry(total=0) - try: + with pytest.raises(MaxRetryError) as e: retry = retry.increment(method='GET', error=ReadTimeoutError(None, "/", "read timed out")) - raise AssertionError("Should have raised a MaxRetryError") - except MaxRetryError as e: - assert 'Caused by redirect' not in str(e) - self.assertEqual(str(e.reason), 'None: read timed out') + assert 'Caused by redirect' not in str(e.value) + assert str(e.value.reason) == 'None: read timed out' retry = Retry(total=1) - try: + with pytest.raises(MaxRetryError) as e: retry = retry.increment('POST', '/') retry = retry.increment('POST', '/') - raise AssertionError("Should have raised a MaxRetryError") - except MaxRetryError as e: - assert 'Caused by redirect' not in str(e) - self.assertTrue(isinstance(e.reason, ResponseError), - "%s should be a ResponseError" % e.reason) - self.assertEqual(str(e.reason), ResponseError.GENERIC_ERROR) + assert 'Caused by redirect' not in str(e.value) + assert isinstance(e.value.reason, ResponseError) + assert str(e.value.reason) == ResponseError.GENERIC_ERROR retry = Retry(total=1) - try: - response = HTTPResponse(status=500) + response = HTTPResponse(status=500) + with pytest.raises(MaxRetryError) as e: retry = retry.increment('POST', '/', response=response) retry = retry.increment('POST', '/', response=response) - raise AssertionError("Should have raised a MaxRetryError") - except MaxRetryError as e: - assert 'Caused by redirect' not in str(e) - msg = ResponseError.SPECIFIC_ERROR.format(status_code=500) - self.assertEqual(str(e.reason), msg) + assert 'Caused by redirect' not in str(e.value) + msg = ResponseError.SPECIFIC_ERROR.format(status_code=500) + assert str(e.value.reason) == msg retry = Retry(connect=1) - try: + with pytest.raises(MaxRetryError) as e: retry = retry.increment(error=ConnectTimeoutError('conntimeout')) retry = retry.increment(error=ConnectTimeoutError('conntimeout')) - raise AssertionError("Should have raised a MaxRetryError") - except MaxRetryError as e: - assert 'Caused by redirect' not in str(e) - self.assertEqual(str(e.reason), 'conntimeout') + assert 'Caused by redirect' not in str(e.value) + assert str(e.value.reason) == 'conntimeout' def test_history(self): retry = Retry(total=10, method_whitelist=frozenset(['GET', 'POST'])) - self.assertEqual(retry.history, tuple()) + assert retry.history == tuple() connection_error = ConnectTimeoutError('conntimeout') retry = retry.increment('GET', '/test1', None, connection_error) history = (RequestHistory('GET', '/test1', connection_error, None, None),) - self.assertEqual(retry.history, history) + assert retry.history == history read_error = ReadTimeoutError(None, "/test2", "read timed out") retry = retry.increment('POST', '/test2', None, read_error) history = (RequestHistory('GET', '/test1', connection_error, None, None), RequestHistory('POST', '/test2', read_error, None, None)) - self.assertEqual(retry.history, history) + assert retry.history == history response = HTTPResponse(status=500) retry = retry.increment('GET', '/test3', response, None) history = (RequestHistory('GET', '/test1', connection_error, None, None), RequestHistory('POST', '/test2', read_error, None, None), RequestHistory('GET', '/test3', None, 500, None)) - self.assertEqual(retry.history, history) + assert retry.history == history def test_retry_method_not_in_whitelist(self): error = ReadTimeoutError(None, "/", "read timed out") retry = Retry() - self.assertRaises(ReadTimeoutError, retry.increment, method='POST', error=error) + with pytest.raises(ReadTimeoutError): + retry.increment(method='POST', error=error) diff --git a/test/test_util.py b/test/test_util.py index 1885f164..8cbb2a54 100644 --- a/test/test_util.py +++ b/test/test_util.py @@ -2,12 +2,12 @@ import hashlib import warnings import logging import io -import unittest import ssl import socket from itertools import chain from mock import patch, Mock +import pytest from urllib3 import add_stderr_logger, disable_warnings from urllib3.util.request import make_headers, rewind_body, _FAILEDTELL @@ -48,106 +48,96 @@ from . import clear_warnings TIMEOUT_EPOCH = 1000 -class TestUtil(unittest.TestCase): - def test_get_host(self): - url_host_map = { - # Hosts - 'http://google.com/mail': ('http', 'google.com', None), - 'http://google.com/mail/': ('http', 'google.com', None), - 'google.com/mail': ('http', 'google.com', None), - 'http://google.com/': ('http', 'google.com', None), - 'http://google.com': ('http', 'google.com', None), - 'http://www.google.com': ('http', 'www.google.com', None), - 'http://mail.google.com': ('http', 'mail.google.com', None), - 'http://google.com:8000/mail/': ('http', 'google.com', 8000), - 'http://google.com:8000': ('http', 'google.com', 8000), - 'https://google.com': ('https', 'google.com', None), - 'https://google.com:8000': ('https', 'google.com', 8000), - 'http://user:password@127.0.0.1:1234': ('http', '127.0.0.1', 1234), - 'http://google.com/foo=http://bar:42/baz': ('http', 'google.com', None), - 'http://google.com?foo=http://bar:42/baz': ('http', 'google.com', None), - 'http://google.com#foo=http://bar:42/baz': ('http', 'google.com', None), - - # IPv4 - '173.194.35.7': ('http', '173.194.35.7', None), - 'http://173.194.35.7': ('http', '173.194.35.7', None), - 'http://173.194.35.7/test': ('http', '173.194.35.7', None), - 'http://173.194.35.7:80': ('http', '173.194.35.7', 80), - 'http://173.194.35.7:80/test': ('http', '173.194.35.7', 80), - - # IPv6 - '[2a00:1450:4001:c01::67]': ('http', '[2a00:1450:4001:c01::67]', None), - 'http://[2a00:1450:4001:c01::67]': ('http', '[2a00:1450:4001:c01::67]', None), - 'http://[2a00:1450:4001:c01::67]/test': ('http', '[2a00:1450:4001:c01::67]', None), - 'http://[2a00:1450:4001:c01::67]:80': ('http', '[2a00:1450:4001:c01::67]', 80), - 'http://[2a00:1450:4001:c01::67]:80/test': ('http', '[2a00:1450:4001:c01::67]', 80), - - # More IPv6 from http://www.ietf.org/rfc/rfc2732.txt - 'http://[fedc:ba98:7654:3210:fedc:ba98:7654:3210]:8000/index.html': ( - 'http', '[fedc:ba98:7654:3210:fedc:ba98:7654:3210]', 8000), - 'http://[1080:0:0:0:8:800:200c:417a]/index.html': ( - 'http', '[1080:0:0:0:8:800:200c:417a]', None), - 'http://[3ffe:2a00:100:7031::1]': ('http', '[3ffe:2a00:100:7031::1]', None), - 'http://[1080::8:800:200c:417a]/foo': ('http', '[1080::8:800:200c:417a]', None), - 'http://[::192.9.5.5]/ipng': ('http', '[::192.9.5.5]', None), - 'http://[::ffff:129.144.52.38]:42/index.html': ('http', '[::ffff:129.144.52.38]', 42), - 'http://[2010:836b:4179::836b:4179]': ('http', '[2010:836b:4179::836b:4179]', None), - } - for url, expected_host in url_host_map.items(): - returned_host = get_host(url) - self.assertEqual(returned_host, expected_host) - - def test_invalid_host(self): - # TODO: Add more tests - invalid_host = [ - 'http://google.com:foo', - 'http://::1/', - 'http://::1:80/', - 'http://google.com:-80', - six.u('http://google.com:\xb2\xb2'), # \xb2 = ^2 - ] - - for location in invalid_host: - self.assertRaises(LocationParseError, get_host, location) - - def test_host_normalization(self): - """ - Asserts the scheme and hosts with a normalizable scheme are - converted to lower-case. - """ - url_host_map = { - # Hosts - 'HTTP://GOOGLE.COM/mail/': ('http', 'google.com', None), - 'GOogle.COM/mail': ('http', 'google.com', None), - 'HTTP://GoOgLe.CoM:8000/mail/': ('http', 'google.com', 8000), - 'HTTP://user:password@EXAMPLE.COM:1234': ('http', 'example.com', 1234), - '173.194.35.7': ('http', '173.194.35.7', None), - 'HTTP://173.194.35.7': ('http', '173.194.35.7', None), - 'HTTP://[2a00:1450:4001:c01::67]:80/test': ('http', '[2a00:1450:4001:c01::67]', 80), - 'HTTP://[FEDC:BA98:7654:3210:FEDC:BA98:7654:3210]:8000/index.html': ( - 'http', '[fedc:ba98:7654:3210:fedc:ba98:7654:3210]', 8000), - 'HTTPS://[1080:0:0:0:8:800:200c:417A]/index.html': ( - 'https', '[1080:0:0:0:8:800:200c:417a]', None), - 'abOut://eXamPlE.com?info=1': ('about', 'eXamPlE.com', None), - 'http+UNIX://%2fvar%2frun%2fSOCKET/path': ( - 'http+unix', '%2fvar%2frun%2fSOCKET', None), - } - for url, expected_host in url_host_map.items(): - returned_host = get_host(url) - self.assertEqual(returned_host, expected_host) - - def test_parse_url_normalization(self): +class TestUtil(object): + + url_host_map = [ + # Hosts + ('http://google.com/mail', ('http', 'google.com', None)), + ('http://google.com/mail/', ('http', 'google.com', None)), + ('google.com/mail', ('http', 'google.com', None)), + ('http://google.com/', ('http', 'google.com', None)), + ('http://google.com', ('http', 'google.com', None)), + ('http://www.google.com', ('http', 'www.google.com', None)), + ('http://mail.google.com', ('http', 'mail.google.com', None)), + ('http://google.com:8000/mail/', ('http', 'google.com', 8000)), + ('http://google.com:8000', ('http', 'google.com', 8000)), + ('https://google.com', ('https', 'google.com', None)), + ('https://google.com:8000', ('https', 'google.com', 8000)), + ('http://user:password@127.0.0.1:1234', ('http', '127.0.0.1', 1234)), + ('http://google.com/foo=http://bar:42/baz', ('http', 'google.com', None)), + ('http://google.com?foo=http://bar:42/baz', ('http', 'google.com', None)), + ('http://google.com#foo=http://bar:42/baz', ('http', 'google.com', None)), + + # IPv4 + ('173.194.35.7', ('http', '173.194.35.7', None)), + ('http://173.194.35.7', ('http', '173.194.35.7', None)), + ('http://173.194.35.7/test', ('http', '173.194.35.7', None)), + ('http://173.194.35.7:80', ('http', '173.194.35.7', 80)), + ('http://173.194.35.7:80/test', ('http', '173.194.35.7', 80)), + + # IPv6 + ('[2a00:1450:4001:c01::67]', ('http', '[2a00:1450:4001:c01::67]', None)), + ('http://[2a00:1450:4001:c01::67]', ('http', '[2a00:1450:4001:c01::67]', None)), + ('http://[2a00:1450:4001:c01::67]/test', ('http', '[2a00:1450:4001:c01::67]', None)), + ('http://[2a00:1450:4001:c01::67]:80', ('http', '[2a00:1450:4001:c01::67]', 80)), + ('http://[2a00:1450:4001:c01::67]:80/test', ('http', '[2a00:1450:4001:c01::67]', 80)), + + # More IPv6 from http://www.ietf.org/rfc/rfc2732.txt + ('http://[fedc:ba98:7654:3210:fedc:ba98:7654:3210]:8000/index.html', ( + 'http', '[fedc:ba98:7654:3210:fedc:ba98:7654:3210]', 8000)), + ('http://[1080:0:0:0:8:800:200c:417a]/index.html', ( + 'http', '[1080:0:0:0:8:800:200c:417a]', None)), + ('http://[3ffe:2a00:100:7031::1]', ('http', '[3ffe:2a00:100:7031::1]', None)), + ('http://[1080::8:800:200c:417a]/foo', ('http', '[1080::8:800:200c:417a]', None)), + ('http://[::192.9.5.5]/ipng', ('http', '[::192.9.5.5]', None)), + ('http://[::ffff:129.144.52.38]:42/index.html', ('http', '[::ffff:129.144.52.38]', 42)), + ('http://[2010:836b:4179::836b:4179]', ('http', '[2010:836b:4179::836b:4179]', None)), + + # Hosts + ('HTTP://GOOGLE.COM/mail/', ('http', 'google.com', None)), + ('GOogle.COM/mail', ('http', 'google.com', None)), + ('HTTP://GoOgLe.CoM:8000/mail/', ('http', 'google.com', 8000)), + ('HTTP://user:password@EXAMPLE.COM:1234', ('http', 'example.com', 1234)), + ('173.194.35.7', ('http', '173.194.35.7', None)), + ('HTTP://173.194.35.7', ('http', '173.194.35.7', None)), + ('HTTP://[2a00:1450:4001:c01::67]:80/test', ('http', '[2a00:1450:4001:c01::67]', 80)), + ('HTTP://[FEDC:BA98:7654:3210:FEDC:BA98:7654:3210]:8000/index.html', ( + 'http', '[fedc:ba98:7654:3210:fedc:ba98:7654:3210]', 8000)), + ('HTTPS://[1080:0:0:0:8:800:200c:417A]/index.html', ( + 'https', '[1080:0:0:0:8:800:200c:417a]', None)), + ('abOut://eXamPlE.com?info=1', ('about', 'eXamPlE.com', None)), + ('http+UNIX://%2fvar%2frun%2fSOCKET/path', ( + 'http+unix', '%2fvar%2frun%2fSOCKET', None)), + ] + + @pytest.mark.parametrize('url, expected_host', url_host_map) + def test_get_host(self, url, expected_host): + returned_host = get_host(url) + assert returned_host == expected_host + + # TODO: Add more tests + @pytest.mark.parametrize('location', [ + 'http://google.com:foo', + 'http://::1/', + 'http://::1:80/', + 'http://google.com:-80', + six.u('http://google.com:\xb2\xb2'), # \xb2 = ^2 + ]) + def test_invalid_host(self, location): + with pytest.raises(LocationParseError): + get_host(location) + + @pytest.mark.parametrize('url, expected_normalized_url', [ + ('HTTP://GOOGLE.COM/MAIL/', 'http://google.com/MAIL/'), + ('HTTP://JeremyCline:Hunter2@Example.com:8080/', + 'http://JeremyCline:Hunter2@example.com:8080/'), + ('HTTPS://Example.Com/?Key=Value', 'https://example.com/?Key=Value'), + ('Https://Example.Com/#Fragment', 'https://example.com/#Fragment'), + ]) + def test_parse_url_normalization(self, url, expected_normalized_url): """Assert parse_url normalizes the scheme/host, and only the scheme/host""" - test_urls = [ - ('HTTP://GOOGLE.COM/MAIL/', 'http://google.com/MAIL/'), - ('HTTP://JeremyCline:Hunter2@Example.com:8080/', - 'http://JeremyCline:Hunter2@example.com:8080/'), - ('HTTPS://Example.Com/?Key=Value', 'https://example.com/?Key=Value'), - ('Https://Example.Com/#Fragment', 'https://example.com/#Fragment'), - ] - for url, expected_normalized_url in test_urls: - actual_normalized_url = parse_url(url).url - self.assertEqual(actual_normalized_url, expected_normalized_url) + actual_normalized_url = parse_url(url).url + assert actual_normalized_url == expected_normalized_url parse_url_host_map = [ ('http://google.com/mail', Url('http', host='google.com', path='/mail')), @@ -183,108 +173,97 @@ class TestUtil(unittest.TestCase): ('http://@', Url('http', host=None, auth='')) ] - non_round_tripping_parse_url_host_map = { + non_round_tripping_parse_url_host_map = [ # Path/query/fragment - '?': Url(path='', query=''), - '#': Url(path='', fragment=''), + ('?', Url(path='', query='')), + ('#', Url(path='', fragment='')), # Empty Port - 'http://google.com:': Url('http', host='google.com'), - 'http://google.com:/': Url('http', host='google.com', path='/'), - - } + ('http://google.com:', Url('http', host='google.com')), + ('http://google.com:/', Url('http', host='google.com', path='/')), + ] - def test_parse_url(self): - for url, expected_Url in chain(self.parse_url_host_map, - self.non_round_tripping_parse_url_host_map.items()): - returned_Url = parse_url(url) - self.assertEqual(returned_Url, expected_Url) + @pytest.mark.parametrize( + 'url, expected_url', + chain(parse_url_host_map, non_round_tripping_parse_url_host_map) + ) + def test_parse_url(self, url, expected_url): + returned_url = parse_url(url) + assert returned_url == expected_url - def test_unparse_url(self): - for url, expected_Url in self.parse_url_host_map: - self.assertEqual(url, expected_Url.url) + @pytest.mark.parametrize('url, expected_url', parse_url_host_map) + def test_unparse_url(self, url, expected_url): + assert url == expected_url.url def test_parse_url_invalid_IPv6(self): - self.assertRaises(ValueError, parse_url, '[::1') + with pytest.raises(ValueError): + parse_url('[::1') def test_Url_str(self): U = Url('http', host='google.com') - self.assertEqual(str(U), U.url) - - def test_request_uri(self): - url_host_map = { - 'http://google.com/mail': '/mail', - 'http://google.com/mail/': '/mail/', - 'http://google.com/': '/', - 'http://google.com': '/', - '': '/', - '/': '/', - '?': '/?', - '#': '/', - '/foo?bar=baz': '/foo?bar=baz', - } - for url, expected_request_uri in url_host_map.items(): - returned_url = parse_url(url) - self.assertEqual(returned_url.request_uri, expected_request_uri) - - def test_netloc(self): - url_netloc_map = { - 'http://google.com/mail': 'google.com', - 'http://google.com:80/mail': 'google.com:80', - 'google.com/foobar': 'google.com', - 'google.com:12345': 'google.com:12345', - } - - for url, expected_netloc in url_netloc_map.items(): - self.assertEqual(parse_url(url).netloc, expected_netloc) - - def test_make_headers(self): - self.assertEqual( - make_headers(accept_encoding=True), - {'accept-encoding': 'gzip,deflate'}) - - self.assertEqual( - make_headers(accept_encoding='foo,bar'), - {'accept-encoding': 'foo,bar'}) - - self.assertEqual( - make_headers(accept_encoding=['foo', 'bar']), - {'accept-encoding': 'foo,bar'}) - - self.assertEqual( - make_headers(accept_encoding=True, user_agent='banana'), - {'accept-encoding': 'gzip,deflate', 'user-agent': 'banana'}) - - self.assertEqual( - make_headers(user_agent='banana'), - {'user-agent': 'banana'}) - - self.assertEqual( - make_headers(keep_alive=True), - {'connection': 'keep-alive'}) - - self.assertEqual( - make_headers(basic_auth='foo:bar'), - {'authorization': 'Basic Zm9vOmJhcg=='}) - - self.assertEqual( - make_headers(proxy_basic_auth='foo:bar'), - {'proxy-authorization': 'Basic Zm9vOmJhcg=='}) - - self.assertEqual( - make_headers(disable_cache=True), - {'cache-control': 'no-cache'}) + assert str(U) == U.url + + request_uri_map = [ + ('http://google.com/mail', '/mail'), + ('http://google.com/mail/', '/mail/'), + ('http://google.com/', '/'), + ('http://google.com', '/'), + ('', '/'), + ('/', '/'), + ('?', '/?'), + ('#', '/'), + ('/foo?bar=baz', '/foo?bar=baz'), + ] + + @pytest.mark.parametrize('url, expected_request_uri', request_uri_map) + def test_request_uri(self, url, expected_request_uri): + returned_url = parse_url(url) + assert returned_url.request_uri == expected_request_uri + + url_netloc_map = [ + ('http://google.com/mail', 'google.com'), + ('http://google.com:80/mail', 'google.com:80'), + ('google.com/foobar', 'google.com'), + ('google.com:12345', 'google.com:12345'), + ] + + @pytest.mark.parametrize('url, expected_netloc', url_netloc_map) + def test_netloc(self, url, expected_netloc): + assert parse_url(url).netloc == expected_netloc + + @pytest.mark.parametrize('kwargs, expected', [ + ({'accept_encoding': True}, + {'accept-encoding': 'gzip,deflate'}), + ({'accept_encoding': 'foo,bar'}, + {'accept-encoding': 'foo,bar'}), + ({'accept_encoding': ['foo', 'bar']}, + {'accept-encoding': 'foo,bar'}), + ({'accept_encoding': True, 'user_agent': 'banana'}, + {'accept-encoding': 'gzip,deflate', 'user-agent': 'banana'}), + ({'user_agent': 'banana'}, + {'user-agent': 'banana'}), + ({'keep_alive': True}, + {'connection': 'keep-alive'}), + ({'basic_auth': 'foo:bar'}, + {'authorization': 'Basic Zm9vOmJhcg=='}), + ({'proxy_basic_auth': 'foo:bar'}, + {'proxy-authorization': 'Basic Zm9vOmJhcg=='}), + ({'disable_cache': True}, + {'cache-control': 'no-cache'}), + ]) + def test_make_headers(self, kwargs, expected): + assert make_headers(**kwargs) == expected def test_rewind_body(self): body = io.BytesIO(b'test data') - self.assertEqual(body.read(), b'test data') + assert body.read() == b'test data' # Assert the file object has been consumed - self.assertEqual(body.read(), b'') + assert body.read() == b'' # Rewind it back to just be b'data' rewind_body(body, 5) - self.assertEqual(body.read(), b'data') + assert body.read() == b'data' def test_rewind_body_failed_tell(self): body = io.BytesIO(b'test data') @@ -292,15 +271,18 @@ class TestUtil(unittest.TestCase): # Simulate failed tell() body_pos = _FAILEDTELL - self.assertRaises(UnrewindableBodyError, rewind_body, body, body_pos) + with pytest.raises(UnrewindableBodyError): + rewind_body(body, body_pos) def test_rewind_body_bad_position(self): body = io.BytesIO(b'test data') body.read() # Consume body # Pass non-integer position - self.assertRaises(ValueError, rewind_body, body, None) - self.assertRaises(ValueError, rewind_body, body, object()) + with pytest.raises(ValueError): + rewind_body(body, body_pos=None) + with pytest.raises(ValueError): + rewind_body(body, body_pos=object()) def test_rewind_body_failed_seek(self): class BadSeek(): @@ -308,24 +290,24 @@ class TestUtil(unittest.TestCase): def seek(self, pos, offset=0): raise IOError - self.assertRaises(UnrewindableBodyError, rewind_body, BadSeek(), 2) + with pytest.raises(UnrewindableBodyError): + rewind_body(BadSeek(), body_pos=2) - def test_split_first(self): - test_cases = { - ('abcd', 'b'): ('a', 'cd', 'b'), - ('abcd', 'cb'): ('a', 'cd', 'b'), - ('abcd', ''): ('abcd', '', None), - ('abcd', 'a'): ('', 'bcd', 'a'), - ('abcd', 'ab'): ('', 'bcd', 'a'), - } - for input, expected in test_cases.items(): - output = split_first(*input) - self.assertEqual(output, expected) + @pytest.mark.parametrize('input, expected', [ + (('abcd', 'b'), ('a', 'cd', 'b')), + (('abcd', 'cb'), ('a', 'cd', 'b')), + (('abcd', ''), ('abcd', '', None)), + (('abcd', 'a'), ('', 'bcd', 'a')), + (('abcd', 'ab'), ('', 'bcd', 'a')), + ]) + def test_split_first(self, input, expected): + output = split_first(*input) + assert output == expected def test_add_stderr_logger(self): handler = add_stderr_logger(level=logging.INFO) # Don't actually print debug logger = logging.getLogger('urllib3') - self.assertTrue(handler in logger.handlers) + assert handler in logger.handlers logger.debug('Testing add_stderr_logger') logger.removeHandler(handler) @@ -334,10 +316,10 @@ class TestUtil(unittest.TestCase): with warnings.catch_warnings(record=True) as w: clear_warnings() warnings.warn('This is a test.', InsecureRequestWarning) - self.assertEqual(len(w), 1) + assert len(w) == 1 disable_warnings() warnings.warn('This is a test.', InsecureRequestWarning) - self.assertEqual(len(w), 1) + assert len(w) == 1 def _make_time_pass(self, seconds, timeout, time_mock): """ Make some time pass for the timeout object """ @@ -346,47 +328,19 @@ class TestUtil(unittest.TestCase): time_mock.return_value = TIMEOUT_EPOCH + seconds return timeout - def test_invalid_timeouts(self): - try: - Timeout(total=-1) - self.fail("negative value should throw exception") - except ValueError as e: - self.assertTrue('less than' in str(e)) - try: - Timeout(connect=2, total=-1) - self.fail("negative value should throw exception") - except ValueError as e: - self.assertTrue('less than' in str(e)) - - try: - Timeout(read=-1) - self.fail("negative value should throw exception") - except ValueError as e: - self.assertTrue('less than' in str(e)) - - try: - Timeout(connect=False) - self.fail("boolean values should throw exception") - except ValueError as e: - self.assertTrue('cannot be a boolean' in str(e)) - - try: - Timeout(read=True) - self.fail("boolean values should throw exception") - except ValueError as e: - self.assertTrue('cannot be a boolean' in str(e)) - - try: - Timeout(connect=0) - self.fail("value <= 0 should throw exception") - except ValueError as e: - self.assertTrue('less than or equal' in str(e)) - - try: - Timeout(read="foo") - self.fail("string value should not be allowed") - except ValueError as e: - self.assertTrue('int, float or None' in str(e)) + @pytest.mark.parametrize('kwargs, message', [ + ({'total': -1}, 'less than'), + ({'connect': 2, 'total': -1}, 'less than'), + ({'read': -1}, 'less than'), + ({'connect': False}, 'cannot be a boolean'), + ({'read': True}, 'cannot be a boolean'), + ({'connect': 0}, 'less than or equal'), + ({'read': 'foo'}, 'int, float or None') + ]) + def test_invalid_timeouts(self, kwargs, message): + with pytest.raises(ValueError) as e: + Timeout(**kwargs) + assert message in str(e.value) @patch('urllib3.util.timeout.current_time') def test_timeout(self, current_time): @@ -395,70 +349,78 @@ class TestUtil(unittest.TestCase): # make 'no time' elapse timeout = self._make_time_pass(seconds=0, timeout=timeout, time_mock=current_time) - self.assertEqual(timeout.read_timeout, 3) - self.assertEqual(timeout.connect_timeout, 3) + assert timeout.read_timeout == 3 + assert timeout.connect_timeout == 3 timeout = Timeout(total=3, connect=2) - self.assertEqual(timeout.connect_timeout, 2) + assert timeout.connect_timeout == 2 timeout = Timeout() - self.assertEqual(timeout.connect_timeout, Timeout.DEFAULT_TIMEOUT) + assert timeout.connect_timeout == Timeout.DEFAULT_TIMEOUT # Connect takes 5 seconds, leaving 5 seconds for read timeout = Timeout(total=10, read=7) timeout = self._make_time_pass(seconds=5, timeout=timeout, time_mock=current_time) - self.assertEqual(timeout.read_timeout, 5) + assert timeout.read_timeout == 5 # Connect takes 2 seconds, read timeout still 7 seconds timeout = Timeout(total=10, read=7) timeout = self._make_time_pass(seconds=2, timeout=timeout, time_mock=current_time) - self.assertEqual(timeout.read_timeout, 7) + assert timeout.read_timeout == 7 timeout = Timeout(total=10, read=7) - self.assertEqual(timeout.read_timeout, 7) + assert timeout.read_timeout == 7 timeout = Timeout(total=None, read=None, connect=None) - self.assertEqual(timeout.connect_timeout, None) - self.assertEqual(timeout.read_timeout, None) - self.assertEqual(timeout.total, None) + assert timeout.connect_timeout is None + assert timeout.read_timeout is None + assert timeout.total is None timeout = Timeout(5) - self.assertEqual(timeout.total, 5) + assert timeout.total == 5 def test_timeout_str(self): timeout = Timeout(connect=1, read=2, total=3) - self.assertEqual(str(timeout), "Timeout(connect=1, read=2, total=3)") + assert str(timeout) == "Timeout(connect=1, read=2, total=3)" timeout = Timeout(connect=1, read=None, total=3) - self.assertEqual(str(timeout), "Timeout(connect=1, read=None, total=3)") + assert str(timeout) == "Timeout(connect=1, read=None, total=3)" @patch('urllib3.util.timeout.current_time') def test_timeout_elapsed(self, current_time): current_time.return_value = TIMEOUT_EPOCH timeout = Timeout(total=3) - self.assertRaises(TimeoutStateError, timeout.get_connect_duration) + with pytest.raises(TimeoutStateError): + timeout.get_connect_duration() timeout.start_connect() - self.assertRaises(TimeoutStateError, timeout.start_connect) + with pytest.raises(TimeoutStateError): + timeout.start_connect() current_time.return_value = TIMEOUT_EPOCH + 2 - self.assertEqual(timeout.get_connect_duration(), 2) + assert timeout.get_connect_duration() == 2 current_time.return_value = TIMEOUT_EPOCH + 37 - self.assertEqual(timeout.get_connect_duration(), 37) - - def test_resolve_cert_reqs(self): - self.assertEqual(resolve_cert_reqs(None), ssl.CERT_NONE) - self.assertEqual(resolve_cert_reqs(ssl.CERT_NONE), ssl.CERT_NONE) - self.assertEqual(resolve_cert_reqs(ssl.CERT_REQUIRED), ssl.CERT_REQUIRED) - self.assertEqual(resolve_cert_reqs('REQUIRED'), ssl.CERT_REQUIRED) - self.assertEqual(resolve_cert_reqs('CERT_REQUIRED'), ssl.CERT_REQUIRED) - - def test_resolve_ssl_version(self): - self.assertEqual(resolve_ssl_version(ssl.PROTOCOL_TLSv1), ssl.PROTOCOL_TLSv1) - self.assertEqual(resolve_ssl_version("PROTOCOL_TLSv1"), ssl.PROTOCOL_TLSv1) - self.assertEqual(resolve_ssl_version("TLSv1"), ssl.PROTOCOL_TLSv1) - self.assertEqual(resolve_ssl_version(ssl.PROTOCOL_SSLv23), ssl.PROTOCOL_SSLv23) + assert timeout.get_connect_duration() == 37 + + @pytest.mark.parametrize('candidate, requirements', [ + (None, ssl.CERT_NONE), + (ssl.CERT_NONE, ssl.CERT_NONE), + (ssl.CERT_REQUIRED, ssl.CERT_REQUIRED), + ('REQUIRED', ssl.CERT_REQUIRED), + ('CERT_REQUIRED', ssl.CERT_REQUIRED), + ]) + def test_resolve_cert_reqs(self, candidate, requirements): + assert resolve_cert_reqs(candidate) == requirements + + @pytest.mark.parametrize('candidate, version', [ + (ssl.PROTOCOL_TLSv1, ssl.PROTOCOL_TLSv1), + ("PROTOCOL_TLSv1", ssl.PROTOCOL_TLSv1), + ("TLSv1", ssl.PROTOCOL_TLSv1), + (ssl.PROTOCOL_SSLv23, ssl.PROTOCOL_SSLv23), + ]) + def test_resolve_ssl_version(self, candidate, version): + assert resolve_ssl_version(candidate) == version def test_is_fp_closed_object_supports_closed(self): class ClosedFile(object): @@ -466,7 +428,7 @@ class TestUtil(unittest.TestCase): def closed(self): return True - self.assertTrue(is_fp_closed(ClosedFile())) + assert is_fp_closed(ClosedFile()) def test_is_fp_closed_object_has_none_fp(self): class NoneFpFile(object): @@ -474,7 +436,7 @@ class TestUtil(unittest.TestCase): def fp(self): return None - self.assertTrue(is_fp_closed(NoneFpFile())) + assert is_fp_closed(NoneFpFile()) def test_is_fp_closed_object_has_fp(self): class FpFile(object): @@ -482,13 +444,14 @@ class TestUtil(unittest.TestCase): def fp(self): return True - self.assertTrue(not is_fp_closed(FpFile())) + assert not is_fp_closed(FpFile()) def test_is_fp_closed_object_has_neither_fp_nor_closed(self): class NotReallyAFile(object): pass - self.assertRaises(ValueError, is_fp_closed, NotReallyAFile()) + with pytest.raises(ValueError): + is_fp_closed(NotReallyAFile()) def test_ssl_wrap_socket_loads_the_cert_chain(self): socket = object() @@ -497,7 +460,8 @@ class TestUtil(unittest.TestCase): certfile='/path/to/certfile') mock_context.load_cert_chain.assert_called_once_with( - '/path/to/certfile', None) + '/path/to/certfile', None + ) @patch('urllib3.util.ssl_.create_urllib3_context') def test_ssl_wrap_socket_creates_new_context(self, @@ -515,7 +479,8 @@ class TestUtil(unittest.TestCase): ssl_wrap_socket(ssl_context=mock_context, ca_certs='/path/to/pem', sock=socket) mock_context.load_verify_locations.assert_called_once_with( - '/path/to/pem', None) + '/path/to/pem', None + ) def test_ssl_wrap_socket_loads_certificate_directories(self): socket = object() @@ -523,7 +488,8 @@ class TestUtil(unittest.TestCase): ssl_wrap_socket(ssl_context=mock_context, ca_cert_dir='/path/to/pems', sock=socket) mock_context.load_verify_locations.assert_called_once_with( - None, '/path/to/pems') + None, '/path/to/pems' + ) def test_ssl_wrap_socket_with_no_sni_warns(self): socket = object() @@ -543,55 +509,59 @@ class TestUtil(unittest.TestCase): def test_const_compare_digest_fallback(self): target = hashlib.sha256(b'abcdef').digest() - self.assertTrue(_const_compare_digest_backport(target, target)) + assert _const_compare_digest_backport(target, target) prefix = target[:-1] - self.assertFalse(_const_compare_digest_backport(target, prefix)) + assert not _const_compare_digest_backport(target, prefix) suffix = target + b'0' - self.assertFalse(_const_compare_digest_backport(target, suffix)) + assert not _const_compare_digest_backport(target, suffix) incorrect = hashlib.sha256(b'xyz').digest() - self.assertFalse(_const_compare_digest_backport(target, incorrect)) + assert not _const_compare_digest_backport(target, incorrect) def test_has_ipv6_disabled_on_compile(self): with patch('socket.has_ipv6', False): - self.assertFalse(_has_ipv6('::1')) + assert not _has_ipv6('::1') def test_has_ipv6_enabled_but_fails(self): with patch('socket.has_ipv6', True): with patch('socket.socket') as mock: instance = mock.return_value instance.bind = Mock(side_effect=Exception('No IPv6 here!')) - self.assertFalse(_has_ipv6('::1')) + assert not _has_ipv6('::1') def test_has_ipv6_enabled_and_working(self): with patch('socket.has_ipv6', True): with patch('socket.socket') as mock: instance = mock.return_value instance.bind.return_value = True - self.assertTrue(_has_ipv6('::1')) + assert _has_ipv6('::1') def test_ip_family_ipv6_enabled(self): with patch('urllib3.util.connection.HAS_IPV6', True): - self.assertEqual(allowed_gai_family(), socket.AF_UNSPEC) + assert allowed_gai_family() == socket.AF_UNSPEC def test_ip_family_ipv6_disabled(self): with patch('urllib3.util.connection.HAS_IPV6', False): - self.assertEqual(allowed_gai_family(), socket.AF_INET) - - def test_parse_retry_after(self): - invalid = [ - "-1", - "+1", - "1.0", - six.u("\xb2"), # \xb2 = ^2 - ] + assert allowed_gai_family() == socket.AF_INET + + @pytest.mark.parametrize('value', [ + "-1", + "+1", + "1.0", + six.u("\xb2"), # \xb2 = ^2 + ]) + def test_parse_retry_after_invalid(self, value): retry = Retry() - - for value in invalid: - self.assertRaises(InvalidHeader, retry.parse_retry_after, value) - - self.assertEqual(retry.parse_retry_after("0"), 0) - self.assertEqual(retry.parse_retry_after("1000"), 1000) - self.assertEqual(retry.parse_retry_after("\t42 "), 42) + with pytest.raises(InvalidHeader): + retry.parse_retry_after(value) + + @pytest.mark.parametrize('value, expected', [ + ("0", 0), + ("1000", 1000), + ("\t42 ", 42), + ]) + def test_parse_retry_after(self, value, expected): + retry = Retry() + assert retry.parse_retry_after(value) == expected diff --git a/test/with_dummyserver/test_https.py b/test/with_dummyserver/test_https.py index 837751be..dbebbea4 100644 --- a/test/with_dummyserver/test_https.py +++ b/test/with_dummyserver/test_https.py @@ -472,6 +472,7 @@ class TestHTTPS(HTTPSDummyServerTestCase): fingerprint = '92:81:FE:85:F7:0C:26:60:EC:D6:B3:BF:93:CF:F9:71:CC:07:7D:0A' conn = VerifiedHTTPSConnection(self.host, self.port) + self.addCleanup(conn.close) https_pool = HTTPSConnectionPool(self.host, self.port, cert_reqs='CERT_REQUIRED', ca_certs=DEFAULT_CA, diff --git a/test/with_dummyserver/test_poolmanager.py b/test/with_dummyserver/test_poolmanager.py index 4ae3d62d..ebfd3c56 100644 --- a/test/with_dummyserver/test_poolmanager.py +++ b/test/with_dummyserver/test_poolmanager.py @@ -210,6 +210,7 @@ class TestPoolManager(HTTPDummyServerTestCase): def test_http_with_ca_cert_dir(self): http = PoolManager(ca_certs='REQUIRED', ca_cert_dir='/nosuchdir') + self.addCleanup(http.clear) r = http.request('GET', 'http://%s:%s/' % (self.host, self.port)) self.assertEqual(r.status, 200) |