From 14a902c1a793af09ecac86dfce5f9aed277d7760 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Fri, 28 Jan 2022 10:59:41 +0800 Subject: [PATCH] feat: handling panic in mapreduce, panic in calling goroutine, not inside goroutines (#1490) * feat: handle panic * chore: update fuzz test * chore: optimize square sum algorithm --- core/mr/mapreduce.go | 155 ++++++++++++----- core/mr/mapreduce_fuzz_test.go | 78 +++++++++ core/mr/mapreduce_test.go | 306 +++++++++++++++++++-------------- 3 files changed, 372 insertions(+), 167 deletions(-) create mode 100644 core/mr/mapreduce_fuzz_test.go diff --git a/core/mr/mapreduce.go b/core/mr/mapreduce.go index 80c54aa9..d73ff6b8 100644 --- a/core/mr/mapreduce.go +++ b/core/mr/mapreduce.go @@ -3,12 +3,11 @@ package mr import ( "context" "errors" - "fmt" "sync" + "sync/atomic" "github.com/zeromicro/go-zero/core/errorx" "github.com/zeromicro/go-zero/core/lang" - "github.com/zeromicro/go-zero/core/threading" ) const ( @@ -42,6 +41,16 @@ type ( // Option defines the method to customize the mapreduce. Option func(opts *mapReduceOptions) + mapperContext struct { + ctx context.Context + mapper MapFunc + source <-chan interface{} + panicChan *onceChan + collector chan<- interface{} + doneChan <-chan lang.PlaceholderType + workers int + } + mapReduceOptions struct { ctx context.Context workers int @@ -90,46 +99,72 @@ func FinishVoid(fns ...func()) { // ForEach maps all elements from given generate but no output. func ForEach(generate GenerateFunc, mapper ForEachFunc, opts ...Option) { - drain(Map(generate, func(item interface{}, writer Writer) { - mapper(item) - }, opts...)) -} - -// Map maps all elements generated from given generate func, and returns an output channel. -func Map(generate GenerateFunc, mapper MapFunc, opts ...Option) chan interface{} { options := buildOptions(opts...) 
- source := buildSource(generate) + panicChan := &onceChan{channel: make(chan interface{})} + source := buildSource(generate, panicChan) collector := make(chan interface{}, options.workers) done := make(chan lang.PlaceholderType) - go executeMappers(options.ctx, mapper, source, collector, done, options.workers) + go executeMappers(mapperContext{ + ctx: options.ctx, + mapper: func(item interface{}, writer Writer) { + mapper(item) + }, + source: source, + panicChan: panicChan, + collector: collector, + doneChan: done, + workers: options.workers, + }) - return collector + for { + select { + case v := <-panicChan.channel: + panic(v) + case _, ok := <-collector: + if !ok { + return + } + } + } } // MapReduce maps all elements generated from given generate func, // and reduces the output elements with given reducer. func MapReduce(generate GenerateFunc, mapper MapperFunc, reducer ReducerFunc, opts ...Option) (interface{}, error) { - source := buildSource(generate) - return MapReduceChan(source, mapper, reducer, opts...) + panicChan := &onceChan{channel: make(chan interface{})} + source := buildSource(generate, panicChan) + return mapReduceWithPanicChan(source, panicChan, mapper, reducer, opts...) } // MapReduceChan maps all elements from source, and reduce the output elements with given reducer. func MapReduceChan(source <-chan interface{}, mapper MapperFunc, reducer ReducerFunc, opts ...Option) (interface{}, error) { + panicChan := &onceChan{channel: make(chan interface{})} + return mapReduceWithPanicChan(source, panicChan, mapper, reducer, opts...) +} + +// mapReduceWithPanicChan maps all elements from source and reduces the output elements with given reducer; panics written to panicChan are re-raised in the calling goroutine. +func mapReduceWithPanicChan(source <-chan interface{}, panicChan *onceChan, mapper MapperFunc, + reducer ReducerFunc, opts ...Option) (interface{}, error) { options := buildOptions(opts...) 
+ // output is used to write the final result output := make(chan interface{}) defer func() { + // reducer can only write once, if more, panic for range output { panic("more than one element written in reducer") } }() + // collector is used to collect data from mapper, and consume in reducer collector := make(chan interface{}, options.workers) + // if done is closed, all mappers and reducer should stop processing done := make(chan lang.PlaceholderType) writer := newGuardedWriter(options.ctx, output, done) var closeOnce sync.Once + // use atomic.Value to avoid data race var retErr errorx.AtomicError finish := func() { closeOnce.Do(func() { @@ -151,30 +186,38 @@ func MapReduceChan(source <-chan interface{}, mapper MapperFunc, reducer Reducer go func() { defer func() { drain(collector) - if r := recover(); r != nil { - cancel(fmt.Errorf("%v", r)) - } else { - finish() + panicChan.write(r) } + finish() }() reducer(collector, writer, cancel) }() - go executeMappers(options.ctx, func(item interface{}, w Writer) { - mapper(item, w, cancel) - }, source, collector, done, options.workers) + go executeMappers(mapperContext{ + ctx: options.ctx, + mapper: func(item interface{}, w Writer) { + mapper(item, w, cancel) + }, + source: source, + panicChan: panicChan, + collector: collector, + doneChan: done, + workers: options.workers, + }) select { case <-options.ctx.Done(): cancel(context.DeadlineExceeded) return nil, context.DeadlineExceeded - case value, ok := <-output: + case v := <-panicChan.channel: + panic(v) + case v, ok := <-output: if err := retErr.Load(); err != nil { return nil, err } else if ok { - return value, nil + return v, nil } else { return nil, ErrReduceNoOutput } @@ -221,12 +264,18 @@ func buildOptions(opts ...Option) *mapReduceOptions { return options } -func buildSource(generate GenerateFunc) chan interface{} { +func buildSource(generate GenerateFunc, panicChan *onceChan) chan interface{} { source := make(chan interface{}) - threading.GoSafe(func() { - defer 
close(source) + go func() { + defer func() { + if r := recover(); r != nil { + panicChan.write(r) + } + close(source) + }() + generate(source) - }) + }() return source } @@ -238,39 +287,54 @@ func drain(channel <-chan interface{}) { } } -func executeMappers(ctx context.Context, mapper MapFunc, input <-chan interface{}, - collector chan<- interface{}, done <-chan lang.PlaceholderType, workers int) { +func executeMappers(mCtx mapperContext) { var wg sync.WaitGroup + pc := &onceChan{channel: make(chan interface{})} defer func() { + // in case panic happens when processing last item, for loop not handling it. + select { + case r := <-pc.channel: + mCtx.panicChan.write(r) + default: + } + wg.Wait() - close(collector) + close(mCtx.collector) + drain(mCtx.source) }() - pool := make(chan lang.PlaceholderType, workers) - writer := newGuardedWriter(ctx, collector, done) + pool := make(chan lang.PlaceholderType, mCtx.workers) + writer := newGuardedWriter(mCtx.ctx, mCtx.collector, mCtx.doneChan) for { select { - case <-ctx.Done(): + case <-mCtx.ctx.Done(): + return + case <-mCtx.doneChan: return - case <-done: + case r := <-pc.channel: + // make sure this method quit ASAP, + // without this case branch, all the items from source will be consumed. 
+ mCtx.panicChan.write(r) return case pool <- lang.Placeholder: - item, ok := <-input + item, ok := <-mCtx.source if !ok { <-pool return } wg.Add(1) - // better to safely run caller defined method - threading.GoSafe(func() { + go func() { defer func() { + if r := recover(); r != nil { + pc.write(r) + } wg.Done() <-pool }() - mapper(item, writer) - }) + mCtx.mapper(item, writer) + }() } } } @@ -316,3 +380,16 @@ func (gw guardedWriter) Write(v interface{}) { gw.channel <- v } } + +type onceChan struct { + channel chan interface{} + wrote int32 +} + +func (oc *onceChan) write(val interface{}) { + if atomic.AddInt32(&oc.wrote, 1) > 1 { + return + } + + oc.channel <- val +} diff --git a/core/mr/mapreduce_fuzz_test.go b/core/mr/mapreduce_fuzz_test.go new file mode 100644 index 00000000..770315ae --- /dev/null +++ b/core/mr/mapreduce_fuzz_test.go @@ -0,0 +1,78 @@ +//go:build go1.18 +// +build go1.18 + +package mr + +import ( + "fmt" + "math/rand" + "runtime" + "strings" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "go.uber.org/goleak" +) + +func FuzzMapReduce(f *testing.F) { + rand.Seed(time.Now().UnixNano()) + + f.Add(int64(10), runtime.NumCPU()) + f.Fuzz(func(t *testing.T, n int64, workers int) { + n = n%5000 + 5000 + genPanic := rand.Intn(100) == 0 + mapperPanic := rand.Intn(100) == 0 + reducerPanic := rand.Intn(100) == 0 + genIdx := rand.Int63n(n) + mapperIdx := rand.Int63n(n) + reducerIdx := rand.Int63n(n) + squareSum := (n - 1) * n * (2*n - 1) / 6 + + fn := func() (interface{}, error) { + defer goleak.VerifyNone(t, goleak.IgnoreCurrent()) + + return MapReduce(func(source chan<- interface{}) { + for i := int64(0); i < n; i++ { + source <- i + if genPanic && i == genIdx { + panic("foo") + } + } + }, func(item interface{}, writer Writer, cancel func(error)) { + v := item.(int64) + if mapperPanic && v == mapperIdx { + panic("bar") + } + writer.Write(v * v) + }, func(pipe <-chan interface{}, writer Writer, cancel func(error)) { + var idx int64 + var 
total int64 + for v := range pipe { + if reducerPanic && idx == reducerIdx { + panic("baz") + } + total += v.(int64) + idx++ + } + writer.Write(total) + }, WithWorkers(workers%50+runtime.NumCPU())) + } + + if genPanic || mapperPanic || reducerPanic { + var buf strings.Builder + buf.WriteString(fmt.Sprintf("n: %d", n)) + buf.WriteString(fmt.Sprintf(", genPanic: %t", genPanic)) + buf.WriteString(fmt.Sprintf(", mapperPanic: %t", mapperPanic)) + buf.WriteString(fmt.Sprintf(", reducerPanic: %t", reducerPanic)) + buf.WriteString(fmt.Sprintf(", genIdx: %d", genIdx)) + buf.WriteString(fmt.Sprintf(", mapperIdx: %d", mapperIdx)) + buf.WriteString(fmt.Sprintf(", reducerIdx: %d", reducerIdx)) + assert.Panicsf(t, func() { fn() }, buf.String()) + } else { + val, err := fn() + assert.Nil(t, err) + assert.Equal(t, squareSum, val.(int64)) + } + }) +} diff --git a/core/mr/mapreduce_test.go b/core/mr/mapreduce_test.go index 6ae2060d..b43f3f00 100644 --- a/core/mr/mapreduce_test.go +++ b/core/mr/mapreduce_test.go @@ -11,8 +11,6 @@ import ( "time" "github.com/stretchr/testify/assert" - "github.com/zeromicro/go-zero/core/stringx" - "github.com/zeromicro/go-zero/core/syncx" "go.uber.org/goleak" ) @@ -124,84 +122,69 @@ func TestForEach(t *testing.T) { t.Run("all", func(t *testing.T) { defer goleak.VerifyNone(t) - ForEach(func(source chan<- interface{}) { - for i := 0; i < tasks; i++ { - source <- i - } - }, func(item interface{}) { - panic("foo") + assert.PanicsWithValue(t, "foo", func() { + ForEach(func(source chan<- interface{}) { + for i := 0; i < tasks; i++ { + source <- i + } + }, func(item interface{}) { + panic("foo") + }) }) }) } -func TestMap(t *testing.T) { +func TestGeneratePanic(t *testing.T) { defer goleak.VerifyNone(t) - tests := []struct { - mapper MapFunc - expect int - }{ - { - mapper: func(item interface{}, writer Writer) { - v := item.(int) - writer.Write(v * v) - }, - expect: 30, - }, - { - mapper: func(item interface{}, writer Writer) { - v := item.(int) - if v%2 == 0 
{ - return - } - writer.Write(v * v) - }, - expect: 10, - }, - { - mapper: func(item interface{}, writer Writer) { - v := item.(int) - if v%2 == 0 { - panic(v) - } - writer.Write(v * v) - }, - expect: 10, - }, - } + t.Run("all", func(t *testing.T) { + assert.PanicsWithValue(t, "foo", func() { + ForEach(func(source chan<- interface{}) { + panic("foo") + }, func(item interface{}) { + }) + }) + }) +} - for _, test := range tests { - t.Run(stringx.Rand(), func(t *testing.T) { - channel := Map(func(source chan<- interface{}) { - for i := 1; i < 5; i++ { +func TestMapperPanic(t *testing.T) { + defer goleak.VerifyNone(t) + + const tasks = 1000 + var run int32 + t.Run("all", func(t *testing.T) { + assert.PanicsWithValue(t, "foo", func() { + _, _ = MapReduce(func(source chan<- interface{}) { + for i := 0; i < tasks; i++ { source <- i } - }, test.mapper, WithWorkers(-1)) - - var result int - for v := range channel { - result += v.(int) - } - - assert.Equal(t, test.expect, result) + }, func(item interface{}, writer Writer, cancel func(error)) { + atomic.AddInt32(&run, 1) + panic("foo") + }, func(pipe <-chan interface{}, writer Writer, cancel func(error)) { + }) }) - } + assert.True(t, atomic.LoadInt32(&run) < tasks/2) + }) } func TestMapReduce(t *testing.T) { defer goleak.VerifyNone(t) tests := []struct { + name string mapper MapperFunc reducer ReducerFunc expectErr error expectValue interface{} }{ { + name: "simple", expectErr: nil, expectValue: 30, }, { + name: "cancel with error", mapper: func(item interface{}, writer Writer, cancel func(error)) { v := item.(int) if v%3 == 0 { @@ -212,6 +195,7 @@ func TestMapReduce(t *testing.T) { expectErr: errDummy, }, { + name: "cancel with nil", mapper: func(item interface{}, writer Writer, cancel func(error)) { v := item.(int) if v%3 == 0 { @@ -223,6 +207,7 @@ func TestMapReduce(t *testing.T) { expectValue: nil, }, { + name: "cancel with more", reducer: func(pipe <-chan interface{}, writer Writer, cancel func(error)) { var result int 
for item := range pipe { @@ -237,45 +222,68 @@ func TestMapReduce(t *testing.T) { }, } - for _, test := range tests { - t.Run(stringx.Rand(), func(t *testing.T) { - if test.mapper == nil { - test.mapper = func(item interface{}, writer Writer, cancel func(error)) { - v := item.(int) - writer.Write(v * v) - } - } - if test.reducer == nil { - test.reducer = func(pipe <-chan interface{}, writer Writer, cancel func(error)) { - var result int - for item := range pipe { - result += item.(int) + t.Run("MapReduce", func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.mapper == nil { + test.mapper = func(item interface{}, writer Writer, cancel func(error)) { + v := item.(int) + writer.Write(v * v) } - writer.Write(result) } - } - value, err := MapReduce(func(source chan<- interface{}) { - for i := 1; i < 5; i++ { - source <- i + if test.reducer == nil { + test.reducer = func(pipe <-chan interface{}, writer Writer, cancel func(error)) { + var result int + for item := range pipe { + result += item.(int) + } + writer.Write(result) + } } - }, test.mapper, test.reducer, WithWorkers(runtime.NumCPU())) + value, err := MapReduce(func(source chan<- interface{}) { + for i := 1; i < 5; i++ { + source <- i + } + }, test.mapper, test.reducer, WithWorkers(runtime.NumCPU())) - assert.Equal(t, test.expectErr, err) - assert.Equal(t, test.expectValue, value) - }) - } -} + assert.Equal(t, test.expectErr, err) + assert.Equal(t, test.expectValue, value) + }) + } + }) -func TestMapReducePanicBothMapperAndReducer(t *testing.T) { - defer goleak.VerifyNone(t) + t.Run("MapReduce", func(t *testing.T) { + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if test.mapper == nil { + test.mapper = func(item interface{}, writer Writer, cancel func(error)) { + v := item.(int) + writer.Write(v * v) + } + } + if test.reducer == nil { + test.reducer = func(pipe <-chan interface{}, writer Writer, cancel func(error)) { + var result int + for 
item := range pipe { + result += item.(int) + } + writer.Write(result) + } + } - _, _ = MapReduce(func(source chan<- interface{}) { - source <- 0 - source <- 1 - }, func(item interface{}, writer Writer, cancel func(error)) { - panic("foo") - }, func(pipe <-chan interface{}, writer Writer, cancel func(error)) { - panic("bar") + source := make(chan interface{}) + go func() { + for i := 1; i < 5; i++ { + source <- i + } + close(source) + }() + + value, err := MapReduceChan(source, test.mapper, test.reducer, WithWorkers(-1)) + assert.Equal(t, test.expectErr, err) + assert.Equal(t, test.expectValue, value) + }) + } }) } @@ -302,16 +310,19 @@ func TestMapReduceVoid(t *testing.T) { var value uint32 tests := []struct { + name string mapper MapperFunc reducer VoidReducerFunc expectValue uint32 expectErr error }{ { + name: "simple", expectValue: 30, expectErr: nil, }, { + name: "cancel with error", mapper: func(item interface{}, writer Writer, cancel func(error)) { v := item.(int) if v%3 == 0 { @@ -322,6 +333,7 @@ func TestMapReduceVoid(t *testing.T) { expectErr: errDummy, }, { + name: "cancel with nil", mapper: func(item interface{}, writer Writer, cancel func(error)) { v := item.(int) if v%3 == 0 { @@ -332,6 +344,7 @@ func TestMapReduceVoid(t *testing.T) { expectErr: ErrCancelWithNil, }, { + name: "cancel with more", reducer: func(pipe <-chan interface{}, cancel func(error)) { for item := range pipe { result := atomic.AddUint32(&value, uint32(item.(int))) @@ -345,7 +358,7 @@ func TestMapReduceVoid(t *testing.T) { } for _, test := range tests { - t.Run(stringx.Rand(), func(t *testing.T) { + t.Run(test.name, func(t *testing.T) { atomic.StoreUint32(&value, 0) if test.mapper == nil { @@ -400,39 +413,59 @@ func TestMapReduceVoidWithDelay(t *testing.T) { assert.Equal(t, 0, result[1]) } -func TestMapVoid(t *testing.T) { +func TestMapReducePanic(t *testing.T) { defer goleak.VerifyNone(t) - const tasks = 1000 - var count uint32 - ForEach(func(source chan<- interface{}) { - for i := 
0; i < tasks; i++ { - source <- i - } - }, func(item interface{}) { - atomic.AddUint32(&count, 1) + assert.Panics(t, func() { + _, _ = MapReduce(func(source chan<- interface{}) { + source <- 0 + source <- 1 + }, func(item interface{}, writer Writer, cancel func(error)) { + i := item.(int) + writer.Write(i) + }, func(pipe <-chan interface{}, writer Writer, cancel func(error)) { + for range pipe { + panic("panic") + } + }) }) +} - assert.Equal(t, tasks, int(count)) +func TestMapReducePanicOnce(t *testing.T) { + defer goleak.VerifyNone(t) + + assert.Panics(t, func() { + _, _ = MapReduce(func(source chan<- interface{}) { + for i := 0; i < 100; i++ { + source <- i + } + }, func(item interface{}, writer Writer, cancel func(error)) { + i := item.(int) + if i == 0 { + panic("foo") + } + writer.Write(i) + }, func(pipe <-chan interface{}, writer Writer, cancel func(error)) { + for range pipe { + panic("bar") + } + }) + }) } -func TestMapReducePanic(t *testing.T) { +func TestMapReducePanicBothMapperAndReducer(t *testing.T) { defer goleak.VerifyNone(t) - v, err := MapReduce(func(source chan<- interface{}) { - source <- 0 - source <- 1 - }, func(item interface{}, writer Writer, cancel func(error)) { - i := item.(int) - writer.Write(i) - }, func(pipe <-chan interface{}, writer Writer, cancel func(error)) { - for range pipe { - panic("panic") - } + assert.Panics(t, func() { + _, _ = MapReduce(func(source chan<- interface{}) { + source <- 0 + source <- 1 + }, func(item interface{}, writer Writer, cancel func(error)) { + panic("foo") + }, func(pipe <-chan interface{}, writer Writer, cancel func(error)) { + panic("bar") + }) }) - assert.Nil(t, v) - assert.NotNil(t, err) - assert.Equal(t, "panic", err.Error()) } func TestMapReduceVoidCancel(t *testing.T) { @@ -461,13 +494,13 @@ func TestMapReduceVoidCancel(t *testing.T) { func TestMapReduceVoidCancelWithRemains(t *testing.T) { defer goleak.VerifyNone(t) - var done syncx.AtomicBool + var done int32 var result []int err := 
MapReduceVoid(func(source chan<- interface{}) { for i := 0; i < defaultWorkers*2; i++ { source <- i } - done.Set(true) + atomic.AddInt32(&done, 1) }, func(item interface{}, writer Writer, cancel func(error)) { i := item.(int) if i == defaultWorkers/2 { @@ -482,7 +515,7 @@ func TestMapReduceVoidCancelWithRemains(t *testing.T) { }) assert.NotNil(t, err) assert.Equal(t, "anything", err.Error()) - assert.True(t, done.True()) + assert.Equal(t, int32(1), done) } func TestMapReduceWithoutReducerWrite(t *testing.T) { @@ -507,34 +540,51 @@ func TestMapReduceVoidPanicInReducer(t *testing.T) { defer goleak.VerifyNone(t) const message = "foo" - var done syncx.AtomicBool - err := MapReduceVoid(func(source chan<- interface{}) { + assert.Panics(t, func() { + var done int32 + _ = MapReduceVoid(func(source chan<- interface{}) { + for i := 0; i < defaultWorkers*2; i++ { + source <- i + } + atomic.AddInt32(&done, 1) + }, func(item interface{}, writer Writer, cancel func(error)) { + i := item.(int) + writer.Write(i) + }, func(pipe <-chan interface{}, cancel func(error)) { + panic(message) + }, WithWorkers(1)) + }) +} + +func TestForEachWithContext(t *testing.T) { + defer goleak.VerifyNone(t) + + var done int32 + ctx, cancel := context.WithCancel(context.Background()) + ForEach(func(source chan<- interface{}) { for i := 0; i < defaultWorkers*2; i++ { source <- i } - done.Set(true) - }, func(item interface{}, writer Writer, cancel func(error)) { + atomic.AddInt32(&done, 1) + }, func(item interface{}) { i := item.(int) - writer.Write(i) - }, func(pipe <-chan interface{}, cancel func(error)) { - panic(message) - }, WithWorkers(1)) - assert.NotNil(t, err) - assert.Equal(t, message, err.Error()) - assert.True(t, done.True()) + if i == defaultWorkers/2 { + cancel() + } + }, WithContext(ctx)) } func TestMapReduceWithContext(t *testing.T) { defer goleak.VerifyNone(t) - var done syncx.AtomicBool + var done int32 var result []int ctx, cancel := context.WithCancel(context.Background()) err := 
MapReduceVoid(func(source chan<- interface{}) { for i := 0; i < defaultWorkers*2; i++ { source <- i } - done.Set(true) + atomic.AddInt32(&done, 1) }, func(item interface{}, writer Writer, c func(error)) { i := item.(int) if i == defaultWorkers/2 {