From 7c02226016dc0d2e9b6b6d12b77564f63467dd46 Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Mon, 13 Jan 2020 09:34:15 -0500 Subject: [PATCH] Fix role filters and nested where bugs --- psql/query.go | 34 +++++++++++++++++++++++----------- psql/query_test.go | 4 ++-- qcode/config.go | 2 +- qcode/qcode.go | 14 +++++++++++--- serv/config_compile.go | 16 ++++++++-------- 5 files changed, 45 insertions(+), 25 deletions(-) diff --git a/psql/query.go b/psql/query.go index 6c655b5..607eb32 100644 --- a/psql/query.go +++ b/psql/query.go @@ -814,11 +814,15 @@ func (c *compilerContext) renderRelationshipByName(table, parent string, id int3 } func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error { - st := util.NewStack() - if sel.Where != nil { - st.Push(sel.Where) + return c.renderExp(sel.Where, ti, false) } + return nil +} + +func (c *compilerContext) renderExp(ex *qcode.Exp, ti *DBTableInfo, skipNested bool) error { + st := util.NewStack() + st.Push(ex) for { if st.Len() == 0 { @@ -873,16 +877,16 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error qcode.FreeExp(val) default: - if len(val.NestedCols) != 0 { + if !skipNested && len(val.NestedCols) != 0 { io.WriteString(c.w, `EXISTS `) - if err := c.renderNestedWhere(val, sel, ti); err != nil { + if err := c.renderNestedWhere(val, ti); err != nil { return err } } else { //fmt.Fprintf(w, `(("%s"."%s") `, c.sel.Name, val.Col) - if err := c.renderOp(val, sel, ti); err != nil { + if err := c.renderOp(val, ti); err != nil { return err } qcode.FreeExp(val) @@ -898,7 +902,7 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error return nil } -func (c *compilerContext) renderNestedWhere(ex *qcode.Exp, sel *qcode.Select, ti *DBTableInfo) error { +func (c *compilerContext) renderNestedWhere(ex *qcode.Exp, ti *DBTableInfo) error { for i := 0; i < len(ex.NestedCols)-1; i++ { cti, err := c.schema.GetTable(ex.NestedCols[i]) if err != nil { @@ -922,6 +926,15 @@ func (c *compilerContext) renderNestedWhere(ex *qcode.Exp, sel *qcode.Select, ti return err } + io.WriteString(c.w, ` AND (`) + + if err := c.renderExp(ex, cti, true); err != nil { + return err + } + + //fmt.Println(">", ex) + io.WriteString(c.w, `)`) + } for i := 0; i < len(ex.NestedCols)-1; i++ { @@ -931,7 +944,7 @@ func (c *compilerContext) renderNestedWhere(ex *qcode.Exp, sel *qcode.Select, ti return nil } -func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTableInfo) error { +func (c *compilerContext) renderOp(ex *qcode.Exp, ti *DBTableInfo) error { var col *DBColumn var ok bool @@ -1029,10 +1042,9 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable if ex.Type == qcode.ValList { c.renderList(ex) + } else if col == nil { + return errors.New("no column found for expression value") } else { - if col == nil { - return errors.New("no column found for expression value") - } c.renderVal(ex, c.vars, col) } diff --git a/psql/query_test.go b/psql/query_test.go index 37fbd13..1d649b9 100644 --- a/psql/query_test.go +++ b/psql/query_test.go @@ -409,7 +409,7 @@ func withWhereOnRelations(t *testing.T) { users(where: { not: { products: { - price: { gt: 3 } + price: { gt: 3 } } } }) { @@ -418,7 +418,7 @@ func withWhereOnRelations(t *testing.T) { } }` - sql := `SELECT json_object_agg('users', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "users_0"."id" AS "id", "users_0"."email" AS "email") AS "json_row_0")) AS "json_0" FROM (SELECT "users"."id", "users"."email" FROM "users" WHERE (NOT EXISTS (SELECT 1 FROM products WHERE (("products"."user_id") = ("users"."id")))) LIMIT ('20') :: integer) AS "users_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"` + sql := `SELECT json_object_agg('users', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "users_0"."id" AS "id", "users_0"."email" AS "email") AS "json_row_0")) AS "json_0" FROM (SELECT "users"."id", "users"."email" FROM "users" WHERE (NOT EXISTS (SELECT 1 FROM products WHERE (("products"."user_id") = ("users"."id")) AND ((("products"."price") > 3)))) LIMIT ('20') :: integer) AS "users_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { diff --git a/qcode/config.go b/qcode/config.go index ddd118f..52bc6fb 100644 --- a/qcode/config.go +++ b/qcode/config.go @@ -80,7 +80,7 @@ func (trv *trval) allowedColumns(qt QType) map[string]struct{} { case QTUpdate: return trv.update.cols case QTDelete: - return trv.insert.cols + return trv.delete.cols case QTUpsert: return trv.insert.cols } diff --git a/qcode/qcode.go b/qcode/qcode.go index 0a840c0..b268818 100644 --- a/qcode/qcode.go +++ b/qcode/qcode.go @@ -940,10 +940,13 @@ func setWhereColName(ex *Exp, node *Node) { list = append([]string{k}, list...) } } - if len(list) == 1 { + listlen := len(list) + + if listlen == 1 { ex.Col = list[0] - } else if len(list) > 1 { - ex.NestedCols = list + } else if listlen > 1 { + ex.Col = list[listlen-1] + ex.NestedCols = list[:listlen] } } @@ -996,6 +999,11 @@ func compileFilter(filter []string) (*Exp, error) { return nil, err } + // TODO: Invalid table names in nested where causes fail silently + // returning a nil 'f' this needs to be fixed + + // TODO: Invalid where clauses such as missing op (eg. eq) also fail silently + if fl == nil { fl = f } else { diff --git a/serv/config_compile.go b/serv/config_compile.go index 3c55164..36ad073 100644 --- a/serv/config_compile.go +++ b/serv/config_compile.go @@ -80,26 +80,26 @@ func addRole(qc *qcode.Compiler, r configRole, t configRoleTable) error { Presets: t.Insert.Presets, } - if t.Query.Block { + if t.Insert.Block { insert.Filters = blockFilter } update := qcode.UpdateConfig{ - Filters: t.Insert.Filters, - Columns: t.Insert.Columns, - Presets: t.Insert.Presets, + Filters: t.Update.Filters, + Columns: t.Update.Columns, + Presets: t.Update.Presets, } - if t.Query.Block { + if t.Update.Block { update.Filters = blockFilter } delete := qcode.DeleteConfig{ - Filters: t.Insert.Filters, - Columns: t.Insert.Columns, + Filters: t.Delete.Filters, + Columns: t.Delete.Columns, } - if t.Query.Block { + if t.Delete.Block { delete.Filters = blockFilter }