daddy/internal/session/user_email.go

115 lines
2.2 KiB
Go

package session
import (
"context"
"net/http"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/middleware/container"
"gitlab.com/wpetit/goweb/service/session"
)
// contextKey is an unexported key type for values this package stores in a
// context.Context, so its keys cannot collide with keys defined elsewhere.
type contextKey string

const (
	// userEmailKey is the context key under which the user's email is stored.
	// Converted to a string, it is also the session attribute name used by
	// SaveUserEmail, ClearUserEmail and GetUserEmail.
	userEmailKey contextKey = "user_email"
)

// ErrUserEmailNotFound is returned when no user email is present in the
// request context or in the session.
var ErrUserEmailNotFound = errors.New("user email not found")
// UserEmailMiddleware reads the user's email from the session (via
// GetUserEmail) and attaches it to the request context (via WithUserEmail)
// before delegating to next.
//
// If the email cannot be retrieved, including when the session holds none,
// it panics with the underlying error wrapped as "could not find user email".
// NOTE(review): this presumably relies on a recovery middleware further up the
// chain to turn the panic into an error response; confirm before reuse.
func UserEmailMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		email, err := GetUserEmail(w, r)
		if err != nil {
			panic(errors.Wrap(err, "could not find user email"))
		}

		next.ServeHTTP(w, r.WithContext(WithUserEmail(r.Context(), email)))
	})
}
// WithUserEmail returns a child of ctx that carries email under the package's
// private user email key. Retrieve it later with UserEmail.
func WithUserEmail(ctx context.Context, email string) context.Context {
	derived := context.WithValue(ctx, userEmailKey, email)
	return derived
}
// UserEmail returns the email previously attached to ctx by WithUserEmail.
// It returns ErrUserEmailNotFound, with a stack trace attached, when ctx holds
// no string value under the user email key.
func UserEmail(ctx context.Context) (string, error) {
	if email, ok := ctx.Value(userEmailKey).(string); ok {
		return email, nil
	}

	return "", errors.WithStack(ErrUserEmailNotFound)
}
// SaveUserEmail stores email in the current request's session under the user
// email attribute and saves the session right away.
// Any error from resolving or saving the session is returned with a stack
// trace attached.
func SaveUserEmail(w http.ResponseWriter, r *http.Request, email string) error {
	s, err := getSession(w, r)
	if err != nil {
		return errors.WithStack(err)
	}

	s.Set(string(userEmailKey), email)

	if err = s.Save(w, r); err != nil {
		return errors.WithStack(err)
	}

	return nil
}
// ClearUserEmail removes the user email attribute from the current request's
// session. The session is saved only when saveSession is true. When it is
// false, the session is modified but this function does not save it.
// Any error from resolving or saving the session is returned with a stack
// trace attached.
func ClearUserEmail(w http.ResponseWriter, r *http.Request, saveSession bool) error {
	s, err := getSession(w, r)
	if err != nil {
		return errors.WithStack(err)
	}

	s.Unset(string(userEmailKey))

	if !saveSession {
		return nil
	}

	if err = s.Save(w, r); err != nil {
		return errors.WithStack(err)
	}

	return nil
}
// GetUserEmail reads the user email attribute from the current request's
// session. It returns ErrUserEmailNotFound, with a stack trace attached, when
// the attribute is missing or is not a string. Errors from resolving the
// session are also returned with a stack trace.
func GetUserEmail(w http.ResponseWriter, r *http.Request) (string, error) {
	s, err := getSession(w, r)
	if err != nil {
		return "", errors.WithStack(err)
	}

	if email, ok := s.Get(string(userEmailKey)).(string); ok {
		return email, nil
	}

	return "", errors.WithStack(ErrUserEmailNotFound)
}
// getSession resolves the session bound to the current request. It does this
// in three steps:
//  1. fetch the service container from the request context (container.From);
//  2. look up the session service in that container (session.From);
//  3. ask the session service for the session bound to (w, r).
//
// Each failure is returned with a stack trace attached.
func getSession(w http.ResponseWriter, r *http.Request) (session.Session, error) {
	ctn, err := container.From(r.Context())
	if err != nil {
		return nil, errors.WithStack(err)
	}

	// Named sessionSrv rather than "session" so it does not shadow the
	// imported session package, which this function's signature also uses.
	sessionSrv, err := session.From(ctn)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	sess, err := sessionSrv.Get(w, r)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	return sess, nil
}