diff options
author | Jordan Cook <jordan.cook@pioneer.com> | 2021-09-19 21:06:03 -0500 |
---|---|---|
committer | Jordan Cook <jordan.cook@pioneer.com> | 2021-09-20 23:33:30 -0500 |
commit | b3fc1f042e2deebcc62c1a074a02c0f90ab406ca (patch) | |
tree | 505b517e76d06b736966213f7f6e52075d9f463a | |
parent | 5605db4a84c4ee15f4406ea872aa650e31f3348f (diff) | |
download | requests-cache-b3fc1f042e2deebcc62c1a074a02c0f90ab406ca.tar.gz |
Reorganize & improve request normalization functions:
* Handle all normalization in `cache_keys` module, get rid of `normalize_dict()` function used in `CachedSession`
* Reorganize `cache_keys` helper functions into the following:
* `normalize_request()`
* `normalize_url()`
* `normalize_headers()`
* `normalize_params()`
* `normalize_body()`
* `normalize_json_body()`
* `redact_response()`
-rw-r--r-- | requests_cache/backends/base.py | 5 | ||||
-rw-r--r-- | requests_cache/cache_control.py | 3 | ||||
-rw-r--r-- | requests_cache/cache_keys.py | 243 | ||||
-rw-r--r-- | requests_cache/models/__init__.py | 5 | ||||
-rw-r--r-- | requests_cache/models/request.py | 6 | ||||
-rw-r--r-- | requests_cache/session.py | 8 | ||||
-rw-r--r-- | tests/unit/test_cache_keys.py | 68 | ||||
-rw-r--r-- | tests/unit/test_session.py | 71 |
8 files changed, 233 insertions, 176 deletions
diff --git a/requests_cache/backends/base.py b/requests_cache/backends/base.py index 5c3d385..d57e02d 100644 --- a/requests_cache/backends/base.py +++ b/requests_cache/backends/base.py @@ -13,7 +13,7 @@ from logging import getLogger from typing import Callable, Iterable, Iterator, Tuple, Union from ..cache_control import ExpirationTime -from ..cache_keys import create_key, remove_ignored_params, remove_ignored_url_params +from ..cache_keys import create_key, redact_response from ..models import AnyRequest, AnyResponse, CachedResponse from ..serializers import init_serializer @@ -94,8 +94,7 @@ class BaseCache: """ cache_key = cache_key or self.create_key(response.request) cached_response = CachedResponse.from_response(response, expires=expires) - cached_response.url = remove_ignored_url_params(response.url, self.ignored_parameters) - cached_response.request = remove_ignored_params(cached_response.request, self.ignored_parameters) + cached_response = redact_response(cached_response, self.ignored_parameters) self.responses[cache_key] = cached_response for r in response.history: self.redirects[self.create_key(r.request)] = cache_key diff --git a/requests_cache/cache_control.py b/requests_cache/cache_control.py index 418ac29..5a5c383 100644 --- a/requests_cache/cache_control.py +++ b/requests_cache/cache_control.py @@ -23,7 +23,8 @@ from requests import PreparedRequest, Response if TYPE_CHECKING: from .models import CachedResponse -# Value that may be set by either Cache-Control headers or CachedSession params to disable caching +__all__ = ['DO_NOT_CACHE', 'CacheActions'] +# May be set by either headers or expire_after param to disable caching DO_NOT_CACHE = 0 # Supported Cache-Control directives diff --git a/requests_cache/cache_keys.py b/requests_cache/cache_keys.py index d8f787c..953d5c0 100644 --- a/requests_cache/cache_keys.py +++ b/requests_cache/cache_keys.py @@ -8,178 +8,184 @@ from __future__ import annotations import json from hashlib import blake2b -from 
operator import itemgetter -from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Tuple, Union +from logging import getLogger +from typing import TYPE_CHECKING, Dict, Iterable, List, Mapping, Optional, Union from urllib.parse import parse_qsl, urlencode, urlparse, urlunparse from requests import Request, Session from requests.models import CaseInsensitiveDict -from requests.utils import default_headers from url_normalize import url_normalize from . import get_valid_kwargs if TYPE_CHECKING: - from .models import AnyRequest + from .models import AnyPreparedRequest, AnyRequest, CachedResponse -DEFAULT_REQUEST_HEADERS = default_headers() +__all__ = ['create_key', 'normalize_request'] +# Request headers that are always excluded from cache keys, but not redacted from cached responses DEFAULT_EXCLUDE_HEADERS = {'Cache-Control', 'If-None-Match', 'If-Modified-Since'} + +ParamList = Optional[Iterable[str]] RequestContent = Union[Mapping, str, bytes] +logger = getLogger(__name__) + def create_key( request: AnyRequest = None, - ignored_parameters: Iterable[str] = None, - match_headers: Union[Iterable[str], bool] = False, - **kwargs, + ignored_parameters: ParamList = None, + match_headers: Union[ParamList, bool] = False, + **request_kwargs, ) -> str: """Create a normalized cache key from either a request object or :py:class:`~requests.Request` arguments + + Args: + request: Request object to generate a cache key from + ignored_parameters: Request parames, headers, and/or body params to not match against + match_headers: Match only the specified headers, or ``True`` to match all headers + request_kwargs: Request arguments to generate a cache key from """ - # Create a PreparedRequest, if needed + # Convert raw request arguments into a request object, if needed if not request: - request_kwargs = get_valid_kwargs(Request.__init__, kwargs) - request = Session().prepare_request(Request(**request_kwargs)) - if TYPE_CHECKING: - assert request is not None - - # Add 
method and relevant request settings + request = Request(**get_valid_kwargs(Request.__init__, request_kwargs)) + + # Normalize and gather all relevant request info to match against + request = normalize_request(request, ignored_parameters) + key_parts = [ + request.method or '', + request.url, + request.body or '', + request_kwargs.get('verify', True), + *get_matched_headers(request.headers, match_headers), + ] + + # Generate a hash based on this info key = blake2b(digest_size=8) - key.update(encode((request.method or '').upper())) - key.update(encode(kwargs.get('verify', True))) - - # Add filtered/normalized URL + request params - url = remove_ignored_url_params(request.url, ignored_parameters) - key.update(encode(url_normalize(url))) - - # Add filtered request body - body = remove_ignored_body_params(request, ignored_parameters) - if body: - key.update(body) - - # Add filtered/normalized headers - headers = get_matched_headers(request.headers, ignored_parameters, match_headers) - for k, v in headers.items(): - key.update(encode(f'{k}={v}')) - + for part in key_parts: + key.update(encode(part)) return key.hexdigest() def get_matched_headers( - headers: CaseInsensitiveDict, ignored_parameters: Optional[Iterable[str]], match_headers -) -> Dict: - """Get only the headers we should match against, given an optional include list and/or exclude - list. Also normalizes headers (sorted/lowercased keys). + headers: CaseInsensitiveDict, match_headers: Union[ParamList, bool] +) -> List[str]: + """Get only the headers we should match against as a list of ``k=v`` strings, given an optional + include list. 
""" if not match_headers: - return {} + return [] - included = set(match_headers if isinstance(match_headers, Iterable) else headers.keys()) - included -= set(ignored_parameters or []) - included -= DEFAULT_EXCLUDE_HEADERS - return {k.lower(): headers[k] for k in sorted(included) if k in headers} + if isinstance(match_headers, Iterable): + included = set(match_headers) - DEFAULT_EXCLUDE_HEADERS + else: + included = set(headers) - DEFAULT_EXCLUDE_HEADERS + return [f'{k.lower()}={headers[k]}' for k in included if k in headers] -def remove_ignored_headers( - headers: Mapping, ignored_parameters: Optional[Iterable[str]] -) -> CaseInsensitiveDict: - """Remove any ignored request headers""" - if not ignored_parameters: - return CaseInsensitiveDict(headers) - headers = CaseInsensitiveDict(headers) - for k in ignored_parameters: - headers.pop(k, None) - return headers +def normalize_request(request: AnyRequest, ignored_parameters: ParamList) -> AnyPreparedRequest: + """Normalize and remove ignored parameters from request URL, body, and headers. 
+ This is used for both: + * Increasing cache hits by generating more precise cache keys + * Redacting potentially sensitive info from cached requests + + Args: + request: Request object to normalize + ignored_parameters: Request parames, headers, and/or body params to not match against and + to remove from the request + """ + if isinstance(request, Request): + norm_request = Session().prepare_request(request) + else: + norm_request = request.copy() + + norm_request.method = (norm_request.method or '').upper() + norm_request.url = normalize_url(norm_request.url, ignored_parameters) + norm_request.headers = normalize_headers(norm_request.headers, ignored_parameters) + norm_request.body = normalize_body(norm_request, ignored_parameters) + return norm_request -def remove_ignored_params( - request: AnyRequest, ignored_parameters: Optional[Iterable[str]] -) -> AnyRequest: - """Remove ignored parameters from request URL, body, and headers""" - if not ignored_parameters: - return request - request.headers = remove_ignored_headers(request.headers, ignored_parameters) - request.url = remove_ignored_url_params(request.url, ignored_parameters) - request.body = remove_ignored_body_params(request, ignored_parameters) - return request +def normalize_headers(headers: Mapping[str, str], ignored_parameters: ParamList) -> CaseInsensitiveDict: + """Sort and filter request headers""" + if ignored_parameters: + headers = filter_sort_dict(headers, ignored_parameters) + return CaseInsensitiveDict(headers) -def remove_ignored_url_params(url: Optional[str], ignored_parameters: Optional[Iterable[str]]) -> str: - """Remove any ignored request parameters from the URL""" - if not ignored_parameters or not url: - return url or '' +def normalize_url(url: str, ignored_parameters: ParamList) -> str: + """Normalize and filter a URL. This includes request parameters, IDN domains, scheme, host, + port, etc. 
+ """ + # Strip query params from URL, sort and filter, and reassemble into a complete URL url_tokens = urlparse(url) - query = _filter_params(parse_qsl(url_tokens.query), ignored_parameters) - return urlunparse( + url = urlunparse( ( url_tokens.scheme, url_tokens.netloc, url_tokens.path, url_tokens.params, - urlencode(query), + normalize_params(url_tokens.query, ignored_parameters), url_tokens.fragment, ) ) + return url_normalize(url) -def remove_ignored_body_params( - request: AnyRequest, ignored_parameters: Optional[Iterable[str]] -) -> bytes: - """Remove any ignored parameters from the request body""" - original_body = request.body - filtered_body: Union[str, bytes] = b'' - content_type = request.headers.get('content-type') - if not ignored_parameters or not original_body or not content_type: - return encode(original_body) - - if content_type == 'application/x-www-form-urlencoded': - body = _filter_params(parse_qsl(decode(original_body)), ignored_parameters) - filtered_body = urlencode(body) - elif content_type == 'application/json': - body = json.loads(decode(original_body)).items() - body = _filter_params(sorted(body), ignored_parameters) - filtered_body = json.dumps(body) - else: - filtered_body = original_body + +def normalize_body(request: AnyPreparedRequest, ignored_parameters: ParamList) -> bytes: + """Normalize and filter a request body if possible, depending on Content-Type""" + original_body = request.body or b'' + content_type = request.headers.get('Content-Type') + + # Filter and sort params if possible + filtered_body: Union[str, bytes] = original_body + if content_type == 'application/json': + filtered_body = normalize_json_body(original_body, ignored_parameters) + elif content_type == 'application/x-www-form-urlencoded': + filtered_body = normalize_params(original_body, ignored_parameters) return encode(filtered_body) -def _filter_params( - data: List[Tuple[str, str]], ignored_parameters: Iterable[str] -) -> List[Tuple[str, str]]: - return [(k, 
v) for k, v in data if k not in set(ignored_parameters)] +# TODO: Skip this for a very large response body? +def normalize_json_body( + original_body: Union[str, bytes], ignored_parameters: ParamList +) -> Union[str, bytes]: + """Normalize and filter a request body with serialized JSON data""" + try: + body = json.loads(decode(original_body)) + body = filter_sort_dict(body, ignored_parameters) + return json.dumps(body) + # If it's invalid JSON, then don't mess with it + except (AttributeError, TypeError, ValueError): + logger.warning('Invalid JSON body:', exc_info=True) + return original_body -def normalize_dict( - items: Optional[RequestContent], normalize_data: bool = True -) -> Optional[RequestContent]: - """Sort items in a dict +# TODO: More thorough tests +def normalize_params(value: Union[str, bytes], ignored_parameters: ParamList) -> str: + """Normalize and filter urlencoded params from either a URL or request body with form data""" + params = dict(parse_qsl(decode(value))) + params = filter_sort_dict(params, ignored_parameters) + return urlencode(params) - Args: - items: Request params, data, or json - normalize_data: Also normalize stringified JSON - """ - if not items: - return None - if isinstance(items, Mapping): - return sort_dict(items) - if normalize_data and isinstance(items, (bytes, str)): - # Attempt to load body as JSON; not doing this by default as it could impact performance - try: - dict_items = json.loads(decode(items)) - dict_items = json.dumps(sort_dict(dict_items)) - return dict_items.encode('utf-8') if isinstance(items, bytes) else dict_items - except Exception: - pass - return items +def redact_response(response: CachedResponse, ignored_parameters: ParamList) -> CachedResponse: + """Redact any ignored parameters (potentially containing sensitive info) from a cached request""" + if ignored_parameters: + response.url = normalize_url(response.url, ignored_parameters) + response.request = normalize_request(response.request, 
ignored_parameters) # type: ignore + return response -def sort_dict(d: Mapping) -> Dict: - return dict(sorted(d.items(), key=itemgetter(0))) +def decode(value, encoding='utf-8') -> str: + """Decode a value from bytes, if hasn't already been. + Note: ``PreparedRequest.body`` is always encoded in utf-8. + """ + return value.decode(encoding) if isinstance(value, bytes) else value def encode(value, encoding='utf-8') -> bytes: @@ -187,8 +193,7 @@ def encode(value, encoding='utf-8') -> bytes: return value if isinstance(value, bytes) else str(value).encode(encoding) -def decode(value, encoding='utf-8') -> str: - """Decode a value from bytes, if hasn't already been. - Note: ``PreparedRequest.body`` is always encoded in utf-8. - """ - return value.decode(encoding) if isinstance(value, bytes) else value +def filter_sort_dict(data: Mapping[str, str], ignored_parameters: ParamList) -> Dict[str, str]: + if not ignored_parameters: + return dict(sorted(data.items())) + return {k: v for k, v in sorted(data.items()) if k not in set(ignored_parameters)} diff --git a/requests_cache/models/__init__.py b/requests_cache/models/__init__.py index dec305a..6ffc7ad 100644 --- a/requests_cache/models/__init__.py +++ b/requests_cache/models/__init__.py @@ -2,11 +2,12 @@ # flake8: noqa: F401 from typing import Union -from requests import PreparedRequest, Response +from requests import PreparedRequest, Request, Response from .raw_response import CachedHTTPResponse from .request import CachedRequest from .response import CachedResponse, set_response_defaults AnyResponse = Union[Response, CachedResponse] -AnyRequest = Union[PreparedRequest, CachedRequest] +AnyRequest = Union[Request, PreparedRequest, CachedRequest] +AnyPreparedRequest = Union[PreparedRequest, CachedRequest] diff --git a/requests_cache/models/request.py b/requests_cache/models/request.py index 46951fc..dbeddd0 100644 --- a/requests_cache/models/request.py +++ b/requests_cache/models/request.py @@ -1,6 +1,6 @@ from logging import 
getLogger -from attr import define, field, fields_dict +from attr import asdict, define, field, fields_dict from requests import PreparedRequest from requests.cookies import RequestsCookieJar from requests.structures import CaseInsensitiveDict @@ -27,6 +27,10 @@ class CachedRequest: kwargs['cookies'] = getattr(original_request, '_cookies', None) return cls(**kwargs) + def copy(self) -> 'CachedRequest': + """Return a copy of the CachedRequest""" + return self.__class__(**asdict(self)) + def prepare(self) -> PreparedRequest: """Convert the CachedRequest back into a PreparedRequest""" prepared_request = PreparedRequest() diff --git a/requests_cache/session.py b/requests_cache/session.py index a60c93e..bc6ffd5 100644 --- a/requests_cache/session.py +++ b/requests_cache/session.py @@ -26,9 +26,9 @@ from urllib3 import filepost from . import get_valid_kwargs from .backends import BackendSpecifier, init_backend from .cache_control import CacheActions, ExpirationTime, get_expiration_seconds -from .cache_keys import normalize_dict from .models import AnyResponse, CachedResponse, set_response_defaults +__all__ = ['ALL_METHODS', 'CachedSession', 'CacheMixin'] ALL_METHODS = ['GET', 'HEAD', 'OPTIONS', 'POST', 'PUT', 'PATCH', 'DELETE'] FILTER_FN = Callable[[AnyResponse], bool] @@ -119,9 +119,9 @@ class CacheMixin(MIXIN_BASE): return super().request( method, url, - params=normalize_dict(params), - data=normalize_dict(data), - json=normalize_dict(json), + params=params, + data=data, + json=json, headers=headers, **kwargs, ) diff --git a/tests/unit/test_cache_keys.py b/tests/unit/test_cache_keys.py index 0add627..c685d2b 100644 --- a/tests/unit/test_cache_keys.py +++ b/tests/unit/test_cache_keys.py @@ -2,24 +2,13 @@ This just contains tests for some extra edge cases not covered elsewhere. 
""" import pytest -from requests import PreparedRequest +from requests import PreparedRequest, Request -from requests_cache.cache_keys import ( - create_key, - normalize_dict, - remove_ignored_body_params, - remove_ignored_headers, -) - - -def test_normalize_dict__skip_body(): - assert normalize_dict(b'some bytes', normalize_data=False) == b'some bytes' +from requests_cache.cache_keys import create_key, normalize_request +CACHE_KEY = 'e8cb526891875e37' -CACHE_KEY = 'f8cd92cfe57ddbf9' - -# All of the following variations should produce the same cache key @pytest.mark.parametrize( 'url, params', [ @@ -35,9 +24,9 @@ CACHE_KEY = 'f8cd92cfe57ddbf9' ('https://example.com?', {'foo': 'bar', 'param': '1'}), ], ) -def test_normalize_url_params(url, params): - request = PreparedRequest() - request.prepare( +def test_create_key__normalize_url_params(url, params): + """All of the above variations should produce the same cache key""" + request = Request( method='GET', url=url, params=params, @@ -45,18 +34,41 @@ def test_normalize_url_params(url, params): assert create_key(request) == CACHE_KEY -def test_remove_ignored_body_params__binary(): - request = PreparedRequest() - request.method = 'GET' - request.url = 'https://img.site.com/base/img.jpg' - request.body = b'some bytes' - request.headers = {'Content-Type': 'application/octet-stream'} - assert remove_ignored_body_params(request, ignored_parameters=None) == request.body +def test_normalize_request__json_body(): + request = Request( + method='GET', + url='https://img.site.com/base/img.jpg', + data=b'{"param_1": "value_1", "param_2": "value_2"}', + headers={'Content-Type': 'application/json'}, + ) + assert normalize_request(request, ignored_parameters=['param_2']).body == b'{"param_1": "value_1"}' + + +def test_normalize_request__invalid_json_body(): + request = Request( + method='GET', + url='https://img.site.com/base/img.jpg', + data=b'invalid JSON!', + headers={'Content-Type': 'application/json'}, + ) + assert 
normalize_request(request, ignored_parameters=['param_2']).body == b'invalid JSON!' + + +def test_normalize_request__binary_body(): + request = Request( + method='GET', + url='https://img.site.com/base/img.jpg', + data=b'some bytes', + headers={'Content-Type': 'application/octet-stream'}, + ) + assert normalize_request(request, ignored_parameters=['param']).body == request.data def test_remove_ignored_headers__empty(): request = PreparedRequest() - request.method = 'GET' - request.url = 'https://img.site.com/base/img.jpg' - request.headers = {'foo': 'bar'} - assert remove_ignored_headers(request.headers, ignored_parameters=None) == request.headers + request.prepare( + method='GET', + url='https://img.site.com/base/img.jpg', + headers={'foo': 'bar'}, + ) + assert normalize_request(request, ignored_parameters=None).headers == request.headers diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 150a650..e95c829 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -6,6 +6,7 @@ from collections import UserDict, defaultdict from datetime import datetime, timedelta from pickle import PickleError from unittest.mock import patch +from urllib.parse import urlencode import pytest import requests @@ -279,15 +280,14 @@ def test_raw_data(method, mock_session): """POST and PUT requests with different data (raw) should be cached under different keys""" assert mock_session.request(method, MOCKED_URL, data='raw data').from_cache is False assert mock_session.request(method, MOCKED_URL, data='raw data').from_cache is True - assert mock_session.request(method, MOCKED_URL, data='new raw data').from_cache is False + assert mock_session.request(method, MOCKED_URL, data='{"data": "new raw data"}').from_cache is False -@pytest.mark.parametrize('mapping_class', [dict, UserDict, CaseInsensitiveDict]) @pytest.mark.parametrize('field', ['params', 'data', 'json']) -def test_normalize_params(field, mapping_class, mock_session): - """Test normalization 
with different combinations of data fields and dict-like classes""" +def test_normalize_params(field, mock_session): + """Test normalization with different combinations of data fields""" params = {"a": "a", "b": ["1", "2", "3"], "c": "4"} - reversed_params = mapping_class(sorted(params.items(), reverse=True)) + reversed_params = dict(sorted(params.items(), reverse=True)) assert mock_session.get(MOCKED_URL, **{field: params}).from_cache is False assert mock_session.get(MOCKED_URL, **{field: params}).from_cache is True @@ -297,25 +297,47 @@ def test_normalize_params(field, mapping_class, mock_session): assert mock_session.post(MOCKED_URL, **{field: {"a": "b"}}).from_cache is False -@pytest.mark.parametrize('field', ['data', 'json']) -def test_normalize_serialized_body(field, mock_session): +@pytest.mark.parametrize('mapping_class', [dict, UserDict, CaseInsensitiveDict]) +def test_normalize_params__custom_dicts(mapping_class, mock_session): + """Test normalization with different dict-like classes""" + params = {"a": "a", "b": ["1", "2", "3"], "c": "4"} + params = mapping_class(params.items()) + + assert mock_session.get(MOCKED_URL, params=params).from_cache is False + assert mock_session.get(MOCKED_URL, params=params).from_cache is True + assert mock_session.post(MOCKED_URL, params=params).from_cache is False + assert mock_session.post(MOCKED_URL, params=params).from_cache is True + + +def test_normalize_params__serialized_body(mock_session): """Test normalization for serialized request body content""" + headers = {'Content-Type': 'application/json'} params = {"a": "a", "b": ["1", "2", "3"], "c": "4"} - reversed_params = dict(sorted(params.items(), reverse=True)) + sorted_params = json.dumps(params) + reversed_params = json.dumps(dict(sorted(params.items(), reverse=True))) + + assert mock_session.post(MOCKED_URL, headers=headers, data=sorted_params).from_cache is False + assert mock_session.post(MOCKED_URL, headers=headers, data=sorted_params).from_cache is True + 
assert mock_session.post(MOCKED_URL, headers=headers, data=reversed_params).from_cache is True + + +def test_normalize_params__urlencoded_body(mock_session): + headers = {'Content-Type': 'application/x-www-form-urlencoded'} + params = urlencode({"a": "a", "b": "!@#$%^&*()[]", "c": "4"}) - assert mock_session.post(MOCKED_URL, **{field: json.dumps(params)}).from_cache is False - assert mock_session.post(MOCKED_URL, **{field: json.dumps(params)}).from_cache is True - assert mock_session.post(MOCKED_URL, **{field: json.dumps(reversed_params)}).from_cache is True + assert mock_session.post(MOCKED_URL, headers=headers, data=params).from_cache is False + assert mock_session.post(MOCKED_URL, headers=headers, data=params).from_cache is True + assert mock_session.post(MOCKED_URL, headers=headers, data=params).from_cache is True -def test_normalize_non_json_body(mock_session): +def test_normalize_params__non_json_body(mock_session): """For serialized request body content that isn't in JSON format, no normalization is expected""" assert mock_session.post(MOCKED_URL, data=b'key_1=value_1,key_2=value_2').from_cache is False assert mock_session.post(MOCKED_URL, data=b'key_1=value_1,key_2=value_2').from_cache is True assert mock_session.post(MOCKED_URL, data=b'key_2=value_2,key_1=value_1').from_cache is False -def test_normalize_url(mock_session): +def test_normalize_params__url(mock_session): """Test URL variations that should all result in the same key""" urls = [ 'https://site.com?param_1=value_1&param_2=value_2', @@ -403,7 +425,7 @@ def test_response_defaults(mock_session): response_1 = mock_session.get(MOCKED_URL) response_2 = mock_session.get(MOCKED_URL) response_3 = mock_session.get(MOCKED_URL) - cache_key = '71c046cdb0afaa62' + cache_key = 'd7fa9fb7317b7412' assert response_1.cache_key == cache_key assert response_1.created_at is None @@ -423,21 +445,34 @@ def test_match_headers(mock_session): """With match_headers, requests with 
different headers should have different cache keys""" mock_session.cache.match_headers = True - headers_list = [{'Accept': 'text/json'}, {'Accept': 'text/xml'}, {'Accept': 'custom'}, None] + headers_list = [{'Accept': 'application/json'}, {'Accept': 'text/xml'}, {'Accept': 'custom'}, None] for headers in headers_list: assert mock_session.get(MOCKED_URL, headers=headers).from_cache is False assert mock_session.get(MOCKED_URL, headers=headers).from_cache is True -def test_match_headers_normalize(mock_session): +def test_match_headers__normalize(mock_session): """With match_headers, the same headers (in any order) should have the same cache key""" mock_session.cache.match_headers = True - headers = {'Accept': 'text/json', 'Custom': 'abc'} - reversed_headers = {'Custom': 'abc', 'Accept': 'text/json'} + headers = {'Accept': 'application/json', 'Custom': 'abc'} + reversed_headers = {'Custom': 'abc', 'Accept': 'application/json'} assert mock_session.get(MOCKED_URL, headers=headers).from_cache is False assert mock_session.get(MOCKED_URL, headers=reversed_headers).from_cache is True +def test_match_headers__list(mock_session): + """match_headers can optionally be a list of specific headers to include""" + mock_session.cache.match_headers = ['Accept'] + headers_1 = {'Accept': 'application/json', 'User-Agent': 'qutebrowser'} + headers_2 = {'Accept': 'application/json', 'User-Agent': 'Firefox'} + headers_3 = {'Accept': 'text/plain', 'User-Agent': 'qutebrowser'} + + assert mock_session.get(MOCKED_URL, headers=headers_1).from_cache is False + assert mock_session.get(MOCKED_URL, headers=headers_1).from_cache is True + assert mock_session.get(MOCKED_URL, headers=headers_2).from_cache is True + assert mock_session.get(MOCKED_URL, headers=headers_3).from_cache is False + + def test_include_get_headers(): """include_get_headers is aliased to match_headers for backwards-compatibility""" session = CachedSession(include_get_headers=True, backend='memory') |