From 3fc052fc6e64c6306974d4606551cc9b4711039f Mon Sep 17 00:00:00 2001 From: "Michal W. Tarnowski" Date: Wed, 16 Dec 2020 09:39:31 +0100 Subject: DOC: Fix and extend the docstring for np.inner (#18002) * DOC: fix the docstring for np.inner * DOC: extend the docstring for np.inner and add an example * DOC: update numpy/core/multiarray.py Co-authored-by: Eric Wieser * DOC: apply suggestions from code review Co-authored-by: Matti Picus Co-authored-by: Eric Wieser Co-authored-by: Matti Picus --- numpy/core/multiarray.py | 27 +++++++++++++++++++++------ 1 file changed, 21 insertions(+), 6 deletions(-) (limited to 'numpy/core/multiarray.py') diff --git a/numpy/core/multiarray.py b/numpy/core/multiarray.py index f736973de..07179a627 100644 --- a/numpy/core/multiarray.py +++ b/numpy/core/multiarray.py @@ -259,12 +259,16 @@ def inner(a, b): Returns ------- out : ndarray - `out.shape = a.shape[:-1] + b.shape[:-1]` + If `a` and `b` are both + scalars or both 1-D arrays then a scalar is returned; otherwise + an array is returned. + ``out.shape = (*a.shape[:-1], *b.shape[:-1])`` Raises ------ ValueError - If the last dimension of `a` and `b` has different size. + If both `a` and `b` are nonscalar and their last dimensions have + different sizes. See Also -------- @@ -284,8 +288,8 @@ def inner(a, b): or explicitly:: - np.inner(a, b)[i0,...,ir-1,j0,...,js-1] - = sum(a[i0,...,ir-1,:]*b[j0,...,js-1,:]) + np.inner(a, b)[i0,...,ir-2,j0,...,js-2] + = sum(a[i0,...,ir-2,:]*b[j0,...,js-2,:]) In addition `a` or `b` may be scalars, in which case:: @@ -300,14 +304,25 @@ def inner(a, b): >>> np.inner(a, b) 2 - A multidimensional example: + Some multidimensional examples: >>> a = np.arange(24).reshape((2,3,4)) >>> b = np.arange(4) - >>> np.inner(a, b) + >>> c = np.inner(a, b) + >>> c.shape + (2, 3) + >>> c array([[ 14, 38, 62], [ 86, 110, 134]]) + >>> a = np.arange(2).reshape((1,1,2)) + >>> b = np.arange(6).reshape((3,2)) + >>> c = np.inner(a, b) + >>> c.shape + (1, 1, 3) + >>> c + array([[[1, 3, 5]]]) + An example where `b` is a scalar: >>> np.inner(np.eye(2), 7) -- cgit v1.2.1