100 lines
1.7 KiB
Go
100 lines
1.7 KiB
Go
|
package sqlite
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"database/sql"
|
||
|
"sync"
|
||
|
|
||
|
"github.com/pkg/errors"
|
||
|
"gitlab.com/wpetit/goweb/logger"
|
||
|
)
|
||
|
|
||
|
func withTx(ctx context.Context, db *sql.DB, fn func(tx *sql.Tx) error) error {
|
||
|
var tx *sql.Tx
|
||
|
|
||
|
tx, err := db.Begin()
|
||
|
if err != nil {
|
||
|
return errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
defer func() {
|
||
|
if err := tx.Rollback(); err != nil {
|
||
|
if errors.Is(err, sql.ErrTxDone) {
|
||
|
return
|
||
|
}
|
||
|
|
||
|
panic(errors.WithStack(err))
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
if err = fn(tx); err != nil {
|
||
|
return errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
if err = tx.Commit(); err != nil {
|
||
|
return errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
// getDBFunc returns a *sql.DB handle for the caller to use, or an error if
// the handle could not be obtained. The constructors in this package
// (newGetDBFunc, newGetDBFuncFromDB) build implementations that run an
// initialization function on the handle before returning it.
type getDBFunc func(ctx context.Context) (*sql.DB, error)
|
||
|
|
||
|
func newGetDBFunc(dsn string, initFunc func(ctx context.Context, db *sql.DB) error) getDBFunc {
|
||
|
var (
|
||
|
db *sql.DB
|
||
|
mutex sync.RWMutex
|
||
|
)
|
||
|
|
||
|
return func(ctx context.Context) (*sql.DB, error) {
|
||
|
mutex.RLock()
|
||
|
if db != nil {
|
||
|
defer mutex.RUnlock()
|
||
|
|
||
|
return db, nil
|
||
|
}
|
||
|
|
||
|
mutex.RUnlock()
|
||
|
|
||
|
mutex.Lock()
|
||
|
defer mutex.Unlock()
|
||
|
|
||
|
logger.Debug(ctx, "opening database", logger.F("dsn", dsn))
|
||
|
|
||
|
newDB, err := sql.Open("sqlite", dsn)
|
||
|
if err != nil {
|
||
|
return nil, errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
logger.Debug(ctx, "initializing database")
|
||
|
|
||
|
if err = initFunc(ctx, newDB); err != nil {
|
||
|
return nil, errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
db = newDB
|
||
|
|
||
|
return db, nil
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func newGetDBFuncFromDB(db *sql.DB, initFunc func(ctx context.Context, db *sql.DB) error) getDBFunc {
|
||
|
var err error
|
||
|
|
||
|
initOnce := &sync.Once{}
|
||
|
|
||
|
return func(ctx context.Context) (*sql.DB, error) {
|
||
|
initOnce.Do(func() {
|
||
|
logger.Debug(ctx, "initializing database")
|
||
|
|
||
|
err = initFunc(ctx, db)
|
||
|
})
|
||
|
|
||
|
if err != nil {
|
||
|
return nil, errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
return db, nil
|
||
|
}
|
||
|
}
|