summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorJules Chéron <jules.cheron@gmail.com>2021-05-22 23:03:34 +0200
committerJules Chéron <jules.cheron@gmail.com>2021-08-01 20:22:11 +0200
commitc60e1af833391a60a82f3c47460fca899e220de5 (patch)
treec9fbae3c8633e8d3cf6d36a7e1e6d91285394c14
parent78769ec967a24cd57c9fb6474aff2aa3708ca483 (diff)
downloadpint-c60e1af833391a60a82f3c47460fca899e220de5.tar.gz
Add pint typing module
- Quantity as Generic class - Add overloaded signature for __new__ Quantity - Add typing module as private - Add py.typed for PEP561 supports - Add overloaded signature for __new__ Quantity - Quantity as Generic class - Add type hints throughout the project - Add py.typed in package data in setup.cfg - Add type hints for decorators - Add type hints for public API of registry.py - Add type hints for units.py
-rw-r--r--.coveragerc16
-rw-r--r--CHANGES2
-rw-r--r--pint/_typing.py18
-rw-r--r--pint/context.py33
-rw-r--r--pint/definitions.py79
-rw-r--r--pint/formatting.py3
-rw-r--r--pint/py.typed0
-rw-r--r--pint/quantity.py199
-rw-r--r--pint/registry.py282
-rw-r--r--pint/registry_helpers.py24
-rw-r--r--pint/unit.py53
-rw-r--r--pint/util.py48
-rw-r--r--setup.cfg2
13 files changed, 529 insertions, 230 deletions
diff --git a/.coveragerc b/.coveragerc
index 73fab9e..fbb079e 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -1,3 +1,19 @@
[run]
omit = pint/testsuite/*
+[report]
+# Regexes for lines to exclude from consideration
+exclude_lines =
+ # Have to re-enable the standard pragma
+ pragma: no cover
+
+ # Don't complain about missing debug-only code:
+ def __repr__
+
+ # Don't complain if tests don't hit defensive assertion code:
+ raise AssertionError
+ raise NotImplementedError
+ AbstractMethodError
+
+ # Don't complain if non-runnable code isn't run:
+ if TYPE_CHECKING: \ No newline at end of file
diff --git a/CHANGES b/CHANGES
index 5dc0e84..d624141 100644
--- a/CHANGES
+++ b/CHANGES
@@ -19,6 +19,8 @@ Pint Changelog
- pint no longer supports Python 3.6
- Minimum Numpy version supported is 1.17+
+- Add supports for type hints for Quantity class. Quantity is now a Generic (PEP560).
+- Add support for [PEP561](https://www.python.org/dev/peps/pep-0561/) (Package Type information)
0.17 (2021-03-22)
diff --git a/pint/_typing.py b/pint/_typing.py
new file mode 100644
index 0000000..dbeca15
--- /dev/null
+++ b/pint/_typing.py
@@ -0,0 +1,18 @@
+from typing import TYPE_CHECKING, Any, Callable, Tuple, TypeVar, Union
+
+if TYPE_CHECKING:
+ from .quantity import Quantity
+ from .unit import Unit
+ from .util import UnitsContainer
+
+UnitLike = Union[str, "UnitsContainer", "Unit"]
+
+QuantityOrUnitLike = Union["Quantity", UnitLike]
+
+Shape = Tuple[int, ...]
+
+_MagnitudeType = TypeVar("_MagnitudeType")
+S = TypeVar("S")
+
+FuncType = Callable[..., Any]
+F = TypeVar("F", bound=FuncType)
diff --git a/pint/context.py b/pint/context.py
index 6cd440e..59ea9cf 100644
--- a/pint/context.py
+++ b/pint/context.py
@@ -8,14 +8,22 @@
:license: BSD, see LICENSE for more details.
"""
+from __future__ import annotations
+
import re
import weakref
from collections import ChainMap, defaultdict
+from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple
from .definitions import Definition, UnitDefinition
from .errors import DefinitionSyntaxError
from .util import ParserHelper, SourceIterator, to_units_container
+if TYPE_CHECKING:
+ from .quantity import Quantity
+ from .registry import UnitRegistry
+ from .util import UnitsContainer
+
#: Regex to match the header parts of a context.
_header_re = re.compile(
r"@context\s*(?P<defaults>\(.*\))?\s+(?P<name>\w+)\s*(=(?P<aliases>.*))*"
@@ -25,8 +33,8 @@ _header_re = re.compile(
_varname_re = re.compile(r"[A-Za-z_][A-Za-z0-9_]*")
-def _expression_to_function(eq):
- def func(ureg, value, **kwargs):
+def _expression_to_function(eq: str) -> Callable[..., Quantity[Any]]:
+ def func(ureg: UnitRegistry, value: Any, **kwargs: Any) -> Quantity[Any]:
return ureg.parse_expression(eq, value=value, **kwargs)
return func
@@ -84,7 +92,12 @@ class Context:
>>> c.redefine("pound = 0.5 kg")
"""
- def __init__(self, name=None, aliases=(), defaults=None):
+ def __init__(
+ self,
+ name: Optional[str] = None,
+ aliases: Tuple[str, ...] = (),
+ defaults: Optional[dict] = None,
+ ) -> None:
self.name = name
self.aliases = aliases
@@ -106,7 +119,7 @@ class Context:
self.relation_to_context = weakref.WeakValueDictionary()
@classmethod
- def from_context(cls, context, **defaults):
+ def from_context(cls, context: Context, **defaults) -> Context:
"""Creates a new context that shares the funcs dictionary with the
original context. The default values are copied from the original
context and updated with the new defaults.
@@ -135,7 +148,7 @@ class Context:
return context
@classmethod
- def from_lines(cls, lines, to_base_func=None, non_int_type=float):
+ def from_lines(cls, lines, to_base_func=None, non_int_type=float) -> Context:
lines = SourceIterator(lines)
lineno, header = next(lines)
@@ -223,14 +236,14 @@ class Context:
return ctx
- def add_transformation(self, src, dst, func):
+ def add_transformation(self, src, dst, func) -> None:
"""Add a transformation function to the context."""
_key = self.__keytransform__(src, dst)
self.funcs[_key] = func
self.relation_to_context[_key] = self
- def remove_transformation(self, src, dst):
+ def remove_transformation(self, src, dst) -> None:
"""Add a transformation function to the context."""
_key = self.__keytransform__(src, dst)
@@ -238,7 +251,7 @@ class Context:
del self.relation_to_context[_key]
@staticmethod
- def __keytransform__(src, dst):
+ def __keytransform__(src, dst) -> Tuple[UnitsContainer, UnitsContainer]:
return to_units_container(src), to_units_container(dst)
def transform(self, src, dst, registry, value):
@@ -270,7 +283,9 @@ class Context:
raise DefinitionSyntaxError("Can't define base units within a context")
self.redefinitions.append(d)
- def hashable(self):
+ def hashable(
+ self,
+ ) -> Tuple[Optional[str], Tuple[str, ...], frozenset, frozenset, tuple]:
"""Generate a unique hashable and comparable representation of self, which can
be used as a key in a dict. This class cannot define ``__hash__`` because it is
mutable, and the Python interpreter does cache the output of ``__hash__``.
diff --git a/pint/definitions.py b/pint/definitions.py
index 7e30c89..f02157b 100644
--- a/pint/definitions.py
+++ b/pint/definitions.py
@@ -8,9 +8,12 @@
:license: BSD, see LICENSE for more details.
"""
+from __future__ import annotations
+
from collections import namedtuple
+from typing import Callable, Iterable, Optional, Union
-from .converters import LogarithmicConverter, OffsetConverter, ScaleConverter
+from .converters import Converter, LogarithmicConverter, OffsetConverter, ScaleConverter
from .errors import DefinitionSyntaxError
from .util import ParserHelper, UnitsContainer, _is_dim
@@ -42,7 +45,7 @@ class PreprocessedDefinition(
"""
@classmethod
- def from_string(cls, definition):
+ def from_string(cls, definition: str) -> PreprocessedDefinition:
name, definition = definition.split("=", 1)
name = name.strip()
@@ -64,7 +67,7 @@ class _NotNumeric(Exception):
self.value = value
-def numeric_parse(s, non_int_type=float):
+def numeric_parse(s: str, non_int_type: type = float):
"""Try parse a string into a number (without using eval).
Parameters
@@ -103,7 +106,13 @@ class Definition:
converter : callable or Converter or None
"""
- def __init__(self, name, symbol, aliases, converter):
+ def __init__(
+ self,
+ name: str,
+ symbol: Optional[str],
+ aliases: Iterable[str],
+ converter: Optional[Union[Callable, Converter]],
+ ):
if isinstance(converter, str):
raise TypeError(
@@ -112,19 +121,21 @@ class Definition:
self._name = name
self._symbol = symbol
- self._aliases = aliases
+ self._aliases = tuple(aliases)
self._converter = converter
@property
- def is_multiplicative(self):
+ def is_multiplicative(self) -> bool:
return self._converter.is_multiplicative
@property
- def is_logarithmic(self):
+ def is_logarithmic(self) -> bool:
return self._converter.is_logarithmic
@classmethod
- def from_string(cls, definition, non_int_type=float):
+ def from_string(
+ cls, definition: Union[str, PreprocessedDefinition], non_int_type: type = float
+ ) -> "Definition":
"""Parse a definition.
Parameters
@@ -150,30 +161,30 @@ class Definition:
return UnitDefinition.from_string(definition, non_int_type)
@property
- def name(self):
+ def name(self) -> str:
return self._name
@property
- def symbol(self):
+ def symbol(self) -> str:
return self._symbol or self._name
@property
- def has_symbol(self):
+ def has_symbol(self) -> bool:
return bool(self._symbol)
@property
- def aliases(self):
+ def aliases(self) -> Iterable[str]:
return self._aliases
- def add_aliases(self, *alias):
+ def add_aliases(self, *alias: str) -> None:
alias = tuple(a for a in alias if a not in self._aliases)
self._aliases = self._aliases + alias
@property
- def converter(self):
+ def converter(self) -> Converter:
return self._converter
- def __str__(self):
+ def __str__(self) -> str:
return self.name
@@ -188,7 +199,9 @@ class PrefixDefinition(Definition):
"""
@classmethod
- def from_string(cls, definition, non_int_type=float):
+ def from_string(
+ cls, definition: Union[str, PreprocessedDefinition], non_int_type: type = float
+ ) -> "PrefixDefinition":
if isinstance(definition, str):
definition = PreprocessedDefinition.from_string(definition)
@@ -226,14 +239,24 @@ class UnitDefinition(Definition):
"""
- def __init__(self, name, symbol, aliases, converter, reference=None, is_base=False):
+ def __init__(
+ self,
+ name: str,
+ symbol: Optional[str],
+ aliases: Iterable[str],
+ converter: Converter,
+ reference: Optional[UnitsContainer] = None,
+ is_base: bool = False,
+ ) -> None:
self.reference = reference
self.is_base = is_base
super().__init__(name, symbol, aliases, converter)
@classmethod
- def from_string(cls, definition, non_int_type=float):
+ def from_string(
+ cls, definition: Union[str, PreprocessedDefinition], non_int_type: type = float
+ ) -> "UnitDefinition":
if isinstance(definition, str):
definition = PreprocessedDefinition.from_string(definition)
@@ -305,14 +328,24 @@ class DimensionDefinition(Definition):
[density] = [mass] / [volume]
"""
- def __init__(self, name, symbol, aliases, converter, reference=None, is_base=False):
+ def __init__(
+ self,
+ name: str,
+ symbol: Optional[str],
+ aliases: Iterable[str],
+ converter: Optional[Converter],
+ reference: Optional[UnitsContainer] = None,
+ is_base: bool = False,
+ ) -> None:
self.reference = reference
self.is_base = is_base
super().__init__(name, symbol, aliases, converter=None)
@classmethod
- def from_string(cls, definition, non_int_type=float):
+ def from_string(
+ cls, definition: Union[str, PreprocessedDefinition], non_int_type: type = float
+ ) -> "DimensionDefinition":
if isinstance(definition, str):
definition = PreprocessedDefinition.from_string(definition)
@@ -350,11 +383,13 @@ class AliasDefinition(Definition):
@alias meter = my_meter
"""
- def __init__(self, name, aliases):
+ def __init__(self, name: str, aliases: Iterable[str]) -> None:
super().__init__(name=name, symbol=None, aliases=aliases, converter=None)
@classmethod
- def from_string(cls, definition, non_int_type=float):
+ def from_string(
+ cls, definition: Union[str, PreprocessedDefinition], non_int_type: type = float
+ ) -> AliasDefinition:
if isinstance(definition, str):
definition = PreprocessedDefinition.from_string(definition)
diff --git a/pint/formatting.py b/pint/formatting.py
index afc51fe..e9f2090 100644
--- a/pint/formatting.py
+++ b/pint/formatting.py
@@ -9,6 +9,7 @@
"""
import re
+from typing import Dict
from .babel_names import _babel_lengths, _babel_units
from .compat import babel_parse
@@ -71,7 +72,7 @@ def _pretty_fmt_exponent(num):
#: _FORMATS maps format specifications to the corresponding argument set to
#: formatter().
-_FORMATS = {
+_FORMATS: Dict[str, dict] = {
"P": { # Pretty format.
"as_ratio": True,
"single_denominator": False,
diff --git a/pint/py.typed b/pint/py.typed
new file mode 100644
index 0000000..e69de29
--- /dev/null
+++ b/pint/py.typed
diff --git a/pint/quantity.py b/pint/quantity.py
index fb0396e..34361bb 100644
--- a/pint/quantity.py
+++ b/pint/quantity.py
@@ -6,6 +6,8 @@
:license: BSD, see LICENSE for more details.
"""
+from __future__ import annotations
+
import bisect
import contextlib
import copy
@@ -17,9 +19,27 @@ import numbers
import operator
import re
import warnings
-from typing import List
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ Dict,
+ Generic,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Sequence,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+ overload,
+)
+from ._typing import S, Shape, UnitLike, _MagnitudeType
from .compat import (
+ HAS_NUMPY,
_to_magnitude,
babel_parse,
compute,
@@ -66,6 +86,14 @@ from .util import (
to_units_container,
)
+if TYPE_CHECKING:
+ from . import Context, Unit
+ from .registry import BaseRegistry
+ from .unit import UnitsContainer as UnitsContainerT
+
+ if HAS_NUMPY:
+ import numpy as np # noqa
+
class _Exception(Exception): # pragma: no cover
def __init__(self, internal):
@@ -153,7 +181,11 @@ def printoptions(*args, **kwargs):
np.set_printoptions(**opts)
-class Quantity(PrettyIPython, SharedRegistryObject):
+# Workaround to bypass dynamically generated Quantity with overload method
+Magnitude = TypeVar("Magnitude")
+
+
+class Quantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]):
"""Implements a class to describe a physical quantity:
the product of a numerical value and a unit of measurement.
@@ -170,21 +202,22 @@ class Quantity(PrettyIPython, SharedRegistryObject):
"""
#: Default formatting string.
- default_format = ""
+ default_format: str = ""
+ _magnitude: _MagnitudeType
@property
- def force_ndarray(self):
+ def force_ndarray(self) -> bool:
return self._REGISTRY.force_ndarray
@property
- def force_ndarray_like(self):
+ def force_ndarray_like(self) -> bool:
return self._REGISTRY.force_ndarray_like
@property
- def UnitsContainer(self):
+ def UnitsContainer(self) -> Callable[..., UnitsContainerT]:
return self._REGISTRY.UnitsContainer
- def __reduce__(self):
+ def __reduce__(self) -> tuple:
"""Allow pickling quantities. Since UnitRegistries are not pickled, upon
unpickling the new object is always attached to the application registry.
"""
@@ -194,6 +227,30 @@ class Quantity(PrettyIPython, SharedRegistryObject):
# build_quantity_class can't be pickled
return _unpickle_quantity, (Quantity, self.magnitude, self._units)
+ @overload
+ def __new__(
+ cls, value: str, units: Optional[UnitLike] = None
+ ) -> Quantity[Magnitude]:
+ ...
+
+ @overload
+ def __new__( # type: ignore[misc]
+ cls, value: Sequence, units: Optional[UnitLike] = None
+ ) -> Quantity[np.ndarray]:
+ ...
+
+ @overload
+ def __new__(
+ cls, value: Quantity[Magnitude], units: Optional[UnitLike] = None
+ ) -> Quantity[Magnitude]:
+ ...
+
+ @overload
+ def __new__(
+ cls, value: Magnitude, units: Optional[UnitLike] = None
+ ) -> Quantity[Magnitude]:
+ ...
+
def __new__(cls, value, units=None):
if is_upcast_type(type(value)):
raise TypeError(f"Quantity cannot wrap upcast type {type(value)}")
@@ -250,7 +307,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
def debug_used(self):
return self.__used
- def __iter__(self):
+ def __iter__(self: Quantity[Iterable[S]]) -> Iterator[S]:
# Make sure that, if self.magnitude is not iterable, we raise TypeError as soon
# as one calls iter(self) without waiting for the first element to be drawn from
# the iterator
@@ -262,34 +319,34 @@ class Quantity(PrettyIPython, SharedRegistryObject):
return it_outer()
- def __copy__(self):
+ def __copy__(self) -> Quantity[_MagnitudeType]:
ret = self.__class__(copy.copy(self._magnitude), self._units)
ret.__used = self.__used
return ret
- def __deepcopy__(self, memo):
+ def __deepcopy__(self, memo) -> Quantity[_MagnitudeType]:
ret = self.__class__(
copy.deepcopy(self._magnitude, memo), copy.deepcopy(self._units, memo)
)
ret.__used = self.__used
return ret
- def __str__(self):
+ def __str__(self) -> str:
if self._REGISTRY.fmt_locale is not None:
return self.format_babel()
return format(self)
- def __bytes__(self):
+ def __bytes__(self) -> bytes:
return str(self).encode(locale.getpreferredencoding())
- def __repr__(self):
+ def __repr__(self) -> str:
if isinstance(self._magnitude, float):
return f"<Quantity({self._magnitude:.9}, '{self._units}')>"
else:
return f"<Quantity({self._magnitude}, '{self._units}')>"
- def __hash__(self):
+ def __hash__(self) -> int:
self_base = self.to_base_units()
if self_base.dimensionless:
return hash(self_base.magnitude)
@@ -298,7 +355,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
_exp_pattern = re.compile(r"([0-9]\.?[0-9]*)e(-?)\+?0*([0-9]+)")
- def __format__(self, spec):
+ def __format__(self, spec: str) -> str:
if self._REGISTRY.fmt_locale is not None:
return self.format_babel(spec)
@@ -410,7 +467,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
p.text(" ")
p.pretty(self.units)
- def format_babel(self, spec="", **kwspec):
+ def format_babel(self, spec: str = "", **kwspec: Any) -> str:
spec = spec or self.default_format
# standard cases
@@ -435,16 +492,16 @@ class Quantity(PrettyIPython, SharedRegistryObject):
).replace("\n", "")
@property
- def magnitude(self):
+ def magnitude(self) -> _MagnitudeType:
"""Quantity's magnitude. Long form for `m`"""
return self._magnitude
@property
- def m(self):
+ def m(self) -> _MagnitudeType:
"""Quantity's magnitude. Short form for `magnitude`"""
return self._magnitude
- def m_as(self, units):
+ def m_as(self, units) -> _MagnitudeType:
"""Quantity's magnitude expressed in particular units.
Parameters
@@ -459,31 +516,31 @@ class Quantity(PrettyIPython, SharedRegistryObject):
return self.to(units).magnitude
@property
- def units(self):
+ def units(self) -> "Unit":
"""Quantity's units. Long form for `u`"""
return self._REGISTRY.Unit(self._units)
@property
- def u(self):
+ def u(self) -> "Unit":
"""Quantity's units. Short form for `units`"""
return self._REGISTRY.Unit(self._units)
@property
- def unitless(self):
+ def unitless(self) -> bool:
""" """
return not bool(self.to_root_units()._units)
@property
- def dimensionless(self):
+ def dimensionless(self) -> bool:
""" """
tmp = self.to_root_units()
return not bool(tmp.dimensionality)
- _dimensionality = None
+ _dimensionality: Optional[UnitsContainerT] = None
@property
- def dimensionality(self):
+ def dimensionality(self) -> UnitsContainerT:
"""
Returns
-------
@@ -495,12 +552,12 @@ class Quantity(PrettyIPython, SharedRegistryObject):
return self._dimensionality
- def check(self, dimension):
+ def check(self, dimension: UnitLike) -> bool:
"""Return true if the quantity's dimension matches passed dimension."""
return self.dimensionality == self._REGISTRY.get_dimensionality(dimension)
@classmethod
- def from_list(cls, quant_list, units=None):
+ def from_list(cls, quant_list: List[Quantity], units=None) -> Quantity[np.ndarray]:
"""Transforms a list of Quantities into an numpy.array quantity.
If no units are specified, the unit of the first element will be used.
Same as from_sequence.
@@ -522,7 +579,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
return cls.from_sequence(quant_list, units=units)
@classmethod
- def from_sequence(cls, seq, units=None):
+ def from_sequence(cls, seq: Sequence[Quantity], units=None) -> Quantity[np.ndarray]:
"""Transforms a sequence of Quantities into an numpy.array quantity.
If no units are specified, the unit of the first element will be used.
@@ -560,7 +617,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
def from_tuple(cls, tup):
return cls(tup[0], cls._REGISTRY.UnitsContainer(tup[1]))
- def to_tuple(self):
+ def to_tuple(self) -> Tuple[_MagnitudeType, Tuple[Tuple[str]]]:
return self.m, tuple(self._units.items())
def compatible_units(self, *contexts):
@@ -570,7 +627,9 @@ class Quantity(PrettyIPython, SharedRegistryObject):
return self._REGISTRY.get_compatible_units(self._units)
- def is_compatible_with(self, other, *contexts, **ctx_kwargs):
+ def is_compatible_with(
+ self, other: Any, *contexts: Union[str, Context], **ctx_kwargs: Any
+ ) -> bool:
"""check if the other object is compatible
Parameters
@@ -623,7 +682,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
inplace=is_duck_array_type(type(self._magnitude)),
)
- def ito(self, other=None, *contexts, **ctx_kwargs):
+ def ito(self, other=None, *contexts, **ctx_kwargs) -> None:
"""Inplace rescale to different units.
Parameters
@@ -642,7 +701,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
return None
- def to(self, other=None, *contexts, **ctx_kwargs):
+ def to(self, other=None, *contexts, **ctx_kwargs) -> Quantity[_MagnitudeType]:
"""Return Quantity rescaled to different units.
Parameters
@@ -664,7 +723,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
return self.__class__(magnitude, other)
- def ito_root_units(self):
+ def ito_root_units(self) -> None:
"""Return Quantity rescaled to root units."""
_, other = self._REGISTRY._get_root_units(self._units)
@@ -674,7 +733,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
return None
- def to_root_units(self):
+ def to_root_units(self) -> Quantity[_MagnitudeType]:
"""Return Quantity rescaled to root units."""
_, other = self._REGISTRY._get_root_units(self._units)
@@ -683,7 +742,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
return self.__class__(magnitude, other)
- def ito_base_units(self):
+ def ito_base_units(self) -> None:
"""Return Quantity rescaled to base units."""
_, other = self._REGISTRY._get_base_units(self._units)
@@ -693,7 +752,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
return None
- def to_base_units(self):
+ def to_base_units(self) -> Quantity[_MagnitudeType]:
"""Return Quantity rescaled to base units."""
_, other = self._REGISTRY._get_base_units(self._units)
@@ -702,7 +761,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
return self.__class__(magnitude, other)
- def ito_reduced_units(self):
+ def ito_reduced_units(self) -> None:
"""Return Quantity scaled in place to reduced units, i.e. one unit per
dimension. This will not reduce compound units (e.g., 'J/kg' will not
be reduced to m**2/s**2), nor can it make use of contexts at this time.
@@ -730,7 +789,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
return self.ito(newunits)
- def to_reduced_units(self):
+ def to_reduced_units(self) -> Quantity[_MagnitudeType]:
"""Return Quantity scaled in place to reduced units, i.e. one unit per
dimension. This will not reduce compound units (intentionally), nor
can it make use of contexts at this time.
@@ -741,7 +800,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
newq.ito_reduced_units()
return newq
- def to_compact(self, unit=None):
+ def to_compact(self, unit=None) -> Quantity[_MagnitudeType]:
""" "Return Quantity rescaled to compact, human-readable units.
To get output in terms of a different unit, use the unit parameter.
@@ -775,7 +834,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
):
return self
- SI_prefixes = {}
+ SI_prefixes: Dict[int, str] = {}
for prefix in self._REGISTRY._prefixes.values():
try:
scale = prefix.converter.scale
@@ -786,9 +845,9 @@ class Quantity(PrettyIPython, SharedRegistryObject):
except Exception:
SI_prefixes[0] = ""
- SI_prefixes = sorted(SI_prefixes.items())
- SI_powers = [item[0] for item in SI_prefixes]
- SI_bases = [item[1] for item in SI_prefixes]
+ SI_prefixes_list = sorted(SI_prefixes.items())
+ SI_powers = [item[0] for item in SI_prefixes_list]
+ SI_bases = [item[1] for item in SI_prefixes_list]
if unit is None:
unit = infer_base_unit(self)
@@ -817,25 +876,25 @@ class Quantity(PrettyIPython, SharedRegistryObject):
if index >= len(SI_bases):
index = -1
- prefix = SI_bases[index]
+ prefix_str = SI_bases[index]
- new_unit_str = prefix + unit_str
+ new_unit_str = prefix_str + unit_str
new_unit_container = q_base._units.rename(unit_str, new_unit_str)
return self.to(new_unit_container)
# Mathematical operations
- def __int__(self):
+ def __int__(self) -> int:
if self.dimensionless:
return int(self._convert_magnitude_not_inplace(UnitsContainer()))
raise DimensionalityError(self._units, "dimensionless")
- def __float__(self):
+ def __float__(self) -> float:
if self.dimensionless:
return float(self._convert_magnitude_not_inplace(UnitsContainer()))
raise DimensionalityError(self._units, "dimensionless")
- def __complex__(self):
+ def __complex__(self) -> complex:
if self.dimensionless:
return complex(self._convert_magnitude_not_inplace(UnitsContainer()))
raise DimensionalityError(self._units, "dimensionless")
@@ -1066,6 +1125,14 @@ class Quantity(PrettyIPython, SharedRegistryObject):
return self.__class__(magnitude, units)
+ @overload
+ def __iadd__(self, other: datetime.datetime) -> datetime.timedelta: # type: ignore[misc]
+ ...
+
+ @overload
+ def __iadd__(self, other) -> Quantity[_MagnitudeType]:
+ ...
+
def __iadd__(self, other):
if isinstance(other, datetime.datetime):
return self.to_timedelta() + other
@@ -1431,7 +1498,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
return self
@check_implemented
- def __pow__(self, other):
+ def __pow__(self, other) -> Quantity[_MagnitudeType]:
try:
_to_magnitude(other, self.force_ndarray, self.force_ndarray_like)
except PintTypeError:
@@ -1496,7 +1563,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
return self.__class__(magnitude, units)
@check_implemented
- def __rpow__(self, other):
+ def __rpow__(self, other) -> Quantity[_MagnitudeType]:
try:
_to_magnitude(other, self.force_ndarray, self.force_ndarray_like)
except PintTypeError:
@@ -1509,16 +1576,16 @@ class Quantity(PrettyIPython, SharedRegistryObject):
new_self = self.to_root_units()
return other ** new_self._magnitude
- def __abs__(self):
+ def __abs__(self) -> Quantity[_MagnitudeType]:
return self.__class__(abs(self._magnitude), self._units)
- def __round__(self, ndigits=0):
+ def __round__(self, ndigits: Optional[int] = 0) -> Quantity[int]:
return self.__class__(round(self._magnitude, ndigits=ndigits), self._units)
- def __pos__(self):
+ def __pos__(self) -> Quantity[_MagnitudeType]:
return self.__class__(operator.pos(self._magnitude), self._units)
- def __neg__(self):
+ def __neg__(self) -> Quantity[_MagnitudeType]:
return self.__class__(operator.neg(self._magnitude), self._units)
@check_implemented
@@ -1627,7 +1694,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
__ge__ = lambda self, other: self.compare(other, op=operator.ge)
__gt__ = lambda self, other: self.compare(other, op=operator.gt)
- def __bool__(self):
+ def __bool__(self) -> bool:
# Only cast when non-ambiguous (when multiplicative unit)
if self._is_multiplicative:
return bool(self._magnitude)
@@ -1695,7 +1762,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
else:
return value
- def __array__(self, t=None):
+ def __array__(self, t=None) -> np.ndarray:
warnings.warn(
"The unit of the quantity is stripped when downcasting to ndarray.",
UnitStrippedWarning,
@@ -1722,11 +1789,11 @@ class Quantity(PrettyIPython, SharedRegistryObject):
raise DimensionalityError("dimensionless", self._units)
return self.__class__(self.magnitude.clip(min, max, out, **kwargs), self._units)
- def fill(self, value):
+ def fill(self: Quantity[np.ndarray], value) -> None:
self._units = value._units
return self.magnitude.fill(value.magnitude)
- def put(self, indices, values, mode="raise"):
+ def put(self: Quantity[np.ndarray], indices, values, mode="raise") -> None:
if isinstance(values, self.__class__):
values = values.to(self).magnitude
elif self.dimensionless:
@@ -1736,11 +1803,11 @@ class Quantity(PrettyIPython, SharedRegistryObject):
self.magnitude.put(indices, values, mode)
@property
- def real(self):
+ def real(self) -> Quantity[_MagnitudeType]:
return self.__class__(self._magnitude.real, self._units)
@property
- def imag(self):
+ def imag(self) -> Quantity[_MagnitudeType]:
return self.__class__(self._magnitude.imag, self._units)
@property
@@ -1753,7 +1820,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
yield self.__class__(v, self._units)
@property
- def shape(self):
+ def shape(self) -> Shape:
return self._magnitude.shape
@shape.setter
@@ -1791,10 +1858,10 @@ class Quantity(PrettyIPython, SharedRegistryObject):
self.ito(to_units)
- def __len__(self):
+ def __len__(self) -> int:
return len(self._magnitude)
- def __getattr__(self, item):
+ def __getattr__(self, item) -> Any:
if item.startswith("__array_"):
# Handle array protocol attributes other than `__array__`
raise AttributeError(f"Array protocol attribute {item} not available.")
@@ -1944,7 +2011,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
self._get_unit_definition(d).reference == offset_unit_dim for d in deltas
)
- def _ok_for_muldiv(self, no_offset_units=None):
+ def _ok_for_muldiv(self, no_offset_units=None) -> bool:
"""Checks if Quantity object can be multiplied or divided"""
is_ok = True
@@ -1964,7 +2031,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
is_ok = False
return is_ok
- def to_timedelta(self):
+ def to_timedelta(self: Quantity[float]) -> datetime.timedelta:
return datetime.timedelta(microseconds=self.to("microseconds").magnitude)
# Dask.array.Array ducking
@@ -2058,7 +2125,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
_Quantity = Quantity
-def build_quantity_class(registry):
+def build_quantity_class(registry: BaseRegistry) -> Type[Quantity]:
class Quantity(_Quantity):
_REGISTRY = registry
diff --git a/pint/registry.py b/pint/registry.py
index 42c75c6..599a049 100644
--- a/pint/registry.py
+++ b/pint/registry.py
@@ -33,6 +33,8 @@ The module actually defines 5 registries with different capabilities:
:license: BSD, see LICENSE for more details.
"""
+from __future__ import annotations
+
import copy
import functools
import importlib.resources
@@ -45,10 +47,29 @@ from contextlib import contextmanager
from decimal import Decimal
from fractions import Fraction
from io import StringIO
+from numbers import Number
from tokenize import NAME, NUMBER
+from typing import (
+ TYPE_CHECKING,
+ Any,
+ Callable,
+ ContextManager,
+ Dict,
+ FrozenSet,
+ Iterable,
+ Iterator,
+ List,
+ Optional,
+ Set,
+ Tuple,
+ Type,
+ TypeVar,
+ Union,
+)
from . import registry_helpers, systems
-from .compat import babel_parse, tokenizer
+from ._typing import F, QuantityOrUnitLike
+from .compat import HAS_BABEL, babel_parse, tokenizer
from .context import Context, ContextChain
from .converters import LogarithmicConverter, ScaleConverter
from .definitions import (
@@ -65,6 +86,7 @@ from .errors import (
UndefinedUnitError,
)
from .pint_eval import build_eval_tree
+from .systems import Group, System
from .util import (
ParserHelper,
SourceIterator,
@@ -80,6 +102,21 @@ from .util import (
to_units_container,
)
+if TYPE_CHECKING:
+ from ._typing import UnitLike
+ from .quantity import Quantity
+ from .unit import Unit
+ from .unit import UnitsContainer as UnitsContainerT
+
+ if HAS_BABEL:
+ import babel
+
+ Locale = babel.Locale
+ else:
+ Locale = None
+
+T = TypeVar("T")
+
_BLOCK_RE = re.compile(r"[ (]")
@@ -110,15 +147,15 @@ class RegistryMeta(type):
class RegistryCache:
"""Cache to speed up unit registries"""
- def __init__(self):
+ def __init__(self) -> None:
#: Maps dimensionality (UnitsContainer) to Units (str)
- self.dimensional_equivalents = {}
+ self.dimensional_equivalents: Dict[UnitsContainer, Set[str]] = {}
#: Maps dimensionality (UnitsContainer) to Dimensionality (UnitsContainer)
self.root_units = {}
#: Maps dimensionality (UnitsContainer) to Units (UnitsContainer)
- self.dimensionality = {}
+ self.dimensionality: Dict[UnitsContainer, UnitsContainer] = {}
#: Cache the unit name associated to user input. ('mV' -> 'millivolt')
- self.parse_unit = {}
+ self.parse_unit: Dict[str, UnitsContainer] = {}
class ContextCacheOverlay:
@@ -126,13 +163,17 @@ class ContextCacheOverlay:
active contexts which contain unit redefinitions.
"""
- def __init__(self, registry_cache: RegistryCache):
+ def __init__(self, registry_cache: RegistryCache) -> None:
self.dimensional_equivalents = registry_cache.dimensional_equivalents
self.root_units = {}
self.dimensionality = registry_cache.dimensionality
self.parse_unit = registry_cache.parse_unit
+NON_INT_TYPE = Type[Union[float, Decimal, Fraction]]
+PreprocessorType = Callable[[str], str]
+
+
class BaseRegistry(metaclass=RegistryMeta):
"""Base class for all registries.
@@ -173,22 +214,22 @@ class BaseRegistry(metaclass=RegistryMeta):
#: Map context prefix to function
#: type: Dict[str, (SourceIterator -> None)]
- _parsers = None
+ _parsers: Dict[str, Callable[[SourceIterator], None]] = None
#: Babel.Locale instance or None
- fmt_locale = None
+ fmt_locale: Optional[Locale] = None
def __init__(
self,
filename="",
- force_ndarray=False,
- force_ndarray_like=False,
- on_redefinition="warn",
- auto_reduce_dimensions=False,
- preprocessors=None,
- fmt_locale=None,
- non_int_type=float,
- case_sensitive=True,
+ force_ndarray: bool = False,
+ force_ndarray_like: bool = False,
+ on_redefinition: str = "warn",
+ auto_reduce_dimensions: bool = False,
+ preprocessors: Optional[List[PreprocessorType]] = None,
+ fmt_locale: Optional[str] = None,
+ non_int_type: NON_INT_TYPE = float,
+ case_sensitive: bool = True,
):
self._register_parsers()
self._init_dynamic_classes()
@@ -215,33 +256,35 @@ class BaseRegistry(metaclass=RegistryMeta):
#: Map between name (string) and value (string) of defaults stored in the
#: definitions file.
- self._defaults = {}
+ self._defaults: Dict[str, str] = {}
#: Map dimension name (string) to its definition (DimensionDefinition).
- self._dimensions = {}
+ self._dimensions: Dict[str, DimensionDefinition] = {}
#: Map unit name (string) to its definition (UnitDefinition).
#: Might contain prefixed units.
- self._units = {}
+ self._units: Dict[str, UnitDefinition] = {}
#: Map unit name in lower case (string) to a set of unit names with the right
#: case.
#: Does not contain prefixed units.
#: e.g: 'hz' - > set('Hz', )
- self._units_casei = defaultdict(set)
+ self._units_casei: Dict[str, Set[str]] = defaultdict(set)
#: Map prefix name (string) to its definition (PrefixDefinition).
- self._prefixes = {"": PrefixDefinition("", "", (), 1)}
+ self._prefixes: Dict[str, PrefixDefinition] = {
+ "": PrefixDefinition("", "", (), 1)
+ }
#: Map suffix name (string) to canonical , and unit alias to canonical unit name
- self._suffixes = {"": "", "s": ""}
+ self._suffixes: Dict[str, str] = {"": "", "s": ""}
#: Map contexts to RegistryCache
self._cache = RegistryCache()
self._initialized = False
- def _init_dynamic_classes(self):
+ def _init_dynamic_classes(self) -> None:
"""Generate subclasses on the fly and attach them to self"""
from .unit import build_unit_class
@@ -249,13 +292,13 @@ class BaseRegistry(metaclass=RegistryMeta):
from .quantity import build_quantity_class
- self.Quantity = build_quantity_class(self)
+ self.Quantity: Type["Quantity"] = build_quantity_class(self)
from .measurement import build_measurement_class
self.Measurement = build_measurement_class(self)
- def _after_init(self):
+ def _after_init(self) -> None:
"""This should be called after all __init__"""
if self._filename == "":
@@ -266,17 +309,17 @@ class BaseRegistry(metaclass=RegistryMeta):
self._build_cache()
self._initialized = True
- def _register_parsers(self):
+ def _register_parsers(self) -> None:
self._register_parser("@defaults", self._parse_defaults)
- def _parse_defaults(self, ifile):
+ def _parse_defaults(self, ifile) -> None:
"""Loader for a @default section."""
next(ifile)
for lineno, part in ifile.block_iter():
k, v = part.split("=")
self._defaults[k.strip()] = v.strip()
- def __deepcopy__(self, memo):
+ def __deepcopy__(self, memo) -> "BaseRegistry":
new = object.__new__(type(self))
new.__dict__ = copy.deepcopy(self.__dict__, memo)
new._init_dynamic_classes()
@@ -293,7 +336,7 @@ class BaseRegistry(metaclass=RegistryMeta):
)
return self.parse_expression(item)
- def __contains__(self, item):
+ def __contains__(self, item) -> bool:
"""Support checking prefixed units with the `in` operator"""
try:
self.__getattr__(item)
@@ -301,12 +344,12 @@ class BaseRegistry(metaclass=RegistryMeta):
except UndefinedUnitError:
return False
- def __dir__(self):
+ def __dir__(self) -> List[str]:
#: Calling dir(registry) gives all units, methods, and attributes.
#: Also used for autocompletion in IPython.
return list(self._units.keys()) + list(object.__dir__(self))
- def __iter__(self):
+ def __iter__(self) -> Iterator[str]:
"""Allows for listing all units in registry with `list(ureg)`.
Returns
@@ -315,7 +358,7 @@ class BaseRegistry(metaclass=RegistryMeta):
"""
return iter(sorted(self._units.keys()))
- def set_fmt_locale(self, loc):
+ def set_fmt_locale(self, loc: Optional[str]) -> None:
"""Change the locale used by default by `format_babel`.
Parameters
@@ -332,20 +375,20 @@ class BaseRegistry(metaclass=RegistryMeta):
self.fmt_locale = loc
- def UnitsContainer(self, *args, **kwargs):
+ def UnitsContainer(self, *args, **kwargs) -> UnitsContainerT:
return UnitsContainer(*args, non_int_type=self.non_int_type, **kwargs)
@property
- def default_format(self):
+ def default_format(self) -> str:
"""Default formatting string for quantities."""
return self.Quantity.default_format
@default_format.setter
- def default_format(self, value):
+ def default_format(self, value: str):
self.Unit.default_format = value
self.Quantity.default_format = value
- def define(self, definition):
+ def define(self, definition: Union[str, Definition]) -> None:
"""Add unit to the registry.
Parameters
@@ -360,7 +403,7 @@ class BaseRegistry(metaclass=RegistryMeta):
else:
self._define(definition)
- def _define(self, definition):
+ def _define(self, definition: Definition) -> Tuple[Definition, dict, dict]:
"""Add unit to the registry.
This method defines only multiplicative units, converting any other type
@@ -509,7 +552,7 @@ class BaseRegistry(metaclass=RegistryMeta):
else:
raise ValueError("Prefix directives must start with '@'")
- def load_definitions(self, file, is_resource=False):
+ def load_definitions(self, file, is_resource: bool = False) -> None:
"""Add units and prefixes defined in a definition text file.
Parameters
@@ -584,7 +627,7 @@ class BaseRegistry(metaclass=RegistryMeta):
except Exception as ex:
logger.error("In line {}, cannot add '{}' {}".format(no, line, ex))
- def _build_cache(self):
+ def _build_cache(self) -> None:
"""Build a cache of dimensionality and base units."""
self._cache = RegistryCache()
@@ -621,7 +664,9 @@ class BaseRegistry(metaclass=RegistryMeta):
except Exception as exc:
logger.warning(f"Could not resolve {unit_name}: {exc!r}")
- def get_name(self, name_or_alias, case_sensitive=None):
+ def get_name(
+ self, name_or_alias: str, case_sensitive: Optional[bool] = None
+ ) -> str:
"""Return the canonical name of a unit."""
if name_or_alias == "dimensionless":
@@ -659,7 +704,9 @@ class BaseRegistry(metaclass=RegistryMeta):
return unit_name
- def get_symbol(self, name_or_alias, case_sensitive=None):
+ def get_symbol(
+ self, name_or_alias: str, case_sensitive: Optional[bool] = None
+ ) -> str:
"""Return the preferred alias for a unit."""
candidates = self.parse_unit_name(name_or_alias, case_sensitive)
if not candidates:
@@ -675,10 +722,10 @@ class BaseRegistry(metaclass=RegistryMeta):
return self._prefixes[prefix].symbol + self._units[unit_name].symbol
- def _get_symbol(self, name):
+ def _get_symbol(self, name: str) -> str:
return self._units[name].symbol
- def get_dimensionality(self, input_units):
+ def get_dimensionality(self, input_units) -> UnitsContainerT:
"""Convert unit or dict of units or dimensions to a dict of base dimensions
dimensions
"""
@@ -689,7 +736,9 @@ class BaseRegistry(metaclass=RegistryMeta):
return self._get_dimensionality(input_units)
- def _get_dimensionality(self, input_units):
+ def _get_dimensionality(
+ self, input_units: Optional[UnitsContainerT]
+ ) -> UnitsContainerT:
"""Convert a UnitsContainer to base dimensions."""
if not input_units:
return self.UnitsContainer()
@@ -757,7 +806,9 @@ class BaseRegistry(metaclass=RegistryMeta):
return first
return None
- def get_root_units(self, input_units, check_nonmult=True):
+ def get_root_units(
+ self, input_units: UnitLike, check_nonmult: bool = True
+ ) -> Tuple[Number, Unit]:
"""Convert unit or dict of units to the root units.
If any unit is non multiplicative and check_converter is True,
@@ -868,7 +919,9 @@ class BaseRegistry(metaclass=RegistryMeta):
if reg.reference is not None:
self._get_root_units_recurse(reg.reference, exp2, accumulators)
- def get_compatible_units(self, input_units, group_or_system=None):
+ def get_compatible_units(
+ self, input_units, group_or_system=None
+ ) -> FrozenSet["Unit"]:
""" """
input_units = to_units_container(input_units)
@@ -884,7 +937,9 @@ class BaseRegistry(metaclass=RegistryMeta):
src_dim = self._get_dimensionality(input_units)
return self._cache.dimensional_equivalents[src_dim]
- def is_compatible_with(self, obj1, obj2, *contexts, **ctx_kwargs):
+ def is_compatible_with(
+ self, obj1: Any, obj2: Any, *contexts: Union[str, Context], **ctx_kwargs
+ ) -> bool:
"""check if the other object is compatible
Parameters
@@ -911,7 +966,13 @@ class BaseRegistry(metaclass=RegistryMeta):
return not isinstance(obj2, (self.Quantity, self.Unit))
- def convert(self, value, src, dst, inplace=False):
+ def convert(
+ self,
+ value: T,
+ src: QuantityOrUnitLike,
+ dst: QuantityOrUnitLike,
+ inplace: bool = False,
+ ) -> T:
"""Convert value from some source to destination units.
Parameters
@@ -991,7 +1052,9 @@ class BaseRegistry(metaclass=RegistryMeta):
return value
- def parse_unit_name(self, unit_name, case_sensitive=None):
+ def parse_unit_name(
+ self, unit_name: str, case_sensitive: Optional[bool] = None
+ ) -> Tuple[Tuple[str, str, str], ...]:
"""Parse a unit to identify prefix, unit name and suffix
by walking the list of prefix and suffix.
In case of equivalent combinations (e.g. ('kilo', 'gram', '') and
@@ -1014,7 +1077,9 @@ class BaseRegistry(metaclass=RegistryMeta):
self._parse_unit_name(unit_name, case_sensitive=case_sensitive)
)
- def _parse_unit_name(self, unit_name, case_sensitive=None):
+ def _parse_unit_name(
+ self, unit_name: str, case_sensitive: Optional[bool] = None
+ ) -> Iterator[Tuple[str, str, str]]:
"""Helper of parse_unit_name."""
case_sensitive = (
self.case_sensitive if case_sensitive is None else case_sensitive
@@ -1044,7 +1109,9 @@ class BaseRegistry(metaclass=RegistryMeta):
)
@staticmethod
- def _dedup_candidates(candidates):
+ def _dedup_candidates(
+ candidates: Iterable[Tuple[str, str, str]]
+ ) -> Tuple[Tuple[str, str, str], ...]:
"""Helper of parse_unit_name.
Given an iterable of unit triplets (prefix, name, suffix), remove those with
@@ -1062,7 +1129,12 @@ class BaseRegistry(metaclass=RegistryMeta):
candidates.pop(("", cp + cu, ""), None)
return tuple(candidates)
- def parse_units(self, input_string, as_delta=None, case_sensitive=None):
+ def parse_units(
+ self,
+ input_string: str,
+ as_delta: Optional[bool] = None,
+ case_sensitive: Optional[bool] = None,
+ ) -> Unit:
"""Parse a units expression and returns a UnitContainer with
the canonical names.
@@ -1080,6 +1152,7 @@ class BaseRegistry(metaclass=RegistryMeta):
Returns
-------
+ pint.Unit
"""
for p in self.preprocessors:
@@ -1159,8 +1232,13 @@ class BaseRegistry(metaclass=RegistryMeta):
raise Exception("unknown token type")
def parse_pattern(
- self, input_string, pattern, case_sensitive=None, use_decimal=False, many=False
- ):
+ self,
+ input_string: str,
+ pattern: str,
+ case_sensitive: Optional[bool] = None,
+ use_decimal: bool = False,
+ many: bool = False,
+ ) -> Union[List[str], str, None]:
"""Parse a string with a given regex pattern and returns result.
Parameters
@@ -1215,8 +1293,12 @@ class BaseRegistry(metaclass=RegistryMeta):
return results
def parse_expression(
- self, input_string, case_sensitive=None, use_decimal=False, **values
- ):
+ self,
+ input_string: str,
+ case_sensitive: Optional[bool] = None,
+ use_decimal: bool = False,
+ **values,
+ ) -> Quantity:
"""Parse a mathematical expression including units and return a quantity object.
Numerical constants can be specified as keyword arguments and will take precedence
@@ -1280,8 +1362,11 @@ class NonMultiplicativeRegistry(BaseRegistry):
"""
def __init__(
- self, default_as_delta=True, autoconvert_offset_to_baseunit=False, **kwargs
- ):
+ self,
+ default_as_delta: bool = True,
+ autoconvert_offset_to_baseunit: bool = False,
+ **kwargs: Any,
+ ) -> None:
super().__init__(**kwargs)
#: When performing a multiplication of units, interpret
@@ -1292,14 +1377,19 @@ class NonMultiplicativeRegistry(BaseRegistry):
# base units on multiplication and division.
self.autoconvert_offset_to_baseunit = autoconvert_offset_to_baseunit
- def _parse_units(self, input_string, as_delta=None, case_sensitive=None):
+ def _parse_units(
+ self,
+ input_string: str,
+ as_delta: Optional[bool] = None,
+ case_sensitive: Optional[bool] = None,
+ ):
""" """
if as_delta is None:
as_delta = self.default_as_delta
return super()._parse_units(input_string, as_delta, case_sensitive)
- def _define(self, definition):
+ def _define(self, definition: Union[str, Definition]):
"""Add unit to the registry.
In addition to what is done by the BaseRegistry,
@@ -1325,7 +1415,7 @@ class NonMultiplicativeRegistry(BaseRegistry):
return definition, d, di
- def _is_multiplicative(self, u):
+ def _is_multiplicative(self, u) -> bool:
if u in self._units:
return self._units[u].is_multiplicative
@@ -1473,9 +1563,9 @@ class ContextRegistry(BaseRegistry):
- Parse @context directive.
"""
- def __init__(self, **kwargs):
+ def __init__(self, **kwargs: Any) -> None:
# Map context name (string) or abbreviation to context.
- self._contexts = {}
+ self._contexts: Dict[str, Context] = {}
# Stores active contexts.
self._active_ctx = ContextChain()
# Map context chain to cache
@@ -1488,11 +1578,11 @@ class ContextRegistry(BaseRegistry):
# Allow contexts to add override layers to the units
self._units = ChainMap(self._units)
- def _register_parsers(self):
+ def _register_parsers(self) -> None:
super()._register_parsers()
self._register_parser("@context", self._parse_context)
- def _parse_context(self, ifile):
+ def _parse_context(self, ifile) -> None:
try:
self.add_context(
Context.from_lines(
@@ -1624,7 +1714,9 @@ class ContextRegistry(BaseRegistry):
# Write into the context-specific self._units.maps[0] and self._cache.root_units
self.define(definition)
- def enable_contexts(self, *names_or_contexts, **kwargs) -> None:
+ def enable_contexts(
+ self, *names_or_contexts: Union[str, Context], **kwargs
+ ) -> None:
"""Enable contexts provided by name or by object.
Parameters
@@ -1664,10 +1756,10 @@ class ContextRegistry(BaseRegistry):
ctx.checked = True
# and create a new one with the new defaults.
- ctxs = tuple(Context.from_context(ctx, **kwargs) for ctx in ctxs)
+ contexts = tuple(Context.from_context(ctx, **kwargs) for ctx in ctxs)
# Finally we add them to the active context.
- self._active_ctx.insert_contexts(*ctxs)
+ self._active_ctx.insert_contexts(*contexts)
self._switch_context_cache_and_units()
def disable_contexts(self, n: int = None) -> None:
@@ -1682,7 +1774,7 @@ class ContextRegistry(BaseRegistry):
self._switch_context_cache_and_units()
@contextmanager
- def context(self, *names, **kwargs):
+ def context(self, *names, **kwargs) -> ContextManager[Context]:
"""Used as a context manager, this function enables to activate a context
which is removed after usage.
@@ -1739,7 +1831,7 @@ class ContextRegistry(BaseRegistry):
# the added contexts are removed from the active one.
self.disable_contexts(len(names))
- def with_context(self, name, **kwargs):
+ def with_context(self, name, **kwargs) -> Callable[[F], F]:
"""Decorator to wrap a function call in a Pint context.
Use it to ensure that a certain context is active when
@@ -1858,23 +1950,23 @@ class SystemRegistry(BaseRegistry):
#: Map system name to system.
#: :type: dict[ str | System]
- self._systems = {}
+ self._systems: Dict[str, System] = {}
#: Maps dimensionality (UnitsContainer) to Dimensionality (UnitsContainer)
self._base_units_cache = dict()
#: Map group name to group.
#: :type: dict[ str | Group]
- self._groups = {}
+ self._groups: Dict[str, Group] = {}
self._groups["root"] = self.Group("root")
self._default_system = system
- def _init_dynamic_classes(self):
+ def _init_dynamic_classes(self) -> None:
super()._init_dynamic_classes()
self.Group = systems.build_group_class(self)
self.System = systems.build_system_class(self)
- def _after_init(self):
+ def _after_init(self) -> None:
"""Invoked at the end of ``__init__``.
- Create default group and add all orphan units to it
@@ -1901,20 +1993,20 @@ class SystemRegistry(BaseRegistry):
"system", None
)
- def _register_parsers(self):
+ def _register_parsers(self) -> None:
super()._register_parsers()
self._register_parser("@group", self._parse_group)
self._register_parser("@system", self._parse_system)
- def _parse_group(self, ifile):
+ def _parse_group(self, ifile) -> None:
self.Group.from_lines(ifile.block_iter(), self.define, self.non_int_type)
- def _parse_system(self, ifile):
+ def _parse_system(self, ifile) -> None:
self.System.from_lines(
ifile.block_iter(), self.get_root_units, self.non_int_type
)
- def get_group(self, name, create_if_needed=True):
+ def get_group(self, name: str, create_if_needed: bool = True) -> Group:
"""Return a Group.
Parameters
@@ -1943,7 +2035,7 @@ class SystemRegistry(BaseRegistry):
return systems.Lister(self._systems)
@property
- def default_system(self):
+ def default_system(self) -> System:
return self._default_system
@default_system.setter
@@ -1956,7 +2048,7 @@ class SystemRegistry(BaseRegistry):
self._default_system = name
- def get_system(self, name, create_if_needed=True):
+ def get_system(self, name: str, create_if_needed: bool = True) -> System:
"""Return a Group.
Parameters
@@ -1994,7 +2086,12 @@ class SystemRegistry(BaseRegistry):
return definition, d, di
- def get_base_units(self, input_units, check_nonmult=True, system=None):
+ def get_base_units(
+ self,
+ input_units: Union[UnitLike, Quantity],
+ check_nonmult: bool = True,
+ system: Union[str, System, None] = None,
+ ) -> Tuple[Number, Unit]:
"""Convert unit or dict of units to the base units.
If any unit is non multiplicative and check_converter is True,
@@ -2027,7 +2124,12 @@ class SystemRegistry(BaseRegistry):
return f, self.Unit(units)
- def _get_base_units(self, input_units, check_nonmult=True, system=None):
+ def _get_base_units(
+ self,
+ input_units: UnitsContainerT,
+ check_nonmult: bool = True,
+ system: Union[str, System, None] = None,
+ ):
if system is None:
system = self._default_system
@@ -2068,7 +2170,7 @@ class SystemRegistry(BaseRegistry):
return base_factor, destination_units
- def _get_compatible_units(self, input_units, group_or_system):
+ def _get_compatible_units(self, input_units, group_or_system) -> FrozenSet[Unit]:
if group_or_system is None:
group_or_system = self._default_system
@@ -2126,17 +2228,17 @@ class UnitRegistry(SystemRegistry, ContextRegistry, NonMultiplicativeRegistry):
def __init__(
self,
filename="",
- force_ndarray=False,
- force_ndarray_like=False,
- default_as_delta=True,
- autoconvert_offset_to_baseunit=False,
- on_redefinition="warn",
+ force_ndarray: bool = False,
+ force_ndarray_like: bool = False,
+ default_as_delta: bool = True,
+ autoconvert_offset_to_baseunit: bool = False,
+ on_redefinition: str = "warn",
system=None,
auto_reduce_dimensions=False,
preprocessors=None,
fmt_locale=None,
non_int_type=float,
- case_sensitive=True,
+ case_sensitive: bool = True,
):
super().__init__(
@@ -2170,7 +2272,7 @@ class UnitRegistry(SystemRegistry, ContextRegistry, NonMultiplicativeRegistry):
"""
return pi_theorem(quantities, self)
- def setup_matplotlib(self, enable=True):
+ def setup_matplotlib(self, enable: bool = True) -> None:
"""Set up handlers for matplotlib's unit support.
Parameters
diff --git a/pint/registry_helpers.py b/pint/registry_helpers.py
index b8888c8..7f6ee7f 100644
--- a/pint/registry_helpers.py
+++ b/pint/registry_helpers.py
@@ -11,10 +11,19 @@
import functools
from inspect import signature
from itertools import zip_longest
+from typing import TYPE_CHECKING, Callable, Iterable, TypeVar, Union
+from ._typing import F
from .errors import DimensionalityError
+from .quantity import Quantity
from .util import UnitsContainer, to_units_container
+if TYPE_CHECKING:
+ from .registry import UnitRegistry
+ from .unit import Unit
+
+T = TypeVar("T")
+
def _replace_units(original_units, values_by_name):
"""Convert a unit compatible type to a UnitsContainer.
@@ -175,7 +184,12 @@ def _apply_defaults(func, args, kwargs):
return args, {}
-def wraps(ureg, ret, args, strict=True):
+def wraps(
+ ureg: "UnitRegistry",
+ ret: Union[str, "Unit", Iterable[str], Iterable["Unit"], None],
+ args: Union[str, "Unit", Iterable[str], Iterable["Unit"], None],
+ strict: bool = True,
+) -> Callable[[Callable[..., T]], Callable[..., Quantity[T]]]:
"""Wraps a function to become pint-aware.
Use it when a function requires a numerical value but in some specific
@@ -239,7 +253,7 @@ def wraps(ureg, ret, args, strict=True):
)
ret = _to_units_container(ret, ureg)
- def decorator(func):
+ def decorator(func: Callable[..., T]) -> Callable[..., Quantity[T]]:
count_params = len(signature(func).parameters)
if len(args) != count_params:
@@ -256,7 +270,7 @@ def wraps(ureg, ret, args, strict=True):
)
@functools.wraps(func, assigned=assigned, updated=updated)
- def wrapper(*values, **kw):
+ def wrapper(*values, **kw) -> Quantity[T]:
values, kw = _apply_defaults(func, values, kw)
@@ -288,7 +302,9 @@ def wraps(ureg, ret, args, strict=True):
return decorator
-def check(ureg, *args):
+def check(
+ ureg: "UnitRegistry", *args: Union[str, UnitsContainer, "Unit", None]
+) -> Callable[[F], F]:
"""Decorator to for quantity type checking for function inputs.
Use it to ensure that the decorated function input parameters match
diff --git a/pint/unit.py b/pint/unit.py
index 5208eab..f91b6c1 100644
--- a/pint/unit.py
+++ b/pint/unit.py
@@ -8,23 +8,30 @@
:license: BSD, see LICENSE for more details.
"""
+from __future__ import annotations
+
import copy
import locale
import operator
from numbers import Number
+from typing import TYPE_CHECKING, Any, Type, Union
+from ._typing import UnitLike
from .compat import NUMERIC_TYPES, babel_parse, is_upcast_type
from .definitions import UnitDefinition
from .errors import DimensionalityError
from .formatting import siunitx_format_unit
from .util import PrettyIPython, SharedRegistryObject, UnitsContainer
+if TYPE_CHECKING:
+ from .context import Context
+
class Unit(PrettyIPython, SharedRegistryObject):
"""Implements a class to describe a unit supporting math operations."""
#: Default formatting string.
- default_format = ""
+ default_format: str = ""
def __reduce__(self):
# See notes in Quantity.__reduce__
@@ -32,7 +39,7 @@ class Unit(PrettyIPython, SharedRegistryObject):
return _unpickle_unit, (Unit, self._units)
- def __init__(self, units):
+ def __init__(self, units: UnitLike) -> None:
super().__init__()
if isinstance(units, (UnitsContainer, UnitDefinition)):
self._units = units
@@ -50,29 +57,29 @@ class Unit(PrettyIPython, SharedRegistryObject):
self.__handling = None
@property
- def debug_used(self):
+ def debug_used(self) -> Any:
return self.__used
- def __copy__(self):
+ def __copy__(self) -> Unit:
ret = self.__class__(self._units)
ret.__used = self.__used
return ret
- def __deepcopy__(self, memo):
+ def __deepcopy__(self, memo) -> Unit:
ret = self.__class__(copy.deepcopy(self._units, memo))
ret.__used = self.__used
return ret
- def __str__(self):
+ def __str__(self) -> str:
return format(self)
- def __bytes__(self):
+ def __bytes__(self) -> bytes:
return str(self).encode(locale.getpreferredencoding())
- def __repr__(self):
+ def __repr__(self) -> str:
return "<Unit('{}')>".format(self._units)
- def __format__(self, spec):
+ def __format__(self, spec) -> str:
spec = spec or self.default_format
# special cases
if "Lx" in spec: # the LaTeX siunitx code
@@ -93,7 +100,7 @@ class Unit(PrettyIPython, SharedRegistryObject):
return format(units, spec)
- def format_babel(self, spec="", locale=None, **kwspec):
+ def format_babel(self, spec="", locale=None, **kwspec: Any) -> str:
spec = spec or self.default_format
if "~" in spec:
@@ -119,12 +126,12 @@ class Unit(PrettyIPython, SharedRegistryObject):
return units.format_babel(spec, **kwspec)
@property
- def dimensionless(self):
+ def dimensionless(self) -> bool:
"""Return True if the Unit is dimensionless; False otherwise."""
return not bool(self.dimensionality)
@property
- def dimensionality(self):
+ def dimensionality(self) -> UnitsContainer:
"""
Returns
-------
@@ -146,7 +153,9 @@ class Unit(PrettyIPython, SharedRegistryObject):
return self._REGISTRY.get_compatible_units(self)
- def is_compatible_with(self, other, *contexts, **ctx_kwargs):
+ def is_compatible_with(
+ self, other: Any, *contexts: Union[str, Context], **ctx_kwargs: Any
+ ) -> bool:
"""check if the other object is compatible
Parameters
@@ -218,7 +227,7 @@ class Unit(PrettyIPython, SharedRegistryObject):
__div__ = __truediv__
__rdiv__ = __rtruediv__
- def __pow__(self, other):
+ def __pow__(self, other) -> "Unit":
if isinstance(other, NUMERIC_TYPES):
return self.__class__(self._units ** other)
@@ -226,10 +235,10 @@ class Unit(PrettyIPython, SharedRegistryObject):
mess = "Cannot power Unit by {}".format(type(other))
raise TypeError(mess)
- def __hash__(self):
+ def __hash__(self) -> int:
return self._units.__hash__()
- def __eq__(self, other):
+ def __eq__(self, other) -> bool:
# We compare to the base class of Unit because each Unit class is
# unique.
if self._check(other):
@@ -244,10 +253,10 @@ class Unit(PrettyIPython, SharedRegistryObject):
else:
return self._units == other
- def __ne__(self, other):
+ def __ne__(self, other) -> bool:
return not (self == other)
- def compare(self, other, op):
+ def compare(self, other, op) -> bool:
self_q = self._REGISTRY.Quantity(1, self)
if isinstance(other, NUMERIC_TYPES):
@@ -262,13 +271,13 @@ class Unit(PrettyIPython, SharedRegistryObject):
__ge__ = lambda self, other: self.compare(other, op=operator.ge)
__gt__ = lambda self, other: self.compare(other, op=operator.gt)
- def __int__(self):
+ def __int__(self) -> int:
return int(self._REGISTRY.Quantity(1, self._units))
- def __float__(self):
+ def __float__(self) -> float:
return float(self._REGISTRY.Quantity(1, self._units))
- def __complex__(self):
+ def __complex__(self) -> complex:
return complex(self._REGISTRY.Quantity(1, self._units))
__array_priority__ = 17
@@ -361,7 +370,7 @@ class Unit(PrettyIPython, SharedRegistryObject):
_Unit = Unit
-def build_unit_class(registry):
+def build_unit_class(registry) -> Type[Unit]:
class Unit(_Unit):
_REGISTRY = registry
diff --git a/pint/util.py b/pint/util.py
index f2162f4..18474ec 100644
--- a/pint/util.py
+++ b/pint/util.py
@@ -8,6 +8,8 @@
:license: BSD, see LICENSE for more details.
"""
+from __future__ import annotations
+
import logging
import math
import operator
@@ -18,12 +20,19 @@ from functools import lru_cache, partial
from logging import NullHandler
from numbers import Number
from token import NAME, NUMBER
+from typing import TYPE_CHECKING, ClassVar, Optional, Union
from .compat import NUMERIC_TYPES, tokenizer
from .errors import DefinitionSyntaxError
from .formatting import format_unit
from .pint_eval import build_eval_tree
+if TYPE_CHECKING:
+ from ._typing import UnitLike
+ from .quantity import Quantity
+ from .registry import BaseRegistry
+
+
logger = logging.getLogger(__name__)
logger.addHandler(NullHandler())
@@ -321,7 +330,7 @@ class UnitsContainer(Mapping):
__slots__ = ("_d", "_hash", "_one", "_non_int_type")
- def __init__(self, *args, **kwargs):
+ def __init__(self, *args, **kwargs) -> None:
if args and isinstance(args[0], UnitsContainer):
default_non_int_type = args[0]._non_int_type
else:
@@ -398,7 +407,7 @@ class UnitsContainer(Mapping):
def __iter__(self):
return iter(self._d)
- def __len__(self):
+ def __len__(self) -> int:
return len(self._d)
def __getitem__(self, key):
@@ -419,7 +428,7 @@ class UnitsContainer(Mapping):
def __setstate__(self, state):
self._d, self._hash, self._one, self._non_int_type = state
- def __eq__(self, other):
+ def __eq__(self, other) -> bool:
if isinstance(other, UnitsContainer):
# UnitsContainer.__hash__(self) is not the same as hash(self); see
# ParserHelper.__hash__ and __eq__.
@@ -440,19 +449,19 @@ class UnitsContainer(Mapping):
return dict.__eq__(self._d, other)
- def __str__(self):
+ def __str__(self) -> str:
return self.__format__("")
- def __repr__(self):
+ def __repr__(self) -> str:
tmp = "{%s}" % ", ".join(
["'{}': {}".format(key, value) for key, value in sorted(self._d.items())]
)
return "<UnitsContainer({})>".format(tmp)
- def __format__(self, spec):
+ def __format__(self, spec: str) -> str:
return format_unit(self, spec)
- def format_babel(self, spec, **kwspec):
+ def format_babel(self, spec: str, **kwspec) -> str:
return format_unit(self, spec, **kwspec)
def __copy__(self):
@@ -742,7 +751,7 @@ class ParserHelper(UnitsContainer):
#: List of regex substitution pairs.
-_subs_re = [
+_subs_re_list = [
("\N{DEGREE SIGN}", " degree"),
(r"([\w\.\-\+\*\\\^])\s+", r"\1 "), # merge multiple spaces
(r"({}) squared", r"\1**2"), # Handle square and cube
@@ -758,12 +767,14 @@ _subs_re = [
]
#: Compiles the regex and replace {} by a regex that matches an identifier.
-_subs_re = [(re.compile(a.format(r"[_a-zA-Z][_a-zA-Z0-9]*")), b) for a, b in _subs_re]
+_subs_re = [
+ (re.compile(a.format(r"[_a-zA-Z][_a-zA-Z0-9]*")), b) for a, b in _subs_re_list
+]
_pretty_table = str.maketrans("⁰¹²³⁴⁵⁶⁷⁸⁹·⁻", "0123456789*-")
_pretty_exp_re = re.compile(r"⁻?[⁰¹²³⁴⁵⁶⁷⁸⁹]+(?:\.[⁰¹²³⁴⁵⁶⁷⁸⁹]*)?")
-def string_preprocessor(input_string):
+def string_preprocessor(input_string: str) -> str:
input_string = input_string.replace(",", "")
input_string = input_string.replace(" per ", "/")
@@ -781,7 +792,7 @@ def string_preprocessor(input_string):
return input_string
-def _is_dim(name):
+def _is_dim(name: str) -> bool:
return name[0] == "[" and name[-1] == "]"
@@ -799,6 +810,9 @@ class SharedRegistryObject:
"""
+ _REGISTRY: ClassVar[BaseRegistry]
+ _units: UnitsContainer
+
def __new__(cls, *args, **kwargs):
inst = object.__new__(cls)
if not hasattr(cls, "_REGISTRY"):
@@ -809,7 +823,7 @@ class SharedRegistryObject:
inst._REGISTRY = _APP_REGISTRY
return inst
- def _check(self, other):
+ def _check(self, other) -> bool:
"""Check if the other object use a registry and if so that it is the
same registry.
@@ -840,6 +854,8 @@ class SharedRegistryObject:
class PrettyIPython:
"""Mixin to add pretty-printers for IPython"""
+ default_format: str
+
def _repr_html_(self):
if "~" in self.default_format:
return "{:~H}".format(self)
@@ -859,7 +875,9 @@ class PrettyIPython:
p.text("{:P}".format(self))
-def to_units_container(unit_like, registry=None):
+def to_units_container(
+ unit_like: Union[UnitLike, Quantity], registry: Optional[BaseRegistry] = None
+) -> UnitsContainer:
"""Convert a unit compatible type to a UnitsContainer.
Parameters
@@ -1014,7 +1032,7 @@ class BlockIterator(SourceIterator):
next = __next__
-def iterable(y):
+def iterable(y) -> bool:
"""Check whether or not an object can be iterated over.
Vendored from numpy under the terms of the BSD 3-Clause License. (Copyright
@@ -1036,7 +1054,7 @@ def iterable(y):
return True
-def sized(y):
+def sized(y) -> bool:
"""Check whether or not an object has a defined length.
Parameters
diff --git a/setup.cfg b/setup.cfg
index 6f10386..fc811ed 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -44,7 +44,7 @@ test =
pytest-subtests
[options.package_data]
-pint = default_en.txt; constants_en.txt
+pint = default_en.txt; constants_en.txt; py.typed
[check-manifest]
ignore =