package orm

import (
	"context"
	"database/sql"

	"github.com/jinzhu/gorm"
	"github.com/pkg/errors"
)

// WithTx runs fn inside a transaction begun on db with ctx and default
// sql.TxOptions, passing fn the same ctx and the transaction handle.
//
// If fn returns nil the transaction is committed. If fn returns an error the
// transaction is rolled back and fn's error is returned, wrapped; fn's error
// stays the cause even when the rollback itself also fails (the rollback
// error is folded into the message). If fn panics, the transaction is rolled
// back on a best-effort basis and the panic keeps propagating.
//
// Failures to begin or commit the transaction are returned wrapped; WithTx
// never panics on its own.
func WithTx(ctx context.Context, db *gorm.DB, fn func(context.Context, *gorm.DB) error) error {
	tx := db.BeginTx(ctx, &sql.TxOptions{})
	if tx.Error != nil {
		// gorm reports Begin failures through tx.Error rather than a return
		// value; fn must not run on a transaction that never started.
		return errors.Wrap(tx.Error, "could not begin transaction")
	}

	// finished becomes true once fn has returned normally. If the deferred
	// func still sees false, fn panicked (or called runtime.Goexit) and the
	// transaction has to be rolled back before unwinding continues. On the
	// normal paths rollback/commit are handled explicitly below, so the defer
	// never touches a transaction that was already committed or rolled back
	// (gorm accumulates errors on tx.Error, so a second call would resurface
	// the earlier failure).
	finished := false
	defer func() {
		if !finished {
			// Best effort: with a panic in flight there is no error return to
			// report a rollback failure through, and panicking here would mask
			// the original panic.
			tx.Rollback()
		}
	}()

	fnErr := fn(ctx, tx)
	finished = true

	if fnErr != nil {
		err := errors.Wrap(fnErr, "could not run transaction")
		if rbErr := tx.Rollback().Error; rbErr != nil {
			return errors.Wrapf(err, "could not rollback transaction: %v", rbErr)
		}
		return err
	}

	if err := tx.Commit().Error; err != nil {
		return errors.Wrap(err, "could not commit transaction")
	}
	return nil
}

// isGormError reports whether err is compErr. If err is, or wraps, a
// gorm.Errors aggregate, it also reports true when any aggregated element is
// compErr. The elements are checked one by one because errors.Is does not
// look inside a gorm.Errors slice (NOTE(review): assumes gorm.Errors exposes
// no Unwrap method in the vendored gorm version — confirm).
func isGormError(err error, compErr error) bool {
	var errs gorm.Errors
	if errors.As(err, &errs) {
		for _, e := range errs {
			if errors.Is(e, compErr) {
				return true
			}
		}
	}
	return errors.Is(err, compErr)
}