diff options
author | Russell Manser <russell.p.manser@ttu.edu> | 2020-07-03 16:43:28 -0500 |
---|---|---|
committer | Russell Manser <russell.p.manser@ttu.edu> | 2020-07-03 16:43:28 -0500 |
commit | 11a351feefb5dd333a193c7e8c18f6d0a9a60360 (patch) | |
tree | 0e71c42ba3e1c34ecb878494656e4cd55f1976f0 /pint/testsuite/test_dask.py | |
parent | 230941df17b58c070c619ec3fc615c3ef7659dc2 (diff) | |
download | pint-11a351feefb5dd333a193c7e8c18f6d0a9a60360.tar.gz |
Change __dask_optimize__ and __dask_scheduler__ to class properties
Change these methods to `@property`s and have them redirect to the
appropriate Dask Array methods. Tests are included. Add a
test utility function for incrementing test quantities.
Diffstat (limited to 'pint/testsuite/test_dask.py')
-rw-r--r-- | pint/testsuite/test_dask.py | 36 |
1 files changed, 29 insertions, 7 deletions
diff --git a/pint/testsuite/test_dask.py b/pint/testsuite/test_dask.py index 8ca6afe..6d8140d 100644 --- a/pint/testsuite/test_dask.py +++ b/pint/testsuite/test_dask.py @@ -4,15 +4,19 @@ from pint import UnitRegistry # Conditionally import NumPy and Dask np = pytest.importorskip("numpy", reason="NumPy is not available") -da = pytest.importorskip("dask.array", reason="Dask is not available") +dask = pytest.importorskip("dask", reason="Dask is not available") ureg = UnitRegistry(force_ndarray_like=True) units_ = "kilogram" +def add_five(q): + return q + 5 * ureg(units_) + + @pytest.fixture def dask_array(): - return da.arange(0, 25, chunks=5, dtype=float).reshape((5, 5)) + return dask.array.arange(0, 25, chunks=5, dtype=float).reshape((5, 5)) @pytest.fixture @@ -40,10 +44,28 @@ def test_has_dask_keys(dask_array): assert q.__dask_keys__() == dask_array.__dask_keys__() +def test_dask_scheduler(dask_array): + """Test that a pint.Quantity wrapped Dask array has the correct default scheduler.""" + q = ureg.Quantity(dask_array, units_) + scheduler = q.__dask_scheduler__ + scheduler_name = f'{scheduler.__module__}.{scheduler.__name__}' + true_name = 'dask.threaded.get' + + assert scheduler == dask.array.Array.__dask_scheduler__ + assert scheduler_name == true_name + + +def test_dask_optimize(dask_array): + """Test that a pint.Quantity wrapped Dask array can be optimized.""" + q = ureg.Quantity(dask_array, units_) + + assert q.__dask_optimize__ == dask.array.Array.__dask_optimize__ + + def test_compute(dask_array, numpy_array): """Test the compute() method on a pint.Quantity wrapped Dask array.""" q = ureg.Quantity(dask_array, units_) - comps = q + 5 * ureg(units_) + comps = add_five(q) res = comps.compute() assert np.all(res.m == numpy_array) @@ -57,7 +79,7 @@ def test_persist(dask_array, numpy_array): For single machines, persist() is expected to return the computed result(s). """ q = ureg.Quantity(dask_array, units_) - comps = q + 5 * ureg(units_) + comps = add_five(q) res = comps.persist() assert np.all(res.m == numpy_array) @@ -74,7 +96,7 @@ def test_compute_exception(numpy_array): that is not a dask.array.core.Array object. """ q = ureg.Quantity(numpy_array, units_) - comps = q + 5 * ureg(units_) + comps = add_five(q) with pytest.raises(AttributeError) as excinfo: comps.compute() @@ -87,7 +109,7 @@ def test_persist_exception(numpy_array): that is not a dask.array.core.Array object. """ q = ureg.Quantity(numpy_array, units_) - comps = q + 5 * ureg(units_) + comps = add_five(q) with pytest.raises(AttributeError) as excinfo: comps.persist() @@ -100,7 +122,7 @@ def test_visualize_exception(numpy_array): that is not a dask.array.core.Array object. """ q = ureg.Quantity(numpy_array, units_) - comps = q + 5 * ureg(units_) + comps = add_five(q) with pytest.raises(AttributeError) as excinfo: comps.visualize() |