// File: bouncer/internal/session/store.go

package session
import (
"bytes"
"context"
"crypto/rand"
"encoding/base32"
"encoding/gob"
"io"
"net/http"
"strings"
"time"
"github.com/gorilla/sessions"
"github.com/pkg/errors"
)
var (
// ErrNotFound is the sentinel Store.New checks for (via errors.Is) after a
// failed load: when the load error matches it, New returns the session
// with a nil error instead of failing.
// NOTE(review): presumably StoreAdapter.Get implementations return this
// when the key is absent — confirm against the adapters.
ErrNotFound = errors.New("not found")
)
// StoreAdapter is the key/value backend a Store persists sessions through.
// Store always calls it with keys of the form keyPrefix + session ID.
type StoreAdapter interface {
// Set stores data under key; Store passes the TTL it computed for the session.
Set(ctx context.Context, key string, data []byte, ttl time.Duration) error
// Del removes the entry stored under key.
Del(ctx context.Context, key string) error
// Get returns the bytes stored under key.
// NOTE(review): presumably reports a missing key with ErrNotFound — confirm.
Get(ctx context.Context, key string) ([]byte, error)
}
// Store keeps session values in a StoreAdapter and hands the client only the
// session ID, carried as the cookie value.
type Store struct {
// adapter is the backend used to set, get and delete serialized sessions.
adapter StoreAdapter
// options is copied into every session created by New.
options sessions.Options
// keyPrefix is prepended to the session ID to build the adapter key.
keyPrefix string
// keyGen produces the ID for sessions saved without one.
keyGen KeyGenFunc
// serializer converts session values to and from the stored bytes.
serializer SessionSerializer
// ttl is the store-level TTL; save uses it when the session's MaxAge is 0
// or longer than this value.
ttl time.Duration
}
// KeyGenFunc generates a new session ID. Store.Save calls it when the session
// being saved has an empty ID.
type KeyGenFunc func() (string, error)
// NewStore builds a Store that persists sessions through adapter.
//
// The session options, key prefix and TTL are taken from NewOptions(funcs...).
// The ID generator defaults to generateRandomKey and the serializer to
// GobSerializer; both can be replaced afterwards with KeyGen and Serializer.
func NewStore(adapter StoreAdapter, funcs ...OptionFunc) *Store {
	cfg := NewOptions(funcs...)

	return &Store{
		adapter:    adapter,
		options:    cfg.Session,
		keyPrefix:  cfg.KeyPrefix,
		ttl:        cfg.TTL,
		keyGen:     generateRandomKey,
		serializer: GobSerializer{},
	}
}
// Get looks up the session called name through the request's gorilla/sessions
// registry, passing this store as the one to use for it.
func (s *Store) Get(r *http.Request, name string) (*sessions.Session, error) {
	registry := sessions.GetRegistry(r)
	return registry.Get(s, name)
}
// New creates a session named name and, when the request carries a cookie
// with that name, tries to load the stored values for the ID in the cookie.
//
// The returned session is never nil. It has IsNew == true unless stored data
// was loaded successfully. Outcomes:
//   - no cookie, or an empty cookie value: a fresh session with no ID;
//   - load succeeds: the session keeps the cookie's ID and IsNew is false;
//   - load fails with ErrNotFound: a fresh session with no ID, nil error;
//   - load fails otherwise: the session (still carrying the cookie's ID) and
//     the wrapped error.
//
// The cookie value is client-controlled. An ID that the store does not know
// is therefore discarded rather than adopted: keeping it would let a client
// choose the ID (and thus the adapter key) that Save later persists under,
// which enables session fixation.
func (s *Store) New(r *http.Request, name string) (*sessions.Session, error) {
	session := sessions.NewSession(s, name)
	// Give every session its own copy so per-session tweaks do not leak
	// back into the store defaults.
	opts := s.options
	session.Options = &opts
	session.IsNew = true

	c, err := r.Cookie(name)
	if err != nil || c.Value == "" {
		// Nothing to look up; Save will generate an ID.
		return session, nil
	}
	session.ID = c.Value

	err = s.load(r.Context(), session)
	switch {
	case err == nil:
		session.IsNew = false
	case errors.Is(err, ErrNotFound):
		// Unknown or expired ID: drop it so Save issues a server-generated one.
		session.ID = ""
	default:
		return session, errors.WithStack(err)
	}
	return session, nil
}
// Save persists session through the adapter and writes its cookie to w.
//
// A negative Options.MaxAge means the session is to be destroyed: its stored
// entry is deleted and a cookie with an empty value is written. Otherwise the
// session is given an ID from the key generator if it has none, stored, and a
// cookie carrying the ID is written. Any failure is returned before a cookie
// is set.
func (s *Store) Save(r *http.Request, w http.ResponseWriter, session *sessions.Session) error {
	cookieName := session.Name()

	if session.Options.MaxAge < 0 {
		err := s.delete(r.Context(), session)
		if err != nil {
			return errors.WithStack(err)
		}
		expired := sessions.NewCookie(cookieName, "", session.Options)
		http.SetCookie(w, expired)
		return nil
	}

	// First save of this session: mint its ID.
	if session.ID == "" {
		newID, err := s.keyGen()
		if err != nil {
			return errors.Wrap(err, "failed to generate session id")
		}
		session.ID = newID
	}

	err := s.save(r.Context(), session)
	if err != nil {
		return errors.WithStack(err)
	}

	cookie := sessions.NewCookie(cookieName, session.ID, session.Options)
	http.SetCookie(w, cookie)
	return nil
}
// Options replaces the default session options; New copies them into every
// session it creates from then on.
func (s *Store) Options(opts sessions.Options) {
s.options = opts
}
// KeyPrefix sets the string prepended to a session ID to form the key passed
// to the adapter.
func (s *Store) KeyPrefix(keyPrefix string) {
s.keyPrefix = keyPrefix
}
// KeyGen sets the function Save calls to generate an ID for a session that
// does not have one yet.
func (s *Store) KeyGen(f KeyGenFunc) {
s.keyGen = f
}
// Serializer sets the serializer used to encode sessions before they are
// stored and to decode them when they are loaded.
func (s *Store) Serializer(ss SessionSerializer) {
s.serializer = ss
}
// save serializes session and stores it under keyPrefix + session.ID.
//
// The TTL handed to the adapter is the session's MaxAge in seconds, except
// that the store-level ttl is used instead when MaxAge is 0 or exceeds it.
func (s *Store) save(ctx context.Context, session *sessions.Session) error {
	payload, err := s.serializer.Serialize(session)
	if err != nil {
		return errors.WithStack(err)
	}

	expiry := time.Duration(session.Options.MaxAge) * time.Second
	if expiry == 0 || expiry > s.ttl {
		expiry = s.ttl
	}

	key := s.keyPrefix + session.ID
	err = s.adapter.Set(ctx, key, payload, expiry)
	if err != nil {
		return errors.WithStack(err)
	}
	return nil
}
// load fetches the bytes stored under keyPrefix + session.ID from the adapter
// and deserializes them into session. An adapter error is returned wrapped
// with a stack trace; a deserializer error is returned as-is.
func (s *Store) load(ctx context.Context, session *sessions.Session) error {
	key := s.keyPrefix + session.ID

	payload, err := s.adapter.Get(ctx, key)
	if err != nil {
		return errors.WithStack(err)
	}

	return s.serializer.Deserialize(payload, session)
}
// delete removes the entry stored under keyPrefix + session.ID from the
// adapter, returning any adapter error wrapped with a stack trace.
func (s *Store) delete(ctx context.Context, session *sessions.Session) error {
	key := s.keyPrefix + session.ID

	err := s.adapter.Del(ctx, key)
	if err == nil {
		return nil
	}
	return errors.WithStack(err)
}
// SessionSerializer converts a session to the bytes kept in the store and
// back again.
type SessionSerializer interface {
// Serialize encodes s into the bytes that Store hands to the adapter.
Serialize(s *sessions.Session) ([]byte, error)
// Deserialize decodes b, as previously read from the adapter, into s.
Deserialize(b []byte, s *sessions.Session) error
}
// GobSerializer is a SessionSerializer that encodes a session's Values with
// encoding/gob. It is the serializer NewStore installs by default.
type GobSerializer struct{}
// Serialize gob-encodes s.Values and returns the encoded bytes. Only the
// Values map is encoded; an encoding error is returned wrapped with a stack
// trace.
func (gs GobSerializer) Serialize(s *sessions.Session) ([]byte, error) {
	var out bytes.Buffer

	encoder := gob.NewEncoder(&out)
	err := encoder.Encode(s.Values)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	return out.Bytes(), nil
}
// Deserialize gob-decodes d into s.Values and returns the decoder's error
// unchanged.
func (gs GobSerializer) Deserialize(d []byte, s *sessions.Session) error {
	src := bytes.NewReader(d)
	decoder := gob.NewDecoder(src)
	return decoder.Decode(&s.Values)
}
// generateRandomKey produces a new session ID: 64 bytes read from
// crypto/rand, encoded with the standard base32 alphabet, with the trailing
// "=" padding stripped. A failure to read random bytes is returned wrapped
// with a stack trace.
func generateRandomKey() (string, error) {
	const keyBytes = 64

	raw := make([]byte, keyBytes)
	_, err := io.ReadFull(rand.Reader, raw)
	if err != nil {
		return "", errors.WithStack(err)
	}

	encoded := base32.StdEncoding.EncodeToString(raw)
	return strings.TrimRight(encoded, "="), nil
}