From 2ccef5bb4f92b8a5b41b2e5a986fcae65e5685c3 Mon Sep 17 00:00:00 2001 From: Kevin Wan Date: Tue, 6 Feb 2024 14:26:22 +0800 Subject: [PATCH] feat: support ScheduleImmediately in TaskRunner (#3896) --- core/threading/taskrunner.go | 41 +++++++++++++++++++++++++++++++ core/threading/taskrunner_test.go | 32 +++++++++++++++++++----- 2 files changed, 67 insertions(+), 6 deletions(-) diff --git a/core/threading/taskrunner.go b/core/threading/taskrunner.go index 955beeab..18ac4f6c 100644 --- a/core/threading/taskrunner.go +++ b/core/threading/taskrunner.go @@ -1,13 +1,20 @@ package threading import ( + "errors" + "sync" + "github.com/zeromicro/go-zero/core/lang" "github.com/zeromicro/go-zero/core/rescue" ) +// ErrTaskRunnerBusy is the error that indicates the runner is busy. +var ErrTaskRunnerBusy = errors.New("task runner is busy") + // A TaskRunner is used to control the concurrency of goroutines. type TaskRunner struct { limitChan chan lang.PlaceholderType + waitGroup sync.WaitGroup } // NewTaskRunner returns a TaskRunner. @@ -19,13 +26,47 @@ func NewTaskRunner(concurrency int) *TaskRunner { // Schedule schedules a task to run under concurrency control. func (rp *TaskRunner) Schedule(task func()) { + // Why we add waitGroup first, in case of race condition on starting a task and wait returns. + // For example, limitChan is full, and the task is scheduled to run, but the waitGroup is not added, + // then the wait returns, and the task is then scheduled to run, but caller thinks all tasks are done. + // the same reason for ScheduleImmediately. + rp.waitGroup.Add(1) rp.limitChan <- lang.Placeholder go func() { defer rescue.Recover(func() { <-rp.limitChan + rp.waitGroup.Done() }) task() }() } + +// ScheduleImmediately schedules a task to run immediately under concurrency control. +// It returns ErrTaskRunnerBusy if the runner is busy. +func (rp *TaskRunner) ScheduleImmediately(task func()) error { + // Why we add waitGroup first, check the comment in Schedule. + rp.waitGroup.Add(1) + select { + case rp.limitChan <- lang.Placeholder: + default: + rp.waitGroup.Done() + return ErrTaskRunnerBusy + } + + go func() { + defer rescue.Recover(func() { + <-rp.limitChan + rp.waitGroup.Done() + }) + task() + }() + + return nil +} + +// Wait waits all running tasks to be done. +func (rp *TaskRunner) Wait() { + rp.waitGroup.Wait() +} diff --git a/core/threading/taskrunner_test.go b/core/threading/taskrunner_test.go index 81cefc82..7771760a 100644 --- a/core/threading/taskrunner_test.go +++ b/core/threading/taskrunner_test.go @@ -2,32 +2,52 @@ package threading import ( "runtime" - "sync" "sync/atomic" "testing" + "time" "github.com/stretchr/testify/assert" ) -func TestRoutinePool(t *testing.T) { +func TestTaskRunner_Schedule(t *testing.T) { times := 100 pool := NewTaskRunner(runtime.NumCPU()) var counter int32 - var waitGroup sync.WaitGroup for i := 0; i < times; i++ { - waitGroup.Add(1) pool.Schedule(func() { atomic.AddInt32(&counter, 1) - waitGroup.Done() }) } - waitGroup.Wait() + pool.Wait() assert.Equal(t, times, int(counter)) } +func TestTaskRunner_ScheduleImmediately(t *testing.T) { + cpus := runtime.NumCPU() + times := cpus * 2 + pool := NewTaskRunner(cpus) + + var counter int32 + for i := 0; i < times; i++ { + err := pool.ScheduleImmediately(func() { + atomic.AddInt32(&counter, 1) + time.Sleep(time.Millisecond * 100) + }) + if i < cpus { + assert.Nil(t, err) + } else { + assert.ErrorIs(t, err, ErrTaskRunnerBusy) + } + } + + pool.Wait() + + assert.Equal(t, cpus, int(counter)) +} + func BenchmarkRoutinePool(b *testing.B) { queue := NewTaskRunner(runtime.NumCPU()) for i := 0; i < b.N; i++ {