package oidc import ( "encoding/gob" "log" "net/http" "github.com/coreos/go-oidc" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" "gitlab.com/wpetit/goweb/middleware/container" "gitlab.com/wpetit/goweb/service/session" ) const ( SessionOIDCTokenKey = "oidc-token" SessionOIDCRawTokenKey = "oidc-raw-token" SessionOIDCStateKey = "oidc-state" ) func init() { gob.Register(&oidc.IDToken{}) } func Middleware(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { if _, err := IDToken(w, r); err != nil { ctn := container.Must(r.Context()) log.Println("retrieving oidc client") client := Must(ctn) client.Login(w, r) return } next.ServeHTTP(w, r) } return http.HandlerFunc(fn) } func HandleCallback(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() ctn := container.Must(ctx) client := Must(ctn) idToken, rawIDToken, err := client.Validate(w, r) if err != nil { logger.Error(ctx, "could not validate oidc token", logger.E(err)) http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return } sess, err := session.Must(ctn).Get(w, r) if err != nil { panic(errors.Wrap(err, "could not retrieve session")) } sess.Set(SessionOIDCTokenKey, idToken) sess.Set(SessionOIDCRawTokenKey, rawIDToken) if err := sess.Save(w, r); err != nil { panic(errors.Wrap(err, "could not save session")) } next.ServeHTTP(w, r) } return http.HandlerFunc(fn) } func IDToken(w http.ResponseWriter, r *http.Request) (*oidc.IDToken, error) { ctn := container.Must(r.Context()) sess, err := session.Must(ctn).Get(w, r) if err != nil { return nil, errors.Wrap(err, "could not retrieve session") } idToken, ok := sess.Get(SessionOIDCTokenKey).(*oidc.IDToken) if !ok || idToken == nil { return nil, errors.New("invalid id token") } return idToken, nil } func RawIDToken(w http.ResponseWriter, r *http.Request) (string, error) { ctn := container.Must(r.Context()) sess, err := session.Must(ctn).Get(w, r) if err != nil { return "", errors.Wrap(err, "could not retrieve session") } rawIDToken, ok := sess.Get(SessionOIDCRawTokenKey).(string) if !ok || rawIDToken == "" { return "", errors.New("invalid raw id token") } return rawIDToken, nil }