summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorHernan Grecco <hgrecco@gmail.com>2023-05-04 17:21:35 -0300
committerHernan Grecco <hgrecco@gmail.com>2023-05-05 03:27:28 -0300
commit2f4125d0be4caa21a4ce2726bbc266cf265d822d (patch)
tree01d34e44e1252d953236a56bf1f03b9b4d314202
parent5643c32f7f2c886015df459f26b4d72adaee1207 (diff)
downloadpint-2f4125d0be4caa21a4ce2726bbc266cf265d822d.tar.gz
Large commit to make Pint more typing friendly
In this very large commit we tackle a few aspects of Pint that makes it difficult to do static typing. 1. Dynamic classes became static: Quantity and Unit are now (for the most part) static classes with a static inheritance. This allows mypy/pylance and other type checker to properly inspect them. 2. Added types through out all the code. (WIP) 3. Refactor minor parts of the code to make it more typing homogeneous. Catch a few potential bugs in the way. 4. Add several TODOs that need to be addressed in 0.23 5. Moved some group and system and context code out of the PlainRegistry 6. Moved certain specialized methods out of the PlainRegistry.
-rw-r--r--pint/_typing.py53
-rw-r--r--pint/compat.py14
-rw-r--r--pint/converters.py14
-rw-r--r--pint/delegates/txt_defparser/defparser.py2
-rw-r--r--pint/facets/__init__.py33
-rw-r--r--pint/facets/context/__init__.py4
-rw-r--r--pint/facets/context/definitions.py8
-rw-r--r--pint/facets/context/objects.py97
-rw-r--r--pint/facets/context/registry.py48
-rw-r--r--pint/facets/dask/__init__.py29
-rw-r--r--pint/facets/formatting/__init__.py9
-rw-r--r--pint/facets/formatting/objects.py6
-rw-r--r--pint/facets/formatting/registry.py21
-rw-r--r--pint/facets/group/__init__.py13
-rw-r--r--pint/facets/group/definitions.py2
-rw-r--r--pint/facets/group/objects.py37
-rw-r--r--pint/facets/group/registry.py50
-rw-r--r--pint/facets/measurement/__init__.py9
-rw-r--r--pint/facets/measurement/objects.py9
-rw-r--r--pint/facets/measurement/registry.py21
-rw-r--r--pint/facets/nonmultiplicative/__init__.py6
-rw-r--r--pint/facets/nonmultiplicative/objects.py10
-rw-r--r--pint/facets/nonmultiplicative/registry.py90
-rw-r--r--pint/facets/numpy/__init__.py4
-rw-r--r--pint/facets/numpy/quantity.py16
-rw-r--r--pint/facets/numpy/registry.py17
-rw-r--r--pint/facets/plain/__init__.py7
-rw-r--r--pint/facets/plain/definitions.py44
-rw-r--r--pint/facets/plain/qto.py386
-rw-r--r--pint/facets/plain/quantity.py493
-rw-r--r--pint/facets/plain/registry.py329
-rw-r--r--pint/facets/system/__init__.py4
-rw-r--r--pint/facets/system/definitions.py2
-rw-r--r--pint/facets/system/objects.py43
-rw-r--r--pint/facets/system/registry.py80
-rw-r--r--pint/formatting.py63
-rw-r--r--pint/registry.py59
-rw-r--r--pint/registry_helpers.py8
-rw-r--r--pint/testing.py6
-rw-r--r--pint/util.py55
40 files changed, 1368 insertions, 833 deletions
diff --git a/pint/_typing.py b/pint/_typing.py
index 65e355c..5177e78 100644
--- a/pint/_typing.py
+++ b/pint/_typing.py
@@ -1,9 +1,9 @@
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable, TypeVar, Union, Protocol
+from decimal import Decimal
+from fractions import Fraction
-# TODO: Remove when 3.11 becomes minimal version.
-Self = TypeVar("Self")
if TYPE_CHECKING:
from .facets.plain import PlainQuantity as Quantity
@@ -11,7 +11,7 @@ if TYPE_CHECKING:
from .util import UnitsContainer
-class PintScalar(Protocol):
+class ScalarProtocol(Protocol):
def __add__(self, other: Any) -> Any:
...
@@ -36,8 +36,20 @@ class PintScalar(Protocol):
def __pow__(self, other: Any, modulo: Any) -> Any:
...
+ def __gt__(self, other: Any) -> bool:
+ ...
+
+ def __lt__(self, other: Any) -> bool:
+ ...
+
+ def __ge__(self, other: Any) -> bool:
+ ...
-class PintArray(Protocol):
+ def __le__(self, other: Any) -> bool:
+ ...
+
+
+class ArrayProtocol(Protocol):
def __len__(self) -> int:
...
@@ -48,18 +60,41 @@ class PintArray(Protocol):
...
+HAS_NUMPY = False
+if TYPE_CHECKING:
+ from .compat import HAS_NUMPY
+
+if HAS_NUMPY:
+ from .compat import np
+
+ Scalar = Union[ScalarProtocol, float, int, Decimal, Fraction, np.number[Any]]
+ Array = Union[np.ndarray[Any, Any]]
+else:
+ Scalar = Union[ScalarProtocol, float, int, Decimal, Fraction]
+ Array = ArrayProtocol
+
+
# TODO: Change when Python 3.10 becomes minimal version.
-# Magnitude = PintScalar | PintArray
-Magnitude = Union[PintScalar, PintArray]
+Magnitude = Union[ScalarProtocol, ArrayProtocol]
-UnitLike = Union[str, "UnitsContainer", "Unit"]
+UnitLike = Union[str, dict[str, Scalar], "UnitsContainer", "Unit"]
QuantityOrUnitLike = Union["Quantity", UnitLike]
-Shape = tuple[int, ...]
+Shape = tuple[int]
-_MagnitudeType = TypeVar("_MagnitudeType")
S = TypeVar("S")
FuncType = Callable[..., Any]
F = TypeVar("F", bound=FuncType)
+
+
+# TODO: Improve or delete types
+QuantityArgument = Any
+
+T = TypeVar("T")
+
+
+class Handler(Protocol):
+ def __getitem__(self, item: type[T]) -> Callable[[T], None]:
+ ...
diff --git a/pint/compat.py b/pint/compat.py
index 7b48efa..727ff99 100644
--- a/pint/compat.py
+++ b/pint/compat.py
@@ -20,6 +20,16 @@ from collections.abc import Mapping
from typing import Any, NoReturn, Callable
from collections.abc import Generator, Iterable
+try:
+ from typing import TypeAlias # noqa
+except ImportError:
+ from typing_extensions import TypeAlias # noqa
+
+try:
+ from typing import Self # noqa
+except ImportError:
+ from typing_extensions import Self # noqa
+
def missing_dependency(
package: str, display_name: str | None = None
@@ -137,10 +147,10 @@ except ImportError:
HAS_UNCERTAINTIES = False
try:
- from babel import Locale as Loc
+ from babel import Locale
from babel import units as babel_units
- babel_parse = Loc.parse
+ babel_parse = Locale.parse
HAS_BABEL = hasattr(babel_units, "format_unit")
except ImportError:
diff --git a/pint/converters.py b/pint/converters.py
index 9494ad1..822b8a0 100644
--- a/pint/converters.py
+++ b/pint/converters.py
@@ -15,16 +15,18 @@ from dataclasses import fields as dc_fields
from typing import Any
-from ._typing import Self, Magnitude
+from ._typing import Magnitude
-from .compat import HAS_NUMPY, exp, log # noqa: F401
+from .compat import HAS_NUMPY, exp, log, Self # noqa: F401
@dataclass(frozen=True)
class Converter:
"""Base class for value converters."""
+ # list[type[Converter]]
_subclasses = []
+ # dict[frozenset[str], type[Converter]]
_param_names_to_subclass = {}
@property
@@ -41,21 +43,21 @@ class Converter:
def from_reference(self, value: Magnitude, inplace: bool = False) -> Magnitude:
return value
- def __init_subclass__(cls, **kwargs):
+ def __init_subclass__(cls, **kwargs: Any):
# Get constructor parameters
super().__init_subclass__(**kwargs)
cls._subclasses.append(cls)
@classmethod
- def get_field_names(cls, new_cls) -> frozenset[str]:
+ def get_field_names(cls, new_cls: type) -> frozenset[str]:
return frozenset(p.name for p in dc_fields(new_cls))
@classmethod
- def preprocess_kwargs(cls, **kwargs):
+ def preprocess_kwargs(cls, **kwargs: Any) -> dict[str, Any] | None:
return None
@classmethod
- def from_arguments(cls: type[Self], **kwargs: Any) -> Self:
+ def from_arguments(cls, **kwargs: Any) -> Converter:
kwk = frozenset(kwargs.keys())
try:
new_cls = cls._param_names_to_subclass[kwk]
diff --git a/pint/delegates/txt_defparser/defparser.py b/pint/delegates/txt_defparser/defparser.py
index f1b8e45..4acea2f 100644
--- a/pint/delegates/txt_defparser/defparser.py
+++ b/pint/delegates/txt_defparser/defparser.py
@@ -130,7 +130,7 @@ class DefParser:
else:
yield stmt
- def parse_file(self, filename: pathlib.Path, cfg: ParserConfig | None = None):
+ def parse_file(self, filename: pathlib.Path | str, cfg: ParserConfig | None = None):
return fp.parse(
filename,
_PintParser,
diff --git a/pint/facets/__init__.py b/pint/facets/__init__.py
index 750f729..4fd1597 100644
--- a/pint/facets/__init__.py
+++ b/pint/facets/__init__.py
@@ -71,15 +71,18 @@
from __future__ import annotations
-from .context import ContextRegistry
-from .dask import DaskRegistry
-from .formatting import FormattingRegistry
-from .group import GroupRegistry
-from .measurement import MeasurementRegistry
-from .nonmultiplicative import NonMultiplicativeRegistry
-from .numpy import NumpyRegistry
-from .plain import PlainRegistry
-from .system import SystemRegistry
+from .context import ContextRegistry, GenericContextRegistry
+from .dask import DaskRegistry, GenericDaskRegistry
+from .formatting import FormattingRegistry, GenericFormattingRegistry
+from .group import GroupRegistry, GenericGroupRegistry
+from .measurement import MeasurementRegistry, GenericMeasurementRegistry
+from .nonmultiplicative import (
+ NonMultiplicativeRegistry,
+ GenericNonMultiplicativeRegistry,
+)
+from .numpy import NumpyRegistry, GenericNumpyRegistry
+from .plain import PlainRegistry, GenericPlainRegistry, QuantityT, UnitT, MagnitudeT
+from .system import SystemRegistry, GenericSystemRegistry
__all__ = [
"ContextRegistry",
@@ -91,4 +94,16 @@ __all__ = [
"NumpyRegistry",
"PlainRegistry",
"SystemRegistry",
+ "GenericContextRegistry",
+ "GenericDaskRegistry",
+ "GenericFormattingRegistry",
+ "GenericGroupRegistry",
+ "GenericMeasurementRegistry",
+ "GenericNonMultiplicativeRegistry",
+ "GenericNumpyRegistry",
+ "GenericPlainRegistry",
+ "GenericSystemRegistry",
+ "QuantityT",
+ "UnitT",
+ "MagnitudeT",
]
diff --git a/pint/facets/context/__init__.py b/pint/facets/context/__init__.py
index db28436..28c7b5c 100644
--- a/pint/facets/context/__init__.py
+++ b/pint/facets/context/__init__.py
@@ -13,6 +13,6 @@ from __future__ import annotations
from .definitions import ContextDefinition
from .objects import Context
-from .registry import ContextRegistry
+from .registry import ContextRegistry, GenericContextRegistry
-__all__ = ["ContextDefinition", "Context", "ContextRegistry"]
+__all__ = ["ContextDefinition", "Context", "ContextRegistry", "GenericContextRegistry"]
diff --git a/pint/facets/context/definitions.py b/pint/facets/context/definitions.py
index 833857e..d2581d5 100644
--- a/pint/facets/context/definitions.py
+++ b/pint/facets/context/definitions.py
@@ -12,7 +12,7 @@ import itertools
import numbers
import re
from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Callable
from collections.abc import Iterable
from ... import errors
@@ -47,7 +47,7 @@ class Relation:
return set(self._varname_re.findall(self.equation))
@property
- def transformation(self) -> Callable[..., Quantity[Any]]:
+ def transformation(self) -> Callable[..., Quantity]:
"""Return a transformation callable that uses the registry
to parse the transformation equation.
"""
@@ -68,7 +68,7 @@ class ForwardRelation(Relation):
"""
@property
- def bidirectional(self):
+ def bidirectional(self) -> bool:
return False
@@ -82,7 +82,7 @@ class BidirectionalRelation(Relation):
"""
@property
- def bidirectional(self):
+ def bidirectional(self) -> bool:
return True
diff --git a/pint/facets/context/objects.py b/pint/facets/context/objects.py
index 38d8805..9517821 100644
--- a/pint/facets/context/objects.py
+++ b/pint/facets/context/objects.py
@@ -10,12 +10,32 @@ from __future__ import annotations
import weakref
from collections import ChainMap, defaultdict
-from typing import Any
+from typing import Any, Callable, Protocol, Generic
from collections.abc import Iterable
-from ...facets.plain import UnitDefinition
+from ...facets.plain import UnitDefinition, PlainQuantity, PlainUnit, MagnitudeT
from ...util import UnitsContainer, to_units_container
from .definitions import ContextDefinition
+from ..._typing import Magnitude
+
+
+class Transformation(Protocol):
+ def __call__(self, value: Magnitude, **kwargs: Any) -> Magnitude:
+ ...
+
+
+from ..._typing import UnitLike
+
+ToBaseFunc = Callable[[UnitsContainer], UnitsContainer]
+SrcDst = tuple[UnitsContainer, UnitsContainer]
+
+
+class ContextQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]):
+ pass
+
+
+class ContextUnit(PlainUnit):
+ pass
class Context:
@@ -75,24 +95,27 @@ class Context:
aliases: tuple[str] = tuple(),
defaults: dict[str, Any] | None = None,
) -> None:
- self.name = name
- self.aliases = aliases
+ self.name: str | None = name
+ self.aliases: tuple[str] = aliases
#: Maps (src, dst) -> transformation function
- self.funcs = {}
+ self.funcs: dict[SrcDst, Transformation] = {}
#: Maps defaults variable names to values
- self.defaults = defaults or {}
+ self.defaults: dict[str, Any] = defaults or {}
# Store Definition objects that are context-specific
- self.redefinitions = []
+ # TODO: narrow type this if possible.
+ self.redefinitions: list[Any] = []
# Flag set to True by the Registry the first time the context is enabled
self.checked = False
#: Maps (src, dst) -> self
#: Used as a convenience dictionary to be composed by ContextChain
- self.relation_to_context = weakref.WeakValueDictionary()
+ self.relation_to_context: weakref.WeakValueDictionary[
+ SrcDst, Context
+ ] = weakref.WeakValueDictionary()
@classmethod
def from_context(cls, context: Context, **defaults: Any) -> Context:
@@ -125,13 +148,22 @@ class Context:
@classmethod
def from_lines(
- cls, lines: Iterable[str], to_base_func=None, non_int_type: type = float
+ cls,
+ lines: Iterable[str],
+ to_base_func: ToBaseFunc | None = None,
+ non_int_type: type = float,
) -> Context:
- cd = ContextDefinition.from_lines(lines, non_int_type)
- return cls.from_definition(cd, to_base_func)
+ context_definition = ContextDefinition.from_lines(lines, non_int_type)
+
+ if context_definition is None:
+ raise ValueError(f"Could not define Context from from {lines}")
+
+ return cls.from_definition(context_definition, to_base_func)
@classmethod
- def from_definition(cls, cd: ContextDefinition, to_base_func=None) -> Context:
+ def from_definition(
+ cls, cd: ContextDefinition, to_base_func: ToBaseFunc | None = None
+ ) -> Context:
ctx = cls(cd.name, cd.aliases, cd.defaults)
for definition in cd.redefinitions:
@@ -139,6 +171,7 @@ class Context:
for relation in cd.relations:
try:
+ # TODO: check to_base_func. Is it a good API idea?
if to_base_func:
src = to_base_func(relation.src)
dst = to_base_func(relation.dst)
@@ -154,14 +187,16 @@ class Context:
return ctx
- def add_transformation(self, src, dst, func) -> None:
+ def add_transformation(
+ self, src: UnitLike, dst: UnitLike, func: Transformation
+ ) -> None:
"""Add a transformation function to the context."""
_key = self.__keytransform__(src, dst)
self.funcs[_key] = func
self.relation_to_context[_key] = self
- def remove_transformation(self, src, dst) -> None:
+ def remove_transformation(self, src: UnitLike, dst: UnitLike) -> None:
"""Add a transformation function to the context."""
_key = self.__keytransform__(src, dst)
@@ -169,14 +204,17 @@ class Context:
del self.relation_to_context[_key]
@staticmethod
- def __keytransform__(src, dst) -> tuple[UnitsContainer, UnitsContainer]:
+ def __keytransform__(src: UnitLike, dst: UnitLike) -> SrcDst:
return to_units_container(src), to_units_container(dst)
- def transform(self, src, dst, registry, value):
+ def transform(
+ self, src: UnitLike, dst: UnitLike, registry: Any, value: Magnitude
+ ) -> Magnitude:
"""Transform a value."""
_key = self.__keytransform__(src, dst)
- return self.funcs[_key](registry, value, **self.defaults)
+ func = self.funcs[_key]
+ return func(registry, value, **self.defaults)
def redefine(self, definition: str) -> None:
"""Override the definition of a unit in the registry.
@@ -202,7 +240,13 @@ class Context:
def hashable(
self,
- ) -> tuple[str | None, tuple[str, ...], frozenset, frozenset, tuple]:
+ ) -> tuple[
+ str | None,
+ tuple[str],
+ frozenset[tuple[SrcDst, int]],
+ frozenset[tuple[str, Any]],
+ tuple[Any],
+ ]:
"""Generate a unique hashable and comparable representation of self, which can
be used as a key in a dict. This class cannot define ``__hash__`` because it is
mutable, and the Python interpreter does cache the output of ``__hash__``.
@@ -220,18 +264,18 @@ class Context:
)
-class ContextChain(ChainMap):
+class ContextChain(ChainMap[SrcDst, Context]):
"""A specialized ChainMap for contexts that simplifies finding rules
to transform from one dimension to another.
"""
def __init__(self):
super().__init__()
- self.contexts = []
+ self.contexts: list[Context] = []
self.maps.clear() # Remove default empty map
- self._graph = None
+ self._graph: dict[SrcDst, set[UnitsContainer]] | None = None
- def insert_contexts(self, *contexts):
+ def insert_contexts(self, *contexts: Context):
"""Insert one or more contexts in reversed order the chained map.
(A rule in last context will take precedence)
@@ -243,7 +287,7 @@ class ContextChain(ChainMap):
self.maps = [ctx.relation_to_context for ctx in reversed(contexts)] + self.maps
self._graph = None
- def remove_contexts(self, n: int = None):
+ def remove_contexts(self, n: int | None = None):
"""Remove the last n inserted contexts from the chain.
Parameters
@@ -257,7 +301,7 @@ class ContextChain(ChainMap):
self._graph = None
@property
- def defaults(self):
+ def defaults(self) -> dict[str, Any]:
for ctx in self.values():
return ctx.defaults
return {}
@@ -271,7 +315,10 @@ class ContextChain(ChainMap):
self._graph[fr_].add(to_)
return self._graph
- def transform(self, src, dst, registry, value):
+ # TODO: type registry
+ def transform(
+ self, src: UnitsContainer, dst: UnitsContainer, registry: Any, value: Magnitude
+ ):
"""Transform the value, finding the rule in the chained context.
(A rule in last context will take precedence)
"""
diff --git a/pint/facets/context/registry.py b/pint/facets/context/registry.py
index a36d82d..746e79c 100644
--- a/pint/facets/context/registry.py
+++ b/pint/facets/context/registry.py
@@ -11,12 +11,13 @@ from __future__ import annotations
import functools
from collections import ChainMap
from contextlib import contextmanager
-from typing import Any, Callable, ContextManager
+from typing import Any, Callable, Generator, Generic
-from ..._typing import F
+from ...compat import TypeAlias
+from ..._typing import F, Magnitude
from ...errors import UndefinedUnitError
-from ...util import find_connected_nodes, find_shortest_path, logger
-from ..plain import PlainRegistry, UnitDefinition
+from ...util import find_connected_nodes, find_shortest_path, logger, UnitsContainer
+from ..plain import GenericPlainRegistry, UnitDefinition, QuantityT, UnitT
from .definitions import ContextDefinition
from . import objects
@@ -36,7 +37,9 @@ class ContextCacheOverlay:
self.parse_unit = registry_cache.parse_unit
-class ContextRegistry(PlainRegistry):
+class GenericContextRegistry(
+ Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT]
+):
"""Handle of Contexts.
Conversion between units with different dimensions according
@@ -50,7 +53,7 @@ class ContextRegistry(PlainRegistry):
- Parse @context directive.
"""
- Context = objects.Context
+ Context: type[objects.Context] = objects.Context
def __init__(self, **kwargs: Any) -> None:
# Map context name (string) or abbreviation to context.
@@ -65,13 +68,13 @@ class ContextRegistry(PlainRegistry):
super().__init__(**kwargs)
# Allow contexts to add override layers to the units
- self._units = ChainMap(self._units)
+ self._units: ChainMap[str, UnitDefinition] = ChainMap(self._units)
def _register_definition_adders(self) -> None:
super()._register_definition_adders()
self._register_adder(ContextDefinition, self.add_context)
- def add_context(self, context: Context | ContextDefinition) -> None:
+ def add_context(self, context: objects.Context | ContextDefinition) -> None:
"""Add a context object to the registry.
The context will be accessible by its name and aliases.
@@ -194,7 +197,7 @@ class ContextRegistry(PlainRegistry):
self.define(definition)
def enable_contexts(
- self, *names_or_contexts: str | objects.Context, **kwargs
+ self, *names_or_contexts: str | objects.Context, **kwargs: Any
) -> None:
"""Enable contexts provided by name or by object.
@@ -241,7 +244,7 @@ class ContextRegistry(PlainRegistry):
self._active_ctx.insert_contexts(*contexts)
self._switch_context_cache_and_units()
- def disable_contexts(self, n: int = None) -> None:
+ def disable_contexts(self, n: int | None = None) -> None:
"""Disable the last n enabled contexts.
Parameters
@@ -253,7 +256,9 @@ class ContextRegistry(PlainRegistry):
self._switch_context_cache_and_units()
@contextmanager
- def context(self, *names, **kwargs) -> ContextManager[objects.Context]:
+ def context(
+ self: GenericContextRegistry[QuantityT, UnitT], *names: str, **kwargs: Any
+ ) -> Generator[GenericContextRegistry[QuantityT, UnitT], None, None]:
"""Used as a context manager, this function enables to activate a context
which is removed after usage.
@@ -309,7 +314,7 @@ class ContextRegistry(PlainRegistry):
# the added contexts are removed from the active one.
self.disable_contexts(len(names))
- def with_context(self, name, **kwargs) -> Callable[[F], F]:
+ def with_context(self, name: str, **kwargs: Any) -> Callable[[F], F]:
"""Decorator to wrap a function call in a Pint context.
Use it to ensure that a certain context is active when
@@ -351,7 +356,13 @@ class ContextRegistry(PlainRegistry):
return decorator
- def _convert(self, value, src, dst, inplace=False):
+ def _convert(
+ self,
+ value: Magnitude,
+ src: UnitsContainer,
+ dst: UnitsContainer,
+ inplace: bool = False,
+ ) -> Magnitude:
"""Convert value from some source to destination units.
In addition to what is done by the PlainRegistry,
@@ -391,7 +402,9 @@ class ContextRegistry(PlainRegistry):
return super()._convert(value, src, dst, inplace)
- def _get_compatible_units(self, input_units, group_or_system):
+ def _get_compatible_units(
+ self, input_units: UnitsContainer, group_or_system: str | None = None
+ ):
src_dim = self._get_dimensionality(input_units)
ret = super()._get_compatible_units(input_units, group_or_system)
@@ -404,3 +417,10 @@ class ContextRegistry(PlainRegistry):
ret |= self._cache.dimensional_equivalents[node]
return ret
+
+
+class ContextRegistry(
+ GenericContextRegistry[objects.ContextQuantity[Any], objects.ContextUnit]
+):
+ Quantity: TypeAlias = objects.ContextQuantity[Any]
+ Unit: TypeAlias = objects.ContextUnit
diff --git a/pint/facets/dask/__init__.py b/pint/facets/dask/__init__.py
index 90c8972..8d62f55 100644
--- a/pint/facets/dask/__init__.py
+++ b/pint/facets/dask/__init__.py
@@ -11,10 +11,18 @@
from __future__ import annotations
+from typing import Generic, Any
import functools
-from ...compat import compute, dask_array, persist, visualize
-from ..plain import PlainRegistry, PlainQuantity
+from ...compat import compute, dask_array, persist, visualize, TypeAlias
+from ..plain import (
+ GenericPlainRegistry,
+ PlainQuantity,
+ QuantityT,
+ UnitT,
+ PlainUnit,
+ MagnitudeT,
+)
def check_dask_array(f):
@@ -31,7 +39,7 @@ def check_dask_array(f):
return wrapper
-class DaskQuantity(PlainQuantity):
+class DaskQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]):
# Dask.array.Array ducking
def __dask_graph__(self):
if isinstance(self._magnitude, dask_array.Array):
@@ -119,5 +127,16 @@ class DaskQuantity(PlainQuantity):
visualize(self, **kwargs)
-class DaskRegistry(PlainRegistry):
- Quantity = DaskQuantity
+class DaskUnit(PlainUnit):
+ pass
+
+
+class GenericDaskRegistry(
+ Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT]
+):
+ pass
+
+
+class DaskRegistry(GenericDaskRegistry[DaskQuantity[Any], DaskUnit]):
+ Quantity: TypeAlias = DaskQuantity[Any]
+ Unit: TypeAlias = DaskUnit
diff --git a/pint/facets/formatting/__init__.py b/pint/facets/formatting/__init__.py
index e3f4381..799fa31 100644
--- a/pint/facets/formatting/__init__.py
+++ b/pint/facets/formatting/__init__.py
@@ -11,6 +11,11 @@
from __future__ import annotations
from .objects import FormattingQuantity, FormattingUnit
-from .registry import FormattingRegistry
+from .registry import FormattingRegistry, GenericFormattingRegistry
-__all__ = ["FormattingQuantity", "FormattingUnit", "FormattingRegistry"]
+__all__ = [
+ "FormattingQuantity",
+ "FormattingUnit",
+ "FormattingRegistry",
+ "GenericFormattingRegistry",
+]
diff --git a/pint/facets/formatting/objects.py b/pint/facets/formatting/objects.py
index 5df937c..7d39e91 100644
--- a/pint/facets/formatting/objects.py
+++ b/pint/facets/formatting/objects.py
@@ -9,7 +9,7 @@
from __future__ import annotations
import re
-from typing import Any
+from typing import Any, Generic
from ...compat import babel_parse, ndarray, np
from ...formatting import (
@@ -23,10 +23,10 @@ from ...formatting import (
)
from ...util import UnitsContainer, iterable
-from ..plain import PlainQuantity, PlainUnit
+from ..plain import PlainQuantity, PlainUnit, MagnitudeT
-class FormattingQuantity(PlainQuantity):
+class FormattingQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]):
_exp_pattern = re.compile(r"([0-9]\.?[0-9]*)e(-?)\+?0*([0-9]+)")
def __format__(self, spec: str) -> str:
diff --git a/pint/facets/formatting/registry.py b/pint/facets/formatting/registry.py
index c4dc373..7684597 100644
--- a/pint/facets/formatting/registry.py
+++ b/pint/facets/formatting/registry.py
@@ -8,10 +8,21 @@
from __future__ import annotations
-from ..plain import PlainRegistry
-from .objects import FormattingQuantity, FormattingUnit
+from typing import Generic, Any
+from ...compat import TypeAlias
+from ..plain import GenericPlainRegistry, QuantityT, UnitT
+from . import objects
-class FormattingRegistry(PlainRegistry):
- Quantity = FormattingQuantity
- Unit = FormattingUnit
+
+class GenericFormattingRegistry(
+ Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT]
+):
+ pass
+
+
+class FormattingRegistry(
+ GenericFormattingRegistry[objects.FormattingQuantity[Any], objects.FormattingUnit]
+):
+ Quantity: TypeAlias = objects.FormattingQuantity[Any]
+ Unit: TypeAlias = objects.FormattingUnit
diff --git a/pint/facets/group/__init__.py b/pint/facets/group/__init__.py
index e1fad04..b25ea85 100644
--- a/pint/facets/group/__init__.py
+++ b/pint/facets/group/__init__.py
@@ -11,7 +11,14 @@
from __future__ import annotations
from .definitions import GroupDefinition
-from .objects import Group
-from .registry import GroupRegistry
+from .objects import Group, GroupQuantity, GroupUnit
+from .registry import GroupRegistry, GenericGroupRegistry
-__all__ = ["GroupDefinition", "Group", "GroupRegistry"]
+__all__ = [
+ "GroupDefinition",
+ "Group",
+ "GroupRegistry",
+ "GenericGroupRegistry",
+ "GroupQuantity",
+ "GroupUnit",
+]
diff --git a/pint/facets/group/definitions.py b/pint/facets/group/definitions.py
index 554a63b..2f34750 100644
--- a/pint/facets/group/definitions.py
+++ b/pint/facets/group/definitions.py
@@ -11,7 +11,7 @@ from __future__ import annotations
from collections.abc import Iterable
from dataclasses import dataclass
-from ..._typing import Self
+from ...compat import Self
from ... import errors
from .. import plain
diff --git a/pint/facets/group/objects.py b/pint/facets/group/objects.py
index 200a323..64d91c1 100644
--- a/pint/facets/group/objects.py
+++ b/pint/facets/group/objects.py
@@ -8,9 +8,36 @@
from __future__ import annotations
+from typing import Callable, Any, TYPE_CHECKING, Generic
+
from collections.abc import Generator, Iterable
from ...util import SharedRegistryObject, getattr_maybe_raise
from .definitions import GroupDefinition
+from ..plain import PlainQuantity, PlainUnit, MagnitudeT
+
+if TYPE_CHECKING:
+ from ..plain import UnitDefinition
+
+ DefineFunc = Callable[
+ [
+ Any,
+ ],
+ None,
+ ]
+ AddUnitFunc = Callable[
+ [
+ UnitDefinition,
+ ],
+ None,
+ ]
+
+
+class GroupQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]):
+ pass
+
+
+class GroupUnit(PlainUnit):
+ pass
class Group(SharedRegistryObject):
@@ -57,7 +84,7 @@ class Group(SharedRegistryObject):
self._computed_members: frozenset[str] | None = None
@property
- def members(self):
+ def members(self) -> frozenset[str]:
"""Names of the units that are members of the group.
Calculated to include to all units in all included _used_groups.
@@ -143,7 +170,7 @@ class Group(SharedRegistryObject):
@classmethod
def from_lines(
- cls, lines: Iterable[str], define_func, non_int_type: type = float
+ cls, lines: Iterable[str], define_func: DefineFunc, non_int_type: type = float
) -> Group:
"""Return a Group object parsing an iterable of lines.
@@ -160,11 +187,15 @@ class Group(SharedRegistryObject):
"""
group_definition = GroupDefinition.from_lines(lines, non_int_type)
+
+ if group_definition is None:
+ raise ValueError(f"Could not define group from {lines}")
+
return cls.from_definition(group_definition, define_func)
@classmethod
def from_definition(
- cls, group_definition: GroupDefinition, add_unit_func=None
+ cls, group_definition: GroupDefinition, add_unit_func: AddUnitFunc | None = None
) -> Group:
grp = cls(group_definition.name)
diff --git a/pint/facets/group/registry.py b/pint/facets/group/registry.py
index 0d35ae0..f130e61 100644
--- a/pint/facets/group/registry.py
+++ b/pint/facets/group/registry.py
@@ -8,20 +8,28 @@
from __future__ import annotations
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Generic, Any
+from ...compat import TypeAlias
from ... import errors
if TYPE_CHECKING:
- from ..._typing import Unit
-
-from ...util import create_class_with_registry
-from ..plain import PlainRegistry, UnitDefinition
+ from ..._typing import Unit, UnitsContainer
+
+from ...util import create_class_with_registry, to_units_container
+from ..plain import (
+ GenericPlainRegistry,
+ UnitDefinition,
+ QuantityT,
+ UnitT,
+)
from .definitions import GroupDefinition
from . import objects
-class GroupRegistry(PlainRegistry):
+class GenericGroupRegistry(
+ Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT]
+):
"""Handle of Groups.
Group units
@@ -34,7 +42,7 @@ class GroupRegistry(PlainRegistry):
# TODO: Change this to Group: Group to specify class
# and use introspection to get system class as a way
# to enjoy typing goodies
- Group = objects.Group
+ Group = type[objects.Group]
def __init__(self, **kwargs):
super().__init__(**kwargs)
@@ -46,7 +54,7 @@ class GroupRegistry(PlainRegistry):
def _init_dynamic_classes(self) -> None:
"""Generate subclasses on the fly and attach them to self"""
super()._init_dynamic_classes()
- self.Group = create_class_with_registry(self, self.Group)
+ self.Group = create_class_with_registry(self, objects.Group)
def _after_init(self) -> None:
"""Invoked at the end of ``__init__``.
@@ -113,8 +121,23 @@ class GroupRegistry(PlainRegistry):
return self.Group(name)
- def _get_compatible_units(self, input_units, group) -> frozenset[Unit]:
- ret = super()._get_compatible_units(input_units, group)
+ def get_compatible_units(
+ self, input_units: UnitsContainer, group: str | None = None
+ ) -> frozenset[Unit]:
+ """ """
+ if group is None:
+ return super().get_compatible_units(input_units)
+
+ input_units = to_units_container(input_units)
+
+ equiv = self._get_compatible_units(input_units, group)
+
+ return frozenset(self.Unit(eq) for eq in equiv)
+
+ def _get_compatible_units(
+ self, input_units: UnitsContainer, group: str | None = None
+ ) -> frozenset[str]:
+ ret = super()._get_compatible_units(input_units)
if not group:
return ret
@@ -124,3 +147,10 @@ class GroupRegistry(PlainRegistry):
else:
raise ValueError("Unknown Group with name '%s'" % group)
return frozenset(ret & members)
+
+
+class GroupRegistry(
+ GenericGroupRegistry[objects.GroupQuantity[Any], objects.GroupUnit]
+):
+ Quantity: TypeAlias = objects.GroupQuantity[Any]
+ Unit: TypeAlias = objects.GroupUnit
diff --git a/pint/facets/measurement/__init__.py b/pint/facets/measurement/__init__.py
index 21539dc..d36a5c3 100644
--- a/pint/facets/measurement/__init__.py
+++ b/pint/facets/measurement/__init__.py
@@ -11,6 +11,11 @@
from __future__ import annotations
from .objects import Measurement, MeasurementQuantity
-from .registry import MeasurementRegistry
+from .registry import MeasurementRegistry, GenericMeasurementRegistry
-__all__ = ["Measurement", "MeasurementQuantity", "MeasurementRegistry"]
+__all__ = [
+ "Measurement",
+ "MeasurementQuantity",
+ "MeasurementRegistry",
+ "GenericMeasurementRegistry",
+]
diff --git a/pint/facets/measurement/objects.py b/pint/facets/measurement/objects.py
index 5f3ba7a..b9cacda 100644
--- a/pint/facets/measurement/objects.py
+++ b/pint/facets/measurement/objects.py
@@ -10,15 +10,16 @@ from __future__ import annotations
import copy
import re
+from typing import Generic
from ...compat import ufloat
from ...formatting import _FORMATS, extract_custom_flags, siunitx_format_unit
-from ..plain import PlainQuantity
+from ..plain import PlainQuantity, PlainUnit, MagnitudeT
MISSING = object()
-class MeasurementQuantity(PlainQuantity):
+class MeasurementQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]):
# Measurement support
def plus_minus(self, error, relative=False):
if isinstance(error, self.__class__):
@@ -32,6 +33,10 @@ class MeasurementQuantity(PlainQuantity):
return self._REGISTRY.Measurement(copy.copy(self.magnitude), error, self._units)
+class MeasurementUnit(PlainUnit):
+ pass
+
+
class Measurement(PlainQuantity):
"""Implements a class to describe a quantity with uncertainty.
diff --git a/pint/facets/measurement/registry.py b/pint/facets/measurement/registry.py
index 0fc4391..4a3e878 100644
--- a/pint/facets/measurement/registry.py
+++ b/pint/facets/measurement/registry.py
@@ -9,15 +9,17 @@
from __future__ import annotations
-from ...compat import ufloat
+from typing import Generic, Any
+
+from ...compat import ufloat, TypeAlias
from ...util import create_class_with_registry
-from ..plain import PlainRegistry
-from .objects import MeasurementQuantity
+from ..plain import GenericPlainRegistry, QuantityT, UnitT
from . import objects
-class MeasurementRegistry(PlainRegistry):
- Quantity = MeasurementQuantity
+class GenericMeasurementRegistry(
+ Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT]
+):
Measurement = objects.Measurement
def _init_dynamic_classes(self) -> None:
@@ -34,3 +36,12 @@ class MeasurementRegistry(PlainRegistry):
)
self.Measurement = no_uncertainties
+
+
+class MeasurementRegistry(
+ GenericMeasurementRegistry[
+ objects.MeasurementQuantity[Any], objects.MeasurementUnit
+ ]
+):
+ Quantity: TypeAlias = objects.MeasurementQuantity[Any]
+ Unit: TypeAlias = objects.MeasurementUnit
diff --git a/pint/facets/nonmultiplicative/__init__.py b/pint/facets/nonmultiplicative/__init__.py
index cbba410..eb3292b 100644
--- a/pint/facets/nonmultiplicative/__init__.py
+++ b/pint/facets/nonmultiplicative/__init__.py
@@ -15,8 +15,6 @@ from __future__ import annotations
# This import register LogarithmicConverter and OffsetConverter to be usable
# (via subclassing)
from .definitions import LogarithmicConverter, OffsetConverter # noqa: F401
-from .registry import NonMultiplicativeRegistry
+from .registry import NonMultiplicativeRegistry, GenericNonMultiplicativeRegistry
-__all__ = [
- "NonMultiplicativeRegistry",
-]
+__all__ = ["NonMultiplicativeRegistry", "GenericNonMultiplicativeRegistry"]
diff --git a/pint/facets/nonmultiplicative/objects.py b/pint/facets/nonmultiplicative/objects.py
index 0ab743e..8b944b1 100644
--- a/pint/facets/nonmultiplicative/objects.py
+++ b/pint/facets/nonmultiplicative/objects.py
@@ -8,10 +8,12 @@
from __future__ import annotations
-from ..plain import PlainQuantity
+from typing import Generic
+from ..plain import PlainQuantity, PlainUnit, MagnitudeT
-class NonMultiplicativeQuantity(PlainQuantity):
+
+class NonMultiplicativeQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]):
@property
def _is_multiplicative(self) -> bool:
"""Check if the PlainQuantity object has only multiplicative units."""
@@ -59,3 +61,7 @@ class NonMultiplicativeQuantity(PlainQuantity):
if next(iter(self._units.values())) != 1:
is_ok = False
return is_ok
+
+
+class NonMultiplicativeUnit(PlainUnit):
+ pass
diff --git a/pint/facets/nonmultiplicative/registry.py b/pint/facets/nonmultiplicative/registry.py
index 8bc04db..505406c 100644
--- a/pint/facets/nonmultiplicative/registry.py
+++ b/pint/facets/nonmultiplicative/registry.py
@@ -8,16 +8,22 @@
from __future__ import annotations
-from typing import Any
+from typing import Any, TypeVar, Generic
+from ...compat import TypeAlias
from ...errors import DimensionalityError, UndefinedUnitError
from ...util import UnitsContainer, logger
-from ..plain import PlainRegistry, UnitDefinition
+from ..plain import GenericPlainRegistry, UnitDefinition, QuantityT, UnitT
from .definitions import OffsetConverter, ScaleConverter
-from .objects import NonMultiplicativeQuantity
+from . import objects
-class NonMultiplicativeRegistry(PlainRegistry):
+T = TypeVar("T")
+
+
+class GenericNonMultiplicativeRegistry(
+ Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT]
+):
"""Handle of non multiplicative units (e.g. Temperature).
Capabilities:
@@ -35,8 +41,6 @@ class NonMultiplicativeRegistry(PlainRegistry):
"""
- Quantity = NonMultiplicativeQuantity
-
def __init__(
self,
default_as_delta: bool = True,
@@ -58,14 +62,14 @@ class NonMultiplicativeRegistry(PlainRegistry):
input_string: str,
as_delta: bool | None = None,
case_sensitive: bool | None = None,
- ):
+ ) -> UnitsContainer:
""" """
if as_delta is None:
as_delta = self.default_as_delta
return super()._parse_units(input_string, as_delta, case_sensitive)
- def _add_unit(self, definition: UnitDefinition):
+ def _add_unit(self, definition: UnitDefinition) -> None:
super()._add_unit(definition)
if definition.is_multiplicative:
@@ -104,22 +108,60 @@ class NonMultiplicativeRegistry(PlainRegistry):
)
super()._add_unit(delta_def)
- def _is_multiplicative(self, u) -> bool:
- if u in self._units:
- return self._units[u].is_multiplicative
+ def _is_multiplicative(self, unit_name: str) -> bool:
+ """True if the unit is multiplicative.
+
+ Parameters
+ ----------
+ unit_name
+ Name of the unit to check.
+ Can be prefixed, pluralized or even an alias
+
+ Raises
+ ------
+ UndefinedUnitError
+ If the unit is not in the registyr.
+ """
+ if unit_name in self._units:
+ return self._units[unit_name].is_multiplicative
# If the unit is not in the registry might be because it is not
# registered with its prefixed version.
# TODO: Might be better to register them.
- names = self.parse_unit_name(u)
+ names = self.parse_unit_name(unit_name)
assert len(names) == 1
_, base_name, _ = names[0]
try:
return self._units[base_name].is_multiplicative
except KeyError:
- raise UndefinedUnitError(u)
+ raise UndefinedUnitError(unit_name)
+
+ def _validate_and_extract(self, units: UnitsContainer) -> str | None:
+ """Used to check if a given units is suitable for a simple
+ conversion.
+
+ Return None if all units are non-multiplicative
+ Return the unit name if a single non-multiplicative unit is found
+ and is raised to a power equals to 1.
+
+ Otherwise, raise an Exception.
+
+ Parameters
+ ----------
+ units
+ Compound dictionary.
+
+ Raises
+ ------
+ ValueError
+ If the more than a single non-multiplicative unit is present,
+ or a single one is present but raised to a power different from 1.
+
+ """
+
+ # TODO: document what happens if autoconvert_offset_to_baseunit
+ # TODO: Clarify docs
- def _validate_and_extract(self, units):
# u is for unit, e is for exponent
nonmult_units = [
(u, e) for u, e in units.items() if not self._is_multiplicative(u)
@@ -147,11 +189,16 @@ class NonMultiplicativeRegistry(PlainRegistry):
return None
- def _add_ref_of_log_or_offset_unit(self, offset_unit, all_units):
+ def _add_ref_of_log_or_offset_unit(
+ self, offset_unit: str, all_units: UnitsContainer
+ ) -> UnitsContainer:
slct_unit = self._units[offset_unit]
if slct_unit.is_logarithmic or (not slct_unit.is_multiplicative):
# Extract reference unit
slct_ref = slct_unit.reference
+
+ # TODO: Check that reference is None
+
# If reference unit is not dimensionless
if slct_ref != UnitsContainer():
# Extract reference unit
@@ -161,7 +208,9 @@ class NonMultiplicativeRegistry(PlainRegistry):
# Otherwise, return the units unmodified
return all_units
- def _convert(self, value, src, dst, inplace=False):
+ def _convert(
+ self, value: T, src: UnitsContainer, dst: UnitsContainer, inplace: bool = False
+ ) -> T:
"""Convert value from some source to destination units.
In addition to what is done by the PlainRegistry,
@@ -235,3 +284,12 @@ class NonMultiplicativeRegistry(PlainRegistry):
)
return value
+
+
+class NonMultiplicativeRegistry(
+ GenericNonMultiplicativeRegistry[
+ objects.NonMultiplicativeQuantity[Any], objects.NonMultiplicativeUnit
+ ]
+):
+ Quantity: TypeAlias = objects.NonMultiplicativeQuantity[Any]
+ Unit: TypeAlias = objects.NonMultiplicativeUnit
diff --git a/pint/facets/numpy/__init__.py b/pint/facets/numpy/__init__.py
index aad9508..2e38dc1 100644
--- a/pint/facets/numpy/__init__.py
+++ b/pint/facets/numpy/__init__.py
@@ -10,6 +10,6 @@
from __future__ import annotations
-from .registry import NumpyRegistry
+from .registry import NumpyRegistry, GenericNumpyRegistry
-__all__ = ["NumpyRegistry"]
+__all__ = ["NumpyRegistry", "GenericNumpyRegistry"]
diff --git a/pint/facets/numpy/quantity.py b/pint/facets/numpy/quantity.py
index 131983c..880f860 100644
--- a/pint/facets/numpy/quantity.py
+++ b/pint/facets/numpy/quantity.py
@@ -11,11 +11,11 @@ from __future__ import annotations
import functools
import math
import warnings
-from typing import Any
+from typing import Any, Generic
-from ..plain import PlainQuantity
+from ..plain import PlainQuantity, MagnitudeT
-from ..._typing import Shape, _MagnitudeType
+from ..._typing import Shape
from ...compat import _to_magnitude, np
from ...errors import DimensionalityError, PintTypeError, UnitStrippedWarning
from .numpy_func import (
@@ -42,7 +42,7 @@ def method_wraps(numpy_func):
return wrapper
-class NumpyQuantity(PlainQuantity):
+class NumpyQuantity(Generic[MagnitudeT], PlainQuantity[MagnitudeT]):
""" """
# NumPy function/ufunc support
@@ -130,11 +130,11 @@ class NumpyQuantity(PlainQuantity):
raise DimensionalityError("dimensionless", self._units)
return self.__class__(self.magnitude.clip(min, max, out, **kwargs), self._units)
- def fill(self: NumpyQuantity[np.ndarray], value) -> None:
+ def fill(self: NumpyQuantity, value) -> None:
self._units = value._units
return self.magnitude.fill(value.magnitude)
- def put(self: NumpyQuantity[np.ndarray], indices, values, mode="raise") -> None:
+ def put(self: NumpyQuantity, indices, values, mode="raise") -> None:
if isinstance(values, self.__class__):
values = values.to(self).magnitude
elif self.dimensionless:
@@ -144,11 +144,11 @@ class NumpyQuantity(PlainQuantity):
self.magnitude.put(indices, values, mode)
@property
- def real(self) -> NumpyQuantity[_MagnitudeType]:
+ def real(self) -> NumpyQuantity:
return self.__class__(self._magnitude.real, self._units)
@property
- def imag(self) -> NumpyQuantity[_MagnitudeType]:
+ def imag(self) -> NumpyQuantity:
return self.__class__(self._magnitude.imag, self._units)
@property
diff --git a/pint/facets/numpy/registry.py b/pint/facets/numpy/registry.py
index 11d57f3..e93de44 100644
--- a/pint/facets/numpy/registry.py
+++ b/pint/facets/numpy/registry.py
@@ -9,11 +9,20 @@
from __future__ import annotations
-from ..plain import PlainRegistry
+from typing import Generic, Any
+
+from ...compat import TypeAlias
+from ..plain import GenericPlainRegistry, QuantityT, UnitT
from .quantity import NumpyQuantity
from .unit import NumpyUnit
-class NumpyRegistry(PlainRegistry):
- Quantity = NumpyQuantity
- Unit = NumpyUnit
+class GenericNumpyRegistry(
+ Generic[QuantityT, UnitT], GenericPlainRegistry[QuantityT, UnitT]
+):
+ pass
+
+
+class NumpyRegistry(GenericPlainRegistry[NumpyQuantity[Any], NumpyUnit]):
+ Quantity: TypeAlias = NumpyQuantity[Any]
+ Unit: TypeAlias = NumpyUnit
diff --git a/pint/facets/plain/__init__.py b/pint/facets/plain/__init__.py
index 211d017..90bf2e3 100644
--- a/pint/facets/plain/__init__.py
+++ b/pint/facets/plain/__init__.py
@@ -19,9 +19,11 @@ from .definitions import (
UnitDefinition,
)
from .objects import PlainQuantity, PlainUnit
-from .registry import PlainRegistry
+from .registry import PlainRegistry, GenericPlainRegistry, QuantityT, UnitT
+from .quantity import MagnitudeT
__all__ = [
+ "GenericPlainRegistry",
"PlainUnit",
"PlainQuantity",
"PlainRegistry",
@@ -31,4 +33,7 @@ __all__ = [
"PrefixDefinition",
"ScaleConverter",
"UnitDefinition",
+ "QuantityT",
+ "UnitT",
+ "MagnitudeT",
]
diff --git a/pint/facets/plain/definitions.py b/pint/facets/plain/definitions.py
index 79a44f1..4b352e7 100644
--- a/pint/facets/plain/definitions.py
+++ b/pint/facets/plain/definitions.py
@@ -13,7 +13,7 @@ import numbers
import typing as ty
from dataclasses import dataclass
from functools import cached_property
-from typing import Callable, Any
+from typing import Any
from ..._typing import Magnitude
from ... import errors
@@ -69,11 +69,15 @@ class DefaultsDefinition:
@dataclass(frozen=True)
-class PrefixDefinition(errors.WithDefErr):
- """Definition of a prefix."""
-
+class NamedDefinition:
#: name of the prefix
name: str
+
+
+@dataclass(frozen=True)
+class PrefixDefinition(NamedDefinition, errors.WithDefErr):
+ """Definition of a prefix."""
+
#: scaling value for this prefix
value: numbers.Number
#: canonical symbol
@@ -90,8 +94,8 @@ class PrefixDefinition(errors.WithDefErr):
return bool(self.defined_symbol)
@cached_property
- def converter(self):
- return Converter.from_arguments(scale=self.value)
+ def converter(self) -> ScaleConverter:
+ return ScaleConverter(self.value)
def __post_init__(self):
if not errors.is_valid_prefix_name(self.name):
@@ -110,22 +114,19 @@ class PrefixDefinition(errors.WithDefErr):
@dataclass(frozen=True)
-class UnitDefinition(errors.WithDefErr):
+class UnitDefinition(NamedDefinition, errors.WithDefErr):
"""Definition of a unit."""
- #: canonical name of the unit
- name: str
#: canonical symbol
defined_symbol: str | None
#: additional names for the same unit
aliases: tuple[str]
#: A functiont that converts a value in these units into the reference units
- converter: Callable[
- [
- Magnitude,
- ],
- Magnitude,
- ] | Converter | None
+ # TODO: this has changed as converter is now annotated as converter.
+ # Briefly, in several places converter attributes like as_multiplicative were
+ # accesed. So having a generic function is a no go.
+ # I guess this was never used as errors where not raised.
+ converter: Converter | None
#: Reference units.
reference: UnitsContainer | None
@@ -190,7 +191,7 @@ class UnitDefinition(errors.WithDefErr):
def is_base(self) -> bool:
"""Indicates if it is a base unit."""
- # TODO: why is this here
+ # TODO: This is set in __post_init__
return self._is_base
@property
@@ -215,17 +216,14 @@ class UnitDefinition(errors.WithDefErr):
@dataclass(frozen=True)
-class DimensionDefinition(errors.WithDefErr):
+class DimensionDefinition(NamedDefinition, errors.WithDefErr):
"""Definition of a root dimension"""
- #: name of the dimension
- name: str
-
@property
- def is_base(self):
+ def is_base(self) -> bool:
return True
- def __post_init__(self):
+ def __post_init__(self) -> None:
if not errors.is_valid_dimension_name(self.name):
raise self.def_err(errors.MSG_INVALID_DIMENSION_NAME)
@@ -238,7 +236,7 @@ class DerivedDimensionDefinition(DimensionDefinition):
reference: UnitsContainer
@property
- def is_base(self):
+ def is_base(self) -> bool:
return False
def __post_init__(self):
diff --git a/pint/facets/plain/qto.py b/pint/facets/plain/qto.py
new file mode 100644
index 0000000..72b8157
--- /dev/null
+++ b/pint/facets/plain/qto.py
@@ -0,0 +1,386 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+import bisect
+import math
+import numbers
+from ...util import infer_base_unit
+import warnings
+from ...compat import (
+ mip_INF,
+ mip_INTEGER,
+ mip_model,
+ mip_Model,
+ mip_OptimizationStatus,
+ mip_xsum,
+)
+
+if TYPE_CHECKING:
+ from ..._typing import UnitLike
+ from ...util import UnitsContainer
+ from .quantity import PlainQuantity
+
+
+def _get_reduced_units(
+ quantity: PlainQuantity, units: UnitsContainer
+) -> UnitsContainer:
+ # loop through individual units and compare to each other unit
+ # can we do better than a nested loop here?
+ for unit1, exp in units.items():
+ # make sure it wasn't already reduced to zero exponent on prior pass
+ if unit1 not in units:
+ continue
+ for unit2 in units:
+ # get exponent after reduction
+ exp = units[unit1]
+ if unit1 != unit2:
+ power = quantity._REGISTRY._get_dimensionality_ratio(unit1, unit2)
+ if power:
+ units = units.add(unit2, exp / power).remove([unit1])
+ break
+ return units
+
+
+def ito_reduced_units(quantity: PlainQuantity) -> None:
+ """Return PlainQuantity scaled in place to reduced units, i.e. one unit per
+ dimension. This will not reduce compound units (e.g., 'J/kg' will not
+ be reduced to m**2/s**2), nor can it make use of contexts at this time.
+ """
+
+ # shortcuts in case we're dimensionless or only a single unit
+ if quantity.dimensionless:
+ return quantity.ito({})
+ if len(quantity._units) == 1:
+ return None
+
+ units = quantity._units.copy()
+ new_units = _get_reduced_units(quantity, units)
+
+ return quantity.ito(new_units)
+
+
+def to_reduced_units(
+ quantity: PlainQuantity,
+) -> PlainQuantity:
+ """Return PlainQuantity scaled in place to reduced units, i.e. one unit per
+ dimension. This will not reduce compound units (intentionally), nor
+ can it make use of contexts at this time.
+ """
+
+ # shortcuts in case we're dimensionless or only a single unit
+ if quantity.dimensionless:
+ return quantity.to({})
+ if len(quantity._units) == 1:
+ return quantity
+
+ units = quantity._units.copy()
+ new_units = _get_reduced_units(quantity, units)
+
+ return quantity.to(new_units)
+
+
+def to_compact(
+ quantity: PlainQuantity, unit: UnitsContainer | None = None
+) -> PlainQuantity:
+ """ "Return PlainQuantity rescaled to compact, human-readable units.
+
+ To get output in terms of a different unit, use the unit parameter.
+
+
+ Examples
+ --------
+
+ >>> import pint
+ >>> ureg = pint.UnitRegistry()
+ >>> (200e-9*ureg.s).to_compact()
+ <Quantity(200.0, 'nanosecond')>
+ >>> (1e-2*ureg('kg m/s^2')).to_compact('N')
+ <Quantity(10.0, 'millinewton')>
+ """
+
+ if not isinstance(quantity.magnitude, numbers.Number):
+ msg = "to_compact applied to non numerical types " "has an undefined behavior."
+ w = RuntimeWarning(msg)
+ warnings.warn(w, stacklevel=2)
+ return quantity
+
+ if (
+ quantity.unitless
+ or quantity.magnitude == 0
+ or math.isnan(quantity.magnitude)
+ or math.isinf(quantity.magnitude)
+ ):
+ return quantity
+
+ SI_prefixes: dict[int, str] = {}
+ for prefix in quantity._REGISTRY._prefixes.values():
+ try:
+ scale = prefix.converter.scale
+ # Kludgy way to check if this is an SI prefix
+ log10_scale = int(math.log10(scale))
+ if log10_scale == math.log10(scale):
+ SI_prefixes[log10_scale] = prefix.name
+ except Exception:
+ SI_prefixes[0] = ""
+
+ SI_prefixes_list = sorted(SI_prefixes.items())
+ SI_powers = [item[0] for item in SI_prefixes_list]
+ SI_bases = [item[1] for item in SI_prefixes_list]
+
+ if unit is None:
+ unit = infer_base_unit(quantity, registry=quantity._REGISTRY)
+ else:
+ unit = infer_base_unit(quantity.__class__(1, unit), registry=quantity._REGISTRY)
+
+ q_base = quantity.to(unit)
+
+ magnitude = q_base.magnitude
+
+ units = list(q_base._units.items())
+ units_numerator = [a for a in units if a[1] > 0]
+
+ if len(units_numerator) > 0:
+ unit_str, unit_power = units_numerator[0]
+ else:
+ unit_str, unit_power = units[0]
+
+ if unit_power > 0:
+ power = math.floor(math.log10(abs(magnitude)) / float(unit_power) / 3) * 3
+ else:
+ power = math.ceil(math.log10(abs(magnitude)) / float(unit_power) / 3) * 3
+
+ index = bisect.bisect_left(SI_powers, power)
+
+ if index >= len(SI_bases):
+ index = -1
+
+ prefix_str = SI_bases[index]
+
+ new_unit_str = prefix_str + unit_str
+ new_unit_container = q_base._units.rename(unit_str, new_unit_str)
+
+ return quantity.to(new_unit_container)
+
+
+def to_preferred(
+ quantity: PlainQuantity, preferred_units: list[UnitLike]
+) -> PlainQuantity:
+ """Return Quantity converted to a unit composed of the preferred units.
+
+ Examples
+ --------
+
+ >>> import pint
+ >>> ureg = pint.UnitRegistry()
+ >>> (1*ureg.acre).to_preferred([ureg.meters])
+ <Quantity(4046.87261, 'meter ** 2')>
+ >>> (1*(ureg.force_pound*ureg.m)).to_preferred([ureg.W])
+ <Quantity(4.44822162, 'second * watt')>
+ """
+
+ if not quantity.dimensionality:
+ return quantity
+
+ # The optimizer isn't perfect, and will sometimes miss obvious solutions.
+ # This sub-algorithm is less powerful, but always finds the very simple solutions.
+ def find_simple():
+ best_ratio = None
+ best_unit = None
+ self_dims = sorted(quantity.dimensionality)
+ self_exps = [quantity.dimensionality[d] for d in self_dims]
+ s_exps_head, *s_exps_tail = self_exps
+ n = len(s_exps_tail)
+ for preferred_unit in preferred_units:
+ dims = sorted(preferred_unit.dimensionality)
+ if dims == self_dims:
+ p_exps_head, *p_exps_tail = (
+ preferred_unit.dimensionality[d] for d in dims
+ )
+ if all(
+ s_exps_tail[i] * p_exps_head == p_exps_tail[i] ** s_exps_head
+ for i in range(n)
+ ):
+ ratio = p_exps_head / s_exps_head
+ ratio = max(ratio, 1 / ratio)
+ if best_ratio is None or ratio < best_ratio:
+ best_ratio = ratio
+ best_unit = preferred_unit ** (s_exps_head / p_exps_head)
+ return best_unit
+
+ simple = find_simple()
+ if simple is not None:
+ return quantity.to(simple)
+
+ # For each dimension (e.g. T(ime), L(ength), M(ass)), assign a default base unit from
+ # the collection of base units
+
+ unit_selections = {
+ base_unit.dimensionality: base_unit
+ for base_unit in map(quantity._REGISTRY.Unit, quantity._REGISTRY._base_units)
+ }
+
+ # Override the default unit of each dimension with the 1D-units used in this Quantity
+ unit_selections.update(
+ {
+ unit.dimensionality: unit
+ for unit in map(quantity._REGISTRY.Unit, quantity._units.keys())
+ }
+ )
+
+ # Determine the preferred unit for each dimensionality from the preferred_units
+ # (A prefered unit doesn't have to be only one dimensional, e.g. Watts)
+ preferred_dims = {
+ preferred_unit.dimensionality: preferred_unit
+ for preferred_unit in map(quantity._REGISTRY.Unit, preferred_units)
+ }
+
+ # Combine the defaults and preferred, favoring the preferred
+ unit_selections.update(preferred_dims)
+
+ # This algorithm has poor asymptotic time complexity, so first reduce the considered
+ # dimensions and units to only those that are useful to the problem
+
+ # The dimensions (without powers) of this Quantity
+ dimension_set = set(quantity.dimensionality)
+
+ # Getting zero exponents in dimensions not in dimension_set can be facilitated
+ # by units that interact with that dimension and one or more dimension_set members.
+ # For example MT^1 * LT^-1 lets you get MLT^0 when T is not in dimension_set.
+ # For each candidate unit that interacts with a dimension_set member, add the
+ # candidate unit's other dimensions to dimension_set, and repeat until no more
+ # dimensions are selected.
+
+ discovery_done = False
+ while not discovery_done:
+ discovery_done = True
+ for d in unit_selections:
+ unit_dimensions = set(d)
+ intersection = unit_dimensions.intersection(dimension_set)
+ if 0 < len(intersection) < len(unit_dimensions):
+ # there are dimensions in this unit that are in dimension set
+ # and others that are not in dimension set
+ dimension_set = dimension_set.union(unit_dimensions)
+ discovery_done = False
+ break
+
+ # filter out dimensions and their unit selections that don't interact with any
+ # dimension_set members
+ unit_selections = {
+ dimensionality: unit
+ for dimensionality, unit in unit_selections.items()
+ if set(dimensionality).intersection(dimension_set)
+ }
+
+ # update preferred_units with the selected units that were originally preferred
+ preferred_units = list(
+ {u for d, u in unit_selections.items() if d in preferred_dims}
+ )
+ preferred_units.sort(key=str) # for determinism
+
+ # and unpreferred_units are the selected units that weren't originally preferred
+ unpreferred_units = list(
+ {u for d, u in unit_selections.items() if d not in preferred_dims}
+ )
+ unpreferred_units.sort(key=str) # for determinism
+
+ # for indexability
+ dimensions = list(dimension_set)
+ dimensions.sort() # for determinism
+
+ # the powers for each elemet of dimensions (the list) for this Quantity
+ dimensionality = [quantity.dimensionality[dimension] for dimension in dimensions]
+
+ # Now that the input data is minimized, setup the optimization problem
+
+ # use mip to select units from preferred units
+
+ model = mip_Model()
+ model.verbose = 0
+
+ # Make one variable for each candidate unit
+
+ vars = [
+ model.add_var(str(unit), lb=-mip_INF, ub=mip_INF, var_type=mip_INTEGER)
+ for unit in (preferred_units + unpreferred_units)
+ ]
+
+ # where [u1 ... uN] are powers of N candidate units (vars)
+ # and [d1(uI) ... dK(uI)] are the K dimensional exponents of candidate unit I
+ # and [t1 ... tK] are the dimensional exponents of the quantity (quantity)
+ # create the following constraints
+ #
+ # ⎡ d1(u1) ⋯ dK(u1) ⎤
+ # [ u1 ⋯ uN ] * ⎢ ⋮ ⋱ ⎢ = [ t1 ⋯ tK ]
+ # ⎣ d1(uN) dK(uN) ⎦
+ #
+ # in English, the units we choose, and their exponents, when combined, must have the
+ # target dimensionality
+
+ matrix = [
+ [preferred_unit.dimensionality[dimension] for dimension in dimensions]
+ for preferred_unit in (preferred_units + unpreferred_units)
+ ]
+
+ # Do the matrix multiplication with mip_model.xsum for performance and create constraints
+ for i in range(len(dimensions)):
+ dot = mip_model.xsum([var * vector[i] for var, vector in zip(vars, matrix)])
+ # add constraint to the model
+ model += dot == dimensionality[i]
+
+ # where [c1 ... cN] are costs, 1 when a preferred variable, and a large value when not
+ # minimize sum(abs(u1) * c1 ... abs(uN) * cN)
+
+ # linearize the optimization variable via a proxy
+ objective = model.add_var("objective", lb=0, ub=mip_INF, var_type=mip_INTEGER)
+
+ # Constrain the objective to be equal to the sums of the absolute values of the preferred
+ # unit powers. Do this by making a separate constraint for each permutation of signedness.
+ # Also apply the cost coefficient, which causes the output to prefer the preferred units
+
+ # prefer units that interact with fewer dimensions
+ cost = [len(p.dimensionality) for p in preferred_units]
+
+ # set the cost for non preferred units to a higher number
+ bias = (
+ max(map(abs, dimensionality)) * max((1, *cost)) * 10
+ ) # arbitrary, just needs to be larger
+ cost.extend([bias] * len(unpreferred_units))
+
+ for i in range(1 << len(vars)):
+ sum = mip_xsum(
+ [
+ (-1 if i & 1 << (len(vars) - j - 1) else 1) * cost[j] * var
+ for j, var in enumerate(vars)
+ ]
+ )
+ model += objective >= sum
+
+ model.objective = objective
+
+ # run the mips minimizer and extract the result if successful
+ if model.optimize() == mip_OptimizationStatus.OPTIMAL:
+ optimal_units = []
+ min_objective = float("inf")
+ for i in range(model.num_solutions):
+ if model.objective_values[i] < min_objective:
+ min_objective = model.objective_values[i]
+ optimal_units.clear()
+ elif model.objective_values[i] > min_objective:
+ continue
+
+ temp_unit = quantity._REGISTRY.Unit("")
+ for var in vars:
+ if var.xi(i):
+ temp_unit *= quantity._REGISTRY.Unit(var.name) ** var.xi(i)
+ optimal_units.append(temp_unit)
+
+ sorting_keys = {tuple(sorted(unit._units)): unit for unit in optimal_units}
+ min_key = sorted(sorting_keys)[0]
+ result_unit = sorting_keys[min_key]
+
+ return quantity.to(result_unit)
+
+ # for whatever reason, a solution wasn't found
+ # return the original quantity
+ return quantity
diff --git a/pint/facets/plain/quantity.py b/pint/facets/plain/quantity.py
index 1eaaa3d..0058549 100644
--- a/pint/facets/plain/quantity.py
+++ b/pint/facets/plain/quantity.py
@@ -8,37 +8,22 @@
from __future__ import annotations
-import bisect
+
import copy
import datetime
import locale
-import math
import numbers
import operator
-import warnings
-from typing import (
- TYPE_CHECKING,
- Any,
- Callable,
- Generic,
- TypeVar,
- overload,
-)
-from collections.abc import Iterable, Iterator, Sequence
+from typing import TYPE_CHECKING, Any, Callable, overload, Generic, TypeVar
+from collections.abc import Iterator, Sequence
-from ..._typing import S, UnitLike, _MagnitudeType
+from ..._typing import UnitLike, QuantityOrUnitLike, Magnitude
from ...compat import (
HAS_NUMPY,
_to_magnitude,
eq,
is_duck_array_type,
is_upcast_type,
- mip_INF,
- mip_INTEGER,
- mip_model,
- mip_Model,
- mip_OptimizationStatus,
- mip_xsum,
np,
zero_or_nan,
)
@@ -47,11 +32,11 @@ from ...util import (
PrettyIPython,
SharedRegistryObject,
UnitsContainer,
- infer_base_unit,
logger,
to_units_container,
)
from .definitions import UnitDefinition
+from . import qto
if TYPE_CHECKING:
from ..context import Context
@@ -61,6 +46,10 @@ if TYPE_CHECKING:
if HAS_NUMPY:
import numpy as np # noqa
+MagnitudeT = TypeVar("MagnitudeT", bound=Magnitude)
+
+T = TypeVar("T", bound=Magnitude)
+
def reduce_dimensions(f):
def wrapped(self, *args, **kwargs):
@@ -115,14 +104,10 @@ def method_wraps(numpy_func):
return wrapper
-# Workaround to bypass dynamically generated PlainQuantity with overload method
-Magnitude = TypeVar("Magnitude")
-
-
# TODO: remove all nonmultiplicative remnants
-class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]):
+class PlainQuantity(Generic[MagnitudeT], PrettyIPython, SharedRegistryObject):
"""Implements a class to describe a physical quantity:
the product of a numerical value and a unit of measurement.
@@ -140,7 +125,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
#: Default formatting string.
default_format: str = ""
- _magnitude: _MagnitudeType
+ _magnitude: MagnitudeT
@property
def ndim(self) -> int:
@@ -156,11 +141,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
def force_ndarray_like(self) -> bool:
return self._REGISTRY.force_ndarray_like
- @property
- def UnitsContainer(self) -> Callable[..., UnitsContainerT]:
- return self._REGISTRY.UnitsContainer
-
- def __reduce__(self) -> tuple:
+ def __reduce__(self) -> tuple[type, Magnitude, UnitsContainer]:
"""Allow pickling quantities. Since UnitRegistries are not pickled, upon
unpickling the new object is always attached to the application registry.
"""
@@ -168,12 +149,17 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
# Note: type(self) would be a mistake as subclasses built by
# dinamically can't be pickled
+ # TODO: Check if this is still the case.
return _unpickle_quantity, (PlainQuantity, self.magnitude, self._units)
+ # @overload
+ # def __new__(
+ # cls, value: T, units: UnitLike | None = None
+ # ) -> PlainQuantity[T]:
+ # ...
+
@overload
- def __new__(
- cls, value: str, units: UnitLike | None = None
- ) -> PlainQuantity[Magnitude]:
+ def __new__(cls, value: str, units: UnitLike | None = None) -> PlainQuantity[int]:
...
@overload
@@ -182,17 +168,11 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
) -> PlainQuantity[np.ndarray]:
...
- @overload
- def __new__(
- cls, value: PlainQuantity[Magnitude], units: UnitLike | None = None
- ) -> PlainQuantity[Magnitude]:
- ...
-
- @overload
- def __new__(
- cls, value: Magnitude, units: UnitLike | None = None
- ) -> PlainQuantity[Magnitude]:
- ...
+ # @overload
+ # def __new__(
+ # cls, value: PlainQuantity[Any], units: UnitLike | None = None
+ # ) -> PlainQuantity[Any]:
+ # ...
def __new__(cls, value, units=None):
if is_upcast_type(type(value)):
@@ -243,7 +223,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
return inst
- def __iter__(self: PlainQuantity[Iterable[S]]) -> Iterator[S]:
+ def __iter__(self: PlainQuantity[MagnitudeT]) -> Iterator[Any]:
# Make sure that, if self.magnitude is not iterable, we raise TypeError as soon
# as one calls iter(self) without waiting for the first element to be drawn from
# the iterator
@@ -255,11 +235,11 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
return it_outer()
- def __copy__(self) -> PlainQuantity[_MagnitudeType]:
+ def __copy__(self) -> PlainQuantity[MagnitudeT]:
ret = self.__class__(copy.copy(self._magnitude), self._units)
return ret
- def __deepcopy__(self, memo) -> PlainQuantity[_MagnitudeType]:
+ def __deepcopy__(self, memo) -> PlainQuantity[MagnitudeT]:
ret = self.__class__(
copy.deepcopy(self._magnitude, memo), copy.deepcopy(self._units, memo)
)
@@ -285,16 +265,16 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
return hash((self_base.__class__, self_base.magnitude, self_base.units))
@property
- def magnitude(self) -> _MagnitudeType:
+ def magnitude(self) -> MagnitudeT:
"""PlainQuantity's magnitude. Long form for `m`"""
return self._magnitude
@property
- def m(self) -> _MagnitudeType:
+ def m(self) -> MagnitudeT:
"""PlainQuantity's magnitude. Short form for `magnitude`"""
return self._magnitude
- def m_as(self, units) -> _MagnitudeType:
+ def m_as(self, units) -> MagnitudeT:
"""PlainQuantity's magnitude expressed in particular units.
Parameters
@@ -351,8 +331,8 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
@classmethod
def from_list(
- cls, quant_list: list[PlainQuantity], units=None
- ) -> PlainQuantity[np.ndarray]:
+ cls, quant_list: list[PlainQuantity[MagnitudeT]], units=None
+ ) -> PlainQuantity[MagnitudeT]:
"""Transforms a list of Quantities into an numpy.array quantity.
If no units are specified, the unit of the first element will be used.
Same as from_sequence.
@@ -375,8 +355,8 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
@classmethod
def from_sequence(
- cls, seq: Sequence[PlainQuantity], units=None
- ) -> PlainQuantity[np.ndarray]:
+ cls, seq: Sequence[PlainQuantity[MagnitudeT]], units=None
+ ) -> PlainQuantity[MagnitudeT]:
"""Transforms a sequence of Quantities into an numpy.array quantity.
If no units are specified, the unit of the first element will be used.
@@ -414,7 +394,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
def from_tuple(cls, tup):
return cls(tup[0], cls._REGISTRY.UnitsContainer(tup[1]))
- def to_tuple(self) -> tuple[_MagnitudeType, tuple[tuple[str]]]:
+ def to_tuple(self) -> tuple[MagnitudeT, tuple[tuple[str]]]:
return self.m, tuple(self._units.items())
def compatible_units(self, *contexts):
@@ -452,7 +432,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
except DimensionalityError:
return False
- if isinstance(other, (PlainQuantity, PlainUnit)):
+ if isinstance(other, (PlainQuantity[MagnitudeT], PlainUnit)):
return self.dimensionality == other.dimensionality
if isinstance(other, str):
@@ -481,7 +461,9 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
inplace=is_duck_array_type(type(self._magnitude)),
)
- def ito(self, other=None, *contexts, **ctx_kwargs) -> None:
+ def ito(
+ self, other: QuantityOrUnitLike | None = None, *contexts, **ctx_kwargs
+ ) -> None:
"""Inplace rescale to different units.
Parameters
@@ -500,7 +482,9 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
return None
- def to(self, other=None, *contexts, **ctx_kwargs) -> PlainQuantity[_MagnitudeType]:
+ def to(
+ self, other: QuantityOrUnitLike | None = None, *contexts, **ctx_kwargs
+ ) -> PlainQuantity:
"""Return PlainQuantity rescaled to different units.
Parameters
@@ -532,7 +516,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
return None
- def to_root_units(self) -> PlainQuantity[_MagnitudeType]:
+ def to_root_units(self) -> PlainQuantity[MagnitudeT]:
"""Return PlainQuantity rescaled to root units."""
_, other = self._REGISTRY._get_root_units(self._units)
@@ -551,7 +535,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
return None
- def to_base_units(self) -> PlainQuantity[_MagnitudeType]:
+ def to_base_units(self) -> PlainQuantity[MagnitudeT]:
"""Return PlainQuantity rescaled to plain units."""
_, other = self._REGISTRY._get_base_units(self._units)
@@ -560,361 +544,13 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
return self.__class__(magnitude, other)
- def _get_reduced_units(self, units):
- # loop through individual units and compare to each other unit
- # can we do better than a nested loop here?
- for unit1, exp in units.items():
- # make sure it wasn't already reduced to zero exponent on prior pass
- if unit1 not in units:
- continue
- for unit2 in units:
- # get exponent after reduction
- exp = units[unit1]
- if unit1 != unit2:
- power = self._REGISTRY._get_dimensionality_ratio(unit1, unit2)
- if power:
- units = units.add(unit2, exp / power).remove([unit1])
- break
- return units
-
- def ito_reduced_units(self) -> None:
- """Return PlainQuantity scaled in place to reduced units, i.e. one unit per
- dimension. This will not reduce compound units (e.g., 'J/kg' will not
- be reduced to m**2/s**2), nor can it make use of contexts at this time.
- """
-
- # shortcuts in case we're dimensionless or only a single unit
- if self.dimensionless:
- return self.ito({})
- if len(self._units) == 1:
- return None
-
- units = self._units.copy()
- new_units = self._get_reduced_units(units)
-
- return self.ito(new_units)
-
- def to_reduced_units(self) -> PlainQuantity[_MagnitudeType]:
- """Return PlainQuantity scaled in place to reduced units, i.e. one unit per
- dimension. This will not reduce compound units (intentionally), nor
- can it make use of contexts at this time.
- """
-
- # shortcuts in case we're dimensionless or only a single unit
- if self.dimensionless:
- return self.to({})
- if len(self._units) == 1:
- return self
-
- units = self._units.copy()
- new_units = self._get_reduced_units(units)
-
- return self.to(new_units)
-
- def to_compact(self, unit=None) -> PlainQuantity[_MagnitudeType]:
- """ "Return PlainQuantity rescaled to compact, human-readable units.
-
- To get output in terms of a different unit, use the unit parameter.
-
-
- Examples
- --------
-
- >>> import pint
- >>> ureg = pint.UnitRegistry()
- >>> (200e-9*ureg.s).to_compact()
- <Quantity(200.0, 'nanosecond')>
- >>> (1e-2*ureg('kg m/s^2')).to_compact('N')
- <Quantity(10.0, 'millinewton')>
- """
-
- if not isinstance(self.magnitude, numbers.Number):
- msg = (
- "to_compact applied to non numerical types "
- "has an undefined behavior."
- )
- w = RuntimeWarning(msg)
- warnings.warn(w, stacklevel=2)
- return self
-
- if (
- self.unitless
- or self.magnitude == 0
- or math.isnan(self.magnitude)
- or math.isinf(self.magnitude)
- ):
- return self
-
- SI_prefixes: dict[int, str] = {}
- for prefix in self._REGISTRY._prefixes.values():
- try:
- scale = prefix.converter.scale
- # Kludgy way to check if this is an SI prefix
- log10_scale = int(math.log10(scale))
- if log10_scale == math.log10(scale):
- SI_prefixes[log10_scale] = prefix.name
- except Exception:
- SI_prefixes[0] = ""
-
- SI_prefixes_list = sorted(SI_prefixes.items())
- SI_powers = [item[0] for item in SI_prefixes_list]
- SI_bases = [item[1] for item in SI_prefixes_list]
-
- if unit is None:
- unit = infer_base_unit(self, registry=self._REGISTRY)
- else:
- unit = infer_base_unit(self.__class__(1, unit), registry=self._REGISTRY)
-
- q_base = self.to(unit)
-
- magnitude = q_base.magnitude
-
- units = list(q_base._units.items())
- units_numerator = [a for a in units if a[1] > 0]
-
- if len(units_numerator) > 0:
- unit_str, unit_power = units_numerator[0]
- else:
- unit_str, unit_power = units[0]
-
- if unit_power > 0:
- power = math.floor(math.log10(abs(magnitude)) / float(unit_power) / 3) * 3
- else:
- power = math.ceil(math.log10(abs(magnitude)) / float(unit_power) / 3) * 3
-
- index = bisect.bisect_left(SI_powers, power)
-
- if index >= len(SI_bases):
- index = -1
-
- prefix_str = SI_bases[index]
-
- new_unit_str = prefix_str + unit_str
- new_unit_container = q_base._units.rename(unit_str, new_unit_str)
-
- return self.to(new_unit_container)
-
- def to_preferred(
- self, preferred_units: list[UnitLike]
- ) -> PlainQuantity[_MagnitudeType]:
- """Return Quantity converted to a unit composed of the preferred units.
-
- Examples
- --------
-
- >>> import pint
- >>> ureg = pint.UnitRegistry()
- >>> (1*ureg.acre).to_preferred([ureg.meters])
- <Quantity(4046.87261, 'meter ** 2')>
- >>> (1*(ureg.force_pound*ureg.m)).to_preferred([ureg.W])
- <Quantity(4.44822162, 'second * watt')>
- """
-
- if not self.dimensionality:
- return self
-
- # The optimizer isn't perfect, and will sometimes miss obvious solutions.
- # This sub-algorithm is less powerful, but always finds the very simple solutions.
- def find_simple():
- best_ratio = None
- best_unit = None
- self_dims = sorted(self.dimensionality)
- self_exps = [self.dimensionality[d] for d in self_dims]
- s_exps_head, *s_exps_tail = self_exps
- n = len(s_exps_tail)
- for preferred_unit in preferred_units:
- dims = sorted(preferred_unit.dimensionality)
- if dims == self_dims:
- p_exps_head, *p_exps_tail = (
- preferred_unit.dimensionality[d] for d in dims
- )
- if all(
- s_exps_tail[i] * p_exps_head == p_exps_tail[i] ** s_exps_head
- for i in range(n)
- ):
- ratio = p_exps_head / s_exps_head
- ratio = max(ratio, 1 / ratio)
- if best_ratio is None or ratio < best_ratio:
- best_ratio = ratio
- best_unit = preferred_unit ** (s_exps_head / p_exps_head)
- return best_unit
-
- simple = find_simple()
- if simple is not None:
- return self.to(simple)
-
- # For each dimension (e.g. T(ime), L(ength), M(ass)), assign a default base unit from
- # the collection of base units
-
- unit_selections = {
- base_unit.dimensionality: base_unit
- for base_unit in map(self._REGISTRY.Unit, self._REGISTRY._base_units)
- }
-
- # Override the default unit of each dimension with the 1D-units used in this Quantity
- unit_selections.update(
- {
- unit.dimensionality: unit
- for unit in map(self._REGISTRY.Unit, self._units.keys())
- }
- )
-
- # Determine the preferred unit for each dimensionality from the preferred_units
- # (A prefered unit doesn't have to be only one dimensional, e.g. Watts)
- preferred_dims = {
- preferred_unit.dimensionality: preferred_unit
- for preferred_unit in map(self._REGISTRY.Unit, preferred_units)
- }
-
- # Combine the defaults and preferred, favoring the preferred
- unit_selections.update(preferred_dims)
-
- # This algorithm has poor asymptotic time complexity, so first reduce the considered
- # dimensions and units to only those that are useful to the problem
-
- # The dimensions (without powers) of this Quantity
- dimension_set = set(self.dimensionality)
-
- # Getting zero exponents in dimensions not in dimension_set can be facilitated
- # by units that interact with that dimension and one or more dimension_set members.
- # For example MT^1 * LT^-1 lets you get MLT^0 when T is not in dimension_set.
- # For each candidate unit that interacts with a dimension_set member, add the
- # candidate unit's other dimensions to dimension_set, and repeat until no more
- # dimensions are selected.
-
- discovery_done = False
- while not discovery_done:
- discovery_done = True
- for d in unit_selections:
- unit_dimensions = set(d)
- intersection = unit_dimensions.intersection(dimension_set)
- if 0 < len(intersection) < len(unit_dimensions):
- # there are dimensions in this unit that are in dimension set
- # and others that are not in dimension set
- dimension_set = dimension_set.union(unit_dimensions)
- discovery_done = False
- break
-
- # filter out dimensions and their unit selections that don't interact with any
- # dimension_set members
- unit_selections = {
- dimensionality: unit
- for dimensionality, unit in unit_selections.items()
- if set(dimensionality).intersection(dimension_set)
- }
-
- # update preferred_units with the selected units that were originally preferred
- preferred_units = list(
- {u for d, u in unit_selections.items() if d in preferred_dims}
- )
- preferred_units.sort(key=str) # for determinism
-
- # and unpreferred_units are the selected units that weren't originally preferred
- unpreferred_units = list(
- {u for d, u in unit_selections.items() if d not in preferred_dims}
- )
- unpreferred_units.sort(key=str) # for determinism
-
- # for indexability
- dimensions = list(dimension_set)
- dimensions.sort() # for determinism
-
- # the powers for each elemet of dimensions (the list) for this Quantity
- dimensionality = [self.dimensionality[dimension] for dimension in dimensions]
-
- # Now that the input data is minimized, setup the optimization problem
-
- # use mip to select units from preferred units
-
- model = mip_Model()
- model.verbose = 0
-
- # Make one variable for each candidate unit
-
- vars = [
- model.add_var(str(unit), lb=-mip_INF, ub=mip_INF, var_type=mip_INTEGER)
- for unit in (preferred_units + unpreferred_units)
- ]
-
- # where [u1 ... uN] are powers of N candidate units (vars)
- # and [d1(uI) ... dK(uI)] are the K dimensional exponents of candidate unit I
- # and [t1 ... tK] are the dimensional exponents of the quantity (self)
- # create the following constraints
- #
- # ⎡ d1(u1) ⋯ dK(u1) ⎤
- # [ u1 ⋯ uN ] * ⎢ ⋮ ⋱ ⎢ = [ t1 ⋯ tK ]
- # ⎣ d1(uN) dK(uN) ⎦
- #
- # in English, the units we choose, and their exponents, when combined, must have the
- # target dimensionality
-
- matrix = [
- [preferred_unit.dimensionality[dimension] for dimension in dimensions]
- for preferred_unit in (preferred_units + unpreferred_units)
- ]
-
- # Do the matrix multiplication with mip_model.xsum for performance and create constraints
- for i in range(len(dimensions)):
- dot = mip_model.xsum([var * vector[i] for var, vector in zip(vars, matrix)])
- # add constraint to the model
- model += dot == dimensionality[i]
-
- # where [c1 ... cN] are costs, 1 when a preferred variable, and a large value when not
- # minimize sum(abs(u1) * c1 ... abs(uN) * cN)
-
- # linearize the optimization variable via a proxy
- objective = model.add_var("objective", lb=0, ub=mip_INF, var_type=mip_INTEGER)
-
- # Constrain the objective to be equal to the sums of the absolute values of the preferred
- # unit powers. Do this by making a separate constraint for each permutation of signedness.
- # Also apply the cost coefficient, which causes the output to prefer the preferred units
-
- # prefer units that interact with fewer dimensions
- cost = [len(p.dimensionality) for p in preferred_units]
-
- # set the cost for non preferred units to a higher number
- bias = (
- max(map(abs, dimensionality)) * max((1, *cost)) * 10
- ) # arbitrary, just needs to be larger
- cost.extend([bias] * len(unpreferred_units))
-
- for i in range(1 << len(vars)):
- sum = mip_xsum(
- [
- (-1 if i & 1 << (len(vars) - j - 1) else 1) * cost[j] * var
- for j, var in enumerate(vars)
- ]
- )
- model += objective >= sum
-
- model.objective = objective
-
- # run the mips minimizer and extract the result if successful
- if model.optimize() == mip_OptimizationStatus.OPTIMAL:
- optimal_units = []
- min_objective = float("inf")
- for i in range(model.num_solutions):
- if model.objective_values[i] < min_objective:
- min_objective = model.objective_values[i]
- optimal_units.clear()
- elif model.objective_values[i] > min_objective:
- continue
-
- temp_unit = self._REGISTRY.Unit("")
- for var in vars:
- if var.xi(i):
- temp_unit *= self._REGISTRY.Unit(var.name) ** var.xi(i)
- optimal_units.append(temp_unit)
-
- sorting_keys = {tuple(sorted(unit._units)): unit for unit in optimal_units}
- min_key = sorted(sorting_keys)[0]
- result_unit = sorting_keys[min_key]
-
- return self.to(result_unit)
-
- # for whatever reason, a solution wasn't found
- # return the original quantity
- return self
+ # Functions not essential to a Quantity but it is
+ # convenient that they live in PlainQuantity.
+ # They are implemented elsewhere to keep Quantity class clean.
+ to_compact = qto.to_compact
+ to_preferred = qto.to_preferred
+ to_reduced_units = qto.to_reduced_units
+ ito_reduced_units = qto.ito_reduced_units
# Mathematical operations
def __int__(self) -> int:
@@ -1163,7 +799,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
...
@overload
- def __iadd__(self, other) -> PlainQuantity[_MagnitudeType]:
+ def __iadd__(self, other) -> PlainQuantity[MagnitudeT]:
...
def __iadd__(self, other):
@@ -1539,7 +1175,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
return self
@check_implemented
- def __pow__(self, other) -> PlainQuantity[_MagnitudeType]:
+ def __pow__(self, other) -> PlainQuantity[MagnitudeT]:
try:
_to_magnitude(other, self.force_ndarray, self.force_ndarray_like)
except PintTypeError:
@@ -1604,7 +1240,7 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
return self.__class__(magnitude, units)
@check_implemented
- def __rpow__(self, other) -> PlainQuantity[_MagnitudeType]:
+ def __rpow__(self, other) -> PlainQuantity[MagnitudeT]:
try:
_to_magnitude(other, self.force_ndarray, self.force_ndarray_like)
except PintTypeError:
@@ -1617,16 +1253,16 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
new_self = self.to_root_units()
return other**new_self._magnitude
- def __abs__(self) -> PlainQuantity[_MagnitudeType]:
+ def __abs__(self) -> PlainQuantity[MagnitudeT]:
return self.__class__(abs(self._magnitude), self._units)
- def __round__(self, ndigits: int | None = 0) -> PlainQuantity[int]:
+ def __round__(self, ndigits: int | None = 0) -> PlainQuantity[MagnitudeT]:
return self.__class__(round(self._magnitude, ndigits=ndigits), self._units)
- def __pos__(self) -> PlainQuantity[_MagnitudeType]:
+ def __pos__(self) -> PlainQuantity[MagnitudeT]:
return self.__class__(operator.pos(self._magnitude), self._units)
- def __neg__(self) -> PlainQuantity[_MagnitudeType]:
+ def __neg__(self) -> PlainQuantity[MagnitudeT]:
return self.__class__(operator.neg(self._magnitude), self._units)
@check_implemented
@@ -1797,5 +1433,14 @@ class PlainQuantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]
def _ok_for_muldiv(self, no_offset_units=None) -> bool:
return True
- def to_timedelta(self: PlainQuantity[float]) -> datetime.timedelta:
+ def to_timedelta(self: PlainQuantity[MagnitudeT]) -> datetime.timedelta:
return datetime.timedelta(microseconds=self.to("microseconds").magnitude)
+
+ # We put this last to avoid overriding UnitsContainer
+ # and I do not want to rename it.
+ # TODO: Maybe in the future we need to change it to a more meaningful
+ # non-colliding name.
+
+ @property
+ def UnitsContainer(self) -> Callable[..., UnitsContainerT]:
+ return self._REGISTRY.UnitsContainer
diff --git a/pint/facets/plain/registry.py b/pint/facets/plain/registry.py
index d3baff4..ed46608 100644
--- a/pint/facets/plain/registry.py
+++ b/pint/facets/plain/registry.py
@@ -20,27 +20,38 @@ from decimal import Decimal
from fractions import Fraction
from numbers import Number
from token import NAME, NUMBER
+from tokenize import TokenInfo
+
from typing import (
TYPE_CHECKING,
Any,
Callable,
TypeVar,
Union,
+ Generic,
)
from collections.abc import Iterable, Iterator
if TYPE_CHECKING:
from ..context import Context
- from ..._typing import Quantity, Unit
+ from ...compat import Locale
+
+ # from ..._typing import Quantity, Unit
+
+from ..._typing import (
+ QuantityOrUnitLike,
+ UnitLike,
+ QuantityArgument,
+ Scalar,
+ Handler,
+)
-from ..._typing import QuantityOrUnitLike, UnitLike
from ..._vendor import appdirs
-from ...compat import HAS_BABEL, babel_parse, tokenizer
+from ...compat import babel_parse, tokenizer, TypeAlias, Self
from ...errors import DimensionalityError, RedefinitionError, UndefinedUnitError
from ...pint_eval import build_eval_tree
from ...util import ParserHelper
-from ...util import UnitsContainer
-from ...util import UnitsContainer as UnitsContainerT
+from ...util import UnitsContainer as UnitsContainer
from ...util import (
_is_dim,
create_class_with_registry,
@@ -58,25 +69,20 @@ from .definitions import (
DimensionDefinition,
PrefixDefinition,
UnitDefinition,
+ NamedDefinition,
)
from .objects import PlainQuantity, PlainUnit
-if TYPE_CHECKING:
- if HAS_BABEL:
- import babel
-
- Locale = babel.Locale
- else:
- Locale = None
-
T = TypeVar("T")
_BLOCK_RE = re.compile(r"[ (]")
@functools.lru_cache
-def pattern_to_regex(pattern):
- if hasattr(pattern, "finditer"):
+def pattern_to_regex(pattern: str | re.Pattern[str]) -> re.Pattern[str]:
+ # TODO: This has been changed during typing improvements.
+ # if hasattr(pattern, "finditer"):
+ if not isinstance(pattern, str):
pattern = pattern.pattern
# Replace "{unit_name}" match string with float regex with unit_name as group
@@ -96,15 +102,19 @@ class RegistryCache:
def __init__(self) -> None:
#: Maps dimensionality (UnitsContainer) to Units (str)
- self.dimensional_equivalents: dict[UnitsContainer, set[str]] = {}
+ self.dimensional_equivalents: dict[UnitsContainer, frozenset[str]] = {}
+
#: Maps dimensionality (UnitsContainer) to Dimensionality (UnitsContainer)
- self.root_units = {}
+ # TODO: this description is not right.
+ self.root_units: dict[UnitsContainer, tuple[Scalar, UnitsContainer]] = {}
+
#: Maps dimensionality (UnitsContainer) to Units (UnitsContainer)
self.dimensionality: dict[UnitsContainer, UnitsContainer] = {}
+
#: Cache the unit name associated to user input. ('mV' -> 'millivolt')
self.parse_unit: dict[str, UnitsContainer] = {}
- def __eq__(self, other):
+ def __eq__(self, other: Any):
if not isinstance(other, self.__class__):
return False
attrs = (
@@ -127,7 +137,12 @@ class RegistryMeta(type):
return obj
-class PlainRegistry(metaclass=RegistryMeta):
+# Generic types used to mark types associated to Registries.
+QuantityT = TypeVar("QuantityT", bound=PlainQuantity)
+UnitT = TypeVar("UnitT", bound=PlainUnit)
+
+
+class GenericPlainRegistry(Generic[QuantityT, UnitT], metaclass=RegistryMeta):
"""Base class for all registries.
Capabilities:
@@ -174,11 +189,10 @@ class PlainRegistry(metaclass=RegistryMeta):
#: Babel.Locale instance or None
fmt_locale: Locale | None = None
- _diskcache = None
-
- Quantity = PlainQuantity
- Unit = PlainUnit
+ Quantity: type[QuantityT]
+ Unit: type[UnitT]
+ _diskcache = None
_def_parser = None
def __init__(
@@ -197,7 +211,7 @@ class PlainRegistry(metaclass=RegistryMeta):
mpl_formatter: str = "{:P}",
):
#: Map a definition class to a adder methods.
- self._adders = {}
+ self._adders: Handler = {}
self._register_definition_adders()
self._init_dynamic_classes()
@@ -280,8 +294,8 @@ class PlainRegistry(metaclass=RegistryMeta):
def _init_dynamic_classes(self) -> None:
"""Generate subclasses on the fly and attach them to self"""
- self.Unit: Unit = create_class_with_registry(self, self.Unit)
- self.Quantity: Quantity = create_class_with_registry(self, self.Quantity)
+ self.Unit = create_class_with_registry(self, self.Unit)
+ self.Quantity = create_class_with_registry(self, self.Quantity)
def _after_init(self) -> None:
"""This should be called after all __init__"""
@@ -297,7 +311,16 @@ class PlainRegistry(metaclass=RegistryMeta):
self._build_cache(loaded_files)
self._initialized = True
- def _register_adder(self, definition_class, adder_func):
+ def _register_adder(
+ self,
+ definition_class: type[T],
+ adder_func: Callable[
+ [
+ T,
+ ],
+ None,
+ ],
+ ) -> None:
"""Register a block definition."""
self._adders[definition_class] = adder_func
@@ -310,24 +333,25 @@ class PlainRegistry(metaclass=RegistryMeta):
self._register_adder(DimensionDefinition, self._add_dimension)
self._register_adder(DerivedDimensionDefinition, self._add_derived_dimension)
- def __deepcopy__(self, memo) -> PlainRegistry:
+ def __deepcopy__(self: Self, memo) -> type[Self]:
new = object.__new__(type(self))
new.__dict__ = copy.deepcopy(self.__dict__, memo)
new._init_dynamic_classes()
return new
- def __getattr__(self, item):
+ def __getattr__(self, item: str) -> QuantityT:
getattr_maybe_raise(self, item)
return self.Unit(item)
- def __getitem__(self, item):
+ def __getitem__(self, item: str) -> UnitT:
logger.warning(
"Calling the getitem method from a UnitRegistry is deprecated. "
"use `parse_expression` method or use the registry as a callable."
)
- return self.parse_expression(item)
+ return self.Quantity()
+ # return self.parse_expression(item)
- def __contains__(self, item) -> bool:
+ def __contains__(self, item: str) -> bool:
"""Support checking prefixed units with the `in` operator"""
try:
self.__getattr__(item)
@@ -366,16 +390,13 @@ class PlainRegistry(metaclass=RegistryMeta):
self.fmt_locale = loc
- def UnitsContainer(self, *args, **kwargs) -> UnitsContainerT:
- return UnitsContainer(*args, non_int_type=self.non_int_type, **kwargs)
-
@property
def default_format(self) -> str:
"""Default formatting string for quantities."""
return self.Quantity.default_format
@default_format.setter
- def default_format(self, value: str):
+ def default_format(self, value: str) -> None:
self.Unit.default_format = value
self.Quantity.default_format = value
self.Measurement.default_format = value
@@ -390,7 +411,7 @@ class PlainRegistry(metaclass=RegistryMeta):
def non_int_type(self):
return self._non_int_type
- def define(self, definition):
+ def define(self, definition: str | type) -> None:
"""Add unit to the registry.
Parameters
@@ -413,7 +434,7 @@ class PlainRegistry(metaclass=RegistryMeta):
# - then we define specific adder for each definition class. :-D
############
- def _helper_dispatch_adder(self, definition):
+ def _helper_dispatch_adder(self, definition: Any) -> None:
"""Helper function to add a single definition,
choosing the appropiate method by class.
"""
@@ -428,7 +449,12 @@ class PlainRegistry(metaclass=RegistryMeta):
adder_func(definition)
- def _helper_adder(self, definition, target_dict, casei_target_dict):
+ def _helper_adder(
+ self,
+ definition: NamedDefinition,
+ target_dict: dict[str, Any],
+ casei_target_dict: dict[str, Any] | None,
+ ) -> None:
"""Helper function to store a definition in the internal dictionaries.
It stores the definition under its name, symbol and aliases.
"""
@@ -436,6 +462,7 @@ class PlainRegistry(metaclass=RegistryMeta):
definition.name, definition, target_dict, casei_target_dict
)
+ # TODO: Not sure why but using hasattr does not work here.
if getattr(definition, "has_symbol", ""):
self._helper_single_adder(
definition.symbol, definition, target_dict, casei_target_dict
@@ -447,7 +474,13 @@ class PlainRegistry(metaclass=RegistryMeta):
self._helper_single_adder(alias, definition, target_dict, casei_target_dict)
- def _helper_single_adder(self, key, value, target_dict, casei_target_dict):
+ def _helper_single_adder(
+ self,
+ key: str,
+ value: NamedDefinition,
+ target_dict: dict[str, Any],
+ casei_target_dict: dict[str, Any] | None,
+ ) -> None:
"""Helper function to store a definition in the internal dictionaries.
It warns or raise error on redefinition.
@@ -462,11 +495,11 @@ class PlainRegistry(metaclass=RegistryMeta):
if casei_target_dict is not None:
casei_target_dict[key.lower()].add(key)
- def _add_defaults(self, defaults_definition: DefaultsDefinition):
+ def _add_defaults(self, defaults_definition: DefaultsDefinition) -> None:
for k, v in defaults_definition.items():
self._defaults[k] = v
- def _add_alias(self, definition: AliasDefinition):
+ def _add_alias(self, definition: AliasDefinition) -> None:
unit_dict = self._units
unit = unit_dict[definition.name]
while not isinstance(unit, UnitDefinition):
@@ -474,19 +507,19 @@ class PlainRegistry(metaclass=RegistryMeta):
for alias in definition.aliases:
self._helper_single_adder(alias, unit, self._units, self._units_casei)
- def _add_dimension(self, definition: DimensionDefinition):
+ def _add_dimension(self, definition: DimensionDefinition) -> None:
self._helper_adder(definition, self._dimensions, None)
- def _add_derived_dimension(self, definition: DerivedDimensionDefinition):
+ def _add_derived_dimension(self, definition: DerivedDimensionDefinition) -> None:
for dim_name in definition.reference.keys():
if dim_name not in self._dimensions:
self._add_dimension(DimensionDefinition(dim_name))
self._helper_adder(definition, self._dimensions, None)
- def _add_prefix(self, definition: PrefixDefinition):
+ def _add_prefix(self, definition: PrefixDefinition) -> None:
self._helper_adder(definition, self._prefixes, None)
- def _add_unit(self, definition: UnitDefinition):
+ def _add_unit(self, definition: UnitDefinition) -> None:
if definition.is_base:
self._base_units.append(definition.name)
for dim_name in definition.reference.keys():
@@ -495,7 +528,9 @@ class PlainRegistry(metaclass=RegistryMeta):
self._helper_adder(definition, self._units, self._units_casei)
- def load_definitions(self, file, is_resource: bool = False):
+ def load_definitions(
+ self, file: Iterable[str] | str | pathlib.Path, is_resource: bool = False
+ ):
"""Add units and prefixes defined in a definition text file.
Parameters
@@ -531,8 +566,8 @@ class PlainRegistry(metaclass=RegistryMeta):
self._cache = RegistryCache()
- deps = {
- name: definition.reference.keys() if definition.reference else set()
+ deps: dict[str, set[str]] = {
+ name: set(definition.reference.keys()) if definition.reference else set()
for name, definition in self._units.items()
}
@@ -579,14 +614,13 @@ class PlainRegistry(metaclass=RegistryMeta):
candidates = self.parse_unit_name(name_or_alias, case_sensitive)
if not candidates:
raise UndefinedUnitError(name_or_alias)
- elif len(candidates) == 1:
- prefix, unit_name, _ = candidates[0]
- else:
+
+ prefix, unit_name, _ = candidates[0]
+ if len(candidates) > 1:
logger.warning(
"Parsing {} yield multiple results. "
- "Options are: {}".format(name_or_alias, candidates)
+ "Options are: {!r}".format(name_or_alias, candidates)
)
- prefix, unit_name, _ = candidates[0]
if prefix:
name = prefix + unit_name
@@ -595,7 +629,7 @@ class PlainRegistry(metaclass=RegistryMeta):
self._units[name] = UnitDefinition(
name,
symbol,
- (),
+ tuple(),
prefix_def.converter,
self.UnitsContainer({unit_name: 1}),
)
@@ -608,21 +642,20 @@ class PlainRegistry(metaclass=RegistryMeta):
candidates = self.parse_unit_name(name_or_alias, case_sensitive)
if not candidates:
raise UndefinedUnitError(name_or_alias)
- elif len(candidates) == 1:
- prefix, unit_name, _ = candidates[0]
- else:
+
+ prefix, unit_name, _ = candidates[0]
+ if len(candidates) > 1:
logger.warning(
"Parsing {} yield multiple results. "
"Options are: {!r}".format(name_or_alias, candidates)
)
- prefix, unit_name, _ = candidates[0]
return self._prefixes[prefix].symbol + self._units[unit_name].symbol
def _get_symbol(self, name: str) -> str:
return self._units[name].symbol
- def get_dimensionality(self, input_units) -> UnitsContainerT:
+ def get_dimensionality(self, input_units: UnitLike) -> UnitsContainer:
"""Convert unit or dict of units or dimensions to a dict of plain dimensions
dimensions
"""
@@ -633,9 +666,7 @@ class PlainRegistry(metaclass=RegistryMeta):
return self._get_dimensionality(input_units)
- def _get_dimensionality(
- self, input_units: UnitsContainerT | None
- ) -> UnitsContainerT:
+ def _get_dimensionality(self, input_units: UnitsContainer | None) -> UnitsContainer:
"""Convert a UnitsContainer to plain dimensions."""
if not input_units:
return self.UnitsContainer()
@@ -647,7 +678,7 @@ class PlainRegistry(metaclass=RegistryMeta):
except KeyError:
pass
- accumulator = defaultdict(int)
+ accumulator: dict[str, int] = defaultdict(int)
self._get_dimensionality_recurse(input_units, 1, accumulator)
if "[]" in accumulator:
@@ -659,21 +690,25 @@ class PlainRegistry(metaclass=RegistryMeta):
return dims
- def _get_dimensionality_recurse(self, ref, exp, accumulator):
+ def _get_dimensionality_recurse(
+ self, ref: UnitsContainer, exp: Scalar, accumulator: dict[str, int]
+ ) -> None:
for key in ref:
exp2 = exp * ref[key]
if _is_dim(key):
reg = self._dimensions[key]
- if reg.is_base:
- accumulator[key] += exp2
- elif reg.reference is not None:
+ if isinstance(reg, DerivedDimensionDefinition):
self._get_dimensionality_recurse(reg.reference, exp2, accumulator)
+ else:
+ # DimensionDefinition.
+ accumulator[key] += exp2
+
else:
reg = self._units[self.get_name(key)]
if reg.reference is not None:
self._get_dimensionality_recurse(reg.reference, exp2, accumulator)
- def _get_dimensionality_ratio(self, unit1, unit2):
+ def _get_dimensionality_ratio(self, unit1: UnitLike, unit2: UnitLike):
"""Get the exponential ratio between two units, i.e. solve unit2 = unit1**x for x.
Parameters
@@ -707,7 +742,7 @@ class PlainRegistry(metaclass=RegistryMeta):
def get_root_units(
self, input_units: UnitLike, check_nonmult: bool = True
- ) -> tuple[Number, PlainUnit]:
+ ) -> tuple[Number, UnitT]:
"""Convert unit or dict of units to the root units.
If any unit is non multiplicative and check_converter is True,
@@ -734,7 +769,9 @@ class PlainRegistry(metaclass=RegistryMeta):
return f, self.Unit(units)
- def _get_root_units(self, input_units, check_nonmult=True):
+ def _get_root_units(
+ self, input_units: UnitsContainer, check_nonmult: bool = True
+ ) -> tuple[Scalar, UnitsContainer]:
"""Convert unit or dict of units to the root units.
If any unit is non multiplicative and check_converter is True,
@@ -764,12 +801,13 @@ class PlainRegistry(metaclass=RegistryMeta):
except KeyError:
pass
- accumulators = [1, defaultdict(int)]
+ accumulators: dict[str | None, int] = defaultdict(int)
+ accumulators[None] = 1
self._get_root_units_recurse(input_units, 1, accumulators)
- factor = accumulators[0]
+ factor = accumulators[None]
units = self.UnitsContainer(
- {k: v for k, v in accumulators[1].items() if v != 0}
+ {k: v for k, v in accumulators.items() if k is not None and v != 0}
)
# Check if any of the final units is non multiplicative and return None instead.
@@ -780,7 +818,9 @@ class PlainRegistry(metaclass=RegistryMeta):
cache[input_units] = factor, units
return factor, units
- def get_base_units(self, input_units, check_nonmult=True, system=None):
+ def get_base_units(
+ self, input_units: UnitsContainer | str, check_nonmult: bool = True, system=None
+ ) -> tuple[Number, UnitT]:
"""Convert unit or dict of units to the plain units.
If any unit is non multiplicative and check_converter is True,
@@ -806,35 +846,44 @@ class PlainRegistry(metaclass=RegistryMeta):
return self.get_root_units(input_units, check_nonmult)
- def _get_root_units_recurse(self, ref, exp, accumulators):
+ # TODO: accumulators breaks typing list[int, dict[str, int]]
+ # So we have changed the behavior here
+ def _get_root_units_recurse(
+ self, ref: UnitsContainer, exp: Scalar, accumulators: dict[str | None, int]
+ ) -> None:
+ """
+
+ accumulators None keeps the scalar prefactor not associated with a specific unit.
+
+ """
for key in ref:
exp2 = exp * ref[key]
key = self.get_name(key)
reg = self._units[key]
if reg.is_base:
- accumulators[1][key] += exp2
+ accumulators[key] += exp2
else:
- accumulators[0] *= reg.converter.scale**exp2
+ accumulators[None] *= reg.converter.scale**exp2
if reg.reference is not None:
self._get_root_units_recurse(reg.reference, exp2, accumulators)
- def get_compatible_units(
- self, input_units, group_or_system=None
- ) -> frozenset[Unit]:
+ def get_compatible_units(self, input_units: QuantityOrUnitLike) -> frozenset[UnitT]:
""" """
input_units = to_units_container(input_units)
- equiv = self._get_compatible_units(input_units, group_or_system)
+ equiv = self._get_compatible_units(input_units)
return frozenset(self.Unit(eq) for eq in equiv)
- def _get_compatible_units(self, input_units, group_or_system):
+ def _get_compatible_units(
+ self, input_units: UnitsContainer, *args, **kwargs
+ ) -> frozenset[str]:
""" """
if not input_units:
return frozenset()
src_dim = self._get_dimensionality(input_units)
- return self._cache.dimensional_equivalents.setdefault(src_dim, set())
+ return self._cache.dimensional_equivalents.setdefault(src_dim, frozenset())
# TODO: remove context from here
def is_compatible_with(
@@ -901,7 +950,14 @@ class PlainRegistry(metaclass=RegistryMeta):
return self._convert(value, src, dst, inplace)
- def _convert(self, value, src, dst, inplace=False, check_dimensionality=True):
+ def _convert(
+ self,
+ value: T,
+ src: UnitsContainer,
+ dst: UnitsContainer,
+ inplace: bool = False,
+ check_dimensionality: bool = True,
+ ) -> T:
"""Convert value from some source to destination units.
Parameters
@@ -931,7 +987,7 @@ class PlainRegistry(metaclass=RegistryMeta):
# If the source and destination dimensionality are different,
# then the conversion cannot be performed.
if src_dim != dst_dim:
- raise DimensionalityError(src, dst, src_dim, dst_dim)
+ raise DimensionalityError(src, dst, str(src_dim), str(dst_dim))
# Here src and dst have only multiplicative units left. Thus we can
# convert with a factor.
@@ -953,7 +1009,7 @@ class PlainRegistry(metaclass=RegistryMeta):
def parse_unit_name(
self, unit_name: str, case_sensitive: bool | None = None
- ) -> tuple[tuple[str, str, str], ...]:
+ ) -> tuple[tuple[str, str, str]]:
"""Parse a unit to identify prefix, unit name and suffix
by walking the list of prefix and suffix.
In case of equivalent combinations (e.g. ('kilo', 'gram', '') and
@@ -1033,7 +1089,7 @@ class PlainRegistry(metaclass=RegistryMeta):
input_string: str,
as_delta: bool | None = None,
case_sensitive: bool | None = None,
- ) -> Unit:
+ ) -> UnitT:
"""Parse a units expression and returns a UnitContainer with
the canonical names.
@@ -1054,6 +1110,8 @@ class PlainRegistry(metaclass=RegistryMeta):
pint.Unit
"""
+
+ # TODO: deal or remove with as_delta = None
for p in self.preprocessors:
input_string = p(input_string)
units = self._parse_units(input_string, as_delta, case_sensitive)
@@ -1064,7 +1122,7 @@ class PlainRegistry(metaclass=RegistryMeta):
input_string: str,
as_delta: bool = True,
case_sensitive: bool | None = None,
- ) -> UnitsContainerT:
+ ) -> UnitsContainer:
"""Parse a units expression and returns a UnitContainer with
the canonical names.
"""
@@ -1104,12 +1162,37 @@ class PlainRegistry(metaclass=RegistryMeta):
return ret
- def _eval_token(self, token, case_sensitive=None, **values):
+ def _eval_token(
+ self,
+ token: TokenInfo,
+ case_sensitive: bool | None = None,
+ **values: QuantityArgument,
+ ):
+ """Evaluate a single token using the following rules:
+
+ 1. numerical values as strings are replaced by their numeric counterparts
+ - integers are parsed as integers
+ - other numeric values are parses of non_int_type
+ 2. strings in (inf, infinity, nan, dimensionless) with their numerical value.
+ 3. strings in values.keys() are replaced by Quantity(values[key])
+ 4. in other cases, the values are parsed as units and replaced by their canonical name.
+
+ Parameters
+ ----------
+ token
+ Token to evaluate.
+ case_sensitive, optional
+ If true, a case sensitive matching of the unit name will be done in the registry.
+ If false, a case INsensitive matching of the unit name will be done in the registry.
+ (Default value = None, which uses registry setting)
+ **values
+ Other string that will be parsed using the Quantity constructor on their corresponding value.
+ """
token_type = token[0]
token_text = token[1]
if token_type == NAME:
if token_text == "dimensionless":
- return self.Quantity(1, self.dimensionless)
+ return self.Quantity(1)
elif token_text.lower() in ("inf", "infinity"):
return self.non_int_type("inf")
elif token_text.lower() == "nan":
@@ -1139,28 +1222,25 @@ class PlainRegistry(metaclass=RegistryMeta):
Parameters
----------
- input_string :
+ input_string
pattern_string:
- The regex parse string
- case_sensitive :
- (Default value = None, which uses registry setting)
- many :
+ The regex parse string
+ case_sensitive, optional
+ If true, a case sensitive matching of the unit name will be done in the registry.
+ If false, a case INsensitive matching of the unit name will be done in the registry.
+ (Default value = None, which uses registry setting)
+ many, optional
Match many results
(Default value = False)
-
-
- Returns
- -------
-
"""
if not input_string:
return [] if many else None
# Parse string
- pattern = pattern_to_regex(pattern)
- matched = re.finditer(pattern, input_string)
+ regex = pattern_to_regex(pattern)
+ matched = re.finditer(regex, input_string)
# Extract result(s)
results = []
@@ -1184,11 +1264,11 @@ class PlainRegistry(metaclass=RegistryMeta):
return results
def parse_expression(
- self,
+ self: Self,
input_string: str,
case_sensitive: bool | None = None,
- **values,
- ) -> Quantity:
+ **values: QuantityArgument,
+ ) -> QuantityT:
"""Parse a mathematical expression including units and return a quantity object.
Numerical constants can be specified as keyword arguments and will take precedence
@@ -1196,16 +1276,14 @@ class PlainRegistry(metaclass=RegistryMeta):
Parameters
----------
- input_string :
-
- case_sensitive :
- (Default value = None, which uses registry setting)
- **values :
-
-
- Returns
- -------
-
+ input_string
+
+ case_sensitive, optional
+ If true, a case sensitive matching of the unit name will be done in the registry.
+ If false, a case INsensitive matching of the unit name will be done in the registry.
+ (Default value = None, which uses registry setting)
+ **values
+ Other string that will be parsed using the Quantity constructor on their corresponding value.
"""
if not input_string:
return self.Quantity(1)
@@ -1215,8 +1293,21 @@ class PlainRegistry(metaclass=RegistryMeta):
input_string = string_preprocessor(input_string)
gen = tokenizer(input_string)
- return build_eval_tree(gen).evaluate(
- lambda x: self._eval_token(x, case_sensitive=case_sensitive, **values)
- )
+ def _define_op(s: str):
+ return self._eval_token(s, case_sensitive=case_sensitive, **values)
+
+ return build_eval_tree(gen).evaluate(_define_op)
+
+ # We put this last to avoid overriding UnitsContainer
+ # and I do not want to rename it.
+ # TODO: Maybe in the future we need to change it to a more meaningful
+ # non-colliding name.
+ def UnitsContainer(self, *args: Any, **kwargs: Any) -> UnitsContainer:
+ return UnitsContainer(*args, non_int_type=self.non_int_type, **kwargs)
__call__ = parse_expression
+
+
+class PlainRegistry(GenericPlainRegistry[PlainQuantity[Any], PlainUnit]):
+ Quantity: TypeAlias = PlainQuantity[Any]
+ Unit: TypeAlias = PlainUnit
diff --git a/pint/facets/system/__init__.py b/pint/facets/system/__init__.py
index e95098b..24e68b7 100644
--- a/pint/facets/system/__init__.py
+++ b/pint/facets/system/__init__.py
@@ -12,6 +12,6 @@ from __future__ import annotations
from .definitions import SystemDefinition
from .objects import System
-from .registry import SystemRegistry
+from .registry import SystemRegistry, GenericSystemRegistry
-__all__ = ["SystemDefinition", "System", "SystemRegistry"]
+__all__ = ["SystemDefinition", "System", "SystemRegistry", "GenericSystemRegistry"]
diff --git a/pint/facets/system/definitions.py b/pint/facets/system/definitions.py
index 1ce8269..eb582f3 100644
--- a/pint/facets/system/definitions.py
+++ b/pint/facets/system/definitions.py
@@ -11,7 +11,7 @@ from __future__ import annotations
from collections.abc import Iterable
from dataclasses import dataclass
-from ..._typing import Self
+from ...compat import Self
from ... import errors
diff --git a/pint/facets/system/objects.py b/pint/facets/system/objects.py
index 69b1c84..cf6a24f 100644
--- a/pint/facets/system/objects.py
+++ b/pint/facets/system/objects.py
@@ -14,7 +14,9 @@ import numbers
from typing import Any
from collections.abc import Iterable
-from ..._typing import Self
+
+from typing import Callable, Generic
+from numbers import Number
from ...babel_names import _babel_systems
from ...compat import babel_parse
@@ -25,6 +27,20 @@ from ...util import (
to_units_container,
)
from .definitions import SystemDefinition
+from .. import group
+from ..plain import MagnitudeT
+
+from ..._typing import UnitLike
+
+GetRootUnits = Callable[[UnitLike, bool], tuple[Number, UnitLike]]
+
+
+class SystemQuantity(Generic[MagnitudeT], group.GroupQuantity[MagnitudeT]):
+ pass
+
+
+class SystemUnit(group.GroupUnit):
+ pass
class System(SharedRegistryObject):
@@ -76,11 +92,11 @@ class System(SharedRegistryObject):
def members(self):
d = self._REGISTRY._groups
if self._computed_members is None:
- self._computed_members = set()
+ tmp: set[str] = set()
for group_name in self._used_groups:
try:
- self._computed_members |= d[group_name].members
+ tmp |= d[group_name].members
except KeyError:
logger.warning(
"Could not resolve {} in System {}".format(
@@ -88,7 +104,7 @@ class System(SharedRegistryObject):
)
)
- self._computed_members = frozenset(self._computed_members)
+ self._computed_members = frozenset(tmp)
return self._computed_members
@@ -116,17 +132,30 @@ class System(SharedRegistryObject):
return locale.measurement_systems[name]
return self.name
+ # TODO: When 3.11 is minimal version, use Self
+
@classmethod
def from_lines(
- cls: type[Self], lines: Iterable[str], get_root_func, non_int_type: type = float
- ) -> Self:
+ cls: type[System],
+ lines: Iterable[str],
+ get_root_func: GetRootUnits,
+ non_int_type: type = float,
+ ) -> System:
# TODO: we changed something here it used to be
# system_definition = SystemDefinition.from_lines(lines, get_root_func)
system_definition = SystemDefinition.from_lines(lines, non_int_type)
+
+ if system_definition is None:
+ raise ValueError(f"Could not define System from from {lines}")
+
return cls.from_definition(system_definition, get_root_func)
@classmethod
- def from_definition(cls, system_definition: SystemDefinition, get_root_func=None):
+ def from_definition(
+ cls: type[System],
+ system_definition: SystemDefinition,
+ get_root_func: GetRootUnits | None = None,
+ ) -> System:
if get_root_func is None:
# TODO: kept for backwards compatibility
get_root_func = cls._REGISTRY.get_root_units
diff --git a/pint/facets/system/registry.py b/pint/facets/system/registry.py
index 6e0878e..30921bd 100644
--- a/pint/facets/system/registry.py
+++ b/pint/facets/system/registry.py
@@ -9,10 +9,14 @@
from __future__ import annotations
from numbers import Number
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Generic, Any
from ... import errors
+from ...compat import TypeAlias
+
+from ..plain import QuantityT, UnitT
+
if TYPE_CHECKING:
from ..._typing import Quantity, Unit
@@ -22,13 +26,14 @@ from ...util import (
create_class_with_registry,
to_units_container,
)
-from ..group import GroupRegistry
+from ..group import GenericGroupRegistry
from .definitions import SystemDefinition
-from .objects import Lister, System
from . import objects
-class SystemRegistry(GroupRegistry):
+class GenericSystemRegistry(
+ Generic[QuantityT, UnitT], GenericGroupRegistry[QuantityT, UnitT]
+):
"""Handle of Systems.
Conversion between units with different dimensions according
@@ -46,24 +51,24 @@ class SystemRegistry(GroupRegistry):
# TODO: Change this to System: System to specify class
# and use introspection to get system class as a way
# to enjoy typing goodies
- System = objects.System
+ System: type[objects.System]
- def __init__(self, system=None, **kwargs):
+ def __init__(self, system: str | None = None, **kwargs):
super().__init__(**kwargs)
#: Map system name to system.
#: :type: dict[ str | System]
- self._systems: dict[str, System] = {}
+ self._systems: dict[str, objects.System] = {}
#: Maps dimensionality (UnitsContainer) to Dimensionality (UnitsContainer)
- self._base_units_cache = {}
+ self._base_units_cache: dict[UnitsContainerT, UnitsContainerT] = {}
- self._default_system = system
+ self._default_system_name: str | None = system
def _init_dynamic_classes(self) -> None:
"""Generate subclasses on the fly and attach them to self"""
super()._init_dynamic_classes()
- self.System = create_class_with_registry(self, self.System)
+ self.System = create_class_with_registry(self, objects.System)
def _after_init(self) -> None:
"""Invoked at the end of ``__init__``.
@@ -74,7 +79,7 @@ class SystemRegistry(GroupRegistry):
super()._after_init()
#: System name to be used by default.
- self._default_system = self._default_system or self._defaults.get(
+ self._default_system_name = self._default_system_name or self._defaults.get(
"system", None
)
@@ -82,7 +87,7 @@ class SystemRegistry(GroupRegistry):
super()._register_definition_adders()
self._register_adder(SystemDefinition, self._add_system)
- def _add_system(self, sd: SystemDefinition):
+ def _add_system(self, sd: SystemDefinition) -> None:
if sd.name in self._systems:
raise ValueError(f"System {sd.name} already present in registry")
@@ -96,29 +101,29 @@ class SystemRegistry(GroupRegistry):
@property
def sys(self):
- return Lister(self._systems)
+ return objects.Lister(self._systems)
@property
- def default_system(self) -> System:
- return self._default_system
+ def default_system(self) -> str | None:
+ return self._default_system_name
@default_system.setter
- def default_system(self, name):
+ def default_system(self, name: str) -> None:
if name:
if name not in self._systems:
raise ValueError("Unknown system %s" % name)
self._base_units_cache = {}
- self._default_system = name
+ self._default_system_name = name
- def get_system(self, name: str, create_if_needed: bool = True) -> System:
+ def get_system(self, name: str, create_if_needed: bool = True) -> objects.System:
"""Return a Group.
Parameters
----------
name : str
- Name of the group to be
+ Name of the group to be.
create_if_needed : bool
If True, create a group if not found. If False, raise an Exception.
(Default value = True)
@@ -141,7 +146,7 @@ class SystemRegistry(GroupRegistry):
self,
input_units: UnitLike | Quantity,
check_nonmult: bool = True,
- system: str | System | None = None,
+ system: str | objects.System | None = None,
) -> tuple[Number, Unit]:
"""Convert unit or dict of units to the plain units.
@@ -179,15 +184,15 @@ class SystemRegistry(GroupRegistry):
self,
input_units: UnitsContainerT,
check_nonmult: bool = True,
- system: str | System | None = None,
+ system: str | objects.System | None = None,
):
if system is None:
- system = self._default_system
+ system = self._default_system_name
# The cache is only done for check_nonmult=True and the current system.
if (
check_nonmult
- and system == self._default_system
+ and system == self._default_system_name
and input_units in self._base_units_cache
):
return self._base_units_cache[input_units]
@@ -220,16 +225,32 @@ class SystemRegistry(GroupRegistry):
return base_factor, destination_units
- def _get_compatible_units(self, input_units, group_or_system) -> frozenset[Unit]:
+ def get_compatible_units(
+ self, input_units: UnitsContainerT, group_or_system: str | None = None
+ ) -> frozenset[Unit]:
+ """ """
+
+ group_or_system = group_or_system or self._default_system_name
+
if group_or_system is None:
- group_or_system = self._default_system
+ return super().get_compatible_units(input_units)
+
+ input_units = to_units_container(input_units)
+
+ equiv = self._get_compatible_units(input_units, group_or_system)
+
+ return frozenset(self.Unit(eq) for eq in equiv)
+ def _get_compatible_units(
+ self, input_units: UnitsContainerT, group_or_system: str | None = None
+ ) -> frozenset[Unit]:
if group_or_system and group_or_system in self._systems:
members = self._systems[group_or_system].members
# group_or_system has been handled by System
- return frozenset(members & super()._get_compatible_units(input_units, None))
+ return frozenset(members & super()._get_compatible_units(input_units))
try:
+ # This will be handled by groups
return super()._get_compatible_units(input_units, group_or_system)
except ValueError as ex:
# It might be also a system
@@ -238,3 +259,10 @@ class SystemRegistry(GroupRegistry):
"Unknown Group o System with name '%s'" % group_or_system
) from ex
raise ex
+
+
+class SystemRegistry(
+ GenericSystemRegistry[objects.SystemQuantity[Any], objects.SystemUnit]
+):
+ Quantity: TypeAlias = objects.SystemQuantity[Any]
+ Unit: TypeAlias = objects.SystemUnit
diff --git a/pint/formatting.py b/pint/formatting.py
index 880f55b..28adf25 100644
--- a/pint/formatting.py
+++ b/pint/formatting.py
@@ -13,17 +13,27 @@ from __future__ import annotations
import functools
import re
import warnings
-from typing import Callable, Any
+from typing import Callable, Any, TYPE_CHECKING, TypeVar
from collections.abc import Iterable
from numbers import Number
from .babel_names import _babel_lengths, _babel_units
-from .compat import babel_parse
+from .compat import babel_parse, HAS_BABEL
+
+if TYPE_CHECKING:
+ from .util import ItMatrix, UnitsContainer
+
+ if HAS_BABEL:
+ import babel
+
+ Locale = babel.Locale
+ else:
+ Locale = TypeVar("Locale")
__JOIN_REG_EXP = re.compile(r"{\d*}")
-def _join(fmt: str, iterable: Iterable[Any]):
+def _join(fmt: str, iterable: Iterable[Any]) -> str:
"""Join an iterable with the format specified in fmt.
The format can be specified in two ways:
@@ -124,6 +134,7 @@ _FORMATS: dict[str, dict[str, Any]] = {
}
#: _FORMATTERS maps format names to callables doing the formatting
+# TODO fix Callable typing
_FORMATTERS: dict[str, Callable] = {}
@@ -167,7 +178,7 @@ def register_unit_format(name: str):
@register_unit_format("P")
-def format_pretty(unit, registry, **options):
+def format_pretty(unit: UnitsContainer, registry, **options) -> str:
return formatter(
unit.items(),
as_ratio=True,
@@ -181,7 +192,7 @@ def format_pretty(unit, registry, **options):
)
-def latex_escape(string):
+def latex_escape(string: str) -> str:
"""
Prepend characters that have a special meaning in LaTeX with a backslash.
"""
@@ -198,7 +209,7 @@ def latex_escape(string):
@register_unit_format("L")
-def format_latex(unit, registry, **options):
+def format_latex(unit: UnitsContainer, registry, **options) -> str:
preprocessed = {rf"\mathrm{{{latex_escape(u)}}}": p for u, p in unit.items()}
formatted = formatter(
preprocessed.items(),
@@ -214,7 +225,7 @@ def format_latex(unit, registry, **options):
@register_unit_format("Lx")
-def format_latex_siunitx(unit, registry, **options):
+def format_latex_siunitx(unit: UnitsContainer, registry, **options) -> str:
if registry is None:
raise ValueError(
"Can't format as siunitx without a registry."
@@ -228,7 +239,7 @@ def format_latex_siunitx(unit, registry, **options):
@register_unit_format("H")
-def format_html(unit, registry, **options):
+def format_html(unit: UnitsContainer, registry, **options) -> str:
return formatter(
unit.items(),
as_ratio=True,
@@ -242,7 +253,7 @@ def format_html(unit, registry, **options):
@register_unit_format("D")
-def format_default(unit, registry, **options):
+def format_default(unit: UnitsContainer, registry, **options) -> str:
return formatter(
unit.items(),
as_ratio=True,
@@ -256,7 +267,7 @@ def format_default(unit, registry, **options):
@register_unit_format("C")
-def format_compact(unit, registry, **options):
+def format_compact(unit: UnitsContainer, registry, **options) -> str:
return formatter(
unit.items(),
as_ratio=True,
@@ -270,7 +281,7 @@ def format_compact(unit, registry, **options):
def formatter(
- items: list[tuple[str, Number]],
+ items: Iterable[tuple[str, Number]],
as_ratio: bool = True,
single_denominator: bool = False,
product_fmt: str = " * ",
@@ -282,7 +293,7 @@ def formatter(
babel_length: str = "long",
babel_plural_form: str = "one",
sort: bool = True,
-):
+) -> str:
"""Format a list of (name, exponent) pairs.
Parameters
@@ -393,7 +404,7 @@ def formatter(
_BASIC_TYPES = frozenset("bcdeEfFgGnosxX%uS")
-def _parse_spec(spec):
+def _parse_spec(spec: str) -> str:
result = ""
for ch in reversed(spec):
if ch == "~" or ch in _BASIC_TYPES:
@@ -410,7 +421,7 @@ def _parse_spec(spec):
return result
-def format_unit(unit, spec, registry=None, **options):
+def format_unit(unit, spec: str, registry=None, **options):
# registry may be None to allow formatting `UnitsContainer` objects
# in that case, the spec may not be "Lx"
@@ -430,10 +441,10 @@ def format_unit(unit, spec, registry=None, **options):
return fmt(unit, registry=registry, **options)
-def siunitx_format_unit(units, registry):
+def siunitx_format_unit(units: UnitsContainer, registry) -> str:
"""Returns LaTeX code for the unit that can be put into an siunitx command."""
- def _tothe(power):
+ def _tothe(power: int | float) -> str:
if isinstance(power, int) or (isinstance(power, float) and power.is_integer()):
if power == 1:
return ""
@@ -473,7 +484,7 @@ def siunitx_format_unit(units, registry):
return "".join(lpos) + "".join(lneg)
-def extract_custom_flags(spec):
+def extract_custom_flags(spec: str) -> str:
import re
if not spec:
@@ -488,14 +499,16 @@ def extract_custom_flags(spec):
return "".join(custom_flags)
-def remove_custom_flags(spec):
+def remove_custom_flags(spec: str) -> str:
for flag in sorted(_FORMATTERS.keys(), key=len, reverse=True) + ["~"]:
if flag:
spec = spec.replace(flag, "")
return spec
-def split_format(spec, default, separate_format_defaults=True):
+def split_format(
+ spec: str, default: str, separate_format_defaults: bool = True
+) -> tuple[str, str]:
mspec = remove_custom_flags(spec)
uspec = extract_custom_flags(spec)
@@ -535,11 +548,11 @@ def split_format(spec, default, separate_format_defaults=True):
return mspec, uspec
-def vector_to_latex(vec, fmtfun=lambda x: format(x, ".2f")):
+def vector_to_latex(vec: Iterable[Any], fmtfun=lambda x: format(x, ".2f")) -> str:
return matrix_to_latex([vec], fmtfun)
-def matrix_to_latex(matrix, fmtfun=lambda x: format(x, ".2f")):
+def matrix_to_latex(matrix: ItMatrix, fmtfun=lambda x: format(x, ".2f")) -> str:
ret = []
for row in matrix:
@@ -548,7 +561,9 @@ def matrix_to_latex(matrix, fmtfun=lambda x: format(x, ".2f")):
return r"\begin{pmatrix}%s\end{pmatrix}" % "\\\\ \n".join(ret)
-def ndarray_to_latex_parts(ndarr, fmtfun=lambda x: format(x, ".2f"), dim=()):
+def ndarray_to_latex_parts(
+ ndarr, fmtfun=lambda x: format(x, ".2f"), dim: tuple[int] = tuple()
+):
if isinstance(fmtfun, str):
fmt = fmtfun
fmtfun = lambda x: format(x, fmt)
@@ -573,5 +588,7 @@ def ndarray_to_latex_parts(ndarr, fmtfun=lambda x: format(x, ".2f"), dim=()):
return ret
-def ndarray_to_latex(ndarr, fmtfun=lambda x: format(x, ".2f"), dim=()):
+def ndarray_to_latex(
+ ndarr, fmtfun=lambda x: format(x, ".2f"), dim: tuple[int] = tuple()
+) -> str:
return "\n".join(ndarray_to_latex_parts(ndarr, fmtfun, dim))
diff --git a/pint/registry.py b/pint/registry.py
index 474eb77..964d8a5 100644
--- a/pint/registry.py
+++ b/pint/registry.py
@@ -14,16 +14,10 @@
from __future__ import annotations
+from typing import Generic
+
from . import registry_helpers
-from .facets import (
- ContextRegistry,
- DaskRegistry,
- FormattingRegistry,
- MeasurementRegistry,
- NonMultiplicativeRegistry,
- NumpyRegistry,
- SystemRegistry,
-)
+from . import facets
from .util import logger, pi_theorem
@@ -33,37 +27,40 @@ from .util import logger, pi_theorem
class Quantity(
- # SystemRegistry.Quantity,
- # ContextRegistry.Quantity,
- DaskRegistry.Quantity,
- NumpyRegistry.Quantity,
- MeasurementRegistry.Quantity,
- FormattingRegistry.Quantity,
- NonMultiplicativeRegistry.Quantity,
+ facets.SystemRegistry.Quantity,
+ facets.ContextRegistry.Quantity,
+ facets.DaskRegistry.Quantity,
+ facets.NumpyRegistry.Quantity,
+ facets.MeasurementRegistry.Quantity,
+ facets.FormattingRegistry.Quantity,
+ facets.NonMultiplicativeRegistry.Quantity,
+ facets.PlainRegistry.Quantity,
):
pass
class Unit(
- # SystemRegistry.Unit,
- # ContextRegistry.Unit,
- # DaskRegistry.Unit,
- NumpyRegistry.Unit,
- # MeasurementRegistry.Unit,
- FormattingRegistry.Unit,
- NonMultiplicativeRegistry.Unit,
+ facets.SystemRegistry.Unit,
+ facets.ContextRegistry.Unit,
+ facets.DaskRegistry.Unit,
+ facets.NumpyRegistry.Unit,
+ facets.MeasurementRegistry.Unit,
+ facets.FormattingRegistry.Unit,
+ facets.NonMultiplicativeRegistry.Unit,
+ facets.PlainRegistry.Unit,
):
pass
class UnitRegistry(
- SystemRegistry,
- ContextRegistry,
- DaskRegistry,
- NumpyRegistry,
- MeasurementRegistry,
- FormattingRegistry,
- NonMultiplicativeRegistry,
+ facets.GenericSystemRegistry[Quantity, Unit],
+ facets.GenericContextRegistry[Quantity, Unit],
+ facets.GenericDaskRegistry[Quantity, Unit],
+ facets.GenericNumpyRegistry[Quantity, Unit],
+ facets.GenericMeasurementRegistry[Quantity, Unit],
+ facets.GenericFormattingRegistry[Quantity, Unit],
+ facets.GenericNonMultiplicativeRegistry[Quantity, Unit],
+ facets.GenericPlainRegistry[Quantity, Unit],
):
"""The unit registry stores the definitions and relationships between units.
@@ -171,7 +168,7 @@ class UnitRegistry(
check = registry_helpers.check
-class LazyRegistry:
+class LazyRegistry(Generic[facets.QuantityT, facets.UnitT]):
def __init__(self, args=None, kwargs=None):
self.__dict__["params"] = args or (), kwargs or {}
diff --git a/pint/registry_helpers.py b/pint/registry_helpers.py
index 1f28036..7eee694 100644
--- a/pint/registry_helpers.py
+++ b/pint/registry_helpers.py
@@ -13,7 +13,7 @@ from __future__ import annotations
import functools
from inspect import signature
from itertools import zip_longest
-from typing import TYPE_CHECKING, Callable, TypeVar
+from typing import TYPE_CHECKING, Callable, TypeVar, Any
from collections.abc import Iterable
from ._typing import F
@@ -189,7 +189,7 @@ def wraps(
ret: str | Unit | Iterable[str | Unit | None] | None,
args: str | Unit | Iterable[str | Unit | None] | None,
strict: bool = True,
-) -> Callable[[Callable[..., T]], Callable[..., Quantity[T]]]:
+) -> Callable[[Callable[..., Any]], Callable[..., Quantity]]:
"""Wraps a function to become pint-aware.
Use it when a function requires a numerical value but in some specific
@@ -253,7 +253,7 @@ def wraps(
)
ret = _to_units_container(ret, ureg)
- def decorator(func: Callable[..., T]) -> Callable[..., Quantity[T]]:
+ def decorator(func: Callable[..., Any]) -> Callable[..., Quantity]:
count_params = len(signature(func).parameters)
if len(args) != count_params:
raise TypeError(
@@ -269,7 +269,7 @@ def wraps(
)
@functools.wraps(func, assigned=assigned, updated=updated)
- def wrapper(*values, **kw) -> Quantity[T]:
+ def wrapper(*values, **kw) -> Quantity:
values, kw = _apply_defaults(func, values, kw)
# In principle, the values are used as is
diff --git a/pint/testing.py b/pint/testing.py
index 8e4f15f..d99df0b 100644
--- a/pint/testing.py
+++ b/pint/testing.py
@@ -34,7 +34,7 @@ def _get_comparable_magnitudes(first, second, msg):
return m1, m2
-def assert_equal(first, second, msg=None):
+def assert_equal(first, second, msg: str | None = None) -> None:
if msg is None:
msg = f"Comparing {first!r} and {second!r}. "
@@ -57,7 +57,9 @@ def assert_equal(first, second, msg=None):
assert m1 == m2, msg
-def assert_allclose(first, second, rtol=1e-07, atol=0, msg=None):
+def assert_allclose(
+ first, second, rtol: float = 1e-07, atol: float = 0, msg: str | None = None
+) -> None:
if msg is None:
try:
msg = f"Comparing {first!r} and {second!r}. "
diff --git a/pint/util.py b/pint/util.py
index d75d1b5..40ea39e 100644
--- a/pint/util.py
+++ b/pint/util.py
@@ -30,15 +30,15 @@ from typing import (
)
from collections.abc import Hashable, Generator
-from .compat import NUMERIC_TYPES, tokenizer
+from .compat import NUMERIC_TYPES, tokenizer, Self
from .errors import DefinitionSyntaxError
from .formatting import format_unit
from .pint_eval import build_eval_tree
-from ._typing import PintScalar
+from ._typing import Scalar
if TYPE_CHECKING:
- from ._typing import Quantity, UnitLike, Self
+ from ._typing import Quantity, UnitLike, QuantityOrUnitLike
from .registry import UnitRegistry
@@ -47,12 +47,13 @@ logger.addHandler(NullHandler())
T = TypeVar("T")
TH = TypeVar("TH", bound=Hashable)
+TT = TypeVar("TT", bound=type)
# TODO: Change when Python 3.10 becomes minimal version.
# ItMatrix: TypeAlias = Iterable[Iterable[PintScalar]]
# Matrix: TypeAlias = list[list[PintScalar]]
-ItMatrix = Iterable[Iterable[PintScalar]]
-Matrix = list[list[PintScalar]]
+ItMatrix = Iterable[Iterable[Scalar]]
+Matrix = list[list[Scalar]]
def _noop(x: T) -> T:
@@ -65,7 +66,7 @@ def matrix_to_string(
col_headers: Iterable[str] | None = None,
fmtfun: Callable[
[
- PintScalar,
+ Scalar,
],
str,
] = "{:0.0f}".format,
@@ -125,9 +126,9 @@ def matrix_apply(
matrix: ItMatrix,
func: Callable[
[
- PintScalar,
+ Scalar,
],
- PintScalar,
+ Scalar,
],
) -> Matrix:
"""Apply a function to individual elements within a matrix.
@@ -172,7 +173,14 @@ def column_echelon_form(
Swapped rows.
"""
- _transpose = transpose if transpose_result else _noop
+ _transpose: Callable[
+ [
+ ItMatrix,
+ ],
+ Matrix,
+ ] = (
+ transpose if transpose_result else _noop
+ )
ech_matrix = matrix_apply(
transpose(matrix),
@@ -181,7 +189,7 @@ def column_echelon_form(
rows, cols = len(ech_matrix), len(ech_matrix[0])
# M = [[ntype(x) for x in row] for row in M]
- id_matrix: list[list[PintScalar]] = [ # noqa: E741
+ id_matrix: list[list[Scalar]] = [ # noqa: E741
[ntype(1) if n == nc else ntype(0) for nc in range(rows)] for n in range(rows)
]
@@ -415,7 +423,7 @@ def find_connected_nodes(
return visited
-class udict(dict[str, PintScalar]):
+class udict(dict[str, Scalar]):
"""Custom dict implementing __missing__."""
def __missing__(self, key: str):
@@ -425,7 +433,7 @@ class udict(dict[str, PintScalar]):
return udict(self)
-class UnitsContainer(Mapping[str, PintScalar]):
+class UnitsContainer(Mapping[str, Scalar]):
"""The UnitsContainer stores the product of units and their respective
exponent and implements the corresponding operations.
@@ -441,10 +449,12 @@ class UnitsContainer(Mapping[str, PintScalar]):
_d: udict
_hash: int | None
- _one: PintScalar
+ _one: Scalar
_non_int_type: type
- def __init__(self, *args, non_int_type: type | None = None, **kwargs) -> None:
+ def __init__(
+ self, *args: Any, non_int_type: type | None = None, **kwargs: Any
+ ) -> None:
if args and isinstance(args[0], UnitsContainer):
default_non_int_type = args[0]._non_int_type
else:
@@ -542,7 +552,7 @@ class UnitsContainer(Mapping[str, PintScalar]):
def __len__(self) -> int:
return len(self._d)
- def __getitem__(self, key: str) -> PintScalar:
+ def __getitem__(self, key: str) -> Scalar:
return self._d[key]
def __contains__(self, key: str) -> bool:
@@ -554,10 +564,10 @@ class UnitsContainer(Mapping[str, PintScalar]):
return self._hash
# Only needed by pickle protocol 0 and 1 (used by pytables)
- def __getstate__(self) -> tuple[udict, PintScalar, type]:
+ def __getstate__(self) -> tuple[udict, Scalar, type]:
return self._d, self._one, self._non_int_type
- def __setstate__(self, state: tuple[udict, PintScalar, type]):
+ def __setstate__(self, state: tuple[udict, Scalar, type]):
self._d, self._one, self._non_int_type = state
self._hash = None
@@ -682,9 +692,9 @@ class ParserHelper(UnitsContainer):
__slots__ = ("scale",)
- scale: PintScalar
+ scale: Scalar
- def __init__(self, scale: PintScalar = 1, *args, **kwargs):
+ def __init__(self, scale: Scalar = 1, *args, **kwargs):
super().__init__(*args, **kwargs)
self.scale = scale
@@ -1002,7 +1012,7 @@ class PrettyIPython:
def to_units_container(
- unit_like: UnitLike | Quantity, registry: UnitRegistry | None = None
+ unit_like: QuantityOrUnitLike, registry: UnitRegistry | None = None
) -> UnitsContainer:
"""Convert a unit compatible type to a UnitsContainer.
@@ -1025,6 +1035,7 @@ def to_units_container(
return unit_like._units
elif str in mro:
if registry:
+ # TODO: Why not parse.units here?
return registry._parse_units(unit_like)
else:
return ParserHelper.from_string(unit_like)
@@ -1124,7 +1135,9 @@ def sized(y: Any) -> bool:
return True
-def create_class_with_registry(registry: UnitRegistry, base_class: type) -> type:
+def create_class_with_registry(
+ registry: UnitRegistry, base_class: type[TT]
+) -> type[TT]:
"""Create new class inheriting from base_class and
filling _REGISTRY class attribute with an actual instanced registry.
"""