package oidc import ( "bytes" "context" "crypto/tls" "fmt" "net/http" "net/url" "text/template" "time" "forge.cadoles.com/Cadoles/go-proxy/wildcard" "forge.cadoles.com/cadoles/bouncer/internal/proxy/director" "forge.cadoles.com/cadoles/bouncer/internal/proxy/director/layer/authn" "forge.cadoles.com/cadoles/bouncer/internal/store" "github.com/coreos/go-oidc/v3/oidc" "github.com/gorilla/sessions" "github.com/pkg/errors" "github.com/prometheus/client_golang/prometheus" "gitlab.com/wpetit/goweb/logger" ) type Authenticator struct { store sessions.Store httpTransport *http.Transport httpClientTimeout time.Duration } func (a *Authenticator) PreAuthentication(w http.ResponseWriter, r *http.Request, layer *store.Layer) error { ctx := r.Context() originalURL, err := director.OriginalURL(ctx) if err != nil { return errors.WithStack(err) } baseURL := originalURL.Scheme + "://" + originalURL.Host options, err := fromStoreOptions(layer.Options, baseURL) if err != nil { return errors.WithStack(err) } sess, err := a.store.Get(r, a.getCookieName(options.Cookie.Name, layer.Name)) if err != nil { logger.Error(ctx, "could not retrieve session", logger.E(errors.WithStack(err))) } loginCallbackURL, err := a.getLoginCallbackURL(layer.Proxy, layer.Name, options) if err != nil { return errors.WithStack(err) } client, err := a.getClient(options, loginCallbackURL.String()) if err != nil { return errors.WithStack(err) } loginCallbackURLPattern, err := a.templatize(options.OIDC.MatchLoginCallbackURL, layer.Proxy, layer.Name) if err != nil { return errors.WithStack(err) } logoutURLPattern, err := a.templatize(options.OIDC.MatchLogoutURL, layer.Proxy, layer.Name) if err != nil { return errors.WithStack(err) } logger.Debug(ctx, "checking url", logger.F("loginCallbackURLPattern", loginCallbackURLPattern), logger.F("logoutURLPattern", logoutURLPattern)) switch { case wildcard.Match(originalURL.String(), loginCallbackURLPattern): if err := client.HandleCallback(w, r, sess); err != nil { return errors.WithStack(err) } metricLoginSuccessesTotal.With(prometheus.Labels{ metricLabelLayer: string(layer.Name), metricLabelProxy: string(layer.Proxy), }).Add(1) case wildcard.Match(originalURL.String(), logoutURLPattern): postLogoutRedirectURL := options.OIDC.PostLogoutRedirectURL if options.OIDC.PostLogoutRedirectURL == "" { postLogoutRedirectURL = originalURL.Scheme + "://" + originalURL.Host } if err := client.HandleLogout(w, r, sess, postLogoutRedirectURL); err != nil { return errors.WithStack(err) } metricLogoutsTotal.With(prometheus.Labels{ metricLabelLayer: string(layer.Name), metricLabelProxy: string(layer.Proxy), }).Add(1) } return nil } // Authenticate implements authn.Authenticator. func (a *Authenticator) Authenticate(w http.ResponseWriter, r *http.Request, layer *store.Layer) (*authn.User, error) { ctx := r.Context() originalURL, err := director.OriginalURL(ctx) if err != nil { return nil, errors.WithStack(err) } baseURL := originalURL.Scheme + "://" + originalURL.Host options, err := fromStoreOptions(layer.Options, baseURL) if err != nil { return nil, errors.WithStack(err) } sess, err := a.store.Get(r, a.getCookieName(options.Cookie.Name, layer.Name)) if err != nil { return nil, errors.WithStack(err) } defer func() { if err := sess.Save(r, w); err != nil { logger.Error(ctx, "could not save session", logger.E(errors.WithStack(err))) } }() sess.Options.Domain = options.Cookie.Domain sess.Options.HttpOnly = options.Cookie.HTTPOnly sess.Options.MaxAge = int(options.Cookie.MaxAge.Seconds()) sess.Options.Path = options.Cookie.Path switch options.Cookie.SameSite { case "lax": sess.Options.SameSite = http.SameSiteLaxMode case "strict": sess.Options.SameSite = http.SameSiteStrictMode case "none": sess.Options.SameSite = http.SameSiteNoneMode default: sess.Options.SameSite = http.SameSiteDefaultMode } loginCallbackURL, err := a.getLoginCallbackURL(layer.Proxy, layer.Name, options) if err != nil { return nil, errors.WithStack(err) } client, err := a.getClient(options, loginCallbackURL.String()) if err != nil { return nil, errors.WithStack(err) } idToken, err := client.Authenticate(w, r, sess) if err != nil { if errors.Is(err, ErrLoginRequired) { metricLoginRequestsTotal.With(prometheus.Labels{ metricLabelLayer: string(layer.Name), metricLabelProxy: string(layer.Proxy), }).Add(1) return nil, errors.WithStack(authn.ErrSkipRequest) } return nil, errors.WithStack(err) } user, err := a.toUser(idToken, layer.Proxy, layer.Name, options, sess) if err != nil { return nil, errors.WithStack(err) } return user, nil } type claims struct { Issuer string `json:"iss"` Subject string `json:"sub"` Expiration int64 `json:"exp"` IssuedAt int64 `json:"iat"` AuthTime int64 `json:"auth_time"` Nonce string `json:"nonce"` ACR string `json:"acr"` AMR string `json:"amr"` AZP string `json:"amp"` Others map[string]any `json:"-"` } func (c claims) AsAttrs() map[string]any { attrs := make(map[string]any) for key, val := range c.Others { if val != nil { attrs["claim_"+key] = val } } attrs["claim_iss"] = c.Issuer attrs["claim_sub"] = c.Subject attrs["claim_exp"] = c.Expiration attrs["claim_iat"] = c.IssuedAt if c.AuthTime != 0 { attrs["claim_auth_time"] = c.AuthTime } if c.Nonce != "" { attrs["claim_nonce"] = c.Nonce } if c.ACR != "" { attrs["claim_arc"] = c.ACR } if c.AMR != "" { attrs["claim_amr"] = c.AMR } if c.AZP != "" { attrs["claim_azp"] = c.AZP } return attrs } func (a *Authenticator) toUser(idToken *oidc.IDToken, proxyName store.ProxyName, layerName store.LayerName, options *LayerOptions, sess *sessions.Session) (*authn.User, error) { var claims claims if err := idToken.Claims(&claims); err != nil { return nil, errors.WithStack(err) } if err := idToken.Claims(&claims.Others); err != nil { return nil, errors.WithStack(err) } attrs := claims.AsAttrs() logoutURL, err := a.getLogoutURL(proxyName, layerName, options) if err != nil { return nil, errors.WithStack(err) } attrs["logout_url"] = logoutURL.String() if accessToken, exists := sess.Values[sessionKeyAccessToken]; exists && accessToken != nil { attrs["access_token"] = accessToken } if refreshToken, exists := sess.Values[sessionKeyRefreshToken]; exists && refreshToken != nil { attrs["refresh_token"] = refreshToken } if tokenExpiry, exists := sess.Values[sessionKeyTokenExpiry]; exists && tokenExpiry != nil { attrs["token_expiry"] = tokenExpiry } user := authn.NewUser(idToken.Subject, attrs) return user, nil } func (a *Authenticator) getLoginCallbackURL(proxyName store.ProxyName, layerName store.LayerName, options *LayerOptions) (*url.URL, error) { url, err := a.generateURL(options.OIDC.LoginCallbackURL, proxyName, layerName) if err != nil { return nil, errors.WithStack(err) } return url, nil } func (a *Authenticator) getLogoutURL(proxyName store.ProxyName, layerName store.LayerName, options *LayerOptions) (*url.URL, error) { url, err := a.generateURL(options.OIDC.LogoutURL, proxyName, layerName) if err != nil { return nil, errors.WithStack(err) } return url, nil } func (a *Authenticator) generateURL(rawURLTemplate string, proxyName store.ProxyName, layerName store.LayerName) (*url.URL, error) { rawURL, err := a.templatize(rawURLTemplate, proxyName, layerName) if err != nil { return nil, errors.WithStack(err) } url, err := url.Parse(rawURL) if err != nil { return nil, errors.WithStack(err) } return url, nil } func (a *Authenticator) templatize(rawTemplate string, proxyName store.ProxyName, layerName store.LayerName) (string, error) { tmpl, err := template.New("").Parse(rawTemplate) if err != nil { return "", errors.WithStack(err) } var raw bytes.Buffer err = tmpl.Execute(&raw, struct { ProxyName store.ProxyName LayerName store.LayerName }{ ProxyName: proxyName, LayerName: layerName, }) if err != nil { return "", errors.WithStack(err) } return raw.String(), nil } func (a *Authenticator) getClient(options *LayerOptions, redirectURL string) (*Client, error) { ctx := context.Background() transport := a.httpTransport.Clone() if options.OIDC.TLSInsecureSkipVerify { if transport.TLSClientConfig == nil { transport.TLSClientConfig = &tls.Config{} } transport.TLSClientConfig.InsecureSkipVerify = true } httpClient := &http.Client{ Timeout: a.httpClientTimeout, Transport: transport, } ctx = oidc.ClientContext(ctx, httpClient) if options.OIDC.SkipIssuerVerification { ctx = oidc.InsecureIssuerURLContext(ctx, options.OIDC.IssuerURL) } provider, err := oidc.NewProvider(ctx, options.OIDC.IssuerURL) if err != nil { return nil, errors.Wrap(err, "could not create oidc provider") } client := NewClient( WithCredentials(options.OIDC.ClientID, options.OIDC.ClientSecret), WithProvider(provider), WithRedirectURL(redirectURL), WithScopes(options.OIDC.Scopes...), WithAuthParams(options.OIDC.AuthParams), WithHTTPClient(httpClient), ) return client, nil } func (a *Authenticator) getCookieName(cookieName string, layerName store.LayerName) string { return fmt.Sprintf("%s_%s", cookieName, layerName) } var ( _ authn.PreAuthentication = &Authenticator{} _ authn.Authenticator = &Authenticator{} )