// Package oidc provides HTTP middlewares and helpers that keep an OpenID
// Connect ID token in the user session and verify it on incoming requests.
package oidc

import (
	"encoding/gob"
	"net/http"

	"github.com/coreos/go-oidc/v3/oidc"
	"github.com/pkg/errors"
	"gitlab.com/wpetit/goweb/logger"
	"gitlab.com/wpetit/goweb/middleware/container"
	"gitlab.com/wpetit/goweb/service/session"
)

// Session keys used by this package.
const (
	// SessionIDTokenKey is the session key under which HandleCallback stores
	// the raw ID token and from which RawIDToken reads it back.
	SessionIDTokenKey = "oidc-id-token"
	// SessionOIDCStateKey is not referenced in this file.
	// NOTE(review): presumably holds the OAuth2 "state" value during the
	// login flow; confirm against the client implementation.
	SessionOIDCStateKey = "oidc-state"
	// SessionOIDCNonceKey is not referenced in this file.
	// NOTE(review): presumably holds the OIDC "nonce" value during the
	// login flow; confirm against the client implementation.
	SessionOIDCNonceKey = "oidc-nonce"
)

// init registers *oidc.IDToken with encoding/gob.
// NOTE(review): presumably required by a gob-based session store; only the
// raw token string is stored by the code in this file.
func init() {
	gob.Register(&oidc.IDToken{})
}

// Middleware wraps next so that it is only served when a verifiable ID token
// can be obtained for the request through IDToken. Otherwise the failure is
// logged and the request is handed over to the OIDC client's Login method.
//
// The service container and the OIDC client are resolved with their Must
// variants on the failure path.
func Middleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		_, err := IDToken(w, r)
		if err == nil {
			next.ServeHTTP(w, r)

			return
		}

		logger.Error(
			r.Context(),
			"could not retrieve idtoken",
			logger.E(errors.WithStack(err)),
		)

		services := container.Must(r.Context())
		oidcClient := Must(services)
		oidcClient.Login(w, r)
	})
}

// HandleCallback wraps next with the handling of the OIDC callback: the
// request is validated through the OIDC client, the resulting raw ID token is
// saved in the session under SessionIDTokenKey, and next is then served.
//
// A validation failure is logged and answered with 400 Bad Request. A failure
// to retrieve or save the session panics.
func HandleCallback(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()
		services := container.Must(ctx)
		oidcClient := Must(services)

		_, rawToken, validateErr := oidcClient.Validate(w, r)
		if validateErr != nil {
			logger.Error(
				ctx,
				"could not validate oidc token",
				logger.E(errors.WithStack(validateErr)),
			)

			badRequest := http.StatusBadRequest
			http.Error(w, http.StatusText(badRequest), badRequest)

			return
		}

		userSession, sessionErr := session.Must(services).Get(w, r)
		if sessionErr != nil {
			panic(errors.Wrap(sessionErr, "could not retrieve session"))
		}

		userSession.Set(SessionIDTokenKey, rawToken)

		saveErr := userSession.Save(w, r)
		if saveErr != nil {
			panic(errors.Wrap(saveErr, "could not save session"))
		}

		next.ServeHTTP(w, r)
	})
}

// RawIDToken returns the raw ID token string stored in the session under
// SessionIDTokenKey.
//
// It returns an error when the service container cannot be found in the
// request context, when the session cannot be retrieved, or when the stored
// value is not a non-empty string.
func RawIDToken(w http.ResponseWriter, r *http.Request) (string, error) {
	services, err := container.From(r.Context())
	if err != nil {
		return "", errors.Wrap(err, "could not retrieve service container")
	}

	userSession, err := session.Must(services).Get(w, r)
	if err != nil {
		return "", errors.Wrap(err, "could not retrieve session")
	}

	stored := userSession.Get(SessionIDTokenKey)

	token, isString := stored.(string)
	if isString && token != "" {
		return token, nil
	}

	return "", errors.New("invalid id token")
}

// IDToken reads the raw ID token from the session (see RawIDToken) and
// verifies it with the verifier of the OIDC client found in the service
// container.
//
// It returns the verified token, or a wrapped error when the raw token, the
// service container or the OIDC client cannot be retrieved, or when the
// verification fails.
func IDToken(w http.ResponseWriter, r *http.Request) (*oidc.IDToken, error) {
	rawToken, err := RawIDToken(w, r)
	if err != nil {
		return nil, errors.Wrap(err, "could not retrieve raw idtoken")
	}

	ctx := r.Context()

	services, err := container.From(ctx)
	if err != nil {
		return nil, errors.Wrap(err, "could not retrieve service container")
	}

	oidcClient, err := From(services)
	if err != nil {
		return nil, errors.Wrap(err, "could not retrieve oidc service")
	}

	verified, err := oidcClient.verifier.Verify(ctx, rawToken)
	if err != nil {
		return nil, errors.Wrap(err, "could not verify id token")
	}

	return verified, nil
}