package orm import ( "context" "github.com/pkg/errors" ) var ( ErrNoAvailableMigration = errors.New("no available migration") ErrMigrationNotFound = errors.New("migration not found") ) type MigrationManager struct { migrations []Migration resolver VersionResolver } func (m *MigrationManager) Up(ctx context.Context) error { currentVersion, err := m.resolver.Current(ctx) if err != nil { return errors.Wrap(err, "could not retrieve current version") } migrate := func(up Migration) error { if err := up.Up(ctx); err != nil { return errors.Wrapf(err, "could not apply '%s' up migration", up.Version()) } if err := m.resolver.Set(ctx, up.Version()); err != nil { return errors.Wrapf(err, "could not update schema version to '%s'", up.Version()) } return nil } if currentVersion == "" { up := m.migrations[0] return migrate(up) } for i, mi := range m.migrations { if mi.Version() != currentVersion && currentVersion != "" { continue } // Already at latest, do nothing if i >= len(m.migrations)-1 { return nil } up := m.migrations[i+1] return migrate(up) } return errors.WithStack(ErrMigrationNotFound) } func (m *MigrationManager) Down(ctx context.Context) error { currentVersion, err := m.resolver.Current(ctx) if err != nil { return errors.Wrap(err, "could not retrieve current version") } for i, mi := range m.migrations { if mi.Version() != currentVersion { continue } if err := mi.Down(ctx); err != nil { return errors.Wrapf(err, "could not apply '%s' down migration", mi.Version()) } var version string // Already at oldest, do nothing if i != 0 { down := m.migrations[i-1] version = down.Version() } if err := m.resolver.Set(ctx, version); err != nil { return errors.Wrapf(err, "could not update schema version to '%s'", version) } return nil } return errors.WithStack(ErrMigrationNotFound) } func (m *MigrationManager) Latest(ctx context.Context) error { for { isLatest, err := m.IsLatest(ctx) if err != nil { return errors.Wrap(err, "could not retrieve schema state") } if isLatest { return nil } if err := m.Up(ctx); err != nil { return errors.WithStack(err) } } } func (m *MigrationManager) Register(migrations ...Migration) { m.migrations = migrations } func (m *MigrationManager) CurrentVersion(ctx context.Context) (string, error) { return m.resolver.Current(ctx) } func (m *MigrationManager) LatestVersion() (string, error) { if len(m.migrations) == 0 { return "", errors.WithStack(ErrNoAvailableMigration) } return m.migrations[len(m.migrations)-1].Version(), nil } func (m *MigrationManager) IsLatest(ctx context.Context) (bool, error) { currentVersion, err := m.resolver.Current(ctx) if err != nil { return false, errors.Wrap(err, "could not retrieve current version") } latestVersion, err := m.LatestVersion() if err != nil { return false, errors.Wrap(err, "could not retrieve latest version") } return currentVersion == latestVersion, nil } func NewMigrationManager(resolver VersionResolver) *MigrationManager { return &MigrationManager{ resolver: resolver, migrations: make([]Migration, 0), } }