package redis import ( "context" "encoding/json" "math/rand" "strings" "time" "github.com/pkg/errors" "github.com/redis/go-redis/v9" "gitlab.com/wpetit/goweb/logger" ) var ( DefaultTxMaxAttempts = 20 DefaultTxBaseDelay = 100 * time.Millisecond ) type jsonWrapper[T any] struct { value T } func (w *jsonWrapper[T]) MarshalBinary() ([]byte, error) { data, err := json.Marshal(w.value) if err != nil { return nil, errors.WithStack(err) } return data, nil } func (w *jsonWrapper[T]) UnmarshalBinary(data []byte) error { if err := json.Unmarshal(data, &w.value); err != nil { return errors.WithStack(err) } return nil } func (w *jsonWrapper[T]) UnmarshalText(data []byte) error { if err := json.Unmarshal(data, &w.value); err != nil { return errors.WithStack(err) } return nil } func (w *jsonWrapper[T]) Value() T { return w.value } func wrap[T any](v T) *jsonWrapper[T] { return &jsonWrapper[T]{v} } func unwrap[T any](v any) (T, error) { str, ok := v.(string) if !ok { return *new(T), errors.Errorf("could not unwrap value of type '%T'", v) } u := new(T) if err := json.Unmarshal([]byte(str), u); err != nil { return *new(T), errors.WithStack(err) } return *u, nil } func key(parts ...string) string { return strings.Join(parts, ":") } func WithRetry(ctx context.Context, client redis.UniversalClient, key string, fn func(ctx context.Context, tx *redis.Tx) error, maxAttempts int, baseDelay time.Duration) error { var err error delay := baseDelay for attempt := 0; attempt < maxAttempts; attempt++ { if err = WithTx(ctx, client, key, fn); err != nil { err = errors.WithStack(err) logger.Debug(ctx, "redis transaction failed", logger.E(err)) if errors.Is(err, redis.TxFailedErr) { logger.Debug(ctx, "retrying redis transaction", logger.F("attempts", attempt), logger.F("delay", delay)) time.Sleep(delay) delay = delay*2 + time.Duration(rand.Int63n(int64(baseDelay))) continue } return err } return nil } return errors.WithStack(redis.TxFailedErr) } func WithTx(ctx context.Context, client redis.UniversalClient, key string, fn func(ctx context.Context, tx *redis.Tx) error) error { txf := func(tx *redis.Tx) error { if err := fn(ctx, tx); err != nil { return errors.WithStack(err) } return nil } err := client.Watch(ctx, txf, key) if err != nil { return errors.WithStack(err) } return nil } func contains[T ~string](values []T, v T) bool { for _, vv := range values { if vv == v { return true } } return false }