fix: goroutine stuck on edge case (#1495)

* fix: goroutine stuck on edge case

* refactor: simplify mapreduce implementation
master
Kevin Wan 3 years ago committed by GitHub
parent 14a902c1a7
commit 6c2abe7474
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -289,33 +289,21 @@ func drain(channel <-chan interface{}) {
func executeMappers(mCtx mapperContext) { func executeMappers(mCtx mapperContext) {
var wg sync.WaitGroup var wg sync.WaitGroup
pc := &onceChan{channel: make(chan interface{})}
defer func() { defer func() {
// in case panic happens when processing last item, for loop not handling it.
select {
case r := <-pc.channel:
mCtx.panicChan.write(r)
default:
}
wg.Wait() wg.Wait()
close(mCtx.collector) close(mCtx.collector)
drain(mCtx.source) drain(mCtx.source)
}() }()
var failed int32
pool := make(chan lang.PlaceholderType, mCtx.workers) pool := make(chan lang.PlaceholderType, mCtx.workers)
writer := newGuardedWriter(mCtx.ctx, mCtx.collector, mCtx.doneChan) writer := newGuardedWriter(mCtx.ctx, mCtx.collector, mCtx.doneChan)
for { for atomic.LoadInt32(&failed) == 0 {
select { select {
case <-mCtx.ctx.Done(): case <-mCtx.ctx.Done():
return return
case <-mCtx.doneChan: case <-mCtx.doneChan:
return return
case r := <-pc.channel:
// make sure this method quit ASAP,
// without this case branch, all the items from source will be consumed.
mCtx.panicChan.write(r)
return
case pool <- lang.Placeholder: case pool <- lang.Placeholder:
item, ok := <-mCtx.source item, ok := <-mCtx.source
if !ok { if !ok {
@ -327,7 +315,8 @@ func executeMappers(mCtx mapperContext) {
go func() { go func() {
defer func() { defer func() {
if r := recover(); r != nil { if r := recover(); r != nil {
pc.write(r) atomic.AddInt32(&failed, 1)
mCtx.panicChan.write(r)
} }
wg.Done() wg.Done()
<-pool <-pool

@ -18,9 +18,9 @@ import (
func FuzzMapReduce(f *testing.F) { func FuzzMapReduce(f *testing.F) {
rand.Seed(time.Now().UnixNano()) rand.Seed(time.Now().UnixNano())
f.Add(int64(10), runtime.NumCPU()) f.Add(uint(10), uint(runtime.NumCPU()))
f.Fuzz(func(t *testing.T, n int64, workers int) { f.Fuzz(func(t *testing.T, num uint, workers uint) {
n = n%5000 + 5000 n := int64(num)%5000 + 5000
genPanic := rand.Intn(100) == 0 genPanic := rand.Intn(100) == 0
mapperPanic := rand.Intn(100) == 0 mapperPanic := rand.Intn(100) == 0
reducerPanic := rand.Intn(100) == 0 reducerPanic := rand.Intn(100) == 0
@ -56,7 +56,7 @@ func FuzzMapReduce(f *testing.F) {
idx++ idx++
} }
writer.Write(total) writer.Write(total)
}, WithWorkers(workers%50+runtime.NumCPU())) }, WithWorkers(int(workers)%50+runtime.NumCPU()/2))
} }
if genPanic || mapperPanic || reducerPanic { if genPanic || mapperPanic || reducerPanic {

@ -0,0 +1,107 @@
//go:build fuzz
// +build fuzz
package mr
import (
"fmt"
"math/rand"
"runtime"
"strconv"
"strings"
"sync"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/threading"
"gopkg.in/cheggaaa/pb.v1"
)
// TestMapReduceRandom replays the randomized scenario of FuzzMapReduce as a
// regular test. When the fuzzer gets stuck it only reports "hung or
// unexpected", which does not say why, so the same workload is simulated in
// test mode where a stuck run can be inspected.
//
// It schedules `times` subtests on a threading task runner created with
// runtime.NumCPU(). Each subtest sums the squares of [0, n) through MapReduce
// while randomly (1 in 100 each) injecting a panic into the generator, the
// mapper and/or the reducer. With a panic injected the call is expected to
// panic; otherwise it must return the exact sum of squares without error.
func TestMapReduceRandom(t *testing.T) {
	rand.Seed(time.Now().UnixNano())

	const (
		times  = 10000
		nRange = 500
	)

	bar := pb.New(times).Start()
	runner := threading.NewTaskRunner(runtime.NumCPU())
	var wg sync.WaitGroup
	wg.Add(times)
	for i := 0; i < times; i++ {
		// Per-iteration copy: the scheduled closure may run after the loop has
		// advanced, and before Go 1.22 all iterations share one variable.
		i := i
		runner.Schedule(func() {
			// Deferred first so it runs last and unconditionally; otherwise a
			// failure path that skips it would leave wg.Wait() blocked forever.
			defer wg.Done()

			start := time.Now()
			defer func() {
				if time.Since(start) > time.Minute {
					// t.Error instead of t.Fatal: FailNow must only be called
					// from the goroutine running the test function, and this
					// closure runs on a goroutine owned by the task runner.
					t.Error("timeout")
				}
			}()

			t.Run(strconv.Itoa(i), func(t *testing.T) {
				// n is in [nRange, 2*nRange).
				n := rand.Int63n(nRange) + nRange
				workers := rand.Intn(50) + runtime.NumCPU()/2
				genPanic := rand.Intn(100) == 0
				mapperPanic := rand.Intn(100) == 0
				reducerPanic := rand.Intn(100) == 0
				genIdx := rand.Int63n(n)
				mapperIdx := rand.Int63n(n)
				reducerIdx := rand.Int63n(n)
				// Closed form of 0^2 + 1^2 + ... + (n-1)^2.
				squareSum := (n - 1) * n * (2*n - 1) / 6

				fn := func() (interface{}, error) {
					return MapReduce(func(source chan<- interface{}) {
						for v := int64(0); v < n; v++ {
							source <- v
							if genPanic && v == genIdx {
								panic("foo")
							}
						}
					}, func(item interface{}, writer Writer, cancel func(error)) {
						v := item.(int64)
						if mapperPanic && v == mapperIdx {
							panic("bar")
						}
						writer.Write(v * v)
					}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
						var idx int64
						var total int64
						for v := range pipe {
							if reducerPanic && idx == reducerIdx {
								panic("baz")
							}
							total += v.(int64)
							idx++
						}
						writer.Write(total)
					}, WithWorkers(workers))
				}

				if genPanic || mapperPanic || reducerPanic {
					// Record the scenario so a failing run can be reproduced.
					var buf strings.Builder
					fmt.Fprintf(&buf, "n: %d", n)
					fmt.Fprintf(&buf, ", genPanic: %t", genPanic)
					fmt.Fprintf(&buf, ", mapperPanic: %t", mapperPanic)
					fmt.Fprintf(&buf, ", reducerPanic: %t", reducerPanic)
					fmt.Fprintf(&buf, ", genIdx: %d", genIdx)
					fmt.Fprintf(&buf, ", mapperIdx: %d", mapperIdx)
					fmt.Fprintf(&buf, ", reducerIdx: %d", reducerIdx)
					assert.Panicsf(t, func() { fn() }, buf.String())
				} else {
					val, err := fn()
					assert.Nil(t, err)
					assert.Equal(t, squareSum, val.(int64))
				}
				bar.Increment()
			})
		})
	}

	wg.Wait()
	bar.Finish()
}
Loading…
Cancel
Save