From 692523e54f73b40fa7e3eb851f83d949da017b3c Mon Sep 17 00:00:00 2001 From: William Petit Date: Tue, 18 Mar 2025 15:51:25 +0100 Subject: [PATCH] feat: prevent call bursts on oidc provider refresh --- internal/proxy/director/director.go | 93 ++++--------------- .../layer/authn/oidc/authenticator.go | 76 ++++++++++----- .../proxy/director/layer/authn/oidc/layer.go | 22 ++--- internal/syncx/cached_resource.go | 63 +++++++++++++ internal/syncx/cached_resource_test.go | 66 +++++++++++++ 5 files changed, 205 insertions(+), 115 deletions(-) create mode 100644 internal/syncx/cached_resource.go create mode 100644 internal/syncx/cached_resource_test.go diff --git a/internal/proxy/director/director.go b/internal/proxy/director/director.go index fe3cf5d..d94b93d 100644 --- a/internal/proxy/director/director.go +++ b/internal/proxy/director/director.go @@ -4,12 +4,11 @@ import ( "context" "net/http" "sort" - "sync" "forge.cadoles.com/Cadoles/go-proxy" "forge.cadoles.com/Cadoles/go-proxy/wildcard" - "forge.cadoles.com/cadoles/bouncer/internal/cache" "forge.cadoles.com/cadoles/bouncer/internal/store" + "forge.cadoles.com/cadoles/bouncer/internal/syncx" "github.com/getsentry/sentry-go" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" @@ -21,19 +20,18 @@ type Director struct { layerRepository store.LayerRepository layerRegistry *LayerRegistry - proxyCache cache.Cache[string, []*store.Proxy] - layerCache cache.Cache[string, []*store.Layer] - - proxyCacheLock sync.RWMutex - layerCacheLock sync.RWMutex + cachedProxies *syncx.CachedResource[string, []*store.Proxy] + cachedLayers *syncx.CachedResource[string, []*store.Layer] handleError HandleErrorFunc } +const proxiesCacheKey = "proxies" + func (d *Director) rewriteRequest(r *http.Request) (*http.Request, error) { ctx := r.Context() - proxies, err := d.getProxies(ctx) + proxies, _, err := d.cachedProxies.Get(ctx, proxiesCacheKey) if err != nil { return r, errors.WithStack(err) } @@ -58,7 +56,7 @@ func (d *Director) rewriteRequest(r *http.Request) (*http.Request, error) { metricProxyRequestsTotal.With(prometheus.Labels{metricLabelProxy: string(p.Name)}).Add(1) - proxyLayers, err := d.getLayers(proxyCtx, p.Name) + proxyLayers, _, err := d.cachedLayers.Get(proxyCtx, string(p.Name)) if err != nil { return r, errors.WithStack(err) } @@ -98,35 +96,7 @@ func (d *Director) rewriteRequest(r *http.Request) (*http.Request, error) { return r, nil } -const proxiesCacheKey = "proxies" - -func (d *Director) getProxies(ctx context.Context) ([]*store.Proxy, error) { - proxies, exists := d.proxyCache.Get(proxiesCacheKey) - if exists { - logger.Debug(ctx, "using cached proxies") - return proxies, nil - } - - locked := d.proxyCacheLock.TryLock() - if !locked { - d.proxyCacheLock.RLock() - - proxies, exists := d.proxyCache.Get(proxiesCacheKey) - if exists { - d.proxyCacheLock.RUnlock() - logger.Debug(ctx, "using cached proxies") - return proxies, nil - } - - d.proxyCacheLock.RUnlock() - } - - if !locked { - d.proxyCacheLock.Lock() - } - - defer d.proxyCacheLock.Unlock() - +func (d *Director) getProxies(ctx context.Context, key string) ([]*store.Proxy, error) { logger.Debug(ctx, "querying fresh proxies") headers, err := d.proxyRepository.QueryProxy(ctx, store.WithProxyQueryEnabled(true)) @@ -136,7 +106,7 @@ func (d *Director) getProxies(ctx context.Context) ([]*store.Proxy, error) { sort.Sort(store.ByProxyWeight(headers)) - proxies = make([]*store.Proxy, 0, len(headers)) + proxies := make([]*store.Proxy, 0, len(headers)) for _, h := range headers { if !h.Enabled { @@ -151,39 +121,11 @@ func (d *Director) getProxies(ctx context.Context) ([]*store.Proxy, error) { proxies = append(proxies, proxy) } - d.proxyCache.Set(proxiesCacheKey, proxies) - return proxies, nil } -func (d *Director) getLayers(ctx context.Context, proxyName store.ProxyName) ([]*store.Layer, error) { - cacheKey := "layers-" + string(proxyName) - - layers, exists := d.layerCache.Get(cacheKey) - if exists { - logger.Debug(ctx, "using cached layers") - return layers, nil - } - - locked := d.layerCacheLock.TryLock() - if !locked { - d.layerCacheLock.RLock() - - layers, exists := d.layerCache.Get(cacheKey) - if exists { - d.layerCacheLock.RUnlock() - logger.Debug(ctx, "using cached layers") - return layers, nil - } - - d.layerCacheLock.RUnlock() - } - - if !locked { - d.layerCacheLock.Lock() - } - - defer d.layerCacheLock.Unlock() +func (d *Director) getLayers(ctx context.Context, rawProxyName string) ([]*store.Layer, error) { + proxyName := store.ProxyName(rawProxyName) logger.Debug(ctx, "querying fresh layers") @@ -194,7 +136,7 @@ func (d *Director) getLayers(ctx context.Context, proxyName store.ProxyName) ([] sort.Sort(store.ByLayerWeight(headers)) - layers = make([]*store.Layer, 0, len(headers)) + layers := make([]*store.Layer, 0, len(headers)) for _, h := range headers { if !h.Enabled { @@ -209,8 +151,6 @@ func (d *Director) getLayers(ctx context.Context, proxyName store.ProxyName) ([] layers = append(layers, layer) } - d.layerCache.Set(cacheKey, layers) - return layers, nil } @@ -322,12 +262,15 @@ func New(proxyRepository store.ProxyRepository, layerRepository store.LayerRepos registry := NewLayerRegistry(opts.Layers...) - return &Director{ + director := &Director{ proxyRepository: proxyRepository, layerRepository: layerRepository, layerRegistry: registry, - proxyCache: opts.ProxyCache, - layerCache: opts.LayerCache, handleError: opts.HandleError, } + + director.cachedProxies = syncx.NewCachedResource(opts.ProxyCache, director.getProxies) + director.cachedLayers = syncx.NewCachedResource(opts.LayerCache, director.getLayers) + + return director } diff --git a/internal/proxy/director/layer/authn/oidc/authenticator.go b/internal/proxy/director/layer/authn/oidc/authenticator.go index abfc3a6..4fb306d 100644 --- a/internal/proxy/director/layer/authn/oidc/authenticator.go +++ b/internal/proxy/director/layer/authn/oidc/authenticator.go @@ -13,10 +13,12 @@ import ( "time" "forge.cadoles.com/Cadoles/go-proxy/wildcard" - "forge.cadoles.com/cadoles/bouncer/internal/cache" + "forge.cadoles.com/cadoles/bouncer/internal/cache/memory" + "forge.cadoles.com/cadoles/bouncer/internal/cache/ttl" "forge.cadoles.com/cadoles/bouncer/internal/proxy/director" "forge.cadoles.com/cadoles/bouncer/internal/proxy/director/layer/authn" "forge.cadoles.com/cadoles/bouncer/internal/store" + "forge.cadoles.com/cadoles/bouncer/internal/syncx" "github.com/coreos/go-oidc/v3/oidc" "github.com/gorilla/sessions" "github.com/pkg/errors" @@ -25,10 +27,10 @@ import ( ) type Authenticator struct { - store sessions.Store - httpTransport *http.Transport - httpClientTimeout time.Duration - oidcProviderCache cache.Cache[string, *oidc.Provider] + store sessions.Store + httpTransport *http.Transport + httpClientTimeout time.Duration + cachedOIDCProvider *syncx.CachedResource[string, *oidc.Provider] } func (a *Authenticator) PreAuthentication(w http.ResponseWriter, r *http.Request, layer *store.Layer) error { @@ -54,7 +56,7 @@ func (a *Authenticator) PreAuthentication(w http.ResponseWriter, r *http.Request return errors.WithStack(err) } - client, err := a.getClient(options, loginCallbackURL.String()) + client, err := a.getClient(ctx, options, loginCallbackURL.String()) if err != nil { return errors.WithStack(err) } @@ -160,7 +162,7 @@ func (a *Authenticator) Authenticate(w http.ResponseWriter, r *http.Request, lay return nil, errors.WithStack(err) } - client, err := a.getClient(options, loginCallbackURL.String()) + client, err := a.getClient(ctx, options, loginCallbackURL.String()) if err != nil { return nil, errors.WithStack(err) } @@ -362,9 +364,7 @@ func (a *Authenticator) templatize(rawTemplate string, proxyName store.ProxyName return raw.String(), nil } -func (a *Authenticator) getClient(options *LayerOptions, redirectURL string) (*Client, error) { - ctx := context.Background() - +func (a *Authenticator) getClient(ctx context.Context, options *LayerOptions, redirectURL string) (*Client, error) { transport := a.httpTransport.Clone() if options.OIDC.TLSInsecureSkipVerify { @@ -375,28 +375,24 @@ func (a *Authenticator) getClient(options *LayerOptions, redirectURL string) (*C transport.TLSClientConfig.InsecureSkipVerify = true } + if options.OIDC.SkipIssuerVerification { + ctx = oidc.InsecureIssuerURLContext(ctx, options.OIDC.IssuerURL) + } + httpClient := &http.Client{ Timeout: a.httpClientTimeout, Transport: transport, } - provider, exists := a.oidcProviderCache.Get(options.OIDC.IssuerURL) - if !exists { - var err error - ctx = oidc.ClientContext(ctx, httpClient) + ctx = oidc.ClientContext(ctx, httpClient) - if options.OIDC.SkipIssuerVerification { - ctx = oidc.InsecureIssuerURLContext(ctx, options.OIDC.IssuerURL) - } + if options.OIDC.SkipIssuerVerification { + ctx = oidc.InsecureIssuerURLContext(ctx, options.OIDC.IssuerURL) + } - logger.Debug(ctx, "refreshing oidc provider", logger.F("issuerURL", options.OIDC.IssuerURL)) - - provider, err = oidc.NewProvider(ctx, options.OIDC.IssuerURL) - if err != nil { - return nil, errors.Wrap(err, "could not create oidc provider") - } - - a.oidcProviderCache.Set(options.OIDC.IssuerURL, provider) + provider, _, err := a.cachedOIDCProvider.Get(ctx, options.OIDC.IssuerURL) + if err != nil { + return nil, errors.Wrap(err, "could not retrieve oidc provider") } client := NewClient( @@ -411,6 +407,17 @@ func (a *Authenticator) getClient(options *LayerOptions, redirectURL string) (*C return client, nil } +func (a *Authenticator) getOIDCProvider(ctx context.Context, issuerURL string) (*oidc.Provider, error) { + logger.Debug(ctx, "refreshing oidc provider", logger.F("issuerURL", issuerURL)) + + provider, err := oidc.NewProvider(ctx, issuerURL) + if err != nil { + return nil, errors.Wrap(err, "could not create oidc provider") + } + + return provider, nil +} + const defaultCookieNamePrefix = "_bouncer_authn_oidc" func (a *Authenticator) getCookieName(cookieName string, proxyName store.ProxyName, layerName store.LayerName) string { @@ -421,6 +428,25 @@ func (a *Authenticator) getCookieName(cookieName string, proxyName store.ProxyNa return strings.ToLower(fmt.Sprintf("%s_%s_%s", defaultCookieNamePrefix, proxyName, layerName)) } +func NewAuthenticator(httpTransport *http.Transport, clientTimeout time.Duration, store sessions.Store, oidcProviderCacheTimeout time.Duration) *Authenticator { + authenticator := &Authenticator{ + httpTransport: httpTransport, + httpClientTimeout: clientTimeout, + store: store, + } + + authenticator.cachedOIDCProvider = syncx.NewCachedResource( + ttl.NewCache( + memory.NewCache[string, *oidc.Provider](), + memory.NewCache[string, time.Time](), + oidcProviderCacheTimeout, + ), + authenticator.getOIDCProvider, + ) + + return authenticator +} + var ( _ authn.PreAuthentication = &Authenticator{} _ authn.Authenticator = &Authenticator{} diff --git a/internal/proxy/director/layer/authn/oidc/layer.go b/internal/proxy/director/layer/authn/oidc/layer.go index 617988f..33971b6 100644 --- a/internal/proxy/director/layer/authn/oidc/layer.go +++ b/internal/proxy/director/layer/authn/oidc/layer.go @@ -1,13 +1,8 @@ package oidc import ( - "time" - - "forge.cadoles.com/cadoles/bouncer/internal/cache/memory" - "forge.cadoles.com/cadoles/bouncer/internal/cache/ttl" "forge.cadoles.com/cadoles/bouncer/internal/proxy/director/layer/authn" "forge.cadoles.com/cadoles/bouncer/internal/store" - "github.com/coreos/go-oidc/v3/oidc" "github.com/gorilla/sessions" ) @@ -15,14 +10,11 @@ const LayerType store.LayerType = "authn-oidc" func NewLayer(store sessions.Store, funcs ...OptionFunc) *authn.Layer { opts := NewOptions(funcs...) - return authn.NewLayer(LayerType, &Authenticator{ - httpTransport: opts.HTTPTransport, - httpClientTimeout: opts.HTTPClientTimeout, - store: store, - oidcProviderCache: ttl.NewCache( - memory.NewCache[string, *oidc.Provider](), - memory.NewCache[string, time.Time](), - opts.OIDCProviderCacheTimeout, - ), - }, opts.AuthnOptions...) + authenticator := NewAuthenticator( + opts.HTTPTransport, + opts.HTTPClientTimeout, + store, + opts.OIDCProviderCacheTimeout, + ) + return authn.NewLayer(LayerType, authenticator, opts.AuthnOptions...) } diff --git a/internal/syncx/cached_resource.go b/internal/syncx/cached_resource.go new file mode 100644 index 0000000..1f1dcd4 --- /dev/null +++ b/internal/syncx/cached_resource.go @@ -0,0 +1,63 @@ +package syncx + +import ( + "context" + "sync" + + "forge.cadoles.com/cadoles/bouncer/internal/cache" + "github.com/pkg/errors" +) + +type RefreshFunc[K comparable, V any] func(ctx context.Context, key K) (V, error) + +type CachedResource[K comparable, V any] struct { + cache cache.Cache[K, V] + lock sync.RWMutex + refresh RefreshFunc[K, V] +} + +func (r *CachedResource[K, V]) Clear() { + r.cache.Clear() +} + +func (r *CachedResource[K, V]) Get(ctx context.Context, key K) (V, bool, error) { + value, exists := r.cache.Get(key) + if exists { + return value, false, nil + } + + locked := r.lock.TryLock() + if !locked { + r.lock.RLock() + + value, exists := r.cache.Get(key) + if exists { + r.lock.RUnlock() + return value, false, nil + } + + r.lock.RUnlock() + } + + if !locked { + r.lock.Lock() + } + + defer r.lock.Unlock() + + value, err := r.refresh(ctx, key) + if err != nil { + return *new(V), false, errors.WithStack(err) + } + + r.cache.Set(key, value) + + return value, true, nil +} + +func NewCachedResource[K comparable, V any](cache cache.Cache[K, V], refresh RefreshFunc[K, V]) *CachedResource[K, V] { + return &CachedResource[K, V]{ + cache: cache, + refresh: refresh, + } +} diff --git a/internal/syncx/cached_resource_test.go b/internal/syncx/cached_resource_test.go new file mode 100644 index 0000000..731b3a4 --- /dev/null +++ b/internal/syncx/cached_resource_test.go @@ -0,0 +1,66 @@ +package syncx + +import ( + "context" + "math" + "sync" + "testing" + "time" + + "forge.cadoles.com/cadoles/bouncer/internal/cache/memory" + "forge.cadoles.com/cadoles/bouncer/internal/cache/ttl" + "github.com/pkg/errors" +) + +func TestCachedResource(t *testing.T) { + refreshCalls := 0 + cacheTTL := 1*time.Second + 500*time.Millisecond + duration := 2 * time.Second + + expectedCalls := math.Ceil(float64(duration) / float64(cacheTTL)) + + resource := NewCachedResource( + ttl.NewCache( + memory.NewCache[string, string](), + memory.NewCache[string, time.Time](), + cacheTTL, + ), + func(ctx context.Context, key string) (string, error) { + refreshCalls++ + return "bar", nil + }, + ) + + concurrents := 50 + key := "foo" + + var wg sync.WaitGroup + + wg.Add(concurrents) + + for i := range concurrents { + go func(i int) { + done := time.After(duration) + + defer wg.Done() + for { + select { + case <-done: + return + default: + value, fresh, err := resource.Get(context.Background(), key) + if err != nil { + t.Errorf("%+v", errors.WithStack(err)) + } + t.Logf("resource retrieved for goroutine #%d: (%s, %s, %v)", i, key, value, fresh) + } + } + }(i) + } + + wg.Wait() + + if e, g := int(expectedCalls), refreshCalls; e != g { + t.Errorf("refreshCalls: expected '%d', got '%d'", e, g) + } +}