fix: bug with parsing variables in roles_query
This commit is contained in:
parent
82cc712a93
commit
bd157290f6
|
@ -12,7 +12,8 @@ import (
|
||||||
// to a prepared statement.
|
// to a prepared statement.
|
||||||
|
|
||||||
func (c *scontext) argList(md psql.Metadata) ([]interface{}, error) {
|
func (c *scontext) argList(md psql.Metadata) ([]interface{}, error) {
|
||||||
vars := make([]interface{}, len(md.Params))
|
params := md.Params()
|
||||||
|
vars := make([]interface{}, len(params))
|
||||||
|
|
||||||
var fields map[string]json.RawMessage
|
var fields map[string]json.RawMessage
|
||||||
var err error
|
var err error
|
||||||
|
@ -25,7 +26,7 @@ func (c *scontext) argList(md psql.Metadata) ([]interface{}, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for i, p := range md.Params {
|
for i, p := range params {
|
||||||
switch p.Name {
|
switch p.Name {
|
||||||
case "user_id":
|
case "user_id":
|
||||||
if v := c.Value(UserIDKey); v != nil {
|
if v := c.Value(UserIDKey); v != nil {
|
||||||
|
|
|
@ -88,6 +88,7 @@ func (sg *SuperGraph) buildMultiStmt(query, vars []byte) ([]stmt, error) {
|
||||||
|
|
||||||
stmts := make([]stmt, 0, len(sg.conf.Roles))
|
stmts := make([]stmt, 0, len(sg.conf.Roles))
|
||||||
w := &bytes.Buffer{}
|
w := &bytes.Buffer{}
|
||||||
|
md := psql.Metadata{}
|
||||||
|
|
||||||
for i := 0; i < len(sg.conf.Roles); i++ {
|
for i := 0; i < len(sg.conf.Roles); i++ {
|
||||||
role := &sg.conf.Roles[i]
|
role := &sg.conf.Roles[i]
|
||||||
|
@ -105,16 +106,18 @@ func (sg *SuperGraph) buildMultiStmt(query, vars []byte) ([]stmt, error) {
|
||||||
stmts = append(stmts, stmt{role: role, qc: qc})
|
stmts = append(stmts, stmt{role: role, qc: qc})
|
||||||
s := &stmts[len(stmts)-1]
|
s := &stmts[len(stmts)-1]
|
||||||
|
|
||||||
s.md, err = sg.pc.Compile(w, qc, psql.Variables(vm))
|
md, err = sg.pc.CompileWithMetadata(w, qc, psql.Variables(vm), md)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
s.sql = w.String()
|
s.sql = w.String()
|
||||||
|
s.md = md
|
||||||
|
|
||||||
w.Reset()
|
w.Reset()
|
||||||
}
|
}
|
||||||
|
|
||||||
sql, err := sg.renderUserQuery(stmts)
|
sql, err := sg.renderUserQuery(md, stmts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -124,7 +127,7 @@ func (sg *SuperGraph) buildMultiStmt(query, vars []byte) ([]stmt, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
//nolint: errcheck
|
//nolint: errcheck
|
||||||
func (sg *SuperGraph) renderUserQuery(stmts []stmt) (string, error) {
|
func (sg *SuperGraph) renderUserQuery(md psql.Metadata, stmts []stmt) (string, error) {
|
||||||
w := &bytes.Buffer{}
|
w := &bytes.Buffer{}
|
||||||
|
|
||||||
io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
|
io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
|
||||||
|
@ -142,7 +145,7 @@ func (sg *SuperGraph) renderUserQuery(stmts []stmt) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
io.WriteString(w, `END) FROM (SELECT (CASE WHEN EXISTS (`)
|
io.WriteString(w, `END) FROM (SELECT (CASE WHEN EXISTS (`)
|
||||||
io.WriteString(w, sg.conf.RolesQuery)
|
md.RenderVar(w, sg.conf.RolesQuery)
|
||||||
io.WriteString(w, `) THEN `)
|
io.WriteString(w, `) THEN `)
|
||||||
|
|
||||||
io.WriteString(w, `(SELECT (CASE`)
|
io.WriteString(w, `(SELECT (CASE`)
|
||||||
|
@ -158,7 +161,7 @@ func (sg *SuperGraph) renderUserQuery(stmts []stmt) (string, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
io.WriteString(w, ` ELSE 'user' END) FROM (`)
|
io.WriteString(w, ` ELSE 'user' END) FROM (`)
|
||||||
io.WriteString(w, sg.conf.RolesQuery)
|
md.RenderVar(w, sg.conf.RolesQuery)
|
||||||
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `)
|
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `)
|
||||||
io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `)
|
io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `)
|
||||||
|
|
||||||
|
|
|
@ -125,7 +125,7 @@ func (c *scontext) execQuery() ([]byte, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(data) == 0 || st.md.Skipped == 0 {
|
if len(data) == 0 || st.md.Skipped() == 0 {
|
||||||
return data, nil
|
return data, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
19
core/init.go
19
core/init.go
|
@ -75,13 +75,22 @@ func (sg *SuperGraph) initConfig() error {
|
||||||
|
|
||||||
if c.RolesQuery == "" {
|
if c.RolesQuery == "" {
|
||||||
sg.log.Printf("INF roles_query not defined: attribute based access control disabled")
|
sg.log.Printf("INF roles_query not defined: attribute based access control disabled")
|
||||||
|
} else {
|
||||||
|
n := 0
|
||||||
|
for k, v := range sg.roles {
|
||||||
|
if k == "user" || k == "anon" {
|
||||||
|
n++
|
||||||
|
} else if v.Match != "" {
|
||||||
|
n++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
sg.abacEnabled = (n > 2)
|
||||||
|
|
||||||
|
if !sg.abacEnabled {
|
||||||
|
sg.log.Printf("WRN attribute based access control disabled: no custom roles found (with 'match' defined)")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
_, userExists := sg.roles["user"]
|
|
||||||
_, sg.anonExists = sg.roles["anon"]
|
|
||||||
|
|
||||||
sg.abacEnabled = userExists && c.RolesQuery != ""
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -1,4 +1,3 @@
|
||||||
//nolint:errcheck
|
|
||||||
package psql
|
package psql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
@ -112,15 +111,15 @@ func (c *compilerContext) renderColumnSearchRank(sel *qcode.Select, ti *DBTableI
|
||||||
c.renderComma(columnsRendered)
|
c.renderComma(columnsRendered)
|
||||||
//fmt.Fprintf(w, `ts_rank("%s"."%s", websearch_to_tsquery('%s')) AS %s`,
|
//fmt.Fprintf(w, `ts_rank("%s"."%s", websearch_to_tsquery('%s')) AS %s`,
|
||||||
//c.sel.Name, cn, arg.Val, col.Name)
|
//c.sel.Name, cn, arg.Val, col.Name)
|
||||||
io.WriteString(c.w, `ts_rank(`)
|
_, _ = io.WriteString(c.w, `ts_rank(`)
|
||||||
colWithTable(c.w, ti.Name, cn)
|
colWithTable(c.w, ti.Name, cn)
|
||||||
if c.schema.ver >= 110000 {
|
if c.schema.ver >= 110000 {
|
||||||
io.WriteString(c.w, `, websearch_to_tsquery(`)
|
_, _ = io.WriteString(c.w, `, websearch_to_tsquery(`)
|
||||||
} else {
|
} else {
|
||||||
io.WriteString(c.w, `, to_tsquery(`)
|
_, _ = io.WriteString(c.w, `, to_tsquery(`)
|
||||||
}
|
}
|
||||||
c.renderValueExp(Param{Name: arg.Val, Type: "string"})
|
c.md.renderValueExp(c.w, Param{Name: arg.Val, Type: "string"})
|
||||||
io.WriteString(c.w, `))`)
|
_, _ = io.WriteString(c.w, `))`)
|
||||||
alias(c.w, col.Name)
|
alias(c.w, col.Name)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -137,15 +136,15 @@ func (c *compilerContext) renderColumnSearchHeadline(sel *qcode.Select, ti *DBTa
|
||||||
c.renderComma(columnsRendered)
|
c.renderComma(columnsRendered)
|
||||||
//fmt.Fprintf(w, `ts_headline("%s"."%s", websearch_to_tsquery('%s')) AS %s`,
|
//fmt.Fprintf(w, `ts_headline("%s"."%s", websearch_to_tsquery('%s')) AS %s`,
|
||||||
//c.sel.Name, cn, arg.Val, col.Name)
|
//c.sel.Name, cn, arg.Val, col.Name)
|
||||||
io.WriteString(c.w, `ts_headline(`)
|
_, _ = io.WriteString(c.w, `ts_headline(`)
|
||||||
colWithTable(c.w, ti.Name, cn)
|
colWithTable(c.w, ti.Name, cn)
|
||||||
if c.schema.ver >= 110000 {
|
if c.schema.ver >= 110000 {
|
||||||
io.WriteString(c.w, `, websearch_to_tsquery(`)
|
_, _ = io.WriteString(c.w, `, websearch_to_tsquery(`)
|
||||||
} else {
|
} else {
|
||||||
io.WriteString(c.w, `, to_tsquery(`)
|
_, _ = io.WriteString(c.w, `, to_tsquery(`)
|
||||||
}
|
}
|
||||||
c.renderValueExp(Param{Name: arg.Val, Type: "string"})
|
c.md.renderValueExp(c.w, Param{Name: arg.Val, Type: "string"})
|
||||||
io.WriteString(c.w, `))`)
|
_, _ = io.WriteString(c.w, `))`)
|
||||||
alias(c.w, col.Name)
|
alias(c.w, col.Name)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -157,9 +156,9 @@ func (c *compilerContext) renderColumnTypename(sel *qcode.Select, ti *DBTableInf
|
||||||
}
|
}
|
||||||
|
|
||||||
c.renderComma(columnsRendered)
|
c.renderComma(columnsRendered)
|
||||||
io.WriteString(c.w, `(`)
|
_, _ = io.WriteString(c.w, `(`)
|
||||||
squoted(c.w, ti.Name)
|
squoted(c.w, ti.Name)
|
||||||
io.WriteString(c.w, ` :: text)`)
|
_, _ = io.WriteString(c.w, ` :: text)`)
|
||||||
alias(c.w, col.Name)
|
alias(c.w, col.Name)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -169,9 +168,9 @@ func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInf
|
||||||
pl := funcPrefixLen(c.schema.fm, col.Name)
|
pl := funcPrefixLen(c.schema.fm, col.Name)
|
||||||
// if pl == 0 {
|
// if pl == 0 {
|
||||||
// //fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name)
|
// //fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name)
|
||||||
// io.WriteString(c.w, `'`)
|
// _, _ = io.WriteString(c.w, `'`)
|
||||||
// io.WriteString(c.w, col.Name)
|
// _, _ = io.WriteString(c.w, col.Name)
|
||||||
// io.WriteString(c.w, ` not defined'`)
|
// _, _ = io.WriteString(c.w, ` not defined'`)
|
||||||
// alias(c.w, col.Name)
|
// alias(c.w, col.Name)
|
||||||
// }
|
// }
|
||||||
|
|
||||||
|
@ -190,10 +189,10 @@ func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInf
|
||||||
c.renderComma(columnsRendered)
|
c.renderComma(columnsRendered)
|
||||||
|
|
||||||
//fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, c.sel.Name, cn, col.Name)
|
//fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, c.sel.Name, cn, col.Name)
|
||||||
io.WriteString(c.w, fn)
|
_, _ = io.WriteString(c.w, fn)
|
||||||
io.WriteString(c.w, `(`)
|
_, _ = io.WriteString(c.w, `(`)
|
||||||
colWithTable(c.w, ti.Name, cn)
|
colWithTable(c.w, ti.Name, cn)
|
||||||
io.WriteString(c.w, `)`)
|
_, _ = io.WriteString(c.w, `)`)
|
||||||
alias(c.w, col.Name)
|
alias(c.w, col.Name)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -201,7 +200,7 @@ func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInf
|
||||||
|
|
||||||
func (c *compilerContext) renderComma(columnsRendered int) {
|
func (c *compilerContext) renderComma(columnsRendered int) {
|
||||||
if columnsRendered != 0 {
|
if columnsRendered != 0 {
|
||||||
io.WriteString(c.w, `, `)
|
_, _ = io.WriteString(c.w, `, `)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -25,7 +25,7 @@ func (c *compilerContext) renderInsert(
|
||||||
if insert[0] == '[' {
|
if insert[0] == '[' {
|
||||||
io.WriteString(c.w, `json_array_elements(`)
|
io.WriteString(c.w, `json_array_elements(`)
|
||||||
}
|
}
|
||||||
c.renderValueExp(Param{Name: qc.ActionVar, Type: "json"})
|
c.md.renderValueExp(c.w, Param{Name: qc.ActionVar, Type: "json"})
|
||||||
io.WriteString(c.w, ` :: json`)
|
io.WriteString(c.w, ` :: json`)
|
||||||
if insert[0] == '[' {
|
if insert[0] == '[' {
|
||||||
io.WriteString(c.w, `)`)
|
io.WriteString(c.w, `)`)
|
||||||
|
|
|
@ -0,0 +1,61 @@
|
||||||
|
package psql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (md *Metadata) RenderVar(w io.Writer, vv string) {
|
||||||
|
f, s := -1, 0
|
||||||
|
|
||||||
|
for i := range vv {
|
||||||
|
v := vv[i]
|
||||||
|
switch {
|
||||||
|
case (i > 0 && vv[i-1] != '\\' && v == '$') || v == '$':
|
||||||
|
if (i - s) > 0 {
|
||||||
|
_, _ = io.WriteString(w, vv[s:i])
|
||||||
|
}
|
||||||
|
f = i
|
||||||
|
|
||||||
|
case (v < 'a' && v > 'z') &&
|
||||||
|
(v < 'A' && v > 'Z') &&
|
||||||
|
(v < '0' && v > '9') &&
|
||||||
|
v != '_' &&
|
||||||
|
f != -1 &&
|
||||||
|
(i-f) > 1:
|
||||||
|
md.renderValueExp(w, Param{Name: vv[f+1 : i]})
|
||||||
|
s = i
|
||||||
|
f = -1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if f != -1 && (len(vv)-f) > 1 {
|
||||||
|
md.renderValueExp(w, Param{Name: vv[f+1:]})
|
||||||
|
} else {
|
||||||
|
_, _ = io.WriteString(w, vv[s:])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (md *Metadata) renderValueExp(w io.Writer, p Param) {
|
||||||
|
_, _ = io.WriteString(w, `$`)
|
||||||
|
if v, ok := md.pindex[p.Name]; ok {
|
||||||
|
int32String(w, int32(v))
|
||||||
|
|
||||||
|
} else {
|
||||||
|
md.params = append(md.params, p)
|
||||||
|
n := len(md.params)
|
||||||
|
|
||||||
|
if md.pindex == nil {
|
||||||
|
md.pindex = make(map[string]int)
|
||||||
|
}
|
||||||
|
md.pindex[p.Name] = n
|
||||||
|
int32String(w, int32(n))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (md Metadata) Skipped() uint32 {
|
||||||
|
return md.skipped
|
||||||
|
}
|
||||||
|
|
||||||
|
func (md Metadata) Params() []Param {
|
||||||
|
return md.params
|
||||||
|
}
|
|
@ -432,11 +432,11 @@ func (c *compilerContext) renderInsertUpdateColumns(
|
||||||
val := root.PresetMap[cn]
|
val := root.PresetMap[cn]
|
||||||
switch {
|
switch {
|
||||||
case ok && len(val) > 1 && val[0] == '$':
|
case ok && len(val) > 1 && val[0] == '$':
|
||||||
c.renderValueExp(Param{Name: val[1:], Type: col.Type})
|
c.md.renderValueExp(c.w, Param{Name: val[1:], Type: col.Type})
|
||||||
|
|
||||||
case ok && strings.HasPrefix(val, "sql:"):
|
case ok && strings.HasPrefix(val, "sql:"):
|
||||||
io.WriteString(c.w, `(`)
|
io.WriteString(c.w, `(`)
|
||||||
c.renderVar(val[4:], c.renderValueExp)
|
c.md.RenderVar(c.w, val[4:])
|
||||||
io.WriteString(c.w, `)`)
|
io.WriteString(c.w, `)`)
|
||||||
|
|
||||||
case ok:
|
case ok:
|
||||||
|
|
|
@ -25,8 +25,8 @@ type Param struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Metadata struct {
|
type Metadata struct {
|
||||||
Skipped uint32
|
skipped uint32
|
||||||
Params []Param
|
params []Param
|
||||||
pindex map[string]int
|
pindex map[string]int
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -80,26 +80,30 @@ func (co *Compiler) CompileEx(qc *qcode.QCode, vars Variables) (Metadata, []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
func (co *Compiler) Compile(w io.Writer, qc *qcode.QCode, vars Variables) (Metadata, error) {
|
func (co *Compiler) Compile(w io.Writer, qc *qcode.QCode, vars Variables) (Metadata, error) {
|
||||||
|
return co.CompileWithMetadata(w, qc, vars, Metadata{})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (co *Compiler) CompileWithMetadata(w io.Writer, qc *qcode.QCode, vars Variables, md Metadata) (Metadata, error) {
|
||||||
|
md.skipped = 0
|
||||||
|
|
||||||
if qc == nil {
|
if qc == nil {
|
||||||
return Metadata{}, fmt.Errorf("qcode is nil")
|
return md, fmt.Errorf("qcode is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
switch qc.Type {
|
switch qc.Type {
|
||||||
case qcode.QTQuery:
|
case qcode.QTQuery:
|
||||||
return co.compileQuery(w, qc, vars)
|
return co.compileQueryWithMetadata(w, qc, vars, md)
|
||||||
|
|
||||||
case qcode.QTInsert,
|
case qcode.QTInsert,
|
||||||
qcode.QTUpdate,
|
qcode.QTUpdate,
|
||||||
qcode.QTDelete,
|
qcode.QTDelete,
|
||||||
qcode.QTUpsert:
|
qcode.QTUpsert:
|
||||||
return co.compileMutation(w, qc, vars)
|
return co.compileMutation(w, qc, vars)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return Metadata{}, fmt.Errorf("Unknown operation type %d", qc.Type)
|
||||||
}
|
}
|
||||||
|
|
||||||
return Metadata{}, fmt.Errorf("Unknown operation type %d", qc.Type)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (co *Compiler) compileQuery(w io.Writer, qc *qcode.QCode, vars Variables) (Metadata, error) {
|
|
||||||
return co.compileQueryWithMetadata(w, qc, vars, Metadata{})
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (co *Compiler) compileQueryWithMetadata(
|
func (co *Compiler) compileQueryWithMetadata(
|
||||||
|
@ -176,7 +180,7 @@ func (co *Compiler) compileQueryWithMetadata(
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, cid := range sel.Children {
|
for _, cid := range sel.Children {
|
||||||
if hasBit(c.md.Skipped, uint32(cid)) {
|
if hasBit(c.md.skipped, uint32(cid)) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
child := &c.s[cid]
|
child := &c.s[cid]
|
||||||
|
@ -354,7 +358,7 @@ func (c *compilerContext) initSelect(sel *qcode.Select, ti *DBTableInfo, vars Va
|
||||||
if _, ok := colmap[rel.Left.Col]; !ok {
|
if _, ok := colmap[rel.Left.Col]; !ok {
|
||||||
cols = append(cols, &qcode.Column{Table: ti.Name, Name: rel.Left.Col, FieldName: rel.Right.Col})
|
cols = append(cols, &qcode.Column{Table: ti.Name, Name: rel.Left.Col, FieldName: rel.Right.Col})
|
||||||
colmap[rel.Left.Col] = struct{}{}
|
colmap[rel.Left.Col] = struct{}{}
|
||||||
c.md.Skipped |= (1 << uint(id))
|
c.md.skipped |= (1 << uint(id))
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
@ -622,7 +626,7 @@ func (c *compilerContext) renderJoinColumns(sel *qcode.Select, ti *DBTableInfo,
|
||||||
i := colsRendered
|
i := colsRendered
|
||||||
|
|
||||||
for _, id := range sel.Children {
|
for _, id := range sel.Children {
|
||||||
if hasBit(c.md.Skipped, uint32(id)) {
|
if hasBit(c.md.skipped, uint32(id)) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
childSel := &c.s[id]
|
childSel := &c.s[id]
|
||||||
|
@ -804,7 +808,7 @@ func (c *compilerContext) renderCursorCTE(sel *qcode.Select) error {
|
||||||
quoted(c.w, ob.Col)
|
quoted(c.w, ob.Col)
|
||||||
}
|
}
|
||||||
io.WriteString(c.w, ` FROM string_to_array(`)
|
io.WriteString(c.w, ` FROM string_to_array(`)
|
||||||
c.renderValueExp(Param{Name: "cursor", Type: "json"})
|
c.md.renderValueExp(c.w, Param{Name: "cursor", Type: "json"})
|
||||||
io.WriteString(c.w, `, ',') as a) `)
|
io.WriteString(c.w, `, ',') as a) `)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -1102,7 +1106,7 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, ti *DBTableInfo) error {
|
||||||
} else {
|
} else {
|
||||||
io.WriteString(c.w, `) @@ to_tsquery(`)
|
io.WriteString(c.w, `) @@ to_tsquery(`)
|
||||||
}
|
}
|
||||||
c.renderValueExp(Param{Name: ex.Val, Type: "string"})
|
c.md.renderValueExp(c.w, Param{Name: ex.Val, Type: "string"})
|
||||||
io.WriteString(c.w, `))`)
|
io.WriteString(c.w, `))`)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -1191,7 +1195,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *
|
||||||
switch {
|
switch {
|
||||||
case ok && strings.HasPrefix(val, "sql:"):
|
case ok && strings.HasPrefix(val, "sql:"):
|
||||||
io.WriteString(c.w, `(`)
|
io.WriteString(c.w, `(`)
|
||||||
c.renderVar(val[4:], c.renderValueExp)
|
c.md.RenderVar(c.w, val[4:])
|
||||||
io.WriteString(c.w, `)`)
|
io.WriteString(c.w, `)`)
|
||||||
|
|
||||||
case ok:
|
case ok:
|
||||||
|
@ -1199,7 +1203,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *
|
||||||
|
|
||||||
case ex.Op == qcode.OpIn || ex.Op == qcode.OpNotIn:
|
case ex.Op == qcode.OpIn || ex.Op == qcode.OpNotIn:
|
||||||
io.WriteString(c.w, `(ARRAY(SELECT json_array_elements_text(`)
|
io.WriteString(c.w, `(ARRAY(SELECT json_array_elements_text(`)
|
||||||
c.renderValueExp(Param{Name: ex.Val, Type: col.Type, IsArray: true})
|
c.md.renderValueExp(c.w, Param{Name: ex.Val, Type: col.Type, IsArray: true})
|
||||||
io.WriteString(c.w, `))`)
|
io.WriteString(c.w, `))`)
|
||||||
|
|
||||||
io.WriteString(c.w, ` :: `)
|
io.WriteString(c.w, ` :: `)
|
||||||
|
@ -1208,7 +1212,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *
|
||||||
return
|
return
|
||||||
|
|
||||||
default:
|
default:
|
||||||
c.renderValueExp(Param{Name: ex.Val, Type: col.Type, IsArray: false})
|
c.md.renderValueExp(c.w, Param{Name: ex.Val, Type: col.Type, IsArray: false})
|
||||||
}
|
}
|
||||||
|
|
||||||
case qcode.ValRef:
|
case qcode.ValRef:
|
||||||
|
@ -1222,54 +1226,6 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *
|
||||||
io.WriteString(c.w, col.Type)
|
io.WriteString(c.w, col.Type)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *compilerContext) renderValueExp(p Param) {
|
|
||||||
io.WriteString(c.w, `$`)
|
|
||||||
if v, ok := c.md.pindex[p.Name]; ok {
|
|
||||||
int32String(c.w, int32(v))
|
|
||||||
|
|
||||||
} else {
|
|
||||||
c.md.Params = append(c.md.Params, p)
|
|
||||||
n := len(c.md.Params)
|
|
||||||
|
|
||||||
if c.md.pindex == nil {
|
|
||||||
c.md.pindex = make(map[string]int)
|
|
||||||
}
|
|
||||||
c.md.pindex[p.Name] = n
|
|
||||||
int32String(c.w, int32(n))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *compilerContext) renderVar(vv string, fn func(Param)) {
|
|
||||||
f, s := -1, 0
|
|
||||||
|
|
||||||
for i := range vv {
|
|
||||||
v := vv[i]
|
|
||||||
switch {
|
|
||||||
case (i > 0 && vv[i-1] != '\\' && v == '$') || v == '$':
|
|
||||||
if (i - s) > 0 {
|
|
||||||
io.WriteString(c.w, vv[s:i])
|
|
||||||
}
|
|
||||||
f = i
|
|
||||||
|
|
||||||
case (v < 'a' && v > 'z') &&
|
|
||||||
(v < 'A' && v > 'Z') &&
|
|
||||||
(v < '0' && v > '9') &&
|
|
||||||
v != '_' &&
|
|
||||||
f != -1 &&
|
|
||||||
(i-f) > 1:
|
|
||||||
fn(Param{Name: vv[f+1 : i]})
|
|
||||||
s = i
|
|
||||||
f = -1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if f != -1 && (len(vv)-f) > 1 {
|
|
||||||
fn(Param{Name: vv[f+1:]})
|
|
||||||
} else {
|
|
||||||
io.WriteString(c.w, vv[s:])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func funcPrefixLen(fm map[string]*DBFunction, fn string) int {
|
func funcPrefixLen(fm map[string]*DBFunction, fn string) int {
|
||||||
switch {
|
switch {
|
||||||
case strings.HasPrefix(fn, "avg_"):
|
case strings.HasPrefix(fn, "avg_"):
|
||||||
|
|
|
@ -22,7 +22,7 @@ func (c *compilerContext) renderUpdate(
|
||||||
}
|
}
|
||||||
|
|
||||||
io.WriteString(c.w, `WITH "_sg_input" AS (SELECT `)
|
io.WriteString(c.w, `WITH "_sg_input" AS (SELECT `)
|
||||||
c.renderValueExp(Param{Name: qc.ActionVar, Type: "json"})
|
c.md.renderValueExp(c.w, Param{Name: qc.ActionVar, Type: "json"})
|
||||||
// io.WriteString(c.w, qc.ActionVar)
|
// io.WriteString(c.w, qc.ActionVar)
|
||||||
io.WriteString(c.w, ` :: json AS j)`)
|
io.WriteString(c.w, ` :: json AS j)`)
|
||||||
|
|
||||||
|
|
|
@ -2,8 +2,9 @@ package qcode
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"errors"
|
"errors"
|
||||||
"github.com/chirino/graphql/schema"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
|
"github.com/chirino/graphql/schema"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestCompile1(t *testing.T) {
|
func TestCompile1(t *testing.T) {
|
||||||
|
@ -130,6 +131,22 @@ updateThread {
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFragmentsCompile(t *testing.T) {
|
||||||
|
gql := `
|
||||||
|
fragment userFields on user {
|
||||||
|
name
|
||||||
|
email
|
||||||
|
}
|
||||||
|
|
||||||
|
query { users { ...userFields } }`
|
||||||
|
qcompile, _ := NewCompiler(Config{})
|
||||||
|
_, err := qcompile.Compile([]byte(gql), "anon")
|
||||||
|
|
||||||
|
if err == nil {
|
||||||
|
t.Fatal(errors.New("expecting an error"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
var gql = []byte(`
|
var gql = []byte(`
|
||||||
{products(
|
{products(
|
||||||
# returns only 30 items
|
# returns only 30 items
|
||||||
|
|
|
@ -125,7 +125,7 @@ func (sg *SuperGraph) prepareRoleStmt() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
io.WriteString(w, ` ELSE $2 END) FROM (`)
|
io.WriteString(w, ` ELSE $2 END) FROM (`)
|
||||||
io.WriteString(w, sg.conf.RolesQuery)
|
io.WriteString(w, rq)
|
||||||
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `)
|
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `)
|
||||||
io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler" LIMIT 1; `)
|
io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler" LIMIT 1; `)
|
||||||
|
|
||||||
|
|
|
@ -22,7 +22,7 @@ func (sg *SuperGraph) execRemoteJoin(st *stmt, data []byte, hdr http.Header) ([]
|
||||||
// fetch the field name used within the db response json
|
// fetch the field name used within the db response json
|
||||||
// that are used to mark insertion points and the mapping between
|
// that are used to mark insertion points and the mapping between
|
||||||
// those field names and their select objects
|
// those field names and their select objects
|
||||||
fids, sfmap := sg.parentFieldIds(&h, sel, st.md.Skipped)
|
fids, sfmap := sg.parentFieldIds(&h, sel, st.md.Skipped())
|
||||||
|
|
||||||
// fetch the field values of the marked insertion points
|
// fetch the field values of the marked insertion points
|
||||||
// these values contain the id to be used with fetching remote data
|
// these values contain the id to be used with fetching remote data
|
||||||
|
@ -67,7 +67,7 @@ func (sg *SuperGraph) resolveRemote(
|
||||||
to := toA[:1]
|
to := toA[:1]
|
||||||
|
|
||||||
// use the json key to find the related Select object
|
// use the json key to find the related Select object
|
||||||
h.Write(field.Key)
|
_, _ = h.Write(field.Key)
|
||||||
k1 := h.Sum64()
|
k1 := h.Sum64()
|
||||||
|
|
||||||
s, ok := sfmap[k1]
|
s, ok := sfmap[k1]
|
||||||
|
@ -136,7 +136,7 @@ func (sg *SuperGraph) resolveRemotes(
|
||||||
for i, id := range from {
|
for i, id := range from {
|
||||||
|
|
||||||
// use the json key to find the related Select object
|
// use the json key to find the related Select object
|
||||||
h.Write(id.Key)
|
_, _ = h.Write(id.Key)
|
||||||
k1 := h.Sum64()
|
k1 := h.Sum64()
|
||||||
|
|
||||||
s, ok := sfmap[k1]
|
s, ok := sfmap[k1]
|
||||||
|
@ -230,7 +230,7 @@ func (sg *SuperGraph) parentFieldIds(h *maphash.Hash, sel []qcode.Select, skippe
|
||||||
fm[n] = r.IDField
|
fm[n] = r.IDField
|
||||||
n++
|
n++
|
||||||
|
|
||||||
h.Write(r.IDField)
|
_, _ = h.Write(r.IDField)
|
||||||
sm[h.Sum64()] = s
|
sm[h.Sum64()] = s
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -86,7 +86,7 @@ func (sg *SuperGraph) initRemotes(t Table) error {
|
||||||
sg.rmap[mkkey(&h, r.Name, t.Name)] = rf
|
sg.rmap[mkkey(&h, r.Name, t.Name)] = rf
|
||||||
|
|
||||||
// index resolver obj by IDField
|
// index resolver obj by IDField
|
||||||
h.Write(rf.IDField)
|
_, _ = h.Write(rf.IDField)
|
||||||
sg.rmap[h.Sum64()] = rf
|
sg.rmap[h.Sum64()] = rf
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -12,7 +12,7 @@ func Filter(w *bytes.Buffer, b []byte, keys []string) error {
|
||||||
h := maphash.Hash{}
|
h := maphash.Hash{}
|
||||||
|
|
||||||
for i := range keys {
|
for i := range keys {
|
||||||
h.WriteString(keys[i])
|
_, _ = h.WriteString(keys[i])
|
||||||
kmap[h.Sum64()] = struct{}{}
|
kmap[h.Sum64()] = struct{}{}
|
||||||
h.Reset()
|
h.Reset()
|
||||||
}
|
}
|
||||||
|
@ -134,7 +134,7 @@ func Filter(w *bytes.Buffer, b []byte, keys []string) error {
|
||||||
cb := b[s:(e + 1)]
|
cb := b[s:(e + 1)]
|
||||||
e = 0
|
e = 0
|
||||||
|
|
||||||
h.Write(k)
|
_, _ = h.Write(k)
|
||||||
_, ok := kmap[h.Sum64()]
|
_, ok := kmap[h.Sum64()]
|
||||||
h.Reset()
|
h.Reset()
|
||||||
|
|
||||||
|
|
|
@ -44,7 +44,7 @@ func Get(b []byte, keys [][]byte) []Field {
|
||||||
h := maphash.Hash{}
|
h := maphash.Hash{}
|
||||||
|
|
||||||
for i := range keys {
|
for i := range keys {
|
||||||
h.Write(keys[i])
|
_, _ = h.Write(keys[i])
|
||||||
kmap[h.Sum64()] = struct{}{}
|
kmap[h.Sum64()] = struct{}{}
|
||||||
h.Reset()
|
h.Reset()
|
||||||
}
|
}
|
||||||
|
@ -144,7 +144,7 @@ func Get(b []byte, keys [][]byte) []Field {
|
||||||
}
|
}
|
||||||
|
|
||||||
if e != 0 {
|
if e != 0 {
|
||||||
h.Write(k)
|
_, _ = h.Write(k)
|
||||||
_, ok := kmap[h.Sum64()]
|
_, ok := kmap[h.Sum64()]
|
||||||
h.Reset()
|
h.Reset()
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue