diff options
Diffstat (limited to 'pint')
-rw-r--r-- | pint/facets/dask/__init__.py | 14 | ||||
-rw-r--r-- | pint/testsuite/test_dask.py | 2 |
2 files changed, 5 insertions, 11 deletions
diff --git a/pint/facets/dask/__init__.py b/pint/facets/dask/__init__.py index f99e8a2..5276d3c 100644 --- a/pint/facets/dask/__init__.py +++ b/pint/facets/dask/__init__.py @@ -46,10 +46,7 @@ class DaskQuantity: def __dask_tokenize__(self): from dask.base import tokenize - from pint import UnitRegistry - - # TODO: Check if this is the right class as first argument - return (UnitRegistry.Quantity, tokenize(self._magnitude), self.units) + return (type(self), tokenize(self._magnitude), self.units) @property def __dask_optimize__(self): @@ -67,14 +64,9 @@ class DaskQuantity: func, args = self._magnitude.__dask_postpersist__() return self._dask_finalize, (func, args, self.units) - @staticmethod - def _dask_finalize(results, func, args, units): + def _dask_finalize(self, results, func, args, units): values = func(results, *args) - - from pint import Quantity - - # TODO: Check if this is the right class as first argument - return Quantity(values, units) + return type(self)(values, units) @check_dask_array def compute(self, **kwargs): diff --git a/pint/testsuite/test_dask.py b/pint/testsuite/test_dask.py index 69c80fe..f4dee6a 100644 --- a/pint/testsuite/test_dask.py +++ b/pint/testsuite/test_dask.py @@ -149,6 +149,8 @@ def test_compute_persist_equivalent(local_registry, dask_array, numpy_array): assert np.all(res_compute == res_persist) assert res_compute.units == res_persist.units == units_ + assert type(res_compute) == local_registry.Quantity + assert type(res_persist) == local_registry.Quantity @pytest.mark.parametrize("method", ["compute", "persist", "visualize"]) |