summaryrefslogtreecommitdiff
path: root/pint
diff options
context:
space:
mode:
authorKeewis <keewis@posteo.de>2020-06-21 13:54:11 +0200
committerKeewis <keewis@posteo.de>2020-06-21 14:07:50 +0200
commitfe1629cbfc00fbfe8315108ef2caf1d6dac16b3f (patch)
tree99b136c99cf6402950dccf2b7dab6e940dcefa4a /pint
parent6d9f82a7dfe667ecba9a6d19865fcb08e7a80eb6 (diff)
downloadpint-fe1629cbfc00fbfe8315108ef2caf1d6dac16b3f.tar.gz
potentially make this work with duck arrays
Diffstat (limited to 'pint')
-rw-r--r--pint/compat.py2
-rw-r--r--pint/quantity.py10
2 files changed, 9 insertions, 3 deletions
diff --git a/pint/compat.py b/pint/compat.py
index a671e19..cbd897d 100644
--- a/pint/compat.py
+++ b/pint/compat.py
@@ -212,7 +212,7 @@ def eq(lhs, rhs, check_all: bool):
bool or array_like of bool
"""
out = lhs == rhs
- if check_all and isinstance(out, ndarray):
+ if check_all and is_duck_array_type(type(out)):
return out.all()
return out
diff --git a/pint/quantity.py b/pint/quantity.py
index 1c17ed4..33a9f47 100644
--- a/pint/quantity.py
+++ b/pint/quantity.py
@@ -1470,11 +1470,16 @@ class Quantity(PrettyIPython, SharedRegistryObject):
@check_implemented
def __eq__(self, other):
def bool_result(value):
+ nonlocal other
+
if not is_duck_array_type(type(self._magnitude)):
return value
- shape = np.broadcast(self._magnitude, other).shape
- return np.full(shape=shape, fill_value=False)
+ if isinstance(other, Quantity):
+ other = other._magnitude
+
+ template, _ = np.broadcast_arrays(self._magnitude, other)
+ return np.full_like(template, fill_value=False)
# We compare to the base class of Quantity because
# each Quantity class is unique.
@@ -1498,6 +1503,7 @@ class Quantity(PrettyIPython, SharedRegistryObject):
return bool_result(False)
+ # TODO: this might be expensive. Do we even need it?
if eq(self._magnitude, 0, True) and eq(other._magnitude, 0, True):
return bool_result(self.dimensionality == other.dimensionality)