package admin

import (
	"context"
	"fmt"
	"log"
	"net"
	"net/http"
	"net/http/pprof"
	"time"

	"forge.cadoles.com/cadoles/bouncer/internal/auth"
	"forge.cadoles.com/cadoles/bouncer/internal/auth/jwt"
	bouncerChi "forge.cadoles.com/cadoles/bouncer/internal/chi"
	"forge.cadoles.com/cadoles/bouncer/internal/config"
	"forge.cadoles.com/cadoles/bouncer/internal/integration"
	"forge.cadoles.com/cadoles/bouncer/internal/jwk"
	"forge.cadoles.com/cadoles/bouncer/internal/store"
	sentryhttp "github.com/getsentry/sentry-go/http"
	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/go-chi/cors"
	"github.com/pkg/errors"
	"github.com/prometheus/client_golang/prometheus/promhttp"
	"github.com/redis/go-redis/v9"
	"gitlab.com/wpetit/goweb/logger"
)

// Server is the bouncer admin HTTP server. It exposes the JWT-authenticated
// /api/v1 proxy and layer management API and, depending on configuration,
// Prometheus metrics and pprof profiling endpoints.
//
// Create one with NewServer and run it with Start.
type Server struct {
	// serverConfig holds the admin server settings read when running: HTTP
	// listen address, CORS, Sentry, metrics, profiling and JWT issuer.
	serverConfig config.AdminServerConfig
	// redisConfig is copied from the options by NewServer; presumably used to
	// build redisClient during repository initialization — TODO confirm.
	redisConfig  config.RedisConfig

	// redisClient is not set by NewServer; presumably created during
	// repository initialization — verify.
	redisClient redis.UniversalClient

	// integrations have their startup hooks executed (integration.RunOnStartup)
	// before the HTTP listener is opened.
	integrations []integration.Integration

	// bootstrapConfig is copied from the options by NewServer; presumably
	// consumed when bootstrapping proxies — TODO confirm.
	bootstrapConfig config.BootstrapConfig
	// proxyRepository and layerRepository are not set by NewServer; presumably
	// populated during repository initialization and used by the /api/v1
	// handlers — verify.
	proxyRepository store.ProxyRepository
	layerRepository store.LayerRepository

	// privateKey is exposed to integrations through the context
	// (integration.WithPrivateKey); presumably loaded by initPrivateKey.
	privateKey jwk.Key
	// publicKeys verifies API request JWTs and is exposed to integrations
	// through the context (integration.WithPublicKeySet).
	publicKeys jwk.Set
}

// Start launches the server in a background goroutine and returns at once.
//
// The first returned channel receives the address the HTTP listener is bound
// to once start-up succeeds; the second receives start-up and serving errors.
// Both channels are unbuffered (callers must keep receiving, or the server
// goroutine blocks) and both are closed when the server stops, i.e. when ctx
// is cancelled or a fatal error occurs.
func (s *Server) Start(ctx context.Context) (<-chan net.Addr, <-chan error) {
	errCh, addrCh := make(chan error), make(chan net.Addr)

	go s.run(ctx, addrCh, errCh)

	return addrCh, errCh
}

// run performs the server start-up sequence, then serves the admin HTTP API
// until ctx is cancelled or serving fails.
//
// Start-up steps, in order: repository initialization, proxy bootstrapping,
// private key initialization, then the integrations' startup hooks (which
// receive the private key and public key set through the context). The first
// failing step is reported on errs and aborts the run.
//
// Once the TCP listener is bound, its address is sent on addrs. A serving
// error other than "listener closed" is sent on errs. Both channels are
// closed when run returns.
func (s *Server) run(parentCtx context.Context, addrs chan net.Addr, errs chan error) {
	defer func() {
		close(errs)
		close(addrs)
	}()

	ctx, cancel := context.WithCancel(parentCtx)
	defer cancel()

	if err := s.initRepositories(ctx); err != nil {
		errs <- errors.WithStack(err)

		return
	}

	if err := s.bootstrapProxies(ctx); err != nil {
		errs <- errors.WithStack(err)

		return
	}

	if err := s.initPrivateKey(ctx); err != nil {
		errs <- errors.WithStack(err)

		return
	}

	ctx = integration.WithPrivateKey(ctx, s.privateKey)
	ctx = integration.WithPublicKeySet(ctx, s.publicKeys)

	if err := integration.RunOnStartup(ctx, s.integrations); err != nil {
		errs <- errors.WithStack(err)

		return
	}

	// net.JoinHostPort brackets IPv6 literals (e.g. "::1" -> "[::1]:8081");
	// a plain "%s:%d" format would yield an address net.Listen cannot parse.
	address := net.JoinHostPort(
		string(s.serverConfig.HTTP.Host),
		fmt.Sprintf("%d", s.serverConfig.HTTP.Port),
	)

	listener, err := net.Listen("tcp", address)
	if err != nil {
		errs <- errors.WithStack(err)

		return
	}

	addrs <- listener.Addr()

	defer func() {
		if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
			errs <- errors.WithStack(err)
		}
	}()

	// Closing the listener on cancellation is what makes Serve below return.
	go func() {
		<-ctx.Done()

		if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
			log.Printf("%+v", errors.WithStack(err))
		}
	}()

	router := chi.NewRouter()

	if s.serverConfig.HTTP.UseRealIP {
		router.Use(middleware.RealIP)
	}

	router.Use(middleware.RequestLogger(bouncerChi.NewLogFormatter()))

	if s.serverConfig.Sentry.DSN != "" {
		logger.Info(ctx, "enabling sentry http middleware")

		sentryMiddleware := sentryhttp.New(sentryhttp.Options{
			Repanic: true,
		})

		router.Use(sentryMiddleware.Handle)
	}

	corsMiddleware := cors.New(cors.Options{
		AllowedOrigins:   s.serverConfig.CORS.AllowedOrigins,
		AllowedMethods:   s.serverConfig.CORS.AllowedMethods,
		AllowCredentials: bool(s.serverConfig.CORS.AllowCredentials),
		AllowedHeaders:   s.serverConfig.CORS.AllowedHeaders,
		Debug:            bool(s.serverConfig.CORS.Debug),
	})

	router.Use(corsMiddleware.Handler)

	if s.serverConfig.Metrics.Enabled {
		metrics := s.serverConfig.Metrics

		logger.Info(ctx, "enabling metrics", logger.F("endpoint", metrics.Endpoint))

		router.Group(func(r chi.Router) {
			if metrics.BasicAuth != nil {
				logger.Info(ctx, "enabling authentication on metrics endpoint")

				r.Use(middleware.BasicAuth(
					"metrics",
					metrics.BasicAuth.CredentialsMap(),
				))
			}

			r.Handle(string(metrics.Endpoint), promhttp.Handler())
		})
	}

	if s.serverConfig.Profiling.Enabled {
		profiling := s.serverConfig.Profiling
		logger.Info(ctx, "enabling profiling", logger.F("endpoint", profiling.Endpoint))

		router.Group(func(r chi.Router) {
			if profiling.BasicAuth != nil {
				logger.Info(ctx, "enabling authentication on profiling endpoint")

				r.Use(middleware.BasicAuth(
					"profiling",
					profiling.BasicAuth.CredentialsMap(),
				))
			}

			r.Route(string(profiling.Endpoint), func(r chi.Router) {
				r.HandleFunc("/", pprof.Index)
				r.HandleFunc("/cmdline", pprof.Cmdline)
				r.HandleFunc("/profile", pprof.Profile)
				r.HandleFunc("/symbol", pprof.Symbol)
				r.HandleFunc("/trace", pprof.Trace)
				// Any other named profile (heap, goroutine, allocs, ...).
				r.HandleFunc("/{name}", func(w http.ResponseWriter, r *http.Request) {
					name := chi.URLParam(r, "name")
					pprof.Handler(name).ServeHTTP(w, r)
				})
			})
		})
	}

	router.Route("/api/v1", func(r chi.Router) {
		r.Group(func(r chi.Router) {
			r.Use(auth.Middleware(
				jwt.NewAuthenticator(s.publicKeys, string(s.serverConfig.Auth.Issuer), jwt.DefaultAcceptableSkew),
			))

			r.Route("/definitions", func(r chi.Router) {
				r.With(assertReadAccess).Get("/layers", s.queryLayerDefinition)
			})

			r.Route("/proxies", func(r chi.Router) {
				r.With(assertReadAccess).Get("/", s.queryProxy)
				r.With(assertWriteAccess).Post("/", s.createProxy)
				r.With(assertReadAccess).Get("/{proxyName}", s.getProxy)
				r.With(assertWriteAccess).Put("/{proxyName}", s.updateProxy)
				r.With(assertWriteAccess).Delete("/{proxyName}", s.deleteProxy)

				r.With(assertReadAccess).Get("/{proxyName}/layers", s.queryLayer)
				r.With(assertWriteAccess).Post("/{proxyName}/layers", s.createLayer)
				r.With(assertReadAccess).Get("/{proxyName}/layers/{layerName}", s.getLayer)
				r.With(assertWriteAccess).Put("/{proxyName}/layers/{layerName}", s.updateLayer)
				r.With(assertWriteAccess).Delete("/{proxyName}/layers/{layerName}", s.deleteLayer)
			})
		})
	})

	// Bound the time a client may take to send its request headers so slow or
	// idle connections cannot pin server resources indefinitely (slowloris).
	// Only header reads are limited: long-running handlers such as the pprof
	// profile/trace endpoints are unaffected.
	const readHeaderTimeout = 10 * time.Second

	server := &http.Server{
		Handler:           router,
		ReadHeaderTimeout: readHeaderTimeout,
	}

	logger.Info(ctx, "http server listening")

	if err := server.Serve(listener); err != nil && !errors.Is(err, net.ErrClosed) {
		errs <- errors.WithStack(err)
	}

	logger.Info(ctx, "http server exiting")
}

// NewServer creates an admin Server. It starts from defaultOption() and applies
// each OptionFunc in the order given, then copies the resulting server, Redis
// and bootstrap configuration and the integrations into the Server.
//
// Repositories, the Redis client and the JWT keys are left at their zero
// values here; they are expected to be initialized when the server is started.
func NewServer(funcs ...OptionFunc) *Server {
	options := defaultOption()

	for i := range funcs {
		funcs[i](options)
	}

	srv := &Server{}
	srv.serverConfig = options.ServerConfig
	srv.redisConfig = options.RedisConfig
	srv.bootstrapConfig = options.BootstrapConfig
	srv.integrations = options.Integrations

	return srv
}