From 6bc66d28bc83a82d99cc521318d1acfd3260c7ca Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Thu, 24 Oct 2019 02:07:42 -0400 Subject: [PATCH] Get RBAC working for queries and mutations --- config/dev.yml | 17 +- config/prod.yml | 4 - psql/mutate.go | 83 ++++---- psql/query.go | 463 ++++++++++++++++++++++----------------------- serv/allow.go | 8 +- serv/auth.go | 34 +++- serv/auth_jwt.go | 6 - serv/auth_rails.go | 23 +-- serv/cmd.go | 27 ++- serv/cmd_seed.go | 3 +- serv/config.go | 55 +++--- serv/core.go | 293 +++++++++++++++------------- serv/core_build.go | 144 ++++++++++++++ serv/http.go | 7 +- serv/prepare.go | 125 ++++++++---- serv/serv.go | 2 +- serv/utils.go | 29 ++- serv/vars.go | 24 ++- tmpl/prod.yml | 123 ++++++++---- 19 files changed, 902 insertions(+), 568 deletions(-) create mode 100644 serv/core_build.go diff --git a/config/dev.yml b/config/dev.yml index c94e4bb..ade488b 100644 --- a/config/dev.yml +++ b/config/dev.yml @@ -22,7 +22,7 @@ enable_tracing: true # Watch the config folder and reload Super Graph # with the new configs when a change is detected -reload_on_config_change: false +reload_on_config_change: true # File that points to the database seeding script # seed_file: seed.js @@ -53,7 +53,7 @@ auth: # Comment this out if you want to disable setting # the user_id via a header. Good for testing - header: X-User-ID + creds_in_header: true rails: # Rails version this is used for reading the @@ -143,6 +143,8 @@ tables: name: me table: users +roles_query: "SELECT * FROM users as usr WHERE id = $user_id" + roles: - name: anon tables: @@ -164,6 +166,10 @@ roles: - name: user tables: + - name: users + query: + filter: ["{ id: { _eq: $user_id } }"] + - name: products query: @@ -189,9 +195,10 @@ roles: delete: deny: true - - name: manager + - name: admin + match: id = 1 tables: - name: users - select: - filter: ["{ account_id: { _eq: $account_id } }"] + # select: + # filter: ["{ account_id: { _eq: $account_id } }"] diff --git a/config/prod.yml b/config/prod.yml index fa5f932..a52af3d 100644 --- a/config/prod.yml +++ b/config/prod.yml @@ -47,10 +47,6 @@ auth: type: rails cookie: _app_session - # Comment this out if you want to disable setting - # the user_id via a header. Good for testing - header: X-User-ID - rails: # Rails version this is used for reading the # various cookies formats. diff --git a/psql/mutate.go b/psql/mutate.go index 63b3a5a..067270e 100644 --- a/psql/mutate.go +++ b/psql/mutate.go @@ -1,7 +1,6 @@ package psql import ( - "bytes" "errors" "fmt" "io" @@ -12,7 +11,7 @@ import ( var noLimit = qcode.Paging{NoLimit: true} -func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) { +func (co *Compiler) compileMutation(qc *qcode.QCode, w io.Writer, vars Variables) (uint32, error) { if len(qc.Selects) == 0 { return 0, errors.New("empty query") } @@ -25,9 +24,9 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia return 0, err } - c.w.WriteString(`WITH `) + io.WriteString(c.w, `WITH `) quoted(c.w, ti.Name) - c.w.WriteString(` AS `) + io.WriteString(c.w, ` AS `) switch qc.Type { case qcode.QTInsert: @@ -67,7 +66,7 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia return c.compileQuery(qc, w) } -func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer, +func (c *compilerContext) renderInsert(qc *qcode.QCode, w io.Writer, vars Variables, ti *DBTableInfo) (uint32, error) { insert, ok := vars[qc.ActionVar] @@ -80,32 +79,32 @@ func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer, return 0, err } - c.w.WriteString(`(WITH "input" AS (SELECT {{`) - c.w.WriteString(qc.ActionVar) - c.w.WriteString(`}}::json AS j) INSERT INTO `) + io.WriteString(c.w, `(WITH "input" AS (SELECT {{`) + io.WriteString(c.w, qc.ActionVar) + io.WriteString(c.w, `}}::json AS j) INSERT INTO `) quoted(c.w, ti.Name) io.WriteString(c.w, ` (`) c.renderInsertUpdateColumns(qc, w, jt, ti) io.WriteString(c.w, `)`) - c.w.WriteString(` SELECT `) + io.WriteString(c.w, ` SELECT `) c.renderInsertUpdateColumns(qc, w, jt, ti) - c.w.WriteString(` FROM input i, `) + io.WriteString(c.w, ` FROM input i, `) if array { - c.w.WriteString(`json_populate_recordset`) + io.WriteString(c.w, `json_populate_recordset`) } else { - c.w.WriteString(`json_populate_record`) + io.WriteString(c.w, `json_populate_record`) } - c.w.WriteString(`(NULL::`) - c.w.WriteString(ti.Name) - c.w.WriteString(`, i.j) t`) + io.WriteString(c.w, `(NULL::`) + io.WriteString(c.w, ti.Name) + io.WriteString(c.w, `, i.j) t`) return 0, nil } -func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w *bytes.Buffer, +func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w io.Writer, jt map[string]interface{}, ti *DBTableInfo) (uint32, error) { root := &qc.Selects[0] @@ -122,14 +121,14 @@ func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w *bytes.Bu if i != 0 { io.WriteString(c.w, `, `) } - c.w.WriteString(cn) + io.WriteString(c.w, cn) i++ } return 0, nil } -func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer, +func (c *compilerContext) renderUpdate(qc *qcode.QCode, w io.Writer, vars Variables, ti *DBTableInfo) (uint32, error) { root := &qc.Selects[0] @@ -143,26 +142,26 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer, return 0, err } - c.w.WriteString(`(WITH "input" AS (SELECT {{`) - c.w.WriteString(qc.ActionVar) - c.w.WriteString(`}}::json AS j) UPDATE `) + io.WriteString(c.w, `(WITH "input" AS (SELECT {{`) + io.WriteString(c.w, qc.ActionVar) + io.WriteString(c.w, `}}::json AS j) UPDATE `) quoted(c.w, ti.Name) io.WriteString(c.w, ` SET (`) c.renderInsertUpdateColumns(qc, w, jt, ti) - c.w.WriteString(`) = (SELECT `) + io.WriteString(c.w, `) = (SELECT `) c.renderInsertUpdateColumns(qc, w, jt, ti) - c.w.WriteString(` FROM input i, `) + io.WriteString(c.w, ` FROM input i, `) if array { - c.w.WriteString(`json_populate_recordset`) + io.WriteString(c.w, `json_populate_recordset`) } else { - c.w.WriteString(`json_populate_record`) + io.WriteString(c.w, `json_populate_record`) } - c.w.WriteString(`(NULL::`) - c.w.WriteString(ti.Name) - c.w.WriteString(`, i.j) t)`) + io.WriteString(c.w, `(NULL::`) + io.WriteString(c.w, ti.Name) + io.WriteString(c.w, `, i.j) t)`) io.WriteString(c.w, ` WHERE `) @@ -173,11 +172,11 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer, return 0, nil } -func (c *compilerContext) renderDelete(qc *qcode.QCode, w *bytes.Buffer, +func (c *compilerContext) renderDelete(qc *qcode.QCode, w io.Writer, vars Variables, ti *DBTableInfo) (uint32, error) { root := &qc.Selects[0] - c.w.WriteString(`(DELETE FROM `) + io.WriteString(c.w, `(DELETE FROM `) quoted(c.w, ti.Name) io.WriteString(c.w, ` WHERE `) @@ -188,7 +187,7 @@ func (c *compilerContext) renderDelete(qc *qcode.QCode, w *bytes.Buffer, return 0, nil } -func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer, +func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer, vars Variables, ti *DBTableInfo) (uint32, error) { upsert, ok := vars[qc.ActionVar] @@ -205,7 +204,7 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer, return 0, err } - c.w.WriteString(` ON CONFLICT DO (`) + io.WriteString(c.w, ` ON CONFLICT DO (`) i := 0 for _, cn := range ti.ColumnNames { @@ -220,15 +219,15 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer, if i != 0 { io.WriteString(c.w, `, `) } - c.w.WriteString(cn) + io.WriteString(c.w, cn) i++ } if i == 0 { - c.w.WriteString(ti.PrimaryCol) + io.WriteString(c.w, ti.PrimaryCol) } - c.w.WriteString(`) DO `) + io.WriteString(c.w, `) DO `) - c.w.WriteString(`UPDATE `) + io.WriteString(c.w, `UPDATE `) io.WriteString(c.w, ` SET `) i = 0 @@ -239,17 +238,17 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer, if i != 0 { io.WriteString(c.w, `, `) } - c.w.WriteString(cn) + io.WriteString(c.w, cn) io.WriteString(c.w, ` = EXCLUDED.`) - c.w.WriteString(cn) + io.WriteString(c.w, cn) i++ } return 0, nil } -func quoted(w *bytes.Buffer, identifier string) { - w.WriteString(`"`) - w.WriteString(identifier) - w.WriteString(`"`) +func quoted(w io.Writer, identifier string) { + io.WriteString(w, `"`) + io.WriteString(w, identifier) + io.WriteString(w, `"`) } diff --git a/psql/query.go b/psql/query.go index 70ea0a2..c06ab95 100644 --- a/psql/query.go +++ b/psql/query.go @@ -49,7 +49,7 @@ func (c *Compiler) IDColumn(table string) (string, error) { } type compilerContext struct { - w *bytes.Buffer + w io.Writer s []qcode.Select *Compiler } @@ -60,7 +60,7 @@ func (co *Compiler) CompileEx(qc *qcode.QCode, vars Variables) (uint32, []byte, return skipped, w.Bytes(), err } -func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) { +func (co *Compiler) Compile(qc *qcode.QCode, w io.Writer, vars Variables) (uint32, error) { switch qc.Type { case qcode.QTQuery: return co.compileQuery(qc, w) @@ -71,7 +71,7 @@ func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (u return 0, fmt.Errorf("Unknown operation type %d", qc.Type) } -func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) { +func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) { if len(qc.Selects) == 0 { return 0, errors.New("empty query") } @@ -90,17 +90,17 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, erro //fmt.Fprintf(w, `SELECT json_object_agg('%s', %s) FROM (`, //root.FieldName, root.Table) - c.w.WriteString(`SELECT json_object_agg('`) - c.w.WriteString(root.FieldName) - c.w.WriteString(`', `) + io.WriteString(c.w, `SELECT json_object_agg('`) + io.WriteString(c.w, root.FieldName) + io.WriteString(c.w, `', `) if ti.Singular == false { - c.w.WriteString(root.Table) + io.WriteString(c.w, root.Table) } else { - c.w.WriteString("sel_json_") + io.WriteString(c.w, "sel_json_") int2string(c.w, root.ID) } - c.w.WriteString(`) FROM (`) + io.WriteString(c.w, `) FROM (`) var ignored uint32 @@ -161,9 +161,8 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, erro } } - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) alias(c.w, `done_1337`) - c.w.WriteString(`;`) return ignored, nil } @@ -219,10 +218,10 @@ func (c *compilerContext) renderSelect(sel *qcode.Select, ti *DBTableInfo) (uint // SELECT if ti.Singular == false { //fmt.Fprintf(w, `SELECT coalesce(json_agg("%s"`, c.sel.Table) - c.w.WriteString(`SELECT coalesce(json_agg("`) - c.w.WriteString("sel_json_") + io.WriteString(c.w, `SELECT coalesce(json_agg("`) + io.WriteString(c.w, "sel_json_") int2string(c.w, sel.ID) - c.w.WriteString(`"`) + io.WriteString(c.w, `"`) if hasOrder { err := c.renderOrderBy(sel, ti) @@ -232,24 +231,24 @@ func (c *compilerContext) renderSelect(sel *qcode.Select, ti *DBTableInfo) (uint } //fmt.Fprintf(w, `), '[]') AS "%s" FROM (`, c.sel.Table) - c.w.WriteString(`), '[]')`) + io.WriteString(c.w, `), '[]')`) alias(c.w, sel.Table) - c.w.WriteString(` FROM (`) + io.WriteString(c.w, ` FROM (`) } // ROW-TO-JSON - c.w.WriteString(`SELECT `) + io.WriteString(c.w, `SELECT `) if len(sel.DistinctOn) != 0 { c.renderDistinctOn(sel, ti) } - c.w.WriteString(`row_to_json((`) + io.WriteString(c.w, `row_to_json((`) //fmt.Fprintf(w, `SELECT "sel_%d" FROM (SELECT `, c.sel.ID) - c.w.WriteString(`SELECT "sel_`) + io.WriteString(c.w, `SELECT "sel_`) int2string(c.w, sel.ID) - c.w.WriteString(`" FROM (SELECT `) + io.WriteString(c.w, `" FROM (SELECT `) // Combined column names c.renderColumns(sel, ti) @@ -262,11 +261,11 @@ func (c *compilerContext) renderSelect(sel *qcode.Select, ti *DBTableInfo) (uint } //fmt.Fprintf(w, `) AS "sel_%d"`, c.sel.ID) - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) aliasWithID(c.w, "sel", sel.ID) //fmt.Fprintf(w, `)) AS "%s"`, c.sel.Table) - c.w.WriteString(`))`) + io.WriteString(c.w, `))`) aliasWithID(c.w, "sel_json", sel.ID) // END-ROW-TO-JSON @@ -301,27 +300,27 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo) case len(sel.Paging.Limit) != 0: //fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit) - c.w.WriteString(` LIMIT ('`) - c.w.WriteString(sel.Paging.Limit) - c.w.WriteString(`') :: integer`) + io.WriteString(c.w, ` LIMIT ('`) + io.WriteString(c.w, sel.Paging.Limit) + io.WriteString(c.w, `') :: integer`) case ti.Singular: - c.w.WriteString(` LIMIT ('1') :: integer`) + io.WriteString(c.w, ` LIMIT ('1') :: integer`) default: - c.w.WriteString(` LIMIT ('20') :: integer`) + io.WriteString(c.w, ` LIMIT ('20') :: integer`) } if len(sel.Paging.Offset) != 0 { //fmt.Fprintf(w, ` OFFSET ('%s') :: integer`, c.sel.Paging.Offset) - c.w.WriteString(`OFFSET ('`) - c.w.WriteString(sel.Paging.Offset) - c.w.WriteString(`') :: integer`) + io.WriteString(c.w, `OFFSET ('`) + io.WriteString(c.w, sel.Paging.Offset) + io.WriteString(c.w, `') :: integer`) } if ti.Singular == false { //fmt.Fprintf(w, `) AS "sel_json_agg_%d"`, c.sel.ID) - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) aliasWithID(c.w, "sel_json_agg", sel.ID) } @@ -329,15 +328,15 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo) } func (c *compilerContext) renderJoin(sel *qcode.Select) error { - c.w.WriteString(` LEFT OUTER JOIN LATERAL (`) + io.WriteString(c.w, ` LEFT OUTER JOIN LATERAL (`) return nil } func (c *compilerContext) renderJoinClose(sel *qcode.Select) error { //fmt.Fprintf(w, `) AS "%s_%d_join" ON ('true')`, c.sel.Table, c.sel.ID) - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) aliasWithIDSuffix(c.w, sel.Table, sel.ID, "_join") - c.w.WriteString(` ON ('true')`) + io.WriteString(c.w, ` ON ('true')`) return nil } @@ -360,13 +359,13 @@ func (c *compilerContext) renderJoinTable(sel *qcode.Select) error { //fmt.Fprintf(w, ` LEFT OUTER JOIN "%s" ON (("%s"."%s") = ("%s_%d"."%s"))`, //rel.Through, rel.Through, rel.ColT, c.parent.Table, c.parent.ID, rel.Col1) - c.w.WriteString(` LEFT OUTER JOIN "`) - c.w.WriteString(rel.Through) - c.w.WriteString(`" ON ((`) + io.WriteString(c.w, ` LEFT OUTER JOIN "`) + io.WriteString(c.w, rel.Through) + io.WriteString(c.w, `" ON ((`) colWithTable(c.w, rel.Through, rel.ColT) - c.w.WriteString(`) = (`) + io.WriteString(c.w, `) = (`) colWithTableID(c.w, pt.Name, parent.ID, rel.Col1) - c.w.WriteString(`))`) + io.WriteString(c.w, `))`) return nil } @@ -443,11 +442,11 @@ func (c *compilerContext) renderJoinedColumns(sel *qcode.Select, ti *DBTableInfo //fmt.Fprintf(w, `"%s_%d_join"."%s" AS "%s"`, //s.Table, s.ID, s.Table, s.FieldName) if cti.Singular { - c.w.WriteString(`"sel_json_`) + io.WriteString(c.w, `"sel_json_`) int2string(c.w, childSel.ID) - c.w.WriteString(`" AS "`) - c.w.WriteString(childSel.FieldName) - c.w.WriteString(`"`) + io.WriteString(c.w, `" AS "`) + io.WriteString(c.w, childSel.FieldName) + io.WriteString(c.w, `"`) } else { colWithTableIDSuffixAlias(c.w, childSel.Table, childSel.ID, @@ -467,7 +466,7 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, isSearch := sel.Args["search"] != nil isAgg := false - c.w.WriteString(` FROM (SELECT `) + io.WriteString(c.w, ` FROM (SELECT `) i := 0 for n, col := range sel.Cols { @@ -483,15 +482,15 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, arg := sel.Args["search"] if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } //fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`, //c.sel.Table, cn, arg.Val, col.Name) - c.w.WriteString(`ts_rank(`) + io.WriteString(c.w, `ts_rank(`) colWithTable(c.w, ti.Name, cn) - c.w.WriteString(`, to_tsquery('`) - c.w.WriteString(arg.Val) - c.w.WriteString(`')`) + io.WriteString(c.w, `, to_tsquery('`) + io.WriteString(c.w, arg.Val) + io.WriteString(c.w, `')`) alias(c.w, col.Name) i++ @@ -500,15 +499,15 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, arg := sel.Args["search"] if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } //fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`, //c.sel.Table, cn, arg.Val, col.Name) - c.w.WriteString(`ts_headlinek(`) + io.WriteString(c.w, `ts_headlinek(`) colWithTable(c.w, ti.Name, cn) - c.w.WriteString(`, to_tsquery('`) - c.w.WriteString(arg.Val) - c.w.WriteString(`')`) + io.WriteString(c.w, `, to_tsquery('`) + io.WriteString(c.w, arg.Val) + io.WriteString(c.w, `')`) alias(c.w, col.Name) i++ @@ -517,12 +516,12 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, pl := funcPrefixLen(cn) if pl == 0 { if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } //fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name) - c.w.WriteString(`'`) - c.w.WriteString(cn) - c.w.WriteString(` not defined'`) + io.WriteString(c.w, `'`) + io.WriteString(c.w, cn) + io.WriteString(c.w, ` not defined'`) alias(c.w, col.Name) i++ @@ -532,16 +531,16 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, continue } if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } fn := cn[0 : pl-1] isAgg = true //fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, c.sel.Table, cn, col.Name) - c.w.WriteString(fn) - c.w.WriteString(`(`) + io.WriteString(c.w, fn) + io.WriteString(c.w, `(`) colWithTable(c.w, ti.Name, cn1) - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) alias(c.w, col.Name) i++ @@ -551,7 +550,7 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, groupBy = append(groupBy, n) //fmt.Fprintf(w, `"%s"."%s"`, c.sel.Table, cn) if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } colWithTable(c.w, ti.Name, cn) i++ @@ -561,7 +560,7 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, for _, col := range childCols { if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } //fmt.Fprintf(w, `"%s"."%s"`, col.Table, col.Name) @@ -569,29 +568,29 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, i++ } - c.w.WriteString(` FROM `) + io.WriteString(c.w, ` FROM `) //fmt.Fprintf(w, ` FROM "%s"`, c.sel.Table) - c.w.WriteString(`"`) - c.w.WriteString(ti.Name) - c.w.WriteString(`"`) + io.WriteString(c.w, `"`) + io.WriteString(c.w, ti.Name) + io.WriteString(c.w, `"`) // if tn, ok := c.tmap[sel.Table]; ok { // //fmt.Fprintf(w, ` FROM "%s" AS "%s"`, tn, c.sel.Table) // tableWithAlias(c.w, ti.Name, sel.Table) // } else { // //fmt.Fprintf(w, ` FROM "%s"`, c.sel.Table) - // c.w.WriteString(`"`) - // c.w.WriteString(sel.Table) - // c.w.WriteString(`"`) + // io.WriteString(c.w, `"`) + // io.WriteString(c.w, sel.Table) + // io.WriteString(c.w, `"`) // } if isRoot && isFil { - c.w.WriteString(` WHERE (`) + io.WriteString(c.w, ` WHERE (`) if err := c.renderWhere(sel, ti); err != nil { return err } - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) } if !isRoot { @@ -599,28 +598,28 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, return err } - c.w.WriteString(` WHERE (`) + io.WriteString(c.w, ` WHERE (`) if err := c.renderRelationship(sel, ti); err != nil { return err } if isFil { - c.w.WriteString(` AND `) + io.WriteString(c.w, ` AND `) if err := c.renderWhere(sel, ti); err != nil { return err } } - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) } if isAgg { if len(groupBy) != 0 { - c.w.WriteString(` GROUP BY `) + io.WriteString(c.w, ` GROUP BY `) for i, id := range groupBy { if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } //fmt.Fprintf(w, `"%s"."%s"`, c.sel.Table, c.sel.Cols[id].Name) colWithTable(c.w, ti.Name, sel.Cols[id].Name) @@ -634,26 +633,26 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, case len(sel.Paging.Limit) != 0: //fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit) - c.w.WriteString(` LIMIT ('`) - c.w.WriteString(sel.Paging.Limit) - c.w.WriteString(`') :: integer`) + io.WriteString(c.w, ` LIMIT ('`) + io.WriteString(c.w, sel.Paging.Limit) + io.WriteString(c.w, `') :: integer`) case ti.Singular: - c.w.WriteString(` LIMIT ('1') :: integer`) + io.WriteString(c.w, ` LIMIT ('1') :: integer`) default: - c.w.WriteString(` LIMIT ('20') :: integer`) + io.WriteString(c.w, ` LIMIT ('20') :: integer`) } if len(sel.Paging.Offset) != 0 { //fmt.Fprintf(w, ` OFFSET ('%s') :: integer`, c.sel.Paging.Offset) - c.w.WriteString(` OFFSET ('`) - c.w.WriteString(sel.Paging.Offset) - c.w.WriteString(`') :: integer`) + io.WriteString(c.w, ` OFFSET ('`) + io.WriteString(c.w, sel.Paging.Offset) + io.WriteString(c.w, `') :: integer`) } //fmt.Fprintf(w, `) AS "%s_%d"`, c.sel.Table, c.sel.ID) - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) aliasWithID(c.w, ti.Name, sel.ID) return nil } @@ -664,7 +663,7 @@ func (c *compilerContext) renderOrderByColumns(sel *qcode.Select, ti *DBTableInf for i := range sel.OrderBy { if colsRendered { //io.WriteString(w, ", ") - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } col := sel.OrderBy[i].Col @@ -672,7 +671,7 @@ func (c *compilerContext) renderOrderByColumns(sel *qcode.Select, ti *DBTableInf //c.sel.Table, c.sel.ID, c, //c.sel.Table, c.sel.ID, c) colWithTableID(c.w, ti.Name, sel.ID, col) - c.w.WriteString(` AS `) + io.WriteString(c.w, ` AS `) tableIDColSuffix(c.w, sel.Table, sel.ID, col, "_ob") } } @@ -689,29 +688,29 @@ func (c *compilerContext) renderRelationship(sel *qcode.Select, ti *DBTableInfo) case RelBelongTo: //fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`, //c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2) - c.w.WriteString(`((`) + io.WriteString(c.w, `((`) colWithTable(c.w, ti.Name, rel.Col1) - c.w.WriteString(`) = (`) + io.WriteString(c.w, `) = (`) colWithTableID(c.w, parent.Table, parent.ID, rel.Col2) - c.w.WriteString(`))`) + io.WriteString(c.w, `))`) case RelOneToMany: //fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`, //c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2) - c.w.WriteString(`((`) + io.WriteString(c.w, `((`) colWithTable(c.w, ti.Name, rel.Col1) - c.w.WriteString(`) = (`) + io.WriteString(c.w, `) = (`) colWithTableID(c.w, parent.Table, parent.ID, rel.Col2) - c.w.WriteString(`))`) + io.WriteString(c.w, `))`) case RelOneToManyThrough: //fmt.Fprintf(w, `(("%s"."%s") = ("%s"."%s"))`, //c.sel.Table, rel.Col1, rel.Through, rel.Col2) - c.w.WriteString(`((`) + io.WriteString(c.w, `((`) colWithTable(c.w, ti.Name, rel.Col1) - c.w.WriteString(`) = (`) + io.WriteString(c.w, `) = (`) colWithTable(c.w, rel.Through, rel.Col2) - c.w.WriteString(`))`) + io.WriteString(c.w, `))`) } return nil @@ -735,11 +734,11 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error case qcode.ExpOp: switch val { case qcode.OpAnd: - c.w.WriteString(` AND `) + io.WriteString(c.w, ` AND `) case qcode.OpOr: - c.w.WriteString(` OR `) + io.WriteString(c.w, ` OR `) case qcode.OpNot: - c.w.WriteString(`NOT `) + io.WriteString(c.w, `NOT `) default: return fmt.Errorf("11: unexpected value %v (%t)", intf, intf) } @@ -763,62 +762,62 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error default: if val.NestedCol { //fmt.Fprintf(w, `(("%s") `, val.Col) - c.w.WriteString(`(("`) - c.w.WriteString(val.Col) - c.w.WriteString(`") `) + io.WriteString(c.w, `(("`) + io.WriteString(c.w, val.Col) + io.WriteString(c.w, `") `) } else if len(val.Col) != 0 { //fmt.Fprintf(w, `(("%s"."%s") `, c.sel.Table, val.Col) - c.w.WriteString(`((`) + io.WriteString(c.w, `((`) colWithTable(c.w, ti.Name, val.Col) - c.w.WriteString(`) `) + io.WriteString(c.w, `) `) } valExists := true switch val.Op { case qcode.OpEquals: - c.w.WriteString(`=`) + io.WriteString(c.w, `=`) case qcode.OpNotEquals: - c.w.WriteString(`!=`) + io.WriteString(c.w, `!=`) case qcode.OpGreaterOrEquals: - c.w.WriteString(`>=`) + io.WriteString(c.w, `>=`) case qcode.OpLesserOrEquals: - c.w.WriteString(`<=`) + io.WriteString(c.w, `<=`) case qcode.OpGreaterThan: - c.w.WriteString(`>`) + io.WriteString(c.w, `>`) case qcode.OpLesserThan: - c.w.WriteString(`<`) + io.WriteString(c.w, `<`) case qcode.OpIn: - c.w.WriteString(`IN`) + io.WriteString(c.w, `IN`) case qcode.OpNotIn: - c.w.WriteString(`NOT IN`) + io.WriteString(c.w, `NOT IN`) case qcode.OpLike: - c.w.WriteString(`LIKE`) + io.WriteString(c.w, `LIKE`) case qcode.OpNotLike: - c.w.WriteString(`NOT LIKE`) + io.WriteString(c.w, `NOT LIKE`) case qcode.OpILike: - c.w.WriteString(`ILIKE`) + io.WriteString(c.w, `ILIKE`) case qcode.OpNotILike: - c.w.WriteString(`NOT ILIKE`) + io.WriteString(c.w, `NOT ILIKE`) case qcode.OpSimilar: - c.w.WriteString(`SIMILAR TO`) + io.WriteString(c.w, `SIMILAR TO`) case qcode.OpNotSimilar: - c.w.WriteString(`NOT SIMILAR TO`) + io.WriteString(c.w, `NOT SIMILAR TO`) case qcode.OpContains: - c.w.WriteString(`@>`) + io.WriteString(c.w, `@>`) case qcode.OpContainedIn: - c.w.WriteString(`<@`) + io.WriteString(c.w, `<@`) case qcode.OpHasKey: - c.w.WriteString(`?`) + io.WriteString(c.w, `?`) case qcode.OpHasKeyAny: - c.w.WriteString(`?|`) + io.WriteString(c.w, `?|`) case qcode.OpHasKeyAll: - c.w.WriteString(`?&`) + io.WriteString(c.w, `?&`) case qcode.OpIsNull: if strings.EqualFold(val.Val, "true") { - c.w.WriteString(`IS NULL)`) + io.WriteString(c.w, `IS NULL)`) } else { - c.w.WriteString(`IS NOT NULL)`) + io.WriteString(c.w, `IS NOT NULL)`) } valExists = false case qcode.OpEqID: @@ -826,20 +825,20 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error return fmt.Errorf("no primary key column defined for %s", ti.Name) } //fmt.Fprintf(w, `(("%s") =`, c.ti.PrimaryCol) - c.w.WriteString(`((`) + io.WriteString(c.w, `((`) colWithTable(c.w, ti.Name, ti.PrimaryCol) - //c.w.WriteString(ti.PrimaryCol) - c.w.WriteString(`) =`) + //io.WriteString(c.w, ti.PrimaryCol) + io.WriteString(c.w, `) =`) case qcode.OpTsQuery: if len(ti.TSVCol) == 0 { return fmt.Errorf("no tsv column defined for %s", ti.Name) } //fmt.Fprintf(w, `(("%s") @@ to_tsquery('%s'))`, c.ti.TSVCol, val.Val) - c.w.WriteString(`(("`) - c.w.WriteString(ti.TSVCol) - c.w.WriteString(`") @@ to_tsquery('`) - c.w.WriteString(val.Val) - c.w.WriteString(`'))`) + io.WriteString(c.w, `(("`) + io.WriteString(c.w, ti.TSVCol) + io.WriteString(c.w, `") @@ to_tsquery('`) + io.WriteString(c.w, val.Val) + io.WriteString(c.w, `'))`) valExists = false default: @@ -852,7 +851,7 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error } else { c.renderVal(val, c.vars) } - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) } qcode.FreeExp(val) @@ -868,10 +867,10 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error } func (c *compilerContext) renderOrderBy(sel *qcode.Select, ti *DBTableInfo) error { - c.w.WriteString(` ORDER BY `) + io.WriteString(c.w, ` ORDER BY `) for i := range sel.OrderBy { if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } ob := sel.OrderBy[i] @@ -879,27 +878,27 @@ func (c *compilerContext) renderOrderBy(sel *qcode.Select, ti *DBTableInfo) erro case qcode.OrderAsc: //fmt.Fprintf(w, `"%s_%d.ob.%s" ASC`, sel.Table, sel.ID, ob.Col) tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") - c.w.WriteString(` ASC`) + io.WriteString(c.w, ` ASC`) case qcode.OrderDesc: //fmt.Fprintf(w, `"%s_%d.ob.%s" DESC`, sel.Table, sel.ID, ob.Col) tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") - c.w.WriteString(` DESC`) + io.WriteString(c.w, ` DESC`) case qcode.OrderAscNullsFirst: //fmt.Fprintf(w, `"%s_%d.ob.%s" ASC NULLS FIRST`, sel.Table, sel.ID, ob.Col) tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") - c.w.WriteString(` ASC NULLS FIRST`) + io.WriteString(c.w, ` ASC NULLS FIRST`) case qcode.OrderDescNullsFirst: //fmt.Fprintf(w, `%s_%d.ob.%s DESC NULLS FIRST`, sel.Table, sel.ID, ob.Col) tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") - c.w.WriteString(` DESC NULLLS FIRST`) + io.WriteString(c.w, ` DESC NULLLS FIRST`) case qcode.OrderAscNullsLast: //fmt.Fprintf(w, `"%s_%d.ob.%s ASC NULLS LAST`, sel.Table, sel.ID, ob.Col) tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") - c.w.WriteString(` ASC NULLS LAST`) + io.WriteString(c.w, ` ASC NULLS LAST`) case qcode.OrderDescNullsLast: //fmt.Fprintf(w, `%s_%d.ob.%s DESC NULLS LAST`, sel.Table, sel.ID, ob.Col) tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") - c.w.WriteString(` DESC NULLS LAST`) + io.WriteString(c.w, ` DESC NULLS LAST`) default: return fmt.Errorf("13: unexpected value %v", ob.Order) } @@ -911,30 +910,30 @@ func (c *compilerContext) renderDistinctOn(sel *qcode.Select, ti *DBTableInfo) { io.WriteString(c.w, `DISTINCT ON (`) for i := range sel.DistinctOn { if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } //fmt.Fprintf(w, `"%s_%d.ob.%s"`, c.sel.Table, c.sel.ID, c.sel.DistinctOn[i]) tableIDColSuffix(c.w, ti.Name, sel.ID, sel.DistinctOn[i], "_ob") } - c.w.WriteString(`) `) + io.WriteString(c.w, `) `) } func (c *compilerContext) renderList(ex *qcode.Exp) { io.WriteString(c.w, ` (`) for i := range ex.ListVal { if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } switch ex.ListType { case qcode.ValBool, qcode.ValInt, qcode.ValFloat: - c.w.WriteString(ex.ListVal[i]) + io.WriteString(c.w, ex.ListVal[i]) case qcode.ValStr: - c.w.WriteString(`'`) - c.w.WriteString(ex.ListVal[i]) - c.w.WriteString(`'`) + io.WriteString(c.w, `'`) + io.WriteString(c.w, ex.ListVal[i]) + io.WriteString(c.w, `'`) } } - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) } func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) { @@ -943,27 +942,27 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) { switch ex.Type { case qcode.ValBool, qcode.ValInt, qcode.ValFloat: if len(ex.Val) != 0 { - c.w.WriteString(ex.Val) + io.WriteString(c.w, ex.Val) } else { - c.w.WriteString(`''`) + io.WriteString(c.w, `''`) } case qcode.ValStr: - c.w.WriteString(`'`) - c.w.WriteString(ex.Val) - c.w.WriteString(`'`) + io.WriteString(c.w, `'`) + io.WriteString(c.w, ex.Val) + io.WriteString(c.w, `'`) case qcode.ValVar: if val, ok := vars[ex.Val]; ok { - c.w.WriteString(val) + io.WriteString(c.w, val) } else { //fmt.Fprintf(w, `'{{%s}}'`, ex.Val) - c.w.WriteString(`{{`) - c.w.WriteString(ex.Val) - c.w.WriteString(`}}`) + io.WriteString(c.w, `{{`) + io.WriteString(c.w, ex.Val) + io.WriteString(c.w, `}}`) } } - //c.w.WriteString(`)`) + //io.WriteString(c.w, `)`) } func funcPrefixLen(fn string) int { @@ -999,105 +998,105 @@ func hasBit(n uint32, pos uint32) bool { return (val > 0) } -func alias(w *bytes.Buffer, alias string) { - w.WriteString(` AS "`) - w.WriteString(alias) - w.WriteString(`"`) +func alias(w io.Writer, alias string) { + io.WriteString(w, ` AS "`) + io.WriteString(w, alias) + io.WriteString(w, `"`) } -func aliasWithID(w *bytes.Buffer, alias string, id int32) { - w.WriteString(` AS "`) - w.WriteString(alias) - w.WriteString(`_`) +func aliasWithID(w io.Writer, alias string, id int32) { + io.WriteString(w, ` AS "`) + io.WriteString(w, alias) + io.WriteString(w, `_`) int2string(w, id) - w.WriteString(`"`) + io.WriteString(w, `"`) } -func aliasWithIDSuffix(w *bytes.Buffer, alias string, id int32, suffix string) { - w.WriteString(` AS "`) - w.WriteString(alias) - w.WriteString(`_`) +func aliasWithIDSuffix(w io.Writer, alias string, id int32, suffix string) { + io.WriteString(w, ` AS "`) + io.WriteString(w, alias) + io.WriteString(w, `_`) int2string(w, id) - w.WriteString(suffix) - w.WriteString(`"`) + io.WriteString(w, suffix) + io.WriteString(w, `"`) } -func colWithAlias(w *bytes.Buffer, col, alias string) { - w.WriteString(`"`) - w.WriteString(col) - w.WriteString(`" AS "`) - w.WriteString(alias) - w.WriteString(`"`) +func colWithAlias(w io.Writer, col, alias string) { + io.WriteString(w, `"`) + io.WriteString(w, col) + io.WriteString(w, `" AS "`) + io.WriteString(w, alias) + io.WriteString(w, `"`) } -func tableWithAlias(w *bytes.Buffer, table, alias string) { - w.WriteString(`"`) - w.WriteString(table) - w.WriteString(`" AS "`) - w.WriteString(alias) - w.WriteString(`"`) +func tableWithAlias(w io.Writer, table, alias string) { + io.WriteString(w, `"`) + io.WriteString(w, table) + io.WriteString(w, `" AS "`) + io.WriteString(w, alias) + io.WriteString(w, `"`) } -func colWithTable(w *bytes.Buffer, table, col string) { - w.WriteString(`"`) - w.WriteString(table) - w.WriteString(`"."`) - w.WriteString(col) - w.WriteString(`"`) +func colWithTable(w io.Writer, table, col string) { + io.WriteString(w, `"`) + io.WriteString(w, table) + io.WriteString(w, `"."`) + io.WriteString(w, col) + io.WriteString(w, `"`) } -func colWithTableID(w *bytes.Buffer, table string, id int32, col string) { - w.WriteString(`"`) - w.WriteString(table) - w.WriteString(`_`) +func colWithTableID(w io.Writer, table string, id int32, col string) { + io.WriteString(w, `"`) + io.WriteString(w, table) + io.WriteString(w, `_`) int2string(w, id) - w.WriteString(`"."`) - w.WriteString(col) - w.WriteString(`"`) + io.WriteString(w, `"."`) + io.WriteString(w, col) + io.WriteString(w, `"`) } -func colWithTableIDAlias(w *bytes.Buffer, table string, id int32, col, alias string) { - w.WriteString(`"`) - w.WriteString(table) - w.WriteString(`_`) +func colWithTableIDAlias(w io.Writer, table string, id int32, col, alias string) { + io.WriteString(w, `"`) + io.WriteString(w, table) + io.WriteString(w, `_`) int2string(w, id) - w.WriteString(`"."`) - w.WriteString(col) - w.WriteString(`" AS "`) - w.WriteString(alias) - w.WriteString(`"`) + io.WriteString(w, `"."`) + io.WriteString(w, col) + io.WriteString(w, `" AS "`) + io.WriteString(w, alias) + io.WriteString(w, `"`) } -func colWithTableIDSuffixAlias(w *bytes.Buffer, table string, id int32, +func colWithTableIDSuffixAlias(w io.Writer, table string, id int32, suffix, col, alias string) { - w.WriteString(`"`) - w.WriteString(table) - w.WriteString(`_`) + io.WriteString(w, `"`) + io.WriteString(w, table) + io.WriteString(w, `_`) int2string(w, id) - w.WriteString(suffix) - w.WriteString(`"."`) - w.WriteString(col) - w.WriteString(`" AS "`) - w.WriteString(alias) - w.WriteString(`"`) + io.WriteString(w, suffix) + io.WriteString(w, `"."`) + io.WriteString(w, col) + io.WriteString(w, `" AS "`) + io.WriteString(w, alias) + io.WriteString(w, `"`) } -func tableIDColSuffix(w *bytes.Buffer, table string, id int32, col, suffix string) { - w.WriteString(`"`) - w.WriteString(table) - w.WriteString(`_`) +func tableIDColSuffix(w io.Writer, table string, id int32, col, suffix string) { + io.WriteString(w, `"`) + io.WriteString(w, table) + io.WriteString(w, `_`) int2string(w, id) - w.WriteString(`_`) - w.WriteString(col) - w.WriteString(suffix) - w.WriteString(`"`) + io.WriteString(w, `_`) + io.WriteString(w, col) + io.WriteString(w, suffix) + io.WriteString(w, `"`) } const charset = "0123456789" -func int2string(w *bytes.Buffer, val int32) { +func int2string(w io.Writer, val int32) { if val < 10 { - w.WriteByte(charset[val]) + w.Write([]byte{charset[val]}) return } @@ -1113,7 +1112,7 @@ func int2string(w *bytes.Buffer, val int32) { for val3 > 0 { d := val3 % 10 val3 /= 10 - w.WriteByte(charset[d]) + w.Write([]byte{charset[d]}) } } diff --git a/serv/allow.go b/serv/allow.go index 729ef61..1279170 100644 --- a/serv/allow.go +++ b/serv/allow.go @@ -182,7 +182,7 @@ func (al *allowList) load() { item.vars = varBytes } - al.list[gqlHash(q, varBytes)] = item + al.list[gqlHash(q, varBytes, "")] = item varBytes = nil } else if ty == AL_VARS { @@ -203,7 +203,11 @@ func (al *allowList) save(item *allowItem) { if al.active == false { return } - al.list[gqlHash(item.gql, item.vars)] = item + h := gqlHash(item.gql, item.vars, "") + if _, ok := al.list[h]; ok { + return + } + al.list[gqlHash(item.gql, item.vars, "")] = item f, err := os.Create(al.filepath) if err != nil { diff --git a/serv/auth.go b/serv/auth.go index e597a9b..77942eb 100644 --- a/serv/auth.go +++ b/serv/auth.go @@ -9,26 +9,40 @@ import ( var ( userIDProviderKey = struct{}{} userIDKey = struct{}{} + userRoleKey = struct{}{} ) -func headerAuth(r *http.Request, c *config) *http.Request { - if len(c.Auth.Header) == 0 { - return nil - } +func headerAuth(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() - userID := r.Header.Get(c.Auth.Header) - if len(userID) != 0 { - ctx := context.WithValue(r.Context(), userIDKey, userID) - return r.WithContext(ctx) - } + userIDProvider := r.Header.Get("X-User-ID-Provider") + if len(userIDProvider) != 0 { + ctx = context.WithValue(ctx, userIDProviderKey, userIDProvider) + } - return nil + userID := r.Header.Get("X-User-ID") + if len(userID) != 0 { + ctx = context.WithValue(ctx, userIDKey, userID) + } + + userRole := r.Header.Get("X-User-Role") + if len(userRole) != 0 { + ctx = context.WithValue(ctx, userRoleKey, userRole) + } + + next.ServeHTTP(w, r.WithContext(ctx)) + } } func withAuth(next http.HandlerFunc) http.HandlerFunc { at := conf.Auth.Type ru := conf.Auth.Rails.URL + if conf.Auth.CredsInHeader { + next = headerAuth(next) + } + switch at { case "rails": if strings.HasPrefix(ru, "memcache:") { diff --git a/serv/auth_jwt.go b/serv/auth_jwt.go index 25ed785..ef4f834 100644 --- a/serv/auth_jwt.go +++ b/serv/auth_jwt.go @@ -58,11 +58,6 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var tok string - if rn := headerAuth(r, conf); rn != nil { - next.ServeHTTP(w, rn) - return - } - if len(cookie) != 0 { ck, err := r.Cookie(cookie) if err != nil { @@ -102,7 +97,6 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc { } next.ServeHTTP(w, r.WithContext(ctx)) } - next.ServeHTTP(w, r) } } diff --git a/serv/auth_rails.go b/serv/auth_rails.go index c72c0d7..cd0b327 100644 --- a/serv/auth_rails.go +++ b/serv/auth_rails.go @@ -42,11 +42,6 @@ func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc { } return func(w http.ResponseWriter, r *http.Request) { - if rn := headerAuth(r, conf); rn != nil { - next.ServeHTTP(w, rn) - return - } - ck, err := r.Cookie(cookie) if err != nil { next.ServeHTTP(w, r) @@ -83,17 +78,12 @@ func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc { rURL, err := url.Parse(conf.Auth.Rails.URL) if err != nil { - logger.Fatal().Err(err) + logger.Fatal().Err(err).Send() } mc := memcache.New(rURL.Host) return func(w http.ResponseWriter, r *http.Request) { - if rn := headerAuth(r, conf); rn != nil { - next.ServeHTTP(w, rn) - return - } - ck, err := r.Cookie(cookie) if err != nil { next.ServeHTTP(w, r) @@ -126,25 +116,20 @@ func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc { ra, err := railsAuth(conf) if err != nil { - logger.Fatal().Err(err) + logger.Fatal().Err(err).Send() } return func(w http.ResponseWriter, r *http.Request) { - if rn := headerAuth(r, conf); rn != nil { - next.ServeHTTP(w, rn) - return - } - ck, err := r.Cookie(cookie) if err != nil { - logger.Error().Err(err) + logger.Warn().Err(err).Send() next.ServeHTTP(w, r) return } userID, err := ra.ParseCookie(ck.Value) if err != nil { - logger.Error().Err(err) + logger.Warn().Err(err).Send() next.ServeHTTP(w, r) return } diff --git a/serv/cmd.go b/serv/cmd.go index e73b4fb..12b3ce7 100644 --- a/serv/cmd.go +++ b/serv/cmd.go @@ -183,7 +183,32 @@ func initConf() (*config, error) { } zerolog.SetGlobalLevel(logLevel) - //fmt.Printf("%#v", c) + for k, v := range c.DB.Vars { + c.DB.Vars[k] = sanitize(v) + } + + c.RolesQuery = sanitize(c.RolesQuery) + + rolesMap := make(map[string]struct{}) + + for i := range c.Roles { + role := &c.Roles[i] + + if _, ok := rolesMap[role.Name]; ok { + logger.Fatal().Msgf("duplicate role '%s' found", role.Name) + } + role.Name = sanitize(role.Name) + role.Match = sanitize(role.Match) + rolesMap[role.Name] = struct{}{} + } + + if _, ok := rolesMap["user"]; !ok { + c.Roles = append(c.Roles, configRole{Name: "user"}) + } + + if _, ok := rolesMap["anon"]; !ok { + c.Roles = append(c.Roles, configRole{Name: "anon"}) + } return c, nil } diff --git a/serv/cmd_seed.go b/serv/cmd_seed.go index f9b152e..514c543 100644 --- a/serv/cmd_seed.go +++ b/serv/cmd_seed.go @@ -66,8 +66,9 @@ func graphQLFunc(query string, data interface{}) map[string]interface{} { c := &coreContext{Context: context.Background()} c.req.Query = query c.req.Vars = b + c.req.role = "user" - res, err := c.execQuery("user") + res, err := c.execQuery() if err != nil { logger.Fatal().Err(err).Msg("graphql query failed") } diff --git a/serv/config.go b/serv/config.go index b2ebf21..8420c66 100644 --- a/serv/config.go +++ b/serv/config.go @@ -1,7 +1,9 @@ package serv import ( + "regexp" "strings" + "unicode" "github.com/spf13/viper" ) @@ -24,9 +26,9 @@ type config struct { Inflections map[string]string Auth struct { - Type string - Cookie string - Header string + Type string + Cookie string + CredsInHeader bool `mapstructure:"creds_in_header"` Rails struct { Version string @@ -60,7 +62,7 @@ type config struct { MaxRetries int `mapstructure:"max_retries"` LogLevel string `mapstructure:"log_level"` - vars map[string][]byte `mapstructure:"variables"` + Vars map[string]string `mapstructure:"variables"` Defaults struct { Filter []string @@ -71,7 +73,9 @@ type config struct { } `mapstructure:"database"` Tables []configTable - Roles []configRoles + + RolesQuery string `mapstructure:"roles_query"` + Roles []configRole } type configTable struct { @@ -94,8 +98,9 @@ type configRemote struct { } `mapstructure:"set_headers"` } -type configRoles struct { +type configRole struct { Name string + Match string Tables []struct { Name string @@ -163,26 +168,6 @@ func newConfig() *viper.Viper { return vi } -func (c *config) getVariables() map[string]string { - vars := make(map[string]string, len(c.DB.vars)) - - for k, v := range c.DB.vars { - isVar := false - - for i := range v { - if v[i] == '$' { - isVar = true - } else if v[i] == ' ' { - isVar = false - } else if isVar && v[i] >= 'a' && v[i] <= 'z' { - v[i] = 'A' + (v[i] - 'a') - } - } - vars[k] = string(v) - } - return vars -} - func (c *config) getAliasMap() map[string][]string { m := make(map[string][]string, len(c.Tables)) @@ -198,3 +183,21 @@ func (c *config) getAliasMap() map[string][]string { } return m } + +var varRe1 = regexp.MustCompile(`(?mi)\$([a-zA-Z0-9_.]+)`) +var varRe2 = regexp.MustCompile(`\{\{([a-zA-Z0-9_.]+)\}\}`) + +func sanitize(s string) string { + s0 := varRe1.ReplaceAllString(s, `{{$1}}`) + + s1 := strings.Map(func(r rune) rune { + if unicode.IsSpace(r) { + return ' ' + } + return r + }, s0) + + return varRe2.ReplaceAllStringFunc(s1, func(m string) string { + return strings.ToLower(m) + }) +} diff --git a/serv/core.go b/serv/core.go index edc4779..8a007ac 100644 --- a/serv/core.go +++ b/serv/core.go @@ -13,8 +13,8 @@ import ( "github.com/cespare/xxhash/v2" "github.com/dosco/super-graph/jsn" - "github.com/dosco/super-graph/psql" "github.com/dosco/super-graph/qcode" + "github.com/jackc/pgx/v4" "github.com/valyala/fasttemplate" ) @@ -32,15 +32,13 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error { c.req.ref = req.Referer() c.req.hdr = req.Header - var role string - if authCheck(c) { - role = "user" + c.req.role = "user" } else { - role = "anon" + c.req.role = "anon" } - b, err := c.execQuery(role) + b, err := c.execQuery() if err != nil { return err } @@ -48,18 +46,18 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error { return c.render(w, b) } -func (c *coreContext) execQuery(role string) ([]byte, error) { +func (c *coreContext) execQuery() ([]byte, error) { var err error var skipped uint32 var qc *qcode.QCode var data []byte - logger.Debug().Str("role", role).Msg(c.req.Query) + logger.Debug().Str("role", c.req.role).Msg(c.req.Query) if conf.UseAllowList { var ps *preparedItem - data, ps, err = c.resolvePreparedSQL(c.req.Query) + data, ps, err = c.resolvePreparedSQL() if err != nil { return nil, err } @@ -69,12 +67,7 @@ func (c *coreContext) execQuery(role string) ([]byte, error) { } else { - qc, err = qcompile.Compile([]byte(c.req.Query), role) - if err != nil { - return nil, err - } - - data, skipped, err = c.resolveSQL(qc) + data, skipped, err = c.resolveSQL() if err != nil { return nil, err } @@ -122,6 +115,152 @@ func (c *coreContext) execQuery(role string) ([]byte, error) { return ob.Bytes(), nil } +func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) { + tx, err := db.Begin(c) + if err != nil { + return nil, nil, err + } + defer tx.Rollback(c) + + if v := c.Value(userIDKey); v != nil { + _, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v)) + + if err != nil { + return nil, nil, err + } + } + + var role string + useRoleQuery := len(conf.RolesQuery) != 0 && isMutation(c.req.Query) + + if useRoleQuery { + if role, err = c.executeRoleQuery(tx); err != nil { + return nil, nil, err + } + } else if v := c.Value(userRoleKey); v != nil { + role = v.(string) + } else { + role = c.req.role + } + + ps, ok := _preparedList[gqlHash(c.req.Query, c.req.Vars, role)] + if !ok { + return nil, nil, errUnauthorized + } + + var root []byte + vars := varList(c, ps.args) + + err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root) + if err != nil { + return nil, nil, err + } + + if err := tx.Commit(c); err != nil { + return nil, nil, err + } + + return root, ps, nil +} + +func (c *coreContext) resolveSQL() ([]byte, uint32, error) { + tx, err := db.Begin(c) + if err != nil { + return nil, 0, err + } + defer tx.Rollback(c) + + mutation := isMutation(c.req.Query) + useRoleQuery := len(conf.RolesQuery) != 0 && mutation + + if useRoleQuery { + if c.req.role, err = c.executeRoleQuery(tx); err != nil { + return nil, 0, err + } + + } else if v := c.Value(userRoleKey); v != nil { + c.req.role = v.(string) + } + + stmts, err := c.buildStmt() + if err != nil { + return nil, 0, err + } + + var st *stmt + + if mutation { + st = findStmt(c.req.role, stmts) + } else { + st = &stmts[0] + } + + t := fasttemplate.New(st.sql, openVar, closeVar) + + buf := &bytes.Buffer{} + _, err = t.ExecuteFunc(buf, varMap(c)) + + if err == errNoUserID && + authFailBlock == authFailBlockPerQuery && + authCheck(c) == false { + return nil, 0, errUnauthorized + } + + if err != nil { + return nil, 0, err + } + + finalSQL := buf.String() + + var stime time.Time + + if conf.EnableTracing { + stime = time.Now() + } + + if v := c.Value(userIDKey); v != nil { + _, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v)) + + if err != nil { + return nil, 0, err + } + } + + var root []byte + + if mutation { + err = tx.QueryRow(c, finalSQL).Scan(&root) + } else { + err = tx.QueryRow(c, finalSQL).Scan(&c.req.role, &root) + } + if err != nil { + return nil, 0, err + } + + if err := tx.Commit(c); err != nil { + return nil, 0, err + } + + if mutation { + st = findStmt(c.req.role, stmts) + } else { + st = &stmts[0] + } + + if conf.EnableTracing && len(st.qc.Selects) != 0 { + c.addTrace( + st.qc.Selects, + st.qc.Selects[0].ID, + stime) + } + + if conf.UseAllowList == false { + _allowList.add(&c.req) + } + + return root, st.skipped, nil +} + func (c *coreContext) resolveRemote( hdr http.Header, h *xxhash.Digest, @@ -269,125 +408,15 @@ func (c *coreContext) resolveRemotes( return to, cerr } -func (c *coreContext) resolvePreparedSQL(gql string) ([]byte, *preparedItem, error) { - ps, ok := _preparedList[gqlHash(gql, c.req.Vars)] - if !ok { - return nil, nil, errUnauthorized +func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) { + var role string + row := tx.QueryRow(c, "_sg_get_role", c.req.role, 1) + + if err := row.Scan(&role); err != nil { + return "", err } - var root []byte - vars := varList(c, ps.args) - - tx, err := db.Begin(c) - if err != nil { - return nil, nil, err - } - defer tx.Rollback(c) - - if v := c.Value(userIDKey); v != nil { - _, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v)) - - if err != nil { - return nil, nil, err - } - } - - err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root) - if err != nil { - return nil, nil, err - } - - if err := tx.Commit(c); err != nil { - return nil, nil, err - } - - fmt.Printf("PRE: %v\n", ps.stmt) - - return root, ps, nil -} - -func (c *coreContext) resolveSQL(qc *qcode.QCode) ([]byte, uint32, error) { - var vars map[string]json.RawMessage - stmt := &bytes.Buffer{} - - if len(c.req.Vars) != 0 { - if err := json.Unmarshal(c.req.Vars, &vars); err != nil { - return nil, 0, err - } - } - - skipped, err := pcompile.Compile(qc, stmt, psql.Variables(vars)) - if err != nil { - return nil, 0, err - } - - t := fasttemplate.New(stmt.String(), openVar, closeVar) - - stmt.Reset() - _, err = t.ExecuteFunc(stmt, varMap(c)) - - if err == errNoUserID && - authFailBlock == authFailBlockPerQuery && - authCheck(c) == false { - return nil, 0, errUnauthorized - } - - if err != nil { - return nil, 0, err - } - - finalSQL := stmt.String() - - // if conf.LogLevel == "debug" { - // os.Stdout.WriteString(finalSQL) - // os.Stdout.WriteString("\n\n") - // } - - var st time.Time - - if conf.EnableTracing { - st = time.Now() - } - - tx, err := db.Begin(c) - if err != nil { - return nil, 0, err - } - defer tx.Rollback(c) - - if v := c.Value(userIDKey); v != nil { - _, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v)) - - if err != nil { - return nil, 0, err - } - } - - //fmt.Printf("\nRAW: %#v\n", finalSQL) - - var root []byte - - err = tx.QueryRow(c, finalSQL).Scan(&root) - if err != nil { - return nil, 0, err - } - - if err := tx.Commit(c); err != nil { - return nil, 0, err - } - - if conf.EnableTracing && len(qc.Selects) != 0 { - c.addTrace( - qc.Selects, - qc.Selects[0].ID, - st) - } - - if conf.UseAllowList == false { - _allowList.add(&c.req) - } - - return root, skipped, nil + return role, nil } func (c *coreContext) render(w io.Writer, data []byte) error { diff --git a/serv/core_build.go b/serv/core_build.go new file mode 100644 index 0000000..fd86003 --- /dev/null +++ b/serv/core_build.go @@ -0,0 +1,144 @@ +package serv + +import ( + "bytes" + "encoding/json" + "errors" + "io" + + "github.com/dosco/super-graph/psql" + "github.com/dosco/super-graph/qcode" +) + +type stmt struct { + role *configRole + qc *qcode.QCode + skipped uint32 + sql string +} + +func (c *coreContext) buildStmt() ([]stmt, error) { + var vars map[string]json.RawMessage + + if len(c.req.Vars) != 0 { + if err := json.Unmarshal(c.req.Vars, &vars); err != nil { + return nil, err + } + } + + gql := []byte(c.req.Query) + + if len(conf.Roles) == 0 { + return nil, errors.New(`no roles found ('user' and 'anon' required)`) + } + + qc, err := qcompile.Compile(gql, conf.Roles[0].Name) + if err != nil { + return nil, err + } + + stmts := make([]stmt, 0, len(conf.Roles)) + mutation := (qc.Type != qcode.QTQuery) + w := &bytes.Buffer{} + + for i := range conf.Roles { + role := &conf.Roles[i] + + if mutation && len(c.req.role) != 0 && role.Name != c.req.role { + continue + } + + if i > 0 { + qc, err = qcompile.Compile(gql, role.Name) + if err != nil { + return nil, err + } + } + + stmts = append(stmts, stmt{role: role, qc: qc}) + + if mutation { + skipped, err := pcompile.Compile(qc, w, psql.Variables(vars)) + if err != nil { + return nil, err + } + + s := &stmts[len(stmts)-1] + s.skipped = skipped + s.sql = w.String() + w.Reset() + } + } + + if mutation { + return stmts, nil + } + + io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `) + + for _, s := range stmts { + io.WriteString(w, `WHEN '`) + io.WriteString(w, s.role.Name) + io.WriteString(w, `' THEN (`) + + s.skipped, err = pcompile.Compile(s.qc, w, psql.Variables(vars)) + if err != nil { + return nil, err + } + + io.WriteString(w, `) `) + } + io.WriteString(w, `END) FROM (`) + + if len(conf.RolesQuery) == 0 { + v := c.Value(userRoleKey) + + io.WriteString(w, `VALUES ("`) + if v != nil { + io.WriteString(w, v.(string)) + } else { + io.WriteString(w, c.req.role) + } + io.WriteString(w, `")) AS "_sg_auth_info"(role) LIMIT 1;`) + + } else { + + io.WriteString(w, `SELECT (CASE WHEN EXISTS (`) + io.WriteString(w, conf.RolesQuery) + io.WriteString(w, `) THEN `) + + io.WriteString(w, `(SELECT (CASE`) + for _, s := range stmts { + if len(s.role.Match) == 0 { + continue + } + io.WriteString(w, ` WHEN `) + io.WriteString(w, s.role.Match) + io.WriteString(w, ` THEN '`) + io.WriteString(w, s.role.Name) + io.WriteString(w, `'`) + } + + if len(c.req.role) == 0 { + io.WriteString(w, ` ELSE 'anon' END) FROM (`) + } else { + io.WriteString(w, ` ELSE '`) + io.WriteString(w, c.req.role) + io.WriteString(w, `' END) FROM (`) + } + + io.WriteString(w, conf.RolesQuery) + io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) ELSE '`) + if len(c.req.role) == 0 { + io.WriteString(w, `anon`) + } else { + io.WriteString(w, c.req.role) + } + io.WriteString(w, `' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `) + } + + stmts[0].sql = w.String() + stmts[0].role = nil + + return stmts, nil +} diff --git a/serv/http.go b/serv/http.go index d419d8b..ae8ff84 100644 --- a/serv/http.go +++ b/serv/http.go @@ -30,6 +30,7 @@ type gqlReq struct { Query string `json:"query"` Vars json.RawMessage `json:"variables"` ref string + role string hdr http.Header } @@ -101,13 +102,11 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) { err = ctx.handleReq(w, r) if err == errUnauthorized { - err := "Not authorized" - logger.Debug().Msg(err) - http.Error(w, err, 401) + http.Error(w, "Not authorized", 401) } if err != nil { - logger.Err(err).Msg("Failed to handle request") + logger.Err(err).Msg("failed to handle request") errorResp(w, err) } } diff --git a/serv/prepare.go b/serv/prepare.go index 697d80a..8415397 100644 --- a/serv/prepare.go +++ b/serv/prepare.go @@ -7,7 +7,6 @@ import ( "fmt" "io" - "github.com/dosco/super-graph/psql" "github.com/dosco/super-graph/qcode" "github.com/jackc/pgconn" "github.com/valyala/fasttemplate" @@ -27,55 +26,100 @@ var ( func initPreparedList() { _preparedList = make(map[string]*preparedItem) - for k, v := range _allowList.list { - err := prepareStmt(k, v.gql, v.vars) + if err := prepareRoleStmt(); err != nil { + logger.Fatal().Err(err).Msg("failed to prepare get role statement") + } + + for _, v := range _allowList.list { + err := prepareStmt(v.gql, v.vars) if err != nil { logger.Warn().Str("gql", v.gql).Err(err).Send() } } } -func prepareStmt(key, gql string, varBytes json.RawMessage) error { - if len(gql) == 0 || len(key) == 0 { +func prepareStmt(gql string, varBytes json.RawMessage) error { + if len(gql) == 0 { return nil } - qc, err := qcompile.Compile([]byte(gql), "user") + c := &coreContext{Context: context.Background()} + c.req.Query = gql + c.req.Vars = varBytes + + stmts, err := c.buildStmt() if err != nil { return err } - var vars map[string]json.RawMessage + for _, s := range stmts { + if len(s.sql) == 0 { + continue + } - if len(varBytes) != 0 { - vars = make(map[string]json.RawMessage) + finalSQL, am := processTemplate(s.sql) - if err := json.Unmarshal(varBytes, &vars); err != nil { + ctx := context.Background() + + tx, err := db.Begin(ctx) + if err != nil { + return err + } + defer tx.Rollback(ctx) + + pstmt, err := tx.Prepare(ctx, "", finalSQL) + if err != nil { + return err + } + + var key string + + if s.role == nil { + key = gqlHash(gql, varBytes, "") + } else { + key = gqlHash(gql, varBytes, s.role.Name) + } + + _preparedList[key] = &preparedItem{ + stmt: pstmt, + args: am, + skipped: s.skipped, + qc: s.qc, + } + + if err := tx.Commit(ctx); err != nil { return err } } - buf := &bytes.Buffer{} + return nil +} - skipped, err := pcompile.Compile(qc, buf, psql.Variables(vars)) - if err != nil { - return err +func prepareRoleStmt() error { + if len(conf.RolesQuery) == 0 { + return nil } - t := fasttemplate.New(buf.String(), `{{`, `}}`) - am := make([][]byte, 0, 5) - i := 0 + w := &bytes.Buffer{} - finalSQL := t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) { - am = append(am, []byte(tag)) - i++ - return w.Write([]byte(fmt.Sprintf("$%d", i))) - }) - - if err != nil { - return err + io.WriteString(w, `SELECT (CASE`) + for _, role := range conf.Roles { + if len(role.Match) == 0 { + continue + } + io.WriteString(w, ` WHEN `) + io.WriteString(w, role.Match) + io.WriteString(w, ` THEN '`) + io.WriteString(w, role.Name) + io.WriteString(w, `'`) } + io.WriteString(w, ` ELSE {{role}} END) FROM (`) + io.WriteString(w, conf.RolesQuery) + io.WriteString(w, `) AS "_sg_auth_roles_query"`) + + roleSQL, _ := processTemplate(w.String()) + ctx := context.Background() tx, err := db.Begin(ctx) @@ -84,21 +128,28 @@ func prepareStmt(key, gql string, varBytes json.RawMessage) error { } defer tx.Rollback(ctx) - pstmt, err := tx.Prepare(ctx, "", finalSQL) + _, err = tx.Prepare(ctx, "_sg_get_role", roleSQL) if err != nil { return err } - _preparedList[key] = &preparedItem{ - stmt: pstmt, - args: am, - skipped: skipped, - qc: qc, - } - - if err := tx.Commit(ctx); err != nil { - return err - } - return nil } + +func processTemplate(tmpl string) (string, [][]byte) { + t := fasttemplate.New(tmpl, `{{`, `}}`) + am := make([][]byte, 0, 5) + i := 0 + + vmap := make(map[string]int) + + return t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) { + if n, ok := vmap[tag]; ok { + return w.Write([]byte(fmt.Sprintf("$%d", n))) + } + am = append(am, []byte(tag)) + i++ + vmap[tag] = i + return w.Write([]byte(fmt.Sprintf("$%d", i))) + }), am +} diff --git a/serv/serv.go b/serv/serv.go index 980eb23..e03b9c2 100644 --- a/serv/serv.go +++ b/serv/serv.go @@ -67,7 +67,7 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) { pc := psql.NewCompiler(psql.Config{ Schema: schema, - Vars: c.getVariables(), + Vars: c.DB.Vars, }) return qc, pc, nil diff --git a/serv/utils.go b/serv/utils.go index baf49c3..9e095e8 100644 --- a/serv/utils.go +++ b/serv/utils.go @@ -21,7 +21,7 @@ func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 { return v } -func gqlHash(b string, vars []byte) string { +func gqlHash(b string, vars []byte, role string) string { b = strings.TrimSpace(b) h := sha1.New() @@ -56,6 +56,10 @@ func gqlHash(b string, vars []byte) string { } } + if len(role) != 0 { + io.WriteString(h, role) + } + if vars == nil || len(vars) == 0 { return hex.EncodeToString(h.Sum(nil)) } @@ -80,3 +84,26 @@ func ws(b byte) bool { func al(b byte) bool { return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9') } + +func isMutation(sql string) bool { + for i := range sql { + b := sql[i] + if b == '{' { + return false + } + if al(b) { + return (b == 'm' || b == 'M') + } + } + return false +} + +func findStmt(role string, stmts []stmt) *stmt { + for i := range stmts { + if stmts[i].role.Name != role { + continue + } + return &stmts[i] + } + return nil +} diff --git a/serv/vars.go b/serv/vars.go index 6ad9da6..f20627c 100644 --- a/serv/vars.go +++ b/serv/vars.go @@ -11,17 +11,27 @@ import ( func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) { return func(w io.Writer, tag string) (int, error) { switch tag { - case "user_id": - if v := ctx.Value(userIDKey); v != nil { - return stringVar(w, v.(string)) - } - return 0, errNoUserID - case "user_id_provider": if v := ctx.Value(userIDProviderKey); v != nil { return stringVar(w, v.(string)) } - return 0, errNoUserID + io.WriteString(w, "null") + return 0, nil + + case "user_id": + if v := ctx.Value(userIDKey); v != nil { + return stringVar(w, v.(string)) + } + + io.WriteString(w, "null") + return 0, nil + + case "user_role": + if v := ctx.Value(userRoleKey); v != nil { + return stringVar(w, v.(string)) + } + io.WriteString(w, "null") + return 0, nil } fields := jsn.Get(ctx.req.Vars, [][]byte{[]byte(tag)}) diff --git a/tmpl/prod.yml b/tmpl/prod.yml index 9597d7a..29c6b45 100644 --- a/tmpl/prod.yml +++ b/tmpl/prod.yml @@ -47,10 +47,6 @@ auth: type: rails cookie: _app_session - # Comment this out if you want to disable setting - # the user_id via a header. Good for testing - header: X-User-ID - rails: # Rails version this is used for reading the # various cookies formats. @@ -106,42 +102,93 @@ database: - token tables: - - name: users - # This filter will overwrite defaults.filter - # filter: ["{ id: { eq: $user_id } }"] - # filter_query: ["{ id: { eq: $user_id } }"] - filter_update: ["{ id: { eq: $user_id } }"] - filter_delete: ["{ id: { eq: $user_id } }"] + - name: users + # This filter will overwrite defaults.filter + # filter: ["{ id: { eq: $user_id } }"] - - name: products - # Multiple filters are AND'd together - filter: [ - "{ price: { gt: 0 } }", - "{ price: { lt: 8 } }" - ] + # - name: products + # # Multiple filters are AND'd together + # filter: [ + # "{ price: { gt: 0 } }", + # "{ price: { lt: 8 } }" + # ] - - name: customers - # No filter is used for this field not - # even defaults.filter - filter: none + - name: customers + remotes: + - name: payments + id: stripe_id + url: http://rails_app:3000/stripe/$id + path: data + # debug: true + pass_headers: + - cookie + set_headers: + - name: Host + value: 0.0.0.0 + # - name: Authorization + # value: Bearer - # remotes: - # - name: payments - # id: stripe_id - # url: http://rails_app:3000/stripe/$id - # path: data - # # pass_headers: - # # - cookie - # # - host - # set_headers: - # - name: Authorization - # value: Bearer + - # You can create new fields that have a + # real db table backing them + name: me + table: users - - # You can create new fields that have a - # real db table backing them - name: me - table: users - filter: ["{ id: { eq: $user_id } }"] +roles_query: "SELECT * FROM users as usr WHERE id = $user_id" - # - name: posts - # filter: ["{ account_id: { _eq: $account_id } }"] \ No newline at end of file +roles: + - name: anon + tables: + - name: products + limit: 10 + + query: + columns: ["id", "name", "description" ] + aggregation: false + + insert: + allow: false + + update: + allow: false + + delete: + allow: false + + - name: user + tables: + - name: users + query: + filter: ["{ id: { _eq: $user_id } }"] + + - name: products + + query: + limit: 50 + filter: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + disable_aggregation: false + + insert: + filter: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + set: + - created_at: "now" + + update: + filter: ["{ user_id: { eq: $user_id } }"] + columns: + - id + - name + set: + - updated_at: "now" + + delete: + deny: true + + - name: admin + match: id = 1 + tables: + - name: users + + # select: + # filter: ["{ account_id: { _eq: $account_id } }"]