summaryrefslogtreecommitdiff
path: root/pint/compat.py
diff options
context:
space:
mode:
Diffstat (limited to 'pint/compat.py')
-rw-r--r--pint/compat.py75
1 files changed, 40 insertions, 35 deletions
diff --git a/pint/compat.py b/pint/compat.py
index a76e15c..de149ac 100644
--- a/pint/compat.py
+++ b/pint/compat.py
@@ -13,6 +13,7 @@ from __future__ import annotations
import math
import tokenize
from decimal import Decimal
+from importlib import import_module
from io import BytesIO
from numbers import Number
@@ -80,7 +81,6 @@ try:
NP_NO_VALUE = np._NoValue
except ImportError:
-
np = None
class ndarray:
@@ -166,53 +166,58 @@ if not HAS_MIP:
# Define location of pint.Quantity in NEP-13 type cast hierarchy by defining upcast
# types using guarded imports
-upcast_types = []
-# pint-pandas (PintArray)
try:
- from pint_pandas import PintArray
-
- upcast_types.append(PintArray)
+ from dask import array as dask_array
+ from dask.base import compute, persist, visualize
except ImportError:
- pass
+ compute, persist, visualize = None, None, None
+ dask_array = None
-# Pandas (Series)
-try:
- from pandas import Series
- upcast_types.append(Series)
-except ImportError:
- pass
+upcast_type_names = (
+ "pint_pandas.PintArray",
+ "pandas.Series",
+ "xarray.core.dataarray.DataArray",
+ "xarray.core.dataset.Dataset",
+ "xarray.core.variable.Variable",
+ "pandas.core.series.Series",
+ "xarray.core.dataarray.DataArray",
+)
-# xarray (DataArray, Dataset, Variable)
-try:
- from xarray import DataArray, Dataset, Variable
+upcast_type_map = {k: None for k in upcast_type_names}
- upcast_types += [DataArray, Dataset, Variable]
-except ImportError:
- pass
-try:
- from dask import array as dask_array
- from dask.base import compute, persist, visualize
+def fully_qualified_name(obj):
+ t = type(obj)
+ module = t.__module__
+ name = t.__qualname__
-except ImportError:
- compute, persist, visualize = None, None, None
- dask_array = None
+ if module is None or module == "__builtin__":
+ return name
+ return f"{module}.{name}"
-def is_upcast_type(other) -> bool:
- """Check if the type object is a upcast type using preset list.
- Parameters
- ----------
- other : object
+def check_upcast_type(obj):
+ fqn = fully_qualified_name(obj)
+ if fqn not in upcast_type_map:
+ return False
+ else:
+ module_name, class_name = fqn.rsplit(".", 1)
+ cls = getattr(import_module(module_name), class_name)
- Returns
- -------
- bool
- """
- return other in upcast_types
+ upcast_type_map[fqn] = cls
+ # This is to check we are importing the same thing.
+ # and avoid weird problems. Maybe instead of return
+ # we should raise an error if false.
+ return isinstance(obj, cls)
+
+
+def is_upcast_type(other):
+ if other in upcast_type_map.values():
+ return True
+ return check_upcast_type(other)
def is_duck_array_type(cls) -> bool: