package sqlite import ( "context" "database/sql" "encoding/json" "fmt" "time" "forge.cadoles.com/Cadoles/emissary/internal/datastore" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" ) type AgentRepository struct { db *sql.DB } // DeleteSpec implements datastore.AgentRepository. func (r *AgentRepository) DeleteSpec(ctx context.Context, agentID datastore.AgentID, name string) error { err := r.withTx(ctx, func(tx *sql.Tx) error { exists, err := r.agentExists(ctx, tx, agentID) if err != nil { return errors.WithStack(err) } if !exists { return errors.WithStack(datastore.ErrNotFound) } query := `DELETE FROM specs WHERE agent_id = $1 AND name = $2` if _, err = tx.ExecContext(ctx, query, agentID, name); err != nil { return errors.WithStack(err) } return nil }) if err != nil { return errors.WithStack(err) } return nil } // GetSpecs implements datastore.AgentRepository. func (r *AgentRepository) GetSpecs(ctx context.Context, agentID datastore.AgentID) ([]*datastore.Spec, error) { specs := make([]*datastore.Spec, 0) err := r.withTx(ctx, func(tx *sql.Tx) error { exists, err := r.agentExists(ctx, tx, agentID) if err != nil { return errors.WithStack(err) } if !exists { return errors.WithStack(datastore.ErrNotFound) } query := ` SELECT id, name, revision, data, created_at, updated_at FROM specs WHERE agent_id = $1 ` rows, err := tx.QueryContext(ctx, query, agentID) if err != nil { return errors.WithStack(err) } defer func() { if err := rows.Close(); err != nil { logger.Error(ctx, "could not close rows", logger.E(errors.WithStack(err))) } }() for rows.Next() { spec := &datastore.Spec{} data := JSONMap{} if err := rows.Scan(&spec.ID, &spec.Name, &spec.Revision, &data, &spec.CreatedAt, &spec.UpdatedAt); err != nil { return errors.WithStack(err) } spec.Data = data specs = append(specs, spec) } if err := rows.Err(); err != nil { return errors.WithStack(err) } return nil }) if err != nil { return nil, errors.WithStack(err) } return specs, nil } // UpdateSpec implements datastore.AgentRepository. func (r *AgentRepository) UpdateSpec(ctx context.Context, agentID datastore.AgentID, name string, revision int, data map[string]any) (*datastore.Spec, error) { spec := &datastore.Spec{} err := r.withTx(ctx, func(tx *sql.Tx) error { exists, err := r.agentExists(ctx, tx, agentID) if err != nil { return errors.WithStack(err) } if !exists { return errors.WithStack(datastore.ErrNotFound) } now := time.Now().UTC() query := ` INSERT INTO specs (agent_id, name, revision, data, created_at, updated_at) VALUES($1, $2, $3, $4, $5, $5) ON CONFLICT (agent_id, name) DO UPDATE SET data = $4, updated_at = $5, revision = specs.revision + 1 WHERE revision = $3 RETURNING "id", "name", "revision", "data", "created_at", "updated_at" ` args := []any{agentID, name, revision, JSONMap(data), now} logger.Debug(ctx, "executing query", logger.F("query", query), logger.F("args", args)) row := tx.QueryRowContext(ctx, query, args...) data := JSONMap{} err = row.Scan(&spec.ID, &spec.Name, &spec.Revision, &data, &spec.CreatedAt, &spec.UpdatedAt) if err != nil { if errors.Is(err, sql.ErrNoRows) { return errors.WithStack(datastore.ErrUnexpectedRevision) } return errors.WithStack(err) } spec.Data = data return nil }) if err != nil { return nil, errors.WithStack(err) } return spec, nil } // Query implements datastore.AgentRepository. func (r *AgentRepository) Query(ctx context.Context, opts ...datastore.AgentQueryOptionFunc) ([]*datastore.Agent, int, error) { options := &datastore.AgentQueryOptions{} for _, fn := range opts { fn(options) } agents := make([]*datastore.Agent, 0) count := 0 err := r.withTx(ctx, func(tx *sql.Tx) error { query := `SELECT id, label, thumbprint, status, contacted_at, created_at, updated_at FROM agents` limit := 10 if options.Limit != nil { limit = *options.Limit } offset := 0 if options.Offset != nil { offset = *options.Offset } filters := "" paramIndex := 3 args := []any{offset, limit} if options.IDs != nil && len(options.IDs) > 0 { filter, newArgs, newParamIndex := inFilter("id", paramIndex, options.IDs) filters += filter paramIndex = newParamIndex args = append(args, newArgs...) } if options.Thumbprints != nil && len(options.Thumbprints) > 0 { if filters != "" { filters += " AND " } filter, newArgs, newParamIndex := inFilter("thumbprint", paramIndex, options.Thumbprints) filters += filter paramIndex = newParamIndex args = append(args, newArgs...) } if options.Statuses != nil && len(options.Statuses) > 0 { if filters != "" { filters += " AND " } filter, newArgs, _ := inFilter("status", paramIndex, options.Statuses) filters += filter args = append(args, newArgs...) } if filters != "" { filters = ` WHERE ` + filters } query += filters + ` LIMIT $2 OFFSET $1` logger.Debug(ctx, "executing query", logger.F("query", query), logger.F("args", args)) rows, err := r.db.QueryContext(ctx, query, args...) if err != nil { return errors.WithStack(err) } defer func() { if err := rows.Close(); err != nil { logger.Error(ctx, "could not close rows", logger.E(errors.WithStack(err))) } }() for rows.Next() { agent := &datastore.Agent{} metadata := JSONMap{} contactedAt := sql.NullTime{} if err := rows.Scan(&agent.ID, &agent.Label, &agent.Thumbprint, &agent.Status, &contactedAt, &agent.CreatedAt, &agent.UpdatedAt); err != nil { return errors.WithStack(err) } agent.Metadata = metadata if contactedAt.Valid { agent.ContactedAt = &contactedAt.Time } agents = append(agents, agent) } if err := rows.Err(); err != nil { return errors.WithStack(err) } row := tx.QueryRowContext(ctx, `SELECT count(id) FROM agents `+filters, args...) if err := row.Scan(&count); err != nil { return errors.WithStack(err) } return nil }) if err != nil { return nil, 0, errors.WithStack(err) } return agents, count, nil } // Create implements datastore.AgentRepository func (r *AgentRepository) Create(ctx context.Context, thumbprint string, keySet jwk.Set, metadata map[string]any) (*datastore.Agent, error) { agent := &datastore.Agent{} err := r.withTx(ctx, func(tx *sql.Tx) error { query := `SELECT count(id) FROM agents WHERE thumbprint = $1` row := tx.QueryRowContext(ctx, query, thumbprint) var count int if err := row.Scan(&count); err != nil { return errors.WithStack(err) } if count > 0 { return errors.WithStack(datastore.ErrAlreadyExist) } now := time.Now().UTC() query = ` INSERT INTO agents (thumbprint, keyset, metadata, status, created_at, updated_at) VALUES($1, $2, $3, $4, $5, $5) RETURNING "id", "thumbprint", "keyset", "metadata", "status", "created_at", "updated_at" ` rawKeySet, err := json.Marshal(keySet) if err != nil { return errors.WithStack(err) } row = tx.QueryRowContext( ctx, query, thumbprint, rawKeySet, JSONMap(metadata), datastore.AgentStatusPending, now, ) metadata := JSONMap{} err = row.Scan(&agent.ID, &agent.Thumbprint, &rawKeySet, &metadata, &agent.Status, &agent.CreatedAt, &agent.UpdatedAt) if err != nil { return errors.WithStack(err) } agent.Metadata = metadata keySet, err = jwk.Parse(rawKeySet) if err != nil { return errors.WithStack(err) } agent.KeySet = &datastore.SerializableKeySet{keySet} return nil }) if err != nil { return nil, errors.WithStack(err) } return agent, nil } // Delete implements datastore.AgentRepository func (r *AgentRepository) Delete(ctx context.Context, id datastore.AgentID) error { err := r.withTx(ctx, func(tx *sql.Tx) error { query := `DELETE FROM agents WHERE id = $1` _, err := r.db.ExecContext(ctx, query, id) if err != nil { return errors.WithStack(err) } query = `DELETE FROM specs WHERE agent_id = $1` _, err = r.db.ExecContext(ctx, query, id) if err != nil { return errors.WithStack(err) } return nil }) if err != nil { return errors.WithStack(err) } return nil } // Get implements datastore.AgentRepository func (r *AgentRepository) Get(ctx context.Context, id datastore.AgentID) (*datastore.Agent, error) { agent := &datastore.Agent{ ID: id, } err := r.withTx(ctx, func(tx *sql.Tx) error { query := ` SELECT "id", "label", "thumbprint", "keyset", "metadata", "status", "contacted_at", "created_at", "updated_at" FROM agents WHERE id = $1 ` row := r.db.QueryRowContext(ctx, query, id) metadata := JSONMap{} contactedAt := sql.NullTime{} var rawKeySet []byte if err := row.Scan(&agent.ID, &agent.Label, &agent.Thumbprint, &rawKeySet, &metadata, &agent.Status, &contactedAt, &agent.CreatedAt, &agent.UpdatedAt); err != nil { if errors.Is(err, sql.ErrNoRows) { return datastore.ErrNotFound } return errors.WithStack(err) } agent.Metadata = metadata if contactedAt.Valid { agent.ContactedAt = &contactedAt.Time } keySet := jwk.NewSet() if err := json.Unmarshal(rawKeySet, &keySet); err != nil { return errors.WithStack(err) } agent.KeySet = &datastore.SerializableKeySet{keySet} return nil }) if err != nil { return nil, errors.WithStack(err) } return agent, nil } // Update implements datastore.AgentRepository func (r *AgentRepository) Update(ctx context.Context, id datastore.AgentID, opts ...datastore.AgentUpdateOptionFunc) (*datastore.Agent, error) { options := &datastore.AgentUpdateOptions{} for _, fn := range opts { fn(options) } agent := &datastore.Agent{} err := r.withTx(ctx, func(tx *sql.Tx) error { query := ` UPDATE agents SET id = $1 ` args := []any{id} index := 2 if options.Status != nil { query += fmt.Sprintf(`, status = $%d`, index) args = append(args, *options.Status) index++ } if options.KeySet != nil { rawKeySet, err := json.Marshal(*options.KeySet) if err != nil { return errors.WithStack(err) } query += fmt.Sprintf(`, keyset = $%d`, index) args = append(args, rawKeySet) index++ } if options.Thumbprint != nil { query += fmt.Sprintf(`, thumbprint = $%d`, index) args = append(args, *options.Thumbprint) index++ } if options.Label != nil { query += fmt.Sprintf(`, label = $%d`, index) args = append(args, *options.Label) index++ } if options.ContactedAt != nil { query += fmt.Sprintf(`, contacted_at = $%d`, index) utc := options.ContactedAt.UTC() args = append(args, utc) index++ } if options.Metadata != nil { query += fmt.Sprintf(`, metadata = $%d`, index) args = append(args, JSONMap(*options.Metadata)) index++ } updated := options.Metadata != nil || options.Status != nil || options.Label != nil || options.KeySet != nil || options.Thumbprint != nil if updated { now := time.Now().UTC() query += fmt.Sprintf(`, updated_at = $%d`, index) args = append(args, now) index++ } query += ` WHERE id = $1 RETURNING "id", "label", "thumbprint", "keyset", "metadata", "status", "contacted_at", "created_at", "updated_at" ` logger.Debug(ctx, "executing query", logger.F("query", query), logger.F("args", args)) row := tx.QueryRowContext(ctx, query, args...) metadata := JSONMap{} contactedAt := sql.NullTime{} var rawKeySet []byte if err := row.Scan(&agent.ID, &agent.Label, &agent.Thumbprint, &rawKeySet, &metadata, &agent.Status, &contactedAt, &agent.CreatedAt, &agent.UpdatedAt); err != nil { if errors.Is(err, sql.ErrNoRows) { return datastore.ErrNotFound } return errors.WithStack(err) } agent.Metadata = metadata if contactedAt.Valid { agent.ContactedAt = &contactedAt.Time } keySet := jwk.NewSet() if err := json.Unmarshal(rawKeySet, &keySet); err != nil { return errors.WithStack(err) } agent.KeySet = &datastore.SerializableKeySet{keySet} return nil }) if err != nil { return nil, errors.WithStack(err) } return agent, nil } func (r *AgentRepository) agentExists(ctx context.Context, tx *sql.Tx, agentID datastore.AgentID) (bool, error) { row := tx.QueryRowContext(ctx, `SELECT count(id) FROM agents WHERE id = $1`, agentID) var count int if err := row.Scan(&count); err != nil { if errors.Is(err, sql.ErrNoRows) { return false, errors.WithStack(datastore.ErrNotFound) } return false, errors.WithStack(err) } if count == 0 { return false, errors.WithStack(datastore.ErrNotFound) } return true, nil } func (r *AgentRepository) withTx(ctx context.Context, fn func(*sql.Tx) error) error { tx, err := r.db.BeginTx(ctx, nil) if err != nil { return errors.WithStack(err) } defer func() { if err := tx.Rollback(); err != nil { if errors.Is(err, sql.ErrTxDone) { return } logger.Error(ctx, "could not rollback transaction", logger.E(errors.WithStack(err))) } }() if err := fn(tx); err != nil { return errors.WithStack(err) } if err := tx.Commit(); err != nil { return errors.WithStack(err) } return nil } func NewAgentRepository(db *sql.DB) *AgentRepository { return &AgentRepository{db} } var _ datastore.AgentRepository = &AgentRepository{} func inFilter[T any](column string, paramIndex int, items []T) (string, []any, int) { args := make([]any, 0, len(items)) filter := fmt.Sprintf("%s in (", column) for idx, item := range items { if idx != 0 { filter += "," } filter += fmt.Sprintf("$%d", paramIndex) paramIndex++ args = append(args, item) } filter += ")" return filter, args, paramIndex }