package sqlite

import (
	"context"
	"database/sql"
	"strings"
	"time"

	"github.com/pkg/errors"
	"gitlab.com/wpetit/goweb/logger"
)

// repository wraps a *sql.DB and provides transaction helpers that retry
// when SQLite reports the database as busy.
type repository struct {
	db *sql.DB
	// sqliteBusyRetryMaxAttempts is the maximum number of times withTxRetry
	// runs a transaction that keeps failing with SQLITE_BUSY. Values below 1
	// are treated as 1.
	sqliteBusyRetryMaxAttempts int
}

// withTxRetry runs fn inside a transaction (see withTx) and retries the whole
// transaction when it fails because the SQLite database is busy.
//
// At most r.sqliteBusyRetryMaxAttempts attempts are made, and always at least
// one. Between attempts it sleeps with an exponential backoff: 16ms, 32ms,
// 64ms, and so on. Any error that is not SQLITE_BUSY is returned immediately.
// If every attempt is busy, the last busy error is returned, wrapped.
// If ctx is done while waiting, ctx.Err() is returned.
func (r *repository) withTxRetry(ctx context.Context, fn func(*sql.Tx) error) error {
	// maxBackoffShift bounds the backoff exponent so the computed wait can
	// never overflow time.Duration, whatever the configured attempt count.
	const maxBackoffShift = 20

	maxAttempts := r.sqliteBusyRetryMaxAttempts
	// A non-positive setting must not skip fn: errors.Wrapf(nil, ...) returns
	// nil, which would report success without ever running the transaction.
	if maxAttempts < 1 {
		maxAttempts = 1
	}

	ctx = logger.With(ctx, logger.F("max", maxAttempts))

	// attempts counts the transactions that have failed with SQLITE_BUSY.
	attempts := 0
	for {
		err := r.withTx(ctx, fn)
		if err == nil {
			return nil
		}

		// NOTE(review): busy detection relies on the driver's error message
		// format; confirm this marker if the SQLite driver is changed.
		if !strings.Contains(err.Error(), "(5) (SQLITE_BUSY)") {
			return errors.WithStack(err)
		}

		err = errors.WithStack(err)
		logger.Warn(ctx, "database is busy", logger.E(err))

		attempts++
		if attempts >= maxAttempts {
			// Give up right away: waiting after the final attempt is wasted time.
			logger.Debug(ctx, "transaction retrying failed", logger.F("attempts", attempts))
			return errors.Wrapf(err, "transaction failed after %d attempts", maxAttempts)
		}

		shift := attempts
		if shift > maxBackoffShift {
			shift = maxBackoffShift
		}
		wait := time.Duration(8<<shift) * time.Millisecond

		logger.Debug(
			ctx, "database is busy, waiting before retrying transaction",
			logger.F("wait", wait.String()),
			logger.F("attempts", attempts),
		)

		timer := time.NewTimer(wait)
		select {
		case <-timer.C:
		case <-ctx.Done():
			// Release the timer promptly instead of leaving it to fire.
			timer.Stop()
			return errors.WithStack(ctx.Err())
		}
	}
}

// withTx begins a transaction on r.db, passes it to fn and commits it if fn
// returns nil. If fn returns an error, or panics, the deferred Rollback undoes
// the transaction. After a commit, Rollback returns sql.ErrTxDone, which is
// deliberately ignored. Any other rollback failure is logged rather than
// returned, so the caller still sees the original error from fn or Commit.
func (r *repository) withTx(ctx context.Context, fn func(*sql.Tx) error) error {
	tx, err := r.db.BeginTx(ctx, nil)
	if err != nil {
		return errors.WithStack(err)
	}

	defer func() {
		if err := tx.Rollback(); err != nil {
			// Expected once the transaction has been committed.
			if errors.Is(err, sql.ErrTxDone) {
				return
			}

			err = errors.WithStack(err)
			logger.Error(ctx, "could not rollback transaction", logger.CapturedE(err))
		}
	}()

	if err := fn(tx); err != nil {
		return errors.WithStack(err)
	}

	if err := tx.Commit(); err != nil {
		return errors.WithStack(err)
	}

	return nil
}