185 lines
3.6 KiB
Go
185 lines
3.6 KiB
Go
|
package rebound
|
||
|
|
||
|
import (
|
||
|
"errors"
|
||
|
"log"
|
||
|
"sync"
|
||
|
"time"
|
||
|
)
|
||
|
|
||
|
type (
	// SessionID identifies a session held by a SessionManager.
	SessionID string

	// SessionType classifies a session; see the TypeService* constants.
	SessionType int
)

// Session types. TypeServiceUnknown is the zero value of SessionType.
const (
	TypeServiceUnknown SessionType = iota
	TypeServiceProvider
	TypeServiceConsumer
)

// SessionData is the state stored for a single session.
type SessionData struct {
	// Type is the session's classification.
	Type SessionType
	// Token is indexed so FindByToken can map it back to the session;
	// the empty string means "no token" and is never indexed.
	Token string
}

// SessionManager stores sessions, indexes them by token and fans out
// session changes to subscribers registered through OnUpdate.
//
// Build it with NewSessionManager: the zero value has nil maps, and the
// first write (e.g. via Set) would panic.
type SessionManager struct {
	// sessions holds the current state of every session.
	// Guarded by sessionsMutex.
	sessions      map[SessionID]SessionData
	sessionsMutex sync.Mutex

	// tokenIndex maps a session token back to the owning session.
	// Guarded by sessionsMutex (not by a mutex of its own).
	tokenIndex map[string]SessionID

	// updates lists the subscriber channels of each session.
	// Guarded by updatesMutex.
	updates      map[SessionID][]chan SessionData
	updatesMutex sync.Mutex

	// updateReadTimeout bounds how long a dispatch waits on each subscriber.
	updateReadTimeout time.Duration
}
|
||
|
|
||
|
func (m *SessionManager) Get(sessID SessionID, defaultValue SessionData) SessionData {
|
||
|
log.Println("reading session", sessID)
|
||
|
|
||
|
m.sessionsMutex.Lock()
|
||
|
defer m.sessionsMutex.Unlock()
|
||
|
|
||
|
session, exists := m.sessions[sessID]
|
||
|
if !exists {
|
||
|
session = defaultValue
|
||
|
m.sessions[sessID] = session
|
||
|
|
||
|
m.updatesMutex.Lock()
|
||
|
m.dispatchUpdate(sessID, session)
|
||
|
m.updatesMutex.Unlock()
|
||
|
}
|
||
|
|
||
|
return session
|
||
|
}
|
||
|
|
||
|
func (m *SessionManager) FindByToken(token string) SessionID {
|
||
|
m.sessionsMutex.Lock()
|
||
|
defer m.sessionsMutex.Unlock()
|
||
|
|
||
|
sessID, exists := m.tokenIndex[token]
|
||
|
if !exists {
|
||
|
return ""
|
||
|
}
|
||
|
|
||
|
return sessID
|
||
|
}
|
||
|
|
||
|
func (m *SessionManager) Set(sessID SessionID, sess SessionData) {
|
||
|
log.Println("updating session", sessID, sess)
|
||
|
|
||
|
m.sessionsMutex.Lock()
|
||
|
oldSess, ok := m.sessions[sessID]
|
||
|
if ok {
|
||
|
m.updateTokenIndex(sessID, sess.Token, oldSess.Token)
|
||
|
} else {
|
||
|
m.updateTokenIndex(sessID, sess.Token, "")
|
||
|
}
|
||
|
m.sessions[sessID] = sess
|
||
|
m.sessionsMutex.Unlock()
|
||
|
|
||
|
m.updatesMutex.Lock()
|
||
|
m.dispatchUpdate(sessID, sess)
|
||
|
m.updatesMutex.Unlock()
|
||
|
}
|
||
|
|
||
|
func (m *SessionManager) Remove(sessID SessionID) {
|
||
|
m.sessionsMutex.Lock()
|
||
|
oldSess, ok := m.sessions[sessID]
|
||
|
if ok {
|
||
|
m.updateTokenIndex(sessID, "", oldSess.Token)
|
||
|
}
|
||
|
delete(m.sessions, sessID)
|
||
|
m.sessionsMutex.Unlock()
|
||
|
|
||
|
m.updatesMutex.Lock()
|
||
|
m.closeAllUpdates(sessID)
|
||
|
m.updatesMutex.Unlock()
|
||
|
}
|
||
|
|
||
|
func (m *SessionManager) OnUpdate(sessID SessionID) (<-chan SessionData, func()) {
|
||
|
update := make(chan SessionData)
|
||
|
|
||
|
close := func() {
|
||
|
m.updatesMutex.Lock()
|
||
|
m.closeUpdate(sessID, update)
|
||
|
m.updatesMutex.Unlock()
|
||
|
}
|
||
|
|
||
|
m.updatesMutex.Lock()
|
||
|
defer m.updatesMutex.Unlock()
|
||
|
|
||
|
channels, exists := m.updates[sessID]
|
||
|
if !exists {
|
||
|
channels = make([]chan SessionData, 0, 1)
|
||
|
}
|
||
|
|
||
|
channels = append(channels, update)
|
||
|
m.updates[sessID] = channels
|
||
|
|
||
|
return update, close
|
||
|
}
|
||
|
|
||
|
func (m *SessionManager) closeAllUpdates(sessID SessionID) {
|
||
|
channels, exists := m.updates[sessID]
|
||
|
if !exists {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
for _, ch := range channels {
|
||
|
m.closeUpdate(sessID, ch)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (m *SessionManager) closeUpdate(sessID SessionID, update chan SessionData) {
|
||
|
channels, exists := m.updates[sessID]
|
||
|
if !exists {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
for idx, ch := range channels {
|
||
|
if ch != update {
|
||
|
continue
|
||
|
}
|
||
|
|
||
|
close(ch)
|
||
|
m.updates[sessID] = append(channels[:idx], channels[idx+1:]...)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (m *SessionManager) dispatchUpdate(sessID SessionID, sess SessionData) {
|
||
|
channels, exists := m.updates[sessID]
|
||
|
if !exists {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
for _, ch := range channels {
|
||
|
timeout := time.After(m.updateReadTimeout)
|
||
|
select {
|
||
|
case ch <- sess:
|
||
|
case <-timeout:
|
||
|
err := errors.New("session update read timed out")
|
||
|
log.Printf("[ERROR] %+v", err)
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func (m *SessionManager) updateTokenIndex(sessID SessionID, addedToken, deletedToken string) {
|
||
|
if addedToken != "" {
|
||
|
m.tokenIndex[addedToken] = sessID
|
||
|
}
|
||
|
|
||
|
if deletedToken != "" {
|
||
|
delete(m.tokenIndex, deletedToken)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func NewSessionManager(updateReadTimeout time.Duration) *SessionManager {
|
||
|
return &SessionManager{
|
||
|
sessions: make(map[SessionID]SessionData),
|
||
|
tokenIndex: make(map[string]SessionID),
|
||
|
updates: make(map[SessionID][]chan SessionData),
|
||
|
updateReadTimeout: updateReadTimeout,
|
||
|
}
|
||
|
}
|