package director

import (
	"context"
	"net/http"
	"sort"

	"forge.cadoles.com/Cadoles/go-proxy"
	"forge.cadoles.com/Cadoles/go-proxy/wildcard"
	"forge.cadoles.com/cadoles/bouncer/internal/cache"
	"forge.cadoles.com/cadoles/bouncer/internal/store"
	"github.com/pkg/errors"
	"github.com/prometheus/client_golang/prometheus"
	"gitlab.com/wpetit/goweb/logger"
)

// Director resolves which proxy handles an incoming request, rewrites the
// request toward that proxy's target, and applies the layers configured for
// the matching proxies as request/response transformers and HTTP middlewares.
type Director struct {
	// Sources of proxy definitions and of per-proxy layers, and the registry
	// mapping a layer type to its transformer or middleware implementation.
	proxyRepository store.ProxyRepository
	layerRepository store.LayerRepository
	layerRegistry   *LayerRegistry

	// Caches for the resolved proxy list (stored under a single key by
	// getProxies) and for each proxy's layer list (keyed by proxy name).
	proxyCache cache.Cache[string, []*store.Proxy]
	layerCache cache.Cache[string, []*store.Layer]

	// Error handler injected into each request's context by Middleware.
	handleError HandleErrorFunc
}

// rewriteRequest selects the proxies matching the incoming request and
// rewrites the request toward the target of the first one that has a target.
//
// Proxies are examined in the order returned by getProxies. A proxy matches
// when the request URL matches at least one of its From wildcard patterns.
// Each matching proxy is counted once in metricProxyRequestsTotal and
// contributes its layers once. Matching proxies without a target (empty To)
// only contribute layers. The first matching proxy with a target ends the
// search, and the request's scheme, host and path are rewritten toward it.
//
// The returned request carries the original request URL (withOriginalURL) and
// the accumulated layers (withLayers) in its context. On error, the request
// passed in is returned unmodified together with the error.
func (d *Director) rewriteRequest(r *http.Request) (*http.Request, error) {
	ctx := r.Context()

	proxies, err := d.getProxies(ctx)
	if err != nil {
		return r, errors.WithStack(err)
	}

	reqURL := getRequestURL(r)
	// Computed once: used for logging and for every wildcard comparison below.
	rawURL := reqURL.String()

	ctx = withOriginalURL(ctx, reqURL)
	ctx = logger.With(ctx, logger.F("url", rawURL))

	layers := make([]*store.Layer, 0)

	for _, p := range proxies {
		// A proxy is considered once, however many of its patterns match.
		// Otherwise its layers would be applied (and its requests counted)
		// several times.
		matched := false
		for _, from := range p.From {
			if wildcard.Match(rawURL, from) {
				matched = true
				break
			}
		}
		if !matched {
			continue
		}

		proxyCtx := logger.With(ctx,
			logger.F("proxy", p.Name),
			logger.F("host", r.Host),
			logger.F("remoteAddr", r.RemoteAddr),
		)

		metricProxyRequestsTotal.With(prometheus.Labels{metricLabelProxy: string(p.Name)}).Add(1)

		proxyLayers, err := d.getLayers(proxyCtx, p.Name)
		if err != nil {
			return r, errors.WithStack(err)
		}

		layers = append(layers, proxyLayers...)

		// Without a target, this proxy only contributes layers.
		if p.To == "" {
			continue
		}

		// reqURL.Parse (presumably (*url.URL).Parse) resolves p.To against the
		// request URL: an absolute target is used as-is, while a relative one
		// inherits the request's scheme and host.
		toURL, err := reqURL.Parse(p.To)
		if err != nil {
			return r, errors.WithStack(err)
		}

		// JoinPath expects *escaped* path elements. Passing the decoded
		// r.URL.Path would unescape it a second time: "%25" sequences get
		// mangled, and the request path is silently dropped when the decoded
		// form is not a valid escape. Encoded slashes would also be lost.
		// Join the escaped path instead and keep RawPath in sync with Path.
		target := toURL.JoinPath(r.URL.EscapedPath())

		r.URL.Host = toURL.Host
		r.URL.Scheme = toURL.Scheme
		r.URL.Path = target.Path
		r.URL.RawPath = target.RawPath

		proxyCtx = withLayers(proxyCtx, layers)
		r = r.WithContext(proxyCtx)

		return r, nil
	}

	ctx = withLayers(ctx, layers)
	r = r.WithContext(ctx)

	return r, nil
}

// proxiesCacheKey is the single key under which getProxies stores the
// resolved list of enabled proxies in the proxy cache.
const proxiesCacheKey = "proxies"

// getProxies returns the enabled proxies, ordered by store.ByProxyWeight.
//
// The list is served from d.proxyCache when present. On a cache miss, the
// proxy repository is queried for enabled proxy headers. The headers are
// sorted by weight, each enabled proxy is loaded in full, and the result is
// cached. Any repository error aborts the lookup and nothing is cached.
func (d *Director) getProxies(ctx context.Context) ([]*store.Proxy, error) {
	if cached, hit := d.proxyCache.Get(proxiesCacheKey); hit {
		return cached, nil
	}

	summaries, err := d.proxyRepository.QueryProxy(ctx, store.WithProxyQueryEnabled(true))
	if err != nil {
		return nil, errors.WithStack(err)
	}

	sort.Sort(store.ByProxyWeight(summaries))

	resolved := make([]*store.Proxy, 0, len(summaries))

	for _, summary := range summaries {
		// Presumably already filtered by WithProxyQueryEnabled; re-checked defensively.
		if summary.Enabled {
			loaded, err := d.proxyRepository.GetProxy(ctx, summary.Name)
			if err != nil {
				return nil, errors.WithStack(err)
			}

			resolved = append(resolved, loaded)
		}
	}

	d.proxyCache.Set(proxiesCacheKey, resolved)

	return resolved, nil
}

// getLayers returns the enabled layers of the proxy named proxyName, ordered
// by store.ByLayerWeight.
//
// Results are cached in d.layerCache under "layers-<proxyName>". On a cache
// miss, the layer repository is queried for enabled layer headers. The headers
// are sorted by weight and each enabled layer is loaded in turn. A repository
// error aborts the lookup and nothing is cached.
func (d *Director) getLayers(ctx context.Context, proxyName store.ProxyName) ([]*store.Layer, error) {
	key := "layers-" + string(proxyName)

	if cached, hit := d.layerCache.Get(key); hit {
		return cached, nil
	}

	summaries, err := d.layerRepository.QueryLayers(ctx, proxyName, store.WithLayerQueryEnabled(true))
	if err != nil {
		return nil, errors.WithStack(err)
	}

	sort.Sort(store.ByLayerWeight(summaries))

	resolved := make([]*store.Layer, 0, len(summaries))

	for _, summary := range summaries {
		// Presumably already filtered by WithLayerQueryEnabled; re-checked defensively.
		if summary.Enabled {
			loaded, err := d.layerRepository.GetLayer(ctx, proxyName, summary.Name)
			if err != nil {
				return nil, errors.WithStack(err)
			}

			resolved = append(resolved, loaded)
		}
	}

	d.layerCache.Set(key, resolved)

	return resolved, nil
}

// RequestTransformer returns a proxy.RequestTransformer that applies, in
// order, the request transformer of every layer stored in the request context
// whose type has one registered in the layer registry.
//
// A request without layers in its context is left untouched. Any other
// failure to read the layers is logged, and the transformer returns without
// modifying the request.
func (d *Director) RequestTransformer() proxy.RequestTransformer {
	return func(r *http.Request) {
		ctx := r.Context()

		layers, err := ctxLayers(ctx)
		switch {
		case errors.Is(err, errContextKeyNotFound):
			return
		case err != nil:
			logger.Error(ctx, "could not retrieve layers from context", logger.CapturedE(errors.WithStack(err)))
			return
		}

		for _, layer := range layers {
			if tl, found := d.layerRegistry.GetRequestTransformer(layer.Type); found {
				tl.RequestTransformer(layer)(r)
			}
		}
	}
}

// ResponseTransformer returns a proxy.ResponseTransformer that runs, in order,
// the response transformer of every layer stored in the originating request's
// context whose type has one registered in the layer registry.
//
// A response whose request carries no layers is left untouched. Any other
// failure to read the layers stops processing and is returned, as does the
// first error returned by a layer transformer.
func (d *Director) ResponseTransformer() proxy.ResponseTransformer {
	return func(r *http.Response) error {
		layers, err := ctxLayers(r.Request.Context())
		switch {
		case errors.Is(err, errContextKeyNotFound):
			return nil
		case err != nil:
			return errors.WithStack(err)
		}

		for _, layer := range layers {
			tl, found := d.layerRegistry.GetResponseTransformer(layer.Type)
			if !found {
				continue
			}

			if terr := tl.ResponseTransformer(layer)(r); terr != nil {
				return errors.WithStack(terr)
			}
		}

		return nil
	}
}

// Middleware returns a proxy.Middleware that routes each request through the
// middlewares of the layers attached to it.
//
// For every request, it:
//  1. stores d.handleError in the request context via withHandleError;
//  2. rewrites the request with rewriteRequest, answering with a 500 through
//     HandleError if that fails;
//  3. collects the middleware of each layer found in the rewritten request's
//     context whose type is registered in the layer registry, chains them
//     around next with createMiddlewareChain, and serves the request.
//
// A request whose context carries no layers is served by next directly.
func (d *Director) Middleware() proxy.Middleware {
	return func(next http.Handler) http.Handler {
		fn := func(w http.ResponseWriter, r *http.Request) {
			ctx := withHandleError(r.Context(), d.handleError)
			r = r.WithContext(ctx)

			r, err := d.rewriteRequest(r)
			if err != nil {
				HandleError(ctx, w, r, http.StatusInternalServerError, errors.Wrap(err, "could not rewrite request"))
				return
			}

			ctx = r.Context()

			layers, err := ctxLayers(ctx)
			if err != nil {
				if errors.Is(err, errContextKeyNotFound) {
					// No layers attached: nothing to wrap, but the request must
					// still reach next. Returning here would send an empty
					// response, unlike the request/response transformers, which
					// treat this case as a pass-through.
					next.ServeHTTP(w, r)
					return
				}

				HandleError(ctx, w, r, http.StatusInternalServerError, errors.Wrap(err, "could not retrieve proxy and layers from context"))
				return
			}

			httpMiddlewares := make([]proxy.Middleware, 0, len(layers))
			for _, layer := range layers {
				middleware, ok := d.layerRegistry.GetMiddleware(layer.Type)
				if !ok {
					continue
				}

				httpMiddlewares = append(httpMiddlewares, middleware.Middleware(layer))
			}

			handler := createMiddlewareChain(next, httpMiddlewares)

			handler.ServeHTTP(w, r)
		}

		return http.HandlerFunc(fn)
	}
}

// New creates a Director backed by the given proxy and layer repositories.
//
// The remaining settings are taken from the options built by NewOptions(funcs...):
// the layers registered in the LayerRegistry, the proxy and layer caches, and
// the error handler.
func New(proxyRepository store.ProxyRepository, layerRepository store.LayerRepository, funcs ...OptionFunc) *Director {
	opts := NewOptions(funcs...)

	d := &Director{
		proxyRepository: proxyRepository,
		layerRepository: layerRepository,
	}

	d.layerRegistry = NewLayerRegistry(opts.Layers...)
	d.proxyCache = opts.ProxyCache
	d.layerCache = opts.LayerCache
	d.handleError = opts.HandleError

	return d
}