You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
278 lines
6.1 KiB
Go
278 lines
6.1 KiB
Go
package mr
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
|
|
"github.com/tal-tech/go-zero/core/errorx"
|
|
"github.com/tal-tech/go-zero/core/lang"
|
|
"github.com/tal-tech/go-zero/core/syncx"
|
|
"github.com/tal-tech/go-zero/core/threading"
|
|
)
|
|
|
|
const (
|
|
defaultWorkers = 16
|
|
minWorkers = 1
|
|
)
|
|
|
|
var (
|
|
ErrCancelWithNil = errors.New("mapreduce cancelled with nil")
|
|
ErrReduceNoOutput = errors.New("reduce not writing value")
|
|
)
|
|
|
|
type (
|
|
GenerateFunc func(source chan<- interface{})
|
|
MapFunc func(item interface{}, writer Writer)
|
|
VoidMapFunc func(item interface{})
|
|
MapperFunc func(item interface{}, writer Writer, cancel func(error))
|
|
ReducerFunc func(pipe <-chan interface{}, writer Writer, cancel func(error))
|
|
VoidReducerFunc func(pipe <-chan interface{}, cancel func(error))
|
|
Option func(opts *mapReduceOptions)
|
|
|
|
mapReduceOptions struct {
|
|
workers int
|
|
}
|
|
|
|
Writer interface {
|
|
Write(v interface{})
|
|
}
|
|
)
|
|
|
|
func Finish(fns ...func() error) error {
|
|
if len(fns) == 0 {
|
|
return nil
|
|
}
|
|
|
|
return MapReduceVoid(func(source chan<- interface{}) {
|
|
for _, fn := range fns {
|
|
source <- fn
|
|
}
|
|
}, func(item interface{}, writer Writer, cancel func(error)) {
|
|
fn := item.(func() error)
|
|
if err := fn(); err != nil {
|
|
cancel(err)
|
|
}
|
|
}, func(pipe <-chan interface{}, cancel func(error)) {
|
|
drain(pipe)
|
|
}, WithWorkers(len(fns)))
|
|
}
|
|
|
|
func FinishVoid(fns ...func()) {
|
|
if len(fns) == 0 {
|
|
return
|
|
}
|
|
|
|
MapVoid(func(source chan<- interface{}) {
|
|
for _, fn := range fns {
|
|
source <- fn
|
|
}
|
|
}, func(item interface{}) {
|
|
fn := item.(func())
|
|
fn()
|
|
}, WithWorkers(len(fns)))
|
|
}
|
|
|
|
func Map(generate GenerateFunc, mapper MapFunc, opts ...Option) chan interface{} {
|
|
options := buildOptions(opts...)
|
|
source := buildSource(generate)
|
|
collector := make(chan interface{}, options.workers)
|
|
done := syncx.NewDoneChan()
|
|
|
|
go mapDispatcher(mapper, source, collector, done.Done(), options.workers)
|
|
|
|
return collector
|
|
}
|
|
|
|
func MapReduce(generate GenerateFunc, mapper MapperFunc, reducer ReducerFunc, opts ...Option) (interface{}, error) {
|
|
source := buildSource(generate)
|
|
return MapReduceWithSource(source, mapper, reducer, opts...)
|
|
}
|
|
|
|
func MapReduceWithSource(source <-chan interface{}, mapper MapperFunc, reducer ReducerFunc,
|
|
opts ...Option) (interface{}, error) {
|
|
options := buildOptions(opts...)
|
|
output := make(chan interface{})
|
|
collector := make(chan interface{}, options.workers)
|
|
done := syncx.NewDoneChan()
|
|
writer := newGuardedWriter(output, done.Done())
|
|
var closeOnce sync.Once
|
|
var retErr errorx.AtomicError
|
|
finish := func() {
|
|
closeOnce.Do(func() {
|
|
done.Close()
|
|
close(output)
|
|
})
|
|
}
|
|
cancel := once(func(err error) {
|
|
if err != nil {
|
|
retErr.Set(err)
|
|
} else {
|
|
retErr.Set(ErrCancelWithNil)
|
|
}
|
|
|
|
drain(source)
|
|
finish()
|
|
})
|
|
|
|
go func() {
|
|
defer func() {
|
|
if r := recover(); r != nil {
|
|
cancel(fmt.Errorf("%v", r))
|
|
} else {
|
|
finish()
|
|
}
|
|
}()
|
|
reducer(collector, writer, cancel)
|
|
drain(collector)
|
|
}()
|
|
go mapperDispatcher(mapper, source, collector, done.Done(), cancel, options.workers)
|
|
|
|
value, ok := <-output
|
|
if err := retErr.Load(); err != nil {
|
|
return nil, err
|
|
} else if ok {
|
|
return value, nil
|
|
} else {
|
|
return nil, ErrReduceNoOutput
|
|
}
|
|
}
|
|
|
|
func MapReduceVoid(generator GenerateFunc, mapper MapperFunc, reducer VoidReducerFunc, opts ...Option) error {
|
|
_, err := MapReduce(generator, mapper, func(input <-chan interface{}, writer Writer, cancel func(error)) {
|
|
reducer(input, cancel)
|
|
drain(input)
|
|
// We need to write a placeholder to let MapReduce to continue on reducer done,
|
|
// otherwise, all goroutines are waiting. The placeholder will be discarded by MapReduce.
|
|
writer.Write(lang.Placeholder)
|
|
}, opts...)
|
|
return err
|
|
}
|
|
|
|
func MapVoid(generate GenerateFunc, mapper VoidMapFunc, opts ...Option) {
|
|
drain(Map(generate, func(item interface{}, writer Writer) {
|
|
mapper(item)
|
|
}, opts...))
|
|
}
|
|
|
|
func WithWorkers(workers int) Option {
|
|
return func(opts *mapReduceOptions) {
|
|
if workers < minWorkers {
|
|
opts.workers = minWorkers
|
|
} else {
|
|
opts.workers = workers
|
|
}
|
|
}
|
|
}
|
|
|
|
func buildOptions(opts ...Option) *mapReduceOptions {
|
|
options := newOptions()
|
|
for _, opt := range opts {
|
|
opt(options)
|
|
}
|
|
|
|
return options
|
|
}
|
|
|
|
func buildSource(generate GenerateFunc) chan interface{} {
|
|
source := make(chan interface{})
|
|
threading.GoSafe(func() {
|
|
defer close(source)
|
|
generate(source)
|
|
})
|
|
|
|
return source
|
|
}
|
|
|
|
// drain drains the channel.
|
|
func drain(channel <-chan interface{}) {
|
|
// drain the channel
|
|
for range channel {
|
|
}
|
|
}
|
|
|
|
func executeMappers(mapper MapFunc, input <-chan interface{}, collector chan<- interface{},
|
|
done <-chan lang.PlaceholderType, workers int) {
|
|
var wg sync.WaitGroup
|
|
defer func() {
|
|
wg.Wait()
|
|
close(collector)
|
|
}()
|
|
|
|
pool := make(chan lang.PlaceholderType, workers)
|
|
writer := newGuardedWriter(collector, done)
|
|
for {
|
|
select {
|
|
case <-done:
|
|
return
|
|
case pool <- lang.Placeholder:
|
|
item, ok := <-input
|
|
if !ok {
|
|
<-pool
|
|
return
|
|
}
|
|
|
|
wg.Add(1)
|
|
// better to safely run caller defined method
|
|
threading.GoSafe(func() {
|
|
defer func() {
|
|
wg.Done()
|
|
<-pool
|
|
}()
|
|
|
|
mapper(item, writer)
|
|
})
|
|
}
|
|
}
|
|
}
|
|
|
|
func mapDispatcher(mapper MapFunc, input <-chan interface{}, collector chan<- interface{},
|
|
done <-chan lang.PlaceholderType, workers int) {
|
|
executeMappers(func(item interface{}, writer Writer) {
|
|
mapper(item, writer)
|
|
}, input, collector, done, workers)
|
|
}
|
|
|
|
func mapperDispatcher(mapper MapperFunc, input <-chan interface{}, collector chan<- interface{},
|
|
done <-chan lang.PlaceholderType, cancel func(error), workers int) {
|
|
executeMappers(func(item interface{}, writer Writer) {
|
|
mapper(item, writer, cancel)
|
|
}, input, collector, done, workers)
|
|
}
|
|
|
|
func newOptions() *mapReduceOptions {
|
|
return &mapReduceOptions{
|
|
workers: defaultWorkers,
|
|
}
|
|
}
|
|
|
|
func once(fn func(error)) func(error) {
|
|
once := new(sync.Once)
|
|
return func(err error) {
|
|
once.Do(func() {
|
|
fn(err)
|
|
})
|
|
}
|
|
}
|
|
|
|
type guardedWriter struct {
|
|
channel chan<- interface{}
|
|
done <-chan lang.PlaceholderType
|
|
}
|
|
|
|
func newGuardedWriter(channel chan<- interface{}, done <-chan lang.PlaceholderType) guardedWriter {
|
|
return guardedWriter{
|
|
channel: channel,
|
|
done: done,
|
|
}
|
|
}
|
|
|
|
func (gw guardedWriter) Write(v interface{}) {
|
|
select {
|
|
case <-gw.done:
|
|
return
|
|
default:
|
|
gw.channel <- v
|
|
}
|
|
}
|