diff options
Diffstat (limited to 'pint/compat.py')
-rw-r--r-- | pint/compat.py | 75 |
1 files changed, 40 insertions, 35 deletions
diff --git a/pint/compat.py b/pint/compat.py index a76e15c..de149ac 100644 --- a/pint/compat.py +++ b/pint/compat.py @@ -13,6 +13,7 @@ from __future__ import annotations import math import tokenize from decimal import Decimal +from importlib import import_module from io import BytesIO from numbers import Number @@ -80,7 +81,6 @@ try: NP_NO_VALUE = np._NoValue except ImportError: - np = None class ndarray: @@ -166,53 +166,58 @@ if not HAS_MIP: # Define location of pint.Quantity in NEP-13 type cast hierarchy by defining upcast # types using guarded imports -upcast_types = [] -# pint-pandas (PintArray) try: - from pint_pandas import PintArray - - upcast_types.append(PintArray) + from dask import array as dask_array + from dask.base import compute, persist, visualize except ImportError: - pass + compute, persist, visualize = None, None, None + dask_array = None -# Pandas (Series) -try: - from pandas import Series - upcast_types.append(Series) -except ImportError: - pass +upcast_type_names = ( + "pint_pandas.PintArray", + "pandas.Series", + "xarray.core.dataarray.DataArray", + "xarray.core.dataset.Dataset", + "xarray.core.variable.Variable", + "pandas.core.series.Series", + "xarray.core.dataarray.DataArray", +) -# xarray (DataArray, Dataset, Variable) -try: - from xarray import DataArray, Dataset, Variable +upcast_type_map = {k: None for k in upcast_type_names} - upcast_types += [DataArray, Dataset, Variable] -except ImportError: - pass -try: - from dask import array as dask_array - from dask.base import compute, persist, visualize +def fully_qualified_name(obj): + t = type(obj) + module = t.__module__ + name = t.__qualname__ -except ImportError: - compute, persist, visualize = None, None, None - dask_array = None + if module is None or module == "__builtin__": + return name + return f"{module}.{name}" -def is_upcast_type(other) -> bool: - """Check if the type object is a upcast type using preset list. - Parameters - ---------- - other : object +def check_upcast_type(obj): + fqn = fully_qualified_name(obj) + if fqn not in upcast_type_map: + return False + else: + module_name, class_name = fqn.rsplit(".", 1) + cls = getattr(import_module(module_name), class_name) - Returns - ------- - bool - """ - return other in upcast_types + upcast_type_map[fqn] = cls + # This is to check we are importing the same thing. + # and avoid weird problems. Maybe instead of return + # we should raise an error if false. + return isinstance(obj, cls) + + +def is_upcast_type(other): + if other in upcast_type_map.values(): + return True + return check_upcast_type(other) def is_duck_array_type(cls) -> bool: |