summaryrefslogtreecommitdiff
path: root/requests_cache/cache_keys.py
diff options
context:
space:
mode:
authorJordan Cook <jordan.cook@pioneer.com>2021-08-10 16:56:13 -0500
committerJordan Cook <jordan.cook@pioneer.com>2021-08-14 21:58:01 -0500
commit85ccdfd12928964e2861768ce70301232f50c769 (patch)
tree6862f34ee3ec4e42b8f9e979d32bb99baa032615 /requests_cache/cache_keys.py
parentd17ad7f422ad01bc7ddcd707396f763ec0b236d3 (diff)
downloadrequests-cache-85ccdfd12928964e2861768ce70301232f50c769.tar.gz
Replace some 'type: ignore' statements with better type hinting
Diffstat (limited to 'requests_cache/cache_keys.py')
-rw-r--r--requests_cache/cache_keys.py31
1 files changed, 18 insertions, 13 deletions
diff --git a/requests_cache/cache_keys.py b/requests_cache/cache_keys.py
index d1ad8dc..3ded784 100644
--- a/requests_cache/cache_keys.py
+++ b/requests_cache/cache_keys.py
@@ -1,20 +1,25 @@
+from __future__ import annotations
+
import hashlib
import json
from operator import itemgetter
-from typing import Iterable, List, Mapping, Optional, Tuple, Union
+from typing import TYPE_CHECKING, Iterable, List, Mapping, Optional, Tuple, Union
from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse
-from requests import PreparedRequest, Request, Session
+from requests import Request, Session
from requests.models import CaseInsensitiveDict
from requests.utils import default_headers
from url_normalize import url_normalize
+if TYPE_CHECKING:
+ from .models import AnyRequest
+
DEFAULT_HEADERS = default_headers()
RequestContent = Union[Mapping, str, bytes]
def create_key(
- request: PreparedRequest,
+ request: AnyRequest,
ignored_params: Iterable[str] = None,
include_get_headers: bool = False,
**kwargs,
@@ -31,15 +36,16 @@ def create_key(
if body:
key.update(body)
if include_get_headers and request.headers != DEFAULT_HEADERS:
- for name, value in normalize_dict(request.headers).items(): # type: ignore
+ headers = normalize_dict(request.headers)
+ if TYPE_CHECKING:
+ assert isinstance(headers, dict)
+ for name, value in headers.items():
key.update(encode(f'{name}={value}'))
return key.hexdigest()
-def remove_ignored_params(
- request: PreparedRequest, ignored_params: Optional[Iterable[str]]
-) -> PreparedRequest:
+def remove_ignored_params(request: AnyRequest, ignored_params: Optional[Iterable[str]]) -> AnyRequest:
if not ignored_params:
return request
request.headers = remove_ignored_headers(request, ignored_params)
@@ -49,7 +55,7 @@ def remove_ignored_params(
def remove_ignored_headers(
- request: PreparedRequest, ignored_params: Optional[Iterable[str]]
+ request: AnyRequest, ignored_params: Optional[Iterable[str]]
) -> CaseInsensitiveDict:
if not ignored_params:
return request.headers
@@ -59,7 +65,7 @@ def remove_ignored_headers(
return headers
-def remove_ignored_url_params(request: PreparedRequest, ignored_params: Optional[Iterable[str]]) -> str:
+def remove_ignored_url_params(request: AnyRequest, ignored_params: Optional[Iterable[str]]) -> str:
url_str = str(request.url)
if not ignored_params:
return url_str
@@ -69,10 +75,9 @@ def remove_ignored_url_params(request: PreparedRequest, ignored_params: Optional
return urlunparse((url.scheme, url.netloc, url.path, url.params, urlencode(query), url.fragment))
-def remove_ignored_body_params(
- request: PreparedRequest, ignored_params: Optional[Iterable[str]]
-) -> bytes:
+def remove_ignored_body_params(request: AnyRequest, ignored_params: Optional[Iterable[str]]) -> bytes:
original_body = request.body
+ filtered_body: Union[str, bytes] = b''
content_type = request.headers.get('content-type')
if not ignored_params or not original_body or not content_type:
return encode(original_body)
@@ -85,7 +90,7 @@ def remove_ignored_body_params(
body = filter_params(sorted(body), ignored_params)
filtered_body = json.dumps(body)
else:
- filtered_body = original_body # type: ignore
+ filtered_body = original_body
return encode(filtered_body)