129 lines
2.5 KiB
Go
129 lines
2.5 KiB
Go
package redis
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"math/rand"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/pkg/errors"
|
|
"github.com/redis/go-redis/v9"
|
|
"gitlab.com/wpetit/goweb/logger"
|
|
)
|
|
|
|
var (
	// DefaultTxMaxAttempts is a default attempt budget for optimistic
	// transaction retries (presumably intended as WithRetry's maxAttempts
	// argument — confirm against callers).
	DefaultTxMaxAttempts int = 20

	// DefaultTxBaseDelay is a default initial backoff delay between
	// transaction retries (presumably intended as WithRetry's baseDelay
	// argument — confirm against callers).
	DefaultTxBaseDelay time.Duration = 100 * time.Millisecond
)
|
|
|
|
// jsonWrapper boxes a value of any type T. Its methods encode and decode that
// value as JSON through the encoding.BinaryMarshaler, BinaryUnmarshaler and
// TextUnmarshaler method sets.
type jsonWrapper[T any] struct {
	value T // the wrapped payload; read it back with Value
}
|
|
|
|
func (w *jsonWrapper[T]) MarshalBinary() ([]byte, error) {
|
|
data, err := json.Marshal(w.value)
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
return data, nil
|
|
}
|
|
|
|
func (w *jsonWrapper[T]) UnmarshalBinary(data []byte) error {
|
|
if err := json.Unmarshal(data, &w.value); err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *jsonWrapper[T]) UnmarshalText(data []byte) error {
|
|
if err := json.Unmarshal(data, &w.value); err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (w *jsonWrapper[T]) Value() T {
|
|
return w.value
|
|
}
|
|
|
|
func wrap[T any](v T) *jsonWrapper[T] {
|
|
return &jsonWrapper[T]{v}
|
|
}
|
|
|
|
func unwrap[T any](v any) (T, error) {
|
|
str, ok := v.(string)
|
|
if !ok {
|
|
return *new(T), errors.Errorf("could not unwrap value of type '%T'", v)
|
|
}
|
|
|
|
u := new(T)
|
|
|
|
if err := json.Unmarshal([]byte(str), u); err != nil {
|
|
return *new(T), errors.WithStack(err)
|
|
}
|
|
|
|
return *u, nil
|
|
}
|
|
|
|
// key builds a colon-delimited key from its parts, e.g.
// key("a", "b", "c") == "a:b:c". No parts yields the empty string.
func key(parts ...string) string {
	const separator = ":"

	return strings.Join(parts, separator)
}
|
|
|
|
func WithRetry(ctx context.Context, client redis.UniversalClient, key string, fn func(ctx context.Context, tx *redis.Tx) error, maxAttempts int, baseDelay time.Duration) error {
|
|
var err error
|
|
|
|
delay := baseDelay
|
|
|
|
for attempt := 0; attempt < maxAttempts; attempt++ {
|
|
if err = WithTx(ctx, client, key, fn); err != nil {
|
|
err = errors.WithStack(err)
|
|
logger.Debug(ctx, "redis transaction failed", logger.E(err))
|
|
|
|
if errors.Is(err, redis.TxFailedErr) {
|
|
logger.Debug(ctx, "retrying redis transaction", logger.F("attempts", attempt), logger.F("delay", delay))
|
|
time.Sleep(delay)
|
|
delay = delay*2 + time.Duration(rand.Int63n(int64(baseDelay)))
|
|
|
|
continue
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
return errors.WithStack(redis.TxFailedErr)
|
|
}
|
|
|
|
func WithTx(ctx context.Context, client redis.UniversalClient, key string, fn func(ctx context.Context, tx *redis.Tx) error) error {
|
|
txf := func(tx *redis.Tx) error {
|
|
if err := fn(ctx, tx); err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
err := client.Watch(ctx, txf, key)
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// contains reports whether v occurs in values, comparing with ==. The ~string
// constraint accepts string and any named type whose underlying type is
// string. A nil or empty slice never contains anything.
func contains[T ~string](values []T, v T) bool {
	for i := range values {
		if values[i] == v {
			return true
		}
	}

	return false
}
|