package session import ( "bytes" "context" "crypto/rand" "encoding/base32" "encoding/gob" "io" "net/http" "strings" "time" "github.com/gorilla/sessions" "github.com/pkg/errors" ) var ( ErrNotFound = errors.New("not found") ) type StoreAdapter interface { Set(ctx context.Context, key string, data []byte, ttl time.Duration) error Del(ctx context.Context, key string) error Get(ctx context.Context, key string) ([]byte, error) } type Store struct { adapter StoreAdapter options sessions.Options keyPrefix string keyGen KeyGenFunc serializer SessionSerializer ttl time.Duration } type KeyGenFunc func() (string, error) func NewStore(adapter StoreAdapter, funcs ...OptionFunc) *Store { opts := NewOptions(funcs...) rs := &Store{ options: opts.Session, adapter: adapter, keyPrefix: opts.KeyPrefix, keyGen: generateRandomKey, serializer: GobSerializer{}, ttl: opts.TTL, } return rs } func (s *Store) Get(r *http.Request, name string) (*sessions.Session, error) { return sessions.GetRegistry(r).Get(s, name) } func (s *Store) New(r *http.Request, name string) (*sessions.Session, error) { session := sessions.NewSession(s, name) opts := s.options session.Options = &opts session.IsNew = true c, err := r.Cookie(name) if err != nil { return session, nil } session.ID = c.Value err = s.load(r.Context(), session) if err == nil { session.IsNew = false } else if !errors.Is(err, ErrNotFound) { return session, errors.WithStack(err) } return session, nil } func (s *Store) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error { if session.Options.MaxAge < 0 { if err := s.delete(r.Context(), session); err != nil { return errors.WithStack(err) } http.SetCookie(w, sessions.NewCookie(session.Name(), "", session.Options)) return nil } if session.ID == "" { id, err := s.keyGen() if err != nil { return errors.Wrap(err, "failed to generate session id") } session.ID = id } if err := s.save(r.Context(), session); err != nil { return errors.WithStack(err) } http.SetCookie(w, sessions.NewCookie(session.Name(), session.ID, session.Options)) return nil } func (s *Store) Options(opts sessions.Options) { s.options = opts } func (s *Store) KeyPrefix(keyPrefix string) { s.keyPrefix = keyPrefix } func (s *Store) KeyGen(f KeyGenFunc) { s.keyGen = f } func (s *Store) Serializer(ss SessionSerializer) { s.serializer = ss } func (s *Store) save(ctx context.Context, session *sessions.Session) error { b, err := s.serializer.Serialize(session) if err != nil { return errors.WithStack(err) } ttl := time.Duration(session.Options.MaxAge) * time.Second if s.ttl < ttl || ttl == 0 { ttl = s.ttl } if err := s.adapter.Set(ctx, s.keyPrefix+session.ID, b, ttl); err != nil { return errors.WithStack(err) } return nil } // load reads session from Redis func (s *Store) load(ctx context.Context, session *sessions.Session) error { data, err := s.adapter.Get(ctx, s.keyPrefix+session.ID) if err != nil { return errors.WithStack(err) } return s.serializer.Deserialize(data, session) } // delete deletes session in Redis func (s *Store) delete(ctx context.Context, session *sessions.Session) error { if err := s.adapter.Del(ctx, s.keyPrefix+session.ID); err != nil { return errors.WithStack(err) } return nil } // SessionSerializer provides an interface for serialize/deserialize a session type SessionSerializer interface { Serialize(s *sessions.Session) ([]byte, error) Deserialize(b []byte, s *sessions.Session) error } // Gob serializer type GobSerializer struct{} func (gs GobSerializer) Serialize(s *sessions.Session) ([]byte, error) { buf := new(bytes.Buffer) enc := gob.NewEncoder(buf) if err := enc.Encode(s.Values); err != nil { return nil, errors.WithStack(err) } return buf.Bytes(), nil } func (gs GobSerializer) Deserialize(d []byte, s *sessions.Session) error { dec := gob.NewDecoder(bytes.NewBuffer(d)) return dec.Decode(&s.Values) } // generateRandomKey returns a new random key func generateRandomKey() (string, error) { k := make([]byte, 64) if _, err := io.ReadFull(rand.Reader, k); err != nil { return "", errors.WithStack(err) } return strings.TrimRight(base32.StdEncoding.EncodeToString(k), "="), nil }