From f518d5fc69c2d72534f05bc51ad5a6a36359fa16 Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Mon, 25 Nov 2019 02:22:33 -0500 Subject: [PATCH] Fix bug with compiling anon queries --- config/allow.list | 156 +------------------- config/dev.yml | 21 ++- docs/guide.md | 20 +-- psql/query.go | 6 +- qcode/fuzz.go | 2 + qcode/qcode.go | 1 + qcode/utils.go | 23 +++ serv/allow.go | 9 +- serv/args.go | 29 ++-- serv/auth_jwt.go | 4 +- serv/auth_rails.go | 18 +-- serv/cmd.go | 33 ++--- serv/cmd_conf.go | 4 +- serv/cmd_migrate.go | 38 +++-- serv/cmd_new.go | 4 +- serv/cmd_seed.go | 52 +++---- serv/cmd_serv.go | 4 +- serv/config.go | 32 ++-- serv/core.go | 352 +++++++++++--------------------------------- serv/core_build.go | 244 +++++++++++++++--------------- serv/core_remote.go | 197 +++++++++++++++++++++++++ serv/fuzz.go | 1 - serv/fuzz_test.go | 1 - serv/http.go | 7 +- serv/prepare.go | 158 ++++++++++++-------- serv/reload.go | 2 +- serv/reso.go | 2 +- serv/serv.go | 26 ++-- serv/utils.go | 13 -- tmpl/dev.yml | 20 +-- 30 files changed, 687 insertions(+), 792 deletions(-) create mode 100644 qcode/utils.go create mode 100644 serv/core_remote.go diff --git a/config/allow.list b/config/allow.list index aa1f6ab..5eec9c2 100644 --- a/config/allow.list +++ b/config/allow.list @@ -73,43 +73,6 @@ mutation { } } -variables { - "data": [ - { - "name": "Gumbo1", - "created_at": "now", - "updated_at": "now" - }, - { - "name": "Gumbo2", - "created_at": "now", - "updated_at": "now" - } - ] -} - -query { - products { - id - name - } -} - -variables { - "data": [ - { - "name": "Gumbo1", - "created_at": "now", - "updated_at": "now" - }, - { - "name": "Gumbo2", - "created_at": "now", - "updated_at": "now" - } - ] -} - query { products { id @@ -133,21 +96,6 @@ mutation { } } -variables { - "data": [ - { - "name": "Gumbo1", - "created_at": "now", - "updated_at": "now" - }, - { - "name": "Gumbo2", - "created_at": "now", - "updated_at": "now" - } - ] -} - query { products { id @@ -174,39 +122,6 @@ mutation { } } -variables { - "data": [ - { - "name": "Gumbo1", - "created_at": "now", - "updated_at": "now" - }, - { - "name": "Gumbo2", - "created_at": "now", - "updated_at": "now" - } - ] -} - -query { - products { - id - name - users { - email - } - } -} - -query { - me { - id - email - full_name - } -} - variables { "update": { "name": "Helloo", @@ -223,70 +138,6 @@ mutation { } } -variables { - "data": [ - { - "name": "Gumbo1", - "created_at": "now", - "updated_at": "now" - }, - { - "name": "Gumbo2", - "created_at": "now", - "updated_at": "now" - } - ] -} - -query { - product { - id - name - } -} - -variables { - "data": [ - { - "name": "Gumbo1", - "created_at": "now", - "updated_at": "now" - }, - { - "name": "Gumbo2", - "created_at": "now", - "updated_at": "now" - } - ] -} - -query { - products { - id - name - description - users { - email - } - } -} - -query { - users { - id - email - picture: avatar - password - full_name - products(limit: 2, where: {price: {gt: 10}}) { - id - name - description - price - } - } -} - variables { "data": { "name": "WOOO", @@ -301,4 +152,11 @@ mutation { } } +query { + products { + id + name + } +} + diff --git a/config/dev.yml b/config/dev.yml index bb468e0..b721b86 100644 --- a/config/dev.yml +++ b/config/dev.yml @@ -101,18 +101,14 @@ database: variables: admin_account_id: "5" - # Define defaults to for the field key and values below - defaults: - # filters: ["{ user_id: { eq: $user_id } }"] - - # Field and table names that you wish to block - blocklist: - - ar_internal_metadata - - schema_migrations - - secret - - password - - encrypted - - token + # Field and table names that you wish to block + blocklist: + - ar_internal_metadata + - schema_migrations + - secret + - password + - encrypted + - token tables: - name: customers @@ -140,6 +136,7 @@ roles_query: "SELECT * FROM users WHERE id = $user_id" roles: - name: anon tables: + - name: users - name: products limit: 10 diff --git a/docs/guide.md b/docs/guide.md index 59039f8..863ac77 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -1275,18 +1275,14 @@ database: variables: admin_account_id: "5" - # Define defaults to for the field key and values below - defaults: - # filters: ["{ user_id: { eq: $user_id } }"] - - # Field and table names that you wish to block - blocklist: - - ar_internal_metadata - - schema_migrations - - secret - - password - - encrypted - - token + # Field and table names that you wish to block + blocklist: + - ar_internal_metadata + - schema_migrations + - secret + - password + - encrypted + - token tables: - name: customers diff --git a/psql/query.go b/psql/query.go index 3bc8f0f..1d759b8 100644 --- a/psql/query.go +++ b/psql/query.go @@ -500,7 +500,7 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, var groupBy []int isRoot := sel.ParentID == -1 - isFil := sel.Where != nil + isFil := (sel.Where != nil && sel.Where.Op != qcode.OpNop) isSearch := sel.Args["search"] != nil isAgg := false @@ -880,6 +880,10 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable var col *DBColumn var ok bool + if ex.Op == qcode.OpNop { + return nil + } + if len(ex.Col) != 0 { if col, ok = ti.Columns[ex.Col]; !ok { return fmt.Errorf("no column '%s' found ", ex.Col) diff --git a/qcode/fuzz.go b/qcode/fuzz.go index 1ab9de6..630bba6 100644 --- a/qcode/fuzz.go +++ b/qcode/fuzz.go @@ -4,6 +4,8 @@ package qcode // FuzzerEntrypoint for Fuzzbuzz func Fuzz(data []byte) int { + GetQType(string(data)) + qcompile, _ := NewCompiler(Config{}) _, err := qcompile.Compile(data, "user") if err != nil { diff --git a/qcode/qcode.go b/qcode/qcode.go index 1d47560..cfb7869 100644 --- a/qcode/qcode.go +++ b/qcode/qcode.go @@ -20,6 +20,7 @@ const ( const ( QTQuery QType = iota + 1 + QTMutation QTInsert QTUpdate QTDelete diff --git a/qcode/utils.go b/qcode/utils.go new file mode 100644 index 0000000..346043a --- /dev/null +++ b/qcode/utils.go @@ -0,0 +1,23 @@ +package qcode + +func GetQType(gql string) QType { + for i := range gql { + b := gql[i] + if b == '{' { + return QTQuery + } + if al(b) { + switch b { + case 'm', 'M': + return QTMutation + case 'q', 'Q': + return QTQuery + } + } + } + return -1 +} + +func al(b byte) bool { + return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9') +} diff --git a/serv/allow.go b/serv/allow.go index b238580..3ce2b83 100644 --- a/serv/allow.go +++ b/serv/allow.go @@ -46,7 +46,7 @@ func initAllowList(cpath string) { if _, err := os.Stat(fp); err == nil { _allowList.filepath = fp } else if !os.IsNotExist(err) { - logger.Fatal().Err(err).Send() + errlog.Fatal().Err(err).Send() } } @@ -56,7 +56,7 @@ func initAllowList(cpath string) { if _, err := os.Stat(fp); err == nil { _allowList.filepath = fp } else if !os.IsNotExist(err) { - logger.Fatal().Err(err).Send() + errlog.Fatal().Err(err).Send() } } @@ -66,13 +66,13 @@ func initAllowList(cpath string) { if _, err := os.Stat(fp); err == nil { _allowList.filepath = fp } else if !os.IsNotExist(err) { - logger.Fatal().Err(err).Send() + errlog.Fatal().Err(err).Send() } } if len(_allowList.filepath) == 0 { if conf.Production { - logger.Fatal().Msg("allow.list not found") + errlog.Fatal().Msg("allow.list not found") } if len(cpath) == 0 { @@ -187,7 +187,6 @@ func (al *allowList) load() { item.gql = q item.vars = varBytes } - varBytes = nil } else if ty == AL_VARS { diff --git a/serv/args.go b/serv/args.go index a0f6434..41b8b0d 100644 --- a/serv/args.go +++ b/serv/args.go @@ -2,44 +2,46 @@ package serv import ( "bytes" + "context" + "errors" "fmt" "io" "github.com/dosco/super-graph/jsn" ) -func argMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) { +func argMap(ctx context.Context, vars []byte) func(w io.Writer, tag string) (int, error) { return func(w io.Writer, tag string) (int, error) { switch tag { case "user_id_provider": if v := ctx.Value(userIDProviderKey); v != nil { return io.WriteString(w, v.(string)) } - return io.WriteString(w, "null") + return 0, errors.New("query requires variable $user_id_provider") case "user_id": if v := ctx.Value(userIDKey); v != nil { return io.WriteString(w, v.(string)) } - return io.WriteString(w, "null") + return 0, errors.New("query requires variable $user_id") case "user_role": if v := ctx.Value(userRoleKey); v != nil { return io.WriteString(w, v.(string)) } - return io.WriteString(w, "null") + return 0, errors.New("query requires variable $user_role") } - fields := jsn.Get(ctx.req.Vars, [][]byte{[]byte(tag)}) + fields := jsn.Get(vars, [][]byte{[]byte(tag)}) if len(fields) == 0 { - return 0, fmt.Errorf("variable '%s' not found", tag) + return 0, nil } return w.Write(fields[0].Value) } } -func argList(ctx *coreContext, args [][]byte) []interface{} { +func argList(ctx *coreContext, args [][]byte) ([]interface{}, error) { vars := make([]interface{}, len(args)) var fields map[string]interface{} @@ -49,7 +51,7 @@ func argList(ctx *coreContext, args [][]byte) []interface{} { fields, _, err = jsn.Tree(ctx.req.Vars) if err != nil { - logger.Warn().Err(err).Msg("Failed to parse variables") + return nil, err } } @@ -60,24 +62,33 @@ func argList(ctx *coreContext, args [][]byte) []interface{} { case bytes.Equal(av, []byte("user_id")): if v := ctx.Value(userIDKey); v != nil { vars[i] = v.(string) + } else { + return nil, errors.New("query requires variable $user_id") } case bytes.Equal(av, []byte("user_id_provider")): if v := ctx.Value(userIDProviderKey); v != nil { vars[i] = v.(string) + } else { + return nil, errors.New("query requires variable $user_id_provider") } case bytes.Equal(av, []byte("user_role")): if v := ctx.Value(userRoleKey); v != nil { vars[i] = v.(string) + } else { + return nil, errors.New("query requires variable $user_role") } default: if v, ok := fields[string(av)]; ok { vars[i] = v + } else { + return nil, fmt.Errorf("query requires variable $%s", string(av)) + } } } - return vars + return vars, nil } diff --git a/serv/auth_jwt.go b/serv/auth_jwt.go index d7041a2..326fd12 100644 --- a/serv/auth_jwt.go +++ b/serv/auth_jwt.go @@ -35,7 +35,7 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc { case len(publicKeyFile) != 0: kd, err := ioutil.ReadFile(publicKeyFile) if err != nil { - logger.Fatal().Err(err).Send() + errlog.Fatal().Err(err).Send() } switch conf.Auth.JWT.PubKeyType { @@ -51,7 +51,7 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc { } if err != nil { - logger.Fatal().Err(err).Send() + errlog.Fatal().Err(err).Send() } } diff --git a/serv/auth_rails.go b/serv/auth_rails.go index 4b71d75..0673f9a 100644 --- a/serv/auth_rails.go +++ b/serv/auth_rails.go @@ -15,11 +15,11 @@ import ( func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc { cookie := conf.Auth.Cookie if len(cookie) == 0 { - logger.Fatal().Msg("no auth.cookie defined") + errlog.Fatal().Msg("no auth.cookie defined") } if len(conf.Auth.Rails.URL) == 0 { - logger.Fatal().Msg("no auth.rails.url defined") + errlog.Fatal().Msg("no auth.rails.url defined") } rp := &redis.Pool{ @@ -28,13 +28,13 @@ func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc { Dial: func() (redis.Conn, error) { c, err := redis.DialURL(conf.Auth.Rails.URL) if err != nil { - logger.Fatal().Err(err).Send() + errlog.Fatal().Err(err).Send() } pwd := conf.Auth.Rails.Password if len(pwd) != 0 { if _, err := c.Do("AUTH", pwd); err != nil { - logger.Fatal().Err(err).Send() + errlog.Fatal().Err(err).Send() } } return c, err @@ -69,16 +69,16 @@ func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc { func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc { cookie := conf.Auth.Cookie if len(cookie) == 0 { - logger.Fatal().Msg("no auth.cookie defined") + errlog.Fatal().Msg("no auth.cookie defined") } if len(conf.Auth.Rails.URL) == 0 { - logger.Fatal().Msg("no auth.rails.url defined") + errlog.Fatal().Msg("no auth.rails.url defined") } rURL, err := url.Parse(conf.Auth.Rails.URL) if err != nil { - logger.Fatal().Err(err).Send() + errlog.Fatal().Err(err).Send() } mc := memcache.New(rURL.Host) @@ -111,12 +111,12 @@ func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc { func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc { cookie := conf.Auth.Cookie if len(cookie) == 0 { - logger.Fatal().Msg("no auth.cookie defined") + errlog.Fatal().Msg("no auth.cookie defined") } ra, err := railsAuth(conf) if err != nil { - logger.Fatal().Err(err).Send() + errlog.Fatal().Err(err).Send() } return func(w http.ResponseWriter, r *http.Request) { diff --git a/serv/cmd.go b/serv/cmd.go index f4dee00..c603d9f 100644 --- a/serv/cmd.go +++ b/serv/cmd.go @@ -22,16 +22,18 @@ const ( ) var ( - logger *zerolog.Logger + logger zerolog.Logger + errlog zerolog.Logger conf *config confPath string db *pgxpool.Pool + schema *psql.DBSchema qcompile *qcode.Compiler pcompile *psql.Compiler ) func Init() { - logger = initLog() + initLog() rootCmd := &cobra.Command{ Use: "super-graph", @@ -135,19 +137,14 @@ e.g. db:migrate -+1 "path", "./config", "path to config files") if err := rootCmd.Execute(); err != nil { - logger.Fatal().Err(err).Send() + errlog.Fatal().Err(err).Send() } } -func initLog() *zerolog.Logger { +func initLog() { out := zerolog.ConsoleWriter{Out: os.Stderr} - logger := zerolog.New(out). - With(). - Timestamp(). - Caller(). - Logger() - - return &logger + logger = zerolog.New(out).With().Timestamp().Logger() + errlog = logger.With().Caller().Logger() } func initConf() (*config, error) { @@ -166,7 +163,7 @@ func initConf() (*config, error) { } if vi.IsSet("inherits") { - logger.Fatal().Msgf("inherited config (%s) cannot itself inherit (%s)", + errlog.Fatal().Msgf("inherited config (%s) cannot itself inherit (%s)", inherits, vi.GetString("inherits")) } @@ -183,7 +180,7 @@ func initConf() (*config, error) { logLevel, err := zerolog.ParseLevel(c.LogLevel) if err != nil { - logger.Error().Err(err).Msg("error setting log_level") + errlog.Error().Err(err).Msg("error setting log_level") } zerolog.SetGlobalLevel(logLevel) @@ -218,7 +215,7 @@ func initDB(c *config, useDB bool) (*pgx.Conn, error) { config.LogLevel = pgx.LogLevelNone } - config.Logger = NewSQLLogger(*logger) + config.Logger = NewSQLLogger(logger) db, err := pgx.ConnectConfig(context.Background(), config) if err != nil { @@ -253,7 +250,7 @@ func initDBPool(c *config) (*pgxpool.Pool, error) { config.ConnConfig.LogLevel = pgx.LogLevelNone } - config.ConnConfig.Logger = NewSQLLogger(*logger) + config.ConnConfig.Logger = NewSQLLogger(logger) // if c.DB.MaxRetries != 0 { // opt.MaxRetries = c.DB.MaxRetries @@ -276,11 +273,11 @@ func initCompiler() { qcompile, pcompile, err = initCompilers(conf) if err != nil { - logger.Fatal().Err(err).Msg("failed to initialize compilers") + errlog.Fatal().Err(err).Msg("failed to initialize compilers") } if err := initResolvers(); err != nil { - logger.Fatal().Err(err).Msg("failed to initialized resolvers") + errlog.Fatal().Err(err).Msg("failed to initialized resolvers") } } @@ -289,7 +286,7 @@ func initConfOnce() { if conf == nil { if conf, err = initConf(); err != nil { - logger.Fatal().Err(err).Msg("failed to read config") + errlog.Fatal().Err(err).Msg("failed to read config") } } } diff --git a/serv/cmd_conf.go b/serv/cmd_conf.go index 13809e2..2b93bef 100644 --- a/serv/cmd_conf.go +++ b/serv/cmd_conf.go @@ -17,11 +17,11 @@ func cmdConfDump(cmd *cobra.Command, args []string) { conf, err := initConf() if err != nil { - logger.Fatal().Err(err).Msg("failed to read config") + errlog.Fatal().Err(err).Msg("failed to read config") } if err := conf.Viper.WriteConfigAs(fname); err != nil { - logger.Fatal().Err(err).Send() + errlog.Fatal().Err(err).Send() } logger.Info().Msgf("config dumped to ./%s", fname) diff --git a/serv/cmd_migrate.go b/serv/cmd_migrate.go index a14091d..2a7b030 100644 --- a/serv/cmd_migrate.go +++ b/serv/cmd_migrate.go @@ -49,7 +49,7 @@ func cmdDBSetup(cmd *cobra.Command, args []string) { } if os.IsNotExist(err) == false { - logger.Fatal().Err(err).Msgf("unable to check if '%s' exists", sfile) + errlog.Fatal().Err(err).Msgf("unable to check if '%s' exists", sfile) } logger.Warn().Msgf("failed to read seed file '%s'", sfile) @@ -59,7 +59,7 @@ func cmdDBReset(cmd *cobra.Command, args []string) { initConfOnce() if conf.Production { - logger.Fatal().Msg("db:reset does not work in production") + errlog.Fatal().Msg("db:reset does not work in production") return } cmdDBDrop(cmd, []string{}) @@ -72,7 +72,7 @@ func cmdDBCreate(cmd *cobra.Command, args []string) { conn, err := initDB(conf, false) if err != nil { - logger.Fatal().Err(err).Msg("failed to connect to database") + errlog.Fatal().Err(err).Msg("failed to connect to database") } defer conn.Close(ctx) @@ -80,7 +80,7 @@ func cmdDBCreate(cmd *cobra.Command, args []string) { _, err = conn.Exec(ctx, sql) if err != nil { - logger.Fatal().Err(err).Msg("failed to create database") + errlog.Fatal().Err(err).Msg("failed to create database") } logger.Info().Msgf("created database '%s'", conf.DB.DBName) @@ -92,7 +92,7 @@ func cmdDBDrop(cmd *cobra.Command, args []string) { conn, err := initDB(conf, false) if err != nil { - logger.Fatal().Err(err).Msg("failed to connect to database") + errlog.Fatal().Err(err).Msg("failed to connect to database") } defer conn.Close(ctx) @@ -100,7 +100,7 @@ func cmdDBDrop(cmd *cobra.Command, args []string) { _, err = conn.Exec(ctx, sql) if err != nil { - logger.Fatal().Err(err).Msg("failed to create database") + errlog.Fatal().Err(err).Msg("failed to create database") } logger.Info().Msgf("dropped database '%s'", conf.DB.DBName) @@ -151,24 +151,24 @@ func cmdDBMigrate(cmd *cobra.Command, args []string) { conn, err := initDB(conf, true) if err != nil { - logger.Fatal().Err(err).Msg("failed to connect to database") + errlog.Fatal().Err(err).Msg("failed to connect to database") } defer conn.Close(context.Background()) m, err := migrate.NewMigrator(conn, "schema_version") if err != nil { - logger.Fatal().Err(err).Msg("failed to initializing migrator") + errlog.Fatal().Err(err).Msg("failed to initializing migrator") } m.Data = getMigrationVars() err = m.LoadMigrations(conf.MigrationsPath) if err != nil { - logger.Fatal().Err(err).Msg("failed to load migrations") + errlog.Fatal().Err(err).Msg("failed to load migrations") } if len(m.Migrations) == 0 { - logger.Fatal().Msg("No migrations found") + errlog.Fatal().Msg("No migrations found") } m.OnStart = func(sequence int32, name, direction, sql string) { @@ -187,7 +187,7 @@ func cmdDBMigrate(cmd *cobra.Command, args []string) { var n int64 n, err = strconv.ParseInt(d, 10, 32) if err != nil { - logger.Fatal().Err(err).Msg("invalid destination") + errlog.Fatal().Err(err).Msg("invalid destination") } return int32(n) } @@ -218,17 +218,15 @@ func cmdDBMigrate(cmd *cobra.Command, args []string) { if err != nil { logger.Info().Err(err).Send() - // logger.Info().Err(err).Send() - // if err, ok := err.(m.MigrationPgError); ok { // if err.Detail != "" { - // logger.Info().Err(err).Msg(err.Detail) + // info.Err(err).Msg(err.Detail) // } // if err.Position != 0 { // ele, err := ExtractErrorLine(err.Sql, int(err.Position)) // if err != nil { - // logger.Fatal().Err(err).Send() + // errlog.Fatal().Err(err).Send() // } // prefix := fmt.Sprintf() @@ -247,29 +245,29 @@ func cmdDBStatus(cmd *cobra.Command, args []string) { conn, err := initDB(conf, true) if err != nil { - logger.Fatal().Err(err).Msg("failed to connect to database") + errlog.Fatal().Err(err).Msg("failed to connect to database") } defer conn.Close(context.Background()) m, err := migrate.NewMigrator(conn, "schema_version") if err != nil { - logger.Fatal().Err(err).Msg("failed to initialize migrator") + errlog.Fatal().Err(err).Msg("failed to initialize migrator") } m.Data = getMigrationVars() err = m.LoadMigrations(conf.MigrationsPath) if err != nil { - logger.Fatal().Err(err).Msg("failed to load migrations") + errlog.Fatal().Err(err).Msg("failed to load migrations") } if len(m.Migrations) == 0 { - logger.Fatal().Msg("no migrations found") + errlog.Fatal().Msg("no migrations found") } mver, err := m.GetCurrentVersion() if err != nil { - logger.Fatal().Err(err).Msg("failed to retrieve migration") + errlog.Fatal().Err(err).Msg("failed to retrieve migration") } var status string diff --git a/serv/cmd_new.go b/serv/cmd_new.go index 44a851c..3329c7b 100644 --- a/serv/cmd_new.go +++ b/serv/cmd_new.go @@ -134,12 +134,12 @@ func ifNotExists(filePath string, doFn func(string) error) { } if os.IsNotExist(err) == false { - logger.Fatal().Err(err).Msgf("unable to check if '%s' exists", filePath) + errlog.Fatal().Err(err).Msgf("unable to check if '%s' exists", filePath) } err = doFn(filePath) if err != nil { - logger.Fatal().Err(err).Msgf("unable to create '%s'", filePath) + errlog.Fatal().Err(err).Msgf("unable to create '%s'", filePath) } logger.Info().Msgf("created '%s'", filePath) diff --git a/serv/cmd_seed.go b/serv/cmd_seed.go index 95ba3fb..62682a8 100644 --- a/serv/cmd_seed.go +++ b/serv/cmd_seed.go @@ -20,14 +20,14 @@ func cmdDBSeed(cmd *cobra.Command, args []string) { var err error if conf, err = initConf(); err != nil { - logger.Fatal().Err(err).Msg("failed to read config") + errlog.Fatal().Err(err).Msg("failed to read config") } conf.Production = false db, err = initDBPool(conf) if err != nil { - logger.Fatal().Err(err).Msg("failed to connect to database") + errlog.Fatal().Err(err).Msg("failed to connect to database") } initCompiler() @@ -36,7 +36,7 @@ func cmdDBSeed(cmd *cobra.Command, args []string) { b, err := ioutil.ReadFile(path.Join(confPath, conf.SeedFile)) if err != nil { - logger.Fatal().Err(err).Msgf("failed to read seed file '%s'", sfile) + errlog.Fatal().Err(err).Msgf("failed to read seed file '%s'", sfile) } vm := goja.New() @@ -52,7 +52,7 @@ func cmdDBSeed(cmd *cobra.Command, args []string) { _, err = vm.RunScript("seed.js", string(b)) if err != nil { - logger.Fatal().Err(err).Msg("failed to execute script") + errlog.Fatal().Err(err).Msg("failed to execute script") } logger.Info().Msg("seed script done") @@ -60,15 +60,15 @@ func cmdDBSeed(cmd *cobra.Command, args []string) { //func runFunc(call goja.FunctionCall) { func graphQLFunc(query string, data interface{}, opt map[string]string) map[string]interface{} { - b, err := json.Marshal(data) + vars, err := json.Marshal(data) if err != nil { - logger.Fatal().Err(err).Send() + errlog.Fatal().Err(err).Send() } - ctx := context.Background() + c := context.Background() if v, ok := opt["user_id"]; ok && len(v) != 0 { - ctx = context.WithValue(ctx, userIDKey, v) + c = context.WithValue(c, userIDKey, v) } var role string @@ -79,62 +79,50 @@ func graphQLFunc(query string, data interface{}, opt map[string]string) map[stri role = "user" } - c := &coreContext{Context: ctx} - c.req.Query = query - c.req.Vars = b - - st, err := c.buildStmtByRole(role) + stmts, err := buildRoleStmt([]byte(query), vars, role) if err != nil { - logger.Fatal().Err(err).Msg("graphql query failed") + errlog.Fatal().Err(err).Msg("graphql query failed") } + st := stmts[0] buf := &bytes.Buffer{} t := fasttemplate.New(st.sql, openVar, closeVar) - _, err = t.ExecuteFunc(buf, argMap(c)) - - if err == errNoUserID { - logger.Fatal().Err(err).Msg("query requires a user_id") - } + _, err = t.ExecuteFunc(buf, argMap(c, vars)) if err != nil { - logger.Fatal().Err(err).Send() + errlog.Fatal().Err(err).Send() } finalSQL := buf.String() tx, err := db.Begin(c) if err != nil { - logger.Fatal().Err(err).Send() + errlog.Fatal().Err(err).Send() } defer tx.Rollback(c) if conf.DB.SetUserID { - if err := c.setLocalUserID(tx); err != nil { - logger.Fatal().Err(err).Send() + if err := setLocalUserID(c, tx); err != nil { + errlog.Fatal().Err(err).Send() } } var root []byte if err = tx.QueryRow(c, finalSQL).Scan(&root); err != nil { - logger.Fatal().Err(err).Msg("sql query failed") + errlog.Fatal().Err(err).Msg("sql query failed") } if err := tx.Commit(c); err != nil { - logger.Fatal().Err(err).Send() - } - - res, err := c.execRemoteJoin(st.qc, st.skipped, root) - if err != nil { - logger.Fatal().Err(err).Send() + errlog.Fatal().Err(err).Send() } val := make(map[string]interface{}) - err = json.Unmarshal(res, &val) + err = json.Unmarshal(root, &val) if err != nil { - logger.Fatal().Err(err).Send() + errlog.Fatal().Err(err).Send() } return val diff --git a/serv/cmd_serv.go b/serv/cmd_serv.go index c5718db..b4fba9b 100644 --- a/serv/cmd_serv.go +++ b/serv/cmd_serv.go @@ -8,12 +8,12 @@ func cmdServ(cmd *cobra.Command, args []string) { var err error if conf, err = initConf(); err != nil { - logger.Fatal().Err(err).Msg("failed to read config") + errlog.Fatal().Err(err).Msg("failed to read config") } db, err = initDBPool(conf) if err != nil { - logger.Fatal().Err(err).Msg("failed to connect to database") + errlog.Fatal().Err(err).Msg("failed to connect to database") } initCompiler() diff --git a/serv/config.go b/serv/config.go index 698d25b..df1b6e3 100644 --- a/serv/config.go +++ b/serv/config.go @@ -68,12 +68,8 @@ type config struct { MaxRetries int `mapstructure:"max_retries"` SetUserID bool `mapstructure:"set_user_id"` - Vars map[string]string `mapstructure:"variables"` - - Defaults struct { - Filters []string - Blocklist []string - } + Vars map[string]string `mapstructure:"variables"` + Blocklist []string Tables []configTable } `mapstructure:"database"` @@ -82,6 +78,7 @@ type config struct { RolesQuery string `mapstructure:"roles_query"` Roles []configRole + roles map[string]*configRole } type configTable struct { @@ -220,16 +217,15 @@ func (c *config) Init(vi *viper.Viper) error { } c.RolesQuery = sanitize(c.RolesQuery) - - rolesMap := make(map[string]struct{}) + c.roles = make(map[string]*configRole) for i := range c.Roles { role := &c.Roles[i] - if _, ok := rolesMap[role.Name]; ok { - logger.Fatal().Msgf("duplicate role '%s' found", role.Name) + if _, ok := c.roles[role.Name]; ok { + errlog.Fatal().Msgf("duplicate role '%s' found", role.Name) } - role.Name = sanitize(role.Name) + role.Name = strings.ToLower(role.Name) role.Match = sanitize(role.Match) role.tablesMap = make(map[string]*configRoleTable) @@ -237,14 +233,16 @@ func (c *config) Init(vi *viper.Viper) error { role.tablesMap[table.Name] = &role.Tables[n] } - rolesMap[role.Name] = struct{}{} + c.roles[role.Name] = role } - if _, ok := rolesMap["user"]; !ok { - c.Roles = append(c.Roles, configRole{Name: "user"}) + if _, ok := c.roles["user"]; !ok { + u := configRole{Name: "user"} + c.Roles = append(c.Roles, u) + c.roles["user"] = &u } - if _, ok := rolesMap["anon"]; !ok { + if _, ok := c.roles["anon"]; !ok { logger.Warn().Msg("unauthenticated requests will be blocked. no role 'anon' defined") c.AuthFailBlock = true } @@ -261,7 +259,7 @@ func (c *config) validate() { name := c.Roles[i].Name if _, ok := rm[name]; ok { - logger.Fatal().Msgf("duplicate config for role '%s'", c.Roles[i].Name) + errlog.Fatal().Msgf("duplicate config for role '%s'", c.Roles[i].Name) } rm[name] = struct{}{} } @@ -272,7 +270,7 @@ func (c *config) validate() { name := c.Tables[i].Name if _, ok := tm[name]; ok { - logger.Fatal().Msgf("duplicate config for table '%s'", c.Tables[i].Name) + errlog.Fatal().Msgf("duplicate config for table '%s'", c.Tables[i].Name) } tm[name] = struct{}{} } diff --git a/serv/core.go b/serv/core.go index cbbbfd4..88a4206 100644 --- a/serv/core.go +++ b/serv/core.go @@ -8,11 +8,9 @@ import ( "fmt" "io" "net/http" - "sync" "time" "github.com/cespare/xxhash/v2" - "github.com/dosco/super-graph/jsn" "github.com/dosco/super-graph/qcode" "github.com/jackc/pgx/v4" "github.com/valyala/fasttemplate" @@ -32,6 +30,10 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error { c.req.ref = req.Referer() c.req.hdr = req.Header + if len(c.req.Vars) == 2 { + c.req.Vars = nil + } + if authCheck(c) { c.req.role = "user" } else { @@ -47,83 +49,38 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error { } func (c *coreContext) execQuery() ([]byte, error) { - var err error - var skipped uint32 - var qc *qcode.QCode var data []byte + var st *stmt + var err error if conf.Production { - var ps *preparedItem - - data, ps, err = c.resolvePreparedSQL() + data, st, err = c.resolvePreparedSQL() if err != nil { - return nil, err - } + logger.Error(). + Err(err). + Str("default_role", c.req.role). + Msg(c.req.Query) - skipped = ps.skipped - qc = ps.qc + return nil, errors.New("query failed. check logs for error") + } } else { - - data, skipped, err = c.resolveSQL() - if err != nil { + if data, st, err = c.resolveSQL(); err != nil { return nil, err } } - return c.execRemoteJoin(qc, skipped, data) + return execRemoteJoin(st, data, c.req.hdr) } -func (c *coreContext) execRemoteJoin(qc *qcode.QCode, skipped uint32, data []byte) ([]byte, error) { - var err error - - if len(data) == 0 || skipped == 0 { - return data, nil - } - - sel := qc.Selects - h := xxhash.New() - - // fetch the field name used within the db response json - // that are used to mark insertion points and the mapping between - // those field names and their select objects - fids, sfmap := parentFieldIds(h, sel, skipped) - - // fetch the field values of the marked insertion points - // these values contain the id to be used with fetching remote data - from := jsn.Get(data, fids) - - var to []jsn.Field - switch { - case len(from) == 1: - to, err = c.resolveRemote(c.req.hdr, h, from[0], sel, sfmap) - - case len(from) > 1: - to, err = c.resolveRemotes(c.req.hdr, h, from, sel, sfmap) - - default: - return nil, errors.New("something wrong no remote ids found in db response") - } - - if err != nil { - return nil, err - } - - var ob bytes.Buffer - - err = jsn.Replace(&ob, data, from, to) - if err != nil { - return nil, err - } - - return ob.Bytes(), nil -} - -func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) { +func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) { var tx pgx.Tx var err error - mutation := isMutation(c.req.Query) + qt := qcode.GetQType(c.req.Query) + mutation := (qt == qcode.QTMutation) + anonQuery := (qt == qcode.QTQuery && c.req.role == "anon") + useRoleQuery := len(conf.RolesQuery) != 0 && mutation useTx := useRoleQuery || conf.DB.SetUserID @@ -135,7 +92,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) { } if conf.DB.SetUserID { - if err := c.setLocalUserID(tx); err != nil { + if err := setLocalUserID(c, tx); err != nil { return nil, nil, err } } @@ -150,7 +107,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) { } else if v := c.Value(userRoleKey); v != nil { role = v.(string) - } else if mutation { + } else { role = c.req.role } @@ -162,21 +119,29 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) { var root []byte var row pgx.Row - vars := argList(c, ps.args) - if useTx { - row = tx.QueryRow(c, ps.stmt.SQL, vars...) - } else { - row = db.QueryRow(c, ps.stmt.SQL, vars...) + vars, err := argList(c, ps.args) + if err != nil { + return nil, nil, err } - if mutation { + if useTx { + row = tx.QueryRow(c, ps.sd.SQL, vars...) + } else { + row = db.QueryRow(c, ps.sd.SQL, vars...) + } + + if mutation || anonQuery { err = row.Scan(&root) } else { err = row.Scan(&role, &root) } - logger.Debug().Str("default_role", c.req.role).Str("role", role).Msg(c.req.Query) + if len(role) == 0 { + logger.Debug().Str("default_role", c.req.role).Msg(c.req.Query) + } else { + logger.Debug().Str("default_role", c.req.role).Str("role", role).Msg(c.req.Query) + } if err != nil { return nil, nil, err @@ -190,65 +155,55 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) { } } - return root, ps, nil + return root, ps.st, nil } -func (c *coreContext) resolveSQL() ([]byte, uint32, error) { +func (c *coreContext) resolveSQL() ([]byte, *stmt, error) { var tx pgx.Tx var err error - mutation := isMutation(c.req.Query) + qt := qcode.GetQType(c.req.Query) + mutation := (qt == qcode.QTMutation) + //anonQuery := (qt == qcode.QTQuery && c.req.role == "anon") + useRoleQuery := len(conf.RolesQuery) != 0 && mutation useTx := useRoleQuery || conf.DB.SetUserID if useTx { if tx, err = db.Begin(c); err != nil { - return nil, 0, err + return nil, nil, err } defer tx.Rollback(c) } if conf.DB.SetUserID { - if err := c.setLocalUserID(tx); err != nil { - return nil, 0, err + if err := setLocalUserID(c, tx); err != nil { + return nil, nil, err } } if useRoleQuery { if c.req.role, err = c.executeRoleQuery(tx); err != nil { - return nil, 0, err + return nil, nil, err } } else if v := c.Value(userRoleKey); v != nil { c.req.role = v.(string) } - stmts, err := c.buildStmt() + stmts, err := buildStmt(qt, []byte(c.req.Query), c.req.Vars, c.req.role) if err != nil { - return nil, 0, err - } - - var st *stmt - - if mutation { - st = findStmt(c.req.role, stmts) - } else { - st = &stmts[0] + return nil, nil, err } + st := &stmts[0] t := fasttemplate.New(st.sql, openVar, closeVar) - buf := &bytes.Buffer{} - _, err = t.ExecuteFunc(buf, argMap(c)) - - if err == errNoUserID { - logger.Warn().Msg("no user id found. query requires an authenicated request") - } + _, err = t.ExecuteFunc(buf, argMap(c, c.req.Vars)) if err != nil { - return nil, 0, err + return nil, nil, err } - finalSQL := buf.String() var stime time.Time @@ -258,195 +213,56 @@ func (c *coreContext) resolveSQL() ([]byte, uint32, error) { } var root []byte - var role, defaultRole string + var role string var row pgx.Row + defaultRole := c.req.role + if useTx { row = tx.QueryRow(c, finalSQL) } else { row = db.QueryRow(c, finalSQL) } - if mutation { + if len(stmts) == 1 { err = row.Scan(&root) - } else { err = row.Scan(&role, &root) - defaultRole = c.req.role - c.req.role = role - } - logger.Debug().Str("default_role", defaultRole).Str("role", role).Msg(c.req.Query) + if len(role) == 0 { + logger.Debug().Str("default_role", defaultRole).Msg(c.req.Query) + } else { + logger.Debug().Str("default_role", defaultRole).Str("role", role).Msg(c.req.Query) + } if err != nil { - return nil, 0, err + return nil, nil, err } if useTx { if err := tx.Commit(c); err != nil { - return nil, 0, err + return nil, nil, err } } - if conf.EnableTracing && len(st.qc.Selects) != 0 { + // if conf.Production == false { + // _allowList.add(&c.req) + // } + + if len(stmts) > 1 { + if st = findStmt(role, stmts); st == nil { + return nil, nil, fmt.Errorf("invalid role '%s' returned", role) + } + } + + if conf.EnableTracing { for _, id := range st.qc.Roots { c.addTrace(st.qc.Selects, id, stime) } } - if conf.Production == false { - _allowList.add(&c.req) - } - - return root, st.skipped, nil -} - -func (c *coreContext) resolveRemote( - hdr http.Header, - h *xxhash.Digest, - field jsn.Field, - sel []qcode.Select, - sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) { - - // replacement data for the marked insertion points - // key and value will be replaced by whats below - toA := [1]jsn.Field{} - to := toA[:1] - - // use the json key to find the related Select object - k1 := xxhash.Sum64(field.Key) - - s, ok := sfmap[k1] - if !ok { - return nil, nil - } - p := sel[s.ParentID] - - // then use the Table nme in the Select and it's parent - // to find the resolver to use for this relationship - k2 := mkkey(h, s.Table, p.Table) - - r, ok := rmap[k2] - if !ok { - return nil, nil - } - - id := jsn.Value(field.Value) - if len(id) == 0 { - return nil, nil - } - - st := time.Now() - - b, err := r.Fn(hdr, id) - if err != nil { - return nil, err - } - - if conf.EnableTracing { - c.addTrace(sel, s.ID, st) - } - - if len(r.Path) != 0 { - b = jsn.Strip(b, r.Path) - } - - var ob bytes.Buffer - - if len(s.Cols) != 0 { - err = jsn.Filter(&ob, b, colsToList(s.Cols)) - if err != nil { - return nil, err - } - - } else { - ob.WriteString("null") - } - - to[0] = jsn.Field{Key: []byte(s.FieldName), Value: ob.Bytes()} - return to, nil -} - -func (c *coreContext) resolveRemotes( - hdr http.Header, - h *xxhash.Digest, - from []jsn.Field, - sel []qcode.Select, - sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) { - - // replacement data for the marked insertion points - // key and value will be replaced by whats below - to := make([]jsn.Field, len(from)) - - var wg sync.WaitGroup - wg.Add(len(from)) - - var cerr error - - for i, id := range from { - - // use the json key to find the related Select object - k1 := xxhash.Sum64(id.Key) - - s, ok := sfmap[k1] - if !ok { - return nil, nil - } - p := sel[s.ParentID] - - // then use the Table nme in the Select and it's parent - // to find the resolver to use for this relationship - k2 := mkkey(h, s.Table, p.Table) - - r, ok := rmap[k2] - if !ok { - return nil, nil - } - - id := jsn.Value(id.Value) - if len(id) == 0 { - return nil, nil - } - - go func(n int, id []byte, s *qcode.Select) { - defer wg.Done() - - st := time.Now() - - b, err := r.Fn(hdr, id) - if err != nil { - cerr = fmt.Errorf("%s: %s", s.Table, err) - return - } - - if conf.EnableTracing { - c.addTrace(sel, s.ID, st) - } - - if len(r.Path) != 0 { - b = jsn.Strip(b, r.Path) - } - - var ob bytes.Buffer - - if len(s.Cols) != 0 { - err = jsn.Filter(&ob, b, colsToList(s.Cols)) - if err != nil { - cerr = fmt.Errorf("%s: %s", s.Table, err) - return - } - - } else { - ob.WriteString("null") - } - - to[n] = jsn.Field{Key: []byte(s.FieldName), Value: ob.Bytes()} - }(i, id, s) - } - wg.Wait() - - return to, cerr + return root, st, nil } func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) { @@ -460,15 +276,6 @@ func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) { return role, nil } -func (c *coreContext) setLocalUserID(tx pgx.Tx) error { - var err error - if v := c.Value(userIDKey); v != nil { - _, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v)) - } - - return err -} - func (c *coreContext) render(w io.Writer, data []byte) error { c.res.Data = json.RawMessage(data) return json.NewEncoder(w).Encode(c.res) @@ -560,6 +367,15 @@ func parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) ( return fm, sm } +func setLocalUserID(c context.Context, tx pgx.Tx) error { + var err error + if v := c.Value(userIDKey); v != nil { + _, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v)) + } + + return err +} + func isSkipped(n uint32, pos uint32) bool { return ((n & (1 << pos)) != 0) } diff --git a/serv/core_build.go b/serv/core_build.go index 89d0d55..8d6d90c 100644 --- a/serv/core_build.go +++ b/serv/core_build.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/json" "errors" + "fmt" "io" "github.com/dosco/super-graph/psql" @@ -17,172 +18,171 @@ type stmt struct { sql string } -func (c *coreContext) buildStmt() ([]stmt, error) { - var vars map[string]json.RawMessage +func buildStmt(qt qcode.QType, gql, vars []byte, role string) ([]stmt, error) { + switch qt { + case qcode.QTMutation: + return buildRoleStmt(gql, vars, role) - if len(c.req.Vars) != 0 { - if err := json.Unmarshal(c.req.Vars, &vars); err != nil { + case qcode.QTQuery: + switch { + case role == "anon": + return buildRoleStmt(gql, vars, role) + + default: + return buildMultiStmt(gql, vars) + } + + default: + return nil, fmt.Errorf("unknown query type '%d'", qt) + } +} + +func buildRoleStmt(gql, vars []byte, role string) ([]stmt, error) { + ro, ok := conf.roles[role] + if !ok { + return nil, fmt.Errorf(`roles '%s' not defined in config`, role) + } + + var vm map[string]json.RawMessage + var err error + + if len(vars) != 0 { + if err := json.Unmarshal(vars, &vm); err != nil { return nil, err } } - gql := []byte(c.req.Query) - - if len(conf.Roles) == 0 { - return nil, errors.New(`no roles found ('user' and 'anon' required)`) - } - - qc, err := qcompile.Compile(gql, conf.Roles[0].Name) + qc, err := qcompile.Compile(gql, ro.Name) if err != nil { return nil, err } - stmts := make([]stmt, 0, len(conf.Roles)) - mutation := (qc.Type != qcode.QTQuery) + // For the 'anon' role in production only compile + // queries for tables defined in the config file. + if conf.Production && + ro.Name == "anon" && + hasTablesWithConfig(qc, ro) == false { + return nil, errors.New("query contains tables with no 'anon' role config") + } + + stmts := []stmt{stmt{role: ro, qc: qc}} w := &bytes.Buffer{} - for i := 1; i < len(conf.Roles); i++ { + skipped, err := pcompile.Compile(qc, w, psql.Variables(vm)) + if err != nil { + return nil, err + } + + stmts[0].skipped = skipped + stmts[0].sql = w.String() + + return stmts, nil +} + +func buildMultiStmt(gql, vars []byte) ([]stmt, error) { + var vm map[string]json.RawMessage + var err error + + if len(vars) != 0 { + if err := json.Unmarshal(vars, &vm); err != nil { + return nil, err + } + } + + if len(conf.RolesQuery) == 0 { + return buildRoleStmt(gql, vars, "user") + } + + stmts := make([]stmt, 0, len(conf.Roles)) + w := &bytes.Buffer{} + + for i := 0; i < len(conf.Roles); i++ { role := &conf.Roles[i] - // For mutations only render sql for a single role from the request - if mutation && len(c.req.role) != 0 && role.Name != c.req.role { - continue - } - - qc, err = qcompile.Compile(gql, role.Name) + qc, err := qcompile.Compile(gql, role.Name) if err != nil { return nil, err } - if conf.Production && role.Name == "anon" { - for _, id := range qc.Roots { - root := qc.Selects[id] - if _, ok := role.tablesMap[root.Table]; !ok { - continue - } - } - } - stmts = append(stmts, stmt{role: role, qc: qc}) - if mutation { - skipped, err := pcompile.Compile(qc, w, psql.Variables(vars)) - if err != nil { - return nil, err - } - - s := &stmts[len(stmts)-1] - s.skipped = skipped - s.sql = w.String() - w.Reset() + skipped, err := pcompile.Compile(qc, w, psql.Variables(vm)) + if err != nil { + return nil, err } + + s := &stmts[len(stmts)-1] + s.skipped = skipped + s.sql = w.String() + w.Reset() } - if mutation { - return stmts, nil + sql, err := renderUserQuery(stmts, vm) + if err != nil { + return nil, err } + stmts[0].sql = sql + return stmts, nil +} + +func renderUserQuery( + stmts []stmt, vars map[string]json.RawMessage) (string, error) { + + var err error + w := &bytes.Buffer{} + io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `) for _, s := range stmts { + if len(s.role.Match) == 0 && + s.role.Name != "user" && s.role.Name != "anon" { + continue + } io.WriteString(w, `WHEN '`) io.WriteString(w, s.role.Name) io.WriteString(w, `' THEN (`) s.skipped, err = pcompile.Compile(s.qc, w, psql.Variables(vars)) if err != nil { - return nil, err + return "", err } - io.WriteString(w, `) `) } - io.WriteString(w, `END) FROM (`) - if len(conf.RolesQuery) == 0 { - v := c.Value(userRoleKey) + io.WriteString(w, `END) FROM (SELECT (CASE WHEN EXISTS (`) + io.WriteString(w, conf.RolesQuery) + io.WriteString(w, `) THEN `) - io.WriteString(w, `VALUES ("`) - if v != nil { - io.WriteString(w, v.(string)) - } else { - io.WriteString(w, c.req.role) + io.WriteString(w, `(SELECT (CASE`) + for _, s := range stmts { + if len(s.role.Match) == 0 { + continue } - io.WriteString(w, `")) AS "_sg_auth_info"(role) LIMIT 1;`) - - } else { - - io.WriteString(w, `SELECT (CASE WHEN EXISTS (`) - io.WriteString(w, conf.RolesQuery) - io.WriteString(w, `) THEN `) - - io.WriteString(w, `(SELECT (CASE`) - for _, s := range stmts { - if len(s.role.Match) == 0 { - continue - } - io.WriteString(w, ` WHEN `) - io.WriteString(w, s.role.Match) - io.WriteString(w, ` THEN '`) - io.WriteString(w, s.role.Name) - io.WriteString(w, `'`) - } - - if len(c.req.role) == 0 { - io.WriteString(w, ` ELSE 'anon' END) FROM (`) - } else { - io.WriteString(w, ` ELSE '`) - io.WriteString(w, c.req.role) - io.WriteString(w, `' END) FROM (`) - } - - io.WriteString(w, conf.RolesQuery) - io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) ELSE '`) - if len(c.req.role) == 0 { - io.WriteString(w, `anon`) - } else { - io.WriteString(w, c.req.role) - } - io.WriteString(w, `' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `) + io.WriteString(w, ` WHEN `) + io.WriteString(w, s.role.Match) + io.WriteString(w, ` THEN '`) + io.WriteString(w, s.role.Name) + io.WriteString(w, `'`) } - stmts[0].sql = w.String() - stmts[0].role = nil + io.WriteString(w, ` ELSE 'user' END) FROM (`) + io.WriteString(w, conf.RolesQuery) + io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `) + io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `) - return stmts, nil + return w.String(), nil } -func (c *coreContext) buildStmtByRole(role string) (stmt, error) { - var st stmt - var err error - - if len(role) == 0 { - return st, errors.New(`no role defined`) - } - - var vars map[string]json.RawMessage - - if len(c.req.Vars) != 0 { - if err := json.Unmarshal(c.req.Vars, &vars); err != nil { - return st, err +func hasTablesWithConfig(qc *qcode.QCode, role *configRole) bool { + for _, id := range qc.Roots { + t, err := schema.GetTable(qc.Selects[id].Table) + if err != nil { + return false + } + if _, ok := role.tablesMap[t.Name]; !ok { + return false } } - - gql := []byte(c.req.Query) - - st.qc, err = qcompile.Compile(gql, role) - if err != nil { - return st, err - } - - w := &bytes.Buffer{} - - st.skipped, err = pcompile.Compile(st.qc, w, psql.Variables(vars)) - if err != nil { - return st, err - } - - st.sql = w.String() - - return st, nil - + return true } diff --git a/serv/core_remote.go b/serv/core_remote.go new file mode 100644 index 0000000..3fcb8f4 --- /dev/null +++ b/serv/core_remote.go @@ -0,0 +1,197 @@ +package serv + +import ( + "bytes" + "errors" + "fmt" + "net/http" + "sync" + + "github.com/cespare/xxhash/v2" + "github.com/dosco/super-graph/jsn" + "github.com/dosco/super-graph/qcode" +) + +func execRemoteJoin(st *stmt, data []byte, hdr http.Header) ([]byte, error) { + var err error + + if len(data) == 0 || st.skipped == 0 { + return data, nil + } + + sel := st.qc.Selects + h := xxhash.New() + + // fetch the field name used within the db response json + // that are used to mark insertion points and the mapping between + // those field names and their select objects + fids, sfmap := parentFieldIds(h, sel, st.skipped) + + // fetch the field values of the marked insertion points + // these values contain the id to be used with fetching remote data + from := jsn.Get(data, fids) + var to []jsn.Field + + switch { + case len(from) == 1: + to, err = resolveRemote(hdr, h, from[0], sel, sfmap) + + case len(from) > 1: + to, err = resolveRemotes(hdr, h, from, sel, sfmap) + + default: + return nil, errors.New("something wrong no remote ids found in db response") + } + + if err != nil { + return nil, err + } + + var ob bytes.Buffer + + err = jsn.Replace(&ob, data, from, to) + if err != nil { + return nil, err + } + + return ob.Bytes(), nil +} + +func resolveRemote( + hdr http.Header, + h *xxhash.Digest, + field jsn.Field, + sel []qcode.Select, + sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) { + + // replacement data for the marked insertion points + // key and value will be replaced by whats below + toA := [1]jsn.Field{} + to := toA[:1] + + // use the json key to find the related Select object + k1 := xxhash.Sum64(field.Key) + + s, ok := sfmap[k1] + if !ok { + return nil, nil + } + p := sel[s.ParentID] + + // then use the Table nme in the Select and it's parent + // to find the resolver to use for this relationship + k2 := mkkey(h, s.Table, p.Table) + + r, ok := rmap[k2] + if !ok { + return nil, nil + } + + id := jsn.Value(field.Value) + if len(id) == 0 { + return nil, nil + } + + //st := time.Now() + + b, err := r.Fn(hdr, id) + if err != nil { + return nil, err + } + + if len(r.Path) != 0 { + b = jsn.Strip(b, r.Path) + } + + var ob bytes.Buffer + + if len(s.Cols) != 0 { + err = jsn.Filter(&ob, b, colsToList(s.Cols)) + if err != nil { + return nil, err + } + + } else { + ob.WriteString("null") + } + + to[0] = jsn.Field{Key: []byte(s.FieldName), Value: ob.Bytes()} + return to, nil +} + +func resolveRemotes( + hdr http.Header, + h *xxhash.Digest, + from []jsn.Field, + sel []qcode.Select, + sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) { + + // replacement data for the marked insertion points + // key and value will be replaced by whats below + to := make([]jsn.Field, len(from)) + + var wg sync.WaitGroup + wg.Add(len(from)) + + var cerr error + + for i, id := range from { + + // use the json key to find the related Select object + k1 := xxhash.Sum64(id.Key) + + s, ok := sfmap[k1] + if !ok { + return nil, nil + } + p := sel[s.ParentID] + + // then use the Table nme in the Select and it's parent + // to find the resolver to use for this relationship + k2 := mkkey(h, s.Table, p.Table) + + r, ok := rmap[k2] + if !ok { + return nil, nil + } + + id := jsn.Value(id.Value) + if len(id) == 0 { + return nil, nil + } + + go func(n int, id []byte, s *qcode.Select) { + defer wg.Done() + + //st := time.Now() + + b, err := r.Fn(hdr, id) + if err != nil { + cerr = fmt.Errorf("%s: %s", s.Table, err) + return + } + + if len(r.Path) != 0 { + b = jsn.Strip(b, r.Path) + } + + var ob bytes.Buffer + + if len(s.Cols) != 0 { + err = jsn.Filter(&ob, b, colsToList(s.Cols)) + if err != nil { + cerr = fmt.Errorf("%s: %s", s.Table, err) + return + } + + } else { + ob.WriteString("null") + } + + to[n] = jsn.Field{Key: []byte(s.FieldName), Value: ob.Bytes()} + }(i, id, s) + } + wg.Wait() + + return to, cerr +} diff --git a/serv/fuzz.go b/serv/fuzz.go index 5464cce..f00eed5 100644 --- a/serv/fuzz.go +++ b/serv/fuzz.go @@ -4,7 +4,6 @@ package serv func Fuzz(data []byte) int { gql := string(data) - isMutation(gql) gqlHash(gql, nil, "") return 1 diff --git a/serv/fuzz_test.go b/serv/fuzz_test.go index 6ff030f..68fe2c6 100644 --- a/serv/fuzz_test.go +++ b/serv/fuzz_test.go @@ -10,7 +10,6 @@ func TestFuzzCrashers(t *testing.T) { } for _, f := range crashers { - isMutation(f) gqlHash(f, nil, "") } } diff --git a/serv/http.go b/serv/http.go index 83939c3..4f6fa99 100644 --- a/serv/http.go +++ b/serv/http.go @@ -21,7 +21,6 @@ const ( var ( upgrader = websocket.Upgrader{} - errNoUserID = errors.New("no user_id available") errUnauthorized = errors.New("not authorized") ) @@ -78,7 +77,7 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) { b, err := ioutil.ReadAll(io.LimitReader(r.Body, maxReadBytes)) if err != nil { - logger.Err(err).Msg("failed to read request body") + errlog.Error().Err(err).Msg("failed to read request body") errorResp(w, err) return } @@ -86,7 +85,7 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) { err = json.Unmarshal(b, &ctx.req) if err != nil { - logger.Err(err).Msg("failed to decode json request body") + errlog.Error().Err(err).Msg("failed to decode json request body") errorResp(w, err) return } @@ -105,7 +104,7 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) { } if err != nil { - logger.Err(err).Msg("failed to handle request") + errlog.Error().Err(err).Msg("failed to handle request") errorResp(w, err) return } diff --git a/serv/prepare.go b/serv/prepare.go index 36631dd..75b3c55 100644 --- a/serv/prepare.go +++ b/serv/prepare.go @@ -3,7 +3,6 @@ package serv import ( "bytes" "context" - "encoding/json" "fmt" "io" @@ -14,10 +13,9 @@ import ( ) type preparedItem struct { - stmt *pgconn.StatementDescription - args [][]byte - skipped uint32 - qc *qcode.QCode + sd *pgconn.StatementDescription + args [][]byte + st *stmt } var ( @@ -25,85 +23,119 @@ var ( ) func initPreparedList() { - ctx := context.Background() - - tx, err := db.Begin(ctx) - if err != nil { - logger.Fatal().Err(err).Send() - } - defer tx.Rollback(ctx) - + c := context.Background() _preparedList = make(map[string]*preparedItem) - if err := prepareRoleStmt(ctx, tx); err != nil { - logger.Fatal().Err(err).Msg("failed to prepare get role statement") + tx, err := db.Begin(c) + if err != nil { + errlog.Fatal().Err(err).Send() } + defer tx.Rollback(c) + + err = prepareRoleStmt(c, tx) + if err != nil { + errlog.Fatal().Err(err).Msg("failed to prepare get role statement") + } + + if err := tx.Commit(c); err != nil { + errlog.Fatal().Err(err).Send() + } + + success := 0 for _, v := range _allowList.list { - err := prepareStmt(ctx, tx, v.gql, v.vars) - if err != nil { - logger.Warn().Str("gql", v.gql).Err(err).Send() - } - } - - if err := tx.Commit(ctx); err != nil { - logger.Fatal().Err(err).Send() - } - - logger.Info().Msgf("Registered %d queries from allow.list as prepared statements", len(_allowList.list)) -} - -func prepareStmt(ctx context.Context, tx pgx.Tx, gql string, varBytes json.RawMessage) error { - if len(gql) == 0 { - return nil - } - - c := &coreContext{Context: context.Background()} - c.req.Query = gql - c.req.Vars = varBytes - - stmts, err := c.buildStmt() - if err != nil { - return err - } - - if len(stmts) != 0 && stmts[0].qc.Type == qcode.QTQuery { - c.req.Vars = nil - } - - for _, s := range stmts { - if len(s.sql) == 0 { + if len(v.gql) == 0 { continue } - finalSQL, am := processTemplate(s.sql) + err := prepareStmt(c, v.gql, v.vars) + if err == nil { + success++ + continue + } - pstmt, err := tx.Prepare(c.Context, "", finalSQL) + if len(v.vars) == 0 { + logger.Warn().Err(err).Msg(v.gql) + } else { + logger.Warn().Err(err).Msgf("%s %s", v.vars, v.gql) + } + } + + logger.Info(). + Msgf("Registered %d of %d queries from allow.list as prepared statements", + success, len(_allowList.list)) +} + +func prepareStmt(c context.Context, gql string, vars []byte) error { + qt := qcode.GetQType(gql) + q := []byte(gql) + + tx, err := db.Begin(c) + if err != nil { + return err + } + defer tx.Rollback(c) + + switch qt { + case qcode.QTQuery: + stmts1, err := buildMultiStmt(q, vars) if err != nil { return err } - var key string - - if s.role == nil { - key = gqlHash(gql, c.req.Vars, "") - } else { - key = gqlHash(gql, c.req.Vars, s.role.Name) + err = prepare(c, tx, &stmts1[0], gqlHash(gql, vars, "user")) + if err != nil { + return err } - _preparedList[key] = &preparedItem{ - stmt: pstmt, - args: am, - skipped: s.skipped, - qc: s.qc, + stmts2, err := buildRoleStmt(q, vars, "anon") + if err != nil { + return err } + err = prepare(c, tx, &stmts2[0], gqlHash(gql, vars, "anon")) + if err != nil { + return err + } + + case qcode.QTMutation: + for _, role := range conf.Roles { + stmts, err := buildRoleStmt(q, vars, role.Name) + if err != nil { + return err + } + + err = prepare(c, tx, &stmts[0], gqlHash(gql, vars, role.Name)) + if err != nil { + return err + } + } + } + + if err := tx.Commit(c); err != nil { + return err } return nil } -func prepareRoleStmt(ctx context.Context, tx pgx.Tx) error { +func prepare(c context.Context, tx pgx.Tx, st *stmt, key string) error { + finalSQL, am := processTemplate(st.sql) + + sd, err := tx.Prepare(c, "", finalSQL) + if err != nil { + return err + } + + _preparedList[key] = &preparedItem{ + sd: sd, + args: am, + st: st, + } + return nil +} + +func prepareRoleStmt(c context.Context, tx pgx.Tx) error { if len(conf.RolesQuery) == 0 { return nil } @@ -128,7 +160,7 @@ func prepareRoleStmt(ctx context.Context, tx pgx.Tx) error { roleSQL, _ := processTemplate(w.String()) - _, err := tx.Prepare(ctx, "_sg_get_role", roleSQL) + _, err := tx.Prepare(c, "_sg_get_role", roleSQL) if err != nil { return err } diff --git a/serv/reload.go b/serv/reload.go index fbce73c..0e92da2 100644 --- a/serv/reload.go +++ b/serv/reload.go @@ -168,7 +168,7 @@ func Do(log func(string, ...interface{}), additional ...dir) error { func ReExec() { err := syscall.Exec(binSelf, append([]string{binSelf}, os.Args[1:]...), os.Environ()) if err != nil { - logger.Fatal().Err(err).Msg("cannot restart") + errlog.Fatal().Err(err).Msg("cannot restart") } } diff --git a/serv/reso.go b/serv/reso.go index 6e132df..827a80e 100644 --- a/serv/reso.go +++ b/serv/reso.go @@ -117,7 +117,7 @@ func buildFn(r configRemote) func(http.Header, []byte) ([]byte, error) { res, err := client.Do(req) if err != nil { - logger.Error().Err(err).Msgf("Failed to connect to: %s", uri) + errlog.Error().Err(err).Msgf("Failed to connect to: %s", uri) return nil, err } defer res.Body.Close() diff --git a/serv/serv.go b/serv/serv.go index 8036119..09143ef 100644 --- a/serv/serv.go +++ b/serv/serv.go @@ -15,13 +15,15 @@ import ( ) func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) { - schema, err := psql.NewDBSchema(db, c.getAliasMap()) + var err error + + schema, err = psql.NewDBSchema(db, c.getAliasMap()) if err != nil { return nil, nil, err } conf := qcode.Config{ - Blocklist: c.DB.Defaults.Blocklist, + Blocklist: c.DB.Blocklist, KeepArgs: false, } @@ -106,7 +108,7 @@ func initWatcher(cpath string) { go func() { err := Do(logger.Printf, d) if err != nil { - logger.Fatal().Err(err).Send() + errlog.Fatal().Err(err).Send() } }() } @@ -139,7 +141,7 @@ func startHTTP() { <-sigint if err := srv.Shutdown(context.Background()); err != nil { - logger.Error().Err(err).Msg("shutdown signal received") + errlog.Error().Err(err).Msg("shutdown signal received") } close(idleConnsClosed) }() @@ -148,18 +150,14 @@ func startHTTP() { db.Close() }) - var ident string - - if len(conf.AppName) == 0 { - ident = conf.Env - } else { - ident = conf.AppName - } - - fmt.Printf("%s listening on %s (%s)\n", serverName, hostPort, ident) + logger.Info(). + Str("host_post", hostPort). + Str("app_name", conf.AppName). + Str("env", conf.Env). + Msgf("%s listening", serverName) if err := srv.ListenAndServe(); err != http.ErrServerClosed { - logger.Error().Err(err).Msg("server closed") + errlog.Error().Err(err).Msg("server closed") } <-idleConnsClosed diff --git a/serv/utils.go b/serv/utils.go index d671813..2452ed0 100644 --- a/serv/utils.go +++ b/serv/utils.go @@ -106,19 +106,6 @@ func al(b byte) bool { return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9') } -func isMutation(sql string) bool { - for i := range sql { - b := sql[i] - if b == '{' { - return false - } - if al(b) { - return (b == 'm' || b == 'M') - } - } - return false -} - func findStmt(role string, stmts []stmt) *stmt { for i := range stmts { if stmts[i].role.Name != role { diff --git a/tmpl/dev.yml b/tmpl/dev.yml index 0f20b0e..7a5e990 100644 --- a/tmpl/dev.yml +++ b/tmpl/dev.yml @@ -101,18 +101,14 @@ database: variables: admin_account_id: "5" - # Define defaults to for the field key and values below - defaults: - # filters: ["{ user_id: { eq: $user_id } }"] - - # Field and table names that you wish to block - blocklist: - - ar_internal_metadata - - schema_migrations - - secret - - password - - encrypted - - token + # Field and table names that you wish to block + blocklist: + - ar_internal_metadata + - schema_migrations + - secret + - password + - encrypted + - token tables: - name: customers