From 006f13bc7b880bd95a04d8e171647ad55c86e48d Mon Sep 17 00:00:00 2001 From: William Petit Date: Wed, 5 Apr 2023 15:19:22 +0200 Subject: [PATCH] feat(module,auth): dynamically define authentication cookie domain --- pkg/module/auth/http/local_handler.go | 42 ++++++++++++++++++--------- pkg/module/auth/http/options.go | 31 +++++++++++++------- 2 files changed, 49 insertions(+), 24 deletions(-) diff --git a/pkg/module/auth/http/local_handler.go b/pkg/module/auth/http/local_handler.go index 3b86c66..dd482f2 100644 --- a/pkg/module/auth/http/local_handler.go +++ b/pkg/module/auth/http/local_handler.go @@ -30,12 +30,12 @@ func init() { } type LocalHandler struct { - router chi.Router - algo jwa.KeyAlgorithm - key jwk.Key - cookieDomain string - cookieDuration time.Duration - accounts map[string]LocalAccount + router chi.Router + algo jwa.KeyAlgorithm + key jwk.Key + getCookieDomain GetCookieDomainFunc + cookieDuration time.Duration + accounts map[string]LocalAccount } func (h *LocalHandler) initRouter(prefix string) { @@ -118,10 +118,18 @@ func (h *LocalHandler) handleForm(w http.ResponseWriter, r *http.Request) { return } + cookieDomain, err := h.getCookieDomain(r) + if err != nil { + logger.Error(ctx, "could not retrieve cookie domain", logger.E(errors.WithStack(err))) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + + return + } + cookie := http.Cookie{ Name: auth.CookieName, Value: string(token), - Domain: h.cookieDomain, + Domain: cookieDomain, HttpOnly: false, Expires: time.Now().Add(h.cookieDuration), Path: "/", @@ -133,12 +141,20 @@ func (h *LocalHandler) handleForm(w http.ResponseWriter, r *http.Request) { } func (h *LocalHandler) handleLogout(w http.ResponseWriter, r *http.Request) { + cookieDomain, err := h.getCookieDomain(r) + if err != nil { + logger.Error(r.Context(), "could not retrieve cookie domain", logger.E(errors.WithStack(err))) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + + return + } + http.SetCookie(w, &http.Cookie{ Name: auth.CookieName, Value: "", HttpOnly: false, Expires: time.Unix(0, 0), - Domain: h.cookieDomain, + Domain: cookieDomain, Path: "/", }) @@ -170,11 +186,11 @@ func NewLocalHandler(algo jwa.KeyAlgorithm, key jwk.Key, funcs ...LocalHandlerOp } handler := &LocalHandler{ - algo: algo, - key: key, - accounts: toAccountsMap(opts.Accounts), - cookieDomain: opts.CookieDomain, - cookieDuration: opts.CookieDuration, + algo: algo, + key: key, + accounts: toAccountsMap(opts.Accounts), + getCookieDomain: opts.GetCookieDomain, + cookieDuration: opts.CookieDuration, } handler.initRouter(opts.RoutePrefix) diff --git a/pkg/module/auth/http/options.go b/pkg/module/auth/http/options.go index 84c6cf7..4f3fb4f 100644 --- a/pkg/module/auth/http/options.go +++ b/pkg/module/auth/http/options.go @@ -1,22 +1,31 @@ package http -import "time" +import ( + "net/http" + "time" +) + +type GetCookieDomainFunc func(r *http.Request) (string, error) + +func defaultGetCookieDomain(r *http.Request) (string, error) { + return "", nil +} type LocalHandlerOptions struct { - RoutePrefix string - Accounts []LocalAccount - CookieDomain string - CookieDuration time.Duration + RoutePrefix string + Accounts []LocalAccount + GetCookieDomain GetCookieDomainFunc + CookieDuration time.Duration } type LocalHandlerOptionFunc func(*LocalHandlerOptions) func defaultLocalHandlerOptions() *LocalHandlerOptions { return &LocalHandlerOptions{ - RoutePrefix: "", - Accounts: make([]LocalAccount, 0), - CookieDomain: "", - CookieDuration: 24 * time.Hour, + RoutePrefix: "", + Accounts: make([]LocalAccount, 0), + GetCookieDomain: defaultGetCookieDomain, + CookieDuration: 24 * time.Hour, } } @@ -32,9 +41,9 @@ func WithRoutePrefix(prefix string) LocalHandlerOptionFunc { } } -func WithCookieOptions(domain string, duration time.Duration) LocalHandlerOptionFunc { +func WithCookieOptions(getCookieDomain GetCookieDomainFunc, duration time.Duration) LocalHandlerOptionFunc { return func(opts *LocalHandlerOptions) { - opts.CookieDomain = domain + opts.GetCookieDomain = getCookieDomain opts.CookieDuration = duration } }