// Package store wraps a *gorm.DB and lazily prepares it (SQLite pragmas plus
// schema auto-migration of models) before the first query runs through it.
package store

import (
	"context"
	"database/sql"
	"sync"

	"github.com/pkg/errors"
	"gorm.io/gorm"
)

// Store gives callers access to a lazily prepared *gorm.DB.
type Store struct {
	// getDatabase returns a ctx-scoped handle to the database, preparing the
	// database first if preparation has not yet succeeded.
	getDatabase func(ctx context.Context) (*gorm.DB, error)
}

// New returns a Store backed by db. Nothing is executed against db until the
// first call to Do or Tx.
func New(db *gorm.DB) *Store {
	return &Store{
		getDatabase: createGetDatabase(db),
	}
}

// Do runs fn with a *gorm.DB session bound to ctx, so ctx is passed down to
// the statements fn issues through it. On the first successful call the
// database is prepared first (pragmas + AutoMigrate). Errors from preparation
// or from fn are returned with a stack trace attached.
func (s *Store) Do(ctx context.Context, fn func(db *gorm.DB) error) error {
	db, err := s.getDatabase(ctx)
	if err != nil {
		return errors.WithStack(err)
	}
	if err := fn(db); err != nil {
		return errors.WithStack(err)
	}
	return nil
}

// Tx runs fn inside a single transaction opened via gorm's Transaction on the
// ctx-bound session from Do. gorm commits when fn returns nil and rolls back
// otherwise. opts, if given, are forwarded to gorm.
func (s *Store) Tx(ctx context.Context, fn func(db *gorm.DB) error, opts ...*sql.TxOptions) error {
	// Do already attaches a stack trace, so its result is returned as-is.
	return s.Do(ctx, func(db *gorm.DB) error {
		err := db.Transaction(func(tx *gorm.DB) error {
			if err := fn(tx); err != nil {
				return errors.WithStack(err)
			}
			return nil
		}, opts...)
		// errors.WithStack(nil) is nil, so success still returns nil.
		return errors.WithStack(err)
	})
}

// createGetDatabase returns a function that prepares db on first use and then
// hands out sessions of db bound to the caller's ctx.
//
// Preparation executes the pragmas below and then AutoMigrate(models...). It
// runs under a mutex, so concurrent first callers wait for a single attempt.
// Unlike the previous sync.Once, a failed attempt is not cached: the next
// call retries. A transient failure, or a cancelled ctx on the very first
// call, therefore does not leave the Store permanently unusable.
func createGetDatabase(db *gorm.DB) func(ctx context.Context) (*gorm.DB, error) {
	// NOTE(review): in SQLite, busy_timeout and synchronous are
	// per-connection settings. Executing them through a connection pool only
	// configures whichever pooled connection runs this statement. Consider
	// moving them into the driver DSN so every connection gets them.
	// "txlock" is not a SQLite PRAGMA (SQLite ignores unknown pragmas); some
	// Go drivers accept it as a DSN option (e.g. mattn/go-sqlite3 `_txlock`).
	// Confirm the driver in use and configure it there.
	const pragmas = "PRAGMA journal_mode=WAL; PRAGMA busy_timeout=500; PRAGMA txlock=deferred; PRAGMA synchronous=normal; PRAGMA encoding='UTF-8';"

	var (
		mu       sync.Mutex
		prepared bool
	)

	// prepare applies pragmas and migrations once; it only records success.
	prepare := func(ctx context.Context) error {
		mu.Lock()
		defer mu.Unlock()

		if prepared {
			return nil
		}

		conn := db.WithContext(ctx)

		// Pragmas go first: "encoding" only takes effect before database
		// content is created, and busy_timeout should already be in force
		// while migrating.
		if err := conn.Exec(pragmas).Error; err != nil {
			return errors.WithStack(err)
		}
		if err := conn.AutoMigrate(models...); err != nil {
			return errors.WithStack(err)
		}

		prepared = true
		return nil
	}

	return func(ctx context.Context) (*gorm.DB, error) {
		if err := prepare(ctx); err != nil {
			return nil, errors.WithStack(err)
		}
		// Bind the caller's ctx so Do/Tx callers' contexts reach their queries.
		return db.WithContext(ctx), nil
	}
}