package oidc import ( "bytes" "net/http" "net/url" "strings" "github.com/coreos/go-oidc" "github.com/dchest/uniuri" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/middleware/container" "gitlab.com/wpetit/goweb/service/session" "golang.org/x/oauth2" ) type Client struct { oauth2 *oauth2.Config provider *oidc.Provider verifier *oidc.IDTokenVerifier } func (c *Client) Verifier() *oidc.IDTokenVerifier { return c.verifier } func (c *Client) Provider() *oidc.Provider { return c.provider } func (c *Client) Login(w http.ResponseWriter, r *http.Request) { ctn := container.Must(r.Context()) sess, err := session.Must(ctn).Get(w, r) if err != nil { panic(errors.Wrap(err, "could not retrieve session")) } state := uniuri.New() sess.Set(SessionOIDCStateKey, state) if err := sess.Save(w, r); err != nil { panic(errors.Wrap(err, "could not save session")) } authCodeOptions := []oauth2.AuthCodeOption{} rawIDToken, _ := RawIDToken(w, r) if rawIDToken != "" { authCodeOptions = append( authCodeOptions, oauth2.SetAuthURLParam("id_token_hint", rawIDToken), ) } authCodeURL := c.oauth2.AuthCodeURL( state, authCodeOptions..., ) http.Redirect(w, r, authCodeURL, http.StatusFound) } func (c *Client) Logout(w http.ResponseWriter, r *http.Request, postLogoutRedirectURL string) { rawIDToken, err := RawIDToken(w, r) if err != nil { panic(errors.Wrap(err, "could not retrieve raw id token")) } ctn := container.Must(r.Context()) sess, err := session.Must(ctn).Get(w, r) if err != nil { panic(errors.Wrap(err, "could not retrieve session")) } state := uniuri.New() sess.Set(SessionOIDCStateKey, state) sess.Unset(SessionOIDCRawTokenKey) sess.Unset(SessionOIDCTokenKey) if err := sess.Save(w, r); err != nil { panic(errors.Wrap(err, "could not save session")) } sessionEndURL, err := c.sessionEndURL(rawIDToken, state, postLogoutRedirectURL) if err != nil { panic(errors.Wrap(err, "could not retrieve session end url")) } if sessionEndURL != "" { http.Redirect(w, r, sessionEndURL, http.StatusFound) } else { http.Redirect(w, r, postLogoutRedirectURL, http.StatusFound) } } func (c *Client) sessionEndURL(idTokenHint, state, postLogoutRedirectURL string) (string, error) { sessionEndEndpoint := &struct { URL string `json:"end_session_endpoint"` }{} if err := c.provider.Claims(&sessionEndEndpoint); err != nil { return "", errors.Wrap(err, "could not unmarshal claims") } if sessionEndEndpoint.URL == "" { return "", nil } var buf bytes.Buffer buf.WriteString(sessionEndEndpoint.URL) v := url.Values{} if idTokenHint != "" { v.Set("id_token_hint", idTokenHint) } if postLogoutRedirectURL != "" { v.Set("post_logout_redirect_uri", postLogoutRedirectURL) } if state != "" { v.Set("state", state) } if strings.Contains(sessionEndEndpoint.URL, "?") { buf.WriteByte('&') } else { buf.WriteByte('?') } buf.WriteString(v.Encode()) return buf.String(), nil } func (c *Client) Validate(w http.ResponseWriter, r *http.Request) (*oidc.IDToken, string, error) { ctx := r.Context() ctn := container.Must(ctx) sess, err := session.Must(ctn).Get(w, r) if err != nil { return nil, "", errors.Wrap(err, "could not retrieve session") } state, ok := sess.Get(SessionOIDCStateKey).(string) if !ok { return nil, "", errors.New("invalid state") } if r.URL.Query().Get("state") != state { return nil, "", errors.New("state mismatch") } code := r.URL.Query().Get("code") token, err := c.oauth2.Exchange(ctx, code) if err != nil { return nil, "", errors.Wrap(err, "could not exchange token") } rawIDToken, ok := token.Extra("id_token").(string) if !ok { return nil, "", errors.New("could not find id token") } idToken, err := c.verifier.Verify(ctx, rawIDToken) if err != nil { return nil, "", errors.Wrap(err, "could not verify id token") } return idToken, rawIDToken, nil } func NewClient(opts ...OptionFunc) *Client { opt := fromDefault(opts...) oauth2 := &oauth2.Config{ ClientID: opt.ClientID, ClientSecret: opt.ClientSecret, Endpoint: opt.Provider.Endpoint(), RedirectURL: opt.RedirectURL, Scopes: opt.Scopes, } verifier := opt.Provider.Verifier(&oidc.Config{ ClientID: opt.ClientID, }) return &Client{oauth2, opt.Provider, verifier} }