diff --git a/core/mr/mapreduce.go b/core/mr/mapreduce.go index 83c3c4dd..dce1a3bd 100644 --- a/core/mr/mapreduce.go +++ b/core/mr/mapreduce.go @@ -1,6 +1,7 @@ package mr import ( + "context" "errors" "fmt" "sync" @@ -43,6 +44,7 @@ type ( Option func(opts *mapReduceOptions) mapReduceOptions struct { + ctx context.Context workers int } @@ -95,14 +97,15 @@ func Map(generate GenerateFunc, mapper MapFunc, opts ...Option) chan interface{} collector := make(chan interface{}, options.workers) done := syncx.NewDoneChan() - go executeMappers(mapper, source, collector, done.Done(), options.workers) + go executeMappers(options.ctx, mapper, source, collector, done.Done(), options.workers) return collector } // MapReduce maps all elements generated from given generate func, // and reduces the output elements with given reducer. -func MapReduce(generate GenerateFunc, mapper MapperFunc, reducer ReducerFunc, opts ...Option) (interface{}, error) { +func MapReduce(generate GenerateFunc, mapper MapperFunc, reducer ReducerFunc, + opts ...Option) (interface{}, error) { source := buildSource(generate) return MapReduceWithSource(source, mapper, reducer, opts...) } @@ -120,7 +123,7 @@ func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer R collector := make(chan interface{}, options.workers) done := syncx.NewDoneChan() - writer := newGuardedWriter(output, done.Done()) + writer := newGuardedWriter(options.ctx, output, done.Done()) var closeOnce sync.Once var retErr errorx.AtomicError finish := func() { @@ -154,7 +157,7 @@ func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer R reducer(collector, writer, cancel) }() - go executeMappers(func(item interface{}, w Writer) { + go executeMappers(options.ctx, func(item interface{}, w Writer) { mapper(item, w, cancel) }, source, collector, done.Done(), options.workers) @@ -187,6 +190,13 @@ func MapVoid(generate GenerateFunc, mapper VoidMapFunc, opts ...Option) { }, opts...)) } +// WithContext customizes a mapreduce processing accepts a given ctx. +func WithContext(ctx context.Context) Option { + return func(opts *mapReduceOptions) { + opts.ctx = ctx + } +} + // WithWorkers customizes a mapreduce processing with given workers. func WithWorkers(workers int) Option { return func(opts *mapReduceOptions) { @@ -224,8 +234,8 @@ func drain(channel <-chan interface{}) { } } -func executeMappers(mapper MapFunc, input <-chan interface{}, collector chan<- interface{}, - done <-chan lang.PlaceholderType, workers int) { +func executeMappers(ctx context.Context, mapper MapFunc, input <-chan interface{}, + collector chan<- interface{}, done <-chan lang.PlaceholderType, workers int) { var wg sync.WaitGroup defer func() { wg.Wait() @@ -233,9 +243,11 @@ func executeMappers(mapper MapFunc, input <-chan interface{}, collector chan<- i }() pool := make(chan lang.PlaceholderType, workers) - writer := newGuardedWriter(collector, done) + writer := newGuardedWriter(ctx, collector, done) for { select { + case <-ctx.Done(): + return case <-done: return case pool <- lang.Placeholder: @@ -261,6 +273,7 @@ func executeMappers(mapper MapFunc, input <-chan interface{}, collector chan<- i func newOptions() *mapReduceOptions { return &mapReduceOptions{ + ctx: context.Background(), workers: defaultWorkers, } } @@ -275,12 +288,15 @@ func once(fn func(error)) func(error) { } type guardedWriter struct { + ctx context.Context channel chan<- interface{} done <-chan lang.PlaceholderType } -func newGuardedWriter(channel chan<- interface{}, done <-chan lang.PlaceholderType) guardedWriter { +func newGuardedWriter(ctx context.Context, channel chan<- interface{}, + done <-chan lang.PlaceholderType) guardedWriter { return guardedWriter{ + ctx: ctx, channel: channel, done: done, } @@ -288,6 +304,8 @@ func newGuardedWriter(channel chan<- interface{}, done <-chan lang.PlaceholderTy func (gw guardedWriter) Write(v interface{}) { select { + case <-gw.ctx.Done(): + return case <-gw.done: return default: diff --git a/core/mr/mapreduce_test.go b/core/mr/mapreduce_test.go index 5c736c9c..da61b4c5 100644 --- a/core/mr/mapreduce_test.go +++ b/core/mr/mapreduce_test.go @@ -1,6 +1,7 @@ package mr import ( + "context" "errors" "io/ioutil" "log" @@ -410,6 +411,50 @@ func TestMapReduceWithoutReducerWrite(t *testing.T) { assert.Nil(t, res) } +func TestMapReduceVoidPanicInReducer(t *testing.T) { + const message = "foo" + var done syncx.AtomicBool + err := MapReduceVoid(func(source chan<- interface{}) { + for i := 0; i < defaultWorkers*2; i++ { + source <- i + } + done.Set(true) + }, func(item interface{}, writer Writer, cancel func(error)) { + i := item.(int) + writer.Write(i) + }, func(pipe <-chan interface{}, cancel func(error)) { + panic(message) + }, WithWorkers(1)) + assert.NotNil(t, err) + assert.Equal(t, message, err.Error()) + assert.True(t, done.True()) +} + +func TestMapReduceWithContext(t *testing.T) { + var done syncx.AtomicBool + var result []int + ctx, cancel := context.WithCancel(context.Background()) + err := MapReduceVoid(func(source chan<- interface{}) { + for i := 0; i < defaultWorkers*2; i++ { + source <- i + } + done.Set(true) + }, func(item interface{}, writer Writer, c func(error)) { + i := item.(int) + if i == defaultWorkers/2 { + cancel() + } + writer.Write(i) + }, func(pipe <-chan interface{}, cancel func(error)) { + for item := range pipe { + i := item.(int) + result = append(result, i) + } + }, WithContext(ctx)) + assert.NotNil(t, err) + assert.Equal(t, ErrReduceNoOutput, err) +} + func BenchmarkMapReduce(b *testing.B) { b.ReportAllocs()