diff options
author | shiftinv <me@shiftinv.cc> | 2021-04-18 18:35:01 +0200 |
---|---|---|
committer | shiftinv <me@shiftinv.cc> | 2021-04-18 18:59:05 +0200 |
commit | e3deb4b2d7a388550d004d0109fac805dd91db41 (patch) | |
tree | 19dab7e17d6e2d69e948c9a577754a7526043c75 | |
parent | 2a4681d6342f1635da6c3ea96838956c228ba006 (diff) | |
download | requests-cache-e3deb4b2d7a388550d004d0109fac805dd91db41.tar.gz |
Improve raw response reset, update tests
-rwxr-xr-x | requests_cache/response.py | 46 |
-rw-r--r-- | tests/unit/test_response.py | 34 |
2 files changed, 35 insertions, 45 deletions
diff --git a/requests_cache/response.py b/requests_cache/response.py index 299196c..a9d733b 100755 --- a/requests_cache/response.py +++ b/requests_cache/response.py @@ -1,7 +1,6 @@ """Classes to wrap cached response objects""" from copy import copy from datetime import datetime, timedelta -from functools import wraps from io import BytesIO from logging import getLogger from typing import Any, Dict, Optional, Union @@ -51,26 +50,20 @@ class CachedResponse(Response): self.request = copy(original_response.request) self.request.hooks = [] - # Read content to support streaming requests, reset file pointer on original request, - # and patch `decode_content` parameter to avoid decoding twice - self._content = original_response.content - if hasattr(original_response.raw, '_fp'): - data = self._content or b'' - original_response.raw._fp = BytesIO(data) + # Read content to support streaming requests, and reset file pointer on original request + if hasattr(original_response.raw, '_fp') and not original_response.raw.isclosed(): + # Cache raw data in `_body` + original_response.raw.read(decode_content=False, cache_content=True) + # Reset `_fp` + original_response.raw._fp = BytesIO(original_response.raw._body) + # Read and store (decoded) data + self._content = original_response.content + # Reset `_fp` again + original_response.raw._fp = BytesIO(original_response.raw._body) original_response.raw._fp_bytes_read = 0 - original_response.raw.length_remaining = len(data) - - # Only need to patch if response is encoded and `read` is bound (i.e. 
not patched yet) - if 'content-encoding' in original_response.headers and hasattr(original_response.raw.read, '__self__'): - orig = original_response.raw.read - - @wraps(orig) - def patched_read(amt=None, decode_content=None, *args, **kwargs): - _check_response_read_decode(decode_content, original_response.raw) - # Force decode_content to be False, as _fp already contains decoded data - return orig(amt, False, *args, **kwargs) - - original_response.raw.read = patched_read # noqa + original_response.raw.length_remaining = len(original_response.raw._body) + else: + self._content = original_response.content # Copy raw response self._raw_response = None @@ -144,8 +137,11 @@ class CachedHTTPResponse(HTTPResponse): """Simplified reader for cached content that emulates :py:meth:`urllib3.response.HTTPResponse.read()` """ - if 'content-encoding' in self.headers: - _check_response_read_decode(decode_content, self) + if 'content-encoding' in self.headers and ( + decode_content is False or (decode_content is None and not self.decode_content) + ): + # Warn if content was encoded and decode_content is set to False + logger.warning('read() returns decoded data for cached responses, even with decode_content=False set') data = self._fp.read(amt) # "close" the file to inform consumers to stop reading from it @@ -174,9 +170,3 @@ def set_response_defaults(response: AnyResponse) -> AnyResponse: response.from_cache = False response.is_expired = False return response - - -def _check_response_read_decode(decode_content: Optional[bool], response: HTTPResponse): - if decode_content is False or (decode_content is None and not response.decode_content): - # Warn if decode_content is set to False - logger.warning('read() returns decoded data, even with decode_content=False set') diff --git a/tests/unit/test_response.py b/tests/unit/test_response.py index 4f67280..f92391e 100644 --- a/tests/unit/test_response.py +++ b/tests/unit/test_response.py @@ -1,4 +1,3 @@ -import gzip import pytest from 
datetime import datetime, timedelta from io import BytesIO @@ -7,7 +6,7 @@ from time import sleep from urllib3.response import HTTPResponse from requests_cache import CachedHTTPResponse, CachedResponse -from tests.conftest import MOCKED_URL +from tests.conftest import MOCKED_URL, httpbin def test_basic_attrs(mock_session): @@ -72,23 +71,24 @@ def test_raw_response__reset(mock_session): assert response.raw.read(None) == b'mock response' -def test_raw_response__decode(mock_session): - """Test that a gzip-compressed raw response does not get decoded twice with decode_content""" - url = f'{MOCKED_URL}/utf-8' - mock_session.mock_adapter.register_uri( - 'GET', - url, - status_code=200, - body=BytesIO(gzip.compress(b'compressed response')), - headers={'content-encoding': 'gzip'}, - ) - response = mock_session.get(url) +def test_raw_response__decode(tempfile_session): + """Test that a gzip-compressed raw response can be manually uncompressed with decode_content""" + response = tempfile_session.get(httpbin('gzip')) + assert b'gzipped' in response.content + cached = CachedResponse(response) + assert b'gzipped' in cached.content + assert b'gzipped' in cached.raw.read(None, decode_content=True) + + +def test_raw_response__decode_stream(tempfile_session): + """Test that streamed gzip-compressed responses can be uncompressed with decode_content""" + response_uncached = tempfile_session.get(httpbin('gzip'), stream=True) + response_cached = tempfile_session.get(httpbin('gzip'), stream=True) - # Test the original response after creating a CachedResponse based on it, - # as well as the CachedResponse itself - for res in (response, cached): - assert res.raw.read(None, decode_content=True) == b'compressed response' + for res in (response_uncached, response_cached): + assert b'gzipped' in res.content + assert b'gzipped' in res.raw.read(None, decode_content=True) def test_raw_response__stream(mock_session): |