diff options
author | Jules Chéron <jules.cheron@gmail.com> | 2021-05-22 23:03:34 +0200 |
---|---|---|
committer | Jules Chéron <jules.cheron@gmail.com> | 2021-08-01 20:22:11 +0200 |
commit | c60e1af833391a60a82f3c47460fca899e220de5 (patch) | |
tree | c9fbae3c8633e8d3cf6d36a7e1e6d91285394c14 | |
parent | 78769ec967a24cd57c9fb6474aff2aa3708ca483 (diff) | |
download | pint-c60e1af833391a60a82f3c47460fca899e220de5.tar.gz |
Add pint typing module
- Quantity as Generic class
- Add overloaded signature for __new__ Quantity
- Add typing module as private
- Add py.typed for PEP561 supports
- Add overloaded signature for __new__ Quantity
- Quantity as Generic class
- Add type hints throughout the project
- Add py.typed in package data in setup.cfg
- Add type hints for decorators
- Add type hints for public API of registry.py
- Add type hints for units.py
-rw-r--r-- | .coveragerc | 16 | ||||
-rw-r--r-- | CHANGES | 2 | ||||
-rw-r--r-- | pint/_typing.py | 18 | ||||
-rw-r--r-- | pint/context.py | 33 | ||||
-rw-r--r-- | pint/definitions.py | 79 | ||||
-rw-r--r-- | pint/formatting.py | 3 | ||||
-rw-r--r-- | pint/py.typed | 0 | ||||
-rw-r--r-- | pint/quantity.py | 199 | ||||
-rw-r--r-- | pint/registry.py | 282 | ||||
-rw-r--r-- | pint/registry_helpers.py | 24 | ||||
-rw-r--r-- | pint/unit.py | 53 | ||||
-rw-r--r-- | pint/util.py | 48 | ||||
-rw-r--r-- | setup.cfg | 2 |
13 files changed, 529 insertions, 230 deletions
diff --git a/.coveragerc b/.coveragerc index 73fab9e..fbb079e 100644 --- a/.coveragerc +++ b/.coveragerc @@ -1,3 +1,19 @@ [run] omit = pint/testsuite/* +[report] +# Regexes for lines to exclude from consideration +exclude_lines = + # Have to re-enable the standard pragma + pragma: no cover + + # Don't complain about missing debug-only code: + def __repr__ + + # Don't complain if tests don't hit defensive assertion code: + raise AssertionError + raise NotImplementedError + AbstractMethodError + + # Don't complain if non-runnable code isn't run: + if TYPE_CHECKING:
\ No newline at end of file @@ -19,6 +19,8 @@ Pint Changelog - pint no longer supports Python 3.6 - Minimum Numpy version supported is 1.17+ +- Add supports for type hints for Quantity class. Quantity is now a Generic (PEP560). +- Add support for [PEP561](https://www.python.org/dev/peps/pep-0561/) (Package Type information) 0.17 (2021-03-22) diff --git a/pint/_typing.py b/pint/_typing.py new file mode 100644 index 0000000..dbeca15 --- /dev/null +++ b/pint/_typing.py @@ -0,0 +1,18 @@ +from typing import TYPE_CHECKING, Any, Callable, Tuple, TypeVar, Union + +if TYPE_CHECKING: + from .quantity import Quantity + from .unit import Unit + from .util import UnitsContainer + +UnitLike = Union[str, "UnitsContainer", "Unit"] + +QuantityOrUnitLike = Union["Quantity", UnitLike] + +Shape = Tuple[int, ...] + +_MagnitudeType = TypeVar("_MagnitudeType") +S = TypeVar("S") + +FuncType = Callable[..., Any] +F = TypeVar("F", bound=FuncType) diff --git a/pint/context.py b/pint/context.py index 6cd440e..59ea9cf 100644 --- a/pint/context.py +++ b/pint/context.py @@ -8,14 +8,22 @@ :license: BSD, see LICENSE for more details. """ +from __future__ import annotations + import re import weakref from collections import ChainMap, defaultdict +from typing import TYPE_CHECKING, Any, Callable, Optional, Tuple from .definitions import Definition, UnitDefinition from .errors import DefinitionSyntaxError from .util import ParserHelper, SourceIterator, to_units_container +if TYPE_CHECKING: + from .quantity import Quantity + from .registry import UnitRegistry + from .util import UnitsContainer + #: Regex to match the header parts of a context. _header_re = re.compile( r"@context\s*(?P<defaults>\(.*\))?\s+(?P<name>\w+)\s*(=(?P<aliases>.*))*" @@ -25,8 +33,8 @@ _header_re = re.compile( _varname_re = re.compile(r"[A-Za-z_][A-Za-z0-9_]*") -def _expression_to_function(eq): - def func(ureg, value, **kwargs): +def _expression_to_function(eq: str) -> Callable[..., Quantity[Any]]: + def func(ureg: UnitRegistry, value: Any, **kwargs: Any) -> Quantity[Any]: return ureg.parse_expression(eq, value=value, **kwargs) return func @@ -84,7 +92,12 @@ class Context: >>> c.redefine("pound = 0.5 kg") """ - def __init__(self, name=None, aliases=(), defaults=None): + def __init__( + self, + name: Optional[str] = None, + aliases: Tuple[str, ...] = (), + defaults: Optional[dict] = None, + ) -> None: self.name = name self.aliases = aliases @@ -106,7 +119,7 @@ class Context: self.relation_to_context = weakref.WeakValueDictionary() @classmethod - def from_context(cls, context, **defaults): + def from_context(cls, context: Context, **defaults) -> Context: """Creates a new context that shares the funcs dictionary with the original context. The default values are copied from the original context and updated with the new defaults. @@ -135,7 +148,7 @@ class Context: return context @classmethod - def from_lines(cls, lines, to_base_func=None, non_int_type=float): + def from_lines(cls, lines, to_base_func=None, non_int_type=float) -> Context: lines = SourceIterator(lines) lineno, header = next(lines) @@ -223,14 +236,14 @@ class Context: return ctx - def add_transformation(self, src, dst, func): + def add_transformation(self, src, dst, func) -> None: """Add a transformation function to the context.""" _key = self.__keytransform__(src, dst) self.funcs[_key] = func self.relation_to_context[_key] = self - def remove_transformation(self, src, dst): + def remove_transformation(self, src, dst) -> None: """Add a transformation function to the context.""" _key = self.__keytransform__(src, dst) @@ -238,7 +251,7 @@ class Context: del self.relation_to_context[_key] @staticmethod - def __keytransform__(src, dst): + def __keytransform__(src, dst) -> Tuple[UnitsContainer, UnitsContainer]: return to_units_container(src), to_units_container(dst) def transform(self, src, dst, registry, value): @@ -270,7 +283,9 @@ class Context: raise DefinitionSyntaxError("Can't define base units within a context") self.redefinitions.append(d) - def hashable(self): + def hashable( + self, + ) -> Tuple[Optional[str], Tuple[str, ...], frozenset, frozenset, tuple]: """Generate a unique hashable and comparable representation of self, which can be used as a key in a dict. This class cannot define ``__hash__`` because it is mutable, and the Python interpreter does cache the output of ``__hash__``. diff --git a/pint/definitions.py b/pint/definitions.py index 7e30c89..f02157b 100644 --- a/pint/definitions.py +++ b/pint/definitions.py @@ -8,9 +8,12 @@ :license: BSD, see LICENSE for more details. """ +from __future__ import annotations + from collections import namedtuple +from typing import Callable, Iterable, Optional, Union -from .converters import LogarithmicConverter, OffsetConverter, ScaleConverter +from .converters import Converter, LogarithmicConverter, OffsetConverter, ScaleConverter from .errors import DefinitionSyntaxError from .util import ParserHelper, UnitsContainer, _is_dim @@ -42,7 +45,7 @@ class PreprocessedDefinition( """ @classmethod - def from_string(cls, definition): + def from_string(cls, definition: str) -> PreprocessedDefinition: name, definition = definition.split("=", 1) name = name.strip() @@ -64,7 +67,7 @@ class _NotNumeric(Exception): self.value = value -def numeric_parse(s, non_int_type=float): +def numeric_parse(s: str, non_int_type: type = float): """Try parse a string into a number (without using eval). Parameters @@ -103,7 +106,13 @@ class Definition: converter : callable or Converter or None """ - def __init__(self, name, symbol, aliases, converter): + def __init__( + self, + name: str, + symbol: Optional[str], + aliases: Iterable[str], + converter: Optional[Union[Callable, Converter]], + ): if isinstance(converter, str): raise TypeError( @@ -112,19 +121,21 @@ class Definition: self._name = name self._symbol = symbol - self._aliases = aliases + self._aliases = tuple(aliases) self._converter = converter @property - def is_multiplicative(self): + def is_multiplicative(self) -> bool: return self._converter.is_multiplicative @property - def is_logarithmic(self): + def is_logarithmic(self) -> bool: return self._converter.is_logarithmic @classmethod - def from_string(cls, definition, non_int_type=float): + def from_string( + cls, definition: Union[str, PreprocessedDefinition], non_int_type: type = float + ) -> "Definition": """Parse a definition. Parameters @@ -150,30 +161,30 @@ class Definition: return UnitDefinition.from_string(definition, non_int_type) @property - def name(self): + def name(self) -> str: return self._name @property - def symbol(self): + def symbol(self) -> str: return self._symbol or self._name @property - def has_symbol(self): + def has_symbol(self) -> bool: return bool(self._symbol) @property - def aliases(self): + def aliases(self) -> Iterable[str]: return self._aliases - def add_aliases(self, *alias): + def add_aliases(self, *alias: str) -> None: alias = tuple(a for a in alias if a not in self._aliases) self._aliases = self._aliases + alias @property - def converter(self): + def converter(self) -> Converter: return self._converter - def __str__(self): + def __str__(self) -> str: return self.name @@ -188,7 +199,9 @@ class PrefixDefinition(Definition): """ @classmethod - def from_string(cls, definition, non_int_type=float): + def from_string( + cls, definition: Union[str, PreprocessedDefinition], non_int_type: type = float + ) -> "PrefixDefinition": if isinstance(definition, str): definition = PreprocessedDefinition.from_string(definition) @@ -226,14 +239,24 @@ class UnitDefinition(Definition): """ - def __init__(self, name, symbol, aliases, converter, reference=None, is_base=False): + def __init__( + self, + name: str, + symbol: Optional[str], + aliases: Iterable[str], + converter: Converter, + reference: Optional[UnitsContainer] = None, + is_base: bool = False, + ) -> None: self.reference = reference self.is_base = is_base super().__init__(name, symbol, aliases, converter) @classmethod - def from_string(cls, definition, non_int_type=float): + def from_string( + cls, definition: Union[str, PreprocessedDefinition], non_int_type: type = float + ) -> "UnitDefinition": if isinstance(definition, str): definition = PreprocessedDefinition.from_string(definition) @@ -305,14 +328,24 @@ class DimensionDefinition(Definition): [density] = [mass] / [volume] """ - def __init__(self, name, symbol, aliases, converter, reference=None, is_base=False): + def __init__( + self, + name: str, + symbol: Optional[str], + aliases: Iterable[str], + converter: Optional[Converter], + reference: Optional[UnitsContainer] = None, + is_base: bool = False, + ) -> None: self.reference = reference self.is_base = is_base super().__init__(name, symbol, aliases, converter=None) @classmethod - def from_string(cls, definition, non_int_type=float): + def from_string( + cls, definition: Union[str, PreprocessedDefinition], non_int_type: type = float + ) -> "DimensionDefinition": if isinstance(definition, str): definition = PreprocessedDefinition.from_string(definition) @@ -350,11 +383,13 @@ class AliasDefinition(Definition): @alias meter = my_meter """ - def __init__(self, name, aliases): + def __init__(self, name: str, aliases: Iterable[str]) -> None: super().__init__(name=name, symbol=None, aliases=aliases, converter=None) @classmethod - def from_string(cls, definition, non_int_type=float): + def from_string( + cls, definition: Union[str, PreprocessedDefinition], non_int_type: type = float + ) -> AliasDefinition: if isinstance(definition, str): definition = PreprocessedDefinition.from_string(definition) diff --git a/pint/formatting.py b/pint/formatting.py index afc51fe..e9f2090 100644 --- a/pint/formatting.py +++ b/pint/formatting.py @@ -9,6 +9,7 @@ """ import re +from typing import Dict from .babel_names import _babel_lengths, _babel_units from .compat import babel_parse @@ -71,7 +72,7 @@ def _pretty_fmt_exponent(num): #: _FORMATS maps format specifications to the corresponding argument set to #: formatter(). -_FORMATS = { +_FORMATS: Dict[str, dict] = { "P": { # Pretty format. "as_ratio": True, "single_denominator": False, diff --git a/pint/py.typed b/pint/py.typed new file mode 100644 index 0000000..e69de29 --- /dev/null +++ b/pint/py.typed diff --git a/pint/quantity.py b/pint/quantity.py index fb0396e..34361bb 100644 --- a/pint/quantity.py +++ b/pint/quantity.py @@ -6,6 +6,8 @@ :license: BSD, see LICENSE for more details. """ +from __future__ import annotations + import bisect import contextlib import copy @@ -17,9 +19,27 @@ import numbers import operator import re import warnings -from typing import List +from typing import ( + TYPE_CHECKING, + Any, + Callable, + Dict, + Generic, + Iterable, + Iterator, + List, + Optional, + Sequence, + Tuple, + Type, + TypeVar, + Union, + overload, +) +from ._typing import S, Shape, UnitLike, _MagnitudeType from .compat import ( + HAS_NUMPY, _to_magnitude, babel_parse, compute, @@ -66,6 +86,14 @@ from .util import ( to_units_container, ) +if TYPE_CHECKING: + from . import Context, Unit + from .registry import BaseRegistry + from .unit import UnitsContainer as UnitsContainerT + + if HAS_NUMPY: + import numpy as np # noqa + class _Exception(Exception): # pragma: no cover def __init__(self, internal): @@ -153,7 +181,11 @@ def printoptions(*args, **kwargs): np.set_printoptions(**opts) -class Quantity(PrettyIPython, SharedRegistryObject): +# Workaround to bypass dynamically generated Quantity with overload method +Magnitude = TypeVar("Magnitude") + + +class Quantity(PrettyIPython, SharedRegistryObject, Generic[_MagnitudeType]): """Implements a class to describe a physical quantity: the product of a numerical value and a unit of measurement. @@ -170,21 +202,22 @@ class Quantity(PrettyIPython, SharedRegistryObject): """ #: Default formatting string. - default_format = "" + default_format: str = "" + _magnitude: _MagnitudeType @property - def force_ndarray(self): + def force_ndarray(self) -> bool: return self._REGISTRY.force_ndarray @property - def force_ndarray_like(self): + def force_ndarray_like(self) -> bool: return self._REGISTRY.force_ndarray_like @property - def UnitsContainer(self): + def UnitsContainer(self) -> Callable[..., UnitsContainerT]: return self._REGISTRY.UnitsContainer - def __reduce__(self): + def __reduce__(self) -> tuple: """Allow pickling quantities. Since UnitRegistries are not pickled, upon unpickling the new object is always attached to the application registry. """ @@ -194,6 +227,30 @@ class Quantity(PrettyIPython, SharedRegistryObject): # build_quantity_class can't be pickled return _unpickle_quantity, (Quantity, self.magnitude, self._units) + @overload + def __new__( + cls, value: str, units: Optional[UnitLike] = None + ) -> Quantity[Magnitude]: + ... + + @overload + def __new__( # type: ignore[misc] + cls, value: Sequence, units: Optional[UnitLike] = None + ) -> Quantity[np.ndarray]: + ... + + @overload + def __new__( + cls, value: Quantity[Magnitude], units: Optional[UnitLike] = None + ) -> Quantity[Magnitude]: + ... + + @overload + def __new__( + cls, value: Magnitude, units: Optional[UnitLike] = None + ) -> Quantity[Magnitude]: + ... + def __new__(cls, value, units=None): if is_upcast_type(type(value)): raise TypeError(f"Quantity cannot wrap upcast type {type(value)}") @@ -250,7 +307,7 @@ class Quantity(PrettyIPython, SharedRegistryObject): def debug_used(self): return self.__used - def __iter__(self): + def __iter__(self: Quantity[Iterable[S]]) -> Iterator[S]: # Make sure that, if self.magnitude is not iterable, we raise TypeError as soon # as one calls iter(self) without waiting for the first element to be drawn from # the iterator @@ -262,34 +319,34 @@ class Quantity(PrettyIPython, SharedRegistryObject): return it_outer() - def __copy__(self): + def __copy__(self) -> Quantity[_MagnitudeType]: ret = self.__class__(copy.copy(self._magnitude), self._units) ret.__used = self.__used return ret - def __deepcopy__(self, memo): + def __deepcopy__(self, memo) -> Quantity[_MagnitudeType]: ret = self.__class__( copy.deepcopy(self._magnitude, memo), copy.deepcopy(self._units, memo) ) ret.__used = self.__used return ret - def __str__(self): + def __str__(self) -> str: if self._REGISTRY.fmt_locale is not None: return self.format_babel() return format(self) - def __bytes__(self): + def __bytes__(self) -> bytes: return str(self).encode(locale.getpreferredencoding()) - def __repr__(self): + def __repr__(self) -> str: if isinstance(self._magnitude, float): return f"<Quantity({self._magnitude:.9}, '{self._units}')>" else: return f"<Quantity({self._magnitude}, '{self._units}')>" - def __hash__(self): + def __hash__(self) -> int: self_base = self.to_base_units() if self_base.dimensionless: return hash(self_base.magnitude) @@ -298,7 +355,7 @@ class Quantity(PrettyIPython, SharedRegistryObject): _exp_pattern = re.compile(r"([0-9]\.?[0-9]*)e(-?)\+?0*([0-9]+)") - def __format__(self, spec): + def __format__(self, spec: str) -> str: if self._REGISTRY.fmt_locale is not None: return self.format_babel(spec) @@ -410,7 +467,7 @@ class Quantity(PrettyIPython, SharedRegistryObject): p.text(" ") p.pretty(self.units) - def format_babel(self, spec="", **kwspec): + def format_babel(self, spec: str = "", **kwspec: Any) -> str: spec = spec or self.default_format # standard cases @@ -435,16 +492,16 @@ class Quantity(PrettyIPython, SharedRegistryObject): ).replace("\n", "") @property - def magnitude(self): + def magnitude(self) -> _MagnitudeType: """Quantity's magnitude. Long form for `m`""" return self._magnitude @property - def m(self): + def m(self) -> _MagnitudeType: """Quantity's magnitude. Short form for `magnitude`""" return self._magnitude - def m_as(self, units): + def m_as(self, units) -> _MagnitudeType: """Quantity's magnitude expressed in particular units. Parameters @@ -459,31 +516,31 @@ class Quantity(PrettyIPython, SharedRegistryObject): return self.to(units).magnitude @property - def units(self): + def units(self) -> "Unit": """Quantity's units. Long form for `u`""" return self._REGISTRY.Unit(self._units) @property - def u(self): + def u(self) -> "Unit": """Quantity's units. Short form for `units`""" return self._REGISTRY.Unit(self._units) @property - def unitless(self): + def unitless(self) -> bool: """ """ return not bool(self.to_root_units()._units) @property - def dimensionless(self): + def dimensionless(self) -> bool: """ """ tmp = self.to_root_units() return not bool(tmp.dimensionality) - _dimensionality = None + _dimensionality: Optional[UnitsContainerT] = None @property - def dimensionality(self): + def dimensionality(self) -> UnitsContainerT: """ Returns ------- @@ -495,12 +552,12 @@ class Quantity(PrettyIPython, SharedRegistryObject): return self._dimensionality - def check(self, dimension): + def check(self, dimension: UnitLike) -> bool: """Return true if the quantity's dimension matches passed dimension.""" return self.dimensionality == self._REGISTRY.get_dimensionality(dimension) @classmethod - def from_list(cls, quant_list, units=None): + def from_list(cls, quant_list: List[Quantity], units=None) -> Quantity[np.ndarray]: """Transforms a list of Quantities into an numpy.array quantity. If no units are specified, the unit of the first element will be used. Same as from_sequence. @@ -522,7 +579,7 @@ class Quantity(PrettyIPython, SharedRegistryObject): return cls.from_sequence(quant_list, units=units) @classmethod - def from_sequence(cls, seq, units=None): + def from_sequence(cls, seq: Sequence[Quantity], units=None) -> Quantity[np.ndarray]: """Transforms a sequence of Quantities into an numpy.array quantity. If no units are specified, the unit of the first element will be used. @@ -560,7 +617,7 @@ class Quantity(PrettyIPython, SharedRegistryObject): def from_tuple(cls, tup): return cls(tup[0], cls._REGISTRY.UnitsContainer(tup[1])) - def to_tuple(self): + def to_tuple(self) -> Tuple[_MagnitudeType, Tuple[Tuple[str]]]: return self.m, tuple(self._units.items()) def compatible_units(self, *contexts): @@ -570,7 +627,9 @@ class Quantity(PrettyIPython, SharedRegistryObject): return self._REGISTRY.get_compatible_units(self._units) - def is_compatible_with(self, other, *contexts, **ctx_kwargs): + def is_compatible_with( + self, other: Any, *contexts: Union[str, Context], **ctx_kwargs: Any + ) -> bool: """check if the other object is compatible Parameters @@ -623,7 +682,7 @@ class Quantity(PrettyIPython, SharedRegistryObject): inplace=is_duck_array_type(type(self._magnitude)), ) - def ito(self, other=None, *contexts, **ctx_kwargs): + def ito(self, other=None, *contexts, **ctx_kwargs) -> None: """Inplace rescale to different units. Parameters @@ -642,7 +701,7 @@ class Quantity(PrettyIPython, SharedRegistryObject): return None - def to(self, other=None, *contexts, **ctx_kwargs): + def to(self, other=None, *contexts, **ctx_kwargs) -> Quantity[_MagnitudeType]: """Return Quantity rescaled to different units. Parameters @@ -664,7 +723,7 @@ class Quantity(PrettyIPython, SharedRegistryObject): return self.__class__(magnitude, other) - def ito_root_units(self): + def ito_root_units(self) -> None: """Return Quantity rescaled to root units.""" _, other = self._REGISTRY._get_root_units(self._units) @@ -674,7 +733,7 @@ class Quantity(PrettyIPython, SharedRegistryObject): return None - def to_root_units(self): + def to_root_units(self) -> Quantity[_MagnitudeType]: """Return Quantity rescaled to root units.""" _, other = self._REGISTRY._get_root_units(self._units) @@ -683,7 +742,7 @@ class Quantity(PrettyIPython, SharedRegistryObject): return self.__class__(magnitude, other) - def ito_base_units(self): + def ito_base_units(self) -> None: """Return Quantity rescaled to base units.""" _, other = self._REGISTRY._get_base_units(self._units) @@ -693,7 +752,7 @@ class Quantity(PrettyIPython, SharedRegistryObject): return None - def to_base_units(self): + def to_base_units(self) -> Quantity[_MagnitudeType]: """Return Quantity rescaled to base units.""" _, other = self._REGISTRY._get_base_units(self._units) @@ -702,7 +761,7 @@ class Quantity(PrettyIPython, SharedRegistryObject): return self.__class__(magnitude, other) - def ito_reduced_units(self): + def ito_reduced_units(self) -> None: """Return Quantity scaled in place to reduced units, i.e. one unit per dimension. This will not reduce compound units (e.g., 'J/kg' will not be reduced to m**2/s**2), nor can it make use of contexts at this time. @@ -730,7 +789,7 @@ class Quantity(PrettyIPython, SharedRegistryObject): return self.ito(newunits) - def to_reduced_units(self): + def to_reduced_units(self) -> Quantity[_MagnitudeType]: """Return Quantity scaled in place to reduced units, i.e. one unit per dimension. This will not reduce compound units (intentionally), nor can it make use of contexts at this time. @@ -741,7 +800,7 @@ class Quantity(PrettyIPython, SharedRegistryObject): newq.ito_reduced_units() return newq - def to_compact(self, unit=None): + def to_compact(self, unit=None) -> Quantity[_MagnitudeType]: """ "Return Quantity rescaled to compact, human-readable units. To get output in terms of a different unit, use the unit parameter. @@ -775,7 +834,7 @@ class Quantity(PrettyIPython, SharedRegistryObject): ): return self - SI_prefixes = {} + SI_prefixes: Dict[int, str] = {} for prefix in self._REGISTRY._prefixes.values(): try: scale = prefix.converter.scale @@ -786,9 +845,9 @@ class Quantity(PrettyIPython, SharedRegistryObject): except Exception: SI_prefixes[0] = "" - SI_prefixes = sorted(SI_prefixes.items()) - SI_powers = [item[0] for item in SI_prefixes] - SI_bases = [item[1] for item in SI_prefixes] + SI_prefixes_list = sorted(SI_prefixes.items()) + SI_powers = [item[0] for item in SI_prefixes_list] + SI_bases = [item[1] for item in SI_prefixes_list] if unit is None: unit = infer_base_unit(self) @@ -817,25 +876,25 @@ class Quantity(PrettyIPython, SharedRegistryObject): if index >= len(SI_bases): index = -1 - prefix = SI_bases[index] + prefix_str = SI_bases[index] - new_unit_str = prefix + unit_str + new_unit_str = prefix_str + unit_str new_unit_container = q_base._units.rename(unit_str, new_unit_str) return self.to(new_unit_container) # Mathematical operations - def __int__(self): + def __int__(self) -> int: if self.dimensionless: return int(self._convert_magnitude_not_inplace(UnitsContainer())) raise DimensionalityError(self._units, "dimensionless") - def __float__(self): + def __float__(self) -> float: if self.dimensionless: return float(self._convert_magnitude_not_inplace(UnitsContainer())) raise DimensionalityError(self._units, "dimensionless") - def __complex__(self): + def __complex__(self) -> complex: if self.dimensionless: return complex(self._convert_magnitude_not_inplace(UnitsContainer())) raise DimensionalityError(self._units, "dimensionless") @@ -1066,6 +1125,14 @@ class Quantity(PrettyIPython, SharedRegistryObject): return self.__class__(magnitude, units) + @overload + def __iadd__(self, other: datetime.datetime) -> datetime.timedelta: # type: ignore[misc] + ... + + @overload + def __iadd__(self, other) -> Quantity[_MagnitudeType]: + ... + def __iadd__(self, other): if isinstance(other, datetime.datetime): return self.to_timedelta() + other @@ -1431,7 +1498,7 @@ class Quantity(PrettyIPython, SharedRegistryObject): return self @check_implemented - def __pow__(self, other): + def __pow__(self, other) -> Quantity[_MagnitudeType]: try: _to_magnitude(other, self.force_ndarray, self.force_ndarray_like) except PintTypeError: @@ -1496,7 +1563,7 @@ class Quantity(PrettyIPython, SharedRegistryObject): return self.__class__(magnitude, units) @check_implemented - def __rpow__(self, other): + def __rpow__(self, other) -> Quantity[_MagnitudeType]: try: _to_magnitude(other, self.force_ndarray, self.force_ndarray_like) except PintTypeError: @@ -1509,16 +1576,16 @@ class Quantity(PrettyIPython, SharedRegistryObject): new_self = self.to_root_units() return other ** new_self._magnitude - def __abs__(self): + def __abs__(self) -> Quantity[_MagnitudeType]: return self.__class__(abs(self._magnitude), self._units) - def __round__(self, ndigits=0): + def __round__(self, ndigits: Optional[int] = 0) -> Quantity[int]: return self.__class__(round(self._magnitude, ndigits=ndigits), self._units) - def __pos__(self): + def __pos__(self) -> Quantity[_MagnitudeType]: return self.__class__(operator.pos(self._magnitude), self._units) - def __neg__(self): + def __neg__(self) -> Quantity[_MagnitudeType]: return self.__class__(operator.neg(self._magnitude), self._units) @check_implemented @@ -1627,7 +1694,7 @@ class Quantity(PrettyIPython, SharedRegistryObject): __ge__ = lambda self, other: self.compare(other, op=operator.ge) __gt__ = lambda self, other: self.compare(other, op=operator.gt) - def __bool__(self): + def __bool__(self) -> bool: # Only cast when non-ambiguous (when multiplicative unit) if self._is_multiplicative: return bool(self._magnitude) @@ -1695,7 +1762,7 @@ class Quantity(PrettyIPython, SharedRegistryObject): else: return value - def __array__(self, t=None): + def __array__(self, t=None) -> np.ndarray: warnings.warn( "The unit of the quantity is stripped when downcasting to ndarray.", UnitStrippedWarning, @@ -1722,11 +1789,11 @@ class Quantity(PrettyIPython, SharedRegistryObject): raise DimensionalityError("dimensionless", self._units) return self.__class__(self.magnitude.clip(min, max, out, **kwargs), self._units) - def fill(self, value): + def fill(self: Quantity[np.ndarray], value) -> None: self._units = value._units return self.magnitude.fill(value.magnitude) - def put(self, indices, values, mode="raise"): + def put(self: Quantity[np.ndarray], indices, values, mode="raise") -> None: if isinstance(values, self.__class__): values = values.to(self).magnitude elif self.dimensionless: @@ -1736,11 +1803,11 @@ class Quantity(PrettyIPython, SharedRegistryObject): self.magnitude.put(indices, values, mode) @property - def real(self): + def real(self) -> Quantity[_MagnitudeType]: return self.__class__(self._magnitude.real, self._units) @property - def imag(self): + def imag(self) -> Quantity[_MagnitudeType]: return self.__class__(self._magnitude.imag, self._units) @property @@ -1753,7 +1820,7 @@ class Quantity(PrettyIPython, SharedRegistryObject): yield self.__class__(v, self._units) @property - def shape(self): + def shape(self) -> Shape: return self._magnitude.shape @shape.setter @@ -1791,10 +1858,10 @@ class Quantity(PrettyIPython, SharedRegistryObject): self.ito(to_units) - def __len__(self): + def __len__(self) -> int: return len(self._magnitude) - def __getattr__(self, item): + def __getattr__(self, item) -> Any: if item.startswith("__array_"): # Handle array protocol attributes other than `__array__` raise AttributeError(f"Array protocol attribute {item} not available.") @@ -1944,7 +2011,7 @@ class Quantity(PrettyIPython, SharedRegistryObject): self._get_unit_definition(d).reference == offset_unit_dim for d in deltas ) - def _ok_for_muldiv(self, no_offset_units=None): + def _ok_for_muldiv(self, no_offset_units=None) -> bool: """Checks if Quantity object can be multiplied or divided""" is_ok = True @@ -1964,7 +2031,7 @@ class Quantity(PrettyIPython, SharedRegistryObject): is_ok = False return is_ok - def to_timedelta(self): + def to_timedelta(self: Quantity[float]) -> datetime.timedelta: return datetime.timedelta(microseconds=self.to("microseconds").magnitude) # Dask.array.Array ducking @@ -2058,7 +2125,7 @@ class Quantity(PrettyIPython, SharedRegistryObject): _Quantity = Quantity -def build_quantity_class(registry): +def build_quantity_class(registry: BaseRegistry) -> Type[Quantity]: class Quantity(_Quantity): _REGISTRY = registry diff --git a/pint/registry.py b/pint/registry.py index 42c75c6..599a049 100644 --- a/pint/registry.py +++ b/pint/registry.py @@ -33,6 +33,8 @@ The module actually defines 5 registries with different capabilities: :license: BSD, see LICENSE for more details. """ +from __future__ import annotations + import copy import functools import importlib.resources @@ -45,10 +47,29 @@ from contextlib import contextmanager from decimal import Decimal from fractions import Fraction from io import StringIO +from numbers import Number from tokenize import NAME, NUMBER +from typing import ( + TYPE_CHECKING, + Any, + Callable, + ContextManager, + Dict, + FrozenSet, + Iterable, + Iterator, + List, + Optional, + Set, + Tuple, + Type, + TypeVar, + Union, +) from . import registry_helpers, systems -from .compat import babel_parse, tokenizer +from ._typing import F, QuantityOrUnitLike +from .compat import HAS_BABEL, babel_parse, tokenizer from .context import Context, ContextChain from .converters import LogarithmicConverter, ScaleConverter from .definitions import ( @@ -65,6 +86,7 @@ from .errors import ( UndefinedUnitError, ) from .pint_eval import build_eval_tree +from .systems import Group, System from .util import ( ParserHelper, SourceIterator, @@ -80,6 +102,21 @@ from .util import ( to_units_container, ) +if TYPE_CHECKING: + from ._typing import UnitLike + from .quantity import Quantity + from .unit import Unit + from .unit import UnitsContainer as UnitsContainerT + + if HAS_BABEL: + import babel + + Locale = babel.Locale + else: + Locale = None + +T = TypeVar("T") + _BLOCK_RE = re.compile(r"[ (]") @@ -110,15 +147,15 @@ class RegistryMeta(type): class RegistryCache: """Cache to speed up unit registries""" - def __init__(self): + def __init__(self) -> None: #: Maps dimensionality (UnitsContainer) to Units (str) - self.dimensional_equivalents = {} + self.dimensional_equivalents: Dict[UnitsContainer, Set[str]] = {} #: Maps dimensionality (UnitsContainer) to Dimensionality (UnitsContainer) self.root_units = {} #: Maps dimensionality (UnitsContainer) to Units (UnitsContainer) - self.dimensionality = {} + self.dimensionality: Dict[UnitsContainer, UnitsContainer] = {} #: Cache the unit name associated to user input. ('mV' -> 'millivolt') - self.parse_unit = {} + self.parse_unit: Dict[str, UnitsContainer] = {} class ContextCacheOverlay: @@ -126,13 +163,17 @@ class ContextCacheOverlay: active contexts which contain unit redefinitions. """ - def __init__(self, registry_cache: RegistryCache): + def __init__(self, registry_cache: RegistryCache) -> None: self.dimensional_equivalents = registry_cache.dimensional_equivalents self.root_units = {} self.dimensionality = registry_cache.dimensionality self.parse_unit = registry_cache.parse_unit +NON_INT_TYPE = Type[Union[float, Decimal, Fraction]] +PreprocessorType = Callable[[str], str] + + class BaseRegistry(metaclass=RegistryMeta): """Base class for all registries. @@ -173,22 +214,22 @@ class BaseRegistry(metaclass=RegistryMeta): #: Map context prefix to function #: type: Dict[str, (SourceIterator -> None)] - _parsers = None + _parsers: Dict[str, Callable[[SourceIterator], None]] = None #: Babel.Locale instance or None - fmt_locale = None + fmt_locale: Optional[Locale] = None def __init__( self, filename="", - force_ndarray=False, - force_ndarray_like=False, - on_redefinition="warn", - auto_reduce_dimensions=False, - preprocessors=None, - fmt_locale=None, - non_int_type=float, - case_sensitive=True, + force_ndarray: bool = False, + force_ndarray_like: bool = False, + on_redefinition: str = "warn", + auto_reduce_dimensions: bool = False, + preprocessors: Optional[List[PreprocessorType]] = None, + fmt_locale: Optional[str] = None, + non_int_type: NON_INT_TYPE = float, + case_sensitive: bool = True, ): self._register_parsers() self._init_dynamic_classes() @@ -215,33 +256,35 @@ class BaseRegistry(metaclass=RegistryMeta): #: Map between name (string) and value (string) of defaults stored in the #: definitions file. - self._defaults = {} + self._defaults: Dict[str, str] = {} #: Map dimension name (string) to its definition (DimensionDefinition). - self._dimensions = {} + self._dimensions: Dict[str, DimensionDefinition] = {} #: Map unit name (string) to its definition (UnitDefinition). #: Might contain prefixed units. - self._units = {} + self._units: Dict[str, UnitDefinition] = {} #: Map unit name in lower case (string) to a set of unit names with the right #: case. #: Does not contain prefixed units. #: e.g: 'hz' - > set('Hz', ) - self._units_casei = defaultdict(set) + self._units_casei: Dict[str, Set[str]] = defaultdict(set) #: Map prefix name (string) to its definition (PrefixDefinition). - self._prefixes = {"": PrefixDefinition("", "", (), 1)} + self._prefixes: Dict[str, PrefixDefinition] = { + "": PrefixDefinition("", "", (), 1) + } #: Map suffix name (string) to canonical , and unit alias to canonical unit name - self._suffixes = {"": "", "s": ""} + self._suffixes: Dict[str, str] = {"": "", "s": ""} #: Map contexts to RegistryCache self._cache = RegistryCache() self._initialized = False - def _init_dynamic_classes(self): + def _init_dynamic_classes(self) -> None: """Generate subclasses on the fly and attach them to self""" from .unit import build_unit_class @@ -249,13 +292,13 @@ class BaseRegistry(metaclass=RegistryMeta): from .quantity import build_quantity_class - self.Quantity = build_quantity_class(self) + self.Quantity: Type["Quantity"] = build_quantity_class(self) from .measurement import build_measurement_class self.Measurement = build_measurement_class(self) - def _after_init(self): + def _after_init(self) -> None: """This should be called after all __init__""" if self._filename == "": @@ -266,17 +309,17 @@ class BaseRegistry(metaclass=RegistryMeta): self._build_cache() self._initialized = True - def _register_parsers(self): + def _register_parsers(self) -> None: self._register_parser("@defaults", self._parse_defaults) - def _parse_defaults(self, ifile): + def _parse_defaults(self, ifile) -> None: """Loader for a @default section.""" next(ifile) for lineno, part in ifile.block_iter(): k, v = part.split("=") self._defaults[k.strip()] = v.strip() - def __deepcopy__(self, memo): + def __deepcopy__(self, memo) -> "BaseRegistry": new = object.__new__(type(self)) new.__dict__ = copy.deepcopy(self.__dict__, memo) new._init_dynamic_classes() @@ -293,7 +336,7 @@ class BaseRegistry(metaclass=RegistryMeta): ) return self.parse_expression(item) - def __contains__(self, item): + def __contains__(self, item) -> bool: """Support checking prefixed units with the `in` operator""" try: self.__getattr__(item) @@ -301,12 +344,12 @@ class BaseRegistry(metaclass=RegistryMeta): except UndefinedUnitError: return False - def __dir__(self): + def __dir__(self) -> List[str]: #: Calling dir(registry) gives all units, methods, and attributes. #: Also used for autocompletion in IPython. return list(self._units.keys()) + list(object.__dir__(self)) - def __iter__(self): + def __iter__(self) -> Iterator[str]: """Allows for listing all units in registry with `list(ureg)`. Returns @@ -315,7 +358,7 @@ class BaseRegistry(metaclass=RegistryMeta): """ return iter(sorted(self._units.keys())) - def set_fmt_locale(self, loc): + def set_fmt_locale(self, loc: Optional[str]) -> None: """Change the locale used by default by `format_babel`. Parameters @@ -332,20 +375,20 @@ class BaseRegistry(metaclass=RegistryMeta): self.fmt_locale = loc - def UnitsContainer(self, *args, **kwargs): + def UnitsContainer(self, *args, **kwargs) -> UnitsContainerT: return UnitsContainer(*args, non_int_type=self.non_int_type, **kwargs) @property - def default_format(self): + def default_format(self) -> str: """Default formatting string for quantities.""" return self.Quantity.default_format @default_format.setter - def default_format(self, value): + def default_format(self, value: str): self.Unit.default_format = value self.Quantity.default_format = value - def define(self, definition): + def define(self, definition: Union[str, Definition]) -> None: """Add unit to the registry. Parameters @@ -360,7 +403,7 @@ class BaseRegistry(metaclass=RegistryMeta): else: self._define(definition) - def _define(self, definition): + def _define(self, definition: Definition) -> Tuple[Definition, dict, dict]: """Add unit to the registry. This method defines only multiplicative units, converting any other type @@ -509,7 +552,7 @@ class BaseRegistry(metaclass=RegistryMeta): else: raise ValueError("Prefix directives must start with '@'") - def load_definitions(self, file, is_resource=False): + def load_definitions(self, file, is_resource: bool = False) -> None: """Add units and prefixes defined in a definition text file. Parameters @@ -584,7 +627,7 @@ class BaseRegistry(metaclass=RegistryMeta): except Exception as ex: logger.error("In line {}, cannot add '{}' {}".format(no, line, ex)) - def _build_cache(self): + def _build_cache(self) -> None: """Build a cache of dimensionality and base units.""" self._cache = RegistryCache() @@ -621,7 +664,9 @@ class BaseRegistry(metaclass=RegistryMeta): except Exception as exc: logger.warning(f"Could not resolve {unit_name}: {exc!r}") - def get_name(self, name_or_alias, case_sensitive=None): + def get_name( + self, name_or_alias: str, case_sensitive: Optional[bool] = None + ) -> str: """Return the canonical name of a unit.""" if name_or_alias == "dimensionless": @@ -659,7 +704,9 @@ class BaseRegistry(metaclass=RegistryMeta): return unit_name - def get_symbol(self, name_or_alias, case_sensitive=None): + def get_symbol( + self, name_or_alias: str, case_sensitive: Optional[bool] = None + ) -> str: """Return the preferred alias for a unit.""" candidates = self.parse_unit_name(name_or_alias, case_sensitive) if not candidates: @@ -675,10 +722,10 @@ class BaseRegistry(metaclass=RegistryMeta): return self._prefixes[prefix].symbol + self._units[unit_name].symbol - def _get_symbol(self, name): + def _get_symbol(self, name: str) -> str: return self._units[name].symbol - def get_dimensionality(self, input_units): + def get_dimensionality(self, input_units) -> UnitsContainerT: """Convert unit or dict of units or dimensions to a dict of base dimensions dimensions """ @@ -689,7 +736,9 @@ class BaseRegistry(metaclass=RegistryMeta): return self._get_dimensionality(input_units) - def _get_dimensionality(self, input_units): + def _get_dimensionality( + self, input_units: Optional[UnitsContainerT] + ) -> UnitsContainerT: """Convert a UnitsContainer to base dimensions.""" if not input_units: return self.UnitsContainer() @@ -757,7 +806,9 @@ class BaseRegistry(metaclass=RegistryMeta): return first return None - def get_root_units(self, input_units, check_nonmult=True): + def get_root_units( + self, input_units: UnitLike, check_nonmult: bool = True + ) -> Tuple[Number, Unit]: """Convert unit or dict of units to the root units. If any unit is non multiplicative and check_converter is True, @@ -868,7 +919,9 @@ class BaseRegistry(metaclass=RegistryMeta): if reg.reference is not None: self._get_root_units_recurse(reg.reference, exp2, accumulators) - def get_compatible_units(self, input_units, group_or_system=None): + def get_compatible_units( + self, input_units, group_or_system=None + ) -> FrozenSet["Unit"]: """ """ input_units = to_units_container(input_units) @@ -884,7 +937,9 @@ class BaseRegistry(metaclass=RegistryMeta): src_dim = self._get_dimensionality(input_units) return self._cache.dimensional_equivalents[src_dim] - def is_compatible_with(self, obj1, obj2, *contexts, **ctx_kwargs): + def is_compatible_with( + self, obj1: Any, obj2: Any, *contexts: Union[str, Context], **ctx_kwargs + ) -> bool: """check if the other object is compatible Parameters @@ -911,7 +966,13 @@ class BaseRegistry(metaclass=RegistryMeta): return not isinstance(obj2, (self.Quantity, self.Unit)) - def convert(self, value, src, dst, inplace=False): + def convert( + self, + value: T, + src: QuantityOrUnitLike, + dst: QuantityOrUnitLike, + inplace: bool = False, + ) -> T: """Convert value from some source to destination units. Parameters @@ -991,7 +1052,9 @@ class BaseRegistry(metaclass=RegistryMeta): return value - def parse_unit_name(self, unit_name, case_sensitive=None): + def parse_unit_name( + self, unit_name: str, case_sensitive: Optional[bool] = None + ) -> Tuple[Tuple[str, str, str], ...]: """Parse a unit to identify prefix, unit name and suffix by walking the list of prefix and suffix. In case of equivalent combinations (e.g. ('kilo', 'gram', '') and @@ -1014,7 +1077,9 @@ class BaseRegistry(metaclass=RegistryMeta): self._parse_unit_name(unit_name, case_sensitive=case_sensitive) ) - def _parse_unit_name(self, unit_name, case_sensitive=None): + def _parse_unit_name( + self, unit_name: str, case_sensitive: Optional[bool] = None + ) -> Iterator[Tuple[str, str, str]]: """Helper of parse_unit_name.""" case_sensitive = ( self.case_sensitive if case_sensitive is None else case_sensitive @@ -1044,7 +1109,9 @@ class BaseRegistry(metaclass=RegistryMeta): ) @staticmethod - def _dedup_candidates(candidates): + def _dedup_candidates( + candidates: Iterable[Tuple[str, str, str]] + ) -> Tuple[Tuple[str, str, str], ...]: """Helper of parse_unit_name. Given an iterable of unit triplets (prefix, name, suffix), remove those with @@ -1062,7 +1129,12 @@ class BaseRegistry(metaclass=RegistryMeta): candidates.pop(("", cp + cu, ""), None) return tuple(candidates) - def parse_units(self, input_string, as_delta=None, case_sensitive=None): + def parse_units( + self, + input_string: str, + as_delta: Optional[bool] = None, + case_sensitive: Optional[bool] = None, + ) -> Unit: """Parse a units expression and returns a UnitContainer with the canonical names. @@ -1080,6 +1152,7 @@ class BaseRegistry(metaclass=RegistryMeta): Returns ------- + pint.Unit """ for p in self.preprocessors: @@ -1159,8 +1232,13 @@ class BaseRegistry(metaclass=RegistryMeta): raise Exception("unknown token type") def parse_pattern( - self, input_string, pattern, case_sensitive=None, use_decimal=False, many=False - ): + self, + input_string: str, + pattern: str, + case_sensitive: Optional[bool] = None, + use_decimal: bool = False, + many: bool = False, + ) -> Union[List[str], str, None]: """Parse a string with a given regex pattern and returns result. Parameters @@ -1215,8 +1293,12 @@ class BaseRegistry(metaclass=RegistryMeta): return results def parse_expression( - self, input_string, case_sensitive=None, use_decimal=False, **values - ): + self, + input_string: str, + case_sensitive: Optional[bool] = None, + use_decimal: bool = False, + **values, + ) -> Quantity: """Parse a mathematical expression including units and return a quantity object. Numerical constants can be specified as keyword arguments and will take precedence @@ -1280,8 +1362,11 @@ class NonMultiplicativeRegistry(BaseRegistry): """ def __init__( - self, default_as_delta=True, autoconvert_offset_to_baseunit=False, **kwargs - ): + self, + default_as_delta: bool = True, + autoconvert_offset_to_baseunit: bool = False, + **kwargs: Any, + ) -> None: super().__init__(**kwargs) #: When performing a multiplication of units, interpret @@ -1292,14 +1377,19 @@ class NonMultiplicativeRegistry(BaseRegistry): # base units on multiplication and division. self.autoconvert_offset_to_baseunit = autoconvert_offset_to_baseunit - def _parse_units(self, input_string, as_delta=None, case_sensitive=None): + def _parse_units( + self, + input_string: str, + as_delta: Optional[bool] = None, + case_sensitive: Optional[bool] = None, + ): """ """ if as_delta is None: as_delta = self.default_as_delta return super()._parse_units(input_string, as_delta, case_sensitive) - def _define(self, definition): + def _define(self, definition: Union[str, Definition]): """Add unit to the registry. In addition to what is done by the BaseRegistry, @@ -1325,7 +1415,7 @@ class NonMultiplicativeRegistry(BaseRegistry): return definition, d, di - def _is_multiplicative(self, u): + def _is_multiplicative(self, u) -> bool: if u in self._units: return self._units[u].is_multiplicative @@ -1473,9 +1563,9 @@ class ContextRegistry(BaseRegistry): - Parse @context directive. """ - def __init__(self, **kwargs): + def __init__(self, **kwargs: Any) -> None: # Map context name (string) or abbreviation to context. - self._contexts = {} + self._contexts: Dict[str, Context] = {} # Stores active contexts. self._active_ctx = ContextChain() # Map context chain to cache @@ -1488,11 +1578,11 @@ class ContextRegistry(BaseRegistry): # Allow contexts to add override layers to the units self._units = ChainMap(self._units) - def _register_parsers(self): + def _register_parsers(self) -> None: super()._register_parsers() self._register_parser("@context", self._parse_context) - def _parse_context(self, ifile): + def _parse_context(self, ifile) -> None: try: self.add_context( Context.from_lines( @@ -1624,7 +1714,9 @@ class ContextRegistry(BaseRegistry): # Write into the context-specific self._units.maps[0] and self._cache.root_units self.define(definition) - def enable_contexts(self, *names_or_contexts, **kwargs) -> None: + def enable_contexts( + self, *names_or_contexts: Union[str, Context], **kwargs + ) -> None: """Enable contexts provided by name or by object. Parameters @@ -1664,10 +1756,10 @@ class ContextRegistry(BaseRegistry): ctx.checked = True # and create a new one with the new defaults. - ctxs = tuple(Context.from_context(ctx, **kwargs) for ctx in ctxs) + contexts = tuple(Context.from_context(ctx, **kwargs) for ctx in ctxs) # Finally we add them to the active context. - self._active_ctx.insert_contexts(*ctxs) + self._active_ctx.insert_contexts(*contexts) self._switch_context_cache_and_units() def disable_contexts(self, n: int = None) -> None: @@ -1682,7 +1774,7 @@ class ContextRegistry(BaseRegistry): self._switch_context_cache_and_units() @contextmanager - def context(self, *names, **kwargs): + def context(self, *names, **kwargs) -> ContextManager[Context]: """Used as a context manager, this function enables to activate a context which is removed after usage. @@ -1739,7 +1831,7 @@ class ContextRegistry(BaseRegistry): # the added contexts are removed from the active one. self.disable_contexts(len(names)) - def with_context(self, name, **kwargs): + def with_context(self, name, **kwargs) -> Callable[[F], F]: """Decorator to wrap a function call in a Pint context. Use it to ensure that a certain context is active when @@ -1858,23 +1950,23 @@ class SystemRegistry(BaseRegistry): #: Map system name to system. #: :type: dict[ str | System] - self._systems = {} + self._systems: Dict[str, System] = {} #: Maps dimensionality (UnitsContainer) to Dimensionality (UnitsContainer) self._base_units_cache = dict() #: Map group name to group. #: :type: dict[ str | Group] - self._groups = {} + self._groups: Dict[str, Group] = {} self._groups["root"] = self.Group("root") self._default_system = system - def _init_dynamic_classes(self): + def _init_dynamic_classes(self) -> None: super()._init_dynamic_classes() self.Group = systems.build_group_class(self) self.System = systems.build_system_class(self) - def _after_init(self): + def _after_init(self) -> None: """Invoked at the end of ``__init__``. - Create default group and add all orphan units to it @@ -1901,20 +1993,20 @@ class SystemRegistry(BaseRegistry): "system", None ) - def _register_parsers(self): + def _register_parsers(self) -> None: super()._register_parsers() self._register_parser("@group", self._parse_group) self._register_parser("@system", self._parse_system) - def _parse_group(self, ifile): + def _parse_group(self, ifile) -> None: self.Group.from_lines(ifile.block_iter(), self.define, self.non_int_type) - def _parse_system(self, ifile): + def _parse_system(self, ifile) -> None: self.System.from_lines( ifile.block_iter(), self.get_root_units, self.non_int_type ) - def get_group(self, name, create_if_needed=True): + def get_group(self, name: str, create_if_needed: bool = True) -> Group: """Return a Group. Parameters @@ -1943,7 +2035,7 @@ class SystemRegistry(BaseRegistry): return systems.Lister(self._systems) @property - def default_system(self): + def default_system(self) -> System: return self._default_system @default_system.setter @@ -1956,7 +2048,7 @@ class SystemRegistry(BaseRegistry): self._default_system = name - def get_system(self, name, create_if_needed=True): + def get_system(self, name: str, create_if_needed: bool = True) -> System: """Return a Group. Parameters @@ -1994,7 +2086,12 @@ class SystemRegistry(BaseRegistry): return definition, d, di - def get_base_units(self, input_units, check_nonmult=True, system=None): + def get_base_units( + self, + input_units: Union[UnitLike, Quantity], + check_nonmult: bool = True, + system: Union[str, System, None] = None, + ) -> Tuple[Number, Unit]: """Convert unit or dict of units to the base units. If any unit is non multiplicative and check_converter is True, @@ -2027,7 +2124,12 @@ class SystemRegistry(BaseRegistry): return f, self.Unit(units) - def _get_base_units(self, input_units, check_nonmult=True, system=None): + def _get_base_units( + self, + input_units: UnitsContainerT, + check_nonmult: bool = True, + system: Union[str, System, None] = None, + ): if system is None: system = self._default_system @@ -2068,7 +2170,7 @@ class SystemRegistry(BaseRegistry): return base_factor, destination_units - def _get_compatible_units(self, input_units, group_or_system): + def _get_compatible_units(self, input_units, group_or_system) -> FrozenSet[Unit]: if group_or_system is None: group_or_system = self._default_system @@ -2126,17 +2228,17 @@ class UnitRegistry(SystemRegistry, ContextRegistry, NonMultiplicativeRegistry): def __init__( self, filename="", - force_ndarray=False, - force_ndarray_like=False, - default_as_delta=True, - autoconvert_offset_to_baseunit=False, - on_redefinition="warn", + force_ndarray: bool = False, + force_ndarray_like: bool = False, + default_as_delta: bool = True, + autoconvert_offset_to_baseunit: bool = False, + on_redefinition: str = "warn", system=None, auto_reduce_dimensions=False, preprocessors=None, fmt_locale=None, non_int_type=float, - case_sensitive=True, + case_sensitive: bool = True, ): super().__init__( @@ -2170,7 +2272,7 @@ class UnitRegistry(SystemRegistry, ContextRegistry, NonMultiplicativeRegistry): """ return pi_theorem(quantities, self) - def setup_matplotlib(self, enable=True): + def setup_matplotlib(self, enable: bool = True) -> None: """Set up handlers for matplotlib's unit support. Parameters diff --git a/pint/registry_helpers.py b/pint/registry_helpers.py index b8888c8..7f6ee7f 100644 --- a/pint/registry_helpers.py +++ b/pint/registry_helpers.py @@ -11,10 +11,19 @@ import functools from inspect import signature from itertools import zip_longest +from typing import TYPE_CHECKING, Callable, Iterable, TypeVar, Union +from ._typing import F from .errors import DimensionalityError +from .quantity import Quantity from .util import UnitsContainer, to_units_container +if TYPE_CHECKING: + from .registry import UnitRegistry + from .unit import Unit + +T = TypeVar("T") + def _replace_units(original_units, values_by_name): """Convert a unit compatible type to a UnitsContainer. @@ -175,7 +184,12 @@ def _apply_defaults(func, args, kwargs): return args, {} -def wraps(ureg, ret, args, strict=True): +def wraps( + ureg: "UnitRegistry", + ret: Union[str, "Unit", Iterable[str], Iterable["Unit"], None], + args: Union[str, "Unit", Iterable[str], Iterable["Unit"], None], + strict: bool = True, +) -> Callable[[Callable[..., T]], Callable[..., Quantity[T]]]: """Wraps a function to become pint-aware. Use it when a function requires a numerical value but in some specific @@ -239,7 +253,7 @@ def wraps(ureg, ret, args, strict=True): ) ret = _to_units_container(ret, ureg) - def decorator(func): + def decorator(func: Callable[..., T]) -> Callable[..., Quantity[T]]: count_params = len(signature(func).parameters) if len(args) != count_params: @@ -256,7 +270,7 @@ def wraps(ureg, ret, args, strict=True): ) @functools.wraps(func, assigned=assigned, updated=updated) - def wrapper(*values, **kw): + def wrapper(*values, **kw) -> Quantity[T]: values, kw = _apply_defaults(func, values, kw) @@ -288,7 +302,9 @@ def wraps(ureg, ret, args, strict=True): return decorator -def check(ureg, *args): +def check( + ureg: "UnitRegistry", *args: Union[str, UnitsContainer, "Unit", None] +) -> Callable[[F], F]: """Decorator to for quantity type checking for function inputs. Use it to ensure that the decorated function input parameters match diff --git a/pint/unit.py b/pint/unit.py index 5208eab..f91b6c1 100644 --- a/pint/unit.py +++ b/pint/unit.py @@ -8,23 +8,30 @@ :license: BSD, see LICENSE for more details. """ +from __future__ import annotations + import copy import locale import operator from numbers import Number +from typing import TYPE_CHECKING, Any, Type, Union +from ._typing import UnitLike from .compat import NUMERIC_TYPES, babel_parse, is_upcast_type from .definitions import UnitDefinition from .errors import DimensionalityError from .formatting import siunitx_format_unit from .util import PrettyIPython, SharedRegistryObject, UnitsContainer +if TYPE_CHECKING: + from .context import Context + class Unit(PrettyIPython, SharedRegistryObject): """Implements a class to describe a unit supporting math operations.""" #: Default formatting string. - default_format = "" + default_format: str = "" def __reduce__(self): # See notes in Quantity.__reduce__ @@ -32,7 +39,7 @@ class Unit(PrettyIPython, SharedRegistryObject): return _unpickle_unit, (Unit, self._units) - def __init__(self, units): + def __init__(self, units: UnitLike) -> None: super().__init__() if isinstance(units, (UnitsContainer, UnitDefinition)): self._units = units @@ -50,29 +57,29 @@ class Unit(PrettyIPython, SharedRegistryObject): self.__handling = None @property - def debug_used(self): + def debug_used(self) -> Any: return self.__used - def __copy__(self): + def __copy__(self) -> Unit: ret = self.__class__(self._units) ret.__used = self.__used return ret - def __deepcopy__(self, memo): + def __deepcopy__(self, memo) -> Unit: ret = self.__class__(copy.deepcopy(self._units, memo)) ret.__used = self.__used return ret - def __str__(self): + def __str__(self) -> str: return format(self) - def __bytes__(self): + def __bytes__(self) -> bytes: return str(self).encode(locale.getpreferredencoding()) - def __repr__(self): + def __repr__(self) -> str: return "<Unit('{}')>".format(self._units) - def __format__(self, spec): + def __format__(self, spec) -> str: spec = spec or self.default_format # special cases if "Lx" in spec: # the LaTeX siunitx code @@ -93,7 +100,7 @@ class Unit(PrettyIPython, SharedRegistryObject): return format(units, spec) - def format_babel(self, spec="", locale=None, **kwspec): + def format_babel(self, spec="", locale=None, **kwspec: Any) -> str: spec = spec or self.default_format if "~" in spec: @@ -119,12 +126,12 @@ class Unit(PrettyIPython, SharedRegistryObject): return units.format_babel(spec, **kwspec) @property - def dimensionless(self): + def dimensionless(self) -> bool: """Return True if the Unit is dimensionless; False otherwise.""" return not bool(self.dimensionality) @property - def dimensionality(self): + def dimensionality(self) -> UnitsContainer: """ Returns ------- @@ -146,7 +153,9 @@ class Unit(PrettyIPython, SharedRegistryObject): return self._REGISTRY.get_compatible_units(self) - def is_compatible_with(self, other, *contexts, **ctx_kwargs): + def is_compatible_with( + self, other: Any, *contexts: Union[str, Context], **ctx_kwargs: Any + ) -> bool: """check if the other object is compatible Parameters @@ -218,7 +227,7 @@ class Unit(PrettyIPython, SharedRegistryObject): __div__ = __truediv__ __rdiv__ = __rtruediv__ - def __pow__(self, other): + def __pow__(self, other) -> "Unit": if isinstance(other, NUMERIC_TYPES): return self.__class__(self._units ** other) @@ -226,10 +235,10 @@ class Unit(PrettyIPython, SharedRegistryObject): mess = "Cannot power Unit by {}".format(type(other)) raise TypeError(mess) - def __hash__(self): + def __hash__(self) -> int: return self._units.__hash__() - def __eq__(self, other): + def __eq__(self, other) -> bool: # We compare to the base class of Unit because each Unit class is # unique. if self._check(other): @@ -244,10 +253,10 @@ class Unit(PrettyIPython, SharedRegistryObject): else: return self._units == other - def __ne__(self, other): + def __ne__(self, other) -> bool: return not (self == other) - def compare(self, other, op): + def compare(self, other, op) -> bool: self_q = self._REGISTRY.Quantity(1, self) if isinstance(other, NUMERIC_TYPES): @@ -262,13 +271,13 @@ class Unit(PrettyIPython, SharedRegistryObject): __ge__ = lambda self, other: self.compare(other, op=operator.ge) __gt__ = lambda self, other: self.compare(other, op=operator.gt) - def __int__(self): + def __int__(self) -> int: return int(self._REGISTRY.Quantity(1, self._units)) - def __float__(self): + def __float__(self) -> float: return float(self._REGISTRY.Quantity(1, self._units)) - def __complex__(self): + def __complex__(self) -> complex: return complex(self._REGISTRY.Quantity(1, self._units)) __array_priority__ = 17 @@ -361,7 +370,7 @@ class Unit(PrettyIPython, SharedRegistryObject): _Unit = Unit -def build_unit_class(registry): +def build_unit_class(registry) -> Type[Unit]: class Unit(_Unit): _REGISTRY = registry diff --git a/pint/util.py b/pint/util.py index f2162f4..18474ec 100644 --- a/pint/util.py +++ b/pint/util.py @@ -8,6 +8,8 @@ :license: BSD, see LICENSE for more details. """ +from __future__ import annotations + import logging import math import operator @@ -18,12 +20,19 @@ from functools import lru_cache, partial from logging import NullHandler from numbers import Number from token import NAME, NUMBER +from typing import TYPE_CHECKING, ClassVar, Optional, Union from .compat import NUMERIC_TYPES, tokenizer from .errors import DefinitionSyntaxError from .formatting import format_unit from .pint_eval import build_eval_tree +if TYPE_CHECKING: + from ._typing import UnitLike + from .quantity import Quantity + from .registry import BaseRegistry + + logger = logging.getLogger(__name__) logger.addHandler(NullHandler()) @@ -321,7 +330,7 @@ class UnitsContainer(Mapping): __slots__ = ("_d", "_hash", "_one", "_non_int_type") - def __init__(self, *args, **kwargs): + def __init__(self, *args, **kwargs) -> None: if args and isinstance(args[0], UnitsContainer): default_non_int_type = args[0]._non_int_type else: @@ -398,7 +407,7 @@ class UnitsContainer(Mapping): def __iter__(self): return iter(self._d) - def __len__(self): + def __len__(self) -> int: return len(self._d) def __getitem__(self, key): @@ -419,7 +428,7 @@ class UnitsContainer(Mapping): def __setstate__(self, state): self._d, self._hash, self._one, self._non_int_type = state - def __eq__(self, other): + def __eq__(self, other) -> bool: if isinstance(other, UnitsContainer): # UnitsContainer.__hash__(self) is not the same as hash(self); see # ParserHelper.__hash__ and __eq__. @@ -440,19 +449,19 @@ class UnitsContainer(Mapping): return dict.__eq__(self._d, other) - def __str__(self): + def __str__(self) -> str: return self.__format__("") - def __repr__(self): + def __repr__(self) -> str: tmp = "{%s}" % ", ".join( ["'{}': {}".format(key, value) for key, value in sorted(self._d.items())] ) return "<UnitsContainer({})>".format(tmp) - def __format__(self, spec): + def __format__(self, spec: str) -> str: return format_unit(self, spec) - def format_babel(self, spec, **kwspec): + def format_babel(self, spec: str, **kwspec) -> str: return format_unit(self, spec, **kwspec) def __copy__(self): @@ -742,7 +751,7 @@ class ParserHelper(UnitsContainer): #: List of regex substitution pairs. -_subs_re = [ +_subs_re_list = [ ("\N{DEGREE SIGN}", " degree"), (r"([\w\.\-\+\*\\\^])\s+", r"\1 "), # merge multiple spaces (r"({}) squared", r"\1**2"), # Handle square and cube @@ -758,12 +767,14 @@ _subs_re = [ ] #: Compiles the regex and replace {} by a regex that matches an identifier. -_subs_re = [(re.compile(a.format(r"[_a-zA-Z][_a-zA-Z0-9]*")), b) for a, b in _subs_re] +_subs_re = [ + (re.compile(a.format(r"[_a-zA-Z][_a-zA-Z0-9]*")), b) for a, b in _subs_re_list +] _pretty_table = str.maketrans("⁰¹²³⁴⁵⁶⁷⁸⁹·⁻", "0123456789*-") _pretty_exp_re = re.compile(r"⁻?[⁰¹²³⁴⁵⁶⁷⁸⁹]+(?:\.[⁰¹²³⁴⁵⁶⁷⁸⁹]*)?") -def string_preprocessor(input_string): +def string_preprocessor(input_string: str) -> str: input_string = input_string.replace(",", "") input_string = input_string.replace(" per ", "/") @@ -781,7 +792,7 @@ def string_preprocessor(input_string): return input_string -def _is_dim(name): +def _is_dim(name: str) -> bool: return name[0] == "[" and name[-1] == "]" @@ -799,6 +810,9 @@ class SharedRegistryObject: """ + _REGISTRY: ClassVar[BaseRegistry] + _units: UnitsContainer + def __new__(cls, *args, **kwargs): inst = object.__new__(cls) if not hasattr(cls, "_REGISTRY"): @@ -809,7 +823,7 @@ class SharedRegistryObject: inst._REGISTRY = _APP_REGISTRY return inst - def _check(self, other): + def _check(self, other) -> bool: """Check if the other object use a registry and if so that it is the same registry. @@ -840,6 +854,8 @@ class SharedRegistryObject: class PrettyIPython: """Mixin to add pretty-printers for IPython""" + default_format: str + def _repr_html_(self): if "~" in self.default_format: return "{:~H}".format(self) @@ -859,7 +875,9 @@ class PrettyIPython: p.text("{:P}".format(self)) -def to_units_container(unit_like, registry=None): +def to_units_container( + unit_like: Union[UnitLike, Quantity], registry: Optional[BaseRegistry] = None +) -> UnitsContainer: """Convert a unit compatible type to a UnitsContainer. Parameters @@ -1014,7 +1032,7 @@ class BlockIterator(SourceIterator): next = __next__ -def iterable(y): +def iterable(y) -> bool: """Check whether or not an object can be iterated over. Vendored from numpy under the terms of the BSD 3-Clause License. (Copyright @@ -1036,7 +1054,7 @@ def iterable(y): return True -def sized(y): +def sized(y) -> bool: """Check whether or not an object has a defined length. Parameters @@ -44,7 +44,7 @@ test = pytest-subtests [options.package_data] -pint = default_en.txt; constants_en.txt +pint = default_en.txt; constants_en.txt; py.typed [check-manifest] ignore = |