diff --git a/internal/datastore/sqlite/agent_repository.go b/internal/datastore/sqlite/agent_repository.go index 62ef48a..7149b76 100644 --- a/internal/datastore/sqlite/agent_repository.go +++ b/internal/datastore/sqlite/agent_repository.go @@ -20,9 +20,24 @@ type AgentRepository struct { // DeleteSpec implements datastore.AgentRepository. func (r *AgentRepository) DeleteSpec(ctx context.Context, agentID datastore.AgentID, name string) error { - query := `DELETE FROM specs WHERE agent_id = $1 AND name = $2` + err := r.withTx(ctx, func(tx *sql.Tx) error { + exists, err := r.agentExists(ctx, tx, agentID) + if err != nil { + return errors.WithStack(err) + } - _, err := r.db.ExecContext(ctx, query, agentID, name) + if !exists { + return errors.WithStack(datastore.ErrNotFound) + } + + query := `DELETE FROM specs WHERE agent_id = $1 AND name = $2` + + if _, err = tx.ExecContext(ctx, query, agentID, name); err != nil { + return errors.WithStack(err) + } + + return nil + }) if err != nil { return errors.WithStack(err) } @@ -34,41 +49,57 @@ func (r *AgentRepository) DeleteSpec(ctx context.Context, agentID datastore.Agen func (r *AgentRepository) GetSpecs(ctx context.Context, agentID datastore.AgentID) ([]*datastore.Spec, error) { specs := make([]*datastore.Spec, 0) - query := ` + err := r.withTx(ctx, func(tx *sql.Tx) error { + exists, err := r.agentExists(ctx, tx, agentID) + if err != nil { + return errors.WithStack(err) + } + + if !exists { + return errors.WithStack(datastore.ErrNotFound) + } + + query := ` SELECT id, name, revision, data, created_at, updated_at FROM specs WHERE agent_id = $1 - ` + ` - rows, err := r.db.QueryContext(ctx, query, agentID) + rows, err := tx.QueryContext(ctx, query, agentID) + if err != nil { + return errors.WithStack(err) + } + + defer func() { + if err := rows.Close(); err != nil { + logger.Error(ctx, "could not close rows", logger.E(errors.WithStack(err))) + } + }() + + for rows.Next() { + spec := &datastore.Spec{} + + data := JSONMap{} + + if err := rows.Scan(&spec.ID, &spec.Name, &spec.Revision, &data, &spec.CreatedAt, &spec.UpdatedAt); err != nil { + return errors.WithStack(err) + } + + spec.Data = data + + specs = append(specs, spec) + } + + if err := rows.Err(); err != nil { + return errors.WithStack(err) + } + + return nil + }) if err != nil { return nil, errors.WithStack(err) } - defer func() { - if err := rows.Close(); err != nil { - logger.Error(ctx, "could not close rows", logger.E(errors.WithStack(err))) - } - }() - - for rows.Next() { - spec := &datastore.Spec{} - - data := JSONMap{} - - if err := rows.Scan(&spec.ID, &spec.Name, &spec.Revision, &data, &spec.CreatedAt, &spec.UpdatedAt); err != nil { - return nil, errors.WithStack(err) - } - - spec.Data = data - - specs = append(specs, spec) - } - - if err := rows.Err(); err != nil { - return nil, errors.WithStack(err) - } - return specs, nil } @@ -77,6 +108,15 @@ func (r *AgentRepository) UpdateSpec(ctx context.Context, agentID datastore.Agen spec := &datastore.Spec{} err := r.withTx(ctx, func(tx *sql.Tx) error { + exists, err := r.agentExists(ctx, tx, agentID) + if err != nil { + return errors.WithStack(err) + } + + if !exists { + return errors.WithStack(datastore.ErrNotFound) + } + now := time.Now().UTC() query := ` @@ -96,7 +136,7 @@ func (r *AgentRepository) UpdateSpec(ctx context.Context, agentID datastore.Agen data := JSONMap{} - err := row.Scan(&spec.ID, &spec.Name, &spec.Revision, &data, &spec.CreatedAt, &spec.UpdatedAt) + err = row.Scan(&spec.ID, &spec.Name, &spec.Revision, &data, &spec.CreatedAt, &spec.UpdatedAt) if err != nil { if errors.Is(err, sql.ErrNoRows) { return errors.WithStack(datastore.ErrUnexpectedRevision) @@ -472,8 +512,28 @@ func (r *AgentRepository) Update(ctx context.Context, id datastore.AgentID, opts return agent, nil } +func (r *AgentRepository) agentExists(ctx context.Context, tx *sql.Tx, agentID datastore.AgentID) (bool, error) { + row := tx.QueryRowContext(ctx, `SELECT count(id) FROM agents WHERE id = $1`, agentID) + + var count int + + if err := row.Scan(&count); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return false, errors.WithStack(datastore.ErrNotFound) + } + + return false, errors.WithStack(err) + } + + if count == 0 { + return false, errors.WithStack(datastore.ErrNotFound) + } + + return true, nil +} + func (r *AgentRepository) withTx(ctx context.Context, fn func(*sql.Tx) error) error { - tx, err := r.db.Begin() + tx, err := r.db.BeginTx(ctx, nil) if err != nil { return errors.WithStack(err) } diff --git a/internal/datastore/sqlite/agent_repository_test.go b/internal/datastore/sqlite/agent_repository_test.go new file mode 100644 index 0000000..47e70c9 --- /dev/null +++ b/internal/datastore/sqlite/agent_repository_test.go @@ -0,0 +1,46 @@ +package sqlite + +import ( + "database/sql" + "fmt" + "os" + "testing" + "time" + + "forge.cadoles.com/Cadoles/emissary/internal/datastore/testsuite" + "forge.cadoles.com/Cadoles/emissary/internal/migrate" + "github.com/pkg/errors" + "gitlab.com/wpetit/goweb/logger" + + _ "modernc.org/sqlite" +) + +func TestSQLiteAgentRepository(t *testing.T) { + logger.SetLevel(logger.LevelDebug) + + file := "testdata/agent_repository_test.sqlite" + + if err := os.Remove(file); err != nil && !errors.Is(err, os.ErrNotExist) { + t.Fatalf("%+v", errors.WithStack(err)) + } + + dsn := fmt.Sprintf("%s?_pragma=foreign_keys(1)&_pragma=busy_timeout=%d", file, (60 * time.Second).Milliseconds()) + + migr, err := migrate.New("../../../migrations", "sqlite", "sqlite://"+dsn) + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + if err := migr.Up(); err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + db, err := sql.Open("sqlite", dsn) + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + repo := NewAgentRepository(db) + + testsuite.TestAgentRepository(t, repo) +} diff --git a/internal/datastore/sqlite/testdata/.gitignore b/internal/datastore/sqlite/testdata/.gitignore new file mode 100644 index 0000000..885029a --- /dev/null +++ b/internal/datastore/sqlite/testdata/.gitignore @@ -0,0 +1 @@ +*.sqlite* \ No newline at end of file diff --git a/internal/datastore/testsuite/agent_repository.go b/internal/datastore/testsuite/agent_repository.go new file mode 100644 index 0000000..d40aca1 --- /dev/null +++ b/internal/datastore/testsuite/agent_repository.go @@ -0,0 +1,14 @@ +package testsuite + +import ( + "testing" + + "forge.cadoles.com/Cadoles/emissary/internal/datastore" +) + +func TestAgentRepository(t *testing.T, repo datastore.AgentRepository) { + t.Run("Cases", func(t *testing.T) { + t.Parallel() + runAgentRepositoryTests(t, repo) + }) +} diff --git a/internal/datastore/testsuite/agent_repository_cases.go b/internal/datastore/testsuite/agent_repository_cases.go new file mode 100644 index 0000000..d682062 --- /dev/null +++ b/internal/datastore/testsuite/agent_repository_cases.go @@ -0,0 +1,129 @@ +package testsuite + +import ( + "context" + "testing" + + "forge.cadoles.com/Cadoles/emissary/internal/agent/controller/mdns/spec" + "forge.cadoles.com/Cadoles/emissary/internal/datastore" + "forge.cadoles.com/Cadoles/emissary/internal/jwk" + "github.com/pkg/errors" +) + +type agentRepositoryTestCase struct { + Name string + Skip bool + Run func(ctx context.Context, repo datastore.AgentRepository) error +} + +var agentRepositoryTestCases = []agentRepositoryTestCase{ + { + Name: "Create a new agent", + Run: func(ctx context.Context, repo datastore.AgentRepository) error { + thumbprint := "foo" + keySet := jwk.NewSet() + var metadata map[string]any + + agent, err := repo.Create(ctx, thumbprint, keySet, metadata) + if err != nil { + return errors.WithStack(err) + } + + if agent.CreatedAt.IsZero() { + return errors.Errorf("agent.CreatedAt should not be zero time") + } + + if agent.UpdatedAt.IsZero() { + return errors.Errorf("agent.UpdatedAt should not be zero time") + } + + if e, g := datastore.AgentStatusPending, agent.Status; e != g { + return errors.Errorf("agent.Status: expected '%v', got '%v'", e, g) + } + + return nil + }, + }, + { + Name: "Try to update spec for an unexistant agent", + Run: func(ctx context.Context, repo datastore.AgentRepository) error { + var unexistantAgentID datastore.AgentID = 9999 + var specData map[string]any + + agent, err := repo.UpdateSpec(ctx, unexistantAgentID, string(spec.Name), 0, specData) + if err == nil { + return errors.New("error should not be nil") + } + + if !errors.Is(err, datastore.ErrNotFound) { + return errors.Errorf("error should be datastore.ErrNotFound, got '%+v'", err) + } + + if agent != nil { + return errors.New("agent should be nil") + } + + return nil + }, + }, + { + Name: "Try to delete spec of an unexistant agent", + Run: func(ctx context.Context, repo datastore.AgentRepository) error { + var unexistantAgentID datastore.AgentID = 9999 + + err := repo.DeleteSpec(ctx, unexistantAgentID, string(spec.Name)) + if err == nil { + return errors.New("error should not be nil") + } + + if !errors.Is(err, datastore.ErrNotFound) { + return errors.Errorf("error should be datastore.ErrNotFound, got '%+v'", err) + } + + return nil + }, + }, + { + Name: "Try to get specs of an unexistant agent", + Run: func(ctx context.Context, repo datastore.AgentRepository) error { + var unexistantAgentID datastore.AgentID = 9999 + + specs, err := repo.GetSpecs(ctx, unexistantAgentID) + if err == nil { + return errors.New("error should not be nil") + } + + if !errors.Is(err, datastore.ErrNotFound) { + return errors.Errorf("error should be datastore.ErrNotFound, got '%+v'", err) + } + + if specs != nil { + return errors.Errorf("specs should be nil, got '%+v'", err) + } + + return nil + }, + }, +} + +func runAgentRepositoryTests(t *testing.T, repo datastore.AgentRepository) { + for _, tc := range agentRepositoryTestCases { + func(tc agentRepositoryTestCase) { + t.Run(tc.Name, func(t *testing.T) { + t.Parallel() + + if tc.Skip { + t.SkipNow() + + return + } + + ctx := context.Background() + + if err := tc.Run(ctx, repo); err != nil { + t.Errorf("%+v", errors.WithStack(err)) + } + }) + }(tc) + } +} diff --git a/internal/migrate/migrate.go b/internal/migrate/migrate.go index 56e5f70..6cc47a4 100644 --- a/internal/migrate/migrate.go +++ b/internal/migrate/migrate.go @@ -2,7 +2,6 @@ package migrate import ( "fmt" - "log" "github.com/golang-migrate/migrate/v4" _ "github.com/golang-migrate/migrate/v4/database/postgres" @@ -23,8 +22,6 @@ func New(migrationDir, driver, dsn string) (*migrate.Migrate, error) { fmt.Sprintf("file://%s/%s", migrationDir, driver), dsn, ) - - log.Println(migrationDir, driver, dsn) if err != nil { return nil, errors.WithStack(err) }