package server import ( "context" "crypto/tls" "fmt" "net" "net/http" "strconv" "forge.cadoles.com/arcad/arcast/pkg/network" "github.com/go-chi/chi/v5" "github.com/go-chi/cors" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" ) func (s *Server) startWebServers(ctx context.Context) error { router := chi.NewRouter() allowedOrigins, err := s.getAllowedOrigins() if err != nil { return errors.WithStack(err) } if len(allowedOrigins) > 0 { router.Use(cors.Handler(cors.Options{ AllowedOrigins: allowedOrigins, AllowedMethods: []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"}, AllowedHeaders: []string{"Accept", "Authorization", "Content-Type"}, AllowCredentials: false, })) } router.Get("/api/v1/info", s.handleInfo) router.Post("/api/v1/cast", s.handleCast) router.Delete("/api/v1/cast", s.handleReset) router.Get("/api/v1/status", s.handleStatus) if s.appsEnabled { router.Get("/apps", s.handleDefaultApp) router.Get("/api/v1/apps", s.handleApps) router.Handle("/apps/{appID}/*", http.HandlerFunc(s.handleAppFilesystem)) router.Handle("/api/v1/broadcast/{channelID}", http.HandlerFunc(s.handleBroadcast)) } router.Get("/", s.handleIndex) router.Get("/*", s.handleStatic) if err := s.startHTTPServer(ctx, router); err != nil { return errors.WithStack(err) } if s.tlsCert != nil { if err := s.startHTTPSServer(ctx, router); err != nil { return errors.WithStack(err) } } else { logger.Info(ctx, "no tls certificate configured, not starting https server") } if err := s.resetBrowser(); err != nil { return errors.WithStack(err) } return nil } func (s *Server) startHTTPServer(ctx context.Context, router chi.Router) error { server := http.Server{ Addr: s.address, Handler: router, } listener, err := net.Listen("tcp", s.address) if err != nil { return errors.WithStack(err) } host, rawPort, err := net.SplitHostPort(listener.Addr().String()) if err != nil { return errors.WithStack(err) } port, err := strconv.ParseInt(rawPort, 10, 32) if err != nil { return errors.Wrapf(err, "could not parse listening port '%v'", rawPort) } logger.Debug(ctx, "listening for tcp connections", logger.F("port", port), logger.F("host", host)) s.port = int(port) go func() { logger.Debug(ctx, "starting http server") if err := server.Serve(listener); err != nil && !errors.Is(err, http.ErrServerClosed) { logger.Error(ctx, "could not listen", logger.CapturedE(errors.WithStack(err))) } }() go func() { <-ctx.Done() logger.Debug(ctx, "closing http server") if err := server.Close(); err != nil { logger.Error(ctx, "could not close http server", logger.CapturedE(errors.WithStack(err))) } }() return nil } func (s *Server) startHTTPSServer(ctx context.Context, router chi.Router) error { server := http.Server{ Addr: s.address, Handler: router, TLSConfig: &tls.Config{ GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) { return s.tlsCert, nil }, }, } listener, err := net.Listen("tcp", s.tlsAddress) if err != nil { return errors.WithStack(err) } host, rawPort, err := net.SplitHostPort(listener.Addr().String()) if err != nil { return errors.WithStack(err) } port, err := strconv.ParseInt(rawPort, 10, 32) if err != nil { return errors.Wrapf(err, "could not parse listening port '%v'", rawPort) } logger.Debug(ctx, "listening for tls tcp connections", logger.F("port", port), logger.F("host", host)) s.tlsPort = int(port) go func() { logger.Debug(ctx, "starting https server") if err := server.ServeTLS(listener, "", ""); err != nil && !errors.Is(err, http.ErrServerClosed) { logger.Error(ctx, "could not listen", logger.CapturedE(errors.WithStack(err))) } }() go func() { <-ctx.Done() logger.Debug(ctx, "closing https server") if err := server.Close(); err != nil { logger.Error(ctx, "could not close https server", logger.CapturedE(errors.WithStack(err))) } }() return nil } func (s *Server) getAllowedOrigins() ([]string, error) { allowedOrigins := make([]string, 0) if s.appsEnabled { ips, err := network.GetLANIPv4Addrs() if err != nil { return nil, errors.WithStack(err) } for _, ip := range ips { allowedOrigins = append(allowedOrigins, fmt.Sprintf("http://%s:%d", ip, s.port)) } } if len(s.allowedOrigins) > 0 { allowedOrigins = append(allowedOrigins, s.allowedOrigins...) } return allowedOrigins, nil }