Skip to content

Commit f9e94c2

Browse files
authored
Merge pull request #111 from uphold-forks/feature/expose-panic-error-wrapper
Support cache skipping for `Load()` calls that throw `SkipCacheError`
2 parents fa2606f + bf5efb6 commit f9e94c2

File tree

2 files changed

+83
-1
lines changed

2 files changed

+83
-1
lines changed

dataloader.go

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,24 @@ func (p *PanicErrorWrapper) Error() string {
6060
return p.panicError.Error()
6161
}
6262

63+
// SkipCacheError wraps the error interface.
64+
// The cache should not store SkipCacheErrors.
65+
type SkipCacheError struct {
66+
err error
67+
}
68+
69+
func (s *SkipCacheError) Error() string {
70+
return s.err.Error()
71+
}
72+
73+
func (s *SkipCacheError) Unwrap() error {
74+
return s.err
75+
}
76+
77+
func NewSkipCacheError(err error) *SkipCacheError {
78+
return &SkipCacheError{err: err}
79+
}
80+
6381
// Loader implements the dataloader.Interface.
6482
type Loader[K comparable, V any] struct {
6583
// the batch function to be used by this loader
@@ -249,7 +267,8 @@ func (l *Loader[K, V]) Load(originalContext context.Context, key K) Thunk[V] {
249267
result.mu.RLock()
250268
defer result.mu.RUnlock()
251269
var ev *PanicErrorWrapper
252-
if result.value.Error != nil && errors.As(result.value.Error, &ev) {
270+
var es *SkipCacheError
271+
if result.value.Error != nil && (errors.As(result.value.Error, &ev) || errors.As(result.value.Error, &es)){
253272
l.Clear(ctx, key)
254273
}
255274
return result.value.Data, result.value.Error

dataloader_test.go

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,45 @@ func TestLoader(t *testing.T) {
7979
}
8080
})
8181

82+
t.Run("test Load Method not caching results with errors of type SkipCacheError", func(t *testing.T) {
83+
t.Parallel()
84+
skipCacheLoader, loadCalls := SkipCacheErrorLoader(3, "1")
85+
ctx := context.Background()
86+
futures1 := skipCacheLoader.LoadMany(ctx, []string{"1", "2", "3"})
87+
_, errs1 := futures1()
88+
var errCount int = 0
89+
var nilCount int = 0
90+
for _, err := range errs1 {
91+
if err == nil {
92+
nilCount++
93+
} else {
94+
errCount++
95+
}
96+
}
97+
if errCount != 1 {
98+
t.Error("Expected an error on only key \"1\"")
99+
}
100+
101+
if nilCount != 2 {
102+
t.Error("Expected the other errors to be nil")
103+
}
104+
105+
futures2 := skipCacheLoader.LoadMany(ctx, []string{"2", "3", "1"})
106+
_, errs2 := futures2()
107+
// There should be no errors in the second batch, as the only key that was not cached
108+
// this time around will not throw an error
109+
if errs2 != nil {
110+
t.Error("Expected LoadMany() to return nil error slice when no errors occurred")
111+
}
112+
113+
calls := (*loadCalls)[1]
114+
expected := []string{"1"}
115+
116+
if !reflect.DeepEqual(calls, expected) {
117+
t.Errorf("Expected load calls %#v, got %#v", expected, calls)
118+
}
119+
})
120+
82121
t.Run("test Load Method Panic Safety in multiple keys", func(t *testing.T) {
83122
t.Parallel()
84123
defer func() {
@@ -622,6 +661,30 @@ func ErrorCacheLoader[K comparable](max int) (*Loader[K, K], *[][]K) {
622661
return errorCacheLoader, &loadCalls
623662
}
624663

664+
func SkipCacheErrorLoader[K comparable](max int, onceErrorKey K) (*Loader[K, K], *[][]K) {
665+
var mu sync.Mutex
666+
var loadCalls [][]K
667+
errorThrown := false
668+
skipCacheErrorLoader := NewBatchedLoader(func(_ context.Context, keys []K) []*Result[K] {
669+
var results []*Result[K]
670+
mu.Lock()
671+
loadCalls = append(loadCalls, keys)
672+
mu.Unlock()
673+
// return a non cacheable error for the first occurence of onceErrorKey
674+
for _, k := range keys {
675+
if !errorThrown && k == onceErrorKey {
676+
results = append(results, &Result[K]{k, NewSkipCacheError(fmt.Errorf("non cacheable error"))})
677+
errorThrown = true
678+
} else {
679+
results = append(results, &Result[K]{k, nil})
680+
}
681+
}
682+
683+
return results
684+
}, WithBatchCapacity[K, K](max))
685+
return skipCacheErrorLoader, &loadCalls
686+
}
687+
625688
func BadLoader[K comparable](max int) (*Loader[K, K], *[][]K) {
626689
var mu sync.Mutex
627690
var loadCalls [][]K

0 commit comments

Comments
 (0)