From bd157290f6eca176a208ba5ab2c0b19b75671351 Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Thu, 4 Jun 2020 21:55:52 -0400 Subject: [PATCH] fix: bug with parsing variables in roles_query --- core/args.go | 5 +- core/build.go | 13 +++-- core/core.go | 2 +- core/init.go | 19 +++++-- core/internal/psql/columns.go | 39 +++++++------- core/internal/psql/insert.go | 2 +- core/internal/psql/metadata.go | 61 ++++++++++++++++++++++ core/internal/psql/mutate.go | 4 +- core/internal/psql/query.go | 86 ++++++++----------------------- core/internal/psql/update.go | 2 +- core/internal/qcode/parse_test.go | 19 ++++++- core/prepare.go | 2 +- core/remote.go | 8 +-- core/resolve.go | 2 +- jsn/filter.go | 4 +- jsn/get.go | 4 +- 16 files changed, 159 insertions(+), 113 deletions(-) create mode 100644 core/internal/psql/metadata.go diff --git a/core/args.go b/core/args.go index 9a309f6..ab9b473 100644 --- a/core/args.go +++ b/core/args.go @@ -12,7 +12,8 @@ import ( // to a prepared statement. func (c *scontext) argList(md psql.Metadata) ([]interface{}, error) { - vars := make([]interface{}, len(md.Params)) + params := md.Params() + vars := make([]interface{}, len(params)) var fields map[string]json.RawMessage var err error @@ -25,7 +26,7 @@ func (c *scontext) argList(md psql.Metadata) ([]interface{}, error) { } } - for i, p := range md.Params { + for i, p := range params { switch p.Name { case "user_id": if v := c.Value(UserIDKey); v != nil { diff --git a/core/build.go b/core/build.go index 9089235..a24cda9 100644 --- a/core/build.go +++ b/core/build.go @@ -88,6 +88,7 @@ func (sg *SuperGraph) buildMultiStmt(query, vars []byte) ([]stmt, error) { stmts := make([]stmt, 0, len(sg.conf.Roles)) w := &bytes.Buffer{} + md := psql.Metadata{} for i := 0; i < len(sg.conf.Roles); i++ { role := &sg.conf.Roles[i] @@ -105,16 +106,18 @@ func (sg *SuperGraph) buildMultiStmt(query, vars []byte) ([]stmt, error) { stmts = append(stmts, stmt{role: role, qc: qc}) s := &stmts[len(stmts)-1] - s.md, err = sg.pc.Compile(w, qc, psql.Variables(vm)) + md, err = sg.pc.CompileWithMetadata(w, qc, psql.Variables(vm), md) if err != nil { return nil, err } s.sql = w.String() + s.md = md + w.Reset() } - sql, err := sg.renderUserQuery(stmts) + sql, err := sg.renderUserQuery(md, stmts) if err != nil { return nil, err } @@ -124,7 +127,7 @@ func (sg *SuperGraph) buildMultiStmt(query, vars []byte) ([]stmt, error) { } //nolint: errcheck -func (sg *SuperGraph) renderUserQuery(stmts []stmt) (string, error) { +func (sg *SuperGraph) renderUserQuery(md psql.Metadata, stmts []stmt) (string, error) { w := &bytes.Buffer{} io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `) @@ -142,7 +145,7 @@ func (sg *SuperGraph) renderUserQuery(stmts []stmt) (string, error) { } io.WriteString(w, `END) FROM (SELECT (CASE WHEN EXISTS (`) - io.WriteString(w, sg.conf.RolesQuery) + md.RenderVar(w, sg.conf.RolesQuery) io.WriteString(w, `) THEN `) io.WriteString(w, `(SELECT (CASE`) @@ -158,7 +161,7 @@ func (sg *SuperGraph) renderUserQuery(stmts []stmt) (string, error) { } io.WriteString(w, ` ELSE 'user' END) FROM (`) - io.WriteString(w, sg.conf.RolesQuery) + md.RenderVar(w, sg.conf.RolesQuery) io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `) io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `) diff --git a/core/core.go b/core/core.go index 9c9d7f3..31df8d2 100644 --- a/core/core.go +++ b/core/core.go @@ -125,7 +125,7 @@ func (c *scontext) execQuery() ([]byte, error) { return nil, err } - if len(data) == 0 || st.md.Skipped == 0 { + if len(data) == 0 || st.md.Skipped() == 0 { return data, nil } diff --git a/core/init.go b/core/init.go index fddb1b4..a464ef0 100644 --- a/core/init.go +++ b/core/init.go @@ -75,13 +75,22 @@ func (sg *SuperGraph) initConfig() error { if c.RolesQuery == "" { sg.log.Printf("INF roles_query not defined: attribute based access control disabled") + } else { + n := 0 + for k, v := range sg.roles { + if k == "user" || k == "anon" { + n++ + } else if v.Match != "" { + n++ + } + } + sg.abacEnabled = (n > 2) + + if !sg.abacEnabled { + sg.log.Printf("WRN attribute based access control disabled: no custom roles found (with 'match' defined)") + } } - _, userExists := sg.roles["user"] - _, sg.anonExists = sg.roles["anon"] - - sg.abacEnabled = userExists && c.RolesQuery != "" - return nil } diff --git a/core/internal/psql/columns.go b/core/internal/psql/columns.go index d339dc9..73148fc 100644 --- a/core/internal/psql/columns.go +++ b/core/internal/psql/columns.go @@ -1,4 +1,3 @@ -//nolint:errcheck package psql import ( @@ -112,15 +111,15 @@ func (c *compilerContext) renderColumnSearchRank(sel *qcode.Select, ti *DBTableI c.renderComma(columnsRendered) //fmt.Fprintf(w, `ts_rank("%s"."%s", websearch_to_tsquery('%s')) AS %s`, //c.sel.Name, cn, arg.Val, col.Name) - io.WriteString(c.w, `ts_rank(`) + _, _ = io.WriteString(c.w, `ts_rank(`) colWithTable(c.w, ti.Name, cn) if c.schema.ver >= 110000 { - io.WriteString(c.w, `, websearch_to_tsquery(`) + _, _ = io.WriteString(c.w, `, websearch_to_tsquery(`) } else { - io.WriteString(c.w, `, to_tsquery(`) + _, _ = io.WriteString(c.w, `, to_tsquery(`) } - c.renderValueExp(Param{Name: arg.Val, Type: "string"}) - io.WriteString(c.w, `))`) + c.md.renderValueExp(c.w, Param{Name: arg.Val, Type: "string"}) + _, _ = io.WriteString(c.w, `))`) alias(c.w, col.Name) return nil @@ -137,15 +136,15 @@ func (c *compilerContext) renderColumnSearchHeadline(sel *qcode.Select, ti *DBTa c.renderComma(columnsRendered) //fmt.Fprintf(w, `ts_headline("%s"."%s", websearch_to_tsquery('%s')) AS %s`, //c.sel.Name, cn, arg.Val, col.Name) - io.WriteString(c.w, `ts_headline(`) + _, _ = io.WriteString(c.w, `ts_headline(`) colWithTable(c.w, ti.Name, cn) if c.schema.ver >= 110000 { - io.WriteString(c.w, `, websearch_to_tsquery(`) + _, _ = io.WriteString(c.w, `, websearch_to_tsquery(`) } else { - io.WriteString(c.w, `, to_tsquery(`) + _, _ = io.WriteString(c.w, `, to_tsquery(`) } - c.renderValueExp(Param{Name: arg.Val, Type: "string"}) - io.WriteString(c.w, `))`) + c.md.renderValueExp(c.w, Param{Name: arg.Val, Type: "string"}) + _, _ = io.WriteString(c.w, `))`) alias(c.w, col.Name) return nil @@ -157,9 +156,9 @@ func (c *compilerContext) renderColumnTypename(sel *qcode.Select, ti *DBTableInf } c.renderComma(columnsRendered) - io.WriteString(c.w, `(`) + _, _ = io.WriteString(c.w, `(`) squoted(c.w, ti.Name) - io.WriteString(c.w, ` :: text)`) + _, _ = io.WriteString(c.w, ` :: text)`) alias(c.w, col.Name) return nil @@ -169,9 +168,9 @@ func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInf pl := funcPrefixLen(c.schema.fm, col.Name) // if pl == 0 { // //fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name) - // io.WriteString(c.w, `'`) - // io.WriteString(c.w, col.Name) - // io.WriteString(c.w, ` not defined'`) + // _, _ = io.WriteString(c.w, `'`) + // _, _ = io.WriteString(c.w, col.Name) + // _, _ = io.WriteString(c.w, ` not defined'`) // alias(c.w, col.Name) // } @@ -190,10 +189,10 @@ func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInf c.renderComma(columnsRendered) //fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, c.sel.Name, cn, col.Name) - io.WriteString(c.w, fn) - io.WriteString(c.w, `(`) + _, _ = io.WriteString(c.w, fn) + _, _ = io.WriteString(c.w, `(`) colWithTable(c.w, ti.Name, cn) - io.WriteString(c.w, `)`) + _, _ = io.WriteString(c.w, `)`) alias(c.w, col.Name) return nil @@ -201,7 +200,7 @@ func (c *compilerContext) renderColumnFunction(sel *qcode.Select, ti *DBTableInf func (c *compilerContext) renderComma(columnsRendered int) { if columnsRendered != 0 { - io.WriteString(c.w, `, `) + _, _ = io.WriteString(c.w, `, `) } } diff --git a/core/internal/psql/insert.go b/core/internal/psql/insert.go index 21dc243..48d1a2c 100644 --- a/core/internal/psql/insert.go +++ b/core/internal/psql/insert.go @@ -25,7 +25,7 @@ func (c *compilerContext) renderInsert( if insert[0] == '[' { io.WriteString(c.w, `json_array_elements(`) } - c.renderValueExp(Param{Name: qc.ActionVar, Type: "json"}) + c.md.renderValueExp(c.w, Param{Name: qc.ActionVar, Type: "json"}) io.WriteString(c.w, ` :: json`) if insert[0] == '[' { io.WriteString(c.w, `)`) diff --git a/core/internal/psql/metadata.go b/core/internal/psql/metadata.go new file mode 100644 index 0000000..e8d8bfe --- /dev/null +++ b/core/internal/psql/metadata.go @@ -0,0 +1,61 @@ +package psql + +import ( + "io" +) + +func (md *Metadata) RenderVar(w io.Writer, vv string) { + f, s := -1, 0 + + for i := range vv { + v := vv[i] + switch { + case (i > 0 && vv[i-1] != '\\' && v == '$') || v == '$': + if (i - s) > 0 { + _, _ = io.WriteString(w, vv[s:i]) + } + f = i + + case (v < 'a' && v > 'z') && + (v < 'A' && v > 'Z') && + (v < '0' && v > '9') && + v != '_' && + f != -1 && + (i-f) > 1: + md.renderValueExp(w, Param{Name: vv[f+1 : i]}) + s = i + f = -1 + } + } + + if f != -1 && (len(vv)-f) > 1 { + md.renderValueExp(w, Param{Name: vv[f+1:]}) + } else { + _, _ = io.WriteString(w, vv[s:]) + } +} + +func (md *Metadata) renderValueExp(w io.Writer, p Param) { + _, _ = io.WriteString(w, `$`) + if v, ok := md.pindex[p.Name]; ok { + int32String(w, int32(v)) + + } else { + md.params = append(md.params, p) + n := len(md.params) + + if md.pindex == nil { + md.pindex = make(map[string]int) + } + md.pindex[p.Name] = n + int32String(w, int32(n)) + } +} + +func (md Metadata) Skipped() uint32 { + return md.skipped +} + +func (md Metadata) Params() []Param { + return md.params +} diff --git a/core/internal/psql/mutate.go b/core/internal/psql/mutate.go index f23dae4..e5023b3 100644 --- a/core/internal/psql/mutate.go +++ b/core/internal/psql/mutate.go @@ -432,11 +432,11 @@ func (c *compilerContext) renderInsertUpdateColumns( val := root.PresetMap[cn] switch { case ok && len(val) > 1 && val[0] == '$': - c.renderValueExp(Param{Name: val[1:], Type: col.Type}) + c.md.renderValueExp(c.w, Param{Name: val[1:], Type: col.Type}) case ok && strings.HasPrefix(val, "sql:"): io.WriteString(c.w, `(`) - c.renderVar(val[4:], c.renderValueExp) + c.md.RenderVar(c.w, val[4:]) io.WriteString(c.w, `)`) case ok: diff --git a/core/internal/psql/query.go b/core/internal/psql/query.go index 5824933..77ee4eb 100644 --- a/core/internal/psql/query.go +++ b/core/internal/psql/query.go @@ -25,8 +25,8 @@ type Param struct { } type Metadata struct { - Skipped uint32 - Params []Param + skipped uint32 + params []Param pindex map[string]int } @@ -80,26 +80,30 @@ func (co *Compiler) CompileEx(qc *qcode.QCode, vars Variables) (Metadata, []byte } func (co *Compiler) Compile(w io.Writer, qc *qcode.QCode, vars Variables) (Metadata, error) { + return co.CompileWithMetadata(w, qc, vars, Metadata{}) +} + +func (co *Compiler) CompileWithMetadata(w io.Writer, qc *qcode.QCode, vars Variables, md Metadata) (Metadata, error) { + md.skipped = 0 + if qc == nil { - return Metadata{}, fmt.Errorf("qcode is nil") + return md, fmt.Errorf("qcode is nil") } switch qc.Type { case qcode.QTQuery: - return co.compileQuery(w, qc, vars) + return co.compileQueryWithMetadata(w, qc, vars, md) case qcode.QTInsert, qcode.QTUpdate, qcode.QTDelete, qcode.QTUpsert: return co.compileMutation(w, qc, vars) + + default: + return Metadata{}, fmt.Errorf("Unknown operation type %d", qc.Type) } - return Metadata{}, fmt.Errorf("Unknown operation type %d", qc.Type) -} - -func (co *Compiler) compileQuery(w io.Writer, qc *qcode.QCode, vars Variables) (Metadata, error) { - return co.compileQueryWithMetadata(w, qc, vars, Metadata{}) } func (co *Compiler) compileQueryWithMetadata( @@ -176,7 +180,7 @@ func (co *Compiler) compileQueryWithMetadata( } for _, cid := range sel.Children { - if hasBit(c.md.Skipped, uint32(cid)) { + if hasBit(c.md.skipped, uint32(cid)) { continue } child := &c.s[cid] @@ -354,7 +358,7 @@ func (c *compilerContext) initSelect(sel *qcode.Select, ti *DBTableInfo, vars Va if _, ok := colmap[rel.Left.Col]; !ok { cols = append(cols, &qcode.Column{Table: ti.Name, Name: rel.Left.Col, FieldName: rel.Right.Col}) colmap[rel.Left.Col] = struct{}{} - c.md.Skipped |= (1 << uint(id)) + c.md.skipped |= (1 << uint(id)) } default: @@ -622,7 +626,7 @@ func (c *compilerContext) renderJoinColumns(sel *qcode.Select, ti *DBTableInfo, i := colsRendered for _, id := range sel.Children { - if hasBit(c.md.Skipped, uint32(id)) { + if hasBit(c.md.skipped, uint32(id)) { continue } childSel := &c.s[id] @@ -804,7 +808,7 @@ func (c *compilerContext) renderCursorCTE(sel *qcode.Select) error { quoted(c.w, ob.Col) } io.WriteString(c.w, ` FROM string_to_array(`) - c.renderValueExp(Param{Name: "cursor", Type: "json"}) + c.md.renderValueExp(c.w, Param{Name: "cursor", Type: "json"}) io.WriteString(c.w, `, ',') as a) `) return nil } @@ -1102,7 +1106,7 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, ti *DBTableInfo) error { } else { io.WriteString(c.w, `) @@ to_tsquery(`) } - c.renderValueExp(Param{Name: ex.Val, Type: "string"}) + c.md.renderValueExp(c.w, Param{Name: ex.Val, Type: "string"}) io.WriteString(c.w, `))`) return nil @@ -1191,7 +1195,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col * switch { case ok && strings.HasPrefix(val, "sql:"): io.WriteString(c.w, `(`) - c.renderVar(val[4:], c.renderValueExp) + c.md.RenderVar(c.w, val[4:]) io.WriteString(c.w, `)`) case ok: @@ -1199,7 +1203,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col * case ex.Op == qcode.OpIn || ex.Op == qcode.OpNotIn: io.WriteString(c.w, `(ARRAY(SELECT json_array_elements_text(`) - c.renderValueExp(Param{Name: ex.Val, Type: col.Type, IsArray: true}) + c.md.renderValueExp(c.w, Param{Name: ex.Val, Type: col.Type, IsArray: true}) io.WriteString(c.w, `))`) io.WriteString(c.w, ` :: `) @@ -1208,7 +1212,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col * return default: - c.renderValueExp(Param{Name: ex.Val, Type: col.Type, IsArray: false}) + c.md.renderValueExp(c.w, Param{Name: ex.Val, Type: col.Type, IsArray: false}) } case qcode.ValRef: @@ -1222,54 +1226,6 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col * io.WriteString(c.w, col.Type) } -func (c *compilerContext) renderValueExp(p Param) { - io.WriteString(c.w, `$`) - if v, ok := c.md.pindex[p.Name]; ok { - int32String(c.w, int32(v)) - - } else { - c.md.Params = append(c.md.Params, p) - n := len(c.md.Params) - - if c.md.pindex == nil { - c.md.pindex = make(map[string]int) - } - c.md.pindex[p.Name] = n - int32String(c.w, int32(n)) - } -} - -func (c *compilerContext) renderVar(vv string, fn func(Param)) { - f, s := -1, 0 - - for i := range vv { - v := vv[i] - switch { - case (i > 0 && vv[i-1] != '\\' && v == '$') || v == '$': - if (i - s) > 0 { - io.WriteString(c.w, vv[s:i]) - } - f = i - - case (v < 'a' && v > 'z') && - (v < 'A' && v > 'Z') && - (v < '0' && v > '9') && - v != '_' && - f != -1 && - (i-f) > 1: - fn(Param{Name: vv[f+1 : i]}) - s = i - f = -1 - } - } - - if f != -1 && (len(vv)-f) > 1 { - fn(Param{Name: vv[f+1:]}) - } else { - io.WriteString(c.w, vv[s:]) - } -} - func funcPrefixLen(fm map[string]*DBFunction, fn string) int { switch { case strings.HasPrefix(fn, "avg_"): diff --git a/core/internal/psql/update.go b/core/internal/psql/update.go index 21395d9..cbdaf98 100644 --- a/core/internal/psql/update.go +++ b/core/internal/psql/update.go @@ -22,7 +22,7 @@ func (c *compilerContext) renderUpdate( } io.WriteString(c.w, `WITH "_sg_input" AS (SELECT `) - c.renderValueExp(Param{Name: qc.ActionVar, Type: "json"}) + c.md.renderValueExp(c.w, Param{Name: qc.ActionVar, Type: "json"}) // io.WriteString(c.w, qc.ActionVar) io.WriteString(c.w, ` :: json AS j)`) diff --git a/core/internal/qcode/parse_test.go b/core/internal/qcode/parse_test.go index c146dc7..c465e95 100644 --- a/core/internal/qcode/parse_test.go +++ b/core/internal/qcode/parse_test.go @@ -2,8 +2,9 @@ package qcode import ( "errors" - "github.com/chirino/graphql/schema" "testing" + + "github.com/chirino/graphql/schema" ) func TestCompile1(t *testing.T) { @@ -130,6 +131,22 @@ updateThread { } +func TestFragmentsCompile(t *testing.T) { + gql := ` +fragment userFields on user { + name + email +} + +query { users { ...userFields } }` + qcompile, _ := NewCompiler(Config{}) + _, err := qcompile.Compile([]byte(gql), "anon") + + if err == nil { + t.Fatal(errors.New("expecting an error")) + } +} + var gql = []byte(` {products( # returns only 30 items diff --git a/core/prepare.go b/core/prepare.go index 2f6f235..a2f57ac 100644 --- a/core/prepare.go +++ b/core/prepare.go @@ -125,7 +125,7 @@ func (sg *SuperGraph) prepareRoleStmt() error { } io.WriteString(w, ` ELSE $2 END) FROM (`) - io.WriteString(w, sg.conf.RolesQuery) + io.WriteString(w, rq) io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `) io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler" LIMIT 1; `) diff --git a/core/remote.go b/core/remote.go index e905520..5515772 100644 --- a/core/remote.go +++ b/core/remote.go @@ -22,7 +22,7 @@ func (sg *SuperGraph) execRemoteJoin(st *stmt, data []byte, hdr http.Header) ([] // fetch the field name used within the db response json // that are used to mark insertion points and the mapping between // those field names and their select objects - fids, sfmap := sg.parentFieldIds(&h, sel, st.md.Skipped) + fids, sfmap := sg.parentFieldIds(&h, sel, st.md.Skipped()) // fetch the field values of the marked insertion points // these values contain the id to be used with fetching remote data @@ -67,7 +67,7 @@ func (sg *SuperGraph) resolveRemote( to := toA[:1] // use the json key to find the related Select object - h.Write(field.Key) + _, _ = h.Write(field.Key) k1 := h.Sum64() s, ok := sfmap[k1] @@ -136,7 +136,7 @@ func (sg *SuperGraph) resolveRemotes( for i, id := range from { // use the json key to find the related Select object - h.Write(id.Key) + _, _ = h.Write(id.Key) k1 := h.Sum64() s, ok := sfmap[k1] @@ -230,7 +230,7 @@ func (sg *SuperGraph) parentFieldIds(h *maphash.Hash, sel []qcode.Select, skippe fm[n] = r.IDField n++ - h.Write(r.IDField) + _, _ = h.Write(r.IDField) sm[h.Sum64()] = s } } diff --git a/core/resolve.go b/core/resolve.go index bdb4857..0f4e804 100644 --- a/core/resolve.go +++ b/core/resolve.go @@ -86,7 +86,7 @@ func (sg *SuperGraph) initRemotes(t Table) error { sg.rmap[mkkey(&h, r.Name, t.Name)] = rf // index resolver obj by IDField - h.Write(rf.IDField) + _, _ = h.Write(rf.IDField) sg.rmap[h.Sum64()] = rf } diff --git a/jsn/filter.go b/jsn/filter.go index 6b9c6dd..60ddc12 100644 --- a/jsn/filter.go +++ b/jsn/filter.go @@ -12,7 +12,7 @@ func Filter(w *bytes.Buffer, b []byte, keys []string) error { h := maphash.Hash{} for i := range keys { - h.WriteString(keys[i]) + _, _ = h.WriteString(keys[i]) kmap[h.Sum64()] = struct{}{} h.Reset() } @@ -134,7 +134,7 @@ func Filter(w *bytes.Buffer, b []byte, keys []string) error { cb := b[s:(e + 1)] e = 0 - h.Write(k) + _, _ = h.Write(k) _, ok := kmap[h.Sum64()] h.Reset() diff --git a/jsn/get.go b/jsn/get.go index 52f00ed..a9b3381 100644 --- a/jsn/get.go +++ b/jsn/get.go @@ -44,7 +44,7 @@ func Get(b []byte, keys [][]byte) []Field { h := maphash.Hash{} for i := range keys { - h.Write(keys[i]) + _, _ = h.Write(keys[i]) kmap[h.Sum64()] = struct{}{} h.Reset() } @@ -144,7 +144,7 @@ func Get(b []byte, keys [][]byte) []Field { } if e != 0 { - h.Write(k) + _, _ = h.Write(k) _, ok := kmap[h.Sum64()] h.Reset()