package oidc

import (
	"bytes"
	"context"
	"crypto/tls"
	"fmt"
	"net/http"
	"net/url"
	"slices"
	"strings"
	"text/template"
	"time"

	"forge.cadoles.com/Cadoles/go-proxy/wildcard"
	"forge.cadoles.com/cadoles/bouncer/internal/cache"
	"forge.cadoles.com/cadoles/bouncer/internal/proxy/director"
	"forge.cadoles.com/cadoles/bouncer/internal/proxy/director/layer/authn"
	"forge.cadoles.com/cadoles/bouncer/internal/store"
	"github.com/coreos/go-oidc/v3/oidc"
	"github.com/gorilla/sessions"
	"github.com/pkg/errors"
	"github.com/prometheus/client_golang/prometheus"
	"gitlab.com/wpetit/goweb/logger"
)

// Authenticator is the OIDC implementation of the authn.PreAuthentication and
// authn.Authenticator interfaces.
//
// Fields:
//   - store: session store from which the per-layer session is loaded, using
//     the cookie name computed by getCookieName.
//   - httpTransport: base transport; getClient clones it for every client it
//     builds so per-layer TLS settings never touch the shared instance.
//   - httpClientTimeout: timeout applied to the HTTP client handed to the
//     OIDC client.
//   - oidcProviderCache: discovered OIDC providers, keyed by issuer URL.
type Authenticator struct {
	store             sessions.Store
	httpTransport     *http.Transport
	httpClientTimeout time.Duration
	oidcProviderCache cache.Cache[string, *oidc.Provider]
}

// PreAuthentication implements authn.PreAuthentication.
//
// It inspects the path of the original request and, when it matches the
// layer's login callback pattern or logout pattern, delegates to the OIDC
// client to handle the callback or the logout. Requests matching neither
// pattern are left untouched and nil is returned.
//
// For a logout, the optional "redirect" query parameter must be one of the
// configured post-logout redirect URLs; otherwise a 400 response is written
// and authn.ErrSkipRequest is returned. When the parameter is absent, the
// public base URL is used if configured, else the scheme and host of the
// original request.
func (a *Authenticator) PreAuthentication(w http.ResponseWriter, r *http.Request, layer *store.Layer) error {
	ctx := r.Context()

	originalURL, err := director.OriginalURL(ctx)
	if err != nil {
		return errors.WithStack(err)
	}

	options, err := fromStoreOptions(layer.Options)
	if err != nil {
		return errors.WithStack(err)
	}

	cookieName := a.getCookieName(options.Cookie.Name, layer.Proxy, layer.Name)

	// A session retrieval failure is only logged here: processing continues
	// with whatever session value the store returned.
	sess, err := a.store.Get(r, cookieName)
	if err != nil {
		logger.Error(ctx, "could not retrieve session", logger.CapturedE(errors.WithStack(err)))
	}

	callbackURL, err := a.getLoginCallbackURL(originalURL, layer.Proxy, layer.Name, options)
	if err != nil {
		return errors.WithStack(err)
	}

	client, err := a.getClient(options, callbackURL.String())
	if err != nil {
		return errors.WithStack(err)
	}

	callbackPattern, err := a.templatize(options.OIDC.MatchLoginCallbackPath, layer.Proxy, layer.Name)
	if err != nil {
		return errors.WithStack(err)
	}

	logoutPattern, err := a.templatize(options.OIDC.MatchLogoutPath, layer.Proxy, layer.Name)
	if err != nil {
		return errors.WithStack(err)
	}

	logger.Debug(ctx, "checking url", logger.F("loginCallbackPathPattern", callbackPattern), logger.F("logoutPathPattern", logoutPattern))

	metricLabels := prometheus.Labels{
		metricLabelLayer: string(layer.Name),
		metricLabelProxy: string(layer.Proxy),
	}

	// Login callback: let the client complete the flow, then count it.
	if wildcard.Match(originalURL.Path, callbackPattern) {
		if err := client.HandleCallback(w, r, sess); err != nil {
			return errors.WithStack(err)
		}

		metricLoginSuccessesTotal.With(metricLabels).Add(1)

		return nil
	}

	// Anything that is not the logout endpoint either is none of our business.
	if !wildcard.Match(originalURL.Path, logoutPattern) {
		return nil
	}

	redirectURL := r.URL.Query().Get("redirect")

	switch {
	case redirectURL != "":
		// Caller-supplied targets must be explicitly allowed by the layer.
		if !slices.Contains(options.OIDC.PostLogoutRedirectURLs, redirectURL) {
			director.HandleError(ctx, w, r, http.StatusBadRequest, errors.New("unauthorized post-logout redirect"))
			return errors.WithStack(authn.ErrSkipRequest)
		}
	case options.OIDC.PublicBaseURL != "":
		redirectURL = options.OIDC.PublicBaseURL
	default:
		redirectURL = originalURL.Scheme + "://" + originalURL.Host
	}

	if err := client.HandleLogout(w, r, sess, redirectURL); err != nil {
		return errors.WithStack(err)
	}

	metricLogoutsTotal.With(metricLabels).Add(1)

	return nil
}

// Authenticate implements authn.Authenticator.
//
// It loads the layer session, applies the configured cookie attributes to it,
// and asks the OIDC client to authenticate the request. On success the ID
// token is converted into an authn.User. When the client reports
// ErrLoginRequired, the login request counter is incremented and
// authn.ErrSkipRequest is returned; any other failure is returned as is.
//
// The session is saved when the function returns, whatever the outcome; a
// failure to save is logged and does not alter the returned values.
func (a *Authenticator) Authenticate(w http.ResponseWriter, r *http.Request, layer *store.Layer) (*authn.User, error) {
	ctx := r.Context()

	options, err := fromStoreOptions(layer.Options)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	cookieName := a.getCookieName(options.Cookie.Name, layer.Proxy, layer.Name)

	sess, err := a.store.Get(r, cookieName)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	defer func() {
		if saveErr := sess.Save(r, w); saveErr != nil {
			logger.Error(ctx, "could not save session", logger.CapturedE(errors.WithStack(saveErr)))
		}
	}()

	// Unknown or empty SameSite settings fall back to the default mode.
	sameSite := http.SameSiteDefaultMode

	switch options.Cookie.SameSite {
	case "lax":
		sameSite = http.SameSiteLaxMode
	case "strict":
		sameSite = http.SameSiteStrictMode
	case "none":
		sameSite = http.SameSiteNoneMode
	}

	cookieOptions := sess.Options
	cookieOptions.Domain = options.Cookie.Domain
	cookieOptions.Path = options.Cookie.Path
	cookieOptions.HttpOnly = options.Cookie.HTTPOnly
	cookieOptions.MaxAge = int(options.Cookie.MaxAge.Seconds())
	cookieOptions.SameSite = sameSite

	originalURL, err := director.OriginalURL(ctx)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	callbackURL, err := a.getLoginCallbackURL(originalURL, layer.Proxy, layer.Name, options)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	client, err := a.getClient(options, callbackURL.String())
	if err != nil {
		return nil, errors.WithStack(err)
	}

	// The post-login target is the requested URL itself (query included),
	// rebased onto the public base URL when one is configured.
	postLoginURL, err := a.mergeURL(originalURL, originalURL.Path, options.OIDC.PublicBaseURL, true)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	idToken, err := client.Authenticate(w, r, sess, postLoginURL.String())

	switch {
	case err == nil:
		// Authenticated: fall through to the user conversion below.
	case errors.Is(err, ErrLoginRequired):
		metricLoginRequestsTotal.With(prometheus.Labels{
			metricLabelLayer: string(layer.Name),
			metricLabelProxy: string(layer.Proxy),
		}).Add(1)

		return nil, errors.WithStack(authn.ErrSkipRequest)
	default:
		return nil, errors.WithStack(err)
	}

	user, err := a.toUser(originalURL, idToken, layer.Proxy, layer.Name, options, sess)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	return user, nil
}

// claims holds the registered ID token claims this package exposes by name,
// plus the complete decoded claim set in Others.
//
// AMR is typed as any on purpose: OpenID Connect Core defines "amr" as a JSON
// array of strings, and decoding an array into a string field makes the whole
// claims decoding fail. Using any accepts the standard array form as well as
// a plain string.
//
// AZP is bound to the "azp" (authorized party) claim.
//
// Others is excluded from struct decoding (json:"-"); it is expected to be
// filled separately with every claim of the token.
type claims struct {
	Issuer     string         `json:"iss"`
	Subject    string         `json:"sub"`
	Expiration int64          `json:"exp"`
	IssuedAt   int64          `json:"iat"`
	AuthTime   int64          `json:"auth_time"`
	Nonce      string         `json:"nonce"`
	ACR        string         `json:"acr"`
	AMR        any            `json:"amr"`
	AZP        string         `json:"azp"`
	Others     map[string]any `json:"-"`
}

// AsAttrs flattens the claims into a map of user attributes.
//
// Every non-nil entry of Others is exported as "claim_<name>". The typed
// fields are written afterwards and therefore take precedence over the raw
// value of the same claim. Issuer, subject, expiration and issued-at are
// always present; the remaining typed claims are only exported when set.
func (c claims) AsAttrs() map[string]any {
	attrs := make(map[string]any)

	for key, val := range c.Others {
		if val != nil {
			attrs["claim_"+key] = val
		}
	}

	attrs["claim_iss"] = c.Issuer
	attrs["claim_sub"] = c.Subject
	attrs["claim_exp"] = c.Expiration
	attrs["claim_iat"] = c.IssuedAt

	if c.AuthTime != 0 {
		attrs["claim_auth_time"] = c.AuthTime
	}

	if c.Nonce != "" {
		attrs["claim_nonce"] = c.Nonce
	}

	if c.ACR != "" {
		attrs["claim_acr"] = c.ACR
		// "claim_arc" is a historical misspelling of "claim_acr"; it is still
		// emitted so that consumers relying on the old key keep working.
		attrs["claim_arc"] = c.ACR
	}

	switch amr := c.AMR.(type) {
	case nil:
		// Claim absent: nothing to export.
	case string:
		if amr != "" {
			attrs["claim_amr"] = amr
		}
	default:
		attrs["claim_amr"] = amr
	}

	if c.AZP != "" {
		attrs["claim_azp"] = c.AZP
	}

	return attrs
}

// toUser converts a verified ID token and the current session into an
// authn.User.
//
// The user's attributes are made of the token claims (see claims.AsAttrs),
// the layer's logout URL under "logout_url", and, when present and non-nil
// in the session, "access_token", "refresh_token" and "token_expiry". The
// user is created with the token subject.
func (a *Authenticator) toUser(originalURL *url.URL, idToken *oidc.IDToken, proxyName store.ProxyName, layerName store.LayerName, options *LayerOptions, sess *sessions.Session) (*authn.User, error) {
	var decoded claims

	// Two decoding passes: the first fills the typed fields, the second
	// collects every claim of the token into the Others map.
	if err := idToken.Claims(&decoded); err != nil {
		return nil, errors.WithStack(err)
	}

	if err := idToken.Claims(&decoded.Others); err != nil {
		return nil, errors.WithStack(err)
	}

	attrs := decoded.AsAttrs()

	logoutURL, err := a.getLogoutURL(originalURL, proxyName, layerName, options)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	attrs["logout_url"] = logoutURL.String()

	// Session values forwarded as user attributes when they are set.
	forwarded := []struct {
		attr string
		key  any
	}{
		{attr: "access_token", key: sessionKeyAccessToken},
		{attr: "refresh_token", key: sessionKeyRefreshToken},
		{attr: "token_expiry", key: sessionKeyTokenExpiry},
	}

	for _, entry := range forwarded {
		if value, found := sess.Values[entry.key]; found && value != nil {
			attrs[entry.attr] = value
		}
	}

	return authn.NewUser(idToken.Subject, attrs), nil
}

// getLoginCallbackURL returns the absolute login callback URL for the given
// proxy and layer. The configured callback path template is rendered, then
// merged with the original request URL and the optional public base URL. The
// query string of the original request is not carried over.
func (a *Authenticator) getLoginCallbackURL(originalURL *url.URL, proxyName store.ProxyName, layerName store.LayerName, options *LayerOptions) (*url.URL, error) {
	callbackPath, tmplErr := a.templatize(options.OIDC.LoginCallbackPath, proxyName, layerName)
	if tmplErr != nil {
		return nil, errors.WithStack(tmplErr)
	}

	callbackURL, mergeErr := a.mergeURL(originalURL, callbackPath, options.OIDC.PublicBaseURL, false)
	if mergeErr != nil {
		return nil, errors.WithStack(mergeErr)
	}

	return callbackURL, nil
}

// getLogoutURL returns the absolute logout URL for the given proxy and layer.
// The configured logout path template is rendered, then merged with the
// original request URL and the optional public base URL. Unlike the login
// callback URL, the query string of the original request is carried over.
func (a *Authenticator) getLogoutURL(originalURL *url.URL, proxyName store.ProxyName, layerName store.LayerName, options *LayerOptions) (*url.URL, error) {
	logoutPath, tmplErr := a.templatize(options.OIDC.LogoutPath, proxyName, layerName)
	if tmplErr != nil {
		return nil, errors.WithStack(tmplErr)
	}

	logoutURL, mergeErr := a.mergeURL(originalURL, logoutPath, options.OIDC.PublicBaseURL, true)
	if mergeErr != nil {
		return nil, errors.WithStack(mergeErr)
	}

	return logoutURL, nil
}

// mergeURL builds a new URL made of the scheme and host of base and the given
// path. When withQuery is true, the raw query of base is copied as well.
//
// When overlay is not empty it is parsed as a URL and takes over:
//   - scheme and host are replaced by the overlay's;
//   - path is appended to the overlay path with exactly one "/" between the
//     two, whether or not the overlay path ends with a slash (an empty or
//     "/" path leaves the overlay path unchanged);
//   - the overlay's query parameters are added to the merged query, which is
//     then re-encoded.
//
// An error is returned only when overlay cannot be parsed.
func (a *Authenticator) mergeURL(base *url.URL, path string, overlay string, withQuery bool) (*url.URL, error) {
	merged := &url.URL{
		Scheme: base.Scheme,
		Host:   base.Host,
		Path:   path,
	}

	if withQuery {
		merged.RawQuery = base.RawQuery
	}

	if overlay == "" {
		return merged, nil
	}

	overlayURL, err := url.Parse(overlay)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	merged.Scheme = overlayURL.Scheme
	merged.Host = overlayURL.Host

	// Join both paths with a single separator. Plain concatenation used to
	// turn "/base" + "/callback" into "/basecallback" and produced a path
	// without leading slash for host-only overlays.
	if relPath := strings.TrimPrefix(path, "/"); relPath == "" {
		merged.Path = overlayURL.Path
	} else {
		merged.Path = strings.TrimSuffix(overlayURL.Path, "/") + "/" + relPath
	}

	// Parse and encode the merged query once instead of once per overlay key.
	// It is left untouched when the overlay carries no parameter.
	if overlayQuery := overlayURL.Query(); len(overlayQuery) > 0 {
		query := merged.Query()

		for key, values := range overlayQuery {
			for _, v := range values {
				query.Add(key, v)
			}
		}

		merged.RawQuery = query.Encode()
	}

	return merged, nil
}

// templatize renders rawTemplate as a text/template, exposing the proxy and
// layer names to it as .ProxyName and .LayerName, and returns the rendered
// text. Parsing and execution errors are returned with a stack trace.
func (a *Authenticator) templatize(rawTemplate string, proxyName store.ProxyName, layerName store.LayerName) (string, error) {
	parsed, err := template.New("").Parse(rawTemplate)
	if err != nil {
		return "", errors.WithStack(err)
	}

	data := struct {
		ProxyName store.ProxyName
		LayerName store.LayerName
	}{
		ProxyName: proxyName,
		LayerName: layerName,
	}

	var rendered bytes.Buffer

	if err := parsed.Execute(&rendered, data); err != nil {
		return "", errors.WithStack(err)
	}

	return rendered.String(), nil
}

// getClient builds an OIDC client for the given layer options and redirect
// URL.
//
// A clone of the authenticator's transport backs a dedicated HTTP client;
// certificate verification is disabled on that clone when the layer asks for
// it. The OIDC provider is looked up in the cache by issuer URL and, on a
// miss, discovered with that HTTP client (optionally without issuer
// verification) and stored in the cache.
func (a *Authenticator) getClient(options *LayerOptions, redirectURL string) (*Client, error) {
	transport := a.httpTransport.Clone()

	if options.OIDC.TLSInsecureSkipVerify {
		tlsConfig := transport.TLSClientConfig
		if tlsConfig == nil {
			tlsConfig = &tls.Config{}
			transport.TLSClientConfig = tlsConfig
		}

		tlsConfig.InsecureSkipVerify = true
	}

	httpClient := &http.Client{
		Transport: transport,
		Timeout:   a.httpClientTimeout,
	}

	issuerURL := options.OIDC.IssuerURL

	provider, cached := a.oidcProviderCache.Get(issuerURL)
	if !cached {
		providerCtx := oidc.ClientContext(context.Background(), httpClient)

		if options.OIDC.SkipIssuerVerification {
			providerCtx = oidc.InsecureIssuerURLContext(providerCtx, issuerURL)
		}

		logger.Debug(providerCtx, "refreshing oidc provider", logger.F("issuerURL", issuerURL))

		discovered, err := oidc.NewProvider(providerCtx, issuerURL)
		if err != nil {
			return nil, errors.Wrap(err, "could not create oidc provider")
		}

		a.oidcProviderCache.Set(issuerURL, discovered)
		provider = discovered
	}

	return NewClient(
		WithCredentials(options.OIDC.ClientID, options.OIDC.ClientSecret),
		WithProvider(provider),
		WithRedirectURL(redirectURL),
		WithScopes(options.OIDC.Scopes...),
		WithAuthParams(options.OIDC.AuthParams),
		WithHTTPClient(httpClient),
	), nil
}

// defaultCookieNamePrefix starts every generated session cookie name.
const defaultCookieNamePrefix = "_bouncer_authn_oidc"

// getCookieName returns the session cookie name to use for a layer: the
// explicitly configured name when there is one, otherwise a lowercased name
// derived from the default prefix, the proxy name and the layer name.
func (a *Authenticator) getCookieName(cookieName string, proxyName store.ProxyName, layerName store.LayerName) string {
	if cookieName == "" {
		generated := fmt.Sprintf("%s_%s_%s", defaultCookieNamePrefix, proxyName, layerName)

		return strings.ToLower(generated)
	}

	return cookieName
}

// Compile-time checks that *Authenticator satisfies the authn interfaces.
var (
	_ authn.PreAuthentication = (*Authenticator)(nil)
	_ authn.Authenticator     = (*Authenticator)(nil)
)