feat: prevent call bursts on oidc provider refresh
All checks were successful
Cadoles/bouncer/pipeline/head: This commit looks good

2025-03-18 15:51:25 +01:00
parent 59ecfa7b4e
commit 692523e54f
5 changed files with 205 additions and 115 deletions

View File

@@ -0,0 +1,63 @@
package syncx

import (
	"context"
	"sync"

	"forge.cadoles.com/cadoles/bouncer/internal/cache"
	"github.com/pkg/errors"
)

// RefreshFunc produces a fresh value for the given key, typically by calling
// an upstream service.
type RefreshFunc[K comparable, V any] func(ctx context.Context, key K) (V, error)

// CachedResource serves values from a cache and funnels cache misses through
// a refresh call guarded by a write lock, preventing bursts of refresh calls
// when many goroutines miss at the same time.
type CachedResource[K comparable, V any] struct {
	cache   cache.Cache[K, V]
	lock    sync.RWMutex
	refresh RefreshFunc[K, V]
}

// Clear empties the underlying cache.
func (r *CachedResource[K, V]) Clear() {
	r.cache.Clear()
}

// Get returns the value associated with key, refreshing it on cache miss.
// The second return value reports whether the value was freshly refreshed.
func (r *CachedResource[K, V]) Get(ctx context.Context, key K) (V, bool, error) {
	// Fast path: serve directly from the cache.
	value, exists := r.cache.Get(key)
	if exists {
		return value, false, nil
	}

	// Try to become the goroutine responsible for the refresh.
	locked := r.lock.TryLock()
	if !locked {
		// Another goroutine holds the lock: wait for read access and check
		// whether the value has landed in the cache in the meantime.
		r.lock.RLock()
		value, exists := r.cache.Get(key)
		if exists {
			r.lock.RUnlock()
			return value, false, nil
		}
		r.lock.RUnlock()
	}

	if !locked {
		r.lock.Lock()
	}
	defer r.lock.Unlock()

	// Refresh while holding the write lock so concurrent misses do not
	// translate into a burst of refresh calls.
	value, err := r.refresh(ctx, key)
	if err != nil {
		return *new(V), false, errors.WithStack(err)
	}

	r.cache.Set(key, value)

	return value, true, nil
}

// NewCachedResource wraps the given cache and refresh function.
func NewCachedResource[K comparable, V any](cache cache.Cache[K, V], refresh RefreshFunc[K, V]) *CachedResource[K, V] {
	return &CachedResource[K, V]{
		cache:   cache,
		refresh: refresh,
	}
}
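
For context, a minimal usage sketch of how this wrapper could sit in front of an OIDC provider lookup, which is what the commit message targets. The oidcProvider type, the fetchProviderMetadata helper, the 30-second TTL and the internal/syncx import path are assumptions made for the example (and, since these are internal packages, the snippet would have to live inside the bouncer module); only the cache construction mirrors the test file below.

// Hypothetical wiring sketch, not part of this commit: oidcProvider,
// fetchProviderMetadata and the syncx import path are placeholders.
package main

import (
	"context"
	"log"
	"time"

	"forge.cadoles.com/cadoles/bouncer/internal/cache/memory"
	"forge.cadoles.com/cadoles/bouncer/internal/cache/ttl"
	"forge.cadoles.com/cadoles/bouncer/internal/syncx"
)

type oidcProvider struct {
	Issuer string
}

// fetchProviderMetadata stands in for the real OIDC discovery/refresh call.
func fetchProviderMetadata(ctx context.Context, issuer string) (*oidcProvider, error) {
	return &oidcProvider{Issuer: issuer}, nil
}

func main() {
	// Same cache construction as in the test: a TTL cache composed of one
	// in-memory cache for values and one for expiry timestamps.
	providers := syncx.NewCachedResource(
		ttl.NewCache(
			memory.NewCache[string, *oidcProvider](),
			memory.NewCache[string, time.Time](),
			30*time.Second,
		),
		func(ctx context.Context, issuer string) (*oidcProvider, error) {
			// Concurrent misses are funneled through the write lock, so a
			// burst of requests does not become a burst of refreshes.
			return fetchProviderMetadata(ctx, issuer)
		},
	)

	provider, fresh, err := providers.Get(context.Background(), "https://idp.example.net")
	if err != nil {
		log.Fatalf("%+v", err)
	}

	log.Printf("issuer=%s fresh=%v", provider.Issuer, fresh)
}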

View File

@@ -0,0 +1,66 @@
package syncx

import (
	"context"
	"math"
	"sync"
	"testing"
	"time"

	"forge.cadoles.com/cadoles/bouncer/internal/cache/memory"
	"forge.cadoles.com/cadoles/bouncer/internal/cache/ttl"
	"github.com/pkg/errors"
)

func TestCachedResource(t *testing.T) {
	refreshCalls := 0

	cacheTTL := 1*time.Second + 500*time.Millisecond
	duration := 2 * time.Second

	// With a 1.5s TTL and a 2s test window, the refresh function is expected
	// to run ceil(2 / 1.5) = 2 times in total.
	expectedCalls := math.Ceil(float64(duration) / float64(cacheTTL))

	resource := NewCachedResource(
		ttl.NewCache(
			memory.NewCache[string, string](),
			memory.NewCache[string, time.Time](),
			cacheTTL,
		),
		func(ctx context.Context, key string) (string, error) {
			// refreshCalls is only incremented while the write lock is held,
			// so no extra synchronization is needed here.
			refreshCalls++
			return "bar", nil
		},
	)

	concurrents := 50
	key := "foo"

	var wg sync.WaitGroup
	wg.Add(concurrents)

	// Hammer Get from many goroutines for the whole test window.
	for i := range concurrents {
		go func(i int) {
			done := time.After(duration)
			defer wg.Done()

			for {
				select {
				case <-done:
					return
				default:
					value, fresh, err := resource.Get(context.Background(), key)
					if err != nil {
						t.Errorf("%+v", errors.WithStack(err))
					}

					t.Logf("resource retrieved for goroutine #%d: (%s, %s, %v)", i, key, value, fresh)
				}
			}
		}(i)
	}

	wg.Wait()

	if e, g := int(expectedCalls), refreshCalls; e != g {
		t.Errorf("refreshCalls: expected '%d', got '%d'", e, g)
	}
}
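
Note that the fresh flag returned by Get is true only for the call that actually executed the refresh; cache hits, including those found after waiting on the read lock, return false. Since the test drives 50 goroutines against shared state, it is presumably most useful when run with the race detector enabled (go test -race), which would flag any unsynchronized access to refreshCalls or the underlying caches.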