summaryrefslogtreecommitdiff
path: root/src/go/types/call.go
diff options
context:
space:
mode:
authorRobert Griesemer <gri@golang.org>2023-03-13 16:38:14 -0700
committerGopher Robot <gobot@golang.org>2023-03-29 20:53:08 +0000
commitcc048b32f3de4168de6b0207fd01c65e51d37ac0 (patch)
treea8dc26316b4550692afb9acb97c88485e5c85451 /src/go/types/call.go
parent93b3035dbbcd21c1d0538142cba4e7f79631e7a2 (diff)
downloadgo-git-cc048b32f3de4168de6b0207fd01c65e51d37ac0.tar.gz
go/types, types2: reverse inference of function type arguments
This CL implements type inference for generic functions used in assignments: variable init expressions, regular assignments, and return statements, but (not yet) function arguments passed to functions. For instance, given a generic function func f[P any](x P) and a variable of function type var v func(x int) the assignment v = f is valid w/o explicit instantiation of f, and the missing type argument for f is inferred from the type of v. More generally, the function f may have multiple type arguments, and it may be partially instantiated. This new form of inference is not enabled by default (it needs to go through the proposal process first). It can be enabled by setting Config.EnableReverseTypeInference. The mechanism is implemented as follows: - The various expression evaluation functions take an additional (first) argument T, which is the target type for the expression. If not nil, it is the type of the LHS in an assignment. - The method Checker.funcInst is changed such that it uses both, provided type arguments (if any), and a target type (if any) to augment type inference. Change-Id: Idfde61078e1ee4f22abcca894a4c84d681734ff6 Reviewed-on: https://go-review.googlesource.com/c/go/+/476075 TryBot-Result: Gopher Robot <gobot@golang.org> Auto-Submit: Robert Griesemer <gri@google.com> Reviewed-by: Robert Findley <rfindley@google.com> Reviewed-by: Robert Griesemer <gri@google.com> Run-TryBot: Robert Griesemer <gri@google.com>
Diffstat (limited to 'src/go/types/call.go')
-rw-r--r--src/go/types/call.go119
1 files changed, 99 insertions, 20 deletions
diff --git a/src/go/types/call.go b/src/go/types/call.go
index e5968c7cfc..f75043d5dc 100644
--- a/src/go/types/call.go
+++ b/src/go/types/call.go
@@ -7,6 +7,7 @@
package types
import (
+ "fmt"
"go/ast"
"go/internal/typeparams"
"go/token"
@@ -15,25 +16,48 @@ import (
"unicode"
)
-// funcInst type-checks a function instantiation inst and returns the result in x.
-// The operand x must be the evaluation of inst.X and its type must be a signature.
-func (check *Checker) funcInst(x *operand, ix *typeparams.IndexExpr) {
+// funcInst type-checks a function instantiation and returns the result in x.
+// The incoming x must be an uninstantiated generic function. If ix != 0,
+// it provides (some or all of) the type arguments (ix.Indices) for the
+// instantiation. If the target type T != nil and is a (non-generic) function
+// signature, the signature's parameter types are used to infer additional
+// missing type arguments of x, if any.
+// At least one of inst or T must be provided.
+func (check *Checker) funcInst(T Type, pos token.Pos, x *operand, ix *typeparams.IndexExpr) {
if !check.allowVersion(check.pkg, 1, 18) {
check.softErrorf(inNode(ix.Orig, ix.Lbrack), UnsupportedFeature, "function instantiation requires go1.18 or later")
}
- targs := check.typeList(ix.Indices)
- if targs == nil {
- x.mode = invalid
- x.expr = ix.Orig
- return
+ // tsig is the (assignment) target function signature, or nil.
+ // TODO(gri) refactor and pass in tsig to funcInst instead
+ var tsig *Signature
+ if check.conf._EnableReverseTypeInference && T != nil {
+ tsig, _ = under(T).(*Signature)
+ }
+
+ // targs and xlist are the type arguments and corresponding type expressions, or nil.
+ var targs []Type
+ var xlist []ast.Expr
+ if ix != nil {
+ xlist = ix.Indices
+ targs = check.typeList(xlist)
+ if targs == nil {
+ x.mode = invalid
+ x.expr = ix
+ return
+ }
+ assert(len(targs) == len(xlist))
}
- assert(len(targs) == len(ix.Indices))
- // check number of type arguments (got) vs number of type parameters (want)
+ assert(tsig != nil || targs != nil)
+
+ // Check the number of type arguments (got) vs number of type parameters (want).
+ // Note that x is a function value, not a type expression, so we don't need to
+ // call under below.
sig := x.typ.(*Signature)
got, want := len(targs), sig.TypeParams().Len()
if got > want {
+ // Providing too many type arguments is always an error.
check.errorf(ix.Indices[got-1], WrongTypeArgCount, "got %d type arguments but want %d", got, want)
x.mode = invalid
x.expr = ix.Orig
@@ -41,11 +65,43 @@ func (check *Checker) funcInst(x *operand, ix *typeparams.IndexExpr) {
}
if got < want {
- targs = check.infer(ix.Orig, sig.TypeParams().list(), targs, nil, nil)
+ // If the uninstantiated or partially instantiated function x is used in an
+ // assignment (tsig != nil), use the respective function parameter and result
+ // types to infer additional type arguments.
+ var args []*operand
+ var params []*Var
+ if tsig != nil && sig.tparams != nil && tsig.params.Len() == sig.params.Len() && tsig.results.Len() == sig.results.Len() {
+ // x is a generic function and the signature arity matches the target function.
+ // To infer x's missing type arguments, treat the function assignment as a call
+ // of a synthetic function f where f's parameters are the parameters and results
+ // of x and where the arguments to the call of f are values of the parameter and
+ // result types of x.
+ n := tsig.params.Len()
+ m := tsig.results.Len()
+ args = make([]*operand, n+m)
+ params = make([]*Var, n+m)
+ for i := 0; i < n; i++ {
+ lvar := tsig.params.At(i)
+ lname := ast.NewIdent(paramName(lvar.name, i, "parameter"))
+ lname.NamePos = x.Pos() // correct position
+ args[i] = &operand{mode: value, expr: lname, typ: lvar.typ}
+ params[i] = sig.params.At(i)
+ }
+ for i := 0; i < m; i++ {
+ lvar := tsig.results.At(i)
+ lname := ast.NewIdent(paramName(lvar.name, i, "result parameter"))
+ lname.NamePos = x.Pos() // correct position
+ args[n+i] = &operand{mode: value, expr: lname, typ: lvar.typ}
+ params[n+i] = sig.results.At(i)
+ }
+ }
+
+ // Note that NewTuple(params...) below is nil if len(params) == 0, as desired.
+ targs = check.infer(atPos(pos), sig.TypeParams().list(), targs, NewTuple(params...), args)
if targs == nil {
// error was already reported
x.mode = invalid
- x.expr = ix.Orig
+ x.expr = ix // TODO(gri) is this correct?
return
}
got = len(targs)
@@ -53,12 +109,35 @@ func (check *Checker) funcInst(x *operand, ix *typeparams.IndexExpr) {
assert(got == want)
// instantiate function signature
- sig = check.instantiateSignature(x.Pos(), sig, targs, ix.Indices)
+ sig = check.instantiateSignature(x.Pos(), sig, targs, xlist)
assert(sig.TypeParams().Len() == 0) // signature is not generic anymore
- check.recordInstance(ix.Orig, targs, sig)
+
x.typ = sig
x.mode = value
- x.expr = ix.Orig
+ // If we don't have an index expression, keep the existing expression of x.
+ if ix != nil {
+ x.expr = ix.Orig
+ }
+ check.recordInstance(x.expr, targs, sig)
+}
+
+func paramName(name string, i int, kind string) string {
+ if name != "" {
+ return name
+ }
+ return nth(i+1) + " " + kind
+}
+
+func nth(n int) string {
+ switch n {
+ case 1:
+ return "1st"
+ case 2:
+ return "2nd"
+ case 3:
+ return "3rd"
+ }
+ return fmt.Sprintf("%dth", n)
}
func (check *Checker) instantiateSignature(pos token.Pos, typ *Signature, targs []Type, xlist []ast.Expr) (res *Signature) {
@@ -121,7 +200,7 @@ func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind {
case typexpr:
// conversion
- check.nonGeneric(x)
+ check.nonGeneric(nil, x)
if x.mode == invalid {
return conversion
}
@@ -131,7 +210,7 @@ func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind {
case 0:
check.errorf(inNode(call, call.Rparen), WrongArgCount, "missing argument in conversion to %s", T)
case 1:
- check.expr(x, call.Args[0])
+ check.expr(nil, x, call.Args[0])
if x.mode != invalid {
if call.Ellipsis.IsValid() {
check.errorf(call.Args[0], BadDotDotDotSyntax, "invalid use of ... in conversion to %s", T)
@@ -274,7 +353,7 @@ func (check *Checker) exprList(elist []ast.Expr) (xlist []*operand) {
xlist = make([]*operand, len(elist))
for i, e := range elist {
var x operand
- check.expr(&x, e)
+ check.expr(nil, &x, e)
xlist[i] = &x
}
}
@@ -791,12 +870,12 @@ func (check *Checker) use1(e ast.Expr, lhs bool) bool {
}
}
}
- check.rawExpr(&x, n, nil, true)
+ check.rawExpr(nil, &x, n, nil, true)
if v != nil {
v.used = v_used // restore v.used
}
default:
- check.rawExpr(&x, e, nil, true)
+ check.rawExpr(nil, &x, e, nil, true)
}
return x.mode != invalid
}