286 lines
7.3 KiB
Go
286 lines
7.3 KiB
Go
package oidc
|
|
|
|
import (
|
|
"bytes"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
|
|
"github.com/coreos/go-oidc/v3/oidc"
|
|
"github.com/dchest/uniuri"
|
|
"github.com/gorilla/sessions"
|
|
"github.com/pkg/errors"
|
|
"gitlab.com/wpetit/goweb/logger"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
const (
|
|
sessionKeyAccessToken = "access-token"
|
|
sessionKeyRefreshToken = "refresh-token"
|
|
sessionKeyTokenExpiry = "token-expiry"
|
|
sessionKeyIDToken = "id-token"
|
|
sessionKeyPostLoginRedirectURL = "post-login-redirect-url"
|
|
sessionKeyLoginState = "login-state"
|
|
sessionKeyLoginNonce = "login-nonce"
|
|
)
|
|
|
|
var (
|
|
ErrLoginRequired = errors.New("login required")
|
|
)
|
|
|
|
type Client struct {
|
|
httpClient *http.Client
|
|
oauth2 *oauth2.Config
|
|
provider *oidc.Provider
|
|
verifier *oidc.IDTokenVerifier
|
|
authParams map[string]string
|
|
}
|
|
|
|
func (c *Client) Verifier() *oidc.IDTokenVerifier {
|
|
return c.verifier
|
|
}
|
|
|
|
func (c *Client) Provider() *oidc.Provider {
|
|
return c.provider
|
|
}
|
|
|
|
// Authenticate returns the verified ID token held in sess.
//
// If the session has no ID token, or the stored one fails verification,
// Authenticate starts a new login flow. That flow writes the response to w
// (normally a redirect to the provider). Authenticate then returns
// ErrLoginRequired wrapped with a stack trace, and callers must not write to
// w after receiving that error.
//
// postLoginRedirectURL is saved in the session so that HandleCallback can send
// the user back there once login completes.
func (c *Client) Authenticate(w http.ResponseWriter, r *http.Request, sess *sessions.Session, postLoginRedirectURL string) (*oidc.IDToken, error) {
	idToken, err := c.getIDToken(r, sess)
	if err == nil {
		return idToken, nil
	}

	logger.Warn(r.Context(), "could not retrieve idtoken", logger.CapturedE(errors.WithStack(err)))

	c.login(w, r, sess, postLoginRedirectURL)

	return nil, errors.WithStack(ErrLoginRequired)
}
|
|
|
|
// login starts the authorization-code flow.
//
// It stores a fresh random state, a fresh nonce and postLoginRedirectURL in the
// session, then redirects the browser (302) to the provider's authorization
// endpoint. If the session cannot be saved, it responds with
// 500 Internal Server Error instead.
func (c *Client) login(w http.ResponseWriter, r *http.Request, sess *sessions.Session, postLoginRedirectURL string) {
	state, nonce := uniuri.New(), uniuri.New()

	sess.Values[sessionKeyLoginState] = state
	sess.Values[sessionKeyLoginNonce] = nonce
	sess.Values[sessionKeyPostLoginRedirectURL] = postLoginRedirectURL

	if err := sess.Save(r, w); err != nil {
		logger.Error(r.Context(), "could not save session", logger.CapturedE(errors.WithStack(err)))
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)

		return
	}

	// Options are applied in order: the nonce first, then the configured
	// extra authorization parameters.
	opts := make([]oauth2.AuthCodeOption, 0, len(c.authParams)+1)
	opts = append(opts, oidc.Nonce(nonce))

	for name, value := range c.authParams {
		opts = append(opts, oauth2.SetAuthURLParam(name, value))
	}

	http.Redirect(w, r, c.oauth2.AuthCodeURL(state, opts...), http.StatusFound)
}
|
|
|
|
// HandleCallback completes the login flow started by Authenticate.
//
// It validates the callback request (see validate) and stores the ID, access
// and refresh tokens plus the token expiry (Unix seconds, UTC) in the session.
// It then discards the single-use login state and nonce, saves the session,
// and redirects (307) to the URL recorded when the login began.
//
// An error is returned, and no redirect is issued, when validation fails,
// when the session cannot be saved, or when the session holds no string
// post-login redirect URL.
func (c *Client) HandleCallback(w http.ResponseWriter, r *http.Request, sess *sessions.Session) error {
	token, _, rawIDToken, err := c.validate(r, sess)
	if err != nil {
		return errors.Wrap(err, "could not validate oidc token")
	}

	sess.Values[sessionKeyIDToken] = rawIDToken
	sess.Values[sessionKeyAccessToken] = token.AccessToken
	sess.Values[sessionKeyRefreshToken] = token.RefreshToken
	sess.Values[sessionKeyTokenExpiry] = token.Expiry.UTC().Unix()

	// State and nonce are single-use: drop them once consumed so that they
	// cannot be replayed against a later callback.
	delete(sess.Values, sessionKeyLoginState)
	delete(sess.Values, sessionKeyLoginNonce)

	if err := sess.Save(r, w); err != nil {
		return errors.WithStack(err)
	}

	// There is no underlying error to wrap in the two checks below. Note that
	// errors.Wrap(nil, ...) returns nil and would turn these failures into a
	// silent success with no response written.
	rawPostLoginRedirectURL, exists := sess.Values[sessionKeyPostLoginRedirectURL]
	if !exists {
		return errors.New("could not find post login redirect url")
	}

	postLoginRedirectURL, ok := rawPostLoginRedirectURL.(string)
	if !ok {
		return errors.Errorf("unexpected value '%v' for post login redirect url", rawPostLoginRedirectURL)
	}

	http.Redirect(w, r, postLoginRedirectURL, http.StatusTemporaryRedirect)

	return nil
}
|
|
|
|
// HandleLogout logs the user out.
//
// It clears the tokens stored in sess, expires the session cookie (MaxAge -1)
// and issues a 302 redirect. When the session held an ID token and the provider
// advertises an end_session_endpoint, the redirect goes to that endpoint with
// id_token_hint, state and post_logout_redirect_uri. Otherwise it goes
// straight to postLogoutRedirectURL.
func (c *Client) HandleLogout(w http.ResponseWriter, r *http.Request, sess *sessions.Session, postLogoutRedirectURL string) error {
	state := uniuri.New()
	sess.Values[sessionKeyLoginState] = state

	rawIDToken, err := c.getRawIDToken(sess)
	if err != nil {
		// Not fatal: without an ID token the provider-side logout is skipped.
		logger.Error(r.Context(), "could not retrieve raw id token", logger.CapturedE(errors.WithStack(err)))
	}

	for _, key := range []string{sessionKeyIDToken, sessionKeyAccessToken, sessionKeyRefreshToken, sessionKeyTokenExpiry} {
		sess.Values[key] = nil
	}

	sess.Options.MaxAge = -1

	if err := sess.Save(r, w); err != nil {
		return errors.Wrap(err, "could not save session")
	}

	target := postLogoutRedirectURL

	if rawIDToken != "" {
		endURL, err := c.sessionEndURL(rawIDToken, state, postLogoutRedirectURL)
		if err != nil {
			return errors.Wrap(err, "could not retrieve session end url")
		}

		if endURL != "" {
			target = endURL
		}
	}

	http.Redirect(w, r, target, http.StatusFound)

	return nil
}
|
|
|
|
// sessionEndURL builds the provider logout URL from the end_session_endpoint
// in the provider's discovery metadata.
//
// It appends id_token_hint, state and post_logout_redirect_uri for each
// argument that is non-empty. It returns "" with a nil error when the provider
// does not advertise an end_session_endpoint.
func (c *Client) sessionEndURL(idTokenHint, state, postLogoutRedirectURL string) (string, error) {
	var metadata struct {
		EndSessionEndpoint string `json:"end_session_endpoint"`
	}

	// Claims decodes the raw discovery document into the given pointer.
	if err := c.provider.Claims(&metadata); err != nil {
		return "", errors.Wrap(err, "could not unmarshal claims")
	}

	if metadata.EndSessionEndpoint == "" {
		return "", nil
	}

	params := url.Values{}

	if idTokenHint != "" {
		params.Set("id_token_hint", idTokenHint)
	}

	if postLogoutRedirectURL != "" {
		params.Set("post_logout_redirect_uri", postLogoutRedirectURL)
	}

	if state != "" {
		params.Set("state", state)
	}

	var buf bytes.Buffer
	buf.WriteString(metadata.EndSessionEndpoint)

	// Only add a separator when there is something to append, so the
	// endpoint is never left with a dangling "?" or "&".
	if len(params) > 0 {
		if strings.Contains(metadata.EndSessionEndpoint, "?") {
			buf.WriteByte('&')
		} else {
			buf.WriteByte('?')
		}

		buf.WriteString(params.Encode())
	}

	return buf.String(), nil
}
|
|
|
|
// validate checks an authorization-code callback request against the login
// flow recorded in sess and exchanges the code for tokens.
//
// It checks, in order, that:
//   - the "state" query parameter matches the state stored by login;
//   - the session still holds the nonce stored by login;
//   - the provider did not report an authorization error ("error" parameter);
//   - a non-empty "code" parameter is present and can be exchanged;
//   - the token response carries an id_token that passes c.verifier;
//   - the ID token's nonce claim equals the stored nonce.
//
// On success it returns the OAuth2 token, the parsed ID token and the raw ID
// token string.
func (c *Client) validate(r *http.Request, sess *sessions.Session) (*oauth2.Token, *oidc.IDToken, string, error) {
	ctx := r.Context()

	// Only override the HTTP client when one is configured. A nil
	// *http.Client stored in the context would be used as-is by oauth2
	// instead of falling back to the default client.
	if c.httpClient != nil {
		ctx = oidc.ClientContext(ctx, c.httpClient)
	}

	query := r.URL.Query()

	storedState, ok := sess.Values[sessionKeyLoginState].(string)
	if !ok {
		return nil, nil, "", errors.New("could not find state in session")
	}

	if query.Get("state") != storedState {
		return nil, nil, "", errors.New("state mismatch")
	}

	// Check the nonce is available before spending the one-time code.
	storedNonce, ok := sess.Values[sessionKeyLoginNonce].(string)
	if !ok || storedNonce == "" {
		return nil, nil, "", errors.New("could not find nonce in session")
	}

	// A refused or failed authorization is reported through the "error"
	// parameter (RFC 6749 section 4.1.2.1) instead of a code.
	if errCode := query.Get("error"); errCode != "" {
		return nil, nil, "", errors.Errorf("authorization failed: %q (%q)", errCode, query.Get("error_description"))
	}

	code := query.Get("code")
	if code == "" {
		return nil, nil, "", errors.New("missing authorization code")
	}

	token, err := c.oauth2.Exchange(ctx, code)
	if err != nil {
		return nil, nil, "", errors.Wrap(err, "could not exchange token")
	}

	rawIDToken, ok := token.Extra("id_token").(string)
	if !ok {
		return nil, nil, "", errors.New("could not find id token")
	}

	idToken, err := c.verifier.Verify(ctx, rawIDToken)
	if err != nil {
		return nil, nil, "", errors.Wrap(err, "could not verify id token")
	}

	// go-oidc does not verify the nonce claim; the caller must. Without this
	// check, the nonce sent by login provides no replay protection.
	if idToken.Nonce != storedNonce {
		return nil, nil, "", errors.New("nonce mismatch")
	}

	return token, idToken, rawIDToken, nil
}
|
|
|
|
func (c *Client) getRawIDToken(sess *sessions.Session) (string, error) {
|
|
rawIDToken, ok := sess.Values[sessionKeyIDToken].(string)
|
|
if !ok || rawIDToken == "" {
|
|
return "", errors.New("id token not found")
|
|
}
|
|
|
|
return rawIDToken, nil
|
|
}
|
|
|
|
func (c *Client) getIDToken(r *http.Request, sess *sessions.Session) (*oidc.IDToken, error) {
|
|
rawIDToken, err := c.getRawIDToken(sess)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "could not retrieve raw idtoken")
|
|
}
|
|
|
|
idToken, err := c.verifier.Verify(r.Context(), rawIDToken)
|
|
if err != nil {
|
|
return nil, errors.Wrap(err, "could not verify id token")
|
|
}
|
|
|
|
return idToken, nil
|
|
}
|
|
|
|
func NewClient(funcs ...ClientOptionFunc) *Client {
|
|
opts := NewClientOptions(funcs...)
|
|
|
|
oauth2 := &oauth2.Config{
|
|
ClientID: opts.ClientID,
|
|
ClientSecret: opts.ClientSecret,
|
|
Endpoint: opts.Provider.Endpoint(),
|
|
RedirectURL: opts.RedirectURL,
|
|
Scopes: opts.Scopes,
|
|
}
|
|
|
|
verifier := opts.Provider.Verifier(&oidc.Config{
|
|
ClientID: opts.ClientID,
|
|
SkipIssuerCheck: opts.SkipIssuerCheck,
|
|
})
|
|
|
|
return &Client{
|
|
oauth2: oauth2,
|
|
provider: opts.Provider,
|
|
verifier: verifier,
|
|
authParams: opts.AuthParams,
|
|
httpClient: opts.HTTPClient,
|
|
}
|
|
}
|