diff --git a/config/dev.yml b/config/dev.yml index 2315ce4..3bf2481 100644 --- a/config/dev.yml +++ b/config/dev.yml @@ -5,11 +5,11 @@ web_ui: true # debug, info, warn, error, fatal, panic log_level: "debug" -# Disable this in development to get a list of -# queries used. When enabled super graph -# will only allow queries from this list -# List saved to ./config/allow.list -use_allow_list: false +# When production mode is 'true' only queries +# from the allow list are permitted. +# When it's 'false' all queries are saved to the +# the allow list in ./config/allow.list +production: true # Throw a 401 on auth failure for queries that need auth auth_fail_block: false diff --git a/config/prod.yml b/config/prod.yml index 70d0cd2..c733472 100644 --- a/config/prod.yml +++ b/config/prod.yml @@ -9,11 +9,11 @@ web_ui: false # debug, info, warn, error, fatal, panic, disable log_level: "info" -# Disable this in development to get a list of -# queries used. When enabled super graph -# will only allow queries from this list -# List saved to ./config/allow.list -use_allow_list: true +# When production mode is 'true' only queries +# from the allow list are permitted. +# When it's 'false' all queries are saved to the +# the allow list in ./config/allow.list +production: true # Throw a 401 on auth failure for queries that need auth auth_fail_block: true diff --git a/psql/mutate.go b/psql/mutate.go index 8c3e939..49c8c15 100644 --- a/psql/mutate.go +++ b/psql/mutate.go @@ -137,16 +137,23 @@ func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w io.Writer } for i := range root.PresetList { + cn := root.PresetList[i] + col, ok := ti.Columns[cn] + if !ok { + continue + } if i != 0 { io.WriteString(c.w, `, `) } if values { io.WriteString(c.w, `'`) - io.WriteString(c.w, root.PresetMap[root.PresetList[i]]) - io.WriteString(c.w, `'`) + io.WriteString(c.w, root.PresetMap[cn]) + io.WriteString(c.w, `' :: `) + io.WriteString(c.w, col.Type) + } else { io.WriteString(c.w, `"`) - io.WriteString(c.w, root.PresetList[i]) + io.WriteString(c.w, cn) io.WriteString(c.w, `"`) } } diff --git a/psql/mutate_test.go b/psql/mutate_test.go index c9538bb..b719255 100644 --- a/psql/mutate_test.go +++ b/psql/mutate_test.go @@ -250,7 +250,7 @@ func simpleInsertWithPresets(t *testing.T) { } }` - sql := `WITH "products" AS (WITH "input" AS (SELECT {{data}}::json AS j) INSERT INTO "products" ("name", "price", "created_at", "updated_at", "user_id") SELECT "name", "price", 'now', 'now', '$user_id' FROM input i, json_populate_record(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id" FROM "products") AS "products_0") AS "done_1337"` + sql := `WITH "products" AS (WITH "input" AS (SELECT {{data}}::json AS j) INSERT INTO "products" ("name", "price", "created_at", "updated_at", "user_id") SELECT "name", "price", 'now' :: timestamp without time zone, 'now' :: timestamp without time zone, '$user_id' :: bigint FROM input i, json_populate_record(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id" FROM "products") AS "products_0") AS "done_1337"` vars := map[string]json.RawMessage{ "data": json.RawMessage(`{"name": "Tomato", "price": 5.76}`), @@ -273,7 +273,7 @@ func simpleUpdateWithPresets(t *testing.T) { } }` - sql := `WITH "products" AS (WITH "input" AS (SELECT {{data}}::json AS j) UPDATE "products" SET ("name", "price", "updated_at") = (SELECT "name", "price", 'now' FROM input i, json_populate_record(NULL::products, i.j) t) WHERE (("products"."user_id") = {{user_id}}) RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id" FROM "products") AS "products_0") AS "done_1337"` + sql := `WITH "products" AS (WITH "input" AS (SELECT {{data}}::json AS j) UPDATE "products" SET ("name", "price", "updated_at") = (SELECT "name", "price", 'now' :: timestamp without time zone FROM input i, json_populate_record(NULL::products, i.j) t) WHERE (("products"."user_id") = {{user_id}}) RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id" FROM "products") AS "products_0") AS "done_1337"` vars := map[string]json.RawMessage{ "data": json.RawMessage(`{"name": "Apple", "price": 1.25}`), diff --git a/psql/query.go b/psql/query.go index 4d5b966..f2c8c1b 100644 --- a/psql/query.go +++ b/psql/query.go @@ -340,9 +340,9 @@ func (c *compilerContext) renderLateralJoinClose(sel *qcode.Select) error { return nil } -func (c *compilerContext) renderJoin(sel *qcode.Select) error { +func (c *compilerContext) renderJoin(sel *qcode.Select, ti *DBTableInfo) error { parent := &c.s[sel.ParentID] - return c.renderJoinByName(sel.Table, parent.Table, parent.ID) + return c.renderJoinByName(ti.Name, parent.Table, parent.ID) } func (c *compilerContext) renderJoinByName(table, parent string, id int32) error { @@ -607,7 +607,7 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, } if !isRoot { - if err := c.renderJoin(sel); err != nil { + if err := c.renderJoin(sel, ti); err != nil { return err } @@ -691,7 +691,7 @@ func (c *compilerContext) renderOrderByColumns(sel *qcode.Select, ti *DBTableInf func (c *compilerContext) renderRelationship(sel *qcode.Select, ti *DBTableInfo) error { parent := c.s[sel.ParentID] - return c.renderRelationshipByName(sel.Table, parent.Table, parent.ID) + return c.renderRelationshipByName(ti.Name, parent.Table, parent.ID) } func (c *compilerContext) renderRelationshipByName(table, parent string, id int32) error { diff --git a/serv/allow.go b/serv/allow.go index b86d51d..b238580 100644 --- a/serv/allow.go +++ b/serv/allow.go @@ -71,7 +71,7 @@ func initAllowList(cpath string) { } if len(_allowList.filepath) == 0 { - if conf.UseAllowList { + if conf.Production { logger.Fatal().Msg("allow.list not found") } diff --git a/serv/cmd_seed.go b/serv/cmd_seed.go index 514c543..0c1f061 100644 --- a/serv/cmd_seed.go +++ b/serv/cmd_seed.go @@ -21,7 +21,7 @@ func cmdDBSeed(cmd *cobra.Command, args []string) { logger.Fatal().Err(err).Msg("failed to read config") } - conf.UseAllowList = false + conf.Production = false db, err = initDBPool(conf) if err != nil { diff --git a/serv/config.go b/serv/config.go index fed61bd..c3f840c 100644 --- a/serv/config.go +++ b/serv/config.go @@ -23,6 +23,7 @@ type config struct { LogLevel string `mapstructure:"log_level"` EnableTracing bool `mapstructure:"enable_tracing"` UseAllowList bool `mapstructure:"use_allow_list"` + Production bool WatchAndReload bool `mapstructure:"reload_on_config_change"` AuthFailBlock bool `mapstructure:"auth_fail_block"` SeedFile string `mapstructure:"seed_file"` @@ -142,9 +143,10 @@ type configRoleTable struct { } type configRole struct { - Name string - Match string - Tables []configRoleTable + Name string + Match string + Tables []configRoleTable + tablesMap map[string]*configRoleTable } func newConfig(name string) *viper.Viper { @@ -195,6 +197,10 @@ func (c *config) Init(vi *viper.Viper) error { c.Tables = c.DB.Tables } + if c.UseAllowList { + c.Production = true + } + for k, v := range c.Inflections { flect.AddPlural(k, v) } @@ -219,13 +225,19 @@ func (c *config) Init(vi *viper.Viper) error { rolesMap := make(map[string]struct{}) for i := range c.Roles { - role := c.Roles[i] + role := &c.Roles[i] if _, ok := rolesMap[role.Name]; ok { logger.Fatal().Msgf("duplicate role '%s' found", role.Name) } role.Name = sanitize(role.Name) role.Match = sanitize(role.Match) + role.tablesMap = make(map[string]*configRoleTable) + + for n, table := range role.Tables { + role.tablesMap[table.Name] = &role.Tables[n] + } + rolesMap[role.Name] = struct{}{} } diff --git a/serv/core.go b/serv/core.go index 2ac9b28..06b88d8 100644 --- a/serv/core.go +++ b/serv/core.go @@ -54,7 +54,7 @@ func (c *coreContext) execQuery() ([]byte, error) { logger.Debug().Str("role", c.req.role).Msg(c.req.Query) - if conf.UseAllowList { + if conf.Production { var ps *preparedItem data, ps, err = c.resolvePreparedSQL() @@ -256,7 +256,7 @@ func (c *coreContext) resolveSQL() ([]byte, uint32, error) { stime) } - if conf.UseAllowList == false { + if conf.Production == false { _allowList.add(&c.req) } diff --git a/serv/core_build.go b/serv/core_build.go index fd86003..c08263d 100644 --- a/serv/core_build.go +++ b/serv/core_build.go @@ -41,17 +41,22 @@ func (c *coreContext) buildStmt() ([]stmt, error) { mutation := (qc.Type != qcode.QTQuery) w := &bytes.Buffer{} - for i := range conf.Roles { + for i := 1; i < len(conf.Roles); i++ { role := &conf.Roles[i] + // For mutations only render sql for a single role from the request if mutation && len(c.req.role) != 0 && role.Name != c.req.role { continue } - if i > 0 { - qc, err = qcompile.Compile(gql, role.Name) - if err != nil { - return nil, err + qc, err = qcompile.Compile(gql, role.Name) + if err != nil { + return nil, err + } + + if conf.Production && role.Name == "anon" { + if _, ok := role.tablesMap[qc.Selects[0].Table]; !ok { + continue } } diff --git a/serv/reload.go b/serv/reload.go index 4364c79..fbce73c 100644 --- a/serv/reload.go +++ b/serv/reload.go @@ -108,7 +108,7 @@ func Do(log func(string, ...interface{}), additional ...dir) error { // Ensure that we use the correct events, as they are not uniform across // platforms. See https://github.com/fsnotify/fsnotify/issues/74 - if conf.UseAllowList == false && strings.HasSuffix(event.Name, "/allow.list") { + if conf.Production == false && strings.HasSuffix(event.Name, "/allow.list") { continue } diff --git a/tmpl/dev.yml b/tmpl/dev.yml index 165206b..1ac418f 100644 --- a/tmpl/dev.yml +++ b/tmpl/dev.yml @@ -5,11 +5,11 @@ web_ui: true # debug, info, warn, error, fatal, panic log_level: "debug" -# Disable this in development to get a list of -# queries used. When enabled super graph -# will only allow queries from this list -# List saved to ./config/allow.list -use_allow_list: false +# When production mode is 'true' only queries +# from the allow list are permitted. +# When it's 'false' all queries are saved to the +# the allow list in ./config/allow.list +production: false # Throw a 401 on auth failure for queries that need auth auth_fail_block: false diff --git a/tmpl/prod.yml b/tmpl/prod.yml index ebfc3d3..f0ada67 100644 --- a/tmpl/prod.yml +++ b/tmpl/prod.yml @@ -8,12 +8,11 @@ web_ui: false # debug, info, warn, error, fatal, panic, disable log_level: "info" - -# Disable this in development to get a list of -# queries used. When enabled super graph -# will only allow queries from this list -# List saved to ./config/allow.list -use_allow_list: true +# When production mode is 'true' only queries +# from the allow list are permitted. +# When it's 'false' all queries are saved to the +# the allow list in ./config/allow.list +production: true # Throw a 401 on auth failure for queries that need auth auth_fail_block: true