diff options
-rw-r--r-- | src/cmd/compile/internal/types2/api.go | 2 | ||||
-rw-r--r-- | src/cmd/compile/internal/types2/api_test.go | 43 | ||||
-rw-r--r-- | src/cmd/compile/internal/types2/call.go | 125 | ||||
-rw-r--r-- | src/cmd/compile/internal/types2/expr.go | 21 | ||||
-rw-r--r-- | src/cmd/compile/internal/types2/infer.go | 23 | ||||
-rw-r--r-- | src/go/types/api.go | 2 | ||||
-rw-r--r-- | src/go/types/api_test.go | 43 | ||||
-rw-r--r-- | src/go/types/call.go | 125 | ||||
-rw-r--r-- | src/go/types/check_test.go | 5 | ||||
-rw-r--r-- | src/go/types/expr.go | 21 | ||||
-rw-r--r-- | src/go/types/infer.go | 23 | ||||
-rw-r--r-- | src/internal/types/testdata/examples/inference2.go | 31 | ||||
-rw-r--r-- | src/internal/types/testdata/fixedbugs/issue59338a.go | 21 | ||||
-rw-r--r-- | src/internal/types/testdata/fixedbugs/issue59338b.go | 21 |
14 files changed, 420 insertions, 86 deletions
diff --git a/src/cmd/compile/internal/types2/api.go b/src/cmd/compile/internal/types2/api.go index bd87945295..0ee9a4bd06 100644 --- a/src/cmd/compile/internal/types2/api.go +++ b/src/cmd/compile/internal/types2/api.go @@ -174,7 +174,7 @@ type Config struct { // partially instantiated generic functions may be assigned // (incl. returned) to variables of function type and type // inference will attempt to infer the missing type arguments. - // Experimental. Needs a proposal. + // See proposal go.dev/issue/59338. EnableReverseTypeInference bool } diff --git a/src/cmd/compile/internal/types2/api_test.go b/src/cmd/compile/internal/types2/api_test.go index a13f43111c..ae253623e6 100644 --- a/src/cmd/compile/internal/types2/api_test.go +++ b/src/cmd/compile/internal/types2/api_test.go @@ -550,18 +550,55 @@ type T[P any] []P {`package issue51803; func foo[T any](T) {}; func _() { foo[int]( /* leave arg away on purpose */ ) }`, []testInst{{`foo`, []string{`int`}, `func(int)`}}, }, + + // reverse type parameter inference + {`package reverse1a; var f func(int) = g; func g[P any](P) {}`, + []testInst{{`g`, []string{`int`}, `func(int)`}}, + }, + {`package reverse1b; func f(func(int)) {}; func g[P any](P) {}; func _() { f(g) }`, + []testInst{{`g`, []string{`int`}, `func(int)`}}, + }, + {`package reverse2a; var f func(int) string = g; func g[P, Q any](P) Q { var q Q; return q }`, + []testInst{{`g`, []string{`int`, `string`}, `func(int) string`}}, + }, + {`package reverse2b; func f(func(int) string) {}; func g[P, Q any](P) Q { var q Q; return q }; func _() { f(g) }`, + []testInst{{`g`, []string{`int`, `string`}, `func(int) string`}}, + }, + // reverse3a not possible (cannot assign to generic function outside of argument passing) + {`package reverse3b; func f[R any](func(int) R) {}; func g[P any](P) string { return "" }; func _() { f(g) }`, + []testInst{ + {`f`, []string{`string`}, `func(func(int) string)`}, + {`g`, []string{`int`}, `func(int) string`}, + }, + }, + {`package reverse4a; var _, _ func([]int, *float32) = g, h; func g[P, Q any]([]P, *Q) {}; func h[R any]([]R, *float32) {}`, + []testInst{ + {`g`, []string{`int`, `float32`}, `func([]int, *float32)`}, + {`h`, []string{`int`}, `func([]int, *float32)`}, + }, + }, + {`package reverse4b; func f(_, _ func([]int, *float32)) {}; func g[P, Q any]([]P, *Q) {}; func h[R any]([]R, *float32) {}; func _() { f(g, h) }`, + []testInst{ + {`g`, []string{`int`, `float32`}, `func([]int, *float32)`}, + {`h`, []string{`int`}, `func([]int, *float32)`}, + }, + }, } for _, test := range tests { imports := make(testImporter) conf := Config{ - Importer: imports, - Error: func(error) {}, // ignore errors + Importer: imports, + EnableReverseTypeInference: true, } instMap := make(map[*syntax.Name]Instance) useMap := make(map[*syntax.Name]Object) makePkg := func(src string) *Package { - pkg, _ := typecheck(src, &conf, &Info{Instances: instMap, Uses: useMap}) + pkg, err := typecheck(src, &conf, &Info{Instances: instMap, Uses: useMap}) + // allow error for issue51803 + if err != nil && (pkg == nil || pkg.Name() != "issue51803") { + t.Fatal(err) + } imports[pkg.Name()] = pkg return pkg } diff --git a/src/cmd/compile/internal/types2/call.go b/src/cmd/compile/internal/types2/call.go index 20cde9f44e..8ad7744ab4 100644 --- a/src/cmd/compile/internal/types2/call.go +++ b/src/cmd/compile/internal/types2/call.go @@ -23,14 +23,16 @@ import ( func (check *Checker) funcInst(tsig *Signature, pos syntax.Pos, x *operand, inst *syntax.IndexExpr) { assert(tsig != nil || inst != nil) + var versionErr bool // set if version error was reported + var instErrPos poser // position for instantion error + if inst != nil { + instErrPos = inst.Pos() + } else { + instErrPos = pos + } if !check.allowVersion(check.pkg, pos, 1, 18) { - var posn poser - if inst != nil { - posn = inst.Pos() - } else { - posn = pos - } - check.versionErrorf(posn, "go1.18", "function instantiation") + check.versionErrorf(instErrPos, "go1.18", "function instantiation") + versionErr = true } // targs and xlist are the type arguments and corresponding type expressions, or nil. @@ -72,6 +74,13 @@ func (check *Checker) funcInst(tsig *Signature, pos syntax.Pos, x *operand, inst // of a synthetic function f where f's parameters are the parameters and results // of x and where the arguments to the call of f are values of the parameter and // result types of x. + if !versionErr && !check.allowVersion(check.pkg, pos, 1, 21) { + if inst != nil { + check.versionErrorf(instErrPos, "go1.21", "partially instantiated function in assignment") + } else { + check.versionErrorf(instErrPos, "go1.21", "implicitly instantiated function in assignment") + } + } n := tsig.params.Len() m := tsig.results.Len() args = make([]*operand, n+m) @@ -303,7 +312,7 @@ func (check *Checker) callExpr(x *operand, call *syntax.CallExpr) exprKind { } // evaluate arguments - args := check.exprList(call.ArgList) + args := check.genericExprList(call.ArgList) sig = check.arguments(call, sig, targs, args, xlist) if wasGeneric && sig.TypeParams().Len() == 0 { @@ -338,6 +347,8 @@ func (check *Checker) callExpr(x *operand, call *syntax.CallExpr) exprKind { return statement } +// exprList evaluates a list of expressions and returns the corresponding operands. +// A single-element expression list may evaluate to multiple operands. func (check *Checker) exprList(elist []syntax.Expr) (xlist []*operand) { switch len(elist) { case 0: @@ -356,6 +367,25 @@ func (check *Checker) exprList(elist []syntax.Expr) (xlist []*operand) { return } +// genericExprList is like exprList but result operands may be generic (not fully instantiated). +func (check *Checker) genericExprList(elist []syntax.Expr) (xlist []*operand) { + switch len(elist) { + case 0: + // nothing to do + case 1: + xlist = check.genericMultiExpr(elist[0]) + default: + // multiple (possibly invalid) values + xlist = make([]*operand, len(elist)) + for i, e := range elist { + var x operand + check.genericExpr(&x, e) + xlist[i] = &x + } + } + return +} + // xlist is the list of type argument expressions supplied in the source code. func (check *Checker) arguments(call *syntax.CallExpr, sig *Signature, targs []Type, args []*operand, xlist []syntax.Expr) (rsig *Signature) { rsig = sig @@ -386,7 +416,7 @@ func (check *Checker) arguments(call *syntax.CallExpr, sig *Signature, targs []T // set up parameters sigParams := sig.params // adjusted for variadic functions (may be nil for empty parameter lists!) - adjusted := false // indicates if sigParams is different from t.params + adjusted := false // indicates if sigParams is different from sig.params if sig.variadic { if ddd { // variadic_func(a, b, c...) @@ -451,8 +481,12 @@ func (check *Checker) arguments(call *syntax.CallExpr, sig *Signature, targs []T return } - // infer type arguments and instantiate signature if necessary - if sig.TypeParams().Len() > 0 { + // collect type parameters of callee and generic function arguments + var tparams []*TypeParam + + // collect type parameters of callee + n := sig.TypeParams().Len() + if n > 0 { if !check.allowVersion(check.pkg, call.Pos(), 1, 18) { if iexpr, _ := call.Fun.(*syntax.IndexExpr); iexpr != nil { check.versionErrorf(iexpr.Pos(), "go1.18", "function instantiation") @@ -460,29 +494,72 @@ func (check *Checker) arguments(call *syntax.CallExpr, sig *Signature, targs []T check.versionErrorf(call.Pos(), "go1.18", "implicit function instantiation") } } - - // Rename type parameters to avoid problems with recursive calls. - var tparams []*TypeParam + // rename type parameters to avoid problems with recursive calls tparams, sigParams = check.renameTParams(call.Pos(), sig.TypeParams().list(), sigParams) + } - targs := check.infer(call.Pos(), tparams, targs, sigParams, args) + // collect type parameters from generic function arguments + var genericArgs []int // indices of generic function arguments + if check.conf.EnableReverseTypeInference { + for i, arg := range args { + // generic arguments cannot have a defined (*Named) type - no need for underlying type below + if asig, _ := arg.typ.(*Signature); asig != nil && asig.TypeParams().Len() > 0 { + // TODO(gri) need to also rename type parameters for cases like f(g, g) + tparams = append(tparams, asig.TypeParams().list()...) + genericArgs = append(genericArgs, i) + } + } + } + if len(genericArgs) > 0 && !check.allowVersion(check.pkg, call.Pos(), 1, 21) { + // at the moment we only support implicit instantiations of argument functions + check.versionErrorf(args[genericArgs[0]].Pos(), "go1.21", "implicitly instantiated function as argument") + } + + // tparams holds the type parameters of the callee and generic function arguments, if any: + // the first n type parameters belong to the callee, followed by mi type parameters for each + // of the generic function arguments, where mi = args[i].typ.(*Signature).TypeParams().Len(). + + // infer missing type arguments of callee and function arguments + if len(tparams) > 0 { + targs = check.infer(call.Pos(), tparams, targs, sigParams, args) if targs == nil { + // TODO(gri) If infer inferred the first targs[:n], consider instantiating + // the call signature for better error messages/gopls behavior. + // Perhaps instantiate as much as we can, also for arguments. + // This will require changes to how infer returns its results. return // error already reported } - // compute result signature - rsig = check.instantiateSignature(call.Pos(), sig, targs, xlist) - assert(rsig.TypeParams().Len() == 0) // signature is not generic anymore - check.recordInstance(call.Fun, targs, rsig) + // compute result signature: instantiate if needed + rsig = sig + if n > 0 { + rsig = check.instantiateSignature(call.Pos(), sig, targs[:n], xlist) + assert(rsig.TypeParams().Len() == 0) // signature is not generic anymore + check.recordInstance(call.Fun, targs[:n], rsig) + } - // Optimization: Only if the parameter list was adjusted do we - // need to compute it from the adjusted list; otherwise we can - // simply use the result signature's parameter list. - if adjusted { - sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(tparams, targs), nil, check.context()).(*Tuple) + // Optimization: Only if the callee's parameter list was adjusted do we need to + // compute it from the adjusted list; otherwise we can simply use the result + // signature's parameter list. We only need the n type parameters and arguments + // of the callee. + if n > 0 && adjusted { + sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(tparams[:n], targs[:n]), nil, check.context()).(*Tuple) } else { sigParams = rsig.params } + + // compute argument signatures: instantiate if needed + j := n + for _, i := range genericArgs { + asig := args[i].typ.(*Signature) + k := j + asig.TypeParams().Len() + // targs[j:k] are the inferred type arguments for asig + asig = check.instantiateSignature(call.Pos(), asig, targs[j:k], nil) // TODO(gri) provide xlist if possible (partial instantiations) + assert(asig.TypeParams().Len() == 0) // signature is not generic anymore + args[i].typ = asig + check.recordInstance(args[i].expr, targs[j:k], asig) + j = k + } } // check arguments diff --git a/src/cmd/compile/internal/types2/expr.go b/src/cmd/compile/internal/types2/expr.go index 7c3b40f086..51b944eead 100644 --- a/src/cmd/compile/internal/types2/expr.go +++ b/src/cmd/compile/internal/types2/expr.go @@ -1882,6 +1882,27 @@ func (check *Checker) multiExpr(e syntax.Expr, allowCommaOk bool) (list []*opera return } +// genericMultiExpr is like multiExpr but a one-element result may also be generic +// and potential comma-ok expressions are returned as single values. +func (check *Checker) genericMultiExpr(e syntax.Expr) (list []*operand) { + var x operand + check.rawExpr(nil, &x, e, nil, true) + check.exclude(&x, 1<<novalue|1<<builtin|1<<typexpr) + + if t, ok := x.typ.(*Tuple); ok && x.mode != invalid { + // multiple values - cannot be generic + list = make([]*operand, t.Len()) + for i, v := range t.vars { + list[i] = &operand{mode: value, expr: e, typ: v.typ} + } + return + } + + // exactly one (possible invalid or generic) value + list = []*operand{&x} + return +} + // exprWithHint typechecks expression e and initializes x with the expression value; // hint is the type of a composite literal element. // If an error occurred, x.mode is set to invalid. diff --git a/src/cmd/compile/internal/types2/infer.go b/src/cmd/compile/internal/types2/infer.go index dbe621cded..9c1022c46f 100644 --- a/src/cmd/compile/internal/types2/infer.go +++ b/src/cmd/compile/internal/types2/infer.go @@ -140,17 +140,17 @@ func (check *Checker) infer(pos syntax.Pos, tparams []*TypeParam, targs []Type, } for i, arg := range args { + if arg.mode == invalid { + // An error was reported earlier. Ignore this arg + // and continue, we may still be able to infer all + // targs resulting in fewer follow-on errors. + // TODO(gri) determine if we still need this check + continue + } par := params.At(i) - // If we permit bidirectional unification, this conditional code needs to be - // executed even if par.typ is not parameterized since the argument may be a - // generic function (for which we want to infer its type arguments). - if isParameterized(tparams, par.typ) { - if arg.mode == invalid { - // An error was reported earlier. Ignore this targ - // and continue, we may still be able to infer all - // targs resulting in fewer follow-on errors. - continue - } + if isParameterized(tparams, par.typ) || isParameterized(tparams, arg.typ) { + // Function parameters are always typed. Arguments may be untyped. + // Collect the indices of untyped arguments and handle them later. if isTyped(arg.typ) { if !u.unify(par.typ, arg.typ) { errorf("type", par.typ, arg.typ, arg) @@ -263,7 +263,7 @@ func (check *Checker) infer(pos syntax.Pos, tparams []*TypeParam, targs []Type, } // --- 3 --- - // use information from untyped contants + // use information from untyped constants if traceInference { u.tracef("== untyped arguments: %v", untyped) @@ -541,7 +541,6 @@ func (w *tpWalker) isParameterized(typ Type) (res bool) { } case *TypeParam: - // t must be one of w.tparams return tparamIndex(w.tparams, t) >= 0 default: diff --git a/src/go/types/api.go b/src/go/types/api.go index 7af84fd244..e202d6dea8 100644 --- a/src/go/types/api.go +++ b/src/go/types/api.go @@ -175,7 +175,7 @@ type Config struct { // partially instantiated generic functions may be assigned // (incl. returned) to variables of function type and type // inference will attempt to infer the missing type arguments. - // Experimental. Needs a proposal. + // See proposal go.dev/issue/59338. _EnableReverseTypeInference bool } diff --git a/src/go/types/api_test.go b/src/go/types/api_test.go index ae1a7e50a7..02e26c3f02 100644 --- a/src/go/types/api_test.go +++ b/src/go/types/api_test.go @@ -550,18 +550,57 @@ type T[P any] []P {`package issue51803; func foo[T any](T) {}; func _() { foo[int]( /* leave arg away on purpose */ ) }`, []testInst{{`foo`, []string{`int`}, `func(int)`}}, }, + + // reverse type parameter inference + {`package reverse1a; var f func(int) = g; func g[P any](P) {}`, + []testInst{{`g`, []string{`int`}, `func(int)`}}, + }, + {`package reverse1b; func f(func(int)) {}; func g[P any](P) {}; func _() { f(g) }`, + []testInst{{`g`, []string{`int`}, `func(int)`}}, + }, + {`package reverse2a; var f func(int) string = g; func g[P, Q any](P) Q { var q Q; return q }`, + []testInst{{`g`, []string{`int`, `string`}, `func(int) string`}}, + }, + {`package reverse2b; func f(func(int) string) {}; func g[P, Q any](P) Q { var q Q; return q }; func _() { f(g) }`, + []testInst{{`g`, []string{`int`, `string`}, `func(int) string`}}, + }, + // reverse3a not possible (cannot assign to generic function outside of argument passing) + {`package reverse3b; func f[R any](func(int) R) {}; func g[P any](P) string { return "" }; func _() { f(g) }`, + []testInst{ + {`f`, []string{`string`}, `func(func(int) string)`}, + {`g`, []string{`int`}, `func(int) string`}, + }, + }, + {`package reverse4a; var _, _ func([]int, *float32) = g, h; func g[P, Q any]([]P, *Q) {}; func h[R any]([]R, *float32) {}`, + []testInst{ + {`g`, []string{`int`, `float32`}, `func([]int, *float32)`}, + {`h`, []string{`int`}, `func([]int, *float32)`}, + }, + }, + {`package reverse4b; func f(_, _ func([]int, *float32)) {}; func g[P, Q any]([]P, *Q) {}; func h[R any]([]R, *float32) {}; func _() { f(g, h) }`, + []testInst{ + {`g`, []string{`int`, `float32`}, `func([]int, *float32)`}, + {`h`, []string{`int`}, `func([]int, *float32)`}, + }, + }, } for _, test := range tests { imports := make(testImporter) conf := Config{ Importer: imports, - Error: func(error) {}, // ignore errors + // Unexported field: set below with boolFieldAddr + // _EnableReverseTypeInference: true, } + *boolFieldAddr(&conf, "_EnableReverseTypeInference") = true instMap := make(map[*ast.Ident]Instance) useMap := make(map[*ast.Ident]Object) makePkg := func(src string) *Package { - pkg, _ := typecheck(src, &conf, &Info{Instances: instMap, Uses: useMap}) + pkg, err := typecheck(src, &conf, &Info{Instances: instMap, Uses: useMap}) + // allow error for issue51803 + if err != nil && (pkg == nil || pkg.Name() != "issue51803") { + t.Fatal(err) + } imports[pkg.Name()] = pkg return pkg } diff --git a/src/go/types/call.go b/src/go/types/call.go index 979de2338f..02b6038ccc 100644 --- a/src/go/types/call.go +++ b/src/go/types/call.go @@ -25,14 +25,16 @@ import ( func (check *Checker) funcInst(tsig *Signature, pos token.Pos, x *operand, ix *typeparams.IndexExpr) { assert(tsig != nil || ix != nil) + var versionErr bool // set if version error was reported + var instErrPos positioner // position for instantion error + if ix != nil { + instErrPos = inNode(ix.Orig, ix.Lbrack) + } else { + instErrPos = atPos(pos) + } if !check.allowVersion(check.pkg, pos, 1, 18) { - var posn positioner - if ix != nil { - posn = inNode(ix.Orig, ix.Lbrack) - } else { - posn = atPos(pos) - } - check.softErrorf(posn, UnsupportedFeature, "function instantiation requires go1.18 or later") + check.softErrorf(instErrPos, UnsupportedFeature, "function instantiation requires go1.18 or later") + versionErr = true } // targs and xlist are the type arguments and corresponding type expressions, or nil. @@ -74,6 +76,13 @@ func (check *Checker) funcInst(tsig *Signature, pos token.Pos, x *operand, ix *t // of a synthetic function f where f's parameters are the parameters and results // of x and where the arguments to the call of f are values of the parameter and // result types of x. + if !versionErr && !check.allowVersion(check.pkg, pos, 1, 21) { + if ix != nil { + check.softErrorf(instErrPos, UnsupportedFeature, "partially instantiated function in assignment requires go1.21 or later") + } else { + check.softErrorf(instErrPos, UnsupportedFeature, "implicitly instantiated function in assignment requires go1.21 or later") + } + } n := tsig.params.Len() m := tsig.results.Len() args = make([]*operand, n+m) @@ -308,7 +317,7 @@ func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind { } // evaluate arguments - args := check.exprList(call.Args) + args := check.genericExprList(call.Args) sig = check.arguments(call, sig, targs, args, xlist) if wasGeneric && sig.TypeParams().Len() == 0 { @@ -343,6 +352,8 @@ func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind { return statement } +// exprList evaluates a list of expressions and returns the corresponding operands. +// A single-element expression list may evaluate to multiple operands. func (check *Checker) exprList(elist []ast.Expr) (xlist []*operand) { switch len(elist) { case 0: @@ -361,6 +372,25 @@ func (check *Checker) exprList(elist []ast.Expr) (xlist []*operand) { return } +// genericExprList is like exprList but result operands may be generic (not fully instantiated). +func (check *Checker) genericExprList(elist []ast.Expr) (xlist []*operand) { + switch len(elist) { + case 0: + // nothing to do + case 1: + xlist = check.genericMultiExpr(elist[0]) + default: + // multiple (possibly invalid) values + xlist = make([]*operand, len(elist)) + for i, e := range elist { + var x operand + check.genericExpr(&x, e) + xlist[i] = &x + } + } + return +} + // xlist is the list of type argument expressions supplied in the source code. func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type, args []*operand, xlist []ast.Expr) (rsig *Signature) { rsig = sig @@ -391,7 +421,7 @@ func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type // set up parameters sigParams := sig.params // adjusted for variadic functions (may be nil for empty parameter lists!) - adjusted := false // indicates if sigParams is different from t.params + adjusted := false // indicates if sigParams is different from sig.params if sig.variadic { if ddd { // variadic_func(a, b, c...) @@ -452,8 +482,12 @@ func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type return } - // infer type arguments and instantiate signature if necessary - if sig.TypeParams().Len() > 0 { + // collect type parameters of callee and generic function arguments + var tparams []*TypeParam + + // collect type parameters of callee + n := sig.TypeParams().Len() + if n > 0 { if !check.allowVersion(check.pkg, call.Pos(), 1, 18) { switch call.Fun.(type) { case *ast.IndexExpr, *ast.IndexListExpr: @@ -463,29 +497,72 @@ func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type check.softErrorf(inNode(call, call.Lparen), UnsupportedFeature, "implicit function instantiation requires go1.18 or later") } } - - // Rename type parameters to avoid problems with recursive calls. - var tparams []*TypeParam + // rename type parameters to avoid problems with recursive calls tparams, sigParams = check.renameTParams(call.Pos(), sig.TypeParams().list(), sigParams) + } - targs := check.infer(call, tparams, targs, sigParams, args) + // collect type parameters from generic function arguments + var genericArgs []int // indices of generic function arguments + if check.conf._EnableReverseTypeInference { + for i, arg := range args { + // generic arguments cannot have a defined (*Named) type - no need for underlying type below + if asig, _ := arg.typ.(*Signature); asig != nil && asig.TypeParams().Len() > 0 { + // TODO(gri) need to also rename type parameters for cases like f(g, g) + tparams = append(tparams, asig.TypeParams().list()...) + genericArgs = append(genericArgs, i) + } + } + } + if len(genericArgs) > 0 && !check.allowVersion(check.pkg, call.Pos(), 1, 21) { + // at the moment we only support implicit instantiations of argument functions + check.softErrorf(inNode(call, call.Lparen), UnsupportedFeature, "implicitly instantiated function as argument requires go1.21 or later") + } + + // tparams holds the type parameters of the callee and generic function arguments, if any: + // the first n type parameters belong to the callee, followed by mi type parameters for each + // of the generic function arguments, where mi = args[i].typ.(*Signature).TypeParams().Len(). + + // infer missing type arguments of callee and function arguments + if len(tparams) > 0 { + targs = check.infer(call, tparams, targs, sigParams, args) if targs == nil { + // TODO(gri) If infer inferred the first targs[:n], consider instantiating + // the call signature for better error messages/gopls behavior. + // Perhaps instantiate as much as we can, also for arguments. + // This will require changes to how infer returns its results. return // error already reported } - // compute result signature - rsig = check.instantiateSignature(call.Pos(), sig, targs, xlist) - assert(rsig.TypeParams().Len() == 0) // signature is not generic anymore - check.recordInstance(call.Fun, targs, rsig) + // compute result signature: instantiate if needed + rsig = sig + if n > 0 { + rsig = check.instantiateSignature(call.Pos(), sig, targs[:n], xlist) + assert(rsig.TypeParams().Len() == 0) // signature is not generic anymore + check.recordInstance(call.Fun, targs[:n], rsig) + } - // Optimization: Only if the parameter list was adjusted do we - // need to compute it from the adjusted list; otherwise we can - // simply use the result signature's parameter list. - if adjusted { - sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(tparams, targs), nil, check.context()).(*Tuple) + // Optimization: Only if the callee's parameter list was adjusted do we need to + // compute it from the adjusted list; otherwise we can simply use the result + // signature's parameter list. We only need the n type parameters and arguments + // of the callee. + if n > 0 && adjusted { + sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(tparams[:n], targs[:n]), nil, check.context()).(*Tuple) } else { sigParams = rsig.params } + + // compute argument signatures: instantiate if needed + j := n + for _, i := range genericArgs { + asig := args[i].typ.(*Signature) + k := j + asig.TypeParams().Len() + // targs[j:k] are the inferred type arguments for asig + asig = check.instantiateSignature(call.Pos(), asig, targs[j:k], nil) // TODO(gri) provide xlist if possible (partial instantiations) + assert(asig.TypeParams().Len() == 0) // signature is not generic anymore + args[i].typ = asig + check.recordInstance(args[i].expr, targs[j:k], asig) + j = k + } } // check arguments diff --git a/src/go/types/check_test.go b/src/go/types/check_test.go index 0f4c320a47..cda052f4d3 100644 --- a/src/go/types/check_test.go +++ b/src/go/types/check_test.go @@ -39,6 +39,7 @@ import ( "go/scanner" "go/token" "internal/testenv" + "internal/types/errors" "os" "path/filepath" "reflect" @@ -295,9 +296,9 @@ func testFiles(t *testing.T, sizes Sizes, filenames []string, srcs [][]byte, man } } -func readCode(err Error) int { +func readCode(err Error) errors.Code { v := reflect.ValueOf(err) - return int(v.FieldByName("go116code").Int()) + return errors.Code(v.FieldByName("go116code").Int()) } // boolFieldAddr(conf, name) returns the address of the boolean field conf.<name>. diff --git a/src/go/types/expr.go b/src/go/types/expr.go index 0db80ca44b..891153ba8d 100644 --- a/src/go/types/expr.go +++ b/src/go/types/expr.go @@ -1829,6 +1829,27 @@ func (check *Checker) multiExpr(e ast.Expr, allowCommaOk bool) (list []*operand, return } +// genericMultiExpr is like multiExpr but a one-element result may also be generic +// and potential comma-ok expressions are returned as single values. +func (check *Checker) genericMultiExpr(e ast.Expr) (list []*operand) { + var x operand + check.rawExpr(nil, &x, e, nil, true) + check.exclude(&x, 1<<novalue|1<<builtin|1<<typexpr) + + if t, ok := x.typ.(*Tuple); ok && x.mode != invalid { + // multiple values - cannot be generic + list = make([]*operand, t.Len()) + for i, v := range t.vars { + list[i] = &operand{mode: value, expr: e, typ: v.typ} + } + return + } + + // exactly one (possible invalid or generic) value + list = []*operand{&x} + return +} + // exprWithHint typechecks expression e and initializes x with the expression value; // hint is the type of a composite literal element. // If an error occurred, x.mode is set to invalid. diff --git a/src/go/types/infer.go b/src/go/types/infer.go index 3aa66105c4..39bf4a14f7 100644 --- a/src/go/types/infer.go +++ b/src/go/types/infer.go @@ -142,17 +142,17 @@ func (check *Checker) infer(posn positioner, tparams []*TypeParam, targs []Type, } for i, arg := range args { + if arg.mode == invalid { + // An error was reported earlier. Ignore this arg + // and continue, we may still be able to infer all + // targs resulting in fewer follow-on errors. + // TODO(gri) determine if we still need this check + continue + } par := params.At(i) - // If we permit bidirectional unification, this conditional code needs to be - // executed even if par.typ is not parameterized since the argument may be a - // generic function (for which we want to infer its type arguments). - if isParameterized(tparams, par.typ) { - if arg.mode == invalid { - // An error was reported earlier. Ignore this targ - // and continue, we may still be able to infer all - // targs resulting in fewer follow-on errors. - continue - } + if isParameterized(tparams, par.typ) || isParameterized(tparams, arg.typ) { + // Function parameters are always typed. Arguments may be untyped. + // Collect the indices of untyped arguments and handle them later. if isTyped(arg.typ) { if !u.unify(par.typ, arg.typ) { errorf("type", par.typ, arg.typ, arg) @@ -265,7 +265,7 @@ func (check *Checker) infer(posn positioner, tparams []*TypeParam, targs []Type, } // --- 3 --- - // use information from untyped contants + // use information from untyped constants if traceInference { u.tracef("== untyped arguments: %v", untyped) @@ -543,7 +543,6 @@ func (w *tpWalker) isParameterized(typ Type) (res bool) { } case *TypeParam: - // t must be one of w.tparams return tparamIndex(w.tparams, t) >= 0 default: diff --git a/src/internal/types/testdata/examples/inference2.go b/src/internal/types/testdata/examples/inference2.go index d309a00c0c..80acc828dd 100644 --- a/src/internal/types/testdata/examples/inference2.go +++ b/src/internal/types/testdata/examples/inference2.go @@ -10,11 +10,13 @@ package p -func f1[P any](P) {} -func f2[P any]() P { var x P; return x } -func f3[P, Q any](P) Q { var x Q; return x } -func f4[P any](P, P) {} -func f5[P any](P) []P { return nil } +func f1[P any](P) {} +func f2[P any]() P { var x P; return x } +func f3[P, Q any](P) Q { var x Q; return x } +func f4[P any](P, P) {} +func f5[P any](P) []P { return nil } +func f6[P any](int) P { var x P; return x } +func f7[P any](P) string { return "" } // initialization expressions var ( @@ -71,3 +73,22 @@ func _() func(string) []int { func _() (_, _ func(int)) { return f1, f1 } func _() (_, _ func(int)) { return f1, f2 /* ERROR "cannot infer P" */ } + +// Argument passing +func g1(func(int)) {} +func g2(func(int, int)) {} +func g3(func(int) string) {} +func g4[P any](func(P) string) {} +func g5[P, Q any](func(P) string, func(P) Q) {} +func g6(func(int), func(string)) {} + +func _() { + g1(f1) + g1(f2 /* ERROR "cannot infer P" */) + g2(f4) + g4(f6) + g5(f6, f7) + + // TODO(gri) this should work (requires type parameter renaming for f1) + g6(f1, f1 /* ERROR "type func[P any](P) of f1 does not match func(string)" */) +} diff --git a/src/internal/types/testdata/fixedbugs/issue59338a.go b/src/internal/types/testdata/fixedbugs/issue59338a.go new file mode 100644 index 0000000000..fd37586cfb --- /dev/null +++ b/src/internal/types/testdata/fixedbugs/issue59338a.go @@ -0,0 +1,21 @@ +// -reverseTypeInference -lang=go1.20 + +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package p + +func g[P any](P) {} +func h[P, Q any](P) Q { panic(0) } + +var _ func(int) = g /* ERROR "implicitly instantiated function in assignment requires go1.21 or later" */ +var _ func(int) string = h[ /* ERROR "partially instantiated function in assignment requires go1.21 or later" */ int] + +func f1(func(int)) {} +func f2(int, func(int)) {} + +func _() { + f1( /* ERROR "implicitly instantiated function as argument requires go1.21 or later" */ g) + f2( /* ERROR "implicitly instantiated function as argument requires go1.21 or later" */ 0, g) +} diff --git a/src/internal/types/testdata/fixedbugs/issue59338b.go b/src/internal/types/testdata/fixedbugs/issue59338b.go new file mode 100644 index 0000000000..ea321bcd17 --- /dev/null +++ b/src/internal/types/testdata/fixedbugs/issue59338b.go @@ -0,0 +1,21 @@ +// -reverseTypeInference + +// Copyright 2023 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package p + +func g[P any](P) {} +func h[P, Q any](P) Q { panic(0) } + +var _ func(int) = g +var _ func(int) string = h[int] + +func f1(func(int)) {} +func f2(int, func(int)) {} + +func _() { + f1(g) + f2(0, g) +} |