arcast/pkg/server/http.go

225 lines
5.4 KiB
Go

package server
import (
"context"
"crypto/tls"
"fmt"
"html/template"
"net"
"net/http"
"strconv"
"forge.cadoles.com/arcad/arcast/pkg/network"
"github.com/go-chi/chi/v5"
"github.com/go-chi/cors"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
_ "embed"
)
var (
//go:embed templates/idle.html.gotmpl
rawIdleTemplate []byte
idleTemplate *template.Template
)
// init parses the embedded idle page template once, when the package is
// loaded, and stores it in idleTemplate. It panics if the template does not
// parse, because the page cannot be served without it.
func init() {
	var err error
	if idleTemplate, err = template.New("").Parse(string(rawIdleTemplate)); err != nil {
		panic(errors.Wrap(err, "could not parse idle template"))
	}
}
// startWebServers assembles the router shared by the web servers and starts
// them: the plain HTTP server always, the HTTPS server only when s.tlsCert is
// set. Once the servers are started, the browser is reset via s.resetBrowser.
//
// Registered routes: "/" (idle page), GET /api/v1/info, POST and DELETE
// /api/v1/cast, GET /api/v1/status and, when apps are enabled, GET /apps,
// GET /api/v1/apps, the per-app file server under /apps/{appID}/ and
// /api/v1/broadcast/{channelID}.
func (s *Server) startWebServers(ctx context.Context) error {
	// NOTE(review): getAllowedOrigins builds the app origins from s.port, but
	// startHTTPServer only assigns s.port further down. Unless s.port is set
	// beforehand elsewhere, those origins carry a stale port — confirm.
	origins, err := s.getAllowedOrigins()
	if err != nil {
		return errors.WithStack(err)
	}

	router := chi.NewRouter()

	// The CORS middleware goes in before any route is registered. It is
	// omitted entirely when there is no origin to allow.
	if len(origins) != 0 {
		router.Use(cors.Handler(cors.Options{
			AllowedOrigins:   origins,
			AllowedMethods:   []string{"GET", "POST", "PUT", "DELETE", "OPTIONS"},
			AllowedHeaders:   []string{"Accept", "Authorization", "Content-Type"},
			AllowCredentials: false,
		}))
	}

	router.Get("/", s.handleHome)
	router.Get("/api/v1/info", s.handleInfo)
	router.Post("/api/v1/cast", s.handleCast)
	router.Delete("/api/v1/cast", s.handleReset)
	router.Get("/api/v1/status", s.handleStatus)

	if s.appsEnabled {
		router.Get("/apps", s.handleDefaultApp)
		router.Get("/api/v1/apps", s.handleApps)
		router.HandleFunc("/apps/{appID}/*", s.handleAppFilesystem)
		router.HandleFunc("/api/v1/broadcast/{channelID}", s.handleBroadcast)
	}

	if err := s.startHTTPServer(ctx, router); err != nil {
		return errors.WithStack(err)
	}

	if s.tlsCert == nil {
		logger.Info(ctx, "no tls certificate configured, not starting https server")
	} else if err := s.startHTTPSServer(ctx, router); err != nil {
		return errors.WithStack(err)
	}

	// errors.WithStack returns nil for a nil error.
	return errors.WithStack(s.resetBrowser())
}
// startHTTPServer listens for plain HTTP connections on s.address, stores the
// port actually bound in s.port and serves router in the background until ctx
// is canceled.
//
// It returns an error if the listener cannot be opened or its port cannot be
// determined; errors that occur while serving are only logged.
func (s *Server) startHTTPServer(ctx context.Context, router chi.Router) error {
	listener, port, err := listenTCP(ctx, s.address, "listening for tcp connections")
	if err != nil {
		return errors.WithStack(err)
	}

	// Assigned before the server starts serving, so handlers that read
	// s.port (e.g. handleHome) get the bound port.
	s.port = port

	server := &http.Server{
		Addr:    s.address,
		Handler: router,
	}

	serveUntilDone(ctx, server, "http", func() error {
		return server.Serve(listener)
	})

	return nil
}

// startHTTPSServer listens for TLS connections on s.tlsAddress, stores the
// port actually bound in s.tlsPort and serves router in the background until
// ctx is canceled. Every TLS handshake is answered with s.tlsCert.
//
// It returns an error if the listener cannot be opened or its port cannot be
// determined; errors that occur while serving are only logged.
func (s *Server) startHTTPSServer(ctx context.Context, router chi.Router) error {
	listener, port, err := listenTCP(ctx, s.tlsAddress, "listening for tls tcp connections")
	if err != nil {
		return errors.WithStack(err)
	}

	s.tlsPort = port

	server := &http.Server{
		// Serve uses the listener, not Addr, but Addr should still name the
		// address this server really listens on (previously s.address).
		Addr:    s.tlsAddress,
		Handler: router,
		TLSConfig: &tls.Config{
			// s.tlsCert is read on every handshake rather than copied once.
			GetCertificate: func(*tls.ClientHelloInfo) (*tls.Certificate, error) {
				return s.tlsCert, nil
			},
		},
	}

	serveUntilDone(ctx, server, "https", func() error {
		// No certificate files: ServeTLS takes the certificate from
		// TLSConfig.GetCertificate when the file names are empty.
		return server.ServeTLS(listener, "", "")
	})

	return nil
}

// listenTCP opens a TCP listener on address and returns it together with the
// port it is actually bound to (which is what callers must publish when
// address asks for port 0). msg is the debug message logged, with the bound
// host and port, once the listener is ready.
//
// If the bound port cannot be determined, the listener is closed before the
// error is returned so that no socket is leaked.
func listenTCP(ctx context.Context, address string, msg string) (net.Listener, int, error) {
	listener, err := net.Listen("tcp", address)
	if err != nil {
		return nil, 0, errors.WithStack(err)
	}

	host, rawPort, err := net.SplitHostPort(listener.Addr().String())
	if err != nil {
		// Best effort: the split error is the one worth reporting.
		_ = listener.Close()
		return nil, 0, errors.WithStack(err)
	}

	port, err := strconv.ParseInt(rawPort, 10, 32)
	if err != nil {
		// Best effort: the parse error is the one worth reporting.
		_ = listener.Close()
		return nil, 0, errors.Wrapf(err, "could not parse listening port '%v'", rawPort)
	}

	logger.Debug(ctx, msg, logger.F("port", port), logger.F("host", host))

	return listener, int(port), nil
}

// serveUntilDone runs serve in its own goroutine and closes server as soon as
// ctx is canceled. kind ("http" or "https") only appears in log messages.
// http.ErrServerClosed, which serve returns once the server has been closed,
// is not reported as a failure.
func serveUntilDone(ctx context.Context, server *http.Server, kind string, serve func() error) {
	go func() {
		logger.Debug(ctx, "starting "+kind+" server")
		if err := serve(); err != nil && !errors.Is(err, http.ErrServerClosed) {
			logger.Error(ctx, "could not listen", logger.CapturedE(errors.WithStack(err)))
		}
	}()

	go func() {
		<-ctx.Done()
		logger.Debug(ctx, "closing "+kind+" server")
		if err := server.Close(); err != nil {
			logger.Error(ctx, "could not close "+kind+" server", logger.CapturedE(errors.WithStack(err)))
		}
	}()
}
// getAllowedOrigins lists the origins the CORS middleware should accept.
//
// When apps are enabled, it starts with one "http://<ip>:<s.port>" origin per
// address returned by network.GetLANIPv4Addrs. The extra origins configured
// in s.allowedOrigins always follow. On success the returned slice is non-nil,
// possibly empty. A failure to list LAN addresses is returned as an error.
func (s *Server) getAllowedOrigins() ([]string, error) {
	origins := []string{}

	if s.appsEnabled {
		ips, err := network.GetLANIPv4Addrs()
		if err != nil {
			return nil, errors.WithStack(err)
		}

		port := strconv.Itoa(s.port)
		for i := range ips {
			origins = append(origins, "http://"+ips[i]+":"+port)
		}
	}

	// Appending an empty slice is a no-op, so no length guard is needed.
	return append(origins, s.allowedOrigins...), nil
}
// handleHome renders the idle page. The template receives the instance ID,
// the LAN IPv4 addresses, the HTTP and HTTPS ports and whether apps are
// enabled.
//
// It answers 500 when the LAN addresses cannot be listed. A template execution
// error is only logged, because the template writes straight to w and part of
// the page may already have been sent.
func (s *Server) handleHome(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()

	ips, err := network.GetLANIPv4Addrs()
	if err != nil {
		logger.Error(ctx, "could not retrieve lan ip addresses", logger.CapturedE(errors.WithStack(err)))
		status := http.StatusInternalServerError
		http.Error(w, http.StatusText(status), status)
		return
	}

	// The field names are what the idle template refers to.
	data := struct {
		IPs     []string
		Port    int
		TLSPort int
		ID      string
		Apps    bool
	}{
		IPs:     ips,
		Port:    s.port,
		TLSPort: s.tlsPort,
		ID:      s.instanceID,
		Apps:    s.appsEnabled,
	}

	if err := idleTemplate.Execute(w, data); err != nil {
		logger.Error(ctx, "could not render idle page", logger.CapturedE(errors.WithStack(err)))
	}
}