summaryrefslogtreecommitdiff
path: root/pint
diff options
context:
space:
mode:
authorJon Thielen <github@jont.cc>2019-12-10 22:40:42 -0600
committerJon Thielen <github@jont.cc>2019-12-10 22:40:42 -0600
commite0d6686d08aa9447270d0ba4139e7d81824482a3 (patch)
tree109317a7b64dcf21fa5d761ef933d772f5216281 /pint
parent8707693c0418fb7c49e85d872f02b3021f057cae (diff)
downloadpint-e0d6686d08aa9447270d0ba4139e7d81824482a3.tar.gz
NumPy function util cleanup based on feedback
Diffstat (limited to 'pint')
-rw-r--r--pint/numpy_func.py63
1 files changed, 32 insertions, 31 deletions
diff --git a/pint/numpy_func.py b/pint/numpy_func.py
index beafd4f..f7ecf7d 100644
--- a/pint/numpy_func.py
+++ b/pint/numpy_func.py
@@ -8,6 +8,7 @@
"""
from inspect import signature
+from itertools import chain
import warnings
from .compat import NP_NO_VALUE, is_upcast_type, np, eq
@@ -36,16 +37,34 @@ def _is_quantity_sequence(arg):
def _get_first_input_units(args, kwargs={}):
"""Obtain the first valid unit from a collection of args and kwargs."""
- args_combo = list(args) + list(kwargs.values())
- out_units=None
- for arg in args_combo:
+ for arg in chain(args, kwargs.values()):
if _is_quantity(arg):
- out_units = arg.units
+ return arg.units
elif _is_quantity_sequence(arg):
- out_units = arg[0].units
- if out_units is not None:
- break
- return out_units
+ return arg[0].units
+
+
+def convert_arg(arg, pre_calc_units):
+ """Convert quantities and sequences of quantities to pre_calc_units and strip units.
+
+ Helper function for convert_to_consistent_units.
+ """
+ if pre_calc_units is not None:
+ if _is_quantity(arg):
+ return arg.m_as(pre_calc_units)
+ elif _is_quantity_sequence(arg):
+ return [item.m_as(pre_calc_units) for item in arg]
+ elif arg is not None:
+ if pre_calc_units.dimensionless:
+ return pre_calc_units._REGISTRY.Quantity(arg).m_as(pre_calc_units)
+ else:
+ raise DimensionalityError('dimensionless', pre_calc_units)
+ else:
+ if _is_quantity(arg):
+ return arg.m
+ elif _is_quantity_sequence(arg):
+ return [item.m for item in arg]
+ return arg
def convert_to_consistent_units(*args, pre_calc_units=None, **kwargs):
@@ -56,27 +75,9 @@ def convert_to_consistent_units(*args, pre_calc_units=None, **kwargs):
Quantities and returns the magnitudes. Other args/kwargs are treated as dimensionless
Quantities. If pre_calc_units is None, units are simply stripped.
"""
- def convert_arg(arg):
- if pre_calc_units is not None:
- if _is_quantity(arg):
- return arg.m_as(pre_calc_units)
- elif _is_quantity_sequence(arg):
- return [item.m_as(pre_calc_units) for item in arg]
- elif arg is not None:
- if pre_calc_units.dimensionless:
- return pre_calc_units._REGISTRY.Quantity(arg).m_as(pre_calc_units)
- else:
- raise DimensionalityError('dimensionless', pre_calc_units)
- else:
- if _is_quantity(arg):
- return arg.m
- elif _is_quantity_sequence(arg):
- return [item.m for item in arg]
- return arg
-
- new_args = tuple(convert_arg(arg) for arg in args)
- new_kwargs = {key: convert_arg(arg) for key, arg in kwargs.items()}
- return new_args, new_kwargs
+ return (tuple(convert_arg(arg, pre_calc_units=pre_calc_units) for arg in args),
+ {key: convert_arg(arg, pre_calc_units=pre_calc_units)
+ for key, arg in kwargs.items()})
def unwrap_and_wrap_consistent_units(*args):
@@ -174,7 +175,6 @@ def implement_func(func_type, func_str, input_units=None, output_unit=None):
@implements(func_str, func_type)
def implementation(*args, **kwargs):
- args_and_kwargs = list(args) + list(kwargs.values())
first_input_units = _get_first_input_units(args, kwargs)
if input_units == "all_consistent":
# Match all input args/kwargs to same units
@@ -196,7 +196,8 @@ def implement_func(func_type, func_str, input_units=None, output_unit=None):
result_unit = first_input_units
elif output_unit in ['sum', 'mul', 'delta', 'delta,div', 'div', 'variance', 'square',
'sqrt', 'reciprocal', 'size']:
- result_unit = get_op_output_unit(output_unit, first_input_units, args_and_kwargs)
+ result_unit = get_op_output_unit(output_unit, first_input_units,
+ tuple(chain(args, kwargs.values())))
else:
result_unit = output_unit