package oidc import ( "bytes" "context" "crypto/tls" "fmt" "net/http" "net/url" "slices" "strings" "text/template" "time" "forge.cadoles.com/Cadoles/go-proxy/wildcard" "forge.cadoles.com/cadoles/bouncer/internal/proxy/director" "forge.cadoles.com/cadoles/bouncer/internal/proxy/director/layer/authn" "forge.cadoles.com/cadoles/bouncer/internal/store" "github.com/coreos/go-oidc/v3/oidc" "github.com/gorilla/sessions" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" "gitlab.com/wpetit/goweb/logger" ) type Authenticator struct { store sessions.Store httpTransport *http.Transport httpClientTimeout time.Duration } func (a *Authenticator) PreAuthentication(w http.ResponseWriter, r *http.Request, layer *store.Layer) error { ctx := r.Context() originalURL, err := director.OriginalURL(ctx) if err != nil { return errors.WithStack(err) } options, err := fromStoreOptions(layer.Options) if err != nil { return errors.WithStack(err) } sess, err := a.store.Get(r, a.getCookieName(options.Cookie.Name, layer.Proxy, layer.Name)) if err != nil { logger.Error(ctx, "could not retrieve session", logger.CapturedE(errors.WithStack(err))) } loginCallbackURL, err := a.getLoginCallbackURL(originalURL, layer.Proxy, layer.Name, options) if err != nil { return errors.WithStack(err) } client, err := a.getClient(options, loginCallbackURL.String()) if err != nil { return errors.WithStack(err) } loginCallbackPathPattern, err := a.templatize(options.OIDC.MatchLoginCallbackPath, layer.Proxy, layer.Name) if err != nil { return errors.WithStack(err) } logoutPathPattern, err := a.templatize(options.OIDC.MatchLogoutPath, layer.Proxy, layer.Name) if err != nil { return errors.WithStack(err) } logger.Debug(ctx, "checking url", logger.F("loginCallbackPathPattern", loginCallbackPathPattern), logger.F("logoutPathPattern", logoutPathPattern)) switch { case wildcard.Match(originalURL.Path, loginCallbackPathPattern): if err := client.HandleCallback(w, r, sess); err != nil { return errors.WithStack(err) } metricLoginSuccessesTotal.With(prometheus.Labels{ metricLabelLayer: string(layer.Name), metricLabelProxy: string(layer.Proxy), }).Add(1) case wildcard.Match(originalURL.Path, logoutPathPattern): postLogoutRedirectURL := r.URL.Query().Get("redirect") if postLogoutRedirectURL != "" { isAuthorized := slices.Contains(options.OIDC.PostLogoutRedirectURLs, postLogoutRedirectURL) if !isAuthorized { director.HandleError(ctx, w, r, http.StatusBadRequest, errors.New("unauthorized post-logout redirect")) return errors.WithStack(authn.ErrSkipRequest) } } if postLogoutRedirectURL == "" { if options.OIDC.PublicBaseURL != "" { postLogoutRedirectURL = options.OIDC.PublicBaseURL } else { postLogoutRedirectURL = originalURL.Scheme + "://" + originalURL.Host } } if err := client.HandleLogout(w, r, sess, postLogoutRedirectURL); err != nil { return errors.WithStack(err) } metricLogoutsTotal.With(prometheus.Labels{ metricLabelLayer: string(layer.Name), metricLabelProxy: string(layer.Proxy), }).Add(1) } return nil } // Authenticate implements authn.Authenticator. func (a *Authenticator) Authenticate(w http.ResponseWriter, r *http.Request, layer *store.Layer) (*authn.User, error) { ctx := r.Context() options, err := fromStoreOptions(layer.Options) if err != nil { return nil, errors.WithStack(err) } sess, err := a.store.Get(r, a.getCookieName(options.Cookie.Name, layer.Proxy, layer.Name)) if err != nil { return nil, errors.WithStack(err) } defer func() { if err := sess.Save(r, w); err != nil { logger.Error(ctx, "could not save session", logger.CapturedE(errors.WithStack(err))) } }() sess.Options.Domain = options.Cookie.Domain sess.Options.HttpOnly = options.Cookie.HTTPOnly sess.Options.MaxAge = int(options.Cookie.MaxAge.Seconds()) sess.Options.Path = options.Cookie.Path switch options.Cookie.SameSite { case "lax": sess.Options.SameSite = http.SameSiteLaxMode case "strict": sess.Options.SameSite = http.SameSiteStrictMode case "none": sess.Options.SameSite = http.SameSiteNoneMode default: sess.Options.SameSite = http.SameSiteDefaultMode } originalURL, err := director.OriginalURL(ctx) if err != nil { return nil, errors.WithStack(err) } loginCallbackURL, err := a.getLoginCallbackURL(originalURL, layer.Proxy, layer.Name, options) if err != nil { return nil, errors.WithStack(err) } client, err := a.getClient(options, loginCallbackURL.String()) if err != nil { return nil, errors.WithStack(err) } postLoginRedirectURL, err := a.mergeURL(originalURL, originalURL.Path, options.OIDC.PublicBaseURL, true) if err != nil { return nil, errors.WithStack(err) } idToken, err := client.Authenticate(w, r, sess, postLoginRedirectURL.String()) if err != nil { if errors.Is(err, ErrLoginRequired) { metricLoginRequestsTotal.With(prometheus.Labels{ metricLabelLayer: string(layer.Name), metricLabelProxy: string(layer.Proxy), }).Add(1) return nil, errors.WithStack(authn.ErrSkipRequest) } return nil, errors.WithStack(err) } user, err := a.toUser(originalURL, idToken, layer.Proxy, layer.Name, options, sess) if err != nil { return nil, errors.WithStack(err) } return user, nil } type claims struct { Issuer string `json:"iss"` Subject string `json:"sub"` Expiration int64 `json:"exp"` IssuedAt int64 `json:"iat"` AuthTime int64 `json:"auth_time"` Nonce string `json:"nonce"` ACR string `json:"acr"` AMR string `json:"amr"` AZP string `json:"amp"` Others map[string]any `json:"-"` } func (c claims) AsAttrs() map[string]any { attrs := make(map[string]any) for key, val := range c.Others { if val != nil { attrs["claim_"+key] = val } } attrs["claim_iss"] = c.Issuer attrs["claim_sub"] = c.Subject attrs["claim_exp"] = c.Expiration attrs["claim_iat"] = c.IssuedAt if c.AuthTime != 0 { attrs["claim_auth_time"] = c.AuthTime } if c.Nonce != "" { attrs["claim_nonce"] = c.Nonce } if c.ACR != "" { attrs["claim_arc"] = c.ACR } if c.AMR != "" { attrs["claim_amr"] = c.AMR } if c.AZP != "" { attrs["claim_azp"] = c.AZP } return attrs } func (a *Authenticator) toUser(originalURL *url.URL, idToken *oidc.IDToken, proxyName store.ProxyName, layerName store.LayerName, options *LayerOptions, sess *sessions.Session) (*authn.User, error) { var claims claims if err := idToken.Claims(&claims); err != nil { return nil, errors.WithStack(err) } if err := idToken.Claims(&claims.Others); err != nil { return nil, errors.WithStack(err) } attrs := claims.AsAttrs() logoutURL, err := a.getLogoutURL(originalURL, proxyName, layerName, options) if err != nil { return nil, errors.WithStack(err) } attrs["logout_url"] = logoutURL.String() if accessToken, exists := sess.Values[sessionKeyAccessToken]; exists && accessToken != nil { attrs["access_token"] = accessToken } if refreshToken, exists := sess.Values[sessionKeyRefreshToken]; exists && refreshToken != nil { attrs["refresh_token"] = refreshToken } if tokenExpiry, exists := sess.Values[sessionKeyTokenExpiry]; exists && tokenExpiry != nil { attrs["token_expiry"] = tokenExpiry } user := authn.NewUser(idToken.Subject, attrs) return user, nil } func (a *Authenticator) getLoginCallbackURL(originalURL *url.URL, proxyName store.ProxyName, layerName store.LayerName, options *LayerOptions) (*url.URL, error) { path, err := a.templatize(options.OIDC.LoginCallbackPath, proxyName, layerName) if err != nil { return nil, errors.WithStack(err) } merged, err := a.mergeURL(originalURL, path, options.OIDC.PublicBaseURL, false) if err != nil { return nil, errors.WithStack(err) } return merged, nil } func (a *Authenticator) getLogoutURL(originalURL *url.URL, proxyName store.ProxyName, layerName store.LayerName, options *LayerOptions) (*url.URL, error) { path, err := a.templatize(options.OIDC.LogoutPath, proxyName, layerName) if err != nil { return nil, errors.WithStack(err) } merged, err := a.mergeURL(originalURL, path, options.OIDC.PublicBaseURL, true) if err != nil { return nil, errors.WithStack(err) } return merged, nil } func (a *Authenticator) mergeURL(base *url.URL, path string, overlay string, withQuery bool) (*url.URL, error) { merged := &url.URL{ Scheme: base.Scheme, Host: base.Host, Path: path, } if withQuery { merged.RawQuery = base.RawQuery } if overlay != "" { overlayURL, err := url.Parse(overlay) if err != nil { return nil, errors.WithStack(err) } merged.Scheme = overlayURL.Scheme merged.Host = overlayURL.Host merged.Path = overlayURL.Path + strings.TrimPrefix(path, "/") for key, values := range overlayURL.Query() { query := merged.Query() for _, v := range values { query.Add(key, v) } merged.RawQuery = query.Encode() } } return merged, nil } func (a *Authenticator) templatize(rawTemplate string, proxyName store.ProxyName, layerName store.LayerName) (string, error) { tmpl, err := template.New("").Parse(rawTemplate) if err != nil { return "", errors.WithStack(err) } var raw bytes.Buffer err = tmpl.Execute(&raw, struct { ProxyName store.ProxyName LayerName store.LayerName }{ ProxyName: proxyName, LayerName: layerName, }) if err != nil { return "", errors.WithStack(err) } return raw.String(), nil } func (a *Authenticator) getClient(options *LayerOptions, redirectURL string) (*Client, error) { ctx := context.Background() transport := a.httpTransport.Clone() if options.OIDC.TLSInsecureSkipVerify { if transport.TLSClientConfig == nil { transport.TLSClientConfig = &tls.Config{} } transport.TLSClientConfig.InsecureSkipVerify = true } httpClient := &http.Client{ Timeout: a.httpClientTimeout, Transport: transport, } ctx = oidc.ClientContext(ctx, httpClient) if options.OIDC.SkipIssuerVerification { ctx = oidc.InsecureIssuerURLContext(ctx, options.OIDC.IssuerURL) } provider, err := oidc.NewProvider(ctx, options.OIDC.IssuerURL) if err != nil { return nil, errors.Wrap(err, "could not create oidc provider") } client := NewClient( WithCredentials(options.OIDC.ClientID, options.OIDC.ClientSecret), WithProvider(provider), WithRedirectURL(redirectURL), WithScopes(options.OIDC.Scopes...), WithAuthParams(options.OIDC.AuthParams), WithHTTPClient(httpClient), ) return client, nil } const defaultCookieNamePrefix = "_bouncer_authn_oidc" func (a *Authenticator) getCookieName(cookieName string, proxyName store.ProxyName, layerName store.LayerName) string { if cookieName != "" { return cookieName } return strings.ToLower(fmt.Sprintf("%s_%s_%s", defaultCookieNamePrefix, proxyName, layerName)) } var ( _ authn.PreAuthentication = &Authenticator{} _ authn.Authenticator = &Authenticator{} )