summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--src/cmd/compile/internal/types2/api.go2
-rw-r--r--src/cmd/compile/internal/types2/api_test.go43
-rw-r--r--src/cmd/compile/internal/types2/call.go125
-rw-r--r--src/cmd/compile/internal/types2/expr.go21
-rw-r--r--src/cmd/compile/internal/types2/infer.go23
-rw-r--r--src/go/types/api.go2
-rw-r--r--src/go/types/api_test.go43
-rw-r--r--src/go/types/call.go125
-rw-r--r--src/go/types/check_test.go5
-rw-r--r--src/go/types/expr.go21
-rw-r--r--src/go/types/infer.go23
-rw-r--r--src/internal/types/testdata/examples/inference2.go31
-rw-r--r--src/internal/types/testdata/fixedbugs/issue59338a.go21
-rw-r--r--src/internal/types/testdata/fixedbugs/issue59338b.go21
14 files changed, 420 insertions, 86 deletions
diff --git a/src/cmd/compile/internal/types2/api.go b/src/cmd/compile/internal/types2/api.go
index bd87945295..0ee9a4bd06 100644
--- a/src/cmd/compile/internal/types2/api.go
+++ b/src/cmd/compile/internal/types2/api.go
@@ -174,7 +174,7 @@ type Config struct {
// partially instantiated generic functions may be assigned
// (incl. returned) to variables of function type and type
// inference will attempt to infer the missing type arguments.
- // Experimental. Needs a proposal.
+ // See proposal go.dev/issue/59338.
EnableReverseTypeInference bool
}
diff --git a/src/cmd/compile/internal/types2/api_test.go b/src/cmd/compile/internal/types2/api_test.go
index a13f43111c..ae253623e6 100644
--- a/src/cmd/compile/internal/types2/api_test.go
+++ b/src/cmd/compile/internal/types2/api_test.go
@@ -550,18 +550,55 @@ type T[P any] []P
{`package issue51803; func foo[T any](T) {}; func _() { foo[int]( /* leave arg away on purpose */ ) }`,
[]testInst{{`foo`, []string{`int`}, `func(int)`}},
},
+
+ // reverse type parameter inference
+ {`package reverse1a; var f func(int) = g; func g[P any](P) {}`,
+ []testInst{{`g`, []string{`int`}, `func(int)`}},
+ },
+ {`package reverse1b; func f(func(int)) {}; func g[P any](P) {}; func _() { f(g) }`,
+ []testInst{{`g`, []string{`int`}, `func(int)`}},
+ },
+ {`package reverse2a; var f func(int) string = g; func g[P, Q any](P) Q { var q Q; return q }`,
+ []testInst{{`g`, []string{`int`, `string`}, `func(int) string`}},
+ },
+ {`package reverse2b; func f(func(int) string) {}; func g[P, Q any](P) Q { var q Q; return q }; func _() { f(g) }`,
+ []testInst{{`g`, []string{`int`, `string`}, `func(int) string`}},
+ },
+ // reverse3a not possible (cannot assign to generic function outside of argument passing)
+ {`package reverse3b; func f[R any](func(int) R) {}; func g[P any](P) string { return "" }; func _() { f(g) }`,
+ []testInst{
+ {`f`, []string{`string`}, `func(func(int) string)`},
+ {`g`, []string{`int`}, `func(int) string`},
+ },
+ },
+ {`package reverse4a; var _, _ func([]int, *float32) = g, h; func g[P, Q any]([]P, *Q) {}; func h[R any]([]R, *float32) {}`,
+ []testInst{
+ {`g`, []string{`int`, `float32`}, `func([]int, *float32)`},
+ {`h`, []string{`int`}, `func([]int, *float32)`},
+ },
+ },
+ {`package reverse4b; func f(_, _ func([]int, *float32)) {}; func g[P, Q any]([]P, *Q) {}; func h[R any]([]R, *float32) {}; func _() { f(g, h) }`,
+ []testInst{
+ {`g`, []string{`int`, `float32`}, `func([]int, *float32)`},
+ {`h`, []string{`int`}, `func([]int, *float32)`},
+ },
+ },
}
for _, test := range tests {
imports := make(testImporter)
conf := Config{
- Importer: imports,
- Error: func(error) {}, // ignore errors
+ Importer: imports,
+ EnableReverseTypeInference: true,
}
instMap := make(map[*syntax.Name]Instance)
useMap := make(map[*syntax.Name]Object)
makePkg := func(src string) *Package {
- pkg, _ := typecheck(src, &conf, &Info{Instances: instMap, Uses: useMap})
+ pkg, err := typecheck(src, &conf, &Info{Instances: instMap, Uses: useMap})
+ // allow error for issue51803
+ if err != nil && (pkg == nil || pkg.Name() != "issue51803") {
+ t.Fatal(err)
+ }
imports[pkg.Name()] = pkg
return pkg
}
diff --git a/src/cmd/compile/internal/types2/call.go b/src/cmd/compile/internal/types2/call.go
index 20cde9f44e..8ad7744ab4 100644
--- a/src/cmd/compile/internal/types2/call.go
+++ b/src/cmd/compile/internal/types2/call.go
@@ -23,14 +23,16 @@ import (
func (check *Checker) funcInst(tsig *Signature, pos syntax.Pos, x *operand, inst *syntax.IndexExpr) {
assert(tsig != nil || inst != nil)
+ var versionErr bool // set if version error was reported
+ var instErrPos poser // position for instantion error
+ if inst != nil {
+ instErrPos = inst.Pos()
+ } else {
+ instErrPos = pos
+ }
if !check.allowVersion(check.pkg, pos, 1, 18) {
- var posn poser
- if inst != nil {
- posn = inst.Pos()
- } else {
- posn = pos
- }
- check.versionErrorf(posn, "go1.18", "function instantiation")
+ check.versionErrorf(instErrPos, "go1.18", "function instantiation")
+ versionErr = true
}
// targs and xlist are the type arguments and corresponding type expressions, or nil.
@@ -72,6 +74,13 @@ func (check *Checker) funcInst(tsig *Signature, pos syntax.Pos, x *operand, inst
// of a synthetic function f where f's parameters are the parameters and results
// of x and where the arguments to the call of f are values of the parameter and
// result types of x.
+ if !versionErr && !check.allowVersion(check.pkg, pos, 1, 21) {
+ if inst != nil {
+ check.versionErrorf(instErrPos, "go1.21", "partially instantiated function in assignment")
+ } else {
+ check.versionErrorf(instErrPos, "go1.21", "implicitly instantiated function in assignment")
+ }
+ }
n := tsig.params.Len()
m := tsig.results.Len()
args = make([]*operand, n+m)
@@ -303,7 +312,7 @@ func (check *Checker) callExpr(x *operand, call *syntax.CallExpr) exprKind {
}
// evaluate arguments
- args := check.exprList(call.ArgList)
+ args := check.genericExprList(call.ArgList)
sig = check.arguments(call, sig, targs, args, xlist)
if wasGeneric && sig.TypeParams().Len() == 0 {
@@ -338,6 +347,8 @@ func (check *Checker) callExpr(x *operand, call *syntax.CallExpr) exprKind {
return statement
}
+// exprList evaluates a list of expressions and returns the corresponding operands.
+// A single-element expression list may evaluate to multiple operands.
func (check *Checker) exprList(elist []syntax.Expr) (xlist []*operand) {
switch len(elist) {
case 0:
@@ -356,6 +367,25 @@ func (check *Checker) exprList(elist []syntax.Expr) (xlist []*operand) {
return
}
+// genericExprList is like exprList but result operands may be generic (not fully instantiated).
+func (check *Checker) genericExprList(elist []syntax.Expr) (xlist []*operand) {
+ switch len(elist) {
+ case 0:
+ // nothing to do
+ case 1:
+ xlist = check.genericMultiExpr(elist[0])
+ default:
+ // multiple (possibly invalid) values
+ xlist = make([]*operand, len(elist))
+ for i, e := range elist {
+ var x operand
+ check.genericExpr(&x, e)
+ xlist[i] = &x
+ }
+ }
+ return
+}
+
// xlist is the list of type argument expressions supplied in the source code.
func (check *Checker) arguments(call *syntax.CallExpr, sig *Signature, targs []Type, args []*operand, xlist []syntax.Expr) (rsig *Signature) {
rsig = sig
@@ -386,7 +416,7 @@ func (check *Checker) arguments(call *syntax.CallExpr, sig *Signature, targs []T
// set up parameters
sigParams := sig.params // adjusted for variadic functions (may be nil for empty parameter lists!)
- adjusted := false // indicates if sigParams is different from t.params
+ adjusted := false // indicates if sigParams is different from sig.params
if sig.variadic {
if ddd {
// variadic_func(a, b, c...)
@@ -451,8 +481,12 @@ func (check *Checker) arguments(call *syntax.CallExpr, sig *Signature, targs []T
return
}
- // infer type arguments and instantiate signature if necessary
- if sig.TypeParams().Len() > 0 {
+ // collect type parameters of callee and generic function arguments
+ var tparams []*TypeParam
+
+ // collect type parameters of callee
+ n := sig.TypeParams().Len()
+ if n > 0 {
if !check.allowVersion(check.pkg, call.Pos(), 1, 18) {
if iexpr, _ := call.Fun.(*syntax.IndexExpr); iexpr != nil {
check.versionErrorf(iexpr.Pos(), "go1.18", "function instantiation")
@@ -460,29 +494,72 @@ func (check *Checker) arguments(call *syntax.CallExpr, sig *Signature, targs []T
check.versionErrorf(call.Pos(), "go1.18", "implicit function instantiation")
}
}
-
- // Rename type parameters to avoid problems with recursive calls.
- var tparams []*TypeParam
+ // rename type parameters to avoid problems with recursive calls
tparams, sigParams = check.renameTParams(call.Pos(), sig.TypeParams().list(), sigParams)
+ }
- targs := check.infer(call.Pos(), tparams, targs, sigParams, args)
+ // collect type parameters from generic function arguments
+ var genericArgs []int // indices of generic function arguments
+ if check.conf.EnableReverseTypeInference {
+ for i, arg := range args {
+ // generic arguments cannot have a defined (*Named) type - no need for underlying type below
+ if asig, _ := arg.typ.(*Signature); asig != nil && asig.TypeParams().Len() > 0 {
+ // TODO(gri) need to also rename type parameters for cases like f(g, g)
+ tparams = append(tparams, asig.TypeParams().list()...)
+ genericArgs = append(genericArgs, i)
+ }
+ }
+ }
+ if len(genericArgs) > 0 && !check.allowVersion(check.pkg, call.Pos(), 1, 21) {
+ // at the moment we only support implicit instantiations of argument functions
+ check.versionErrorf(args[genericArgs[0]].Pos(), "go1.21", "implicitly instantiated function as argument")
+ }
+
+ // tparams holds the type parameters of the callee and generic function arguments, if any:
+ // the first n type parameters belong to the callee, followed by mi type parameters for each
+ // of the generic function arguments, where mi = args[i].typ.(*Signature).TypeParams().Len().
+
+ // infer missing type arguments of callee and function arguments
+ if len(tparams) > 0 {
+ targs = check.infer(call.Pos(), tparams, targs, sigParams, args)
if targs == nil {
+ // TODO(gri) If infer inferred the first targs[:n], consider instantiating
+ // the call signature for better error messages/gopls behavior.
+ // Perhaps instantiate as much as we can, also for arguments.
+ // This will require changes to how infer returns its results.
return // error already reported
}
- // compute result signature
- rsig = check.instantiateSignature(call.Pos(), sig, targs, xlist)
- assert(rsig.TypeParams().Len() == 0) // signature is not generic anymore
- check.recordInstance(call.Fun, targs, rsig)
+ // compute result signature: instantiate if needed
+ rsig = sig
+ if n > 0 {
+ rsig = check.instantiateSignature(call.Pos(), sig, targs[:n], xlist)
+ assert(rsig.TypeParams().Len() == 0) // signature is not generic anymore
+ check.recordInstance(call.Fun, targs[:n], rsig)
+ }
- // Optimization: Only if the parameter list was adjusted do we
- // need to compute it from the adjusted list; otherwise we can
- // simply use the result signature's parameter list.
- if adjusted {
- sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(tparams, targs), nil, check.context()).(*Tuple)
+ // Optimization: Only if the callee's parameter list was adjusted do we need to
+ // compute it from the adjusted list; otherwise we can simply use the result
+ // signature's parameter list. We only need the n type parameters and arguments
+ // of the callee.
+ if n > 0 && adjusted {
+ sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(tparams[:n], targs[:n]), nil, check.context()).(*Tuple)
} else {
sigParams = rsig.params
}
+
+ // compute argument signatures: instantiate if needed
+ j := n
+ for _, i := range genericArgs {
+ asig := args[i].typ.(*Signature)
+ k := j + asig.TypeParams().Len()
+ // targs[j:k] are the inferred type arguments for asig
+ asig = check.instantiateSignature(call.Pos(), asig, targs[j:k], nil) // TODO(gri) provide xlist if possible (partial instantiations)
+ assert(asig.TypeParams().Len() == 0) // signature is not generic anymore
+ args[i].typ = asig
+ check.recordInstance(args[i].expr, targs[j:k], asig)
+ j = k
+ }
}
// check arguments
diff --git a/src/cmd/compile/internal/types2/expr.go b/src/cmd/compile/internal/types2/expr.go
index 7c3b40f086..51b944eead 100644
--- a/src/cmd/compile/internal/types2/expr.go
+++ b/src/cmd/compile/internal/types2/expr.go
@@ -1882,6 +1882,27 @@ func (check *Checker) multiExpr(e syntax.Expr, allowCommaOk bool) (list []*opera
return
}
+// genericMultiExpr is like multiExpr but a one-element result may also be generic
+// and potential comma-ok expressions are returned as single values.
+func (check *Checker) genericMultiExpr(e syntax.Expr) (list []*operand) {
+ var x operand
+ check.rawExpr(nil, &x, e, nil, true)
+ check.exclude(&x, 1<<novalue|1<<builtin|1<<typexpr)
+
+ if t, ok := x.typ.(*Tuple); ok && x.mode != invalid {
+ // multiple values - cannot be generic
+ list = make([]*operand, t.Len())
+ for i, v := range t.vars {
+ list[i] = &operand{mode: value, expr: e, typ: v.typ}
+ }
+ return
+ }
+
+ // exactly one (possible invalid or generic) value
+ list = []*operand{&x}
+ return
+}
+
// exprWithHint typechecks expression e and initializes x with the expression value;
// hint is the type of a composite literal element.
// If an error occurred, x.mode is set to invalid.
diff --git a/src/cmd/compile/internal/types2/infer.go b/src/cmd/compile/internal/types2/infer.go
index dbe621cded..9c1022c46f 100644
--- a/src/cmd/compile/internal/types2/infer.go
+++ b/src/cmd/compile/internal/types2/infer.go
@@ -140,17 +140,17 @@ func (check *Checker) infer(pos syntax.Pos, tparams []*TypeParam, targs []Type,
}
for i, arg := range args {
+ if arg.mode == invalid {
+ // An error was reported earlier. Ignore this arg
+ // and continue, we may still be able to infer all
+ // targs resulting in fewer follow-on errors.
+ // TODO(gri) determine if we still need this check
+ continue
+ }
par := params.At(i)
- // If we permit bidirectional unification, this conditional code needs to be
- // executed even if par.typ is not parameterized since the argument may be a
- // generic function (for which we want to infer its type arguments).
- if isParameterized(tparams, par.typ) {
- if arg.mode == invalid {
- // An error was reported earlier. Ignore this targ
- // and continue, we may still be able to infer all
- // targs resulting in fewer follow-on errors.
- continue
- }
+ if isParameterized(tparams, par.typ) || isParameterized(tparams, arg.typ) {
+ // Function parameters are always typed. Arguments may be untyped.
+ // Collect the indices of untyped arguments and handle them later.
if isTyped(arg.typ) {
if !u.unify(par.typ, arg.typ) {
errorf("type", par.typ, arg.typ, arg)
@@ -263,7 +263,7 @@ func (check *Checker) infer(pos syntax.Pos, tparams []*TypeParam, targs []Type,
}
// --- 3 ---
- // use information from untyped contants
+ // use information from untyped constants
if traceInference {
u.tracef("== untyped arguments: %v", untyped)
@@ -541,7 +541,6 @@ func (w *tpWalker) isParameterized(typ Type) (res bool) {
}
case *TypeParam:
- // t must be one of w.tparams
return tparamIndex(w.tparams, t) >= 0
default:
diff --git a/src/go/types/api.go b/src/go/types/api.go
index 7af84fd244..e202d6dea8 100644
--- a/src/go/types/api.go
+++ b/src/go/types/api.go
@@ -175,7 +175,7 @@ type Config struct {
// partially instantiated generic functions may be assigned
// (incl. returned) to variables of function type and type
// inference will attempt to infer the missing type arguments.
- // Experimental. Needs a proposal.
+ // See proposal go.dev/issue/59338.
_EnableReverseTypeInference bool
}
diff --git a/src/go/types/api_test.go b/src/go/types/api_test.go
index ae1a7e50a7..02e26c3f02 100644
--- a/src/go/types/api_test.go
+++ b/src/go/types/api_test.go
@@ -550,18 +550,57 @@ type T[P any] []P
{`package issue51803; func foo[T any](T) {}; func _() { foo[int]( /* leave arg away on purpose */ ) }`,
[]testInst{{`foo`, []string{`int`}, `func(int)`}},
},
+
+ // reverse type parameter inference
+ {`package reverse1a; var f func(int) = g; func g[P any](P) {}`,
+ []testInst{{`g`, []string{`int`}, `func(int)`}},
+ },
+ {`package reverse1b; func f(func(int)) {}; func g[P any](P) {}; func _() { f(g) }`,
+ []testInst{{`g`, []string{`int`}, `func(int)`}},
+ },
+ {`package reverse2a; var f func(int) string = g; func g[P, Q any](P) Q { var q Q; return q }`,
+ []testInst{{`g`, []string{`int`, `string`}, `func(int) string`}},
+ },
+ {`package reverse2b; func f(func(int) string) {}; func g[P, Q any](P) Q { var q Q; return q }; func _() { f(g) }`,
+ []testInst{{`g`, []string{`int`, `string`}, `func(int) string`}},
+ },
+ // reverse3a not possible (cannot assign to generic function outside of argument passing)
+ {`package reverse3b; func f[R any](func(int) R) {}; func g[P any](P) string { return "" }; func _() { f(g) }`,
+ []testInst{
+ {`f`, []string{`string`}, `func(func(int) string)`},
+ {`g`, []string{`int`}, `func(int) string`},
+ },
+ },
+ {`package reverse4a; var _, _ func([]int, *float32) = g, h; func g[P, Q any]([]P, *Q) {}; func h[R any]([]R, *float32) {}`,
+ []testInst{
+ {`g`, []string{`int`, `float32`}, `func([]int, *float32)`},
+ {`h`, []string{`int`}, `func([]int, *float32)`},
+ },
+ },
+ {`package reverse4b; func f(_, _ func([]int, *float32)) {}; func g[P, Q any]([]P, *Q) {}; func h[R any]([]R, *float32) {}; func _() { f(g, h) }`,
+ []testInst{
+ {`g`, []string{`int`, `float32`}, `func([]int, *float32)`},
+ {`h`, []string{`int`}, `func([]int, *float32)`},
+ },
+ },
}
for _, test := range tests {
imports := make(testImporter)
conf := Config{
Importer: imports,
- Error: func(error) {}, // ignore errors
+ // Unexported field: set below with boolFieldAddr
+ // _EnableReverseTypeInference: true,
}
+ *boolFieldAddr(&conf, "_EnableReverseTypeInference") = true
instMap := make(map[*ast.Ident]Instance)
useMap := make(map[*ast.Ident]Object)
makePkg := func(src string) *Package {
- pkg, _ := typecheck(src, &conf, &Info{Instances: instMap, Uses: useMap})
+ pkg, err := typecheck(src, &conf, &Info{Instances: instMap, Uses: useMap})
+ // allow error for issue51803
+ if err != nil && (pkg == nil || pkg.Name() != "issue51803") {
+ t.Fatal(err)
+ }
imports[pkg.Name()] = pkg
return pkg
}
diff --git a/src/go/types/call.go b/src/go/types/call.go
index 979de2338f..02b6038ccc 100644
--- a/src/go/types/call.go
+++ b/src/go/types/call.go
@@ -25,14 +25,16 @@ import (
func (check *Checker) funcInst(tsig *Signature, pos token.Pos, x *operand, ix *typeparams.IndexExpr) {
assert(tsig != nil || ix != nil)
+ var versionErr bool // set if version error was reported
+ var instErrPos positioner // position for instantion error
+ if ix != nil {
+ instErrPos = inNode(ix.Orig, ix.Lbrack)
+ } else {
+ instErrPos = atPos(pos)
+ }
if !check.allowVersion(check.pkg, pos, 1, 18) {
- var posn positioner
- if ix != nil {
- posn = inNode(ix.Orig, ix.Lbrack)
- } else {
- posn = atPos(pos)
- }
- check.softErrorf(posn, UnsupportedFeature, "function instantiation requires go1.18 or later")
+ check.softErrorf(instErrPos, UnsupportedFeature, "function instantiation requires go1.18 or later")
+ versionErr = true
}
// targs and xlist are the type arguments and corresponding type expressions, or nil.
@@ -74,6 +76,13 @@ func (check *Checker) funcInst(tsig *Signature, pos token.Pos, x *operand, ix *t
// of a synthetic function f where f's parameters are the parameters and results
// of x and where the arguments to the call of f are values of the parameter and
// result types of x.
+ if !versionErr && !check.allowVersion(check.pkg, pos, 1, 21) {
+ if ix != nil {
+ check.softErrorf(instErrPos, UnsupportedFeature, "partially instantiated function in assignment requires go1.21 or later")
+ } else {
+ check.softErrorf(instErrPos, UnsupportedFeature, "implicitly instantiated function in assignment requires go1.21 or later")
+ }
+ }
n := tsig.params.Len()
m := tsig.results.Len()
args = make([]*operand, n+m)
@@ -308,7 +317,7 @@ func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind {
}
// evaluate arguments
- args := check.exprList(call.Args)
+ args := check.genericExprList(call.Args)
sig = check.arguments(call, sig, targs, args, xlist)
if wasGeneric && sig.TypeParams().Len() == 0 {
@@ -343,6 +352,8 @@ func (check *Checker) callExpr(x *operand, call *ast.CallExpr) exprKind {
return statement
}
+// exprList evaluates a list of expressions and returns the corresponding operands.
+// A single-element expression list may evaluate to multiple operands.
func (check *Checker) exprList(elist []ast.Expr) (xlist []*operand) {
switch len(elist) {
case 0:
@@ -361,6 +372,25 @@ func (check *Checker) exprList(elist []ast.Expr) (xlist []*operand) {
return
}
+// genericExprList is like exprList but result operands may be generic (not fully instantiated).
+func (check *Checker) genericExprList(elist []ast.Expr) (xlist []*operand) {
+ switch len(elist) {
+ case 0:
+ // nothing to do
+ case 1:
+ xlist = check.genericMultiExpr(elist[0])
+ default:
+ // multiple (possibly invalid) values
+ xlist = make([]*operand, len(elist))
+ for i, e := range elist {
+ var x operand
+ check.genericExpr(&x, e)
+ xlist[i] = &x
+ }
+ }
+ return
+}
+
// xlist is the list of type argument expressions supplied in the source code.
func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type, args []*operand, xlist []ast.Expr) (rsig *Signature) {
rsig = sig
@@ -391,7 +421,7 @@ func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type
// set up parameters
sigParams := sig.params // adjusted for variadic functions (may be nil for empty parameter lists!)
- adjusted := false // indicates if sigParams is different from t.params
+ adjusted := false // indicates if sigParams is different from sig.params
if sig.variadic {
if ddd {
// variadic_func(a, b, c...)
@@ -452,8 +482,12 @@ func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type
return
}
- // infer type arguments and instantiate signature if necessary
- if sig.TypeParams().Len() > 0 {
+ // collect type parameters of callee and generic function arguments
+ var tparams []*TypeParam
+
+ // collect type parameters of callee
+ n := sig.TypeParams().Len()
+ if n > 0 {
if !check.allowVersion(check.pkg, call.Pos(), 1, 18) {
switch call.Fun.(type) {
case *ast.IndexExpr, *ast.IndexListExpr:
@@ -463,29 +497,72 @@ func (check *Checker) arguments(call *ast.CallExpr, sig *Signature, targs []Type
check.softErrorf(inNode(call, call.Lparen), UnsupportedFeature, "implicit function instantiation requires go1.18 or later")
}
}
-
- // Rename type parameters to avoid problems with recursive calls.
- var tparams []*TypeParam
+ // rename type parameters to avoid problems with recursive calls
tparams, sigParams = check.renameTParams(call.Pos(), sig.TypeParams().list(), sigParams)
+ }
- targs := check.infer(call, tparams, targs, sigParams, args)
+ // collect type parameters from generic function arguments
+ var genericArgs []int // indices of generic function arguments
+ if check.conf._EnableReverseTypeInference {
+ for i, arg := range args {
+ // generic arguments cannot have a defined (*Named) type - no need for underlying type below
+ if asig, _ := arg.typ.(*Signature); asig != nil && asig.TypeParams().Len() > 0 {
+ // TODO(gri) need to also rename type parameters for cases like f(g, g)
+ tparams = append(tparams, asig.TypeParams().list()...)
+ genericArgs = append(genericArgs, i)
+ }
+ }
+ }
+ if len(genericArgs) > 0 && !check.allowVersion(check.pkg, call.Pos(), 1, 21) {
+ // at the moment we only support implicit instantiations of argument functions
+ check.softErrorf(inNode(call, call.Lparen), UnsupportedFeature, "implicitly instantiated function as argument requires go1.21 or later")
+ }
+
+ // tparams holds the type parameters of the callee and generic function arguments, if any:
+ // the first n type parameters belong to the callee, followed by mi type parameters for each
+ // of the generic function arguments, where mi = args[i].typ.(*Signature).TypeParams().Len().
+
+ // infer missing type arguments of callee and function arguments
+ if len(tparams) > 0 {
+ targs = check.infer(call, tparams, targs, sigParams, args)
if targs == nil {
+ // TODO(gri) If infer inferred the first targs[:n], consider instantiating
+ // the call signature for better error messages/gopls behavior.
+ // Perhaps instantiate as much as we can, also for arguments.
+ // This will require changes to how infer returns its results.
return // error already reported
}
- // compute result signature
- rsig = check.instantiateSignature(call.Pos(), sig, targs, xlist)
- assert(rsig.TypeParams().Len() == 0) // signature is not generic anymore
- check.recordInstance(call.Fun, targs, rsig)
+ // compute result signature: instantiate if needed
+ rsig = sig
+ if n > 0 {
+ rsig = check.instantiateSignature(call.Pos(), sig, targs[:n], xlist)
+ assert(rsig.TypeParams().Len() == 0) // signature is not generic anymore
+ check.recordInstance(call.Fun, targs[:n], rsig)
+ }
- // Optimization: Only if the parameter list was adjusted do we
- // need to compute it from the adjusted list; otherwise we can
- // simply use the result signature's parameter list.
- if adjusted {
- sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(tparams, targs), nil, check.context()).(*Tuple)
+ // Optimization: Only if the callee's parameter list was adjusted do we need to
+ // compute it from the adjusted list; otherwise we can simply use the result
+ // signature's parameter list. We only need the n type parameters and arguments
+ // of the callee.
+ if n > 0 && adjusted {
+ sigParams = check.subst(call.Pos(), sigParams, makeSubstMap(tparams[:n], targs[:n]), nil, check.context()).(*Tuple)
} else {
sigParams = rsig.params
}
+
+ // compute argument signatures: instantiate if needed
+ j := n
+ for _, i := range genericArgs {
+ asig := args[i].typ.(*Signature)
+ k := j + asig.TypeParams().Len()
+ // targs[j:k] are the inferred type arguments for asig
+ asig = check.instantiateSignature(call.Pos(), asig, targs[j:k], nil) // TODO(gri) provide xlist if possible (partial instantiations)
+ assert(asig.TypeParams().Len() == 0) // signature is not generic anymore
+ args[i].typ = asig
+ check.recordInstance(args[i].expr, targs[j:k], asig)
+ j = k
+ }
}
// check arguments
diff --git a/src/go/types/check_test.go b/src/go/types/check_test.go
index 0f4c320a47..cda052f4d3 100644
--- a/src/go/types/check_test.go
+++ b/src/go/types/check_test.go
@@ -39,6 +39,7 @@ import (
"go/scanner"
"go/token"
"internal/testenv"
+ "internal/types/errors"
"os"
"path/filepath"
"reflect"
@@ -295,9 +296,9 @@ func testFiles(t *testing.T, sizes Sizes, filenames []string, srcs [][]byte, man
}
}
-func readCode(err Error) int {
+func readCode(err Error) errors.Code {
v := reflect.ValueOf(err)
- return int(v.FieldByName("go116code").Int())
+ return errors.Code(v.FieldByName("go116code").Int())
}
// boolFieldAddr(conf, name) returns the address of the boolean field conf.<name>.
diff --git a/src/go/types/expr.go b/src/go/types/expr.go
index 0db80ca44b..891153ba8d 100644
--- a/src/go/types/expr.go
+++ b/src/go/types/expr.go
@@ -1829,6 +1829,27 @@ func (check *Checker) multiExpr(e ast.Expr, allowCommaOk bool) (list []*operand,
return
}
+// genericMultiExpr is like multiExpr but a one-element result may also be generic
+// and potential comma-ok expressions are returned as single values.
+func (check *Checker) genericMultiExpr(e ast.Expr) (list []*operand) {
+ var x operand
+ check.rawExpr(nil, &x, e, nil, true)
+ check.exclude(&x, 1<<novalue|1<<builtin|1<<typexpr)
+
+ if t, ok := x.typ.(*Tuple); ok && x.mode != invalid {
+ // multiple values - cannot be generic
+ list = make([]*operand, t.Len())
+ for i, v := range t.vars {
+ list[i] = &operand{mode: value, expr: e, typ: v.typ}
+ }
+ return
+ }
+
+ // exactly one (possible invalid or generic) value
+ list = []*operand{&x}
+ return
+}
+
// exprWithHint typechecks expression e and initializes x with the expression value;
// hint is the type of a composite literal element.
// If an error occurred, x.mode is set to invalid.
diff --git a/src/go/types/infer.go b/src/go/types/infer.go
index 3aa66105c4..39bf4a14f7 100644
--- a/src/go/types/infer.go
+++ b/src/go/types/infer.go
@@ -142,17 +142,17 @@ func (check *Checker) infer(posn positioner, tparams []*TypeParam, targs []Type,
}
for i, arg := range args {
+ if arg.mode == invalid {
+ // An error was reported earlier. Ignore this arg
+ // and continue, we may still be able to infer all
+ // targs resulting in fewer follow-on errors.
+ // TODO(gri) determine if we still need this check
+ continue
+ }
par := params.At(i)
- // If we permit bidirectional unification, this conditional code needs to be
- // executed even if par.typ is not parameterized since the argument may be a
- // generic function (for which we want to infer its type arguments).
- if isParameterized(tparams, par.typ) {
- if arg.mode == invalid {
- // An error was reported earlier. Ignore this targ
- // and continue, we may still be able to infer all
- // targs resulting in fewer follow-on errors.
- continue
- }
+ if isParameterized(tparams, par.typ) || isParameterized(tparams, arg.typ) {
+ // Function parameters are always typed. Arguments may be untyped.
+ // Collect the indices of untyped arguments and handle them later.
if isTyped(arg.typ) {
if !u.unify(par.typ, arg.typ) {
errorf("type", par.typ, arg.typ, arg)
@@ -265,7 +265,7 @@ func (check *Checker) infer(posn positioner, tparams []*TypeParam, targs []Type,
}
// --- 3 ---
- // use information from untyped contants
+ // use information from untyped constants
if traceInference {
u.tracef("== untyped arguments: %v", untyped)
@@ -543,7 +543,6 @@ func (w *tpWalker) isParameterized(typ Type) (res bool) {
}
case *TypeParam:
- // t must be one of w.tparams
return tparamIndex(w.tparams, t) >= 0
default:
diff --git a/src/internal/types/testdata/examples/inference2.go b/src/internal/types/testdata/examples/inference2.go
index d309a00c0c..80acc828dd 100644
--- a/src/internal/types/testdata/examples/inference2.go
+++ b/src/internal/types/testdata/examples/inference2.go
@@ -10,11 +10,13 @@
package p
-func f1[P any](P) {}
-func f2[P any]() P { var x P; return x }
-func f3[P, Q any](P) Q { var x Q; return x }
-func f4[P any](P, P) {}
-func f5[P any](P) []P { return nil }
+func f1[P any](P) {}
+func f2[P any]() P { var x P; return x }
+func f3[P, Q any](P) Q { var x Q; return x }
+func f4[P any](P, P) {}
+func f5[P any](P) []P { return nil }
+func f6[P any](int) P { var x P; return x }
+func f7[P any](P) string { return "" }
// initialization expressions
var (
@@ -71,3 +73,22 @@ func _() func(string) []int {
func _() (_, _ func(int)) { return f1, f1 }
func _() (_, _ func(int)) { return f1, f2 /* ERROR "cannot infer P" */ }
+
+// Argument passing
+func g1(func(int)) {}
+func g2(func(int, int)) {}
+func g3(func(int) string) {}
+func g4[P any](func(P) string) {}
+func g5[P, Q any](func(P) string, func(P) Q) {}
+func g6(func(int), func(string)) {}
+
+func _() {
+ g1(f1)
+ g1(f2 /* ERROR "cannot infer P" */)
+ g2(f4)
+ g4(f6)
+ g5(f6, f7)
+
+ // TODO(gri) this should work (requires type parameter renaming for f1)
+ g6(f1, f1 /* ERROR "type func[P any](P) of f1 does not match func(string)" */)
+}
diff --git a/src/internal/types/testdata/fixedbugs/issue59338a.go b/src/internal/types/testdata/fixedbugs/issue59338a.go
new file mode 100644
index 0000000000..fd37586cfb
--- /dev/null
+++ b/src/internal/types/testdata/fixedbugs/issue59338a.go
@@ -0,0 +1,21 @@
+// -reverseTypeInference -lang=go1.20
+
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package p
+
+func g[P any](P) {}
+func h[P, Q any](P) Q { panic(0) }
+
+var _ func(int) = g /* ERROR "implicitly instantiated function in assignment requires go1.21 or later" */
+var _ func(int) string = h[ /* ERROR "partially instantiated function in assignment requires go1.21 or later" */ int]
+
+func f1(func(int)) {}
+func f2(int, func(int)) {}
+
+func _() {
+ f1( /* ERROR "implicitly instantiated function as argument requires go1.21 or later" */ g)
+ f2( /* ERROR "implicitly instantiated function as argument requires go1.21 or later" */ 0, g)
+}
diff --git a/src/internal/types/testdata/fixedbugs/issue59338b.go b/src/internal/types/testdata/fixedbugs/issue59338b.go
new file mode 100644
index 0000000000..ea321bcd17
--- /dev/null
+++ b/src/internal/types/testdata/fixedbugs/issue59338b.go
@@ -0,0 +1,21 @@
+// -reverseTypeInference
+
+// Copyright 2023 The Go Authors. All rights reserved.
+// Use of this source code is governed by a BSD-style
+// license that can be found in the LICENSE file.
+
+package p
+
+func g[P any](P) {}
+func h[P, Q any](P) Q { panic(0) }
+
+var _ func(int) = g
+var _ func(int) string = h[int]
+
+func f1(func(int)) {}
+func f2(int, func(int)) {}
+
+func _() {
+ f1(g)
+ f2(0, g)
+}