diff --git a/client.go b/client.go
index e1aee9b..33670a5 100644
--- a/client.go
+++ b/client.go
@@ -1,7 +1,10 @@
package oidc
import (
+ "bytes"
"net/http"
+ "net/url"
+ "strings"
"github.com/coreos/go-oidc"
"github.com/dchest/uniuri"
@@ -41,10 +44,30 @@ func (c *Client) Login(w http.ResponseWriter, r *http.Request) {
panic(errors.Wrap(err, "could not save session"))
}
- http.Redirect(w, r, c.oauth2.AuthCodeURL(state), http.StatusFound)
+ authCodeOptions := []oauth2.AuthCodeOption{}
+
+ rawIDToken, _ := RawIDToken(w, r)
+ if rawIDToken != "" {
+ authCodeOptions = append(
+ authCodeOptions,
+ oauth2.SetAuthURLParam("id_token_hint", rawIDToken),
+ )
+ }
+
+ authCodeURL := c.oauth2.AuthCodeURL(
+ state,
+ authCodeOptions...,
+ )
+
+ http.Redirect(w, r, authCodeURL, http.StatusFound)
}
-func (c *Client) Logout(w http.ResponseWriter, r *http.Request) {
+func (c *Client) Logout(w http.ResponseWriter, r *http.Request, postLogoutRedirectURL string) {
+ rawIDToken, err := RawIDToken(w, r)
+ if err != nil {
+ panic(errors.Wrap(err, "could not retrieve raw id token"))
+ }
+
ctn := container.Must(r.Context())
sess, err := session.Must(ctn).Get(w, r)
@@ -55,50 +78,102 @@ func (c *Client) Logout(w http.ResponseWriter, r *http.Request) {
state := uniuri.New()
sess.Set(SessionOIDCStateKey, state)
+ sess.Unset(SessionOIDCRawTokenKey)
+ sess.Unset(SessionOIDCTokenKey)
if err := sess.Save(w, r); err != nil {
panic(errors.Wrap(err, "could not save session"))
}
- http.Redirect(w, r, c.oauth2.AuthCodeURL(state), http.StatusFound)
+ sessionEndURL, err := c.sessionEndURL(rawIDToken, state, postLogoutRedirectURL)
+ if err != nil {
+ panic(errors.Wrap(err, "could not retrieve session end url"))
+ }
+
+ if sessionEndURL != "" {
+ http.Redirect(w, r, sessionEndURL, http.StatusFound)
+ } else {
+ http.Redirect(w, r, postLogoutRedirectURL, http.StatusFound)
+ }
}
-func (c *Client) Validate(w http.ResponseWriter, r *http.Request) (*oidc.IDToken, error) {
+func (c *Client) sessionEndURL(idTokenHint, state, postLogoutRedirectURL string) (string, error) {
+ sessionEndEndpoint := &struct {
+ URL string `json:"end_session_endpoint"`
+ }{}
+
+ if err := c.provider.Claims(&sessionEndEndpoint); err != nil {
+ return "", errors.Wrap(err, "could not unmarshal claims")
+ }
+
+ if sessionEndEndpoint.URL == "" {
+ return "", nil
+ }
+
+ var buf bytes.Buffer
+ buf.WriteString(sessionEndEndpoint.URL)
+
+ v := url.Values{}
+
+ if idTokenHint != "" {
+ v.Set("id_token_hint", idTokenHint)
+ }
+
+ if postLogoutRedirectURL != "" {
+ v.Set("post_logout_redirect_uri", postLogoutRedirectURL)
+ }
+
+ if state != "" {
+ v.Set("state", state)
+ }
+
+ if strings.Contains(sessionEndEndpoint.URL, "?") {
+ buf.WriteByte('&')
+ } else {
+ buf.WriteByte('?')
+ }
+
+ buf.WriteString(v.Encode())
+
+ return buf.String(), nil
+}
+
+func (c *Client) Validate(w http.ResponseWriter, r *http.Request) (*oidc.IDToken, string, error) {
ctx := r.Context()
ctn := container.Must(ctx)
sess, err := session.Must(ctn).Get(w, r)
if err != nil {
- return nil, errors.Wrap(err, "could not retrieve session")
+ return nil, "", errors.Wrap(err, "could not retrieve session")
}
state, ok := sess.Get(SessionOIDCStateKey).(string)
if !ok {
- return nil, errors.New("invalid state")
+ return nil, "", errors.New("invalid state")
}
if r.URL.Query().Get("state") != state {
- return nil, errors.New("state mismatch")
+ return nil, "", errors.New("state mismatch")
}
code := r.URL.Query().Get("code")
token, err := c.oauth2.Exchange(ctx, code)
if err != nil {
- return nil, errors.Wrap(err, "could not exchange token")
+ return nil, "", errors.Wrap(err, "could not exchange token")
}
rawIDToken, ok := token.Extra("id_token").(string)
if !ok {
- return nil, errors.New("could not find id token")
+ return nil, "", errors.New("could not find id token")
}
idToken, err := c.verifier.Verify(ctx, rawIDToken)
if err != nil {
- return nil, errors.Wrap(err, "could not verify id token")
+ return nil, "", errors.Wrap(err, "could not verify id token")
}
- return idToken, nil
+ return idToken, rawIDToken, nil
}
func NewClient(opts ...OptionFunc) *Client {
diff --git a/cmd/server/template/blocks/header.html.tmpl b/cmd/server/template/blocks/header.html.tmpl
index 83e0e81..6af7913 100644
--- a/cmd/server/template/blocks/header.html.tmpl
+++ b/cmd/server/template/blocks/header.html.tmpl
@@ -9,7 +9,11 @@
diff --git a/cmd/server/template/layouts/home.html.tmpl b/cmd/server/template/layouts/home.html.tmpl
index 8ccc566..78e2508 100644
--- a/cmd/server/template/layouts/home.html.tmpl
+++ b/cmd/server/template/layouts/home.html.tmpl
@@ -3,8 +3,6 @@
{{template "header" .}}
-
Jeton OpenID Connect
-
{{ .JSONIDToken }}
{{template "footer" .}}
diff --git a/cmd/server/template/layouts/profile.html.tmpl b/cmd/server/template/layouts/profile.html.tmpl
new file mode 100644
index 0000000..8ccc566
--- /dev/null
+++ b/cmd/server/template/layouts/profile.html.tmpl
@@ -0,0 +1,12 @@
+{{define "title"}}Accueil{{end}}
+{{define "body"}}
+
+
+ {{template "header" .}}
+
Jeton OpenID Connect
+
{{ .JSONIDToken }}
+ {{template "footer" .}}
+
+
+{{end}}
+{{template "base" .}}
diff --git a/go.sum b/go.sum
index ca9d749..36a6b8d 100644
--- a/go.sum
+++ b/go.sum
@@ -34,6 +34,7 @@ github.com/daaku/go.zipexe v1.0.0/go.mod h1:z8IiR6TsVLEYKwXAoE/I+8ys/sDkgTzSL0CL
github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964 h1:y5HC9v93H5EPKqaS1UYVg1uYah5Xf51mBfIoWehClUQ=
github.com/danwakefield/fnmatch v0.0.0-20160403171240-cbb64ac3d964/go.mod h1:Xd9hchkHSWYkEqJwUGisez3G1QY8Ryz0sdWrLPMGjLk=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/dchest/uniuri v0.0.0-20200228104902-7aecb25e1fe5 h1:RAV05c0xOkJ3dZGS0JFybxFKZ2WMLabgx3uXnd7rpGs=
github.com/dchest/uniuri v0.0.0-20200228104902-7aecb25e1fe5/go.mod h1:GgB8SF9nRG+GqaDtLcwJZsQFhcogVCJ79j4EdT0c2V4=
diff --git a/internal/route/home.go b/internal/route/home.go
new file mode 100644
index 0000000..2b0e17a
--- /dev/null
+++ b/internal/route/home.go
@@ -0,0 +1,28 @@
+package route
+
+import (
+ "net/http"
+
+ oidc "forge.cadoles.com/wpetit/goweb-oidc"
+ "github.com/pkg/errors"
+ "gitlab.com/wpetit/goweb/middleware/container"
+ "gitlab.com/wpetit/goweb/service/template"
+)
+
+func serveHomePage(w http.ResponseWriter, r *http.Request) {
+ ctn := container.Must(r.Context())
+ tmpl := template.Must(ctn)
+
+ idToken, _ := oidc.IDToken(w, r)
+
+ if idToken != nil {
+ http.Redirect(w, r, "/profile", http.StatusSeeOther)
+ return
+ }
+
+ data := extendTemplateData(w, r, template.Data{})
+
+ if err := tmpl.RenderPage(w, "home.html.tmpl", data); err != nil {
+ panic(errors.Wrapf(err, "could not render '%s' page", r.URL.Path))
+ }
+}
diff --git a/internal/route/login.go b/internal/route/login.go
index e61d436..917efae 100644
--- a/internal/route/login.go
+++ b/internal/route/login.go
@@ -1,41 +1,20 @@
package route
import (
- "encoding/json"
"net/http"
oidc "forge.cadoles.com/wpetit/goweb-oidc"
- "github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
"gitlab.com/wpetit/goweb/middleware/container"
- "gitlab.com/wpetit/goweb/service/template"
)
-func serveHomePage(w http.ResponseWriter, r *http.Request) {
+func handleLogin(w http.ResponseWriter, r *http.Request) {
ctn := container.Must(r.Context())
- tmpl := template.Must(ctn)
-
- idToken, err := oidc.IDToken(w, r)
- if err != nil {
- panic(errors.Wrap(err, "could not retrieve idToken"))
- }
-
- jsonIDToken, err := json.MarshalIndent(idToken, "", " ")
- if err != nil {
- panic(errors.Wrap(err, "could not encode idToken"))
- }
-
- data := extendTemplateData(w, r, template.Data{
- "IDToken": idToken,
- "JSONIDToken": string(jsonIDToken),
- })
-
- if err := tmpl.RenderPage(w, "home.html.tmpl", data); err != nil {
- panic(errors.Wrapf(err, "could not render '%s' page", r.URL.Path))
- }
+ client := oidc.Must(ctn)
+ client.Login(w, r)
}
-func handleLogin(w http.ResponseWriter, r *http.Request) {
+func handleLoginCallback(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
idToken, err := oidc.IDToken(w, r)
diff --git a/internal/route/logout.go b/internal/route/logout.go
index 07eefe5..a239175 100644
--- a/internal/route/logout.go
+++ b/internal/route/logout.go
@@ -23,5 +23,5 @@ func handleLogout(w http.ResponseWriter, r *http.Request) {
client := oidc.Must(ctn)
- client.Logout(w, r)
+ client.Logout(w, r, "")
}
diff --git a/internal/route/mount.go b/internal/route/mount.go
index 9f7f89b..4b073f2 100644
--- a/internal/route/mount.go
+++ b/internal/route/mount.go
@@ -9,14 +9,16 @@ import (
)
func Mount(r *chi.Mux, config *config.Config) error {
- r.Group(func(r chi.Router) {
- r.Use(oidc.Middleware)
- r.Get("/", serveHomePage)
- })
-
- r.With(oidc.HandleCallback).Get("/oauth2/callback", handleLogin)
+ r.With(oidc.HandleCallback).Get("/oauth2/callback", handleLoginCallback)
+ r.Get("/", serveHomePage)
r.Get("/logout", handleLogout)
+ r.Get("/login", handleLogin)
+
+ r.Route("/profile", func(r chi.Router) {
+ r.Use(oidc.Middleware)
+ r.Get("/", serveProfilePage)
+ })
notFoundHandler := r.NotFoundHandler()
r.Get("/*", static.Dir(config.HTTP.PublicDir, "", notFoundHandler))
diff --git a/internal/route/profile.go b/internal/route/profile.go
new file mode 100644
index 0000000..36ee144
--- /dev/null
+++ b/internal/route/profile.go
@@ -0,0 +1,35 @@
+package route
+
+import (
+ "encoding/json"
+ "net/http"
+
+ oidc "forge.cadoles.com/wpetit/goweb-oidc"
+ "github.com/pkg/errors"
+ "gitlab.com/wpetit/goweb/middleware/container"
+ "gitlab.com/wpetit/goweb/service/template"
+)
+
+func serveProfilePage(w http.ResponseWriter, r *http.Request) {
+ ctn := container.Must(r.Context())
+ tmpl := template.Must(ctn)
+
+ idToken, err := oidc.IDToken(w, r)
+ if err != nil {
+ panic(errors.Wrap(err, "could not retrieve idToken"))
+ }
+
+ jsonIDToken, err := json.MarshalIndent(idToken, "", " ")
+ if err != nil {
+ panic(errors.Wrap(err, "could not encode idToken"))
+ }
+
+ data := extendTemplateData(w, r, template.Data{
+ "IDToken": idToken,
+ "JSONIDToken": string(jsonIDToken),
+ })
+
+ if err := tmpl.RenderPage(w, "profile.html.tmpl", data); err != nil {
+ panic(errors.Wrapf(err, "could not render '%s' page", r.URL.Path))
+ }
+}
diff --git a/middleware.go b/middleware.go
index 0701d45..efafcf0 100644
--- a/middleware.go
+++ b/middleware.go
@@ -13,8 +13,9 @@ import (
)
const (
- SessionOIDCTokenKey = "oidc-token"
- SessionOIDCStateKey = "oidc-state"
+ SessionOIDCTokenKey = "oidc-token"
+ SessionOIDCRawTokenKey = "oidc-raw-token"
+ SessionOIDCStateKey = "oidc-state"
)
func init() {
@@ -47,7 +48,7 @@ func HandleCallback(next http.Handler) http.Handler {
ctn := container.Must(ctx)
client := Must(ctn)
- idToken, err := client.Validate(w, r)
+ idToken, rawIDToken, err := client.Validate(w, r)
if err != nil {
logger.Error(ctx, "could not validate oidc token", logger.E(err))
@@ -62,6 +63,7 @@ func HandleCallback(next http.Handler) http.Handler {
}
sess.Set(SessionOIDCTokenKey, idToken)
+ sess.Set(SessionOIDCRawTokenKey, rawIDToken)
if err := sess.Save(w, r); err != nil {
panic(errors.Wrap(err, "could not save session"))
@@ -73,14 +75,6 @@ func HandleCallback(next http.Handler) http.Handler {
return http.HandlerFunc(fn)
}
-func Logout(w http.ResponseWriter, r *http.Request) {
- // ctx := r.Context()
- // ctn := container.Must(ctx)
- // client := Must(ctn)
-
- // client
-}
-
func IDToken(w http.ResponseWriter, r *http.Request) (*oidc.IDToken, error) {
ctn := container.Must(r.Context())
@@ -96,3 +90,19 @@ func IDToken(w http.ResponseWriter, r *http.Request) (*oidc.IDToken, error) {
return idToken, nil
}
+
+func RawIDToken(w http.ResponseWriter, r *http.Request) (string, error) {
+ ctn := container.Must(r.Context())
+
+ sess, err := session.Must(ctn).Get(w, r)
+ if err != nil {
+ return "", errors.Wrap(err, "could not retrieve session")
+ }
+
+ rawIDToken, ok := sess.Get(SessionOIDCRawTokenKey).(string)
+ if !ok || rawIDToken == "" {
+ return "", errors.New("invalid raw id token")
+ }
+
+ return rawIDToken, nil
+}
diff --git a/option.go b/option.go
index b7492fd..837805b 100644
--- a/option.go
+++ b/option.go
@@ -16,6 +16,12 @@ type Option struct {
Scopes []string
}
+func WithRedirectURL(url string) OptionFunc {
+ return func(opt *Option) {
+ opt.RedirectURL = url
+ }
+}
+
func WithCredentials(clientID, clientSecret string) OptionFunc {
return func(opt *Option) {
opt.ClientID = clientID