diff options
author | Ryan May <rmay@ucar.edu> | 2023-04-27 17:39:33 -0600 |
---|---|---|
committer | Ryan May <rmay@ucar.edu> | 2023-04-27 18:09:15 -0600 |
commit | 027c15c810a01568fe7864a810110d79061346ba (patch) | |
tree | 5ba98bfe3956202e242420b93889a1e18d2e39db /pint | |
parent | 61571a77e1a765b36ce1a26951975a1332ed3cf6 (diff) | |
download | pint-027c15c810a01568fe7864a810110d79061346ba.tar.gz |
Fix up dot/cross wrapper for non-multiplicative units
Diffstat (limited to 'pint')
-rw-r--r-- | pint/facets/numpy/numpy_func.py | 27 | ||||
-rw-r--r-- | pint/testsuite/test_numpy_func.py | 42 |
2 files changed, 67 insertions, 2 deletions
diff --git a/pint/facets/numpy/numpy_func.py b/pint/facets/numpy/numpy_func.py index 2a004a8..f25f4a4 100644 --- a/pint/facets/numpy/numpy_func.py +++ b/pint/facets/numpy/numpy_func.py @@ -729,6 +729,7 @@ for name in ["prod", "nanprod"]: implement_prod_func(name) +# Handle mutliplicative functions separately to deal with non-multiplicative units def _base_unit_if_needed(a): if a._is_multiplicative: return a @@ -759,6 +760,30 @@ def _trapz(a, x=None, dx=1.0, **kwargs): return a.units._REGISTRY.Quantity(ret, units) +def implement_mul_func(func): + # If NumPy is not available, do not attempt implement that which does not exist + if np is None: + return + + func = getattr(np, func_str) + + @implements(func_str, "function") + def implementation(a, b, **kwargs): + a = _base_unit_if_needed(a) + units = a.units + if hasattr(b, "units"): + b = _base_unit_if_needed(b) + units *= b.units + b = b._magnitude + + mag = func(a._magnitude, b, **kwargs) + return a.units._REGISTRY.Quantity(mag, units) + + +for func_str in ["cross", "dot"]: + implement_mul_func(func_str) + + # Implement simple matching-unit or stripped-unit functions based on signature @@ -950,8 +975,6 @@ for func_str in [ # Handle functions with output unit defined by operation for func_str in ["std", "nanstd", "sum", "nansum", "cumsum", "nancumsum"]: implement_func("function", func_str, input_units=None, output_unit="sum") -for func_str in ["cross", "dot"]: - implement_func("function", func_str, input_units=None, output_unit="mul") for func_str in ["diff", "ediff1d"]: implement_func("function", func_str, input_units=None, output_unit="delta") for func_str in ["gradient"]: diff --git a/pint/testsuite/test_numpy_func.py b/pint/testsuite/test_numpy_func.py index 4f1488c..7a0cdb7 100644 --- a/pint/testsuite/test_numpy_func.py +++ b/pint/testsuite/test_numpy_func.py @@ -213,3 +213,45 @@ class TestNumPyFuncUtils(TestNumpyMethods): z = self.Q_(np.array([0.0, 2.0, 4.0]), "m") with pytest.raises(OffsetUnitCalculusError): np.trapz(t, x=z) + + def test_dot(self): + with ExitStack() as stack: + stack.callback( + setattr, + self.ureg, + "autoconvert_offset_to_baseunit", + self.ureg.autoconvert_offset_to_baseunit, + ) + self.ureg.autoconvert_offset_to_baseunit = True + t = self.Q_(np.array([0.0, 5.0, 10.0]), "degC") + z = self.Q_(np.array([1.0, 2.0, 3.0]), "m") + helpers.assert_quantity_almost_equal( + np.dot(t, z), self.Q_(1678.9, "kelvin meter") + ) + + def test_dot_no_autoconvert(self): + t = self.Q_(np.array([0.0, 5.0, 10.0]), "degC") + z = self.Q_(np.array([1.0, 2.0, 3.0]), "m") + with pytest.raises(OffsetUnitCalculusError): + np.dot(t, z) + + def test_cross(self): + with ExitStack() as stack: + stack.callback( + setattr, + self.ureg, + "autoconvert_offset_to_baseunit", + self.ureg.autoconvert_offset_to_baseunit, + ) + self.ureg.autoconvert_offset_to_baseunit = True + t = self.Q_(np.array([0.0, 5.0, 10.0]), "degC") + z = self.Q_(np.array([1.0, 2.0, 3.0]), "m") + helpers.assert_quantity_almost_equal( + np.cross(t, z), self.Q_([268.15, -536.3, 268.15], "kelvin meter") + ) + + def test_cross_no_autoconvert(self): + t = self.Q_(np.array([0.0, 5.0, 10.0]), "degC") + z = self.Q_(np.array([1.0, 2.0, 3.0]), "m") + with pytest.raises(OffsetUnitCalculusError): + np.cross(t, z) |