diff options
-rw-r--r-- | api/next/51365.txt | 3 | ||||
-rw-r--r-- | src/context/context.go | 94 | ||||
-rw-r--r-- | src/context/context_test.go | 153 | ||||
-rw-r--r-- | src/context/x_test.go | 1 |
4 files changed, 231 insertions, 20 deletions
diff --git a/api/next/51365.txt b/api/next/51365.txt new file mode 100644 index 0000000000..df629f1852 --- /dev/null +++ b/api/next/51365.txt @@ -0,0 +1,3 @@ +pkg context, func Cause(Context) error #51365 +pkg context, func WithCancelCause(Context) (Context, CancelCauseFunc) #51365 +pkg context, type CancelCauseFunc func(error) #51365 diff --git a/src/context/context.go b/src/context/context.go index 7eace57893..a0b5edc524 100644 --- a/src/context/context.go +++ b/src/context/context.go @@ -22,6 +22,12 @@ // fires. The go vet tool checks that CancelFuncs are used on all // control-flow paths. // +// The WithCancelCause function returns a CancelCauseFunc, which +// takes an error and records it as the cancelation cause. Calling +// Cause on the canceled context or any of its children retrieves +// the cause. If no cause is specified, Cause(ctx) returns the same +// value as ctx.Err(). +// // Programs that use Contexts should follow these rules to keep interfaces // consistent across packages and enable static analysis tools to check context // propagation: @@ -230,17 +236,63 @@ type CancelFunc func() // Canceling this context releases resources associated with it, so code should // call cancel as soon as the operations running in this Context complete. func WithCancel(parent Context) (ctx Context, cancel CancelFunc) { + c := withCancel(parent) + return c, func() { c.cancel(true, Canceled, nil) } +} + +// A CancelCauseFunc behaves like a CancelFunc but additionally sets the cancelation cause. +// This cause can be retrieved by calling Cause on the canceled Context or on +// any of its derived Contexts. +// +// If the context has already been canceled, CancelCauseFunc does not set the cause. +// For example, if childContext is derived from parentContext: +// - if parentContext is canceled with cause1 before childContext is canceled with cause2, +// then Cause(parentContext) == Cause(childContext) == cause1 +// - if childContext is canceled with cause2 before parentContext is canceled with cause1, +// then Cause(parentContext) == cause1 and Cause(childContext) == cause2 +type CancelCauseFunc func(cause error) + +// WithCancelCause behaves like WithCancel but returns a CancelCauseFunc instead of a CancelFunc. +// Calling cancel with a non-nil error (the "cause") records that error in ctx; +// it can then be retrieved using Cause(ctx). +// Calling cancel with nil sets the cause to Canceled. +// +// Example use: +// +// ctx, cancel := context.WithCancelCause(parent) +// cancel(myError) +// ctx.Err() // returns context.Canceled +// context.Cause(ctx) // returns myError +func WithCancelCause(parent Context) (ctx Context, cancel CancelCauseFunc) { + c := withCancel(parent) + return c, func(cause error) { c.cancel(true, Canceled, cause) } +} + +func withCancel(parent Context) *cancelCtx { if parent == nil { panic("cannot create context from nil parent") } c := newCancelCtx(parent) - propagateCancel(parent, &c) - return &c, func() { c.cancel(true, Canceled) } + propagateCancel(parent, c) + return c +} + +// Cause returns a non-nil error explaining why c was canceled. +// The first cancelation of c or one of its parents sets the cause. +// If that cancelation happened via a call to CancelCauseFunc(err), +// then Cause returns err. +// Otherwise Cause(c) returns the same value as c.Err(). +// Cause returns nil if c has not been canceled yet. +func Cause(c Context) error { + if cc, ok := c.Value(&cancelCtxKey).(*cancelCtx); ok { + return cc.cause + } + return nil } // newCancelCtx returns an initialized cancelCtx. -func newCancelCtx(parent Context) cancelCtx { - return cancelCtx{Context: parent} +func newCancelCtx(parent Context) *cancelCtx { + return &cancelCtx{Context: parent} } // goroutines counts the number of goroutines ever created; for testing. @@ -256,7 +308,7 @@ func propagateCancel(parent Context, child canceler) { select { case <-done: // parent is already canceled - child.cancel(false, parent.Err()) + child.cancel(false, parent.Err(), Cause(parent)) return default: } @@ -265,7 +317,7 @@ func propagateCancel(parent Context, child canceler) { p.mu.Lock() if p.err != nil { // parent has already been canceled - child.cancel(false, p.err) + child.cancel(false, p.err, p.cause) } else { if p.children == nil { p.children = make(map[canceler]struct{}) @@ -278,7 +330,7 @@ func propagateCancel(parent Context, child canceler) { go func() { select { case <-parent.Done(): - child.cancel(false, parent.Err()) + child.cancel(false, parent.Err(), Cause(parent)) case <-child.Done(): } }() @@ -326,7 +378,7 @@ func removeChild(parent Context, child canceler) { // A canceler is a context type that can be canceled directly. The // implementations are *cancelCtx and *timerCtx. type canceler interface { - cancel(removeFromParent bool, err error) + cancel(removeFromParent bool, err, cause error) Done() <-chan struct{} } @@ -346,6 +398,7 @@ type cancelCtx struct { done atomic.Value // of chan struct{}, created lazily, closed by first cancel call children map[canceler]struct{} // set to nil by the first cancel call err error // set to non-nil by the first cancel call + cause error // set to non-nil by the first cancel call } func (c *cancelCtx) Value(key any) any { @@ -394,16 +447,21 @@ func (c *cancelCtx) String() string { // cancel closes c.done, cancels each of c's children, and, if // removeFromParent is true, removes c from its parent's children. -func (c *cancelCtx) cancel(removeFromParent bool, err error) { +// cancel sets c.cause to cause if this is the first time c is canceled. +func (c *cancelCtx) cancel(removeFromParent bool, err, cause error) { if err == nil { panic("context: internal error: missing cancel error") } + if cause == nil { + cause = err + } c.mu.Lock() if c.err != nil { c.mu.Unlock() return // already canceled } c.err = err + c.cause = cause d, _ := c.done.Load().(chan struct{}) if d == nil { c.done.Store(closedchan) @@ -412,7 +470,7 @@ func (c *cancelCtx) cancel(removeFromParent bool, err error) { } for child := range c.children { // NOTE: acquiring the child's lock while holding parent's lock. - child.cancel(false, err) + child.cancel(false, err, cause) } c.children = nil c.mu.Unlock() @@ -446,24 +504,24 @@ func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) { propagateCancel(parent, c) dur := time.Until(d) if dur <= 0 { - c.cancel(true, DeadlineExceeded) // deadline has already passed - return c, func() { c.cancel(false, Canceled) } + c.cancel(true, DeadlineExceeded, nil) // deadline has already passed + return c, func() { c.cancel(false, Canceled, nil) } } c.mu.Lock() defer c.mu.Unlock() if c.err == nil { c.timer = time.AfterFunc(dur, func() { - c.cancel(true, DeadlineExceeded) + c.cancel(true, DeadlineExceeded, nil) }) } - return c, func() { c.cancel(true, Canceled) } + return c, func() { c.cancel(true, Canceled, nil) } } // A timerCtx carries a timer and a deadline. It embeds a cancelCtx to // implement Done and Err. It implements cancel by stopping its timer then // delegating to cancelCtx.cancel. type timerCtx struct { - cancelCtx + *cancelCtx timer *time.Timer // Under cancelCtx.mu. deadline time.Time @@ -479,8 +537,8 @@ func (c *timerCtx) String() string { time.Until(c.deadline).String() + "])" } -func (c *timerCtx) cancel(removeFromParent bool, err error) { - c.cancelCtx.cancel(false, err) +func (c *timerCtx) cancel(removeFromParent bool, err, cause error) { + c.cancelCtx.cancel(false, err, cause) if removeFromParent { // Remove this timerCtx from its parent cancelCtx's children. removeChild(c.cancelCtx.Context, c) @@ -581,7 +639,7 @@ func value(c Context, key any) any { c = ctx.Context case *timerCtx: if key == &cancelCtxKey { - return &ctx.cancelCtx + return ctx.cancelCtx } c = ctx.Context case *emptyCtx: diff --git a/src/context/context_test.go b/src/context/context_test.go index 0991880907..593a7b1521 100644 --- a/src/context/context_test.go +++ b/src/context/context_test.go @@ -650,8 +650,9 @@ func XTestCancelRemoves(t testingT) { } func XTestWithCancelCanceledParent(t testingT) { - parent, pcancel := WithCancel(Background()) - pcancel() + parent, pcancel := WithCancelCause(Background()) + cause := fmt.Errorf("Because!") + pcancel(cause) c, _ := WithCancel(parent) select { @@ -662,6 +663,9 @@ func XTestWithCancelCanceledParent(t testingT) { if got, want := c.Err(), Canceled; got != want { t.Errorf("child not canceled; got = %v, want = %v", got, want) } + if got, want := Cause(c), cause; got != want { + t.Errorf("child has wrong cause; got = %v, want = %v", got, want) + } } func XTestWithValueChecksKey(t testingT) { @@ -785,3 +789,148 @@ func XTestCustomContextGoroutines(t testingT) { defer cancel7() checkNoGoroutine() } + +func XTestCause(t testingT) { + var ( + parentCause = fmt.Errorf("parentCause") + childCause = fmt.Errorf("childCause") + ) + for _, test := range []struct { + name string + ctx Context + err error + cause error + }{ + { + name: "Background", + ctx: Background(), + err: nil, + cause: nil, + }, + { + name: "TODO", + ctx: TODO(), + err: nil, + cause: nil, + }, + { + name: "WithCancel", + ctx: func() Context { + ctx, cancel := WithCancel(Background()) + cancel() + return ctx + }(), + err: Canceled, + cause: Canceled, + }, + { + name: "WithCancelCause", + ctx: func() Context { + ctx, cancel := WithCancelCause(Background()) + cancel(parentCause) + return ctx + }(), + err: Canceled, + cause: parentCause, + }, + { + name: "WithCancelCause nil", + ctx: func() Context { + ctx, cancel := WithCancelCause(Background()) + cancel(nil) + return ctx + }(), + err: Canceled, + cause: Canceled, + }, + { + name: "WithCancelCause: parent cause before child", + ctx: func() Context { + ctx, cancelParent := WithCancelCause(Background()) + ctx, cancelChild := WithCancelCause(ctx) + cancelParent(parentCause) + cancelChild(childCause) + return ctx + }(), + err: Canceled, + cause: parentCause, + }, + { + name: "WithCancelCause: parent cause after child", + ctx: func() Context { + ctx, cancelParent := WithCancelCause(Background()) + ctx, cancelChild := WithCancelCause(ctx) + cancelChild(childCause) + cancelParent(parentCause) + return ctx + }(), + err: Canceled, + cause: childCause, + }, + { + name: "WithCancelCause: parent cause before nil", + ctx: func() Context { + ctx, cancelParent := WithCancelCause(Background()) + ctx, cancelChild := WithCancel(ctx) + cancelParent(parentCause) + cancelChild() + return ctx + }(), + err: Canceled, + cause: parentCause, + }, + { + name: "WithCancelCause: parent cause after nil", + ctx: func() Context { + ctx, cancelParent := WithCancelCause(Background()) + ctx, cancelChild := WithCancel(ctx) + cancelChild() + cancelParent(parentCause) + return ctx + }(), + err: Canceled, + cause: Canceled, + }, + { + name: "WithCancelCause: child cause after nil", + ctx: func() Context { + ctx, cancelParent := WithCancel(Background()) + ctx, cancelChild := WithCancelCause(ctx) + cancelParent() + cancelChild(childCause) + return ctx + }(), + err: Canceled, + cause: Canceled, + }, + { + name: "WithCancelCause: child cause before nil", + ctx: func() Context { + ctx, cancelParent := WithCancel(Background()) + ctx, cancelChild := WithCancelCause(ctx) + cancelChild(childCause) + cancelParent() + return ctx + }(), + err: Canceled, + cause: childCause, + }, + { + name: "WithTimeout", + ctx: func() Context { + ctx, cancel := WithTimeout(Background(), 0) + cancel() + return ctx + }(), + err: DeadlineExceeded, + cause: DeadlineExceeded, + }, + } { + if got, want := test.ctx.Err(), test.err; want != got { + t.Errorf("%s: ctx.Err() = %v want %v", test.name, got, want) + } + if got, want := Cause(test.ctx), test.cause; want != got { + t.Errorf("%s: Cause(ctx) = %v want %v", test.name, got, want) + } + } +} diff --git a/src/context/x_test.go b/src/context/x_test.go index 00eca72d5a..d3adb381d6 100644 --- a/src/context/x_test.go +++ b/src/context/x_test.go @@ -29,3 +29,4 @@ func TestWithValueChecksKey(t *testing.T) { XTestWithValueChecksKey func TestInvalidDerivedFail(t *testing.T) { XTestInvalidDerivedFail(t) } func TestDeadlineExceededSupportsTimeout(t *testing.T) { XTestDeadlineExceededSupportsTimeout(t) } func TestCustomContextGoroutines(t *testing.T) { XTestCustomContextGoroutines(t) } +func TestCause(t *testing.T) { XTestCause(t) } |