package session import ( "context" "net/http" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/middleware/container" "gitlab.com/wpetit/goweb/service/session" ) type contextKey string const userEmailKey contextKey = "user_email" var ( ErrUserEmailNotFound = errors.New("user email not found") ) func UserEmailMiddleware(next http.Handler) http.Handler { fn := func(w http.ResponseWriter, r *http.Request) { userEmail, err := GetUserEmail(w, r) if err != nil { http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } ctx := WithUserEmail(r.Context(), userEmail) r = r.WithContext(ctx) next.ServeHTTP(w, r) } return http.HandlerFunc(fn) } func WithUserEmail(ctx context.Context, email string) context.Context { return context.WithValue(ctx, userEmailKey, email) } func UserEmail(ctx context.Context) (string, error) { email, ok := ctx.Value(userEmailKey).(string) if !ok { return "", errors.WithStack(ErrUserEmailNotFound) } return email, nil } func SaveUserEmail(w http.ResponseWriter, r *http.Request, email string) error { sess, err := getSession(w, r) if err != nil { return errors.WithStack(err) } sess.Set(string(userEmailKey), email) if err := sess.Save(w, r); err != nil { return errors.WithStack(err) } return nil } func ClearUserEmail(w http.ResponseWriter, r *http.Request, saveSession bool) error { sess, err := getSession(w, r) if err != nil { return errors.WithStack(err) } sess.Unset(string(userEmailKey)) if saveSession { if err := sess.Save(w, r); err != nil { return errors.WithStack(err) } } return nil } func GetUserEmail(w http.ResponseWriter, r *http.Request) (string, error) { sess, err := getSession(w, r) if err != nil { return "", errors.WithStack(err) } email, ok := sess.Get(string(userEmailKey)).(string) if !ok { return "", errors.WithStack(ErrUserEmailNotFound) } return email, nil } func getSession(w http.ResponseWriter, r *http.Request) (session.Session, error) { ctx := r.Context() ctn, err := container.From(ctx) if err != nil { return nil, errors.WithStack(err) } session, err := session.From(ctn) if err != nil { return nil, errors.WithStack(err) } sess, err := session.Get(w, r) if err != nil { return nil, errors.WithStack(err) } return sess, nil }