From ea9097c9f75cf7149ebbcc3edc1161122bb15e5a Mon Sep 17 00:00:00 2001 From: Robert Griesemer Date: Tue, 11 Apr 2023 18:02:26 -0700 Subject: go/types, types2: implement reverse type inference for function arguments Allow function-typed function arguments to be generic and collect their type parameters together with the callee's type parameters (if any). Use a single inference step to infer the type arguments for all type parameters simultaneously. Requires Go 1.21 and that Config.EnableReverseTypeInference is set. Does not yet support partially instantiated generic function arguments. Not yet enabled in the compiler. Known bug: inference may produce an incorrect result is the same generic function is passed twice in the same function call. For #59338. Change-Id: Ia1faa27a28c6353f0bbfd7f81feafc21bd36652c Reviewed-on: https://go-review.googlesource.com/c/go/+/483935 Auto-Submit: Robert Griesemer Reviewed-by: Robert Findley TryBot-Result: Gopher Robot Reviewed-by: Robert Griesemer Run-TryBot: Robert Griesemer --- src/go/types/api.go | 2 +- src/go/types/api_test.go | 43 +++++++++++++++- src/go/types/call.go | 125 ++++++++++++++++++++++++++++++++++++--------- src/go/types/check_test.go | 5 +- src/go/types/expr.go | 21 ++++++++ src/go/types/infer.go | 23 ++++----- 6 files changed, 178 insertions(+), 41 deletions(-) (limited to 'src/go') diff --git a/src/go/types/api.go b/src/go/types/api.go index 7af84fd244..e202d6dea8 100644 --- a/src/go/types/api.go +++ b/src/go/types/api.go @@ -175,7 +175,7 @@ type Config struct { // partially instantiated generic functions may be assigned // (incl. returned) to variables of function type and type // inference will attempt to infer the missing type arguments. - // Experimental. Needs a proposal. + // See proposal go.dev/issue/59338. _EnableReverseTypeInference bool } diff --git a/src/go/types/api_test.go b/src/go/types/api_test.go index ae1a7e50a7..02e26c3f02 100644 --- a/src/go/types/api_test.go +++ b/src/go/types/api_test.go @@ -550,18 +550,57 @@ type T[P any] []P {`package issue51803; func foo[T any](T) {}; func _() { foo[int]( /* leave arg away on purpose */ ) }`, []testInst{{`foo`, []string{`int`}, `func(int)`}}, }, + + // reverse type parameter inference + {`package reverse1a; var f func(int) = g; func g[P any](P) {}`, + []testInst{{`g`, []string{`int`}, `func(int)`}}, + }, + {`package reverse1b; func f(func(int)) {}; func g[P any](P) {}; func _() { f(g) }`, + []testInst{{`g`, []string{`int`}, `func(int)`}}, + }, + {`package reverse2a; var f func(int) string = g; func g[P, Q any](P) Q { var q Q; return q }`, + []testInst{{`g`, []string{`int`, `string`}, `func(int) string`}}, + }, + {`package reverse2b; func f(func(int) string) {}; func g[P, Q any](P) Q { var q Q; return q }; func _() { f(g) }`, + []testInst{{`g`, []string{`int`, `string`}, `func(int) string`}}, + }, + // reverse3a not possible (cannot assign to generic function outside of argument passing) + {`package reverse3b; func f[R any](func(int) R) {}; func g[P any](P) string { return "" }; func _() { f(g) }`, + []testInst{ + {`f`, []string{`string`}, `func(func(int) string)`}, + {`g`, []string{`int`}, `func(int) string`}, + }, + }, + {`package reverse4a; var _, _ func([]int, *float32) = g, h; func g[P, Q any]([]P, *Q) {}; func h[R any]([]R, *float32) {}`, + []testInst{ + {`g`, []string{`int`, `float32`}, `func([]int, *float32)`}, + {`h`, []string{`int`}, `func([]int, *float32)`}, + }, + }, + {`package reverse4b; func f(_, _ func([]int, *float32)) {}; func g[P, Q any]([]P, *Q) {}; func h[R any]([]R, *float32) {}; func _() { f(g, h) }`, + []testInst{ + {`g`, []string{`int`, `float32`}, `func([]int, *float32)`}, + {`h`, []string{`int`}, `func([]int, *float32)`}, + }, + }, } for _, test := range tests { imports := make(testImporter) conf := Config{ Importer: imports, - Error: func(error) {}, // ignore errors + // Unexported field: set below with boolFieldAddr + // _EnableReverseTypeInference: true, } + *boolFieldAddr(&conf, "_EnableReverseTypeInference") = true instMap := make(map[*ast.Ident]Instance) useMap := make(map[*ast.Ident]Object) makePkg := func(src string) *Package { - pkg, _ := typecheck(src, &conf, &Info{Instances: instMap, Uses: useMap}) + pkg, err := typecheck(src, &conf, &Info{Instances: instMap, Uses: useMap}) + // allow error for issue51803 + if err != nil && (pkg == nil || pkg.Name() != "issue51803") { + t.Fatal(err) + } imports[pkg.Name()] = pkg return pkg } diff --git a/src/go/types/call.go b/src/go/types/call.go index 979de2338f..02b6038ccc 100644 --- a/src/go/types/call.go +++ b/src/go/types/call.go @@ -25,14 +25,16 @@ import ( func (check *Checker) funcInst(tsig *Signature, pos token.Pos, x *operand, ix *typeparams.IndexExpr) { assert(tsig != nil || ix != nil) + var versionErr bool // set if version error was reported + var instErrPos positioner // position for instantion error + if ix != nil { + instErrPos = inNode(ix.Orig, ix.Lbrack) + } else { + instErrPos = atPos(pos) + } if !check.allowVersion(check.pkg, pos, 1, 18) { - var posn positioner - if ix != nil { - posn = inNode(ix.Orig, ix.Lbrack) - } else { - posn = atPos(pos) - } - check.softErrorf(posn, UnsupportedFeature, "function instantiation requires go1.18 or later") + check.softErrorf(instErrPos, UnsupportedFeature, "function instantiation requires go1.18 or later") + versionErr = true } // targs and xlist are the type arguments and corresponding type expressions, or nil. @@ -74,6 +76,13 @@ func (check *Checker) funcInst(tsig *Signature, pos token.Pos, x *operand, ix *t // of a synthetic function f where f's parameters are the parameters and results // of x and where the arguments to the call of f are values of the parameter and // result types of x. + if !versionErr && !check.allowVersion(check.pkg, pos, 1, 21) { + if ix != nil { + check.softErrorf(instErrPos, UnsupportedFeature, "partially instantiated function in assignment requires go1.21 or later") + } else { + check.softErrorf(instErrPos, UnsupportedFeature, "implicitly instantiated function in assignment requires go1.21 or later") + } + } n := tsig.params.Len() m := tsig.results.Len() args = make([]*operand, n+m) @@ -308,7 +317,7 @@ func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind { } // evaluate arguments - args := check.exprList(call.Args) + args := check.genericExprList(call.Args) sig = check.arguments(call, sig, targs, args, xlist) if wasGeneric && sig.TypeParams().Len() == 0 { @@ -343,6 +352,8 @@ func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind { return statement } +// exprList evaluates a list of expressions and returns the corresponding operands. +// A single-element expression list may evaluate to multiple operands. func (check *Checker) exprList(elist []ast.Expr) (xlist []*operand) { switch len(elist) { case 0: @@ -361,6 +372,25 @@ func (check *Checker) exprList(elist []ast.Expr) (xlist []*operand) { return } +// genericExprList is like exprList but result operands may be generic (not fully instantiated). +func (check *Checker) genericExprList(elist []ast.Expr) (xlist []*operand) { + switch len(elist) { + case 0: + // nothing to do + case 1: + xlist = check.genericMultiExpr(elist[0]) + default: + // multiple (possibly invalid) values + xlist = make([]*operand, len(elist)) + for i, e := range elist { + var x operand + check.genericExpr(&x, e) + xlist[i] = &x + } + } + return +} + // xlist is the list of type argument expressions supplied in the source code. func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type, args []*operand, xlist []ast.Expr) (rsig *Signature) { rsig = sig @@ -391,7 +421,7 @@ func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type // set up parameters sigParams := sig.params // adjusted for variadic functions (may be nil for empty parameter lists!) - adjusted := false // indicates if sigParams is different from t.params + adjusted := false // indicates if sigParams is different from sig.params if sig.variadic { if ddd { // variadic_func(a, b, c...) @@ -452,8 +482,12 @@ func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type return } - // infer type arguments and instantiate signature if necessary - if sig.TypeParams().Len() > 0 { + // collect type parameters of callee and generic function arguments + var tparams []*TypeParam + + // collect type parameters of callee + n := sig.TypeParams().Len() + if n > 0 { if !check.allowVersion(check.pkg, call.Pos(), 1, 18) { switch call.Fun.(type) { case *ast.IndexExpr, *ast.IndexListExpr: @@ -463,29 +497,72 @@ func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type check.softErrorf(inNode(call, call.Lparen), UnsupportedFeature, "implicit function instantiation requires go1.18 or later") } } - - // Rename type parameters to avoid problems with recursive calls. - var tparams []*TypeParam + // rename type parameters to avoid problems with recursive calls tparams, sigParams = check.renameTParams(call.Pos(), sig.TypeParams().list(), sigParams) + } - targs := check.infer(call, tparams, targs, sigParams, args) + // collect type parameters from generic function arguments + var genericArgs []int // indices of generic function arguments + if check.conf._EnableReverseTypeInference { + for i, arg := range args { + // generic arguments cannot have a defined (*Named) type - no need for underlying type below + if asig, _ := arg.typ.(*Signature); asig != nil && asig.TypeParams().Len() > 0 { + // TODO(gri) need to also rename type parameters for cases like f(g, g) + tparams = append(tparams, asig.TypeParams().list()...) + genericArgs = append(genericArgs, i) + } + } + } + if len(genericArgs) > 0 && !check.allowVersion(check.pkg, call.Pos(), 1, 21) { + // at the moment we only support implicit instantiations of argument functions + check.softErrorf(inNode(call, call.Lparen), UnsupportedFeature, "implicitly instantiated function as argument requires go1.21 or later") + } + + // tparams holds the type parameters of the callee and generic function arguments, if any: + // the first n type parameters belong to the callee, followed by mi type parameters for each + // of the generic function arguments, where mi = args[i].typ.(*Signature).TypeParams().Len(). + + // infer missing type arguments of callee and function arguments + if len(tparams) > 0 { + targs = check.infer(call, tparams, targs, sigParams, args) if targs == nil { + // TODO(gri) If infer inferred the first targs[:n], consider instantiating + // the call signature for better error messages/gopls behavior. + // Perhaps instantiate as much as we can, also for arguments. + // This will require changes to how infer returns its results. return // error already reported } - // compute result signature - rsig = check.instantiateSignature(call.Pos(), sig, targs, xlist) - assert(rsig.TypeParams().Len() == 0) // signature is not generic anymore - check.recordInstance(call.Fun, targs, rsig) + // compute result signature: instantiate if needed + rsig = sig + if n > 0 { + rsig = check.instantiateSignature(call.Pos(), sig, targs[:n], xlist) + assert(rsig.TypeParams().Len() == 0) // signature is not generic anymore + check.recordInstance(call.Fun, targs[:n], rsig) + } - // Optimization: Only if the parameter list was adjusted do we - // need to compute it from the adjusted list; otherwise we can - // simply use the result signature's parameter list. - if adjusted { - sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(tparams, targs), nil, check.context()).(*Tuple) + // Optimization: Only if the callee's parameter list was adjusted do we need to + // compute it from the adjusted list; otherwise we can simply use the result + // signature's parameter list. We only need the n type parameters and arguments + // of the callee. + if n > 0 && adjusted { + sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(tparams[:n], targs[:n]), nil, check.context()).(*Tuple) } else { sigParams = rsig.params } + + // compute argument signatures: instantiate if needed + j := n + for _, i := range genericArgs { + asig := args[i].typ.(*Signature) + k := j + asig.TypeParams().Len() + // targs[j:k] are the inferred type arguments for asig + asig = check.instantiateSignature(call.Pos(), asig, targs[j:k], nil) // TODO(gri) provide xlist if possible (partial instantiations) + assert(asig.TypeParams().Len() == 0) // signature is not generic anymore + args[i].typ = asig + check.recordInstance(args[i].expr, targs[j:k], asig) + j = k + } } // check arguments diff --git a/src/go/types/check_test.go b/src/go/types/check_test.go index 0f4c320a47..cda052f4d3 100644 --- a/src/go/types/check_test.go +++ b/src/go/types/check_test.go @@ -39,6 +39,7 @@ import ( "go/scanner" "go/token" "internal/testenv" + "internal/types/errors" "os" "path/filepath" "reflect" @@ -295,9 +296,9 @@ func testFiles(t *testing.T, sizes Sizes, filenames []string, srcs [][]byte, man } } -func readCode(err Error) int { +func readCode(err Error) errors.Code { v := reflect.ValueOf(err) - return int(v.FieldByName("go116code").Int()) + return errors.Code(v.FieldByName("go116code").Int()) } // boolFieldAddr(conf, name) returns the address of the boolean field conf.. diff --git a/src/go/types/expr.go b/src/go/types/expr.go index 0db80ca44b..891153ba8d 100644 --- a/src/go/types/expr.go +++ b/src/go/types/expr.go @@ -1829,6 +1829,27 @@ func (check *Checker) multiExpr(e ast.Expr, allowCommaOk bool) (list []*operand, return } +// genericMultiExpr is like multiExpr but a one-element result may also be generic +// and potential comma-ok expressions are returned as single values. +func (check *Checker) genericMultiExpr(e ast.Expr) (list []*operand) { + var x operand + check.rawExpr(nil, &x, e, nil, true) + check.exclude(&x, 1<= 0 default: -- cgit v1.2.1