diff options
Diffstat (limited to 'pint')
-rw-r--r-- | pint/numpy_func.py | 2 | ||||
-rw-r--r-- | pint/quantity.py | 9 | ||||
-rw-r--r-- | pint/testsuite/test_quantity.py | 9 |
3 files changed, 19 insertions, 1 deletions
diff --git a/pint/numpy_func.py b/pint/numpy_func.py index eb13daa..866d06f 100644 --- a/pint/numpy_func.py +++ b/pint/numpy_func.py @@ -228,7 +228,7 @@ copy_units_output_ufuncs = ['ldexp', 'fmod', 'mod', 'remainder'] op_units_output_ufuncs = {'var': 'square', 'prod': 'size', 'multiply': 'mul', 'true_divide': 'div', 'divide': 'div', 'floor_divide': 'div', 'sqrt': 'sqrt', 'square': 'square', 'reciprocal': 'reciprocal', - 'std': 'sum', 'sum': 'sum', 'cumsum': 'sum'} + 'std': 'sum', 'sum': 'sum', 'cumsum': 'sum', 'matmul': 'mul'} # Perform the standard ufunc implementations based on behavior collections diff --git a/pint/quantity.py b/pint/quantity.py index 1af4ae5..24ca26e 100644 --- a/pint/quantity.py +++ b/pint/quantity.py @@ -976,6 +976,15 @@ class Quantity(PrettyIPython, SharedRegistryObject): __rmul__ = __mul__ + def __matmul__(self, other): + # Use NumPy ufunc for matrix multiplication + try: + return np.matmul(self, other) + except AttributeError: + return NotImplemented + + __rmatmul__ = __matmul__ + def __itruediv__(self, other): if not isinstance(self._magnitude, ndarray): return self._mul_div(other, operator.truediv) diff --git a/pint/testsuite/test_quantity.py b/pint/testsuite/test_quantity.py index e15923d..7030e14 100644 --- a/pint/testsuite/test_quantity.py +++ b/pint/testsuite/test_quantity.py @@ -1303,6 +1303,15 @@ class TestOffsetUnitMath(QuantityTestCase, ParameterizedTestCase): in1_cp = copy.copy(in1) self.assertQuantityAlmostEqual(op.ipow(in1_cp, in2), expected) + @helpers.requires_numpy() + def test_matmul_with_numpy(self): + A = [[1, 2], [3, 4]] * self.ureg.m + B = np.array([[0, -1], [-1, 0]]) + b = [[1], [0]] * self.ureg.m + self.assertQuantityEqual(A @ B, [[-2, -1], [-4, -3]] * self.ureg.m) + self.assertQuantityEqual(A @ b, [[1], [3]] * self.ureg.m**2) + self.assertQuantityEqual(B @ b, [[0], [-1]] * self.ureg.m) + class TestDimensionReduction(QuantityTestCase): def _calc_mass(self, ureg): |