diff options
author | Guido Imperiale <crusaderky@gmail.com> | 2020-03-12 23:32:10 +0000 |
---|---|---|
committer | Guido Imperiale <crusaderky@gmail.com> | 2020-03-12 23:32:10 +0000 |
commit | 5191978e9a546c6b0caf47dc743b0eab54d0eb2d (patch) | |
tree | cc6c49fe153af609e2eef668cab85404096bf1a7 /pint/compat.py | |
parent | 7206b5caac5b9228f44f54c7500bb24e831027cf (diff) | |
download | pint-5191978e9a546c6b0caf47dc743b0eab54d0eb2d.tar.gz |
add, iadd, eq, gt, etc. to treat bare NaN like 0
Trivial
Bugfix on legacy numpy
black
Diffstat (limited to 'pint/compat.py')
-rw-r--r-- | pint/compat.py | 86 |
1 files changed, 73 insertions, 13 deletions
diff --git a/pint/compat.py b/pint/compat.py index e8e1a1b..a671e19 100644 --- a/pint/compat.py +++ b/pint/compat.py @@ -7,6 +7,7 @@ :copyright: 2013 by Pint Authors, see AUTHORS for more details. :license: BSD, see LICENSE for more details. """ +import math import tokenize from decimal import Decimal from io import BytesIO @@ -38,7 +39,7 @@ class BehaviorChangeWarning(UserWarning): try: import numpy as np - from numpy import ndarray + from numpy import ndarray, datetime64 as np_datetime64 HAS_NUMPY = True NUMPY_VER = np.__version__ @@ -81,6 +82,9 @@ except ImportError: class ndarray: pass + class np_datetime64: + pass + HAS_NUMPY = False NUMPY_VER = "0" NUMERIC_TYPES = (Number, Decimal) @@ -154,7 +158,7 @@ except ImportError: pass -def is_upcast_type(other): +def is_upcast_type(other) -> bool: """Check if the type object is a upcast type using preset list. Parameters @@ -168,29 +172,29 @@ def is_upcast_type(other): return other in upcast_types -def is_duck_array_type(other): +def is_duck_array_type(cls) -> bool: """Check if the type object represents a (non-Quantity) duck array type. Parameters ---------- - other : object + cls : class Returns ------- bool """ # TODO (NEP 30): replace duck array check with hasattr(other, "__duckarray__") - return other is ndarray or ( - not hasattr(other, "_magnitude") - and not hasattr(other, "_units") + return issubclass(cls, ndarray) or ( + not hasattr(cls, "_magnitude") + and not hasattr(cls, "_units") and HAS_NUMPY_ARRAY_FUNCTION - and hasattr(other, "__array_function__") - and hasattr(other, "ndim") - and hasattr(other, "dtype") + and hasattr(cls, "__array_function__") + and hasattr(cls, "ndim") + and hasattr(cls, "dtype") ) -def eq(lhs, rhs, check_all): +def eq(lhs, rhs, check_all: bool): """Comparison of scalars and arrays. Parameters @@ -200,7 +204,8 @@ def eq(lhs, rhs, check_all): rhs : object right-hand side check_all : bool - if True, reduce sequence to single bool. + if True, reduce sequence to single bool; + return True if all the elements are equal. Returns ------- @@ -208,5 +213,60 @@ def eq(lhs, rhs, check_all): """ out = lhs == rhs if check_all and isinstance(out, ndarray): - return np.all(out) + return out.all() + return out + + +def isnan(obj, check_all: bool): + """Test for NaN or NaT + + Parameters + ---------- + obj : object + scalar or vector + check_all : bool + if True, reduce sequence to single bool; + return True if any of the elements are NaN. + + Returns + ------- + bool or array_like of bool. + Always return False for non-numeric types. + """ + if is_duck_array_type(type(obj)): + if obj.dtype.kind in "if": + out = np.isnan(obj) + elif obj.dtype.kind in "Mm": + out = np.isnat(obj) + else: + # Not a numeric or datetime type + out = np.full(obj.shape, False) + return out.any() if check_all else out + if isinstance(obj, np_datetime64): + return np.isnat(obj) + try: + return math.isnan(obj) + except TypeError: + return False + + +def zero_or_nan(obj, check_all: bool): + """Test if obj is zero, NaN, or NaT + + Parameters + ---------- + obj : object + scalar or vector + check_all : bool + if True, reduce sequence to single bool; + return True if all the elements are zero, NaN, or NaT. + + Returns + ------- + bool or array_like of bool. + Always return False for non-numeric types. + """ + out = eq(obj, 0, False) + isnan(obj, False) + if check_all and is_duck_array_type(type(out)): + return out.all() return out |