diff options
author | Ryan May <rmay@ucar.edu> | 2022-10-11 15:19:37 -0600 |
---|---|---|
committer | Ryan May <rmay@ucar.edu> | 2022-10-14 17:05:48 -0600 |
commit | c659d9ed8dda8b4223f157addc7ee6435566cb94 (patch) | |
tree | 4b3c3827d13112fd0a2771d310486d9f6a9f477c /pint | |
parent | 052a92041912e02abee29df48e95541c8448f78b (diff) | |
download | pint-c659d9ed8dda8b4223f157addc7ee6435566cb94.tar.gz |
Fix setitem with a masked array with multiple items (Fixes #1584)
This was incorrectly passing through some non-masked values.
Diffstat (limited to 'pint')
-rw-r--r-- | pint/facets/numpy/quantity.py | 7 | ||||
-rw-r--r-- | pint/testsuite/test_numpy.py | 18 |
2 files changed, 23 insertions, 2 deletions
diff --git a/pint/facets/numpy/quantity.py b/pint/facets/numpy/quantity.py index 2436100..40a97a4 100644 --- a/pint/facets/numpy/quantity.py +++ b/pint/facets/numpy/quantity.py @@ -245,7 +245,12 @@ class NumpyQuantity: def __setitem__(self, key, value): try: - if np.ma.is_masked(value) or math.isnan(value): + # If we're dealing with a masked single value or a nan, set it + if ( + isinstance(self._magnitude, np.ma.MaskedArray) + and np.ma.is_masked(value) + and getattr(value, "size", 0) == 1 + ) or math.isnan(value): self._magnitude[key] = value return except TypeError: diff --git a/pint/testsuite/test_numpy.py b/pint/testsuite/test_numpy.py index 77d18e3..4e178c6 100644 --- a/pint/testsuite/test_numpy.py +++ b/pint/testsuite/test_numpy.py @@ -912,7 +912,7 @@ class TestNumpyUnclassified(TestNumpyMethods): q[:] = 1 * self.ureg.m helpers.assert_quantity_equal(q, [[1, 1], [1, 1]] * self.ureg.m) - # check and see that dimensionless num bers work correctly + # check and see that dimensionless numbers work correctly q = [0, 1, 2, 3] * self.ureg.dimensionless q[0] = 1 helpers.assert_quantity_equal(q, np.asarray([1, 1, 2, 3])) @@ -933,6 +933,22 @@ class TestNumpyUnclassified(TestNumpyMethods): assert not w assert q.mask[0] + def test_setitem_mixed_masked(self): + masked = np.ma.array( + [ + 1, + 2, + ], + mask=[True, False], + ) + q = self.Q_(np.ones(shape=(2,)), "m") + with pytest.raises(DimensionalityError): + q[:] = masked + + masked_q = self.Q_(masked, "mm") + q[:] = masked_q + helpers.assert_quantity_equal(q, [1.0, 0.002] * self.ureg.m) + def test_iterator(self): for q, v in zip(self.q.flatten(), [1, 2, 3, 4]): assert q == v * self.ureg.m |