Add support for `websearch_to_tsquery` in PG 11

This commit is contained in:
Vikram Rangnekar 2019-12-02 10:52:22 -05:00
parent 5593c66996
commit 5da79d91bf
15 changed files with 249 additions and 196 deletions

View File

@ -169,3 +169,17 @@ query {
} }
} }
variables {
"beer": "smoke"
}
query beerSearch {
products(search: $beer) {
id
name
search_rank
search_headline_description
}
}

View File

@ -171,7 +171,7 @@ roles:
query: query:
limit: 50 limit: 50
filters: ["{ user_id: { eq: $user_id } }"] filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ] columns: ["id", "name", "description", "search_rank", "search_headline_description" ]
disable_functions: false disable_functions: false
insert: insert:

View File

@ -353,16 +353,14 @@ func (m *Migrator) MigrateTo(targetVersion int32) (err error) {
} }
func (m *Migrator) GetCurrentVersion() (v int32, err error) { func (m *Migrator) GetCurrentVersion() (v int32, err error) {
ctx := context.Background() err = m.conn.QueryRow(context.Background(),
"select version from "+m.versionTable).Scan(&v)
err = m.conn.QueryRow(ctx, "select version from "+m.versionTable).Scan(&v)
return v, err return v, err
} }
func (m *Migrator) ensureSchemaVersionTableExists() (err error) { func (m *Migrator) ensureSchemaVersionTableExists() (err error) {
ctx := context.Background() _, err = m.conn.Exec(context.Background(), fmt.Sprintf(`
_, err = m.conn.Exec(ctx, fmt.Sprintf(`
create table if not exists %s(version int4 not null); create table if not exists %s(version int4 not null);
insert into %s(version) insert into %s(version)

View File

@ -178,6 +178,7 @@ func TestMain(m *testing.M) {
} }
schema := &DBSchema{ schema := &DBSchema{
ver: 110000,
t: make(map[string]*DBTableInfo), t: make(map[string]*DBTableInfo),
rm: make(map[string]map[string]*DBRel), rm: make(map[string]map[string]*DBRel),
al: make(map[string]struct{}), al: make(map[string]struct{}),

View File

@ -189,6 +189,11 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) {
} }
} }
if len(sel.Args) != 0 {
for _, v := range sel.Args {
qcode.FreeNode(v)
}
}
} }
} }
@ -515,36 +520,54 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
if isSearch { if isSearch {
switch { switch {
case cn == "search_rank": case cn == "search_rank":
if len(sel.Allowed) != 0 {
if _, ok := sel.Allowed[cn]; !ok {
continue
}
}
cn = ti.TSVCol cn = ti.TSVCol
arg := sel.Args["search"] arg := sel.Args["search"]
if i != 0 { if i != 0 {
io.WriteString(c.w, `, `) io.WriteString(c.w, `, `)
} }
//fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`, //fmt.Fprintf(w, `ts_rank("%s"."%s", websearch_to_tsquery('%s')) AS %s`,
//c.sel.Name, cn, arg.Val, col.Name) //c.sel.Name, cn, arg.Val, col.Name)
io.WriteString(c.w, `ts_rank(`) io.WriteString(c.w, `ts_rank(`)
colWithTable(c.w, ti.Name, cn) colWithTable(c.w, ti.Name, cn)
if c.schema.ver >= 110000 {
io.WriteString(c.w, `, websearch_to_tsquery('`)
} else {
io.WriteString(c.w, `, to_tsquery('`) io.WriteString(c.w, `, to_tsquery('`)
}
io.WriteString(c.w, arg.Val) io.WriteString(c.w, arg.Val)
io.WriteString(c.w, `')`) io.WriteString(c.w, `'))`)
alias(c.w, col.Name) alias(c.w, col.Name)
i++ i++
case strings.HasPrefix(cn, "search_headline_"): case strings.HasPrefix(cn, "search_headline_"):
cn = cn[16:] cn1 := cn[16:]
if len(sel.Allowed) != 0 {
if _, ok := sel.Allowed[cn1]; !ok {
continue
}
}
arg := sel.Args["search"] arg := sel.Args["search"]
if i != 0 { if i != 0 {
io.WriteString(c.w, `, `) io.WriteString(c.w, `, `)
} }
//fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`, //fmt.Fprintf(w, `ts_headline("%s"."%s", websearch_to_tsquery('%s')) AS %s`,
//c.sel.Name, cn, arg.Val, col.Name) //c.sel.Name, cn, arg.Val, col.Name)
io.WriteString(c.w, `ts_headlinek(`) io.WriteString(c.w, `ts_headline(`)
colWithTable(c.w, ti.Name, cn) colWithTable(c.w, ti.Name, cn1)
if c.schema.ver >= 110000 {
io.WriteString(c.w, `, websearch_to_tsquery('`)
} else {
io.WriteString(c.w, `, to_tsquery('`) io.WriteString(c.w, `, to_tsquery('`)
}
io.WriteString(c.w, arg.Val) io.WriteString(c.w, arg.Val)
io.WriteString(c.w, `')`) io.WriteString(c.w, `'))`)
alias(c.w, col.Name) alias(c.w, col.Name)
i++ i++
@ -693,6 +716,7 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
//fmt.Fprintf(w, `) AS "%s_%d"`, c.sel.Name, c.sel.ID) //fmt.Fprintf(w, `) AS "%s_%d"`, c.sel.Name, c.sel.ID)
io.WriteString(c.w, `)`) io.WriteString(c.w, `)`)
aliasWithID(c.w, ti.Name, sel.ID) aliasWithID(c.w, ti.Name, sel.ID)
return nil return nil
} }
@ -939,6 +963,7 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
io.WriteString(c.w, `IS NOT NULL)`) io.WriteString(c.w, `IS NOT NULL)`)
} }
return nil return nil
case qcode.OpEqID: case qcode.OpEqID:
if len(ti.PrimaryCol) == 0 { if len(ti.PrimaryCol) == 0 {
return fmt.Errorf("no primary key column defined for %s", ti.Name) return fmt.Errorf("no primary key column defined for %s", ti.Name)
@ -951,6 +976,7 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
colWithTable(c.w, ti.Name, ti.PrimaryCol) colWithTable(c.w, ti.Name, ti.PrimaryCol)
//io.WriteString(c.w, ti.PrimaryCol) //io.WriteString(c.w, ti.PrimaryCol)
io.WriteString(c.w, `) =`) io.WriteString(c.w, `) =`)
case qcode.OpTsQuery: case qcode.OpTsQuery:
if len(ti.TSVCol) == 0 { if len(ti.TSVCol) == 0 {
return fmt.Errorf("no tsv column defined for %s", ti.Name) return fmt.Errorf("no tsv column defined for %s", ti.Name)
@ -958,10 +984,14 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
if _, ok = ti.Columns[ti.TSVCol]; !ok { if _, ok = ti.Columns[ti.TSVCol]; !ok {
return fmt.Errorf("no tsv column '%s' found ", ti.TSVCol) return fmt.Errorf("no tsv column '%s' found ", ti.TSVCol)
} }
//fmt.Fprintf(w, `(("%s") @@ to_tsquery('%s'))`, c.ti.TSVCol, val.Val) //fmt.Fprintf(w, `(("%s") @@ websearch_to_tsquery('%s'))`, c.ti.TSVCol, val.Val)
io.WriteString(c.w, `(("`) io.WriteString(c.w, `((`)
io.WriteString(c.w, ti.TSVCol) colWithTable(c.w, ti.Name, ti.TSVCol)
io.WriteString(c.w, `") @@ to_tsquery('`) if c.schema.ver >= 110000 {
io.WriteString(c.w, `) @@ websearch_to_tsquery('`)
} else {
io.WriteString(c.w, `) @@ to_tsquery('`)
}
io.WriteString(c.w, ex.Val) io.WriteString(c.w, ex.Val)
io.WriteString(c.w, `'))`) io.WriteString(c.w, `'))`)
return nil return nil

View File

@ -142,15 +142,17 @@ func fetchByID(t *testing.T) {
func searchQuery(t *testing.T) { func searchQuery(t *testing.T) {
gql := `query { gql := `query {
products(search: "Imperial") { products(search: "ale") {
id id
name name
search_rank
search_headline_description
} }
}` }`
sql := `SELECT json_object_agg('products', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"` sql := `SELECT json_object_agg('products', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."search_rank" AS "search_rank", "products_0"."search_headline_description" AS "search_headline_description") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name", ts_rank("products"."tsv", websearch_to_tsquery('ale')) AS "search_rank", ts_headline("products"."description", websearch_to_tsquery('ale')) AS "search_headline_description" FROM "products" WHERE ((("products"."tsv") @@ websearch_to_tsquery('ale'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"`
resSQL, err := compileGQLToPSQL(gql, nil, "user") resSQL, err := compileGQLToPSQL(gql, nil, "admin")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }

View File

@ -3,6 +3,7 @@ package psql
import ( import (
"context" "context"
"fmt" "fmt"
"strconv"
"strings" "strings"
"github.com/gobuffalo/flect" "github.com/gobuffalo/flect"
@ -144,6 +145,7 @@ ORDER BY id;`
} }
type DBSchema struct { type DBSchema struct {
ver int
t map[string]*DBTableInfo t map[string]*DBTableInfo
rm map[string]map[string]*DBRel rm map[string]map[string]*DBRel
al map[string]struct{} al map[string]struct{}
@ -184,10 +186,22 @@ func NewDBSchema(db *pgxpool.Pool, aliases map[string][]string) (*DBSchema, erro
dbc, err := db.Acquire(context.Background()) dbc, err := db.Acquire(context.Background())
if err != nil { if err != nil {
return nil, fmt.Errorf("error acquiring connection from pool") return nil, fmt.Errorf("error acquiring connection from pool: %w", err)
} }
defer dbc.Release() defer dbc.Release()
var version string
err = dbc.QueryRow(context.Background(), `SHOW server_version_num`).Scan(&version)
if err != nil {
return nil, fmt.Errorf("error fetching version: %w", err)
}
schema.ver, err = strconv.Atoi(version)
if err != nil {
return nil, err
}
tables, err := GetTables(dbc) tables, err := GetTables(dbc)
if err != nil { if err != nil {
return nil, err return nil, err

View File

@ -8,7 +8,6 @@ import (
type Config struct { type Config struct {
Blocklist []string Blocklist []string
KeepArgs bool
} }
type QueryConfig struct { type QueryConfig struct {

View File

@ -26,13 +26,13 @@ const (
opQuery opQuery
opMutate opMutate
opSub opSub
nodeStr NodeStr
nodeInt NodeInt
nodeFloat NodeFloat
nodeBool NodeBool
nodeObj NodeObj
nodeList NodeList
nodeVar NodeVar
) )
type Operation struct { type Operation struct {
@ -413,7 +413,7 @@ func (p *Parser) parseList() (*Node, error) {
return nil, errors.New("List cannot be empty") return nil, errors.New("List cannot be empty")
} }
parent.Type = nodeList parent.Type = NodeList
parent.Children = nodes parent.Children = nodes
return parent, nil return parent, nil
@ -450,7 +450,7 @@ func (p *Parser) parseObj() (*Node, error) {
nodes = append(nodes, node) nodes = append(nodes, node)
} }
parent.Type = nodeObj parent.Type = NodeObj
parent.Children = nodes parent.Children = nodes
return parent, nil return parent, nil
@ -473,17 +473,17 @@ func (p *Parser) parseValue() (*Node, error) {
switch item.typ { switch item.typ {
case itemIntVal: case itemIntVal:
node.Type = nodeInt node.Type = NodeInt
case itemFloatVal: case itemFloatVal:
node.Type = nodeFloat node.Type = NodeFloat
case itemStringVal: case itemStringVal:
node.Type = nodeStr node.Type = NodeStr
case itemBoolVal: case itemBoolVal:
node.Type = nodeBool node.Type = NodeBool
case itemName: case itemName:
node.Type = nodeStr node.Type = NodeStr
case itemVariable: case itemVariable:
node.Type = nodeVar node.Type = NodeVar
default: default:
return nil, fmt.Errorf("expecting a number, string, object, list or variable as an argument value (not %s)", p.val(p.next())) return nil, fmt.Errorf("expecting a number, string, object, list or variable as an argument value (not %s)", p.val(p.next()))
} }
@ -514,19 +514,19 @@ func (t parserType) String() string {
v = "mutation" v = "mutation"
case opSub: case opSub:
v = "subscription" v = "subscription"
case nodeStr: case NodeStr:
v = "node-string" v = "node-string"
case nodeInt: case NodeInt:
v = "node-int" v = "node-int"
case nodeFloat: case NodeFloat:
v = "node-float" v = "node-float"
case nodeBool: case NodeBool:
v = "node-bool" v = "node-bool"
case nodeVar: case NodeVar:
v = "node-var" v = "node-var"
case nodeObj: case NodeObj:
v = "node-obj" v = "node-obj"
case nodeList: case NodeList:
v = "node-list" v = "node-list"
} }
return fmt.Sprintf("<%s>", v) return fmt.Sprintf("<%s>", v)

View File

@ -157,7 +157,6 @@ const (
type Compiler struct { type Compiler struct {
tr map[string]map[string]*trval tr map[string]map[string]*trval
bl map[string]struct{} bl map[string]struct{}
ka bool
} }
var expPool = sync.Pool{ var expPool = sync.Pool{
@ -165,7 +164,7 @@ var expPool = sync.Pool{
} }
func NewCompiler(c Config) (*Compiler, error) { func NewCompiler(c Config) (*Compiler, error) {
co := &Compiler{ka: c.KeepArgs} co := &Compiler{}
co.tr = make(map[string]map[string]*trval) co.tr = make(map[string]map[string]*trval)
co.bl = make(map[string]struct{}, len(c.Blocklist)) co.bl = make(map[string]struct{}, len(c.Blocklist))
@ -380,11 +379,13 @@ func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error {
return nil return nil
} }
func (com *Compiler) addFilters(qc *QCode, root *Select, role string) { func (com *Compiler) addFilters(qc *QCode, sel *Select, role string) {
var fil *Exp var fil *Exp
if trv, ok := com.tr[role][root.Name]; ok { if trv, ok := com.tr[role][sel.Name]; ok {
fil = trv.filter(qc.Type) fil = trv.filter(qc.Type)
} else {
return
} }
if fil == nil { if fil == nil {
@ -394,60 +395,61 @@ func (com *Compiler) addFilters(qc *QCode, root *Select, role string) {
switch fil.Op { switch fil.Op {
case OpNop: case OpNop:
case OpFalse: case OpFalse:
root.Where = fil sel.Where = fil
default: default:
if root.Where != nil { if sel.Where != nil {
ow := root.Where ow := sel.Where
root.Where = expPool.Get().(*Exp) sel.Where = expPool.Get().(*Exp)
root.Where.Reset() sel.Where.Reset()
root.Where.Op = OpAnd sel.Where.Op = OpAnd
root.Where.Children = root.Where.childrenA[:2] sel.Where.Children = sel.Where.childrenA[:2]
root.Where.Children[0] = fil sel.Where.Children[0] = fil
root.Where.Children[1] = ow sel.Where.Children[1] = ow
} else { } else {
root.Where = fil sel.Where = fil
} }
} }
} }
func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg) error { func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg) error {
var err error var err error
var ka bool
if com.ka {
sel.Args = make(map[string]*Node, len(args))
}
for i := range args { for i := range args {
arg := &args[i] arg := &args[i]
switch arg.Name { switch arg.Name {
case "id": case "id":
err = com.compileArgID(sel, arg) err, ka = com.compileArgID(sel, arg)
case "search": case "search":
err = com.compileArgSearch(sel, arg) err, ka = com.compileArgSearch(sel, arg)
case "where": case "where":
err = com.compileArgWhere(sel, arg) err, ka = com.compileArgWhere(sel, arg)
case "orderby", "order_by", "order": case "orderby", "order_by", "order":
err = com.compileArgOrderBy(sel, arg) err, ka = com.compileArgOrderBy(sel, arg)
case "distinct_on", "distinct": case "distinct_on", "distinct":
err = com.compileArgDistinctOn(sel, arg) err, ka = com.compileArgDistinctOn(sel, arg)
case "limit": case "limit":
err = com.compileArgLimit(sel, arg) err, ka = com.compileArgLimit(sel, arg)
case "offset": case "offset":
err = com.compileArgOffset(sel, arg) err, ka = com.compileArgOffset(sel, arg)
}
if !ka {
nodePool.Put(arg.Val)
} }
if err != nil { if err != nil {
return err return err
} }
if sel.Args != nil {
sel.Args[arg.Name] = arg.Val
} else {
nodePool.Put(arg.Val)
}
} }
return nil return nil
@ -455,7 +457,7 @@ func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg) error {
func (com *Compiler) setMutationType(qc *QCode, args []Arg) error { func (com *Compiler) setMutationType(qc *QCode, args []Arg) error {
setActionVar := func(arg *Arg) error { setActionVar := func(arg *Arg) error {
if arg.Val.Type != nodeVar { if arg.Val.Type != NodeVar {
return fmt.Errorf("value for argument '%s' must be a variable", arg.Name) return fmt.Errorf("value for argument '%s' must be a variable", arg.Name)
} }
qc.ActionVar = arg.Val.Val qc.ActionVar = arg.Val.Val
@ -478,7 +480,7 @@ func (com *Compiler) setMutationType(qc *QCode, args []Arg) error {
case "delete": case "delete":
qc.Type = QTDelete qc.Type = QTDelete
if arg.Val.Type != nodeBool { if arg.Val.Type != NodeBool {
return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name) return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name)
} }
@ -493,7 +495,7 @@ func (com *Compiler) setMutationType(qc *QCode, args []Arg) error {
} }
func (com *Compiler) compileArgObj(st *util.Stack, arg *Arg) (*Exp, error) { func (com *Compiler) compileArgObj(st *util.Stack, arg *Arg) (*Exp, error) {
if arg.Val.Type != nodeObj { if arg.Val.Type != NodeObj {
return nil, fmt.Errorf("expecting an object") return nil, fmt.Errorf("expecting an object")
} }
@ -545,11 +547,6 @@ func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (*
} else { } else {
node.exp.Children = append(node.exp.Children, ex) node.exp.Children = append(node.exp.Children, ex)
} }
}
if com.ka {
return root, nil
} }
pushChild(st, nil, node) pushChild(st, nil, node)
@ -570,13 +567,13 @@ func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (*
return root, nil return root, nil
} }
func (com *Compiler) compileArgID(sel *Select, arg *Arg) error { func (com *Compiler) compileArgID(sel *Select, arg *Arg) (error, bool) {
if sel.ID != 0 { if sel.ID != 0 {
return nil return nil, false
} }
if sel.Where != nil && sel.Where.Op == OpEqID { if sel.Where != nil && sel.Where.Op == OpEqID {
return nil return nil, false
} }
ex := expPool.Get().(*Exp) ex := expPool.Get().(*Exp)
@ -586,30 +583,41 @@ func (com *Compiler) compileArgID(sel *Select, arg *Arg) error {
ex.Val = arg.Val.Val ex.Val = arg.Val.Val
switch arg.Val.Type { switch arg.Val.Type {
case nodeStr: case NodeStr:
ex.Type = ValStr ex.Type = ValStr
case nodeInt: case NodeInt:
ex.Type = ValInt ex.Type = ValInt
case nodeFloat: case NodeFloat:
ex.Type = ValFloat ex.Type = ValFloat
case nodeVar: case NodeVar:
ex.Type = ValVar ex.Type = ValVar
default: default:
return fmt.Errorf("expecting a string, int, float or variable") return fmt.Errorf("expecting a string, int, float or variable"), false
} }
sel.Where = ex sel.Where = ex
return nil return nil, false
} }
func (com *Compiler) compileArgSearch(sel *Select, arg *Arg) error { func (com *Compiler) compileArgSearch(sel *Select, arg *Arg) (error, bool) {
ex := expPool.Get().(*Exp) ex := expPool.Get().(*Exp)
ex.Reset() ex.Reset()
ex.Op = OpTsQuery ex.Op = OpTsQuery
ex.Type = ValStr
ex.Val = arg.Val.Val ex.Val = arg.Val.Val
if arg.Val.Type == NodeVar {
ex.Type = ValVar
} else {
ex.Type = ValStr
}
if sel.Args == nil {
sel.Args = make(map[string]*Node)
}
sel.Args[arg.Name] = arg.Val
if sel.Where != nil { if sel.Where != nil {
ow := sel.Where ow := sel.Where
@ -622,16 +630,16 @@ func (com *Compiler) compileArgSearch(sel *Select, arg *Arg) error {
} else { } else {
sel.Where = ex sel.Where = ex
} }
return nil return nil, true
} }
func (com *Compiler) compileArgWhere(sel *Select, arg *Arg) error { func (com *Compiler) compileArgWhere(sel *Select, arg *Arg) (error, bool) {
st := util.NewStack() st := util.NewStack()
var err error var err error
ex, err := com.compileArgObj(st, arg) ex, err := com.compileArgObj(st, arg)
if err != nil { if err != nil {
return err return err, false
} }
if sel.Where != nil { if sel.Where != nil {
@ -647,12 +655,12 @@ func (com *Compiler) compileArgWhere(sel *Select, arg *Arg) error {
sel.Where = ex sel.Where = ex
} }
return nil return nil, false
} }
func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error { func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) (error, bool) {
if arg.Val.Type != nodeObj { if arg.Val.Type != NodeObj {
return fmt.Errorf("expecting an object") return fmt.Errorf("expecting an object"), false
} }
st := util.NewStack() st := util.NewStack()
@ -670,23 +678,19 @@ func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error {
node, ok := intf.(*Node) node, ok := intf.(*Node)
if !ok || node == nil { if !ok || node == nil {
return fmt.Errorf("17: unexpected value %v (%t)", intf, intf) return fmt.Errorf("17: unexpected value %v (%t)", intf, intf), false
} }
if _, ok := com.bl[node.Name]; ok { if _, ok := com.bl[node.Name]; ok {
if !com.ka {
nodePool.Put(node) nodePool.Put(node)
}
continue continue
} }
if node.Type == nodeObj { if node.Type == NodeObj {
for i := range node.Children { for i := range node.Children {
st.Push(node.Children[i]) st.Push(node.Children[i])
} }
if !com.ka {
nodePool.Put(node) nodePool.Put(node)
}
continue continue
} }
@ -706,65 +710,60 @@ func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error {
case "desc_nulls_last": case "desc_nulls_last":
ob.Order = OrderDescNullsLast ob.Order = OrderDescNullsLast
default: default:
return fmt.Errorf("valid values include asc, desc, asc_nulls_first and desc_nulls_first") return fmt.Errorf("valid values include asc, desc, asc_nulls_first and desc_nulls_first"), false
} }
setOrderByColName(ob, node) setOrderByColName(ob, node)
sel.OrderBy = append(sel.OrderBy, ob) sel.OrderBy = append(sel.OrderBy, ob)
if !com.ka {
nodePool.Put(node) nodePool.Put(node)
} }
} return nil, false
return nil
} }
func (com *Compiler) compileArgDistinctOn(sel *Select, arg *Arg) error { func (com *Compiler) compileArgDistinctOn(sel *Select, arg *Arg) (error, bool) {
node := arg.Val node := arg.Val
if _, ok := com.bl[node.Name]; ok { if _, ok := com.bl[node.Name]; ok {
return nil return nil, false
} }
if node.Type != nodeList && node.Type != nodeStr { if node.Type != NodeList && node.Type != NodeStr {
return fmt.Errorf("expecting a list of strings or just a string") return fmt.Errorf("expecting a list of strings or just a string"), false
} }
if node.Type == nodeStr { if node.Type == NodeStr {
sel.DistinctOn = append(sel.DistinctOn, node.Val) sel.DistinctOn = append(sel.DistinctOn, node.Val)
} }
for i := range node.Children { for i := range node.Children {
sel.DistinctOn = append(sel.DistinctOn, node.Children[i].Val) sel.DistinctOn = append(sel.DistinctOn, node.Children[i].Val)
if !com.ka {
nodePool.Put(node.Children[i]) nodePool.Put(node.Children[i])
} }
return nil, false
} }
return nil func (com *Compiler) compileArgLimit(sel *Select, arg *Arg) (error, bool) {
}
func (com *Compiler) compileArgLimit(sel *Select, arg *Arg) error {
node := arg.Val node := arg.Val
if node.Type != nodeInt { if node.Type != NodeInt {
return fmt.Errorf("expecting an integer") return fmt.Errorf("expecting an integer"), false
} }
sel.Paging.Limit = node.Val sel.Paging.Limit = node.Val
return nil return nil, false
} }
func (com *Compiler) compileArgOffset(sel *Select, arg *Arg) error { func (com *Compiler) compileArgOffset(sel *Select, arg *Arg) (error, bool) {
node := arg.Val node := arg.Val
if node.Type != nodeInt { if node.Type != NodeInt {
return fmt.Errorf("expecting an integer") return fmt.Errorf("expecting an integer"), false
} }
sel.Paging.Offset = node.Val sel.Paging.Offset = node.Val
return nil return nil, false
} }
var zeroTrv = &trval{} var zeroTrv = &trval{}
@ -879,17 +878,17 @@ func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {
if ex.Op != OpAnd && ex.Op != OpOr && ex.Op != OpNot { if ex.Op != OpAnd && ex.Op != OpOr && ex.Op != OpNot {
switch node.Type { switch node.Type {
case nodeStr: case NodeStr:
ex.Type = ValStr ex.Type = ValStr
case nodeInt: case NodeInt:
ex.Type = ValInt ex.Type = ValInt
case nodeBool: case NodeBool:
ex.Type = ValBool ex.Type = ValBool
case nodeFloat: case NodeFloat:
ex.Type = ValFloat ex.Type = ValFloat
case nodeList: case NodeList:
ex.Type = ValList ex.Type = ValList
case nodeVar: case NodeVar:
ex.Type = ValVar ex.Type = ValVar
default: default:
return nil, fmt.Errorf("[Where] valid values include string, int, float, boolean and list: %s", node.Type) return nil, fmt.Errorf("[Where] valid values include string, int, float, boolean and list: %s", node.Type)
@ -903,13 +902,13 @@ func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {
func setListVal(ex *Exp, node *Node) { func setListVal(ex *Exp, node *Node) {
if len(node.Children) != 0 { if len(node.Children) != 0 {
switch node.Children[0].Type { switch node.Children[0].Type {
case nodeStr: case NodeStr:
ex.ListType = ValStr ex.ListType = ValStr
case nodeInt: case NodeInt:
ex.ListType = ValInt ex.ListType = ValInt
case nodeBool: case NodeBool:
ex.ListType = ValBool ex.ListType = ValBool
case nodeFloat: case NodeFloat:
ex.ListType = ValFloat ex.ListType = ValFloat
} }
} }
@ -922,7 +921,7 @@ func setWhereColName(ex *Exp, node *Node) {
var list []string var list []string
for n := node.Parent; n != nil; n = n.Parent { for n := node.Parent; n != nil; n = n.Parent {
if n.Type != nodeObj { if n.Type != NodeObj {
continue continue
} }
if len(n.Name) != 0 { if len(n.Name) != 0 {

View File

@ -110,7 +110,7 @@ func graphQLFunc(query string, data interface{}, opt map[string]string) map[stri
var root []byte var root []byte
if err = tx.QueryRow(c, finalSQL).Scan(&root); err != nil { if err = tx.QueryRow(context.Background(), finalSQL).Scan(&root); err != nil {
errlog.Fatal().Err(err).Msg("sql query failed") errlog.Fatal().Err(err).Msg("sql query failed")
} }

View File

@ -81,7 +81,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) {
useTx := useRoleQuery || conf.DB.SetUserID useTx := useRoleQuery || conf.DB.SetUserID
if useTx { if useTx {
if tx, err = db.Begin(c); err != nil { if tx, err = db.Begin(context.Background()); err != nil {
return nil, nil, err return nil, nil, err
} }
defer tx.Rollback(c) //nolint: errcheck defer tx.Rollback(c) //nolint: errcheck
@ -122,9 +122,9 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) {
} }
if useTx { if useTx {
row = tx.QueryRow(c, ps.sd.SQL, vars...) row = tx.QueryRow(context.Background(), ps.sd.SQL, vars...)
} else { } else {
row = db.QueryRow(c, ps.sd.SQL, vars...) row = db.QueryRow(context.Background(), ps.sd.SQL, vars...)
} }
if mutation || anonQuery { if mutation || anonQuery {
@ -146,7 +146,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) {
c.req.role = role c.req.role = role
if useTx { if useTx {
if err := tx.Commit(c); err != nil { if err := tx.Commit(context.Background()); err != nil {
return nil, nil, err return nil, nil, err
} }
} }
@ -166,10 +166,10 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
useTx := useRoleQuery || conf.DB.SetUserID useTx := useRoleQuery || conf.DB.SetUserID
if useTx { if useTx {
if tx, err = db.Begin(c); err != nil { if tx, err = db.Begin(context.Background()); err != nil {
return nil, nil, err return nil, nil, err
} }
defer tx.Rollback(c) //nolint: errcheck defer tx.Rollback(context.Background()) //nolint: errcheck
} }
if conf.DB.SetUserID { if conf.DB.SetUserID {
@ -215,9 +215,9 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
defaultRole := c.req.role defaultRole := c.req.role
if useTx { if useTx {
row = tx.QueryRow(c, finalSQL) row = tx.QueryRow(context.Background(), finalSQL)
} else { } else {
row = db.QueryRow(c, finalSQL) row = db.QueryRow(context.Background(), finalSQL)
} }
if len(stmts) == 1 { if len(stmts) == 1 {
@ -237,7 +237,7 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
} }
if useTx { if useTx {
if err := tx.Commit(c); err != nil { if err := tx.Commit(context.Background()); err != nil {
return nil, nil, err return nil, nil, err
} }
} }
@ -263,7 +263,7 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) { func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
var role string var role string
row := tx.QueryRow(c, "_sg_get_role", c.req.role, 1) row := tx.QueryRow(context.Background(), "_sg_get_role", c.req.role, 1)
if err := row.Scan(&role); err != nil { if err := row.Scan(&role); err != nil {
return "", err return "", err
@ -320,6 +320,15 @@ func (c *coreContext) addTrace(sel []qcode.Select, id int32, st time.Time) {
append(c.res.Extensions.Tracing.Execution.Resolvers, tr) append(c.res.Extensions.Tracing.Execution.Resolvers, tr)
} }
func setLocalUserID(c context.Context, tx pgx.Tx) error {
var err error
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(context.Background(), fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
}
return err
}
func parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) ( func parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) (
[][]byte, [][]byte,
map[uint64]*qcode.Select) { map[uint64]*qcode.Select) {
@ -363,15 +372,6 @@ func parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) (
return fm, sm return fm, sm
} }
func setLocalUserID(c context.Context, tx pgx.Tx) error {
var err error
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
}
return err
}
func isSkipped(n uint32, pos uint32) bool { func isSkipped(n uint32, pos uint32) bool {
return ((n & (1 << pos)) != 0) return ((n & (1 << pos)) != 0)
} }

View File

@ -97,6 +97,10 @@ func buildMultiStmt(gql, vars []byte) ([]stmt, error) {
for i := 0; i < len(conf.Roles); i++ { for i := 0; i < len(conf.Roles); i++ {
role := &conf.Roles[i] role := &conf.Roles[i]
if role.Name == "anon" {
continue
}
qc, err := qcompile.Compile(gql, role.Name) qc, err := qcompile.Compile(gql, role.Name)
if err != nil { if err != nil {
return nil, err return nil, err
@ -127,8 +131,6 @@ func buildMultiStmt(gql, vars []byte) ([]stmt, error) {
//nolint: errcheck //nolint: errcheck
func renderUserQuery( func renderUserQuery(
stmts []stmt, vars map[string]json.RawMessage) (string, error) { stmts []stmt, vars map[string]json.RawMessage) (string, error) {
var err error
w := &bytes.Buffer{} w := &bytes.Buffer{}
io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `) io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
@ -141,11 +143,7 @@ func renderUserQuery(
io.WriteString(w, `WHEN '`) io.WriteString(w, `WHEN '`)
io.WriteString(w, s.role.Name) io.WriteString(w, s.role.Name)
io.WriteString(w, `' THEN (`) io.WriteString(w, `' THEN (`)
io.WriteString(w, s.sql)
s.skipped, err = pcompile.Compile(s.qc, w, psql.Variables(vars))
if err != nil {
return "", err
}
io.WriteString(w, `) `) io.WriteString(w, `) `)
} }

View File

@ -23,21 +23,20 @@ var (
) )
func initPreparedList() { func initPreparedList() {
c := context.Background()
_preparedList = make(map[string]*preparedItem) _preparedList = make(map[string]*preparedItem)
tx, err := db.Begin(c) tx, err := db.Begin(context.Background())
if err != nil { if err != nil {
errlog.Fatal().Err(err).Send() errlog.Fatal().Err(err).Send()
} }
defer tx.Rollback(c) //nolint: errcheck defer tx.Rollback(context.Background()) //nolint: errcheck
err = prepareRoleStmt(c, tx) err = prepareRoleStmt(tx)
if err != nil { if err != nil {
errlog.Fatal().Err(err).Msg("failed to prepare get role statement") errlog.Fatal().Err(err).Msg("failed to prepare get role statement")
} }
if err := tx.Commit(c); err != nil { if err := tx.Commit(context.Background()); err != nil {
errlog.Fatal().Err(err).Send() errlog.Fatal().Err(err).Send()
} }
@ -48,7 +47,7 @@ func initPreparedList() {
continue continue
} }
err := prepareStmt(c, v.gql, v.vars) err := prepareStmt(v.gql, v.vars)
if err == nil { if err == nil {
success++ success++
continue continue
@ -66,15 +65,15 @@ func initPreparedList() {
success, len(_allowList.list)) success, len(_allowList.list))
} }
func prepareStmt(c context.Context, gql string, vars []byte) error { func prepareStmt(gql string, vars []byte) error {
qt := qcode.GetQType(gql) qt := qcode.GetQType(gql)
q := []byte(gql) q := []byte(gql)
tx, err := db.Begin(c) tx, err := db.Begin(context.Background())
if err != nil { if err != nil {
return err return err
} }
defer tx.Rollback(c) //nolint: errcheck defer tx.Rollback(context.Background()) //nolint: errcheck
switch qt { switch qt {
case qcode.QTQuery: case qcode.QTQuery:
@ -83,7 +82,7 @@ func prepareStmt(c context.Context, gql string, vars []byte) error {
return err return err
} }
err = prepare(c, tx, &stmts1[0], gqlHash(gql, vars, "user")) err = prepare(tx, &stmts1[0], gqlHash(gql, vars, "user"))
if err != nil { if err != nil {
return err return err
} }
@ -93,7 +92,7 @@ func prepareStmt(c context.Context, gql string, vars []byte) error {
return err return err
} }
err = prepare(c, tx, &stmts2[0], gqlHash(gql, vars, "anon")) err = prepare(tx, &stmts2[0], gqlHash(gql, vars, "anon"))
if err != nil { if err != nil {
return err return err
} }
@ -105,7 +104,7 @@ func prepareStmt(c context.Context, gql string, vars []byte) error {
return err return err
} }
err = prepare(c, tx, &stmts[0], gqlHash(gql, vars, role.Name)) err = prepare(tx, &stmts[0], gqlHash(gql, vars, role.Name))
if err != nil { if err != nil {
return err return err
} }
@ -118,17 +117,17 @@ func prepareStmt(c context.Context, gql string, vars []byte) error {
logger.Debug().Msgf("Building prepared statement:\n %s\n%s", vars, gql) logger.Debug().Msgf("Building prepared statement:\n %s\n%s", vars, gql)
} }
if err := tx.Commit(c); err != nil { if err := tx.Commit(context.Background()); err != nil {
return err return err
} }
return nil return nil
} }
func prepare(c context.Context, tx pgx.Tx, st *stmt, key string) error { func prepare(tx pgx.Tx, st *stmt, key string) error {
finalSQL, am := processTemplate(st.sql) finalSQL, am := processTemplate(st.sql)
sd, err := tx.Prepare(c, "", finalSQL) sd, err := tx.Prepare(context.Background(), "", finalSQL)
if err != nil { if err != nil {
return err return err
} }
@ -142,7 +141,7 @@ func prepare(c context.Context, tx pgx.Tx, st *stmt, key string) error {
} }
// nolint: errcheck // nolint: errcheck
func prepareRoleStmt(c context.Context, tx pgx.Tx) error { func prepareRoleStmt(tx pgx.Tx) error {
if len(conf.RolesQuery) == 0 { if len(conf.RolesQuery) == 0 {
return nil return nil
} }
@ -167,7 +166,7 @@ func prepareRoleStmt(c context.Context, tx pgx.Tx) error {
roleSQL, _ := processTemplate(w.String()) roleSQL, _ := processTemplate(w.String())
_, err := tx.Prepare(c, "_sg_get_role", roleSQL) _, err := tx.Prepare(context.Background(), "_sg_get_role", roleSQL)
if err != nil { if err != nil {
return err return err
} }

View File

@ -24,7 +24,6 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
conf := qcode.Config{ conf := qcode.Config{
Blocklist: c.DB.Blocklist, Blocklist: c.DB.Blocklist,
KeepArgs: false,
} }
qc, err := qcode.NewCompiler(conf) qc, err := qcode.NewCompiler(conf)