package sqlite

import (
	"context"
	"database/sql"
	"encoding/json"
	"fmt"
	"strings"
	"time"

	"forge.cadoles.com/wpetit/hydra-webauthn/internal/storage"
	"github.com/go-webauthn/webauthn/webauthn"
	"github.com/pkg/errors"
	"gitlab.com/wpetit/goweb/logger"

	_ "embed"

	_ "modernc.org/sqlite"
)

// UserRepository implements storage.UserRepository on top of a SQLite
// database opened through database/sql.
type UserRepository struct {
	db *sql.DB

	// sqliteBusyRetryMaxAttempts bounds how many times withTxRetry runs a
	// transaction that keeps failing with SQLITE_BUSY.
	sqliteBusyRetryMaxAttempts int
}

// CreateUser implements storage.UserRepository.
//
// It inserts a new user built by storage.NewUser from username and attributes
// and returns it. storage.ErrAlreadyExist is returned when a user with the
// same username is already stored.
func (r *UserRepository) CreateUser(ctx context.Context, username string, attributes map[string]any) (*storage.User, error) {
	user := storage.NewUser(username, attributes)

	// Marshalled once, outside the transaction, so retries do not redo it.
	rawAttributes, err := json.Marshal(user.Attributes)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	err = r.withTxRetry(ctx, func(tx *sql.Tx) error {
		query := `SELECT COUNT(id) FROM users WHERE username = $1`

		row := tx.QueryRowContext(ctx, query, user.Username)

		var count int64
		if err := row.Scan(&count); err != nil {
			return errors.WithStack(err)
		}

		if err := row.Err(); err != nil {
			return errors.WithStack(err)
		}

		if count > 0 {
			return errors.WithStack(storage.ErrAlreadyExist)
		}

		// Each timestamp column gets its own placeholder so the stored
		// updated_at really is user.UpdatedAt.
		query = `
			INSERT INTO users (id, username, attributes, created_at, updated_at)
			VALUES ($1, $2, $3, $4, $5)
		`

		args := []any{
			user.ID,
			user.Username,
			rawAttributes,
			user.CreatedAt,
			user.UpdatedAt,
		}

		if _, err := tx.ExecContext(ctx, query, args...); err != nil {
			return errors.WithStack(err)
		}

		return nil
	})
	if err != nil {
		return nil, errors.WithStack(err)
	}

	return user, nil
}

// DeleteUserByID implements storage.UserRepository.
//
// It removes the user identified by userID together with its credentials and
// registration links, in a single transaction. storage.ErrNotFound is
// returned when no such user exists.
func (r *UserRepository) DeleteUserByID(ctx context.Context, userID string) error {
	err := r.withTxRetry(ctx, func(tx *sql.Tx) error {
		exists, err := r.userExists(ctx, tx, userID)
		if err != nil {
			return errors.WithStack(err)
		}

		if !exists {
			return errors.WithStack(storage.ErrNotFound)
		}

		// Dependent rows are deleted before the user row so the operation
		// works whether or not the schema declares cascading foreign keys.
		queries := []string{
			`DELETE FROM user_credentials WHERE user_id = $1`,
			`DELETE FROM registration_links WHERE user_id = $1`,
			`DELETE FROM users WHERE id = $1`,
		}

		for _, query := range queries {
			if _, err := tx.ExecContext(ctx, query, userID); err != nil {
				return errors.WithStack(err)
			}
		}

		return nil
	})
	if err != nil {
		return errors.WithStack(err)
	}

	return nil
}

// FindUserByUsername implements storage.UserRepository.
//
// It returns the user (with its credentials) whose username matches, or
// storage.ErrNotFound.
func (r *UserRepository) FindUserByUsername(ctx context.Context, username string) (*storage.User, error) {
	user, err := r.findUserBy(ctx, "username", username)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	return user, nil
}

// FindUserByID implements storage.UserRepository.
func (r *UserRepository) FindUserByID(ctx context.Context, userID string) (*storage.User, error) { user, err := r.findUserBy(ctx, "id", userID) if err != nil { return nil, errors.WithStack(err) } return user, nil } func (r *UserRepository) findUserBy(ctx context.Context, column string, value any) (*storage.User, error) { user := &storage.User{} err := r.withTxRetry(ctx, func(tx *sql.Tx) error { query := fmt.Sprintf(`SELECT id, username, attributes, created_at, updated_at FROM users WHERE %s = $1`, column) args := []any{ value, } var rawAttributes []byte row := tx.QueryRowContext(ctx, query, args...) if err := row.Scan(&user.ID, &user.Username, &rawAttributes, &user.CreatedAt, &user.UpdatedAt); err != nil { if errors.Is(err, sql.ErrNoRows) { return errors.WithStack(storage.ErrNotFound) } return errors.WithStack(err) } if err := row.Err(); err != nil { return errors.WithStack(err) } if err := json.Unmarshal(rawAttributes, &user.Attributes); err != nil { return errors.WithStack(err) } if user.Attributes == nil { user.Attributes = make(map[string]any) } query = `SELECT credential FROM user_credentials WHERE user_id = $1` args = []any{ user.ID, } rows, err := tx.QueryContext(ctx, query, args...) if err != nil { return errors.WithStack(err) } defer func() { if err := rows.Close(); err != nil { logger.Error(ctx, "could not close rows", logger.E(errors.WithStack(err))) } }() user.Credentials = make([]webauthn.Credential, 0) for rows.Next() { var ( rawCredential []byte credential webauthn.Credential ) if err := rows.Scan(&rawCredential); err != nil { return errors.WithStack(err) } if err := json.Unmarshal(rawCredential, &credential); err != nil { return errors.WithStack(err) } user.Credentials = append(user.Credentials, credential) } if err := rows.Err(); err != nil { return errors.WithStack(err) } return nil }) if err != nil { return nil, errors.WithStack(err) } return user, nil } // List implements storage.UserRepository. 
func (r *UserRepository) ListUsers(ctx context.Context) ([]storage.UserHeader, error) { users := make([]storage.UserHeader, 0) err := r.withTxRetry(ctx, func(tx *sql.Tx) error { query := `SELECT id, username, created_at, updated_at FROM users` rows, err := tx.QueryContext(ctx, query) if err != nil { return errors.WithStack(err) } defer func() { if err := rows.Close(); err != nil { logger.Error(ctx, "could not close rows", logger.E(errors.WithStack(err))) } }() for rows.Next() { user := storage.UserHeader{} if err := rows.Scan(&user.ID, &user.Username, &user.CreatedAt, &user.UpdatedAt); err != nil { return errors.WithStack(err) } users = append(users, user) } if err := rows.Err(); err != nil { return errors.WithStack(err) } return nil }) if err != nil { return nil, errors.WithStack(err) } return users, nil } // UpdateUsername implements storage.UserRepository. func (r *UserRepository) UpdateUserUsername(ctx context.Context, userID string, username string) (*storage.User, error) { var user *storage.User err := r.withTxRetry(ctx, func(tx *sql.Tx) error { query := `SELECT COUNT(id) FROM users WHERE id = $1` args := []any{ userID, } row := tx.QueryRowContext(ctx, query, args...) var count int64 if err := row.Scan(&count); err != nil { return errors.WithStack(err) } if err := row.Err(); err != nil { return errors.WithStack(err) } if count == 0 { return errors.WithStack(storage.ErrNotFound) } query = ` UPDATE users SET username = $1, updated_at = $2 WHERE id = $3 RETURNING id, username, attributes, created_at, updated_at ` args = []any{ username, time.Now(), userID, } var rawAttributes []byte user = &storage.User{} row = tx.QueryRowContext(ctx, query, args...) 
if err := row.Scan(&user.ID, &user.Username, &rawAttributes, &user.CreatedAt, &user.UpdatedAt); err != nil { return errors.WithStack(err) } if err := row.Err(); err != nil { return errors.WithStack(err) } if err := json.Unmarshal(rawAttributes, &user.Attributes); err != nil { return errors.WithStack(err) } if user.Attributes == nil { user.Attributes = make(map[string]any) } return nil }) if err != nil { return nil, errors.WithStack(err) } return user, nil } // Update implements storage.UserRepository. func (r *UserRepository) UpdateUserAttributes(ctx context.Context, userID string, attributes map[string]any) (*storage.User, error) { var user *storage.User err := r.withTxRetry(ctx, func(tx *sql.Tx) error { query := `SELECT COUNT(id) FROM users WHERE id = $1` args := []any{ userID, } row := tx.QueryRowContext(ctx, query, args...) var count int64 if err := row.Scan(&count); err != nil { return errors.WithStack(err) } if err := row.Err(); err != nil { return errors.WithStack(err) } if count == 0 { return errors.WithStack(storage.ErrNotFound) } query = ` UPDATE users SET attributes = $1, updated_at = $2 WHERE id = $3 RETURNING id, username, attributes, created_at, updated_at ` rawAttributes, err := json.Marshal(attributes) if err != nil { return errors.WithStack(err) } args = []any{ rawAttributes, time.Now(), userID, } user = &storage.User{} row = tx.QueryRowContext(ctx, query, args...) if err := row.Scan(&user.ID, &user.Username, &rawAttributes, &user.CreatedAt, &user.UpdatedAt); err != nil { return errors.WithStack(err) } if err := row.Err(); err != nil { return errors.WithStack(err) } if err := json.Unmarshal(rawAttributes, &user.Attributes); err != nil { return errors.WithStack(err) } if user.Attributes == nil { user.Attributes = make(map[string]any) } return nil }) if err != nil { return nil, errors.WithStack(err) } return user, nil } // AddCredential implements storage.UserRepository. 
func (r *UserRepository) AddUserCredential(ctx context.Context, userID string, credential *webauthn.Credential) (string, error) { credentialID := storage.NewID() err := r.withTxRetry(ctx, func(tx *sql.Tx) error { exists, err := r.userExists(ctx, tx, userID) if err != nil { return errors.WithStack(err) } if !exists { return errors.WithStack(storage.ErrNotFound) } query := ` INSERT INTO user_credentials (id, user_id, credential, created_at) VALUES ($1, $2, $3, $4) ` rawCredential, err := json.Marshal(credential) if err != nil { return errors.WithStack(err) } args := []any{ credentialID, userID, rawCredential, time.Now(), } if _, err := tx.ExecContext(ctx, query, args...); err != nil { return errors.WithStack(err) } return nil }) if err != nil { return "", errors.WithStack(err) } return credentialID, nil } // RemoveCredential implements storage.UserRepository. func (r *UserRepository) RemoveUserCredential(ctx context.Context, userID string, credentialID string) error { err := r.withTxRetry(ctx, func(tx *sql.Tx) error { exists, err := r.userExists(ctx, tx, userID) if err != nil { return errors.WithStack(err) } if !exists { return errors.WithStack(storage.ErrNotFound) } query := ` DELETE FROM user_credentials WHERE id = $1 AND user_id = $2 ` args := []any{ credentialID, userID, } if _, err := tx.ExecContext(ctx, query, args...); err != nil { return errors.WithStack(err) } return nil }) if err != nil { return errors.WithStack(err) } return nil } // GenerateRegistrationLink implements storage.UserRepository. 
func (r *UserRepository) GenerateRegistrationLink(ctx context.Context, userID string) (*storage.RegistrationLink, error) { registrationLink := storage.NewRegistrationLink(userID) err := r.withTxRetry(ctx, func(tx *sql.Tx) error { exists, err := r.userExists(ctx, tx, userID) if err != nil { return errors.WithStack(err) } if !exists { return errors.WithStack(storage.ErrNotFound) } query := `DELETE FROM registration_links WHERE user_id = $1` args := []any{ registrationLink.UserID, } if _, err := tx.ExecContext(ctx, query, args...); err != nil { return errors.WithStack(err) } query = ` INSERT INTO registration_links (token, user_id, created_at) VALUES ($1, $2, $3) ` args = []any{ registrationLink.Token, registrationLink.UserID, registrationLink.CreatedAt, } if _, err := tx.ExecContext(ctx, query, args...); err != nil { return errors.WithStack(err) } return nil }) if err != nil { return nil, errors.WithStack(err) } return registrationLink, nil } // ClearRegistrationLink implements storage.UserRepository. func (*UserRepository) ClearRegistrationLink(ctx context.Context, userID string) error { panic("unimplemented") } // GetRegistrationLink implements storage.UserRepository. func (r *UserRepository) GetRegistrationLink(ctx context.Context, userID string) (*storage.RegistrationLink, error) { registrationLink := &storage.RegistrationLink{} err := r.withTxRetry(ctx, func(tx *sql.Tx) error { query := `SELECT token, user_id, created_at FROM registration_links WHERE user_id = $1` args := []any{ userID, } row := tx.QueryRowContext(ctx, query, args...) 
if err := row.Scan(®istrationLink.Token, ®istrationLink.UserID, ®istrationLink.CreatedAt); err != nil { if errors.Is(err, sql.ErrNoRows) { return errors.WithStack(storage.ErrNotFound) } return errors.WithStack(err) } if err := row.Err(); err != nil { return errors.WithStack(err) } return nil }) if err != nil { return nil, errors.WithStack(err) } return registrationLink, nil } // GetRegistrationLinkByToken implements storage.UserRepository. func (r *UserRepository) GetRegistrationLinkByToken(ctx context.Context, token string) (*storage.RegistrationLink, error) { registrationLink := &storage.RegistrationLink{} err := r.withTxRetry(ctx, func(tx *sql.Tx) error { query := `SELECT token, user_id, created_at FROM registration_links WHERE token = $1` args := []any{ token, } row := tx.QueryRowContext(ctx, query, args...) if err := row.Scan(®istrationLink.Token, ®istrationLink.UserID, ®istrationLink.CreatedAt); err != nil { if errors.Is(err, sql.ErrNoRows) { return errors.WithStack(storage.ErrNotFound) } return errors.WithStack(err) } if err := row.Err(); err != nil { return errors.WithStack(err) } return nil }) if err != nil { return nil, errors.WithStack(err) } return registrationLink, nil } func (r *UserRepository) userExists(ctx context.Context, tx *sql.Tx, userID string) (bool, error) { query := `SELECT COUNT(id) FROM users WHERE id = $1` args := []any{ userID, } row := tx.QueryRowContext(ctx, query, args...) 
var count int64 if err := row.Scan(&count); err != nil { return false, errors.WithStack(err) } if err := row.Err(); err != nil { return false, errors.WithStack(err) } return count >= 1, nil } func (r *UserRepository) withTxRetry(ctx context.Context, fn func(*sql.Tx) error) error { attempts := 0 max := r.sqliteBusyRetryMaxAttempts ctx = logger.With(ctx, logger.F("max", max)) var err error for { ctx = logger.With(ctx) if attempts >= max { logger.Debug(ctx, "transaction retrying failed", logger.F("attempts", attempts)) return errors.Wrapf(err, "transaction failed after %d attempts", max) } err = r.withTx(ctx, fn) if err != nil { if !strings.Contains(err.Error(), "(5) (SQLITE_BUSY)") { return errors.WithStack(err) } err = errors.WithStack(err) logger.Warn(ctx, "database is busy", logger.E(err)) wait := time.Duration(8<<(attempts+1)) * time.Millisecond logger.Debug( ctx, "database is busy, waiting before retrying transaction", logger.F("wait", wait.String()), logger.F("attempts", attempts), ) timer := time.NewTimer(wait) select { case <-timer.C: attempts++ continue case <-ctx.Done(): if err := ctx.Err(); err != nil { return errors.WithStack(err) } return nil } } return nil } } func (r *UserRepository) withTx(ctx context.Context, fn func(*sql.Tx) error) error { tx, err := r.db.BeginTx(ctx, nil) if err != nil { return errors.WithStack(err) } defer func() { if err := tx.Rollback(); err != nil { if errors.Is(err, sql.ErrTxDone) { return } err = errors.WithStack(err) logger.Error(ctx, "could not rollback transaction", logger.CapturedE(err)) } }() if err := fn(tx); err != nil { return errors.WithStack(err) } if err := tx.Commit(); err != nil { return errors.WithStack(err) } return nil } func NewUserRepository(dsn string) (*UserRepository, error) { db, err := sql.Open("sqlite", dsn) if err != nil { return nil, errors.WithStack(err) } if err := applyUserRepositoryMigration(db); err != nil { return nil, errors.Wrap(err, "could not migrate schema") } return &UserRepository{db, 
5}, nil } var _ storage.UserRepository = &UserRepository{} //go:embed user_repository.sql var userRepositoryMigrationScript string func applyUserRepositoryMigration(db *sql.DB) error { if err := db.Ping(); err != nil { return errors.WithStack(err) } if _, err := db.Exec(userRepositoryMigrationScript); err != nil { return errors.WithStack(err) } return nil }