Make go get install work.
internal/serv/internal/auth/auth.go (new file, 127 lines)
@@ -0,0 +1,127 @@
package auth

import (
    "context"
    "fmt"
    "net/http"

    "github.com/dosco/super-graph/core"
)

// Auth struct contains authentication related config values used by the Super Graph service
type Auth struct {
    Name          string
    Type          string
    Cookie        string
    CredsInHeader bool `mapstructure:"creds_in_header"`

    Rails struct {
        Version       string
        SecretKeyBase string `mapstructure:"secret_key_base"`
        URL           string
        Password      string
        MaxIdle       int `mapstructure:"max_idle"`
        MaxActive     int `mapstructure:"max_active"`
        Salt          string
        SignSalt      string `mapstructure:"sign_salt"`
        AuthSalt      string `mapstructure:"auth_salt"`
    }

    JWT struct {
        Provider   string
        Secret     string
        PubKeyFile string `mapstructure:"public_key_file"`
        PubKeyType string `mapstructure:"public_key_type"`
    }

    Header struct {
        Name   string
        Value  string
        Exists bool
    }
}

func SimpleHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
    return func(w http.ResponseWriter, r *http.Request) {
        ctx := r.Context()

        userIDProvider := r.Header.Get("X-User-ID-Provider")
        if len(userIDProvider) != 0 {
            ctx = context.WithValue(ctx, core.UserIDProviderKey, userIDProvider)
        }

        userID := r.Header.Get("X-User-ID")
        if len(userID) != 0 {
            ctx = context.WithValue(ctx, core.UserIDKey, userID)
        }

        userRole := r.Header.Get("X-User-Role")
        if len(userRole) != 0 {
            ctx = context.WithValue(ctx, core.UserRoleKey, userRole)
        }

        next.ServeHTTP(w, r.WithContext(ctx))
    }, nil
}

func HeaderHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
    hdr := ac.Header

    if len(hdr.Name) == 0 {
        return nil, fmt.Errorf("auth '%s': no header.name defined", ac.Name)
    }

    if !hdr.Exists && len(hdr.Value) == 0 {
        return nil, fmt.Errorf("auth '%s': no header.value defined", ac.Name)
    }

    return func(w http.ResponseWriter, r *http.Request) {
        var fo1 bool
        value := r.Header.Get(hdr.Name)

        switch {
        case hdr.Exists:
            fo1 = (len(value) == 0)

        default:
            fo1 = (value != hdr.Value)
        }

        if fo1 {
            http.Error(w, "401 unauthorized", http.StatusUnauthorized)
            return
        }

        next.ServeHTTP(w, r)
    }, nil
}

func WithAuth(next http.Handler, ac *Auth) (http.Handler, error) {
    var err error

    if ac.CredsInHeader {
        next, err = SimpleHandler(ac, next)
    }

    if err != nil {
        return nil, err
    }

    switch ac.Type {
    case "rails":
        return RailsHandler(ac, next)

    case "jwt":
        return JwtHandler(ac, next)

    case "header":
        return HeaderHandler(ac, next)
    }

    return next, nil
}

func IsAuth(ct context.Context) bool {
    return ct.Value(core.UserIDKey) != nil
}
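Usage sketch (not part of the commit): one way the WithAuth wrapper above could be wired into a plain net/http server from inside this module, assuming the package is imported as auth; the header name and value are illustrative placeholders.

package main

import (
    "net/http"

    "github.com/dosco/super-graph/internal/serv/internal/auth"
)

func main() {
    // Protect a route with the header-based auth handler.
    // "X-API-Key" and its value are placeholders, not config the commit defines.
    ac := &auth.Auth{Name: "api", Type: "header"}
    ac.Header.Name = "X-API-Key"
    ac.Header.Value = "example-secret"

    api := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        w.Write([]byte("ok"))
    })

    // WithAuth picks the concrete handler based on ac.Type.
    h, err := auth.WithAuth(api, ac)
    if err != nil {
        panic(err)
    }

    http.ListenAndServe(":8080", h)
}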
internal/serv/internal/auth/jwt.go (new file, 105 lines)
@@ -0,0 +1,105 @@
package auth

import (
    "context"
    "io/ioutil"
    "net/http"
    "strings"

    jwt "github.com/dgrijalva/jwt-go"
    "github.com/dosco/super-graph/core"
)

const (
    authHeader = "Authorization"
    jwtAuth0   int = iota + 1
)

func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
    var key interface{}
    var jwtProvider int

    cookie := ac.Cookie

    if ac.JWT.Provider == "auth0" {
        jwtProvider = jwtAuth0
    }

    secret := ac.JWT.Secret
    publicKeyFile := ac.JWT.PubKeyFile

    switch {
    case len(secret) != 0:
        key = []byte(secret)

    case len(publicKeyFile) != 0:
        kd, err := ioutil.ReadFile(publicKeyFile)
        if err != nil {
            return nil, err
        }

        switch ac.JWT.PubKeyType {
        case "ecdsa":
            key, err = jwt.ParseECPublicKeyFromPEM(kd)

        case "rsa":
            key, err = jwt.ParseRSAPublicKeyFromPEM(kd)

        default:
            key, err = jwt.ParseECPublicKeyFromPEM(kd)
        }

        if err != nil {
            return nil, err
        }
    }

    return func(w http.ResponseWriter, r *http.Request) {
        var tok string

        if len(cookie) != 0 {
            ck, err := r.Cookie(cookie)
            if err != nil {
                next.ServeHTTP(w, r)
                return
            }
            tok = ck.Value
        } else {
            ah := r.Header.Get(authHeader)
            if len(ah) < 10 {
                next.ServeHTTP(w, r)
                return
            }
            tok = ah[7:]
        }

        token, err := jwt.ParseWithClaims(tok, &jwt.StandardClaims{}, func(token *jwt.Token) (interface{}, error) {
            return key, nil
        })

        if err != nil {
            next.ServeHTTP(w, r)
            return
        }

        if claims, ok := token.Claims.(*jwt.StandardClaims); ok {
            ctx := r.Context()

            if jwtProvider == jwtAuth0 {
                // Auth0 subjects look like "provider|user-id"; only a split with
                // exactly two parts carries both values (avoids indexing past the end).
                sub := strings.Split(claims.Subject, "|")
                if len(sub) == 2 {
                    ctx = context.WithValue(ctx, core.UserIDProviderKey, sub[0])
                    ctx = context.WithValue(ctx, core.UserIDKey, sub[1])
                }
            } else {
                ctx = context.WithValue(ctx, core.UserIDKey, claims.Subject)
            }

            next.ServeHTTP(w, r.WithContext(ctx))
            return
        }

        next.ServeHTTP(w, r)
    }, nil
}
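Client-side sketch (not part of the commit): when no cookie name is configured, JwtHandler reads the Authorization header and slices off the first 7 characters, which matches the conventional "Bearer " prefix. The URL and token below are placeholders.

package main

import (
    "fmt"
    "net/http"
)

// callAPI shows the request shape JwtHandler expects when no cookie is
// configured: an Authorization header whose value starts with "Bearer ".
func callAPI(signedToken string) error {
    // Placeholder endpoint; any route behind the JWT handler works the same way.
    req, err := http.NewRequest("GET", "http://localhost:8080/api/v1/graphql", nil)
    if err != nil {
        return err
    }
    req.Header.Set("Authorization", "Bearer "+signedToken)

    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        return err
    }
    defer resp.Body.Close()

    fmt.Println("status:", resp.Status)
    return nil
}

func main() {
    _ = callAPI("eyJhbGciOi...") // placeholder token signed with the configured secret or key
}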
internal/serv/internal/auth/rails.go (new file, 190 lines)
@@ -0,0 +1,190 @@
package auth

import (
    "context"
    "errors"
    "fmt"
    "net/http"
    "net/url"
    "strings"

    "github.com/bradfitz/gomemcache/memcache"
    "github.com/dosco/super-graph/core"
    "github.com/dosco/super-graph/internal/serv/internal/rails"
    "github.com/garyburd/redigo/redis"
)

func RailsHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
    ru := ac.Rails.URL

    if strings.HasPrefix(ru, "memcache:") {
        return RailsMemcacheHandler(ac, next)
    }

    if strings.HasPrefix(ru, "redis:") {
        return RailsRedisHandler(ac, next)
    }

    return RailsCookieHandler(ac, next)
}

func RailsRedisHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
    cookie := ac.Cookie

    if len(cookie) == 0 {
        return nil, fmt.Errorf("no auth.cookie defined")
    }

    if len(ac.Rails.URL) == 0 {
        return nil, fmt.Errorf("no auth.rails.url defined")
    }

    rp := &redis.Pool{
        MaxIdle:   ac.Rails.MaxIdle,
        MaxActive: ac.Rails.MaxActive,
        Dial: func() (redis.Conn, error) {
            c, err := redis.DialURL(ac.Rails.URL)
            if err != nil {
                return nil, err
            }

            pwd := ac.Rails.Password
            if len(pwd) != 0 {
                if _, err := c.Do("AUTH", pwd); err != nil {
                    return nil, err
                }
            }

            return c, nil
        },
    }

    return func(w http.ResponseWriter, r *http.Request) {
        ck, err := r.Cookie(cookie)
        if err != nil {
            next.ServeHTTP(w, r)
            return
        }

        key := fmt.Sprintf("session:%s", ck.Value)
        sessionData, err := redis.Bytes(rp.Get().Do("GET", key))
        if err != nil {
            next.ServeHTTP(w, r)
            return
        }

        userID, err := rails.ParseCookie(string(sessionData))
        if err != nil {
            next.ServeHTTP(w, r)
            return
        }

        ctx := context.WithValue(r.Context(), core.UserIDKey, userID)
        next.ServeHTTP(w, r.WithContext(ctx))
    }, nil
}

func RailsMemcacheHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
    cookie := ac.Cookie

    if len(cookie) == 0 {
        return nil, fmt.Errorf("no auth.cookie defined")
    }

    if len(ac.Rails.URL) == 0 {
        return nil, fmt.Errorf("no auth.rails.url defined")
    }

    rURL, err := url.Parse(ac.Rails.URL)
    if err != nil {
        return nil, err
    }

    mc := memcache.New(rURL.Host)

    return func(w http.ResponseWriter, r *http.Request) {
        ck, err := r.Cookie(cookie)
        if err != nil {
            next.ServeHTTP(w, r)
            return
        }

        key := fmt.Sprintf("session:%s", ck.Value)
        item, err := mc.Get(key)
        if err != nil {
            next.ServeHTTP(w, r)
            return
        }

        userID, err := rails.ParseCookie(string(item.Value))
        if err != nil {
            next.ServeHTTP(w, r)
            return
        }

        ctx := context.WithValue(r.Context(), core.UserIDKey, userID)
        next.ServeHTTP(w, r.WithContext(ctx))
    }, nil
}

func RailsCookieHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
    cookie := ac.Cookie
    if len(cookie) == 0 {
        return nil, fmt.Errorf("no auth.cookie defined")
    }

    ra, err := railsAuth(ac)
    if err != nil {
        return nil, err
    }

    return func(w http.ResponseWriter, r *http.Request) {
        ck, err := r.Cookie(cookie)
        if err != nil || len(ck.Value) == 0 {
            // logger.Warn().Err(err).Msg("rails cookie missing")
            next.ServeHTTP(w, r)
            return
        }

        userID, err := ra.ParseCookie(ck.Value)
        if err != nil {
            // logger.Warn().Err(err).Msg("failed to parse rails cookie")
            next.ServeHTTP(w, r)
            return
        }

        ctx := context.WithValue(r.Context(), core.UserIDKey, userID)
        next.ServeHTTP(w, r.WithContext(ctx))
    }, nil
}

func railsAuth(ac *Auth) (*rails.Auth, error) {
    secret := ac.Rails.SecretKeyBase
    if len(secret) == 0 {
        return nil, errors.New("no auth.rails.secret_key_base defined")
    }

    version := ac.Rails.Version
    if len(version) == 0 {
        return nil, errors.New("no auth.rails.version defined")
    }

    ra, err := rails.NewAuth(version, secret)
    if err != nil {
        return nil, err
    }

    if len(ac.Rails.Salt) != 0 {
        ra.Salt = ac.Rails.Salt
    }

    if len(ac.Rails.SignSalt) != 0 {
        ra.SignSalt = ac.Rails.SignSalt
    }

    if len(ac.Rails.AuthSalt) != 0 {
        ra.AuthSalt = ac.Rails.AuthSalt
    }

    return ra, nil
}
internal/serv/internal/migrate/migrate.go (new file, 370 lines)
@@ -0,0 +1,370 @@
package migrate

import (
    "bytes"
    "context"
    "database/sql"
    "fmt"
    "io/ioutil"
    "os"
    "path/filepath"
    "regexp"
    "strconv"
    "strings"
    "text/template"

    "github.com/pkg/errors"
)

var migrationPattern = regexp.MustCompile(`\A(\d+)_[^\.]+\.sql\z`)

var ErrNoFwMigration = errors.Errorf("no sql in forward migration step")

type BadVersionError string

func (e BadVersionError) Error() string {
    return string(e)
}

type IrreversibleMigrationError struct {
    m *Migration
}

func (e IrreversibleMigrationError) Error() string {
    return fmt.Sprintf("Irreversible migration: %d - %s", e.m.Sequence, e.m.Name)
}

type NoMigrationsFoundError struct {
    Path string
}

func (e NoMigrationsFoundError) Error() string {
    return fmt.Sprintf("No migrations found at %s", e.Path)
}

type MigrationPgError struct {
    Sql   string
    Error error
}

type Migration struct {
    Sequence int32
    Name     string
    UpSQL    string
    DownSQL  string
}

type MigratorOptions struct {
    // DisableTx causes the Migrator not to run migrations in a transaction.
    DisableTx bool
    // MigratorFS is the interface used for collecting the migrations.
    MigratorFS MigratorFS
}

type Migrator struct {
    db           *sql.DB
    versionTable string
    options      *MigratorOptions
    Migrations   []*Migration
    OnStart      func(int32, string, string, string) // OnStart is called when a migration is run with the sequence, name, direction, and SQL
    Data         map[string]interface{}              // Data available to use in migrations
}

func NewMigrator(db *sql.DB, versionTable string) (m *Migrator, err error) {
    return NewMigratorEx(db, versionTable, &MigratorOptions{MigratorFS: defaultMigratorFS{}})
}

func NewMigratorEx(db *sql.DB, versionTable string, opts *MigratorOptions) (m *Migrator, err error) {
    m = &Migrator{db: db, versionTable: versionTable, options: opts}
    err = m.ensureSchemaVersionTableExists()
    m.Migrations = make([]*Migration, 0)
    m.Data = make(map[string]interface{})
    return
}

type MigratorFS interface {
    ReadDir(dirname string) ([]os.FileInfo, error)
    ReadFile(filename string) ([]byte, error)
    Glob(pattern string) (matches []string, err error)
}

type defaultMigratorFS struct{}

func (defaultMigratorFS) ReadDir(dirname string) ([]os.FileInfo, error) {
    return ioutil.ReadDir(dirname)
}

func (defaultMigratorFS) ReadFile(filename string) ([]byte, error) {
    return ioutil.ReadFile(filename)
}

func (defaultMigratorFS) Glob(pattern string) ([]string, error) {
    return filepath.Glob(pattern)
}

func FindMigrationsEx(path string, fs MigratorFS) ([]string, error) {
    path = strings.TrimRight(path, string(filepath.Separator))

    fileInfos, err := fs.ReadDir(path)
    if err != nil {
        return nil, err
    }

    paths := make([]string, 0, len(fileInfos))
    for _, fi := range fileInfos {
        if fi.IsDir() {
            continue
        }

        matches := migrationPattern.FindStringSubmatch(fi.Name())
        if len(matches) != 2 {
            continue
        }

        n, err := strconv.ParseInt(matches[1], 10, 32)
        if err != nil {
            // The regexp already validated that the prefix is all digits so this *should* never fail
            return nil, err
        }

        mcount := len(paths)

        if n < int64(mcount) {
            return nil, fmt.Errorf("Duplicate migration %d", n)
        }

        if int64(mcount) < n {
            return nil, fmt.Errorf("Missing migration %d", mcount)
        }

        paths = append(paths, filepath.Join(path, fi.Name()))
    }

    return paths, nil
}

func FindMigrations(path string) ([]string, error) {
    return FindMigrationsEx(path, defaultMigratorFS{})
}

func (m *Migrator) LoadMigrations(path string) error {
    path = strings.TrimRight(path, string(filepath.Separator))

    mainTmpl := template.New("main")
    sharedPaths, err := m.options.MigratorFS.Glob(filepath.Join(path, "*", "*.sql"))
    if err != nil {
        return err
    }

    for _, p := range sharedPaths {
        body, err := m.options.MigratorFS.ReadFile(p)
        if err != nil {
            return err
        }

        name := strings.Replace(p, path+string(filepath.Separator), "", 1)
        _, err = mainTmpl.New(name).Parse(string(body))
        if err != nil {
            return err
        }
    }

    paths, err := FindMigrationsEx(path, m.options.MigratorFS)
    if err != nil {
        return err
    }

    if len(paths) == 0 {
        return NoMigrationsFoundError{Path: path}
    }

    for _, p := range paths {
        body, err := m.options.MigratorFS.ReadFile(p)
        if err != nil {
            return err
        }

        pieces := strings.SplitN(string(body), "---- create above / drop below ----", 2)
        var upSQL, downSQL string
        upSQL = strings.TrimSpace(pieces[0])
        upSQL, err = m.evalMigration(mainTmpl.New(filepath.Base(p)+" up"), upSQL)
        if err != nil {
            return err
        }
        // Make sure there is SQL in the forward migration step.
        containsSQL := false
        for _, v := range strings.Split(upSQL, "\n") {
            // Only account for regular single line comment, empty line and space/comment combination
            cleanString := strings.TrimSpace(v)
            if len(cleanString) != 0 &&
                !strings.HasPrefix(cleanString, "--") {
                containsSQL = true
                break
            }
        }
        if !containsSQL {
            return ErrNoFwMigration
        }

        if len(pieces) == 2 {
            downSQL = strings.TrimSpace(pieces[1])
            downSQL, err = m.evalMigration(mainTmpl.New(filepath.Base(p)+" down"), downSQL)
            if err != nil {
                return err
            }
        }

        m.AppendMigration(filepath.Base(p), upSQL, downSQL)
    }

    return nil
}

func (m *Migrator) evalMigration(tmpl *template.Template, sql string) (string, error) {
    tmpl, err := tmpl.Parse(sql)
    if err != nil {
        return "", err
    }

    var buf bytes.Buffer
    err = tmpl.Execute(&buf, m.Data)
    if err != nil {
        return "", err
    }

    return buf.String(), nil
}

func (m *Migrator) AppendMigration(name, upSQL, downSQL string) {
    m.Migrations = append(
        m.Migrations,
        &Migration{
            Sequence: int32(len(m.Migrations)) + 1,
            Name:     name,
            UpSQL:    upSQL,
            DownSQL:  downSQL,
        })
}

// Migrate runs pending migrations
// It calls m.OnStart when it begins a migration
func (m *Migrator) Migrate() error {
    return m.MigrateTo(int32(len(m.Migrations)))
}

// MigrateTo migrates to targetVersion
func (m *Migrator) MigrateTo(targetVersion int32) (err error) {
    // Lock to ensure multiple migrations cannot occur simultaneously
    lockNum := int64(9628173550095224) // arbitrary random number
    if _, lockErr := m.db.Exec("select pg_try_advisory_lock($1)", lockNum); lockErr != nil {
        return lockErr
    }
    defer func() {
        _, unlockErr := m.db.Exec("select pg_advisory_unlock($1)", lockNum)
        if err == nil && unlockErr != nil {
            err = unlockErr
        }
    }()

    currentVersion, err := m.GetCurrentVersion()
    if err != nil {
        return err
    }

    if targetVersion < 0 || int32(len(m.Migrations)) < targetVersion {
        errMsg := fmt.Sprintf("destination version %d is outside the valid versions of 0 to %d", targetVersion, len(m.Migrations))
        return BadVersionError(errMsg)
    }

    if currentVersion < 0 || int32(len(m.Migrations)) < currentVersion {
        errMsg := fmt.Sprintf("current version %d is outside the valid versions of 0 to %d", currentVersion, len(m.Migrations))
        return BadVersionError(errMsg)
    }

    var direction int32
    if currentVersion < targetVersion {
        direction = 1
    } else {
        direction = -1
    }

    for currentVersion != targetVersion {
        var current *Migration
        var sql, directionName string
        var sequence int32
        if direction == 1 {
            current = m.Migrations[currentVersion]
            sequence = current.Sequence
            sql = current.UpSQL
            directionName = "up"
        } else {
            current = m.Migrations[currentVersion-1]
            sequence = current.Sequence - 1
            sql = current.DownSQL
            directionName = "down"
            if current.DownSQL == "" {
                return IrreversibleMigrationError{m: current}
            }
        }

        ctx := context.Background()

        tx, err := m.db.BeginTx(ctx, nil)
        if err != nil {
            return err
        }
        defer tx.Rollback() //nolint: errcheck

        // Fire on start callback
        if m.OnStart != nil {
            m.OnStart(current.Sequence, current.Name, directionName, sql)
        }

        // Execute the migration
        _, err = tx.Exec(sql)
        if err != nil {
            // if err, ok := err.(pgx.PgError); ok {
            //     return MigrationPgError{Sql: sql, PgError: err}
            // }
            return err
        }

        // Reset all database connection settings. Important to do before updating version as search_path may have been changed.
        // if _, err := tx.Exec(ctx, "reset all"); err != nil {
        //     return err
        // }

        // Add one to the version
        _, err = tx.Exec("update "+m.versionTable+" set version=$1", sequence)
        if err != nil {
            return err
        }

        err = tx.Commit()
        if err != nil {
            return err
        }

        currentVersion = currentVersion + direction
    }

    return nil
}

func (m *Migrator) GetCurrentVersion() (v int32, err error) {
    err = m.db.QueryRow("select version from " + m.versionTable).Scan(&v)

    return v, err
}

func (m *Migrator) ensureSchemaVersionTableExists() (err error) {
    _, err = m.db.Exec(fmt.Sprintf(`
  create table if not exists %s(version int4 not null);

  insert into %s(version)
  select 0
  where 0=(select count(*) from %s);
  `, m.versionTable, m.versionTable, m.versionTable))

    return err
}
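Usage sketch (not part of the commit): driving the Migrator above from database/sql inside this module. The Postgres DSN, the ./migrations directory, and the choice of github.com/lib/pq as the driver are all placeholders.

package main

import (
    "database/sql"
    "log"

    _ "github.com/lib/pq" // any database/sql Postgres driver works; lib/pq is only an example

    "github.com/dosco/super-graph/internal/serv/internal/migrate"
)

func main() {
    // Placeholder connection string.
    db, err := sql.Open("postgres", "postgres://localhost:5432/app?sslmode=disable")
    if err != nil {
        log.Fatal(err)
    }

    // NewMigrator creates the version table if it does not exist yet.
    m, err := migrate.NewMigrator(db, "schema_version")
    if err != nil {
        log.Fatal(err)
    }

    // Log each migration as it starts (sequence, name, direction).
    m.OnStart = func(seq int32, name, dir, _ string) {
        log.Printf("migration %d (%s) %s", seq, name, dir)
    }

    // Placeholder directory containing NNN_name.sql files.
    if err := m.LoadMigrations("./migrations"); err != nil {
        log.Fatal(err)
    }

    if err := m.Migrate(); err != nil {
        log.Fatal(err)
    }
}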
internal/serv/internal/migrate/migrate_test.go (new file, 352 lines)
@@ -0,0 +1,352 @@
package migrate_test

/*
import (
    . "gopkg.in/check.v1"
)

type MigrateSuite struct {
    conn *pgx.Conn
}

func Test(t *testing.T) { TestingT(t) }

var _ = Suite(&MigrateSuite{})

var versionTable string = "schema_version_non_default"

func (s *MigrateSuite) SetUpTest(c *C) {
    var err error
    s.conn, err = pgx.Connect(*defaultConnectionParameters)
    c.Assert(err, IsNil)

    s.cleanupSampleMigrator(c)
}

func (s *MigrateSuite) currentVersion(c *C) int32 {
    var n int32
    err := s.conn.QueryRow("select version from " + versionTable).Scan(&n)
    c.Assert(err, IsNil)
    return n
}

func (s *MigrateSuite) Exec(c *C, sql string, arguments ...interface{}) pgx.CommandTag {
    commandTag, err := s.conn.Exec(sql, arguments...)
    c.Assert(err, IsNil)
    return commandTag
}

func (s *MigrateSuite) tableExists(c *C, tableName string) bool {
    var exists bool
    err := s.conn.QueryRow(
        "select exists(select 1 from information_schema.tables where table_catalog=$1 and table_name=$2)",
        defaultConnectionParameters.Database,
        tableName,
    ).Scan(&exists)
    c.Assert(err, IsNil)
    return exists
}

func (s *MigrateSuite) createEmptyMigrator(c *C) *migrate.Migrator {
    var err error
    m, err := migrate.NewMigrator(s.conn, versionTable)
    c.Assert(err, IsNil)
    return m
}

func (s *MigrateSuite) createSampleMigrator(c *C) *migrate.Migrator {
    m := s.createEmptyMigrator(c)
    m.AppendMigration("Create t1", "create table t1(id serial);", "drop table t1;")
    m.AppendMigration("Create t2", "create table t2(id serial);", "drop table t2;")
    m.AppendMigration("Create t3", "create table t3(id serial);", "drop table t3;")
    return m
}

func (s *MigrateSuite) cleanupSampleMigrator(c *C) {
    tables := []string{versionTable, "t1", "t2", "t3"}
    for _, table := range tables {
        s.Exec(c, "drop table if exists "+table)
    }
}

func (s *MigrateSuite) TestNewMigrator(c *C) {
    var m *migrate.Migrator
    var err error

    // Initial run
    m, err = migrate.NewMigrator(s.conn, versionTable)
    c.Assert(err, IsNil)

    // Creates version table
    schemaVersionExists := s.tableExists(c, versionTable)
    c.Assert(schemaVersionExists, Equals, true)

    // Succeeds when version table is already created
    m, err = migrate.NewMigrator(s.conn, versionTable)
    c.Assert(err, IsNil)

    initialVersion, err := m.GetCurrentVersion()
    c.Assert(err, IsNil)
    c.Assert(initialVersion, Equals, int32(0))
}

func (s *MigrateSuite) TestAppendMigration(c *C) {
    m := s.createEmptyMigrator(c)

    name := "Create t"
    upSQL := "create t..."
    downSQL := "drop t..."
    m.AppendMigration(name, upSQL, downSQL)

    c.Assert(len(m.Migrations), Equals, 1)
    c.Assert(m.Migrations[0].Name, Equals, name)
    c.Assert(m.Migrations[0].UpSQL, Equals, upSQL)
    c.Assert(m.Migrations[0].DownSQL, Equals, downSQL)
}

func (s *MigrateSuite) TestLoadMigrationsMissingDirectory(c *C) {
    m := s.createEmptyMigrator(c)
    err := m.LoadMigrations("testdata/missing")
    c.Assert(err, ErrorMatches, "open testdata/missing: no such file or directory")
}

func (s *MigrateSuite) TestLoadMigrationsEmptyDirectory(c *C) {
    m := s.createEmptyMigrator(c)
    err := m.LoadMigrations("testdata/empty")
    c.Assert(err, ErrorMatches, "No migrations found at testdata/empty")
}

func (s *MigrateSuite) TestFindMigrationsWithGaps(c *C) {
    _, err := migrate.FindMigrations("testdata/gap")
    c.Assert(err, ErrorMatches, "Missing migration 2")
}

func (s *MigrateSuite) TestFindMigrationsWithDuplicate(c *C) {
    _, err := migrate.FindMigrations("testdata/duplicate")
    c.Assert(err, ErrorMatches, "Duplicate migration 2")
}

func (s *MigrateSuite) TestLoadMigrations(c *C) {
    m := s.createEmptyMigrator(c)
    m.Data = map[string]interface{}{"prefix": "foo"}
    err := m.LoadMigrations("testdata/sample")
    c.Assert(err, IsNil)
    c.Assert(m.Migrations, HasLen, 5)

    c.Check(m.Migrations[0].Name, Equals, "001_create_t1.sql")
    c.Check(m.Migrations[0].UpSQL, Equals, `create table t1(
  id serial primary key
);`)
    c.Check(m.Migrations[0].DownSQL, Equals, "drop table t1;")

    c.Check(m.Migrations[1].Name, Equals, "002_create_t2.sql")
    c.Check(m.Migrations[1].UpSQL, Equals, `create table t2(
  id serial primary key
);`)
    c.Check(m.Migrations[1].DownSQL, Equals, "drop table t2;")

    c.Check(m.Migrations[2].Name, Equals, "003_irreversible.sql")
    c.Check(m.Migrations[2].UpSQL, Equals, "drop table t2;")
    c.Check(m.Migrations[2].DownSQL, Equals, "")

    c.Check(m.Migrations[3].Name, Equals, "004_data_interpolation.sql")
    c.Check(m.Migrations[3].UpSQL, Equals, "create table foo_bar(id serial primary key);")
    c.Check(m.Migrations[3].DownSQL, Equals, "drop table foo_bar;")
}

func (s *MigrateSuite) TestLoadMigrationsNoForward(c *C) {
    var err error
    m, err := migrate.NewMigrator(s.conn, versionTable)
    c.Assert(err, IsNil)

    m.Data = map[string]interface{}{"prefix": "foo"}
    err = m.LoadMigrations("testdata/noforward")
    c.Assert(err, Equals, migrate.ErrNoFwMigration)
}

func (s *MigrateSuite) TestMigrate(c *C) {
    m := s.createSampleMigrator(c)

    err := m.Migrate()
    c.Assert(err, IsNil)
    currentVersion := s.currentVersion(c)
    c.Assert(currentVersion, Equals, int32(3))
}

func (s *MigrateSuite) TestMigrateToLifeCycle(c *C) {
    m := s.createSampleMigrator(c)

    var onStartCallUpCount int
    var onStartCallDownCount int
    m.OnStart = func(_ int32, _, direction, _ string) {
        switch direction {
        case "up":
            onStartCallUpCount++
        case "down":
            onStartCallDownCount++
        default:
            c.Fatalf("Unexpected direction: %s", direction)
        }
    }

    // Migrate from 0 up to 1
    err := m.MigrateTo(1)
    c.Assert(err, IsNil)
    currentVersion := s.currentVersion(c)
    c.Assert(currentVersion, Equals, int32(1))
    c.Assert(s.tableExists(c, "t1"), Equals, true)
    c.Assert(s.tableExists(c, "t2"), Equals, false)
    c.Assert(s.tableExists(c, "t3"), Equals, false)
    c.Assert(onStartCallUpCount, Equals, 1)
    c.Assert(onStartCallDownCount, Equals, 0)

    // Migrate from 1 up to 3
    err = m.MigrateTo(3)
    c.Assert(err, IsNil)
    currentVersion = s.currentVersion(c)
    c.Assert(currentVersion, Equals, int32(3))
    c.Assert(s.tableExists(c, "t1"), Equals, true)
    c.Assert(s.tableExists(c, "t2"), Equals, true)
    c.Assert(s.tableExists(c, "t3"), Equals, true)
    c.Assert(onStartCallUpCount, Equals, 3)
    c.Assert(onStartCallDownCount, Equals, 0)

    // Migrate from 3 to 3 is no-op
    err = m.MigrateTo(3)
    c.Assert(err, IsNil)
    currentVersion = s.currentVersion(c)
    c.Assert(s.tableExists(c, "t1"), Equals, true)
    c.Assert(s.tableExists(c, "t2"), Equals, true)
    c.Assert(s.tableExists(c, "t3"), Equals, true)
    c.Assert(onStartCallUpCount, Equals, 3)
    c.Assert(onStartCallDownCount, Equals, 0)

    // Migrate from 3 down to 1
    err = m.MigrateTo(1)
    c.Assert(err, IsNil)
    currentVersion = s.currentVersion(c)
    c.Assert(currentVersion, Equals, int32(1))
    c.Assert(s.tableExists(c, "t1"), Equals, true)
    c.Assert(s.tableExists(c, "t2"), Equals, false)
    c.Assert(s.tableExists(c, "t3"), Equals, false)
    c.Assert(onStartCallUpCount, Equals, 3)
    c.Assert(onStartCallDownCount, Equals, 2)

    // Migrate from 1 down to 0
    err = m.MigrateTo(0)
    c.Assert(err, IsNil)
    currentVersion = s.currentVersion(c)
    c.Assert(currentVersion, Equals, int32(0))
    c.Assert(s.tableExists(c, "t1"), Equals, false)
    c.Assert(s.tableExists(c, "t2"), Equals, false)
    c.Assert(s.tableExists(c, "t3"), Equals, false)
    c.Assert(onStartCallUpCount, Equals, 3)
    c.Assert(onStartCallDownCount, Equals, 3)

    // Migrate back up to 3
    err = m.MigrateTo(3)
    c.Assert(err, IsNil)
    currentVersion = s.currentVersion(c)
    c.Assert(currentVersion, Equals, int32(3))
    c.Assert(s.tableExists(c, "t1"), Equals, true)
    c.Assert(s.tableExists(c, "t2"), Equals, true)
    c.Assert(s.tableExists(c, "t3"), Equals, true)
    c.Assert(onStartCallUpCount, Equals, 6)
    c.Assert(onStartCallDownCount, Equals, 3)
}

func (s *MigrateSuite) TestMigrateToBoundaries(c *C) {
    m := s.createSampleMigrator(c)

    // Migrate to -1 is error
    err := m.MigrateTo(-1)
    c.Assert(err, ErrorMatches, "destination version -1 is outside the valid versions of 0 to 3")

    // Migrate past end is error
    err = m.MigrateTo(int32(len(m.Migrations)) + 1)
    c.Assert(err, ErrorMatches, "destination version 4 is outside the valid versions of 0 to 3")

    // When schema version says it is negative
    s.Exec(c, "update "+versionTable+" set version=-1")
    err = m.MigrateTo(int32(1))
    c.Assert(err, ErrorMatches, "current version -1 is outside the valid versions of 0 to 3")

    // When schema version says it is past the end
    s.Exec(c, "update "+versionTable+" set version=4")
    err = m.MigrateTo(int32(1))
    c.Assert(err, ErrorMatches, "current version 4 is outside the valid versions of 0 to 3")
}

func (s *MigrateSuite) TestMigrateToIrreversible(c *C) {
    m := s.createEmptyMigrator(c)
    m.AppendMigration("Foo", "drop table if exists t3", "")

    err := m.MigrateTo(1)
    c.Assert(err, IsNil)

    err = m.MigrateTo(0)
    c.Assert(err, ErrorMatches, "Irreversible migration: 1 - Foo")
}

func (s *MigrateSuite) TestMigrateToDisableTx(c *C) {
    m, err := migrate.NewMigratorEx(s.conn, versionTable, &migrate.MigratorOptions{DisableTx: true})
    c.Assert(err, IsNil)
    m.AppendMigration("Create t1", "create table t1(id serial);", "drop table t1;")
    m.AppendMigration("Create t2", "create table t2(id serial);", "drop table t2;")
    m.AppendMigration("Create t3", "create table t3(id serial);", "drop table t3;")

    tx, err := s.conn.Begin()
    c.Assert(err, IsNil)

    err = m.MigrateTo(3)
    c.Assert(err, IsNil)
    currentVersion := s.currentVersion(c)
    c.Assert(currentVersion, Equals, int32(3))
    c.Assert(s.tableExists(c, "t1"), Equals, true)
    c.Assert(s.tableExists(c, "t2"), Equals, true)
    c.Assert(s.tableExists(c, "t3"), Equals, true)

    err = tx.Rollback()
    c.Assert(err, IsNil)
    currentVersion = s.currentVersion(c)
    c.Assert(currentVersion, Equals, int32(0))
    c.Assert(s.tableExists(c, "t1"), Equals, false)
    c.Assert(s.tableExists(c, "t2"), Equals, false)
    c.Assert(s.tableExists(c, "t3"), Equals, false)
}

func Example_OnStartMigrationProgressLogging() {
    conn, err := pgx.Connect(*defaultConnectionParameters)
    if err != nil {
        fmt.Printf("Unable to establish connection: %v", err)
        return
    }

    // Clear any previous runs
    if _, err = conn.Exec("drop table if exists schema_version"); err != nil {
        fmt.Printf("Unable to drop schema_version table: %v", err)
        return
    }

    var m *migrate.Migrator
    m, err = migrate.NewMigrator(conn, "schema_version")
    if err != nil {
        fmt.Printf("Unable to create migrator: %v", err)
        return
    }

    m.OnStart = func(_ int32, name, direction, _ string) {
        fmt.Printf("Migrating %s: %s", direction, name)
    }

    m.AppendMigration("create a table", "create temporary table foo(id serial primary key)", "")

    if err = m.Migrate(); err != nil {
        fmt.Printf("Unexpected failure migrating: %v", err)
        return
    }
    // Output:
    // Migrating up: create a table
}
*/
internal/serv/internal/migrate/testdata/duplicate/001_create_t1.sql (new file, vendored, 7 lines)
@@ -0,0 +1,7 @@
create table t1(
  id serial primary key
);

---- create above / drop below ----

drop table t1;
internal/serv/internal/migrate/testdata/duplicate/002_create_t2.sql (new file, vendored, 7 lines)
@@ -0,0 +1,7 @@
create table t2(
  id serial primary key
);

---- create above / drop below ----

drop table t2;
internal/serv/internal/migrate/testdata/duplicate/002_duplicate.sql (new file, vendored, 7 lines)
@@ -0,0 +1,7 @@
create table duplicate(
  id serial primary key
);

---- create above / drop below ----

drop table duplicate;
internal/serv/internal/migrate/testdata/empty/.gitignore (new file, vendored, 0 lines)
internal/serv/internal/migrate/testdata/gap/001_create_people.sql.example (new file, vendored, 0 lines)
internal/serv/internal/migrate/testdata/gap/001_create_t1.sql (new file, vendored, 7 lines)
@@ -0,0 +1,7 @@
create table t1(
  id serial primary key
);

---- create above / drop below ----

drop table t1;
internal/serv/internal/migrate/testdata/gap/003_irreversible.sql (new file, vendored, 1 line)
@@ -0,0 +1 @@
drop table t2;
internal/serv/internal/migrate/testdata/noforward/001_create_no_forward.sql (new file, vendored, 7 lines)
@@ -0,0 +1,7 @@
-- no SQL here
-- nor here, just all comments.
  -- comment with space before

---- create above / drop below ----

drop table t1;
internal/serv/internal/migrate/testdata/sample/001_create_t1.sql (new file, vendored, 7 lines)
@@ -0,0 +1,7 @@
create table t1(
  id serial primary key
);

---- create above / drop below ----

drop table t1;
internal/serv/internal/migrate/testdata/sample/002_create_t2.sql (new file, vendored, 7 lines)
@@ -0,0 +1,7 @@
create table t2(
  id serial primary key
);

---- create above / drop below ----

drop table t2;
internal/serv/internal/migrate/testdata/sample/003_irreversible.sql (new file, vendored, 1 line)
@@ -0,0 +1 @@
drop table t2;
internal/serv/internal/migrate/testdata/sample/004_data_interpolation.sql (new file, vendored, 5 lines)
@@ -0,0 +1,5 @@
create table {{.prefix}}_bar(id serial primary key);

---- create above / drop below ----

drop table {{.prefix}}_bar;
internal/serv/internal/migrate/testdata/sample/005_template_inclusion.sql (new file, vendored, 5 lines)
@@ -0,0 +1,5 @@
{{ template "shared/v1_001.sql" . }}

---- create above / drop below ----

drop view {{.prefix}}v1;
internal/serv/internal/migrate/testdata/sample/shared/v1_001.sql (new file, vendored, 1 line)
@@ -0,0 +1 @@
create view {{.prefix}}v1 as select * from t1;
internal/serv/internal/migrate/testdata/sample/should_be_ignored.sql (new file, vendored, 1 line)
@@ -0,0 +1 @@
-- This file should be ignored because it does not start with a number.
internal/serv/internal/rails/auth.go (new file, 175 lines)
@@ -0,0 +1,175 @@
package rails

import (
    "encoding/json"
    "errors"
    "fmt"
    "strconv"
    "strings"

    "github.com/adjust/gorails/marshal"
)

const (
    salt          = "encrypted cookie"
    signSalt      = "signed encrypted cookie"
    authSalt      = "authenticated encrypted cookie"
    railsCipher   = "aes-256-cbc"
    railsCipher52 = "aes-256-gcm"
)

var (
    errSessionData = errors.New("error decoding session data")
)

type Auth struct {
    Cipher   string
    Secret   string
    Salt     string
    SignSalt string
    AuthSalt string
}

func NewAuth(version, secret string) (*Auth, error) {
    ra := &Auth{
        Secret:   secret,
        Salt:     salt,
        SignSalt: signSalt,
        AuthSalt: authSalt,
    }

    var v1, v2 int
    var err error

    sv := strings.Split(version, ".")
    if len(sv) >= 2 {
        if v1, err = strconv.Atoi(sv[0]); err != nil {
            return nil, err
        }
        if v2, err = strconv.Atoi(sv[1]); err != nil {
            return nil, err
        }
    }

    if v1 >= 5 && v2 >= 2 {
        ra.Cipher = railsCipher52
    } else {
        ra.Cipher = railsCipher
    }

    return ra, nil
}

func (ra *Auth) ParseCookie(cookie string) (userID string, err error) {
    var dcookie []byte

    switch ra.Cipher {
    case railsCipher:
        dcookie, err = parseCookie(cookie, ra.Secret, ra.Salt, ra.SignSalt)

    case railsCipher52:
        dcookie, err = parseCookie52(cookie, ra.Secret, ra.AuthSalt)

    default:
        err = fmt.Errorf("unknown rails cookie cipher '%s'", ra.Cipher)
    }

    if err != nil {
        return
    }

    if dcookie[0] != '{' {
        userID, err = getUserId4(dcookie)
    } else {
        userID, err = getUserId(dcookie)
    }

    return
}

func ParseCookie(cookie string) (string, error) {
    if cookie[0] != '{' {
        return getUserId4([]byte(cookie))
    }

    return getUserId([]byte(cookie))
}

func getUserId(data []byte) (userID string, err error) {
    var sessionData map[string]interface{}

    err = json.Unmarshal(data, &sessionData)
    if err != nil {
        return
    }

    userKey, ok := sessionData["warden.user.user.key"]
    if !ok {
        err = errors.New("key 'warden.user.user.key' not found in session data")
        return
    }

    items, ok := userKey.([]interface{})
    if !ok {
        err = errSessionData
        return
    }

    if len(items) != 2 {
        err = errSessionData
        return
    }

    uids, ok := items[0].([]interface{})
    if !ok {
        err = errSessionData
        return
    }

    uid, ok := uids[0].(float64)
    if !ok {
        err = errSessionData
        return
    }
    userID = fmt.Sprintf("%d", int64(uid))

    return
}

func getUserId4(data []byte) (userID string, err error) {
    sessionData, err := marshal.CreateMarshalledObject(data).GetAsMap()
    if err != nil {
        return
    }

    wardenData, ok := sessionData["warden.user.user.key"]
    if !ok {
        err = errSessionData
        return
    }

    wardenUserKey, err := wardenData.GetAsArray()
    if err != nil {
        return
    }
    if len(wardenUserKey) < 1 {
        err = errSessionData
        return
    }

    userData, err := wardenUserKey[0].GetAsArray()
    if err != nil {
        return
    }
    if len(userData) < 1 {
        err = errSessionData
        return
    }

    uid, err := userData[0].GetAsInteger()
    if err != nil {
        return
    }
    userID = fmt.Sprintf("%d", uid)

    return
}
internal/serv/internal/rails/auth_test.go (new file, 79 lines)
@@ -0,0 +1,79 @@
package rails

import (
    "testing"
)

func TestRailsEncryptedSession1(t *testing.T) {
    cookie := "dDdjMW5jYUNYaFpBT1BSdFgwQkk4ZWNlT214L1FnM0pyZzZ1d21nSnVTTm9zS0ljN000S1JmT3cxcTNtRld2Ny0tQUFBQUFBQUFBQUFBQUFBQUFBQUFBQT09--75d8323b0f0e41cf4d5aabee1b229b1be76b83b6"

    secret := "development_secret"

    ra := Auth{
        Cipher:   railsCipher,
        Secret:   secret,
        Salt:     salt,
        SignSalt: signSalt,
    }

    userID, err := ra.ParseCookie(cookie)
    if err != nil {
        t.Error(err)
        return
    }

    if userID != "1" {
        t.Errorf("Expecting userID 1 got %s", userID)
    }
}

func TestRailsEncryptedSession52(t *testing.T) {
    cookie :=
        "fZy1lt%2FIuXh2cpQgy3wWjbvabh1AqJX%2Bt6qO4D95DOZIpDhMyK2HqPFeNoaBtrXCUa9%2BDQuvbs1GX6tuccEAp14QPLNhm0PPJS5U1pRHqPLWaqT%2BBPYP%2BY9bo677komm9CPuOCOqBKf7rv3%2F4ptLmVO7iefB%2FP2ZlkV1848Johv5q%2B5PGyMxII2BEQnBdS3Petw6lRu741Bquc8z9VofC3t4%2F%2BLxVz%2BvBbTg--VL0MorYITXB8Dj3W--0yr0sr6pRU%2FwlYMQ%2BpEifA%3D%3D"

    secret := "0a248500a64c01184edb4d7ad3a805488f8097ac761b76aaa6c17c01dcb7af03a2f18ba61b2868134b9c7b79a122bc0dadff4367414a2d173297bfea92be5566"

    ra := Auth{
        Cipher:   railsCipher52,
        Secret:   secret,
        AuthSalt: authSalt,
    }

    userID, err := ra.ParseCookie(cookie)
    if err != nil {
        t.Error(err)
        return
    }

    if userID != "2" {
        t.Errorf("Expecting userID 2 got %s", userID)
    }
}

func TestRailsJsonSession(t *testing.T) {
    sessionData := `{"warden.user.user.key":[[1],"secret"]}`

    userID, err := getUserId([]byte(sessionData))
    if err != nil {
        t.Error(err)
        return
    }

    if userID != "1" {
        t.Errorf("Expecting userID 1 got %s", userID)
    }
}

func TestRailsMarshaledSession(t *testing.T) {
    sessionData := "\x04\b{\bI\"\x15member_return_to\x06:\x06ETI\"\x06/\x06;\x00TI\"\x19warden.user.user.key\x06;\x00T[\a[\x06i\aI\"\"$2a$11$6SgXdvO9hld82kQAvpEY3e\x06;\x00TI\"\x10_csrf_token\x06;\x00FI\"17lqwj1UsTTgbXBQKH4ipCNW32uLusvfSPds1txppMec=\x06;\x00F"

    userID, err := getUserId4([]byte(sessionData))
    if err != nil {
        t.Error(err)
        return
    }

    if userID != "2" {
        t.Errorf("Expecting userID 2 got %s", userID)
    }
}
internal/serv/internal/rails/cookie.go (new file, 62 lines)
@@ -0,0 +1,62 @@
package rails

import (
    "crypto/aes"
    "crypto/cipher"
    "crypto/sha1"
    "encoding/base64"
    "net/url"
    "strings"

    "github.com/adjust/gorails/session"
    "golang.org/x/crypto/pbkdf2"
)

func parseCookie(cookie, secretKeyBase, salt, signSalt string) ([]byte, error) {
    return session.DecryptSignedCookie(
        cookie,
        secretKeyBase,
        salt,
        signSalt)
}

// {"session_id":"a71d6ffcd4ed5572ea2097f569eb95ef","warden.user.user.key":[[2],"$2a$11$q9Br7m4wJxQvF11hAHvTZO"],"_csrf_token":"HsYgrD2YBaWAabOYceN0hluNRnGuz49XiplmMPt43aY="}

func parseCookie52(cookie, secretKeyBase, authSalt string) ([]byte, error) {
    ecookie, err := url.QueryUnescape(cookie)
    if err != nil {
        return nil, err
    }

    vectors := strings.Split(ecookie, "--")

    body, err := base64.RawStdEncoding.DecodeString(vectors[0])
    if err != nil {
        return nil, err
    }

    iv, err := base64.RawStdEncoding.DecodeString(vectors[1])
    if err != nil {
        return nil, err
    }

    tag, err := base64.StdEncoding.DecodeString(vectors[2])
    if err != nil {
        return nil, err
    }

    key := pbkdf2.Key([]byte(secretKeyBase), []byte(authSalt),
        1000, 32, sha1.New)

    c, err := aes.NewCipher(key)
    if err != nil {
        return nil, err
    }

    gcm, err := cipher.NewGCM(c)
    if err != nil {
        return nil, err
    }

    return gcm.Open(nil, iv, append(body, tag...), nil)
}