emissary/internal/agent/controller/app/server.go

190 lines
3.8 KiB
Go
Raw Permalink Normal View History

package app
import (
"context"
"net"
"net/http"
"strings"
2023-03-13 10:44:58 +01:00
"sync"
2023-03-28 20:43:45 +02:00
"time"
2023-03-28 20:43:45 +02:00
"forge.cadoles.com/Cadoles/emissary/internal/agent/controller/app/spec"
appSpec "forge.cadoles.com/Cadoles/emissary/internal/agent/controller/app/spec"
"forge.cadoles.com/Cadoles/emissary/internal/proxy/wildcard"
edgeHTTP "forge.cadoles.com/arcad/edge/pkg/http"
2023-03-13 10:44:58 +01:00
"gitlab.com/wpetit/goweb/logger"
"forge.cadoles.com/arcad/edge/pkg/bundle"
"github.com/go-chi/chi/middleware"
"github.com/go-chi/chi/v5"
"github.com/pkg/errors"
2023-03-21 15:21:19 +01:00
_ "forge.cadoles.com/Cadoles/emissary/internal/imports/passwd"
)
2023-03-28 20:43:45 +02:00
// defaultCookieDuration is the default cookie lifetime: one day.
// The constant is a time.Duration because time.Hour is one.
const defaultCookieDuration = 24 * time.Hour
type Server struct {
2023-03-28 20:43:45 +02:00
bundle bundle.Bundle
handlerOptions []edgeHTTP.HandlerOptionFunc
server *http.Server
serverMutex sync.RWMutex
config *appSpec.Config
}
2023-03-13 10:44:58 +01:00
func (s *Server) Start(ctx context.Context, addr string) (err error) {
if s.Running() {
if err := s.Stop(); err != nil {
return errors.WithStack(err)
}
}
s.serverMutex.Lock()
defer s.serverMutex.Unlock()
router := chi.NewRouter()
2023-04-21 20:04:37 +02:00
router.Use(middleware.RealIP)
router.Use(middleware.Logger)
router.Use(middleware.Compress(5))
2023-03-28 20:43:45 +02:00
handler := edgeHTTP.NewHandler(s.handlerOptions...)
if err := handler.Load(s.bundle); err != nil {
return errors.Wrap(err, "could not load app bundle")
}
if s.config != nil {
if s.config.UnexpectedHostRedirect != nil {
router.Use(unexpectedHostRedirect(
s.config.UnexpectedHostRedirect.HostTarget,
s.config.UnexpectedHostRedirect.AcceptedHostPatterns...,
))
}
2023-03-21 15:21:19 +01:00
}
router.Handle("/*", handler)
server := &http.Server{
Addr: addr,
Handler: router,
}
go func() {
2023-03-13 10:44:58 +01:00
defer func() {
if recovered := recover(); recovered != nil {
if err, ok := recovered.(error); ok {
logger.Error(ctx, err.Error(), logger.E(errors.WithStack(err)))
return
}
panic(recovered)
}
}()
defer func() {
if err := s.Stop(); err != nil {
panic(errors.WithStack(err))
}
}()
if err := server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
panic(errors.WithStack(err))
}
}()
s.server = server
return nil
}
2023-03-13 10:44:58 +01:00
// Running reports whether an HTTP server is currently registered on s.
func (s *Server) Running() bool {
	s.serverMutex.RLock()
	running := s.server != nil
	s.serverMutex.RUnlock()

	return running
}
// Stop closes the running HTTP server, if any.
//
// Calling Stop on a server that is not running is a no-op. The server is
// unregistered even when closing it fails; the close error is then returned.
func (s *Server) Stop() error {
	if !s.Running() {
		return nil
	}

	s.serverMutex.Lock()
	defer s.serverMutex.Unlock()

	// Re-check under the write lock: the server may have been stopped
	// between the check above and the lock acquisition.
	current := s.server
	if current == nil {
		return nil
	}

	s.server = nil

	if closeErr := current.Close(); closeErr != nil {
		return errors.WithStack(closeErr)
	}

	return nil
}
func NewServer(bundle bundle.Bundle, config *spec.Config, handlerOptions ...edgeHTTP.HandlerOptionFunc) *Server {
return &Server{
2023-03-28 20:43:45 +02:00
bundle: bundle,
config: config,
2023-03-28 20:43:45 +02:00
handlerOptions: handlerOptions,
}
}
// getCookieDomain derives a cookie domain from the request's Host header.
//
// It returns:
//   - "" when the host is an IP address (IPv4 or IPv6, with or without port
//     or brackets);
//   - the last two labels of the host when it has at least two
//     (e.g. "app.example.com:8080" gives "example.com");
//   - the host itself otherwise (e.g. "localhost").
//
// The result is not public-suffix aware: "app.example.co.uk" gives "co.uk".
// The returned error is currently always nil.
func getCookieDomain(r *http.Request) (string, error) {
	host, _, err := net.SplitHostPort(r.Host)
	if err != nil {
		// No port in the Host header: use it as is.
		host = r.Host
	}

	// A bracketed IPv6 literal without port (e.g. "[::1]") is not unwrapped
	// by SplitHostPort and would not be recognized by ParseIP.
	if strings.HasPrefix(host, "[") && strings.HasSuffix(host, "]") {
		host = host[1 : len(host)-1]
	}

	// Drop the root label of a fully qualified name ("example.com.") so it is
	// not counted as an empty trailing label.
	host = strings.TrimSuffix(host, ".")

	// IP addresses have no parent domain to share a cookie with.
	if net.ParseIP(host) != nil {
		return "", nil
	}

	labels := strings.Split(host, ".")
	if n := len(labels); n >= 2 {
		return labels[n-2] + "." + labels[n-1], nil
	}

	// Single-label host.
	return host, nil
}
// unexpectedHostRedirect returns a middleware that redirects requests whose
// host does not match any of acceptedHostPatterns to hostTarget.
//
// The host is taken from the request's Host header with the port removed, and
// matched with wildcard.MatchAny. On mismatch the client receives a
// 307 Temporary Redirect to the same URL with its host replaced by hostTarget;
// the port of the original request, if any, is preserved. Matching requests
// are passed to the next handler unchanged.
func unexpectedHostRedirect(hostTarget string, acceptedHostPatterns ...string) func(http.Handler) http.Handler {
	return func(h http.Handler) http.Handler {
		fn := func(w http.ResponseWriter, r *http.Request) {
			host, port, err := net.SplitHostPort(r.Host)
			if err != nil {
				// No port in the Host header: use it as is.
				host = r.Host
			}

			if wildcard.MatchAny(host, acceptedHostPatterns...) {
				h.ServeHTTP(w, r)

				return
			}

			// Work on a copy: r.URL is a pointer and handlers should not
			// modify the request they are given.
			target := *r.URL
			target.Host = hostTarget

			if port != "" {
				target.Host += ":" + port
			}

			http.Redirect(w, r, target.String(), http.StatusTemporaryRedirect)
		}

		return http.HandlerFunc(fn)
	}
}