Skip to content

Commit e17a19f

Browse files
committed
Add WithSplitFunc option (WithTimeout now need generic, damn)
1 parent 91bfaf2 commit e17a19f

File tree

2 files changed

+37
-11
lines changed

2 files changed

+37
-11
lines changed

batchan.go

Lines changed: 21 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,29 +2,40 @@ package batchan
22

33
import "time"
44

5-
type option func(*config)
5+
type option[T any] func(*config[T])
66

7-
type config struct {
7+
type config[T any] struct {
88
timeout time.Duration
99
hasTimeout bool
10+
splitFunc func(T, T) bool
1011
}
1112

12-
func WithTimeout(timeout time.Duration) option {
13-
return func(cfg *config) {
13+
func WithTimeout[T any](timeout time.Duration) option[T] {
14+
return func(cfg *config[T]) {
1415
cfg.timeout = timeout
1516
cfg.hasTimeout = true
1617
}
1718
}
1819

20+
func WithSplitFunc[T any](splitFunc func(T, T) bool) option[T] {
21+
return func(cfg *config[T]) {
22+
cfg.splitFunc = splitFunc
23+
}
24+
}
25+
1926
func timerOrNil(t *time.Timer) <-chan time.Time {
2027
if t != nil {
2128
return t.C
2229
}
2330
return nil
2431
}
2532

26-
func New[T any](in <-chan T, size int, opts ...option) <-chan []T {
27-
cfg := &config{}
33+
func noSplitFunc[T any](t1, t2 T) bool { return false }
34+
35+
func New[T any](in <-chan T, size int, opts ...option[T]) <-chan []T {
36+
cfg := &config[T]{
37+
splitFunc: noSplitFunc[T],
38+
}
2839

2940
for _, opt := range opts {
3041
opt(cfg)
@@ -63,11 +74,14 @@ func New[T any](in <-chan T, size int, opts ...option) <-chan []T {
6374
return
6475
}
6576

77+
if len(currentBatch) > 0 && cfg.splitFunc(currentBatch[len(currentBatch)-1], t) {
78+
flush()
79+
}
80+
6681
currentBatch = append(currentBatch, t)
6782
if len(currentBatch) >= size {
6883
flush()
6984
}
70-
7185
case <-timerOrNil(timer):
7286
flush()
7387
}

batchan_test.go

Lines changed: 16 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,7 @@ func collectBatches[T any](out <-chan []T) [][]T {
3030

3131
func TestBatchingFlushOnTimeout(t *testing.T) {
3232
in := sendIntsToChan([]int{1, 2}, 150*time.Millisecond) // delay > timeout
33-
out := batchan.New(in, 5, batchan.WithTimeout(100*time.Millisecond))
33+
out := batchan.New(in, 5, batchan.WithTimeout[int](100*time.Millisecond))
3434

3535
got := collectBatches(out)
3636

@@ -43,7 +43,7 @@ func TestBatchingFlushOnTimeout(t *testing.T) {
4343

4444
func TestBatchingFlushOnSizeOrTimeout(t *testing.T) {
4545
in := sendIntsToChan([]int{1, 2, 3, 4}, 50*time.Millisecond)
46-
out := batchan.New(in, 2, batchan.WithTimeout(200*time.Millisecond))
46+
out := batchan.New(in, 2, batchan.WithTimeout[int](200*time.Millisecond))
4747

4848
got := collectBatches(out)
4949
expected := [][]int{{1, 2}, {3, 4}}
@@ -55,7 +55,7 @@ func TestBatchingFlushOnSizeOrTimeout(t *testing.T) {
5555

5656
func TestFlushTimeoutMultiple(t *testing.T) {
5757
in := sendIntsToChan([]int{1, 2, 3}, 300*time.Millisecond)
58-
out := batchan.New(in, 10, batchan.WithTimeout(200*time.Millisecond)) // small timeout, large batch size
58+
out := batchan.New(in, 10, batchan.WithTimeout[int](200*time.Millisecond)) // small timeout, large batch size
5959

6060
got := collectBatches(out)
6161
expected := [][]int{{1}, {2}, {3}}
@@ -77,7 +77,7 @@ func TestTimeoutResetsAfterFlush(t *testing.T) {
7777
in <- 3
7878
}()
7979

80-
out := batchan.New(in, 2, batchan.WithTimeout(100*time.Millisecond))
80+
out := batchan.New(in, 2, batchan.WithTimeout[int](100*time.Millisecond))
8181

8282
got := collectBatches(out)
8383
expected := [][]int{{1}, {2}, {3}}
@@ -146,6 +146,18 @@ func TestBatchSizeLargerThanInput(t *testing.T) {
146146
}
147147
}
148148

149+
func TestSplitFunc(t *testing.T) {
150+
in := sendIntsToChan([]int{1, 2, 3, 5, 6}, time.Microsecond)
151+
out := batchan.New(in, 5, batchan.WithSplitFunc(func(i1, i2 int) bool { return i2-i1 > 1 }))
152+
153+
expected := [][]int{{1, 2, 3}, {5, 6}}
154+
got := collectBatches(out)
155+
156+
if !reflect.DeepEqual(got, expected) {
157+
t.Errorf("Expected %v, got %v", expected, got)
158+
}
159+
}
160+
149161
// Optional: Test that the output channel closes properly
150162
func TestOutputChannelClosure(t *testing.T) {
151163
in := sendIntsToChan([]int{1, 2, 3}, time.Microsecond)

0 commit comments

Comments
 (0)