diff options
author | Jordan Cook <JWCook@users.noreply.github.com> | 2021-11-16 10:11:19 -0600 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-11-16 10:11:19 -0600 |
commit | 40d109aef985366f845414df683c2b54af097251 (patch) | |
tree | 2b9195f35f48617c6268643d2c530ce975c3ac77 | |
parent | da847a2fc0dbf40fbb8596a74450f84b3c2db114 (diff) | |
parent | 8a6f8691bf3cadd98bf60365880ba04c7a708765 (diff) | |
download | requests-cache-40d109aef985366f845414df683c2b54af097251.tar.gz |
Merge pull request #447 from JWCook/backend-instance-kwargs
Add support for BaseCache keyword arguments passed along with a backend instance
-rw-r--r-- | HISTORY.md | 1 | ||||
-rw-r--r-- | requests_cache/backends/__init__.py | 18 | ||||
-rw-r--r-- | requests_cache/backends/base.py | 20 | ||||
-rw-r--r-- | requests_cache/session.py | 1 | ||||
-rw-r--r-- | tests/unit/test_session.py | 19 |
5 files changed, 49 insertions, 10 deletions
@@ -22,6 +22,7 @@ **Bugfixes:** * Handle some additional corner cases when normalizing request data +* Add support for `BaseCache` keyword arguments passed along with a backend instance * Fix issue with cache headers not being used correctly if `cache_control=True` is used with an `expire_after` value * Fix license metadata as shown on PyPI diff --git a/requests_cache/backends/__init__.py b/requests_cache/backends/__init__.py index bc921f7..0c525fe 100644 --- a/requests_cache/backends/__init__.py +++ b/requests_cache/backends/__init__.py @@ -3,7 +3,7 @@ from logging import getLogger from typing import Callable, Dict, Iterable, Optional, Type, Union -from .._utils import get_placeholder_class +from .._utils import get_placeholder_class, get_valid_kwargs from .base import BaseCache, BaseStorage, DictStorage # Backend-specific keyword arguments equivalent to 'cache_name' @@ -83,14 +83,26 @@ def init_backend(cache_name: str, backend: Optional[BackendSpecifier], **kwargs) # Determine backend class if isinstance(backend, BaseCache): - return backend + return _set_backend_kwargs(cache_name, backend, **kwargs) elif isinstance(backend, type): return backend(cache_name, **kwargs) elif not backend: - backend = 'sqlite' if BACKEND_CLASSES['sqlite'] else 'memory' + sqlite_supported = issubclass(BACKEND_CLASSES['sqlite'], BaseCache) + backend = 'sqlite' if sqlite_supported else 'memory' backend = str(backend).lower() if backend not in BACKEND_CLASSES: raise ValueError(f'Invalid backend: {backend}. Choose from: {BACKEND_CLASSES.keys()}') return BACKEND_CLASSES[backend](cache_name, **kwargs) + + +def _set_backend_kwargs(cache_name, backend, **kwargs): + """Set any backend arguments if they are passed along with a backend instance""" + backend_kwargs = get_valid_kwargs(BaseCache.__init__, kwargs) + backend_kwargs.setdefault('match_headers', kwargs.pop('include_get_headers', False)) + for k, v in backend_kwargs.items(): + setattr(backend, k, v) + if cache_name: + backend.cache_name = cache_name + return backend diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py index 7d46a85..ddd47e7 100644 --- a/requests_cache/backends/base.py +++ b/requests_cache/backends/base.py @@ -44,7 +44,7 @@ class BaseCache: def __init__( self, - *args, + cache_name: str = 'http_cache', match_headers: Union[Iterable[str], bool] = False, ignored_parameters: Iterable[str] = None, key_fn: KEY_FN = None, @@ -52,9 +52,9 @@ class BaseCache: ): self.responses: BaseStorage = DictStorage() self.redirects: BaseStorage = DictStorage() + self.cache_name = cache_name self.ignored_parameters = ignored_parameters self.key_fn = key_fn or create_key - self.name: str = kwargs.get('cache_name', '') self.match_headers = match_headers or kwargs.pop('include_get_headers', False) @property @@ -235,7 +235,7 @@ class BaseCache: return f'Total rows: {len(self.responses)} responses, {len(self.redirects)} redirects' def __repr__(self): - return f'<{self.__class__.__name__}(name={self.name})>' + return f'<{self.__class__.__name__}(name={self.cache_name})>' class BaseStorage(MutableMapping, ABC): @@ -261,9 +261,17 @@ class BaseStorage(MutableMapping, ABC): serializer=None, **kwargs, ): - self.serializer = init_serializer(serializer, **kwargs) + self._serializer = init_serializer(serializer, **kwargs) logger.debug(f'Initializing {type(self).__name__} with serializer: {self.serializer}') + @property + def serializer(self): + return self._serializer + + @serializer.setter + def serializer(self, value): + self._serializer = init_serializer(value) + def bulk_delete(self, keys: Iterable[str]): """Delete multiple keys from the cache, without raising errors for missing keys. This is a naive implementation that subclasses should override with a more efficient backend-specific @@ -289,6 +297,10 @@ class DictStorage(UserDict, BaseStorage): """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self._serializer = None + def __getitem__(self, key): """An additional step is needed here for response data. Since the original response object is still in memory, its content has already been read and needs to be reset. diff --git a/requests_cache/session.py b/requests_cache/session.py index 3b26b4a..9b2531d 100644 --- a/requests_cache/session.py +++ b/requests_cache/session.py @@ -66,7 +66,6 @@ class CacheMixin(MIXIN_BASE): self.filter_fn = filter_fn or (lambda r: True) self.stale_if_error = stale_if_error or kwargs.pop('old_data_on_error', False) - self.cache.name = cache_name # Set to handle backend=<instance> self._disabled = False self._lock = RLock() diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index d3509de..d225e57 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -52,9 +52,24 @@ def test_init_backend_instance(): assert session.cache is backend +def test_init_backend_instance__kwargs(): + backend = MyCache() + session = CachedSession( + 'test_cache', + backend=backend, + ignored_parameters=['foo'], + include_get_headers=True, + ) + + assert session.cache.cache_name == 'test_cache' + assert session.cache.ignored_parameters == ['foo'] + assert session.cache.match_headers is True + + def test_init_backend_class(): - session = CachedSession(backend=MyCache) + session = CachedSession('test_cache', backend=MyCache) assert isinstance(session.cache, MyCache) + assert session.cache.cache_name == 'test_cache' @pytest.mark.parametrize('method', ALL_METHODS) @@ -138,7 +153,7 @@ def test_repr(mock_session): mock_session.cache.redirects['key'] = 'value' mock_session.cache.redirects['key_2'] = 'value' - assert mock_session.cache.name in repr(mock_session) and '10.5' in repr(mock_session) + assert mock_session.cache.cache_name in repr(mock_session) and '10.5' in repr(mock_session) assert '2 redirects' in str(mock_session.cache) and '1 responses' in str(mock_session.cache) |