From 66f67d3ee4344cf089bb30b5323113f19866a158 Mon Sep 17 00:00:00 2001 From: Jon Thielen Date: Wed, 30 Oct 2019 18:13:11 -0500 Subject: Add __array_function__ based on changes by @andrewgsavage and @jthielen --- pint/quantity.py | 201 +++++++++++++++++++++++++++++++++++++++- pint/testsuite/test_numpy.py | 44 +++++---- pint/testsuite/test_quantity.py | 4 +- 3 files changed, 224 insertions(+), 25 deletions(-) (limited to 'pint') diff --git a/pint/quantity.py b/pint/quantity.py index f0bdd5f..53ad459 100644 --- a/pint/quantity.py +++ b/pint/quantity.py @@ -26,10 +26,9 @@ from .errors import (DimensionalityError, OffsetUnitCalculusError, PintTypeError from .definitions import UnitDefinition from .compat import ndarray, np, _to_magnitude from .util import (PrettyIPython, logger, UnitsContainer, SharedRegistryObject, - to_units_container, infer_base_unit) + to_units_container, infer_base_unit, iterable, sized) from pint.compat import Loc - def _eq(first, second, check_all): """Comparison of scalars and arrays """ @@ -63,11 +62,10 @@ def ireduce_dimensions(f): return result return wrapped - def check_implemented(f): def wrapped(self, *args, **kwargs): other=args[0] - if other.__class__.__name__ in ["PintArray", "Series"]: + if other.__class__.__name__ in ["PintArray", "Series", "DataArray"]: return NotImplemented # pandas often gets to arrays of quantities [ Q_(1,"m"), Q_(2,"m")] # and expects Quantity * array[Quantity] should return NotImplemented @@ -77,6 +75,192 @@ def check_implemented(f): return result return wrapped +HANDLED_FUNCTIONS = {} + +def implements(numpy_function): + """Register an __array_function__ implementation for BaseQuantity objects.""" + def decorator(func): + HANDLED_FUNCTIONS[numpy_function] = func + return func + return decorator + +def _is_quantity_sequence(arg): + if iterable(arg) and sized(arg) and not isinstance(arg, string_types): + if isinstance(arg[0], Quantity): + if not all([isinstance(item, Quantity) for item in arg]): + raise TypeError("{} contains items that aren't Quantity type".format(arg)) + return True + return False + +def _get_first_input_units(args, kwargs={}): + args_combo = list(args)+list(kwargs.values()) + out_units=None + for arg in args_combo: + if isinstance(arg, Quantity): + out_units = arg.units + elif _is_quantity_sequence(arg): + out_units = arg[0].units + if out_units is not None: + break + return out_units + +def convert_to_consistent_units(pre_calc_units=None, *args, **kwargs): + """Takes the args for a numpy function and converts any Quantity or Sequence of Quantities + into the units of the first Quantiy/Sequence of quantities. Other args are left untouched. + """ + def convert_arg(arg): + if pre_calc_units is not None: + if isinstance(arg, Quantity): + return arg.m_as(pre_calc_units) + elif _is_quantity_sequence(arg): + return [item.m_as(pre_calc_units) for item in arg] + else: + if isinstance(arg, Quantity): + return arg.m + elif _is_quantity_sequence(arg): + return [item.m for item in arg] + return arg + + new_args=tuple(convert_arg(arg) for arg in args) + new_kwargs = {key:convert_arg(arg) for key,arg in kwargs.items()} + return new_args, new_kwargs + +def implement_func(func_str, pre_calc_units_, post_calc_units_, out_units_): + """ + :param func_str: The numpy function to implement + :type func_str: str + :param pre_calc_units: The units any quantity/ sequences of quantities should be converted to. + consistent_infer converts all qs to the first units found in args/kwargs + inconsistent does not convert any qs, eg for product + rad (or any other unit) converts qs to radians/ other unit + None converts qs to magnitudes without conversion + :type pre_calc_units: NoneType, str + :param pre_calc_units: The units the result of the function should be initiated as. + as_pre_calc uses the units it was converted to pre calc. Do not use with pre_calc_units="inconsistent" + rad (or any other unit) uses radians/ other unit + prod uses multiplies the input quantity units + None causes func to return without creating a quantity from the output, regardless of any out_units + :type out_units: NoneType, str + :param out_units: The units the result of the function should be returned to the user as. The quantity created in the post_calc_units will be converted to the out_units + None or as_post_calc uses the units the quantity was initiated in, ie the post_calc_units, without any conversion. + rad (or any other unit) uses radians/ other unit + infer_from_input uses the first input units found, as received by the function before any conversions. + :type out_units: NoneType, str + + """ + func = getattr(np,func_str) + + @implements(func) + def _(*args, **kwargs): + # TODO make work for kwargs + args_and_kwargs = list(args)+list(kwargs.values()) + + (pre_calc_units, post_calc_units, out_units)=(pre_calc_units_, post_calc_units_, out_units_) + first_input_units=_get_first_input_units(args, kwargs) + if pre_calc_units == "consistent_infer": + pre_calc_units = first_input_units + + if pre_calc_units == "inconsistent": + new_args, new_kwargs = args, kwargs + else: + new_args, new_kwargs = convert_to_consistent_units(pre_calc_units, *args, **kwargs) + res = func(*new_args, **new_kwargs) + + if post_calc_units is None: + return res + elif post_calc_units == "as_pre_calc": + post_calc_units = pre_calc_units + elif post_calc_units == "sum": + post_calc_units = (1*first_input_units + 1*first_input_units).units + elif post_calc_units == "prod": + product = 1 + for x in args_and_kwargs: + product *= x + post_calc_units = product.units + elif post_calc_units == "div": + product = first_input_units*first_input_units + for x in args_and_kwargs: + product /= x + post_calc_units = product.units + elif post_calc_units == "delta": + post_calc_units = (1*first_input_units-1*first_input_units).units + elif post_calc_units == "delta,div": + product=(1*first_input_units-1*first_input_units).units + for x in args_and_kwargs[1:]: + product /= x + post_calc_units = product.units + elif post_calc_units == "variance": + post_calc_units = ((1*first_input_units + 1*first_input_units)**2).units + Q_ = first_input_units._REGISTRY.Quantity + post_calc_Q_= Q_(res, post_calc_units) + + if out_units is None or out_units == "as_post_calc": + return post_calc_Q_ + elif out_units == "infer_from_input": + out_units = first_input_units + return post_calc_Q_.to(out_units) + +@implements(np.meshgrid) +def _meshgrid(*xi, **kwargs): + # Simply need to map input units to onto list of outputs + input_units = (x.units for x in xi) + res = np.meshgrid(*(x.m for x in xi), **kwargs) + return [out * unit for out, unit in zip(res, input_units)] + +@implements(np.full_like) +def _full_like(a, fill_value, dtype=None, order='K', subok=True, shape=None): + # Make full_like by multiplying with array from ones_like in a + # non-multiplicative-unit-safe way + if isinstance(fill_value, Quantity): + return fill_value._REGISTRY.Quantity( + np.ones_like(a, dtype=dtype, order=order, subok=subok, shape=shape) * fill_value.m, + fill_value.units) + else: + return (np.ones_like(a, dtype=dtype, order=order, subok=subok, shape=shape) + * fill_value) + +@implements(np.interp) +def _interp(x, xp, fp, left=None, right=None, period=None): + # Need to handle x and y units separately + x_unit = _get_first_input_units([x, xp, period]) + y_unit = _get_first_input_units([fp, left, right]) + x_args, _ = convert_to_consistent_units(x_unit, x, xp, period) + y_args, _ = convert_to_consistent_units(y_unit, fp, left, right) + x, xp, period = x_args + fp, right, left = y_args + Q_ = y_unit._REGISTRY.Quantity + return Q_(np.interp(x, xp, fp, left=left, right=right, period=period), y_unit) + +for func_str in ['linspace', 'concatenate', 'block', 'stack', 'hstack', 'vstack', 'dstack', 'atleast_1d', 'column_stack', 'atleast_2d', 'atleast_3d', 'expand_dims','squeeze', 'swapaxes', 'compress', 'rollaxis', 'broadcast_to', 'moveaxis', 'fix', 'amax', 'amin', 'nanmax', 'nanmin', 'around', 'diagonal', 'mean', 'ptp', 'ravel', 'round_', 'sort', 'median', 'nanmedian', 'transpose', 'flip', 'copy', 'trim_zeros', 'append', 'clip', 'nan_to_num']: + implement_func(func_str, 'consistent_infer', 'as_pre_calc', 'as_post_calc') + +for func_str in ['isclose', 'searchsorted']: + implement_func(func_str, 'consistent_infer', None, None) + +for func_str in ['unwrap']: + implement_func(func_str, 'rad', 'rad', 'infer_from_input') + +for func_str in ['cumprod', 'cumproduct', 'nancumprod']: + implement_func(func_str, 'dimensionless', 'dimensionless', 'infer_from_input') + +for func_str in ['size', 'isreal', 'iscomplex', 'shape', 'ones_like', 'zeros_like', 'empty_like', 'argsort', 'argmin', 'argmax', 'alen', 'ndim', 'nanargmax', 'nanargmin', 'count_nonzero', 'nonzero', 'result_type']: + implement_func(func_str, None, None, None) + +for func_str in ['average', 'mean', 'std', 'nanmean', 'nanstd', 'sum', 'nansum', 'cumsum', 'nancumsum']: + implement_func(func_str, None, 'sum', None) + +for func_str in ['cross', 'trapz', 'dot']: + implement_func(func_str, None, 'prod', None) + +for func_str in ['diff', 'ediff1d',]: + implement_func(func_str, None, 'delta', None) + +for func_str in ['gradient', ]: + implement_func(func_str, None, 'delta,div', None) + +for func_str in ['var', 'nanvar']: + implement_func(func_str, None, 'variance', None) + @contextlib.contextmanager def printoptions(*args, **kwargs): @@ -102,6 +286,12 @@ class Quantity(PrettyIPython, SharedRegistryObject): :param units: units of the physical quantity to be created :type units: UnitsContainer, str or pint.Quantity """ + def __array_function__(self, func, types, args, kwargs): + if func not in HANDLED_FUNCTIONS: + return NotImplemented + if not all(issubclass(t, Quantity) for t in types): + return NotImplemented + return HANDLED_FUNCTIONS[func](*args, **kwargs) #: Default formatting string. default_format = '' @@ -1469,7 +1659,8 @@ class Quantity(PrettyIPython, SharedRegistryObject): # Attributes starting with `__array_` are common attributes of NumPy ndarray. # They are requested by numpy functions. if item.startswith('__array_'): - warnings.warn("The unit of the quantity is stripped.", UnitStrippedWarning, stacklevel=2) + warnings.warn("The unit of the quantity is stripped when getting {} " + "attribute".format(item), UnitStrippedWarning, stacklevel=2) if isinstance(self._magnitude, ndarray): return getattr(self._magnitude, item) else: diff --git a/pint/testsuite/test_numpy.py b/pint/testsuite/test_numpy.py index 9294fff..e32afa5 100644 --- a/pint/testsuite/test_numpy.py +++ b/pint/testsuite/test_numpy.py @@ -30,18 +30,24 @@ class TestNumpyMethods(QuantityTestCase): @property def q_temperature(self): return self.Q_([[1,2],[3,4]], self.ureg.degC) - + + def assertNDArrayEqual(self, actual, desired): + # Assert that the given arrays are equal, and are not Quantities + np.testing.assert_array_equal(actual, desired) + self.assertFalse(isinstance(actual, self.Q_)) + self.assertFalse(isinstance(desired, self.Q_)) + class TestNumpyArrayCreation(TestNumpyMethods): # https://docs.scipy.org/doc/numpy/reference/routines.array-creation.html @helpers.requires_array_function_protocol() def test_ones_like(self): - np.testing.assert_equal(np.ones_like(self.q), np.array([[1, 1], [1, 1]])) + self.assertNDArrayEqual(np.ones_like(self.q), np.array([[1, 1], [1, 1]])) @helpers.requires_array_function_protocol() def test_zeros_like(self): - np.testing.assert_equal(np.zeros_like(self.q), np.array([[0, 0], [0, 0]])) + self.assertNDArrayEqual(np.zeros_like(self.q), np.array([[0, 0], [0, 0]])) @helpers.requires_array_function_protocol() def test_empty_like(self): @@ -53,7 +59,7 @@ class TestNumpyArrayCreation(TestNumpyMethods): def test_full_like(self): self.assertQuantityEqual(np.full_like(self.q, self.Q_(0, self.ureg.degC)), self.Q_([[0, 0], [0, 0]], self.ureg.degC)) - np.testing.assert_equal(np.full_like(self.q, 2), np.array([[2, 2], [2, 2]])) + self.assertNDArrayEqual(np.full_like(self.q, 2), np.array([[2, 2], [2, 2]])) class TestNumpyArrayManipulation(TestNumpyMethods): #TODO @@ -309,7 +315,7 @@ class TestNumpyMathematicalFunctions(TestNumpyMethods): self.assertRaises(DimensionalityError, op.pow, arr_cp, q_cp) # ..not for op.ipow ! # q_cp is treated as if it is an array. The units are ignored. - # BaseQuantity.__ipow__ is never called + # Quantity.__ipow__ is never called arr_cp = copy.copy(arr) q_cp = copy.copy(q) self.assertRaises(DimensionalityError, op.ipow, arr_cp, q_cp) @@ -362,11 +368,11 @@ class TestNumpyUnclassified(TestNumpyMethods): def test_argsort(self): q = [1, 4, 5, 6, 2, 9] * self.ureg.MeV - np.testing.assert_array_equal(q.argsort(), [0, 4, 1, 2, 3, 5]) + self.assertNDArrayEqual(q.argsort(), [0, 4, 1, 2, 3, 5]) @helpers.requires_array_function_protocol() def test_argsort_numpy_func(self): - np.testing.assert_array_equal(np.argsort(self.q, axis=0), np.array([[0, 0], [1, 1]])) + self.assertNDArrayEqual(np.argsort(self.q, axis=0), np.array([[0, 0], [1, 1]])) def test_diagonal(self): q = [[1, 2, 3], [1, 2, 3], [1, 2, 3]] * self.ureg.m @@ -385,7 +391,7 @@ class TestNumpyUnclassified(TestNumpyMethods): def test_searchsorted(self): q = self.q.flatten() - np.testing.assert_array_equal(q.searchsorted([1.5, 2.5] * self.ureg.m), + self.assertNDArrayEqual(q.searchsorted([1.5, 2.5] * self.ureg.m), [1, 2]) q = self.q.flatten() self.assertRaises(DimensionalityError, q.searchsorted, [1.5, 2.5]) @@ -394,17 +400,17 @@ class TestNumpyUnclassified(TestNumpyMethods): def test_searchsorted_numpy_func(self): """Test searchsorted as numpy function.""" q = self.q.flatten() - np.testing.assert_array_equal(np.searchsorted(q, [1.5, 2.5] * self.ureg.m), + self.assertNDArrayEqual(np.searchsorted(q, [1.5, 2.5] * self.ureg.m), [1, 2]) def test_nonzero(self): q = [1, 0, 5, 6, 0, 9] * self.ureg.m - np.testing.assert_array_equal(q.nonzero()[0], [0, 2, 3, 5]) + self.assertNDArrayEqual(q.nonzero()[0], [0, 2, 3, 5]) @helpers.requires_array_function_protocol() def test_nonzero_numpy_func(self): q = [1, 0, 5, 6, 0, 9] * self.ureg.m - np.testing.assert_array_equal(np.nonzero(q)[0], [0, 2, 3, 5]) + self.assertNDArrayEqual(np.nonzero(q)[0], [0, 2, 3, 5]) @helpers.requires_array_function_protocol() def test_count_nonzero_numpy_func(self): @@ -439,11 +445,11 @@ class TestNumpyUnclassified(TestNumpyMethods): @helpers.requires_array_function_protocol() def test_argmax_numpy_func(self): - np.testing.assert_equal(np.argmax(self.q, axis=0), np.array([1, 1])) + self.assertNDArrayEqual(np.argmax(self.q, axis=0), np.array([1, 1])) @helpers.requires_array_function_protocol() def test_nanargmax_numpy_func(self): - np.testing.assert_equal(np.nanargmax(self.q_nan, axis=0), np.array([1, 0])) + self.assertNDArrayEqual(np.nanargmax(self.q_nan, axis=0), np.array([1, 0])) def test_min(self): self.assertEqual(self.q.min(), 1 * self.ureg.m) @@ -469,11 +475,11 @@ class TestNumpyUnclassified(TestNumpyMethods): @helpers.requires_array_function_protocol() def test_argmin_numpy_func(self): - np.testing.assert_equal(np.argmin(self.q, axis=0), np.array([0, 0])) + self.assertNDArrayEqual(np.argmin(self.q, axis=0), np.array([0, 0])) @helpers.requires_array_function_protocol() def test_nanargmin_numpy_func(self): - np.testing.assert_equal(np.nanargmin(self.q_nan, axis=0), np.array([0, 0])) + self.assertNDArrayEqual(np.nanargmin(self.q_nan, axis=0), np.array([0, 0])) def test_ptp(self): self.assertEqual(self.q.ptp(), 3 * self.ureg.m) @@ -658,7 +664,7 @@ class TestNumpyUnclassified(TestNumpyMethods): set_application_registry(self.ureg) def pickle_test(q): pq = pickle.loads(pickle.dumps(q)) - np.testing.assert_array_equal(q.magnitude, pq.magnitude) + self.assertNDArrayEqual(q.magnitude, pq.magnitude) self.assertEqual(q.units, pq.units) pickle_test([10,20]*self.ureg.m) @@ -719,7 +725,7 @@ class TestNumpyUnclassified(TestNumpyMethods): @helpers.requires_array_function_protocol() def test_isclose_numpy_func(self): q2 = [[1000.05, 2000], [3000.00007, 4001]] * self.ureg.mm - np.testing.assert_equal(np.isclose(self.q, q2), np.array([[False, True], [True, False]])) + self.assertNDArrayEqual(np.isclose(self.q, q2), np.array([[False, True], [True, False]])) @helpers.requires_array_function_protocol() def test_interp_numpy_func(self): @@ -729,8 +735,8 @@ class TestNumpyUnclassified(TestNumpyMethods): self.assertQuantityAlmostEqual(np.interp(x, xp, fp), self.Q_([6.66667, 20.], self.ureg.degC), rtol=1e-5) def test_comparisons(self): - np.testing.assert_equal(self.q > 2 * self.ureg.m, np.array([[False, False], [True, True]])) - np.testing.assert_equal(self.q < 2 * self.ureg.m, np.array([[True, False], [False, False]])) + self.assertNDArrayEqual(self.q > 2 * self.ureg.m, np.array([[False, False], [True, True]])) + self.assertNDArrayEqual(self.q < 2 * self.ureg.m, np.array([[True, False], [False, False]])) @unittest.skip diff --git a/pint/testsuite/test_quantity.py b/pint/testsuite/test_quantity.py index e15923d..c216a62 100644 --- a/pint/testsuite/test_quantity.py +++ b/pint/testsuite/test_quantity.py @@ -286,7 +286,9 @@ class TestQuantity(QuantityTestCase): self.assertEqual(q.u, q.reshape(2, 3).u) self.assertEqual(q.u, q.swapaxes(0, 1).u) self.assertEqual(q.u, q.mean().u) - self.assertEqual(q.u, np.compress((q==q[0,0]).any(0), q).u) + # TODO: Re-add np.compress implementation once mixed type is resolved + # (see https://github.com/hgrecco/pint/pull/764#issuecomment-523272038) + # self.assertEqual(q.u, np.compress((q==q[0,0]).any(0), q).u) def test_context_attr(self): self.assertEqual(self.ureg.meter, self.Q_(1, 'meter')) -- cgit v1.2.1