// Package sqlite provides helpers for working with SQLite databases through
// database/sql, using the modernc.org/sqlite driver registered as "sqlite".
package sqlite

import (
	"context"
	"database/sql"
	"sync"
	"time"

	"github.com/pkg/errors"
	"gitlab.com/wpetit/goweb/logger"
	// A named import runs the package's init exactly like a blank import
	// would, so a separate `_ "modernc.org/sqlite"` import is redundant.
	"modernc.org/sqlite"
	sqlite3 "modernc.org/sqlite/lib"
)

// Bounds of the exponential backoff WithTx waits between attempts when the
// database reports it is busy.
const (
	txBusyRetryMinDelay = 10 * time.Millisecond
	txBusyRetryMaxDelay = 500 * time.Millisecond
)

// Open returns a *sql.DB for the SQLite database at path, using the "sqlite"
// driver. Like sql.Open, it may only validate its arguments; connection
// errors can surface on first use of the returned handle.
func Open(path string) (*sql.DB, error) {
	db, err := sql.Open("sqlite", path)
	if err != nil {
		return nil, errors.Wrapf(err, "could not open database with path '%s'", path)
	}

	return db, nil
}

// WithTx runs fn inside a transaction on db and commits it if fn returns nil.
//
// If fn fails with an SQLITE_BUSY error (including extended busy codes), the
// transaction is rolled back and fn is retried in a brand-new transaction
// after an exponentially growing delay. Retries continue until fn succeeds,
// fails with a non-busy error, or ctx is done. Because fn may therefore run
// several times, it should not have side effects outside of tx.
//
// Errors from BeginTx and Commit are returned without retrying. When ctx is
// done while waiting to retry, ctx.Err() is returned.
func WithTx(ctx context.Context, db *sql.DB, fn func(tx *sql.Tx) error) error {
	delay := txBusyRetryMinDelay

	for {
		retry, err := tryTx(ctx, db, fn)
		if !retry {
			return err
		}

		logger.Warn(ctx, "database busy, retrying transaction")

		if err := waitBeforeTxRetry(ctx, delay); err != nil {
			logger.Error(ctx, "could not execute transaction", logger.E(errors.WithStack(err)))
			return errors.WithStack(err)
		}

		delay *= 2
		if delay > txBusyRetryMaxDelay {
			delay = txBusyRetryMaxDelay
		}
	}
}

// tryTx performs a single attempt for WithTx: begin, run fn, commit. The
// transaction is always rolled back on return unless it was committed, so a
// busy attempt releases its locks before the caller waits. retry is true only
// when fn itself failed with a busy error.
func tryTx(ctx context.Context, db *sql.DB, fn func(tx *sql.Tx) error) (retry bool, err error) {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return false, errors.WithStack(err)
	}

	defer rollbackTx(ctx, tx)

	if err := fn(tx); err != nil {
		return isBusyError(err), errors.WithStack(err)
	}

	if err := tx.Commit(); err != nil {
		return false, errors.WithStack(err)
	}

	return false, nil
}

// rollbackTx rolls tx back, ignoring sql.ErrTxDone (already committed or
// rolled back). Any other failure is logged rather than panicking, so a
// rollback problem cannot crash the caller's process.
func rollbackTx(ctx context.Context, tx *sql.Tx) {
	err := tx.Rollback()
	if err == nil || errors.Is(err, sql.ErrTxDone) {
		return
	}

	logger.Error(ctx, "could not rollback transaction", logger.E(errors.WithStack(err)))
}

// isBusyError reports whether err wraps a *sqlite.Error whose primary result
// code is SQLITE_BUSY.
func isBusyError(err error) bool {
	var sqlErr *sqlite.Error
	if !errors.As(err, &sqlErr) {
		return false
	}

	// Code() may carry an extended result code (e.g. SQLITE_BUSY_SNAPSHOT);
	// the primary code is its low byte.
	return sqlErr.Code()&0xff == sqlite3.SQLITE_BUSY
}

// waitBeforeTxRetry blocks for delay, returning ctx.Err() early if ctx is
// (or becomes) done.
func waitBeforeTxRetry(ctx context.Context, delay time.Duration) error {
	if err := ctx.Err(); err != nil {
		return err
	}

	timer := time.NewTimer(delay)
	defer timer.Stop()

	select {
	case <-ctx.Done():
		return ctx.Err()
	case <-timer.C:
		return nil
	}
}

// GetDBFunc returns a ready-to-use database handle.
type GetDBFunc func(ctx context.Context) (*sql.DB, error)

// NewGetDBFunc returns a GetDBFunc that lazily opens the database at dsn with
// the "sqlite" driver and runs initFunc on it the first time it succeeds.
// The initialized handle is cached and shared by all later calls. If opening
// or initFunc fails, nothing is cached, so the next call tries again. It is
// safe for concurrent use.
func NewGetDBFunc(dsn string, initFunc func(ctx context.Context, db *sql.DB) error) GetDBFunc {
	var (
		db    *sql.DB
		mutex sync.RWMutex
	)

	return func(ctx context.Context) (*sql.DB, error) {
		mutex.RLock()
		current := db
		mutex.RUnlock()

		if current != nil {
			return current, nil
		}

		mutex.Lock()
		defer mutex.Unlock()

		// Another caller may have initialized the database while we were
		// waiting for the write lock; reuse it instead of opening a second one.
		if db != nil {
			return db, nil
		}

		logger.Debug(ctx, "opening database", logger.F("dsn", dsn))

		newDB, err := sql.Open("sqlite", dsn)
		if err != nil {
			return nil, errors.WithStack(err)
		}

		logger.Debug(ctx, "initializing database")

		if err := initFunc(ctx, newDB); err != nil {
			// The handle is not cached on failure; close it so it does not leak.
			if closeErr := newDB.Close(); closeErr != nil {
				logger.Error(ctx, "could not close database", logger.E(errors.WithStack(closeErr)))
			}

			return nil, errors.WithStack(err)
		}

		db = newDB

		return db, nil
	}
}

// NewGetDBFuncFromDB returns a GetDBFunc that always yields db, running
// initFunc on it exactly once, using the context of the first call. The
// result of that single initialization is remembered: if initFunc failed,
// every call (including later ones) returns that error.
func NewGetDBFuncFromDB(db *sql.DB, initFunc func(ctx context.Context, db *sql.DB) error) GetDBFunc {
	var (
		initErr  error
		initOnce sync.Once
	)

	return func(ctx context.Context) (*sql.DB, error) {
		initOnce.Do(func() {
			logger.Debug(ctx, "initializing database")
			initErr = initFunc(ctx, db)
		})

		if initErr != nil {
			return nil, errors.WithStack(initErr)
		}

		return db, nil
	}
}