summaryrefslogtreecommitdiff
path: root/pint/testsuite/test_dask.py
diff options
context:
space:
mode:
authorRussell Manser <russell.p.manser@ttu.edu>2020-07-03 16:43:28 -0500
committerRussell Manser <russell.p.manser@ttu.edu>2020-07-03 16:43:28 -0500
commit11a351feefb5dd333a193c7e8c18f6d0a9a60360 (patch)
tree0e71c42ba3e1c34ecb878494656e4cd55f1976f0 /pint/testsuite/test_dask.py
parent230941df17b58c070c619ec3fc615c3ef7659dc2 (diff)
downloadpint-11a351feefb5dd333a193c7e8c18f6d0a9a60360.tar.gz
Change __dask_optimize__ and __dask_scheduler__ to class properties
Change these methods to `@property`s and have them redirect to the appropriate Dask Array methods. Tests are included. Add a test utility function for incrementing test quantities.
Diffstat (limited to 'pint/testsuite/test_dask.py')
-rw-r--r--pint/testsuite/test_dask.py36
1 files changed, 29 insertions, 7 deletions
diff --git a/pint/testsuite/test_dask.py b/pint/testsuite/test_dask.py
index 8ca6afe..6d8140d 100644
--- a/pint/testsuite/test_dask.py
+++ b/pint/testsuite/test_dask.py
@@ -4,15 +4,19 @@ from pint import UnitRegistry
# Conditionally import NumPy and Dask
np = pytest.importorskip("numpy", reason="NumPy is not available")
-da = pytest.importorskip("dask.array", reason="Dask is not available")
+dask = pytest.importorskip("dask", reason="Dask is not available")
ureg = UnitRegistry(force_ndarray_like=True)
units_ = "kilogram"
+def add_five(q):
+ return q + 5 * ureg(units_)
+
+
@pytest.fixture
def dask_array():
- return da.arange(0, 25, chunks=5, dtype=float).reshape((5, 5))
+ return dask.array.arange(0, 25, chunks=5, dtype=float).reshape((5, 5))
@pytest.fixture
@@ -40,10 +44,28 @@ def test_has_dask_keys(dask_array):
assert q.__dask_keys__() == dask_array.__dask_keys__()
+def test_dask_scheduler(dask_array):
+ """Test that a pint.Quantity wrapped Dask array has the correct default scheduler."""
+ q = ureg.Quantity(dask_array, units_)
+ scheduler = q.__dask_scheduler__
+ scheduler_name = f'{scheduler.__module__}.{scheduler.__name__}'
+ true_name = 'dask.threaded.get'
+
+ assert scheduler == dask.array.Array.__dask_scheduler__
+ assert scheduler_name == true_name
+
+
+def test_dask_optimize(dask_array):
+ """Test that a pint.Quantity wrapped Dask array can be optimized."""
+ q = ureg.Quantity(dask_array, units_)
+
+ assert q.__dask_optimize__ == dask.array.Array.__dask_optimize__
+
+
def test_compute(dask_array, numpy_array):
"""Test the compute() method on a pint.Quantity wrapped Dask array."""
q = ureg.Quantity(dask_array, units_)
- comps = q + 5 * ureg(units_)
+ comps = add_five(q)
res = comps.compute()
assert np.all(res.m == numpy_array)
@@ -57,7 +79,7 @@ def test_persist(dask_array, numpy_array):
For single machines, persist() is expected to return the computed result(s).
"""
q = ureg.Quantity(dask_array, units_)
- comps = q + 5 * ureg(units_)
+ comps = add_five(q)
res = comps.persist()
assert np.all(res.m == numpy_array)
@@ -74,7 +96,7 @@ def test_compute_exception(numpy_array):
that is not a dask.array.core.Array object.
"""
q = ureg.Quantity(numpy_array, units_)
- comps = q + 5 * ureg(units_)
+ comps = add_five(q)
with pytest.raises(AttributeError) as excinfo:
comps.compute()
@@ -87,7 +109,7 @@ def test_persist_exception(numpy_array):
that is not a dask.array.core.Array object.
"""
q = ureg.Quantity(numpy_array, units_)
- comps = q + 5 * ureg(units_)
+ comps = add_five(q)
with pytest.raises(AttributeError) as excinfo:
comps.persist()
@@ -100,7 +122,7 @@ def test_visualize_exception(numpy_array):
that is not a dask.array.core.Array object.
"""
q = ureg.Quantity(numpy_array, units_)
- comps = q + 5 * ureg(units_)
+ comps = add_five(q)
with pytest.raises(AttributeError) as excinfo:
comps.visualize()