package server

import (
	"context"
	"log"
	"net/http"
	"sync"

	"github.com/go-chi/chi/v5"
	"github.com/gorilla/websocket"
	"github.com/pkg/errors"
	"gitlab.com/wpetit/goweb/logger"
)

var (
	// channels is the package-level registry of open websocket connections,
	// grouped by channel identifier.
	channels = &channelMap{
		index: make(map[string]map[*websocket.Conn]struct{}),
	}
)

// checkOrigin reports whether the request's Origin header matches one of the
// origins returned by s.getAllowedOrigins. An allowed origin of "*" accepts
// every request. If the allowed origins cannot be retrieved, the error is
// logged and the request is rejected.
func (s *Server) checkOrigin(r *http.Request) bool {
	allowedOrigins, err := s.getAllowedOrigins()
	if err != nil {
		logger.Error(r.Context(), "could not retrieve allowed origins", logger.CapturedE(errors.WithStack(err)))

		return false
	}

	requestOrigin := r.Header.Get("Origin")

	for _, origin := range allowedOrigins {
		if requestOrigin == origin || origin == "*" {
			return true
		}
	}

	// No allowed origin matched. The previous implementation returned true
	// here, which made the whole check a no-op.
	return false
}

// handleBroadcast upgrades the request to a websocket connection, registers
// it under the "channelID" URL parameter and then relays every message read
// from this connection to the other connections registered on the same
// channel. The connection is unregistered and closed when the read loop ends.
func (s *Server) handleBroadcast(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	c, err := s.upgrader.Upgrade(w, r, nil)
	if err != nil {
		log.Print("upgrade:", err)

		return
	}

	channelID := chi.URLParam(r, "channelID")

	channels.Add(channelID, c)

	defer func() {
		channels.Remove(channelID, c)

		// A "going away" close is the normal end of a session; do not
		// report it as an error.
		if err := c.Close(); err != nil && !websocket.IsCloseError(err, websocket.CloseGoingAway) {
			logger.Error(ctx, "could not close connection", logger.E(errors.WithStack(err)))
		}
	}()

	for {
		messageType, message, err := c.ReadMessage()
		if err != nil {
			// Any read error ends the session; only unexpected ones are logged.
			if !websocket.IsCloseError(err, websocket.CloseGoingAway) {
				logger.Error(ctx, "could not read message", logger.E(errors.WithStack(err)))
			}

			return
		}

		// Defensive: -1 is the value the original code treated as
		// "no frame was read".
		if messageType == -1 {
			return
		}

		logger.Debug(ctx, "broadcasting message", logger.F("message", message), logger.F("messageType", messageType))

		channels.Send(ctx, channelID, messageType, message, c)
	}
}

// channelMap tracks the set of websocket connections attached to each
// channel identifier.
type channelMap struct {
	// mutex guards index.
	mutex sync.RWMutex
	// writeMutex serializes writes to connections performed by Send, since
	// several Send calls may hold the read lock on mutex at the same time.
	writeMutex sync.Mutex
	index      map[string]map[*websocket.Conn]struct{}
}

// Remove detaches conn from the given channel. The channel entry itself is
// dropped once its last connection has been removed. Removing from an
// unknown channel is a no-op.
func (m *channelMap) Remove(channelID string, conn *websocket.Conn) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	conns, exists := m.index[channelID]
	if !exists {
		return
	}

	delete(conns, conn)

	if len(conns) == 0 {
		delete(m.index, channelID)
	}
}

// Add attaches conn to the given channel, creating the channel entry when it
// does not exist yet.
func (m *channelMap) Add(channelID string, conn *websocket.Conn) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	conns, exists := m.index[channelID]
	if !exists {
		conns = make(map[*websocket.Conn]struct{})
		m.index[channelID] = conns
	}

	conns[conn] = struct{}{}
}

// Send writes message, with the given websocket message type, to every
// connection attached to channelID except the one passed as except. Write
// failures are logged and do not interrupt delivery to the remaining
// connections. Sending to an unknown channel is a no-op.
func (m *channelMap) Send(ctx context.Context, channelID string, messageType int, message []byte, except *websocket.Conn) {
	m.mutex.RLock()
	defer m.mutex.RUnlock()

	conns, exists := m.index[channelID]
	if !exists {
		return
	}

	// The read lock above can be held by several goroutines at once, but a
	// websocket connection must not be written to concurrently, so writes
	// are serialized with a dedicated mutex.
	m.writeMutex.Lock()
	defer m.writeMutex.Unlock()

	for c := range conns {
		if except == c {
			continue
		}

		if err := c.WriteMessage(messageType, message); err != nil {
			logger.Error(ctx, "could not write message", logger.E(errors.WithStack(err)))
		}
	}
}