feat: support context in MapReduce (#1368)

master
Kevin Wan 3 years ago committed by GitHub
parent 8745ed9c61
commit c0647f0719
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -1,6 +1,7 @@
package mr package mr
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"sync" "sync"
@ -43,6 +44,7 @@ type (
Option func(opts *mapReduceOptions) Option func(opts *mapReduceOptions)
mapReduceOptions struct { mapReduceOptions struct {
ctx context.Context
workers int workers int
} }
@ -95,14 +97,15 @@ func Map(generate GenerateFunc, mapper MapFunc, opts ...Option) chan interface{}
collector := make(chan interface{}, options.workers) collector := make(chan interface{}, options.workers)
done := syncx.NewDoneChan() done := syncx.NewDoneChan()
go executeMappers(mapper, source, collector, done.Done(), options.workers) go executeMappers(options.ctx, mapper, source, collector, done.Done(), options.workers)
return collector return collector
} }
// MapReduce maps all elements generated from given generate func, // MapReduce maps all elements generated from given generate func,
// and reduces the output elements with given reducer. // and reduces the output elements with given reducer.
func MapReduce(generate GenerateFunc, mapper MapperFunc, reducer ReducerFunc, opts ...Option) (interface{}, error) { func MapReduce(generate GenerateFunc, mapper MapperFunc, reducer ReducerFunc,
opts ...Option) (interface{}, error) {
source := buildSource(generate) source := buildSource(generate)
return MapReduceWithSource(source, mapper, reducer, opts...) return MapReduceWithSource(source, mapper, reducer, opts...)
} }
@ -120,7 +123,7 @@ func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer R
collector := make(chan interface{}, options.workers) collector := make(chan interface{}, options.workers)
done := syncx.NewDoneChan() done := syncx.NewDoneChan()
writer := newGuardedWriter(output, done.Done()) writer := newGuardedWriter(options.ctx, output, done.Done())
var closeOnce sync.Once var closeOnce sync.Once
var retErr errorx.AtomicError var retErr errorx.AtomicError
finish := func() { finish := func() {
@ -154,7 +157,7 @@ func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer R
reducer(collector, writer, cancel) reducer(collector, writer, cancel)
}() }()
go executeMappers(func(item interface{}, w Writer) { go executeMappers(options.ctx, func(item interface{}, w Writer) {
mapper(item, w, cancel) mapper(item, w, cancel)
}, source, collector, done.Done(), options.workers) }, source, collector, done.Done(), options.workers)
@ -187,6 +190,13 @@ func MapVoid(generate GenerateFunc, mapper VoidMapFunc, opts ...Option) {
}, opts...)) }, opts...))
} }
// WithContext returns an Option that makes a mapreduce processing use the
// given ctx instead of the default one stored in the options.
//
// A nil ctx is ignored and the context already present in the options is
// kept, because the processing calls ctx.Done() unconditionally and a nil
// context would make that call panic.
func WithContext(ctx context.Context) Option {
	return func(opts *mapReduceOptions) {
		if ctx == nil {
			// Keep the existing (default) context rather than storing nil.
			return
		}

		opts.ctx = ctx
	}
}
// WithWorkers customizes a mapreduce processing with given workers. // WithWorkers customizes a mapreduce processing with given workers.
func WithWorkers(workers int) Option { func WithWorkers(workers int) Option {
return func(opts *mapReduceOptions) { return func(opts *mapReduceOptions) {
@ -224,8 +234,8 @@ func drain(channel <-chan interface{}) {
} }
} }
func executeMappers(mapper MapFunc, input <-chan interface{}, collector chan<- interface{}, func executeMappers(ctx context.Context, mapper MapFunc, input <-chan interface{},
done <-chan lang.PlaceholderType, workers int) { collector chan<- interface{}, done <-chan lang.PlaceholderType, workers int) {
var wg sync.WaitGroup var wg sync.WaitGroup
defer func() { defer func() {
wg.Wait() wg.Wait()
@ -233,9 +243,11 @@ func executeMappers(mapper MapFunc, input <-chan interface{}, collector chan<- i
}() }()
pool := make(chan lang.PlaceholderType, workers) pool := make(chan lang.PlaceholderType, workers)
writer := newGuardedWriter(collector, done) writer := newGuardedWriter(ctx, collector, done)
for { for {
select { select {
case <-ctx.Done():
return
case <-done: case <-done:
return return
case pool <- lang.Placeholder: case pool <- lang.Placeholder:
@ -261,6 +273,7 @@ func executeMappers(mapper MapFunc, input <-chan interface{}, collector chan<- i
func newOptions() *mapReduceOptions { func newOptions() *mapReduceOptions {
return &mapReduceOptions{ return &mapReduceOptions{
ctx: context.Background(),
workers: defaultWorkers, workers: defaultWorkers,
} }
} }
@ -275,12 +288,15 @@ func once(fn func(error)) func(error) {
} }
type guardedWriter struct { type guardedWriter struct {
ctx context.Context
channel chan<- interface{} channel chan<- interface{}
done <-chan lang.PlaceholderType done <-chan lang.PlaceholderType
} }
func newGuardedWriter(channel chan<- interface{}, done <-chan lang.PlaceholderType) guardedWriter { func newGuardedWriter(ctx context.Context, channel chan<- interface{},
done <-chan lang.PlaceholderType) guardedWriter {
return guardedWriter{ return guardedWriter{
ctx: ctx,
channel: channel, channel: channel,
done: done, done: done,
} }
@ -288,6 +304,8 @@ func newGuardedWriter(channel chan<- interface{}, done <-chan lang.PlaceholderTy
func (gw guardedWriter) Write(v interface{}) { func (gw guardedWriter) Write(v interface{}) {
select { select {
case <-gw.ctx.Done():
return
case <-gw.done: case <-gw.done:
return return
default: default:

@ -1,6 +1,7 @@
package mr package mr
import ( import (
"context"
"errors" "errors"
"io/ioutil" "io/ioutil"
"log" "log"
@ -410,6 +411,50 @@ func TestMapReduceWithoutReducerWrite(t *testing.T) {
assert.Nil(t, res) assert.Nil(t, res)
} }
// TestMapReduceVoidPanicInReducer runs MapReduceVoid with a single worker and
// a reducer that panics immediately. It asserts that the panic message comes
// back as the returned error and that the generator still ran to the point
// where it sets its done flag.
func TestMapReduceVoidPanicInReducer(t *testing.T) {
	const message = "foo"
	var done syncx.AtomicBool

	// Pushes twice as many items as the default worker count, then marks done.
	generate := func(source chan<- interface{}) {
		for n := 0; n < defaultWorkers*2; n++ {
			source <- n
		}
		done.Set(true)
	}
	// Forwards every item unchanged; the type assertion requires ints.
	mapper := func(item interface{}, writer Writer, _ func(error)) {
		writer.Write(item.(int))
	}
	// Never consumes the pipe, it just panics.
	reducer := func(_ <-chan interface{}, _ func(error)) {
		panic(message)
	}

	err := MapReduceVoid(generate, mapper, reducer, WithWorkers(1))

	assert.NotNil(t, err)
	assert.Equal(t, message, err.Error())
	assert.True(t, done.True())
}
// TestMapReduceWithContext passes a cancellable context through WithContext,
// cancels it from inside the mapper when a specific item is seen, and asserts
// that MapReduceVoid returns ErrReduceNoOutput.
func TestMapReduceWithContext(t *testing.T) {
	var (
		done   syncx.AtomicBool
		result []int
	)
	ctx, stop := context.WithCancel(context.Background())

	// Pushes twice as many items as the default worker count, then marks done.
	generate := func(source chan<- interface{}) {
		for n := 0; n < defaultWorkers*2; n++ {
			source <- n
		}
		done.Set(true)
	}
	// Cancels the context on the trigger item, then still attempts the write.
	mapper := func(item interface{}, writer Writer, _ func(error)) {
		n := item.(int)
		if n == defaultWorkers/2 {
			stop()
		}
		writer.Write(n)
	}
	// Collects whatever mapped values arrive on the pipe.
	reducer := func(pipe <-chan interface{}, _ func(error)) {
		for item := range pipe {
			result = append(result, item.(int))
		}
	}

	err := MapReduceVoid(generate, mapper, reducer, WithContext(ctx))

	assert.NotNil(t, err)
	assert.Equal(t, ErrReduceNoOutput, err)
}
func BenchmarkMapReduce(b *testing.B) { func BenchmarkMapReduce(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()

Loading…
Cancel
Save