edge/pkg/storage/sqlite/sql.go

100 lines
1.7 KiB
Go
Raw Permalink Normal View History

2023-02-09 12:16:36 +01:00
package sqlite
import (
"context"
"database/sql"
"sync"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
)
// withTx runs fn inside a single transaction started on db.
//
// The transaction is bound to ctx via db.BeginTx: if ctx is cancelled before
// the commit, database/sql rolls the transaction back and Commit fails.
// When fn returns nil the transaction is committed. When fn returns an error
// (or panics) the transaction is rolled back. A rollback failure is reported
// through the returned error, attached to the error that triggered the
// rollback, rather than panicking.
func withTx(ctx context.Context, db *sql.DB, fn func(tx *sql.Tx) error) (err error) {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return errors.WithStack(err)
	}

	defer func() {
		// Once the transaction has been committed (or already rolled back),
		// Rollback returns sql.ErrTxDone; that is the success path, not a failure.
		rollbackErr := tx.Rollback()
		if rollbackErr == nil || errors.Is(rollbackErr, sql.ErrTxDone) {
			return
		}

		// Reached on failure paths only; err holds the cause, if any (it is
		// still nil when fn panicked).
		if err == nil {
			err = errors.WithStack(rollbackErr)
			return
		}

		err = errors.Wrapf(err, "could not rollback transaction: %v", rollbackErr)
	}()

	if err = fn(tx); err != nil {
		return errors.WithStack(err)
	}

	if err = tx.Commit(); err != nil {
		return errors.WithStack(err)
	}

	return nil
}
// getDBFunc returns the *sql.DB to use, or an error if it could not be
// obtained. The constructors newGetDBFunc and newGetDBFuncFromDB build
// implementations that run an initialization function on first use and then
// hand back the same handle on later calls.
type getDBFunc func(ctx context.Context) (*sql.DB, error)
// newGetDBFunc returns a getDBFunc that lazily opens the "sqlite" database
// identified by dsn and runs initFunc on it the first time it is called.
//
// The resulting *sql.DB is cached and returned by every subsequent call.
// If opening or initialization fails, the handle is closed, nothing is
// cached, and the next call tries again. The returned function may be
// called from multiple goroutines; at most one database handle is cached.
func newGetDBFunc(dsn string, initFunc func(ctx context.Context, db *sql.DB) error) getDBFunc {
	var (
		db    *sql.DB
		mutex sync.RWMutex
	)

	return func(ctx context.Context) (*sql.DB, error) {
		// Fast path: the database is already opened and initialized.
		mutex.RLock()
		if db != nil {
			defer mutex.RUnlock()
			return db, nil
		}
		mutex.RUnlock()

		mutex.Lock()
		defer mutex.Unlock()

		// Another goroutine may have finished initialization between RUnlock
		// and Lock; reuse its handle instead of opening (and leaking) a second one.
		if db != nil {
			return db, nil
		}

		logger.Debug(ctx, "opening database", logger.F("dsn", dsn))

		newDB, err := sql.Open("sqlite", dsn)
		if err != nil {
			return nil, errors.WithStack(err)
		}

		logger.Debug(ctx, "initializing database")

		if err := initFunc(ctx, newDB); err != nil {
			// Release the handle so a failed initialization does not leak it.
			// The close error is deliberately dropped: the initialization error
			// is the one the caller needs to see.
			_ = newDB.Close()

			return nil, errors.WithStack(err)
		}

		db = newDB

		return db, nil
	}
}
// newGetDBFuncFromDB returns a getDBFunc that always hands back the given db,
// running initFunc on it the first time it is called.
//
// Initialization happens at most once successfully. If initFunc fails (for
// example because the first caller's ctx was cancelled), the error is
// returned and the next call retries, matching newGetDBFunc. It does not
// cache the failure forever the way sync.Once would. The returned function
// may be called from multiple goroutines.
//
// NOTE(review): retrying re-runs initFunc on the same db, so initFunc is
// assumed to be safe to re-run after a partial failure — confirm for callers.
func newGetDBFuncFromDB(db *sql.DB, initFunc func(ctx context.Context, db *sql.DB) error) getDBFunc {
	var (
		initialized bool
		mutex       sync.RWMutex
	)

	return func(ctx context.Context) (*sql.DB, error) {
		// Fast path: initialization already succeeded.
		mutex.RLock()
		done := initialized
		mutex.RUnlock()

		if done {
			return db, nil
		}

		mutex.Lock()
		defer mutex.Unlock()

		// Another goroutine may have completed initialization while we waited.
		if initialized {
			return db, nil
		}

		logger.Debug(ctx, "initializing database")

		if err := initFunc(ctx, db); err != nil {
			return nil, errors.WithStack(err)
		}

		initialized = true

		return db, nil
	}
}