Add migrate command
This commit is contained in:
@ -1,6 +1,7 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
@ -206,7 +207,7 @@ func (al *allowList) save(item *allowItem) {
|
||||
f.WriteString(fmt.Sprintf("# %s\n\n", k))
|
||||
|
||||
for i := range v {
|
||||
if len(v[i].vars) != 0 {
|
||||
if len(v[i].vars) != 0 && bytes.Equal(v[i].vars, []byte("{}")) == false {
|
||||
vj, err := json.MarshalIndent(v[i].vars, "", "\t")
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to write allow list 'vars' to file")
|
||||
@ -215,7 +216,11 @@ func (al *allowList) save(item *allowItem) {
|
||||
f.WriteString(fmt.Sprintf("variables %s\n\n", vj))
|
||||
}
|
||||
|
||||
f.WriteString(fmt.Sprintf("%s\n\n", v[i].gql))
|
||||
if v[i].gql[0] == '{' {
|
||||
f.WriteString(fmt.Sprintf("query %s\n\n", v[i].gql))
|
||||
} else {
|
||||
f.WriteString(fmt.Sprintf("%s\n\n", v[i].gql))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
290
serv/cmd.go
290
serv/cmd.go
@ -1,17 +1,19 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/dosco/super-graph/psql"
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
"github.com/go-pg/pg"
|
||||
"github.com/gobuffalo/flect"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v4/log/zerologadapter"
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
@ -28,12 +30,109 @@ const (
|
||||
var (
|
||||
logger *zerolog.Logger
|
||||
conf *config
|
||||
db *pg.DB
|
||||
confPath string
|
||||
db *pgxpool.Pool
|
||||
qcompile *qcode.Compiler
|
||||
pcompile *psql.Compiler
|
||||
authFailBlock int
|
||||
|
||||
rootCmd *cobra.Command
|
||||
servCmd *cobra.Command
|
||||
seedCmd *cobra.Command
|
||||
migrateCmd *cobra.Command
|
||||
statusCmd *cobra.Command
|
||||
newMigrationCmd *cobra.Command
|
||||
)
|
||||
|
||||
func Init() {
|
||||
var err error
|
||||
|
||||
rootCmd = &cobra.Command{
|
||||
Use: "super-graph",
|
||||
Short: "An instant high-performance GraphQL API. No code needed. https://supergraph.dev",
|
||||
//Run: cmdServ,
|
||||
}
|
||||
|
||||
seedCmd = &cobra.Command{
|
||||
Use: "seed",
|
||||
Short: "Run the seed script to seed the database",
|
||||
Run: cmdSeed,
|
||||
}
|
||||
|
||||
servCmd = &cobra.Command{
|
||||
Use: "serv",
|
||||
Short: "Run the super-graph service",
|
||||
Run: cmdServ,
|
||||
}
|
||||
|
||||
migrateCmd = &cobra.Command{
|
||||
Use: "migrate",
|
||||
Short: "Migrate the database",
|
||||
Long: `Migrate the database to destination migration version.
|
||||
|
||||
Destination migration version can be one of the following value types:
|
||||
|
||||
An integer:
|
||||
Migrate to a specific migration.
|
||||
e.g. tern migrate -d 42
|
||||
|
||||
"+" and an integer:
|
||||
Migrate forward N steps.
|
||||
e.g. tern migrate -d +3
|
||||
|
||||
"-" and an integer:
|
||||
Migrate backward N steps.
|
||||
e.g. tern migrate -d -2
|
||||
|
||||
"-+" and an integer:
|
||||
Redo previous N steps (migrate backward N steps then forward N steps).
|
||||
e.g. tern migrate -d -+1
|
||||
|
||||
The word "last":
|
||||
Migrate to the most recent migration. This is the default value, so it is
|
||||
never needed to specify directly.
|
||||
e.g. tern migrate
|
||||
e.g. tern migrate -d last
|
||||
`,
|
||||
Run: cmdMigrate,
|
||||
}
|
||||
|
||||
statusCmd = &cobra.Command{
|
||||
Use: "status",
|
||||
Short: "Print current migration status",
|
||||
Run: cmdStatus,
|
||||
}
|
||||
|
||||
newMigrationCmd = &cobra.Command{
|
||||
Use: "new NAME",
|
||||
Short: "Generate a new migration",
|
||||
Long: "Generate a new migration with the next sequence number and provided name",
|
||||
Run: cmdNewMigration,
|
||||
}
|
||||
|
||||
logger = initLog()
|
||||
|
||||
rootCmd.Flags().StringVar(&confPath,
|
||||
"path", "./config", "path to config files")
|
||||
|
||||
//cmdMigrate.Flags().StringVarP(&cliOptions.destinationVersion,
|
||||
// "destination", "d", "last", "destination migration version")
|
||||
|
||||
rootCmd.AddCommand(servCmd)
|
||||
rootCmd.AddCommand(seedCmd)
|
||||
rootCmd.AddCommand(migrateCmd)
|
||||
rootCmd.AddCommand(statusCmd)
|
||||
rootCmd.AddCommand(newMigrationCmd)
|
||||
|
||||
if conf, err = initConf(); err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to read config")
|
||||
}
|
||||
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
logger.Fatal().Err(err).Send()
|
||||
}
|
||||
}
|
||||
|
||||
func initLog() *zerolog.Logger {
|
||||
logger := zerolog.New(zerolog.ConsoleWriter{Out: os.Stderr}).
|
||||
With().
|
||||
@ -42,24 +141,16 @@ func initLog() *zerolog.Logger {
|
||||
Logger()
|
||||
|
||||
return &logger
|
||||
/*
|
||||
log := logrus.New()
|
||||
logger.Formatter = new(logrus.TextFormatter)
|
||||
logger.Formatter.(*logrus.TextFormatter).DisableColors = false
|
||||
logger.Formatter.(*logrus.TextFormatter).DisableTimestamp = true
|
||||
logger.Level = logrus.TraceLevel
|
||||
logger.Out = os.Stdout
|
||||
*/
|
||||
}
|
||||
|
||||
func initConf(path string) (*config, error) {
|
||||
func initConf() (*config, error) {
|
||||
vi := viper.New()
|
||||
|
||||
vi.SetEnvPrefix("SG")
|
||||
vi.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
vi.AutomaticEnv()
|
||||
|
||||
vi.AddConfigPath(path)
|
||||
vi.AddConfigPath(confPath)
|
||||
vi.AddConfigPath("./config")
|
||||
vi.SetConfigName(getConfigName())
|
||||
|
||||
@ -73,6 +164,7 @@ func initConf(path string) (*config, error) {
|
||||
vi.SetDefault("database.host", "localhost")
|
||||
vi.SetDefault("database.port", 5432)
|
||||
vi.SetDefault("database.user", "postgres")
|
||||
vi.SetDefault("database.schema", "public")
|
||||
|
||||
vi.SetDefault("env", "development")
|
||||
vi.BindEnv("env", "GO_ENV")
|
||||
@ -103,102 +195,104 @@ func initConf(path string) (*config, error) {
|
||||
|
||||
authFailBlock = getAuthFailBlock(c)
|
||||
|
||||
//fmt.Printf("%#v", c)
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func initDB(c *config) (*pg.DB, error) {
|
||||
opt := &pg.Options{
|
||||
Addr: strings.Join([]string{c.DB.Host, c.DB.Port}, ":"),
|
||||
User: c.DB.User,
|
||||
Password: c.DB.Password,
|
||||
Database: c.DB.DBName,
|
||||
ApplicationName: c.AppName,
|
||||
}
|
||||
|
||||
if c.DB.PoolSize != 0 {
|
||||
opt.PoolSize = conf.DB.PoolSize
|
||||
}
|
||||
|
||||
if c.DB.MaxRetries != 0 {
|
||||
opt.MaxRetries = c.DB.MaxRetries
|
||||
}
|
||||
|
||||
if len(c.DB.Schema) != 0 {
|
||||
opt.OnConnect = func(conn *pg.Conn) error {
|
||||
_, err := conn.Exec("set search_path=?", c.DB.Schema)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
db := pg.Connect(opt)
|
||||
if db == nil {
|
||||
return nil, errors.New("failed to connect to postgres db")
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func Init() {
|
||||
var err error
|
||||
|
||||
path := flag.String("path", "./config", "Path to config files")
|
||||
flag.Parse()
|
||||
|
||||
logger = initLog()
|
||||
|
||||
conf, err = initConf(*path)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to read config")
|
||||
}
|
||||
|
||||
logLevel, err := zerolog.ParseLevel(conf.LogLevel)
|
||||
logLevel, err := zerolog.ParseLevel(c.LogLevel)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("error setting log_level")
|
||||
}
|
||||
zerolog.SetGlobalLevel(logLevel)
|
||||
|
||||
db, err = initDB(conf)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to connect to database")
|
||||
//fmt.Printf("%#v", c)
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func initDB(c *config) (*pgx.Conn, error) {
|
||||
config, _ := pgx.ParseConfig("")
|
||||
config.Host = c.DB.Host
|
||||
config.Port = c.DB.Port
|
||||
config.Database = c.DB.DBName
|
||||
config.User = c.DB.User
|
||||
config.Password = c.DB.Password
|
||||
config.RuntimeParams = map[string]string{
|
||||
"application_name": c.AppName,
|
||||
"search_path": c.DB.Schema,
|
||||
}
|
||||
|
||||
switch c.LogLevel {
|
||||
case "debug":
|
||||
config.LogLevel = pgx.LogLevelDebug
|
||||
case "info":
|
||||
config.LogLevel = pgx.LogLevelInfo
|
||||
case "warn":
|
||||
config.LogLevel = pgx.LogLevelWarn
|
||||
case "error":
|
||||
config.LogLevel = pgx.LogLevelError
|
||||
default:
|
||||
config.LogLevel = pgx.LogLevelNone
|
||||
}
|
||||
|
||||
config.Logger = zerologadapter.NewLogger(*logger)
|
||||
|
||||
db, err := pgx.ConnectConfig(context.Background(), config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func initDBPool(c *config) (*pgxpool.Pool, error) {
|
||||
config, _ := pgxpool.ParseConfig("")
|
||||
config.ConnConfig.Host = c.DB.Host
|
||||
config.ConnConfig.Port = c.DB.Port
|
||||
config.ConnConfig.Database = c.DB.DBName
|
||||
config.ConnConfig.User = c.DB.User
|
||||
config.ConnConfig.Password = c.DB.Password
|
||||
config.ConnConfig.RuntimeParams = map[string]string{
|
||||
"application_name": c.AppName,
|
||||
"search_path": c.DB.Schema,
|
||||
}
|
||||
|
||||
switch c.LogLevel {
|
||||
case "debug":
|
||||
config.ConnConfig.LogLevel = pgx.LogLevelDebug
|
||||
case "info":
|
||||
config.ConnConfig.LogLevel = pgx.LogLevelInfo
|
||||
case "warn":
|
||||
config.ConnConfig.LogLevel = pgx.LogLevelWarn
|
||||
case "error":
|
||||
config.ConnConfig.LogLevel = pgx.LogLevelError
|
||||
default:
|
||||
config.ConnConfig.LogLevel = pgx.LogLevelNone
|
||||
}
|
||||
|
||||
config.ConnConfig.Logger = zerologadapter.NewLogger(*logger)
|
||||
|
||||
// if c.DB.MaxRetries != 0 {
|
||||
// opt.MaxRetries = c.DB.MaxRetries
|
||||
// }
|
||||
|
||||
if c.DB.PoolSize != 0 {
|
||||
config.MaxConns = conf.DB.PoolSize
|
||||
}
|
||||
|
||||
db, err := pgxpool.ConnectConfig(context.Background(), config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func initCompiler() {
|
||||
var err error
|
||||
|
||||
qcompile, pcompile, err = initCompilers(conf)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to connect to database")
|
||||
logger.Fatal().Err(err).Msg("failed to initialize compilers")
|
||||
}
|
||||
|
||||
if err := initResolvers(); err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to initialized resolvers")
|
||||
}
|
||||
|
||||
args := flag.Args()
|
||||
|
||||
if len(args) == 0 {
|
||||
cmdServ(*path)
|
||||
}
|
||||
|
||||
switch args[0] {
|
||||
case "seed":
|
||||
cmdSeed(*path)
|
||||
|
||||
case "serv":
|
||||
fallthrough
|
||||
|
||||
default:
|
||||
logger.Fatal().Msg("options: [serve|seed]")
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func cmdServ(path string) {
|
||||
initAllowList(path)
|
||||
initPreparedList()
|
||||
initWatcher(path)
|
||||
|
||||
startHTTP()
|
||||
}
|
||||
|
@ -11,12 +11,21 @@ import (
|
||||
|
||||
"github.com/brianvoe/gofakeit"
|
||||
"github.com/dop251/goja"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func cmdSeed(cpath string) {
|
||||
func cmdSeed(cmd *cobra.Command, args []string) {
|
||||
var err error
|
||||
conf.UseAllowList = false
|
||||
|
||||
b, err := ioutil.ReadFile(path.Join(cpath, conf.SeedFile))
|
||||
db, err = initDBPool(conf)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to connect to database")
|
||||
}
|
||||
|
||||
initCompiler()
|
||||
|
||||
b, err := ioutil.ReadFile(path.Join(confPath, conf.SeedFile))
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to read file")
|
||||
}
|
21
serv/cmd_serv.go
Normal file
21
serv/cmd_serv.go
Normal file
@ -0,0 +1,21 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
func cmdServ(cmd *cobra.Command, args []string) {
|
||||
var err error
|
||||
|
||||
db, err = initDBPool(conf)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to connect to database")
|
||||
}
|
||||
|
||||
initCompiler()
|
||||
initAllowList(confPath)
|
||||
initPreparedList()
|
||||
initWatcher(confPath)
|
||||
|
||||
startHTTP()
|
||||
}
|
231
serv/cmd_tern.go
Normal file
231
serv/cmd_tern.go
Normal file
@ -0,0 +1,231 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/dosco/super-graph/migrate"
|
||||
"github.com/spf13/cobra"
|
||||
)
|
||||
|
||||
var sampleMigration = `-- This is a sample migration.
|
||||
|
||||
create table users(
|
||||
id serial primary key,
|
||||
fullname varchar not null,
|
||||
email varchar not null
|
||||
);
|
||||
|
||||
---- create above / drop below ----
|
||||
|
||||
drop table users;
|
||||
`
|
||||
|
||||
var newMigrationText = `-- Write your migrate up statements here
|
||||
|
||||
---- create above / drop below ----
|
||||
|
||||
-- Write your migrate down statements here. If this migration is irreversible
|
||||
-- Then delete the separator line above.
|
||||
`
|
||||
|
||||
func cmdNewMigration(cmd *cobra.Command, args []string) {
|
||||
if len(args) != 1 {
|
||||
cmd.Help()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
name := args[0]
|
||||
|
||||
m, err := migrate.FindMigrations(conf.MigrationsPath)
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Error loading migrations:\n %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
mname := fmt.Sprintf("%03d_%s.sql", len(m)+100, name)
|
||||
|
||||
// Write new migration
|
||||
mpath := filepath.Join(conf.MigrationsPath, mname)
|
||||
mfile, err := os.OpenFile(mpath, os.O_CREATE|os.O_EXCL|os.O_WRONLY, 0666)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
defer mfile.Close()
|
||||
|
||||
_, err = mfile.WriteString(newMigrationText)
|
||||
if err != nil {
|
||||
fmt.Fprintln(os.Stderr, err)
|
||||
os.Exit(1)
|
||||
}
|
||||
logger.Info().Msgf("created migration '%s'\n", mpath)
|
||||
}
|
||||
|
||||
func cmdMigrate(cmd *cobra.Command, args []string) {
|
||||
conn, err := initDB(conf)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to connect to database")
|
||||
}
|
||||
defer conn.Close(context.Background())
|
||||
|
||||
m, err := migrate.NewMigrator(conn, "schema_version")
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to initializing migrator")
|
||||
}
|
||||
//m.Data = config.Data
|
||||
|
||||
err = m.LoadMigrations(conf.MigrationsPath)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to load migrations")
|
||||
}
|
||||
|
||||
if len(m.Migrations) == 0 {
|
||||
logger.Fatal().Msg("No migrations found")
|
||||
}
|
||||
|
||||
m.OnStart = func(sequence int32, name, direction, sql string) {
|
||||
logger.Info().Msgf("%s executing %s %s\n%s\n\n",
|
||||
time.Now().Format("2006-01-02 15:04:05"), name, direction, sql)
|
||||
}
|
||||
|
||||
var currentVersion int32
|
||||
currentVersion, err = m.GetCurrentVersion()
|
||||
if err != nil {
|
||||
fmt.Fprintf(os.Stderr, "Unable to get current version:\n %v\n", err)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
dest := args[0]
|
||||
mustParseDestination := func(d string) int32 {
|
||||
var n int64
|
||||
n, err = strconv.ParseInt(d, 10, 32)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("invalid destination")
|
||||
}
|
||||
return int32(n)
|
||||
}
|
||||
|
||||
if dest == "last" {
|
||||
err = m.Migrate()
|
||||
|
||||
} else if len(dest) >= 3 && dest[0:2] == "-+" {
|
||||
err = m.MigrateTo(currentVersion - mustParseDestination(dest[2:]))
|
||||
if err == nil {
|
||||
err = m.MigrateTo(currentVersion)
|
||||
}
|
||||
|
||||
} else if len(dest) >= 2 && dest[0] == '-' {
|
||||
err = m.MigrateTo(currentVersion - mustParseDestination(dest[1:]))
|
||||
|
||||
} else if len(dest) >= 2 && dest[0] == '+' {
|
||||
err = m.MigrateTo(currentVersion + mustParseDestination(dest[1:]))
|
||||
|
||||
} else {
|
||||
//err = make(type, 0).MigrateTo(mustParseDestination(dest))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Info().Err(err).Send()
|
||||
|
||||
// logger.Info().Err(err).Send()
|
||||
|
||||
// if err, ok := err.(m.MigrationPgError); ok {
|
||||
// if err.Detail != "" {
|
||||
// logger.Info().Err(err).Msg(err.Detail)
|
||||
// }
|
||||
|
||||
// if err.Position != 0 {
|
||||
// ele, err := ExtractErrorLine(err.Sql, int(err.Position))
|
||||
// if err != nil {
|
||||
// logger.Fatal().Err(err).Send()
|
||||
// }
|
||||
|
||||
// prefix := fmt.Sprintf()
|
||||
// logger.Info().Msgf("line %d, %s%s", ele.LineNum, prefix, ele.Text)
|
||||
// }
|
||||
// }
|
||||
// os.Exit(1)
|
||||
}
|
||||
logger.Info().Msg("migration done")
|
||||
|
||||
}
|
||||
|
||||
func cmdStatus(cmd *cobra.Command, args []string) {
|
||||
conn, err := initDB(conf)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to connect to database")
|
||||
}
|
||||
defer conn.Close(context.Background())
|
||||
|
||||
m, err := migrate.NewMigrator(conn, "schema_version")
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to initialize migrator")
|
||||
}
|
||||
//m.Data = config.Data
|
||||
|
||||
err = m.LoadMigrations(conf.MigrationsPath)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to load migrations")
|
||||
}
|
||||
|
||||
if len(m.Migrations) == 0 {
|
||||
logger.Fatal().Msg("no migrations found")
|
||||
}
|
||||
|
||||
mver, err := m.GetCurrentVersion()
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to retrieve migration")
|
||||
}
|
||||
|
||||
var status string
|
||||
behindCount := len(m.Migrations) - int(mver)
|
||||
if behindCount == 0 {
|
||||
status = "up to date"
|
||||
} else {
|
||||
status = "migration(s) pending"
|
||||
}
|
||||
|
||||
fmt.Println("status: ", status)
|
||||
fmt.Println("version: %d of %d\n", mver, len(m.Migrations))
|
||||
fmt.Println("host: ", conf.DB.Host)
|
||||
fmt.Println("database:", conf.DB.DBName)
|
||||
}
|
||||
|
||||
type ErrorLineExtract struct {
|
||||
LineNum int // Line number starting with 1
|
||||
ColumnNum int // Column number starting with 1
|
||||
Text string // Text of the line without a new line character.
|
||||
}
|
||||
|
||||
// ExtractErrorLine takes source and character position extracts the line
|
||||
// number, column number, and the line of text.
|
||||
//
|
||||
// The first character is position 1.
|
||||
func ExtractErrorLine(source string, position int) (ErrorLineExtract, error) {
|
||||
ele := ErrorLineExtract{LineNum: 1}
|
||||
|
||||
if position > len(source) {
|
||||
return ele, fmt.Errorf("position (%d) is greater than source length (%d)", position, len(source))
|
||||
}
|
||||
|
||||
lines := strings.SplitAfter(source, "\n")
|
||||
for _, ele.Text = range lines {
|
||||
if position-len(ele.Text) < 1 {
|
||||
ele.ColumnNum = position
|
||||
break
|
||||
}
|
||||
|
||||
ele.LineNum += 1
|
||||
position -= len(ele.Text)
|
||||
}
|
||||
|
||||
ele.Text = strings.TrimSuffix(ele.Text, "\n")
|
||||
|
||||
return ele, nil
|
||||
}
|
104
serv/cmd_test.go
Normal file
104
serv/cmd_test.go
Normal file
@ -0,0 +1,104 @@
|
||||
package serv_test
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestErrorLineExtract(t *testing.T) {
|
||||
tests := []struct {
|
||||
source string
|
||||
position int
|
||||
ele ErrorLineExtract
|
||||
errMsg string
|
||||
}{
|
||||
{
|
||||
source: "single line",
|
||||
position: 3,
|
||||
ele: ErrorLineExtract{
|
||||
LineNum: 1,
|
||||
ColumnNum: 3,
|
||||
Text: "single line",
|
||||
},
|
||||
errMsg: "",
|
||||
},
|
||||
{
|
||||
source: "bad position",
|
||||
position: 32,
|
||||
ele: ErrorLineExtract{},
|
||||
errMsg: "position (32) is greater than source length (12)",
|
||||
},
|
||||
{
|
||||
source: `multi
|
||||
line
|
||||
text`,
|
||||
position: 8,
|
||||
ele: ErrorLineExtract{
|
||||
LineNum: 2,
|
||||
ColumnNum: 2,
|
||||
Text: "line",
|
||||
},
|
||||
errMsg: "",
|
||||
},
|
||||
{
|
||||
source: `last
|
||||
line
|
||||
error`,
|
||||
position: 13,
|
||||
ele: ErrorLineExtract{
|
||||
LineNum: 3,
|
||||
ColumnNum: 3,
|
||||
Text: "error",
|
||||
},
|
||||
errMsg: "",
|
||||
},
|
||||
{
|
||||
source: `first
|
||||
character
|
||||
first
|
||||
line
|
||||
error`,
|
||||
position: 1,
|
||||
ele: ErrorLineExtract{
|
||||
LineNum: 1,
|
||||
ColumnNum: 1,
|
||||
Text: "first",
|
||||
},
|
||||
errMsg: "",
|
||||
},
|
||||
{
|
||||
source: `last
|
||||
character
|
||||
last
|
||||
line
|
||||
error`,
|
||||
position: 30,
|
||||
ele: ErrorLineExtract{
|
||||
LineNum: 5,
|
||||
ColumnNum: 5,
|
||||
Text: "error",
|
||||
},
|
||||
errMsg: "",
|
||||
},
|
||||
}
|
||||
|
||||
for i, tt := range tests {
|
||||
ele, err := ExtractErrorLine(tt.source, tt.position)
|
||||
if err != nil {
|
||||
if tt.errMsg == "" {
|
||||
t.Errorf("%d. Expected success but received err %v", i, err)
|
||||
} else if err.Error() != tt.errMsg {
|
||||
t.Errorf("%d. Expected err %v, but received %v", i, tt.errMsg, err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
if tt.errMsg != "" {
|
||||
t.Errorf("%d. Expected err %v, but it succeeded", i, tt.errMsg)
|
||||
continue
|
||||
}
|
||||
|
||||
if ele != tt.ele {
|
||||
t.Errorf("%d. Expected %v, but received %v", i, tt.ele, ele)
|
||||
}
|
||||
}
|
||||
}
|
@ -19,7 +19,9 @@ type config struct {
|
||||
WatchAndReload bool `mapstructure:"reload_on_config_change"`
|
||||
AuthFailBlock string `mapstructure:"auth_fail_block"`
|
||||
SeedFile string `mapstructure:"seed_file"`
|
||||
Inflections map[string]string
|
||||
MigrationsPath string `mapstructure:"migrations_path"`
|
||||
|
||||
Inflections map[string]string
|
||||
|
||||
Auth struct {
|
||||
Type string
|
||||
@ -49,12 +51,12 @@ type config struct {
|
||||
DB struct {
|
||||
Type string
|
||||
Host string
|
||||
Port string
|
||||
Port uint16
|
||||
DBName string
|
||||
User string
|
||||
Password string
|
||||
Schema string
|
||||
PoolSize int `mapstructure:"pool_size"`
|
||||
PoolSize int32 `mapstructure:"pool_size"`
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
LogLevel string `mapstructure:"log_level"`
|
||||
|
||||
|
50
serv/core.go
50
serv/core.go
@ -8,7 +8,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@ -16,7 +15,6 @@ import (
|
||||
"github.com/dosco/super-graph/jsn"
|
||||
"github.com/dosco/super-graph/psql"
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
"github.com/go-pg/pg"
|
||||
"github.com/valyala/fasttemplate"
|
||||
)
|
||||
|
||||
@ -267,45 +265,45 @@ func (c *coreContext) resolvePreparedSQL(gql string) ([]byte, *preparedItem, err
|
||||
return nil, nil, errUnauthorized
|
||||
}
|
||||
|
||||
var root json.RawMessage
|
||||
var root []byte
|
||||
vars := varList(c, ps.args)
|
||||
|
||||
tx, err := db.Begin()
|
||||
tx, err := db.Begin(c)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
defer tx.Rollback(c)
|
||||
|
||||
if v := c.Value(userIDKey); v != nil {
|
||||
_, err = tx.Exec(`SET LOCAL "user.id" = ?`, v)
|
||||
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
_, err = tx.Stmt(ps.stmt).QueryOne(pg.Scan(&root), vars...)
|
||||
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
if err := tx.Commit(c); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
fmt.Printf("PRE: %v\n", ps.stmt)
|
||||
|
||||
return []byte(root), ps, nil
|
||||
return root, ps, nil
|
||||
}
|
||||
|
||||
func (c *coreContext) resolveSQL(qc *qcode.QCode) (
|
||||
[]byte, uint32, error) {
|
||||
func (c *coreContext) resolveSQL(qc *qcode.QCode) ([]byte, uint32, error) {
|
||||
var vars map[string]json.RawMessage
|
||||
stmt := &bytes.Buffer{}
|
||||
|
||||
vars := make(map[string]json.RawMessage)
|
||||
|
||||
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
|
||||
return nil, 0, err
|
||||
if len(c.req.Vars) != 0 {
|
||||
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
skipped, err := pcompile.Compile(qc, stmt, psql.Variables(vars))
|
||||
@ -330,10 +328,10 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode) (
|
||||
|
||||
finalSQL := stmt.String()
|
||||
|
||||
if conf.LogLevel == "debug" {
|
||||
os.Stdout.WriteString(finalSQL)
|
||||
os.Stdout.WriteString("\n\n")
|
||||
}
|
||||
// if conf.LogLevel == "debug" {
|
||||
// os.Stdout.WriteString(finalSQL)
|
||||
// os.Stdout.WriteString("\n\n")
|
||||
// }
|
||||
|
||||
var st time.Time
|
||||
|
||||
@ -341,14 +339,14 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode) (
|
||||
st = time.Now()
|
||||
}
|
||||
|
||||
tx, err := db.Begin()
|
||||
tx, err := db.Begin(c)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer tx.Rollback()
|
||||
defer tx.Rollback(c)
|
||||
|
||||
if v := c.Value(userIDKey); v != nil {
|
||||
_, err = tx.Exec(`SET LOCAL "user.id" = ?`, v)
|
||||
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
@ -357,14 +355,14 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode) (
|
||||
|
||||
//fmt.Printf("\nRAW: %#v\n", finalSQL)
|
||||
|
||||
var root json.RawMessage
|
||||
_, err = tx.QueryOne(pg.Scan(&root), finalSQL)
|
||||
var root []byte
|
||||
|
||||
err = tx.QueryRow(c, finalSQL).Scan(&root)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(); err != nil {
|
||||
if err := tx.Commit(c); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
@ -379,7 +377,7 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode) (
|
||||
_allowList.add(&c.req)
|
||||
}
|
||||
|
||||
return []byte(root), skipped, nil
|
||||
return root, skipped, nil
|
||||
}
|
||||
|
||||
func (c *coreContext) render(w io.Writer, data []byte) error {
|
||||
|
@ -2,18 +2,19 @@ package serv
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/dosco/super-graph/psql"
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
"github.com/go-pg/pg"
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/valyala/fasttemplate"
|
||||
)
|
||||
|
||||
type preparedItem struct {
|
||||
stmt *pg.Stmt
|
||||
stmt *pgconn.StatementDescription
|
||||
args [][]byte
|
||||
skipped uint32
|
||||
qc *qcode.QCode
|
||||
@ -75,7 +76,15 @@ func prepareStmt(key, gql string, varBytes json.RawMessage) error {
|
||||
return err
|
||||
}
|
||||
|
||||
pstmt, err := db.Prepare(finalSQL)
|
||||
ctx := context.Background()
|
||||
|
||||
tx, err := db.Begin(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
pstmt, err := tx.Prepare(ctx, "", finalSQL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -87,5 +96,9 @@ func prepareStmt(key, gql string, varBytes json.RawMessage) error {
|
||||
qc: qc,
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -92,9 +92,7 @@ func startHTTP() {
|
||||
}()
|
||||
|
||||
srv.RegisterOnShutdown(func() {
|
||||
if err := db.Close(); err != nil {
|
||||
logger.Error().Err(err).Msg("db closed")
|
||||
}
|
||||
db.Close()
|
||||
})
|
||||
|
||||
fmt.Printf("%s listening on %s (%s)\n", serverName, hostPort, conf.Env)
|
||||
|
Reference in New Issue
Block a user