diff options
| author | Jules Chéron <jules.cheron@gmail.com> | 2021-02-16 17:06:03 +0100 |
|---|---|---|
| committer | Jules Chéron <jules.cheron@gmail.com> | 2021-02-16 17:14:27 +0100 |
| commit | 146f4f16a6268860e0f27c1e129df0ac341eebb4 (patch) | |
| tree | 10b67988ccf4e33f5116e98d54a68bb77a6e1cbb /pint/numpy_func.py | |
| parent | 397022713184adb0acd9eedf15d1df236db39a7a (diff) | |
| download | pint-146f4f16a6268860e0f27c1e129df0ac341eebb4.tar.gz | |
Fix numpy.linalg.solve units output.
Update get_op_output_unit with new type invdiv.
It outputs the product of the following units over the first one in the args list.
Update tests with values & np.dot(A, x) == b.
Where x = np.linalg.solve(A, b)
Closes #1246
Diffstat (limited to 'pint/numpy_func.py')
| -rw-r--r-- | pint/numpy_func.py | 14 |
1 files changed, 12 insertions, 2 deletions
diff --git a/pint/numpy_func.py b/pint/numpy_func.py index 1dce044..c335f3d 100644 --- a/pint/numpy_func.py +++ b/pint/numpy_func.py @@ -148,6 +148,7 @@ def get_op_output_unit(unit_op, first_input_units, all_args=None, size=None): - "sqrt": square root of `first_input_units` - "reciprocal": reciprocal of `first_input_units` - "size": `first_input_units` raised to the power of `size` + - "invdiv": inverse of `div`, product of all following units divided by first argument unit Parameters ---------- @@ -205,7 +206,15 @@ def get_op_output_unit(unit_op, first_input_units, all_args=None, size=None): if size is None: raise ValueError('size argument must be given when unit_op=="size"') result_unit = first_input_units ** size - + elif unit_op == "invdiv": + # Start with first arg in numerator, all others in denominator + product = getattr( + all_args[0], "units", first_input_units._REGISTRY.parse_units("") + ) + for x in all_args[1:]: + if hasattr(x, "units"): + product /= x.units + result_unit = product ** -1 else: raise ValueError("Output unit method {} not understood".format(unit_op)) @@ -304,6 +313,7 @@ def implement_func(func_type, func_str, input_units=None, output_unit=None): "delta", "delta,div", "div", + "invdiv", "variance", "square", "sqrt", @@ -878,7 +888,7 @@ for func_str in ["diff", "ediff1d"]: for func_str in ["gradient"]: implement_func("function", func_str, input_units=None, output_unit="delta,div") for func_str in ["linalg.solve"]: - implement_func("function", func_str, input_units=None, output_unit="div") + implement_func("function", func_str, input_units=None, output_unit="invdiv") for func_str in ["var", "nanvar"]: implement_func("function", func_str, input_units=None, output_unit="variance") |
