summaryrefslogtreecommitdiff
path: root/src/go
diff options
context:
space:
mode:
authorRobert Griesemer <gri@golang.org>2023-04-11 18:02:26 -0700
committerGopher Robot <gobot@golang.org>2023-05-03 18:47:48 +0000
commitea9097c9f75cf7149ebbcc3edc1161122bb15e5a (patch)
tree9374f08c69decf60db9a50767f3821fa11b265c5 /src/go
parent57cb47209c655bdd7fb6d7effd5375e9b0fe90cf (diff)
downloadgo-git-ea9097c9f75cf7149ebbcc3edc1161122bb15e5a.tar.gz
go/types, types2: implement reverse type inference for function arguments
Allow function-typed function arguments to be generic and collect their type parameters together with the callee's type parameters (if any). Use a single inference step to infer the type arguments for all type parameters simultaneously. Requires Go 1.21 and that Config.EnableReverseTypeInference is set. Does not yet support partially instantiated generic function arguments. Not yet enabled in the compiler. Known bug: inference may produce an incorrect result is the same generic function is passed twice in the same function call. For #59338. Change-Id: Ia1faa27a28c6353f0bbfd7f81feafc21bd36652c Reviewed-on: https://go-review.googlesource.com/c/go/+/483935 Auto-Submit: Robert Griesemer <gri@google.com> Reviewed-by: Robert Findley <rfindley@google.com> TryBot-Result: Gopher Robot <gobot@golang.org> Reviewed-by: Robert Griesemer <gri@google.com> Run-TryBot: Robert Griesemer <gri@google.com>
Diffstat (limited to 'src/go')
-rw-r--r--src/go/types/api.go2
-rw-r--r--src/go/types/api_test.go43
-rw-r--r--src/go/types/call.go125
-rw-r--r--src/go/types/check_test.go5
-rw-r--r--src/go/types/expr.go21
-rw-r--r--src/go/types/infer.go23
6 files changed, 178 insertions, 41 deletions
diff --git a/src/go/types/api.go b/src/go/types/api.go
index 7af84fd244..e202d6dea8 100644
--- a/src/go/types/api.go
+++ b/src/go/types/api.go
@@ -175,7 +175,7 @@ type Config struct {
// partially instantiated generic functions may be assigned
// (incl. returned) to variables of function type and type
// inference will attempt to infer the missing type arguments.
- // Experimental. Needs a proposal.
+ // See proposal go.dev/issue/59338.
_EnableReverseTypeInference bool
}
diff --git a/src/go/types/api_test.go b/src/go/types/api_test.go
index ae1a7e50a7..02e26c3f02 100644
--- a/src/go/types/api_test.go
+++ b/src/go/types/api_test.go
@@ -550,18 +550,57 @@ type T[P any] []P
{`package issue51803; func foo[T any](T) {}; func _() { foo[int]( /* leave arg away on purpose */ ) }`,
[]testInst{{`foo`, []string{`int`}, `func(int)`}},
},
+
+ // reverse type parameter inference
+ {`package reverse1a; var f func(int) = g; func g[P any](P) {}`,
+ []testInst{{`g`, []string{`int`}, `func(int)`}},
+ },
+ {`package reverse1b; func f(func(int)) {}; func g[P any](P) {}; func _() { f(g) }`,
+ []testInst{{`g`, []string{`int`}, `func(int)`}},
+ },
+ {`package reverse2a; var f func(int) string = g; func g[P, Q any](P) Q { var q Q; return q }`,
+ []testInst{{`g`, []string{`int`, `string`}, `func(int) string`}},
+ },
+ {`package reverse2b; func f(func(int) string) {}; func g[P, Q any](P) Q { var q Q; return q }; func _() { f(g) }`,
+ []testInst{{`g`, []string{`int`, `string`}, `func(int) string`}},
+ },
+ // reverse3a not possible (cannot assign to generic function outside of argument passing)
+ {`package reverse3b; func f[R any](func(int) R) {}; func g[P any](P) string { return "" }; func _() { f(g) }`,
+ []testInst{
+ {`f`, []string{`string`}, `func(func(int) string)`},
+ {`g`, []string{`int`}, `func(int) string`},
+ },
+ },
+ {`package reverse4a; var _, _ func([]int, *float32) = g, h; func g[P, Q any]([]P, *Q) {}; func h[R any]([]R, *float32) {}`,
+ []testInst{
+ {`g`, []string{`int`, `float32`}, `func([]int, *float32)`},
+ {`h`, []string{`int`}, `func([]int, *float32)`},
+ },
+ },
+ {`package reverse4b; func f(_, _ func([]int, *float32)) {}; func g[P, Q any]([]P, *Q) {}; func h[R any]([]R, *float32) {}; func _() { f(g, h) }`,
+ []testInst{
+ {`g`, []string{`int`, `float32`}, `func([]int, *float32)`},
+ {`h`, []string{`int`}, `func([]int, *float32)`},
+ },
+ },
}
for _, test := range tests {
imports := make(testImporter)
conf := Config{
Importer: imports,
- Error: func(error) {}, // ignore errors
+ // Unexported field: set below with boolFieldAddr
+ // _EnableReverseTypeInference: true,
}
+ *boolFieldAddr(&conf, "_EnableReverseTypeInference") = true
instMap := make(map[*ast.Ident]Instance)
useMap := make(map[*ast.Ident]Object)
makePkg := func(src string) *Package {
- pkg, _ := typecheck(src, &conf, &Info{Instances: instMap, Uses: useMap})
+ pkg, err := typecheck(src, &conf, &Info{Instances: instMap, Uses: useMap})
+ // allow error for issue51803
+ if err != nil && (pkg == nil || pkg.Name() != "issue51803") {
+ t.Fatal(err)
+ }
imports[pkg.Name()] = pkg
return pkg
}
diff --git a/src/go/types/call.go b/src/go/types/call.go
index 979de2338f..02b6038ccc 100644
--- a/src/go/types/call.go
+++ b/src/go/types/call.go
@@ -25,14 +25,16 @@ import (
func (check *Checker) funcInst(tsig *Signature, pos token.Pos, x *operand, ix *typeparams.IndexExpr) {
assert(tsig != nil || ix != nil)
+ var versionErr bool // set if version error was reported
+ var instErrPos positioner // position for instantion error
+ if ix != nil {
+ instErrPos = inNode(ix.Orig, ix.Lbrack)
+ } else {
+ instErrPos = atPos(pos)
+ }
if !check.allowVersion(check.pkg, pos, 1, 18) {
- var posn positioner
- if ix != nil {
- posn = inNode(ix.Orig, ix.Lbrack)
- } else {
- posn = atPos(pos)
- }
- check.softErrorf(posn, UnsupportedFeature, "function instantiation requires go1.18 or later")
+ check.softErrorf(instErrPos, UnsupportedFeature, "function instantiation requires go1.18 or later")
+ versionErr = true
}
// targs and xlist are the type arguments and corresponding type expressions, or nil.
@@ -74,6 +76,13 @@ func (check *Checker) funcInst(tsig *Signature, pos token.Pos, x *operand, ix *t
// of a synthetic function f where f's parameters are the parameters and results
// of x and where the arguments to the call of f are values of the parameter and
// result types of x.
+ if !versionErr && !check.allowVersion(check.pkg, pos, 1, 21) {
+ if ix != nil {
+ check.softErrorf(instErrPos, UnsupportedFeature, "partially instantiated function in assignment requires go1.21 or later")
+ } else {
+ check.softErrorf(instErrPos, UnsupportedFeature, "implicitly instantiated function in assignment requires go1.21 or later")
+ }
+ }
n := tsig.params.Len()
m := tsig.results.Len()
args = make([]*operand, n+m)
@@ -308,7 +317,7 @@ func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind {
}
// evaluate arguments
- args := check.exprList(call.Args)
+ args := check.genericExprList(call.Args)
sig = check.arguments(call, sig, targs, args, xlist)
if wasGeneric && sig.TypeParams().Len() == 0 {
@@ -343,6 +352,8 @@ func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind {
return statement
}
+// exprList evaluates a list of expressions and returns the corresponding operands.
+// A single-element expression list may evaluate to multiple operands.
func (check *Checker) exprList(elist []ast.Expr) (xlist []*operand) {
switch len(elist) {
case 0:
@@ -361,6 +372,25 @@ func (check *Checker) exprList(elist []ast.Expr) (xlist []*operand) {
return
}
+// genericExprList is like exprList but result operands may be generic (not fully instantiated).
+func (check *Checker) genericExprList(elist []ast.Expr) (xlist []*operand) {
+ switch len(elist) {
+ case 0:
+ // nothing to do
+ case 1:
+ xlist = check.genericMultiExpr(elist[0])
+ default:
+ // multiple (possibly invalid) values
+ xlist = make([]*operand, len(elist))
+ for i, e := range elist {
+ var x operand
+ check.genericExpr(&x, e)
+ xlist[i] = &x
+ }
+ }
+ return
+}
+
// xlist is the list of type argument expressions supplied in the source code.
func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type, args []*operand, xlist []ast.Expr) (rsig *Signature) {
rsig = sig
@@ -391,7 +421,7 @@ func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type
// set up parameters
sigParams := sig.params // adjusted for variadic functions (may be nil for empty parameter lists!)
- adjusted := false // indicates if sigParams is different from t.params
+ adjusted := false // indicates if sigParams is different from sig.params
if sig.variadic {
if ddd {
// variadic_func(a, b, c...)
@@ -452,8 +482,12 @@ func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type
return
}
- // infer type arguments and instantiate signature if necessary
- if sig.TypeParams().Len() > 0 {
+ // collect type parameters of callee and generic function arguments
+ var tparams []*TypeParam
+
+ // collect type parameters of callee
+ n := sig.TypeParams().Len()
+ if n > 0 {
if !check.allowVersion(check.pkg, call.Pos(), 1, 18) {
switch call.Fun.(type) {
case *ast.IndexExpr, *ast.IndexListExpr:
@@ -463,29 +497,72 @@ func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type
check.softErrorf(inNode(call, call.Lparen), UnsupportedFeature, "implicit function instantiation requires go1.18 or later")
}
}
-
- // Rename type parameters to avoid problems with recursive calls.
- var tparams []*TypeParam
+ // rename type parameters to avoid problems with recursive calls
tparams, sigParams = check.renameTParams(call.Pos(), sig.TypeParams().list(), sigParams)
+ }
- targs := check.infer(call, tparams, targs, sigParams, args)
+ // collect type parameters from generic function arguments
+ var genericArgs []int // indices of generic function arguments
+ if check.conf._EnableReverseTypeInference {
+ for i, arg := range args {
+ // generic arguments cannot have a defined (*Named) type - no need for underlying type below
+ if asig, _ := arg.typ.(*Signature); asig != nil && asig.TypeParams().Len() > 0 {
+ // TODO(gri) need to also rename type parameters for cases like f(g, g)
+ tparams = append(tparams, asig.TypeParams().list()...)
+ genericArgs = append(genericArgs, i)
+ }
+ }
+ }
+ if len(genericArgs) > 0 && !check.allowVersion(check.pkg, call.Pos(), 1, 21) {
+ // at the moment we only support implicit instantiations of argument functions
+ check.softErrorf(inNode(call, call.Lparen), UnsupportedFeature, "implicitly instantiated function as argument requires go1.21 or later")
+ }
+
+ // tparams holds the type parameters of the callee and generic function arguments, if any:
+ // the first n type parameters belong to the callee, followed by mi type parameters for each
+ // of the generic function arguments, where mi = args[i].typ.(*Signature).TypeParams().Len().
+
+ // infer missing type arguments of callee and function arguments
+ if len(tparams) > 0 {
+ targs = check.infer(call, tparams, targs, sigParams, args)
if targs == nil {
+ // TODO(gri) If infer inferred the first targs[:n], consider instantiating
+ // the call signature for better error messages/gopls behavior.
+ // Perhaps instantiate as much as we can, also for arguments.
+ // This will require changes to how infer returns its results.
return // error already reported
}
- // compute result signature
- rsig = check.instantiateSignature(call.Pos(), sig, targs, xlist)
- assert(rsig.TypeParams().Len() == 0) // signature is not generic anymore
- check.recordInstance(call.Fun, targs, rsig)
+ // compute result signature: instantiate if needed
+ rsig = sig
+ if n > 0 {
+ rsig = check.instantiateSignature(call.Pos(), sig, targs[:n], xlist)
+ assert(rsig.TypeParams().Len() == 0) // signature is not generic anymore
+ check.recordInstance(call.Fun, targs[:n], rsig)
+ }
- // Optimization: Only if the parameter list was adjusted do we
- // need to compute it from the adjusted list; otherwise we can
- // simply use the result signature's parameter list.
- if adjusted {
- sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(tparams, targs), nil, check.context()).(*Tuple)
+ // Optimization: Only if the callee's parameter list was adjusted do we need to
+ // compute it from the adjusted list; otherwise we can simply use the result
+ // signature's parameter list. We only need the n type parameters and arguments
+ // of the callee.
+ if n > 0 && adjusted {
+ sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(tparams[:n], targs[:n]), nil, check.context()).(*Tuple)
} else {
sigParams = rsig.params
}
+
+ // compute argument signatures: instantiate if needed
+ j := n
+ for _, i := range genericArgs {
+ asig := args[i].typ.(*Signature)
+ k := j + asig.TypeParams().Len()
+ // targs[j:k] are the inferred type arguments for asig
+ asig = check.instantiateSignature(call.Pos(), asig, targs[j:k], nil) // TODO(gri) provide xlist if possible (partial instantiations)
+ assert(asig.TypeParams().Len() == 0) // signature is not generic anymore
+ args[i].typ = asig
+ check.recordInstance(args[i].expr, targs[j:k], asig)
+ j = k
+ }
}
// check arguments
diff --git a/src/go/types/check_test.go b/src/go/types/check_test.go
index 0f4c320a47..cda052f4d3 100644
--- a/src/go/types/check_test.go
+++ b/src/go/types/check_test.go
@@ -39,6 +39,7 @@ import (
"go/scanner"
"go/token"
"internal/testenv"
+ "internal/types/errors"
"os"
"path/filepath"
"reflect"
@@ -295,9 +296,9 @@ func testFiles(t *testing.T, sizes Sizes, filenames []string, srcs [][]byte, man
}
}
-func readCode(err Error) int {
+func readCode(err Error) errors.Code {
v := reflect.ValueOf(err)
- return int(v.FieldByName("go116code").Int())
+ return errors.Code(v.FieldByName("go116code").Int())
}
// boolFieldAddr(conf, name) returns the address of the boolean field conf.<name>.
diff --git a/src/go/types/expr.go b/src/go/types/expr.go
index 0db80ca44b..891153ba8d 100644
--- a/src/go/types/expr.go
+++ b/src/go/types/expr.go
@@ -1829,6 +1829,27 @@ func (check *Checker) multiExpr(e ast.Expr, allowCommaOk bool) (list []*operand,
return
}
+// genericMultiExpr is like multiExpr but a one-element result may also be generic
+// and potential comma-ok expressions are returned as single values.
+func (check *Checker) genericMultiExpr(e ast.Expr) (list []*operand) {
+ var x operand
+ check.rawExpr(nil, &x, e, nil, true)
+ check.exclude(&x, 1<<novalue|1<<builtin|1<<typexpr)
+
+ if t, ok := x.typ.(*Tuple); ok && x.mode != invalid {
+ // multiple values - cannot be generic
+ list = make([]*operand, t.Len())
+ for i, v := range t.vars {
+ list[i] = &operand{mode: value, expr: e, typ: v.typ}
+ }
+ return
+ }
+
+ // exactly one (possible invalid or generic) value
+ list = []*operand{&x}
+ return
+}
+
// exprWithHint typechecks expression e and initializes x with the expression value;
// hint is the type of a composite literal element.
// If an error occurred, x.mode is set to invalid.
diff --git a/src/go/types/infer.go b/src/go/types/infer.go
index 3aa66105c4..39bf4a14f7 100644
--- a/src/go/types/infer.go
+++ b/src/go/types/infer.go
@@ -142,17 +142,17 @@ func (check *Checker) infer(posn positioner, tparams []*TypeParam, targs []Type,
}
for i, arg := range args {
+ if arg.mode == invalid {
+ // An error was reported earlier. Ignore this arg
+ // and continue, we may still be able to infer all
+ // targs resulting in fewer follow-on errors.
+ // TODO(gri) determine if we still need this check
+ continue
+ }
par := params.At(i)
- // If we permit bidirectional unification, this conditional code needs to be
- // executed even if par.typ is not parameterized since the argument may be a
- // generic function (for which we want to infer its type arguments).
- if isParameterized(tparams, par.typ) {
- if arg.mode == invalid {
- // An error was reported earlier. Ignore this targ
- // and continue, we may still be able to infer all
- // targs resulting in fewer follow-on errors.
- continue
- }
+ if isParameterized(tparams, par.typ) || isParameterized(tparams, arg.typ) {
+ // Function parameters are always typed. Arguments may be untyped.
+ // Collect the indices of untyped arguments and handle them later.
if isTyped(arg.typ) {
if !u.unify(par.typ, arg.typ) {
errorf("type", par.typ, arg.typ, arg)
@@ -265,7 +265,7 @@ func (check *Checker) infer(posn positioner, tparams []*TypeParam, targs []Type,
}
// --- 3 ---
- // use information from untyped contants
+ // use information from untyped constants
if traceInference {
u.tracef("== untyped arguments: %v", untyped)
@@ -543,7 +543,6 @@ func (w *tpWalker) isParameterized(typ Type) (res bool) {
}
case *TypeParam:
- // t must be one of w.tparams
return tparamIndex(w.tparams, t) >= 0
default: