feat: handling panic in mapreduce, panic in calling goroutine, not inside goroutines (#1490)

* feat: handle panic

* chore: update fuzz test

* chore: optimize square sum algorithm
master
Kevin Wan 3 years ago committed by GitHub
parent 5ad6a6d229
commit 14a902c1a7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -3,12 +3,11 @@ package mr
import (
"context"
"errors"
"fmt"
"sync"
"sync/atomic"
"github.com/zeromicro/go-zero/core/errorx"
"github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/threading"
)
const (
@ -42,6 +41,16 @@ type (
// Option defines the method to customize the mapreduce.
Option func(opts *mapReduceOptions)
mapperContext struct {
ctx context.Context
mapper MapFunc
source <-chan interface{}
panicChan *onceChan
collector chan<- interface{}
doneChan <-chan lang.PlaceholderType
workers int
}
mapReduceOptions struct {
ctx context.Context
workers int
@ -90,46 +99,72 @@ func FinishVoid(fns ...func()) {
// ForEach maps all elements from given generate but no output.
func ForEach(generate GenerateFunc, mapper ForEachFunc, opts ...Option) {
drain(Map(generate, func(item interface{}, writer Writer) {
mapper(item)
}, opts...))
}
// Map maps all elements generated from given generate func, and returns an output channel.
func Map(generate GenerateFunc, mapper MapFunc, opts ...Option) chan interface{} {
options := buildOptions(opts...)
source := buildSource(generate)
panicChan := &onceChan{channel: make(chan interface{})}
source := buildSource(generate, panicChan)
collector := make(chan interface{}, options.workers)
done := make(chan lang.PlaceholderType)
go executeMappers(options.ctx, mapper, source, collector, done, options.workers)
go executeMappers(mapperContext{
ctx: options.ctx,
mapper: func(item interface{}, writer Writer) {
mapper(item)
},
source: source,
panicChan: panicChan,
collector: collector,
doneChan: done,
workers: options.workers,
})
return collector
for {
select {
case v := <-panicChan.channel:
panic(v)
case _, ok := <-collector:
if !ok {
return
}
}
}
}
// MapReduce maps all elements generated from given generate func,
// and reduces the output elements with given reducer.
func MapReduce(generate GenerateFunc, mapper MapperFunc, reducer ReducerFunc,
opts ...Option) (interface{}, error) {
source := buildSource(generate)
return MapReduceChan(source, mapper, reducer, opts...)
panicChan := &onceChan{channel: make(chan interface{})}
source := buildSource(generate, panicChan)
return mapReduceWithPanicChan(source, panicChan, mapper, reducer, opts...)
}
// MapReduceChan maps all elements from source, and reduce the output elements with given reducer.
func MapReduceChan(source <-chan interface{}, mapper MapperFunc, reducer ReducerFunc,
opts ...Option) (interface{}, error) {
panicChan := &onceChan{channel: make(chan interface{})}
return mapReduceWithPanicChan(source, panicChan, mapper, reducer, opts...)
}
// MapReduceChan maps all elements from source, and reduce the output elements with given reducer.
func mapReduceWithPanicChan(source <-chan interface{}, panicChan *onceChan, mapper MapperFunc,
reducer ReducerFunc, opts ...Option) (interface{}, error) {
options := buildOptions(opts...)
// output is used to write the final result
output := make(chan interface{})
defer func() {
// reducer can only write once, if more, panic
for range output {
panic("more than one element written in reducer")
}
}()
// collector is used to collect data from mapper, and consume in reducer
collector := make(chan interface{}, options.workers)
// if done is closed, all mappers and reducer should stop processing
done := make(chan lang.PlaceholderType)
writer := newGuardedWriter(options.ctx, output, done)
var closeOnce sync.Once
// use atomic.Value to avoid data race
var retErr errorx.AtomicError
finish := func() {
closeOnce.Do(func() {
@ -151,30 +186,38 @@ func MapReduceChan(source <-chan interface{}, mapper MapperFunc, reducer Reducer
go func() {
defer func() {
drain(collector)
if r := recover(); r != nil {
cancel(fmt.Errorf("%v", r))
} else {
finish()
panicChan.write(r)
}
finish()
}()
reducer(collector, writer, cancel)
}()
go executeMappers(options.ctx, func(item interface{}, w Writer) {
mapper(item, w, cancel)
}, source, collector, done, options.workers)
go executeMappers(mapperContext{
ctx: options.ctx,
mapper: func(item interface{}, w Writer) {
mapper(item, w, cancel)
},
source: source,
panicChan: panicChan,
collector: collector,
doneChan: done,
workers: options.workers,
})
select {
case <-options.ctx.Done():
cancel(context.DeadlineExceeded)
return nil, context.DeadlineExceeded
case value, ok := <-output:
case v := <-panicChan.channel:
panic(v)
case v, ok := <-output:
if err := retErr.Load(); err != nil {
return nil, err
} else if ok {
return value, nil
return v, nil
} else {
return nil, ErrReduceNoOutput
}
@ -221,12 +264,18 @@ func buildOptions(opts ...Option) *mapReduceOptions {
return options
}
func buildSource(generate GenerateFunc) chan interface{} {
func buildSource(generate GenerateFunc, panicChan *onceChan) chan interface{} {
source := make(chan interface{})
threading.GoSafe(func() {
defer close(source)
go func() {
defer func() {
if r := recover(); r != nil {
panicChan.write(r)
}
close(source)
}()
generate(source)
})
}()
return source
}
@ -238,39 +287,54 @@ func drain(channel <-chan interface{}) {
}
}
func executeMappers(ctx context.Context, mapper MapFunc, input <-chan interface{},
collector chan<- interface{}, done <-chan lang.PlaceholderType, workers int) {
func executeMappers(mCtx mapperContext) {
var wg sync.WaitGroup
pc := &onceChan{channel: make(chan interface{})}
defer func() {
// in case panic happens when processing last item, for loop not handling it.
select {
case r := <-pc.channel:
mCtx.panicChan.write(r)
default:
}
wg.Wait()
close(collector)
close(mCtx.collector)
drain(mCtx.source)
}()
pool := make(chan lang.PlaceholderType, workers)
writer := newGuardedWriter(ctx, collector, done)
pool := make(chan lang.PlaceholderType, mCtx.workers)
writer := newGuardedWriter(mCtx.ctx, mCtx.collector, mCtx.doneChan)
for {
select {
case <-ctx.Done():
case <-mCtx.ctx.Done():
return
case <-mCtx.doneChan:
return
case <-done:
case r := <-pc.channel:
// make sure this method quit ASAP,
// without this case branch, all the items from source will be consumed.
mCtx.panicChan.write(r)
return
case pool <- lang.Placeholder:
item, ok := <-input
item, ok := <-mCtx.source
if !ok {
<-pool
return
}
wg.Add(1)
// better to safely run caller defined method
threading.GoSafe(func() {
go func() {
defer func() {
if r := recover(); r != nil {
pc.write(r)
}
wg.Done()
<-pool
}()
mapper(item, writer)
})
mCtx.mapper(item, writer)
}()
}
}
}
@ -316,3 +380,16 @@ func (gw guardedWriter) Write(v interface{}) {
gw.channel <- v
}
}
type onceChan struct {
channel chan interface{}
wrote int32
}
func (oc *onceChan) write(val interface{}) {
if atomic.AddInt32(&oc.wrote, 1) > 1 {
return
}
oc.channel <- val
}

@ -0,0 +1,78 @@
//go:build go1.18
// +build go1.18
package mr
import (
"fmt"
"math/rand"
"runtime"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"go.uber.org/goleak"
)
func FuzzMapReduce(f *testing.F) {
rand.Seed(time.Now().UnixNano())
f.Add(int64(10), runtime.NumCPU())
f.Fuzz(func(t *testing.T, n int64, workers int) {
n = n%5000 + 5000
genPanic := rand.Intn(100) == 0
mapperPanic := rand.Intn(100) == 0
reducerPanic := rand.Intn(100) == 0
genIdx := rand.Int63n(n)
mapperIdx := rand.Int63n(n)
reducerIdx := rand.Int63n(n)
squareSum := (n - 1) * n * (2*n - 1) / 6
fn := func() (interface{}, error) {
defer goleak.VerifyNone(t, goleak.IgnoreCurrent())
return MapReduce(func(source chan<- interface{}) {
for i := int64(0); i < n; i++ {
source <- i
if genPanic && i == genIdx {
panic("foo")
}
}
}, func(item interface{}, writer Writer, cancel func(error)) {
v := item.(int64)
if mapperPanic && v == mapperIdx {
panic("bar")
}
writer.Write(v * v)
}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
var idx int64
var total int64
for v := range pipe {
if reducerPanic && idx == reducerIdx {
panic("baz")
}
total += v.(int64)
idx++
}
writer.Write(total)
}, WithWorkers(workers%50+runtime.NumCPU()))
}
if genPanic || mapperPanic || reducerPanic {
var buf strings.Builder
buf.WriteString(fmt.Sprintf("n: %d", n))
buf.WriteString(fmt.Sprintf(", genPanic: %t", genPanic))
buf.WriteString(fmt.Sprintf(", mapperPanic: %t", mapperPanic))
buf.WriteString(fmt.Sprintf(", reducerPanic: %t", reducerPanic))
buf.WriteString(fmt.Sprintf(", genIdx: %d", genIdx))
buf.WriteString(fmt.Sprintf(", mapperIdx: %d", mapperIdx))
buf.WriteString(fmt.Sprintf(", reducerIdx: %d", reducerIdx))
assert.Panicsf(t, func() { fn() }, buf.String())
} else {
val, err := fn()
assert.Nil(t, err)
assert.Equal(t, squareSum, val.(int64))
}
})
}

@ -11,8 +11,6 @@ import (
"time"
"github.com/stretchr/testify/assert"
"github.com/zeromicro/go-zero/core/stringx"
"github.com/zeromicro/go-zero/core/syncx"
"go.uber.org/goleak"
)
@ -124,84 +122,69 @@ func TestForEach(t *testing.T) {
t.Run("all", func(t *testing.T) {
defer goleak.VerifyNone(t)
ForEach(func(source chan<- interface{}) {
for i := 0; i < tasks; i++ {
source <- i
}
}, func(item interface{}) {
panic("foo")
assert.PanicsWithValue(t, "foo", func() {
ForEach(func(source chan<- interface{}) {
for i := 0; i < tasks; i++ {
source <- i
}
}, func(item interface{}) {
panic("foo")
})
})
})
}
func TestMap(t *testing.T) {
func TestGeneratePanic(t *testing.T) {
defer goleak.VerifyNone(t)
tests := []struct {
mapper MapFunc
expect int
}{
{
mapper: func(item interface{}, writer Writer) {
v := item.(int)
writer.Write(v * v)
},
expect: 30,
},
{
mapper: func(item interface{}, writer Writer) {
v := item.(int)
if v%2 == 0 {
return
}
writer.Write(v * v)
},
expect: 10,
},
{
mapper: func(item interface{}, writer Writer) {
v := item.(int)
if v%2 == 0 {
panic(v)
}
writer.Write(v * v)
},
expect: 10,
},
}
t.Run("all", func(t *testing.T) {
assert.PanicsWithValue(t, "foo", func() {
ForEach(func(source chan<- interface{}) {
panic("foo")
}, func(item interface{}) {
})
})
})
}
for _, test := range tests {
t.Run(stringx.Rand(), func(t *testing.T) {
channel := Map(func(source chan<- interface{}) {
for i := 1; i < 5; i++ {
func TestMapperPanic(t *testing.T) {
defer goleak.VerifyNone(t)
const tasks = 1000
var run int32
t.Run("all", func(t *testing.T) {
assert.PanicsWithValue(t, "foo", func() {
_, _ = MapReduce(func(source chan<- interface{}) {
for i := 0; i < tasks; i++ {
source <- i
}
}, test.mapper, WithWorkers(-1))
var result int
for v := range channel {
result += v.(int)
}
assert.Equal(t, test.expect, result)
}, func(item interface{}, writer Writer, cancel func(error)) {
atomic.AddInt32(&run, 1)
panic("foo")
}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
})
})
}
assert.True(t, atomic.LoadInt32(&run) < tasks/2)
})
}
func TestMapReduce(t *testing.T) {
defer goleak.VerifyNone(t)
tests := []struct {
name string
mapper MapperFunc
reducer ReducerFunc
expectErr error
expectValue interface{}
}{
{
name: "simple",
expectErr: nil,
expectValue: 30,
},
{
name: "cancel with error",
mapper: func(item interface{}, writer Writer, cancel func(error)) {
v := item.(int)
if v%3 == 0 {
@ -212,6 +195,7 @@ func TestMapReduce(t *testing.T) {
expectErr: errDummy,
},
{
name: "cancel with nil",
mapper: func(item interface{}, writer Writer, cancel func(error)) {
v := item.(int)
if v%3 == 0 {
@ -223,6 +207,7 @@ func TestMapReduce(t *testing.T) {
expectValue: nil,
},
{
name: "cancel with more",
reducer: func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
var result int
for item := range pipe {
@ -237,45 +222,68 @@ func TestMapReduce(t *testing.T) {
},
}
for _, test := range tests {
t.Run(stringx.Rand(), func(t *testing.T) {
if test.mapper == nil {
test.mapper = func(item interface{}, writer Writer, cancel func(error)) {
v := item.(int)
writer.Write(v * v)
}
}
if test.reducer == nil {
test.reducer = func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
var result int
for item := range pipe {
result += item.(int)
t.Run("MapReduce", func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if test.mapper == nil {
test.mapper = func(item interface{}, writer Writer, cancel func(error)) {
v := item.(int)
writer.Write(v * v)
}
writer.Write(result)
}
}
value, err := MapReduce(func(source chan<- interface{}) {
for i := 1; i < 5; i++ {
source <- i
if test.reducer == nil {
test.reducer = func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
var result int
for item := range pipe {
result += item.(int)
}
writer.Write(result)
}
}
}, test.mapper, test.reducer, WithWorkers(runtime.NumCPU()))
value, err := MapReduce(func(source chan<- interface{}) {
for i := 1; i < 5; i++ {
source <- i
}
}, test.mapper, test.reducer, WithWorkers(runtime.NumCPU()))
assert.Equal(t, test.expectErr, err)
assert.Equal(t, test.expectValue, value)
})
}
}
assert.Equal(t, test.expectErr, err)
assert.Equal(t, test.expectValue, value)
})
}
})
func TestMapReducePanicBothMapperAndReducer(t *testing.T) {
defer goleak.VerifyNone(t)
t.Run("MapReduce", func(t *testing.T) {
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
if test.mapper == nil {
test.mapper = func(item interface{}, writer Writer, cancel func(error)) {
v := item.(int)
writer.Write(v * v)
}
}
if test.reducer == nil {
test.reducer = func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
var result int
for item := range pipe {
result += item.(int)
}
writer.Write(result)
}
}
_, _ = MapReduce(func(source chan<- interface{}) {
source <- 0
source <- 1
}, func(item interface{}, writer Writer, cancel func(error)) {
panic("foo")
}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
panic("bar")
source := make(chan interface{})
go func() {
for i := 1; i < 5; i++ {
source <- i
}
close(source)
}()
value, err := MapReduceChan(source, test.mapper, test.reducer, WithWorkers(-1))
assert.Equal(t, test.expectErr, err)
assert.Equal(t, test.expectValue, value)
})
}
})
}
@ -302,16 +310,19 @@ func TestMapReduceVoid(t *testing.T) {
var value uint32
tests := []struct {
name string
mapper MapperFunc
reducer VoidReducerFunc
expectValue uint32
expectErr error
}{
{
name: "simple",
expectValue: 30,
expectErr: nil,
},
{
name: "cancel with error",
mapper: func(item interface{}, writer Writer, cancel func(error)) {
v := item.(int)
if v%3 == 0 {
@ -322,6 +333,7 @@ func TestMapReduceVoid(t *testing.T) {
expectErr: errDummy,
},
{
name: "cancel with nil",
mapper: func(item interface{}, writer Writer, cancel func(error)) {
v := item.(int)
if v%3 == 0 {
@ -332,6 +344,7 @@ func TestMapReduceVoid(t *testing.T) {
expectErr: ErrCancelWithNil,
},
{
name: "cancel with more",
reducer: func(pipe <-chan interface{}, cancel func(error)) {
for item := range pipe {
result := atomic.AddUint32(&value, uint32(item.(int)))
@ -345,7 +358,7 @@ func TestMapReduceVoid(t *testing.T) {
}
for _, test := range tests {
t.Run(stringx.Rand(), func(t *testing.T) {
t.Run(test.name, func(t *testing.T) {
atomic.StoreUint32(&value, 0)
if test.mapper == nil {
@ -400,39 +413,59 @@ func TestMapReduceVoidWithDelay(t *testing.T) {
assert.Equal(t, 0, result[1])
}
func TestMapVoid(t *testing.T) {
func TestMapReducePanic(t *testing.T) {
defer goleak.VerifyNone(t)
const tasks = 1000
var count uint32
ForEach(func(source chan<- interface{}) {
for i := 0; i < tasks; i++ {
source <- i
}
}, func(item interface{}) {
atomic.AddUint32(&count, 1)
assert.Panics(t, func() {
_, _ = MapReduce(func(source chan<- interface{}) {
source <- 0
source <- 1
}, func(item interface{}, writer Writer, cancel func(error)) {
i := item.(int)
writer.Write(i)
}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
for range pipe {
panic("panic")
}
})
})
}
assert.Equal(t, tasks, int(count))
func TestMapReducePanicOnce(t *testing.T) {
defer goleak.VerifyNone(t)
assert.Panics(t, func() {
_, _ = MapReduce(func(source chan<- interface{}) {
for i := 0; i < 100; i++ {
source <- i
}
}, func(item interface{}, writer Writer, cancel func(error)) {
i := item.(int)
if i == 0 {
panic("foo")
}
writer.Write(i)
}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
for range pipe {
panic("bar")
}
})
})
}
func TestMapReducePanic(t *testing.T) {
func TestMapReducePanicBothMapperAndReducer(t *testing.T) {
defer goleak.VerifyNone(t)
v, err := MapReduce(func(source chan<- interface{}) {
source <- 0
source <- 1
}, func(item interface{}, writer Writer, cancel func(error)) {
i := item.(int)
writer.Write(i)
}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
for range pipe {
panic("panic")
}
assert.Panics(t, func() {
_, _ = MapReduce(func(source chan<- interface{}) {
source <- 0
source <- 1
}, func(item interface{}, writer Writer, cancel func(error)) {
panic("foo")
}, func(pipe <-chan interface{}, writer Writer, cancel func(error)) {
panic("bar")
})
})
assert.Nil(t, v)
assert.NotNil(t, err)
assert.Equal(t, "panic", err.Error())
}
func TestMapReduceVoidCancel(t *testing.T) {
@ -461,13 +494,13 @@ func TestMapReduceVoidCancel(t *testing.T) {
func TestMapReduceVoidCancelWithRemains(t *testing.T) {
defer goleak.VerifyNone(t)
var done syncx.AtomicBool
var done int32
var result []int
err := MapReduceVoid(func(source chan<- interface{}) {
for i := 0; i < defaultWorkers*2; i++ {
source <- i
}
done.Set(true)
atomic.AddInt32(&done, 1)
}, func(item interface{}, writer Writer, cancel func(error)) {
i := item.(int)
if i == defaultWorkers/2 {
@ -482,7 +515,7 @@ func TestMapReduceVoidCancelWithRemains(t *testing.T) {
})
assert.NotNil(t, err)
assert.Equal(t, "anything", err.Error())
assert.True(t, done.True())
assert.Equal(t, int32(1), done)
}
func TestMapReduceWithoutReducerWrite(t *testing.T) {
@ -507,34 +540,51 @@ func TestMapReduceVoidPanicInReducer(t *testing.T) {
defer goleak.VerifyNone(t)
const message = "foo"
var done syncx.AtomicBool
err := MapReduceVoid(func(source chan<- interface{}) {
assert.Panics(t, func() {
var done int32
_ = MapReduceVoid(func(source chan<- interface{}) {
for i := 0; i < defaultWorkers*2; i++ {
source <- i
}
atomic.AddInt32(&done, 1)
}, func(item interface{}, writer Writer, cancel func(error)) {
i := item.(int)
writer.Write(i)
}, func(pipe <-chan interface{}, cancel func(error)) {
panic(message)
}, WithWorkers(1))
})
}
func TestForEachWithContext(t *testing.T) {
defer goleak.VerifyNone(t)
var done int32
ctx, cancel := context.WithCancel(context.Background())
ForEach(func(source chan<- interface{}) {
for i := 0; i < defaultWorkers*2; i++ {
source <- i
}
done.Set(true)
}, func(item interface{}, writer Writer, cancel func(error)) {
atomic.AddInt32(&done, 1)
}, func(item interface{}) {
i := item.(int)
writer.Write(i)
}, func(pipe <-chan interface{}, cancel func(error)) {
panic(message)
}, WithWorkers(1))
assert.NotNil(t, err)
assert.Equal(t, message, err.Error())
assert.True(t, done.True())
if i == defaultWorkers/2 {
cancel()
}
}, WithContext(ctx))
}
func TestMapReduceWithContext(t *testing.T) {
defer goleak.VerifyNone(t)
var done syncx.AtomicBool
var done int32
var result []int
ctx, cancel := context.WithCancel(context.Background())
err := MapReduceVoid(func(source chan<- interface{}) {
for i := 0; i < defaultWorkers*2; i++ {
source <- i
}
done.Set(true)
atomic.AddInt32(&done, 1)
}, func(item interface{}, writer Writer, c func(error)) {
i := item.(int)
if i == defaultWorkers/2 {

Loading…
Cancel
Save