From 6c2abe7474b4947ef2f4e4e271398b21e9072673 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Sun, 30 Jan 2022 13:09:21 +0800 Subject: [PATCH] fix: goroutine stuck on edge case (#1495) * fix: goroutine stuck on edge case * refactor: simplify mapreduce implementation --- core/mr/mapreduce.go | 19 ++---- core/mr/mapreduce_fuzz_test.go | 8 +-- core/mr/mapreduce_rand_test.go | 107 +++++++++++++++++++++++++++++++++ 3 files changed, 115 insertions(+), 19 deletions(-) create mode 100644 core/mr/mapreduce_rand_test.go diff --git a/core/mr/mapreduce.go b/core/mr/mapreduce.go index d73ff6b8..e7763f0e 100644 --- a/core/mr/mapreduce.go +++ b/core/mr/mapreduce.go @@ -289,33 +289,21 @@ func drain(channel <-chan interface{}) { func executeMappers(mCtx mapperContext) { var wg sync.WaitGroup - pc := &onceChan{channel: make(chan interface{})} defer func() { - // in case panic happens when processing last item, for loop not handling it. - select { - case r := <-pc.channel: - mCtx.panicChan.write(r) - default: - } - wg.Wait() close(mCtx.collector) drain(mCtx.source) }() + var failed int32 pool := make(chan lang.PlaceholderType, mCtx.workers) writer := newGuardedWriter(mCtx.ctx, mCtx.collector, mCtx.doneChan) - for { + for atomic.LoadInt32(&failed) == 0 { select { case <-mCtx.ctx.Done(): return case <-mCtx.doneChan: return - case r := <-pc.channel: - // make sure this method quit ASAP, - // without this case branch, all the items from source will be consumed. 
- mCtx.panicChan.write(r) - return case pool <- lang.Placeholder: item, ok := <-mCtx.source if !ok { @@ -327,7 +315,8 @@ func executeMappers(mCtx mapperContext) { go func() { defer func() { if r := recover(); r != nil { - pc.write(r) + atomic.AddInt32(&failed, 1) + mCtx.panicChan.write(r) } wg.Done() <-pool diff --git a/core/mr/mapreduce_fuzz_test.go b/core/mr/mapreduce_fuzz_test.go index 770315ae..fa930a50 100644 --- a/core/mr/mapreduce_fuzz_test.go +++ b/core/mr/mapreduce_fuzz_test.go @@ -18,9 +18,9 @@ import ( func FuzzMapReduce(f *testing.F) { rand.Seed(time.Now().UnixNano()) - f.Add(int64(10), runtime.NumCPU()) - f.Fuzz(func(t *testing.T, n int64, workers int) { - n = n%5000 + 5000 + f.Add(uint(10), uint(runtime.NumCPU())) + f.Fuzz(func(t *testing.T, num uint, workers uint) { + n := int64(num)%5000 + 5000 genPanic := rand.Intn(100) == 0 mapperPanic := rand.Intn(100) == 0 reducerPanic := rand.Intn(100) == 0 @@ -56,7 +56,7 @@ func FuzzMapReduce(f *testing.F) { idx++ } writer.Write(total) - }, WithWorkers(workers%50+runtime.NumCPU())) + }, WithWorkers(int(workers)%50+runtime.NumCPU()/2)) } if genPanic || mapperPanic || reducerPanic { diff --git a/core/mr/mapreduce_rand_test.go b/core/mr/mapreduce_rand_test.go new file mode 100644 index 00000000..cbc8fc29 --- /dev/null +++ b/core/mr/mapreduce_rand_test.go @@ -0,0 +1,107 @@ +//go:build fuzz +// +build fuzz + +package mr + +import ( + "fmt" + "math/rand" + "runtime" + "strconv" + "strings" + "sync" + "testing" + "time" + + "github.com/stretchr/testify/assert" + "github.com/zeromicro/go-zero/core/threading" + "gopkg.in/cheggaaa/pb.v1" +) + +// If Fuzz stuck, we don't know why, because it only returns hung or unexpected, +// so we need to simulate the fuzz test in test mode. 
+// TestMapReduceRandom replays the fuzz scenario many times with random sizes, worker counts and injected panics.
+func TestMapReduceRandom(t *testing.T) {
+	rand.Seed(time.Now().UnixNano())
+
+	const (
+		times  = 10000
+		nRange = 500
+	)
+
+	bar := pb.New(times).Start()
+	runner := threading.NewTaskRunner(runtime.NumCPU())
+	var wg sync.WaitGroup
+	wg.Add(times)
+	for i := 0; i < times; i++ {
+		i := i // per-iteration copy: the scheduled closure must not observe later values of the loop counter
+		runner.Schedule(func() {
+			start := time.Now()
+			defer func() {
+				if time.Since(start) > time.Minute {
+					// Report with t.Error, not t.Fatal: Fatal is only valid on the test goroutine, and its
+					// Goexit would abandon this deferred func before wg.Done, leaving wg.Wait blocked forever.
+					t.Error("timeout")
+				}
+				wg.Done()
+			}()
+
+			t.Run(strconv.Itoa(i), func(t *testing.T) {
+				n := rand.Int63n(nRange) + nRange // n is in [nRange, 2*nRange)
+				workers := rand.Int()%50 + runtime.NumCPU()/2
+				genPanic := rand.Intn(100) == 0
+				mapperPanic := rand.Intn(100) == 0
+				reducerPanic := rand.Intn(100) == 0
+				genIdx := rand.Int63n(n)
+				mapperIdx := rand.Int63n(n)
+				reducerIdx := rand.Int63n(n)
+				// Expected result: 0^2 + 1^2 + ... + (n-1)^2 in closed form.
+				squareSum := (n - 1) * n * (2*n - 1) / 6
+
+				fn := func() (interface{}, error) {
+					return MapReduce(func(source chan<- interface{}) {
+						for i := int64(0); i < n; i++ {
+							source <- i
+							if genPanic && i == genIdx {
+								panic("foo")
+							}
+						}
+					}, func(item interface{}, writer Writer, cancel func(error)) {
+						v := item.(int64)
+						if mapperPanic && v == mapperIdx {
+							panic("bar")
+						}
+						writer.Write(v * v)
+					}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
+						var idx int64
+						var total int64
+						for v := range pipe {
+							if reducerPanic && idx == reducerIdx {
+								panic("baz")
+							}
+							total += v.(int64)
+							idx++
+						}
+						writer.Write(total)
+					}, WithWorkers(workers))
+				}
+
+				if genPanic || mapperPanic || reducerPanic {
+					var buf strings.Builder
+					fmt.Fprintf(&buf, "n: %d", n)
+					fmt.Fprintf(&buf, ", genPanic: %t, mapperPanic: %t, reducerPanic: %t", genPanic, mapperPanic, reducerPanic)
+					fmt.Fprintf(&buf, ", genIdx: %d, mapperIdx: %d, reducerIdx: %d", genIdx, mapperIdx, reducerIdx)
+					assert.Panicsf(t, func() { fn() }, buf.String())
+				} else {
+					val, err := fn()
+					assert.Nil(t, err)
+					assert.Equal(t, squareSum, val.(int64))
+				}
+				bar.Increment()
+			})
+		})
+	}
+
+	wg.Wait()
+	bar.Finish()
+}