summaryrefslogtreecommitdiff
path: root/pint/compat.py
diff options
context:
space:
mode:
authorGuido Imperiale <crusaderky@gmail.com>2020-03-12 23:32:10 +0000
committerGuido Imperiale <crusaderky@gmail.com>2020-03-12 23:32:10 +0000
commit5191978e9a546c6b0caf47dc743b0eab54d0eb2d (patch)
treecc6c49fe153af609e2eef668cab85404096bf1a7 /pint/compat.py
parent7206b5caac5b9228f44f54c7500bb24e831027cf (diff)
downloadpint-5191978e9a546c6b0caf47dc743b0eab54d0eb2d.tar.gz
add, iadd, eq, gt, etc. to treat bare NaN like 0
Trivial Bugfix on legacy numpy black
Diffstat (limited to 'pint/compat.py')
-rw-r--r--pint/compat.py86
1 files changed, 73 insertions, 13 deletions
diff --git a/pint/compat.py b/pint/compat.py
index e8e1a1b..a671e19 100644
--- a/pint/compat.py
+++ b/pint/compat.py
@@ -7,6 +7,7 @@
:copyright: 2013 by Pint Authors, see AUTHORS for more details.
:license: BSD, see LICENSE for more details.
"""
+import math
import tokenize
from decimal import Decimal
from io import BytesIO
@@ -38,7 +39,7 @@ class BehaviorChangeWarning(UserWarning):
try:
import numpy as np
- from numpy import ndarray
+ from numpy import ndarray, datetime64 as np_datetime64
HAS_NUMPY = True
NUMPY_VER = np.__version__
@@ -81,6 +82,9 @@ except ImportError:
class ndarray:
pass
+ class np_datetime64:
+ pass
+
HAS_NUMPY = False
NUMPY_VER = "0"
NUMERIC_TYPES = (Number, Decimal)
@@ -154,7 +158,7 @@ except ImportError:
pass
-def is_upcast_type(other):
+def is_upcast_type(other) -> bool:
"""Check if the type object is a upcast type using preset list.
Parameters
@@ -168,29 +172,29 @@ def is_upcast_type(other):
return other in upcast_types
-def is_duck_array_type(other):
+def is_duck_array_type(cls) -> bool:
"""Check if the type object represents a (non-Quantity) duck array type.
Parameters
----------
- other : object
+ cls : class
Returns
-------
bool
"""
# TODO (NEP 30): replace duck array check with hasattr(other, "__duckarray__")
- return other is ndarray or (
- not hasattr(other, "_magnitude")
- and not hasattr(other, "_units")
+ return issubclass(cls, ndarray) or (
+ not hasattr(cls, "_magnitude")
+ and not hasattr(cls, "_units")
and HAS_NUMPY_ARRAY_FUNCTION
- and hasattr(other, "__array_function__")
- and hasattr(other, "ndim")
- and hasattr(other, "dtype")
+ and hasattr(cls, "__array_function__")
+ and hasattr(cls, "ndim")
+ and hasattr(cls, "dtype")
)
-def eq(lhs, rhs, check_all):
+def eq(lhs, rhs, check_all: bool):
"""Comparison of scalars and arrays.
Parameters
@@ -200,7 +204,8 @@ def eq(lhs, rhs, check_all):
rhs : object
right-hand side
check_all : bool
- if True, reduce sequence to single bool.
+ if True, reduce sequence to single bool;
+ return True if all the elements are equal.
Returns
-------
@@ -208,5 +213,60 @@ def eq(lhs, rhs, check_all):
"""
out = lhs == rhs
if check_all and isinstance(out, ndarray):
- return np.all(out)
+ return out.all()
+ return out
+
+
+def isnan(obj, check_all: bool):
+ """Test for NaN or NaT
+
+ Parameters
+ ----------
+ obj : object
+ scalar or vector
+ check_all : bool
+ if True, reduce sequence to single bool;
+ return True if any of the elements are NaN.
+
+ Returns
+ -------
+ bool or array_like of bool.
+ Always return False for non-numeric types.
+ """
+ if is_duck_array_type(type(obj)):
+ if obj.dtype.kind in "if":
+ out = np.isnan(obj)
+ elif obj.dtype.kind in "Mm":
+ out = np.isnat(obj)
+ else:
+ # Not a numeric or datetime type
+ out = np.full(obj.shape, False)
+ return out.any() if check_all else out
+ if isinstance(obj, np_datetime64):
+ return np.isnat(obj)
+ try:
+ return math.isnan(obj)
+ except TypeError:
+ return False
+
+
+def zero_or_nan(obj, check_all: bool):
+ """Test if obj is zero, NaN, or NaT
+
+ Parameters
+ ----------
+ obj : object
+ scalar or vector
+ check_all : bool
+ if True, reduce sequence to single bool;
+ return True if all the elements are zero, NaN, or NaT.
+
+ Returns
+ -------
+ bool or array_like of bool.
+ Always return False for non-numeric types.
+ """
+ out = eq(obj, 0, False) + isnan(obj, False)
+ if check_all and is_duck_array_type(type(out)):
+ return out.all()
return out