package redis

import (
	"context"
	"encoding/json"
	"net/url"
	"time"

	"forge.cadoles.com/cadoles/bouncer/internal/store"
	"github.com/pkg/errors"
	"github.com/redis/go-redis/v9"
)

// Hash field names used to persist a proxy, and the key prefix under which
// proxy hashes are stored.
const (
	keyName        = "name"
	keyFrom        = "from"
	keyTo          = "to"
	keyUpdatedAt   = "updated_at"
	keyCreatedAt   = "created_at"
	keyWeight      = "weight"
	keyPrefixProxy = "proxy:"
)

// ProxyRepository implements store.ProxyRepository on top of Redis.
//
// Each proxy is persisted as a hash stored under "proxy:<name>". The name and
// target URL fields are stored as plain strings; the remaining fields (from,
// weight, created_at, updated_at) are JSON-encoded.
type ProxyRepository struct {
	client redis.UniversalClient
}

// GetProxy implements store.ProxyRepository.
//
// It returns a wrapped store.ErrNotFound when none of the proxy's hash fields
// exist, and an error if any stored field cannot be decoded.
func (r *ProxyRepository) GetProxy(ctx context.Context, name store.ProxyName) (*store.Proxy, error) {
	var proxy store.Proxy

	key := proxyKey(name)

	cmd := r.client.HMGet(ctx, key, keyFrom, keyTo, keyWeight, keyCreatedAt, keyUpdatedAt)

	values, err := cmd.Result()
	if err != nil {
		return nil, errors.WithStack(err)
	}

	// HMGET yields nil for every missing field, so a missing key comes back
	// as a slice of nils rather than an error.
	if allNilValues(values) {
		return nil, errors.WithStack(store.ErrNotFound)
	}

	proxy.Name = name

	from, err := unwrap[[]string](values[0])
	if err != nil {
		return nil, errors.WithStack(err)
	}

	proxy.From = from

	// The target URL is stored as a plain string, not JSON.
	rawTo, ok := values[1].(string)
	if !ok {
		return nil, errors.Errorf("unexpected 'to' value of type '%T'", values[1])
	}

	to, err := url.Parse(rawTo)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	proxy.To = to

	weight, err := unwrap[int](values[2])
	if err != nil {
		return nil, errors.WithStack(err)
	}

	proxy.Weight = weight

	createdAt, err := unwrap[time.Time](values[3])
	if err != nil {
		return nil, errors.WithStack(err)
	}

	proxy.CreatedAt = createdAt

	updatedAt, err := unwrap[time.Time](values[4])
	if err != nil {
		return nil, errors.WithStack(err)
	}

	proxy.UpdatedAt = updatedAt

	return &proxy, nil
}

// CreateProxy implements store.ProxyRepository.
//
// It stores a new proxy hash with a weight of 0 and identical creation and
// update timestamps (current UTC time). It returns a wrapped
// store.ErrAlreadyExist when a proxy with the same name is already stored.
// The to URL must be non-nil: its String method is called unconditionally.
func (r *ProxyRepository) CreateProxy(ctx context.Context, name store.ProxyName, to *url.URL, from ...string) (*store.Proxy, error) {
	now := time.Now().UTC()
	key := proxyKey(name)

	// The existence check and the write run under WATCH on the key, so a
	// concurrent creation of the same proxy makes the transaction fail
	// instead of silently overwriting it.
	txf := func(tx *redis.Tx) error {
		exists, err := tx.Exists(ctx, key).Uint64()
		if err != nil {
			return errors.WithStack(err)
		}

		if exists > 0 {
			return errors.WithStack(store.ErrAlreadyExist)
		}

		_, err = tx.TxPipelined(ctx, func(p redis.Pipeliner) error {
			// A single HSET with several field/value pairs replaces the
			// series of HMSET calls (HMSET is deprecated since Redis 4.0).
			p.HSet(ctx, key,
				keyName, string(name),
				keyFrom, wrap(from),
				keyTo, to.String(),
				keyWeight, wrap(0),
				keyCreatedAt, wrap(now),
				keyUpdatedAt, wrap(now),
			)

			return nil
		})

		return err
	}

	err := r.client.Watch(ctx, txf, key)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	return &store.Proxy{
		ProxyHeader: store.ProxyHeader{
			Name:      name,
			CreatedAt: now,
			UpdatedAt: now,
		},
		To:   to,
		From: from,
	}, nil
}

// DeleteProxy implements store.ProxyRepository.
//
// Only the command error is checked, so deleting a proxy that does not exist
// is not reported as an error.
func (r *ProxyRepository) DeleteProxy(ctx context.Context, name store.ProxyName) error {
	key := proxyKey(name)

	if cmd := r.client.Del(ctx, key); cmd.Err() != nil {
		return errors.WithStack(cmd.Err())
	}

	return nil
}

// QueryProxy implements store.ProxyRepository.
//
// It scans every "proxy:*" key and returns the header (name and timestamps)
// of each stored proxy. Keys whose hash fields are all missing are skipped.
// The option funcs are currently ignored.
//
// NOTE(review): with a cluster client SCAN may only cover a single node —
// confirm whether cluster deployments need a per-master scan.
func (r *ProxyRepository) QueryProxy(ctx context.Context, funcs ...store.QueryProxyOptionFunc) ([]*store.ProxyHeader, error) {
	iter := r.client.Scan(ctx, 0, keyPrefixProxy+"*", 0).Iterator()

	headers := make([]*store.ProxyHeader, 0)

	// SCAN may return the same key more than once over a full iteration;
	// remember visited keys so each proxy is reported only once.
	seen := make(map[string]struct{})

	for iter.Next(ctx) {
		key := iter.Val()

		if _, dup := seen[key]; dup {
			continue
		}

		seen[key] = struct{}{}

		cmd := r.client.HMGet(ctx, key, keyName, keyCreatedAt, keyUpdatedAt)

		values, err := cmd.Result()
		if err != nil {
			return nil, errors.WithStack(err)
		}

		// The key may have been deleted between SCAN and HMGET.
		if allNilValues(values) {
			continue
		}

		name, ok := values[0].(string)
		if !ok {
			return nil, errors.Errorf("unexpected 'name' field value for key '%s': '%v'", key, values[0])
		}

		createdAt, err := unwrap[time.Time](values[1])
		if err != nil {
			return nil, errors.WithStack(err)
		}

		updatedAt, err := unwrap[time.Time](values[2])
		if err != nil {
			return nil, errors.WithStack(err)
		}

		h := &store.ProxyHeader{
			Name:      store.ProxyName(name),
			CreatedAt: createdAt,
			UpdatedAt: updatedAt,
		}

		headers = append(headers, h)
	}

	if err := iter.Err(); err != nil {
		return nil, errors.WithStack(err)
	}

	return headers, nil
}

// UpdateProxy implements store.ProxyRepository.
//
// It is not implemented yet and panics when called.
func (*ProxyRepository) UpdateProxy(ctx context.Context, name store.ProxyName, funcs ...store.UpdateProxyOptionFunc) (*store.Proxy, error) {
	// TODO: implement once the update options are applied to the stored hash.
	panic("unimplemented")
}

// NewProxyRepository returns a ProxyRepository that reads and writes proxies
// through the given Redis client.
func NewProxyRepository(client redis.UniversalClient) *ProxyRepository {
	return &ProxyRepository{
		client: client,
	}
}

// Compile-time check that ProxyRepository satisfies store.ProxyRepository.
var _ store.ProxyRepository = (*ProxyRepository)(nil)

// proxyKey returns the Redis key of the hash holding the named proxy.
func proxyKey(name store.ProxyName) string {
	return keyPrefixProxy + string(name)
}

// jsonWrapper adapts an arbitrary value to encoding.BinaryMarshaler /
// encoding.BinaryUnmarshaler using JSON, so it can be passed to the Redis
// client as a command argument.
type jsonWrapper[T any] struct {
	value T
}

// MarshalBinary encodes the wrapped value as JSON.
func (w *jsonWrapper[T]) MarshalBinary() ([]byte, error) {
	data, err := json.Marshal(w.value)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	return data, nil
}

// UnmarshalBinary decodes JSON data into the wrapped value.
func (w *jsonWrapper[T]) UnmarshalBinary(data []byte) error {
	if err := json.Unmarshal(data, &w.value); err != nil {
		return errors.WithStack(err)
	}

	return nil
}

// Value returns the wrapped value.
func (w *jsonWrapper[T]) Value() T {
	return w.value
}

// wrap boxes v so that it is JSON-encoded when written to Redis.
func wrap[T any](v T) *jsonWrapper[T] {
	return &jsonWrapper[T]{v}
}

// unwrap decodes a JSON string returned by Redis into a value of type T.
// It fails if v is not a string (e.g. nil for a missing field) or if the
// string is not valid JSON for T.
func unwrap[T any](v any) (T, error) {
	var zero T

	str, ok := v.(string)
	if !ok {
		return zero, errors.Errorf("could not unwrap value of type '%T'", v)
	}

	var value T
	if err := json.Unmarshal([]byte(str), &value); err != nil {
		return zero, errors.WithStack(err)
	}

	return value, nil
}

// allNilValues reports whether every element of values is nil. It returns
// true for an empty slice.
func allNilValues(values []any) bool {
	for _, v := range values {
		if v != nil {
			return false
		}
	}

	return true
}