summaryrefslogtreecommitdiff
path: root/pint/facets/context/objects.py
diff options
context:
space:
mode:
Diffstat (limited to 'pint/facets/context/objects.py')
-rw-r--r--pint/facets/context/objects.py97
1 files changed, 72 insertions, 25 deletions
diff --git a/pint/facets/context/objects.py b/pint/facets/context/objects.py
index 38d8805..9517821 100644
--- a/pint/facets/context/objects.py
+++ b/pint/facets/context/objects.py
@@ -10,12 +10,32 @@ from __future__ import annotations
import weakref
from collections import ChainMap, defaultdict
-from typing import Any
+from typing import Any, Callable, Protocol, Generic
from collections.abc import Iterable
-from ...facets.plain import UnitDefinition
+from ...facets.plain import UnitDefinition, PlainQuantity, PlainUnit, MagnitudeT
from ...util import UnitsContainer, to_units_container
from .definitions import ContextDefinition
+from ..._typing import Magnitude
+
+
+class Transformation(Protocol):
+ def __call__(self, value: Magnitude, **kwargs: Any) -> Magnitude:
+ ...
+
+
+from ..._typing import UnitLike
+
+ToBaseFunc = Callable[[UnitsContainer], UnitsContainer]
+SrcDst = tuple[UnitsContainer, UnitsContainer]
+
+
+class ContextQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]):
+ pass
+
+
+class ContextUnit(PlainUnit):
+ pass
class Context:
@@ -75,24 +95,27 @@ class Context:
aliases: tuple[str] = tuple(),
defaults: dict[str, Any] | None = None,
) -> None:
- self.name = name
- self.aliases = aliases
+ self.name: str | None = name
+ self.aliases: tuple[str] = aliases
#: Maps (src, dst) -> transformation function
- self.funcs = {}
+ self.funcs: dict[SrcDst, Transformation] = {}
#: Maps defaults variable names to values
- self.defaults = defaults or {}
+ self.defaults: dict[str, Any] = defaults or {}
# Store Definition objects that are context-specific
- self.redefinitions = []
+ # TODO: narrow type this if possible.
+ self.redefinitions: list[Any] = []
# Flag set to True by the Registry the first time the context is enabled
self.checked = False
#: Maps (src, dst) -> self
#: Used as a convenience dictionary to be composed by ContextChain
- self.relation_to_context = weakref.WeakValueDictionary()
+ self.relation_to_context: weakref.WeakValueDictionary[
+ SrcDst, Context
+ ] = weakref.WeakValueDictionary()
@classmethod
def from_context(cls, context: Context, **defaults: Any) -> Context:
@@ -125,13 +148,22 @@ class Context:
@classmethod
def from_lines(
- cls, lines: Iterable[str], to_base_func=None, non_int_type: type = float
+ cls,
+ lines: Iterable[str],
+ to_base_func: ToBaseFunc | None = None,
+ non_int_type: type = float,
) -> Context:
- cd = ContextDefinition.from_lines(lines, non_int_type)
- return cls.from_definition(cd, to_base_func)
+ context_definition = ContextDefinition.from_lines(lines, non_int_type)
+
+ if context_definition is None:
+ raise ValueError(f"Could not define Context from from {lines}")
+
+ return cls.from_definition(context_definition, to_base_func)
@classmethod
- def from_definition(cls, cd: ContextDefinition, to_base_func=None) -> Context:
+ def from_definition(
+ cls, cd: ContextDefinition, to_base_func: ToBaseFunc | None = None
+ ) -> Context:
ctx = cls(cd.name, cd.aliases, cd.defaults)
for definition in cd.redefinitions:
@@ -139,6 +171,7 @@ class Context:
for relation in cd.relations:
try:
+ # TODO: check to_base_func. Is it a good API idea?
if to_base_func:
src = to_base_func(relation.src)
dst = to_base_func(relation.dst)
@@ -154,14 +187,16 @@ class Context:
return ctx
- def add_transformation(self, src, dst, func) -> None:
+ def add_transformation(
+ self, src: UnitLike, dst: UnitLike, func: Transformation
+ ) -> None:
"""Add a transformation function to the context."""
_key = self.__keytransform__(src, dst)
self.funcs[_key] = func
self.relation_to_context[_key] = self
- def remove_transformation(self, src, dst) -> None:
+ def remove_transformation(self, src: UnitLike, dst: UnitLike) -> None:
"""Add a transformation function to the context."""
_key = self.__keytransform__(src, dst)
@@ -169,14 +204,17 @@ class Context:
del self.relation_to_context[_key]
@staticmethod
- def __keytransform__(src, dst) -> tuple[UnitsContainer, UnitsContainer]:
+ def __keytransform__(src: UnitLike, dst: UnitLike) -> SrcDst:
return to_units_container(src), to_units_container(dst)
- def transform(self, src, dst, registry, value):
+ def transform(
+ self, src: UnitLike, dst: UnitLike, registry: Any, value: Magnitude
+ ) -> Magnitude:
"""Transform a value."""
_key = self.__keytransform__(src, dst)
- return self.funcs[_key](registry, value, **self.defaults)
+ func = self.funcs[_key]
+ return func(registry, value, **self.defaults)
def redefine(self, definition: str) -> None:
"""Override the definition of a unit in the registry.
@@ -202,7 +240,13 @@ class Context:
def hashable(
self,
- ) -> tuple[str | None, tuple[str, ...], frozenset, frozenset, tuple]:
+ ) -> tuple[
+ str | None,
+ tuple[str],
+ frozenset[tuple[SrcDst, int]],
+ frozenset[tuple[str, Any]],
+ tuple[Any],
+ ]:
"""Generate a unique hashable and comparable representation of self, which can
be used as a key in a dict. This class cannot define ``__hash__`` because it is
mutable, and the Python interpreter does cache the output of ``__hash__``.
@@ -220,18 +264,18 @@ class Context:
)
-class ContextChain(ChainMap):
+class ContextChain(ChainMap[SrcDst, Context]):
"""A specialized ChainMap for contexts that simplifies finding rules
to transform from one dimension to another.
"""
def __init__(self):
super().__init__()
- self.contexts = []
+ self.contexts: list[Context] = []
self.maps.clear() # Remove default empty map
- self._graph = None
+ self._graph: dict[SrcDst, set[UnitsContainer]] | None = None
- def insert_contexts(self, *contexts):
+ def insert_contexts(self, *contexts: Context):
"""Insert one or more contexts in reversed order the chained map.
(A rule in last context will take precedence)
@@ -243,7 +287,7 @@ class ContextChain(ChainMap):
self.maps = [ctx.relation_to_context for ctx in reversed(contexts)] + self.maps
self._graph = None
- def remove_contexts(self, n: int = None):
+ def remove_contexts(self, n: int | None = None):
"""Remove the last n inserted contexts from the chain.
Parameters
@@ -257,7 +301,7 @@ class ContextChain(ChainMap):
self._graph = None
@property
- def defaults(self):
+ def defaults(self) -> dict[str, Any]:
for ctx in self.values():
return ctx.defaults
return {}
@@ -271,7 +315,10 @@ class ContextChain(ChainMap):
self._graph[fr_].add(to_)
return self._graph
- def transform(self, src, dst, registry, value):
+ # TODO: type registry
+ def transform(
+ self, src: UnitsContainer, dst: UnitsContainer, registry: Any, value: Magnitude
+ ):
"""Transform the value, finding the rule in the chained context.
(A rule in last context will take precedence)
"""