75 lines
1.5 KiB
Go
75 lines
1.5 KiB
Go
package store
|
|
|
|
import (
|
|
"context"
|
|
"database/sql"
|
|
"sync"
|
|
|
|
"github.com/pkg/errors"
|
|
"gorm.io/gorm"
|
|
)
|
|
|
|
type Store struct {
|
|
getDatabase func(ctx context.Context) (*gorm.DB, error)
|
|
}
|
|
|
|
func New(db *gorm.DB) *Store {
|
|
return &Store{
|
|
getDatabase: createGetDatabase(db),
|
|
}
|
|
}
|
|
|
|
func (s *Store) Do(ctx context.Context, fn func(db *gorm.DB) error) error {
|
|
db, err := s.getDatabase(ctx)
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
if err := fn(db); err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Store) Tx(ctx context.Context, fn func(db *gorm.DB) error, opts ...*sql.TxOptions) error {
|
|
return errors.WithStack(s.Do(ctx, func(db *gorm.DB) error {
|
|
err := db.Transaction(func(tx *gorm.DB) error {
|
|
if err := fn(tx); err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
return nil
|
|
}, opts...)
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
return nil
|
|
}))
|
|
}
|
|
|
|
func createGetDatabase(db *gorm.DB) func(ctx context.Context) (*gorm.DB, error) {
|
|
var (
|
|
migrateOnce sync.Once
|
|
migrateErr error
|
|
)
|
|
|
|
return func(ctx context.Context) (*gorm.DB, error) {
|
|
migrateOnce.Do(func() {
|
|
if err := db.AutoMigrate(models...); err != nil {
|
|
migrateErr = errors.WithStack(err)
|
|
return
|
|
}
|
|
|
|
if err := db.Exec("PRAGMA journal_mode=WAL; PRAGMA busy_timeout=500; PRAGMA txlock=deferred; PRAGMA synchronous=normal; PRAGMA encoding='UTF-8';").Error; err != nil {
|
|
migrateErr = errors.WithStack(err)
|
|
return
|
|
}
|
|
})
|
|
if migrateErr != nil {
|
|
return nil, errors.WithStack(migrateErr)
|
|
}
|
|
|
|
return db, nil
|
|
}
|
|
}
|