diff --git a/client.go b/client.go index e1aee9b..33670a5 100644 --- a/client.go +++ b/client.go @@ -1,7 +1,10 @@ package oidc import ( + "bytes" "net/http" + "net/url" + "strings" "github.com/coreos/go-oidc" "github.com/dchest/uniuri" @@ -41,10 +44,30 @@ func (c *Client) Login(w http.ResponseWriter, r *http.Request) { panic(errors.Wrap(err, "could not save session")) } - http.Redirect(w, r, c.oauth2.AuthCodeURL(state), http.StatusFound) + authCodeOptions := []oauth2.AuthCodeOption{} + + rawIDToken, _ := RawIDToken(w, r) + if rawIDToken != "" { + authCodeOptions = append( + authCodeOptions, + oauth2.SetAuthURLParam("id_token_hint", rawIDToken), + ) + } + + authCodeURL := c.oauth2.AuthCodeURL( + state, + authCodeOptions..., + ) + + http.Redirect(w, r, authCodeURL, http.StatusFound) } -func (c *Client) Logout(w http.ResponseWriter, r *http.Request) { +func (c *Client) Logout(w http.ResponseWriter, r *http.Request, postLogoutRedirectURL string) { + rawIDToken, err := RawIDToken(w, r) + if err != nil { + panic(errors.Wrap(err, "could not retrieve raw id token")) + } + ctn := container.Must(r.Context()) sess, err := session.Must(ctn).Get(w, r) @@ -55,50 +78,102 @@ func (c *Client) Logout(w http.ResponseWriter, r *http.Request) { state := uniuri.New() sess.Set(SessionOIDCStateKey, state) + sess.Unset(SessionOIDCRawTokenKey) + sess.Unset(SessionOIDCTokenKey) if err := sess.Save(w, r); err != nil { panic(errors.Wrap(err, "could not save session")) } - http.Redirect(w, r, c.oauth2.AuthCodeURL(state), http.StatusFound) + sessionEndURL, err := c.sessionEndURL(rawIDToken, state, postLogoutRedirectURL) + if err != nil { + panic(errors.Wrap(err, "could not retrieve session end url")) + } + + if sessionEndURL != "" { + http.Redirect(w, r, sessionEndURL, http.StatusFound) + } else { + http.Redirect(w, r, postLogoutRedirectURL, http.StatusFound) + } } -func (c *Client) Validate(w http.ResponseWriter, r *http.Request) (*oidc.IDToken, error) { +func (c *Client) sessionEndURL(idTokenHint, state, postLogoutRedirectURL string) (string, error) { + sessionEndEndpoint := &struct { + URL string `json:"end_session_endpoint"` + }{} + + if err := c.provider.Claims(&sessionEndEndpoint); err != nil { + return "", errors.Wrap(err, "could not unmarshal claims") + } + + if sessionEndEndpoint.URL == "" { + return "", nil + } + + var buf bytes.Buffer + buf.WriteString(sessionEndEndpoint.URL) + + v := url.Values{} + + if idTokenHint != "" { + v.Set("id_token_hint", idTokenHint) + } + + if postLogoutRedirectURL != "" { + v.Set("post_logout_redirect_uri", postLogoutRedirectURL) + } + + if state != "" { + v.Set("state", state) + } + + if strings.Contains(sessionEndEndpoint.URL, "?") { + buf.WriteByte('&') + } else { + buf.WriteByte('?') + } + + buf.WriteString(v.Encode()) + + return buf.String(), nil +} + +func (c *Client) Validate(w http.ResponseWriter, r *http.Request) (*oidc.IDToken, string, error) { ctx := r.Context() ctn := container.Must(ctx) sess, err := session.Must(ctn).Get(w, r) if err != nil { - return nil, errors.Wrap(err, "could not retrieve session") + return nil, "", errors.Wrap(err, "could not retrieve session") } state, ok := sess.Get(SessionOIDCStateKey).(string) if !ok { - return nil, errors.New("invalid state") + return nil, "", errors.New("invalid state") } if r.URL.Query().Get("state") != state { - return nil, errors.New("state mismatch") + return nil, "", errors.New("state mismatch") } code := r.URL.Query().Get("code") token, err := c.oauth2.Exchange(ctx, code) if err != nil { - return nil, errors.Wrap(err, "could not exchange token") + return nil, "", errors.Wrap(err, "could not exchange token") } rawIDToken, ok := token.Extra("id_token").(string) if !ok { - return nil, errors.New("could not find id token") + return nil, "", errors.New("could not find id token") } idToken, err := c.verifier.Verify(ctx, rawIDToken) if err != nil { - return nil, errors.Wrap(err, "could not verify id token") + return nil, "", errors.Wrap(err, "could not verify id token") } - return idToken, nil + return idToken, rawIDToken, nil } func NewClient(opts ...OptionFunc) *Client { diff --git a/cmd/server/template/blocks/header.html.tmpl b/cmd/server/template/blocks/header.html.tmpl index 83e0e81..6af7913 100644 --- a/cmd/server/template/blocks/header.html.tmpl +++ b/cmd/server/template/blocks/header.html.tmpl @@ -9,7 +9,11 @@
+ {{if .IDToken}} Logout + {{else}} + Login + {{end}}
diff --git a/cmd/server/template/layouts/home.html.tmpl b/cmd/server/template/layouts/home.html.tmpl index 8ccc566..78e2508 100644 --- a/cmd/server/template/layouts/home.html.tmpl +++ b/cmd/server/template/layouts/home.html.tmpl @@ -3,8 +3,6 @@
{{template "header" .}} -

Jeton OpenID Connect

-
{{ .JSONIDToken }}
{{template "footer" .}}
diff --git a/cmd/server/template/layouts/profile.html.tmpl b/cmd/server/template/layouts/profile.html.tmpl new file mode 100644 index 0000000..8ccc566 --- /dev/null +++ b/cmd/server/template/layouts/profile.html.tmpl @@ -0,0 +1,12 @@ +{{define "title"}}Accueil{{end}} +{{define "body"}} +
+
+ {{template "header" .}} +

Jeton OpenID Connect

+
{{ .JSONIDToken }}
+ {{template "footer" .}} +
+
+{{end}} +{{template "base" .}} diff --git a/go.sum b/go.sum index ca9d749..36a6b8d 100644 --- a/go.sum +++ b/go.sum @@ -34,6 +34,7 @@ github.com/daaku/go.zipexe v1.0.0/go.mod h1:z8IiR6TsVLEYKwXAoE/I+8ys/sDkgTzSL0CL github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964 h1:y5HC9v93H5EPKqaS1UYVg1uYah5Xf51mBfIoWehClUQ= github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964/go.mod h1:Xd9hchkHSWYkEqJwUGisez3G1QY8Ryz0sdWrLPMGjLk= github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= +github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/dchest/uniuri v0.0.0-20200228104902-7aecb25e1fe5 h1:RAV05c0xOkJ3dZGS0JFybxFKZ2WMLabgx3uXnd7rpGs= github.com/dchest/uniuri v0.0.0-20200228104902-7aecb25e1fe5/go.mod h1:GgB8SF9nRG+GqaDtLcwJZsQFhcogVCJ79j4EdT0c2V4= diff --git a/internal/route/home.go b/internal/route/home.go new file mode 100644 index 0000000..2b0e17a --- /dev/null +++ b/internal/route/home.go @@ -0,0 +1,28 @@ +package route + +import ( + "net/http" + + oidc "forge.cadoles.com/wpetit/goweb-oidc" + "github.com/pkg/errors" + "gitlab.com/wpetit/goweb/middleware/container" + "gitlab.com/wpetit/goweb/service/template" +) + +func serveHomePage(w http.ResponseWriter, r *http.Request) { + ctn := container.Must(r.Context()) + tmpl := template.Must(ctn) + + idToken, _ := oidc.IDToken(w, r) + + if idToken != nil { + http.Redirect(w, r, "/profile", http.StatusSeeOther) + return + } + + data := extendTemplateData(w, r, template.Data{}) + + if err := tmpl.RenderPage(w, "home.html.tmpl", data); err != nil { + panic(errors.Wrapf(err, "could not render '%s' page", r.URL.Path)) + } +} diff --git a/internal/route/login.go b/internal/route/login.go index e61d436..917efae 100644 --- a/internal/route/login.go +++ b/internal/route/login.go @@ -1,41 +1,20 @@ package route import ( - "encoding/json" "net/http" oidc "forge.cadoles.com/wpetit/goweb-oidc" - "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" "gitlab.com/wpetit/goweb/middleware/container" - "gitlab.com/wpetit/goweb/service/template" ) -func serveHomePage(w http.ResponseWriter, r *http.Request) { +func handleLogin(w http.ResponseWriter, r *http.Request) { ctn := container.Must(r.Context()) - tmpl := template.Must(ctn) - - idToken, err := oidc.IDToken(w, r) - if err != nil { - panic(errors.Wrap(err, "could not retrieve idToken")) - } - - jsonIDToken, err := json.MarshalIndent(idToken, "", " ") - if err != nil { - panic(errors.Wrap(err, "could not encode idToken")) - } - - data := extendTemplateData(w, r, template.Data{ - "IDToken": idToken, - "JSONIDToken": string(jsonIDToken), - }) - - if err := tmpl.RenderPage(w, "home.html.tmpl", data); err != nil { - panic(errors.Wrapf(err, "could not render '%s' page", r.URL.Path)) - } + client := oidc.Must(ctn) + client.Login(w, r) } -func handleLogin(w http.ResponseWriter, r *http.Request) { +func handleLoginCallback(w http.ResponseWriter, r *http.Request) { ctx := r.Context() idToken, err := oidc.IDToken(w, r) diff --git a/internal/route/logout.go b/internal/route/logout.go index 07eefe5..a239175 100644 --- a/internal/route/logout.go +++ b/internal/route/logout.go @@ -23,5 +23,5 @@ func handleLogout(w http.ResponseWriter, r *http.Request) { client := oidc.Must(ctn) - client.Logout(w, r) + client.Logout(w, r, "") } diff --git a/internal/route/mount.go b/internal/route/mount.go index 9f7f89b..4b073f2 100644 --- a/internal/route/mount.go +++ b/internal/route/mount.go @@ -9,14 +9,16 @@ import ( ) func Mount(r *chi.Mux, config *config.Config) error { - r.Group(func(r chi.Router) { - r.Use(oidc.Middleware) - r.Get("/", serveHomePage) - }) - - r.With(oidc.HandleCallback).Get("/oauth2/callback", handleLogin) + r.With(oidc.HandleCallback).Get("/oauth2/callback", handleLoginCallback) + r.Get("/", serveHomePage) r.Get("/logout", handleLogout) + r.Get("/login", handleLogin) + + r.Route("/profile", func(r chi.Router) { + r.Use(oidc.Middleware) + r.Get("/", serveProfilePage) + }) notFoundHandler := r.NotFoundHandler() r.Get("/*", static.Dir(config.HTTP.PublicDir, "", notFoundHandler)) diff --git a/internal/route/profile.go b/internal/route/profile.go new file mode 100644 index 0000000..36ee144 --- /dev/null +++ b/internal/route/profile.go @@ -0,0 +1,35 @@ +package route + +import ( + "encoding/json" + "net/http" + + oidc "forge.cadoles.com/wpetit/goweb-oidc" + "github.com/pkg/errors" + "gitlab.com/wpetit/goweb/middleware/container" + "gitlab.com/wpetit/goweb/service/template" +) + +func serveProfilePage(w http.ResponseWriter, r *http.Request) { + ctn := container.Must(r.Context()) + tmpl := template.Must(ctn) + + idToken, err := oidc.IDToken(w, r) + if err != nil { + panic(errors.Wrap(err, "could not retrieve idToken")) + } + + jsonIDToken, err := json.MarshalIndent(idToken, "", " ") + if err != nil { + panic(errors.Wrap(err, "could not encode idToken")) + } + + data := extendTemplateData(w, r, template.Data{ + "IDToken": idToken, + "JSONIDToken": string(jsonIDToken), + }) + + if err := tmpl.RenderPage(w, "profile.html.tmpl", data); err != nil { + panic(errors.Wrapf(err, "could not render '%s' page", r.URL.Path)) + } +} diff --git a/middleware.go b/middleware.go index 0701d45..efafcf0 100644 --- a/middleware.go +++ b/middleware.go @@ -13,8 +13,9 @@ import ( ) const ( - SessionOIDCTokenKey = "oidc-token" - SessionOIDCStateKey = "oidc-state" + SessionOIDCTokenKey = "oidc-token" + SessionOIDCRawTokenKey = "oidc-raw-token" + SessionOIDCStateKey = "oidc-state" ) func init() { @@ -47,7 +48,7 @@ func HandleCallback(next http.Handler) http.Handler { ctn := container.Must(ctx) client := Must(ctn) - idToken, err := client.Validate(w, r) + idToken, rawIDToken, err := client.Validate(w, r) if err != nil { logger.Error(ctx, "could not validate oidc token", logger.E(err)) @@ -62,6 +63,7 @@ func HandleCallback(next http.Handler) http.Handler { } sess.Set(SessionOIDCTokenKey, idToken) + sess.Set(SessionOIDCRawTokenKey, rawIDToken) if err := sess.Save(w, r); err != nil { panic(errors.Wrap(err, "could not save session")) @@ -73,14 +75,6 @@ func HandleCallback(next http.Handler) http.Handler { return http.HandlerFunc(fn) } -func Logout(w http.ResponseWriter, r *http.Request) { - // ctx := r.Context() - // ctn := container.Must(ctx) - // client := Must(ctn) - - // client -} - func IDToken(w http.ResponseWriter, r *http.Request) (*oidc.IDToken, error) { ctn := container.Must(r.Context()) @@ -96,3 +90,19 @@ func IDToken(w http.ResponseWriter, r *http.Request) (*oidc.IDToken, error) { return idToken, nil } + +func RawIDToken(w http.ResponseWriter, r *http.Request) (string, error) { + ctn := container.Must(r.Context()) + + sess, err := session.Must(ctn).Get(w, r) + if err != nil { + return "", errors.Wrap(err, "could not retrieve session") + } + + rawIDToken, ok := sess.Get(SessionOIDCRawTokenKey).(string) + if !ok || rawIDToken == "" { + return "", errors.New("invalid raw id token") + } + + return rawIDToken, nil +} diff --git a/option.go b/option.go index b7492fd..837805b 100644 --- a/option.go +++ b/option.go @@ -16,6 +16,12 @@ type Option struct { Scopes []string } +func WithRedirectURL(url string) OptionFunc { + return func(opt *Option) { + opt.RedirectURL = url + } +} + func WithCredentials(clientID, clientSecret string) OptionFunc { return func(opt *Option) { opt.ClientID = clientID