summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--pint/numpy_func.py2
-rw-r--r--pint/quantity.py9
-rw-r--r--pint/testsuite/test_quantity.py9
3 files changed, 19 insertions, 1 deletions
diff --git a/pint/numpy_func.py b/pint/numpy_func.py
index eb13daa..866d06f 100644
--- a/pint/numpy_func.py
+++ b/pint/numpy_func.py
@@ -228,7 +228,7 @@ copy_units_output_ufuncs = ['ldexp', 'fmod', 'mod', 'remainder']
op_units_output_ufuncs = {'var': 'square', 'prod': 'size', 'multiply': 'mul',
'true_divide': 'div', 'divide': 'div', 'floor_divide': 'div',
'sqrt': 'sqrt', 'square': 'square', 'reciprocal': 'reciprocal',
- 'std': 'sum', 'sum': 'sum', 'cumsum': 'sum'}
+ 'std': 'sum', 'sum': 'sum', 'cumsum': 'sum', 'matmul': 'mul'}
# Perform the standard ufunc implementations based on behavior collections
diff --git a/pint/quantity.py b/pint/quantity.py
index 1af4ae5..24ca26e 100644
--- a/pint/quantity.py
+++ b/pint/quantity.py
@@ -976,6 +976,15 @@ class Quantity(PrettyIPython, SharedRegistryObject):
__rmul__ = __mul__
+ def __matmul__(self, other):
+ # Use NumPy ufunc for matrix multiplication
+ try:
+ return np.matmul(self, other)
+ except AttributeError:
+ return NotImplemented
+
+ __rmatmul__ = __matmul__
+
def __itruediv__(self, other):
if not isinstance(self._magnitude, ndarray):
return self._mul_div(other, operator.truediv)
diff --git a/pint/testsuite/test_quantity.py b/pint/testsuite/test_quantity.py
index e15923d..7030e14 100644
--- a/pint/testsuite/test_quantity.py
+++ b/pint/testsuite/test_quantity.py
@@ -1303,6 +1303,15 @@ class TestOffsetUnitMath(QuantityTestCase, ParameterizedTestCase):
in1_cp = copy.copy(in1)
self.assertQuantityAlmostEqual(op.ipow(in1_cp, in2), expected)
+ @helpers.requires_numpy()
+ def test_matmul_with_numpy(self):
+ A = [[1, 2], [3, 4]] * self.ureg.m
+ B = np.array([[0, -1], [-1, 0]])
+ b = [[1], [0]] * self.ureg.m
+ self.assertQuantityEqual(A @ B, [[-2, -1], [-4, -3]] * self.ureg.m)
+ self.assertQuantityEqual(A @ b, [[1], [3]] * self.ureg.m**2)
+ self.assertQuantityEqual(B @ b, [[0], [-1]] * self.ureg.m)
+
class TestDimensionReduction(QuantityTestCase):
def _calc_mass(self, ureg):