summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--api/next/51365.txt3
-rw-r--r--src/context/context.go94
-rw-r--r--src/context/context_test.go153
-rw-r--r--src/context/x_test.go1
4 files changed, 231 insertions, 20 deletions
diff --git a/api/next/51365.txt b/api/next/51365.txt
new file mode 100644
index 0000000000..df629f1852
--- /dev/null
+++ b/api/next/51365.txt
@@ -0,0 +1,3 @@
+pkg context, func Cause(Context) error #51365
+pkg context, func WithCancelCause(Context) (Context, CancelCauseFunc) #51365
+pkg context, type CancelCauseFunc func(error) #51365
diff --git a/src/context/context.go b/src/context/context.go
index 7eace57893..a0b5edc524 100644
--- a/src/context/context.go
+++ b/src/context/context.go
@@ -22,6 +22,12 @@
// fires. The go vet tool checks that CancelFuncs are used on all
// control-flow paths.
//
+// The WithCancelCause function returns a CancelCauseFunc, which
+// takes an error and records it as the cancelation cause. Calling
+// Cause on the canceled context or any of its children retrieves
+// the cause. If no cause is specified, Cause(ctx) returns the same
+// value as ctx.Err().
+//
// Programs that use Contexts should follow these rules to keep interfaces
// consistent across packages and enable static analysis tools to check context
// propagation:
@@ -230,17 +236,63 @@ type CancelFunc func()
// Canceling this context releases resources associated with it, so code should
// call cancel as soon as the operations running in this Context complete.
func WithCancel(parent Context) (ctx Context, cancel CancelFunc) {
+ c := withCancel(parent)
+ return c, func() { c.cancel(true, Canceled, nil) }
+}
+
+// A CancelCauseFunc behaves like a CancelFunc but additionally sets the cancelation cause.
+// This cause can be retrieved by calling Cause on the canceled Context or on
+// any of its derived Contexts.
+//
+// If the context has already been canceled, CancelCauseFunc does not set the cause.
+// For example, if childContext is derived from parentContext:
+// - if parentContext is canceled with cause1 before childContext is canceled with cause2,
+// then Cause(parentContext) == Cause(childContext) == cause1
+// - if childContext is canceled with cause2 before parentContext is canceled with cause1,
+// then Cause(parentContext) == cause1 and Cause(childContext) == cause2
+type CancelCauseFunc func(cause error)
+
+// WithCancelCause behaves like WithCancel but returns a CancelCauseFunc instead of a CancelFunc.
+// Calling cancel with a non-nil error (the "cause") records that error in ctx;
+// it can then be retrieved using Cause(ctx).
+// Calling cancel with nil sets the cause to Canceled.
+//
+// Example use:
+//
+// ctx, cancel := context.WithCancelCause(parent)
+// cancel(myError)
+// ctx.Err() // returns context.Canceled
+// context.Cause(ctx) // returns myError
+func WithCancelCause(parent Context) (ctx Context, cancel CancelCauseFunc) {
+ c := withCancel(parent)
+ return c, func(cause error) { c.cancel(true, Canceled, cause) }
+}
+
+func withCancel(parent Context) *cancelCtx {
if parent == nil {
panic("cannot create context from nil parent")
}
c := newCancelCtx(parent)
- propagateCancel(parent, &c)
- return &c, func() { c.cancel(true, Canceled) }
+ propagateCancel(parent, c)
+ return c
+}
+
+// Cause returns a non-nil error explaining why c was canceled.
+// The first cancelation of c or one of its parents sets the cause.
+// If that cancelation happened via a call to CancelCauseFunc(err),
+// then Cause returns err.
+// Otherwise Cause(c) returns the same value as c.Err().
+// Cause returns nil if c has not been canceled yet.
+func Cause(c Context) error {
+ if cc, ok := c.Value(&cancelCtxKey).(*cancelCtx); ok {
+ return cc.cause
+ }
+ return nil
}
// newCancelCtx returns an initialized cancelCtx.
-func newCancelCtx(parent Context) cancelCtx {
- return cancelCtx{Context: parent}
+func newCancelCtx(parent Context) *cancelCtx {
+ return &cancelCtx{Context: parent}
}
// goroutines counts the number of goroutines ever created; for testing.
@@ -256,7 +308,7 @@ func propagateCancel(parent Context, child canceler) {
select {
case <-done:
// parent is already canceled
- child.cancel(false, parent.Err())
+ child.cancel(false, parent.Err(), Cause(parent))
return
default:
}
@@ -265,7 +317,7 @@ func propagateCancel(parent Context, child canceler) {
p.mu.Lock()
if p.err != nil {
// parent has already been canceled
- child.cancel(false, p.err)
+ child.cancel(false, p.err, p.cause)
} else {
if p.children == nil {
p.children = make(map[canceler]struct{})
@@ -278,7 +330,7 @@ func propagateCancel(parent Context, child canceler) {
go func() {
select {
case <-parent.Done():
- child.cancel(false, parent.Err())
+ child.cancel(false, parent.Err(), Cause(parent))
case <-child.Done():
}
}()
@@ -326,7 +378,7 @@ func removeChild(parent Context, child canceler) {
// A canceler is a context type that can be canceled directly. The
// implementations are *cancelCtx and *timerCtx.
type canceler interface {
- cancel(removeFromParent bool, err error)
+ cancel(removeFromParent bool, err, cause error)
Done() <-chan struct{}
}
@@ -346,6 +398,7 @@ type cancelCtx struct {
done atomic.Value // of chan struct{}, created lazily, closed by first cancel call
children map[canceler]struct{} // set to nil by the first cancel call
err error // set to non-nil by the first cancel call
+ cause error // set to non-nil by the first cancel call
}
func (c *cancelCtx) Value(key any) any {
@@ -394,16 +447,21 @@ func (c *cancelCtx) String() string {
// cancel closes c.done, cancels each of c's children, and, if
// removeFromParent is true, removes c from its parent's children.
-func (c *cancelCtx) cancel(removeFromParent bool, err error) {
+// cancel sets c.cause to cause if this is the first time c is canceled.
+func (c *cancelCtx) cancel(removeFromParent bool, err, cause error) {
if err == nil {
panic("context: internal error: missing cancel error")
}
+ if cause == nil {
+ cause = err
+ }
c.mu.Lock()
if c.err != nil {
c.mu.Unlock()
return // already canceled
}
c.err = err
+ c.cause = cause
d, _ := c.done.Load().(chan struct{})
if d == nil {
c.done.Store(closedchan)
@@ -412,7 +470,7 @@ func (c *cancelCtx) cancel(removeFromParent bool, err error) {
}
for child := range c.children {
// NOTE: acquiring the child's lock while holding parent's lock.
- child.cancel(false, err)
+ child.cancel(false, err, cause)
}
c.children = nil
c.mu.Unlock()
@@ -446,24 +504,24 @@ func WithDeadline(parent Context, d time.Time) (Context, CancelFunc) {
propagateCancel(parent, c)
dur := time.Until(d)
if dur <= 0 {
- c.cancel(true, DeadlineExceeded) // deadline has already passed
- return c, func() { c.cancel(false, Canceled) }
+ c.cancel(true, DeadlineExceeded, nil) // deadline has already passed
+ return c, func() { c.cancel(false, Canceled, nil) }
}
c.mu.Lock()
defer c.mu.Unlock()
if c.err == nil {
c.timer = time.AfterFunc(dur, func() {
- c.cancel(true, DeadlineExceeded)
+ c.cancel(true, DeadlineExceeded, nil)
})
}
- return c, func() { c.cancel(true, Canceled) }
+ return c, func() { c.cancel(true, Canceled, nil) }
}
// A timerCtx carries a timer and a deadline. It embeds a cancelCtx to
// implement Done and Err. It implements cancel by stopping its timer then
// delegating to cancelCtx.cancel.
type timerCtx struct {
- cancelCtx
+ *cancelCtx
timer *time.Timer // Under cancelCtx.mu.
deadline time.Time
@@ -479,8 +537,8 @@ func (c *timerCtx) String() string {
time.Until(c.deadline).String() + "])"
}
-func (c *timerCtx) cancel(removeFromParent bool, err error) {
- c.cancelCtx.cancel(false, err)
+func (c *timerCtx) cancel(removeFromParent bool, err, cause error) {
+ c.cancelCtx.cancel(false, err, cause)
if removeFromParent {
// Remove this timerCtx from its parent cancelCtx's children.
removeChild(c.cancelCtx.Context, c)
@@ -581,7 +639,7 @@ func value(c Context, key any) any {
c = ctx.Context
case *timerCtx:
if key == &cancelCtxKey {
- return &ctx.cancelCtx
+ return ctx.cancelCtx
}
c = ctx.Context
case *emptyCtx:
diff --git a/src/context/context_test.go b/src/context/context_test.go
index 0991880907..593a7b1521 100644
--- a/src/context/context_test.go
+++ b/src/context/context_test.go
@@ -650,8 +650,9 @@ func XTestCancelRemoves(t testingT) {
}
func XTestWithCancelCanceledParent(t testingT) {
- parent, pcancel := WithCancel(Background())
- pcancel()
+ parent, pcancel := WithCancelCause(Background())
+ cause := fmt.Errorf("Because!")
+ pcancel(cause)
c, _ := WithCancel(parent)
select {
@@ -662,6 +663,9 @@ func XTestWithCancelCanceledParent(t testingT) {
if got, want := c.Err(), Canceled; got != want {
t.Errorf("child not canceled; got = %v, want = %v", got, want)
}
+ if got, want := Cause(c), cause; got != want {
+ t.Errorf("child has wrong cause; got = %v, want = %v", got, want)
+ }
}
func XTestWithValueChecksKey(t testingT) {
@@ -785,3 +789,148 @@ func XTestCustomContextGoroutines(t testingT) {
defer cancel7()
checkNoGoroutine()
}
+
+func XTestCause(t testingT) {
+ var (
+ parentCause = fmt.Errorf("parentCause")
+ childCause = fmt.Errorf("childCause")
+ )
+ for _, test := range []struct {
+ name string
+ ctx Context
+ err error
+ cause error
+ }{
+ {
+ name: "Background",
+ ctx: Background(),
+ err: nil,
+ cause: nil,
+ },
+ {
+ name: "TODO",
+ ctx: TODO(),
+ err: nil,
+ cause: nil,
+ },
+ {
+ name: "WithCancel",
+ ctx: func() Context {
+ ctx, cancel := WithCancel(Background())
+ cancel()
+ return ctx
+ }(),
+ err: Canceled,
+ cause: Canceled,
+ },
+ {
+ name: "WithCancelCause",
+ ctx: func() Context {
+ ctx, cancel := WithCancelCause(Background())
+ cancel(parentCause)
+ return ctx
+ }(),
+ err: Canceled,
+ cause: parentCause,
+ },
+ {
+ name: "WithCancelCause nil",
+ ctx: func() Context {
+ ctx, cancel := WithCancelCause(Background())
+ cancel(nil)
+ return ctx
+ }(),
+ err: Canceled,
+ cause: Canceled,
+ },
+ {
+ name: "WithCancelCause: parent cause before child",
+ ctx: func() Context {
+ ctx, cancelParent := WithCancelCause(Background())
+ ctx, cancelChild := WithCancelCause(ctx)
+ cancelParent(parentCause)
+ cancelChild(childCause)
+ return ctx
+ }(),
+ err: Canceled,
+ cause: parentCause,
+ },
+ {
+ name: "WithCancelCause: parent cause after child",
+ ctx: func() Context {
+ ctx, cancelParent := WithCancelCause(Background())
+ ctx, cancelChild := WithCancelCause(ctx)
+ cancelChild(childCause)
+ cancelParent(parentCause)
+ return ctx
+ }(),
+ err: Canceled,
+ cause: childCause,
+ },
+ {
+ name: "WithCancelCause: parent cause before nil",
+ ctx: func() Context {
+ ctx, cancelParent := WithCancelCause(Background())
+ ctx, cancelChild := WithCancel(ctx)
+ cancelParent(parentCause)
+ cancelChild()
+ return ctx
+ }(),
+ err: Canceled,
+ cause: parentCause,
+ },
+ {
+ name: "WithCancelCause: parent cause after nil",
+ ctx: func() Context {
+ ctx, cancelParent := WithCancelCause(Background())
+ ctx, cancelChild := WithCancel(ctx)
+ cancelChild()
+ cancelParent(parentCause)
+ return ctx
+ }(),
+ err: Canceled,
+ cause: Canceled,
+ },
+ {
+ name: "WithCancelCause: child cause after nil",
+ ctx: func() Context {
+ ctx, cancelParent := WithCancel(Background())
+ ctx, cancelChild := WithCancelCause(ctx)
+ cancelParent()
+ cancelChild(childCause)
+ return ctx
+ }(),
+ err: Canceled,
+ cause: Canceled,
+ },
+ {
+ name: "WithCancelCause: child cause before nil",
+ ctx: func() Context {
+ ctx, cancelParent := WithCancel(Background())
+ ctx, cancelChild := WithCancelCause(ctx)
+ cancelChild(childCause)
+ cancelParent()
+ return ctx
+ }(),
+ err: Canceled,
+ cause: childCause,
+ },
+ {
+ name: "WithTimeout",
+ ctx: func() Context {
+ ctx, cancel := WithTimeout(Background(), 0)
+ cancel()
+ return ctx
+ }(),
+ err: DeadlineExceeded,
+ cause: DeadlineExceeded,
+ },
+ } {
+ if got, want := test.ctx.Err(), test.err; want != got {
+ t.Errorf("%s: ctx.Err() = %v want %v", test.name, got, want)
+ }
+ if got, want := Cause(test.ctx), test.cause; want != got {
+ t.Errorf("%s: Cause(ctx) = %v want %v", test.name, got, want)
+ }
+ }
+}
diff --git a/src/context/x_test.go b/src/context/x_test.go
index 00eca72d5a..d3adb381d6 100644
--- a/src/context/x_test.go
+++ b/src/context/x_test.go
@@ -29,3 +29,4 @@ func TestWithValueChecksKey(t *testing.T) { XTestWithValueChecksKey
func TestInvalidDerivedFail(t *testing.T) { XTestInvalidDerivedFail(t) }
func TestDeadlineExceededSupportsTimeout(t *testing.T) { XTestDeadlineExceededSupportsTimeout(t) }
func TestCustomContextGoroutines(t *testing.T) { XTestCustomContextGoroutines(t) }
+func TestCause(t *testing.T) { XTestCause(t) }