package proxy

import (
	"bytes"
	"context"
	"expvar"
	"fmt"
	"html/template"
	"io"
	"log"
	"net"
	"net/http"
	"net/http/httputil"
	"net/http/pprof"
	"net/url"
	"path/filepath"
	"strconv"
	"time"

	"forge.cadoles.com/Cadoles/go-proxy"
	"forge.cadoles.com/cadoles/bouncer/internal/cache/memory"
	"forge.cadoles.com/cadoles/bouncer/internal/cache/ttl"
	bouncerChi "forge.cadoles.com/cadoles/bouncer/internal/chi"
	"forge.cadoles.com/cadoles/bouncer/internal/config"
	"forge.cadoles.com/cadoles/bouncer/internal/proxy/director"
	"forge.cadoles.com/cadoles/bouncer/internal/store"
	"github.com/Masterminds/sprig/v3"
	sentryhttp "github.com/getsentry/sentry-go/http"
	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/pkg/errors"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"gitlab.com/wpetit/goweb/logger"
)

// Server is the HTTP reverse proxy server. Build it with NewServer and run it
// with Start.
type Server struct {
	serverConfig     config.ProxyServerConfig
	redisConfig      config.RedisConfig
	directorLayers   []director.Layer
	directorCacheTTL time.Duration

	// Populated by initRepositories when the server starts.
	proxyRepository store.ProxyRepository
	layerRepository store.LayerRepository
}

// Start launches the server in a background goroutine.
//
// The first returned channel receives the listening address once the listener
// is open. The second receives errors raised while running. Both channels are
// unbuffered and are closed when the server stops, which happens when ctx is
// cancelled or on a fatal error. If startup fails before listening, the
// address channel is closed without ever receiving a value.
func (s *Server) Start(ctx context.Context) (<-chan net.Addr, <-chan error) {
	errs := make(chan error)
	addrs := make(chan net.Addr)

	go s.run(ctx, addrs, errs)

	return addrs, errs
}

// run initializes the repositories, opens the TCP listener, builds the chi
// router (middlewares, optional metrics/profiling endpoints, proxy handler)
// and serves HTTP until parentCtx is cancelled or the listener fails.
func (s *Server) run(parentCtx context.Context, addrs chan net.Addr, errs chan error) {
	// Registered first so it runs last: the listener-closing defer below may
	// still send on errs.
	defer func() {
		close(errs)
		close(addrs)
	}()

	ctx, cancel := context.WithCancel(parentCtx)
	defer cancel()

	if err := s.initRepositories(ctx); err != nil {
		errs <- errors.WithStack(err)

		return
	}

	// JoinHostPort brackets IPv6 literals ("::1" -> "[::1]:8080"), which a
	// plain "%s:%d" format does not.
	listenAddr := net.JoinHostPort(
		string(s.serverConfig.HTTP.Host),
		fmt.Sprintf("%d", s.serverConfig.HTTP.Port),
	)

	listener, err := net.Listen("tcp", listenAddr)
	if err != nil {
		errs <- errors.WithStack(err)

		return
	}

	addrs <- listener.Addr()

	defer func() {
		if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
			errs <- errors.WithStack(err)
		}
	}()

	// Closing the listener on cancellation makes http.Serve below return.
	go func() {
		<-ctx.Done()

		if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
			log.Printf("%+v", errors.WithStack(err))
		}
	}()

	router := chi.NewRouter()

	logger.Info(ctx, "http server listening")

	// Named proxyDirector so the `director` package stays accessible.
	proxyDirector := director.New(
		s.proxyRepository,
		s.layerRepository,
		director.WithLayers(s.directorLayers...),
		director.WithLayerCache(
			ttl.NewCache(
				memory.NewCache[string, []*store.Layer](),
				memory.NewCache[string, time.Time](),
				s.directorCacheTTL,
			),
		),
		director.WithProxyCache(
			ttl.NewCache(
				memory.NewCache[string, []*store.Proxy](),
				memory.NewCache[string, time.Time](),
				s.directorCacheTTL,
			),
		),
		director.WithHandleErrorFunc(s.handleError),
	)

	if s.serverConfig.HTTP.UseRealIP {
		router.Use(middleware.RealIP)
	}

	router.Use(middleware.RequestID)
	router.Use(middleware.RequestLogger(bouncerChi.NewLogFormatter()))
	router.Use(middleware.Recoverer)

	if s.serverConfig.Sentry.DSN != "" {
		logger.Info(ctx, "enabling sentry http middleware")

		sentryMiddleware := sentryhttp.New(sentryhttp.Options{
			Repanic: true,
		})

		router.Use(sentryMiddleware.Handle)
	}

	if s.serverConfig.Metrics.Enabled {
		metrics := s.serverConfig.Metrics

		logger.Info(ctx, "enabling metrics", logger.F("endpoint", metrics.Endpoint))

		router.Group(func(r chi.Router) {
			if metrics.BasicAuth != nil {
				logger.Info(ctx, "enabling authentication on metrics endpoint")

				r.Use(middleware.BasicAuth(
					"metrics",
					metrics.BasicAuth.CredentialsMap(),
				))
			}

			r.Handle(string(metrics.Endpoint), promhttp.Handler())
		})
	}

	if s.serverConfig.Profiling.Enabled {
		profiling := s.serverConfig.Profiling

		logger.Info(ctx, "enabling profiling", logger.F("endpoint", profiling.Endpoint))

		router.Group(func(r chi.Router) {
			if profiling.BasicAuth != nil {
				logger.Info(ctx, "enabling authentication on profiling endpoint")

				r.Use(middleware.BasicAuth(
					"profiling",
					profiling.BasicAuth.CredentialsMap(),
				))
			}

			r.Route(string(profiling.Endpoint), func(r chi.Router) {
				r.HandleFunc("/", pprof.Index)
				r.HandleFunc("/cmdline", pprof.Cmdline)
				r.HandleFunc("/profile", pprof.Profile)
				r.HandleFunc("/symbol", pprof.Symbol)
				r.HandleFunc("/trace", pprof.Trace)
				r.Handle("/vars", expvar.Handler())
				// Named profiles (heap, goroutine, allocs, ...).
				r.HandleFunc("/{name}", func(w http.ResponseWriter, r *http.Request) {
					name := chi.URLParam(r, "name")
					pprof.Handler(name).ServeHTTP(w, r)
				})
			})
		})
	}

	router.Group(func(r chi.Router) {
		r.Use(proxyDirector.Middleware())

		handler := proxy.New(
			proxy.WithRequestTransformers(
				proxyDirector.RequestTransformer(),
			),
			proxy.WithResponseTransformers(
				proxyDirector.ResponseTransformer(),
			),
			proxy.WithReverseProxyFactory(s.createReverseProxy),
			proxy.WithDefaultHandler(http.HandlerFunc(s.handleDefault)),
		)

		r.Handle("/*", handler)
	})

	if err := http.Serve(listener, router); err != nil && !errors.Is(err, net.ErrClosed) {
		errs <- errors.WithStack(err)
	}

	logger.Info(ctx, "http server exiting")
}

// createReverseProxy builds a single-host reverse proxy to target. It uses a
// transport derived from the configured transport settings and a dialer built
// from the dial settings. Proxy errors are routed to handleProxyError.
func (s *Server) createReverseProxy(ctx context.Context, target *url.URL) *httputil.ReverseProxy {
	reverseProxy := httputil.NewSingleHostReverseProxy(target)

	dialConfig := s.serverConfig.Dial

	// NOTE(review): the dial settings are dereferenced unconditionally; this
	// assumes config defaults always populate them — confirm, otherwise a nil
	// field panics here.
	dialer := &net.Dialer{
		Timeout:       time.Duration(*dialConfig.Timeout),
		KeepAlive:     time.Duration(*dialConfig.KeepAlive),
		FallbackDelay: time.Duration(*dialConfig.FallbackDelay),
		// DualStack is deprecated in net.Dialer (fast fallback is on by
		// default); kept because the configuration still exposes it.
		DualStack: bool(dialConfig.DualStack),
	}

	httpTransport := s.serverConfig.Transport.AsTransport()
	httpTransport.DialContext = dialer.DialContext

	reverseProxy.Transport = httpTransport
	reverseProxy.ErrorHandler = s.handleProxyError

	return reverseProxy
}

// handleDefault answers requests for which no proxy target matched with a 502
// error page.
func (s *Server) handleDefault(w http.ResponseWriter, r *http.Request) {
	s.handleError(w, r, http.StatusBadGateway, errors.Errorf("no proxy target found"))
}

// handleError logs err and renders the error page with the given HTTP status.
// Cancelled request contexts are logged as warnings without being captured.
func (s *Server) handleError(w http.ResponseWriter, r *http.Request, status int, err error) {
	err = errors.WithStack(err)

	if errors.Is(err, context.Canceled) {
		logger.Warn(r.Context(), err.Error(), logger.E(err))
	} else {
		logger.Error(r.Context(), err.Error(), logger.CapturedE(err))
	}

	s.renderErrorPage(w, r, err, status, http.StatusText(status))
}

// handleProxyError is the httputil.ReverseProxy error handler; upstream
// failures are reported as 502 Bad Gateway.
func (s *Server) handleProxyError(w http.ResponseWriter, r *http.Request, err error) {
	s.handleError(w, r, http.StatusBadGateway, err)
}

// renderErrorPage renders the "error" page template with the given status. It
// uses the block named after the numeric status code, falling back to the
// "default" block.
func (s *Server) renderErrorPage(w http.ResponseWriter, r *http.Request, err error, statusCode int, status string) {
	templateData := struct {
		StatusCode int
		Status     string
		Err        error
		Debug      bool
	}{
		Debug:      bool(s.serverConfig.Debug),
		StatusCode: statusCode,
		Status:     status,
		Err:        err,
	}

	// The status is written by renderPageWithStatus, after the response
	// headers are set: writing it here first would discard those headers.
	s.renderPageWithStatus(w, r, "error", strconv.Itoa(statusCode), statusCode, templateData)
}

// renderPage renders block (or "default" when block is missing) of the
// `<page>.gohtml` template with templateData. It does not set the status
// explicitly, so a successful render is sent with an implicit 200.
func (s *Server) renderPage(w http.ResponseWriter, r *http.Request, page string, block string, templateData any) {
	s.renderPageWithStatus(w, r, page, block, 0, templateData)
}

// renderPageWithStatus loads `<page>.gohtml` from the templates directory and
// renders block (or "default") into a buffer. Only after rendering succeeds
// does it write the headers, the status and then the body.
//
// A statusCode <= 0 means "do not call WriteHeader". If the page cannot be
// rendered, a plain-text response is sent instead. It uses statusCode when
// one is given, so a 502 is not turned into a 500, and otherwise 500.
func (s *Server) renderPageWithStatus(w http.ResponseWriter, r *http.Request, page string, block string, statusCode int, templateData any) {
	ctx := r.Context()

	failureStatus := http.StatusInternalServerError
	if statusCode > 0 {
		failureStatus = statusCode
	}

	templatesConf := s.serverConfig.Templates

	pattern := filepath.Join(string(templatesConf.Dir), page+".gohtml")

	logger.Info(ctx, "loading proxy templates", logger.F("pattern", pattern))

	// Templates are re-parsed on every call, so edits on disk are picked up
	// without a restart.
	tmpl, err := template.New("").Funcs(sprig.FuncMap()).ParseGlob(pattern)
	if err != nil {
		logger.Error(ctx, "could not load proxy templates", logger.CapturedE(errors.WithStack(err)))
		http.Error(w, http.StatusText(failureStatus), failureStatus)

		return
	}

	// Must be set before any WriteHeader/Write call: net/http ignores header
	// changes made afterwards.
	w.Header().Add("Cache-Control", "no-cache")

	blockTmpl := tmpl.Lookup(block)
	if blockTmpl == nil {
		blockTmpl = tmpl.Lookup("default")
	}

	if blockTmpl == nil {
		logger.Error(ctx, "could not find template block nor default one", logger.F("block", block))
		http.Error(w, http.StatusText(failureStatus), failureStatus)

		return
	}

	// Render into a buffer so a template error can still produce a clean
	// fallback response instead of a half-written page.
	var buf bytes.Buffer
	if err := blockTmpl.Execute(&buf, templateData); err != nil {
		logger.Error(ctx, "could not render proxy page", logger.CapturedE(errors.WithStack(err)))
		http.Error(w, http.StatusText(failureStatus), failureStatus)

		return
	}

	if statusCode > 0 {
		w.WriteHeader(statusCode)
	}

	if _, err := io.Copy(w, &buf); err != nil {
		logger.Error(ctx, "could not write page", logger.CapturedE(errors.WithStack(err)))
	}
}

// NewServer creates a Server from the default options, modified by the given
// option functions in order.
func NewServer(funcs ...OptionFunc) *Server {
	opt := defaultOption()
	for _, fn := range funcs {
		fn(opt)
	}

	return &Server{
		serverConfig:     opt.ServerConfig,
		redisConfig:      opt.RedisConfig,
		directorLayers:   opt.DirectorLayers,
		directorCacheTTL: opt.DirectorCacheTTL,
	}
}