feat: support ScheduleImmediately in TaskRunner (#3896)

master
Kevin Wan 10 months ago committed by GitHub
parent 10f1d93e2a
commit 2ccef5bb4f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,13 +1,20 @@
package threading package threading
import ( import (
"errors"
"sync"
"github.com/zeromicro/go-zero/core/lang" "github.com/zeromicro/go-zero/core/lang"
"github.com/zeromicro/go-zero/core/rescue" "github.com/zeromicro/go-zero/core/rescue"
) )
// ErrTaskRunnerBusy is the error that indicates the runner is busy.
var ErrTaskRunnerBusy = errors.New("task runner is busy")
// A TaskRunner is used to control the concurrency of goroutines. // A TaskRunner is used to control the concurrency of goroutines.
type TaskRunner struct { type TaskRunner struct {
limitChan chan lang.PlaceholderType limitChan chan lang.PlaceholderType
waitGroup sync.WaitGroup
} }
// NewTaskRunner returns a TaskRunner. // NewTaskRunner returns a TaskRunner.
@ -19,13 +26,47 @@ func NewTaskRunner(concurrency int) *TaskRunner {
// Schedule schedules a task to run under concurrency control. // Schedule schedules a task to run under concurrency control.
func (rp *TaskRunner) Schedule(task func()) { func (rp *TaskRunner) Schedule(task func()) {
// Why we add waitGroup first, in case of race condition on starting a task and wait returns.
// For example, limitChan is full, and the task is scheduled to run, but the waitGroup is not added,
// then the wait returns, and the task is then scheduled to run, but caller thinks all tasks are done.
// the same reason for ScheduleImmediately.
rp.waitGroup.Add(1)
rp.limitChan <- lang.Placeholder rp.limitChan <- lang.Placeholder
go func() { go func() {
defer rescue.Recover(func() { defer rescue.Recover(func() {
<-rp.limitChan <-rp.limitChan
rp.waitGroup.Done()
}) })
task() task()
}() }()
} }
// ScheduleImmediately schedules a task to run immediately under concurrency control.
// It returns ErrTaskRunnerBusy if the runner is busy.
func (rp *TaskRunner) ScheduleImmediately(task func()) error {
// Why we add waitGroup first, check the comment in Schedule.
rp.waitGroup.Add(1)
select {
case rp.limitChan <- lang.Placeholder:
default:
rp.waitGroup.Done()
return ErrTaskRunnerBusy
}
go func() {
defer rescue.Recover(func() {
<-rp.limitChan
rp.waitGroup.Done()
})
task()
}()
return nil
}
// Wait waits all running tasks to be done.
func (rp *TaskRunner) Wait() {
rp.waitGroup.Wait()
}

@ -2,32 +2,52 @@ package threading
import ( import (
"runtime" "runtime"
"sync"
"sync/atomic" "sync/atomic"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
) )
func TestRoutinePool(t *testing.T) { func TestTaskRunner_Schedule(t *testing.T) {
times := 100 times := 100
pool := NewTaskRunner(runtime.NumCPU()) pool := NewTaskRunner(runtime.NumCPU())
var counter int32 var counter int32
var waitGroup sync.WaitGroup
for i := 0; i < times; i++ { for i := 0; i < times; i++ {
waitGroup.Add(1)
pool.Schedule(func() { pool.Schedule(func() {
atomic.AddInt32(&counter, 1) atomic.AddInt32(&counter, 1)
waitGroup.Done()
}) })
} }
waitGroup.Wait() pool.Wait()
assert.Equal(t, times, int(counter)) assert.Equal(t, times, int(counter))
} }
func TestTaskRunner_ScheduleImmediately(t *testing.T) {
cpus := runtime.NumCPU()
times := cpus * 2
pool := NewTaskRunner(cpus)
var counter int32
for i := 0; i < times; i++ {
err := pool.ScheduleImmediately(func() {
atomic.AddInt32(&counter, 1)
time.Sleep(time.Millisecond * 100)
})
if i < cpus {
assert.Nil(t, err)
} else {
assert.ErrorIs(t, err, ErrTaskRunnerBusy)
}
}
pool.Wait()
assert.Equal(t, cpus, int(counter))
}
func BenchmarkRoutinePool(b *testing.B) { func BenchmarkRoutinePool(b *testing.B) {
queue := NewTaskRunner(runtime.NumCPU()) queue := NewTaskRunner(runtime.NumCPU())
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {

Loading…
Cancel
Save