summaryrefslogtreecommitdiff
path: root/pint
diff options
context:
space:
mode:
authorRyan May <rmay@ucar.edu>2023-04-27 17:39:33 -0600
committerRyan May <rmay@ucar.edu>2023-04-27 18:09:15 -0600
commit027c15c810a01568fe7864a810110d79061346ba (patch)
tree5ba98bfe3956202e242420b93889a1e18d2e39db /pint
parent61571a77e1a765b36ce1a26951975a1332ed3cf6 (diff)
downloadpint-027c15c810a01568fe7864a810110d79061346ba.tar.gz
Fix up dot/cross wrapper for non-multiplicative units
Diffstat (limited to 'pint')
-rw-r--r--pint/facets/numpy/numpy_func.py27
-rw-r--r--pint/testsuite/test_numpy_func.py42
2 files changed, 67 insertions, 2 deletions
diff --git a/pint/facets/numpy/numpy_func.py b/pint/facets/numpy/numpy_func.py
index 2a004a8..f25f4a4 100644
--- a/pint/facets/numpy/numpy_func.py
+++ b/pint/facets/numpy/numpy_func.py
@@ -729,6 +729,7 @@ for name in ["prod", "nanprod"]:
implement_prod_func(name)
+# Handle mutliplicative functions separately to deal with non-multiplicative units
def _base_unit_if_needed(a):
if a._is_multiplicative:
return a
@@ -759,6 +760,30 @@ def _trapz(a, x=None, dx=1.0, **kwargs):
return a.units._REGISTRY.Quantity(ret, units)
+def implement_mul_func(func):
+ # If NumPy is not available, do not attempt implement that which does not exist
+ if np is None:
+ return
+
+ func = getattr(np, func_str)
+
+ @implements(func_str, "function")
+ def implementation(a, b, **kwargs):
+ a = _base_unit_if_needed(a)
+ units = a.units
+ if hasattr(b, "units"):
+ b = _base_unit_if_needed(b)
+ units *= b.units
+ b = b._magnitude
+
+ mag = func(a._magnitude, b, **kwargs)
+ return a.units._REGISTRY.Quantity(mag, units)
+
+
+for func_str in ["cross", "dot"]:
+ implement_mul_func(func_str)
+
+
# Implement simple matching-unit or stripped-unit functions based on signature
@@ -950,8 +975,6 @@ for func_str in [
# Handle functions with output unit defined by operation
for func_str in ["std", "nanstd", "sum", "nansum", "cumsum", "nancumsum"]:
implement_func("function", func_str, input_units=None, output_unit="sum")
-for func_str in ["cross", "dot"]:
- implement_func("function", func_str, input_units=None, output_unit="mul")
for func_str in ["diff", "ediff1d"]:
implement_func("function", func_str, input_units=None, output_unit="delta")
for func_str in ["gradient"]:
diff --git a/pint/testsuite/test_numpy_func.py b/pint/testsuite/test_numpy_func.py
index 4f1488c..7a0cdb7 100644
--- a/pint/testsuite/test_numpy_func.py
+++ b/pint/testsuite/test_numpy_func.py
@@ -213,3 +213,45 @@ class TestNumPyFuncUtils(TestNumpyMethods):
z = self.Q_(np.array([0.0, 2.0, 4.0]), "m")
with pytest.raises(OffsetUnitCalculusError):
np.trapz(t, x=z)
+
+ def test_dot(self):
+ with ExitStack() as stack:
+ stack.callback(
+ setattr,
+ self.ureg,
+ "autoconvert_offset_to_baseunit",
+ self.ureg.autoconvert_offset_to_baseunit,
+ )
+ self.ureg.autoconvert_offset_to_baseunit = True
+ t = self.Q_(np.array([0.0, 5.0, 10.0]), "degC")
+ z = self.Q_(np.array([1.0, 2.0, 3.0]), "m")
+ helpers.assert_quantity_almost_equal(
+ np.dot(t, z), self.Q_(1678.9, "kelvin meter")
+ )
+
+ def test_dot_no_autoconvert(self):
+ t = self.Q_(np.array([0.0, 5.0, 10.0]), "degC")
+ z = self.Q_(np.array([1.0, 2.0, 3.0]), "m")
+ with pytest.raises(OffsetUnitCalculusError):
+ np.dot(t, z)
+
+ def test_cross(self):
+ with ExitStack() as stack:
+ stack.callback(
+ setattr,
+ self.ureg,
+ "autoconvert_offset_to_baseunit",
+ self.ureg.autoconvert_offset_to_baseunit,
+ )
+ self.ureg.autoconvert_offset_to_baseunit = True
+ t = self.Q_(np.array([0.0, 5.0, 10.0]), "degC")
+ z = self.Q_(np.array([1.0, 2.0, 3.0]), "m")
+ helpers.assert_quantity_almost_equal(
+ np.cross(t, z), self.Q_([268.15, -536.3, 268.15], "kelvin meter")
+ )
+
+ def test_cross_no_autoconvert(self):
+ t = self.Q_(np.array([0.0, 5.0, 10.0]), "degC")
+ z = self.Q_(np.array([1.0, 2.0, 3.0]), "m")
+ with pytest.raises(OffsetUnitCalculusError):
+ np.cross(t, z)