From 6029c5e05cc2c7f323e94d7f959da36b329ffb61 Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Mon, 2 Dec 2019 10:52:22 -0500 Subject: [PATCH] Add support for `websearch_to_tsquery` in PG 11 --- config/allow.list | 14 ++++ config/dev.yml | 2 +- migrate/migrate.go | 8 +- psql/psql_test.go | 7 +- psql/query.go | 56 ++++++++++--- psql/query_test.go | 8 +- psql/tables.go | 22 ++++- qcode/config.go | 1 - qcode/parse.go | 44 +++++----- qcode/qcode.go | 197 ++++++++++++++++++++++----------------------- serv/cmd_seed.go | 2 +- serv/core.go | 38 ++++----- serv/core_build.go | 12 ++- serv/prepare.go | 33 ++++---- serv/serv.go | 1 - 15 files changed, 249 insertions(+), 196 deletions(-) diff --git a/config/allow.list b/config/allow.list index 11477b5..4971ba5 100644 --- a/config/allow.list +++ b/config/allow.list @@ -169,3 +169,17 @@ query { } } +variables { + "beer": "smoke" +} + +query beerSearch { + products(search: $beer) { + id + name + search_rank + search_headline_description + } +} + + diff --git a/config/dev.yml b/config/dev.yml index 8ef4d6d..e04b1fc 100644 --- a/config/dev.yml +++ b/config/dev.yml @@ -171,7 +171,7 @@ roles: query: limit: 50 filters: ["{ user_id: { eq: $user_id } }"] - columns: ["id", "name", "description" ] + columns: ["id", "name", "description", "search_rank", "search_headline_description" ] disable_functions: false insert: diff --git a/migrate/migrate.go b/migrate/migrate.go index badae2f..7730531 100644 --- a/migrate/migrate.go +++ b/migrate/migrate.go @@ -353,16 +353,14 @@ func (m *Migrator) MigrateTo(targetVersion int32) (err error) { } func (m *Migrator) GetCurrentVersion() (v int32, err error) { - ctx := context.Background() + err = m.conn.QueryRow(context.Background(), + "select version from "+m.versionTable).Scan(&v) - err = m.conn.QueryRow(ctx, "select version from "+m.versionTable).Scan(&v) return v, err } func (m *Migrator) ensureSchemaVersionTableExists() (err error) { - ctx := context.Background() - - _, err = m.conn.Exec(ctx, fmt.Sprintf(` + _, err = m.conn.Exec(context.Background(), fmt.Sprintf(` create table if not exists %s(version int4 not null); insert into %s(version) diff --git a/psql/psql_test.go b/psql/psql_test.go index 455b75d..1609611 100644 --- a/psql/psql_test.go +++ b/psql/psql_test.go @@ -178,9 +178,10 @@ func TestMain(m *testing.M) { } schema := &DBSchema{ - t: make(map[string]*DBTableInfo), - rm: make(map[string]map[string]*DBRel), - al: make(map[string]struct{}), + ver: 110000, + t: make(map[string]*DBTableInfo), + rm: make(map[string]map[string]*DBRel), + al: make(map[string]struct{}), } aliases := map[string][]string{ diff --git a/psql/query.go b/psql/query.go index fdd37ed..b0cc310 100644 --- a/psql/query.go +++ b/psql/query.go @@ -189,6 +189,11 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) { } } + if len(sel.Args) != 0 { + for _, v := range sel.Args { + qcode.FreeNode(v) + } + } } } @@ -515,36 +520,54 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, if isSearch { switch { case cn == "search_rank": + if len(sel.Allowed) != 0 { + if _, ok := sel.Allowed[cn]; !ok { + continue + } + } cn = ti.TSVCol arg := sel.Args["search"] if i != 0 { io.WriteString(c.w, `, `) } - //fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`, + //fmt.Fprintf(w, `ts_rank("%s"."%s", websearch_to_tsquery('%s')) AS %s`, //c.sel.Name, cn, arg.Val, col.Name) io.WriteString(c.w, `ts_rank(`) colWithTable(c.w, ti.Name, cn) - io.WriteString(c.w, `, to_tsquery('`) + if c.schema.ver >= 110000 { + io.WriteString(c.w, `, websearch_to_tsquery('`) + } else { + io.WriteString(c.w, `, to_tsquery('`) + } io.WriteString(c.w, arg.Val) - io.WriteString(c.w, `')`) + io.WriteString(c.w, `'))`) alias(c.w, col.Name) i++ case strings.HasPrefix(cn, "search_headline_"): - cn = cn[16:] + cn1 := cn[16:] + if len(sel.Allowed) != 0 { + if _, ok := sel.Allowed[cn1]; !ok { + continue + } + } arg := sel.Args["search"] if i != 0 { io.WriteString(c.w, `, `) } - //fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`, + //fmt.Fprintf(w, `ts_headline("%s"."%s", websearch_to_tsquery('%s')) AS %s`, //c.sel.Name, cn, arg.Val, col.Name) - io.WriteString(c.w, `ts_headlinek(`) - colWithTable(c.w, ti.Name, cn) - io.WriteString(c.w, `, to_tsquery('`) + io.WriteString(c.w, `ts_headline(`) + colWithTable(c.w, ti.Name, cn1) + if c.schema.ver >= 110000 { + io.WriteString(c.w, `, websearch_to_tsquery('`) + } else { + io.WriteString(c.w, `, to_tsquery('`) + } io.WriteString(c.w, arg.Val) - io.WriteString(c.w, `')`) + io.WriteString(c.w, `'))`) alias(c.w, col.Name) i++ @@ -693,6 +716,7 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, //fmt.Fprintf(w, `) AS "%s_%d"`, c.sel.Name, c.sel.ID) io.WriteString(c.w, `)`) aliasWithID(c.w, ti.Name, sel.ID) + return nil } @@ -939,6 +963,7 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable io.WriteString(c.w, `IS NOT NULL)`) } return nil + case qcode.OpEqID: if len(ti.PrimaryCol) == 0 { return fmt.Errorf("no primary key column defined for %s", ti.Name) @@ -951,6 +976,7 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable colWithTable(c.w, ti.Name, ti.PrimaryCol) //io.WriteString(c.w, ti.PrimaryCol) io.WriteString(c.w, `) =`) + case qcode.OpTsQuery: if len(ti.TSVCol) == 0 { return fmt.Errorf("no tsv column defined for %s", ti.Name) @@ -958,10 +984,14 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable if _, ok = ti.Columns[ti.TSVCol]; !ok { return fmt.Errorf("no tsv column '%s' found ", ti.TSVCol) } - //fmt.Fprintf(w, `(("%s") @@ to_tsquery('%s'))`, c.ti.TSVCol, val.Val) - io.WriteString(c.w, `(("`) - io.WriteString(c.w, ti.TSVCol) - io.WriteString(c.w, `") @@ to_tsquery('`) + //fmt.Fprintf(w, `(("%s") @@ websearch_to_tsquery('%s'))`, c.ti.TSVCol, val.Val) + io.WriteString(c.w, `((`) + colWithTable(c.w, ti.Name, ti.TSVCol) + if c.schema.ver >= 110000 { + io.WriteString(c.w, `) @@ websearch_to_tsquery('`) + } else { + io.WriteString(c.w, `) @@ to_tsquery('`) + } io.WriteString(c.w, ex.Val) io.WriteString(c.w, `'))`) return nil diff --git a/psql/query_test.go b/psql/query_test.go index c4fcfb4..d1e1763 100644 --- a/psql/query_test.go +++ b/psql/query_test.go @@ -142,15 +142,17 @@ func fetchByID(t *testing.T) { func searchQuery(t *testing.T) { gql := `query { - products(search: "Imperial") { + products(search: "ale") { id name + search_rank + search_headline_description } }` - sql := `SELECT json_object_agg('products', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"` + sql := `SELECT json_object_agg('products', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."search_rank" AS "search_rank", "products_0"."search_headline_description" AS "search_headline_description") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name", ts_rank("products"."tsv", websearch_to_tsquery('ale')) AS "search_rank", ts_headline("products"."description", websearch_to_tsquery('ale')) AS "search_headline_description" FROM "products" WHERE ((("products"."tsv") @@ websearch_to_tsquery('ale'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"` - resSQL, err := compileGQLToPSQL(gql, nil, "user") + resSQL, err := compileGQLToPSQL(gql, nil, "admin") if err != nil { t.Fatal(err) } diff --git a/psql/tables.go b/psql/tables.go index 4c74bc6..e82e31b 100644 --- a/psql/tables.go +++ b/psql/tables.go @@ -3,6 +3,7 @@ package psql import ( "context" "fmt" + "strconv" "strings" "github.com/gobuffalo/flect" @@ -144,9 +145,10 @@ ORDER BY id;` } type DBSchema struct { - t map[string]*DBTableInfo - rm map[string]map[string]*DBRel - al map[string]struct{} + ver int + t map[string]*DBTableInfo + rm map[string]map[string]*DBRel + al map[string]struct{} } type DBTableInfo struct { @@ -184,10 +186,22 @@ func NewDBSchema(db *pgxpool.Pool, aliases map[string][]string) (*DBSchema, erro dbc, err := db.Acquire(context.Background()) if err != nil { - return nil, fmt.Errorf("error acquiring connection from pool") + return nil, fmt.Errorf("error acquiring connection from pool: %w", err) } defer dbc.Release() + var version string + + err = dbc.QueryRow(context.Background(), `SHOW server_version_num`).Scan(&version) + if err != nil { + return nil, fmt.Errorf("error fetching version: %w", err) + } + + schema.ver, err = strconv.Atoi(version) + if err != nil { + return nil, err + } + tables, err := GetTables(dbc) if err != nil { return nil, err diff --git a/qcode/config.go b/qcode/config.go index 34a9d88..ddd118f 100644 --- a/qcode/config.go +++ b/qcode/config.go @@ -8,7 +8,6 @@ import ( type Config struct { Blocklist []string - KeepArgs bool } type QueryConfig struct { diff --git a/qcode/parse.go b/qcode/parse.go index 2e6b574..644ecf5 100644 --- a/qcode/parse.go +++ b/qcode/parse.go @@ -26,13 +26,13 @@ const ( opQuery opMutate opSub - nodeStr - nodeInt - nodeFloat - nodeBool - nodeObj - nodeList - nodeVar + NodeStr + NodeInt + NodeFloat + NodeBool + NodeObj + NodeList + NodeVar ) type Operation struct { @@ -413,7 +413,7 @@ func (p *Parser) parseList() (*Node, error) { return nil, errors.New("List cannot be empty") } - parent.Type = nodeList + parent.Type = NodeList parent.Children = nodes return parent, nil @@ -450,7 +450,7 @@ func (p *Parser) parseObj() (*Node, error) { nodes = append(nodes, node) } - parent.Type = nodeObj + parent.Type = NodeObj parent.Children = nodes return parent, nil @@ -473,17 +473,17 @@ func (p *Parser) parseValue() (*Node, error) { switch item.typ { case itemIntVal: - node.Type = nodeInt + node.Type = NodeInt case itemFloatVal: - node.Type = nodeFloat + node.Type = NodeFloat case itemStringVal: - node.Type = nodeStr + node.Type = NodeStr case itemBoolVal: - node.Type = nodeBool + node.Type = NodeBool case itemName: - node.Type = nodeStr + node.Type = NodeStr case itemVariable: - node.Type = nodeVar + node.Type = NodeVar default: return nil, fmt.Errorf("expecting a number, string, object, list or variable as an argument value (not %s)", p.val(p.next())) } @@ -514,19 +514,19 @@ func (t parserType) String() string { v = "mutation" case opSub: v = "subscription" - case nodeStr: + case NodeStr: v = "node-string" - case nodeInt: + case NodeInt: v = "node-int" - case nodeFloat: + case NodeFloat: v = "node-float" - case nodeBool: + case NodeBool: v = "node-bool" - case nodeVar: + case NodeVar: v = "node-var" - case nodeObj: + case NodeObj: v = "node-obj" - case nodeList: + case NodeList: v = "node-list" } return fmt.Sprintf("<%s>", v) diff --git a/qcode/qcode.go b/qcode/qcode.go index 331bac6..b31c588 100644 --- a/qcode/qcode.go +++ b/qcode/qcode.go @@ -157,7 +157,6 @@ const ( type Compiler struct { tr map[string]map[string]*trval bl map[string]struct{} - ka bool } var expPool = sync.Pool{ @@ -165,7 +164,7 @@ var expPool = sync.Pool{ } func NewCompiler(c Config) (*Compiler, error) { - co := &Compiler{ka: c.KeepArgs} + co := &Compiler{} co.tr = make(map[string]map[string]*trval) co.bl = make(map[string]struct{}, len(c.Blocklist)) @@ -380,11 +379,13 @@ func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error { return nil } -func (com *Compiler) addFilters(qc *QCode, root *Select, role string) { +func (com *Compiler) addFilters(qc *QCode, sel *Select, role string) { var fil *Exp - if trv, ok := com.tr[role][root.Name]; ok { + if trv, ok := com.tr[role][sel.Name]; ok { fil = trv.filter(qc.Type) + } else { + return } if fil == nil { @@ -394,60 +395,61 @@ func (com *Compiler) addFilters(qc *QCode, root *Select, role string) { switch fil.Op { case OpNop: case OpFalse: - root.Where = fil + sel.Where = fil default: - if root.Where != nil { - ow := root.Where + if sel.Where != nil { + ow := sel.Where - root.Where = expPool.Get().(*Exp) - root.Where.Reset() - root.Where.Op = OpAnd - root.Where.Children = root.Where.childrenA[:2] - root.Where.Children[0] = fil - root.Where.Children[1] = ow + sel.Where = expPool.Get().(*Exp) + sel.Where.Reset() + sel.Where.Op = OpAnd + sel.Where.Children = sel.Where.childrenA[:2] + sel.Where.Children[0] = fil + sel.Where.Children[1] = ow } else { - root.Where = fil + sel.Where = fil } } } func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg) error { var err error - - if com.ka { - sel.Args = make(map[string]*Node, len(args)) - } + var ka bool for i := range args { arg := &args[i] switch arg.Name { case "id": - err = com.compileArgID(sel, arg) + err, ka = com.compileArgID(sel, arg) + case "search": - err = com.compileArgSearch(sel, arg) + err, ka = com.compileArgSearch(sel, arg) + case "where": - err = com.compileArgWhere(sel, arg) + err, ka = com.compileArgWhere(sel, arg) + case "orderby", "order_by", "order": - err = com.compileArgOrderBy(sel, arg) + err, ka = com.compileArgOrderBy(sel, arg) + case "distinct_on", "distinct": - err = com.compileArgDistinctOn(sel, arg) + err, ka = com.compileArgDistinctOn(sel, arg) + case "limit": - err = com.compileArgLimit(sel, arg) + err, ka = com.compileArgLimit(sel, arg) + case "offset": - err = com.compileArgOffset(sel, arg) + err, ka = com.compileArgOffset(sel, arg) + } + + if !ka { + nodePool.Put(arg.Val) } if err != nil { return err } - - if sel.Args != nil { - sel.Args[arg.Name] = arg.Val - } else { - nodePool.Put(arg.Val) - } } return nil @@ -455,7 +457,7 @@ func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg) error { func (com *Compiler) setMutationType(qc *QCode, args []Arg) error { setActionVar := func(arg *Arg) error { - if arg.Val.Type != nodeVar { + if arg.Val.Type != NodeVar { return fmt.Errorf("value for argument '%s' must be a variable", arg.Name) } qc.ActionVar = arg.Val.Val @@ -478,7 +480,7 @@ func (com *Compiler) setMutationType(qc *QCode, args []Arg) error { case "delete": qc.Type = QTDelete - if arg.Val.Type != nodeBool { + if arg.Val.Type != NodeBool { return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name) } @@ -493,7 +495,7 @@ func (com *Compiler) setMutationType(qc *QCode, args []Arg) error { } func (com *Compiler) compileArgObj(st *util.Stack, arg *Arg) (*Exp, error) { - if arg.Val.Type != nodeObj { + if arg.Val.Type != NodeObj { return nil, fmt.Errorf("expecting an object") } @@ -545,11 +547,6 @@ func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (* } else { node.exp.Children = append(node.exp.Children, ex) } - - } - - if com.ka { - return root, nil } pushChild(st, nil, node) @@ -570,13 +567,13 @@ func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (* return root, nil } -func (com *Compiler) compileArgID(sel *Select, arg *Arg) error { +func (com *Compiler) compileArgID(sel *Select, arg *Arg) (error, bool) { if sel.ID != 0 { - return nil + return nil, false } if sel.Where != nil && sel.Where.Op == OpEqID { - return nil + return nil, false } ex := expPool.Get().(*Exp) @@ -586,30 +583,41 @@ func (com *Compiler) compileArgID(sel *Select, arg *Arg) error { ex.Val = arg.Val.Val switch arg.Val.Type { - case nodeStr: + case NodeStr: ex.Type = ValStr - case nodeInt: + case NodeInt: ex.Type = ValInt - case nodeFloat: + case NodeFloat: ex.Type = ValFloat - case nodeVar: + case NodeVar: ex.Type = ValVar default: - return fmt.Errorf("expecting a string, int, float or variable") + return fmt.Errorf("expecting a string, int, float or variable"), false } sel.Where = ex - return nil + return nil, false } -func (com *Compiler) compileArgSearch(sel *Select, arg *Arg) error { +func (com *Compiler) compileArgSearch(sel *Select, arg *Arg) (error, bool) { ex := expPool.Get().(*Exp) ex.Reset() ex.Op = OpTsQuery - ex.Type = ValStr ex.Val = arg.Val.Val + if arg.Val.Type == NodeVar { + ex.Type = ValVar + } else { + ex.Type = ValStr + } + + if sel.Args == nil { + sel.Args = make(map[string]*Node) + } + + sel.Args[arg.Name] = arg.Val + if sel.Where != nil { ow := sel.Where @@ -622,16 +630,16 @@ func (com *Compiler) compileArgSearch(sel *Select, arg *Arg) error { } else { sel.Where = ex } - return nil + return nil, true } -func (com *Compiler) compileArgWhere(sel *Select, arg *Arg) error { +func (com *Compiler) compileArgWhere(sel *Select, arg *Arg) (error, bool) { st := util.NewStack() var err error ex, err := com.compileArgObj(st, arg) if err != nil { - return err + return err, false } if sel.Where != nil { @@ -647,12 +655,12 @@ func (com *Compiler) compileArgWhere(sel *Select, arg *Arg) error { sel.Where = ex } - return nil + return nil, false } -func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error { - if arg.Val.Type != nodeObj { - return fmt.Errorf("expecting an object") +func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) (error, bool) { + if arg.Val.Type != NodeObj { + return fmt.Errorf("expecting an object"), false } st := util.NewStack() @@ -670,23 +678,19 @@ func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error { node, ok := intf.(*Node) if !ok || node == nil { - return fmt.Errorf("17: unexpected value %v (%t)", intf, intf) + return fmt.Errorf("17: unexpected value %v (%t)", intf, intf), false } if _, ok := com.bl[node.Name]; ok { - if !com.ka { - nodePool.Put(node) - } + nodePool.Put(node) continue } - if node.Type == nodeObj { + if node.Type == NodeObj { for i := range node.Children { st.Push(node.Children[i]) } - if !com.ka { - nodePool.Put(node) - } + nodePool.Put(node) continue } @@ -706,65 +710,60 @@ func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error { case "desc_nulls_last": ob.Order = OrderDescNullsLast default: - return fmt.Errorf("valid values include asc, desc, asc_nulls_first and desc_nulls_first") + return fmt.Errorf("valid values include asc, desc, asc_nulls_first and desc_nulls_first"), false } setOrderByColName(ob, node) sel.OrderBy = append(sel.OrderBy, ob) - - if !com.ka { - nodePool.Put(node) - } + nodePool.Put(node) } - return nil + return nil, false } -func (com *Compiler) compileArgDistinctOn(sel *Select, arg *Arg) error { +func (com *Compiler) compileArgDistinctOn(sel *Select, arg *Arg) (error, bool) { node := arg.Val if _, ok := com.bl[node.Name]; ok { - return nil + return nil, false } - if node.Type != nodeList && node.Type != nodeStr { - return fmt.Errorf("expecting a list of strings or just a string") + if node.Type != NodeList && node.Type != NodeStr { + return fmt.Errorf("expecting a list of strings or just a string"), false } - if node.Type == nodeStr { + if node.Type == NodeStr { sel.DistinctOn = append(sel.DistinctOn, node.Val) } for i := range node.Children { sel.DistinctOn = append(sel.DistinctOn, node.Children[i].Val) - if !com.ka { - nodePool.Put(node.Children[i]) - } + nodePool.Put(node.Children[i]) } - return nil + return nil, false } -func (com *Compiler) compileArgLimit(sel *Select, arg *Arg) error { +func (com *Compiler) compileArgLimit(sel *Select, arg *Arg) (error, bool) { node := arg.Val - if node.Type != nodeInt { - return fmt.Errorf("expecting an integer") + if node.Type != NodeInt { + return fmt.Errorf("expecting an integer"), false } sel.Paging.Limit = node.Val - return nil + return nil, false } -func (com *Compiler) compileArgOffset(sel *Select, arg *Arg) error { +func (com *Compiler) compileArgOffset(sel *Select, arg *Arg) (error, bool) { node := arg.Val - if node.Type != nodeInt { - return fmt.Errorf("expecting an integer") + if node.Type != NodeInt { + return fmt.Errorf("expecting an integer"), false } sel.Paging.Offset = node.Val - return nil + return nil, false } var zeroTrv = &trval{} @@ -879,17 +878,17 @@ func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) { if ex.Op != OpAnd && ex.Op != OpOr && ex.Op != OpNot { switch node.Type { - case nodeStr: + case NodeStr: ex.Type = ValStr - case nodeInt: + case NodeInt: ex.Type = ValInt - case nodeBool: + case NodeBool: ex.Type = ValBool - case nodeFloat: + case NodeFloat: ex.Type = ValFloat - case nodeList: + case NodeList: ex.Type = ValList - case nodeVar: + case NodeVar: ex.Type = ValVar default: return nil, fmt.Errorf("[Where] valid values include string, int, float, boolean and list: %s", node.Type) @@ -903,13 +902,13 @@ func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) { func setListVal(ex *Exp, node *Node) { if len(node.Children) != 0 { switch node.Children[0].Type { - case nodeStr: + case NodeStr: ex.ListType = ValStr - case nodeInt: + case NodeInt: ex.ListType = ValInt - case nodeBool: + case NodeBool: ex.ListType = ValBool - case nodeFloat: + case NodeFloat: ex.ListType = ValFloat } } @@ -922,7 +921,7 @@ func setWhereColName(ex *Exp, node *Node) { var list []string for n := node.Parent; n != nil; n = n.Parent { - if n.Type != nodeObj { + if n.Type != NodeObj { continue } if len(n.Name) != 0 { diff --git a/serv/cmd_seed.go b/serv/cmd_seed.go index 627c3e7..2f19e10 100644 --- a/serv/cmd_seed.go +++ b/serv/cmd_seed.go @@ -110,7 +110,7 @@ func graphQLFunc(query string, data interface{}, opt map[string]string) map[stri var root []byte - if err = tx.QueryRow(c, finalSQL).Scan(&root); err != nil { + if err = tx.QueryRow(context.Background(), finalSQL).Scan(&root); err != nil { errlog.Fatal().Err(err).Msg("sql query failed") } diff --git a/serv/core.go b/serv/core.go index 945ed9c..de482ce 100644 --- a/serv/core.go +++ b/serv/core.go @@ -81,7 +81,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) { useTx := useRoleQuery || conf.DB.SetUserID if useTx { - if tx, err = db.Begin(c); err != nil { + if tx, err = db.Begin(context.Background()); err != nil { return nil, nil, err } defer tx.Rollback(c) //nolint: errcheck @@ -122,9 +122,9 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) { } if useTx { - row = tx.QueryRow(c, ps.sd.SQL, vars...) + row = tx.QueryRow(context.Background(), ps.sd.SQL, vars...) } else { - row = db.QueryRow(c, ps.sd.SQL, vars...) + row = db.QueryRow(context.Background(), ps.sd.SQL, vars...) } if mutation || anonQuery { @@ -146,7 +146,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) { c.req.role = role if useTx { - if err := tx.Commit(c); err != nil { + if err := tx.Commit(context.Background()); err != nil { return nil, nil, err } } @@ -166,10 +166,10 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) { useTx := useRoleQuery || conf.DB.SetUserID if useTx { - if tx, err = db.Begin(c); err != nil { + if tx, err = db.Begin(context.Background()); err != nil { return nil, nil, err } - defer tx.Rollback(c) //nolint: errcheck + defer tx.Rollback(context.Background()) //nolint: errcheck } if conf.DB.SetUserID { @@ -215,9 +215,9 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) { defaultRole := c.req.role if useTx { - row = tx.QueryRow(c, finalSQL) + row = tx.QueryRow(context.Background(), finalSQL) } else { - row = db.QueryRow(c, finalSQL) + row = db.QueryRow(context.Background(), finalSQL) } if len(stmts) == 1 { @@ -237,7 +237,7 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) { } if useTx { - if err := tx.Commit(c); err != nil { + if err := tx.Commit(context.Background()); err != nil { return nil, nil, err } } @@ -263,7 +263,7 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) { func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) { var role string - row := tx.QueryRow(c, "_sg_get_role", c.req.role, 1) + row := tx.QueryRow(context.Background(), "_sg_get_role", c.req.role, 1) if err := row.Scan(&role); err != nil { return "", err @@ -320,6 +320,15 @@ func (c *coreContext) addTrace(sel []qcode.Select, id int32, st time.Time) { append(c.res.Extensions.Tracing.Execution.Resolvers, tr) } +func setLocalUserID(c context.Context, tx pgx.Tx) error { + var err error + if v := c.Value(userIDKey); v != nil { + _, err = tx.Exec(context.Background(), fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v)) + } + + return err +} + func parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) ( [][]byte, map[uint64]*qcode.Select) { @@ -363,15 +372,6 @@ func parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) ( return fm, sm } -func setLocalUserID(c context.Context, tx pgx.Tx) error { - var err error - if v := c.Value(userIDKey); v != nil { - _, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v)) - } - - return err -} - func isSkipped(n uint32, pos uint32) bool { return ((n & (1 << pos)) != 0) } diff --git a/serv/core_build.go b/serv/core_build.go index cb7f66d..2cc1b3b 100644 --- a/serv/core_build.go +++ b/serv/core_build.go @@ -97,6 +97,10 @@ func buildMultiStmt(gql, vars []byte) ([]stmt, error) { for i := 0; i < len(conf.Roles); i++ { role := &conf.Roles[i] + if role.Name == "anon" { + continue + } + qc, err := qcompile.Compile(gql, role.Name) if err != nil { return nil, err @@ -127,8 +131,6 @@ func buildMultiStmt(gql, vars []byte) ([]stmt, error) { //nolint: errcheck func renderUserQuery( stmts []stmt, vars map[string]json.RawMessage) (string, error) { - - var err error w := &bytes.Buffer{} io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `) @@ -141,11 +143,7 @@ func renderUserQuery( io.WriteString(w, `WHEN '`) io.WriteString(w, s.role.Name) io.WriteString(w, `' THEN (`) - - s.skipped, err = pcompile.Compile(s.qc, w, psql.Variables(vars)) - if err != nil { - return "", err - } + io.WriteString(w, s.sql) io.WriteString(w, `) `) } diff --git a/serv/prepare.go b/serv/prepare.go index 9255638..2d3b3a2 100644 --- a/serv/prepare.go +++ b/serv/prepare.go @@ -23,21 +23,20 @@ var ( ) func initPreparedList() { - c := context.Background() _preparedList = make(map[string]*preparedItem) - tx, err := db.Begin(c) + tx, err := db.Begin(context.Background()) if err != nil { errlog.Fatal().Err(err).Send() } - defer tx.Rollback(c) //nolint: errcheck + defer tx.Rollback(context.Background()) //nolint: errcheck - err = prepareRoleStmt(c, tx) + err = prepareRoleStmt(tx) if err != nil { errlog.Fatal().Err(err).Msg("failed to prepare get role statement") } - if err := tx.Commit(c); err != nil { + if err := tx.Commit(context.Background()); err != nil { errlog.Fatal().Err(err).Send() } @@ -48,7 +47,7 @@ func initPreparedList() { continue } - err := prepareStmt(c, v.gql, v.vars) + err := prepareStmt(v.gql, v.vars) if err == nil { success++ continue @@ -66,15 +65,15 @@ func initPreparedList() { success, len(_allowList.list)) } -func prepareStmt(c context.Context, gql string, vars []byte) error { +func prepareStmt(gql string, vars []byte) error { qt := qcode.GetQType(gql) q := []byte(gql) - tx, err := db.Begin(c) + tx, err := db.Begin(context.Background()) if err != nil { return err } - defer tx.Rollback(c) //nolint: errcheck + defer tx.Rollback(context.Background()) //nolint: errcheck switch qt { case qcode.QTQuery: @@ -83,7 +82,7 @@ func prepareStmt(c context.Context, gql string, vars []byte) error { return err } - err = prepare(c, tx, &stmts1[0], gqlHash(gql, vars, "user")) + err = prepare(tx, &stmts1[0], gqlHash(gql, vars, "user")) if err != nil { return err } @@ -93,7 +92,7 @@ func prepareStmt(c context.Context, gql string, vars []byte) error { return err } - err = prepare(c, tx, &stmts2[0], gqlHash(gql, vars, "anon")) + err = prepare(tx, &stmts2[0], gqlHash(gql, vars, "anon")) if err != nil { return err } @@ -105,7 +104,7 @@ func prepareStmt(c context.Context, gql string, vars []byte) error { return err } - err = prepare(c, tx, &stmts[0], gqlHash(gql, vars, role.Name)) + err = prepare(tx, &stmts[0], gqlHash(gql, vars, role.Name)) if err != nil { return err } @@ -118,17 +117,17 @@ func prepareStmt(c context.Context, gql string, vars []byte) error { logger.Debug().Msgf("Building prepared statement:\n %s\n%s", vars, gql) } - if err := tx.Commit(c); err != nil { + if err := tx.Commit(context.Background()); err != nil { return err } return nil } -func prepare(c context.Context, tx pgx.Tx, st *stmt, key string) error { +func prepare(tx pgx.Tx, st *stmt, key string) error { finalSQL, am := processTemplate(st.sql) - sd, err := tx.Prepare(c, "", finalSQL) + sd, err := tx.Prepare(context.Background(), "", finalSQL) if err != nil { return err } @@ -142,7 +141,7 @@ func prepare(c context.Context, tx pgx.Tx, st *stmt, key string) error { } // nolint: errcheck -func prepareRoleStmt(c context.Context, tx pgx.Tx) error { +func prepareRoleStmt(tx pgx.Tx) error { if len(conf.RolesQuery) == 0 { return nil } @@ -167,7 +166,7 @@ func prepareRoleStmt(c context.Context, tx pgx.Tx) error { roleSQL, _ := processTemplate(w.String()) - _, err := tx.Prepare(c, "_sg_get_role", roleSQL) + _, err := tx.Prepare(context.Background(), "_sg_get_role", roleSQL) if err != nil { return err } diff --git a/serv/serv.go b/serv/serv.go index c4df809..e41bb52 100644 --- a/serv/serv.go +++ b/serv/serv.go @@ -24,7 +24,6 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) { conf := qcode.Config{ Blocklist: c.DB.Blocklist, - KeepArgs: false, } qc, err := qcode.NewCompiler(conf)