59 lines
1.1 KiB
Go
59 lines
1.1 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
|
|
"github.com/pkg/errors"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
type Store struct {
|
|
getDatabase func(ctx context.Context) (*gorm.DB, error)
|
|
}
|
|
|
|
func New(db *gorm.DB) *Store {
|
|
return &Store{
|
|
getDatabase: createGetDatabase(db),
|
|
}
|
|
}
|
|
|
|
func (s *Store) Do(ctx context.Context, fn func(db *gorm.DB) error) error {
|
|
db, err := s.getDatabase(ctx)
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
if err := fn(db); err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func createGetDatabase(db *gorm.DB) func(ctx context.Context) (*gorm.DB, error) {
|
|
var (
|
|
migrateOnce sync.Once
|
|
migrateErr error
|
|
)
|
|
|
|
return func(ctx context.Context) (*gorm.DB, error) {
|
|
migrateOnce.Do(func() {
|
|
if err := db.AutoMigrate(models...); err != nil {
|
|
migrateErr = errors.WithStack(err)
|
|
return
|
|
}
|
|
|
|
if err := db.Exec("PRAGMA journal_mode=WAL; PRAGMA encoding='UTF-8';").Error; err != nil {
|
|
migrateErr = errors.WithStack(err)
|
|
return
|
|
}
|
|
})
|
|
if migrateErr != nil {
|
|
return nil, errors.WithStack(migrateErr)
|
|
}
|
|
|
|
return db, nil
|
|
}
|
|
}
|