package oidc import ( "context" "fmt" "net/http" "net/url" "forge.cadoles.com/cadoles/bouncer/internal/proxy/director" "forge.cadoles.com/cadoles/bouncer/internal/proxy/director/layer/authn" "forge.cadoles.com/cadoles/bouncer/internal/store" "github.com/coreos/go-oidc/v3/oidc" "github.com/gorilla/sessions" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" "gitlab.com/wpetit/goweb/logger" ) type Authenticator struct { store sessions.Store } func (a *Authenticator) PreAuthentication(w http.ResponseWriter, r *http.Request, layer *store.Layer) error { ctx := r.Context() originalURL, err := director.OriginalURL(ctx) if err != nil { return errors.WithStack(err) } options, err := fromStoreOptions(layer.Options) if err != nil { return errors.WithStack(err) } sess, err := a.store.Get(r, a.getCookieName(options.Cookie.Name, layer.Name)) if err != nil { logger.Error(ctx, "could not retrieve session", logger.E(errors.WithStack(err))) } redirectURL := a.getRedirectURL(layer.Proxy, layer.Name, originalURL, options) logoutURL := a.getLogoutURL(layer.Proxy, layer.Name, originalURL, options) client, err := a.getClient(options, redirectURL.String()) if err != nil { return errors.WithStack(err) } switch r.URL.Path { case redirectURL.Path: if err := client.HandleCallback(w, r, sess); err != nil { return errors.WithStack(err) } metricLoginSuccessesTotal.With(prometheus.Labels{ metricLabelLayer: string(layer.Name), metricLabelProxy: string(layer.Proxy), }).Add(1) case logoutURL.Path: postLogoutRedirectURL := options.OIDC.PostLogoutRedirectURL if options.OIDC.PostLogoutRedirectURL == "" { postLogoutRedirectURL = originalURL.Scheme + "://" + originalURL.Host } if err := client.HandleLogout(w, r, sess, postLogoutRedirectURL); err != nil { return errors.WithStack(err) } metricLogoutsTotal.With(prometheus.Labels{ metricLabelLayer: string(layer.Name), metricLabelProxy: string(layer.Proxy), }).Add(1) } return nil } // Authenticate implements authn.Authenticator. func (a *Authenticator) Authenticate(w http.ResponseWriter, r *http.Request, layer *store.Layer) (*authn.User, error) { ctx := r.Context() originalURL, err := director.OriginalURL(ctx) if err != nil { return nil, errors.WithStack(err) } options, err := fromStoreOptions(layer.Options) if err != nil { return nil, errors.WithStack(err) } sess, err := a.store.Get(r, a.getCookieName(options.Cookie.Name, layer.Name)) if err != nil { return nil, errors.WithStack(err) } defer func() { if err := sess.Save(r, w); err != nil { logger.Error(ctx, "could not save session", logger.E(errors.WithStack(err))) } }() sess.Options.Domain = options.Cookie.Domain sess.Options.HttpOnly = options.Cookie.HTTPOnly sess.Options.MaxAge = int(options.Cookie.MaxAge.Seconds()) sess.Options.Path = options.Cookie.Path switch options.Cookie.SameSite { case "lax": sess.Options.SameSite = http.SameSiteLaxMode case "strict": sess.Options.SameSite = http.SameSiteStrictMode case "none": sess.Options.SameSite = http.SameSiteNoneMode default: sess.Options.SameSite = http.SameSiteDefaultMode } redirectURL := a.getRedirectURL(layer.Proxy, layer.Name, originalURL, options) client, err := a.getClient(options, redirectURL.String()) if err != nil { return nil, errors.WithStack(err) } idToken, err := client.Authenticate(w, r, sess) if err != nil { if errors.Is(err, ErrLoginRequired) { metricLoginRequestsTotal.With(prometheus.Labels{ metricLabelLayer: string(layer.Name), metricLabelProxy: string(layer.Proxy), }).Add(1) return nil, errors.WithStack(authn.ErrSkipRequest) } return nil, errors.WithStack(err) } user, err := a.toUser(idToken, layer.Proxy, layer.Name, originalURL, options, sess) if err != nil { return nil, errors.WithStack(err) } return user, nil } type claims struct { Issuer string `json:"iss"` Subject string `json:"sub"` Expiration int64 `json:"exp"` IssuedAt int64 `json:"iat"` AuthTime int64 `json:"auth_time"` Nonce string `json:"nonce"` ACR string `json:"acr"` AMR string `json:"amr"` AZP string `json:"amp"` Others map[string]any `json:"-"` } func (c claims) AsAttrs() map[string]any { attrs := make(map[string]any) for key, val := range c.Others { if val != nil { attrs["claim_"+key] = val } } attrs["claim_iss"] = c.Issuer attrs["claim_sub"] = c.Subject attrs["claim_exp"] = c.Expiration attrs["claim_iat"] = c.IssuedAt if c.AuthTime != 0 { attrs["claim_auth_time"] = c.AuthTime } if c.Nonce != "" { attrs["claim_nonce"] = c.Nonce } if c.ACR != "" { attrs["claim_arc"] = c.ACR } if c.AMR != "" { attrs["claim_amr"] = c.AMR } if c.AZP != "" { attrs["claim_azp"] = c.AZP } return attrs } func (a *Authenticator) toUser(idToken *oidc.IDToken, proxyName store.ProxyName, layerName store.LayerName, originalURL *url.URL, options *LayerOptions, sess *sessions.Session) (*authn.User, error) { var claims claims if err := idToken.Claims(&claims); err != nil { return nil, errors.WithStack(err) } if err := idToken.Claims(&claims.Others); err != nil { return nil, errors.WithStack(err) } attrs := claims.AsAttrs() logoutURL := a.getLogoutURL(proxyName, layerName, originalURL, options) attrs["logout_url"] = logoutURL.String() if accessToken, exists := sess.Values[sessionKeyAccessToken]; exists && accessToken != nil { attrs["access_token"] = accessToken } if refreshToken, exists := sess.Values[sessionKeyRefreshToken]; exists && refreshToken != nil { attrs["refresh_token"] = refreshToken } if tokenExpiry, exists := sess.Values[sessionKeyTokenExpiry]; exists && tokenExpiry != nil { attrs["token_expiry"] = tokenExpiry } user := authn.NewUser(idToken.Subject, attrs) return user, nil } func (a *Authenticator) getRedirectURL(proxyName store.ProxyName, layerName store.LayerName, u *url.URL, options *LayerOptions) *url.URL { return &url.URL{ Scheme: u.Scheme, Host: u.Host, Path: fmt.Sprintf(options.OIDC.LoginCallbackPath, fmt.Sprintf("%s/%s", proxyName, layerName)), } } func (a *Authenticator) getLogoutURL(proxyName store.ProxyName, layerName store.LayerName, u *url.URL, options *LayerOptions) *url.URL { return &url.URL{ Scheme: u.Scheme, Host: u.Host, Path: fmt.Sprintf(options.OIDC.LogoutPath, fmt.Sprintf("%s/%s", proxyName, layerName)), } } func (a *Authenticator) getClient(options *LayerOptions, redirectURL string) (*Client, error) { ctx := context.Background() if options.OIDC.SkipIssuerVerification { ctx = oidc.InsecureIssuerURLContext(ctx, options.OIDC.IssuerURL) } provider, err := oidc.NewProvider(ctx, options.OIDC.IssuerURL) if err != nil { return nil, errors.Wrap(err, "could not create oidc provider") } client := NewClient( WithCredentials(options.OIDC.ClientID, options.OIDC.ClientSecret), WithProvider(provider), WithRedirectURL(redirectURL), WithScopes(options.OIDC.Scopes...), WithAuthParams(options.OIDC.AuthParams), ) return client, nil } func (a *Authenticator) getCookieName(cookieName string, layerName store.LayerName) string { return fmt.Sprintf("%s_%s", cookieName, layerName) } var ( _ authn.PreAuthentication = &Authenticator{} _ authn.Authenticator = &Authenticator{} )