summaryrefslogtreecommitdiff
path: root/pint
diff options
context:
space:
mode:
authorJon Thielen <github@jont.cc>2019-10-30 18:13:11 -0500
committerJon Thielen <github@jont.cc>2019-12-05 12:18:22 -0600
commit66f67d3ee4344cf089bb30b5323113f19866a158 (patch)
treea2390c028811f64c60bf2d48318c7d0554c2bc81 /pint
parent6734633ba3e7083b78fb3108a01974f0bca07892 (diff)
downloadpint-66f67d3ee4344cf089bb30b5323113f19866a158.tar.gz
Add __array_function__ based on changes by @andrewgsavage and @jthielen
Diffstat (limited to 'pint')
-rw-r--r--pint/quantity.py201
-rw-r--r--pint/testsuite/test_numpy.py44
-rw-r--r--pint/testsuite/test_quantity.py4
3 files changed, 224 insertions, 25 deletions
diff --git a/pint/quantity.py b/pint/quantity.py
index f0bdd5f..53ad459 100644
--- a/pint/quantity.py
+++ b/pint/quantity.py
@@ -26,10 +26,9 @@ from .errors import (DimensionalityError, OffsetUnitCalculusError, PintTypeError
from .definitions import UnitDefinition
from .compat import ndarray, np, _to_magnitude
from .util import (PrettyIPython, logger, UnitsContainer, SharedRegistryObject,
- to_units_container, infer_base_unit)
+ to_units_container, infer_base_unit, iterable, sized)
from pint.compat import Loc
-
def _eq(first, second, check_all):
"""Comparison of scalars and arrays
"""
@@ -63,11 +62,10 @@ def ireduce_dimensions(f):
return result
return wrapped
-
def check_implemented(f):
def wrapped(self, *args, **kwargs):
other=args[0]
- if other.__class__.__name__ in ["PintArray", "Series"]:
+ if other.__class__.__name__ in ["PintArray", "Series", "DataArray"]:
return NotImplemented
# pandas often gets to arrays of quantities [ Q_(1,"m"), Q_(2,"m")]
# and expects Quantity * array[Quantity] should return NotImplemented
@@ -77,6 +75,192 @@ def check_implemented(f):
return result
return wrapped
+HANDLED_FUNCTIONS = {}
+
+def implements(numpy_function):
+ """Register an __array_function__ implementation for BaseQuantity objects."""
+ def decorator(func):
+ HANDLED_FUNCTIONS[numpy_function] = func
+ return func
+ return decorator
+
+def _is_quantity_sequence(arg):
+ if iterable(arg) and sized(arg) and not isinstance(arg, string_types):
+ if isinstance(arg[0], Quantity):
+ if not all([isinstance(item, Quantity) for item in arg]):
+ raise TypeError("{} contains items that aren't Quantity type".format(arg))
+ return True
+ return False
+
+def _get_first_input_units(args, kwargs={}):
+ args_combo = list(args)+list(kwargs.values())
+ out_units=None
+ for arg in args_combo:
+ if isinstance(arg, Quantity):
+ out_units = arg.units
+ elif _is_quantity_sequence(arg):
+ out_units = arg[0].units
+ if out_units is not None:
+ break
+ return out_units
+
+def convert_to_consistent_units(pre_calc_units=None, *args, **kwargs):
+ """Takes the args for a numpy function and converts any Quantity or Sequence of Quantities
+ into the units of the first Quantiy/Sequence of quantities. Other args are left untouched.
+ """
+ def convert_arg(arg):
+ if pre_calc_units is not None:
+ if isinstance(arg, Quantity):
+ return arg.m_as(pre_calc_units)
+ elif _is_quantity_sequence(arg):
+ return [item.m_as(pre_calc_units) for item in arg]
+ else:
+ if isinstance(arg, Quantity):
+ return arg.m
+ elif _is_quantity_sequence(arg):
+ return [item.m for item in arg]
+ return arg
+
+ new_args=tuple(convert_arg(arg) for arg in args)
+ new_kwargs = {key:convert_arg(arg) for key,arg in kwargs.items()}
+ return new_args, new_kwargs
+
+def implement_func(func_str, pre_calc_units_, post_calc_units_, out_units_):
+ """
+ :param func_str: The numpy function to implement
+ :type func_str: str
+ :param pre_calc_units: The units any quantity/ sequences of quantities should be converted to.
+ consistent_infer converts all qs to the first units found in args/kwargs
+ inconsistent does not convert any qs, eg for product
+ rad (or any other unit) converts qs to radians/ other unit
+ None converts qs to magnitudes without conversion
+ :type pre_calc_units: NoneType, str
+ :param pre_calc_units: The units the result of the function should be initiated as.
+ as_pre_calc uses the units it was converted to pre calc. Do not use with pre_calc_units="inconsistent"
+ rad (or any other unit) uses radians/ other unit
+ prod uses multiplies the input quantity units
+ None causes func to return without creating a quantity from the output, regardless of any out_units
+ :type out_units: NoneType, str
+ :param out_units: The units the result of the function should be returned to the user as. The quantity created in the post_calc_units will be converted to the out_units
+ None or as_post_calc uses the units the quantity was initiated in, ie the post_calc_units, without any conversion.
+ rad (or any other unit) uses radians/ other unit
+ infer_from_input uses the first input units found, as received by the function before any conversions.
+ :type out_units: NoneType, str
+
+ """
+ func = getattr(np,func_str)
+
+ @implements(func)
+ def _(*args, **kwargs):
+ # TODO make work for kwargs
+ args_and_kwargs = list(args)+list(kwargs.values())
+
+ (pre_calc_units, post_calc_units, out_units)=(pre_calc_units_, post_calc_units_, out_units_)
+ first_input_units=_get_first_input_units(args, kwargs)
+ if pre_calc_units == "consistent_infer":
+ pre_calc_units = first_input_units
+
+ if pre_calc_units == "inconsistent":
+ new_args, new_kwargs = args, kwargs
+ else:
+ new_args, new_kwargs = convert_to_consistent_units(pre_calc_units, *args, **kwargs)
+ res = func(*new_args, **new_kwargs)
+
+ if post_calc_units is None:
+ return res
+ elif post_calc_units == "as_pre_calc":
+ post_calc_units = pre_calc_units
+ elif post_calc_units == "sum":
+ post_calc_units = (1*first_input_units + 1*first_input_units).units
+ elif post_calc_units == "prod":
+ product = 1
+ for x in args_and_kwargs:
+ product *= x
+ post_calc_units = product.units
+ elif post_calc_units == "div":
+ product = first_input_units*first_input_units
+ for x in args_and_kwargs:
+ product /= x
+ post_calc_units = product.units
+ elif post_calc_units == "delta":
+ post_calc_units = (1*first_input_units-1*first_input_units).units
+ elif post_calc_units == "delta,div":
+ product=(1*first_input_units-1*first_input_units).units
+ for x in args_and_kwargs[1:]:
+ product /= x
+ post_calc_units = product.units
+ elif post_calc_units == "variance":
+ post_calc_units = ((1*first_input_units + 1*first_input_units)**2).units
+ Q_ = first_input_units._REGISTRY.Quantity
+ post_calc_Q_= Q_(res, post_calc_units)
+
+ if out_units is None or out_units == "as_post_calc":
+ return post_calc_Q_
+ elif out_units == "infer_from_input":
+ out_units = first_input_units
+ return post_calc_Q_.to(out_units)
+
+@implements(np.meshgrid)
+def _meshgrid(*xi, **kwargs):
+ # Simply need to map input units to onto list of outputs
+ input_units = (x.units for x in xi)
+ res = np.meshgrid(*(x.m for x in xi), **kwargs)
+ return [out * unit for out, unit in zip(res, input_units)]
+
+@implements(np.full_like)
+def _full_like(a, fill_value, dtype=None, order='K', subok=True, shape=None):
+ # Make full_like by multiplying with array from ones_like in a
+ # non-multiplicative-unit-safe way
+ if isinstance(fill_value, Quantity):
+ return fill_value._REGISTRY.Quantity(
+ np.ones_like(a, dtype=dtype, order=order, subok=subok, shape=shape) * fill_value.m,
+ fill_value.units)
+ else:
+ return (np.ones_like(a, dtype=dtype, order=order, subok=subok, shape=shape)
+ * fill_value)
+
+@implements(np.interp)
+def _interp(x, xp, fp, left=None, right=None, period=None):
+ # Need to handle x and y units separately
+ x_unit = _get_first_input_units([x, xp, period])
+ y_unit = _get_first_input_units([fp, left, right])
+ x_args, _ = convert_to_consistent_units(x_unit, x, xp, period)
+ y_args, _ = convert_to_consistent_units(y_unit, fp, left, right)
+ x, xp, period = x_args
+ fp, right, left = y_args
+ Q_ = y_unit._REGISTRY.Quantity
+ return Q_(np.interp(x, xp, fp, left=left, right=right, period=period), y_unit)
+
+for func_str in ['linspace', 'concatenate', 'block', 'stack', 'hstack', 'vstack', 'dstack', 'atleast_1d', 'column_stack', 'atleast_2d', 'atleast_3d', 'expand_dims','squeeze', 'swapaxes', 'compress', 'rollaxis', 'broadcast_to', 'moveaxis', 'fix', 'amax', 'amin', 'nanmax', 'nanmin', 'around', 'diagonal', 'mean', 'ptp', 'ravel', 'round_', 'sort', 'median', 'nanmedian', 'transpose', 'flip', 'copy', 'trim_zeros', 'append', 'clip', 'nan_to_num']:
+ implement_func(func_str, 'consistent_infer', 'as_pre_calc', 'as_post_calc')
+
+for func_str in ['isclose', 'searchsorted']:
+ implement_func(func_str, 'consistent_infer', None, None)
+
+for func_str in ['unwrap']:
+ implement_func(func_str, 'rad', 'rad', 'infer_from_input')
+
+for func_str in ['cumprod', 'cumproduct', 'nancumprod']:
+ implement_func(func_str, 'dimensionless', 'dimensionless', 'infer_from_input')
+
+for func_str in ['size', 'isreal', 'iscomplex', 'shape', 'ones_like', 'zeros_like', 'empty_like', 'argsort', 'argmin', 'argmax', 'alen', 'ndim', 'nanargmax', 'nanargmin', 'count_nonzero', 'nonzero', 'result_type']:
+ implement_func(func_str, None, None, None)
+
+for func_str in ['average', 'mean', 'std', 'nanmean', 'nanstd', 'sum', 'nansum', 'cumsum', 'nancumsum']:
+ implement_func(func_str, None, 'sum', None)
+
+for func_str in ['cross', 'trapz', 'dot']:
+ implement_func(func_str, None, 'prod', None)
+
+for func_str in ['diff', 'ediff1d',]:
+ implement_func(func_str, None, 'delta', None)
+
+for func_str in ['gradient', ]:
+ implement_func(func_str, None, 'delta,div', None)
+
+for func_str in ['var', 'nanvar']:
+ implement_func(func_str, None, 'variance', None)
+
@contextlib.contextmanager
def printoptions(*args, **kwargs):
@@ -102,6 +286,12 @@ class Quantity(PrettyIPython, SharedRegistryObject):
:param units: units of the physical quantity to be created
:type units: UnitsContainer, str or pint.Quantity
"""
+ def __array_function__(self, func, types, args, kwargs):
+ if func not in HANDLED_FUNCTIONS:
+ return NotImplemented
+ if not all(issubclass(t, Quantity) for t in types):
+ return NotImplemented
+ return HANDLED_FUNCTIONS[func](*args, **kwargs)
#: Default formatting string.
default_format = ''
@@ -1469,7 +1659,8 @@ class Quantity(PrettyIPython, SharedRegistryObject):
# Attributes starting with `__array_` are common attributes of NumPy ndarray.
# They are requested by numpy functions.
if item.startswith('__array_'):
- warnings.warn("The unit of the quantity is stripped.", UnitStrippedWarning, stacklevel=2)
+ warnings.warn("The unit of the quantity is stripped when getting {} "
+ "attribute".format(item), UnitStrippedWarning, stacklevel=2)
if isinstance(self._magnitude, ndarray):
return getattr(self._magnitude, item)
else:
diff --git a/pint/testsuite/test_numpy.py b/pint/testsuite/test_numpy.py
index 9294fff..e32afa5 100644
--- a/pint/testsuite/test_numpy.py
+++ b/pint/testsuite/test_numpy.py
@@ -30,18 +30,24 @@ class TestNumpyMethods(QuantityTestCase):
@property
def q_temperature(self):
return self.Q_([[1,2],[3,4]], self.ureg.degC)
-
+
+ def assertNDArrayEqual(self, actual, desired):
+ # Assert that the given arrays are equal, and are not Quantities
+ np.testing.assert_array_equal(actual, desired)
+ self.assertFalse(isinstance(actual, self.Q_))
+ self.assertFalse(isinstance(desired, self.Q_))
+
class TestNumpyArrayCreation(TestNumpyMethods):
# https://docs.scipy.org/doc/numpy/reference/routines.array-creation.html
@helpers.requires_array_function_protocol()
def test_ones_like(self):
- np.testing.assert_equal(np.ones_like(self.q), np.array([[1, 1], [1, 1]]))
+ self.assertNDArrayEqual(np.ones_like(self.q), np.array([[1, 1], [1, 1]]))
@helpers.requires_array_function_protocol()
def test_zeros_like(self):
- np.testing.assert_equal(np.zeros_like(self.q), np.array([[0, 0], [0, 0]]))
+ self.assertNDArrayEqual(np.zeros_like(self.q), np.array([[0, 0], [0, 0]]))
@helpers.requires_array_function_protocol()
def test_empty_like(self):
@@ -53,7 +59,7 @@ class TestNumpyArrayCreation(TestNumpyMethods):
def test_full_like(self):
self.assertQuantityEqual(np.full_like(self.q, self.Q_(0, self.ureg.degC)),
self.Q_([[0, 0], [0, 0]], self.ureg.degC))
- np.testing.assert_equal(np.full_like(self.q, 2), np.array([[2, 2], [2, 2]]))
+ self.assertNDArrayEqual(np.full_like(self.q, 2), np.array([[2, 2], [2, 2]]))
class TestNumpyArrayManipulation(TestNumpyMethods):
#TODO
@@ -309,7 +315,7 @@ class TestNumpyMathematicalFunctions(TestNumpyMethods):
self.assertRaises(DimensionalityError, op.pow, arr_cp, q_cp)
# ..not for op.ipow !
# q_cp is treated as if it is an array. The units are ignored.
- # BaseQuantity.__ipow__ is never called
+ # Quantity.__ipow__ is never called
arr_cp = copy.copy(arr)
q_cp = copy.copy(q)
self.assertRaises(DimensionalityError, op.ipow, arr_cp, q_cp)
@@ -362,11 +368,11 @@ class TestNumpyUnclassified(TestNumpyMethods):
def test_argsort(self):
q = [1, 4, 5, 6, 2, 9] * self.ureg.MeV
- np.testing.assert_array_equal(q.argsort(), [0, 4, 1, 2, 3, 5])
+ self.assertNDArrayEqual(q.argsort(), [0, 4, 1, 2, 3, 5])
@helpers.requires_array_function_protocol()
def test_argsort_numpy_func(self):
- np.testing.assert_array_equal(np.argsort(self.q, axis=0), np.array([[0, 0], [1, 1]]))
+ self.assertNDArrayEqual(np.argsort(self.q, axis=0), np.array([[0, 0], [1, 1]]))
def test_diagonal(self):
q = [[1, 2, 3], [1, 2, 3], [1, 2, 3]] * self.ureg.m
@@ -385,7 +391,7 @@ class TestNumpyUnclassified(TestNumpyMethods):
def test_searchsorted(self):
q = self.q.flatten()
- np.testing.assert_array_equal(q.searchsorted([1.5, 2.5] * self.ureg.m),
+ self.assertNDArrayEqual(q.searchsorted([1.5, 2.5] * self.ureg.m),
[1, 2])
q = self.q.flatten()
self.assertRaises(DimensionalityError, q.searchsorted, [1.5, 2.5])
@@ -394,17 +400,17 @@ class TestNumpyUnclassified(TestNumpyMethods):
def test_searchsorted_numpy_func(self):
"""Test searchsorted as numpy function."""
q = self.q.flatten()
- np.testing.assert_array_equal(np.searchsorted(q, [1.5, 2.5] * self.ureg.m),
+ self.assertNDArrayEqual(np.searchsorted(q, [1.5, 2.5] * self.ureg.m),
[1, 2])
def test_nonzero(self):
q = [1, 0, 5, 6, 0, 9] * self.ureg.m
- np.testing.assert_array_equal(q.nonzero()[0], [0, 2, 3, 5])
+ self.assertNDArrayEqual(q.nonzero()[0], [0, 2, 3, 5])
@helpers.requires_array_function_protocol()
def test_nonzero_numpy_func(self):
q = [1, 0, 5, 6, 0, 9] * self.ureg.m
- np.testing.assert_array_equal(np.nonzero(q)[0], [0, 2, 3, 5])
+ self.assertNDArrayEqual(np.nonzero(q)[0], [0, 2, 3, 5])
@helpers.requires_array_function_protocol()
def test_count_nonzero_numpy_func(self):
@@ -439,11 +445,11 @@ class TestNumpyUnclassified(TestNumpyMethods):
@helpers.requires_array_function_protocol()
def test_argmax_numpy_func(self):
- np.testing.assert_equal(np.argmax(self.q, axis=0), np.array([1, 1]))
+ self.assertNDArrayEqual(np.argmax(self.q, axis=0), np.array([1, 1]))
@helpers.requires_array_function_protocol()
def test_nanargmax_numpy_func(self):
- np.testing.assert_equal(np.nanargmax(self.q_nan, axis=0), np.array([1, 0]))
+ self.assertNDArrayEqual(np.nanargmax(self.q_nan, axis=0), np.array([1, 0]))
def test_min(self):
self.assertEqual(self.q.min(), 1 * self.ureg.m)
@@ -469,11 +475,11 @@ class TestNumpyUnclassified(TestNumpyMethods):
@helpers.requires_array_function_protocol()
def test_argmin_numpy_func(self):
- np.testing.assert_equal(np.argmin(self.q, axis=0), np.array([0, 0]))
+ self.assertNDArrayEqual(np.argmin(self.q, axis=0), np.array([0, 0]))
@helpers.requires_array_function_protocol()
def test_nanargmin_numpy_func(self):
- np.testing.assert_equal(np.nanargmin(self.q_nan, axis=0), np.array([0, 0]))
+ self.assertNDArrayEqual(np.nanargmin(self.q_nan, axis=0), np.array([0, 0]))
def test_ptp(self):
self.assertEqual(self.q.ptp(), 3 * self.ureg.m)
@@ -658,7 +664,7 @@ class TestNumpyUnclassified(TestNumpyMethods):
set_application_registry(self.ureg)
def pickle_test(q):
pq = pickle.loads(pickle.dumps(q))
- np.testing.assert_array_equal(q.magnitude, pq.magnitude)
+ self.assertNDArrayEqual(q.magnitude, pq.magnitude)
self.assertEqual(q.units, pq.units)
pickle_test([10,20]*self.ureg.m)
@@ -719,7 +725,7 @@ class TestNumpyUnclassified(TestNumpyMethods):
@helpers.requires_array_function_protocol()
def test_isclose_numpy_func(self):
q2 = [[1000.05, 2000], [3000.00007, 4001]] * self.ureg.mm
- np.testing.assert_equal(np.isclose(self.q, q2), np.array([[False, True], [True, False]]))
+ self.assertNDArrayEqual(np.isclose(self.q, q2), np.array([[False, True], [True, False]]))
@helpers.requires_array_function_protocol()
def test_interp_numpy_func(self):
@@ -729,8 +735,8 @@ class TestNumpyUnclassified(TestNumpyMethods):
self.assertQuantityAlmostEqual(np.interp(x, xp, fp), self.Q_([6.66667, 20.], self.ureg.degC), rtol=1e-5)
def test_comparisons(self):
- np.testing.assert_equal(self.q > 2 * self.ureg.m, np.array([[False, False], [True, True]]))
- np.testing.assert_equal(self.q < 2 * self.ureg.m, np.array([[True, False], [False, False]]))
+ self.assertNDArrayEqual(self.q > 2 * self.ureg.m, np.array([[False, False], [True, True]]))
+ self.assertNDArrayEqual(self.q < 2 * self.ureg.m, np.array([[True, False], [False, False]]))
@unittest.skip
diff --git a/pint/testsuite/test_quantity.py b/pint/testsuite/test_quantity.py
index e15923d..c216a62 100644
--- a/pint/testsuite/test_quantity.py
+++ b/pint/testsuite/test_quantity.py
@@ -286,7 +286,9 @@ class TestQuantity(QuantityTestCase):
self.assertEqual(q.u, q.reshape(2, 3).u)
self.assertEqual(q.u, q.swapaxes(0, 1).u)
self.assertEqual(q.u, q.mean().u)
- self.assertEqual(q.u, np.compress((q==q[0,0]).any(0), q).u)
+ # TODO: Re-add np.compress implementation once mixed type is resolved
+ # (see https://github.com/hgrecco/pint/pull/764#issuecomment-523272038)
+ # self.assertEqual(q.u, np.compress((q==q[0,0]).any(0), q).u)
def test_context_attr(self):
self.assertEqual(self.ureg.meter, self.Q_(1, 'meter'))