rebound/ssh/session_manager.go

185 lines
3.6 KiB
Go

package ssh
import (
"errors"
"log"
"sync"
"time"
)
type (
	// SessionID identifies a session held by a SessionManager.
	SessionID string

	// SessionType classifies what role a session plays.
	SessionType int

	// SessionData is the value a SessionManager stores per session.
	SessionData struct {
		Type  SessionType
		Token string // looked up via SessionManager.FindByToken; "" means no token
	}
)

// Known session types. TypeServiceUnknown is the zero value.
const (
	TypeServiceUnknown  SessionType = 0
	TypeServiceProvider SessionType = 1
	TypeServiceConsumer SessionType = 2
)
// SessionManager stores SessionData keyed by SessionID, maintains an index
// from token to session, and notifies subscribers (see OnUpdate) when a
// session is created by Get or replaced by Set.
//
// The zero value is not usable (its maps are nil); construct one with
// NewSessionManager.
type SessionManager struct {
	// sessions and tokenIndex are accessed under sessionsMutex.
	sessions      map[SessionID]SessionData
	sessionsMutex sync.Mutex
	tokenIndex    map[string]SessionID
	// updates holds the subscriber channels per session and is accessed
	// under updatesMutex.
	updates      map[SessionID][]chan SessionData
	updatesMutex sync.Mutex
	// updateReadTimeout bounds how long a single update send waits for a
	// subscriber to receive it before the update is dropped for that subscriber.
	updateReadTimeout time.Duration
}
// Get returns the session stored under sessID. If none exists yet,
// defaultValue is stored, its token (if any) is added to the token index, and
// OnUpdate subscribers for sessID are notified before the value is returned.
func (m *SessionManager) Get(sessID SessionID, defaultValue SessionData) SessionData {
	log.Println("reading session", sessID)
	m.sessionsMutex.Lock()
	session, exists := m.sessions[sessID]
	if exists {
		m.sessionsMutex.Unlock()
		return session
	}
	session = defaultValue
	m.sessions[sessID] = session
	// Index the token just like Set does, otherwise FindByToken cannot see
	// sessions that were created through Get.
	m.updateTokenIndex(sessID, session.Token, "")
	m.sessionsMutex.Unlock()

	// Dispatch after releasing sessionsMutex (as Set does): dispatchUpdate can
	// block for up to updateReadTimeout per subscriber, and holding the lock
	// that long would stall every other Get/Set/FindByToken caller.
	m.updatesMutex.Lock()
	m.dispatchUpdate(sessID, session)
	m.updatesMutex.Unlock()
	return session
}
// FindByToken returns the ID of the session indexed under token, or the
// empty SessionID when no session is indexed under that token.
func (m *SessionManager) FindByToken(token string) SessionID {
	m.sessionsMutex.Lock()
	defer m.sessionsMutex.Unlock()
	// A missing key yields the zero value, which is the empty SessionID.
	return m.tokenIndex[token]
}
// Set stores sess under sessID, replacing any existing session, updates the
// token index to match sess.Token, and then notifies OnUpdate subscribers.
func (m *SessionManager) Set(sessID SessionID, sess SessionData) {
	// NOTE(review): this logs the whole SessionData, including Token; confirm
	// tokens are not credentials before these logs are shipped anywhere shared.
	log.Println("updating session", sessID, sess)
	m.sessionsMutex.Lock()
	// Only retire the previous token when it actually changes. Passing the
	// same token as both "added" and "deleted" would drop it from the index
	// when updateTokenIndex adds before it deletes.
	var staleToken string
	if oldSess, ok := m.sessions[sessID]; ok && oldSess.Token != sess.Token {
		staleToken = oldSess.Token
	}
	m.updateTokenIndex(sessID, sess.Token, staleToken)
	m.sessions[sessID] = sess
	m.sessionsMutex.Unlock()

	// Notify outside sessionsMutex; dispatch may block up to updateReadTimeout
	// per subscriber.
	m.updatesMutex.Lock()
	m.dispatchUpdate(sessID, sess)
	m.updatesMutex.Unlock()
}
// Remove deletes the session stored under sessID, drops its token from the
// token index, and closes the update channels registered for sessID.
// Channels are torn down even when no session is stored under sessID.
func (m *SessionManager) Remove(sessID SessionID) {
	m.sessionsMutex.Lock()
	if prev, found := m.sessions[sessID]; found {
		m.updateTokenIndex(sessID, "", prev.Token)
		delete(m.sessions, sessID)
	}
	m.sessionsMutex.Unlock()

	m.updatesMutex.Lock()
	m.closeAllUpdates(sessID)
	m.updatesMutex.Unlock()
}
// OnUpdate subscribes to changes of session sessID. The returned channel
// receives the new SessionData when Get creates the session or Set replaces
// it. A value that is not received within the manager's updateReadTimeout is
// dropped for this subscriber. Calling the returned function closes and
// unregisters the channel.
func (m *SessionManager) OnUpdate(sessID SessionID) (<-chan SessionData, func()) {
	sub := make(chan SessionData)
	unsubscribe := func() {
		m.updatesMutex.Lock()
		m.closeUpdate(sessID, sub)
		m.updatesMutex.Unlock()
	}

	m.updatesMutex.Lock()
	// Appending to the nil slice of a missing key allocates a fresh list.
	m.updates[sessID] = append(m.updates[sessID], sub)
	m.updatesMutex.Unlock()
	return sub, unsubscribe
}
// closeAllUpdates closes every update channel registered for sessID and
// forgets the subscription list. Callers must hold updatesMutex.
func (m *SessionManager) closeAllUpdates(sessID SessionID) {
	channels, exists := m.updates[sessID]
	if !exists {
		return
	}
	// Close directly instead of going through closeUpdate: closeUpdate compacts
	// m.updates[sessID] in place, which shifts elements under this range loop
	// and made some subscribers get skipped (never closed).
	for _, ch := range channels {
		close(ch)
	}
	// Dropping the entry also makes any later unsubscribe call a no-op, so
	// no channel is closed twice.
	delete(m.updates, sessID)
}
// closeUpdate closes one update channel previously returned by OnUpdate and
// unregisters it from sessID. Channels that are not registered are ignored,
// so calling it again for the same channel is harmless. Callers must hold
// updatesMutex.
func (m *SessionManager) closeUpdate(sessID SessionID, update chan SessionData) {
	channels := m.updates[sessID] // nil (no iterations) when sessID is unknown
	for idx, ch := range channels {
		if ch != update {
			continue
		}
		close(ch)
		remaining := append(channels[:idx], channels[idx+1:]...)
		if len(remaining) == 0 {
			// Drop empty entries so the map does not keep one slice per
			// session ever subscribed to.
			delete(m.updates, sessID)
		} else {
			m.updates[sessID] = remaining
		}
		// OnUpdate registers a fresh channel each time, so it occurs at most
		// once; stop instead of ranging over the now-shifted slice.
		return
	}
}
// dispatchUpdate sends sess to each channel registered for sessID, one after
// another. A send not received within updateReadTimeout is abandoned and an
// error is logged, so a slow subscriber delays the others by at most that
// timeout. Callers must hold updatesMutex.
func (m *SessionManager) dispatchUpdate(sessID SessionID, sess SessionData) {
	channels, exists := m.updates[sessID]
	if !exists {
		return
	}
	for _, ch := range channels {
		// An explicit timer that is stopped after the send frees its resources
		// immediately; time.After keeps the timer alive until it fires (before
		// Go 1.23), which piles up with long timeouts and frequent updates.
		timer := time.NewTimer(m.updateReadTimeout)
		select {
		case ch <- sess:
		case <-timer.C:
			err := errors.New("session update read timed out")
			log.Printf("[ERROR] %+v", err)
		}
		timer.Stop()
	}
}
// updateTokenIndex moves sessID's token-index entry from deletedToken to
// addedToken. An empty token on either side is skipped. Callers must hold
// sessionsMutex.
func (m *SessionManager) updateTokenIndex(sessID SessionID, addedToken, deletedToken string) {
	// Delete before adding: when both tokens are equal (a Set that keeps the
	// token), add-then-delete would remove the token from the index.
	// Only remove the mapping if it still points at this session; another
	// session may have claimed the same token since.
	if deletedToken != "" && m.tokenIndex[deletedToken] == sessID {
		delete(m.tokenIndex, deletedToken)
	}
	if addedToken != "" {
		m.tokenIndex[addedToken] = sessID
	}
}
// NewSessionManager returns an empty, ready-to-use SessionManager.
// updateReadTimeout is how long each update send waits for an OnUpdate
// subscriber to receive it before the update is dropped for that subscriber.
func NewSessionManager(updateReadTimeout time.Duration) *SessionManager {
	mgr := new(SessionManager)
	mgr.sessions = make(map[SessionID]SessionData)
	mgr.tokenIndex = make(map[string]SessionID)
	mgr.updates = make(map[SessionID][]chan SessionData)
	mgr.updateReadTimeout = updateReadTimeout
	return mgr
}