summaryrefslogtreecommitdiff
path: root/pint/numpy_func.py
diff options
context:
space:
mode:
authorJules Chéron <jules.cheron@gmail.com>2021-02-16 17:06:03 +0100
committerJules Chéron <jules.cheron@gmail.com>2021-02-16 17:14:27 +0100
commit146f4f16a6268860e0f27c1e129df0ac341eebb4 (patch)
tree10b67988ccf4e33f5116e98d54a68bb77a6e1cbb /pint/numpy_func.py
parent397022713184adb0acd9eedf15d1df236db39a7a (diff)
downloadpint-146f4f16a6268860e0f27c1e129df0ac341eebb4.tar.gz
Fix numpy.linalg.solve units output.
Update get_op_output_unit with new type invdiv. It outputs the product of the following units over the first one in the args list. Update tests with values & np.dot(A, x) == b. Where x = np.linalg.solve(A, b) Closes #1246
Diffstat (limited to 'pint/numpy_func.py')
-rw-r--r--pint/numpy_func.py14
1 files changed, 12 insertions, 2 deletions
diff --git a/pint/numpy_func.py b/pint/numpy_func.py
index 1dce044..c335f3d 100644
--- a/pint/numpy_func.py
+++ b/pint/numpy_func.py
@@ -148,6 +148,7 @@ def get_op_output_unit(unit_op, first_input_units, all_args=None, size=None):
- "sqrt": square root of `first_input_units`
- "reciprocal": reciprocal of `first_input_units`
- "size": `first_input_units` raised to the power of `size`
+ - "invdiv": inverse of `div`, product of all following units divided by first argument unit
Parameters
----------
@@ -205,7 +206,15 @@ def get_op_output_unit(unit_op, first_input_units, all_args=None, size=None):
if size is None:
raise ValueError('size argument must be given when unit_op=="size"')
result_unit = first_input_units ** size
-
+ elif unit_op == "invdiv":
+ # Start with first arg in numerator, all others in denominator
+ product = getattr(
+ all_args[0], "units", first_input_units._REGISTRY.parse_units("")
+ )
+ for x in all_args[1:]:
+ if hasattr(x, "units"):
+ product /= x.units
+ result_unit = product ** -1
else:
raise ValueError("Output unit method {} not understood".format(unit_op))
@@ -304,6 +313,7 @@ def implement_func(func_type, func_str, input_units=None, output_unit=None):
"delta",
"delta,div",
"div",
+ "invdiv",
"variance",
"square",
"sqrt",
@@ -878,7 +888,7 @@ for func_str in ["diff", "ediff1d"]:
for func_str in ["gradient"]:
implement_func("function", func_str, input_units=None, output_unit="delta,div")
for func_str in ["linalg.solve"]:
- implement_func("function", func_str, input_units=None, output_unit="div")
+ implement_func("function", func_str, input_units=None, output_unit="invdiv")
for func_str in ["var", "nanvar"]:
implement_func("function", func_str, input_units=None, output_unit="variance")