bouncer/internal/proxy/director/layer/authn/oidc/authenticator.go

279 lines
7.3 KiB
Go

package oidc
import (
"context"
"fmt"
"net/http"
"net/url"
"forge.cadoles.com/cadoles/bouncer/internal/proxy/director"
"forge.cadoles.com/cadoles/bouncer/internal/proxy/director/layer/authn"
"forge.cadoles.com/cadoles/bouncer/internal/store"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/gorilla/sessions"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
"gitlab.com/wpetit/goweb/logger"
)
// Authenticator is the OIDC-backed implementation of authn.Authenticator and
// authn.PreAuthentication. Per-layer session data, including the tokens later
// exposed as user attributes, is read from and written to sessions obtained
// from store.
type Authenticator struct {
	store sessions.Store
}
// PreAuthentication implements authn.PreAuthentication.
//
// It serves the two endpoints this layer owns on the proxied host: the OIDC
// login callback (redirect URL) and the logout URL. Requests to any other
// path return nil immediately without touching the session or the OIDC
// provider.
func (a *Authenticator) PreAuthentication(w http.ResponseWriter, r *http.Request, layer *store.Layer) error {
	ctx := r.Context()

	originalURL, err := director.OriginalURL(ctx)
	if err != nil {
		return errors.WithStack(err)
	}

	options, err := fromStoreOptions(layer.Options)
	if err != nil {
		return errors.WithStack(err)
	}

	redirectURL := a.getRedirectURL(layer.Proxy, layer.Name, originalURL, options)
	logoutURL := a.getLogoutURL(layer.Proxy, layer.Name, originalURL, options)

	// Bail out before building the OIDC client: getClient performs provider
	// discovery, which is wasted work (and an extra failure point) for every
	// request that is neither the callback nor the logout endpoint.
	if r.URL.Path != redirectURL.Path && r.URL.Path != logoutURL.Path {
		return nil
	}

	sess, err := a.store.Get(r, a.getCookieName(options.Cookie.Name, layer.Name))
	if sess == nil {
		// Nothing usable came back; fail instead of handing nil to the client.
		if err != nil {
			return errors.WithStack(err)
		}

		return errors.New("session store returned no session")
	}
	if err != nil {
		// The store handed back a session alongside the error, so processing
		// continues with it and the error is only logged.
		logger.Error(ctx, "could not retrieve session", logger.E(errors.WithStack(err)))
	}

	// Apply the same cookie attributes as Authenticate. If the callback or
	// logout handler saves the session, a cookie written with the store's
	// default domain/path would not replace the one the browser already holds.
	a.applySessionCookieOptions(sess, options)

	client, err := a.getClient(options, redirectURL.String())
	if err != nil {
		return errors.WithStack(err)
	}

	metricLabels := prometheus.Labels{
		metricLabelLayer: string(layer.Name),
		metricLabelProxy: string(layer.Proxy),
	}

	switch r.URL.Path {
	case redirectURL.Path:
		if err := client.HandleCallback(w, r, sess); err != nil {
			return errors.WithStack(err)
		}

		metricLoginSuccessesTotal.With(metricLabels).Inc()

	case logoutURL.Path:
		postLogoutRedirectURL := options.OIDC.PostLogoutRedirectURL
		if postLogoutRedirectURL == "" {
			// Default to the root of the proxied site.
			postLogoutRedirectURL = originalURL.Scheme + "://" + originalURL.Host
		}

		if err := client.HandleLogout(w, r, sess, postLogoutRedirectURL); err != nil {
			return errors.WithStack(err)
		}

		metricLogoutsTotal.With(metricLabels).Inc()
	}

	return nil
}

// Authenticate implements authn.Authenticator.
//
// It loads the layer's session, applies the configured cookie attributes and
// asks the OIDC client to authenticate the request. An ErrLoginRequired from
// the client is counted as a login request and returned as
// authn.ErrSkipRequest. Once the session has been obtained, it is saved on
// every return path; save failures are logged, not returned.
func (a *Authenticator) Authenticate(w http.ResponseWriter, r *http.Request, layer *store.Layer) (*authn.User, error) {
	ctx := r.Context()

	originalURL, err := director.OriginalURL(ctx)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	options, err := fromStoreOptions(layer.Options)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	sess, err := a.store.Get(r, a.getCookieName(options.Cookie.Name, layer.Name))
	if err != nil {
		return nil, errors.WithStack(err)
	}

	defer func() {
		if err := sess.Save(r, w); err != nil {
			logger.Error(ctx, "could not save session", logger.E(errors.WithStack(err)))
		}
	}()

	a.applySessionCookieOptions(sess, options)

	redirectURL := a.getRedirectURL(layer.Proxy, layer.Name, originalURL, options)

	client, err := a.getClient(options, redirectURL.String())
	if err != nil {
		return nil, errors.WithStack(err)
	}

	idToken, err := client.Authenticate(w, r, sess)
	if err != nil {
		if errors.Is(err, ErrLoginRequired) {
			metricLoginRequestsTotal.With(prometheus.Labels{
				metricLabelLayer: string(layer.Name),
				metricLabelProxy: string(layer.Proxy),
			}).Inc()

			return nil, errors.WithStack(authn.ErrSkipRequest)
		}

		return nil, errors.WithStack(err)
	}

	user, err := a.toUser(idToken, layer.Proxy, layer.Name, originalURL, options, sess)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	return user, nil
}

// applySessionCookieOptions copies the layer's cookie settings onto sess so
// that every save of the session cookie uses the configured domain, path,
// max-age, HttpOnly flag and SameSite mode. Unrecognized SameSite values fall
// back to http.SameSiteDefaultMode.
func (a *Authenticator) applySessionCookieOptions(sess *sessions.Session, options *LayerOptions) {
	if sess.Options == nil {
		sess.Options = &sessions.Options{}
	}

	sess.Options.Domain = options.Cookie.Domain
	sess.Options.HttpOnly = options.Cookie.HTTPOnly
	sess.Options.MaxAge = int(options.Cookie.MaxAge.Seconds())
	sess.Options.Path = options.Cookie.Path

	switch options.Cookie.SameSite {
	case "lax":
		sess.Options.SameSite = http.SameSiteLaxMode
	case "strict":
		sess.Options.SameSite = http.SameSiteStrictMode
	case "none":
		sess.Options.SameSite = http.SameSiteNoneMode
	default:
		sess.Options.SameSite = http.SameSiteDefaultMode
	}
}
// claims holds the ID token claims decoded into typed fields. Others receives
// every claim of the token decoded into a generic map.
//
// AMR is a JSON array of strings (OpenID Connect Core 1.0, section 2).
// Declaring it as a plain string made decoding fail for any token carrying
// "amr". AZP is read from the standard "azp" claim.
type claims struct {
	Issuer     string         `json:"iss"`
	Subject    string         `json:"sub"`
	Expiration int64          `json:"exp"`
	IssuedAt   int64          `json:"iat"`
	AuthTime   int64          `json:"auth_time"`
	Nonce      string         `json:"nonce"`
	ACR        string         `json:"acr"`
	AMR        []string       `json:"amr"`
	AZP        string         `json:"azp"`
	Others     map[string]any `json:"-"`
}

// AsAttrs flattens the claims into user attributes keyed "claim_<name>".
// Every non-nil entry of Others is copied first. The typed fields are then
// written on top, so their decoded values take precedence. iss, sub, exp and
// iat are always set; the optional claims are set only when non-zero or
// non-empty.
func (c claims) AsAttrs() map[string]any {
	attrs := make(map[string]any, len(c.Others)+10)

	for key, val := range c.Others {
		if val != nil {
			attrs["claim_"+key] = val
		}
	}

	attrs["claim_iss"] = c.Issuer
	attrs["claim_sub"] = c.Subject
	attrs["claim_exp"] = c.Expiration
	attrs["claim_iat"] = c.IssuedAt

	if c.AuthTime != 0 {
		attrs["claim_auth_time"] = c.AuthTime
	}

	if c.Nonce != "" {
		attrs["claim_nonce"] = c.Nonce
	}

	if c.ACR != "" {
		attrs["claim_acr"] = c.ACR
		// Deprecated misspelled key, kept for backward compatibility with
		// consumers of the previous attribute name.
		attrs["claim_arc"] = c.ACR
	}

	if len(c.AMR) > 0 {
		attrs["claim_amr"] = c.AMR
	}

	if c.AZP != "" {
		attrs["claim_azp"] = c.AZP
	}

	return attrs
}
// toUser converts an ID token into an authn.User identified by the token
// subject. The user's attributes are the token claims (via claims.AsAttrs),
// the layer's logout URL under "logout_url" and, when set to a non-nil value
// in the session, the access token, refresh token and token expiry.
func (a *Authenticator) toUser(idToken *oidc.IDToken, proxyName store.ProxyName, layerName store.LayerName, originalURL *url.URL, options *LayerOptions, sess *sessions.Session) (*authn.User, error) {
	var tokenClaims claims

	// The first pass fills the typed fields; the second captures every claim.
	if err := idToken.Claims(&tokenClaims); err != nil {
		return nil, errors.WithStack(err)
	}

	if err := idToken.Claims(&tokenClaims.Others); err != nil {
		return nil, errors.WithStack(err)
	}

	attrs := tokenClaims.AsAttrs()
	attrs["logout_url"] = a.getLogoutURL(proxyName, layerName, originalURL, options).String()

	sessionAttrs := []struct {
		sessionKey any
		attrName   string
	}{
		{sessionKeyAccessToken, "access_token"},
		{sessionKeyRefreshToken, "refresh_token"},
		{sessionKeyTokenExpiry, "token_expiry"},
	}

	for _, entry := range sessionAttrs {
		if value, found := sess.Values[entry.sessionKey]; found && value != nil {
			attrs[entry.attrName] = value
		}
	}

	return authn.NewUser(idToken.Subject, attrs), nil
}
// getRedirectURL returns the absolute login callback URL for the layer, on
// the scheme and host of u. options.OIDC.LoginCallbackPath is used as a
// format string and receives "<proxyName>/<layerName>" as its only argument.
func (a *Authenticator) getRedirectURL(proxyName store.ProxyName, layerName store.LayerName, u *url.URL, options *LayerOptions) *url.URL {
	layerScope := fmt.Sprintf("%s/%s", proxyName, layerName)

	callbackURL := url.URL{Scheme: u.Scheme, Host: u.Host}
	callbackURL.Path = fmt.Sprintf(options.OIDC.LoginCallbackPath, layerScope)

	return &callbackURL
}
// getLogoutURL returns the absolute logout URL for the layer, on the scheme
// and host of u. options.OIDC.LogoutPath is used as a format string and
// receives "<proxyName>/<layerName>" as its only argument.
func (a *Authenticator) getLogoutURL(proxyName store.ProxyName, layerName store.LayerName, u *url.URL, options *LayerOptions) *url.URL {
	layerScope := fmt.Sprintf("%s/%s", proxyName, layerName)

	logoutURL := url.URL{Scheme: u.Scheme, Host: u.Host}
	logoutURL.Path = fmt.Sprintf(options.OIDC.LogoutPath, layerScope)

	return &logoutURL
}
// getClient builds an OIDC client for the layer. The provider comes from
// options.OIDC.IssuerURL, and redirectURL becomes the client's redirect URL.
// When options.OIDC.SkipIssuerVerification is set, the discovery context is
// wrapped with oidc.InsecureIssuerURLContext for that issuer URL.
//
// NOTE(review): oidc.NewProvider performs discovery on each call, and the
// context has no deadline, yet callers invoke this per request. Consider
// caching providers per issuer and bounding discovery with a timeout.
func (a *Authenticator) getClient(options *LayerOptions, redirectURL string) (*Client, error) {
	discoveryCtx := context.Background()
	if options.OIDC.SkipIssuerVerification {
		discoveryCtx = oidc.InsecureIssuerURLContext(discoveryCtx, options.OIDC.IssuerURL)
	}

	provider, err := oidc.NewProvider(discoveryCtx, options.OIDC.IssuerURL)
	if err != nil {
		return nil, errors.Wrap(err, "could not create oidc provider")
	}

	return NewClient(
		WithCredentials(options.OIDC.ClientID, options.OIDC.ClientSecret),
		WithProvider(provider),
		WithRedirectURL(redirectURL),
		WithScopes(options.OIDC.Scopes...),
		WithAuthParams(options.OIDC.AuthParams),
	), nil
}
// getCookieName returns "<cookieName>_<layerName>". Layers that share the same
// configured cookie name therefore still get distinct session cookies.
func (a *Authenticator) getCookieName(cookieName string, layerName store.LayerName) string {
	const cookieNameFormat = "%s_%s"

	return fmt.Sprintf(cookieNameFormat, cookieName, layerName)
}
// Compile-time checks that *Authenticator satisfies both authn interfaces.
var (
	_ authn.PreAuthentication = (*Authenticator)(nil)
	_ authn.Authenticator     = (*Authenticator)(nil)
)