diff options
Diffstat (limited to 'pint/facets/context/objects.py')
-rw-r--r-- | pint/facets/context/objects.py | 97 |
1 files changed, 72 insertions, 25 deletions
diff --git a/pint/facets/context/objects.py b/pint/facets/context/objects.py index 38d8805..9517821 100644 --- a/pint/facets/context/objects.py +++ b/pint/facets/context/objects.py @@ -10,12 +10,32 @@ from __future__ import annotations import weakref from collections import ChainMap, defaultdict -from typing import Any +from typing import Any, Callable, Protocol, Generic from collections.abc import Iterable -from ...facets.plain import UnitDefinition +from ...facets.plain import UnitDefinition, PlainQuantity, PlainUnit, MagnitudeT from ...util import UnitsContainer, to_units_container from .definitions import ContextDefinition +from ..._typing import Magnitude + + +class Transformation(Protocol): + def __call__(self, value: Magnitude, **kwargs: Any) -> Magnitude: + ... + + +from ..._typing import UnitLike + +ToBaseFunc = Callable[[UnitsContainer], UnitsContainer] +SrcDst = tuple[UnitsContainer, UnitsContainer] + + +class ContextQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]): + pass + + +class ContextUnit(PlainUnit): + pass class Context: @@ -75,24 +95,27 @@ class Context: aliases: tuple[str] = tuple(), defaults: dict[str, Any] | None = None, ) -> None: - self.name = name - self.aliases = aliases + self.name: str | None = name + self.aliases: tuple[str] = aliases #: Maps (src, dst) -> transformation function - self.funcs = {} + self.funcs: dict[SrcDst, Transformation] = {} #: Maps defaults variable names to values - self.defaults = defaults or {} + self.defaults: dict[str, Any] = defaults or {} # Store Definition objects that are context-specific - self.redefinitions = [] + # TODO: narrow type this if possible. + self.redefinitions: list[Any] = [] # Flag set to True by the Registry the first time the context is enabled self.checked = False #: Maps (src, dst) -> self #: Used as a convenience dictionary to be composed by ContextChain - self.relation_to_context = weakref.WeakValueDictionary() + self.relation_to_context: weakref.WeakValueDictionary[ + SrcDst, Context + ] = weakref.WeakValueDictionary() @classmethod def from_context(cls, context: Context, **defaults: Any) -> Context: @@ -125,13 +148,22 @@ class Context: @classmethod def from_lines( - cls, lines: Iterable[str], to_base_func=None, non_int_type: type = float + cls, + lines: Iterable[str], + to_base_func: ToBaseFunc | None = None, + non_int_type: type = float, ) -> Context: - cd = ContextDefinition.from_lines(lines, non_int_type) - return cls.from_definition(cd, to_base_func) + context_definition = ContextDefinition.from_lines(lines, non_int_type) + + if context_definition is None: + raise ValueError(f"Could not define Context from from {lines}") + + return cls.from_definition(context_definition, to_base_func) @classmethod - def from_definition(cls, cd: ContextDefinition, to_base_func=None) -> Context: + def from_definition( + cls, cd: ContextDefinition, to_base_func: ToBaseFunc | None = None + ) -> Context: ctx = cls(cd.name, cd.aliases, cd.defaults) for definition in cd.redefinitions: @@ -139,6 +171,7 @@ class Context: for relation in cd.relations: try: + # TODO: check to_base_func. Is it a good API idea? if to_base_func: src = to_base_func(relation.src) dst = to_base_func(relation.dst) @@ -154,14 +187,16 @@ class Context: return ctx - def add_transformation(self, src, dst, func) -> None: + def add_transformation( + self, src: UnitLike, dst: UnitLike, func: Transformation + ) -> None: """Add a transformation function to the context.""" _key = self.__keytransform__(src, dst) self.funcs[_key] = func self.relation_to_context[_key] = self - def remove_transformation(self, src, dst) -> None: + def remove_transformation(self, src: UnitLike, dst: UnitLike) -> None: """Add a transformation function to the context.""" _key = self.__keytransform__(src, dst) @@ -169,14 +204,17 @@ class Context: del self.relation_to_context[_key] @staticmethod - def __keytransform__(src, dst) -> tuple[UnitsContainer, UnitsContainer]: + def __keytransform__(src: UnitLike, dst: UnitLike) -> SrcDst: return to_units_container(src), to_units_container(dst) - def transform(self, src, dst, registry, value): + def transform( + self, src: UnitLike, dst: UnitLike, registry: Any, value: Magnitude + ) -> Magnitude: """Transform a value.""" _key = self.__keytransform__(src, dst) - return self.funcs[_key](registry, value, **self.defaults) + func = self.funcs[_key] + return func(registry, value, **self.defaults) def redefine(self, definition: str) -> None: """Override the definition of a unit in the registry. @@ -202,7 +240,13 @@ class Context: def hashable( self, - ) -> tuple[str | None, tuple[str, ...], frozenset, frozenset, tuple]: + ) -> tuple[ + str | None, + tuple[str], + frozenset[tuple[SrcDst, int]], + frozenset[tuple[str, Any]], + tuple[Any], + ]: """Generate a unique hashable and comparable representation of self, which can be used as a key in a dict. This class cannot define ``__hash__`` because it is mutable, and the Python interpreter does cache the output of ``__hash__``. @@ -220,18 +264,18 @@ class Context: ) -class ContextChain(ChainMap): +class ContextChain(ChainMap[SrcDst, Context]): """A specialized ChainMap for contexts that simplifies finding rules to transform from one dimension to another. """ def __init__(self): super().__init__() - self.contexts = [] + self.contexts: list[Context] = [] self.maps.clear() # Remove default empty map - self._graph = None + self._graph: dict[SrcDst, set[UnitsContainer]] | None = None - def insert_contexts(self, *contexts): + def insert_contexts(self, *contexts: Context): """Insert one or more contexts in reversed order the chained map. (A rule in last context will take precedence) @@ -243,7 +287,7 @@ class ContextChain(ChainMap): self.maps = [ctx.relation_to_context for ctx in reversed(contexts)] + self.maps self._graph = None - def remove_contexts(self, n: int = None): + def remove_contexts(self, n: int | None = None): """Remove the last n inserted contexts from the chain. Parameters @@ -257,7 +301,7 @@ class ContextChain(ChainMap): self._graph = None @property - def defaults(self): + def defaults(self) -> dict[str, Any]: for ctx in self.values(): return ctx.defaults return {} @@ -271,7 +315,10 @@ class ContextChain(ChainMap): self._graph[fr_].add(to_) return self._graph - def transform(self, src, dst, registry, value): + # TODO: type registry + def transform( + self, src: UnitsContainer, dst: UnitsContainer, registry: Any, value: Magnitude + ): """Transform the value, finding the rule in the chained context. (A rule in last context will take precedence) """ |