diff options
-rw-r--r-- | .gitignore | 1 | ||||
-rw-r--r-- | requests_cache/models/raw_response.py | 23 | ||||
-rwxr-xr-x | requests_cache/models/response.py | 8 | ||||
-rw-r--r-- | requests_cache/serializers/base.py | 42 | ||||
-rw-r--r-- | requests_cache/serializers/pickle.py | 16 | ||||
-rw-r--r-- | tests/integration/base_cache_test.py | 3 | ||||
-rw-r--r-- | tests/integration/base_storage_test.py | 2 | ||||
-rw-r--r-- | tests/unit/models/test_raw_response.py | 58 | ||||
-rw-r--r-- | tests/unit/models/test_request.py | 16 | ||||
-rw-r--r-- | tests/unit/models/test_response.py (renamed from tests/unit/test_response.py) | 75 | ||||
-rw-r--r-- | tests/unit/test_session.py | 22 |
11 files changed, 170 insertions, 96 deletions
@@ -8,6 +8,7 @@ build/ dist/ http_cache/ venv/ +.venv # Editors .~* diff --git a/requests_cache/models/raw_response.py b/requests_cache/models/raw_response.py index fd1875f..63c6d8b 100644 --- a/requests_cache/models/raw_response.py +++ b/requests_cache/models/raw_response.py @@ -3,15 +3,14 @@ from logging import getLogger import attr from requests import Response -from requests.structures import CaseInsensitiveDict -from urllib3.response import HTTPResponse, is_fp_closed +from urllib3.response import HTTPHeaderDict, HTTPResponse, is_fp_closed logger = getLogger(__name__) -@attr.s(auto_attribs=False, auto_detect=True, init=False, kw_only=True) +@attr.s(auto_attribs=False, auto_detect=True, kw_only=True) class CachedHTTPResponse(HTTPResponse): - """A serializable dataclass that emulates :py:class:`~urllib3.response.HTTPResponse`. + """A serializable dataclass that extends/emulates :py:class:`~urllib3.response.HTTPResponse`. Supports streaming requests and generator usage. The only action this doesn't support is explicitly calling :py:meth:`.read` with @@ -19,9 +18,9 @@ class CachedHTTPResponse(HTTPResponse): """ decode_content: bool = attr.ib(default=None) - headers: CaseInsensitiveDict = attr.ib(factory=dict) + headers: HTTPHeaderDict = attr.ib(factory=dict) reason: str = attr.ib(default=None) - request_url: str = attr.ib(default=None) + request_url: str = attr.ib(default=None) # TODO: Not available in urllib <=1.21. Is this needed? status: int = attr.ib(default=0) strict: int = attr.ib(default=0) version: int = attr.ib(default=0) @@ -71,9 +70,15 @@ class CachedHTTPResponse(HTTPResponse): self._fp.close() return data - def reset(self): - """Reset raw response file pointer""" - self._fp = BytesIO(self._body) + def reset(self, body: bytes = None): + """Reset raw response file pointer, and optionally update content""" + if body is not None: + self._body = body + self._fp = BytesIO(self._body or b'') + + def set_content(self, body: bytes): + self._body = body + self.reset() def stream(self, amt=None, **kwargs): """Simplified generator over cached content that emulates diff --git a/requests_cache/models/response.py b/requests_cache/models/response.py index eddffeb..86b21b9 100755 --- a/requests_cache/models/response.py +++ b/requests_cache/models/response.py @@ -34,7 +34,8 @@ class CachedResponse(Response): saves a bit of memory and deserialization steps when those objects aren't accessed. """ - _content: bytes = attr.ib(default=b'', repr=False, converter=lambda x: x or b'') + # _content: bytes = attr.ib(default=b'', repr=False, converter=lambda x: x or b'') + _content: bytes = attr.ib(default=None) url: str = attr.ib(default=None) status_code: int = attr.ib(default=0) cookies: RequestsCookieJar = attr.ib(factory=dict) @@ -48,6 +49,11 @@ class CachedResponse(Response): request: CachedRequest = attr.ib(factory=CachedRequest) raw: CachedHTTPResponse = attr.ib(factory=CachedHTTPResponse, repr=False) + def __attrs_post_init__(self): + """Re-initialize raw response body after deserialization""" + if self.raw._body is None and self._content is not None: + self.raw.reset(self._content) + @classmethod def from_response(cls, original_response: Response, **kwargs): """Create a CachedResponse based on an original response object""" diff --git a/requests_cache/serializers/base.py b/requests_cache/serializers/base.py index bc2b732..144f075 100644 --- a/requests_cache/serializers/base.py +++ b/requests_cache/serializers/base.py @@ -1,10 +1,11 @@ from abc import abstractmethod from datetime import datetime, timedelta -from typing import Dict +from typing import Any import cattr from requests.cookies import RequestsCookieJar, cookiejar_from_dict from requests.structures import CaseInsensitiveDict +from urllib3.response import HTTPHeaderDict from ..models import CachedResponse @@ -20,25 +21,28 @@ class BaseSerializer: def __init__(self, *args, **kwargs): """Make a converter to structure and unstructure some of the nested objects within a response""" super().__init__(*args, **kwargs) - converter = cattr.Converter() + try: + # raise AttributeError + converter = cattr.GenConverter(omit_if_default=True) + # Python 3.6 compatibility + except AttributeError: + converter = cattr.Converter() # Convert datetimes to and from iso-formatted strings converter.register_unstructure_hook(datetime, lambda obj: obj.isoformat() if obj else None) - converter.register_structure_hook( - datetime, lambda obj, cls: datetime.fromisoformat(obj) if obj else None - ) + converter.register_structure_hook(datetime, to_datetime) # Convert timedeltas to and from float values in seconds converter.register_unstructure_hook(timedelta, lambda obj: obj.total_seconds() if obj else None) - converter.register_structure_hook( - timedelta, lambda obj, cls: timedelta(seconds=obj) if obj else None - ) + converter.register_structure_hook(timedelta, to_timedelta) # Convert dict-like objects to and from plain dicts converter.register_unstructure_hook(RequestsCookieJar, lambda obj: dict(obj.items())) converter.register_structure_hook(RequestsCookieJar, lambda obj, cls: cookiejar_from_dict(obj)) converter.register_unstructure_hook(CaseInsensitiveDict, dict) converter.register_structure_hook(CaseInsensitiveDict, lambda obj, cls: CaseInsensitiveDict(obj)) + converter.register_unstructure_hook(HTTPHeaderDict, dict) + converter.register_structure_hook(HTTPHeaderDict, lambda obj, cls: HTTPHeaderDict(obj)) # Not sure yet if this will be needed # converter.register_unstructure_hook(PreparedRequest, CachedRequest.from_request) @@ -49,10 +53,14 @@ class BaseSerializer: self.converter = converter - def unstructure(self, response: CachedResponse) -> Dict: - return self.converter.unstructure(response) + def unstructure(self, obj: Any) -> Any: + if not isinstance(obj, CachedResponse): + return obj + return self.converter.unstructure(obj) - def structure(self, obj: Dict) -> CachedResponse: + def structure(self, obj: Any) -> Any: + if not isinstance(obj, dict): + return obj return self.converter.structure(obj, CachedResponse) @abstractmethod @@ -62,3 +70,15 @@ class BaseSerializer: @abstractmethod def loads(self, obj) -> CachedResponse: pass + + +def to_datetime(obj, cls) -> datetime: + if isinstance(obj, str): + obj = datetime.fromisoformat(obj) + return obj + + +def to_timedelta(obj, cls) -> timedelta: + if isinstance(obj, (int, float)): + obj = timedelta(seconds=obj) + return obj diff --git a/requests_cache/serializers/pickle.py b/requests_cache/serializers/pickle.py index c1a5726..86a4e59 100644 --- a/requests_cache/serializers/pickle.py +++ b/requests_cache/serializers/pickle.py @@ -16,19 +16,15 @@ class PickleSerializer(BaseSerializer): return super().structure(pickle.loads(obj)) -class SafePickleSerializer(BaseSerializer, SafeSerializer): +class SafePickleSerializer(SafeSerializer, BaseSerializer): """Wrapper for itsdangerous + pickle that pre/post-processes with cattrs""" def __init__(self, *args, **kwargs): + # super().__init__(*args, **kwargs, serializer=pickle) super().__init__(*args, **kwargs, serializer=PickleSerializer()) - def dumps(self, response: CachedResponse) -> bytes: - x = super().unstructure(response) - # breakpoint() - return SafeSerializer.dumps(self, x) + # def dumps(self, response: CachedResponse) -> bytes: + # return SafeSerializer.dumps(self, super().unstructure(response)) - # TODO: Something weird is going on here - def loads(self, obj: bytes) -> CachedResponse: - return SafeSerializer.loads(self, obj) - # breakpoint() - return super().structure(SafeSerializer.loads(self, obj)) + # def loads(self, obj: bytes) -> CachedResponse: + # return super().structure(SafeSerializer.loads(self, obj)) diff --git a/tests/integration/base_cache_test.py b/tests/integration/base_cache_test.py index 67d512c..d7267a1 100644 --- a/tests/integration/base_cache_test.py +++ b/tests/integration/base_cache_test.py @@ -124,8 +124,9 @@ class BaseCacheTest: assert b'gzipped' in response.content if stream is True: assert b'gzipped' in response.raw.read(None, decode_content=True) + response.raw._fp = BytesIO(response.content) - cached_response = CachedResponse(response) + cached_response = CachedResponse.from_response(response) assert b'gzipped' in cached_response.content assert b'gzipped' in cached_response.raw.read(None, decode_content=True) diff --git a/tests/integration/base_storage_test.py b/tests/integration/base_storage_test.py index 7c83d6f..c0f9ce1 100644 --- a/tests/integration/base_storage_test.py +++ b/tests/integration/base_storage_test.py @@ -2,7 +2,7 @@ import pytest from typing import Dict, Type -from requests_cache.backends.base import BaseStorage +from requests_cache.backends import BaseStorage from tests.conftest import CACHE_NAME diff --git a/tests/unit/models/test_raw_response.py b/tests/unit/models/test_raw_response.py new file mode 100644 index 0000000..bdcd373 --- /dev/null +++ b/tests/unit/models/test_raw_response.py @@ -0,0 +1,58 @@ +from io import BytesIO + +from requests_cache.models import CachedHTTPResponse +from tests.conftest import MOCKED_URL + + +def test_from_response(mock_session): + response = mock_session.get(MOCKED_URL) + response.raw._fp = BytesIO(b'mock response') + raw = CachedHTTPResponse.from_response(response) + + assert dict(response.raw.headers) == dict(raw.headers) == {'Content-Type': 'text/plain'} + assert raw.read(None) == b'mock response' + assert response.raw.decode_content is raw.decode_content is False + assert response.raw.reason is raw.reason is None + assert response.raw._request_url is raw.request_url is None + assert response.raw.status == raw.status == 200 + assert response.raw.strict == raw.strict == 0 + assert response.raw.version == raw.version == 0 + + +def test_read(): + raw = CachedHTTPResponse(body=b'mock response') + assert raw.read(10) == b'mock respo' + assert raw.read(None) == b'nse' + assert raw.read(1) == b'' + assert raw._fp.closed is True + + +def test_close(): + raw = CachedHTTPResponse(body=b'mock response') + raw.close() + assert raw._fp.closed is True + + +def test_reset(): + raw = CachedHTTPResponse(body=b'mock response') + raw.read(None) + assert raw.read(1) == b'' + assert raw._fp.closed is True + + raw.reset() + assert raw.read(None) == b'mock response' + + +def test_set_content(): + raw = CachedHTTPResponse(body=None) + raw.set_content(b'mock response') + assert raw.read() == b'mock response' + + +def test_stream(): + raw = CachedHTTPResponse(body=b'mock response') + data = b'' + for chunk in raw.stream(1): + data += chunk + assert data == b'mock response' + assert raw._fp.closed diff --git a/tests/unit/models/test_request.py b/tests/unit/models/test_request.py new file mode 100644 index 0000000..8e917ff --- /dev/null +++ b/tests/unit/models/test_request.py @@ -0,0 +1,16 @@ +from requests.utils import default_headers + +from requests_cache.models.response import CachedRequest +from tests.conftest import MOCKED_URL + + +def test_from_request(mock_session): + response = mock_session.get(MOCKED_URL, data=b'mock request', headers={'foo': 'bar'}) + request = CachedRequest.from_request(response.request) + expected_headers = {**default_headers(), 'Content-Length': '12', 'foo': 'bar'} + + assert response.request.body == request.body == b'mock request' + assert response.request._cookies == request.cookies == {} + assert response.request.headers == request.headers == expected_headers + assert response.request.method == request.method == 'GET' + assert response.request.url == request.url == MOCKED_URL diff --git a/tests/unit/test_response.py b/tests/unit/models/test_response.py index 1eb5e64..67c1006 100644 --- a/tests/unit/test_response.py +++ b/tests/unit/models/test_response.py @@ -5,13 +5,12 @@ from time import sleep from urllib3.response import HTTPResponse -from requests_cache import CachedHTTPResponse, CachedResponse -from requests_cache.response import format_file_size +from requests_cache.models.response import CachedResponse, format_file_size from tests.conftest import MOCKED_URL def test_basic_attrs(mock_session): - response = CachedResponse(mock_session.get(MOCKED_URL)) + response = CachedResponse.from_response(mock_session.get(MOCKED_URL)) assert response.from_cache is True assert response.url == MOCKED_URL @@ -25,62 +24,28 @@ def test_basic_attrs(mock_session): assert response.is_expired is False +def test_history(mock_session): + original_response = mock_session.get(MOCKED_URL) + original_response.history = [mock_session.get(MOCKED_URL)] * 3 + response = CachedResponse.from_response(original_response) + assert len(response.history) == 3 + assert all([isinstance(r, CachedResponse) for r in response.history]) + + @pytest.mark.parametrize( - 'expire_after, is_expired', + 'expires, is_expired', [ (datetime.utcnow() + timedelta(days=1), False), (datetime.utcnow() - timedelta(days=1), True), ], ) -def test_is_expired(expire_after, is_expired, mock_session): - response = CachedResponse(mock_session.get(MOCKED_URL), expire_after) +def test_is_expired(expires, is_expired, mock_session): + response = CachedResponse.from_response(mock_session.get(MOCKED_URL), expires=expires) assert response.from_cache is True assert response.is_expired == is_expired -def test_history(mock_session): - original_response = mock_session.get(MOCKED_URL) - original_response.history = [mock_session.get(MOCKED_URL)] * 3 - response = CachedResponse(original_response) - assert len(response.history) == 3 - assert all([isinstance(r, CachedResponse) for r in response.history]) - - -def test_raw_response__read(mock_session): - response = CachedResponse(mock_session.get(MOCKED_URL)) - assert isinstance(response.raw, CachedHTTPResponse) - assert response.raw.read(10) == b'mock respo' - assert response.raw.read(None) == b'nse' - assert response.raw.read(1) == b'' - assert response.raw._fp.closed is True - - -def test_raw_response__close(mock_session): - response = CachedResponse(mock_session.get(MOCKED_URL)) - response.close() - assert response.raw._fp.closed is True - - -def test_raw_response__reset(mock_session): - response = CachedResponse(mock_session.get(MOCKED_URL)) - response.raw.read(None) - assert response.raw.read(1) == b'' - assert response.raw._fp.closed is True - - response.reset() - assert response.raw.read(None) == b'mock response' - - -def test_raw_response__stream(mock_session): - response = CachedResponse(mock_session.get(MOCKED_URL)) - data = b'' - for chunk in response.raw.stream(1): - data += chunk - assert data == b'mock response' - assert response.raw._fp.closed - - -def test_raw_response__iterator(mock_session): +def test_iterator(mock_session): # Set up mock response with streamed content url = f'{MOCKED_URL}/stream' mock_raw_response = HTTPResponse( @@ -112,7 +77,7 @@ def test_raw_response__iterator(mock_session): def test_revalidate__extend_expiration(mock_session): # Start with an expired response - response = CachedResponse( + response = CachedResponse.from_response( mock_session.get(MOCKED_URL), expires=datetime.utcnow() - timedelta(seconds=0.01), ) @@ -127,7 +92,7 @@ def test_revalidate__extend_expiration(mock_session): def test_revalidate__shorten_expiration(mock_session): # Start with a non-expired response - response = CachedResponse( + response = CachedResponse.from_response( mock_session.get(MOCKED_URL), expires=datetime.utcnow() + timedelta(seconds=1), ) @@ -139,7 +104,7 @@ def test_revalidate__shorten_expiration(mock_session): def test_size(mock_session): - response = CachedResponse(mock_session.get(MOCKED_URL)) + response = CachedResponse.from_response(mock_session.get(MOCKED_URL)) response._content = None assert response.size == 0 response._content = b'1' * 1024 @@ -150,7 +115,7 @@ def test_str(mock_session): """Just ensure that a subset of relevant attrs get included in the response str; the format may change without breaking the test. """ - response = CachedResponse(mock_session.get(MOCKED_URL)) + response = CachedResponse.from_response(mock_session.get(MOCKED_URL)) response._content = b'1010' expected_values = ['GET', MOCKED_URL, 200, '4 bytes', 'created', 'expires', 'fresh'] assert all([str(v) in str(response) for v in expected_values]) @@ -158,10 +123,10 @@ def test_str(mock_session): def test_repr(mock_session): """Just ensure that a subset of relevant attrs get included in the response repr""" - response = CachedResponse(mock_session.get(MOCKED_URL)) + response = CachedResponse.from_response(mock_session.get(MOCKED_URL)) expected_values = ['GET', MOCKED_URL, 200, 'ISO-8859-1', response.headers] print(repr(response)) - assert repr(response).startswith('<CachedResponse(') and repr(response).endswith(')>') + assert repr(response).startswith('CachedResponse(') and repr(response).endswith(')') assert all([str(v) in repr(response) for v in expected_values]) diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 7ec4265..e958212 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -12,13 +12,18 @@ from uuid import uuid4 import requests from itsdangerous.exc import BadSignature -from itsdangerous.serializer import Serializer from requests.structures import CaseInsensitiveDict from requests_cache import ALL_METHODS, CachedSession -from requests_cache.backends import BACKEND_CLASSES, BaseCache, get_placeholder_backend -from requests_cache.backends.sqlite import DbDict, DbPickleDict -from requests_cache.response import CachedResponse +from requests_cache.backends import ( + BACKEND_CLASSES, + BaseCache, + DbDict, + DbPickleDict, + get_placeholder_backend, +) +from requests_cache.models import CachedResponse +from requests_cache.serializers import PickleSerializer, SafePickleSerializer from tests.conftest import ( MOCKED_URL, MOCKED_URL_404, @@ -559,16 +564,17 @@ def test_unpickle_errors(mock_session): def test_cache_signing(tempfile_path): session = CachedSession(tempfile_path) - assert session.cache.responses._serializer == pickle + assert isinstance(session.cache.responses.serializer, PickleSerializer) # With a secret key, itsdangerous should be used secret_key = str(uuid4()) session = CachedSession(tempfile_path, secret_key=secret_key) - assert isinstance(session.cache.responses._serializer, Serializer) + assert isinstance(session.cache.responses.serializer, SafePickleSerializer) # Simple serialize/deserialize round trip - session.cache.responses['key'] = 'value' - assert session.cache.responses['key'] == 'value' + response = CachedResponse() + session.cache.responses['key'] = response + assert session.cache.responses['key'] == response # Without the same signing key, the item shouldn't be considered safe to deserialize session = CachedSession(tempfile_path, secret_key='a different key') |