package oidc

import (
	"bytes"
	"net/http"
	"net/url"
	"strings"

	"github.com/coreos/go-oidc/v3/oidc"
	"github.com/dchest/uniuri"
	"github.com/gorilla/sessions"
	"github.com/pkg/errors"
	"gitlab.com/wpetit/goweb/logger"
	"golang.org/x/oauth2"
)

// Keys under which the client stores its data in the session values.
const (
	sessionKeyAccessToken          = "access-token"
	sessionKeyRefreshToken         = "refresh-token"
	sessionKeyTokenExpiry          = "token-expiry"
	sessionKeyIDToken              = "id-token"
	sessionKeyPostLoginRedirectURL = "post-login-redirect-url"
	sessionKeyLoginState           = "login-state"
	sessionKeyLoginNonce           = "login-nonce"
)

var (
	// ErrLoginRequired is returned by Authenticate when no valid ID token
	// could be found in the session and a login redirect has been issued.
	ErrLoginRequired = errors.New("login required")
)

// Client drives the authorization code flow against an OpenID Connect
// provider and keeps the resulting tokens in a gorilla session.
type Client struct {
	httpClient *http.Client
	oauth2     *oauth2.Config
	provider   *oidc.Provider
	verifier   *oidc.IDTokenVerifier
	authParams map[string]string
}

// Verifier returns the ID token verifier used by the client.
func (c *Client) Verifier() *oidc.IDTokenVerifier {
	return c.verifier
}

// Provider returns the OpenID Connect provider used by the client.
func (c *Client) Provider() *oidc.Provider {
	return c.provider
}

// Authenticate returns the verified ID token found in the session.
//
// When the token is missing or fails verification, the cause is logged, a
// login redirect is written to w (see login) and ErrLoginRequired is
// returned; the caller must not write anything else to w in that case.
func (c *Client) Authenticate(w http.ResponseWriter, r *http.Request, sess *sessions.Session, postLoginRedirectURL string) (*oidc.IDToken, error) {
	idToken, err := c.getIDToken(r, sess)
	if err != nil {
		logger.Warn(r.Context(), "could not retrieve idtoken", logger.CapturedE(errors.WithStack(err)))

		c.login(w, r, sess, postLoginRedirectURL)

		return nil, errors.WithStack(ErrLoginRequired)
	}

	return idToken, nil
}

// login stores a fresh state, a fresh nonce and the post-login redirect URL
// in the session, then redirects the user agent to the provider's
// authorization endpoint. If the session cannot be saved, a 500 response is
// written instead.
func (c *Client) login(w http.ResponseWriter, r *http.Request, sess *sessions.Session, postLoginRedirectURL string) {
	ctx := r.Context()

	state := uniuri.New()
	nonce := uniuri.New()

	sess.Values[sessionKeyLoginState] = state
	sess.Values[sessionKeyLoginNonce] = nonce
	sess.Values[sessionKeyPostLoginRedirectURL] = postLoginRedirectURL

	if err := sess.Save(r, w); err != nil {
		logger.Error(ctx, "could not save session", logger.CapturedE(errors.WithStack(err)))
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)

		return
	}

	// One slot for the nonce plus one per extra authorization parameter.
	authCodeOptions := make([]oauth2.AuthCodeOption, 0, len(c.authParams)+1)
	authCodeOptions = append(authCodeOptions, oidc.Nonce(nonce))

	for key, val := range c.authParams {
		authCodeOptions = append(authCodeOptions, oauth2.SetAuthURLParam(key, val))
	}

	authCodeURL := c.oauth2.AuthCodeURL(state, authCodeOptions...)

	http.Redirect(w, r, authCodeURL, http.StatusFound)
}

// HandleCallback processes the provider's redirect back to the application.
//
// It validates the request (see validate), stores the raw ID token, access
// token, refresh token and token expiry (Unix seconds, UTC) in the session,
// saves the session and redirects to the post-login URL recorded by login.
// An error is returned, and no redirect is written, when validation fails,
// the session cannot be saved or the post-login URL is missing or is not a
// string.
func (c *Client) HandleCallback(w http.ResponseWriter, r *http.Request, sess *sessions.Session) error {
	token, _, rawIDToken, err := c.validate(r, sess)
	if err != nil {
		return errors.Wrap(err, "could not validate oidc token")
	}

	sess.Values[sessionKeyIDToken] = rawIDToken
	sess.Values[sessionKeyAccessToken] = token.AccessToken
	sess.Values[sessionKeyRefreshToken] = token.RefreshToken
	sess.Values[sessionKeyTokenExpiry] = token.Expiry.UTC().Unix()

	if err := sess.Save(r, w); err != nil {
		return errors.WithStack(err)
	}

	// There is no underlying error to wrap below: errors.Wrap on a nil error
	// returns nil, which would silently report success without redirecting.
	rawPostLoginRedirectURL, exists := sess.Values[sessionKeyPostLoginRedirectURL]
	if !exists {
		return errors.New("could not find post login redirect url")
	}

	postLoginRedirectURL, ok := rawPostLoginRedirectURL.(string)
	if !ok {
		return errors.Errorf("unexpected value '%v' for post login redirect url", rawPostLoginRedirectURL)
	}

	http.Redirect(w, r, postLoginRedirectURL, http.StatusTemporaryRedirect)

	return nil
}

// HandleLogout clears the tokens from the session, expires the session
// (MaxAge = -1) and redirects the user agent.
//
// When a raw ID token was present in the session and the provider advertises
// an end-session endpoint, the redirect targets that endpoint (see
// sessionEndURL); otherwise it targets postLogoutRedirectURL directly.
func (c *Client) HandleLogout(w http.ResponseWriter, r *http.Request, sess *sessions.Session, postLogoutRedirectURL string) error {
	state := uniuri.New()
	sess.Values[sessionKeyLoginState] = state

	ctx := r.Context()

	// A missing ID token is not fatal: the local session is still cleared,
	// only the provider-side logout is skipped.
	rawIDToken, err := c.getRawIDToken(sess)
	if err != nil {
		logger.Error(ctx, "could not retrieve raw id token", logger.CapturedE(errors.WithStack(err)))
	}

	sess.Values[sessionKeyIDToken] = nil
	sess.Values[sessionKeyAccessToken] = nil
	sess.Values[sessionKeyRefreshToken] = nil
	sess.Values[sessionKeyTokenExpiry] = nil
	sess.Options.MaxAge = -1

	if err := sess.Save(r, w); err != nil {
		return errors.Wrap(err, "could not save session")
	}

	if rawIDToken == "" {
		http.Redirect(w, r, postLogoutRedirectURL, http.StatusFound)

		return nil
	}

	sessionEndURL, err := c.sessionEndURL(rawIDToken, state, postLogoutRedirectURL)
	if err != nil {
		return errors.Wrap(err, "could not retrieve session end url")
	}

	if sessionEndURL == "" {
		http.Redirect(w, r, postLogoutRedirectURL, http.StatusFound)

		return nil
	}

	http.Redirect(w, r, sessionEndURL, http.StatusFound)

	return nil
}

// sessionEndURL builds the provider logout URL from the
// "end_session_endpoint" provider claim. Non-empty arguments are added as
// the "id_token_hint", "state" and "post_logout_redirect_uri" query
// parameters. It returns an empty string when the provider does not expose
// such an endpoint.
func (c *Client) sessionEndURL(idTokenHint, state, postLogoutRedirectURL string) (string, error) {
	var sessionEndEndpoint struct {
		URL string `json:"end_session_endpoint"`
	}

	if err := c.provider.Claims(&sessionEndEndpoint); err != nil {
		return "", errors.Wrap(err, "could not unmarshal claims")
	}

	if sessionEndEndpoint.URL == "" {
		return "", nil
	}

	v := url.Values{}

	if idTokenHint != "" {
		v.Set("id_token_hint", idTokenHint)
	}

	if postLogoutRedirectURL != "" {
		v.Set("post_logout_redirect_uri", postLogoutRedirectURL)
	}

	if state != "" {
		v.Set("state", state)
	}

	var buf bytes.Buffer

	buf.WriteString(sessionEndEndpoint.URL)

	// The advertised endpoint may already carry a query string.
	if strings.Contains(sessionEndEndpoint.URL, "?") {
		buf.WriteByte('&')
	} else {
		buf.WriteByte('?')
	}

	buf.WriteString(v.Encode())

	return buf.String(), nil
}

// validate checks the callback request against the session and exchanges the
// authorization code.
//
// The "state" query parameter must equal the state stored by login; the
// "code" query parameter is exchanged for a token whose "id_token" extra is
// verified with the client's verifier. When the session holds a login nonce,
// the ID token's nonce must match it. It returns the OAuth2 token, the
// verified ID token and its raw form.
func (c *Client) validate(r *http.Request, sess *sessions.Session) (*oauth2.Token, *oidc.IDToken, string, error) {
	ctx := r.Context()
	ctx = oidc.ClientContext(ctx, c.httpClient)

	rawStoredState := sess.Values[sessionKeyLoginState]
	receivedState := r.URL.Query().Get("state")

	storedState, ok := rawStoredState.(string)
	if !ok {
		return nil, nil, "", errors.New("could not find state in session")
	}

	if receivedState != storedState {
		return nil, nil, "", errors.New("state mismatch")
	}

	code := r.URL.Query().Get("code")

	token, err := c.oauth2.Exchange(ctx, code)
	if err != nil {
		return nil, nil, "", errors.Wrap(err, "could not exchange token")
	}

	rawIDToken, ok := token.Extra("id_token").(string)
	if !ok {
		return nil, nil, "", errors.New("could not find id token")
	}

	idToken, err := c.verifier.Verify(ctx, rawIDToken)
	if err != nil {
		return nil, nil, "", errors.Wrap(err, "could not verify id token")
	}

	// The verifier does not check the nonce claim, so compare it here with
	// the value login sent in the authorization request.
	if storedNonce, ok := sess.Values[sessionKeyLoginNonce].(string); ok && storedNonce != "" {
		if idToken.Nonce != storedNonce {
			return nil, nil, "", errors.New("nonce mismatch")
		}
	}

	return token, idToken, rawIDToken, nil
}

// getRawIDToken returns the raw ID token stored in the session, or an error
// when it is absent, empty or not a string.
func (c *Client) getRawIDToken(sess *sessions.Session) (string, error) {
	rawIDToken, ok := sess.Values[sessionKeyIDToken].(string)
	if !ok || rawIDToken == "" {
		return "", errors.New("id token not found")
	}

	return rawIDToken, nil
}

// getIDToken reads the raw ID token from the session and verifies it with
// the client's verifier.
func (c *Client) getIDToken(r *http.Request, sess *sessions.Session) (*oidc.IDToken, error) {
	rawIDToken, err := c.getRawIDToken(sess)
	if err != nil {
		return nil, errors.Wrap(err, "could not retrieve raw idtoken")
	}

	idToken, err := c.verifier.Verify(r.Context(), rawIDToken)
	if err != nil {
		return nil, errors.Wrap(err, "could not verify id token")
	}

	return idToken, nil
}

// NewClient builds a Client from the given option functions. The OAuth2
// configuration and the ID token verifier are both derived from the
// provider carried by the resolved options.
func NewClient(funcs ...ClientOptionFunc) *Client {
	opts := NewClientOptions(funcs...)

	config := &oauth2.Config{
		ClientID:     opts.ClientID,
		ClientSecret: opts.ClientSecret,
		Endpoint:     opts.Provider.Endpoint(),
		RedirectURL:  opts.RedirectURL,
		Scopes:       opts.Scopes,
	}

	verifier := opts.Provider.Verifier(&oidc.Config{
		ClientID:        opts.ClientID,
		SkipIssuerCheck: opts.SkipIssuerCheck,
	})

	return &Client{
		oauth2:     config,
		provider:   opts.Provider,
		verifier:   verifier,
		authParams: opts.AuthParams,
		httpClient: opts.HTTPClient,
	}
}