package ssh import ( "errors" "log" "sync" "time" ) type SessionID string type SessionType int const ( TypeServiceUnknown SessionType = iota TypeServiceProvider TypeServiceConsumer ) type SessionData struct { Type SessionType Token string } type SessionManager struct { sessions map[SessionID]SessionData sessionsMutex sync.Mutex tokenIndex map[string]SessionID updates map[SessionID][]chan SessionData updatesMutex sync.Mutex updateReadTimeout time.Duration } func (m *SessionManager) Get(sessID SessionID, defaultValue SessionData) SessionData { log.Println("reading session", sessID) m.sessionsMutex.Lock() defer m.sessionsMutex.Unlock() session, exists := m.sessions[sessID] if !exists { session = defaultValue m.sessions[sessID] = session m.updatesMutex.Lock() m.dispatchUpdate(sessID, session) m.updatesMutex.Unlock() } return session } func (m *SessionManager) FindByToken(token string) SessionID { m.sessionsMutex.Lock() defer m.sessionsMutex.Unlock() sessID, exists := m.tokenIndex[token] if !exists { return "" } return sessID } func (m *SessionManager) Set(sessID SessionID, sess SessionData) { log.Println("updating session", sessID, sess) m.sessionsMutex.Lock() oldSess, ok := m.sessions[sessID] if ok { m.updateTokenIndex(sessID, sess.Token, oldSess.Token) } else { m.updateTokenIndex(sessID, sess.Token, "") } m.sessions[sessID] = sess m.sessionsMutex.Unlock() m.updatesMutex.Lock() m.dispatchUpdate(sessID, sess) m.updatesMutex.Unlock() } func (m *SessionManager) Remove(sessID SessionID) { m.sessionsMutex.Lock() oldSess, ok := m.sessions[sessID] if ok { m.updateTokenIndex(sessID, "", oldSess.Token) } delete(m.sessions, sessID) m.sessionsMutex.Unlock() m.updatesMutex.Lock() m.closeAllUpdates(sessID) m.updatesMutex.Unlock() } func (m *SessionManager) OnUpdate(sessID SessionID) (<-chan SessionData, func()) { update := make(chan SessionData) close := func() { m.updatesMutex.Lock() m.closeUpdate(sessID, update) m.updatesMutex.Unlock() } m.updatesMutex.Lock() defer m.updatesMutex.Unlock() channels, exists := m.updates[sessID] if !exists { channels = make([]chan SessionData, 0, 1) } channels = append(channels, update) m.updates[sessID] = channels return update, close } func (m *SessionManager) closeAllUpdates(sessID SessionID) { channels, exists := m.updates[sessID] if !exists { return } for _, ch := range channels { m.closeUpdate(sessID, ch) } } func (m *SessionManager) closeUpdate(sessID SessionID, update chan SessionData) { channels, exists := m.updates[sessID] if !exists { return } for idx, ch := range channels { if ch != update { continue } close(ch) m.updates[sessID] = append(channels[:idx], channels[idx+1:]...) } } func (m *SessionManager) dispatchUpdate(sessID SessionID, sess SessionData) { channels, exists := m.updates[sessID] if !exists { return } for _, ch := range channels { timeout := time.After(m.updateReadTimeout) select { case ch <- sess: case <-timeout: err := errors.New("session update read timed out") log.Printf("[ERROR] %+v", err) } } } func (m *SessionManager) updateTokenIndex(sessID SessionID, addedToken, deletedToken string) { if addedToken != "" { m.tokenIndex[addedToken] = sessID } if deletedToken != "" { delete(m.tokenIndex, deletedToken) } } func NewSessionManager(updateReadTimeout time.Duration) *SessionManager { return &SessionManager{ sessions: make(map[SessionID]SessionData), tokenIndex: make(map[string]SessionID), updates: make(map[SessionID][]chan SessionData), updateReadTimeout: updateReadTimeout, } }