package redis

import (
	"context"
	"strconv"
	"strings"
	"time"

	"forge.cadoles.com/cadoles/bouncer/internal/proxy/director/layer/queue"
	"github.com/pkg/errors"
	"github.com/redis/go-redis/v9"
)

// keyPrefixQueue is the namespace prepended to every Redis key built by this
// adapter (see rankKey and lastSeenKey).
const keyPrefixQueue = "queue"

// Adapter implements queue.Adapter on top of Redis sorted sets.
//
// For each queue it keeps a "rank" sorted set, scored by the time of a
// session's first Touch, and a "last_seen" sorted set, scored by the time of
// its latest Touch. Build one with NewAdapter.
type Adapter struct {
	client     redis.UniversalClient // Redis client used for every queue operation.
	txMaxRetry int                   // Maximum attempts for WATCH-based transactions.
}

// Refresh implements queue.Adapter.
//
// It evicts every session of queueName whose last Touch is older than
// keepAlive, removing it from both the rank and the last-seen sorted sets.
//
// The expired members are read while the two keys are WATCHed, and both
// removals are queued in a single MULTI/EXEC block. If a concurrent Touch
// modifies a watched key in between, EXEC aborts with redis.TxFailedErr
// instead of evicting a session that was just refreshed. Aborted
// transactions are retried up to txMaxRetry times (always at least once).
func (a *Adapter) Refresh(ctx context.Context, queueName string, keepAlive time.Duration) error {
	lastSeenKey := lastSeenKey(queueName)
	rankKey := rankKey(queueName)

	attempts := a.txMaxRetry
	if attempts < 1 {
		attempts = 1
	}

	var lastErr error

	for attempt := 0; attempt < attempts; attempt++ {
		err := withTx(ctx, a.client, func(ctx context.Context, tx *redis.Tx) error {
			expires := time.Now().UTC().Add(-keepAlive)

			members, err := tx.ZRangeByScore(ctx, lastSeenKey, &redis.ZRangeBy{
				Min: "0",
				Max: strconv.FormatInt(expires.UnixNano(), 10),
			}).Result()
			if err != nil {
				return errors.WithStack(err)
			}

			if len(members) == 0 {
				return nil
			}

			anyMembers := make([]any, len(members))
			for i, m := range members {
				anyMembers[i] = m
			}

			// Commands sent directly on tx run immediately and are not
			// protected by WATCH. Only a MULTI/EXEC block is, and it also
			// makes both removals atomic.
			_, err = tx.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
				pipe.ZRem(ctx, rankKey, anyMembers...)
				pipe.ZRem(ctx, lastSeenKey, anyMembers...)

				return nil
			})
			if err != nil {
				return errors.WithStack(err)
			}

			return nil
		}, rankKey, lastSeenKey)
		if err == nil {
			return nil
		}

		// Only an aborted optimistic transaction is worth retrying.
		if !errors.Is(err, redis.TxFailedErr) {
			return errors.WithStack(err)
		}

		lastErr = err
	}

	return errors.Wrapf(lastErr, "could not refresh queue '%s' after %d attempt(s)", queueName, attempts)
}

// Touch implements queue.Adapter.
//
// It records activity for sessionId in queueName and returns the session's
// 0-based rank in the queue's rank set.
//
// The arrival score is written with ZADD NX, so only the first Touch sets it
// and the session keeps its place. The last-seen score is overwritten on
// every call.
//
// ZADD NX, ZADD and ZRANK are sent in a single MULTI/EXEC transaction, so a
// concurrent Refresh cannot remove the session between the writes and the
// rank lookup. No WATCH is needed because nothing is read before the writes.
// Any error, including an unexpected redis.Nil from ZRANK, is returned to
// the caller rather than reported as rank 0.
func (a *Adapter) Touch(ctx context.Context, queueName string, sessionId string) (int64, error) {
	lastSeenKey := lastSeenKey(queueName)
	rankKey := rankKey(queueName)

	// Current UnixNano values exceed 2^53, so converting them to float64
	// rounds the score to a few hundred nanoseconds.
	score := float64(time.Now().UTC().UnixNano())
	entry := redis.Z{Score: score, Member: sessionId}

	var rankCmd *redis.IntCmd

	_, err := a.client.TxPipelined(ctx, func(pipe redis.Pipeliner) error {
		pipe.ZAddNX(ctx, rankKey, entry)
		pipe.ZAdd(ctx, lastSeenKey, entry)
		rankCmd = pipe.ZRank(ctx, rankKey, sessionId)

		return nil
	})
	if err != nil {
		return 0, errors.WithStack(err)
	}

	rank, err := rankCmd.Result()
	if err != nil {
		return 0, errors.WithStack(err)
	}

	return rank, nil
}

// Status implements queue.Adapter.
//
// It reports how many sessions are currently present in queueName's rank
// sorted set (ZCARD).
func (a *Adapter) Status(ctx context.Context, queueName string) (*queue.Status, error) {
	sessions, err := a.client.ZCard(ctx, rankKey(queueName)).Result()
	if err != nil {
		return nil, errors.WithStack(err)
	}

	return &queue.Status{Sessions: sessions}, nil
}

// NewAdapter returns an Adapter that stores queue state through client.
// txMaxRetry is the attempt budget for the adapter's WATCH-based
// transactions.
func NewAdapter(client redis.UniversalClient, txMaxRetry int) *Adapter {
	adapter := new(Adapter)
	adapter.client = client
	adapter.txMaxRetry = txMaxRetry

	return adapter
}

// Compile-time check that *Adapter satisfies queue.Adapter.
var _ queue.Adapter = (*Adapter)(nil)

// key builds a Redis key by joining parts with ':' separators, for example
// key("queue", "q", "rank") == "queue:q:rank". No parts yields "".
func key(parts ...string) string {
	var sb strings.Builder
	for i, part := range parts {
		if i > 0 {
			sb.WriteByte(':')
		}
		sb.WriteString(part)
	}

	return sb.String()
}

// rankKey returns the key of the sorted set that orders queueName's sessions
// by the score written on their first Touch ("queue:<name>:rank").
func rankKey(queueName string) string {
	const suffix = "rank"

	return key(keyPrefixQueue, queueName, suffix)
}

// lastSeenKey returns the key of the sorted set holding the timestamp of each
// session's latest Touch in queueName ("queue:<name>:last_seen").
func lastSeenKey(queueName string) string {
	const suffix = "last_seen"

	return key(keyPrefixQueue, queueName, suffix)
}

// withTx WATCHes keys on client and runs fn with the resulting *redis.Tx,
// wrapping any error with a stack trace.
//
// WATCH only guards commands queued in a MULTI/EXEC block, for example via
// tx.TxPipelined. Commands sent directly on tx execute immediately. An EXEC
// aborted because a watched key changed surfaces as redis.TxFailedErr.
func withTx(ctx context.Context, client redis.UniversalClient, fn func(ctx context.Context, tx *redis.Tx) error, keys ...string) error {
	run := func(tx *redis.Tx) error {
		// errors.WithStack(nil) returns nil, so success passes through unchanged.
		return errors.WithStack(fn(ctx, tx))
	}

	return errors.WithStack(client.Watch(ctx, run, keys...))
}