From 90008e8402bde93b9c3fbb7d884d39f98ee6415c Mon Sep 17 00:00:00 2001 From: Ryan May Date: Wed, 8 Mar 2023 16:41:17 -0700 Subject: Fix __dask_postcompute__() to better preserve type In Unidata/MetPy#2945, a call to dask's .compute() was causing the resulting type to be a different Quantity() variant (from pint.util rather than the parent registry), which resulted in isinstance() failing. This changes things to use the appropriate type from `self` rather than hard-coded class names. --- pint/facets/dask/__init__.py | 14 +++----------- pint/testsuite/test_dask.py | 2 ++ 2 files changed, 5 insertions(+), 11 deletions(-) (limited to 'pint') diff --git a/pint/facets/dask/__init__.py b/pint/facets/dask/__init__.py index f99e8a2..5276d3c 100644 --- a/pint/facets/dask/__init__.py +++ b/pint/facets/dask/__init__.py @@ -46,10 +46,7 @@ class DaskQuantity: def __dask_tokenize__(self): from dask.base import tokenize - from pint import UnitRegistry - - # TODO: Check if this is the right class as first argument - return (UnitRegistry.Quantity, tokenize(self._magnitude), self.units) + return (type(self), tokenize(self._magnitude), self.units) @property def __dask_optimize__(self): @@ -67,14 +64,9 @@ class DaskQuantity: func, args = self._magnitude.__dask_postpersist__() return self._dask_finalize, (func, args, self.units) - @staticmethod - def _dask_finalize(results, func, args, units): + def _dask_finalize(self, results, func, args, units): values = func(results, *args) - - from pint import Quantity - - # TODO: Check if this is the right class as first argument - return Quantity(values, units) + return type(self)(values, units) @check_dask_array def compute(self, **kwargs): diff --git a/pint/testsuite/test_dask.py b/pint/testsuite/test_dask.py index 69c80fe..f4dee6a 100644 --- a/pint/testsuite/test_dask.py +++ b/pint/testsuite/test_dask.py @@ -149,6 +149,8 @@ def test_compute_persist_equivalent(local_registry, dask_array, numpy_array): assert np.all(res_compute == res_persist) assert res_compute.units == res_persist.units == units_ + assert type(res_compute) == local_registry.Quantity + assert type(res_persist) == local_registry.Quantity @pytest.mark.parametrize("method", ["compute", "persist", "visualize"]) -- cgit v1.2.1