package sqlite

import (
	"context"
	"database/sql"
	"strings"
	"sync"
	"time"

	"github.com/pkg/errors"
	"gitlab.com/wpetit/goweb/logger"
	"modernc.org/sqlite"
	_ "modernc.org/sqlite"
	sqlite3 "modernc.org/sqlite/lib"
)

// Open returns a database handle for path using the "sqlite" driver.
//
// Like sql.Open, it does not necessarily establish a connection; connection
// problems may only surface on first use of the returned handle.
func Open(path string) (*sql.DB, error) {
	db, err := sql.Open("sqlite", path)
	if err != nil {
		return nil, errors.Wrapf(err, "could not open database with path '%s'", path)
	}

	return db, nil
}

// isBusyError reports whether err, or any error it wraps, is an SQLITE_BUSY
// error. Extended result codes (e.g. SQLITE_BUSY_SNAPSHOT) are matched too,
// since their low byte is the primary SQLITE_BUSY code.
func isBusyError(err error) bool {
	if err == nil {
		return false
	}

	var sqlErr *sqlite.Error
	if errors.As(err, &sqlErr) && sqlErr.Code()&0xff == sqlite3.SQLITE_BUSY {
		return true
	}

	// Fallback on the driver's message format, in case the typed error is not
	// reachable through the wrap chain.
	return strings.Contains(err.Error(), "(5) (SQLITE_BUSY)")
}

// WithRetry runs fn in a transaction (see WithTx). When the transaction fails
// because the database is busy (SQLITE_BUSY), WithRetry runs the whole
// transaction again from scratch. WithTx has already rolled back the failed
// attempt by then.
//
// max is the maximum number of attempts; values lower than 1 are treated as
// 1, so fn always runs at least once. Between attempts WithRetry waits with an
// exponential backoff (16ms, 32ms, 64ms, ...). Errors other than SQLITE_BUSY
// are returned immediately. If ctx is done while waiting, ctx.Err() is
// returned.
func WithRetry(ctx context.Context, db *sql.DB, max int, fn func(*sql.Tx) error) error {
	ctx = logger.With(ctx, logger.F("max", max))

	// Without this guard a non-positive max makes no attempt at all. The caller
	// would then get a nil error for work that never ran.
	if max < 1 {
		max = 1
	}

	for attempts := 0; ; attempts++ {
		err := WithTx(ctx, db, fn)
		if err == nil {
			return nil
		}

		if !isBusyError(err) {
			return errors.WithStack(err)
		}

		err = errors.WithStack(err)
		logger.Warn(ctx, "database is busy", logger.E(err))

		// Do not sleep after the last allowed attempt: there is nothing left to
		// wait for.
		if attempts+1 >= max {
			logger.Debug(ctx, "transaction retrying failed", logger.F("attempts", attempts+1))
			return errors.Wrapf(err, "transaction failed after %d attempts", max)
		}

		wait := time.Duration(8<<(attempts+1)) * time.Millisecond

		logger.Debug(
			ctx,
			"database is busy, waiting before retrying transaction",
			logger.F("wait", wait.String()),
			logger.F("attempts", attempts),
		)

		timer := time.NewTimer(wait)
		select {
		case <-timer.C:
		case <-ctx.Done():
			// Release the timer right away instead of leaving it pending until
			// it fires.
			timer.Stop()
			return errors.WithStack(ctx.Err())
		}
	}
}

// WithTx runs fn inside a single database transaction.
//
// The transaction is committed when fn returns nil. It is rolled back when
// fn returns an error or panics.
//
// fn is called exactly once. A SQLITE_BUSY error from fn is returned to the
// caller, not retried on the same transaction. Retrying in place could re-run
// statements that already succeeded, and could spin forever. Use WithRetry
// to retry busy transactions with backoff.
func WithTx(ctx context.Context, db *sql.DB, fn func(tx *sql.Tx) error) error {
	tx, err := db.BeginTx(ctx, nil)
	if err != nil {
		return errors.WithStack(err)
	}

	defer func() {
		// After a successful Commit, Rollback returns sql.ErrTxDone, which is
		// expected. Other rollback failures are logged rather than panicking,
		// which would crash the caller (or mask a panic raised by fn).
		if err := tx.Rollback(); err != nil && !errors.Is(err, sql.ErrTxDone) {
			logger.Error(ctx, "could not rollback transaction", logger.CapturedE(errors.WithStack(err)))
		}
	}()

	if err := fn(tx); err != nil {
		return errors.WithStack(err)
	}

	if err := tx.Commit(); err != nil {
		return errors.WithStack(err)
	}

	return nil
}

// GetDBFunc returns a database handle, initializing it if needed.
type GetDBFunc func(ctx context.Context) (*sql.DB, error)

// NewGetDBFunc returns a GetDBFunc that lazily opens the database described
// by dsn on first use and runs initFunc on it.
//
// The initialized handle is cached and returned by every later call; it is
// safe to call the returned function from several goroutines. If opening or
// initialization fails, the error is returned, the handle is closed, and the
// next call tries again.
func NewGetDBFunc(dsn string, initFunc func(ctx context.Context, db *sql.DB) error) GetDBFunc {
	var (
		db    *sql.DB
		mutex sync.RWMutex
	)

	return func(ctx context.Context) (*sql.DB, error) {
		mutex.RLock()
		current := db
		mutex.RUnlock()

		if current != nil {
			return current, nil
		}

		mutex.Lock()
		defer mutex.Unlock()

		// Another goroutine may have initialized the database between RUnlock
		// and Lock. Reuse its handle instead of opening (and leaking) a second
		// one.
		if db != nil {
			return db, nil
		}

		logger.Debug(ctx, "opening database", logger.F("dsn", dsn))

		newDB, err := sql.Open("sqlite", dsn)
		if err != nil {
			return nil, errors.WithStack(err)
		}

		logger.Debug(ctx, "initializing database")

		if err := initFunc(ctx, newDB); err != nil {
			// The handle is discarded, so close it to release its connection
			// pool.
			if closeErr := newDB.Close(); closeErr != nil {
				logger.Error(ctx, "could not close database", logger.CapturedE(errors.WithStack(closeErr)))
			}

			return nil, errors.WithStack(err)
		}

		db = newDB

		return db, nil
	}
}

// NewGetDBFuncFromDB returns a GetDBFunc that always yields db.
//
// initFunc runs exactly once, during the first call, with that call's ctx.
// Its result is remembered: if initialization fails, every call (including
// later ones) returns that error, and initialization is not retried.
func NewGetDBFuncFromDB(db *sql.DB, initFunc func(ctx context.Context, db *sql.DB) error) GetDBFunc {
	var (
		initOnce sync.Once
		initErr  error
	)

	return func(ctx context.Context) (*sql.DB, error) {
		initOnce.Do(func() {
			logger.Debug(ctx, "initializing database")
			initErr = initFunc(ctx, db)
		})

		if initErr != nil {
			return nil, errors.WithStack(initErr)
		}

		return db, nil
	}
}