goweb-oidc/client.go

198 lines
4.5 KiB
Go
Raw Normal View History

2020-05-20 10:43:12 +02:00
package oidc
import (
2020-05-20 13:06:04 +02:00
"bytes"
2020-05-20 10:43:12 +02:00
"net/http"
2020-05-20 13:06:04 +02:00
"net/url"
"strings"
2020-05-20 10:43:12 +02:00
2023-11-02 18:21:54 +01:00
"github.com/coreos/go-oidc/v3/oidc"
2020-05-20 10:43:12 +02:00
"github.com/dchest/uniuri"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/middleware/container"
"gitlab.com/wpetit/goweb/service/session"
"golang.org/x/oauth2"
)
type Client struct {
2023-11-02 18:21:54 +01:00
oauth2 *oauth2.Config
provider *oidc.Provider
verifier *oidc.IDTokenVerifier
acrValues string
2020-05-20 10:43:12 +02:00
}
// Verifier returns the ID token verifier the client uses to check tokens
// returned by its provider.
func (c *Client) Verifier() *oidc.IDTokenVerifier { return c.verifier }
// Provider returns the OpenID Connect provider the client was built with.
func (c *Client) Provider() *oidc.Provider { return c.provider }
func (c *Client) Login(w http.ResponseWriter, r *http.Request) {
ctn := container.Must(r.Context())
sess, err := session.Must(ctn).Get(w, r)
if err != nil {
panic(errors.Wrap(err, "could not retrieve session"))
}
state := uniuri.New()
2023-03-02 15:07:02 +01:00
nonce := uniuri.New()
2020-05-20 10:43:12 +02:00
sess.Set(SessionOIDCStateKey, state)
2023-03-02 15:07:02 +01:00
sess.Set(SessionOIDCNonceKey, nonce)
2020-05-20 10:43:12 +02:00
if err := sess.Save(w, r); err != nil {
panic(errors.Wrap(err, "could not save session"))
}
2020-05-20 13:06:04 +02:00
authCodeOptions := []oauth2.AuthCodeOption{}
2023-03-02 15:07:02 +01:00
authCodeOptions = append(authCodeOptions, oidc.Nonce(nonce))
2023-11-02 18:21:54 +01:00
if c.acrValues != "" {
authCodeOptions = append(authCodeOptions, oauth2.SetAuthURLParam("acr_values", c.acrValues))
}
2020-05-20 13:06:04 +02:00
authCodeURL := c.oauth2.AuthCodeURL(
state,
authCodeOptions...,
)
http.Redirect(w, r, authCodeURL, http.StatusFound)
2020-05-20 10:43:12 +02:00
}
2020-05-20 13:06:04 +02:00
func (c *Client) Logout(w http.ResponseWriter, r *http.Request, postLogoutRedirectURL string) {
rawIDToken, err := RawIDToken(w, r)
if err != nil {
panic(errors.Wrap(err, "could not retrieve raw id token"))
}
2020-05-20 10:43:12 +02:00
ctn := container.Must(r.Context())
sess, err := session.Must(ctn).Get(w, r)
if err != nil {
panic(errors.Wrap(err, "could not retrieve session"))
}
state := uniuri.New()
sess.Set(SessionOIDCStateKey, state)
2020-05-26 11:17:16 +02:00
sess.Unset(SessionIDTokenKey)
2020-05-20 10:43:12 +02:00
if err := sess.Save(w, r); err != nil {
panic(errors.Wrap(err, "could not save session"))
}
2020-05-20 13:06:04 +02:00
sessionEndURL, err := c.sessionEndURL(rawIDToken, state, postLogoutRedirectURL)
if err != nil {
panic(errors.Wrap(err, "could not retrieve session end url"))
}
if sessionEndURL != "" {
http.Redirect(w, r, sessionEndURL, http.StatusFound)
} else {
http.Redirect(w, r, postLogoutRedirectURL, http.StatusFound)
}
}
func (c *Client) sessionEndURL(idTokenHint, state, postLogoutRedirectURL string) (string, error) {
sessionEndEndpoint := &struct {
URL string `json:"end_session_endpoint"`
}{}
if err := c.provider.Claims(&sessionEndEndpoint); err != nil {
return "", errors.Wrap(err, "could not unmarshal claims")
}
if sessionEndEndpoint.URL == "" {
return "", nil
}
var buf bytes.Buffer
buf.WriteString(sessionEndEndpoint.URL)
v := url.Values{}
if idTokenHint != "" {
v.Set("id_token_hint", idTokenHint)
}
if postLogoutRedirectURL != "" {
v.Set("post_logout_redirect_uri", postLogoutRedirectURL)
}
if state != "" {
v.Set("state", state)
}
if strings.Contains(sessionEndEndpoint.URL, "?") {
buf.WriteByte('&')
} else {
buf.WriteByte('?')
}
buf.WriteString(v.Encode())
return buf.String(), nil
2020-05-20 10:43:12 +02:00
}
2020-05-20 13:06:04 +02:00
func (c *Client) Validate(w http.ResponseWriter, r *http.Request) (*oidc.IDToken, string, error) {
2020-05-20 10:43:12 +02:00
ctx := r.Context()
ctn := container.Must(ctx)
sess, err := session.Must(ctn).Get(w, r)
if err != nil {
2020-05-20 13:06:04 +02:00
return nil, "", errors.Wrap(err, "could not retrieve session")
2020-05-20 10:43:12 +02:00
}
rawStoredState := sess.Get(SessionOIDCStateKey)
receivedState := r.URL.Query().Get("state")
storedState, ok := rawStoredState.(string)
2020-05-20 10:43:12 +02:00
if !ok {
return nil, "", errors.New("could not find state in session")
2020-05-20 10:43:12 +02:00
}
if receivedState != storedState {
2020-05-20 13:06:04 +02:00
return nil, "", errors.New("state mismatch")
2020-05-20 10:43:12 +02:00
}
code := r.URL.Query().Get("code")
token, err := c.oauth2.Exchange(ctx, code)
if err != nil {
2020-05-20 13:06:04 +02:00
return nil, "", errors.Wrap(err, "could not exchange token")
2020-05-20 10:43:12 +02:00
}
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
2020-05-20 13:06:04 +02:00
return nil, "", errors.New("could not find id token")
2020-05-20 10:43:12 +02:00
}
idToken, err := c.verifier.Verify(ctx, rawIDToken)
if err != nil {
2020-05-20 13:06:04 +02:00
return nil, "", errors.Wrap(err, "could not verify id token")
2020-05-20 10:43:12 +02:00
}
2020-05-20 13:06:04 +02:00
return idToken, rawIDToken, nil
2020-05-20 10:43:12 +02:00
}
func NewClient(opts ...OptionFunc) *Client {
opt := fromDefault(opts...)
oauth2 := &oauth2.Config{
ClientID: opt.ClientID,
ClientSecret: opt.ClientSecret,
Endpoint: opt.Provider.Endpoint(),
RedirectURL: opt.RedirectURL,
Scopes: opt.Scopes,
}
verifier := opt.Provider.Verifier(&oidc.Config{
ClientID: opt.ClientID,
SkipIssuerCheck: opt.SkipIssuerCheck,
2020-05-20 10:43:12 +02:00
})
2023-11-02 18:21:54 +01:00
return &Client{oauth2, opt.Provider, verifier, opt.AcrValues}
2020-05-20 10:43:12 +02:00
}