summaryrefslogtreecommitdiff
path: root/pint
diff options
context:
space:
mode:
authorRyan May <rmay@ucar.edu>2022-10-11 15:19:37 -0600
committerRyan May <rmay@ucar.edu>2022-10-14 17:05:48 -0600
commitc659d9ed8dda8b4223f157addc7ee6435566cb94 (patch)
tree4b3c3827d13112fd0a2771d310486d9f6a9f477c /pint
parent052a92041912e02abee29df48e95541c8448f78b (diff)
downloadpint-c659d9ed8dda8b4223f157addc7ee6435566cb94.tar.gz
Fix setitem with a masked array with multiple items (Fixes #1584)
This was incorrectly passing through some non-masked values.
Diffstat (limited to 'pint')
-rw-r--r--pint/facets/numpy/quantity.py7
-rw-r--r--pint/testsuite/test_numpy.py18
2 files changed, 23 insertions, 2 deletions
diff --git a/pint/facets/numpy/quantity.py b/pint/facets/numpy/quantity.py
index 2436100..40a97a4 100644
--- a/pint/facets/numpy/quantity.py
+++ b/pint/facets/numpy/quantity.py
@@ -245,7 +245,12 @@ class NumpyQuantity:
def __setitem__(self, key, value):
try:
- if np.ma.is_masked(value) or math.isnan(value):
+ # If we're dealing with a masked single value or a nan, set it
+ if (
+ isinstance(self._magnitude, np.ma.MaskedArray)
+ and np.ma.is_masked(value)
+ and getattr(value, "size", 0) == 1
+ ) or math.isnan(value):
self._magnitude[key] = value
return
except TypeError:
diff --git a/pint/testsuite/test_numpy.py b/pint/testsuite/test_numpy.py
index 77d18e3..4e178c6 100644
--- a/pint/testsuite/test_numpy.py
+++ b/pint/testsuite/test_numpy.py
@@ -912,7 +912,7 @@ class TestNumpyUnclassified(TestNumpyMethods):
q[:] = 1 * self.ureg.m
helpers.assert_quantity_equal(q, [[1, 1], [1, 1]] * self.ureg.m)
- # check and see that dimensionless num bers work correctly
+ # check and see that dimensionless numbers work correctly
q = [0, 1, 2, 3] * self.ureg.dimensionless
q[0] = 1
helpers.assert_quantity_equal(q, np.asarray([1, 1, 2, 3]))
@@ -933,6 +933,22 @@ class TestNumpyUnclassified(TestNumpyMethods):
assert not w
assert q.mask[0]
+ def test_setitem_mixed_masked(self):
+ masked = np.ma.array(
+ [
+ 1,
+ 2,
+ ],
+ mask=[True, False],
+ )
+ q = self.Q_(np.ones(shape=(2,)), "m")
+ with pytest.raises(DimensionalityError):
+ q[:] = masked
+
+ masked_q = self.Q_(masked, "mm")
+ q[:] = masked_q
+ helpers.assert_quantity_equal(q, [1.0, 0.002] * self.ureg.m)
+
def test_iterator(self):
for q, v in zip(self.q.flatten(), [1, 2, 3, 4]):
assert q == v * self.ureg.m