feat: resources segregation by tenant
This commit is contained in:
@ -20,6 +20,75 @@ type AgentRepository struct {
|
||||
sqliteBusyRetryMaxAttempts int
|
||||
}
|
||||
|
||||
// Attach implements datastore.AgentRepository.
|
||||
func (r *AgentRepository) Attach(ctx context.Context, tenantID datastore.TenantID, agentID datastore.AgentID) (*datastore.Agent, error) {
|
||||
var agent datastore.Agent
|
||||
|
||||
err := r.withTxRetry(ctx, func(tx *sql.Tx) error {
|
||||
query := `SELECT count(id), tenant_id FROM agents WHERE id = $1`
|
||||
row := tx.QueryRowContext(ctx, query, agentID)
|
||||
|
||||
var (
|
||||
count int
|
||||
attachedTenantID *datastore.TenantID
|
||||
)
|
||||
|
||||
if err := row.Scan(&count, &attachedTenantID); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
if count == 0 {
|
||||
return errors.WithStack(datastore.ErrNotFound)
|
||||
}
|
||||
|
||||
if attachedTenantID != nil {
|
||||
return errors.WithStack(datastore.ErrAlreadyAttached)
|
||||
}
|
||||
|
||||
now := time.Now().UTC()
|
||||
|
||||
query = `
|
||||
UPDATE agents SET tenant_id = $1, updated_at = $2
|
||||
RETURNING "id", "thumbprint", "keyset", "metadata", "status", "created_at", "updated_at", "tenant_id"
|
||||
`
|
||||
|
||||
row = tx.QueryRowContext(
|
||||
ctx, query,
|
||||
tenantID,
|
||||
now,
|
||||
)
|
||||
|
||||
metadata := JSONMap{}
|
||||
var rawKeySet []byte
|
||||
|
||||
err := row.Scan(&agent.ID, &agent.Thumbprint, &rawKeySet, &metadata, &agent.Status, &agent.CreatedAt, &agent.UpdatedAt, &agent.TenantID)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
agent.Metadata = metadata
|
||||
|
||||
keySet, err := jwk.Parse(rawKeySet)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
agent.KeySet = &datastore.SerializableKeySet{keySet}
|
||||
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return &agent, nil
|
||||
}
|
||||
|
||||
// Detach implements datastore.AgentRepository.
|
||||
func (*AgentRepository) Detach(ctx context.Context, agentID datastore.AgentID) (*datastore.Agent, error) {
|
||||
panic("unimplemented")
|
||||
}
|
||||
|
||||
// DeleteSpec implements datastore.AgentRepository.
|
||||
func (r *AgentRepository) DeleteSpec(ctx context.Context, agentID datastore.AgentID, name string) error {
|
||||
err := r.withTxRetry(ctx, func(tx *sql.Tx) error {
|
||||
@ -170,7 +239,7 @@ func (r *AgentRepository) Query(ctx context.Context, opts ...datastore.AgentQuer
|
||||
count := 0
|
||||
|
||||
err := r.withTxRetry(ctx, func(tx *sql.Tx) error {
|
||||
query := `SELECT id, label, thumbprint, status, contacted_at, created_at, updated_at FROM agents`
|
||||
query := `SELECT id, label, thumbprint, status, contacted_at, created_at, updated_at, tenant_id FROM agents`
|
||||
|
||||
limit := 10
|
||||
if options.Limit != nil {
|
||||
@ -193,6 +262,13 @@ func (r *AgentRepository) Query(ctx context.Context, opts ...datastore.AgentQuer
|
||||
args = append(args, newArgs...)
|
||||
}
|
||||
|
||||
if options.TenantIDs != nil && len(options.TenantIDs) > 0 {
|
||||
filter, newArgs, newParamIndex := inFilter("tenant_id", paramIndex, options.TenantIDs)
|
||||
filters += filter
|
||||
paramIndex = newParamIndex
|
||||
args = append(args, newArgs...)
|
||||
}
|
||||
|
||||
if options.Thumbprints != nil && len(options.Thumbprints) > 0 {
|
||||
if filters != "" {
|
||||
filters += " AND "
|
||||
@ -240,7 +316,7 @@ func (r *AgentRepository) Query(ctx context.Context, opts ...datastore.AgentQuer
|
||||
metadata := JSONMap{}
|
||||
contactedAt := sql.NullTime{}
|
||||
|
||||
if err := rows.Scan(&agent.ID, &agent.Label, &agent.Thumbprint, &agent.Status, &contactedAt, &agent.CreatedAt, &agent.UpdatedAt); err != nil {
|
||||
if err := rows.Scan(&agent.ID, &agent.Label, &agent.Thumbprint, &agent.Status, &contactedAt, &agent.CreatedAt, &agent.UpdatedAt, &agent.TenantID); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
@ -293,7 +369,7 @@ func (r *AgentRepository) Create(ctx context.Context, thumbprint string, keySet
|
||||
query = `
|
||||
INSERT INTO agents (thumbprint, keyset, metadata, status, created_at, updated_at)
|
||||
VALUES($1, $2, $3, $4, $5, $5)
|
||||
RETURNING "id", "thumbprint", "keyset", "metadata", "status", "created_at", "updated_at"
|
||||
RETURNING "id", "thumbprint", "keyset", "metadata", "status", "created_at", "updated_at", "tenant_id"
|
||||
`
|
||||
|
||||
rawKeySet, err := json.Marshal(keySet)
|
||||
@ -308,7 +384,7 @@ func (r *AgentRepository) Create(ctx context.Context, thumbprint string, keySet
|
||||
|
||||
metadata := JSONMap{}
|
||||
|
||||
err = row.Scan(&agent.ID, &agent.Thumbprint, &rawKeySet, &metadata, &agent.Status, &agent.CreatedAt, &agent.UpdatedAt)
|
||||
err = row.Scan(&agent.ID, &agent.Thumbprint, &rawKeySet, &metadata, &agent.Status, &agent.CreatedAt, &agent.UpdatedAt, &agent.TenantID)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
@ -363,7 +439,7 @@ func (r *AgentRepository) Get(ctx context.Context, id datastore.AgentID) (*datas
|
||||
|
||||
err := r.withTxRetry(ctx, func(tx *sql.Tx) error {
|
||||
query := `
|
||||
SELECT "id", "label", "thumbprint", "keyset", "metadata", "status", "contacted_at", "created_at", "updated_at"
|
||||
SELECT "id", "label", "thumbprint", "keyset", "metadata", "status", "contacted_at", "created_at", "updated_at", "tenant_id"
|
||||
FROM agents
|
||||
WHERE id = $1
|
||||
`
|
||||
@ -374,7 +450,7 @@ func (r *AgentRepository) Get(ctx context.Context, id datastore.AgentID) (*datas
|
||||
contactedAt := sql.NullTime{}
|
||||
var rawKeySet []byte
|
||||
|
||||
if err := row.Scan(&agent.ID, &agent.Label, &agent.Thumbprint, &rawKeySet, &metadata, &agent.Status, &contactedAt, &agent.CreatedAt, &agent.UpdatedAt); err != nil {
|
||||
if err := row.Scan(&agent.ID, &agent.Label, &agent.Thumbprint, &rawKeySet, &metadata, &agent.Status, &contactedAt, &agent.CreatedAt, &agent.UpdatedAt, &agent.TenantID); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return datastore.ErrNotFound
|
||||
}
|
||||
@ -476,7 +552,7 @@ func (r *AgentRepository) Update(ctx context.Context, id datastore.AgentID, opts
|
||||
|
||||
query += `
|
||||
WHERE id = $1
|
||||
RETURNING "id", "label", "thumbprint", "keyset", "metadata", "status", "contacted_at", "created_at", "updated_at"
|
||||
RETURNING "id", "label", "thumbprint", "keyset", "metadata", "status", "contacted_at", "created_at", "updated_at", "tenant_id"
|
||||
`
|
||||
|
||||
logger.Debug(ctx, "executing query", logger.F("query", query), logger.F("args", args))
|
||||
@ -487,7 +563,7 @@ func (r *AgentRepository) Update(ctx context.Context, id datastore.AgentID, opts
|
||||
contactedAt := sql.NullTime{}
|
||||
var rawKeySet []byte
|
||||
|
||||
if err := row.Scan(&agent.ID, &agent.Label, &agent.Thumbprint, &rawKeySet, &metadata, &agent.Status, &contactedAt, &agent.CreatedAt, &agent.UpdatedAt); err != nil {
|
||||
if err := row.Scan(&agent.ID, &agent.Label, &agent.Thumbprint, &rawKeySet, &metadata, &agent.Status, &contactedAt, &agent.CreatedAt, &agent.UpdatedAt, &agent.TenantID); err != nil {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return datastore.ErrNotFound
|
||||
}
|
||||
@ -622,23 +698,3 @@ func NewAgentRepository(db *sql.DB, sqliteBusyRetryMaxAttempts int) *AgentReposi
|
||||
}
|
||||
|
||||
var _ datastore.AgentRepository = &AgentRepository{}
|
||||
|
||||
func inFilter[T any](column string, paramIndex int, items []T) (string, []any, int) {
|
||||
args := make([]any, 0, len(items))
|
||||
filter := fmt.Sprintf("%s in (", column)
|
||||
|
||||
for idx, item := range items {
|
||||
if idx != 0 {
|
||||
filter += ","
|
||||
}
|
||||
|
||||
filter += fmt.Sprintf("$%d", paramIndex)
|
||||
paramIndex++
|
||||
|
||||
args = append(args, item)
|
||||
}
|
||||
|
||||
filter += ")"
|
||||
|
||||
return filter, args, paramIndex
|
||||
}
|
||||
|
23
internal/datastore/sqlite/sql.go
Normal file
23
internal/datastore/sqlite/sql.go
Normal file
@ -0,0 +1,23 @@
|
||||
package sqlite
|
||||
|
||||
import "fmt"
|
||||
|
||||
func inFilter[T any](column string, paramIndex int, items []T) (string, []any, int) {
|
||||
args := make([]any, 0, len(items))
|
||||
filter := fmt.Sprintf("%s in (", column)
|
||||
|
||||
for idx, item := range items {
|
||||
if idx != 0 {
|
||||
filter += ","
|
||||
}
|
||||
|
||||
filter += fmt.Sprintf("$%d", paramIndex)
|
||||
paramIndex++
|
||||
|
||||
args = append(args, item)
|
||||
}
|
||||
|
||||
filter += ")"
|
||||
|
||||
return filter, args, paramIndex
|
||||
}
|
Reference in New Issue
Block a user