arcast/pkg/server/http.go

180 lines
4.4 KiB
Go

package server
import (
"context"
"crypto/tls"
"fmt"
"net"
"net/http"
"strconv"
"forge.cadoles.com/arcad/arcast/pkg/network"
"github.com/go-chi/chi/v5"
"github.com/go-chi/cors"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
)
// startWebServers builds the chi router exposing the REST API, the optional
// apps endpoints and the static UI, then starts the plain HTTP server and,
// when a TLS certificate is configured, the HTTPS server on the same router.
// Once the servers are started it calls s.resetBrowser.
//
// A CORS middleware is installed only when getAllowedOrigins yields at least
// one origin.
//
// NOTE(review): getAllowedOrigins formats s.port, but s.port is assigned by
// startHTTPServer, which runs afterwards. Unless s.port is initialised
// elsewhere (e.g. by the constructor), the LAN origins computed here may carry
// a stale or zero port. Confirm.
func (s *Server) startWebServers(ctx context.Context) error {
	origins, err := s.getAllowedOrigins()
	if err != nil {
		return errors.WithStack(err)
	}

	mux := chi.NewRouter()

	// Middlewares have to be registered on a chi mux before any route.
	if len(origins) != 0 {
		mux.Use(cors.Handler(cors.Options{
			AllowedOrigins:   origins,
			AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
			AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type"},
			AllowCredentials: false,
		}))
	}

	// Core casting API.
	mux.Get("/api/v1/info", s.handleInfo)
	mux.Post("/api/v1/cast", s.handleCast)
	mux.Delete("/api/v1/cast", s.handleReset)
	mux.Get("/api/v1/status", s.handleStatus)

	// Apps endpoints are only exposed when apps support is turned on.
	if s.appsEnabled {
		mux.Get("/apps", s.handleDefaultApp)
		mux.Get("/api/v1/apps", s.handleApps)
		mux.Handle("/apps/{appID}/*", http.HandlerFunc(s.handleAppFilesystem))
		mux.Handle("/api/v1/broadcast/{channelID}", http.HandlerFunc(s.handleBroadcast))
	}

	// Static UI.
	mux.Get("/", s.handleIndex)
	mux.Get("/*", s.handleStatic)

	if err := s.startHTTPServer(ctx, mux); err != nil {
		return errors.WithStack(err)
	}

	if s.tlsCert == nil {
		logger.Info(ctx, "no tls certificate configured, not starting https server")
	} else if err := s.startHTTPSServer(ctx, mux); err != nil {
		return errors.WithStack(err)
	}

	// errors.WithStack returns nil when resetBrowser succeeds.
	return errors.WithStack(s.resetBrowser())
}
// startHTTPServer binds a TCP listener on s.address, stores the port actually
// bound in s.port (useful when s.address requests an ephemeral port) and
// serves router over plain HTTP in a background goroutine.
//
// A second goroutine closes the server (abruptly, via http.Server.Close) once
// ctx is done. Serve errors other than http.ErrServerClosed happen after this
// function has returned, so they are logged rather than returned.
//
// An error is returned when the listener cannot be created or its address
// cannot be parsed. In the latter case the listener is closed first so the
// port is not left bound.
func (s *Server) startHTTPServer(ctx context.Context, router chi.Router) error {
	server := http.Server{
		Addr:    s.address,
		Handler: router,
	}

	listener, err := net.Listen("tcp", s.address)
	if err != nil {
		return errors.WithStack(err)
	}

	// abort releases the freshly bound listener before reporting err, so a
	// failure after net.Listen does not leak the socket.
	abort := func(err error) error {
		if closeErr := listener.Close(); closeErr != nil {
			logger.Error(ctx, "could not close http listener", logger.CapturedE(errors.WithStack(closeErr)))
		}
		return err
	}

	host, rawPort, err := net.SplitHostPort(listener.Addr().String())
	if err != nil {
		return abort(errors.WithStack(err))
	}

	port, err := strconv.ParseInt(rawPort, 10, 32)
	if err != nil {
		return abort(errors.Wrapf(err, "could not parse listening port '%v'", rawPort))
	}

	logger.Debug(ctx, "listening for tcp connections", logger.F("port", port), logger.F("host", host))

	s.port = int(port)

	go func() {
		logger.Debug(ctx, "starting http server")
		if err := server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) {
			logger.Error(ctx, "could not listen", logger.CapturedE(errors.WithStack(err)))
		}
	}()

	go func() {
		<-ctx.Done()
		logger.Debug(ctx, "closing http server")
		if err := server.Close(); err != nil {
			logger.Error(ctx, "could not close http server", logger.CapturedE(errors.WithStack(err)))
		}
	}()

	return nil
}
// startHTTPSServer binds a TCP listener on s.tlsAddress, stores the port
// actually bound in s.tlsPort and serves router over TLS in a background
// goroutine.
//
// Certificates come from TLSConfig.GetCertificate, which returns s.tlsCert on
// every handshake. No certificate files are passed to ServeTLS.
//
// A second goroutine closes the server (abruptly, via http.Server.Close) once
// ctx is done. Serve errors other than http.ErrServerClosed happen after this
// function has returned, so they are logged rather than returned.
//
// An error is returned when the listener cannot be created or its address
// cannot be parsed. In the latter case the listener is closed first so the
// port is not left bound.
func (s *Server) startHTTPSServer(ctx context.Context, router chi.Router) error {
	server := http.Server{
		// Addr is informational only (ServeTLS uses the listener below), but
		// it should still reflect the address this server is actually bound to.
		Addr:    s.tlsAddress,
		Handler: router,
		TLSConfig: &tls.Config{
			GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
				return s.tlsCert, nil
			},
		},
	}

	listener, err := net.Listen("tcp", s.tlsAddress)
	if err != nil {
		return errors.WithStack(err)
	}

	// abort releases the freshly bound listener before reporting err, so a
	// failure after net.Listen does not leak the socket.
	abort := func(err error) error {
		if closeErr := listener.Close(); closeErr != nil {
			logger.Error(ctx, "could not close https listener", logger.CapturedE(errors.WithStack(closeErr)))
		}
		return err
	}

	host, rawPort, err := net.SplitHostPort(listener.Addr().String())
	if err != nil {
		return abort(errors.WithStack(err))
	}

	port, err := strconv.ParseInt(rawPort, 10, 32)
	if err != nil {
		return abort(errors.Wrapf(err, "could not parse listening port '%v'", rawPort))
	}

	logger.Debug(ctx, "listening for tls tcp connections", logger.F("port", port), logger.F("host", host))

	s.tlsPort = int(port)

	go func() {
		logger.Debug(ctx, "starting https server")
		// Empty cert/key file names: the certificate is provided by
		// TLSConfig.GetCertificate instead.
		if err := server.ServeTLS(listener, "", ""); err != nil && !errors.Is(err, http.ErrServerClosed) {
			logger.Error(ctx, "could not listen", logger.CapturedE(errors.WithStack(err)))
		}
	}()

	go func() {
		<-ctx.Done()
		logger.Debug(ctx, "closing https server")
		if err := server.Close(); err != nil {
			logger.Error(ctx, "could not close https server", logger.CapturedE(errors.WithStack(err)))
		}
	}()

	return nil
}
// getAllowedOrigins returns the origins accepted by the CORS middleware.
//
// When apps are enabled, it first emits one "http://<ip>:<port>" entry per
// address returned by network.GetLANIPv4Addrs, using the current value of
// s.port. It then appends the statically configured s.allowedOrigins.
//
// The returned slice is never nil. It is empty when apps are disabled and no
// origins are configured. An error is returned only if the LAN address
// lookup fails.
func (s *Server) getAllowedOrigins() ([]string, error) {
	origins := []string{}

	if s.appsEnabled {
		lanIPs, err := network.GetLANIPv4Addrs()
		if err != nil {
			return nil, errors.WithStack(err)
		}

		for _, lanIP := range lanIPs {
			origins = append(origins, fmt.Sprintf("http://%s:%d", lanIP, s.port))
		}
	}

	// Appending an empty s.allowedOrigins leaves origins unchanged.
	return append(origins, s.allowedOrigins...), nil
}