From 8d6d37f71e9915553d516424d56fc3191de7d7f4 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Tue, 11 Jan 2022 16:17:51 +0800 Subject: [PATCH] remove unnecessary drain, fix data race (#1435) * remove unnecessary drain, fix data race * chore: fix parameter order * refactor: rename MapVoid to ForEach in mr --- core/mr/mapreduce.go | 34 +++++++++++++------------- core/mr/mapreduce_test.go | 50 ++++++++++++++++++++++++++++++++++++++- go.mod | 2 +- 3 files changed, 67 insertions(+), 19 deletions(-) diff --git a/core/mr/mapreduce.go b/core/mr/mapreduce.go index 70ea52d5..80c54aa9 100644 --- a/core/mr/mapreduce.go +++ b/core/mr/mapreduce.go @@ -24,12 +24,12 @@ var ( ) type ( + // ForEachFunc is used to do element processing, but no output. + ForEachFunc func(item interface{}) // GenerateFunc is used to let callers send elements into source. GenerateFunc func(source chan<- interface{}) // MapFunc is used to do element processing and write the output to writer. MapFunc func(item interface{}, writer Writer) - // VoidMapFunc is used to do element processing, but no output. - VoidMapFunc func(item interface{}) // MapperFunc is used to do element processing and write the output to writer, // use cancel func to cancel the processing. MapperFunc func(item interface{}, writer Writer, cancel func(error)) @@ -69,7 +69,6 @@ func Finish(fns ...func() error) error { cancel(err) } }, func(pipe <-chan interface{}, cancel func(error)) { - drain(pipe) }, WithWorkers(len(fns))) } @@ -79,7 +78,7 @@ func FinishVoid(fns ...func()) { return } - MapVoid(func(source chan<- interface{}) { + ForEach(func(source chan<- interface{}) { for _, fn := range fns { source <- fn } @@ -89,6 +88,13 @@ func FinishVoid(fns ...func()) { }, WithWorkers(len(fns))) } +// ForEach maps all elements from given generate but no output. +func ForEach(generate GenerateFunc, mapper ForEachFunc, opts ...Option) { + drain(Map(generate, func(item interface{}, writer Writer) { + mapper(item) + }, opts...)) +} + // Map maps all elements generated from given generate func, and returns an output channel. func Map(generate GenerateFunc, mapper MapFunc, opts ...Option) chan interface{} { options := buildOptions(opts...) @@ -106,11 +112,11 @@ func Map(generate GenerateFunc, mapper MapFunc, opts ...Option) chan interface{} func MapReduce(generate GenerateFunc, mapper MapperFunc, reducer ReducerFunc, opts ...Option) (interface{}, error) { source := buildSource(generate) - return MapReduceWithSource(source, mapper, reducer, opts...) + return MapReduceChan(source, mapper, reducer, opts...) } -// MapReduceWithSource maps all elements from source, and reduce the output elements with given reducer. -func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer ReducerFunc, +// MapReduceChan maps all elements from source, and reduce the output elements with given reducer. +func MapReduceChan(source <-chan interface{}, mapper MapperFunc, reducer ReducerFunc, opts ...Option) (interface{}, error) { options := buildOptions(opts...) output := make(chan interface{}) @@ -180,18 +186,12 @@ func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer R func MapReduceVoid(generate GenerateFunc, mapper MapperFunc, reducer VoidReducerFunc, opts ...Option) error { _, err := MapReduce(generate, mapper, func(input <-chan interface{}, writer Writer, cancel func(error)) { reducer(input, cancel) - // We need to write a placeholder to let MapReduce to continue on reducer done, - // otherwise, all goroutines are waiting. The placeholder will be discarded by MapReduce. - writer.Write(lang.Placeholder) }, opts...) - return err -} + if errors.Is(err, ErrReduceNoOutput) { + return nil + } -// MapVoid maps all elements from given generate but no output. -func MapVoid(generate GenerateFunc, mapper VoidMapFunc, opts ...Option) { - drain(Map(generate, func(item interface{}, writer Writer) { - mapper(item) - }, opts...)) + return err } // WithContext customizes a mapreduce processing accepts a given ctx. diff --git a/core/mr/mapreduce_test.go b/core/mr/mapreduce_test.go index 324f4ad8..252aed51 100644 --- a/core/mr/mapreduce_test.go +++ b/core/mr/mapreduce_test.go @@ -86,6 +86,54 @@ func TestFinishVoid(t *testing.T) { assert.Equal(t, uint32(10), atomic.LoadUint32(&total)) } +func TestForEach(t *testing.T) { + const tasks = 1000 + + t.Run("all", func(t *testing.T) { + defer goleak.VerifyNone(t) + + var count uint32 + ForEach(func(source chan<- interface{}) { + for i := 0; i < tasks; i++ { + source <- i + } + }, func(item interface{}) { + atomic.AddUint32(&count, 1) + }, WithWorkers(-1)) + + assert.Equal(t, tasks, int(count)) + }) + + t.Run("odd", func(t *testing.T) { + defer goleak.VerifyNone(t) + + var count uint32 + ForEach(func(source chan<- interface{}) { + for i := 0; i < tasks; i++ { + source <- i + } + }, func(item interface{}) { + if item.(int)%2 == 0 { + atomic.AddUint32(&count, 1) + } + }) + + assert.Equal(t, tasks/2, int(count)) + }) + + t.Run("all", func(t *testing.T) { + defer goleak.VerifyNone(t) + + ForEach(func(source chan<- interface{}) { + for i := 0; i < tasks; i++ { + source <- i + } + }, func(item interface{}) { + panic("foo") + }) + }) +} + func TestMap(t *testing.T) { defer goleak.VerifyNone(t) @@ -344,7 +392,7 @@ func TestMapVoid(t *testing.T) { const tasks = 1000 var count uint32 - MapVoid(func(source chan<- interface{}) { + ForEach(func(source chan<- interface{}) { for i := 0; i < tasks; i++ { source <- i } diff --git a/go.mod b/go.mod index 240efed3..7e4b4478 100644 --- a/go.mod +++ b/go.mod @@ -29,7 +29,7 @@ require ( go.opentelemetry.io/otel/sdk v1.1.0 go.opentelemetry.io/otel/trace v1.1.0 go.uber.org/automaxprocs v1.4.0 - go.uber.org/goleak v1.1.12 // indirect + go.uber.org/goleak v1.1.12 golang.org/x/net v0.0.0-20211015210444-4f30a5c0130f // indirect golang.org/x/sys v0.0.0-20211106132015-ebca88c72f68 // indirect golang.org/x/time v0.0.0-20210723032227-1f47c861a9ac