summaryrefslogtreecommitdiff
path: root/requests_cache
diff options
context:
space:
mode:
Diffstat (limited to 'requests_cache')
-rw-r--r--requests_cache/backends/base.py3
-rw-r--r--requests_cache/cache_keys.py51
2 files changed, 33 insertions, 21 deletions
diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py
index 0045f7f..d5f16e8 100644
--- a/requests_cache/backends/base.py
+++ b/requests_cache/backends/base.py
@@ -8,7 +8,7 @@ from logging import getLogger
from typing import Iterable, Iterator, Tuple, Union
from ..cache_control import ExpirationTime
-from ..cache_keys import create_key, remove_ignored_params, url_to_key
+from ..cache_keys import create_key, remove_ignored_params, remove_ignored_url_params, url_to_key
from ..models import AnyRequest, AnyResponse, CachedResponse
from ..serializers import init_serializer
@@ -55,6 +55,7 @@ class BaseCache:
cache_key = cache_key or self.create_key(response.request)
cached_response = CachedResponse.from_response(response, cache_key=cache_key, expires=expires)
cached_response.request = remove_ignored_params(cached_response.request, self.ignored_parameters)
+ cached_response.url = remove_ignored_url_params(cached_response.url, self.ignored_parameters)
self.responses[cache_key] = cached_response
def save_redirect(self, request: AnyRequest, response_key: str):
diff --git a/requests_cache/cache_keys.py b/requests_cache/cache_keys.py
index d1ad8dc..748c511 100644
--- a/requests_cache/cache_keys.py
+++ b/requests_cache/cache_keys.py
@@ -22,7 +22,7 @@ def create_key(
"""Create a normalized cache key from a request object"""
key = hashlib.sha256()
key.update(encode((request.method or '').upper()))
- url = remove_ignored_url_params(request, ignored_params)
+ url = remove_ignored_url_params(request.url, ignored_params)
url = url_normalize(url)
key.update(encode(url))
key.update(encode(kwargs.get('verify', True)))
@@ -42,47 +42,58 @@ def remove_ignored_params(
) -> PreparedRequest:
if not ignored_params:
return request
- request.headers = remove_ignored_headers(request, ignored_params)
- request.url = remove_ignored_url_params(request, ignored_params)
+ request.headers = remove_ignored_headers(request.headers, ignored_params)
+ request.url = remove_ignored_url_params(request.url, ignored_params)
request.body = remove_ignored_body_params(request, ignored_params)
return request
def remove_ignored_headers(
- request: PreparedRequest, ignored_params: Optional[Iterable[str]]
+ headers: Mapping, ignored_parameters: Optional[Iterable[str]]
) -> CaseInsensitiveDict:
- if not ignored_params:
- return request.headers
- headers = CaseInsensitiveDict(request.headers.copy())
- for k in ignored_params:
+ """Remove any ignored request headers"""
+ if not ignored_parameters:
+ return CaseInsensitiveDict(headers)
+
+ headers = CaseInsensitiveDict(headers)
+ for k in ignored_parameters:
headers.pop(k, None)
return headers
-def remove_ignored_url_params(request: PreparedRequest, ignored_params: Optional[Iterable[str]]) -> str:
- url_str = str(request.url)
- if not ignored_params:
- return url_str
+def remove_ignored_url_params(url: Optional[str], ignored_parameters: Optional[Iterable[str]]) -> str:
+ """Remove any ignored request parameters from the URL"""
+ if not ignored_parameters or not url:
+ return url or ''
- url = urlparse(url_str)
- query = filter_params(parse_qsl(url.query), ignored_params)
- return urlunparse((url.scheme, url.netloc, url.path, url.params, urlencode(query), url.fragment))
+ url_tokens = urlparse(url)
+ query = _filter_params(parse_qsl(url_tokens.query), ignored_parameters)
+ return urlunparse(
+ (
+ url_tokens.scheme,
+ url_tokens.netloc,
+ url_tokens.path,
+ url_tokens.params,
+ urlencode(query),
+ url_tokens.fragment,
+ )
+ )
def remove_ignored_body_params(
- request: PreparedRequest, ignored_params: Optional[Iterable[str]]
+ request: PreparedRequest, ignored_parameters: Optional[Iterable[str]]
) -> bytes:
original_body = request.body
content_type = request.headers.get('content-type')
- if not ignored_params or not original_body or not content_type:
+ if not ignored_parameters or not original_body or not content_type:
return encode(original_body)
if content_type == 'application/x-www-form-urlencoded':
- body = filter_params(parse_qsl(decode(original_body)), ignored_params)
+ body = _filter_params(parse_qsl(decode(original_body)), ignored_parameters)
filtered_body = urlencode(body)
elif content_type == 'application/json':
body = json.loads(decode(original_body)).items()
- body = filter_params(sorted(body), ignored_params)
+ body = _filter_params(sorted(body), ignored_parameters)
filtered_body = json.dumps(body)
else:
filtered_body = original_body # type: ignore
@@ -90,7 +101,7 @@ def remove_ignored_body_params(
return encode(filtered_body)
-def filter_params(data: List[Tuple[str, str]], ignored_params: Iterable[str]) -> List[Tuple[str, str]]:
+def _filter_params(data: List[Tuple[str, str]], ignored_params: Iterable[str]) -> List[Tuple[str, str]]:
return [(k, v) for k, v in data if k not in set(ignored_params)]