367 lines
9.4 KiB
Go
367 lines
9.4 KiB
Go
package oidc
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/tls"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"text/template"
|
|
"time"
|
|
|
|
"forge.cadoles.com/Cadoles/go-proxy/wildcard"
|
|
"forge.cadoles.com/cadoles/bouncer/internal/proxy/director"
|
|
"forge.cadoles.com/cadoles/bouncer/internal/proxy/director/layer/authn"
|
|
"forge.cadoles.com/cadoles/bouncer/internal/store"
|
|
"github.com/coreos/go-oidc/v3/oidc"
|
|
"github.com/gorilla/sessions"
|
|
"github.com/pkg/errors"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"gitlab.com/wpetit/goweb/logger"
|
|
)
|
|
|
|
type Authenticator struct {
|
|
store sessions.Store
|
|
httpTransport *http.Transport
|
|
httpClientTimeout time.Duration
|
|
}
|
|
|
|
func (a *Authenticator) PreAuthentication(w http.ResponseWriter, r *http.Request, layer *store.Layer) error {
|
|
ctx := r.Context()
|
|
|
|
originalURL, err := director.OriginalURL(ctx)
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
baseURL := originalURL.Scheme + "://" + originalURL.Host
|
|
|
|
options, err := fromStoreOptions(layer.Options, baseURL)
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
sess, err := a.store.Get(r, a.getCookieName(options.Cookie.Name, layer.Name))
|
|
if err != nil {
|
|
logger.Error(ctx, "could not retrieve session", logger.E(errors.WithStack(err)))
|
|
}
|
|
|
|
loginCallbackURL, err := a.getLoginCallbackURL(layer.Proxy, layer.Name, options)
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
client, err := a.getClient(options, loginCallbackURL.String())
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
loginCallbackURLPattern, err := a.templatize(options.OIDC.MatchLoginCallbackURL, layer.Proxy, layer.Name)
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
logoutURLPattern, err := a.templatize(options.OIDC.MatchLogoutURL, layer.Proxy, layer.Name)
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
logger.Debug(ctx, "checking url", logger.F("loginCallbackURLPattern", loginCallbackURLPattern), logger.F("logoutURLPattern", logoutURLPattern))
|
|
|
|
switch {
|
|
case wildcard.Match(originalURL.String(), loginCallbackURLPattern):
|
|
if err := client.HandleCallback(w, r, sess); err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
metricLoginSuccessesTotal.With(prometheus.Labels{
|
|
metricLabelLayer: string(layer.Name),
|
|
metricLabelProxy: string(layer.Proxy),
|
|
}).Add(1)
|
|
|
|
case wildcard.Match(originalURL.String(), logoutURLPattern):
|
|
postLogoutRedirectURL := options.OIDC.PostLogoutRedirectURL
|
|
if options.OIDC.PostLogoutRedirectURL == "" {
|
|
postLogoutRedirectURL = originalURL.Scheme + "://" + originalURL.Host
|
|
}
|
|
|
|
if err := client.HandleLogout(w, r, sess, postLogoutRedirectURL); err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
metricLogoutsTotal.With(prometheus.Labels{
|
|
metricLabelLayer: string(layer.Name),
|
|
metricLabelProxy: string(layer.Proxy),
|
|
}).Add(1)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Authenticate implements authn.Authenticator.
|
|
func (a *Authenticator) Authenticate(w http.ResponseWriter, r *http.Request, layer *store.Layer) (*authn.User, error) {
|
|
ctx := r.Context()
|
|
|
|
originalURL, err := director.OriginalURL(ctx)
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
baseURL := originalURL.Scheme + "://" + originalURL.Host
|
|
|
|
options, err := fromStoreOptions(layer.Options, baseURL)
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
sess, err := a.store.Get(r, a.getCookieName(options.Cookie.Name, layer.Name))
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
defer func() {
|
|
if err := sess.Save(r, w); err != nil {
|
|
logger.Error(ctx, "could not save session", logger.E(errors.WithStack(err)))
|
|
}
|
|
}()
|
|
|
|
sess.Options.Domain = options.Cookie.Domain
|
|
sess.Options.HttpOnly = options.Cookie.HTTPOnly
|
|
sess.Options.MaxAge = int(options.Cookie.MaxAge.Seconds())
|
|
sess.Options.Path = options.Cookie.Path
|
|
|
|
switch options.Cookie.SameSite {
|
|
case "lax":
|
|
sess.Options.SameSite = http.SameSiteLaxMode
|
|
case "strict":
|
|
sess.Options.SameSite = http.SameSiteStrictMode
|
|
case "none":
|
|
sess.Options.SameSite = http.SameSiteNoneMode
|
|
default:
|
|
sess.Options.SameSite = http.SameSiteDefaultMode
|
|
}
|
|
|
|
loginCallbackURL, err := a.getLoginCallbackURL(layer.Proxy, layer.Name, options)
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
client, err := a.getClient(options, loginCallbackURL.String())
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
idToken, err := client.Authenticate(w, r, sess)
|
|
if err != nil {
|
|
if errors.Is(err, ErrLoginRequired) {
|
|
metricLoginRequestsTotal.With(prometheus.Labels{
|
|
metricLabelLayer: string(layer.Name),
|
|
metricLabelProxy: string(layer.Proxy),
|
|
}).Add(1)
|
|
|
|
return nil, errors.WithStack(authn.ErrSkipRequest)
|
|
}
|
|
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
user, err := a.toUser(idToken, layer.Proxy, layer.Name, options, sess)
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
// claims holds the ID token claims decoded from the token payload.
//
// The named fields map to individual JSON claims through their tags; Others
// is excluded from JSON decoding (tag "-") and is filled separately with the
// raw claim set.
type claims struct {
	Issuer     string `json:"iss"`
	Subject    string `json:"sub"`
	Expiration int64  `json:"exp"`
	IssuedAt   int64  `json:"iat"`
	AuthTime   int64  `json:"auth_time"`
	Nonce      string `json:"nonce"`
	ACR        string `json:"acr"`
	// NOTE(review): OpenID Connect Core describes "amr" as an array of
	// strings; decoding such a value into a string field fails. Confirm
	// against the identity providers in use before changing the type.
	AMR string `json:"amr"`
	// AZP is the "azp" (authorized party) claim.
	AZP    string         `json:"azp"`
	Others map[string]any `json:"-"`
}
|
|
|
|
func (c claims) AsAttrs() map[string]any {
|
|
attrs := make(map[string]any)
|
|
|
|
for key, val := range c.Others {
|
|
if val != nil {
|
|
attrs["claim_"+key] = val
|
|
}
|
|
}
|
|
|
|
attrs["claim_iss"] = c.Issuer
|
|
attrs["claim_sub"] = c.Subject
|
|
attrs["claim_exp"] = c.Expiration
|
|
attrs["claim_iat"] = c.IssuedAt
|
|
|
|
if c.AuthTime != 0 {
|
|
attrs["claim_auth_time"] = c.AuthTime
|
|
}
|
|
|
|
if c.Nonce != "" {
|
|
attrs["claim_nonce"] = c.Nonce
|
|
}
|
|
|
|
if c.ACR != "" {
|
|
attrs["claim_arc"] = c.ACR
|
|
}
|
|
|
|
if c.AMR != "" {
|
|
attrs["claim_amr"] = c.AMR
|
|
}
|
|
|
|
if c.AZP != "" {
|
|
attrs["claim_azp"] = c.AZP
|
|
}
|
|
|
|
return attrs
|
|
}
|
|
|
|
func (a *Authenticator) toUser(idToken *oidc.IDToken, proxyName store.ProxyName, layerName store.LayerName, options *LayerOptions, sess *sessions.Session) (*authn.User, error) {
|
|
var claims claims
|
|
|
|
if err := idToken.Claims(&claims); err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
if err := idToken.Claims(&claims.Others); err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
attrs := claims.AsAttrs()
|
|
|
|
logoutURL, err := a.getLogoutURL(proxyName, layerName, options)
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
attrs["logout_url"] = logoutURL.String()
|
|
|
|
if accessToken, exists := sess.Values[sessionKeyAccessToken]; exists && accessToken != nil {
|
|
attrs["access_token"] = accessToken
|
|
}
|
|
|
|
if refreshToken, exists := sess.Values[sessionKeyRefreshToken]; exists && refreshToken != nil {
|
|
attrs["refresh_token"] = refreshToken
|
|
}
|
|
|
|
if tokenExpiry, exists := sess.Values[sessionKeyTokenExpiry]; exists && tokenExpiry != nil {
|
|
attrs["token_expiry"] = tokenExpiry
|
|
}
|
|
|
|
user := authn.NewUser(idToken.Subject, attrs)
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func (a *Authenticator) getLoginCallbackURL(proxyName store.ProxyName, layerName store.LayerName, options *LayerOptions) (*url.URL, error) {
|
|
url, err := a.generateURL(options.OIDC.LoginCallbackURL, proxyName, layerName)
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
return url, nil
|
|
}
|
|
|
|
func (a *Authenticator) getLogoutURL(proxyName store.ProxyName, layerName store.LayerName, options *LayerOptions) (*url.URL, error) {
|
|
url, err := a.generateURL(options.OIDC.LogoutURL, proxyName, layerName)
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
return url, nil
|
|
}
|
|
|
|
func (a *Authenticator) generateURL(rawURLTemplate string, proxyName store.ProxyName, layerName store.LayerName) (*url.URL, error) {
|
|
rawURL, err := a.templatize(rawURLTemplate, proxyName, layerName)
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
url, err := url.Parse(rawURL)
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
return url, nil
|
|
}
|
|
|
|
func (a *Authenticator) templatize(rawTemplate string, proxyName store.ProxyName, layerName store.LayerName) (string, error) {
|
|
tmpl, err := template.New("").Parse(rawTemplate)
|
|
if err != nil {
|
|
return "", errors.WithStack(err)
|
|
}
|
|
|
|
var raw bytes.Buffer
|
|
|
|
err = tmpl.Execute(&raw, struct {
|
|
ProxyName store.ProxyName
|
|
LayerName store.LayerName
|
|
}{
|
|
ProxyName: proxyName,
|
|
LayerName: layerName,
|
|
})
|
|
if err != nil {
|
|
return "", errors.WithStack(err)
|
|
}
|
|
|
|
return raw.String(), nil
|
|
}
|
|
|
|
func (a *Authenticator) getClient(options *LayerOptions, redirectURL string) (*Client, error) {
|
|
ctx := context.Background()
|
|
|
|
transport := a.httpTransport.Clone()
|
|
|
|
if options.OIDC.TLSInsecureSkipVerify {
|
|
if transport.TLSClientConfig == nil {
|
|
transport.TLSClientConfig = &tls.Config{}
|
|
}
|
|
|
|
transport.TLSClientConfig.InsecureSkipVerify = true
|
|
}
|
|
|
|
httpClient := &http.Client{
|
|
Timeout: a.httpClientTimeout,
|
|
Transport: transport,
|
|
}
|
|
|
|
ctx = oidc.ClientContext(ctx, httpClient)
|
|
|
|
if options.OIDC.SkipIssuerVerification {
|
|
ctx = oidc.InsecureIssuerURLContext(ctx, options.OIDC.IssuerURL)
|
|
}
|
|
|
|
provider, err := oidc.NewProvider(ctx, options.OIDC.IssuerURL)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "could not create oidc provider")
|
|
}
|
|
|
|
client := NewClient(
|
|
WithCredentials(options.OIDC.ClientID, options.OIDC.ClientSecret),
|
|
WithProvider(provider),
|
|
WithRedirectURL(redirectURL),
|
|
WithScopes(options.OIDC.Scopes...),
|
|
WithAuthParams(options.OIDC.AuthParams),
|
|
WithHTTPClient(httpClient),
|
|
)
|
|
|
|
return client, nil
|
|
}
|
|
|
|
func (a *Authenticator) getCookieName(cookieName string, layerName store.LayerName) string {
|
|
return fmt.Sprintf("%s_%s", cookieName, layerName)
|
|
}
|
|
|
|
var (
|
|
_ authn.PreAuthentication = &Authenticator{}
|
|
_ authn.Authenticator = &Authenticator{}
|
|
)
|