summaryrefslogtreecommitdiff
path: root/pint
diff options
context:
space:
mode:
authorJules Chéron <43635101+jules-ch@users.noreply.github.com>2023-03-22 21:31:15 +0100
committerGitHub <noreply@github.com>2023-03-22 21:31:15 +0100
commit3c6f5ffcb9304089caf4be07ded2c3644c2c0e13 (patch)
treea4b286d6053aa9e83ac1f81666886dac7a7c58dd /pint
parentd0106d78df7e39bfbc0ad1cd89f9e28d48e8af26 (diff)
parent90008e8402bde93b9c3fbb7d884d39f98ee6415c (diff)
downloadpint-3c6f5ffcb9304089caf4be07ded2c3644c2c0e13.tar.gz
Merge pull request #1722 from dopplershift/fix-dask
Fix __dask_postcompute__() to better preserve type
Diffstat (limited to 'pint')
-rw-r--r--pint/facets/dask/__init__.py14
-rw-r--r--pint/testsuite/test_dask.py2
2 files changed, 5 insertions, 11 deletions
diff --git a/pint/facets/dask/__init__.py b/pint/facets/dask/__init__.py
index f99e8a2..5276d3c 100644
--- a/pint/facets/dask/__init__.py
+++ b/pint/facets/dask/__init__.py
@@ -46,10 +46,7 @@ class DaskQuantity:
def __dask_tokenize__(self):
from dask.base import tokenize
- from pint import UnitRegistry
-
- # TODO: Check if this is the right class as first argument
- return (UnitRegistry.Quantity, tokenize(self._magnitude), self.units)
+ return (type(self), tokenize(self._magnitude), self.units)
@property
def __dask_optimize__(self):
@@ -67,14 +64,9 @@ class DaskQuantity:
func, args = self._magnitude.__dask_postpersist__()
return self._dask_finalize, (func, args, self.units)
- @staticmethod
- def _dask_finalize(results, func, args, units):
+ def _dask_finalize(self, results, func, args, units):
values = func(results, *args)
-
- from pint import Quantity
-
- # TODO: Check if this is the right class as first argument
- return Quantity(values, units)
+ return type(self)(values, units)
@check_dask_array
def compute(self, **kwargs):
diff --git a/pint/testsuite/test_dask.py b/pint/testsuite/test_dask.py
index 69c80fe..f4dee6a 100644
--- a/pint/testsuite/test_dask.py
+++ b/pint/testsuite/test_dask.py
@@ -149,6 +149,8 @@ def test_compute_persist_equivalent(local_registry, dask_array, numpy_array):
assert np.all(res_compute == res_persist)
assert res_compute.units == res_persist.units == units_
+ assert type(res_compute) == local_registry.Quantity
+ assert type(res_persist) == local_registry.Quantity
@pytest.mark.parametrize("method", ["compute", "persist", "visualize"])