Add support for `websearch_to_tsquery` in PG 11
This commit is contained in:
parent
5593c66996
commit
5da79d91bf
|
@ -169,3 +169,17 @@ query {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
variables {
|
||||||
|
"beer": "smoke"
|
||||||
|
}
|
||||||
|
|
||||||
|
query beerSearch {
|
||||||
|
products(search: $beer) {
|
||||||
|
id
|
||||||
|
name
|
||||||
|
search_rank
|
||||||
|
search_headline_description
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -171,7 +171,7 @@ roles:
|
||||||
query:
|
query:
|
||||||
limit: 50
|
limit: 50
|
||||||
filters: ["{ user_id: { eq: $user_id } }"]
|
filters: ["{ user_id: { eq: $user_id } }"]
|
||||||
columns: ["id", "name", "description" ]
|
columns: ["id", "name", "description", "search_rank", "search_headline_description" ]
|
||||||
disable_functions: false
|
disable_functions: false
|
||||||
|
|
||||||
insert:
|
insert:
|
||||||
|
|
|
@ -353,16 +353,14 @@ func (m *Migrator) MigrateTo(targetVersion int32) (err error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Migrator) GetCurrentVersion() (v int32, err error) {
|
func (m *Migrator) GetCurrentVersion() (v int32, err error) {
|
||||||
ctx := context.Background()
|
err = m.conn.QueryRow(context.Background(),
|
||||||
|
"select version from "+m.versionTable).Scan(&v)
|
||||||
|
|
||||||
err = m.conn.QueryRow(ctx, "select version from "+m.versionTable).Scan(&v)
|
|
||||||
return v, err
|
return v, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Migrator) ensureSchemaVersionTableExists() (err error) {
|
func (m *Migrator) ensureSchemaVersionTableExists() (err error) {
|
||||||
ctx := context.Background()
|
_, err = m.conn.Exec(context.Background(), fmt.Sprintf(`
|
||||||
|
|
||||||
_, err = m.conn.Exec(ctx, fmt.Sprintf(`
|
|
||||||
create table if not exists %s(version int4 not null);
|
create table if not exists %s(version int4 not null);
|
||||||
|
|
||||||
insert into %s(version)
|
insert into %s(version)
|
||||||
|
|
|
@ -178,6 +178,7 @@ func TestMain(m *testing.M) {
|
||||||
}
|
}
|
||||||
|
|
||||||
schema := &DBSchema{
|
schema := &DBSchema{
|
||||||
|
ver: 110000,
|
||||||
t: make(map[string]*DBTableInfo),
|
t: make(map[string]*DBTableInfo),
|
||||||
rm: make(map[string]map[string]*DBRel),
|
rm: make(map[string]map[string]*DBRel),
|
||||||
al: make(map[string]struct{}),
|
al: make(map[string]struct{}),
|
||||||
|
|
|
@ -189,6 +189,11 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(sel.Args) != 0 {
|
||||||
|
for _, v := range sel.Args {
|
||||||
|
qcode.FreeNode(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -515,36 +520,54 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
|
||||||
if isSearch {
|
if isSearch {
|
||||||
switch {
|
switch {
|
||||||
case cn == "search_rank":
|
case cn == "search_rank":
|
||||||
|
if len(sel.Allowed) != 0 {
|
||||||
|
if _, ok := sel.Allowed[cn]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
cn = ti.TSVCol
|
cn = ti.TSVCol
|
||||||
arg := sel.Args["search"]
|
arg := sel.Args["search"]
|
||||||
|
|
||||||
if i != 0 {
|
if i != 0 {
|
||||||
io.WriteString(c.w, `, `)
|
io.WriteString(c.w, `, `)
|
||||||
}
|
}
|
||||||
//fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`,
|
//fmt.Fprintf(w, `ts_rank("%s"."%s", websearch_to_tsquery('%s')) AS %s`,
|
||||||
//c.sel.Name, cn, arg.Val, col.Name)
|
//c.sel.Name, cn, arg.Val, col.Name)
|
||||||
io.WriteString(c.w, `ts_rank(`)
|
io.WriteString(c.w, `ts_rank(`)
|
||||||
colWithTable(c.w, ti.Name, cn)
|
colWithTable(c.w, ti.Name, cn)
|
||||||
|
if c.schema.ver >= 110000 {
|
||||||
|
io.WriteString(c.w, `, websearch_to_tsquery('`)
|
||||||
|
} else {
|
||||||
io.WriteString(c.w, `, to_tsquery('`)
|
io.WriteString(c.w, `, to_tsquery('`)
|
||||||
|
}
|
||||||
io.WriteString(c.w, arg.Val)
|
io.WriteString(c.w, arg.Val)
|
||||||
io.WriteString(c.w, `')`)
|
io.WriteString(c.w, `'))`)
|
||||||
alias(c.w, col.Name)
|
alias(c.w, col.Name)
|
||||||
i++
|
i++
|
||||||
|
|
||||||
case strings.HasPrefix(cn, "search_headline_"):
|
case strings.HasPrefix(cn, "search_headline_"):
|
||||||
cn = cn[16:]
|
cn1 := cn[16:]
|
||||||
|
if len(sel.Allowed) != 0 {
|
||||||
|
if _, ok := sel.Allowed[cn1]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
arg := sel.Args["search"]
|
arg := sel.Args["search"]
|
||||||
|
|
||||||
if i != 0 {
|
if i != 0 {
|
||||||
io.WriteString(c.w, `, `)
|
io.WriteString(c.w, `, `)
|
||||||
}
|
}
|
||||||
//fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`,
|
//fmt.Fprintf(w, `ts_headline("%s"."%s", websearch_to_tsquery('%s')) AS %s`,
|
||||||
//c.sel.Name, cn, arg.Val, col.Name)
|
//c.sel.Name, cn, arg.Val, col.Name)
|
||||||
io.WriteString(c.w, `ts_headlinek(`)
|
io.WriteString(c.w, `ts_headline(`)
|
||||||
colWithTable(c.w, ti.Name, cn)
|
colWithTable(c.w, ti.Name, cn1)
|
||||||
|
if c.schema.ver >= 110000 {
|
||||||
|
io.WriteString(c.w, `, websearch_to_tsquery('`)
|
||||||
|
} else {
|
||||||
io.WriteString(c.w, `, to_tsquery('`)
|
io.WriteString(c.w, `, to_tsquery('`)
|
||||||
|
}
|
||||||
io.WriteString(c.w, arg.Val)
|
io.WriteString(c.w, arg.Val)
|
||||||
io.WriteString(c.w, `')`)
|
io.WriteString(c.w, `'))`)
|
||||||
alias(c.w, col.Name)
|
alias(c.w, col.Name)
|
||||||
i++
|
i++
|
||||||
|
|
||||||
|
@ -693,6 +716,7 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
|
||||||
//fmt.Fprintf(w, `) AS "%s_%d"`, c.sel.Name, c.sel.ID)
|
//fmt.Fprintf(w, `) AS "%s_%d"`, c.sel.Name, c.sel.ID)
|
||||||
io.WriteString(c.w, `)`)
|
io.WriteString(c.w, `)`)
|
||||||
aliasWithID(c.w, ti.Name, sel.ID)
|
aliasWithID(c.w, ti.Name, sel.ID)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -939,6 +963,7 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
|
||||||
io.WriteString(c.w, `IS NOT NULL)`)
|
io.WriteString(c.w, `IS NOT NULL)`)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
|
|
||||||
case qcode.OpEqID:
|
case qcode.OpEqID:
|
||||||
if len(ti.PrimaryCol) == 0 {
|
if len(ti.PrimaryCol) == 0 {
|
||||||
return fmt.Errorf("no primary key column defined for %s", ti.Name)
|
return fmt.Errorf("no primary key column defined for %s", ti.Name)
|
||||||
|
@ -951,6 +976,7 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
|
||||||
colWithTable(c.w, ti.Name, ti.PrimaryCol)
|
colWithTable(c.w, ti.Name, ti.PrimaryCol)
|
||||||
//io.WriteString(c.w, ti.PrimaryCol)
|
//io.WriteString(c.w, ti.PrimaryCol)
|
||||||
io.WriteString(c.w, `) =`)
|
io.WriteString(c.w, `) =`)
|
||||||
|
|
||||||
case qcode.OpTsQuery:
|
case qcode.OpTsQuery:
|
||||||
if len(ti.TSVCol) == 0 {
|
if len(ti.TSVCol) == 0 {
|
||||||
return fmt.Errorf("no tsv column defined for %s", ti.Name)
|
return fmt.Errorf("no tsv column defined for %s", ti.Name)
|
||||||
|
@ -958,10 +984,14 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
|
||||||
if _, ok = ti.Columns[ti.TSVCol]; !ok {
|
if _, ok = ti.Columns[ti.TSVCol]; !ok {
|
||||||
return fmt.Errorf("no tsv column '%s' found ", ti.TSVCol)
|
return fmt.Errorf("no tsv column '%s' found ", ti.TSVCol)
|
||||||
}
|
}
|
||||||
//fmt.Fprintf(w, `(("%s") @@ to_tsquery('%s'))`, c.ti.TSVCol, val.Val)
|
//fmt.Fprintf(w, `(("%s") @@ websearch_to_tsquery('%s'))`, c.ti.TSVCol, val.Val)
|
||||||
io.WriteString(c.w, `(("`)
|
io.WriteString(c.w, `((`)
|
||||||
io.WriteString(c.w, ti.TSVCol)
|
colWithTable(c.w, ti.Name, ti.TSVCol)
|
||||||
io.WriteString(c.w, `") @@ to_tsquery('`)
|
if c.schema.ver >= 110000 {
|
||||||
|
io.WriteString(c.w, `) @@ websearch_to_tsquery('`)
|
||||||
|
} else {
|
||||||
|
io.WriteString(c.w, `) @@ to_tsquery('`)
|
||||||
|
}
|
||||||
io.WriteString(c.w, ex.Val)
|
io.WriteString(c.w, ex.Val)
|
||||||
io.WriteString(c.w, `'))`)
|
io.WriteString(c.w, `'))`)
|
||||||
return nil
|
return nil
|
||||||
|
|
|
@ -142,15 +142,17 @@ func fetchByID(t *testing.T) {
|
||||||
|
|
||||||
func searchQuery(t *testing.T) {
|
func searchQuery(t *testing.T) {
|
||||||
gql := `query {
|
gql := `query {
|
||||||
products(search: "Imperial") {
|
products(search: "ale") {
|
||||||
id
|
id
|
||||||
name
|
name
|
||||||
|
search_rank
|
||||||
|
search_headline_description
|
||||||
}
|
}
|
||||||
}`
|
}`
|
||||||
|
|
||||||
sql := `SELECT json_object_agg('products', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"`
|
sql := `SELECT json_object_agg('products', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."search_rank" AS "search_rank", "products_0"."search_headline_description" AS "search_headline_description") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name", ts_rank("products"."tsv", websearch_to_tsquery('ale')) AS "search_rank", ts_headline("products"."description", websearch_to_tsquery('ale')) AS "search_headline_description" FROM "products" WHERE ((("products"."tsv") @@ websearch_to_tsquery('ale'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"`
|
||||||
|
|
||||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
resSQL, err := compileGQLToPSQL(gql, nil, "admin")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package psql
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gobuffalo/flect"
|
"github.com/gobuffalo/flect"
|
||||||
|
@ -144,6 +145,7 @@ ORDER BY id;`
|
||||||
}
|
}
|
||||||
|
|
||||||
type DBSchema struct {
|
type DBSchema struct {
|
||||||
|
ver int
|
||||||
t map[string]*DBTableInfo
|
t map[string]*DBTableInfo
|
||||||
rm map[string]map[string]*DBRel
|
rm map[string]map[string]*DBRel
|
||||||
al map[string]struct{}
|
al map[string]struct{}
|
||||||
|
@ -184,10 +186,22 @@ func NewDBSchema(db *pgxpool.Pool, aliases map[string][]string) (*DBSchema, erro
|
||||||
|
|
||||||
dbc, err := db.Acquire(context.Background())
|
dbc, err := db.Acquire(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("error acquiring connection from pool")
|
return nil, fmt.Errorf("error acquiring connection from pool: %w", err)
|
||||||
}
|
}
|
||||||
defer dbc.Release()
|
defer dbc.Release()
|
||||||
|
|
||||||
|
var version string
|
||||||
|
|
||||||
|
err = dbc.QueryRow(context.Background(), `SHOW server_version_num`).Scan(&version)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("error fetching version: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
schema.ver, err = strconv.Atoi(version)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
tables, err := GetTables(dbc)
|
tables, err := GetTables(dbc)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
|
|
@ -8,7 +8,6 @@ import (
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Blocklist []string
|
Blocklist []string
|
||||||
KeepArgs bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type QueryConfig struct {
|
type QueryConfig struct {
|
||||||
|
|
|
@ -26,13 +26,13 @@ const (
|
||||||
opQuery
|
opQuery
|
||||||
opMutate
|
opMutate
|
||||||
opSub
|
opSub
|
||||||
nodeStr
|
NodeStr
|
||||||
nodeInt
|
NodeInt
|
||||||
nodeFloat
|
NodeFloat
|
||||||
nodeBool
|
NodeBool
|
||||||
nodeObj
|
NodeObj
|
||||||
nodeList
|
NodeList
|
||||||
nodeVar
|
NodeVar
|
||||||
)
|
)
|
||||||
|
|
||||||
type Operation struct {
|
type Operation struct {
|
||||||
|
@ -413,7 +413,7 @@ func (p *Parser) parseList() (*Node, error) {
|
||||||
return nil, errors.New("List cannot be empty")
|
return nil, errors.New("List cannot be empty")
|
||||||
}
|
}
|
||||||
|
|
||||||
parent.Type = nodeList
|
parent.Type = NodeList
|
||||||
parent.Children = nodes
|
parent.Children = nodes
|
||||||
|
|
||||||
return parent, nil
|
return parent, nil
|
||||||
|
@ -450,7 +450,7 @@ func (p *Parser) parseObj() (*Node, error) {
|
||||||
nodes = append(nodes, node)
|
nodes = append(nodes, node)
|
||||||
}
|
}
|
||||||
|
|
||||||
parent.Type = nodeObj
|
parent.Type = NodeObj
|
||||||
parent.Children = nodes
|
parent.Children = nodes
|
||||||
|
|
||||||
return parent, nil
|
return parent, nil
|
||||||
|
@ -473,17 +473,17 @@ func (p *Parser) parseValue() (*Node, error) {
|
||||||
|
|
||||||
switch item.typ {
|
switch item.typ {
|
||||||
case itemIntVal:
|
case itemIntVal:
|
||||||
node.Type = nodeInt
|
node.Type = NodeInt
|
||||||
case itemFloatVal:
|
case itemFloatVal:
|
||||||
node.Type = nodeFloat
|
node.Type = NodeFloat
|
||||||
case itemStringVal:
|
case itemStringVal:
|
||||||
node.Type = nodeStr
|
node.Type = NodeStr
|
||||||
case itemBoolVal:
|
case itemBoolVal:
|
||||||
node.Type = nodeBool
|
node.Type = NodeBool
|
||||||
case itemName:
|
case itemName:
|
||||||
node.Type = nodeStr
|
node.Type = NodeStr
|
||||||
case itemVariable:
|
case itemVariable:
|
||||||
node.Type = nodeVar
|
node.Type = NodeVar
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("expecting a number, string, object, list or variable as an argument value (not %s)", p.val(p.next()))
|
return nil, fmt.Errorf("expecting a number, string, object, list or variable as an argument value (not %s)", p.val(p.next()))
|
||||||
}
|
}
|
||||||
|
@ -514,19 +514,19 @@ func (t parserType) String() string {
|
||||||
v = "mutation"
|
v = "mutation"
|
||||||
case opSub:
|
case opSub:
|
||||||
v = "subscription"
|
v = "subscription"
|
||||||
case nodeStr:
|
case NodeStr:
|
||||||
v = "node-string"
|
v = "node-string"
|
||||||
case nodeInt:
|
case NodeInt:
|
||||||
v = "node-int"
|
v = "node-int"
|
||||||
case nodeFloat:
|
case NodeFloat:
|
||||||
v = "node-float"
|
v = "node-float"
|
||||||
case nodeBool:
|
case NodeBool:
|
||||||
v = "node-bool"
|
v = "node-bool"
|
||||||
case nodeVar:
|
case NodeVar:
|
||||||
v = "node-var"
|
v = "node-var"
|
||||||
case nodeObj:
|
case NodeObj:
|
||||||
v = "node-obj"
|
v = "node-obj"
|
||||||
case nodeList:
|
case NodeList:
|
||||||
v = "node-list"
|
v = "node-list"
|
||||||
}
|
}
|
||||||
return fmt.Sprintf("<%s>", v)
|
return fmt.Sprintf("<%s>", v)
|
||||||
|
|
191
qcode/qcode.go
191
qcode/qcode.go
|
@ -157,7 +157,6 @@ const (
|
||||||
type Compiler struct {
|
type Compiler struct {
|
||||||
tr map[string]map[string]*trval
|
tr map[string]map[string]*trval
|
||||||
bl map[string]struct{}
|
bl map[string]struct{}
|
||||||
ka bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
var expPool = sync.Pool{
|
var expPool = sync.Pool{
|
||||||
|
@ -165,7 +164,7 @@ var expPool = sync.Pool{
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewCompiler(c Config) (*Compiler, error) {
|
func NewCompiler(c Config) (*Compiler, error) {
|
||||||
co := &Compiler{ka: c.KeepArgs}
|
co := &Compiler{}
|
||||||
co.tr = make(map[string]map[string]*trval)
|
co.tr = make(map[string]map[string]*trval)
|
||||||
co.bl = make(map[string]struct{}, len(c.Blocklist))
|
co.bl = make(map[string]struct{}, len(c.Blocklist))
|
||||||
|
|
||||||
|
@ -380,11 +379,13 @@ func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (com *Compiler) addFilters(qc *QCode, root *Select, role string) {
|
func (com *Compiler) addFilters(qc *QCode, sel *Select, role string) {
|
||||||
var fil *Exp
|
var fil *Exp
|
||||||
|
|
||||||
if trv, ok := com.tr[role][root.Name]; ok {
|
if trv, ok := com.tr[role][sel.Name]; ok {
|
||||||
fil = trv.filter(qc.Type)
|
fil = trv.filter(qc.Type)
|
||||||
|
} else {
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if fil == nil {
|
if fil == nil {
|
||||||
|
@ -394,60 +395,61 @@ func (com *Compiler) addFilters(qc *QCode, root *Select, role string) {
|
||||||
switch fil.Op {
|
switch fil.Op {
|
||||||
case OpNop:
|
case OpNop:
|
||||||
case OpFalse:
|
case OpFalse:
|
||||||
root.Where = fil
|
sel.Where = fil
|
||||||
|
|
||||||
default:
|
default:
|
||||||
if root.Where != nil {
|
if sel.Where != nil {
|
||||||
ow := root.Where
|
ow := sel.Where
|
||||||
|
|
||||||
root.Where = expPool.Get().(*Exp)
|
sel.Where = expPool.Get().(*Exp)
|
||||||
root.Where.Reset()
|
sel.Where.Reset()
|
||||||
root.Where.Op = OpAnd
|
sel.Where.Op = OpAnd
|
||||||
root.Where.Children = root.Where.childrenA[:2]
|
sel.Where.Children = sel.Where.childrenA[:2]
|
||||||
root.Where.Children[0] = fil
|
sel.Where.Children[0] = fil
|
||||||
root.Where.Children[1] = ow
|
sel.Where.Children[1] = ow
|
||||||
} else {
|
} else {
|
||||||
root.Where = fil
|
sel.Where = fil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg) error {
|
func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg) error {
|
||||||
var err error
|
var err error
|
||||||
|
var ka bool
|
||||||
if com.ka {
|
|
||||||
sel.Args = make(map[string]*Node, len(args))
|
|
||||||
}
|
|
||||||
|
|
||||||
for i := range args {
|
for i := range args {
|
||||||
arg := &args[i]
|
arg := &args[i]
|
||||||
|
|
||||||
switch arg.Name {
|
switch arg.Name {
|
||||||
case "id":
|
case "id":
|
||||||
err = com.compileArgID(sel, arg)
|
err, ka = com.compileArgID(sel, arg)
|
||||||
|
|
||||||
case "search":
|
case "search":
|
||||||
err = com.compileArgSearch(sel, arg)
|
err, ka = com.compileArgSearch(sel, arg)
|
||||||
|
|
||||||
case "where":
|
case "where":
|
||||||
err = com.compileArgWhere(sel, arg)
|
err, ka = com.compileArgWhere(sel, arg)
|
||||||
|
|
||||||
case "orderby", "order_by", "order":
|
case "orderby", "order_by", "order":
|
||||||
err = com.compileArgOrderBy(sel, arg)
|
err, ka = com.compileArgOrderBy(sel, arg)
|
||||||
|
|
||||||
case "distinct_on", "distinct":
|
case "distinct_on", "distinct":
|
||||||
err = com.compileArgDistinctOn(sel, arg)
|
err, ka = com.compileArgDistinctOn(sel, arg)
|
||||||
|
|
||||||
case "limit":
|
case "limit":
|
||||||
err = com.compileArgLimit(sel, arg)
|
err, ka = com.compileArgLimit(sel, arg)
|
||||||
|
|
||||||
case "offset":
|
case "offset":
|
||||||
err = com.compileArgOffset(sel, arg)
|
err, ka = com.compileArgOffset(sel, arg)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !ka {
|
||||||
|
nodePool.Put(arg.Val)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if sel.Args != nil {
|
|
||||||
sel.Args[arg.Name] = arg.Val
|
|
||||||
} else {
|
|
||||||
nodePool.Put(arg.Val)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -455,7 +457,7 @@ func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg) error {
|
||||||
|
|
||||||
func (com *Compiler) setMutationType(qc *QCode, args []Arg) error {
|
func (com *Compiler) setMutationType(qc *QCode, args []Arg) error {
|
||||||
setActionVar := func(arg *Arg) error {
|
setActionVar := func(arg *Arg) error {
|
||||||
if arg.Val.Type != nodeVar {
|
if arg.Val.Type != NodeVar {
|
||||||
return fmt.Errorf("value for argument '%s' must be a variable", arg.Name)
|
return fmt.Errorf("value for argument '%s' must be a variable", arg.Name)
|
||||||
}
|
}
|
||||||
qc.ActionVar = arg.Val.Val
|
qc.ActionVar = arg.Val.Val
|
||||||
|
@ -478,7 +480,7 @@ func (com *Compiler) setMutationType(qc *QCode, args []Arg) error {
|
||||||
case "delete":
|
case "delete":
|
||||||
qc.Type = QTDelete
|
qc.Type = QTDelete
|
||||||
|
|
||||||
if arg.Val.Type != nodeBool {
|
if arg.Val.Type != NodeBool {
|
||||||
return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name)
|
return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -493,7 +495,7 @@ func (com *Compiler) setMutationType(qc *QCode, args []Arg) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (com *Compiler) compileArgObj(st *util.Stack, arg *Arg) (*Exp, error) {
|
func (com *Compiler) compileArgObj(st *util.Stack, arg *Arg) (*Exp, error) {
|
||||||
if arg.Val.Type != nodeObj {
|
if arg.Val.Type != NodeObj {
|
||||||
return nil, fmt.Errorf("expecting an object")
|
return nil, fmt.Errorf("expecting an object")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -545,11 +547,6 @@ func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (*
|
||||||
} else {
|
} else {
|
||||||
node.exp.Children = append(node.exp.Children, ex)
|
node.exp.Children = append(node.exp.Children, ex)
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
if com.ka {
|
|
||||||
return root, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
pushChild(st, nil, node)
|
pushChild(st, nil, node)
|
||||||
|
@ -570,13 +567,13 @@ func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (*
|
||||||
return root, nil
|
return root, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (com *Compiler) compileArgID(sel *Select, arg *Arg) error {
|
func (com *Compiler) compileArgID(sel *Select, arg *Arg) (error, bool) {
|
||||||
if sel.ID != 0 {
|
if sel.ID != 0 {
|
||||||
return nil
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if sel.Where != nil && sel.Where.Op == OpEqID {
|
if sel.Where != nil && sel.Where.Op == OpEqID {
|
||||||
return nil
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
ex := expPool.Get().(*Exp)
|
ex := expPool.Get().(*Exp)
|
||||||
|
@ -586,30 +583,41 @@ func (com *Compiler) compileArgID(sel *Select, arg *Arg) error {
|
||||||
ex.Val = arg.Val.Val
|
ex.Val = arg.Val.Val
|
||||||
|
|
||||||
switch arg.Val.Type {
|
switch arg.Val.Type {
|
||||||
case nodeStr:
|
case NodeStr:
|
||||||
ex.Type = ValStr
|
ex.Type = ValStr
|
||||||
case nodeInt:
|
case NodeInt:
|
||||||
ex.Type = ValInt
|
ex.Type = ValInt
|
||||||
case nodeFloat:
|
case NodeFloat:
|
||||||
ex.Type = ValFloat
|
ex.Type = ValFloat
|
||||||
case nodeVar:
|
case NodeVar:
|
||||||
ex.Type = ValVar
|
ex.Type = ValVar
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("expecting a string, int, float or variable")
|
return fmt.Errorf("expecting a string, int, float or variable"), false
|
||||||
}
|
}
|
||||||
|
|
||||||
sel.Where = ex
|
sel.Where = ex
|
||||||
return nil
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (com *Compiler) compileArgSearch(sel *Select, arg *Arg) error {
|
func (com *Compiler) compileArgSearch(sel *Select, arg *Arg) (error, bool) {
|
||||||
ex := expPool.Get().(*Exp)
|
ex := expPool.Get().(*Exp)
|
||||||
ex.Reset()
|
ex.Reset()
|
||||||
|
|
||||||
ex.Op = OpTsQuery
|
ex.Op = OpTsQuery
|
||||||
ex.Type = ValStr
|
|
||||||
ex.Val = arg.Val.Val
|
ex.Val = arg.Val.Val
|
||||||
|
|
||||||
|
if arg.Val.Type == NodeVar {
|
||||||
|
ex.Type = ValVar
|
||||||
|
} else {
|
||||||
|
ex.Type = ValStr
|
||||||
|
}
|
||||||
|
|
||||||
|
if sel.Args == nil {
|
||||||
|
sel.Args = make(map[string]*Node)
|
||||||
|
}
|
||||||
|
|
||||||
|
sel.Args[arg.Name] = arg.Val
|
||||||
|
|
||||||
if sel.Where != nil {
|
if sel.Where != nil {
|
||||||
ow := sel.Where
|
ow := sel.Where
|
||||||
|
|
||||||
|
@ -622,16 +630,16 @@ func (com *Compiler) compileArgSearch(sel *Select, arg *Arg) error {
|
||||||
} else {
|
} else {
|
||||||
sel.Where = ex
|
sel.Where = ex
|
||||||
}
|
}
|
||||||
return nil
|
return nil, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (com *Compiler) compileArgWhere(sel *Select, arg *Arg) error {
|
func (com *Compiler) compileArgWhere(sel *Select, arg *Arg) (error, bool) {
|
||||||
st := util.NewStack()
|
st := util.NewStack()
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
ex, err := com.compileArgObj(st, arg)
|
ex, err := com.compileArgObj(st, arg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if sel.Where != nil {
|
if sel.Where != nil {
|
||||||
|
@ -647,12 +655,12 @@ func (com *Compiler) compileArgWhere(sel *Select, arg *Arg) error {
|
||||||
sel.Where = ex
|
sel.Where = ex
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error {
|
func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) (error, bool) {
|
||||||
if arg.Val.Type != nodeObj {
|
if arg.Val.Type != NodeObj {
|
||||||
return fmt.Errorf("expecting an object")
|
return fmt.Errorf("expecting an object"), false
|
||||||
}
|
}
|
||||||
|
|
||||||
st := util.NewStack()
|
st := util.NewStack()
|
||||||
|
@ -670,23 +678,19 @@ func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error {
|
||||||
node, ok := intf.(*Node)
|
node, ok := intf.(*Node)
|
||||||
|
|
||||||
if !ok || node == nil {
|
if !ok || node == nil {
|
||||||
return fmt.Errorf("17: unexpected value %v (%t)", intf, intf)
|
return fmt.Errorf("17: unexpected value %v (%t)", intf, intf), false
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := com.bl[node.Name]; ok {
|
if _, ok := com.bl[node.Name]; ok {
|
||||||
if !com.ka {
|
|
||||||
nodePool.Put(node)
|
nodePool.Put(node)
|
||||||
}
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if node.Type == nodeObj {
|
if node.Type == NodeObj {
|
||||||
for i := range node.Children {
|
for i := range node.Children {
|
||||||
st.Push(node.Children[i])
|
st.Push(node.Children[i])
|
||||||
}
|
}
|
||||||
if !com.ka {
|
|
||||||
nodePool.Put(node)
|
nodePool.Put(node)
|
||||||
}
|
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -706,65 +710,60 @@ func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error {
|
||||||
case "desc_nulls_last":
|
case "desc_nulls_last":
|
||||||
ob.Order = OrderDescNullsLast
|
ob.Order = OrderDescNullsLast
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("valid values include asc, desc, asc_nulls_first and desc_nulls_first")
|
return fmt.Errorf("valid values include asc, desc, asc_nulls_first and desc_nulls_first"), false
|
||||||
}
|
}
|
||||||
|
|
||||||
setOrderByColName(ob, node)
|
setOrderByColName(ob, node)
|
||||||
sel.OrderBy = append(sel.OrderBy, ob)
|
sel.OrderBy = append(sel.OrderBy, ob)
|
||||||
|
|
||||||
if !com.ka {
|
|
||||||
nodePool.Put(node)
|
nodePool.Put(node)
|
||||||
}
|
}
|
||||||
}
|
return nil, false
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (com *Compiler) compileArgDistinctOn(sel *Select, arg *Arg) error {
|
func (com *Compiler) compileArgDistinctOn(sel *Select, arg *Arg) (error, bool) {
|
||||||
node := arg.Val
|
node := arg.Val
|
||||||
|
|
||||||
if _, ok := com.bl[node.Name]; ok {
|
if _, ok := com.bl[node.Name]; ok {
|
||||||
return nil
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if node.Type != nodeList && node.Type != nodeStr {
|
if node.Type != NodeList && node.Type != NodeStr {
|
||||||
return fmt.Errorf("expecting a list of strings or just a string")
|
return fmt.Errorf("expecting a list of strings or just a string"), false
|
||||||
}
|
}
|
||||||
|
|
||||||
if node.Type == nodeStr {
|
if node.Type == NodeStr {
|
||||||
sel.DistinctOn = append(sel.DistinctOn, node.Val)
|
sel.DistinctOn = append(sel.DistinctOn, node.Val)
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range node.Children {
|
for i := range node.Children {
|
||||||
sel.DistinctOn = append(sel.DistinctOn, node.Children[i].Val)
|
sel.DistinctOn = append(sel.DistinctOn, node.Children[i].Val)
|
||||||
if !com.ka {
|
|
||||||
nodePool.Put(node.Children[i])
|
nodePool.Put(node.Children[i])
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
func (com *Compiler) compileArgLimit(sel *Select, arg *Arg) (error, bool) {
|
||||||
}
|
|
||||||
|
|
||||||
func (com *Compiler) compileArgLimit(sel *Select, arg *Arg) error {
|
|
||||||
node := arg.Val
|
node := arg.Val
|
||||||
|
|
||||||
if node.Type != nodeInt {
|
if node.Type != NodeInt {
|
||||||
return fmt.Errorf("expecting an integer")
|
return fmt.Errorf("expecting an integer"), false
|
||||||
}
|
}
|
||||||
|
|
||||||
sel.Paging.Limit = node.Val
|
sel.Paging.Limit = node.Val
|
||||||
|
|
||||||
return nil
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (com *Compiler) compileArgOffset(sel *Select, arg *Arg) error {
|
func (com *Compiler) compileArgOffset(sel *Select, arg *Arg) (error, bool) {
|
||||||
node := arg.Val
|
node := arg.Val
|
||||||
|
|
||||||
if node.Type != nodeInt {
|
if node.Type != NodeInt {
|
||||||
return fmt.Errorf("expecting an integer")
|
return fmt.Errorf("expecting an integer"), false
|
||||||
}
|
}
|
||||||
|
|
||||||
sel.Paging.Offset = node.Val
|
sel.Paging.Offset = node.Val
|
||||||
return nil
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
var zeroTrv = &trval{}
|
var zeroTrv = &trval{}
|
||||||
|
@ -879,17 +878,17 @@ func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {
|
||||||
|
|
||||||
if ex.Op != OpAnd && ex.Op != OpOr && ex.Op != OpNot {
|
if ex.Op != OpAnd && ex.Op != OpOr && ex.Op != OpNot {
|
||||||
switch node.Type {
|
switch node.Type {
|
||||||
case nodeStr:
|
case NodeStr:
|
||||||
ex.Type = ValStr
|
ex.Type = ValStr
|
||||||
case nodeInt:
|
case NodeInt:
|
||||||
ex.Type = ValInt
|
ex.Type = ValInt
|
||||||
case nodeBool:
|
case NodeBool:
|
||||||
ex.Type = ValBool
|
ex.Type = ValBool
|
||||||
case nodeFloat:
|
case NodeFloat:
|
||||||
ex.Type = ValFloat
|
ex.Type = ValFloat
|
||||||
case nodeList:
|
case NodeList:
|
||||||
ex.Type = ValList
|
ex.Type = ValList
|
||||||
case nodeVar:
|
case NodeVar:
|
||||||
ex.Type = ValVar
|
ex.Type = ValVar
|
||||||
default:
|
default:
|
||||||
return nil, fmt.Errorf("[Where] valid values include string, int, float, boolean and list: %s", node.Type)
|
return nil, fmt.Errorf("[Where] valid values include string, int, float, boolean and list: %s", node.Type)
|
||||||
|
@ -903,13 +902,13 @@ func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {
|
||||||
func setListVal(ex *Exp, node *Node) {
|
func setListVal(ex *Exp, node *Node) {
|
||||||
if len(node.Children) != 0 {
|
if len(node.Children) != 0 {
|
||||||
switch node.Children[0].Type {
|
switch node.Children[0].Type {
|
||||||
case nodeStr:
|
case NodeStr:
|
||||||
ex.ListType = ValStr
|
ex.ListType = ValStr
|
||||||
case nodeInt:
|
case NodeInt:
|
||||||
ex.ListType = ValInt
|
ex.ListType = ValInt
|
||||||
case nodeBool:
|
case NodeBool:
|
||||||
ex.ListType = ValBool
|
ex.ListType = ValBool
|
||||||
case nodeFloat:
|
case NodeFloat:
|
||||||
ex.ListType = ValFloat
|
ex.ListType = ValFloat
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -922,7 +921,7 @@ func setWhereColName(ex *Exp, node *Node) {
|
||||||
var list []string
|
var list []string
|
||||||
|
|
||||||
for n := node.Parent; n != nil; n = n.Parent {
|
for n := node.Parent; n != nil; n = n.Parent {
|
||||||
if n.Type != nodeObj {
|
if n.Type != NodeObj {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if len(n.Name) != 0 {
|
if len(n.Name) != 0 {
|
||||||
|
|
|
@ -110,7 +110,7 @@ func graphQLFunc(query string, data interface{}, opt map[string]string) map[stri
|
||||||
|
|
||||||
var root []byte
|
var root []byte
|
||||||
|
|
||||||
if err = tx.QueryRow(c, finalSQL).Scan(&root); err != nil {
|
if err = tx.QueryRow(context.Background(), finalSQL).Scan(&root); err != nil {
|
||||||
errlog.Fatal().Err(err).Msg("sql query failed")
|
errlog.Fatal().Err(err).Msg("sql query failed")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
38
serv/core.go
38
serv/core.go
|
@ -81,7 +81,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) {
|
||||||
useTx := useRoleQuery || conf.DB.SetUserID
|
useTx := useRoleQuery || conf.DB.SetUserID
|
||||||
|
|
||||||
if useTx {
|
if useTx {
|
||||||
if tx, err = db.Begin(c); err != nil {
|
if tx, err = db.Begin(context.Background()); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
defer tx.Rollback(c) //nolint: errcheck
|
defer tx.Rollback(c) //nolint: errcheck
|
||||||
|
@ -122,9 +122,9 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if useTx {
|
if useTx {
|
||||||
row = tx.QueryRow(c, ps.sd.SQL, vars...)
|
row = tx.QueryRow(context.Background(), ps.sd.SQL, vars...)
|
||||||
} else {
|
} else {
|
||||||
row = db.QueryRow(c, ps.sd.SQL, vars...)
|
row = db.QueryRow(context.Background(), ps.sd.SQL, vars...)
|
||||||
}
|
}
|
||||||
|
|
||||||
if mutation || anonQuery {
|
if mutation || anonQuery {
|
||||||
|
@ -146,7 +146,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) {
|
||||||
c.req.role = role
|
c.req.role = role
|
||||||
|
|
||||||
if useTx {
|
if useTx {
|
||||||
if err := tx.Commit(c); err != nil {
|
if err := tx.Commit(context.Background()); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -166,10 +166,10 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
|
||||||
useTx := useRoleQuery || conf.DB.SetUserID
|
useTx := useRoleQuery || conf.DB.SetUserID
|
||||||
|
|
||||||
if useTx {
|
if useTx {
|
||||||
if tx, err = db.Begin(c); err != nil {
|
if tx, err = db.Begin(context.Background()); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
defer tx.Rollback(c) //nolint: errcheck
|
defer tx.Rollback(context.Background()) //nolint: errcheck
|
||||||
}
|
}
|
||||||
|
|
||||||
if conf.DB.SetUserID {
|
if conf.DB.SetUserID {
|
||||||
|
@ -215,9 +215,9 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
|
||||||
defaultRole := c.req.role
|
defaultRole := c.req.role
|
||||||
|
|
||||||
if useTx {
|
if useTx {
|
||||||
row = tx.QueryRow(c, finalSQL)
|
row = tx.QueryRow(context.Background(), finalSQL)
|
||||||
} else {
|
} else {
|
||||||
row = db.QueryRow(c, finalSQL)
|
row = db.QueryRow(context.Background(), finalSQL)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(stmts) == 1 {
|
if len(stmts) == 1 {
|
||||||
|
@ -237,7 +237,7 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if useTx {
|
if useTx {
|
||||||
if err := tx.Commit(c); err != nil {
|
if err := tx.Commit(context.Background()); err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -263,7 +263,7 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
|
||||||
|
|
||||||
func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
|
func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
|
||||||
var role string
|
var role string
|
||||||
row := tx.QueryRow(c, "_sg_get_role", c.req.role, 1)
|
row := tx.QueryRow(context.Background(), "_sg_get_role", c.req.role, 1)
|
||||||
|
|
||||||
if err := row.Scan(&role); err != nil {
|
if err := row.Scan(&role); err != nil {
|
||||||
return "", err
|
return "", err
|
||||||
|
@ -320,6 +320,15 @@ func (c *coreContext) addTrace(sel []qcode.Select, id int32, st time.Time) {
|
||||||
append(c.res.Extensions.Tracing.Execution.Resolvers, tr)
|
append(c.res.Extensions.Tracing.Execution.Resolvers, tr)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setLocalUserID(c context.Context, tx pgx.Tx) error {
|
||||||
|
var err error
|
||||||
|
if v := c.Value(userIDKey); v != nil {
|
||||||
|
_, err = tx.Exec(context.Background(), fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) (
|
func parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) (
|
||||||
[][]byte,
|
[][]byte,
|
||||||
map[uint64]*qcode.Select) {
|
map[uint64]*qcode.Select) {
|
||||||
|
@ -363,15 +372,6 @@ func parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) (
|
||||||
return fm, sm
|
return fm, sm
|
||||||
}
|
}
|
||||||
|
|
||||||
func setLocalUserID(c context.Context, tx pgx.Tx) error {
|
|
||||||
var err error
|
|
||||||
if v := c.Value(userIDKey); v != nil {
|
|
||||||
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
|
|
||||||
}
|
|
||||||
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
func isSkipped(n uint32, pos uint32) bool {
|
func isSkipped(n uint32, pos uint32) bool {
|
||||||
return ((n & (1 << pos)) != 0)
|
return ((n & (1 << pos)) != 0)
|
||||||
}
|
}
|
||||||
|
|
|
@ -97,6 +97,10 @@ func buildMultiStmt(gql, vars []byte) ([]stmt, error) {
|
||||||
for i := 0; i < len(conf.Roles); i++ {
|
for i := 0; i < len(conf.Roles); i++ {
|
||||||
role := &conf.Roles[i]
|
role := &conf.Roles[i]
|
||||||
|
|
||||||
|
if role.Name == "anon" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
qc, err := qcompile.Compile(gql, role.Name)
|
qc, err := qcompile.Compile(gql, role.Name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
|
@ -127,8 +131,6 @@ func buildMultiStmt(gql, vars []byte) ([]stmt, error) {
|
||||||
//nolint: errcheck
|
//nolint: errcheck
|
||||||
func renderUserQuery(
|
func renderUserQuery(
|
||||||
stmts []stmt, vars map[string]json.RawMessage) (string, error) {
|
stmts []stmt, vars map[string]json.RawMessage) (string, error) {
|
||||||
|
|
||||||
var err error
|
|
||||||
w := &bytes.Buffer{}
|
w := &bytes.Buffer{}
|
||||||
|
|
||||||
io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
|
io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
|
||||||
|
@ -141,11 +143,7 @@ func renderUserQuery(
|
||||||
io.WriteString(w, `WHEN '`)
|
io.WriteString(w, `WHEN '`)
|
||||||
io.WriteString(w, s.role.Name)
|
io.WriteString(w, s.role.Name)
|
||||||
io.WriteString(w, `' THEN (`)
|
io.WriteString(w, `' THEN (`)
|
||||||
|
io.WriteString(w, s.sql)
|
||||||
s.skipped, err = pcompile.Compile(s.qc, w, psql.Variables(vars))
|
|
||||||
if err != nil {
|
|
||||||
return "", err
|
|
||||||
}
|
|
||||||
io.WriteString(w, `) `)
|
io.WriteString(w, `) `)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -23,21 +23,20 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
func initPreparedList() {
|
func initPreparedList() {
|
||||||
c := context.Background()
|
|
||||||
_preparedList = make(map[string]*preparedItem)
|
_preparedList = make(map[string]*preparedItem)
|
||||||
|
|
||||||
tx, err := db.Begin(c)
|
tx, err := db.Begin(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errlog.Fatal().Err(err).Send()
|
errlog.Fatal().Err(err).Send()
|
||||||
}
|
}
|
||||||
defer tx.Rollback(c) //nolint: errcheck
|
defer tx.Rollback(context.Background()) //nolint: errcheck
|
||||||
|
|
||||||
err = prepareRoleStmt(c, tx)
|
err = prepareRoleStmt(tx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errlog.Fatal().Err(err).Msg("failed to prepare get role statement")
|
errlog.Fatal().Err(err).Msg("failed to prepare get role statement")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Commit(c); err != nil {
|
if err := tx.Commit(context.Background()); err != nil {
|
||||||
errlog.Fatal().Err(err).Send()
|
errlog.Fatal().Err(err).Send()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -48,7 +47,7 @@ func initPreparedList() {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
err := prepareStmt(c, v.gql, v.vars)
|
err := prepareStmt(v.gql, v.vars)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
success++
|
success++
|
||||||
continue
|
continue
|
||||||
|
@ -66,15 +65,15 @@ func initPreparedList() {
|
||||||
success, len(_allowList.list))
|
success, len(_allowList.list))
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareStmt(c context.Context, gql string, vars []byte) error {
|
func prepareStmt(gql string, vars []byte) error {
|
||||||
qt := qcode.GetQType(gql)
|
qt := qcode.GetQType(gql)
|
||||||
q := []byte(gql)
|
q := []byte(gql)
|
||||||
|
|
||||||
tx, err := db.Begin(c)
|
tx, err := db.Begin(context.Background())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer tx.Rollback(c) //nolint: errcheck
|
defer tx.Rollback(context.Background()) //nolint: errcheck
|
||||||
|
|
||||||
switch qt {
|
switch qt {
|
||||||
case qcode.QTQuery:
|
case qcode.QTQuery:
|
||||||
|
@ -83,7 +82,7 @@ func prepareStmt(c context.Context, gql string, vars []byte) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = prepare(c, tx, &stmts1[0], gqlHash(gql, vars, "user"))
|
err = prepare(tx, &stmts1[0], gqlHash(gql, vars, "user"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -93,7 +92,7 @@ func prepareStmt(c context.Context, gql string, vars []byte) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = prepare(c, tx, &stmts2[0], gqlHash(gql, vars, "anon"))
|
err = prepare(tx, &stmts2[0], gqlHash(gql, vars, "anon"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -105,7 +104,7 @@ func prepareStmt(c context.Context, gql string, vars []byte) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = prepare(c, tx, &stmts[0], gqlHash(gql, vars, role.Name))
|
err = prepare(tx, &stmts[0], gqlHash(gql, vars, role.Name))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -118,17 +117,17 @@ func prepareStmt(c context.Context, gql string, vars []byte) error {
|
||||||
logger.Debug().Msgf("Building prepared statement:\n %s\n%s", vars, gql)
|
logger.Debug().Msgf("Building prepared statement:\n %s\n%s", vars, gql)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Commit(c); err != nil {
|
if err := tx.Commit(context.Background()); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepare(c context.Context, tx pgx.Tx, st *stmt, key string) error {
|
func prepare(tx pgx.Tx, st *stmt, key string) error {
|
||||||
finalSQL, am := processTemplate(st.sql)
|
finalSQL, am := processTemplate(st.sql)
|
||||||
|
|
||||||
sd, err := tx.Prepare(c, "", finalSQL)
|
sd, err := tx.Prepare(context.Background(), "", finalSQL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -142,7 +141,7 @@ func prepare(c context.Context, tx pgx.Tx, st *stmt, key string) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
// nolint: errcheck
|
// nolint: errcheck
|
||||||
func prepareRoleStmt(c context.Context, tx pgx.Tx) error {
|
func prepareRoleStmt(tx pgx.Tx) error {
|
||||||
if len(conf.RolesQuery) == 0 {
|
if len(conf.RolesQuery) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -167,7 +166,7 @@ func prepareRoleStmt(c context.Context, tx pgx.Tx) error {
|
||||||
|
|
||||||
roleSQL, _ := processTemplate(w.String())
|
roleSQL, _ := processTemplate(w.String())
|
||||||
|
|
||||||
_, err := tx.Prepare(c, "_sg_get_role", roleSQL)
|
_, err := tx.Prepare(context.Background(), "_sg_get_role", roleSQL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,7 +24,6 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
|
||||||
|
|
||||||
conf := qcode.Config{
|
conf := qcode.Config{
|
||||||
Blocklist: c.DB.Blocklist,
|
Blocklist: c.DB.Blocklist,
|
||||||
KeepArgs: false,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
qc, err := qcode.NewCompiler(conf)
|
qc, err := qcode.NewCompiler(conf)
|
||||||
|
|
Loading…
Reference in New Issue