diff options
Diffstat (limited to 'pint/numpy_func.py')
-rw-r--r-- | pint/numpy_func.py | 59 |
1 files changed, 31 insertions, 28 deletions
diff --git a/pint/numpy_func.py b/pint/numpy_func.py index 52d791d..9ebb4df 100644 --- a/pint/numpy_func.py +++ b/pint/numpy_func.py @@ -11,8 +11,8 @@ from inspect import signature from itertools import chain import warnings -from .compat import NP_NO_VALUE, is_upcast_type, np, eq -from .errors import DimensionalityError, UnitStrippedWarning +from .compat import is_upcast_type, np, eq +from .errors import DimensionalityError from .util import iterable, sized HANDLED_UFUNCS = {} @@ -92,8 +92,8 @@ def convert_to_consistent_units(*args, pre_calc_units=None, **kwargs): def unwrap_and_wrap_consistent_units(*args): """Strip units from args while providing a rewrapping function. - Returns the given args as parsed by convert_to_consistent_units assuming units of first - arg with units, along with a wrapper to restore that unit to the output. + Returns the given args as parsed by convert_to_consistent_units assuming units of + first arg with units, along with a wrapper to restore that unit to the output. """ first_input_units = _get_first_input_units(args) args, _ = convert_to_consistent_units(*args, pre_calc_units=first_input_units) @@ -167,7 +167,9 @@ def get_op_output_unit(unit_op, first_input_units, all_args=[], size=None): def implements(numpy_func_string, func_type): - """Register an __array_function__/__array_ufunc__ implementation for Quantity objects.""" + """Register an __array_function__/__array_ufunc__ implementation for Quantity + objects. + """ def decorator(func): if func_type == "function": @@ -198,8 +200,8 @@ def implement_func(func_type, func_str, input_units=None, output_unit=None): *args, pre_calc_units=first_input_units, **kwargs ) else: - # Match all input args/kwargs to input_units, or if input_units is None, simply - # strip units + # Match all input args/kwargs to input_units, or if input_units is None, + # simply strip units stripped_args, stripped_kwargs = convert_to_consistent_units( *args, pre_calc_units=input_units, **kwargs ) @@ -237,19 +239,19 @@ def implement_func(func_type, func_str, input_units=None, output_unit=None): Define ufunc behavior collections. - `strip_unit_input_output_ufuncs`: units should be ignored on both input and output -- `matching_input_bare_output_ufuncs`: inputs are converted to matching units, but outputs are - returned as-is -- `matching_input_set_units_output_ufuncs`: inputs are converted to matching units, and the - output units are as set by the dict value -- `set_units_ufuncs`: dict values are specified as (in_unit, out_unit), so that inputs are - converted to in_unit before having magnitude passed to NumPy ufunc, and outputs are set to - have out_unit +- `matching_input_bare_output_ufuncs`: inputs are converted to matching units, but + outputs are returned as-is +- `matching_input_set_units_output_ufuncs`: inputs are converted to matching units, and + the output units are as set by the dict value +- `set_units_ufuncs`: dict values are specified as (in_unit, out_unit), so that inputs + are converted to in_unit before having magnitude passed to NumPy ufunc, and outputs + are set to have out_unit - `matching_input_copy_units_output_ufuncs`: inputs are converted to matching units, and outputs are set to that unit -- `copy_units_output_ufuncs`: input units (except the first) are ignored, and output is set to - that of the first input unit -- `op_units_output_ufuncs`: determine output unit from input unit as determined by operation - (see `get_op_output_unit`) +- `copy_units_output_ufuncs`: input units (except the first) are ignored, and output is + set to that of the first input unit +- `op_units_output_ufuncs`: determine output unit from input unit as determined by + operation (see `get_op_output_unit`) """ strip_unit_input_output_ufuncs = ["isnan", "isinf", "isfinite", "signbit"] matching_input_bare_output_ufuncs = [ @@ -289,9 +291,10 @@ set_units_ufuncs = { "logaddexp": ("", ""), "logaddexp2": ("", ""), } -# TODO (#905 follow-up): while this matches previous behavior, some of these have optional -# arguments that should not be Quantities. This should be fixed, and tests using these -# optional arguments should be added. +# TODO (#905 follow-up): +# while this matches previous behavior, some of these have optional arguments that +# should not be Quantities. This should be fixed, and tests using these optional +# arguments should be added. matching_input_copy_units_output_ufuncs = [ "compress", "conj", @@ -398,9 +401,9 @@ def _power(x1, x2): def _add_subtract_handle_non_quantity_zero(x1, x2): - # As in #121/#122, if a value is 0 (but not Quantity 0) do the operation without checking - # units. We do the calculation instead of just returning the same value to enforce any - # shape checking and type casting due to the operation. + # As in #121/#122, if a value is 0 (but not Quantity 0) do the operation without + # checking units. We do the calculation instead of just returning the same value to + # enforce any shape checking and type casting due to the operation. if eq(x1, 0, True): (x2,), output_wrap = unwrap_and_wrap_consistent_units(x2) elif eq(x2, 0, True): @@ -546,8 +549,8 @@ def _isin(element, test_elements, assume_unique=False, invert=False): try: compatible_test_elements.append(test_element.m_as(element.units)) except DimensionalityError: - # Incompatible unit test elements cannot be in element, but others in sequence - # may + # Incompatible unit test elements cannot be in element, but others in + # sequence may pass test_elements = compatible_test_elements else: @@ -679,8 +682,8 @@ for func_str in ["block", "hstack", "vstack", "dstack", "column_stack"]: "function", func_str, input_units="all_consistent", output_unit="match_input" ) -# Handle cumulative products (which must be dimensionless for consistent units across output -# array) +# Handle cumulative products (which must be dimensionless for consistent units across +# output array) for func_str in ["cumprod", "cumproduct", "nancumprod"]: implement_func( "function", func_str, input_units="dimensionless", output_unit="match_input" |