feat: a concurrent runner with messages taken in pushing order (#3941)
parent
c98d5fdaf4
commit
a1bacd3fc8
@ -0,0 +1,105 @@
|
||||
package threading
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"runtime"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
)
|
||||
|
||||
// factor scales the ring buffer size relative to the number of CPUs.
const factor = 10

var (
	// ErrRunnerClosed is the error returned when pushing to or getting from a closed runner.
	ErrRunnerClosed = errors.New("runner closed")

	// bufSize is the capacity of the ring buffer: NumCPU * factor slots,
	// so producers can run ahead of the single consumer.
	bufSize = runtime.NumCPU() * factor
)
|
||||
|
||||
// StableRunner is a runner that guarantees messages are taken out with the pushed order.
// This runner is typically useful for Kafka consumers with parallel processing.
type StableRunner[I, O any] struct {
	// handle processes one pushed message of type I into a result of type O.
	handle func(I) O
	// consumedIndex counts messages taken out via Get (accessed atomically).
	consumedIndex uint64
	// writtenIndex counts messages accepted by Push (accessed atomically).
	writtenIndex uint64
	// ring is a fixed-size ring of slots, one per in-flight message.
	// value is a 1-buffered channel carrying the processed result;
	// lock appears to serialize reuse of a slot when indexes wrap the ring
	// (locked in Push, unlocked after the result is delivered).
	ring []*struct {
		value chan O
		lock  sync.Mutex
	}
	// runner schedules the concurrent processing of pushed messages.
	runner *TaskRunner
	// done is closed by Wait to reject further pushes and unblock Get.
	done chan struct{}
}
|
||||
|
||||
// NewStableRunner returns a new StableRunner with given message processor fn.
|
||||
func NewStableRunner[I, O any](fn func(I) O) *StableRunner[I, O] {
|
||||
ring := make([]*struct {
|
||||
value chan O
|
||||
lock sync.Mutex
|
||||
}, bufSize)
|
||||
for i := 0; i < bufSize; i++ {
|
||||
ring[i] = &struct {
|
||||
value chan O
|
||||
lock sync.Mutex
|
||||
}{
|
||||
value: make(chan O, 1),
|
||||
}
|
||||
}
|
||||
|
||||
return &StableRunner[I, O]{
|
||||
handle: fn,
|
||||
ring: ring,
|
||||
runner: NewTaskRunner(runtime.NumCPU()),
|
||||
done: make(chan struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
// Get returns the next processed message in order.
// This method should be called in one goroutine.
func (r *StableRunner[I, O]) Get() (O, error) {
	// The consume index advances on every return, including the error path,
	// because the deferred add runs unconditionally.
	defer atomic.AddUint64(&r.consumedIndex, 1)

	// Map the current consume index onto its ring slot.
	index := atomic.LoadUint64(&r.consumedIndex)
	offset := index % uint64(bufSize)
	holder := r.ring[offset]

	select {
	case o := <-holder.value:
		return o, nil
	case <-r.done:
		// done is closed, but processed messages may still be buffered;
		// drain the remaining ones before reporting the runner as closed.
		if atomic.LoadUint64(&r.consumedIndex) < atomic.LoadUint64(&r.writtenIndex) {
			return <-holder.value, nil
		}

		// Nothing left to drain: return the zero value with ErrRunnerClosed.
		var o O
		return o, ErrRunnerClosed
	}
}
|
||||
|
||||
// Push pushes the message v into the runner and to be processed concurrently,
// after processed, it will be cached to let caller take it in pushing order.
func (r *StableRunner[I, O]) Push(v I) error {
	select {
	case <-r.done:
		return ErrRunnerClosed
	default:
		// Claim the next position; AddUint64 returns the incremented count,
		// so the zero-based slot index is index-1.
		index := atomic.AddUint64(&r.writtenIndex, 1)
		offset := (index - 1) % uint64(bufSize)
		holder := r.ring[offset]
		// Lock the slot before scheduling: a later Push that wraps around to
		// the same slot blocks here until this message's result has been
		// delivered and the scheduled task unlocks.
		holder.lock.Lock()
		r.runner.Schedule(func() {
			defer holder.lock.Unlock()
			o := r.handle(v)
			// value has capacity 1, so this parks the result for Get;
			// it blocks only if the previous occupant was never consumed.
			holder.value <- o
		})

		return nil
	}
}
|
||||
|
||||
// Wait waits all the messages to be processed and taken from inner buffer.
func (r *StableRunner[I, O]) Wait() {
	// Closing done makes Push return ErrRunnerClosed and lets Get drain the
	// remaining buffered results. NOTE(review): a second call to Wait would
	// panic on the double close — confirm callers invoke it only once.
	close(r.done)
	// Wait for every scheduled handler to finish processing.
	r.runner.Wait()
	// Spin (yielding the processor) until the consumer has taken every
	// processed message out of the ring.
	for atomic.LoadUint64(&r.consumedIndex) < atomic.LoadUint64(&r.writtenIndex) {
		runtime.Gosched()
	}
}
|
@ -0,0 +1,97 @@
|
||||
package threading
|
||||
|
||||
import (
|
||||
"math/rand"
|
||||
"sort"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestStableRunner(t *testing.T) {
|
||||
size := bufSize * 2
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
runner := NewStableRunner(func(v int) float64 {
|
||||
if v == 0 {
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
} else {
|
||||
time.Sleep(time.Millisecond * time.Duration(rand.Intn(10)))
|
||||
}
|
||||
return float64(v) + 0.5
|
||||
})
|
||||
|
||||
var waitGroup sync.WaitGroup
|
||||
waitGroup.Add(1)
|
||||
go func() {
|
||||
for i := 0; i < size; i++ {
|
||||
assert.NoError(t, runner.Push(i))
|
||||
}
|
||||
runner.Wait()
|
||||
waitGroup.Done()
|
||||
}()
|
||||
|
||||
values := make([]float64, size)
|
||||
for i := 0; i < size; i++ {
|
||||
var err error
|
||||
values[i], err = runner.Get()
|
||||
assert.NoError(t, err)
|
||||
time.Sleep(time.Millisecond)
|
||||
}
|
||||
|
||||
assert.True(t, sort.Float64sAreSorted(values))
|
||||
waitGroup.Wait()
|
||||
|
||||
assert.Equal(t, ErrRunnerClosed, runner.Push(1))
|
||||
_, err := runner.Get()
|
||||
assert.Equal(t, ErrRunnerClosed, err)
|
||||
}
|
||||
|
||||
func FuzzStableRunner(f *testing.F) {
|
||||
rand.Seed(time.Now().UnixNano())
|
||||
f.Add(uint64(bufSize))
|
||||
f.Fuzz(func(t *testing.T, n uint64) {
|
||||
runner := NewStableRunner(func(v int) float64 {
|
||||
if v == 0 {
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
} else {
|
||||
time.Sleep(time.Millisecond * time.Duration(rand.Intn(10)))
|
||||
}
|
||||
return float64(v) + 0.5
|
||||
})
|
||||
|
||||
go func() {
|
||||
for i := 0; i < int(n); i++ {
|
||||
assert.NoError(t, runner.Push(i))
|
||||
}
|
||||
}()
|
||||
|
||||
values := make([]float64, n)
|
||||
for i := 0; i < int(n); i++ {
|
||||
var err error
|
||||
values[i], err = runner.Get()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
runner.Wait()
|
||||
assert.True(t, sort.Float64sAreSorted(values))
|
||||
|
||||
// make sure returning errors after runner is closed
|
||||
assert.Equal(t, ErrRunnerClosed, runner.Push(1))
|
||||
_, err := runner.Get()
|
||||
assert.Equal(t, ErrRunnerClosed, err)
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkStableRunner(b *testing.B) {
|
||||
runner := NewStableRunner(func(v int) float64 {
|
||||
time.Sleep(time.Millisecond * time.Duration(rand.Intn(10)))
|
||||
return float64(v) + 0.5
|
||||
})
|
||||
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = runner.Push(i)
|
||||
_, _ = runner.Get()
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue