goweb-oidc/client.go

121 lines
2.6 KiB
Go
Raw Normal View History

2020-05-20 10:43:12 +02:00
package oidc
import (
"net/http"
"github.com/coreos/go-oidc"
"github.com/dchest/uniuri"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/middleware/container"
"gitlab.com/wpetit/goweb/service/session"
"golang.org/x/oauth2"
)
// Client is an OpenID Connect relying-party helper: it starts the
// authorization code flow (Login), handles the provider callback (Validate)
// and exposes the underlying provider and ID token verifier.
//
// Build it with NewClient. Note that NewClient uses an unkeyed composite
// literal, so the field order below must not change.
type Client struct {
// oauth2 builds the authorization URL (Login) and exchanges the
// authorization code for tokens (Validate).
oauth2 *oauth2.Config
// provider is the OpenID provider this client was configured with.
provider *oidc.Provider
// verifier checks the raw "id_token" returned by the token exchange.
verifier *oidc.IDTokenVerifier
}
// Verifier returns the ID token verifier the client uses in Validate, so
// callers can verify ID tokens obtained by other means.
func (c *Client) Verifier() *oidc.IDTokenVerifier { return c.verifier }
// Provider returns the OpenID provider the client was built with.
func (c *Client) Provider() *oidc.Provider { return c.provider }
// Login starts the OpenID Connect authorization code flow.
//
// A new random state (uniuri.New) is stored in the session under
// SessionOIDCStateKey. The browser is then redirected with 302 Found to the
// provider's authorization URL carrying that state, which Validate compares
// against on the callback.
//
// Login panics if the session cannot be retrieved or saved.
func (c *Client) Login(w http.ResponseWriter, r *http.Request) {
	sessions := session.Must(container.Must(r.Context()))

	sess, err := sessions.Get(w, r)
	if err != nil {
		panic(errors.Wrap(err, "could not retrieve session"))
	}

	csrfState := uniuri.New()
	sess.Set(SessionOIDCStateKey, csrfState)

	if saveErr := sess.Save(w, r); saveErr != nil {
		panic(errors.Wrap(saveErr, "could not save session"))
	}

	authURL := c.oauth2.AuthCodeURL(csrfState)
	http.Redirect(w, r, authURL, http.StatusFound)
}
func (c *Client) Logout(w http.ResponseWriter, r *http.Request) {
ctn := container.Must(r.Context())
sess, err := session.Must(ctn).Get(w, r)
if err != nil {
panic(errors.Wrap(err, "could not retrieve session"))
}
state := uniuri.New()
sess.Set(SessionOIDCStateKey, state)
if err := sess.Save(w, r); err != nil {
panic(errors.Wrap(err, "could not save session"))
}
http.Redirect(w, r, c.oauth2.AuthCodeURL(state), http.StatusFound)
}
// Validate handles the provider's redirect back to the application (the
// authorization code callback) and returns the verified ID token.
//
// It performs these steps in order:
//  1. checks that the "state" query parameter matches the value Login
//     stored in the session under SessionOIDCStateKey;
//  2. reports any error returned by the authorization server (RFC 6749
//     §4.1.2.1);
//  3. exchanges the authorization code for tokens;
//  4. verifies the "id_token" returned with them.
//
// An error is returned if any of the following holds:
//   - the session cannot be retrieved;
//   - the stored state is missing or does not match;
//   - the provider reported an error;
//   - no code was supplied;
//   - the exchange fails;
//   - the ID token is absent or fails verification.
func (c *Client) Validate(w http.ResponseWriter, r *http.Request) (*oidc.IDToken, error) {
	ctx := r.Context()
	ctn := container.Must(ctx)

	sess, err := session.Must(ctn).Get(w, r)
	if err != nil {
		return nil, errors.Wrap(err, "could not retrieve session")
	}

	query := r.URL.Query()

	// Login always stores a non-empty state; an empty one must not match a
	// callback that simply omits the parameter.
	state, ok := sess.Get(SessionOIDCStateKey).(string)
	if !ok || state == "" {
		return nil, errors.New("invalid state")
	}

	if query.Get("state") != state {
		return nil, errors.New("state mismatch")
	}

	// On failure (e.g. the user denied consent) the authorization server
	// sends "error" instead of "code"; surface it rather than attempting to
	// exchange an empty code.
	if authErr := query.Get("error"); authErr != "" {
		if desc := query.Get("error_description"); desc != "" {
			return nil, errors.Errorf("authorization server returned error %q: %q", authErr, desc)
		}
		return nil, errors.Errorf("authorization server returned error %q", authErr)
	}

	code := query.Get("code")
	if code == "" {
		return nil, errors.New("missing authorization code")
	}

	token, err := c.oauth2.Exchange(ctx, code)
	if err != nil {
		return nil, errors.Wrap(err, "could not exchange token")
	}

	rawIDToken, ok := token.Extra("id_token").(string)
	if !ok {
		return nil, errors.New("could not find id token")
	}

	idToken, err := c.verifier.Verify(ctx, rawIDToken)
	if err != nil {
		return nil, errors.Wrap(err, "could not verify id token")
	}

	return idToken, nil
}
// NewClient builds a Client from the given options, which are applied on
// top of the package defaults by fromDefault.
//
// The OAuth2 configuration takes its endpoints from opt.Provider and uses
// the configured client ID, secret, redirect URL and scopes. ID tokens are
// verified against opt.Provider with the client ID as the expected audience.
//
// opt.Provider is expected to be non-nil: its methods are called
// unconditionally here.
func NewClient(opts ...OptionFunc) *Client {
	opt := fromDefault(opts...)

	// Named oauth2Config (not oauth2) so the oauth2 package is not shadowed.
	oauth2Config := &oauth2.Config{
		ClientID:     opt.ClientID,
		ClientSecret: opt.ClientSecret,
		Endpoint:     opt.Provider.Endpoint(),
		RedirectURL:  opt.RedirectURL,
		Scopes:       opt.Scopes,
	}

	verifier := opt.Provider.Verifier(&oidc.Config{
		ClientID: opt.ClientID,
	})

	// Keyed fields keep this literal correct if Client's fields are reordered.
	return &Client{
		oauth2:   oauth2Config,
		provider: opt.Provider,
		verifier: verifier,
	}
}