package proxy

import (
	"bytes"
	"context"
	"expvar"
	"fmt"
	"html/template"
	"io"
	"log"
	"net"
	"net/http"
	"net/http/httputil"
	"net/http/pprof"
	"net/url"
	"os"
	"os/signal"
	"path/filepath"
	"strconv"
	"syscall"
	"time"

	"forge.cadoles.com/Cadoles/go-proxy"
	"forge.cadoles.com/cadoles/bouncer/internal/cache/memory"
	"forge.cadoles.com/cadoles/bouncer/internal/cache/ttl"
	bouncerChi "forge.cadoles.com/cadoles/bouncer/internal/chi"
	"forge.cadoles.com/cadoles/bouncer/internal/config"
	"forge.cadoles.com/cadoles/bouncer/internal/proxy/director"
	"forge.cadoles.com/cadoles/bouncer/internal/store"

	"github.com/Masterminds/sprig/v3"
	sentryhttp "github.com/getsentry/sentry-go/http"
	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/pkg/errors"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"gitlab.com/wpetit/goweb/logger"
)

// Server is the bouncer proxy HTTP server. Instances are built with
// NewServer and started with Start.
type Server struct {
	// serverConfig holds the HTTP listen address, middleware toggles,
	// metrics/profiling endpoints, dial/transport settings, templates
	// directory and debug flag used by the server.
	serverConfig     config.ProxyServerConfig
	// NOTE(review): redisConfig is not read in this part of the file;
	// presumably consumed by initRepositories — confirm.
	redisConfig      config.RedisConfig
	// directorLayers are passed to the director via director.WithLayers.
	directorLayers   []director.Layer
	// directorCacheTTL is the TTL given to the director's layer and proxy
	// caches (see createDirectorCaches).
	directorCacheTTL time.Duration
	// proxyRepository and layerRepository are handed to director.New. They
	// are not set by NewServer; presumably initRepositories fills them in
	// before they are used — confirm.
	proxyRepository  store.ProxyRepository
	layerRepository  store.LayerRepository
}

// Start runs the server in a new goroutine bound to ctx.
//
// It returns two unbuffered channels. The first receives the listener
// address once the server is listening. The second receives startup or
// serving errors. Both channels are closed when the server goroutine exits.
// The server goroutine blocks on each send, so callers must keep receiving
// from both channels (typically in a select loop).
func (s *Server) Start(ctx context.Context) (<-chan net.Addr, <-chan error) {
	addrs, errs := make(chan net.Addr), make(chan error)

	go s.run(ctx, addrs, errs)

	return addrs, errs
}

// run is the body of the goroutine spawned by Start.
//
// It performs these steps in order:
//   - initializes the repositories;
//   - opens the TCP listener and publishes its address on addrs;
//   - builds the router: request middlewares, optional metrics and
//     profiling endpoints, and the catch-all proxy handler;
//   - serves HTTP until the listener is closed, which happens once
//     parentCtx is done.
//
// Errors are sent on errs. When called from Start, both channels are
// unbuffered, so every send blocks until the caller receives. Both channels
// are closed when run returns.
func (s *Server) run(parentCtx context.Context, addrs chan net.Addr, errs chan error) {
	defer func() {
		close(errs)
		close(addrs)
	}()

	ctx, cancel := context.WithCancel(parentCtx)
	defer cancel()

	if err := s.initRepositories(ctx); err != nil {
		errs <- errors.WithStack(err)

		return
	}

	// net.JoinHostPort brackets IPv6 literals (e.g. "::" -> "[::]:8080").
	// A plain "%s:%d" format would yield an address net.Listen rejects.
	address := net.JoinHostPort(
		fmt.Sprint(s.serverConfig.HTTP.Host),
		fmt.Sprintf("%d", s.serverConfig.HTTP.Port),
	)

	listener, err := net.Listen("tcp", address)
	if err != nil {
		errs <- errors.WithStack(err)

		return
	}

	addrs <- listener.Addr()

	defer func() {
		if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
			errs <- errors.WithStack(err)
		}
	}()

	// Closing the listener once ctx is done makes http.Serve return.
	go func() {
		<-ctx.Done()

		if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
			log.Printf("%+v", errors.WithStack(err))
		}
	}()

	router := chi.NewRouter()

	logger.Info(ctx, "http server listening", logger.F("address", listener.Addr().String()))

	// Named distinctly from the context cancel func above. Reusing "cancel"
	// with := here would silently rebind it to the cache teardown function.
	layerCache, proxyCache, stopCaches := s.createDirectorCaches(ctx)
	defer stopCaches()

	// Named distinctly from the director package so the package identifier
	// is not shadowed for the rest of the function.
	proxyDirector := director.New(
		s.proxyRepository,
		s.layerRepository,
		director.WithLayers(s.directorLayers...),
		director.WithLayerCache(layerCache),
		director.WithProxyCache(proxyCache),
		director.WithHandleErrorFunc(s.handleError),
	)

	if s.serverConfig.HTTP.UseRealIP {
		router.Use(middleware.RealIP)
	}

	router.Use(middleware.RequestID)
	router.Use(middleware.RequestLogger(bouncerChi.NewLogFormatter()))
	router.Use(middleware.Recoverer)

	if s.serverConfig.Sentry.DSN != "" {
		logger.Info(ctx, "enabling sentry http middleware")

		sentryMiddleware := sentryhttp.New(sentryhttp.Options{
			Repanic: true,
		})

		router.Use(sentryMiddleware.Handle)
	}

	s.mountMetricsRoutes(ctx, router)
	s.mountProfilingRoutes(ctx, router)

	// All remaining paths go through the director middleware and the
	// go-proxy handler.
	router.Group(func(r chi.Router) {
		r.Use(proxyDirector.Middleware())

		handler := proxy.New(
			proxy.WithRequestTransformers(
				proxyDirector.RequestTransformer(),
			),
			proxy.WithResponseTransformers(
				proxyDirector.ResponseTransformer(),
			),
			proxy.WithReverseProxyFactory(s.createReverseProxy),
			proxy.WithDefaultHandler(http.HandlerFunc(s.handleDefault)),
		)

		r.Handle("/*", handler)
	})

	// NOTE(review): http.Serve uses a zero-value http.Server, so no
	// ReadHeaderTimeout/ReadTimeout/WriteTimeout applies to client
	// connections. Consider an explicit http.Server if timeouts become
	// configurable.
	if err := http.Serve(listener, router); err != nil && !errors.Is(err, net.ErrClosed) {
		errs <- errors.WithStack(err)
	}

	logger.Info(ctx, "http server exiting")
}

// mountMetricsRoutes exposes the Prometheus handler on the configured metrics
// endpoint, optionally behind HTTP basic auth. It does nothing when metrics
// are disabled.
func (s *Server) mountMetricsRoutes(ctx context.Context, router chi.Router) {
	metrics := s.serverConfig.Metrics
	if !metrics.Enabled {
		return
	}

	logger.Info(ctx, "enabling metrics", logger.F("endpoint", metrics.Endpoint))

	router.Group(func(r chi.Router) {
		if metrics.BasicAuth != nil {
			logger.Info(ctx, "enabling authentication on metrics endpoint")

			r.Use(middleware.BasicAuth(
				"metrics",
				metrics.BasicAuth.CredentialsMap(),
			))
		}

		r.Handle(string(metrics.Endpoint), promhttp.Handler())
	})
}

// mountProfilingRoutes exposes net/http/pprof and expvar handlers under the
// configured profiling endpoint, optionally behind HTTP basic auth. It does
// nothing when profiling is disabled.
func (s *Server) mountProfilingRoutes(ctx context.Context, router chi.Router) {
	profiling := s.serverConfig.Profiling
	if !profiling.Enabled {
		return
	}

	logger.Info(ctx, "enabling profiling", logger.F("endpoint", profiling.Endpoint))

	router.Group(func(r chi.Router) {
		if profiling.BasicAuth != nil {
			logger.Info(ctx, "enabling authentication on profiling endpoint")

			r.Use(middleware.BasicAuth(
				"profiling",
				profiling.BasicAuth.CredentialsMap(),
			))
		}

		r.Route(string(profiling.Endpoint), func(r chi.Router) {
			r.HandleFunc("/", pprof.Index)
			r.HandleFunc("/cmdline", pprof.Cmdline)
			r.HandleFunc("/profile", pprof.Profile)
			r.HandleFunc("/symbol", pprof.Symbol)
			r.HandleFunc("/trace", pprof.Trace)
			r.Handle("/vars", expvar.Handler())
			// Named profiles (heap, goroutine, allocs, ...).
			r.HandleFunc("/{name}", func(w http.ResponseWriter, req *http.Request) {
				name := chi.URLParam(req, "name")
				pprof.Handler(name).ServeHTTP(w, req)
			})
		})
	})
}

// createDirectorCaches builds the in-memory TTL caches the director uses for
// layers and proxies, both using s.directorCacheTTL. It also subscribes to
// SIGUSR2 so an operator can flush both caches at runtime.
//
// The returned function unsubscribes from SIGUSR2 and stops the background
// goroutine. It must be called exactly once, because calling it twice would
// close the channel twice.
func (s *Server) createDirectorCaches(ctx context.Context) (*ttl.Cache[string, []*store.Layer], *ttl.Cache[string, []*store.Proxy], func()) {
	layerCache := ttl.NewCache(
		memory.NewCache[string, []*store.Layer](),
		memory.NewCache[string, time.Time](),
		s.directorCacheTTL,
	)

	proxyCache := ttl.NewCache(
		memory.NewCache[string, []*store.Proxy](),
		memory.NewCache[string, time.Time](),
		s.directorCacheTTL,
	)

	sig := make(chan os.Signal, 1)

	signal.Notify(sig, syscall.SIGUSR2)

	go func() {
		// The loop ends when the cancel function below closes sig.
		for range sig {
			logger.Info(ctx, "received sigusr2 signal, clearing proxies and layers cache")

			layerCache.Clear()
			proxyCache.Clear()
		}
	}()

	cancel := func() {
		// signal.Stop must run before close. Once Stop returns, os/signal no
		// longer delivers to sig. Without it, a SIGUSR2 received after close
		// would make os/signal send on a closed channel and panic.
		signal.Stop(sig)
		close(sig)
	}

	return layerCache, proxyCache, cancel
}

// createReverseProxy is the go-proxy reverse-proxy factory. It returns a
// single-host reverse proxy to target whose transport comes from
// s.serverConfig.Transport and whose dialer is configured from
// s.serverConfig.Dial. Upstream failures are routed to handleProxyError.
// ctx is not used.
//
// NOTE(review): the Timeout, KeepAlive and FallbackDelay dial settings are
// dereferenced unconditionally and would panic if nil. Presumably the config
// defaults always set them; confirm. net.Dialer.DualStack is deprecated and
// has no effect; fast fallback is governed by FallbackDelay.
func (s *Server) createReverseProxy(ctx context.Context, target *url.URL) *httputil.ReverseProxy {
	dial := s.serverConfig.Dial

	dialer := net.Dialer{
		Timeout:       time.Duration(*dial.Timeout),
		KeepAlive:     time.Duration(*dial.KeepAlive),
		FallbackDelay: time.Duration(*dial.FallbackDelay),
		DualStack:     bool(dial.DualStack),
	}

	transport := s.serverConfig.Transport.AsTransport()
	transport.DialContext = dialer.DialContext

	rp := httputil.NewSingleHostReverseProxy(target)
	rp.Transport = transport
	rp.ErrorHandler = s.handleProxyError

	return rp
}

// handleDefault is the go-proxy fallback handler, used when no proxy target
// was selected for the request. It answers 502 Bad Gateway through
// handleError.
func (s *Server) handleDefault(w http.ResponseWriter, r *http.Request) {
	noTarget := errors.New("no proxy target found")
	s.handleError(w, r, http.StatusBadGateway, noTarget)
}

// handleError is the shared error path of the server. It attaches a stack
// trace to err, logs it unless it is (or wraps) context.Canceled, and then
// renders the error page for status with the standard status text.
func (s *Server) handleError(w http.ResponseWriter, r *http.Request, status int, err error) {
	stackErr := errors.WithStack(err)

	// Cancellations are not logged; presumably they come from aborted client
	// requests and would only add noise.
	if canceled := errors.Is(stackErr, context.Canceled); !canceled {
		logger.Error(r.Context(), stackErr.Error(), logger.CapturedE(stackErr))
	}

	statusText := http.StatusText(status)
	s.renderErrorPage(w, r, stackErr, status, statusText)
}

// handleProxyError is installed as httputil.ReverseProxy.ErrorHandler. Every
// upstream failure is reported to the client as 502 Bad Gateway.
func (s *Server) handleProxyError(w http.ResponseWriter, r *http.Request, err error) {
	const status = http.StatusBadGateway
	s.handleError(w, r, status, err)
}

// renderErrorPage writes statusCode and renders the "error" page template.
// The template block used is the one named after the numeric code (e.g.
// "502"); renderPage falls back to the "default" block.
//
// The template receives StatusCode, Status, Err, and Debug. Debug mirrors
// s.serverConfig.Debug; presumably the template uses it to decide whether to
// show Err.
//
// Headers must be set before w.WriteHeader: changes made afterwards are not
// sent. That is why Cache-Control is set here, even though renderPage also
// adds it (too late to reach the client on this path).
//
// NOTE(review): if rendering fails, renderPage calls http.Error after the
// status has already been written. The client then keeps statusCode, and
// net/http logs a superfluous WriteHeader call.
func (s *Server) renderErrorPage(w http.ResponseWriter, r *http.Request, err error, statusCode int, status string) {
	templateData := struct {
		StatusCode int
		Status     string
		Err        error
		Debug      bool
	}{
		Debug:      bool(s.serverConfig.Debug),
		StatusCode: statusCode,
		Status:     status,
		Err:        err,
	}

	w.Header().Set("Cache-Control", "no-cache")
	w.WriteHeader(statusCode)

	s.renderPage(w, r, "error", strconv.Itoa(statusCode), templateData)
}

// renderPage parses <Templates.Dir>/<page>.gohtml, with sprig helpers
// available, and executes the template named block. If no template named
// block exists, it uses the one named "default". The output is written to w
// with a "Cache-Control: no-cache" header.
//
// Templates are parsed from disk on every call; nothing is cached. Output is
// rendered into a buffer first, so an execution failure can still be
// answered with a plain 500 response instead of a partial page.
func (s *Server) renderPage(w http.ResponseWriter, r *http.Request, page string, block string, templateData any) {
	ctx := r.Context()

	internalError := func() {
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
	}

	pattern := filepath.Join(string(s.serverConfig.Templates.Dir), page+".gohtml")

	logger.Info(ctx, "loading proxy templates", logger.F("pattern", pattern))

	tmpl, err := template.New("").Funcs(sprig.FuncMap()).ParseGlob(pattern)
	if err != nil {
		logger.Error(ctx, "could not load proxy templates", logger.CapturedE(errors.WithStack(err)))
		internalError()

		return
	}

	w.Header().Add("Cache-Control", "no-cache")

	// Prefer the requested block; otherwise fall back to "default".
	var selected *template.Template
	for _, name := range []string{block, "default"} {
		if selected = tmpl.Lookup(name); selected != nil {
			break
		}
	}

	if selected == nil {
		logger.Error(ctx, "could not find template block nor default one", logger.F("block", block))
		internalError()

		return
	}

	var rendered bytes.Buffer

	if err := selected.Execute(&rendered, templateData); err != nil {
		logger.Error(ctx, "could not render proxy page", logger.CapturedE(errors.WithStack(err)))
		internalError()

		return
	}

	if _, err := io.Copy(w, &rendered); err != nil {
		logger.Error(ctx, "could not write page", logger.CapturedE(errors.WithStack(err)))
	}
}

// NewServer builds a Server from the default options, applying each option
// function in order.
//
// The proxy and layer repositories are not set here. Presumably they are
// initialized by initRepositories when the server runs.
func NewServer(funcs ...OptionFunc) *Server {
	opt := defaultOption()
	for i := range funcs {
		funcs[i](opt)
	}

	srv := new(Server)
	srv.serverConfig = opt.ServerConfig
	srv.redisConfig = opt.RedisConfig
	srv.directorLayers = opt.DirectorLayers
	srv.directorCacheTTL = opt.DirectorCacheTTL

	return srv
}