summaryrefslogtreecommitdiff
path: root/pint/_vendor/flexcache.py
blob: 7b3969846fc372fd672ff3907b0330e78c18a345 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
"""
    flexcache.flexcache
    ~~~~~~~~~~~~~~~~~~~

    Classes for persistent caching and invalidating cached objects,
    which are built from a source object and a (potentially expensive)
    conversion function.

    Header
    ------
    Contains summary information about the source object that will
    be saved together with the cached file.

    Its capabilities are divided into three groups:
    - The Header itself which contains the information that will
      be saved alongside the cached file
    - The Naming logic which indicates how the cached filename is
      built.
    - The Invalidation logic which indicates whether a cached file
      is valid (i.e. truthful to the actual source file).

    DiskCache
    ---------
    Saves and loads to the cache a transformed version of a source object.

    :copyright: 2022 by flexcache Authors, see AUTHORS for more details.
    :license: BSD, see LICENSE for more details.
"""

from __future__ import annotations

import abc
import hashlib
import json
import pathlib
import pickle
import platform
import typing
from dataclasses import asdict as dc_asdict
from dataclasses import dataclass
from dataclasses import fields as dc_fields
from typing import Any, Iterable

#########
# Header
#########


@dataclass(frozen=True)
class BaseHeader(abc.ABC):
    """Header with no information except the converter_id

    All header files must inherit from this.
    """

    # The actual source of the data (or a reference to it)
    # that is going to be converted.
    source: Any

    # An identification of the function that is used to
    # convert the source into the result object.
    converter_id: str

    # Expected type of `source`; subclasses narrow this (e.g. to
    # pathlib.Path or tuple).  Deliberately left unannotated so the
    # dataclass machinery does not turn it into a field.
    _source_type = object

    def __post_init__(self):
        # TODO: In more modern python versions it would be
        # good to check for things like tuple[str].
        if not isinstance(self.source, self._source_type):
            raise TypeError(
                f"Source must be {self._source_type}, not {type(self.source)}"
            )

    def for_cache_name(self) -> typing.Generator[bytes, None, None]:
        """The basename for the cache file is a hash hexdigest
        built by feeding this collection of values.

        A class can provide its own set of values by rewriting
        `_for_cache_name`.
        """
        for el in self._for_cache_name():
            # Normalize str values to bytes so a hashlib hasher accepts them.
            if isinstance(el, str):
                yield el.encode("utf-8")
            else:
                yield el

    def _for_cache_name(self) -> typing.Generator[bytes | str, None, None]:
        """The basename for the cache file is a hash hexdigest
        built by feeding this collection of values.

        Change the behavior by writing your own.
        """
        yield self.converter_id

    @abc.abstractmethod
    def is_valid(self, cache_path: pathlib.Path) -> bool:
        """Return True if cache_path is a cached version
        of the source_object represented by this header.
        """

@dataclass(frozen=True)
class BasicPythonHeader(BaseHeader):
    """Header with basic Python information."""

    # NOTE: these defaults are evaluated once, at import time.  That is
    # fine here because the platform and interpreter cannot change within
    # a running process.
    system: str = platform.system()
    python_implementation: str = platform.python_implementation()
    python_version: str = platform.python_version()


#####################
# Invalidation logic
#####################


class InvalidateByExist:
    """The cached file is valid if it exists.

    Unlike the MTime-based invalidation classes, no modification-time
    (freshness) comparison is performed.  The previous docstring, copied
    from InvalidateByPathMTime, incorrectly claimed one.
    """

    def is_valid(self, cache_path: pathlib.Path) -> bool:
        # Existence is the only criterion.
        return cache_path.exists()


class InvalidateByPathMTime(abc.ABC):
    """Invalidation mixin: a cached file is valid when it exists and its
    modification time is more recent than that of the source file.
    """

    @property
    @abc.abstractmethod
    def source_path(self) -> pathlib.Path:
        """Path of the source file the cached file was generated from."""

    def is_valid(self, cache_path: pathlib.Path):
        # A missing cache file is always invalid; bail out before stat-ing.
        if not cache_path.exists():
            return False
        # Valid only when the cache is strictly newer than its source.
        return cache_path.stat().st_mtime > self.source_path.stat().st_mtime


class InvalidateByMultiPathsMtime(abc.ABC):
    """The cached file is valid if it exists and is newer than the newest source file."""

    @property
    @abc.abstractmethod
    def source_paths(self) -> tuple[pathlib.Path, ...]:
        # Fixed annotation: the original declared a single pathlib.Path,
        # but this must be an iterable of paths (it is iterated below).
        ...

    @property
    def newest_date(self) -> float:
        """Most recent modification time among the source files (0 if none)."""
        return max((t.stat().st_mtime for t in self.source_paths), default=0)

    def is_valid(self, cache_path: pathlib.Path) -> bool:
        return cache_path.exists() and cache_path.stat().st_mtime > self.newest_date


###############
# Naming logic
###############


class NameByFields:
    """Naming mixin: feed every dataclass field of the Header into the
    cache name, except the source object and the converter_id (which the
    base class already contributes).
    """

    def _for_cache_name(self):
        yield from super()._for_cache_name()
        skip = ("source", "converter_id")
        yield from (
            getattr(self, f.name) for f in dc_fields(self) if f.name not in skip
        )


class NameByFileContent:
    """Given a file source object, the name is built from its content."""

    # Narrow BaseHeader's source validation to filesystem paths.
    _source_type = pathlib.Path

    @property
    def source_path(self) -> pathlib.Path:
        # The source object is the path itself.
        return self.source

    def _for_cache_name(self):
        yield from super()._for_cache_name()
        # Feed the raw bytes: any change to the file content changes the name.
        content = self.source_path.read_bytes()
        yield content

    @classmethod
    def from_string(cls, s: str, converter_id: str):
        """Alternate constructor taking the source path as a plain string."""
        path = pathlib.Path(s)
        return cls(path, converter_id)


@dataclass(frozen=True)
class NameByObj:
    """Given a pickable source object, the name is built from its content."""

    # Pickle protocol used to serialize the source; being a field, it is
    # part of the frozen header state.
    pickle_protocol: int = pickle.HIGHEST_PROTOCOL

    def _for_cache_name(self):
        yield from super()._for_cache_name()
        # Serialize the whole source object so any change renames the cache.
        payload = pickle.dumps(self.source, protocol=self.pickle_protocol)
        yield payload


class NameByPath:
    """Given a file source object, the name is built from its resolved path."""

    # Narrow BaseHeader's source validation to filesystem paths.
    _source_type = pathlib.Path

    @property
    def source_path(self) -> pathlib.Path:
        # The source object is the path itself.
        return self.source

    def _for_cache_name(self):
        yield from super()._for_cache_name()
        # Resolving makes the name stable across equivalent relative spellings.
        resolved = self.source_path.resolve()
        yield bytes(resolved)

    @classmethod
    def from_string(cls, s: str, converter_id: str):
        """Alternate constructor taking the source path as a plain string."""
        path = pathlib.Path(s)
        return cls(path, converter_id)


class NameByMultiPaths:
    """Given multiple file source objects, the name is built from their
    resolved paths in ascending order.
    """

    # The source must be a tuple of paths.
    _source_type = tuple

    @property
    def source_paths(self) -> tuple[pathlib.Path]:
        return self.source

    def _for_cache_name(self):
        yield from super()._for_cache_name()
        # Sort the resolved paths so the name is independent of input order.
        resolved = [bytes(p.resolve()) for p in self.source_paths]
        yield from sorted(resolved)

    @classmethod
    def from_strings(cls, ss: Iterable[str], converter_id: str):
        """Alternate constructor taking the source paths as plain strings."""
        paths = tuple(pathlib.Path(s) for s in ss)
        return cls(paths, converter_id)


class NameByHashIter:
    """Given multiple hashes, the name is built from them in ascending order."""

    # The source must be a tuple of hash values.
    _source_type = tuple

    def _for_cache_name(self):
        yield from super()._for_cache_name()
        # Sort so the resulting name does not depend on input order.
        yield from sorted(self.source)


class DiskCache:
    """A class to store and load cached objects to disk, which
    are built from a source object and conversion function.

    The basename for the cache file is a hash hexdigest
    built by feeding a collection of values determined by
    the Header object.

    Parameters
    ----------
    cache_folder
        indicates where the cache files will be saved.
    """

    # Maps source object classes to the header class (or factory callable)
    # used to build headers for sources of that class.
    _header_classes: dict[type, type[BaseHeader]] | None = None

    # Hasher object constructor (e.g. a member of hashlib)
    # must implement update(b: bytes) and hexdigest() methods
    _hasher = hashlib.sha1

    # If True, for each cached file the header is also stored.
    _store_header: bool = True

    def __init__(self, cache_folder: str | pathlib.Path):
        self.cache_folder = pathlib.Path(cache_folder)
        self.cache_folder.mkdir(parents=True, exist_ok=True)
        # Copy the class-level mapping: the previous code aliased it, so
        # register_header_class on one instance mutated state shared by
        # every other instance of a subclass that declares _header_classes.
        self._header_classes = dict(self._header_classes or {})

    def register_header_class(self, object_class: type, header_class: type[BaseHeader]):
        """Register the header class to use for a given source object class."""
        self._header_classes[object_class] = header_class

    def cache_stem_for(self, header: BaseHeader) -> str:
        """Generate a hash representing the basename of a memoized file
        for a given header.

        The naming strategy is defined by the header class used.
        """
        hd = self._hasher()
        for value in header.for_cache_name():
            hd.update(value)
        return hd.hexdigest()

    def cache_path_for(self, header: BaseHeader) -> pathlib.Path:
        """Generate a Path representing the location of a memoized file
        for a given filepath or object.

        The naming strategy is defined by the header class used.
        """
        h = self.cache_stem_for(header)
        return self.cache_folder.joinpath(h).with_suffix(".pickle")

    def _get_header_class(self, source_object) -> type[BaseHeader]:
        """Return the header class registered for the source object's type.

        Uses isinstance, so subclasses of a registered class also match.
        Raises TypeError when no registered class matches.
        """
        for k, v in self._header_classes.items():
            if isinstance(source_object, k):
                return v
        raise TypeError(f"Cannot find header class for {type(source_object)}")

    def load(self, source_object, converter=None, pass_hash=False) -> tuple[Any, str]:
        """Given a source_object, return the converted value stored
        in the cache together with the cached path stem

        When the cache is not found:
        - If a converter callable is given, use it on the source
          object, store the result in the cache and return it.
        - Return None, otherwise.

        Two signatures for the converter are valid:
        - source_object -> transformed object
        - (source_object, cached_path_stem) -> transformed_object

        To use the second one, use `pass_hash=True`.

        If you want to do the conversion yourself outside this class,
        use the converter argument to provide a name for it. This is
        important as the cached_path_stem depends on the converter name.
        """
        header_class = self._get_header_class(source_object)

        if isinstance(converter, str):
            # A string converter only names the conversion; the caller
            # performs the conversion elsewhere.
            converter_id = converter
            converter = None
        else:
            converter_id = getattr(converter, "__name__", "")

        header = header_class(source_object, converter_id)

        cache_path = self.cache_path_for(header)

        converted_object = self.rawload(header, cache_path)

        # Compare against None rather than truthiness so that falsy cached
        # values (e.g. [] or 0) still count as cache hits; the previous
        # truthiness test re-ran the converter for them on every load.
        if converted_object is not None:
            return converted_object, cache_path.stem
        if converter is None:
            return None, cache_path.stem

        if pass_hash:
            converted_object = converter(source_object, cache_path.stem)
        else:
            converted_object = converter(source_object)

        self.rawsave(header, converted_object, cache_path)

        return converted_object, cache_path.stem

    def save(self, converted_object, source_object, converter_id="") -> str:
        """Given a converted_object and its corresponding source_object,
        store it in the cache and return the cached_path_stem.
        """

        header_class = self._get_header_class(source_object)
        header = header_class(source_object, converter_id)
        return self.rawsave(header, converted_object, self.cache_path_for(header)).stem

    def rawload(
        self, header: BaseHeader, cache_path: pathlib.Path | None = None
    ) -> Any | None:
        """Load the converted_object from the cache if it is valid.

        The invalidating strategy is defined by the header class used.

        The cache_path is optional, it will be calculated from the header
        if not given.
        """
        if cache_path is None:
            cache_path = self.cache_path_for(header)

        if header.is_valid(cache_path):
            with cache_path.open(mode="rb") as fi:
                return pickle.load(fi)

    def rawsave(
        self, header: BaseHeader, converted, cache_path: pathlib.Path | None = None
    ) -> pathlib.Path:
        """Save the converted object (in pickle format) and
        its header (in json format) to the cache folder.

        The cache_path is optional, it will be calculated from the header
        if not given.
        """
        if cache_path is None:
            cache_path = self.cache_path_for(header)

        if self._store_header:
            # Header values are stringified so any field type serializes.
            with cache_path.with_suffix(".json").open("w", encoding="utf-8") as fo:
                json.dump({k: str(v) for k, v in dc_asdict(header).items()}, fo)
        with cache_path.open(mode="wb") as fo:
            pickle.dump(converted, fo)
        return cache_path


class DiskCacheByHash(DiskCache):
    """Convenience class used for caching conversions that take a path,
    naming by hashing its content.
    """

    # Header combining: name derived from file content, cache valid as
    # soon as the cached file exists (no freshness check).
    @dataclass(frozen=True)
    class Header(NameByFileContent, InvalidateByExist, BaseHeader):
        pass

    # Accept both Path and plain-string sources; strings are routed
    # through the Header.from_string alternate constructor.
    _header_classes = {
        pathlib.Path: Header,
        str: Header.from_string,
    }


class DiskCacheByMTime(DiskCache):
    """Convenience class used for caching conversions that take a path,
    naming by hashing its full path and invalidating by the file
    modification time.
    """

    # Header combining: name derived from the resolved path, cache valid
    # only while newer than the source file's modification time.
    @dataclass(frozen=True)
    class Header(NameByPath, InvalidateByPathMTime, BaseHeader):
        pass

    # Accept both Path and plain-string sources; strings are routed
    # through the Header.from_string alternate constructor.
    _header_classes = {
        pathlib.Path: Header,
        str: Header.from_string,
    }