package redis import ( "context" "time" "forge.cadoles.com/cadoles/bouncer/internal/store" "github.com/pkg/errors" "github.com/redis/go-redis/v9" ) const ( keyPrefixProxy = "proxy" ) type ProxyRepository struct { client redis.UniversalClient } // GetProxy implements store.ProxyRepository func (r *ProxyRepository) GetProxy(ctx context.Context, name store.ProxyName) (*store.Proxy, error) { key := proxyKey(name) var proxyItem *proxyItem err := WithTx(ctx, r.client, key, func(ctx context.Context, tx *redis.Tx) error { pItem, err := r.txGetProxyItem(ctx, tx, name) if err != nil { return errors.WithStack(err) } proxyItem = pItem return nil }) if err != nil { return nil, errors.WithStack(err) } proxy, err := proxyItem.ToProxy() if err != nil { return nil, errors.WithStack(err) } return proxy, nil } func (r *ProxyRepository) txGetProxyItem(ctx context.Context, tx *redis.Tx, name store.ProxyName) (*proxyItem, error) { proxyItem := proxyItem{} key := proxyKey(name) exists, err := tx.Exists(ctx, key).Uint64() if err != nil { return nil, errors.WithStack(err) } if exists == 0 { return nil, errors.WithStack(store.ErrNotFound) } if err := tx.HGetAll(ctx, key).Scan(&proxyItem.proxyHeaderItem); err != nil { return nil, errors.WithStack(err) } if err := tx.HGetAll(ctx, key).Scan(&proxyItem); err != nil { return nil, errors.WithStack(err) } return &proxyItem, nil } // CreateProxy implements store.ProxyRepository func (r *ProxyRepository) CreateProxy(ctx context.Context, name store.ProxyName, to string, from ...string) (*store.Proxy, error) { now := time.Now().UTC() key := proxyKey(name) txf := func(tx *redis.Tx) error { exists, err := tx.Exists(ctx, key).Uint64() if err != nil { return errors.WithStack(err) } if exists > 0 { return errors.WithStack(store.ErrAlreadyExist) } proxyItem := &proxyItem{ proxyHeaderItem: proxyHeaderItem{ Name: string(name), CreatedAt: wrap(now), UpdatedAt: wrap(now), Weight: 0, Enabled: false, }, To: to, From: wrap(from), } _, err = tx.TxPipelined(ctx, func(p redis.Pipeliner) error { p.HMSet(ctx, key, proxyItem.proxyHeaderItem) p.HMSet(ctx, key, proxyItem) return nil }) if err != nil { return errors.WithStack(err) } return nil } err := r.client.Watch(ctx, txf, key) if err != nil { return nil, errors.WithStack(err) } return &store.Proxy{ ProxyHeader: store.ProxyHeader{ Name: name, Weight: 0, Enabled: false, }, To: to, From: from, CreatedAt: now, UpdatedAt: now, }, nil } // DeleteProxy implements store.ProxyRepository func (r *ProxyRepository) DeleteProxy(ctx context.Context, name store.ProxyName) error { key := proxyKey(name) if cmd := r.client.Del(ctx, key); cmd.Err() != nil { return errors.WithStack(cmd.Err()) } return nil } // QueryProxy implements store.ProxyRepository func (r *ProxyRepository) QueryProxy(ctx context.Context, funcs ...store.QueryProxyOptionFunc) ([]*store.ProxyHeader, error) { opts := store.DefaultQueryProxyOptions() for _, fn := range funcs { fn(opts) } iter := r.client.Scan(ctx, 0, keyPrefixProxy+"*", 0).Iterator() headers := make([]*store.ProxyHeader, 0) for iter.Next(ctx) { key := iter.Val() proxyHeaderItem := &proxyHeaderItem{} if err := r.client.HGetAll(ctx, key).Scan(proxyHeaderItem); err != nil { return nil, errors.WithStack(err) } proxyHeader, err := proxyHeaderItem.ToProxyHeader() if err != nil { return nil, errors.WithStack(err) } if opts.Enabled != nil && proxyHeader.Enabled != *opts.Enabled { continue } if opts.Names != nil && !contains(opts.Names, proxyHeader.Name) { continue } headers = append(headers, proxyHeader) } if err := iter.Err(); err != nil { return nil, errors.WithStack(err) } return headers, nil } // UpdateProxy implements store.ProxyRepository func (r *ProxyRepository) UpdateProxy(ctx context.Context, name store.ProxyName, funcs ...store.UpdateProxyOptionFunc) (*store.Proxy, error) { opts := &store.UpdateProxyOptions{} for _, fn := range funcs { fn(opts) } key := proxyKey(name) var proxyItem proxyItem err := WithTx(ctx, r.client, key, func(ctx context.Context, tx *redis.Tx) error { item, err := r.txGetProxyItem(ctx, tx, name) if err != nil { return errors.WithStack(err) } if opts.Enabled != nil { item.Enabled = *opts.Enabled } if opts.From != nil { item.From = wrap(opts.From) } if opts.Weight != nil { item.Weight = *opts.Weight } if opts.To != nil { item.To = *opts.To } item.UpdatedAt = wrap(time.Now().UTC()) _, err = tx.TxPipelined(ctx, func(p redis.Pipeliner) error { p.HMSet(ctx, key, item.proxyHeaderItem) p.HMSet(ctx, key, item) return nil }) if err != nil { return errors.WithStack(err) } proxyItem = *item return nil }) if err != nil { return nil, errors.WithStack(err) } proxy, err := proxyItem.ToProxy() if err != nil { return nil, errors.WithStack(err) } return proxy, nil } func NewProxyRepository(client redis.UniversalClient) *ProxyRepository { return &ProxyRepository{ client: client, } } var _ store.ProxyRepository = &ProxyRepository{} func proxyKey(name store.ProxyName) string { return key(keyPrefixProxy, string(name)) }