summaryrefslogtreecommitdiff
path: root/requests_cache/models/response.py
blob: e704f0319c3bbed9b70d973eb96a922b2056cfb9 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
from __future__ import annotations

from datetime import datetime, timedelta, timezone
from logging import getLogger
from time import time
from typing import TYPE_CHECKING, Dict, List, Optional, Union

import attr
from attr import define, field
from requests import PreparedRequest, Response
from requests.cookies import RequestsCookieJar
from requests.structures import CaseInsensitiveDict

from ..policy.expiration import ExpirationTime, get_expiration_datetime
from . import CachedHTTPResponse, CachedRequest, RichMixin

if TYPE_CHECKING:
    from ..policy.actions import CacheActions

# strftime pattern for human-readable timestamps; format used for __str__ only
DATETIME_FORMAT = '%Y-%m-%d %H:%M:%S %Z'
# Type of the ``_decoded_content`` field on CachedResponse: a dict, a string, or nothing
DecodedContent = Union[Dict, str, None]
# Module-level logger, named after this module
logger = getLogger(__name__)


@define(auto_attribs=False, repr=False, slots=False)
class BaseResponse(Response):
    """Wrapper class for responses returned by :py:class:`.CachedSession`. This mainly exists to
    provide type hints for extra cache-related attributes that are added to non-cached responses.

    Both properties defined here return ``False``; :py:class:`.CachedResponse` overrides them.
    """

    # Creation time; defaults to the current naive UTC time (``datetime.utcnow``) at construction
    created_at: datetime = field(factory=datetime.utcnow)
    # Expiration time, or ``None`` when no expiration is set
    expires: Optional[datetime] = field(default=None)
    cache_key: str = ''  # Not serialized; set by BaseCache.get_response()
    revalidated: bool = False  # Not serialized; set by CacheActions.update_revalidated_response()

    @property
    def from_cache(self) -> bool:
        """Always ``False`` in this base class"""
        return False

    @property
    def is_expired(self) -> bool:
        """Always ``False`` in this base class"""
        return False


@define(auto_attribs=False, repr=False, slots=False)
class OriginalResponse(BaseResponse):
    """Wrapper class for non-cached responses returned by :py:class:`.CachedSession`"""

    @classmethod
    def wrap_response(cls, response: Response, actions: 'CacheActions'):
        """Turn a plain response into an instance of this class, in place

        Args:
            response: Response to modify. If it is already an instance of this class, it is
                returned untouched.
            actions: Source of ``skip_write``, ``expires`` and ``cache_key`` for this response

        Returns:
            The same ``response`` object that was passed in
        """
        # Nothing to do for a response that has already been wrapped
        if isinstance(response, cls):
            return response

        response.__class__ = cls
        # expires and cache_key are only meaningful if the response was written to the cache
        was_written = not actions.skip_write
        response.expires = actions.expires if was_written else None  # type: ignore
        response.cache_key = actions.cache_key if was_written else None  # type: ignore
        response.created_at = datetime.utcnow()  # type: ignore
        return response


@define(auto_attribs=False, repr=False, slots=False)
class CachedResponse(RichMixin, BaseResponse):
    """A class that emulates :py:class:`requests.Response`, optimized for serialization.

    Use :py:meth:`from_response` to build an instance from an existing response. It copies the
    attributes named in ``requests.Response.__attrs__`` and converts the request, raw response,
    next request, and redirect history into their cached counterparts.
    """

    _content: bytes = field(default=None)
    _decoded_content: DecodedContent = field(default=None)
    _next: Optional[CachedRequest] = field(default=None)  # Exposed (prepared) via the ``next`` property
    cookies: RequestsCookieJar = field(factory=RequestsCookieJar)
    created_at: datetime = field(default=None)  # Filled in by __attrs_post_init__ if not provided
    elapsed: timedelta = field(factory=timedelta)
    encoding: str = field(default=None)
    expires: Optional[datetime] = field(default=None)
    headers: CaseInsensitiveDict = field(factory=CaseInsensitiveDict)
    history: List['CachedResponse'] = field(factory=list)  # type: ignore
    raw: CachedHTTPResponse = None  # type: ignore  # Not serialized; populated from CachedResponse attrs
    reason: str = field(default=None)
    request: CachedRequest = field(factory=CachedRequest)  # type: ignore
    status_code: int = field(default=0)
    url: str = field(default=None)

    def __attrs_post_init__(self):
        """Fill in ``created_at`` and ``raw`` if they were not set during initialization"""
        # Not using created_at field default due to possible bug on Windows with omit_if_default
        self.created_at = self.created_at or datetime.utcnow()
        # Re-initialize raw (urllib3) response after deserialization
        self.raw = self.raw or CachedHTTPResponse.from_cached_response(self)

    @classmethod
    def from_response(cls, response: Response, **kwargs):
        """Create a CachedResponse based on an original Response or another CachedResponse object

        Args:
            response: Response to copy. A :py:class:`CachedResponse` is copied with
                ``attr.evolve``; any other response is copied attribute by attribute.
            kwargs: Passed to ``attr.evolve`` or to the class constructor, respectively

        Returns:
            A new :py:class:`CachedResponse` whose redirect history has been converted
        """
        if isinstance(response, CachedResponse):
            obj = attr.evolve(response, **kwargs)
            obj._convert_redirects()
            return obj

        obj = cls(**kwargs)

        # Copy basic attributes
        for k in Response.__attrs__:
            setattr(obj, k, getattr(response, k, None))

        # Store request, raw response, and next response (if it's a redirect response)
        obj.raw = CachedHTTPResponse.from_response(response)
        obj.request = CachedRequest.from_request(response.request)
        obj._next = CachedRequest.from_request(response.next) if response.next else None

        # Store response body, which will have been read & decoded by requests.Response by now
        obj._content = response.content

        obj._convert_redirects()
        return obj

    def _convert_redirects(self):
        """Convert redirect history, if any; avoid recursion by not copying redirects of redirects"""
        # A response that is itself a redirect keeps no history of its own
        if self.is_redirect:
            self.history = []
            return
        self.history = [self.from_response(redirect) for redirect in self.history]

    @property
    def _content_consumed(self) -> bool:
        """For compatibility with requests.Response; will always be True for a cached response"""
        return True

    @_content_consumed.setter
    def _content_consumed(self, value: bool):
        # Assignments are accepted but ignored; the getter always returns True
        pass

    @property
    def expires_delta(self) -> Optional[int]:
        """Get time to expiration in seconds (rounded to the nearest second)

        Returns:
            ``None`` if no expiration is set; otherwise ``expires`` minus the current
            ``datetime.utcnow()``, in whole seconds
        """
        if self.expires is None:
            return None
        delta = self.expires - datetime.utcnow()
        return round(delta.total_seconds())

    @property
    def expires_unix(self) -> Optional[int]:
        """Get expiration time as a Unix timestamp, or ``None`` if no expiration is set"""
        # Derived from the remaining seconds rather than from ``expires`` directly
        seconds = self.expires_delta
        return round(time() + seconds) if seconds is not None else None

    @property
    def from_cache(self) -> bool:
        """Always ``True`` for a cached response"""
        return True

    @property
    def is_expired(self) -> bool:
        """Determine if this cached response is expired

        A response without an expiration time (``expires is None``) is never expired.
        """
        return self.expires is not None and datetime.utcnow() >= self.expires

    def is_older_than(self, older_than: ExpirationTime) -> bool:
        """Determine if this cached response is older than the given time

        Args:
            older_than: Value converted to a datetime by ``get_expiration_datetime`` with
                ``negative_delta=True``

        Returns:
            ``True`` if ``created_at`` is earlier than the converted datetime; ``False`` otherwise,
            including when the conversion yields ``None``
        """
        older_than = get_expiration_datetime(older_than, negative_delta=True)
        return older_than is not None and self.created_at < older_than

    @property
    def next(self) -> Optional[PreparedRequest]:
        """Returns a PreparedRequest for the next request in a redirect chain, if there is one."""
        return self._next.prepare() if self._next else None

    def reset_expiration(self, expire_after: ExpirationTime):
        """Set a new expiration for this response

        Args:
            expire_after: Value converted to a datetime by ``get_expiration_datetime``

        Returns:
            Whether the response is expired according to the new expiration
        """
        self.expires = get_expiration_datetime(expire_after)
        return self.is_expired

    @property
    def size(self) -> int:
        """Get the size of the response body in bytes (0 if the body is empty or missing)"""
        return len(self.content) if self.content else 0

    def __getstate__(self):
        """Override pickling behavior from ``requests.Response.__getstate__``"""
        # The instance ``__dict__`` itself is used as the pickled state
        return self.__dict__

    def __setstate__(self, state):
        """Override pickling behavior from ``requests.Response.__setstate__``"""
        # Restore each saved entry as an attribute
        for name, value in state.items():
            setattr(self, name, value)

    def __str__(self):
        """Summarize status code, creation and expiration times, freshness, body size, and request"""
        return (
            f'<CachedResponse [{self.status_code}]: '
            f'created: {format_datetime(self.created_at)}, '
            f'expires: {format_datetime(self.expires)} ({"stale" if self.is_expired else "fresh"}), '
            f'size: {format_file_size(self.size)}, request: {self.request}>'
        )


def format_datetime(value: Optional[datetime]) -> str:
    """Render a datetime as text in the local time zone

    Args:
        value: Datetime to format. A naive datetime is treated as UTC.

    Returns:
        The datetime converted to the local time zone and formatted with ``DATETIME_FORMAT``,
        or ``N/A`` if no value was given
    """
    if value:
        # Attach UTC to naive values so that astimezone() converts from UTC, not from local time
        aware = value if value.tzinfo is not None else value.replace(tzinfo=timezone.utc)
        return aware.astimezone().strftime(DATETIME_FORMAT)
    return "N/A"


def format_file_size(n_bytes: int) -> str:
    """Convert a file size in bytes into a human-readable format

    Args:
        n_bytes: Size in bytes; a falsy value (``None`` or 0) is treated as 0

    Returns:
        A whole number of ``bytes`` for sizes below 1024, otherwise a value with two decimal
        places in ``KiB``, ``MiB`` or ``GiB``. ``GiB`` is the largest unit used.
    """
    size = float(n_bytes or 0)

    # Bytes are the only unit shown without decimal places
    if size < 1024:
        return f'{int(size)} bytes'

    for unit in ('KiB', 'MiB'):
        size /= 1024
        if size < 1024:
            return f'{size:.2f} {unit}'

    # Anything left over is reported in GiB, however large
    return f'{size / 1024:.2f} GiB'