bouncer/internal/proxy/director/layer/authn/oidc/client.go

286 lines
7.3 KiB
Go
Raw Normal View History

package oidc
import (
"bytes"
"net/http"
"net/url"
"strings"
"github.com/coreos/go-oidc/v3/oidc"
"github.com/dchest/uniuri"
"github.com/gorilla/sessions"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
"golang.org/x/oauth2"
)
// Keys under which the client stores OIDC-related values in the user session.
const (
	// Token material persisted once an authorization code has been exchanged.
	sessionKeyAccessToken  = "access-token"
	sessionKeyRefreshToken = "refresh-token"
	sessionKeyTokenExpiry  = "token-expiry"
	sessionKeyIDToken      = "id-token"

	// Transient values recorded before the user agent is sent to the provider.
	sessionKeyPostLoginRedirectURL = "post-login-redirect-url"
	sessionKeyLoginState           = "login-state"
	sessionKeyLoginNonce           = "login-nonce"
)
// ErrLoginRequired is returned (wrapped) by Client.Authenticate when the
// session holds no usable ID token; by then a login redirect or an error
// response has already been written to the client.
var ErrLoginRequired = errors.New("login required")
// Client drives the OpenID Connect authorization code flow and keeps the
// resulting tokens in a gorilla session.
type Client struct {
	// httpClient is placed in the request context for the token exchange
	// performed in validate.
	httpClient *http.Client
	// oauth2 holds client credentials, provider endpoints, redirect URL and scopes.
	oauth2 *oauth2.Config
	// provider is the OIDC provider; its metadata is also read for the
	// end_session_endpoint on logout.
	provider *oidc.Provider
	// verifier checks stored and freshly issued ID tokens.
	verifier *oidc.IDTokenVerifier
	// authParams are extra parameters added to every authorization URL.
	authParams map[string]string
}
// Verifier returns the ID token verifier this client was built with.
func (c *Client) Verifier() *oidc.IDTokenVerifier { return c.verifier }
// Provider returns the OIDC provider this client talks to.
func (c *Client) Provider() *oidc.Provider { return c.provider }
// Authenticate returns the verified ID token held in sess. When none is
// available it starts a new login flow (writing a redirect, or an error
// response if the session cannot be saved) and returns an error wrapping
// ErrLoginRequired.
func (c *Client) Authenticate(w http.ResponseWriter, r *http.Request, sess *sessions.Session, postLoginRedirectURL string) (*oidc.IDToken, error) {
	idToken, err := c.getIDToken(r, sess)
	if err == nil {
		return idToken, nil
	}

	// No usable token in the session: send the user back to the provider.
	logger.Warn(r.Context(), "could not retrieve idtoken", logger.CapturedE(errors.WithStack(err)))
	c.login(w, r, sess, postLoginRedirectURL)

	return nil, errors.WithStack(ErrLoginRequired)
}
// login starts an authorization code flow. It records a fresh state, a fresh
// nonce and the post-login destination in sess, then redirects the user agent
// to the provider's authorization URL. If the session cannot be saved, a 500
// response is written instead of the redirect.
func (c *Client) login(w http.ResponseWriter, r *http.Request, sess *sessions.Session, postLoginRedirectURL string) {
	state, nonce := uniuri.New(), uniuri.New()

	sess.Values[sessionKeyLoginState] = state
	sess.Values[sessionKeyLoginNonce] = nonce
	sess.Values[sessionKeyPostLoginRedirectURL] = postLoginRedirectURL

	if err := sess.Save(r, w); err != nil {
		logger.Error(r.Context(), "could not save session", logger.CapturedE(errors.WithStack(err)))
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
		return
	}

	// The nonce goes into the authorization request; any configured extra
	// parameters are appended as-is.
	opts := make([]oauth2.AuthCodeOption, 0, len(c.authParams)+1)
	opts = append(opts, oidc.Nonce(nonce))
	for key, val := range c.authParams {
		opts = append(opts, oauth2.SetAuthURLParam(key, val))
	}

	http.Redirect(w, r, c.oauth2.AuthCodeURL(state, opts...), http.StatusFound)
}
// HandleCallback completes the authorization code flow on the redirect URI.
// It validates the callback, stores the resulting tokens in sess, clears the
// single-use login values, saves the session and redirects the user agent to
// the URL recorded when the login started.
//
// It returns an error if validation fails, the session cannot be saved, or
// no valid post-login redirect URL was recorded.
func (c *Client) HandleCallback(w http.ResponseWriter, r *http.Request, sess *sessions.Session) error {
	token, _, rawIDToken, err := c.validate(r, sess)
	if err != nil {
		return errors.Wrap(err, "could not validate oidc token")
	}

	sess.Values[sessionKeyIDToken] = rawIDToken
	sess.Values[sessionKeyAccessToken] = token.AccessToken
	sess.Values[sessionKeyRefreshToken] = token.RefreshToken
	sess.Values[sessionKeyTokenExpiry] = token.Expiry.UTC().Unix()

	rawPostLoginRedirectURL, exists := sess.Values[sessionKeyPostLoginRedirectURL]

	// State, nonce and destination belong to a single login attempt; dropping
	// them prevents the same callback from being accepted a second time.
	delete(sess.Values, sessionKeyLoginState)
	delete(sess.Values, sessionKeyLoginNonce)
	delete(sess.Values, sessionKeyPostLoginRedirectURL)

	if err := sess.Save(r, w); err != nil {
		return errors.WithStack(err)
	}

	// There is no underlying error to wrap here: errors.Wrap(nil, ...) would
	// return nil and silently report success without redirecting.
	if !exists {
		return errors.New("could not find post login redirect url")
	}

	postLoginRedirectURL, ok := rawPostLoginRedirectURL.(string)
	if !ok {
		return errors.Errorf("unexpected value '%v' for post login redirect url", rawPostLoginRedirectURL)
	}

	http.Redirect(w, r, postLoginRedirectURL, http.StatusTemporaryRedirect)
	return nil
}
// HandleLogout clears the token values from sess, expires the session
// (MaxAge -1) and saves it. It then redirects the user agent to the
// provider's end-session URL when an ID token was present and the provider
// advertises one; otherwise it redirects straight to postLogoutRedirectURL.
func (c *Client) HandleLogout(w http.ResponseWriter, r *http.Request, sess *sessions.Session, postLogoutRedirectURL string) error {
	ctx := r.Context()

	logoutState := uniuri.New()
	sess.Values[sessionKeyLoginState] = logoutState

	// A missing ID token is logged but not fatal: the local session is still torn down.
	rawIDToken, err := c.getRawIDToken(sess)
	if err != nil {
		logger.Error(ctx, "could not retrieve raw id token", logger.CapturedE(errors.WithStack(err)))
	}

	for _, key := range []string{
		sessionKeyIDToken,
		sessionKeyAccessToken,
		sessionKeyRefreshToken,
		sessionKeyTokenExpiry,
	} {
		sess.Values[key] = nil
	}
	sess.Options.MaxAge = -1

	if err := sess.Save(r, w); err != nil {
		return errors.Wrap(err, "could not save session")
	}

	target := postLogoutRedirectURL
	if rawIDToken != "" {
		endURL, err := c.sessionEndURL(rawIDToken, logoutState, postLogoutRedirectURL)
		if err != nil {
			return errors.Wrap(err, "could not retrieve session end url")
		}
		if endURL != "" {
			target = endURL
		}
	}

	http.Redirect(w, r, target, http.StatusFound)
	return nil
}
// sessionEndURL builds the logout URL from the provider's
// end_session_endpoint metadata. It adds id_token_hint, state and
// post_logout_redirect_uri for each non-empty argument.
//
// It returns "" (and no error) when the provider advertises no such endpoint.
func (c *Client) sessionEndURL(idTokenHint, state, postLogoutRedirectURL string) (string, error) {
	var metadata struct {
		URL string `json:"end_session_endpoint"`
	}
	if err := c.provider.Claims(&metadata); err != nil {
		return "", errors.Wrap(err, "could not unmarshal claims")
	}
	if metadata.URL == "" {
		return "", nil
	}

	params := url.Values{}
	if idTokenHint != "" {
		params.Set("id_token_hint", idTokenHint)
	}
	if postLogoutRedirectURL != "" {
		params.Set("post_logout_redirect_uri", postLogoutRedirectURL)
	}
	if state != "" {
		params.Set("state", state)
	}
	if len(params) == 0 {
		// Nothing to append: avoid leaving a dangling "?" on the endpoint.
		return metadata.URL, nil
	}

	var buf bytes.Buffer
	buf.WriteString(metadata.URL)
	switch {
	case strings.HasSuffix(metadata.URL, "?"), strings.HasSuffix(metadata.URL, "&"):
		// The endpoint already ends with a query separator.
	case strings.Contains(metadata.URL, "?"):
		buf.WriteByte('&')
	default:
		buf.WriteByte('?')
	}
	buf.WriteString(params.Encode())

	return buf.String(), nil
}
// validate checks an authorization callback against the login recorded in
// sess. The callback state must match the stored state. The authorization
// code is exchanged for tokens, and the returned ID token is verified and
// must carry the nonce that login stored.
//
// It returns the OAuth2 token, the verified ID token and its raw string form.
func (c *Client) validate(r *http.Request, sess *sessions.Session) (*oauth2.Token, *oidc.IDToken, string, error) {
	ctx := r.Context()
	if c.httpClient != nil {
		// Route the token exchange through the configured HTTP client. A nil
		// client is not installed: the oauth2 package would try to use it
		// instead of falling back to its default.
		ctx = oidc.ClientContext(ctx, c.httpClient)
	}

	query := r.URL.Query()

	storedState, ok := sess.Values[sessionKeyLoginState].(string)
	if !ok || storedState == "" {
		return nil, nil, "", errors.New("could not find state in session")
	}
	if query.Get("state") != storedState {
		return nil, nil, "", errors.New("state mismatch")
	}

	// A failed authorization comes back with an "error" parameter instead of
	// a code (RFC 6749 §4.1.2.1). Report it rather than exchanging an empty code.
	if errCode := query.Get("error"); errCode != "" {
		return nil, nil, "", errors.Errorf("authorization failed: %q (%q)", errCode, query.Get("error_description"))
	}

	storedNonce, ok := sess.Values[sessionKeyLoginNonce].(string)
	if !ok || storedNonce == "" {
		return nil, nil, "", errors.New("could not find nonce in session")
	}

	code := query.Get("code")
	if code == "" {
		return nil, nil, "", errors.New("could not find authorization code")
	}

	token, err := c.oauth2.Exchange(ctx, code)
	if err != nil {
		return nil, nil, "", errors.Wrap(err, "could not exchange token")
	}

	rawIDToken, ok := token.Extra("id_token").(string)
	if !ok {
		return nil, nil, "", errors.New("could not find id token")
	}

	idToken, err := c.verifier.Verify(ctx, rawIDToken)
	if err != nil {
		return nil, nil, "", errors.Wrap(err, "could not verify id token")
	}

	// go-oidc does not check the nonce; the caller must. Comparing it with the
	// value login sent ties this ID token to the current login attempt and
	// rejects replayed tokens.
	if idToken.Nonce != storedNonce {
		return nil, nil, "", errors.New("nonce mismatch")
	}

	return token, idToken, rawIDToken, nil
}
// getRawIDToken returns the raw ID token string stored in sess. It returns
// an error when the value is absent, not a string, or empty.
func (c *Client) getRawIDToken(sess *sessions.Session) (string, error) {
	if raw, ok := sess.Values[sessionKeyIDToken].(string); ok && raw != "" {
		return raw, nil
	}
	return "", errors.New("id token not found")
}
// getIDToken loads the raw ID token from sess and checks it with the
// client's verifier, returning the parsed token on success.
func (c *Client) getIDToken(r *http.Request, sess *sessions.Session) (*oidc.IDToken, error) {
	raw, err := c.getRawIDToken(sess)
	if err != nil {
		return nil, errors.Wrap(err, "could not retrieve raw idtoken")
	}

	verified, verifyErr := c.verifier.Verify(r.Context(), raw)
	if verifyErr != nil {
		return nil, errors.Wrap(verifyErr, "could not verify id token")
	}
	return verified, nil
}
// NewClient builds a Client from the given option functions. The OAuth2
// endpoints and the ID token verifier are both derived from the resulting
// Provider option, so NewClientOptions must yield a non-nil Provider.
func NewClient(funcs ...ClientOptionFunc) *Client {
	opts := NewClientOptions(funcs...)

	// Named oauth2Config (not oauth2) so the local does not shadow the
	// golang.org/x/oauth2 package within this function.
	oauth2Config := &oauth2.Config{
		ClientID:     opts.ClientID,
		ClientSecret: opts.ClientSecret,
		Endpoint:     opts.Provider.Endpoint(),
		RedirectURL:  opts.RedirectURL,
		Scopes:       opts.Scopes,
	}

	verifier := opts.Provider.Verifier(&oidc.Config{
		ClientID:        opts.ClientID,
		SkipIssuerCheck: opts.SkipIssuerCheck,
	})

	return &Client{
		oauth2:     oauth2Config,
		provider:   opts.Provider,
		verifier:   verifier,
		authParams: opts.AuthParams,
		httpClient: opts.HTTPClient,
	}
}