package sqlite import ( "context" "database/sql" "sync" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" ) func withTx(ctx context.Context, db *sql.DB, fn func(tx *sql.Tx) error) error { var tx *sql.Tx tx, err := db.Begin() if err != nil { return errors.WithStack(err) } defer func() { if err := tx.Rollback(); err != nil { if errors.Is(err, sql.ErrTxDone) { return } panic(errors.WithStack(err)) } }() if err = fn(tx); err != nil { return errors.WithStack(err) } if err = tx.Commit(); err != nil { return errors.WithStack(err) } return nil } type getDBFunc func(ctx context.Context) (*sql.DB, error) func newGetDBFunc(dsn string, initFunc func(ctx context.Context, db *sql.DB) error) getDBFunc { var ( db *sql.DB mutex sync.RWMutex ) return func(ctx context.Context) (*sql.DB, error) { mutex.RLock() if db != nil { defer mutex.RUnlock() return db, nil } mutex.RUnlock() mutex.Lock() defer mutex.Unlock() logger.Debug(ctx, "opening database", logger.F("dsn", dsn)) newDB, err := sql.Open("sqlite", dsn) if err != nil { return nil, errors.WithStack(err) } logger.Debug(ctx, "initializing database") if err = initFunc(ctx, newDB); err != nil { return nil, errors.WithStack(err) } db = newDB return db, nil } } func newGetDBFuncFromDB(db *sql.DB, initFunc func(ctx context.Context, db *sql.DB) error) getDBFunc { var err error initOnce := &sync.Once{} return func(ctx context.Context) (*sql.DB, error) { initOnce.Do(func() { logger.Debug(ctx, "initializing database") err = initFunc(ctx, db) }) if err != nil { return nil, errors.WithStack(err) } return db, nil } }