package redis import ( "context" "time" "forge.cadoles.com/cadoles/bouncer/internal/store" "github.com/pkg/errors" "github.com/redis/go-redis/v9" ) const ( keyPrefixLayer = "layer" ) type LayerRepository struct { client redis.UniversalClient txMaxAttempts int txRetryBaseDelay time.Duration } // CreateLayer implements store.LayerRepository func (r *LayerRepository) CreateLayer(ctx context.Context, proxyName store.ProxyName, layerName store.LayerName, layerType store.LayerType, options store.LayerOptions) (*store.Layer, error) { now := time.Now().UTC() key := layerKey(proxyName, layerName) layerItem := &layerItem{ layerHeaderItem: layerHeaderItem{ Proxy: string(proxyName), Name: string(layerName), Type: string(layerType), Weight: 0, Revision: 0, Enabled: false, }, CreatedAt: wrap(now), UpdatedAt: wrap(now), Options: wrap(options), } txf := func(tx *redis.Tx) error { exists, err := tx.Exists(ctx, key).Uint64() if err != nil { return errors.WithStack(err) } if exists > 0 { return errors.WithStack(store.ErrAlreadyExist) } _, err = tx.TxPipelined(ctx, func(p redis.Pipeliner) error { p.HMSet(ctx, key, &layerItem.layerHeaderItem) p.HMSet(ctx, key, layerItem) return nil }) if err != nil { return errors.WithStack(err) } layerItem, err = r.txGetLayerItem(ctx, tx, proxyName, layerName) if err != nil { return errors.WithStack(err) } return nil } err := r.client.Watch(ctx, txf, key) if err != nil { return nil, errors.WithStack(err) } return &store.Layer{ LayerHeader: store.LayerHeader{ Name: store.LayerName(layerItem.Name), Proxy: store.ProxyName(layerItem.Proxy), Type: store.LayerType(layerItem.Type), Weight: layerItem.Weight, Enabled: layerItem.Enabled, }, CreatedAt: layerItem.CreatedAt.Value(), UpdatedAt: layerItem.UpdatedAt.Value(), Options: layerItem.Options.Value(), }, nil } // DeleteLayer implements store.LayerRepository func (r *LayerRepository) DeleteLayer(ctx context.Context, proxyName store.ProxyName, layerName store.LayerName) error { key := layerKey(proxyName, layerName) if cmd := r.client.Del(ctx, key); cmd.Err() != nil { return errors.WithStack(cmd.Err()) } return nil } // GetLayer implements store.LayerRepository func (r *LayerRepository) GetLayer(ctx context.Context, proxyName store.ProxyName, layerName store.LayerName) (*store.Layer, error) { key := layerKey(proxyName, layerName) var layerItem *layerItem err := WithRetry(ctx, r.client, key, func(ctx context.Context, tx *redis.Tx) error { pItem, err := r.txGetLayerItem(ctx, tx, proxyName, layerName) if err != nil { return errors.WithStack(err) } layerItem = pItem return nil }, r.txMaxAttempts, r.txRetryBaseDelay) if err != nil { return nil, errors.WithStack(err) } layer, err := layerItem.ToLayer() if err != nil { return nil, errors.WithStack(err) } return layer, nil } func (r *LayerRepository) txGetLayerItem(ctx context.Context, tx *redis.Tx, proxyName store.ProxyName, layerName store.LayerName) (*layerItem, error) { layerItem := layerItem{} key := layerKey(proxyName, layerName) exists, err := tx.Exists(ctx, key).Uint64() if err != nil { return nil, errors.WithStack(err) } if exists == 0 { return nil, errors.WithStack(store.ErrNotFound) } if err := tx.HGetAll(ctx, key).Scan(&layerItem.layerHeaderItem); err != nil { return nil, errors.WithStack(err) } if err := tx.HGetAll(ctx, key).Scan(&layerItem); err != nil { return nil, errors.WithStack(err) } return &layerItem, nil } // QueryLayers implements store.LayerRepository func (r *LayerRepository) QueryLayers(ctx context.Context, proxyName store.ProxyName, funcs ...store.QueryLayerOptionFunc) ([]*store.LayerHeader, error) { opts := store.DefaultQueryLayerOptions() for _, fn := range funcs { fn(opts) } keyParts := []string{keyPrefixLayer, string(proxyName)} if opts.Name != nil { keyParts = append(keyParts, string(*opts.Name)) } else { keyParts = append(keyParts, "*") } key := key(keyParts...) iter := r.client.Scan(ctx, 0, key, 0).Iterator() headers := make([]*store.LayerHeader, 0) for iter.Next(ctx) { key := iter.Val() layerHeaderItem := &layerHeaderItem{} if err := r.client.HGetAll(ctx, key).Scan(layerHeaderItem); err != nil { return nil, errors.WithStack(err) } layerHeader, err := layerHeaderItem.ToLayerHeader() if err != nil { return nil, errors.WithStack(err) } headers = append(headers, layerHeader) } if err := iter.Err(); err != nil { return nil, errors.WithStack(err) } return headers, nil } // UpdateLayer implements store.LayerRepository func (r *LayerRepository) UpdateLayer(ctx context.Context, proxyName store.ProxyName, layerName store.LayerName, funcs ...store.UpdateLayerOptionFunc) (*store.Layer, error) { opts := &store.UpdateLayerOptions{} for _, fn := range funcs { fn(opts) } key := layerKey(proxyName, layerName) var layerItem layerItem err := WithRetry(ctx, r.client, key, func(ctx context.Context, tx *redis.Tx) error { item, err := r.txGetLayerItem(ctx, tx, proxyName, layerName) if err != nil { return errors.WithStack(err) } if opts.Enabled != nil { item.Enabled = *opts.Enabled } if opts.Weight != nil { item.Weight = *opts.Weight } if opts.Options != nil { item.Options = wrap(*opts.Options) } item.UpdatedAt = wrap(time.Now().UTC()) item.Revision = item.Revision + 1 _, err = tx.TxPipelined(ctx, func(p redis.Pipeliner) error { p.HMSet(ctx, key, item.layerHeaderItem) p.HMSet(ctx, key, item) return nil }) if err != nil { return errors.WithStack(err) } layerItem = *item return nil }, r.txMaxAttempts, r.txRetryBaseDelay) if err != nil { return nil, errors.WithStack(err) } layer, err := layerItem.ToLayer() if err != nil { return nil, errors.WithStack(err) } return layer, nil } func NewLayerRepository(client redis.UniversalClient, txMaxAttempts int, txRetryBaseDelay time.Duration) *LayerRepository { return &LayerRepository{ client: client, txMaxAttempts: txMaxAttempts, txRetryBaseDelay: txRetryBaseDelay, } } var _ store.LayerRepository = &LayerRepository{} func layerKey(proxyName store.ProxyName, layerName store.LayerName) string { return key(keyPrefixLayer, string(proxyName), string(layerName)) }