195 lines
4.2 KiB
Go
195 lines
4.2 KiB
Go
package oidc
|
|
|
|
import (
|
|
"bytes"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
|
|
"github.com/coreos/go-oidc"
|
|
"github.com/dchest/uniuri"
|
|
"github.com/pkg/errors"
|
|
"gitlab.com/wpetit/goweb/middleware/container"
|
|
"gitlab.com/wpetit/goweb/service/session"
|
|
"golang.org/x/oauth2"
|
|
)
|
|
|
|
type Client struct {
|
|
oauth2 *oauth2.Config
|
|
provider *oidc.Provider
|
|
verifier *oidc.IDTokenVerifier
|
|
}
|
|
|
|
func (c *Client) Verifier() *oidc.IDTokenVerifier {
|
|
return c.verifier
|
|
}
|
|
|
|
func (c *Client) Provider() *oidc.Provider {
|
|
return c.provider
|
|
}
|
|
|
|
func (c *Client) Login(w http.ResponseWriter, r *http.Request) {
|
|
ctn := container.Must(r.Context())
|
|
|
|
sess, err := session.Must(ctn).Get(w, r)
|
|
if err != nil {
|
|
panic(errors.Wrap(err, "could not retrieve session"))
|
|
}
|
|
|
|
state := uniuri.New()
|
|
|
|
sess.Set(SessionOIDCStateKey, state)
|
|
|
|
if err := sess.Save(w, r); err != nil {
|
|
panic(errors.Wrap(err, "could not save session"))
|
|
}
|
|
|
|
authCodeOptions := []oauth2.AuthCodeOption{}
|
|
|
|
rawIDToken, _ := RawIDToken(w, r)
|
|
if rawIDToken != "" {
|
|
authCodeOptions = append(
|
|
authCodeOptions,
|
|
oauth2.SetAuthURLParam("id_token_hint", rawIDToken),
|
|
)
|
|
}
|
|
|
|
authCodeURL := c.oauth2.AuthCodeURL(
|
|
state,
|
|
authCodeOptions...,
|
|
)
|
|
|
|
http.Redirect(w, r, authCodeURL, http.StatusFound)
|
|
}
|
|
|
|
func (c *Client) Logout(w http.ResponseWriter, r *http.Request, postLogoutRedirectURL string) {
|
|
rawIDToken, err := RawIDToken(w, r)
|
|
if err != nil {
|
|
panic(errors.Wrap(err, "could not retrieve raw id token"))
|
|
}
|
|
|
|
ctn := container.Must(r.Context())
|
|
|
|
sess, err := session.Must(ctn).Get(w, r)
|
|
if err != nil {
|
|
panic(errors.Wrap(err, "could not retrieve session"))
|
|
}
|
|
|
|
state := uniuri.New()
|
|
|
|
sess.Set(SessionOIDCStateKey, state)
|
|
sess.Unset(SessionIDTokenKey)
|
|
|
|
if err := sess.Save(w, r); err != nil {
|
|
panic(errors.Wrap(err, "could not save session"))
|
|
}
|
|
|
|
sessionEndURL, err := c.sessionEndURL(rawIDToken, state, postLogoutRedirectURL)
|
|
if err != nil {
|
|
panic(errors.Wrap(err, "could not retrieve session end url"))
|
|
}
|
|
|
|
if sessionEndURL != "" {
|
|
http.Redirect(w, r, sessionEndURL, http.StatusFound)
|
|
} else {
|
|
http.Redirect(w, r, postLogoutRedirectURL, http.StatusFound)
|
|
}
|
|
}
|
|
|
|
func (c *Client) sessionEndURL(idTokenHint, state, postLogoutRedirectURL string) (string, error) {
|
|
sessionEndEndpoint := &struct {
|
|
URL string `json:"end_session_endpoint"`
|
|
}{}
|
|
|
|
if err := c.provider.Claims(&sessionEndEndpoint); err != nil {
|
|
return "", errors.Wrap(err, "could not unmarshal claims")
|
|
}
|
|
|
|
if sessionEndEndpoint.URL == "" {
|
|
return "", nil
|
|
}
|
|
|
|
var buf bytes.Buffer
|
|
buf.WriteString(sessionEndEndpoint.URL)
|
|
|
|
v := url.Values{}
|
|
|
|
if idTokenHint != "" {
|
|
v.Set("id_token_hint", idTokenHint)
|
|
}
|
|
|
|
if postLogoutRedirectURL != "" {
|
|
v.Set("post_logout_redirect_uri", postLogoutRedirectURL)
|
|
}
|
|
|
|
if state != "" {
|
|
v.Set("state", state)
|
|
}
|
|
|
|
if strings.Contains(sessionEndEndpoint.URL, "?") {
|
|
buf.WriteByte('&')
|
|
} else {
|
|
buf.WriteByte('?')
|
|
}
|
|
|
|
buf.WriteString(v.Encode())
|
|
|
|
return buf.String(), nil
|
|
}
|
|
|
|
func (c *Client) Validate(w http.ResponseWriter, r *http.Request) (*oidc.IDToken, string, error) {
|
|
ctx := r.Context()
|
|
ctn := container.Must(ctx)
|
|
|
|
sess, err := session.Must(ctn).Get(w, r)
|
|
if err != nil {
|
|
return nil, "", errors.Wrap(err, "could not retrieve session")
|
|
}
|
|
|
|
state, ok := sess.Get(SessionOIDCStateKey).(string)
|
|
if !ok {
|
|
return nil, "", errors.New("invalid state")
|
|
}
|
|
|
|
if r.URL.Query().Get("state") != state {
|
|
return nil, "", errors.New("state mismatch")
|
|
}
|
|
|
|
code := r.URL.Query().Get("code")
|
|
|
|
token, err := c.oauth2.Exchange(ctx, code)
|
|
if err != nil {
|
|
return nil, "", errors.Wrap(err, "could not exchange token")
|
|
}
|
|
|
|
rawIDToken, ok := token.Extra("id_token").(string)
|
|
if !ok {
|
|
return nil, "", errors.New("could not find id token")
|
|
}
|
|
|
|
idToken, err := c.verifier.Verify(ctx, rawIDToken)
|
|
if err != nil {
|
|
return nil, "", errors.Wrap(err, "could not verify id token")
|
|
}
|
|
|
|
return idToken, rawIDToken, nil
|
|
}
|
|
|
|
func NewClient(opts ...OptionFunc) *Client {
|
|
opt := fromDefault(opts...)
|
|
|
|
oauth2 := &oauth2.Config{
|
|
ClientID: opt.ClientID,
|
|
ClientSecret: opt.ClientSecret,
|
|
Endpoint: opt.Provider.Endpoint(),
|
|
RedirectURL: opt.RedirectURL,
|
|
Scopes: opt.Scopes,
|
|
}
|
|
|
|
verifier := opt.Provider.Verifier(&oidc.Config{
|
|
ClientID: opt.ClientID,
|
|
})
|
|
|
|
return &Client{oauth2, opt.Provider, verifier}
|
|
}
|