arcast/pkg/server/broadcast.go

128 lines
2.6 KiB
Go

package server
import (
"context"
"log"
"net/http"
"sync"
"github.com/go-chi/chi/v5"
"github.com/gorilla/websocket"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
)
// channels is the package-level registry of broadcast connections, keyed by
// channel ID. It is shared by every handleBroadcast invocation.
var channels = &channelMap{
	index: map[string]map[*websocket.Conn]struct{}{},
}
// checkOrigin reports whether the request r comes from an allowed origin.
// Its signature matches websocket.Upgrader.CheckOrigin; presumably it is
// installed there (the wiring is not visible in this file).
//
// The request's Origin header must exactly match one of the entries returned
// by s.getAllowedOrigins, unless that list contains the wildcard "*", which
// allows any origin. If the allowed origins cannot be retrieved, the error is
// logged and the request is rejected.
func (s *Server) checkOrigin(r *http.Request) bool {
	allowedOrigins, err := s.getAllowedOrigins()
	if err != nil {
		logger.Error(r.Context(), "could not retrieve allowed origins", logger.CapturedE(errors.WithStack(err)))
		return false
	}

	requestOrigin := r.Header.Get("Origin")

	for _, origin := range allowedOrigins {
		if origin == "*" || requestOrigin == origin {
			return true
		}
	}

	// No allowed origin matched: reject. Accepting here would disable the
	// origin check entirely and expose the endpoint to cross-site WebSocket
	// hijacking.
	return false
}
// handleBroadcast upgrades the HTTP request to a WebSocket connection and
// registers it under the "channelID" URL parameter. Every message read from
// the connection is relayed, with its original message type, to all other
// connections registered on the same channel.
//
// The connection is unregistered and closed as soon as reading fails,
// including when the peer closes the connection. Normal (1000) and going-away
// (1001) closures are expected and not logged; any other read error is logged.
func (s *Server) handleBroadcast(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	conn, err := s.upgrader.Upgrade(w, r, nil)
	if err != nil {
		// gorilla's Upgrade already replies to the client with an HTTP error
		// on failure, so there is nothing left to write here.
		log.Print("upgrade:", err)
		return
	}

	channelID := chi.URLParam(r, "channelID")

	channels.Add(channelID, conn)
	defer func() {
		// Unregister before closing so no concurrent Send targets a closed conn.
		channels.Remove(channelID, conn)
		if err := conn.Close(); err != nil {
			logger.Error(ctx, "could not close connection", logger.E(errors.WithStack(err)))
		}
	}()

	for {
		messageType, message, err := conn.ReadMessage()
		if err != nil {
			// Any read error ends the session; only unexpected ones are logged.
			if !websocket.IsCloseError(err, websocket.CloseGoingAway, websocket.CloseNormalClosure) {
				logger.Error(ctx, "could not read message", logger.E(errors.WithStack(err)))
			}
			return
		}

		logger.Debug(ctx, "broadcasting message", logger.F("message", message), logger.F("messageType", messageType))
		channels.Send(ctx, channelID, messageType, message, conn)
	}
}
// channelMap indexes WebSocket connections by channel ID. Each channel ID
// maps to a set of connections, represented as a map with empty-struct
// values. All reads and writes of index are guarded by mutex.
type channelMap struct {
	// index maps a channel ID to the set of connections registered on it.
	index map[string]map[*websocket.Conn]struct{}
	// mutex guards index.
	mutex sync.RWMutex
}
// Remove unregisters conn from channelID under the write lock. Once the
// channel has no connections left, its entry is dropped from the index.
// Removing from an unknown channel is a no-op.
func (m *channelMap) Remove(channelID string, conn *websocket.Conn) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	if set, ok := m.index[channelID]; ok {
		delete(set, conn)
		if len(set) == 0 {
			delete(m.index, channelID)
		}
	}
}
// Add registers conn under channelID while holding the write lock. The
// channel's connection set is created on first use.
func (m *channelMap) Add(channelID string, conn *websocket.Conn) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	set, ok := m.index[channelID]
	if !ok {
		set = make(map[*websocket.Conn]struct{})
		m.index[channelID] = set
	}
	set[conn] = struct{}{}
}
// Send writes message, with the given WebSocket messageType, to every
// connection registered under channelID except the except connection
// (handleBroadcast passes the sender). A write failure is logged and does not
// stop delivery to the remaining connections. Sending to an unknown channel
// is a no-op.
//
// The exclusive lock is held for the whole fan-out, not the read lock.
// gorilla/websocket allows at most one concurrent writer per connection.
// Send is called from each connection's handler goroutine, so with a shared
// read lock two senders on the same channel could write to the same
// connection at the same time.
func (m *channelMap) Send(ctx context.Context, channelID string, messageType int, message []byte, except *websocket.Conn) {
	m.mutex.Lock()
	defer m.mutex.Unlock()

	conns, exists := m.index[channelID]
	if !exists {
		return
	}

	for c := range conns {
		if c == except {
			continue
		}
		if err := c.WriteMessage(messageType, message); err != nil {
			logger.Error(ctx, "could not write message", logger.E(errors.WithStack(err)))
		}
	}
}