Add support for `websearch_to_tsquery` in PG 11

This commit is contained in:
Vikram Rangnekar 2019-12-02 10:52:22 -05:00
parent 5593c66996
commit 5da79d91bf
15 changed files with 249 additions and 196 deletions

View File

@ -169,3 +169,17 @@ query {
}
}
variables {
"beer": "smoke"
}
query beerSearch {
products(search: $beer) {
id
name
search_rank
search_headline_description
}
}

View File

@ -171,7 +171,7 @@ roles:
query:
limit: 50
filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
columns: ["id", "name", "description", "search_rank", "search_headline_description" ]
disable_functions: false
insert:

View File

@ -353,16 +353,14 @@ func (m *Migrator) MigrateTo(targetVersion int32) (err error) {
}
func (m *Migrator) GetCurrentVersion() (v int32, err error) {
ctx := context.Background()
err = m.conn.QueryRow(context.Background(),
"select version from "+m.versionTable).Scan(&v)
err = m.conn.QueryRow(ctx, "select version from "+m.versionTable).Scan(&v)
return v, err
}
func (m *Migrator) ensureSchemaVersionTableExists() (err error) {
ctx := context.Background()
_, err = m.conn.Exec(ctx, fmt.Sprintf(`
_, err = m.conn.Exec(context.Background(), fmt.Sprintf(`
create table if not exists %s(version int4 not null);
insert into %s(version)

View File

@ -178,6 +178,7 @@ func TestMain(m *testing.M) {
}
schema := &DBSchema{
ver: 110000,
t: make(map[string]*DBTableInfo),
rm: make(map[string]map[string]*DBRel),
al: make(map[string]struct{}),

View File

@ -189,6 +189,11 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) {
}
}
if len(sel.Args) != 0 {
for _, v := range sel.Args {
qcode.FreeNode(v)
}
}
}
}
@ -515,36 +520,54 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
if isSearch {
switch {
case cn == "search_rank":
if len(sel.Allowed) != 0 {
if _, ok := sel.Allowed[cn]; !ok {
continue
}
}
cn = ti.TSVCol
arg := sel.Args["search"]
if i != 0 {
io.WriteString(c.w, `, `)
}
//fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`,
//fmt.Fprintf(w, `ts_rank("%s"."%s", websearch_to_tsquery('%s')) AS %s`,
//c.sel.Name, cn, arg.Val, col.Name)
io.WriteString(c.w, `ts_rank(`)
colWithTable(c.w, ti.Name, cn)
if c.schema.ver >= 110000 {
io.WriteString(c.w, `, websearch_to_tsquery('`)
} else {
io.WriteString(c.w, `, to_tsquery('`)
}
io.WriteString(c.w, arg.Val)
io.WriteString(c.w, `')`)
io.WriteString(c.w, `'))`)
alias(c.w, col.Name)
i++
case strings.HasPrefix(cn, "search_headline_"):
cn = cn[16:]
cn1 := cn[16:]
if len(sel.Allowed) != 0 {
if _, ok := sel.Allowed[cn1]; !ok {
continue
}
}
arg := sel.Args["search"]
if i != 0 {
io.WriteString(c.w, `, `)
}
//fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`,
//fmt.Fprintf(w, `ts_headline("%s"."%s", websearch_to_tsquery('%s')) AS %s`,
//c.sel.Name, cn, arg.Val, col.Name)
io.WriteString(c.w, `ts_headlinek(`)
colWithTable(c.w, ti.Name, cn)
io.WriteString(c.w, `ts_headline(`)
colWithTable(c.w, ti.Name, cn1)
if c.schema.ver >= 110000 {
io.WriteString(c.w, `, websearch_to_tsquery('`)
} else {
io.WriteString(c.w, `, to_tsquery('`)
}
io.WriteString(c.w, arg.Val)
io.WriteString(c.w, `')`)
io.WriteString(c.w, `'))`)
alias(c.w, col.Name)
i++
@ -693,6 +716,7 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
//fmt.Fprintf(w, `) AS "%s_%d"`, c.sel.Name, c.sel.ID)
io.WriteString(c.w, `)`)
aliasWithID(c.w, ti.Name, sel.ID)
return nil
}
@ -939,6 +963,7 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
io.WriteString(c.w, `IS NOT NULL)`)
}
return nil
case qcode.OpEqID:
if len(ti.PrimaryCol) == 0 {
return fmt.Errorf("no primary key column defined for %s", ti.Name)
@ -951,6 +976,7 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
colWithTable(c.w, ti.Name, ti.PrimaryCol)
//io.WriteString(c.w, ti.PrimaryCol)
io.WriteString(c.w, `) =`)
case qcode.OpTsQuery:
if len(ti.TSVCol) == 0 {
return fmt.Errorf("no tsv column defined for %s", ti.Name)
@ -958,10 +984,14 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
if _, ok = ti.Columns[ti.TSVCol]; !ok {
return fmt.Errorf("no tsv column '%s' found ", ti.TSVCol)
}
//fmt.Fprintf(w, `(("%s") @@ to_tsquery('%s'))`, c.ti.TSVCol, val.Val)
io.WriteString(c.w, `(("`)
io.WriteString(c.w, ti.TSVCol)
io.WriteString(c.w, `") @@ to_tsquery('`)
//fmt.Fprintf(w, `(("%s") @@ websearch_to_tsquery('%s'))`, c.ti.TSVCol, val.Val)
io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, ti.TSVCol)
if c.schema.ver >= 110000 {
io.WriteString(c.w, `) @@ websearch_to_tsquery('`)
} else {
io.WriteString(c.w, `) @@ to_tsquery('`)
}
io.WriteString(c.w, ex.Val)
io.WriteString(c.w, `'))`)
return nil

View File

@ -142,15 +142,17 @@ func fetchByID(t *testing.T) {
func searchQuery(t *testing.T) {
gql := `query {
products(search: "Imperial") {
products(search: "ale") {
id
name
search_rank
search_headline_description
}
}`
sql := `SELECT json_object_agg('products', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"`
sql := `SELECT json_object_agg('products', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."search_rank" AS "search_rank", "products_0"."search_headline_description" AS "search_headline_description") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name", ts_rank("products"."tsv", websearch_to_tsquery('ale')) AS "search_rank", ts_headline("products"."description", websearch_to_tsquery('ale')) AS "search_headline_description" FROM "products" WHERE ((("products"."tsv") @@ websearch_to_tsquery('ale'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"`
resSQL, err := compileGQLToPSQL(gql, nil, "user")
resSQL, err := compileGQLToPSQL(gql, nil, "admin")
if err != nil {
t.Fatal(err)
}

View File

@ -3,6 +3,7 @@ package psql
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/gobuffalo/flect"
@ -144,6 +145,7 @@ ORDER BY id;`
}
type DBSchema struct {
ver int
t map[string]*DBTableInfo
rm map[string]map[string]*DBRel
al map[string]struct{}
@ -184,10 +186,22 @@ func NewDBSchema(db *pgxpool.Pool, aliases map[string][]string) (*DBSchema, erro
dbc, err := db.Acquire(context.Background())
if err != nil {
return nil, fmt.Errorf("error acquiring connection from pool")
return nil, fmt.Errorf("error acquiring connection from pool: %w", err)
}
defer dbc.Release()
var version string
err = dbc.QueryRow(context.Background(), `SHOW server_version_num`).Scan(&version)
if err != nil {
return nil, fmt.Errorf("error fetching version: %w", err)
}
schema.ver, err = strconv.Atoi(version)
if err != nil {
return nil, err
}
tables, err := GetTables(dbc)
if err != nil {
return nil, err

View File

@ -8,7 +8,6 @@ import (
type Config struct {
Blocklist []string
KeepArgs bool
}
type QueryConfig struct {

View File

@ -26,13 +26,13 @@ const (
opQuery
opMutate
opSub
nodeStr
nodeInt
nodeFloat
nodeBool
nodeObj
nodeList
nodeVar
NodeStr
NodeInt
NodeFloat
NodeBool
NodeObj
NodeList
NodeVar
)
type Operation struct {
@ -413,7 +413,7 @@ func (p *Parser) parseList() (*Node, error) {
return nil, errors.New("List cannot be empty")
}
parent.Type = nodeList
parent.Type = NodeList
parent.Children = nodes
return parent, nil
@ -450,7 +450,7 @@ func (p *Parser) parseObj() (*Node, error) {
nodes = append(nodes, node)
}
parent.Type = nodeObj
parent.Type = NodeObj
parent.Children = nodes
return parent, nil
@ -473,17 +473,17 @@ func (p *Parser) parseValue() (*Node, error) {
switch item.typ {
case itemIntVal:
node.Type = nodeInt
node.Type = NodeInt
case itemFloatVal:
node.Type = nodeFloat
node.Type = NodeFloat
case itemStringVal:
node.Type = nodeStr
node.Type = NodeStr
case itemBoolVal:
node.Type = nodeBool
node.Type = NodeBool
case itemName:
node.Type = nodeStr
node.Type = NodeStr
case itemVariable:
node.Type = nodeVar
node.Type = NodeVar
default:
return nil, fmt.Errorf("expecting a number, string, object, list or variable as an argument value (not %s)", p.val(p.next()))
}
@ -514,19 +514,19 @@ func (t parserType) String() string {
v = "mutation"
case opSub:
v = "subscription"
case nodeStr:
case NodeStr:
v = "node-string"
case nodeInt:
case NodeInt:
v = "node-int"
case nodeFloat:
case NodeFloat:
v = "node-float"
case nodeBool:
case NodeBool:
v = "node-bool"
case nodeVar:
case NodeVar:
v = "node-var"
case nodeObj:
case NodeObj:
v = "node-obj"
case nodeList:
case NodeList:
v = "node-list"
}
return fmt.Sprintf("<%s>", v)

View File

@ -157,7 +157,6 @@ const (
type Compiler struct {
tr map[string]map[string]*trval
bl map[string]struct{}
ka bool
}
var expPool = sync.Pool{
@ -165,7 +164,7 @@ var expPool = sync.Pool{
}
func NewCompiler(c Config) (*Compiler, error) {
co := &Compiler{ka: c.KeepArgs}
co := &Compiler{}
co.tr = make(map[string]map[string]*trval)
co.bl = make(map[string]struct{}, len(c.Blocklist))
@ -380,11 +379,13 @@ func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error {
return nil
}
func (com *Compiler) addFilters(qc *QCode, root *Select, role string) {
func (com *Compiler) addFilters(qc *QCode, sel *Select, role string) {
var fil *Exp
if trv, ok := com.tr[role][root.Name]; ok {
if trv, ok := com.tr[role][sel.Name]; ok {
fil = trv.filter(qc.Type)
} else {
return
}
if fil == nil {
@ -394,60 +395,61 @@ func (com *Compiler) addFilters(qc *QCode, root *Select, role string) {
switch fil.Op {
case OpNop:
case OpFalse:
root.Where = fil
sel.Where = fil
default:
if root.Where != nil {
ow := root.Where
if sel.Where != nil {
ow := sel.Where
root.Where = expPool.Get().(*Exp)
root.Where.Reset()
root.Where.Op = OpAnd
root.Where.Children = root.Where.childrenA[:2]
root.Where.Children[0] = fil
root.Where.Children[1] = ow
sel.Where = expPool.Get().(*Exp)
sel.Where.Reset()
sel.Where.Op = OpAnd
sel.Where.Children = sel.Where.childrenA[:2]
sel.Where.Children[0] = fil
sel.Where.Children[1] = ow
} else {
root.Where = fil
sel.Where = fil
}
}
}
func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg) error {
var err error
if com.ka {
sel.Args = make(map[string]*Node, len(args))
}
var ka bool
for i := range args {
arg := &args[i]
switch arg.Name {
case "id":
err = com.compileArgID(sel, arg)
err, ka = com.compileArgID(sel, arg)
case "search":
err = com.compileArgSearch(sel, arg)
err, ka = com.compileArgSearch(sel, arg)
case "where":
err = com.compileArgWhere(sel, arg)
err, ka = com.compileArgWhere(sel, arg)
case "orderby", "order_by", "order":
err = com.compileArgOrderBy(sel, arg)
err, ka = com.compileArgOrderBy(sel, arg)
case "distinct_on", "distinct":
err = com.compileArgDistinctOn(sel, arg)
err, ka = com.compileArgDistinctOn(sel, arg)
case "limit":
err = com.compileArgLimit(sel, arg)
err, ka = com.compileArgLimit(sel, arg)
case "offset":
err = com.compileArgOffset(sel, arg)
err, ka = com.compileArgOffset(sel, arg)
}
if !ka {
nodePool.Put(arg.Val)
}
if err != nil {
return err
}
if sel.Args != nil {
sel.Args[arg.Name] = arg.Val
} else {
nodePool.Put(arg.Val)
}
}
return nil
@ -455,7 +457,7 @@ func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg) error {
func (com *Compiler) setMutationType(qc *QCode, args []Arg) error {
setActionVar := func(arg *Arg) error {
if arg.Val.Type != nodeVar {
if arg.Val.Type != NodeVar {
return fmt.Errorf("value for argument '%s' must be a variable", arg.Name)
}
qc.ActionVar = arg.Val.Val
@ -478,7 +480,7 @@ func (com *Compiler) setMutationType(qc *QCode, args []Arg) error {
case "delete":
qc.Type = QTDelete
if arg.Val.Type != nodeBool {
if arg.Val.Type != NodeBool {
return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name)
}
@ -493,7 +495,7 @@ func (com *Compiler) setMutationType(qc *QCode, args []Arg) error {
}
func (com *Compiler) compileArgObj(st *util.Stack, arg *Arg) (*Exp, error) {
if arg.Val.Type != nodeObj {
if arg.Val.Type != NodeObj {
return nil, fmt.Errorf("expecting an object")
}
@ -545,11 +547,6 @@ func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (*
} else {
node.exp.Children = append(node.exp.Children, ex)
}
}
if com.ka {
return root, nil
}
pushChild(st, nil, node)
@ -570,13 +567,13 @@ func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (*
return root, nil
}
func (com *Compiler) compileArgID(sel *Select, arg *Arg) error {
func (com *Compiler) compileArgID(sel *Select, arg *Arg) (error, bool) {
if sel.ID != 0 {
return nil
return nil, false
}
if sel.Where != nil && sel.Where.Op == OpEqID {
return nil
return nil, false
}
ex := expPool.Get().(*Exp)
@ -586,30 +583,41 @@ func (com *Compiler) compileArgID(sel *Select, arg *Arg) error {
ex.Val = arg.Val.Val
switch arg.Val.Type {
case nodeStr:
case NodeStr:
ex.Type = ValStr
case nodeInt:
case NodeInt:
ex.Type = ValInt
case nodeFloat:
case NodeFloat:
ex.Type = ValFloat
case nodeVar:
case NodeVar:
ex.Type = ValVar
default:
return fmt.Errorf("expecting a string, int, float or variable")
return fmt.Errorf("expecting a string, int, float or variable"), false
}
sel.Where = ex
return nil
return nil, false
}
func (com *Compiler) compileArgSearch(sel *Select, arg *Arg) error {
func (com *Compiler) compileArgSearch(sel *Select, arg *Arg) (error, bool) {
ex := expPool.Get().(*Exp)
ex.Reset()
ex.Op = OpTsQuery
ex.Type = ValStr
ex.Val = arg.Val.Val
if arg.Val.Type == NodeVar {
ex.Type = ValVar
} else {
ex.Type = ValStr
}
if sel.Args == nil {
sel.Args = make(map[string]*Node)
}
sel.Args[arg.Name] = arg.Val
if sel.Where != nil {
ow := sel.Where
@ -622,16 +630,16 @@ func (com *Compiler) compileArgSearch(sel *Select, arg *Arg) error {
} else {
sel.Where = ex
}
return nil
return nil, true
}
func (com *Compiler) compileArgWhere(sel *Select, arg *Arg) error {
func (com *Compiler) compileArgWhere(sel *Select, arg *Arg) (error, bool) {
st := util.NewStack()
var err error
ex, err := com.compileArgObj(st, arg)
if err != nil {
return err
return err, false
}
if sel.Where != nil {
@ -647,12 +655,12 @@ func (com *Compiler) compileArgWhere(sel *Select, arg *Arg) error {
sel.Where = ex
}
return nil
return nil, false
}
func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error {
if arg.Val.Type != nodeObj {
return fmt.Errorf("expecting an object")
func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) (error, bool) {
if arg.Val.Type != NodeObj {
return fmt.Errorf("expecting an object"), false
}
st := util.NewStack()
@ -670,23 +678,19 @@ func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error {
node, ok := intf.(*Node)
if !ok || node == nil {
return fmt.Errorf("17: unexpected value %v (%t)", intf, intf)
return fmt.Errorf("17: unexpected value %v (%t)", intf, intf), false
}
if _, ok := com.bl[node.Name]; ok {
if !com.ka {
nodePool.Put(node)
}
continue
}
if node.Type == nodeObj {
if node.Type == NodeObj {
for i := range node.Children {
st.Push(node.Children[i])
}
if !com.ka {
nodePool.Put(node)
}
continue
}
@ -706,65 +710,60 @@ func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error {
case "desc_nulls_last":
ob.Order = OrderDescNullsLast
default:
return fmt.Errorf("valid values include asc, desc, asc_nulls_first and desc_nulls_first")
return fmt.Errorf("valid values include asc, desc, asc_nulls_first and desc_nulls_first"), false
}
setOrderByColName(ob, node)
sel.OrderBy = append(sel.OrderBy, ob)
if !com.ka {
nodePool.Put(node)
}
}
return nil
return nil, false
}
func (com *Compiler) compileArgDistinctOn(sel *Select, arg *Arg) error {
func (com *Compiler) compileArgDistinctOn(sel *Select, arg *Arg) (error, bool) {
node := arg.Val
if _, ok := com.bl[node.Name]; ok {
return nil
return nil, false
}
if node.Type != nodeList && node.Type != nodeStr {
return fmt.Errorf("expecting a list of strings or just a string")
if node.Type != NodeList && node.Type != NodeStr {
return fmt.Errorf("expecting a list of strings or just a string"), false
}
if node.Type == nodeStr {
if node.Type == NodeStr {
sel.DistinctOn = append(sel.DistinctOn, node.Val)
}
for i := range node.Children {
sel.DistinctOn = append(sel.DistinctOn, node.Children[i].Val)
if !com.ka {
nodePool.Put(node.Children[i])
}
}
return nil
return nil, false
}
func (com *Compiler) compileArgLimit(sel *Select, arg *Arg) error {
func (com *Compiler) compileArgLimit(sel *Select, arg *Arg) (error, bool) {
node := arg.Val
if node.Type != nodeInt {
return fmt.Errorf("expecting an integer")
if node.Type != NodeInt {
return fmt.Errorf("expecting an integer"), false
}
sel.Paging.Limit = node.Val
return nil
return nil, false
}
func (com *Compiler) compileArgOffset(sel *Select, arg *Arg) error {
func (com *Compiler) compileArgOffset(sel *Select, arg *Arg) (error, bool) {
node := arg.Val
if node.Type != nodeInt {
return fmt.Errorf("expecting an integer")
if node.Type != NodeInt {
return fmt.Errorf("expecting an integer"), false
}
sel.Paging.Offset = node.Val
return nil
return nil, false
}
var zeroTrv = &trval{}
@ -879,17 +878,17 @@ func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {
if ex.Op != OpAnd && ex.Op != OpOr && ex.Op != OpNot {
switch node.Type {
case nodeStr:
case NodeStr:
ex.Type = ValStr
case nodeInt:
case NodeInt:
ex.Type = ValInt
case nodeBool:
case NodeBool:
ex.Type = ValBool
case nodeFloat:
case NodeFloat:
ex.Type = ValFloat
case nodeList:
case NodeList:
ex.Type = ValList
case nodeVar:
case NodeVar:
ex.Type = ValVar
default:
return nil, fmt.Errorf("[Where] valid values include string, int, float, boolean and list: %s", node.Type)
@ -903,13 +902,13 @@ func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {
func setListVal(ex *Exp, node *Node) {
if len(node.Children) != 0 {
switch node.Children[0].Type {
case nodeStr:
case NodeStr:
ex.ListType = ValStr
case nodeInt:
case NodeInt:
ex.ListType = ValInt
case nodeBool:
case NodeBool:
ex.ListType = ValBool
case nodeFloat:
case NodeFloat:
ex.ListType = ValFloat
}
}
@ -922,7 +921,7 @@ func setWhereColName(ex *Exp, node *Node) {
var list []string
for n := node.Parent; n != nil; n = n.Parent {
if n.Type != nodeObj {
if n.Type != NodeObj {
continue
}
if len(n.Name) != 0 {

View File

@ -110,7 +110,7 @@ func graphQLFunc(query string, data interface{}, opt map[string]string) map[stri
var root []byte
if err = tx.QueryRow(c, finalSQL).Scan(&root); err != nil {
if err = tx.QueryRow(context.Background(), finalSQL).Scan(&root); err != nil {
errlog.Fatal().Err(err).Msg("sql query failed")
}

View File

@ -81,7 +81,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) {
useTx := useRoleQuery || conf.DB.SetUserID
if useTx {
if tx, err = db.Begin(c); err != nil {
if tx, err = db.Begin(context.Background()); err != nil {
return nil, nil, err
}
defer tx.Rollback(c) //nolint: errcheck
@ -122,9 +122,9 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) {
}
if useTx {
row = tx.QueryRow(c, ps.sd.SQL, vars...)
row = tx.QueryRow(context.Background(), ps.sd.SQL, vars...)
} else {
row = db.QueryRow(c, ps.sd.SQL, vars...)
row = db.QueryRow(context.Background(), ps.sd.SQL, vars...)
}
if mutation || anonQuery {
@ -146,7 +146,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) {
c.req.role = role
if useTx {
if err := tx.Commit(c); err != nil {
if err := tx.Commit(context.Background()); err != nil {
return nil, nil, err
}
}
@ -166,10 +166,10 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
useTx := useRoleQuery || conf.DB.SetUserID
if useTx {
if tx, err = db.Begin(c); err != nil {
if tx, err = db.Begin(context.Background()); err != nil {
return nil, nil, err
}
defer tx.Rollback(c) //nolint: errcheck
defer tx.Rollback(context.Background()) //nolint: errcheck
}
if conf.DB.SetUserID {
@ -215,9 +215,9 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
defaultRole := c.req.role
if useTx {
row = tx.QueryRow(c, finalSQL)
row = tx.QueryRow(context.Background(), finalSQL)
} else {
row = db.QueryRow(c, finalSQL)
row = db.QueryRow(context.Background(), finalSQL)
}
if len(stmts) == 1 {
@ -237,7 +237,7 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
}
if useTx {
if err := tx.Commit(c); err != nil {
if err := tx.Commit(context.Background()); err != nil {
return nil, nil, err
}
}
@ -263,7 +263,7 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
var role string
row := tx.QueryRow(c, "_sg_get_role", c.req.role, 1)
row := tx.QueryRow(context.Background(), "_sg_get_role", c.req.role, 1)
if err := row.Scan(&role); err != nil {
return "", err
@ -320,6 +320,15 @@ func (c *coreContext) addTrace(sel []qcode.Select, id int32, st time.Time) {
append(c.res.Extensions.Tracing.Execution.Resolvers, tr)
}
func setLocalUserID(c context.Context, tx pgx.Tx) error {
var err error
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(context.Background(), fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
}
return err
}
func parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) (
[][]byte,
map[uint64]*qcode.Select) {
@ -363,15 +372,6 @@ func parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) (
return fm, sm
}
func setLocalUserID(c context.Context, tx pgx.Tx) error {
var err error
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
}
return err
}
func isSkipped(n uint32, pos uint32) bool {
return ((n & (1 << pos)) != 0)
}

View File

@ -97,6 +97,10 @@ func buildMultiStmt(gql, vars []byte) ([]stmt, error) {
for i := 0; i < len(conf.Roles); i++ {
role := &conf.Roles[i]
if role.Name == "anon" {
continue
}
qc, err := qcompile.Compile(gql, role.Name)
if err != nil {
return nil, err
@ -127,8 +131,6 @@ func buildMultiStmt(gql, vars []byte) ([]stmt, error) {
//nolint: errcheck
func renderUserQuery(
stmts []stmt, vars map[string]json.RawMessage) (string, error) {
var err error
w := &bytes.Buffer{}
io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
@ -141,11 +143,7 @@ func renderUserQuery(
io.WriteString(w, `WHEN '`)
io.WriteString(w, s.role.Name)
io.WriteString(w, `' THEN (`)
s.skipped, err = pcompile.Compile(s.qc, w, psql.Variables(vars))
if err != nil {
return "", err
}
io.WriteString(w, s.sql)
io.WriteString(w, `) `)
}

View File

@ -23,21 +23,20 @@ var (
)
func initPreparedList() {
c := context.Background()
_preparedList = make(map[string]*preparedItem)
tx, err := db.Begin(c)
tx, err := db.Begin(context.Background())
if err != nil {
errlog.Fatal().Err(err).Send()
}
defer tx.Rollback(c) //nolint: errcheck
defer tx.Rollback(context.Background()) //nolint: errcheck
err = prepareRoleStmt(c, tx)
err = prepareRoleStmt(tx)
if err != nil {
errlog.Fatal().Err(err).Msg("failed to prepare get role statement")
}
if err := tx.Commit(c); err != nil {
if err := tx.Commit(context.Background()); err != nil {
errlog.Fatal().Err(err).Send()
}
@ -48,7 +47,7 @@ func initPreparedList() {
continue
}
err := prepareStmt(c, v.gql, v.vars)
err := prepareStmt(v.gql, v.vars)
if err == nil {
success++
continue
@ -66,15 +65,15 @@ func initPreparedList() {
success, len(_allowList.list))
}
func prepareStmt(c context.Context, gql string, vars []byte) error {
func prepareStmt(gql string, vars []byte) error {
qt := qcode.GetQType(gql)
q := []byte(gql)
tx, err := db.Begin(c)
tx, err := db.Begin(context.Background())
if err != nil {
return err
}
defer tx.Rollback(c) //nolint: errcheck
defer tx.Rollback(context.Background()) //nolint: errcheck
switch qt {
case qcode.QTQuery:
@ -83,7 +82,7 @@ func prepareStmt(c context.Context, gql string, vars []byte) error {
return err
}
err = prepare(c, tx, &stmts1[0], gqlHash(gql, vars, "user"))
err = prepare(tx, &stmts1[0], gqlHash(gql, vars, "user"))
if err != nil {
return err
}
@ -93,7 +92,7 @@ func prepareStmt(c context.Context, gql string, vars []byte) error {
return err
}
err = prepare(c, tx, &stmts2[0], gqlHash(gql, vars, "anon"))
err = prepare(tx, &stmts2[0], gqlHash(gql, vars, "anon"))
if err != nil {
return err
}
@ -105,7 +104,7 @@ func prepareStmt(c context.Context, gql string, vars []byte) error {
return err
}
err = prepare(c, tx, &stmts[0], gqlHash(gql, vars, role.Name))
err = prepare(tx, &stmts[0], gqlHash(gql, vars, role.Name))
if err != nil {
return err
}
@ -118,17 +117,17 @@ func prepareStmt(c context.Context, gql string, vars []byte) error {
logger.Debug().Msgf("Building prepared statement:\n %s\n%s", vars, gql)
}
if err := tx.Commit(c); err != nil {
if err := tx.Commit(context.Background()); err != nil {
return err
}
return nil
}
func prepare(c context.Context, tx pgx.Tx, st *stmt, key string) error {
func prepare(tx pgx.Tx, st *stmt, key string) error {
finalSQL, am := processTemplate(st.sql)
sd, err := tx.Prepare(c, "", finalSQL)
sd, err := tx.Prepare(context.Background(), "", finalSQL)
if err != nil {
return err
}
@ -142,7 +141,7 @@ func prepare(c context.Context, tx pgx.Tx, st *stmt, key string) error {
}
// nolint: errcheck
func prepareRoleStmt(c context.Context, tx pgx.Tx) error {
func prepareRoleStmt(tx pgx.Tx) error {
if len(conf.RolesQuery) == 0 {
return nil
}
@ -167,7 +166,7 @@ func prepareRoleStmt(c context.Context, tx pgx.Tx) error {
roleSQL, _ := processTemplate(w.String())
_, err := tx.Prepare(c, "_sg_get_role", roleSQL)
_, err := tx.Prepare(context.Background(), "_sg_get_role", roleSQL)
if err != nil {
return err
}

View File

@ -24,7 +24,6 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
conf := qcode.Config{
Blocklist: c.DB.Blocklist,
KeepArgs: false,
}
qc, err := qcode.NewCompiler(conf)