128 lines
2.6 KiB
Go
128 lines
2.6 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"log"
|
|
"net/http"
|
|
"sync"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/gorilla/websocket"
|
|
"github.com/pkg/errors"
|
|
"gitlab.com/wpetit/goweb/logger"
|
|
)
|
|
|
|
var (
|
|
channels = &channelMap{
|
|
index: make(map[string]map[*websocket.Conn]struct{}),
|
|
}
|
|
)
|
|
|
|
func (s *Server) checkOrigin(r *http.Request) bool {
|
|
allowedOrigins, err := s.getAllowedOrigins()
|
|
if err != nil {
|
|
logger.Error(r.Context(), "could not retrieve allowed origins", logger.CapturedE(errors.WithStack(err)))
|
|
return false
|
|
}
|
|
|
|
requestOrigin := r.Header.Get("Origin")
|
|
|
|
for _, origin := range allowedOrigins {
|
|
if requestOrigin == origin || origin == "*" {
|
|
return true
|
|
}
|
|
}
|
|
|
|
return true
|
|
}
|
|
|
|
func (s *Server) handleBroadcast(w http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
|
|
c, err := s.upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
log.Print("upgrade:", err)
|
|
return
|
|
}
|
|
|
|
channelID := chi.URLParam(r, "channelID")
|
|
channels.Add(channelID, c)
|
|
|
|
defer func() {
|
|
channels.Remove(channelID, c)
|
|
|
|
if err := c.Close(); err != nil && !websocket.IsCloseError(err, 1001) {
|
|
logger.Error(ctx, "could not close connection", logger.E(errors.WithStack(err)))
|
|
}
|
|
}()
|
|
|
|
for {
|
|
messageType, message, err := c.ReadMessage()
|
|
if err != nil && !websocket.IsCloseError(err, 1001) {
|
|
logger.Error(ctx, "could not read message", logger.E(errors.WithStack(err)))
|
|
return
|
|
}
|
|
|
|
if messageType == -1 {
|
|
return
|
|
}
|
|
|
|
logger.Debug(ctx, "broadcasting message", logger.F("message", message), logger.F("messageType", messageType))
|
|
|
|
channels.Send(ctx, channelID, messageType, message, c)
|
|
}
|
|
}
|
|
|
|
type channelMap struct {
|
|
mutex sync.RWMutex
|
|
index map[string]map[*websocket.Conn]struct{}
|
|
}
|
|
|
|
func (m *channelMap) Remove(channelID string, conn *websocket.Conn) {
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
|
|
conns, exists := m.index[channelID]
|
|
if !exists {
|
|
return
|
|
}
|
|
|
|
delete(conns, conn)
|
|
if len(conns) == 0 {
|
|
delete(m.index, channelID)
|
|
}
|
|
}
|
|
|
|
func (m *channelMap) Add(channelID string, conn *websocket.Conn) {
|
|
m.mutex.Lock()
|
|
defer m.mutex.Unlock()
|
|
|
|
conns, exists := m.index[channelID]
|
|
if !exists {
|
|
conns = make(map[*websocket.Conn]struct{})
|
|
}
|
|
|
|
conns[conn] = struct{}{}
|
|
m.index[channelID] = conns
|
|
}
|
|
|
|
func (m *channelMap) Send(ctx context.Context, channelID string, messageType int, message []byte, except *websocket.Conn) {
|
|
m.mutex.RLock()
|
|
defer m.mutex.RUnlock()
|
|
|
|
conns, exists := m.index[channelID]
|
|
if !exists {
|
|
return
|
|
}
|
|
|
|
for c := range conns {
|
|
if except == c {
|
|
continue
|
|
}
|
|
|
|
if err := c.WriteMessage(messageType, message); err != nil {
|
|
logger.Error(ctx, "could not write message", logger.E(errors.WithStack(err)))
|
|
}
|
|
}
|
|
}
|