feat: new openid connect authentication layer
Some checks are pending
Cadoles/bouncer/pipeline/pr-develop Build started...
Some checks are pending
Cadoles/bouncer/pipeline/pr-develop Build started...
This commit is contained in:
291
internal/proxy/director/layer/authn/oidc/client.go
Normal file
291
internal/proxy/director/layer/authn/oidc/client.go
Normal file
@ -0,0 +1,291 @@
|
||||
package oidc
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"forge.cadoles.com/cadoles/bouncer/internal/proxy/director"
|
||||
"github.com/coreos/go-oidc/v3/oidc"
|
||||
"github.com/dchest/uniuri"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/pkg/errors"
|
||||
"gitlab.com/wpetit/goweb/logger"
|
||||
"golang.org/x/oauth2"
|
||||
)
|
||||
|
||||
const (
|
||||
sessionKeyAccessToken = "access-token"
|
||||
sessionKeyRefreshToken = "refresh-token"
|
||||
sessionKeyTokenExpiry = "token-expiry"
|
||||
sessionKeyIDToken = "id-token"
|
||||
sessionKeyPostLoginRedirectURL = "post-login-redirect-url"
|
||||
sessionKeyLoginState = "login-state"
|
||||
sessionKeyLoginNonce = "login-nonce"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrLoginRequired = errors.New("login required")
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
oauth2 *oauth2.Config
|
||||
provider *oidc.Provider
|
||||
verifier *oidc.IDTokenVerifier
|
||||
authParams map[string]string
|
||||
}
|
||||
|
||||
func (c *Client) Verifier() *oidc.IDTokenVerifier {
|
||||
return c.verifier
|
||||
}
|
||||
|
||||
func (c *Client) Provider() *oidc.Provider {
|
||||
return c.provider
|
||||
}
|
||||
|
||||
func (c *Client) Authenticate(w http.ResponseWriter, r *http.Request, sess *sessions.Session) (*oidc.IDToken, error) {
|
||||
idToken, err := c.getIDToken(r, sess)
|
||||
if err != nil {
|
||||
logger.Error(r.Context(), "could not retrieve idtoken", logger.E(errors.WithStack(err)))
|
||||
|
||||
c.login(w, r, sess)
|
||||
|
||||
return nil, errors.WithStack(ErrLoginRequired)
|
||||
}
|
||||
|
||||
return idToken, nil
|
||||
}
|
||||
|
||||
func (c *Client) login(w http.ResponseWriter, r *http.Request, sess *sessions.Session) {
|
||||
ctx := r.Context()
|
||||
|
||||
state := uniuri.New()
|
||||
nonce := uniuri.New()
|
||||
|
||||
originalURL, err := director.OriginalURL(ctx)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "could not retrieve original url", logger.E(errors.WithStack(err)))
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
sess.Values[sessionKeyLoginState] = state
|
||||
sess.Values[sessionKeyLoginNonce] = nonce
|
||||
sess.Values[sessionKeyPostLoginRedirectURL] = originalURL.String()
|
||||
|
||||
if err := sess.Save(r, w); err != nil {
|
||||
logger.Error(ctx, "could not save session", logger.E(errors.WithStack(err)))
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
authCodeOptions := []oauth2.AuthCodeOption{}
|
||||
authCodeOptions = append(authCodeOptions, oidc.Nonce(nonce))
|
||||
|
||||
for key, val := range c.authParams {
|
||||
authCodeOptions = append(authCodeOptions, oauth2.SetAuthURLParam(key, val))
|
||||
}
|
||||
|
||||
authCodeURL := c.oauth2.AuthCodeURL(
|
||||
state,
|
||||
authCodeOptions...,
|
||||
)
|
||||
|
||||
http.Redirect(w, r, authCodeURL, http.StatusFound)
|
||||
}
|
||||
|
||||
func (c *Client) HandleCallback(w http.ResponseWriter, r *http.Request, sess *sessions.Session) error {
|
||||
token, _, rawIDToken, err := c.validate(r, sess)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "could not validate oidc token")
|
||||
}
|
||||
|
||||
sess.Values[sessionKeyIDToken] = rawIDToken
|
||||
sess.Values[sessionKeyAccessToken] = token.AccessToken
|
||||
sess.Values[sessionKeyRefreshToken] = token.RefreshToken
|
||||
sess.Values[sessionKeyTokenExpiry] = token.Expiry.UTC().Unix()
|
||||
|
||||
if err := sess.Save(r, w); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
rawPostLoginRedirectURL, exists := sess.Values[sessionKeyPostLoginRedirectURL]
|
||||
if !exists {
|
||||
return errors.Wrap(err, "could not find post login redirect url")
|
||||
}
|
||||
|
||||
postLoginRedirectURL, ok := rawPostLoginRedirectURL.(string)
|
||||
if !ok {
|
||||
return errors.Wrapf(err, "unexpected value '%v' for post login redirect url", rawPostLoginRedirectURL)
|
||||
}
|
||||
|
||||
http.Redirect(w, r, postLoginRedirectURL, http.StatusTemporaryRedirect)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) HandleLogout(w http.ResponseWriter, r *http.Request, sess *sessions.Session, postLogoutRedirectURL string) error {
|
||||
state := uniuri.New()
|
||||
sess.Values[sessionKeyLoginState] = state
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
rawIDToken, err := c.getRawIDToken(sess)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "could not retrieve raw id token", logger.E(errors.WithStack(err)))
|
||||
}
|
||||
|
||||
sess.Values[sessionKeyIDToken] = nil
|
||||
sess.Values[sessionKeyAccessToken] = nil
|
||||
sess.Values[sessionKeyRefreshToken] = nil
|
||||
sess.Values[sessionKeyTokenExpiry] = nil
|
||||
sess.Options.MaxAge = -1
|
||||
|
||||
if err := sess.Save(r, w); err != nil {
|
||||
return errors.Wrap(err, "could not save session")
|
||||
}
|
||||
|
||||
if rawIDToken == "" {
|
||||
http.Redirect(w, r, postLogoutRedirectURL, http.StatusFound)
|
||||
return nil
|
||||
}
|
||||
|
||||
sessionEndURL, err := c.sessionEndURL(rawIDToken, state, postLogoutRedirectURL)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "could not retrieve session end url")
|
||||
}
|
||||
|
||||
if sessionEndURL != "" {
|
||||
http.Redirect(w, r, sessionEndURL, http.StatusFound)
|
||||
} else {
|
||||
http.Redirect(w, r, postLogoutRedirectURL, http.StatusFound)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) sessionEndURL(idTokenHint, state, postLogoutRedirectURL string) (string, error) {
|
||||
sessionEndEndpoint := &struct {
|
||||
URL string `json:"end_session_endpoint"`
|
||||
}{}
|
||||
|
||||
if err := c.provider.Claims(&sessionEndEndpoint); err != nil {
|
||||
return "", errors.Wrap(err, "could not unmarshal claims")
|
||||
}
|
||||
|
||||
if sessionEndEndpoint.URL == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
var buf bytes.Buffer
|
||||
buf.WriteString(sessionEndEndpoint.URL)
|
||||
|
||||
v := url.Values{}
|
||||
|
||||
if idTokenHint != "" {
|
||||
v.Set("id_token_hint", idTokenHint)
|
||||
}
|
||||
|
||||
if postLogoutRedirectURL != "" {
|
||||
v.Set("post_logout_redirect_uri", postLogoutRedirectURL)
|
||||
}
|
||||
|
||||
if state != "" {
|
||||
v.Set("state", state)
|
||||
}
|
||||
|
||||
if strings.Contains(sessionEndEndpoint.URL, "?") {
|
||||
buf.WriteByte('&')
|
||||
} else {
|
||||
buf.WriteByte('?')
|
||||
}
|
||||
|
||||
buf.WriteString(v.Encode())
|
||||
|
||||
return buf.String(), nil
|
||||
}
|
||||
|
||||
func (c *Client) validate(r *http.Request, sess *sessions.Session) (*oauth2.Token, *oidc.IDToken, string, error) {
|
||||
ctx := r.Context()
|
||||
|
||||
rawStoredState := sess.Values[sessionKeyLoginState]
|
||||
receivedState := r.URL.Query().Get("state")
|
||||
|
||||
storedState, ok := rawStoredState.(string)
|
||||
if !ok {
|
||||
return nil, nil, "", errors.New("could not find state in session")
|
||||
}
|
||||
|
||||
if receivedState != storedState {
|
||||
return nil, nil, "", errors.New("state mismatch")
|
||||
}
|
||||
|
||||
code := r.URL.Query().Get("code")
|
||||
|
||||
token, err := c.oauth2.Exchange(ctx, code)
|
||||
if err != nil {
|
||||
return nil, nil, "", errors.Wrap(err, "could not exchange token")
|
||||
}
|
||||
|
||||
rawIDToken, ok := token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, nil, "", errors.New("could not find id token")
|
||||
}
|
||||
|
||||
idToken, err := c.verifier.Verify(ctx, rawIDToken)
|
||||
if err != nil {
|
||||
return nil, nil, "", errors.Wrap(err, "could not verify id token")
|
||||
}
|
||||
|
||||
return token, idToken, rawIDToken, nil
|
||||
}
|
||||
|
||||
func (c *Client) getRawIDToken(sess *sessions.Session) (string, error) {
|
||||
rawIDToken, ok := sess.Values[sessionKeyIDToken].(string)
|
||||
if !ok || rawIDToken == "" {
|
||||
return "", errors.New("invalid id token")
|
||||
}
|
||||
|
||||
return rawIDToken, nil
|
||||
}
|
||||
|
||||
func (c *Client) getIDToken(r *http.Request, sess *sessions.Session) (*oidc.IDToken, error) {
|
||||
rawIDToken, err := c.getRawIDToken(sess)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not retrieve raw idtoken")
|
||||
}
|
||||
|
||||
idToken, err := c.verifier.Verify(r.Context(), rawIDToken)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "could not verify id token")
|
||||
}
|
||||
|
||||
return idToken, nil
|
||||
}
|
||||
|
||||
func NewClient(funcs ...ClientOptionFunc) *Client {
|
||||
opts := NewClientOptions(funcs...)
|
||||
|
||||
oauth2 := &oauth2.Config{
|
||||
ClientID: opts.ClientID,
|
||||
ClientSecret: opts.ClientSecret,
|
||||
Endpoint: opts.Provider.Endpoint(),
|
||||
RedirectURL: opts.RedirectURL,
|
||||
Scopes: opts.Scopes,
|
||||
}
|
||||
|
||||
verifier := opts.Provider.Verifier(&oidc.Config{
|
||||
ClientID: opts.ClientID,
|
||||
SkipIssuerCheck: opts.SkipIssuerCheck,
|
||||
})
|
||||
|
||||
return &Client{
|
||||
oauth2: oauth2,
|
||||
provider: opts.Provider,
|
||||
verifier: verifier,
|
||||
authParams: opts.AuthParams,
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user