Make go get to install work.

This commit is contained in:
Vikram Rangnekar
2020-04-16 00:26:32 -04:00
parent 40c99e9ef3
commit 09d6460a13
99 changed files with 240 additions and 73 deletions

View File

@ -0,0 +1,370 @@
package migrate
import (
"bytes"
"context"
"database/sql"
"fmt"
"io/ioutil"
"os"
"path/filepath"
"regexp"
"strconv"
"strings"
"text/template"
"github.com/pkg/errors"
)
var migrationPattern = regexp.MustCompile(`\A(\d+)_[^\.]+\.sql\z`)
var ErrNoFwMigration = errors.Errorf("no sql in forward migration step")
type BadVersionError string
func (e BadVersionError) Error() string {
return string(e)
}
type IrreversibleMigrationError struct {
m *Migration
}
func (e IrreversibleMigrationError) Error() string {
return fmt.Sprintf("Irreversible migration: %d - %s", e.m.Sequence, e.m.Name)
}
type NoMigrationsFoundError struct {
Path string
}
func (e NoMigrationsFoundError) Error() string {
return fmt.Sprintf("No migrations found at %s", e.Path)
}
type MigrationPgError struct {
Sql string
Error error
}
type Migration struct {
Sequence int32
Name string
UpSQL string
DownSQL string
}
type MigratorOptions struct {
// DisableTx causes the Migrator not to run migrations in a transaction.
DisableTx bool
// MigratorFS is the interface used for collecting the migrations.
MigratorFS MigratorFS
}
type Migrator struct {
db *sql.DB
versionTable string
options *MigratorOptions
Migrations []*Migration
OnStart func(int32, string, string, string) // OnStart is called when a migration is run with the sequence, name, direction, and SQL
Data map[string]interface{} // Data available to use in migrations
}
func NewMigrator(db *sql.DB, versionTable string) (m *Migrator, err error) {
return NewMigratorEx(db, versionTable, &MigratorOptions{MigratorFS: defaultMigratorFS{}})
}
func NewMigratorEx(db *sql.DB, versionTable string, opts *MigratorOptions) (m *Migrator, err error) {
m = &Migrator{db: db, versionTable: versionTable, options: opts}
err = m.ensureSchemaVersionTableExists()
m.Migrations = make([]*Migration, 0)
m.Data = make(map[string]interface{})
return
}
type MigratorFS interface {
ReadDir(dirname string) ([]os.FileInfo, error)
ReadFile(filename string) ([]byte, error)
Glob(pattern string) (matches []string, err error)
}
type defaultMigratorFS struct{}
func (defaultMigratorFS) ReadDir(dirname string) ([]os.FileInfo, error) {
return ioutil.ReadDir(dirname)
}
func (defaultMigratorFS) ReadFile(filename string) ([]byte, error) {
return ioutil.ReadFile(filename)
}
func (defaultMigratorFS) Glob(pattern string) ([]string, error) {
return filepath.Glob(pattern)
}
func FindMigrationsEx(path string, fs MigratorFS) ([]string, error) {
path = strings.TrimRight(path, string(filepath.Separator))
fileInfos, err := fs.ReadDir(path)
if err != nil {
return nil, err
}
paths := make([]string, 0, len(fileInfos))
for _, fi := range fileInfos {
if fi.IsDir() {
continue
}
matches := migrationPattern.FindStringSubmatch(fi.Name())
if len(matches) != 2 {
continue
}
n, err := strconv.ParseInt(matches[1], 10, 32)
if err != nil {
// The regexp already validated that the prefix is all digits so this *should* never fail
return nil, err
}
mcount := len(paths)
if n < int64(mcount) {
return nil, fmt.Errorf("Duplicate migration %d", n)
}
if int64(mcount) < n {
return nil, fmt.Errorf("Missing migration %d", mcount)
}
paths = append(paths, filepath.Join(path, fi.Name()))
}
return paths, nil
}
func FindMigrations(path string) ([]string, error) {
return FindMigrationsEx(path, defaultMigratorFS{})
}
func (m *Migrator) LoadMigrations(path string) error {
path = strings.TrimRight(path, string(filepath.Separator))
mainTmpl := template.New("main")
sharedPaths, err := m.options.MigratorFS.Glob(filepath.Join(path, "*", "*.sql"))
if err != nil {
return err
}
for _, p := range sharedPaths {
body, err := m.options.MigratorFS.ReadFile(p)
if err != nil {
return err
}
name := strings.Replace(p, path+string(filepath.Separator), "", 1)
_, err = mainTmpl.New(name).Parse(string(body))
if err != nil {
return err
}
}
paths, err := FindMigrationsEx(path, m.options.MigratorFS)
if err != nil {
return err
}
if len(paths) == 0 {
return NoMigrationsFoundError{Path: path}
}
for _, p := range paths {
body, err := m.options.MigratorFS.ReadFile(p)
if err != nil {
return err
}
pieces := strings.SplitN(string(body), "---- create above / drop below ----", 2)
var upSQL, downSQL string
upSQL = strings.TrimSpace(pieces[0])
upSQL, err = m.evalMigration(mainTmpl.New(filepath.Base(p)+" up"), upSQL)
if err != nil {
return err
}
// Make sure there is SQL in the forward migration step.
containsSQL := false
for _, v := range strings.Split(upSQL, "\n") {
// Only account for regular single line comment, empty line and space/comment combination
cleanString := strings.TrimSpace(v)
if len(cleanString) != 0 &&
!strings.HasPrefix(cleanString, "--") {
containsSQL = true
break
}
}
if !containsSQL {
return ErrNoFwMigration
}
if len(pieces) == 2 {
downSQL = strings.TrimSpace(pieces[1])
downSQL, err = m.evalMigration(mainTmpl.New(filepath.Base(p)+" down"), downSQL)
if err != nil {
return err
}
}
m.AppendMigration(filepath.Base(p), upSQL, downSQL)
}
return nil
}
func (m *Migrator) evalMigration(tmpl *template.Template, sql string) (string, error) {
tmpl, err := tmpl.Parse(sql)
if err != nil {
return "", err
}
var buf bytes.Buffer
err = tmpl.Execute(&buf, m.Data)
if err != nil {
return "", err
}
return buf.String(), nil
}
func (m *Migrator) AppendMigration(name, upSQL, downSQL string) {
m.Migrations = append(
m.Migrations,
&Migration{
Sequence: int32(len(m.Migrations)) + 1,
Name: name,
UpSQL: upSQL,
DownSQL: downSQL,
})
}
// Migrate runs pending migrations
// It calls m.OnStart when it begins a migration
func (m *Migrator) Migrate() error {
return m.MigrateTo(int32(len(m.Migrations)))
}
// MigrateTo migrates to targetVersion
func (m *Migrator) MigrateTo(targetVersion int32) (err error) {
// Lock to ensure multiple migrations cannot occur simultaneously
lockNum := int64(9628173550095224) // arbitrary random number
if _, lockErr := m.db.Exec("select pg_try_advisory_lock($1)", lockNum); lockErr != nil {
return lockErr
}
defer func() {
_, unlockErr := m.db.Exec("select pg_advisory_unlock($1)", lockNum)
if err == nil && unlockErr != nil {
err = unlockErr
}
}()
currentVersion, err := m.GetCurrentVersion()
if err != nil {
return err
}
if targetVersion < 0 || int32(len(m.Migrations)) < targetVersion {
errMsg := fmt.Sprintf("destination version %d is outside the valid versions of 0 to %d", targetVersion, len(m.Migrations))
return BadVersionError(errMsg)
}
if currentVersion < 0 || int32(len(m.Migrations)) < currentVersion {
errMsg := fmt.Sprintf("current version %d is outside the valid versions of 0 to %d", currentVersion, len(m.Migrations))
return BadVersionError(errMsg)
}
var direction int32
if currentVersion < targetVersion {
direction = 1
} else {
direction = -1
}
for currentVersion != targetVersion {
var current *Migration
var sql, directionName string
var sequence int32
if direction == 1 {
current = m.Migrations[currentVersion]
sequence = current.Sequence
sql = current.UpSQL
directionName = "up"
} else {
current = m.Migrations[currentVersion-1]
sequence = current.Sequence - 1
sql = current.DownSQL
directionName = "down"
if current.DownSQL == "" {
return IrreversibleMigrationError{m: current}
}
}
ctx := context.Background()
tx, err := m.db.BeginTx(ctx, nil)
if err != nil {
return err
}
defer tx.Rollback() //nolint: errcheck
// Fire on start callback
if m.OnStart != nil {
m.OnStart(current.Sequence, current.Name, directionName, sql)
}
// Execute the migration
_, err = tx.Exec(sql)
if err != nil {
// if err, ok := err.(pgx.PgError); ok {
// return MigrationPgError{Sql: sql, PgError: err}
// }
return err
}
// Reset all database connection settings. Important to do before updating version as search_path may have been changed.
// if _, err := tx.Exec(ctx, "reset all"); err != nil {
// return err
// }
// Add one to the version
_, err = tx.Exec("update "+m.versionTable+" set version=$1", sequence)
if err != nil {
return err
}
err = tx.Commit()
if err != nil {
return err
}
currentVersion = currentVersion + direction
}
return nil
}
func (m *Migrator) GetCurrentVersion() (v int32, err error) {
err = m.db.QueryRow("select version from " + m.versionTable).Scan(&v)
return v, err
}
func (m *Migrator) ensureSchemaVersionTableExists() (err error) {
_, err = m.db.Exec(fmt.Sprintf(`
create table if not exists %s(version int4 not null);
insert into %s(version)
select 0
where 0=(select count(*) from %s);
`, m.versionTable, m.versionTable, m.versionTable))
return err
}

View File

@ -0,0 +1,352 @@
package migrate_test
/*
import (
. "gopkg.in/check.v1"
)
type MigrateSuite struct {
conn *pgx.Conn
}
func Test(t *testing.T) { TestingT(t) }
var _ = Suite(&MigrateSuite{})
var versionTable string = "schema_version_non_default"
func (s *MigrateSuite) SetUpTest(c *C) {
var err error
s.conn, err = pgx.Connect(*defaultConnectionParameters)
c.Assert(err, IsNil)
s.cleanupSampleMigrator(c)
}
func (s *MigrateSuite) currentVersion(c *C) int32 {
var n int32
err := s.conn.QueryRow("select version from " + versionTable).Scan(&n)
c.Assert(err, IsNil)
return n
}
func (s *MigrateSuite) Exec(c *C, sql string, arguments ...interface{}) pgx.CommandTag {
commandTag, err := s.conn.Exec(sql, arguments...)
c.Assert(err, IsNil)
return commandTag
}
func (s *MigrateSuite) tableExists(c *C, tableName string) bool {
var exists bool
err := s.conn.QueryRow(
"select exists(select 1 from information_schema.tables where table_catalog=$1 and table_name=$2)",
defaultConnectionParameters.Database,
tableName,
).Scan(&exists)
c.Assert(err, IsNil)
return exists
}
func (s *MigrateSuite) createEmptyMigrator(c *C) *migrate.Migrator {
var err error
m, err := migrate.NewMigrator(s.conn, versionTable)
c.Assert(err, IsNil)
return m
}
func (s *MigrateSuite) createSampleMigrator(c *C) *migrate.Migrator {
m := s.createEmptyMigrator(c)
m.AppendMigration("Create t1", "create table t1(id serial);", "drop table t1;")
m.AppendMigration("Create t2", "create table t2(id serial);", "drop table t2;")
m.AppendMigration("Create t3", "create table t3(id serial);", "drop table t3;")
return m
}
func (s *MigrateSuite) cleanupSampleMigrator(c *C) {
tables := []string{versionTable, "t1", "t2", "t3"}
for _, table := range tables {
s.Exec(c, "drop table if exists "+table)
}
}
func (s *MigrateSuite) TestNewMigrator(c *C) {
var m *migrate.Migrator
var err error
// Initial run
m, err = migrate.NewMigrator(s.conn, versionTable)
c.Assert(err, IsNil)
// Creates version table
schemaVersionExists := s.tableExists(c, versionTable)
c.Assert(schemaVersionExists, Equals, true)
// Succeeds when version table is already created
m, err = migrate.NewMigrator(s.conn, versionTable)
c.Assert(err, IsNil)
initialVersion, err := m.GetCurrentVersion()
c.Assert(err, IsNil)
c.Assert(initialVersion, Equals, int32(0))
}
func (s *MigrateSuite) TestAppendMigration(c *C) {
m := s.createEmptyMigrator(c)
name := "Create t"
upSQL := "create t..."
downSQL := "drop t..."
m.AppendMigration(name, upSQL, downSQL)
c.Assert(len(m.Migrations), Equals, 1)
c.Assert(m.Migrations[0].Name, Equals, name)
c.Assert(m.Migrations[0].UpSQL, Equals, upSQL)
c.Assert(m.Migrations[0].DownSQL, Equals, downSQL)
}
func (s *MigrateSuite) TestLoadMigrationsMissingDirectory(c *C) {
m := s.createEmptyMigrator(c)
err := m.LoadMigrations("testdata/missing")
c.Assert(err, ErrorMatches, "open testdata/missing: no such file or directory")
}
func (s *MigrateSuite) TestLoadMigrationsEmptyDirectory(c *C) {
m := s.createEmptyMigrator(c)
err := m.LoadMigrations("testdata/empty")
c.Assert(err, ErrorMatches, "No migrations found at testdata/empty")
}
func (s *MigrateSuite) TestFindMigrationsWithGaps(c *C) {
_, err := migrate.FindMigrations("testdata/gap")
c.Assert(err, ErrorMatches, "Missing migration 2")
}
func (s *MigrateSuite) TestFindMigrationsWithDuplicate(c *C) {
_, err := migrate.FindMigrations("testdata/duplicate")
c.Assert(err, ErrorMatches, "Duplicate migration 2")
}
func (s *MigrateSuite) TestLoadMigrations(c *C) {
m := s.createEmptyMigrator(c)
m.Data = map[string]interface{}{"prefix": "foo"}
err := m.LoadMigrations("testdata/sample")
c.Assert(err, IsNil)
c.Assert(m.Migrations, HasLen, 5)
c.Check(m.Migrations[0].Name, Equals, "001_create_t1.sql")
c.Check(m.Migrations[0].UpSQL, Equals, `create table t1(
id serial primary key
);`)
c.Check(m.Migrations[0].DownSQL, Equals, "drop table t1;")
c.Check(m.Migrations[1].Name, Equals, "002_create_t2.sql")
c.Check(m.Migrations[1].UpSQL, Equals, `create table t2(
id serial primary key
);`)
c.Check(m.Migrations[1].DownSQL, Equals, "drop table t2;")
c.Check(m.Migrations[2].Name, Equals, "003_irreversible.sql")
c.Check(m.Migrations[2].UpSQL, Equals, "drop table t2;")
c.Check(m.Migrations[2].DownSQL, Equals, "")
c.Check(m.Migrations[3].Name, Equals, "004_data_interpolation.sql")
c.Check(m.Migrations[3].UpSQL, Equals, "create table foo_bar(id serial primary key);")
c.Check(m.Migrations[3].DownSQL, Equals, "drop table foo_bar;")
}
func (s *MigrateSuite) TestLoadMigrationsNoForward(c *C) {
var err error
m, err := migrate.NewMigrator(s.conn, versionTable)
c.Assert(err, IsNil)
m.Data = map[string]interface{}{"prefix": "foo"}
err = m.LoadMigrations("testdata/noforward")
c.Assert(err, Equals, migrate.ErrNoFwMigration)
}
func (s *MigrateSuite) TestMigrate(c *C) {
m := s.createSampleMigrator(c)
err := m.Migrate()
c.Assert(err, IsNil)
currentVersion := s.currentVersion(c)
c.Assert(currentVersion, Equals, int32(3))
}
func (s *MigrateSuite) TestMigrateToLifeCycle(c *C) {
m := s.createSampleMigrator(c)
var onStartCallUpCount int
var onStartCallDownCount int
m.OnStart = func(_ int32, _, direction, _ string) {
switch direction {
case "up":
onStartCallUpCount++
case "down":
onStartCallDownCount++
default:
c.Fatalf("Unexpected direction: %s", direction)
}
}
// Migrate from 0 up to 1
err := m.MigrateTo(1)
c.Assert(err, IsNil)
currentVersion := s.currentVersion(c)
c.Assert(currentVersion, Equals, int32(1))
c.Assert(s.tableExists(c, "t1"), Equals, true)
c.Assert(s.tableExists(c, "t2"), Equals, false)
c.Assert(s.tableExists(c, "t3"), Equals, false)
c.Assert(onStartCallUpCount, Equals, 1)
c.Assert(onStartCallDownCount, Equals, 0)
// Migrate from 1 up to 3
err = m.MigrateTo(3)
c.Assert(err, IsNil)
currentVersion = s.currentVersion(c)
c.Assert(currentVersion, Equals, int32(3))
c.Assert(s.tableExists(c, "t1"), Equals, true)
c.Assert(s.tableExists(c, "t2"), Equals, true)
c.Assert(s.tableExists(c, "t3"), Equals, true)
c.Assert(onStartCallUpCount, Equals, 3)
c.Assert(onStartCallDownCount, Equals, 0)
// Migrate from 3 to 3 is no-op
err = m.MigrateTo(3)
c.Assert(err, IsNil)
currentVersion = s.currentVersion(c)
c.Assert(s.tableExists(c, "t1"), Equals, true)
c.Assert(s.tableExists(c, "t2"), Equals, true)
c.Assert(s.tableExists(c, "t3"), Equals, true)
c.Assert(onStartCallUpCount, Equals, 3)
c.Assert(onStartCallDownCount, Equals, 0)
// Migrate from 3 down to 1
err = m.MigrateTo(1)
c.Assert(err, IsNil)
currentVersion = s.currentVersion(c)
c.Assert(currentVersion, Equals, int32(1))
c.Assert(s.tableExists(c, "t1"), Equals, true)
c.Assert(s.tableExists(c, "t2"), Equals, false)
c.Assert(s.tableExists(c, "t3"), Equals, false)
c.Assert(onStartCallUpCount, Equals, 3)
c.Assert(onStartCallDownCount, Equals, 2)
// Migrate from 1 down to 0
err = m.MigrateTo(0)
c.Assert(err, IsNil)
currentVersion = s.currentVersion(c)
c.Assert(currentVersion, Equals, int32(0))
c.Assert(s.tableExists(c, "t1"), Equals, false)
c.Assert(s.tableExists(c, "t2"), Equals, false)
c.Assert(s.tableExists(c, "t3"), Equals, false)
c.Assert(onStartCallUpCount, Equals, 3)
c.Assert(onStartCallDownCount, Equals, 3)
// Migrate back up to 3
err = m.MigrateTo(3)
c.Assert(err, IsNil)
currentVersion = s.currentVersion(c)
c.Assert(currentVersion, Equals, int32(3))
c.Assert(s.tableExists(c, "t1"), Equals, true)
c.Assert(s.tableExists(c, "t2"), Equals, true)
c.Assert(s.tableExists(c, "t3"), Equals, true)
c.Assert(onStartCallUpCount, Equals, 6)
c.Assert(onStartCallDownCount, Equals, 3)
}
func (s *MigrateSuite) TestMigrateToBoundaries(c *C) {
m := s.createSampleMigrator(c)
// Migrate to -1 is error
err := m.MigrateTo(-1)
c.Assert(err, ErrorMatches, "destination version -1 is outside the valid versions of 0 to 3")
// Migrate past end is error
err = m.MigrateTo(int32(len(m.Migrations)) + 1)
c.Assert(err, ErrorMatches, "destination version 4 is outside the valid versions of 0 to 3")
// When schema version says it is negative
s.Exec(c, "update "+versionTable+" set version=-1")
err = m.MigrateTo(int32(1))
c.Assert(err, ErrorMatches, "current version -1 is outside the valid versions of 0 to 3")
// When schema version says it is negative
s.Exec(c, "update "+versionTable+" set version=4")
err = m.MigrateTo(int32(1))
c.Assert(err, ErrorMatches, "current version 4 is outside the valid versions of 0 to 3")
}
func (s *MigrateSuite) TestMigrateToIrreversible(c *C) {
m := s.createEmptyMigrator(c)
m.AppendMigration("Foo", "drop table if exists t3", "")
err := m.MigrateTo(1)
c.Assert(err, IsNil)
err = m.MigrateTo(0)
c.Assert(err, ErrorMatches, "Irreversible migration: 1 - Foo")
}
func (s *MigrateSuite) TestMigrateToDisableTx(c *C) {
m, err := migrate.NewMigratorEx(s.conn, versionTable, &migrate.MigratorOptions{DisableTx: true})
c.Assert(err, IsNil)
m.AppendMigration("Create t1", "create table t1(id serial);", "drop table t1;")
m.AppendMigration("Create t2", "create table t2(id serial);", "drop table t2;")
m.AppendMigration("Create t3", "create table t3(id serial);", "drop table t3;")
tx, err := s.conn.Begin()
c.Assert(err, IsNil)
err = m.MigrateTo(3)
c.Assert(err, IsNil)
currentVersion := s.currentVersion(c)
c.Assert(currentVersion, Equals, int32(3))
c.Assert(s.tableExists(c, "t1"), Equals, true)
c.Assert(s.tableExists(c, "t2"), Equals, true)
c.Assert(s.tableExists(c, "t3"), Equals, true)
err = tx.Rollback()
c.Assert(err, IsNil)
currentVersion = s.currentVersion(c)
c.Assert(currentVersion, Equals, int32(0))
c.Assert(s.tableExists(c, "t1"), Equals, false)
c.Assert(s.tableExists(c, "t2"), Equals, false)
c.Assert(s.tableExists(c, "t3"), Equals, false)
}
func Example_OnStartMigrationProgressLogging() {
conn, err := pgx.Connect(*defaultConnectionParameters)
if err != nil {
fmt.Printf("Unable to establish connection: %v", err)
return
}
// Clear any previous runs
if _, err = conn.Exec("drop table if exists schema_version"); err != nil {
fmt.Printf("Unable to drop schema_version table: %v", err)
return
}
var m *migrate.Migrator
m, err = migrate.NewMigrator(conn, "schema_version")
if err != nil {
fmt.Printf("Unable to create migrator: %v", err)
return
}
m.OnStart = func(_ int32, name, direction, _ string) {
fmt.Printf("Migrating %s: %s", direction, name)
}
m.AppendMigration("create a table", "create temporary table foo(id serial primary key)", "")
if err = m.Migrate(); err != nil {
fmt.Printf("Unexpected failure migrating: %v", err)
return
}
// Output:
// Migrating up: create a table
}
*/

View File

@ -0,0 +1,7 @@
create table t1(
id serial primary key
);
---- create above / drop below ----
drop table t1;

View File

@ -0,0 +1,7 @@
create table t2(
id serial primary key
);
---- create above / drop below ----
drop table t2;

View File

@ -0,0 +1,7 @@
create table duplicate(
id serial primary key
);
---- create above / drop below ----
drop table duplicate;

View File

@ -0,0 +1,7 @@
create table t1(
id serial primary key
);
---- create above / drop below ----
drop table t1;

View File

@ -0,0 +1 @@
drop table t2;

View File

@ -0,0 +1,7 @@
-- no SQL here
-- nor here, just all comments.
-- comment with space before
---- create above / drop below ----
drop table t1;

View File

@ -0,0 +1,7 @@
create table t1(
id serial primary key
);
---- create above / drop below ----
drop table t1;

View File

@ -0,0 +1,7 @@
create table t2(
id serial primary key
);
---- create above / drop below ----
drop table t2;

View File

@ -0,0 +1 @@
drop table t2;

View File

@ -0,0 +1,5 @@
create table {{.prefix}}_bar(id serial primary key);
---- create above / drop below ----
drop table {{.prefix}}_bar;

View File

@ -0,0 +1,5 @@
{{ template "shared/v1_001.sql" . }}
---- create above / drop below ----
drop view {{.prefix}}v1;

View File

@ -0,0 +1 @@
create view {{.prefix}}v1 as select * from t1;

View File

@ -0,0 +1 @@
-- This file should be ignored because it does not start with a number.