diff --git a/core/fx/fn.go b/core/fx/fn.go index 197bcfea..cc7c722e 100644 --- a/core/fx/fn.go +++ b/core/fx/fn.go @@ -68,6 +68,7 @@ func Range(source <-chan interface{}) Stream { } // Buffer buffers the items into a queue with size n. +// It can balance the producer and the consumer if their processing throughput don't match. func (p Stream) Buffer(n int) Stream { if n < 0 { n = 0 @@ -247,6 +248,32 @@ func (p Stream) Sort(less LessFunc) Stream { return Just(items...) } +// Split splits the elements into chunk with size up to n, +// might be less than n on tailing elements. +func (p Stream) Split(n int) Stream { + if n < 1 { + panic("n should be greater than 0") + } + + source := make(chan interface{}) + go func() { + var chunk []interface{} + for item := range p.source { + chunk = append(chunk, item) + if len(chunk) == n { + source <- chunk + chunk = nil + } + } + if chunk != nil { + source <- chunk + } + close(source) + }() + + return Range(source) +} + func (p Stream) Tail(n int64) Stream { if n < 1 { panic("n should be greater than 0") diff --git a/core/fx/fn_test.go b/core/fx/fn_test.go index 7033b2af..ed87193c 100644 --- a/core/fx/fn_test.go +++ b/core/fx/fn_test.go @@ -283,6 +283,22 @@ func TestSort(t *testing.T) { }) } +func TestSplit(t *testing.T) { + assert.Panics(t, func() { + Just(1, 2, 3, 4, 5, 6, 7, 8, 9, 10).Split(0).Done() + }) + var chunks [][]interface{} + Just(1, 2, 3, 4, 5, 6, 7, 8, 9, 10).Split(4).ForEach(func(item interface{}) { + chunk := item.([]interface{}) + chunks = append(chunks, chunk) + }) + assert.EqualValues(t, [][]interface{}{ + {1, 2, 3, 4}, + {5, 6, 7, 8}, + {9, 10}, + }, chunks) +} + func TestTail(t *testing.T) { var result int Just(1, 2, 3, 4).Tail(2).Reduce(func(pipe <-chan interface{}) (interface{}, error) { diff --git a/example/fx/fx_test.go b/example/fx/fx_test.go index 125e1e5e..37df046d 100644 --- a/example/fx/fx_test.go +++ b/example/fx/fx_test.go @@ -1,11 +1,19 @@ package main import ( + "fmt" "testing" "github.com/tal-tech/go-zero/core/fx" ) +func TestFxSplit(t *testing.T) { + fx.Just(1, 2, 3, 4, 5, 6, 7, 8, 9, 10).Split(4).ForEach(func(item interface{}) { + vals := item.([]interface{}) + fmt.Println(len(vals)) + }) +} + func BenchmarkFx(b *testing.B) { type Mixed struct { Name string