48 lines
988 B
Go
48 lines
988 B
Go
|
package orm
|
||
|
|
||
|
import (
|
||
|
"context"
|
||
|
"database/sql"
|
||
|
|
||
|
"github.com/jinzhu/gorm"
|
||
|
"github.com/pkg/errors"
|
||
|
)
|
||
|
|
||
|
func WithTx(ctx context.Context, db *gorm.DB, fn func(context.Context, *gorm.DB) error) error {
|
||
|
tx := db.BeginTx(ctx, &sql.TxOptions{})
|
||
|
|
||
|
defer func() {
|
||
|
if err := tx.Rollback().Error; err != nil && !isGormError(err, gorm.ErrInvalidTransaction) {
|
||
|
panic(errors.Wrap(err, "could not rollback transaction"))
|
||
|
}
|
||
|
}()
|
||
|
|
||
|
if err := fn(ctx, tx); err != nil {
|
||
|
err := errors.Wrap(err, "could not apply down migration")
|
||
|
|
||
|
if rollbackErr := tx.Rollback().Error; rollbackErr != nil {
|
||
|
return errors.Wrap(err, rollbackErr.Error())
|
||
|
}
|
||
|
|
||
|
return err
|
||
|
}
|
||
|
|
||
|
if err := tx.Commit().Error; err != nil {
|
||
|
return errors.Wrap(err, "could not commit transaction")
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func isGormError(err error, compErr error) bool {
|
||
|
if errs, ok := err.(gorm.Errors); ok {
|
||
|
for _, err := range errs {
|
||
|
if errors.Is(err, compErr) {
|
||
|
return true
|
||
|
}
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return errors.Is(err, compErr)
|
||
|
}
|