goweb-oidc/middleware.go

117 lines
2.5 KiB
Go
Raw Normal View History

2020-05-20 10:43:12 +02:00
package oidc
import (
"encoding/gob"
"net/http"
"github.com/coreos/go-oidc"
"github.com/getsentry/sentry-go"
2020-05-20 10:43:12 +02:00
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
"gitlab.com/wpetit/goweb/middleware/container"
"gitlab.com/wpetit/goweb/service/session"
)
// Session keys used by the oidc middleware.
const (
	// SessionIDTokenKey is the session key under which HandleCallback stores
	// the raw OIDC ID token and from which RawIDToken reads it back.
	SessionIDTokenKey = "oidc-id-token"

	// SessionOIDCStateKey is the session key reserved for the OIDC "state"
	// value. It is not referenced in this file; presumably the client's
	// login/validation flow uses it — confirm against the client code.
	SessionOIDCStateKey = "oidc-state"
)
// init registers *oidc.IDToken with encoding/gob so that values of that
// concrete type can be encoded and decoded when held in an interface value
// (for example inside a gob-serialized session store).
//
// NOTE(review): HandleCallback now stores the raw token string under
// SessionIDTokenKey rather than an *oidc.IDToken. This registration may only
// be needed to decode sessions written by older versions — confirm before
// removing it.
func init() {
	gob.Register(new(oidc.IDToken))
}
// Middleware guards next behind OIDC authentication.
//
// A request is forwarded to next only if IDToken can retrieve and verify an
// ID token from the request's session. Otherwise the oidc client taken from
// the request's service container starts its Login flow instead. The Must
// helpers used for that lookup presumably panic when the container or the
// oidc service is missing.
func Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, err := IDToken(w, r)
		if err == nil {
			next.ServeHTTP(w, r)
			return
		}

		// No usable token in the session: (re)start the login flow.
		Must(container.Must(r.Context())).Login(w, r)
	})
}
func HandleCallback(next http.Handler) http.Handler {
fn := func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
ctn := container.Must(ctx)
client := Must(ctn)
2020-05-26 11:17:16 +02:00
_, rawIDToken, err := client.Validate(w, r)
2020-05-20 10:43:12 +02:00
if err != nil {
sentry.CaptureException(err)
2020-05-20 10:43:12 +02:00
logger.Error(ctx, "could not validate oidc token", logger.E(err))
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
}
sess, err := session.Must(ctn).Get(w, r)
if err != nil {
panic(errors.Wrap(err, "could not retrieve session"))
}
2020-05-26 11:17:16 +02:00
sess.Set(SessionIDTokenKey, rawIDToken)
2020-05-20 10:43:12 +02:00
if err := sess.Save(w, r); err != nil {
panic(errors.Wrap(err, "could not save session"))
}
next.ServeHTTP(w, r)
}
return http.HandlerFunc(fn)
}
2020-05-26 11:17:16 +02:00
func RawIDToken(w http.ResponseWriter, r *http.Request) (string, error) {
ctx := r.Context()
ctn, err := container.From(ctx)
if err != nil {
return "", errors.Wrap(err, "could not retrieve service container")
}
2020-05-20 10:43:12 +02:00
sess, err := session.Must(ctn).Get(w, r)
if err != nil {
2020-05-26 11:17:16 +02:00
return "", errors.Wrap(err, "could not retrieve session")
2020-05-20 10:43:12 +02:00
}
2020-05-26 11:17:16 +02:00
rawIDToken, ok := sess.Get(SessionIDTokenKey).(string)
if !ok || rawIDToken == "" {
return "", errors.New("invalid id token")
2020-05-20 10:43:12 +02:00
}
2020-05-26 11:17:16 +02:00
return rawIDToken, nil
2020-05-20 10:43:12 +02:00
}
2020-05-20 13:06:04 +02:00
2020-05-26 11:17:16 +02:00
func IDToken(w http.ResponseWriter, r *http.Request) (*oidc.IDToken, error) {
rawIDToken, err := RawIDToken(w, r)
if err != nil {
return nil, errors.Wrap(err, "could not retrieve raw idtoken")
}
2020-05-20 13:06:04 +02:00
2020-05-26 11:17:16 +02:00
ctx := r.Context()
ctn, err := container.From(ctx)
2020-05-20 13:06:04 +02:00
if err != nil {
2020-05-26 11:17:16 +02:00
return nil, errors.Wrap(err, "could not retrieve service container")
2020-05-20 13:06:04 +02:00
}
2020-05-26 11:17:16 +02:00
client, err := From(ctn)
if err != nil {
return nil, errors.Wrap(err, "could not retrieve oidc service")
2020-05-20 13:06:04 +02:00
}
2020-05-26 11:17:16 +02:00
idToken, err := client.verifier.Verify(ctx, rawIDToken)
if err != nil {
return nil, errors.Wrap(err, "could not verify id token")
}
return idToken, nil
2020-05-20 13:06:04 +02:00
}