-rw-r--r--  .gitignore                                 1
-rw-r--r--  requests_cache/models/raw_response.py    23
-rwxr-xr-x  requests_cache/models/response.py          8
-rw-r--r--  requests_cache/serializers/base.py        42
-rw-r--r--  requests_cache/serializers/pickle.py      16
-rw-r--r--  tests/integration/base_cache_test.py       3
-rw-r--r--  tests/integration/base_storage_test.py     2
-rw-r--r--  tests/unit/models/test_raw_response.py    58
-rw-r--r--  tests/unit/models/test_request.py         16
-rw-r--r--  tests/unit/models/test_response.py (renamed from tests/unit/test_response.py)  75
-rw-r--r--  tests/unit/test_session.py                22
11 files changed, 170 insertions, 96 deletions
diff --git a/.gitignore b/.gitignore
index 04cbefb..cc234f2 100644
--- a/.gitignore
+++ b/.gitignore
@@ -8,6 +8,7 @@ build/
dist/
http_cache/
venv/
+.venv
# Editors
.~*
diff --git a/requests_cache/models/raw_response.py b/requests_cache/models/raw_response.py
index fd1875f..63c6d8b 100644
--- a/requests_cache/models/raw_response.py
+++ b/requests_cache/models/raw_response.py
@@ -3,15 +3,14 @@ from logging import getLogger
import attr
from requests import Response
-from requests.structures import CaseInsensitiveDict
-from urllib3.response import HTTPResponse, is_fp_closed
+from urllib3.response import HTTPHeaderDict, HTTPResponse, is_fp_closed
logger = getLogger(__name__)
-@attr.s(auto_attribs=False, auto_detect=True, init=False, kw_only=True)
+@attr.s(auto_attribs=False, auto_detect=True, kw_only=True)
class CachedHTTPResponse(HTTPResponse):
- """A serializable dataclass that emulates :py:class:`~urllib3.response.HTTPResponse`.
+ """A serializable dataclass that extends/emulates :py:class:`~urllib3.response.HTTPResponse`.
Supports streaming requests and generator usage.
The only action this doesn't support is explicitly calling :py:meth:`.read` with
@@ -19,9 +18,9 @@ class CachedHTTPResponse(HTTPResponse):
"""
    decode_content: bool = attr.ib(default=None)
-    headers: CaseInsensitiveDict = attr.ib(factory=dict)
+    headers: HTTPHeaderDict = attr.ib(factory=dict)
    reason: str = attr.ib(default=None)
-    request_url: str = attr.ib(default=None)
+    request_url: str = attr.ib(default=None)  # TODO: Not available in urllib3 <=1.21. Is this needed?
    status: int = attr.ib(default=0)
    strict: int = attr.ib(default=0)
    version: int = attr.ib(default=0)
@@ -71,9 +70,15 @@ class CachedHTTPResponse(HTTPResponse):
            self._fp.close()
        return data
-    def reset(self):
-        """Reset raw response file pointer"""
-        self._fp = BytesIO(self._body)
+    def reset(self, body: bytes = None):
+        """Reset raw response file pointer, and optionally update content"""
+        if body is not None:
+            self._body = body
+        self._fp = BytesIO(self._body or b'')
+
+    def set_content(self, body: bytes):
+        self._body = body
+        self.reset()
    def stream(self, amt=None, **kwargs):
        """Simplified generator over cached content that emulates
diff --git a/requests_cache/models/response.py b/requests_cache/models/response.py
index eddffeb..86b21b9 100755
--- a/requests_cache/models/response.py
+++ b/requests_cache/models/response.py
@@ -34,7 +34,8 @@ class CachedResponse(Response):
    saves a bit of memory and deserialization steps when those objects aren't accessed.
    """
-    _content: bytes = attr.ib(default=b'', repr=False, converter=lambda x: x or b'')
+    # _content: bytes = attr.ib(default=b'', repr=False, converter=lambda x: x or b'')
+    _content: bytes = attr.ib(default=None)
    url: str = attr.ib(default=None)
    status_code: int = attr.ib(default=0)
    cookies: RequestsCookieJar = attr.ib(factory=dict)
@@ -48,6 +49,11 @@ class CachedResponse(Response):
    request: CachedRequest = attr.ib(factory=CachedRequest)
    raw: CachedHTTPResponse = attr.ib(factory=CachedHTTPResponse, repr=False)
+    def __attrs_post_init__(self):
+        """Re-initialize raw response body after deserialization"""
+        if self.raw._body is None and self._content is not None:
+            self.raw.reset(self._content)
+
    @classmethod
    def from_response(cls, original_response: Response, **kwargs):
        """Create a CachedResponse based on an original response object"""
diff --git a/requests_cache/serializers/base.py b/requests_cache/serializers/base.py
index bc2b732..144f075 100644
--- a/requests_cache/serializers/base.py
+++ b/requests_cache/serializers/base.py
@@ -1,10 +1,11 @@
from abc import abstractmethod
from datetime import datetime, timedelta
-from typing import Dict
+from typing import Any
import cattr
from requests.cookies import RequestsCookieJar, cookiejar_from_dict
from requests.structures import CaseInsensitiveDict
+from urllib3.response import HTTPHeaderDict
from ..models import CachedResponse
@@ -20,25 +21,28 @@ class BaseSerializer:
    def __init__(self, *args, **kwargs):
        """Make a converter to structure and unstructure some of the nested objects within a response"""
        super().__init__(*args, **kwargs)
-        converter = cattr.Converter()
+        try:
+            # raise AttributeError
+            converter = cattr.GenConverter(omit_if_default=True)
+        # Python 3.6 compatibility
+        except AttributeError:
+            converter = cattr.Converter()
        # Convert datetimes to and from iso-formatted strings
        converter.register_unstructure_hook(datetime, lambda obj: obj.isoformat() if obj else None)
-        converter.register_structure_hook(
-            datetime, lambda obj, cls: datetime.fromisoformat(obj) if obj else None
-        )
+        converter.register_structure_hook(datetime, to_datetime)
        # Convert timedeltas to and from float values in seconds
        converter.register_unstructure_hook(timedelta, lambda obj: obj.total_seconds() if obj else None)
-        converter.register_structure_hook(
-            timedelta, lambda obj, cls: timedelta(seconds=obj) if obj else None
-        )
+        converter.register_structure_hook(timedelta, to_timedelta)
        # Convert dict-like objects to and from plain dicts
        converter.register_unstructure_hook(RequestsCookieJar, lambda obj: dict(obj.items()))
        converter.register_structure_hook(RequestsCookieJar, lambda obj, cls: cookiejar_from_dict(obj))
        converter.register_unstructure_hook(CaseInsensitiveDict, dict)
        converter.register_structure_hook(CaseInsensitiveDict, lambda obj, cls: CaseInsensitiveDict(obj))
+        converter.register_unstructure_hook(HTTPHeaderDict, dict)
+        converter.register_structure_hook(HTTPHeaderDict, lambda obj, cls: HTTPHeaderDict(obj))
        # Not sure yet if this will be needed
        # converter.register_unstructure_hook(PreparedRequest, CachedRequest.from_request)
@@ -49,10 +53,14 @@ class BaseSerializer:
        self.converter = converter
-    def unstructure(self, response: CachedResponse) -> Dict:
-        return self.converter.unstructure(response)
+    def unstructure(self, obj: Any) -> Any:
+        if not isinstance(obj, CachedResponse):
+            return obj
+        return self.converter.unstructure(obj)
-    def structure(self, obj: Dict) -> CachedResponse:
+    def structure(self, obj: Any) -> Any:
+        if not isinstance(obj, dict):
+            return obj
        return self.converter.structure(obj, CachedResponse)
    @abstractmethod
@@ -62,3 +70,15 @@ class BaseSerializer:
    @abstractmethod
    def loads(self, obj) -> CachedResponse:
        pass
+
+
+def to_datetime(obj, cls) -> datetime:
+    if isinstance(obj, str):
+        obj = datetime.fromisoformat(obj)
+    return obj
+
+
+def to_timedelta(obj, cls) -> timedelta:
+    if isinstance(obj, (int, float)):
+        obj = timedelta(seconds=obj)
+    return obj
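Illustrative round trip (a sketch, not part of the diff) showing how a converter wired up like the one in __init__ uses the to_datetime/to_timedelta helpers above; the Record class is a made-up example, not a requests-cache type.

    import attr
    import cattr
    from datetime import datetime, timedelta
    from requests_cache.serializers.base import to_datetime, to_timedelta

    @attr.s(auto_attribs=True)
    class Record:  # hypothetical attrs class, for demonstration only
        created: datetime = None
        elapsed: timedelta = None

    converter = cattr.Converter()
    converter.register_unstructure_hook(datetime, lambda obj: obj.isoformat() if obj else None)
    converter.register_structure_hook(datetime, to_datetime)
    converter.register_unstructure_hook(timedelta, lambda obj: obj.total_seconds() if obj else None)
    converter.register_structure_hook(timedelta, to_timedelta)

    record = Record(created=datetime(2021, 4, 1), elapsed=timedelta(seconds=2.5))
    flat = converter.unstructure(record)  # {'created': '2021-04-01T00:00:00', 'elapsed': 2.5}
    assert converter.structure(flat, Record) == record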
diff --git a/requests_cache/serializers/pickle.py b/requests_cache/serializers/pickle.py
index c1a5726..86a4e59 100644
--- a/requests_cache/serializers/pickle.py
+++ b/requests_cache/serializers/pickle.py
@@ -16,19 +16,15 @@ class PickleSerializer(BaseSerializer):
        return super().structure(pickle.loads(obj))
-class SafePickleSerializer(BaseSerializer, SafeSerializer):
+class SafePickleSerializer(SafeSerializer, BaseSerializer):
    """Wrapper for itsdangerous + pickle that pre/post-processes with cattrs"""
    def __init__(self, *args, **kwargs):
+        # super().__init__(*args, **kwargs, serializer=pickle)
        super().__init__(*args, **kwargs, serializer=PickleSerializer())
-    def dumps(self, response: CachedResponse) -> bytes:
-        x = super().unstructure(response)
-        # breakpoint()
-        return SafeSerializer.dumps(self, x)
+    # def dumps(self, response: CachedResponse) -> bytes:
+    #     return SafeSerializer.dumps(self, super().unstructure(response))
-    # TODO: Something weird is going on here
-    def loads(self, obj: bytes) -> CachedResponse:
-        return SafeSerializer.loads(self, obj)
-        # breakpoint()
-        return super().structure(SafeSerializer.loads(self, obj))
+    # def loads(self, obj: bytes) -> CachedResponse:
+    #     return super().structure(SafeSerializer.loads(self, obj))
diff --git a/tests/integration/base_cache_test.py b/tests/integration/base_cache_test.py
index 67d512c..d7267a1 100644
--- a/tests/integration/base_cache_test.py
+++ b/tests/integration/base_cache_test.py
@@ -124,8 +124,9 @@ class BaseCacheTest:
        assert b'gzipped' in response.content
        if stream is True:
            assert b'gzipped' in response.raw.read(None, decode_content=True)
+            response.raw._fp = BytesIO(response.content)
-        cached_response = CachedResponse(response)
+        cached_response = CachedResponse.from_response(response)
        assert b'gzipped' in cached_response.content
        assert b'gzipped' in cached_response.raw.read(None, decode_content=True)
diff --git a/tests/integration/base_storage_test.py b/tests/integration/base_storage_test.py
index 7c83d6f..c0f9ce1 100644
--- a/tests/integration/base_storage_test.py
+++ b/tests/integration/base_storage_test.py
@@ -2,7 +2,7 @@
import pytest
from typing import Dict, Type
-from requests_cache.backends.base import BaseStorage
+from requests_cache.backends import BaseStorage
from tests.conftest import CACHE_NAME
diff --git a/tests/unit/models/test_raw_response.py b/tests/unit/models/test_raw_response.py
new file mode 100644
index 0000000..bdcd373
--- /dev/null
+++ b/tests/unit/models/test_raw_response.py
@@ -0,0 +1,58 @@
+from io import BytesIO
+
+from requests_cache.models import CachedHTTPResponse
+from tests.conftest import MOCKED_URL
+
+
+def test_from_response(mock_session):
+    response = mock_session.get(MOCKED_URL)
+    response.raw._fp = BytesIO(b'mock response')
+    raw = CachedHTTPResponse.from_response(response)
+
+    assert dict(response.raw.headers) == dict(raw.headers) == {'Content-Type': 'text/plain'}
+    assert raw.read(None) == b'mock response'
+    assert response.raw.decode_content is raw.decode_content is False
+    assert response.raw.reason is raw.reason is None
+    assert response.raw._request_url is raw.request_url is None
+    assert response.raw.status == raw.status == 200
+    assert response.raw.strict == raw.strict == 0
+    assert response.raw.version == raw.version == 0
+
+
+def test_read():
+    raw = CachedHTTPResponse(body=b'mock response')
+    assert raw.read(10) == b'mock respo'
+    assert raw.read(None) == b'nse'
+    assert raw.read(1) == b''
+    assert raw._fp.closed is True
+
+
+def test_close():
+    raw = CachedHTTPResponse(body=b'mock response')
+    raw.close()
+    assert raw._fp.closed is True
+
+
+def test_reset():
+    raw = CachedHTTPResponse(body=b'mock response')
+    raw.read(None)
+    assert raw.read(1) == b''
+    assert raw._fp.closed is True
+
+    raw.reset()
+    assert raw.read(None) == b'mock response'
+
+
+def test_set_content():
+    raw = CachedHTTPResponse(body=None)
+    raw.set_content(b'mock response')
+    assert raw.read() == b'mock response'
+
+
+def test_stream():
+    raw = CachedHTTPResponse(body=b'mock response')
+    data = b''
+    for chunk in raw.stream(1):
+        data += chunk
+    assert data == b'mock response'
+    assert raw._fp.closed
diff --git a/tests/unit/models/test_request.py b/tests/unit/models/test_request.py
new file mode 100644
index 0000000..8e917ff
--- /dev/null
+++ b/tests/unit/models/test_request.py
@@ -0,0 +1,16 @@
+from requests.utils import default_headers
+
+from requests_cache.models.response import CachedRequest
+from tests.conftest import MOCKED_URL
+
+
+def test_from_request(mock_session):
+    response = mock_session.get(MOCKED_URL, data=b'mock request', headers={'foo': 'bar'})
+    request = CachedRequest.from_request(response.request)
+    expected_headers = {**default_headers(), 'Content-Length': '12', 'foo': 'bar'}
+
+    assert response.request.body == request.body == b'mock request'
+    assert response.request._cookies == request.cookies == {}
+    assert response.request.headers == request.headers == expected_headers
+    assert response.request.method == request.method == 'GET'
+    assert response.request.url == request.url == MOCKED_URL
diff --git a/tests/unit/test_response.py b/tests/unit/models/test_response.py
index 1eb5e64..67c1006 100644
--- a/tests/unit/test_response.py
+++ b/tests/unit/models/test_response.py
@@ -5,13 +5,12 @@ from time import sleep
from urllib3.response import HTTPResponse
-from requests_cache import CachedHTTPResponse, CachedResponse
-from requests_cache.response import format_file_size
+from requests_cache.models.response import CachedResponse, format_file_size
from tests.conftest import MOCKED_URL
def test_basic_attrs(mock_session):
-    response = CachedResponse(mock_session.get(MOCKED_URL))
+    response = CachedResponse.from_response(mock_session.get(MOCKED_URL))
    assert response.from_cache is True
    assert response.url == MOCKED_URL
@@ -25,62 +24,28 @@ def test_basic_attrs(mock_session):
    assert response.is_expired is False
+def test_history(mock_session):
+    original_response = mock_session.get(MOCKED_URL)
+    original_response.history = [mock_session.get(MOCKED_URL)] * 3
+    response = CachedResponse.from_response(original_response)
+    assert len(response.history) == 3
+    assert all([isinstance(r, CachedResponse) for r in response.history])
+
+
@pytest.mark.parametrize(
-    'expire_after, is_expired',
+    'expires, is_expired',
    [
        (datetime.utcnow() + timedelta(days=1), False),
        (datetime.utcnow() - timedelta(days=1), True),
    ],
)
-def test_is_expired(expire_after, is_expired, mock_session):
-    response = CachedResponse(mock_session.get(MOCKED_URL), expire_after)
+def test_is_expired(expires, is_expired, mock_session):
+    response = CachedResponse.from_response(mock_session.get(MOCKED_URL), expires=expires)
    assert response.from_cache is True
    assert response.is_expired == is_expired
-def test_history(mock_session):
-    original_response = mock_session.get(MOCKED_URL)
-    original_response.history = [mock_session.get(MOCKED_URL)] * 3
-    response = CachedResponse(original_response)
-    assert len(response.history) == 3
-    assert all([isinstance(r, CachedResponse) for r in response.history])
-
-
-def test_raw_response__read(mock_session):
-    response = CachedResponse(mock_session.get(MOCKED_URL))
-    assert isinstance(response.raw, CachedHTTPResponse)
-    assert response.raw.read(10) == b'mock respo'
-    assert response.raw.read(None) == b'nse'
-    assert response.raw.read(1) == b''
-    assert response.raw._fp.closed is True
-
-
-def test_raw_response__close(mock_session):
-    response = CachedResponse(mock_session.get(MOCKED_URL))
-    response.close()
-    assert response.raw._fp.closed is True
-
-
-def test_raw_response__reset(mock_session):
-    response = CachedResponse(mock_session.get(MOCKED_URL))
-    response.raw.read(None)
-    assert response.raw.read(1) == b''
-    assert response.raw._fp.closed is True
-
-    response.reset()
-    assert response.raw.read(None) == b'mock response'
-
-
-def test_raw_response__stream(mock_session):
-    response = CachedResponse(mock_session.get(MOCKED_URL))
-    data = b''
-    for chunk in response.raw.stream(1):
-        data += chunk
-    assert data == b'mock response'
-    assert response.raw._fp.closed
-
-
-def test_raw_response__iterator(mock_session):
+def test_iterator(mock_session):
    # Set up mock response with streamed content
    url = f'{MOCKED_URL}/stream'
    mock_raw_response = HTTPResponse(
@@ -112,7 +77,7 @@ def test_raw_response__iterator(mock_session):
def test_revalidate__extend_expiration(mock_session):
    # Start with an expired response
-    response = CachedResponse(
+    response = CachedResponse.from_response(
        mock_session.get(MOCKED_URL),
        expires=datetime.utcnow() - timedelta(seconds=0.01),
    )
@@ -127,7 +92,7 @@ def test_revalidate__extend_expiration(mock_session):
def test_revalidate__shorten_expiration(mock_session):
    # Start with a non-expired response
-    response = CachedResponse(
+    response = CachedResponse.from_response(
        mock_session.get(MOCKED_URL),
        expires=datetime.utcnow() + timedelta(seconds=1),
    )
@@ -139,7 +104,7 @@ def test_revalidate__shorten_expiration(mock_session):
def test_size(mock_session):
-    response = CachedResponse(mock_session.get(MOCKED_URL))
+    response = CachedResponse.from_response(mock_session.get(MOCKED_URL))
    response._content = None
    assert response.size == 0
    response._content = b'1' * 1024
@@ -150,7 +115,7 @@ def test_str(mock_session):
"""Just ensure that a subset of relevant attrs get included in the response str; the format
may change without breaking the test.
"""
- response = CachedResponse(mock_session.get(MOCKED_URL))
+ response = CachedResponse.from_response(mock_session.get(MOCKED_URL))
response._content = b'1010'
expected_values = ['GET', MOCKED_URL, 200, '4 bytes', 'created', 'expires', 'fresh']
assert all([str(v) in str(response) for v in expected_values])
@@ -158,10 +123,10 @@ def test_str(mock_session):
def test_repr(mock_session):
"""Just ensure that a subset of relevant attrs get included in the response repr"""
- response = CachedResponse(mock_session.get(MOCKED_URL))
+ response = CachedResponse.from_response(mock_session.get(MOCKED_URL))
expected_values = ['GET', MOCKED_URL, 200, 'ISO-8859-1', response.headers]
print(repr(response))
- assert repr(response).startswith('<CachedResponse(') and repr(response).endswith(')>')
+ assert repr(response).startswith('CachedResponse(') and repr(response).endswith(')')
assert all([str(v) in repr(response) for v in expected_values])
diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py
index 7ec4265..e958212 100644
--- a/tests/unit/test_session.py
+++ b/tests/unit/test_session.py
@@ -12,13 +12,18 @@ from uuid import uuid4
import requests
from itsdangerous.exc import BadSignature
-from itsdangerous.serializer import Serializer
from requests.structures import CaseInsensitiveDict
from requests_cache import ALL_METHODS, CachedSession
-from requests_cache.backends import BACKEND_CLASSES, BaseCache, get_placeholder_backend
-from requests_cache.backends.sqlite import DbDict, DbPickleDict
-from requests_cache.response import CachedResponse
+from requests_cache.backends import (
+    BACKEND_CLASSES,
+    BaseCache,
+    DbDict,
+    DbPickleDict,
+    get_placeholder_backend,
+)
+from requests_cache.models import CachedResponse
+from requests_cache.serializers import PickleSerializer, SafePickleSerializer
from tests.conftest import (
    MOCKED_URL,
    MOCKED_URL_404,
@@ -559,16 +564,17 @@ def test_unpickle_errors(mock_session):
def test_cache_signing(tempfile_path):
    session = CachedSession(tempfile_path)
-    assert session.cache.responses._serializer == pickle
+    assert isinstance(session.cache.responses.serializer, PickleSerializer)
    # With a secret key, itsdangerous should be used
    secret_key = str(uuid4())
    session = CachedSession(tempfile_path, secret_key=secret_key)
-    assert isinstance(session.cache.responses._serializer, Serializer)
+    assert isinstance(session.cache.responses.serializer, SafePickleSerializer)
    # Simple serialize/deserialize round trip
-    session.cache.responses['key'] = 'value'
-    assert session.cache.responses['key'] == 'value'
+    response = CachedResponse()
+    session.cache.responses['key'] = response
+    assert session.cache.responses['key'] == response
    # Without the same signing key, the item shouldn't be considered safe to deserialize
    session = CachedSession(tempfile_path, secret_key='a different key')
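Usage sketch of the serializer selection exercised by test_cache_signing above (a rough illustration; the cache name is a placeholder, not from the diff):

    from uuid import uuid4
    from requests_cache import CachedSession
    from requests_cache.serializers import PickleSerializer, SafePickleSerializer

    # Without a secret key, responses are stored with plain pickle
    session = CachedSession('demo_cache')  # hypothetical cache name/path
    assert isinstance(session.cache.responses.serializer, PickleSerializer)

    # With a secret key, responses are signed via itsdangerous (SafePickleSerializer)
    signed_session = CachedSession('demo_cache', secret_key=str(uuid4()))
    assert isinstance(signed_session.cache.responses.serializer, SafePickleSerializer)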