You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
go-zero/core/mr/mapreduce.go

278 lines
6.1 KiB
Go

package mr
4 years ago
import (
"errors"
"fmt"
"sync"
"github.com/tal-tech/go-zero/core/errorx"
"github.com/tal-tech/go-zero/core/lang"
"github.com/tal-tech/go-zero/core/syncx"
"github.com/tal-tech/go-zero/core/threading"
4 years ago
)
const (
	// minWorkers is the lower bound WithWorkers clamps the worker count to.
	minWorkers = 1
	// defaultWorkers is the worker count used when WithWorkers is not supplied.
	defaultWorkers = 16
)
// ErrCancelWithNil is reported by MapReduce when cancel is called with a nil error.
var ErrCancelWithNil = errors.New("mapreduce cancelled with nil")

// ErrReduceNoOutput is reported by MapReduce when the reducer finishes without
// writing a value.
var ErrReduceNoOutput = errors.New("reduce not writing value")
4 years ago
// GenerateFunc produces work items by sending them to source.
type GenerateFunc func(source chan<- interface{})

// MapFunc processes one item and may emit results through writer.
type MapFunc func(item interface{}, writer Writer)

// VoidMapFunc processes one item without emitting any result.
type VoidMapFunc func(item interface{})

// MapperFunc processes one item, may emit results through writer, and may call
// cancel to abort the whole operation with an error.
type MapperFunc func(item interface{}, writer Writer, cancel func(error))

// ReducerFunc consumes mapped values from pipe, writes the final result through
// writer, and may call cancel to abort the whole operation with an error.
type ReducerFunc func(pipe <-chan interface{}, writer Writer, cancel func(error))

// VoidReducerFunc consumes mapped values from pipe without producing a result;
// it may call cancel to abort the whole operation with an error.
type VoidReducerFunc func(pipe <-chan interface{}, cancel func(error))

// Option customizes the mapReduceOptions used by a call (see WithWorkers).
type Option func(opts *mapReduceOptions)

// mapReduceOptions holds the tunables for Map/MapReduce calls.
type mapReduceOptions struct {
	workers int // number of mappers allowed to run concurrently
}

// Writer is the sink mappers and reducers use to emit values.
type Writer interface {
	Write(v interface{})
}
func Finish(fns ...func() error) error {
if len(fns) == 0 {
return nil
}
return MapReduceVoid(func(source chan<- interface{}) {
for _, fn := range fns {
source <- fn
}
}, func(item interface{}, writer Writer, cancel func(error)) {
fn := item.(func() error)
if err := fn(); err != nil {
cancel(err)
}
}, func(pipe <-chan interface{}, cancel func(error)) {
drain(pipe)
}, WithWorkers(len(fns)))
}
// FinishVoid runs fns concurrently, one worker per function, and returns after
// all of them have completed.
func FinishVoid(fns ...func()) {
	if len(fns) == 0 {
		return
	}

	generate := func(source chan<- interface{}) {
		for i := range fns {
			source <- fns[i]
		}
	}
	run := func(item interface{}) {
		item.(func())()
	}

	MapVoid(generate, run, WithWorkers(len(fns)))
}
// Map feeds every item produced by generate to mapper, with up to the configured
// number of workers running at once. It returns the channel that mapper output is
// written to. That channel is closed once all mappers have finished.
func Map(generate GenerateFunc, mapper MapFunc, opts ...Option) chan interface{} {
	options := buildOptions(opts...)
	source := buildSource(generate)
	results := make(chan interface{}, options.workers)
	// done is never closed here, so mapper writes are never suppressed.
	done := syncx.NewDoneChan()
	go mapDispatcher(mapper, source, results, done.Done(), options.workers)
	return results
}
// MapReduce produces items with generate, maps them concurrently with mapper,
// and reduces the mapped values with reducer, returning what reducer writes.
func MapReduce(generate GenerateFunc, mapper MapperFunc, reducer ReducerFunc, opts ...Option) (interface{}, error) {
	return MapReduceWithSource(buildSource(generate), mapper, reducer, opts...)
}
func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer ReducerFunc,
opts ...Option) (interface{}, error) {
options := buildOptions(opts...)
output := make(chan interface{})
collector := make(chan interface{}, options.workers)
done := syncx.NewDoneChan()
writer := newGuardedWriter(output, done.Done())
var closeOnce sync.Once
4 years ago
var retErr errorx.AtomicError
finish := func() {
closeOnce.Do(func() {
done.Close()
close(output)
})
}
4 years ago
cancel := once(func(err error) {
if err != nil {
retErr.Set(err)
} else {
retErr.Set(ErrCancelWithNil)
}
drain(source)
finish()
4 years ago
})
go func() {
defer func() {
if r := recover(); r != nil {
cancel(fmt.Errorf("%v", r))
} else {
finish()
4 years ago
}
}()
reducer(collector, writer, cancel)
drain(collector)
4 years ago
}()
go mapperDispatcher(mapper, source, collector, done.Done(), cancel, options.workers)
value, ok := <-output
if err := retErr.Load(); err != nil {
return nil, err
} else if ok {
return value, nil
} else {
return nil, ErrReduceNoOutput
4 years ago
}
}
func MapReduceVoid(generator GenerateFunc, mapper MapperFunc, reducer VoidReducerFunc, opts ...Option) error {
_, err := MapReduce(generator, mapper, func(input <-chan interface{}, writer Writer, cancel func(error)) {
reducer(input, cancel)
drain(input)
4 years ago
// We need to write a placeholder to let MapReduce to continue on reducer done,
// otherwise, all goroutines are waiting. The placeholder will be discarded by MapReduce.
writer.Write(lang.Placeholder)
}, opts...)
return err
}
// MapVoid runs mapper over every item produced by generate, discarding any
// output, and returns once all mappers have finished.
func MapVoid(generate GenerateFunc, mapper VoidMapFunc, opts ...Option) {
	adapted := func(item interface{}, _ Writer) {
		mapper(item)
	}
	drain(Map(generate, adapted, opts...))
}
// WithWorkers returns an Option that sets how many mappers may run concurrently.
// Values below minWorkers are raised to minWorkers.
func WithWorkers(workers int) Option {
	return func(opts *mapReduceOptions) {
		n := workers
		if n < minWorkers {
			n = minWorkers
		}
		opts.workers = n
	}
}
// buildOptions starts from the package defaults and applies opts in order.
func buildOptions(opts ...Option) *mapReduceOptions {
	result := newOptions()
	for i := range opts {
		opts[i](result)
	}
	return result
}
// buildSource runs generate in a new goroutine and returns the unbuffered channel
// it sends to. The channel is closed when generate returns.
func buildSource(generate GenerateFunc) chan interface{} {
	ch := make(chan interface{})
	threading.GoSafe(func() {
		defer close(ch)
		generate(ch)
	})
	return ch
}
// drain receives and discards values from channel until it is closed.
func drain(channel <-chan interface{}) {
	for {
		if _, ok := <-channel; !ok {
			return
		}
	}
}
// executeMappers pulls items from input and runs mapper on each one in its own
// goroutine, with at most workers mappers in flight at a time. Mappers write
// through a guardedWriter over collector, which skips the send when done is
// already closed at the time of the write.
//
// The dispatch loop stops when done is closed or input is closed. Before
// returning, it waits for every mapper it started and then closes collector.
func executeMappers(mapper MapFunc, input <-chan interface{}, collector chan<- interface{},
	done <-chan lang.PlaceholderType, workers int) {
	var wg sync.WaitGroup
	defer func() {
		// Close collector only after every started mapper has returned, so no
		// mapper sends on a closed channel.
		wg.Wait()
		close(collector)
	}()
	// pool acts as a counting semaphore of size workers: a slot is acquired
	// before reading the next item and released when that item's mapper finishes.
	pool := make(chan lang.PlaceholderType, workers)
	writer := newGuardedWriter(collector, done)
	for {
		select {
		case <-done:
			return
		case pool <- lang.Placeholder:
			// NOTE(review): this receive does not watch done; it relies on input
			// eventually being closed (e.g. cancel in MapReduceWithSource drains source).
			item, ok := <-input
			if !ok {
				// input exhausted: release the slot acquired above.
				<-pool
				return
			}
			wg.Add(1)
			// Run the caller-defined mapper via threading.GoSafe (presumably
			// panic-safe) rather than a bare go statement.
			threading.GoSafe(func() {
				defer func() {
					wg.Done()
					<-pool
				}()
				mapper(item, writer)
			})
		}
	}
}
// mapDispatcher runs mapper over every item from input via executeMappers.
func mapDispatcher(mapper MapFunc, input <-chan interface{}, collector chan<- interface{},
	done <-chan lang.PlaceholderType, workers int) {
	// mapper already has the callback signature executeMappers expects, so it is
	// passed through directly instead of being wrapped in an identical closure.
	executeMappers(mapper, input, collector, done, workers)
}
// mapperDispatcher runs mapper over every item from input via executeMappers,
// binding cancel so each invocation can abort the whole operation.
func mapperDispatcher(mapper MapperFunc, input <-chan interface{}, collector chan<- interface{},
	done <-chan lang.PlaceholderType, cancel func(error), workers int) {
	withCancel := func(item interface{}, writer Writer) {
		mapper(item, writer, cancel)
	}
	executeMappers(withCancel, input, collector, done, workers)
}
// newOptions returns a fresh options value populated with the package defaults.
func newOptions() *mapReduceOptions {
	opts := new(mapReduceOptions)
	opts.workers = defaultWorkers
	return opts
}
// once wraps fn so that only the first call, with its argument, reaches fn.
// Later calls are no-ops. A caller that races with the first call waits until
// that call has returned (sync.Once semantics).
func once(fn func(error)) func(error) {
	var guard sync.Once
	return func(err error) {
		guard.Do(func() { fn(err) })
	}
}
// guardedWriter is a Writer that forwards values to channel until done is closed.
type guardedWriter struct {
	channel chan<- interface{}
	done    <-chan lang.PlaceholderType
}

// newGuardedWriter returns a guardedWriter that sends to channel and stops
// sending once done is closed.
func newGuardedWriter(channel chan<- interface{}, done <-chan lang.PlaceholderType) guardedWriter {
	return guardedWriter{
		channel: channel,
		done:    done,
	}
}

// Write sends v to the underlying channel, or discards it if done is closed.
//
// done is checked first without blocking, so a writer that is already
// cancelled never attempts the send. The send itself also watches done, so a
// writer blocked on a full or unread channel is released when done is closed
// instead of blocking (and leaking its goroutine) forever.
func (gw guardedWriter) Write(v interface{}) {
	select {
	case <-gw.done:
		return
	default:
	}

	select {
	case <-gw.done:
	case gw.channel <- v:
	}
}