diff options
Diffstat (limited to 'requests_cache')
| -rw-r--r-- | requests_cache/backends/base.py | 3 | ||||
| -rw-r--r-- | requests_cache/cache_keys.py | 51 |
2 files changed, 33 insertions, 21 deletions
diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py index 0045f7f..d5f16e8 100644 --- a/requests_cache/backends/base.py +++ b/requests_cache/backends/base.py @@ -8,7 +8,7 @@ from logging import getLogger from typing import Iterable, Iterator, Tuple, Union from ..cache_control import ExpirationTime -from ..cache_keys import create_key, remove_ignored_params, url_to_key +from ..cache_keys import create_key, remove_ignored_params, remove_ignored_url_params, url_to_key from ..models import AnyRequest, AnyResponse, CachedResponse from ..serializers import init_serializer @@ -55,6 +55,7 @@ class BaseCache: cache_key = cache_key or self.create_key(response.request) cached_response = CachedResponse.from_response(response, cache_key=cache_key, expires=expires) cached_response.request = remove_ignored_params(cached_response.request, self.ignored_parameters) + cached_response.url = remove_ignored_url_params(cached_response.url, self.ignored_parameters) self.responses[cache_key] = cached_response def save_redirect(self, request: AnyRequest, response_key: str): diff --git a/requests_cache/cache_keys.py b/requests_cache/cache_keys.py index d1ad8dc..748c511 100644 --- a/requests_cache/cache_keys.py +++ b/requests_cache/cache_keys.py @@ -22,7 +22,7 @@ def create_key( """Create a normalized cache key from a request object""" key = hashlib.sha256() key.update(encode((request.method or '').upper())) - url = remove_ignored_url_params(request, ignored_params) + url = remove_ignored_url_params(request.url, ignored_params) url = url_normalize(url) key.update(encode(url)) key.update(encode(kwargs.get('verify', True))) @@ -42,47 +42,58 @@ def remove_ignored_params( ) -> PreparedRequest: if not ignored_params: return request - request.headers = remove_ignored_headers(request, ignored_params) - request.url = remove_ignored_url_params(request, ignored_params) + request.headers = remove_ignored_headers(request.headers, ignored_params) + request.url = remove_ignored_url_params(request.url, ignored_params) request.body = remove_ignored_body_params(request, ignored_params) return request def remove_ignored_headers( - request: PreparedRequest, ignored_params: Optional[Iterable[str]] + headers: Mapping, ignored_parameters: Optional[Iterable[str]] ) -> CaseInsensitiveDict: - if not ignored_params: - return request.headers - headers = CaseInsensitiveDict(request.headers.copy()) - for k in ignored_params: + """Remove any ignored request headers""" + if not ignored_parameters: + return CaseInsensitiveDict(headers) + + headers = CaseInsensitiveDict(headers) + for k in ignored_parameters: headers.pop(k, None) return headers -def remove_ignored_url_params(request: PreparedRequest, ignored_params: Optional[Iterable[str]]) -> str: - url_str = str(request.url) - if not ignored_params: - return url_str +def remove_ignored_url_params(url: Optional[str], ignored_parameters: Optional[Iterable[str]]) -> str: + """Remove any ignored request parameters from the URL""" + if not ignored_parameters or not url: + return url or '' - url = urlparse(url_str) - query = filter_params(parse_qsl(url.query), ignored_params) - return urlunparse((url.scheme, url.netloc, url.path, url.params, urlencode(query), url.fragment)) + url_tokens = urlparse(url) + query = _filter_params(parse_qsl(url_tokens.query), ignored_parameters) + return urlunparse( + ( + url_tokens.scheme, + url_tokens.netloc, + url_tokens.path, + url_tokens.params, + urlencode(query), + url_tokens.fragment, + ) + ) def remove_ignored_body_params( - request: PreparedRequest, ignored_params: Optional[Iterable[str]] + request: PreparedRequest, ignored_parameters: Optional[Iterable[str]] ) -> bytes: original_body = request.body content_type = request.headers.get('content-type') - if not ignored_params or not original_body or not content_type: + if not ignored_parameters or not original_body or not content_type: return encode(original_body) if content_type == 'application/x-www-form-urlencoded': - body = filter_params(parse_qsl(decode(original_body)), ignored_params) + body = _filter_params(parse_qsl(decode(original_body)), ignored_parameters) filtered_body = urlencode(body) elif content_type == 'application/json': body = json.loads(decode(original_body)).items() - body = filter_params(sorted(body), ignored_params) + body = _filter_params(sorted(body), ignored_parameters) filtered_body = json.dumps(body) else: filtered_body = original_body # type: ignore @@ -90,7 +101,7 @@ def remove_ignored_body_params( return encode(filtered_body) -def filter_params(data: List[Tuple[str, str]], ignored_params: Iterable[str]) -> List[Tuple[str, str]]: +def _filter_params(data: List[Tuple[str, str]], ignored_params: Iterable[str]) -> List[Tuple[str, str]]: return [(k, v) for k, v in data if k not in set(ignored_params)] |
