emissary/internal/datastore/sqlite/agent_repository.go

394 lines
8.8 KiB
Go

package sqlite
import (
"context"
"database/sql"
"fmt"
"time"
"forge.cadoles.com/Cadoles/emissary/internal/datastore"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
)
// AgentRepository implements datastore.AgentRepository on top of a
// database/sql handle (*sql.DB). Every query and transaction issued by its
// methods goes through db.
type AgentRepository struct {
// db is the database handle used by all repository methods.
db *sql.DB
}
// DeleteSpec implements datastore.AgentRepository.
//
// It removes the spec called name from the given agent. The number of
// affected rows is not inspected, so deleting a spec that does not exist
// succeeds silently.
func (r *AgentRepository) DeleteSpec(ctx context.Context, agentID datastore.AgentID, name string) error {
	if _, err := r.db.ExecContext(ctx, `DELETE FROM specs WHERE agent_id = $1 AND name = $2`, agentID, name); err != nil {
		return errors.WithStack(err)
	}

	return nil
}
// GetSpecs implements datastore.AgentRepository.
//
// It returns every spec stored for the given agent, in the order the
// database yields them. An agent without specs yields an empty, non-nil
// slice. Any error raised while iterating over the result set is reported
// instead of silently truncating the returned list.
func (r *AgentRepository) GetSpecs(ctx context.Context, agentID datastore.AgentID) ([]*datastore.Spec, error) {
	query := `
SELECT id, name, revision, data, created_at, updated_at
FROM specs
WHERE agent_id = $1
`

	rows, err := r.db.QueryContext(ctx, query, agentID)
	if err != nil {
		return nil, errors.WithStack(err)
	}
	defer rows.Close()

	specs := make([]*datastore.Spec, 0)

	for rows.Next() {
		spec := &datastore.Spec{}
		data := JSONMap{}

		if err := rows.Scan(&spec.ID, &spec.Name, &spec.Revision, &data, &spec.CreatedAt, &spec.UpdatedAt); err != nil {
			return nil, errors.WithStack(err)
		}

		spec.Data = data
		specs = append(specs, spec)
	}

	// rows.Next returns false both at the end of the result set and on
	// failure; only rows.Err tells the two apart.
	if err := rows.Err(); err != nil {
		return nil, errors.WithStack(err)
	}

	return specs, nil
}
// UpdateSpec implements datastore.AgentRepository.
//
// It upserts the spec called name for the given agent:
//   - when no such spec exists, it is inserted with the given revision and data;
//   - when one exists, its data is overwritten and its revision incremented,
//     but only if the stored revision equals revision.
//
// When the stored revision differs, the upsert returns no row and
// datastore.ErrUnexpectedRevision is reported. On success the spec as stored
// is returned.
func (r *AgentRepository) UpdateSpec(ctx context.Context, agentID datastore.AgentID, name string, revision int, data map[string]any) (*datastore.Spec, error) {
	spec := &datastore.Spec{}

	txErr := r.withTx(ctx, func(tx *sql.Tx) error {
		const query = `
INSERT INTO specs (agent_id, name, revision, data, created_at, updated_at)
VALUES($1, $2, $3, $4, $5, $5)
ON CONFLICT (agent_id, name) DO UPDATE SET
data = $4, updated_at = $5, revision = specs.revision + 1
WHERE revision = $3
RETURNING "id", "name", "revision", "data", "created_at", "updated_at"
`

		timestamp := time.Now().UTC()
		queryArgs := []any{agentID, name, revision, JSONMap(data), timestamp}

		logger.Debug(ctx, "executing query", logger.F("query", query), logger.F("args", queryArgs))

		stored := JSONMap{}
		scanErr := tx.QueryRowContext(ctx, query, queryArgs...).
			Scan(&spec.ID, &spec.Name, &spec.Revision, &stored, &spec.CreatedAt, &spec.UpdatedAt)

		switch {
		case errors.Is(scanErr, sql.ErrNoRows):
			// The DO UPDATE ... WHERE clause rejected the update: the
			// caller's revision does not match the stored one.
			return errors.WithStack(datastore.ErrUnexpectedRevision)
		case scanErr != nil:
			return errors.WithStack(scanErr)
		}

		spec.Data = stored

		return nil
	})
	if txErr != nil {
		return nil, errors.WithStack(txErr)
	}

	return spec, nil
}
// Query implements datastore.AgentRepository.
//
// It returns one page of agents matching the filters carried by opts, along
// with the total number of matching agents (ignoring pagination). The IDs,
// RemoteIDs and Statuses filters are combined with AND; an empty filter is
// ignored. Without a Limit option at most 10 agents are returned; without an
// Offset option the page starts at the first row.
func (r *AgentRepository) Query(ctx context.Context, opts ...datastore.AgentQueryOptionFunc) ([]*datastore.Agent, int, error) {
	options := &datastore.AgentQueryOptions{}
	for _, fn := range opts {
		fn(options)
	}

	agents := make([]*datastore.Agent, 0)
	count := 0

	err := r.withTx(ctx, func(tx *sql.Tx) error {
		limit := 10
		if options.Limit != nil {
			limit = *options.Limit
		}

		offset := 0
		if options.Offset != nil {
			offset = *options.Offset
		}

		// Filter placeholders are numbered from $1 so the exact same WHERE
		// clause and argument list can be shared with the count query. The
		// pagination placeholders are numbered after every filter, so all
		// placeholders appear in increasing order in the final statement.
		filters := ""
		filterArgs := make([]any, 0)
		paramIndex := 1

		if len(options.IDs) > 0 {
			filter, newArgs, newParamIndex := inFilter("id", paramIndex, options.IDs)
			filters += filter
			filterArgs = append(filterArgs, newArgs...)
			paramIndex = newParamIndex
		}

		if len(options.RemoteIDs) > 0 {
			if filters != "" {
				filters += " AND "
			}

			filter, newArgs, newParamIndex := inFilter("remote_id", paramIndex, options.RemoteIDs)
			filters += filter
			filterArgs = append(filterArgs, newArgs...)
			paramIndex = newParamIndex
		}

		if len(options.Statuses) > 0 {
			if filters != "" {
				filters += " AND "
			}

			filter, newArgs, newParamIndex := inFilter("status", paramIndex, options.Statuses)
			filters += filter
			filterArgs = append(filterArgs, newArgs...)
			paramIndex = newParamIndex
		}

		if filters != "" {
			filters = ` WHERE ` + filters
		}

		// Count first so the paginated result set is the only statement left
		// open on the transaction's connection.
		row := tx.QueryRowContext(ctx, `SELECT count(id) FROM agents`+filters, filterArgs...)
		if err := row.Scan(&count); err != nil {
			return errors.WithStack(err)
		}

		query := `SELECT id, remote_id, status, created_at, updated_at FROM agents` + filters +
			fmt.Sprintf(` LIMIT $%d OFFSET $%d`, paramIndex, paramIndex+1)

		args := make([]any, 0, len(filterArgs)+2)
		args = append(args, filterArgs...)
		args = append(args, limit, offset)

		logger.Debug(ctx, "executing query", logger.F("query", query), logger.F("args", args))

		// Go through the transaction, not r.db: the transaction already holds
		// a pooled connection, so asking the pool for another one could block
		// forever when the pool is limited to a single connection.
		rows, err := tx.QueryContext(ctx, query, args...)
		if err != nil {
			return errors.WithStack(err)
		}
		defer rows.Close()

		for rows.Next() {
			agent := &datastore.Agent{}
			if err := rows.Scan(&agent.ID, &agent.RemoteID, &agent.Status, &agent.CreatedAt, &agent.UpdatedAt); err != nil {
				return errors.WithStack(err)
			}

			agents = append(agents, agent)
		}

		// rows.Next returns false both at the end and on failure.
		if err := rows.Err(); err != nil {
			return errors.WithStack(err)
		}

		return nil
	})
	if err != nil {
		return nil, 0, errors.WithStack(err)
	}

	return agents, count, nil
}
// Create implements datastore.AgentRepository
//
// It registers a new agent with the given remote ID and status. Within a
// single transaction it first counts the agents already using remoteID and
// returns datastore.ErrAlreadyExist if there is any. Otherwise it inserts
// the agent and returns it as stored.
//
// NOTE(review): uniqueness of remote_id relies on this check unless the
// schema also enforces it with a constraint — confirm.
func (r *AgentRepository) Create(ctx context.Context, remoteID string, status datastore.AgentStatus) (*datastore.Agent, error) {
	agent := &datastore.Agent{}

	txErr := r.withTx(ctx, func(tx *sql.Tx) error {
		var existing int
		if err := tx.QueryRowContext(ctx, `SELECT count(id) FROM agents WHERE remote_id = $1`, remoteID).Scan(&existing); err != nil {
			return errors.WithStack(err)
		}

		if existing > 0 {
			return errors.WithStack(datastore.ErrAlreadyExist)
		}

		createdAt := time.Now().UTC()

		const insert = `
INSERT INTO agents (remote_id, status, created_at, updated_at)
VALUES($1, $2, $3, $3)
RETURNING "id", "remote_id", "status", "created_at", "updated_at"
`

		inserted := tx.QueryRowContext(ctx, insert, remoteID, status, createdAt)
		if err := inserted.Scan(&agent.ID, &agent.RemoteID, &agent.Status, &agent.CreatedAt, &agent.UpdatedAt); err != nil {
			return errors.WithStack(err)
		}

		return nil
	})
	if txErr != nil {
		return nil, errors.WithStack(txErr)
	}

	return agent, nil
}
// Delete implements datastore.AgentRepository
//
// It removes the agent with the given ID. The number of affected rows is not
// inspected, so deleting an unknown ID succeeds silently.
func (r *AgentRepository) Delete(ctx context.Context, id datastore.AgentID) error {
	_, err := r.db.ExecContext(ctx, `DELETE FROM agents WHERE id = $1`, id)

	// errors.WithStack returns nil for a nil error, so success yields nil.
	return errors.WithStack(err)
}
// Get implements datastore.AgentRepository
//
// It loads the agent with the given ID and returns datastore.ErrNotFound
// (wrapped with a stack trace) when no such agent exists.
func (r *AgentRepository) Get(ctx context.Context, id datastore.AgentID) (*datastore.Agent, error) {
	agent := &datastore.Agent{
		ID: id,
	}

	err := r.withTx(ctx, func(tx *sql.Tx) error {
		query := `
SELECT "remote_id", "status", "created_at", "updated_at"
FROM agents
WHERE id = $1
`

		// Query through the transaction rather than r.db: the transaction
		// already holds a pooled connection, so going through the pool could
		// block forever when it is limited to a single connection.
		row := tx.QueryRowContext(ctx, query, id)
		if err := row.Scan(&agent.RemoteID, &agent.Status, &agent.CreatedAt, &agent.UpdatedAt); err != nil {
			if errors.Is(err, sql.ErrNoRows) {
				return errors.WithStack(datastore.ErrNotFound)
			}

			return errors.WithStack(err)
		}

		return nil
	})
	if err != nil {
		return nil, errors.WithStack(err)
	}

	return agent, nil
}
// Update implements datastore.AgentRepository
//
// It applies the optional changes carried by opts (currently the status) to
// the agent with the given ID and always refreshes its updated_at timestamp.
// It returns the agent as stored after the update, or datastore.ErrNotFound
// (wrapped with a stack trace) when no agent has that ID.
func (r *AgentRepository) Update(ctx context.Context, id datastore.AgentID, opts ...datastore.AgentUpdateOptionFunc) (*datastore.Agent, error) {
	options := &datastore.AgentUpdateOptions{}
	for _, fn := range opts {
		fn(options)
	}

	agent := &datastore.Agent{}

	err := r.withTx(ctx, func(tx *sql.Tx) error {
		now := time.Now().UTC()

		query := `
UPDATE agents SET updated_at = $2
`
		args := []any{id, now}

		if options.Status != nil {
			// Each optional column takes the next placeholder number, which
			// is the position of its argument in args.
			args = append(args, *options.Status)
			query += fmt.Sprintf(`, status = $%d`, len(args))
		}

		// The RETURNING column order must match the Scan destinations below;
		// created_at comes before updated_at in both.
		query += `
WHERE id = $1
RETURNING "id","remote_id","status","created_at","updated_at"
`

		row := tx.QueryRowContext(ctx, query, args...)
		if err := row.Scan(&agent.ID, &agent.RemoteID, &agent.Status, &agent.CreatedAt, &agent.UpdatedAt); err != nil {
			// RETURNING yields no row when no agent matched the WHERE clause.
			if errors.Is(err, sql.ErrNoRows) {
				return errors.WithStack(datastore.ErrNotFound)
			}

			return errors.WithStack(err)
		}

		return nil
	})
	if err != nil {
		return nil, errors.WithStack(err)
	}

	return agent, nil
}
// withTx runs fn inside a database transaction bound to ctx.
//
// The transaction is committed when fn returns nil. Otherwise it is rolled
// back by the deferred Rollback, which also runs if fn panics. Errors from
// beginning the transaction, from fn and from Commit are returned wrapped
// with a stack trace; a failed rollback is only logged.
func (r *AgentRepository) withTx(ctx context.Context, fn func(*sql.Tx) error) error {
	// BeginTx (unlike Begin) honours ctx: if ctx is cancelled before the
	// transaction completes, the sql package rolls it back and Commit fails.
	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
		return errors.WithStack(err)
	}

	defer func() {
		// After a successful Commit (or a context-triggered rollback),
		// Rollback returns sql.ErrTxDone; that is expected and not logged.
		if err := tx.Rollback(); err != nil && !errors.Is(err, sql.ErrTxDone) {
			logger.Error(ctx, "could not rollback transaction", logger.E(errors.WithStack(err)))
		}
	}()

	if err := fn(tx); err != nil {
		return errors.WithStack(err)
	}

	if err := tx.Commit(); err != nil {
		return errors.WithStack(err)
	}

	return nil
}
// NewAgentRepository returns an AgentRepository whose queries and
// transactions all go through db.
func NewAgentRepository(db *sql.DB) *AgentRepository {
	repository := &AgentRepository{db: db}

	return repository
}

// Compile-time check that *AgentRepository satisfies datastore.AgentRepository.
var _ datastore.AgentRepository = (*AgentRepository)(nil)
// inFilter builds an SQL `column in (...)` clause containing one numbered
// placeholder per item, starting at $paramIndex.
//
// It returns the clause, the items as query arguments (in placeholder
// order), and the next unused placeholder number. An empty items slice
// yields `column in ()` and an empty, non-nil argument slice.
//
// column is interpolated verbatim, so it must be a trusted identifier,
// never user input.
func inFilter[T any](column string, paramIndex int, items []T) (string, []any, int) {
	placeholders := ""
	args := make([]any, len(items))

	for i, item := range items {
		if i > 0 {
			placeholders += ","
		}

		placeholders += fmt.Sprintf("$%d", paramIndex+i)
		args[i] = item
	}

	return fmt.Sprintf("%s in (%s)", column, placeholders), args, paramIndex + len(items)
}