Add skip query selectors that require auth in anon role

This commit is contained in:
Vikram Rangnekar 2020-01-20 23:38:17 -05:00
parent a0b8907c3c
commit 2d466bfb12
5 changed files with 132 additions and 58 deletions

View File

@ -82,17 +82,21 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) {
multiRoot := (len(qc.Roots) > 1) multiRoot := (len(qc.Roots) > 1)
st := NewIntStack() st := NewIntStack()
si := 0
if multiRoot { if multiRoot {
io.WriteString(c.w, `SELECT row_to_json("json_root") FROM (SELECT `) io.WriteString(c.w, `SELECT row_to_json("json_root") FROM (SELECT `)
for i, id := range qc.Roots { for _, id := range qc.Roots {
root := qc.Selects[id] root := qc.Selects[id]
if root.SkipRender {
continue
}
st.Push(root.ID + closeBlock) st.Push(root.ID + closeBlock)
st.Push(root.ID) st.Push(root.ID)
if i != 0 { if si != 0 {
io.WriteString(c.w, `, `) io.WriteString(c.w, `, `)
} }
@ -103,24 +107,34 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) {
io.WriteString(c.w, `"`) io.WriteString(c.w, `"`)
alias(c.w, root.FieldName) alias(c.w, root.FieldName)
si++
} }
io.WriteString(c.w, ` FROM `) if si != 0 {
io.WriteString(c.w, ` FROM `)
}
} else { } else {
root := qc.Selects[0] root := qc.Selects[0]
if !root.SkipRender {
io.WriteString(c.w, `SELECT json_object_agg(`)
io.WriteString(c.w, `'`)
io.WriteString(c.w, root.FieldName)
io.WriteString(c.w, `', `)
io.WriteString(c.w, `json_`)
int2string(c.w, root.ID)
io.WriteString(c.w, `SELECT json_object_agg(`) st.Push(root.ID + closeBlock)
io.WriteString(c.w, `'`) st.Push(root.ID)
io.WriteString(c.w, root.FieldName)
io.WriteString(c.w, `', `)
io.WriteString(c.w, `json_`)
int2string(c.w, root.ID)
st.Push(root.ID + closeBlock) io.WriteString(c.w, `) FROM `)
st.Push(root.ID) si++
}
}
io.WriteString(c.w, `) FROM `) if si == 0 {
return 0, errors.New("all tables skipped. cannot render query")
} }
var ignored uint32 var ignored uint32
@ -161,6 +175,9 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) {
continue continue
} }
child := &c.s[cid] child := &c.s[cid]
if child.SkipRender {
continue
}
st.Push(child.ID + closeBlock) st.Push(child.ID + closeBlock)
st.Push(child.ID) st.Push(child.ID)
@ -475,18 +492,22 @@ func (c *compilerContext) renderRemoteRelColumns(sel *qcode.Select, ti *DBTableI
} }
func (c *compilerContext) renderJoinedColumns(sel *qcode.Select, ti *DBTableInfo, skipped uint32) error { func (c *compilerContext) renderJoinedColumns(sel *qcode.Select, ti *DBTableInfo, skipped uint32) error {
colsRendered := len(sel.Cols) != 0
// columns previously rendered
i := len(sel.Cols)
for _, id := range sel.Children { for _, id := range sel.Children {
skipThis := hasBit(skipped, uint32(id)) if hasBit(skipped, uint32(id)) {
if colsRendered && !skipThis {
io.WriteString(c.w, ", ")
}
if skipThis {
continue continue
} }
childSel := &c.s[id] childSel := &c.s[id]
if childSel.SkipRender {
continue
}
if i != 0 {
io.WriteString(c.w, ", ")
}
//fmt.Fprintf(w, `"%s_%d_join"."%s" AS "%s"`, //fmt.Fprintf(w, `"%s_%d_join"."%s" AS "%s"`,
//s.Name, s.ID, s.Name, s.FieldName) //s.Name, s.ID, s.Name, s.FieldName)
@ -500,6 +521,7 @@ func (c *compilerContext) renderJoinedColumns(sel *qcode.Select, ti *DBTableInfo
io.WriteString(c.w, `" AS "`) io.WriteString(c.w, `" AS "`)
io.WriteString(c.w, childSel.FieldName) io.WriteString(c.w, childSel.FieldName)
io.WriteString(c.w, `"`) io.WriteString(c.w, `"`)
i++
} }
return nil return nil
@ -632,10 +654,6 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
} }
} }
// if i != 0 && len(sel.OrderBy) != 0 {
// io.WriteString(c.w, ", ")
// }
for _, ob := range sel.OrderBy { for _, ob := range sel.OrderBy {
if _, ok := colmap[ob.Col]; ok { if _, ok := colmap[ob.Col]; ok {
continue continue

View File

@ -463,6 +463,30 @@ func multiRoot(t *testing.T) {
} }
} }
func skipUserIDForAnonRole(t *testing.T) {
gql := `query {
products {
id
name
user(where: { id: { eq: $user_id } }) {
id
email
}
}
}`
sql := `SELECT json_object_agg('products', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name", "products"."user_id" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"`
resSQL, err := compileGQLToPSQL(gql, nil, "anon")
if err != nil {
t.Fatal(err)
}
if string(resSQL) != sql {
t.Fatal(errNotExpected)
}
}
func blockedQuery(t *testing.T) { func blockedQuery(t *testing.T) {
gql := `query { gql := `query {
user(id: 5, where: { id: { gt: 3 } }) { user(id: 5, where: { id: { gt: 3 } }) {
@ -524,6 +548,7 @@ func TestCompileQuery(t *testing.T) {
t.Run("queryWithVariables", queryWithVariables) t.Run("queryWithVariables", queryWithVariables)
t.Run("withWhereOnRelations", withWhereOnRelations) t.Run("withWhereOnRelations", withWhereOnRelations)
t.Run("multiRoot", multiRoot) t.Run("multiRoot", multiRoot)
t.Run("skipUserIDForAnonRole", skipUserIDForAnonRole)
t.Run("blockedQuery", blockedQuery) t.Run("blockedQuery", blockedQuery)
t.Run("blockedFunctions", blockedFunctions) t.Run("blockedFunctions", blockedFunctions)
} }

View File

@ -45,6 +45,7 @@ type trval struct {
query struct { query struct {
limit string limit string
fil *Exp fil *Exp
filNU bool
cols map[string]struct{} cols map[string]struct{}
disable struct { disable struct {
funcs bool funcs bool
@ -53,6 +54,7 @@ type trval struct {
insert struct { insert struct {
fil *Exp fil *Exp
filNU bool
cols map[string]struct{} cols map[string]struct{}
psmap map[string]string psmap map[string]string
pslist []string pslist []string
@ -60,14 +62,16 @@ type trval struct {
update struct { update struct {
fil *Exp fil *Exp
filNU bool
cols map[string]struct{} cols map[string]struct{}
psmap map[string]string psmap map[string]string
pslist []string pslist []string
} }
delete struct { delete struct {
fil *Exp fil *Exp
cols map[string]struct{} filNU bool
cols map[string]struct{}
} }
} }
@ -88,21 +92,21 @@ func (trv *trval) allowedColumns(qt QType) map[string]struct{} {
return nil return nil
} }
func (trv *trval) filter(qt QType) *Exp { func (trv *trval) filter(qt QType) (*Exp, bool) {
switch qt { switch qt {
case QTQuery: case QTQuery:
return trv.query.fil return trv.query.fil, trv.query.filNU
case QTInsert: case QTInsert:
return trv.insert.fil return trv.insert.fil, trv.insert.filNU
case QTUpdate: case QTUpdate:
return trv.update.fil return trv.update.fil, trv.update.filNU
case QTDelete: case QTDelete:
return trv.delete.fil return trv.delete.fil, trv.delete.filNU
case QTUpsert: case QTUpsert:
return trv.insert.fil return trv.insert.fil, trv.insert.filNU
} }
return nil return nil, false
} }
func listToMap(list []string) map[string]struct{} { func listToMap(list []string) map[string]struct{} {

View File

@ -51,6 +51,7 @@ type Select struct {
Allowed map[string]struct{} Allowed map[string]struct{}
PresetMap map[string]string PresetMap map[string]string
PresetList []string PresetList []string
SkipRender bool
} }
type Column struct { type Column struct {
@ -187,7 +188,7 @@ func (com *Compiler) AddRole(role, table string, trc TRConfig) error {
trv := &trval{} trv := &trval{}
// query config // query config
trv.query.fil, err = compileFilter(trc.Query.Filters) trv.query.fil, trv.query.filNU, err = compileFilter(trc.Query.Filters)
if err != nil { if err != nil {
return err return err
} }
@ -198,7 +199,8 @@ func (com *Compiler) AddRole(role, table string, trc TRConfig) error {
trv.query.disable.funcs = trc.Query.DisableFunctions trv.query.disable.funcs = trc.Query.DisableFunctions
// insert config // insert config
if trv.insert.fil, err = compileFilter(trc.Insert.Filters); err != nil { trv.insert.fil, trv.insert.filNU, err = compileFilter(trc.Insert.Filters)
if err != nil {
return err return err
} }
trv.insert.cols = listToMap(trc.Insert.Columns) trv.insert.cols = listToMap(trc.Insert.Columns)
@ -206,7 +208,8 @@ func (com *Compiler) AddRole(role, table string, trc TRConfig) error {
trv.insert.pslist = mapToList(trv.insert.psmap) trv.insert.pslist = mapToList(trv.insert.psmap)
// update config // update config
if trv.update.fil, err = compileFilter(trc.Update.Filters); err != nil { trv.update.fil, trv.update.filNU, err = compileFilter(trc.Update.Filters)
if err != nil {
return err return err
} }
trv.update.cols = listToMap(trc.Update.Columns) trv.update.cols = listToMap(trc.Update.Columns)
@ -214,7 +217,8 @@ func (com *Compiler) AddRole(role, table string, trc TRConfig) error {
trv.update.pslist = mapToList(trv.update.psmap) trv.update.pslist = mapToList(trv.update.psmap)
// delete config // delete config
if trv.delete.fil, err = compileFilter(trc.Delete.Filters); err != nil { trv.delete.fil, trv.delete.filNU, err = compileFilter(trc.Delete.Filters)
if err != nil {
return err return err
} }
trv.delete.cols = listToMap(trc.Delete.Columns) trv.delete.cols = listToMap(trc.Delete.Columns)
@ -334,7 +338,7 @@ func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error {
s.FieldName = s.Name s.FieldName = s.Name
} }
err := com.compileArgs(qc, s, field.Args) err := com.compileArgs(qc, s, field.Args, role)
if err != nil { if err != nil {
return err return err
} }
@ -388,9 +392,16 @@ func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error {
func (com *Compiler) addFilters(qc *QCode, sel *Select, role string) { func (com *Compiler) addFilters(qc *QCode, sel *Select, role string) {
var fil *Exp var fil *Exp
var nu bool
if trv, ok := com.tr[role][sel.Name]; ok { if trv, ok := com.tr[role][sel.Name]; ok {
fil = trv.filter(qc.Type) fil, nu = trv.filter(qc.Type)
} else if role == "anon" {
// Tables not defined under the anon role will not be rendered
sel.SkipRender = true
return
} else { } else {
return return
} }
@ -399,6 +410,10 @@ func (com *Compiler) addFilters(qc *QCode, sel *Select, role string) {
return return
} }
if nu && role == "anon" {
sel.SkipRender = true
}
switch fil.Op { switch fil.Op {
case OpNop: case OpNop:
case OpFalse: case OpFalse:
@ -420,7 +435,7 @@ func (com *Compiler) addFilters(qc *QCode, sel *Select, role string) {
} }
} }
func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg) error { func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg, role string) error {
var err error var err error
var ka bool var ka bool
@ -435,7 +450,7 @@ func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg) error {
err, ka = com.compileArgSearch(sel, arg) err, ka = com.compileArgSearch(sel, arg)
case "where": case "where":
err, ka = com.compileArgWhere(sel, arg) err, ka = com.compileArgWhere(sel, arg, role)
case "orderby", "order_by", "order": case "orderby", "order_by", "order":
err, ka = com.compileArgOrderBy(sel, arg) err, ka = com.compileArgOrderBy(sel, arg)
@ -501,19 +516,20 @@ func (com *Compiler) setMutationType(qc *QCode, args []Arg) error {
return nil return nil
} }
func (com *Compiler) compileArgObj(st *util.Stack, arg *Arg) (*Exp, error) { func (com *Compiler) compileArgObj(st *util.Stack, arg *Arg) (*Exp, bool, error) {
if arg.Val.Type != NodeObj { if arg.Val.Type != NodeObj {
return nil, fmt.Errorf("expecting an object") return nil, false, fmt.Errorf("expecting an object")
} }
return com.compileArgNode(st, arg.Val, true) return com.compileArgNode(st, arg.Val, true)
} }
func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (*Exp, error) { func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (*Exp, bool, error) {
var root *Exp var root *Exp
var needsUser bool
if node == nil || len(node.Children) == 0 { if node == nil || len(node.Children) == 0 {
return nil, errors.New("invalid argument value") return nil, needsUser, errors.New("invalid argument value")
} }
pushChild(st, nil, node) pushChild(st, nil, node)
@ -526,7 +542,7 @@ func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (*
intf := st.Pop() intf := st.Pop()
node, ok := intf.(*Node) node, ok := intf.(*Node)
if !ok || node == nil { if !ok || node == nil {
return nil, fmt.Errorf("16: unexpected value %v (%t)", intf, intf) return nil, needsUser, fmt.Errorf("16: unexpected value %v (%t)", intf, intf)
} }
// Objects inside a list // Objects inside a list
@ -542,13 +558,17 @@ func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (*
ex, err := newExp(st, node, usePool) ex, err := newExp(st, node, usePool)
if err != nil { if err != nil {
return nil, err return nil, needsUser, err
} }
if ex == nil { if ex == nil {
continue continue
} }
if ex.Type == ValVar && ex.Val == "user_id" {
needsUser = true
}
if node.exp == nil { if node.exp == nil {
root = ex root = ex
} else { } else {
@ -571,7 +591,7 @@ func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (*
nodePool.Put(node) nodePool.Put(node)
} }
return root, nil return root, needsUser, nil
} }
func (com *Compiler) compileArgID(sel *Select, arg *Arg) (error, bool) { func (com *Compiler) compileArgID(sel *Select, arg *Arg) (error, bool) {
@ -640,15 +660,19 @@ func (com *Compiler) compileArgSearch(sel *Select, arg *Arg) (error, bool) {
return nil, true return nil, true
} }
func (com *Compiler) compileArgWhere(sel *Select, arg *Arg) (error, bool) { func (com *Compiler) compileArgWhere(sel *Select, arg *Arg, role string) (error, bool) {
st := util.NewStack() st := util.NewStack()
var err error var err error
ex, err := com.compileArgObj(st, arg) ex, nu, err := com.compileArgObj(st, arg)
if err != nil { if err != nil {
return err, false return err, false
} }
if nu && role == "anon" {
sel.SkipRender = true
}
if sel.Where != nil { if sel.Where != nil {
ow := sel.Where ow := sel.Where
@ -976,27 +1000,32 @@ func pushChild(st *util.Stack, exp *Exp, node *Node) {
} }
func compileFilter(filter []string) (*Exp, error) { func compileFilter(filter []string) (*Exp, bool, error) {
var fl *Exp var fl *Exp
var needsUser bool
com := &Compiler{} com := &Compiler{}
st := util.NewStack() st := util.NewStack()
if len(filter) == 0 { if len(filter) == 0 {
return &Exp{Op: OpNop, doFree: false}, nil return &Exp{Op: OpNop, doFree: false}, false, nil
} }
for i := range filter { for i := range filter {
if filter[i] == "false" { if filter[i] == "false" {
return &Exp{Op: OpFalse, doFree: false}, nil return &Exp{Op: OpFalse, doFree: false}, false, nil
} }
node, err := ParseArgValue(filter[i]) node, err := ParseArgValue(filter[i])
if err != nil { if err != nil {
return nil, err return nil, false, err
} }
f, err := com.compileArgNode(st, node, false) f, nu, err := com.compileArgNode(st, node, false)
if err != nil { if err != nil {
return nil, err return nil, false, err
}
if nu {
needsUser = true
} }
// TODO: Invalid table names in nested where causes fail silently // TODO: Invalid table names in nested where causes fail silently
@ -1010,7 +1039,7 @@ func compileFilter(filter []string) (*Exp, error) {
fl = &Exp{Op: OpAnd, Children: []*Exp{fl, f}, doFree: false} fl = &Exp{Op: OpAnd, Children: []*Exp{fl, f}, doFree: false}
} }
} }
return fl, nil return fl, needsUser, nil
} }
func buildPath(a []string) string { func buildPath(a []string) string {

View File

@ -35,8 +35,6 @@ func argMap(ctx context.Context, vars []byte) func(w io.Writer, tag string) (int
fields := jsn.Get(vars, [][]byte{[]byte(tag)}) fields := jsn.Get(vars, [][]byte{[]byte(tag)})
fmt.Println(">>", tag, string(vars))
if len(fields) == 0 { if len(fields) == 0 {
return 0, nil return 0, nil
} }