Skip to content

Commit a33c0cd

Browse files
authored
Merge pull request #110 from gburt/master
add Flush() to bypass waiting for the timer
2 parents f9e94c2 + 4fafcfa commit a33c0cd

File tree

2 files changed

+78
-13
lines changed

2 files changed

+78
-13
lines changed

dataloader.go

Lines changed: 25 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@ type Interface[K comparable, V any] interface {
2626
Clear(context.Context, K) Interface[K, V]
2727
ClearAll() Interface[K, V]
2828
Prime(ctx context.Context, key K, value V) Interface[K, V]
29+
Flush()
2930
}
3031

3132
// BatchFunc is a function, which when given a slice of keys (string), returns a slice of `results`.
@@ -298,20 +299,37 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] {
298299
l.count++
299300
// if we hit our limit, force the batch to start
300301
if l.count == l.batchCap {
301-
// end the batcher synchronously here because another call to Load
302+
// end/flush the batcher synchronously here because another call to Load
302303
// may concurrently happen and needs to go to a new batcher.
303-
l.curBatcher.end()
304-
// end the sleeper for the current batcher.
305-
// this is to stop the goroutine without waiting for the
306-
// sleeper timeout.
307-
close(l.endSleeper)
308-
l.reset()
304+
l.flush()
309305
}
310306
}
311307

312308
return thunk
313309
}
314310

311+
// flush dispatches whatever keys are collected in the current batch
// immediately, instead of waiting for the sleeper timeout.
//
// It must only be called while holding l.batchLock; both Load (on
// reaching batchCap) and the exported Flush satisfy this.
func (l *Loader[K, V]) flush() {
	// End the current batcher so its collected keys are executed now.
	l.curBatcher.end()

	// end the sleeper for the current batcher.
	// this is to stop the goroutine without waiting for the
	// sleeper timeout.
	close(l.endSleeper)
	l.reset()
}
322+
323+
// Flush will load the items in the current batch immediately without waiting for the timer.
324+
func (l *Loader[K, V]) Flush() {
325+
l.batchLock.Lock()
326+
defer l.batchLock.Unlock()
327+
if l.curBatcher == nil {
328+
return
329+
}
330+
l.flush()
331+
}
332+
315333
// LoadMany loads multiple keys, returning a thunk (type: ThunkMany) that will resolve the keys passed in.
316334
func (l *Loader[K, V]) LoadMany(originalContext context.Context, keys []K) ThunkMany[V] {
317335
ctx, finish := l.tracer.TraceLoadMany(originalContext, keys)

dataloader_test.go

Lines changed: 53 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,12 @@ import (
99
"strconv"
1010
"sync"
1111
"testing"
12+
"time"
1213
)
1314

14-
///////////////////////////////////////////////////
15-
// Tests
16-
///////////////////////////////////////////////////
15+
/*
16+
Tests
17+
*/
1718
func TestLoader(t *testing.T) {
1819
t.Run("test Load method", func(t *testing.T) {
1920
t.Parallel()
@@ -328,6 +329,7 @@ func TestLoader(t *testing.T) {
328329
t.Parallel()
329330
identityLoader, loadCalls := IDLoader[string](0)
330331
ctx := context.Background()
332+
start := time.Now()
331333
future1 := identityLoader.Load(ctx, "1")
332334
future2 := identityLoader.Load(ctx, "1")
333335

@@ -340,6 +342,12 @@ func TestLoader(t *testing.T) {
340342
t.Error(err.Error())
341343
}
342344

345+
// also check that it took the full timeout to return
346+
var duration = time.Since(start)
347+
if duration < 16*time.Millisecond {
348+
t.Errorf("took %v when expected it to take more than 16 ms because of wait", duration)
349+
}
350+
343351
calls := *loadCalls
344352
inner := []string{"1"}
345353
expected := [][]string{inner}
@@ -348,6 +356,45 @@ func TestLoader(t *testing.T) {
348356
}
349357
})
350358

359+
t.Run("doesn't wait for timeout if Flush() is called", func(t *testing.T) {
360+
t.Parallel()
361+
identityLoader, loadCalls := IDLoader[string](0)
362+
ctx := context.Background()
363+
start := time.Now()
364+
future1 := identityLoader.Load(ctx, "1")
365+
future2 := identityLoader.Load(ctx, "2")
366+
367+
// trigger them to be fetched immediately vs waiting for the 16 ms timer
368+
identityLoader.Flush()
369+
370+
_, err := future1()
371+
if err != nil {
372+
t.Error(err.Error())
373+
}
374+
_, err = future2()
375+
if err != nil {
376+
t.Error(err.Error())
377+
}
378+
379+
var duration = time.Since(start)
380+
if duration > 2*time.Millisecond {
381+
t.Errorf("took %v when expected it to take less than 2 ms b/c we called Flush()", duration)
382+
}
383+
384+
calls := *loadCalls
385+
inner := []string{"1", "2"}
386+
expected := [][]string{inner}
387+
if !reflect.DeepEqual(calls, expected) {
388+
t.Errorf("did not respect max batch size. Expected %#v, got %#v", expected, calls)
389+
}
390+
})
391+
392+
t.Run("Nothing for Flush() to do on empty loader with current batch", func(t *testing.T) {
393+
t.Parallel()
394+
identityLoader, _ := IDLoader[string](0)
395+
identityLoader.Flush()
396+
})
397+
351398
t.Run("allows primed cache", func(t *testing.T) {
352399
t.Parallel()
353400
identityLoader, loadCalls := IDLoader[string](0)
@@ -741,9 +788,9 @@ func FaultyLoader[K comparable]() (*Loader[K, K], *[][]K) {
741788
return loader, &loadCalls
742789
}
743790

744-
///////////////////////////////////////////////////
745-
// Benchmarks
746-
///////////////////////////////////////////////////
791+
/*
792+
Benchmarks
793+
*/
747794
var a = &Avg{}
748795

749796
func batchIdentity[K comparable](_ context.Context, keys []K) (results []*Result[K]) {

0 commit comments

Comments
 (0)