summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorshiftinv <me@shiftinv.cc>2021-04-18 18:35:01 +0200
committershiftinv <me@shiftinv.cc>2021-04-18 18:59:05 +0200
commite3deb4b2d7a388550d004d0109fac805dd91db41 (patch)
tree19dab7e17d6e2d69e948c9a577754a7526043c75
parent2a4681d6342f1635da6c3ea96838956c228ba006 (diff)
downloadrequests-cache-e3deb4b2d7a388550d004d0109fac805dd91db41.tar.gz
Improve raw response reset, update tests
-rwxr-xr-xrequests_cache/response.py46
-rw-r--r--tests/unit/test_response.py34
2 files changed, 35 insertions, 45 deletions
diff --git a/requests_cache/response.py b/requests_cache/response.py
index 299196c..a9d733b 100755
--- a/requests_cache/response.py
+++ b/requests_cache/response.py
@@ -1,7 +1,6 @@
"""Classes to wrap cached response objects"""
from copy import copy
from datetime import datetime, timedelta
-from functools import wraps
from io import BytesIO
from logging import getLogger
from typing import Any, Dict, Optional, Union
@@ -51,26 +50,20 @@ class CachedResponse(Response):
self.request = copy(original_response.request)
self.request.hooks = []
- # Read content to support streaming requests, reset file pointer on original request,
- # and patch `decode_content` parameter to avoid decoding twice
- self._content = original_response.content
- if hasattr(original_response.raw, '_fp'):
- data = self._content or b''
- original_response.raw._fp = BytesIO(data)
+ # Read content to support streaming requests, and reset file pointer on original request
+ if hasattr(original_response.raw, '_fp') and not original_response.raw.isclosed():
+ # Cache raw data in `_body`
+ original_response.raw.read(decode_content=False, cache_content=True)
+ # Reset `_fp`
+ original_response.raw._fp = BytesIO(original_response.raw._body)
+ # Read and store (decoded) data
+ self._content = original_response.content
+ # Reset `_fp` again
+ original_response.raw._fp = BytesIO(original_response.raw._body)
original_response.raw._fp_bytes_read = 0
- original_response.raw.length_remaining = len(data)
-
- # Only need to patch if response is encoded and `read` is bound (i.e. not patched yet)
- if 'content-encoding' in original_response.headers and hasattr(original_response.raw.read, '__self__'):
- orig = original_response.raw.read
-
- @wraps(orig)
- def patched_read(amt=None, decode_content=None, *args, **kwargs):
- _check_response_read_decode(decode_content, original_response.raw)
- # Force decode_content to be False, as _fp already contains decoded data
- return orig(amt, False, *args, **kwargs)
-
- original_response.raw.read = patched_read # noqa
+ original_response.raw.length_remaining = len(original_response.raw._body)
+ else:
+ self._content = original_response.content
# Copy raw response
self._raw_response = None
@@ -144,8 +137,11 @@ class CachedHTTPResponse(HTTPResponse):
"""Simplified reader for cached content that emulates
:py:meth:`urllib3.response.HTTPResponse.read()`
"""
- if 'content-encoding' in self.headers:
- _check_response_read_decode(decode_content, self)
+ if 'content-encoding' in self.headers and (
+ decode_content is False or (decode_content is None and not self.decode_content)
+ ):
+ # Warn if content was encoded and decode_content is set to False
+ logger.warning('read() returns decoded data for cached responses, even with decode_content=False set')
data = self._fp.read(amt)
# "close" the file to inform consumers to stop reading from it
@@ -174,9 +170,3 @@ def set_response_defaults(response: AnyResponse) -> AnyResponse:
response.from_cache = False
response.is_expired = False
return response
-
-
-def _check_response_read_decode(decode_content: Optional[bool], response: HTTPResponse):
- if decode_content is False or (decode_content is None and not response.decode_content):
- # Warn if decode_content is set to False
- logger.warning('read() returns decoded data, even with decode_content=False set')
diff --git a/tests/unit/test_response.py b/tests/unit/test_response.py
index 4f67280..f92391e 100644
--- a/tests/unit/test_response.py
+++ b/tests/unit/test_response.py
@@ -1,4 +1,3 @@
-import gzip
import pytest
from datetime import datetime, timedelta
from io import BytesIO
@@ -7,7 +6,7 @@ from time import sleep
from urllib3.response import HTTPResponse
from requests_cache import CachedHTTPResponse, CachedResponse
-from tests.conftest import MOCKED_URL
+from tests.conftest import MOCKED_URL, httpbin
def test_basic_attrs(mock_session):
@@ -72,23 +71,24 @@ def test_raw_response__reset(mock_session):
assert response.raw.read(None) == b'mock response'
-def test_raw_response__decode(mock_session):
- """Test that a gzip-compressed raw response does not get decoded twice with decode_content"""
- url = f'{MOCKED_URL}/utf-8'
- mock_session.mock_adapter.register_uri(
- 'GET',
- url,
- status_code=200,
- body=BytesIO(gzip.compress(b'compressed response')),
- headers={'content-encoding': 'gzip'},
- )
- response = mock_session.get(url)
+def test_raw_response__decode(tempfile_session):
+ """Test that a gzip-compressed raw response can be manually uncompressed with decode_content"""
+ response = tempfile_session.get(httpbin('gzip'))
+ assert b'gzipped' in response.content
+
cached = CachedResponse(response)
+ assert b'gzipped' in cached.content
+ assert b'gzipped' in cached.raw.read(None, decode_content=True)
+
+
+def test_raw_response__decode_stream(tempfile_session):
+ """Test that streamed gzip-compressed responses can be uncompressed with decode_content"""
+ response_uncached = tempfile_session.get(httpbin('gzip'), stream=True)
+ response_cached = tempfile_session.get(httpbin('gzip'), stream=True)
- # Test the original response after creating a CachedResponse based on it,
- # as well as the CachedResponse itself
- for res in (response, cached):
- assert res.raw.read(None, decode_content=True) == b'compressed response'
+ for res in (response_uncached, response_cached):
+ assert b'gzipped' in res.content
+ assert b'gzipped' in res.raw.read(None, decode_content=True)
def test_raw_response__stream(mock_session):