1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
|
import pytest
from pint import UnitRegistry
# Conditionally import NumPy and Dask; the whole module is skipped when either
# is missing.
np = pytest.importorskip("numpy", reason="NumPy is not available")
dask = pytest.importorskip("dask", reason="Dask is not available")
# ``import dask`` alone does not load the ``dask.array`` submodule that the
# fixtures below use via ``dask.array.*``; import it explicitly so the tests do
# not silently depend on pint having imported it as a side effect.
pytest.importorskip("dask.array", reason="Dask array is not available")

# Shared unit registry for every test in this module.
ureg = UnitRegistry(force_ndarray_like=True)
# Unit attached to every Quantity built in this module.
units_ = "kilogram"
def add_five(q):
    """Return ``q`` plus five units of ``units_`` (builds a lazy graph for Dask input)."""
    five = 5 * ureg(units_)
    return q + five
@pytest.fixture
def dask_array():
    """Provide a 5x5 float Dask array holding 0..24, built from 1-D chunks of 5."""
    flat = dask.array.arange(0, 25, chunks=5, dtype=float)
    return flat.reshape((5, 5))
@pytest.fixture
def numpy_array():
    """Provide the expected NumPy result of ``add_five`` on the ``dask_array`` values.

    The values are 0..24 shifted by 5, laid out as a 5x5 float array.
    """
    grid = np.arange(0, 25, dtype=float).reshape((5, 5))
    return grid + 5
def test_has_dask_graph(dask_array):
    """A Quantity wrapping a Dask array exposes the wrapped array's Dask graph."""
    quantity = ureg.Quantity(dask_array, units_)
    expected_graph = dask_array.__dask_graph__()
    assert quantity.__dask_graph__() == expected_graph
def test_has_no_dask_graph(numpy_array):
    """A Quantity wrapping a plain NumPy array has no Dask graph.

    Asking for one must return None rather than raising.
    """
    quantity = ureg.Quantity(numpy_array, units_)
    graph = quantity.__dask_graph__()
    assert graph is None
def test_has_dask_keys(dask_array):
    """A Quantity wrapping a Dask array reports the wrapped array's Dask keys."""
    quantity = ureg.Quantity(dask_array, units_)
    expected_keys = dask_array.__dask_keys__()
    assert quantity.__dask_keys__() == expected_keys
def test_dask_scheduler(dask_array):
    """A Dask-backed Quantity uses the same default scheduler as dask.array.Array."""
    quantity = ureg.Quantity(dask_array, units_)
    sched = quantity.__dask_scheduler__
    # Fully qualified name of the scheduler callable, e.g. "pkg.module.func".
    qualified = f'{sched.__module__}.{sched.__name__}'
    assert sched == dask.array.Array.__dask_scheduler__
    assert qualified == 'dask.threaded.get'
def test_dask_optimize(dask_array):
    """A Dask-backed Quantity exposes dask.array.Array's graph optimizer."""
    quantity = ureg.Quantity(dask_array, units_)
    expected_optimizer = dask.array.Array.__dask_optimize__
    assert quantity.__dask_optimize__ == expected_optimizer
def test_compute(dask_array, numpy_array):
    """compute() on a Dask-backed Quantity yields the evaluated values with units kept."""
    quantity = ureg.Quantity(dask_array, units_)
    result = add_five(quantity).compute()
    assert np.all(result.m == numpy_array)
    assert result.units == units_
    # The original Quantity must still wrap the very same Dask array object.
    assert quantity.magnitude is dask_array
def test_persist(dask_array, numpy_array):
    """persist() on a Dask-backed Quantity keeps units and yields the expected values.

    On a single machine, persist() evaluates the graph, so the persisted
    magnitude compares equal to the expected NumPy result.
    """
    quantity = ureg.Quantity(dask_array, units_)
    result = add_five(quantity).persist()
    assert np.all(result.m == numpy_array)
    assert result.units == units_
    # The original Quantity must still wrap the very same Dask array object.
    assert quantity.magnitude is dask_array
def test_visualize(dask_array, tmp_path):
    """Test the visualize() method on a pint.Quantity wrapped Dask array.

    Rendering needs the optional ``graphviz`` Python package and its ``dot``
    executable, so the test is skipped (rather than silently passing) when
    either is unavailable. The image goes to pytest's ``tmp_path`` so the
    working directory is not polluted.
    """
    import shutil

    pytest.importorskip("graphviz", reason="GraphViz is not available")
    if shutil.which("dot") is None:
        pytest.skip("GraphViz 'dot' executable is not available")

    q = ureg.Quantity(dask_array, units_)
    comps = add_five(q)
    outfile = tmp_path / "mydask.png"
    comps.visualize(filename=str(outfile))
    assert outfile.exists()
    # Visualizing must not replace the wrapped Dask array.
    assert q.magnitude is dask_array
def test_compute_exception(numpy_array):
    """compute() on a Quantity that does not wrap a Dask array raises AttributeError.

    The error message must name both the required and the actual magnitude type.
    """
    exctruth = "Method compute only implemented for objects of <class 'dask.array.core.Array'>, not <class 'numpy.ndarray'>"
    shifted = add_five(ureg.Quantity(numpy_array, units_))
    with pytest.raises(AttributeError) as excinfo:
        shifted.compute()
    assert str(excinfo.value) == exctruth
def test_persist_exception(numpy_array):
    """persist() on a Quantity that does not wrap a Dask array raises AttributeError.

    The error message must name both the required and the actual magnitude type.
    """
    exctruth = "Method persist only implemented for objects of <class 'dask.array.core.Array'>, not <class 'numpy.ndarray'>"
    shifted = add_five(ureg.Quantity(numpy_array, units_))
    with pytest.raises(AttributeError) as excinfo:
        shifted.persist()
    assert str(excinfo.value) == exctruth
def test_visualize_exception(numpy_array):
    """visualize() on a Quantity that does not wrap a Dask array raises AttributeError.

    The error message must name both the required and the actual magnitude type.
    """
    exctruth = "Method visualize only implemented for objects of <class 'dask.array.core.Array'>, not <class 'numpy.ndarray'>"
    shifted = add_five(ureg.Quantity(numpy_array, units_))
    with pytest.raises(AttributeError) as excinfo:
        shifted.visualize()
    assert str(excinfo.value) == exctruth
|