From 2d466bfb12feaa2d5c7c2055f6a6cc0d1d422868 Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Mon, 20 Jan 2020 23:38:17 -0500 Subject: [PATCH] Add skip query selectors that require auth in anon role --- psql/query.go | 64 ++++++++++++++++++++++++-------------- psql/query_test.go | 25 +++++++++++++++ qcode/config.go | 22 +++++++------ qcode/qcode.go | 77 +++++++++++++++++++++++++++++++--------------- serv/args.go | 2 -- 5 files changed, 132 insertions(+), 58 deletions(-) diff --git a/psql/query.go b/psql/query.go index c07675a..1cdc881 100644 --- a/psql/query.go +++ b/psql/query.go @@ -82,17 +82,21 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) { multiRoot := (len(qc.Roots) > 1) st := NewIntStack() + si := 0 if multiRoot { io.WriteString(c.w, `SELECT row_to_json("json_root") FROM (SELECT `) - for i, id := range qc.Roots { + for _, id := range qc.Roots { root := qc.Selects[id] + if root.SkipRender { + continue + } st.Push(root.ID + closeBlock) st.Push(root.ID) - if i != 0 { + if si != 0 { io.WriteString(c.w, `, `) } @@ -103,24 +107,34 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) { io.WriteString(c.w, `"`) alias(c.w, root.FieldName) + si++ } - io.WriteString(c.w, ` FROM `) + if si != 0 { + io.WriteString(c.w, ` FROM `) + + } } else { root := qc.Selects[0] + if !root.SkipRender { + io.WriteString(c.w, `SELECT json_object_agg(`) + io.WriteString(c.w, `'`) + io.WriteString(c.w, root.FieldName) + io.WriteString(c.w, `', `) + io.WriteString(c.w, `json_`) + int2string(c.w, root.ID) - io.WriteString(c.w, `SELECT json_object_agg(`) - io.WriteString(c.w, `'`) - io.WriteString(c.w, root.FieldName) - io.WriteString(c.w, `', `) - io.WriteString(c.w, `json_`) - int2string(c.w, root.ID) + st.Push(root.ID + closeBlock) + st.Push(root.ID) - st.Push(root.ID + closeBlock) - st.Push(root.ID) + io.WriteString(c.w, `) FROM `) + si++ + } + } - io.WriteString(c.w, `) FROM `) + if si == 0 { + return 0, errors.New("all tables skipped. cannot render query") } var ignored uint32 @@ -161,6 +175,9 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) { continue } child := &c.s[cid] + if child.SkipRender { + continue + } st.Push(child.ID + closeBlock) st.Push(child.ID) @@ -475,18 +492,22 @@ func (c *compilerContext) renderRemoteRelColumns(sel *qcode.Select, ti *DBTableI } func (c *compilerContext) renderJoinedColumns(sel *qcode.Select, ti *DBTableInfo, skipped uint32) error { - colsRendered := len(sel.Cols) != 0 + + // columns previously rendered + i := len(sel.Cols) for _, id := range sel.Children { - skipThis := hasBit(skipped, uint32(id)) - - if colsRendered && !skipThis { - io.WriteString(c.w, ", ") - } - if skipThis { + if hasBit(skipped, uint32(id)) { continue } childSel := &c.s[id] + if childSel.SkipRender { + continue + } + + if i != 0 { + io.WriteString(c.w, ", ") + } //fmt.Fprintf(w, `"%s_%d_join"."%s" AS "%s"`, //s.Name, s.ID, s.Name, s.FieldName) @@ -500,6 +521,7 @@ func (c *compilerContext) renderJoinedColumns(sel *qcode.Select, ti *DBTableInfo io.WriteString(c.w, `" AS "`) io.WriteString(c.w, childSel.FieldName) io.WriteString(c.w, `"`) + i++ } return nil @@ -632,10 +654,6 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, } } - // if i != 0 && len(sel.OrderBy) != 0 { - // io.WriteString(c.w, ", ") - // } - for _, ob := range sel.OrderBy { if _, ok := colmap[ob.Col]; ok { continue diff --git a/psql/query_test.go b/psql/query_test.go index 0a07aba..12ee395 100644 --- a/psql/query_test.go +++ b/psql/query_test.go @@ -463,6 +463,30 @@ func multiRoot(t *testing.T) { } } +func skipUserIDForAnonRole(t *testing.T) { + gql := `query { + products { + id + name + user(where: { id: { eq: $user_id } }) { + id + email + } + } + }` + + sql := `SELECT json_object_agg('products', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name", "products"."user_id" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"` + + resSQL, err := compileGQLToPSQL(gql, nil, "anon") + if err != nil { + t.Fatal(err) + } + + if string(resSQL) != sql { + t.Fatal(errNotExpected) + } +} + func blockedQuery(t *testing.T) { gql := `query { user(id: 5, where: { id: { gt: 3 } }) { @@ -524,6 +548,7 @@ func TestCompileQuery(t *testing.T) { t.Run("queryWithVariables", queryWithVariables) t.Run("withWhereOnRelations", withWhereOnRelations) t.Run("multiRoot", multiRoot) + t.Run("skipUserIDForAnonRole", skipUserIDForAnonRole) t.Run("blockedQuery", blockedQuery) t.Run("blockedFunctions", blockedFunctions) } diff --git a/qcode/config.go b/qcode/config.go index 52bc6fb..f692f44 100644 --- a/qcode/config.go +++ b/qcode/config.go @@ -45,6 +45,7 @@ type trval struct { query struct { limit string fil *Exp + filNU bool cols map[string]struct{} disable struct { funcs bool @@ -53,6 +54,7 @@ type trval struct { insert struct { fil *Exp + filNU bool cols map[string]struct{} psmap map[string]string pslist []string @@ -60,14 +62,16 @@ type trval struct { update struct { fil *Exp + filNU bool cols map[string]struct{} psmap map[string]string pslist []string } delete struct { - fil *Exp - cols map[string]struct{} + fil *Exp + filNU bool + cols map[string]struct{} } } @@ -88,21 +92,21 @@ func (trv *trval) allowedColumns(qt QType) map[string]struct{} { return nil } -func (trv *trval) filter(qt QType) *Exp { +func (trv *trval) filter(qt QType) (*Exp, bool) { switch qt { case QTQuery: - return trv.query.fil + return trv.query.fil, trv.query.filNU case QTInsert: - return trv.insert.fil + return trv.insert.fil, trv.insert.filNU case QTUpdate: - return trv.update.fil + return trv.update.fil, trv.update.filNU case QTDelete: - return trv.delete.fil + return trv.delete.fil, trv.delete.filNU case QTUpsert: - return trv.insert.fil + return trv.insert.fil, trv.insert.filNU } - return nil + return nil, false } func listToMap(list []string) map[string]struct{} { diff --git a/qcode/qcode.go b/qcode/qcode.go index b268818..3d68c17 100644 --- a/qcode/qcode.go +++ b/qcode/qcode.go @@ -51,6 +51,7 @@ type Select struct { Allowed map[string]struct{} PresetMap map[string]string PresetList []string + SkipRender bool } type Column struct { @@ -187,7 +188,7 @@ func (com *Compiler) AddRole(role, table string, trc TRConfig) error { trv := &trval{} // query config - trv.query.fil, err = compileFilter(trc.Query.Filters) + trv.query.fil, trv.query.filNU, err = compileFilter(trc.Query.Filters) if err != nil { return err } @@ -198,7 +199,8 @@ func (com *Compiler) AddRole(role, table string, trc TRConfig) error { trv.query.disable.funcs = trc.Query.DisableFunctions // insert config - if trv.insert.fil, err = compileFilter(trc.Insert.Filters); err != nil { + trv.insert.fil, trv.insert.filNU, err = compileFilter(trc.Insert.Filters) + if err != nil { return err } trv.insert.cols = listToMap(trc.Insert.Columns) @@ -206,7 +208,8 @@ func (com *Compiler) AddRole(role, table string, trc TRConfig) error { trv.insert.pslist = mapToList(trv.insert.psmap) // update config - if trv.update.fil, err = compileFilter(trc.Update.Filters); err != nil { + trv.update.fil, trv.update.filNU, err = compileFilter(trc.Update.Filters) + if err != nil { return err } trv.update.cols = listToMap(trc.Update.Columns) @@ -214,7 +217,8 @@ func (com *Compiler) AddRole(role, table string, trc TRConfig) error { trv.update.pslist = mapToList(trv.update.psmap) // delete config - if trv.delete.fil, err = compileFilter(trc.Delete.Filters); err != nil { + trv.delete.fil, trv.delete.filNU, err = compileFilter(trc.Delete.Filters) + if err != nil { return err } trv.delete.cols = listToMap(trc.Delete.Columns) @@ -334,7 +338,7 @@ func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error { s.FieldName = s.Name } - err := com.compileArgs(qc, s, field.Args) + err := com.compileArgs(qc, s, field.Args, role) if err != nil { return err } @@ -388,9 +392,16 @@ func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error { func (com *Compiler) addFilters(qc *QCode, sel *Select, role string) { var fil *Exp + var nu bool if trv, ok := com.tr[role][sel.Name]; ok { - fil = trv.filter(qc.Type) + fil, nu = trv.filter(qc.Type) + + } else if role == "anon" { + // Tables not defined under the anon role will not be rendered + sel.SkipRender = true + return + } else { return } @@ -399,6 +410,10 @@ func (com *Compiler) addFilters(qc *QCode, sel *Select, role string) { return } + if nu && role == "anon" { + sel.SkipRender = true + } + switch fil.Op { case OpNop: case OpFalse: @@ -420,7 +435,7 @@ func (com *Compiler) addFilters(qc *QCode, sel *Select, role string) { } } -func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg) error { +func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg, role string) error { var err error var ka bool @@ -435,7 +450,7 @@ func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg) error { err, ka = com.compileArgSearch(sel, arg) case "where": - err, ka = com.compileArgWhere(sel, arg) + err, ka = com.compileArgWhere(sel, arg, role) case "orderby", "order_by", "order": err, ka = com.compileArgOrderBy(sel, arg) @@ -501,19 +516,20 @@ func (com *Compiler) setMutationType(qc *QCode, args []Arg) error { return nil } -func (com *Compiler) compileArgObj(st *util.Stack, arg *Arg) (*Exp, error) { +func (com *Compiler) compileArgObj(st *util.Stack, arg *Arg) (*Exp, bool, error) { if arg.Val.Type != NodeObj { - return nil, fmt.Errorf("expecting an object") + return nil, false, fmt.Errorf("expecting an object") } return com.compileArgNode(st, arg.Val, true) } -func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (*Exp, error) { +func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (*Exp, bool, error) { var root *Exp + var needsUser bool if node == nil || len(node.Children) == 0 { - return nil, errors.New("invalid argument value") + return nil, needsUser, errors.New("invalid argument value") } pushChild(st, nil, node) @@ -526,7 +542,7 @@ func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (* intf := st.Pop() node, ok := intf.(*Node) if !ok || node == nil { - return nil, fmt.Errorf("16: unexpected value %v (%t)", intf, intf) + return nil, needsUser, fmt.Errorf("16: unexpected value %v (%t)", intf, intf) } // Objects inside a list @@ -542,13 +558,17 @@ func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (* ex, err := newExp(st, node, usePool) if err != nil { - return nil, err + return nil, needsUser, err } if ex == nil { continue } + if ex.Type == ValVar && ex.Val == "user_id" { + needsUser = true + } + if node.exp == nil { root = ex } else { @@ -571,7 +591,7 @@ func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (* nodePool.Put(node) } - return root, nil + return root, needsUser, nil } func (com *Compiler) compileArgID(sel *Select, arg *Arg) (error, bool) { @@ -640,15 +660,19 @@ func (com *Compiler) compileArgSearch(sel *Select, arg *Arg) (error, bool) { return nil, true } -func (com *Compiler) compileArgWhere(sel *Select, arg *Arg) (error, bool) { +func (com *Compiler) compileArgWhere(sel *Select, arg *Arg, role string) (error, bool) { st := util.NewStack() var err error - ex, err := com.compileArgObj(st, arg) + ex, nu, err := com.compileArgObj(st, arg) if err != nil { return err, false } + if nu && role == "anon" { + sel.SkipRender = true + } + if sel.Where != nil { ow := sel.Where @@ -976,27 +1000,32 @@ func pushChild(st *util.Stack, exp *Exp, node *Node) { } -func compileFilter(filter []string) (*Exp, error) { +func compileFilter(filter []string) (*Exp, bool, error) { var fl *Exp + var needsUser bool + com := &Compiler{} st := util.NewStack() if len(filter) == 0 { - return &Exp{Op: OpNop, doFree: false}, nil + return &Exp{Op: OpNop, doFree: false}, false, nil } for i := range filter { if filter[i] == "false" { - return &Exp{Op: OpFalse, doFree: false}, nil + return &Exp{Op: OpFalse, doFree: false}, false, nil } node, err := ParseArgValue(filter[i]) if err != nil { - return nil, err + return nil, false, err } - f, err := com.compileArgNode(st, node, false) + f, nu, err := com.compileArgNode(st, node, false) if err != nil { - return nil, err + return nil, false, err + } + if nu { + needsUser = true } // TODO: Invalid table names in nested where causes fail silently @@ -1010,7 +1039,7 @@ func compileFilter(filter []string) (*Exp, error) { fl = &Exp{Op: OpAnd, Children: []*Exp{fl, f}, doFree: false} } } - return fl, nil + return fl, needsUser, nil } func buildPath(a []string) string { diff --git a/serv/args.go b/serv/args.go index 5e05ba2..ac233ee 100644 --- a/serv/args.go +++ b/serv/args.go @@ -35,8 +35,6 @@ func argMap(ctx context.Context, vars []byte) func(w io.Writer, tag string) (int fields := jsn.Get(vars, [][]byte{[]byte(tag)}) - fmt.Println(">>", tag, string(vars)) - if len(fields) == 0 { return 0, nil }