From 85a74ed30c33eb8253e1f884cb2abd6616fe9d1a Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Thu, 10 Oct 2019 01:35:35 -0400 Subject: [PATCH 1/6] Update filters section in guide --- docs/guide.md | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/docs/guide.md b/docs/guide.md index e0aacef..13d6338 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -1144,7 +1144,10 @@ database: tables: - name: users # This filter will overwrite defaults.filter - filter: ["{ id: { eq: $user_id } }"] + # filter: ["{ id: { eq: $user_id } }"] + # filter_query: ["{ id: { eq: $user_id } }"] + filter_update: ["{ id: { eq: $user_id } }"] + filter_delete: ["{ id: { eq: $user_id } }"] - name: products # Multiple filters are AND'd together From deb5b93c81b2e13f9049524babb113efddfd25c7 Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Mon, 14 Oct 2019 02:51:36 -0400 Subject: [PATCH 2/6] Add role based access control --- config/dev.yml | 62 ++++- psql/{insert.go => mutate.go} | 40 +-- psql/{insert_test.go => mutate_test.go} | 18 +- psql/{select.go => query.go} | 120 ++++++--- psql/{select_test.go => query_test.go} | 165 ++++++++---- qcode/config.go | 99 ++++++++ qcode/fuzz.go | 2 +- qcode/parse_test.go | 43 +++- qcode/qcode.go | 323 ++++++++++++------------ serv/config.go | 49 +++- serv/core.go | 2 +- serv/prepare.go | 2 +- serv/serv.go | 70 ++--- 13 files changed, 645 insertions(+), 350 deletions(-) rename psql/{insert.go => mutate.go} (87%) rename psql/{insert_test.go => mutate_test.go} (93%) rename psql/{select.go => query.go} (93%) rename psql/{select_test.go => query_test.go} (84%) create mode 100644 qcode/config.go diff --git a/config/dev.yml b/config/dev.yml index 7c3ba46..a314dcd 100644 --- a/config/dev.yml +++ b/config/dev.yml @@ -115,9 +115,6 @@ tables: - name: users # This filter will overwrite defaults.filter # filter: ["{ id: { eq: $user_id } }"] - # filter_query: ["{ id: { eq: $user_id } }"] - filter_update: ["{ id: { eq: $user_id } }"] - filter_delete: ["{ id: { eq: $user_id } }"] # - name: products # # Multiple filters are AND'd together @@ -127,10 +124,6 @@ tables: # ] - name: customers - # No filter is used for this field not - # even defaults.filter - filter: none - remotes: - name: payments id: stripe_id @@ -149,7 +142,56 @@ tables: # real db table backing them name: me table: users - filter: ["{ id: { eq: $user_id } }"] - # - name: posts - # filter: ["{ account_id: { _eq: $account_id } }"] \ No newline at end of file +roles: + - name: anon + tables: + - name: products + limit: 10 + + query: + columns: ["id", "name", "description" ] + aggregation: false + + insert: + allow: false + + update: + allow: false + + delete: + allow: false + + - name: user + tables: + - name: products + + query: + limit: 50 + filter: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + disable_aggregation: false + + insert: + filter: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + set: + - created_at: "now" + + update: + filter: ["{ user_id: { eq: $user_id } }"] + columns: + - id + - name + set: + - updated_at: "now" + + delete: + deny: true + + - name: manager + tables: + - name: users + + select: + filter: ["{ account_id: { _eq: $account_id } }"] diff --git a/psql/insert.go b/psql/mutate.go similarity index 87% rename from psql/insert.go rename to psql/mutate.go index 027588e..63b3a5a 100644 --- a/psql/insert.go +++ b/psql/mutate.go @@ -10,7 +10,7 @@ import ( "github.com/dosco/super-graph/qcode" ) -var zeroPaging = qcode.Paging{} +var noLimit = qcode.Paging{NoLimit: true} func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) { if len(qc.Selects) == 0 { @@ -29,23 +29,23 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia quoted(c.w, ti.Name) c.w.WriteString(` AS `) - switch root.Action { - case qcode.ActionInsert: + switch qc.Type { + case qcode.QTInsert: if _, err := c.renderInsert(qc, w, vars, ti); err != nil { return 0, err } - case qcode.ActionUpdate: + case qcode.QTUpdate: if _, err := c.renderUpdate(qc, w, vars, ti); err != nil { return 0, err } - case qcode.ActionUpsert: + case qcode.QTUpsert: if _, err := c.renderUpsert(qc, w, vars, ti); err != nil { return 0, err } - case qcode.ActionDelete: + case qcode.QTDelete: if _, err := c.renderDelete(qc, w, vars, ti); err != nil { return 0, err } @@ -56,22 +56,23 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia io.WriteString(c.w, ` RETURNING *) `) - root.Paging = zeroPaging + root.Paging = noLimit root.DistinctOn = root.DistinctOn[:] root.OrderBy = root.OrderBy[:] root.Where = nil root.Args = nil + qc.Type = qcode.QTQuery + return c.compileQuery(qc, w) } func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer, vars Variables, ti *DBTableInfo) (uint32, error) { - root := &qc.Selects[0] - insert, ok := vars[root.ActionVar] + insert, ok := vars[qc.ActionVar] if !ok { - return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar) + return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar) } jt, array, err := jsn.Tree(insert) @@ -80,7 +81,7 @@ func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer, } c.w.WriteString(`(WITH "input" AS (SELECT {{`) - c.w.WriteString(root.ActionVar) + c.w.WriteString(qc.ActionVar) c.w.WriteString(`}}::json AS j) INSERT INTO `) quoted(c.w, ti.Name) io.WriteString(c.w, ` (`) @@ -106,12 +107,18 @@ func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer, func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w *bytes.Buffer, jt map[string]interface{}, ti *DBTableInfo) (uint32, error) { + root := &qc.Selects[0] i := 0 for _, cn := range ti.ColumnNames { if _, ok := jt[cn]; !ok { continue } + if len(root.Allowed) != 0 { + if _, ok := root.Allowed[cn]; !ok { + continue + } + } if i != 0 { io.WriteString(c.w, `, `) } @@ -126,9 +133,9 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer, vars Variables, ti *DBTableInfo) (uint32, error) { root := &qc.Selects[0] - update, ok := vars[root.ActionVar] + update, ok := vars[qc.ActionVar] if !ok { - return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar) + return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar) } jt, array, err := jsn.Tree(update) @@ -137,7 +144,7 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer, } c.w.WriteString(`(WITH "input" AS (SELECT {{`) - c.w.WriteString(root.ActionVar) + c.w.WriteString(qc.ActionVar) c.w.WriteString(`}}::json AS j) UPDATE `) quoted(c.w, ti.Name) io.WriteString(c.w, ` SET (`) @@ -183,11 +190,10 @@ func (c *compilerContext) renderDelete(qc *qcode.QCode, w *bytes.Buffer, func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer, vars Variables, ti *DBTableInfo) (uint32, error) { - root := &qc.Selects[0] - upsert, ok := vars[root.ActionVar] + upsert, ok := vars[qc.ActionVar] if !ok { - return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar) + return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar) } jt, _, err := jsn.Tree(upsert) diff --git a/psql/insert_test.go b/psql/mutate_test.go similarity index 93% rename from psql/insert_test.go rename to psql/mutate_test.go index bd03d7f..8555f74 100644 --- a/psql/insert_test.go +++ b/psql/mutate_test.go @@ -18,7 +18,7 @@ func simpleInsert(t *testing.T) { "data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`), } - resSQL, err := compileGQLToPSQL(gql, vars) + resSQL, err := compileGQLToPSQL(gql, vars, "user") if err != nil { t.Fatal(err) } @@ -42,7 +42,7 @@ func singleInsert(t *testing.T) { "insert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc", "user_id": 5 }`), } - resSQL, err := compileGQLToPSQL(gql, vars) + resSQL, err := compileGQLToPSQL(gql, vars, "user") if err != nil { t.Fatal(err) } @@ -54,7 +54,7 @@ func singleInsert(t *testing.T) { func bulkInsert(t *testing.T) { gql := `mutation { - product(id: 15, insert: $insert) { + product(name: "test", id: 15, insert: $insert) { id name } @@ -66,7 +66,7 @@ func bulkInsert(t *testing.T) { "insert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`), } - resSQL, err := compileGQLToPSQL(gql, vars) + resSQL, err := compileGQLToPSQL(gql, vars, "user") if err != nil { t.Fatal(err) } @@ -90,7 +90,7 @@ func singleUpsert(t *testing.T) { "upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), } - resSQL, err := compileGQLToPSQL(gql, vars) + resSQL, err := compileGQLToPSQL(gql, vars, "user") if err != nil { t.Fatal(err) } @@ -114,7 +114,7 @@ func bulkUpsert(t *testing.T) { "upsert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`), } - resSQL, err := compileGQLToPSQL(gql, vars) + resSQL, err := compileGQLToPSQL(gql, vars, "user") if err != nil { t.Fatal(err) } @@ -138,7 +138,7 @@ func singleUpdate(t *testing.T) { "update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), } - resSQL, err := compileGQLToPSQL(gql, vars) + resSQL, err := compileGQLToPSQL(gql, vars, "user") if err != nil { t.Fatal(err) } @@ -162,7 +162,7 @@ func delete(t *testing.T) { "update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), } - resSQL, err := compileGQLToPSQL(gql, vars) + resSQL, err := compileGQLToPSQL(gql, vars, "user") if err != nil { t.Fatal(err) } @@ -172,7 +172,7 @@ func delete(t *testing.T) { } } -func TestCompileInsert(t *testing.T) { +func TestCompileMutate(t *testing.T) { t.Run("simpleInsert", simpleInsert) t.Run("singleInsert", singleInsert) t.Run("bulkInsert", bulkInsert) diff --git a/psql/select.go b/psql/query.go similarity index 93% rename from psql/select.go rename to psql/query.go index 1058a62..8d70e43 100644 --- a/psql/select.go +++ b/psql/query.go @@ -64,11 +64,11 @@ func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (u switch qc.Type { case qcode.QTQuery: return co.compileQuery(qc, w) - case qcode.QTMutation: + case qcode.QTInsert, qcode.QTUpdate, qcode.QTDelete, qcode.QTUpsert: return co.compileMutation(qc, w, vars) } - return 0, errors.New("unknown operation") + return 0, fmt.Errorf("Unknown operation type %d", qc.Type) } func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) { @@ -295,19 +295,21 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo) } } - if sel.Action == 0 { - if len(sel.Paging.Limit) != 0 { - //fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit) - c.w.WriteString(` LIMIT ('`) - c.w.WriteString(sel.Paging.Limit) - c.w.WriteString(`') :: integer`) + switch { + case sel.Paging.NoLimit: + break - } else if ti.Singular { - c.w.WriteString(` LIMIT ('1') :: integer`) + case len(sel.Paging.Limit) != 0: + //fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit) + c.w.WriteString(` LIMIT ('`) + c.w.WriteString(sel.Paging.Limit) + c.w.WriteString(`') :: integer`) - } else { - c.w.WriteString(` LIMIT ('20') :: integer`) - } + case ti.Singular: + c.w.WriteString(` LIMIT ('1') :: integer`) + + default: + c.w.WriteString(` LIMIT ('20') :: integer`) } if len(sel.Paging.Offset) != 0 { @@ -370,13 +372,31 @@ func (c *compilerContext) renderJoinTable(sel *qcode.Select) error { } func (c *compilerContext) renderColumns(sel *qcode.Select, ti *DBTableInfo) { - for i, col := range sel.Cols { + i := 0 + for _, col := range sel.Cols { + if len(sel.Allowed) != 0 { + n := funcPrefixLen(col.Name) + if n != 0 { + if sel.Functions == false { + continue + } + if _, ok := sel.Allowed[col.Name[n:]]; !ok { + continue + } + } else { + if _, ok := sel.Allowed[col.Name]; !ok { + continue + } + } + } + if i != 0 { io.WriteString(c.w, ", ") } //fmt.Fprintf(w, `"%s_%d"."%s" AS "%s"`, //c.sel.Table, c.sel.ID, col.Name, col.FieldName) colWithTableIDAlias(c.w, ti.Name, sel.ID, col.Name, col.FieldName) + i++ } } @@ -435,7 +455,8 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, c.w.WriteString(` FROM (SELECT `) - for i, col := range sel.Cols { + i := 0 + for n, col := range sel.Cols { cn := col.Name _, isRealCol := ti.Columns[cn] @@ -447,6 +468,9 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, cn = ti.TSVCol arg := sel.Args["search"] + if i != 0 { + c.w.WriteString(`, `) + } //fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`, //c.sel.Table, cn, arg.Val, col.Name) c.w.WriteString(`ts_rank(`) @@ -455,11 +479,15 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, c.w.WriteString(arg.Val) c.w.WriteString(`')`) alias(c.w, col.Name) + i++ case strings.HasPrefix(cn, "search_headline_"): cn = cn[16:] arg := sel.Args["search"] + if i != 0 { + c.w.WriteString(`, `) + } //fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`, //c.sel.Table, cn, arg.Val, col.Name) c.w.WriteString(`ts_headlinek(`) @@ -468,47 +496,63 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, c.w.WriteString(arg.Val) c.w.WriteString(`')`) alias(c.w, col.Name) + i++ + } } else { pl := funcPrefixLen(cn) if pl == 0 { + if i != 0 { + c.w.WriteString(`, `) + } //fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name) c.w.WriteString(`'`) c.w.WriteString(cn) c.w.WriteString(` not defined'`) alias(c.w, col.Name) - } else { - isAgg = true + i++ + + } else if sel.Functions { + cn1 := cn[pl:] + if _, ok := sel.Allowed[cn1]; !ok { + continue + } + if i != 0 { + c.w.WriteString(`, `) + } fn := cn[0 : pl-1] - cn := cn[pl:] + isAgg = true + //fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, c.sel.Table, cn, col.Name) c.w.WriteString(fn) c.w.WriteString(`(`) - colWithTable(c.w, ti.Name, cn) + colWithTable(c.w, ti.Name, cn1) c.w.WriteString(`)`) alias(c.w, col.Name) + i++ + } } } else { - groupBy = append(groupBy, i) + groupBy = append(groupBy, n) //fmt.Fprintf(w, `"%s"."%s"`, c.sel.Table, cn) + if i != 0 { + c.w.WriteString(`, `) + } colWithTable(c.w, ti.Name, cn) - } + i++ - if i < len(sel.Cols)-1 || len(childCols) != 0 { - //io.WriteString(w, ", ") - c.w.WriteString(`, `) } } - for i, col := range childCols { + for _, col := range childCols { if i != 0 { - //io.WriteString(w, ", ") c.w.WriteString(`, `) } //fmt.Fprintf(w, `"%s"."%s"`, col.Table, col.Name) colWithTable(c.w, col.Table, col.Name) + i++ } c.w.WriteString(` FROM `) @@ -570,19 +614,21 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, } } - if sel.Action == 0 { - if len(sel.Paging.Limit) != 0 { - //fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit) - c.w.WriteString(` LIMIT ('`) - c.w.WriteString(sel.Paging.Limit) - c.w.WriteString(`') :: integer`) + switch { + case sel.Paging.NoLimit: + break - } else if ti.Singular { - c.w.WriteString(` LIMIT ('1') :: integer`) + case len(sel.Paging.Limit) != 0: + //fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit) + c.w.WriteString(` LIMIT ('`) + c.w.WriteString(sel.Paging.Limit) + c.w.WriteString(`') :: integer`) - } else { - c.w.WriteString(` LIMIT ('20') :: integer`) - } + case ti.Singular: + c.w.WriteString(` LIMIT ('1') :: integer`) + + default: + c.w.WriteString(` LIMIT ('20') :: integer`) } if len(sel.Paging.Offset) != 0 { diff --git a/psql/select_test.go b/psql/query_test.go similarity index 84% rename from psql/select_test.go rename to psql/query_test.go index 553f576..ea65645 100644 --- a/psql/select_test.go +++ b/psql/query_test.go @@ -22,32 +22,6 @@ func TestMain(m *testing.M) { var err error qcompile, err = qcode.NewCompiler(qcode.Config{ - DefaultFilter: []string{ - `{ user_id: { _eq: $user_id } }`, - }, - FilterMap: qcode.Filters{ - All: map[string][]string{ - "users": []string{ - "{ id: { eq: $user_id } }", - }, - "products": []string{ - "{ price: { gt: 0 } }", - "{ price: { lt: 8 } }", - }, - "customers": []string{}, - "mes": []string{ - "{ id: { eq: $user_id } }", - }, - }, - Query: map[string][]string{ - "users": []string{}, - }, - Update: map[string][]string{ - "products": []string{ - "{ user_id: { eq: $user_id } }", - }, - }, - }, Blocklist: []string{ "secret", "password", @@ -55,6 +29,59 @@ func TestMain(m *testing.M) { }, }) + qcompile.AddRole("user", "product", qcode.TRConfig{ + Query: qcode.QueryConfig{ + Columns: []string{"id", "name", "price", "users", "customers"}, + Filter: []string{ + "{ price: { gt: 0 } }", + "{ price: { lt: 8 } }", + }, + }, + Update: qcode.UpdateConfig{ + Filter: []string{"{ user_id: { eq: $user_id } }"}, + }, + Delete: qcode.DeleteConfig{ + Filter: []string{ + "{ price: { gt: 0 } }", + "{ price: { lt: 8 } }", + }, + }, + }) + + qcompile.AddRole("anon", "product", qcode.TRConfig{ + Query: qcode.QueryConfig{ + Columns: []string{"id", "name"}, + }, + }) + + qcompile.AddRole("anon1", "product", qcode.TRConfig{ + Query: qcode.QueryConfig{ + Columns: []string{"id", "name", "price"}, + DisableFunctions: true, + }, + }) + + qcompile.AddRole("user", "users", qcode.TRConfig{ + Query: qcode.QueryConfig{ + Columns: []string{"id", "full_name", "avatar", "email", "products"}, + }, + }) + + qcompile.AddRole("user", "mes", qcode.TRConfig{ + Query: qcode.QueryConfig{ + Columns: []string{"id", "full_name", "avatar"}, + Filter: []string{ + "{ id: { eq: $user_id } }", + }, + }, + }) + + qcompile.AddRole("user", "customers", qcode.TRConfig{ + Query: qcode.QueryConfig{ + Columns: []string{"id", "email", "full_name", "products"}, + }, + }) + if err != nil { log.Fatal(err) } @@ -135,9 +162,8 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -func compileGQLToPSQL(gql string, vars Variables) ([]byte, error) { - - qc, err := qcompile.Compile([]byte(gql)) +func compileGQLToPSQL(gql string, vars Variables, role string) ([]byte, error) { + qc, err := qcompile.Compile([]byte(gql), role) if err != nil { return nil, err } @@ -147,6 +173,8 @@ func compileGQLToPSQL(gql string, vars Variables) ([]byte, error) { return nil, err } + //fmt.Println(string(sqlStmt)) + return sqlStmt, nil } @@ -175,7 +203,7 @@ func withComplexArgs(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") < 28) AND (("products"."id") >= 20)) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "sel_json_agg_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -203,7 +231,7 @@ func withWhereMultiOr(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") < 20) OR (("products"."price") > 10) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -229,7 +257,7 @@ func withWhereIsNull(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -255,7 +283,7 @@ func withWhereAndList(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -275,7 +303,7 @@ func fetchByID(t *testing.T) { sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 15)) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -295,7 +323,7 @@ func searchQuery(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -318,7 +346,7 @@ func oneToMany(t *testing.T) { sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email", "users"."id" FROM "users" LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -341,7 +369,7 @@ func belongsTo(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "sel_json_1" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -364,7 +392,7 @@ func manyToMany(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -387,7 +415,7 @@ func manyToManyReverse(t *testing.T) { sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -407,7 +435,47 @@ func aggFunction(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") + if err != nil { + t.Fatal(err) + } + + if string(resSQL) != sql { + t.Fatal(errNotExpected) + } +} + +func aggFunctionBlockedByCol(t *testing.T) { + gql := `query { + products { + name + count_price + } + }` + + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + + resSQL, err := compileGQLToPSQL(gql, nil, "anon") + if err != nil { + t.Fatal(err) + } + + if string(resSQL) != sql { + t.Fatal(errNotExpected) + } +} + +func aggFunctionDisabled(t *testing.T) { + gql := `query { + products { + name + count_price + } + }` + + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + + resSQL, err := compileGQLToPSQL(gql, nil, "anon1") if err != nil { t.Fatal(err) } @@ -427,7 +495,7 @@ func aggFunctionWithFilter(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") > 10)) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -447,7 +515,7 @@ func queryWithVariables(t *testing.T) { sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") = {{product_price}}) AND (("products"."id") = {{product_id}})) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -464,9 +532,9 @@ func syntheticTables(t *testing.T) { } }` - sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337";` + sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT ) AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -476,7 +544,7 @@ func syntheticTables(t *testing.T) { } } -func TestCompileSelect(t *testing.T) { +func TestCompileQuery(t *testing.T) { t.Run("withComplexArgs", withComplexArgs) t.Run("withWhereAndList", withWhereAndList) t.Run("withWhereIsNull", withWhereIsNull) @@ -488,10 +556,11 @@ func TestCompileSelect(t *testing.T) { t.Run("manyToMany", manyToMany) t.Run("manyToManyReverse", manyToManyReverse) t.Run("aggFunction", aggFunction) + t.Run("aggFunctionBlockedByCol", aggFunctionBlockedByCol) + t.Run("aggFunctionDisabled", aggFunctionDisabled) t.Run("aggFunctionWithFilter", aggFunctionWithFilter) t.Run("syntheticTables", syntheticTables) t.Run("queryWithVariables", queryWithVariables) - } var benchGQL = []byte(`query { @@ -526,7 +595,7 @@ func BenchmarkCompile(b *testing.B) { for n := 0; n < b.N; n++ { w.Reset() - qc, err := qcompile.Compile(benchGQL) + qc, err := qcompile.Compile(benchGQL, "user") if err != nil { b.Fatal(err) } @@ -547,7 +616,7 @@ func BenchmarkCompileParallel(b *testing.B) { for pb.Next() { w.Reset() - qc, err := qcompile.Compile(benchGQL) + qc, err := qcompile.Compile(benchGQL, "user") if err != nil { b.Fatal(err) } diff --git a/qcode/config.go b/qcode/config.go new file mode 100644 index 0000000..b2ab9c4 --- /dev/null +++ b/qcode/config.go @@ -0,0 +1,99 @@ +package qcode + +type Config struct { + Blocklist []string + KeepArgs bool +} + +type QueryConfig struct { + Limit int + Filter []string + Columns []string + DisableFunctions bool +} + +type InsertConfig struct { + Filter []string + Columns []string + Set map[string]string +} + +type UpdateConfig struct { + Filter []string + Columns []string + Set map[string]string +} + +type DeleteConfig struct { + Filter []string + Columns []string +} + +type TRConfig struct { + Query QueryConfig + Insert InsertConfig + Update UpdateConfig + Delete DeleteConfig +} + +type trval struct { + query struct { + limit string + fil *Exp + cols map[string]struct{} + disable struct { + funcs bool + } + } + + insert struct { + fil *Exp + cols map[string]struct{} + set map[string]string + } + + update struct { + fil *Exp + cols map[string]struct{} + set map[string]string + } + + delete struct { + fil *Exp + cols map[string]struct{} + } +} + +func (trv *trval) allowedColumns(qt QType) map[string]struct{} { + switch qt { + case QTQuery: + return trv.query.cols + case QTInsert: + return trv.insert.cols + case QTUpdate: + return trv.update.cols + case QTDelete: + return trv.insert.cols + case QTUpsert: + return trv.insert.cols + } + + return nil +} + +func (trv *trval) filter(qt QType) *Exp { + switch qt { + case QTQuery: + return trv.query.fil + case QTInsert: + return trv.insert.fil + case QTUpdate: + return trv.update.fil + case QTDelete: + return trv.delete.fil + case QTUpsert: + return trv.insert.fil + } + + return nil +} diff --git a/qcode/fuzz.go b/qcode/fuzz.go index db8f3c8..89a4a3c 100644 --- a/qcode/fuzz.go +++ b/qcode/fuzz.go @@ -5,7 +5,7 @@ func FuzzerEntrypoint(data []byte) int { //testData := string(data) qcompile, _ := NewCompiler(Config{}) - _, err := qcompile.Compile(data) + _, err := qcompile.Compile(data, "user") if err != nil { return -1 } diff --git a/qcode/parse_test.go b/qcode/parse_test.go index dba397f..0e04ed2 100644 --- a/qcode/parse_test.go +++ b/qcode/parse_test.go @@ -46,13 +46,18 @@ func compareOp(op1, op2 Operation) error { */ func TestCompile1(t *testing.T) { - qcompile, _ := NewCompiler(Config{}) + qc, _ := NewCompiler(Config{}) + qc.AddRole("user", "product", TRConfig{ + Query: QueryConfig{ + Columns: []string{"id", "Name"}, + }, + }) - _, err := qcompile.Compile([]byte(` + _, err := qc.Compile([]byte(` product(id: 15) { id name - }`)) + }`), "user") if err != nil { t.Fatal(err) @@ -60,13 +65,18 @@ func TestCompile1(t *testing.T) { } func TestCompile2(t *testing.T) { - qcompile, _ := NewCompiler(Config{}) + qc, _ := NewCompiler(Config{}) + qc.AddRole("user", "product", TRConfig{ + Query: QueryConfig{ + Columns: []string{"ID"}, + }, + }) - _, err := qcompile.Compile([]byte(` + _, err := qc.Compile([]byte(` query { product(id: 15) { id name - } }`)) + } }`), "user") if err != nil { t.Fatal(err) @@ -74,15 +84,20 @@ func TestCompile2(t *testing.T) { } func TestCompile3(t *testing.T) { - qcompile, _ := NewCompiler(Config{}) + qc, _ := NewCompiler(Config{}) + qc.AddRole("user", "product", TRConfig{ + Query: QueryConfig{ + Columns: []string{"ID"}, + }, + }) - _, err := qcompile.Compile([]byte(` + _, err := qc.Compile([]byte(` mutation { product(id: 15, name: "Test") { id name } - }`)) + }`), "user") if err != nil { t.Fatal(err) @@ -91,7 +106,7 @@ func TestCompile3(t *testing.T) { func TestInvalidCompile1(t *testing.T) { qcompile, _ := NewCompiler(Config{}) - _, err := qcompile.Compile([]byte(`#`)) + _, err := qcompile.Compile([]byte(`#`), "user") if err == nil { t.Fatal(errors.New("expecting an error")) @@ -100,7 +115,7 @@ func TestInvalidCompile1(t *testing.T) { func TestInvalidCompile2(t *testing.T) { qcompile, _ := NewCompiler(Config{}) - _, err := qcompile.Compile([]byte(`{u(where:{not:0})}`)) + _, err := qcompile.Compile([]byte(`{u(where:{not:0})}`), "user") if err == nil { t.Fatal(errors.New("expecting an error")) @@ -109,7 +124,7 @@ func TestInvalidCompile2(t *testing.T) { func TestEmptyCompile(t *testing.T) { qcompile, _ := NewCompiler(Config{}) - _, err := qcompile.Compile([]byte(``)) + _, err := qcompile.Compile([]byte(``), "user") if err == nil { t.Fatal(errors.New("expecting an error")) @@ -144,7 +159,7 @@ func BenchmarkQCompile(b *testing.B) { b.ReportAllocs() for n := 0; n < b.N; n++ { - _, err := qcompile.Compile(gql) + _, err := qcompile.Compile(gql, "user") if err != nil { b.Fatal(err) @@ -160,7 +175,7 @@ func BenchmarkQCompileP(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - _, err := qcompile.Compile(gql) + _, err := qcompile.Compile(gql, "user") if err != nil { b.Fatal(err) diff --git a/qcode/qcode.go b/qcode/qcode.go index bafe3e4..05e8e98 100644 --- a/qcode/qcode.go +++ b/qcode/qcode.go @@ -3,6 +3,7 @@ package qcode import ( "errors" "fmt" + "strconv" "strings" "sync" @@ -17,23 +18,16 @@ const ( maxSelectors = 30 QTQuery QType = iota + 1 - QTMutation - - ActionInsert Action = iota + 1 - ActionUpdate - ActionDelete - ActionUpsert + QTInsert + QTUpdate + QTDelete + QTUpsert ) type QCode struct { - Type QType - Selects []Select -} - -type Column struct { - Table string - Name string - FieldName string + Type QType + ActionVar string + Selects []Select } type Select struct { @@ -47,9 +41,15 @@ type Select struct { OrderBy []*OrderBy DistinctOn []string Paging Paging - Action Action - ActionVar string Children []int32 + Functions bool + Allowed map[string]struct{} +} + +type Column struct { + Table string + Name string + FieldName string } type Exp struct { @@ -77,8 +77,9 @@ type OrderBy struct { } type Paging struct { - Limit string - Offset string + Limit string + Offset string + NoLimit bool } type ExpOp int @@ -145,81 +146,23 @@ const ( OrderDescNullsLast ) -type Filters struct { - All map[string][]string - Query map[string][]string - Insert map[string][]string - Update map[string][]string - Delete map[string][]string -} - -type Config struct { - DefaultFilter []string - FilterMap Filters - Blocklist []string - KeepArgs bool -} - type Compiler struct { - df *Exp - fm struct { - all map[string]*Exp - query map[string]*Exp - insert map[string]*Exp - update map[string]*Exp - delete map[string]*Exp - } + tr map[string]map[string]*trval bl map[string]struct{} ka bool } -var opMap = map[parserType]QType{ - opQuery: QTQuery, - opMutate: QTMutation, -} - var expPool = sync.Pool{ New: func() interface{} { return &Exp{doFree: true} }, } func NewCompiler(c Config) (*Compiler, error) { - var err error co := &Compiler{ka: c.KeepArgs} - + co.tr = make(map[string]map[string]*trval) co.bl = make(map[string]struct{}, len(c.Blocklist)) for i := range c.Blocklist { - co.bl[c.Blocklist[i]] = struct{}{} - } - - co.df, err = compileFilter(c.DefaultFilter) - if err != nil { - return nil, err - } - - co.fm.all, err = buildFilters(c.FilterMap.All) - if err != nil { - return nil, err - } - - co.fm.query, err = buildFilters(c.FilterMap.Query) - if err != nil { - return nil, err - } - - co.fm.insert, err = buildFilters(c.FilterMap.Insert) - if err != nil { - return nil, err - } - - co.fm.update, err = buildFilters(c.FilterMap.Update) - if err != nil { - return nil, err - } - - co.fm.delete, err = buildFilters(c.FilterMap.Delete) - if err != nil { - return nil, err + co.bl[strings.ToLower(c.Blocklist[i])] = struct{}{} } seedExp := [100]Exp{} @@ -232,58 +175,99 @@ func NewCompiler(c Config) (*Compiler, error) { return co, nil } -func buildFilters(filMap map[string][]string) (map[string]*Exp, error) { - fm := make(map[string]*Exp, len(filMap)) +func (com *Compiler) AddRole(role, table string, trc TRConfig) error { + var err error + trv := &trval{} - for k, v := range filMap { - fil, err := compileFilter(v) - if err != nil { - return nil, err + toMap := func(cols []string) map[string]struct{} { + m := make(map[string]struct{}, len(cols)) + for i := range cols { + m[strings.ToLower(cols[i])] = struct{}{} } - singular := flect.Singularize(k) - plural := flect.Pluralize(k) - - fm[singular] = fil - fm[plural] = fil + return m } - return fm, nil + // query config + trv.query.fil, err = compileFilter(trc.Query.Filter) + if err != nil { + return err + } + if trc.Query.Limit > 0 { + trv.query.limit = strconv.Itoa(trc.Query.Limit) + } + trv.query.cols = toMap(trc.Query.Columns) + trv.query.disable.funcs = trc.Query.DisableFunctions + + // insert config + if trv.insert.fil, err = compileFilter(trc.Insert.Filter); err != nil { + return err + } + trv.insert.cols = toMap(trc.Insert.Columns) + + // update config + if trv.update.fil, err = compileFilter(trc.Update.Filter); err != nil { + return err + } + trv.insert.cols = toMap(trc.Insert.Columns) + trv.insert.set = trc.Insert.Set + + // delete config + if trv.delete.fil, err = compileFilter(trc.Delete.Filter); err != nil { + return err + } + trv.delete.cols = toMap(trc.Delete.Columns) + + singular := flect.Singularize(table) + plural := flect.Pluralize(table) + + if _, ok := com.tr[role]; !ok { + com.tr[role] = make(map[string]*trval) + } + + com.tr[role][singular] = trv + com.tr[role][plural] = trv + return nil } -func (com *Compiler) Compile(query []byte) (*QCode, error) { - var qc QCode +func (com *Compiler) Compile(query []byte, role string) (*QCode, error) { var err error + qc := QCode{Type: QTQuery} + op, err := Parse(query) if err != nil { return nil, err } - qc.Selects, err = com.compileQuery(op) - if err != nil { + if err = com.compileQuery(&qc, op, role); err != nil { return nil, err } - if t, ok := opMap[op.Type]; ok { - qc.Type = t - } else { - return nil, fmt.Errorf("Unknown operation type %d", op.Type) - } - opPool.Put(op) return &qc, nil } -func (com *Compiler) compileQuery(op *Operation) ([]Select, error) { +func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error { id := int32(0) parentID := int32(0) + if len(op.Fields) == 0 { + return errors.New("invalid graphql no query found") + } + + if op.Type == opMutate { + if err := com.setMutationType(qc, op.Fields[0].Args); err != nil { + return err + } + } + selects := make([]Select, 0, 5) st := NewStack() + action := qc.Type if len(op.Fields) == 0 { - return nil, errors.New("empty query") + return errors.New("empty query") } st.Push(op.Fields[0].ID) @@ -293,7 +277,7 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) { } if id >= maxSelectors { - return nil, fmt.Errorf("selector limit reached (%d)", maxSelectors) + return fmt.Errorf("selector limit reached (%d)", maxSelectors) } fid := st.Pop() @@ -303,14 +287,28 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) { continue } + trv, ok := com.tr[role][field.Name] + if !ok { + continue + } + selects = append(selects, Select{ ID: id, ParentID: parentID, Table: field.Name, Children: make([]int32, 0, 5), + Allowed: trv.allowedColumns(action), }) s := &selects[(len(selects) - 1)] + if action == QTQuery { + s.Functions = !trv.query.disable.funcs + + if len(trv.query.limit) != 0 { + s.Paging.Limit = trv.query.limit + } + } + if s.ID != 0 { p := &selects[s.ParentID] p.Children = append(p.Children, s.ID) @@ -322,12 +320,13 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) { s.FieldName = s.Table } - err := com.compileArgs(s, field.Args) + err := com.compileArgs(qc, s, field.Args) if err != nil { - return nil, err + return err } s.Cols = make([]Column, 0, len(field.Children)) + action = QTQuery for _, cid := range field.Children { f := op.Fields[cid] @@ -356,36 +355,14 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) { } if id == 0 { - return nil, errors.New("invalid query") + return errors.New("invalid query") } var fil *Exp - root := &selects[0] - switch op.Type { - case opQuery: - fil, _ = com.fm.query[root.Table] - - case opMutate: - switch root.Action { - case ActionInsert: - fil, _ = com.fm.insert[root.Table] - case ActionUpdate: - fil, _ = com.fm.update[root.Table] - case ActionDelete: - fil, _ = com.fm.delete[root.Table] - case ActionUpsert: - fil, _ = com.fm.insert[root.Table] - } - } - - if fil == nil { - fil, _ = com.fm.all[root.Table] - } - - if fil == nil { - fil = com.df + if trv, ok := com.tr[role][op.Fields[0].Name]; ok { + fil = trv.filter(qc.Type) } if fil != nil && fil.Op != OpNop { @@ -403,10 +380,11 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) { } } - return selects[:id], nil + qc.Selects = selects[:id] + return nil } -func (com *Compiler) compileArgs(sel *Select, args []Arg) error { +func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg) error { var err error if com.ka { @@ -418,9 +396,7 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error { switch arg.Name { case "id": - if sel.ID == 0 { - err = com.compileArgID(sel, arg) - } + err = com.compileArgID(sel, arg) case "search": err = com.compileArgSearch(sel, arg) case "where": @@ -433,18 +409,6 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error { err = com.compileArgLimit(sel, arg) case "offset": err = com.compileArgOffset(sel, arg) - case "insert": - sel.Action = ActionInsert - err = com.compileArgAction(sel, arg) - case "update": - sel.Action = ActionUpdate - err = com.compileArgAction(sel, arg) - case "upsert": - sel.Action = ActionUpsert - err = com.compileArgAction(sel, arg) - case "delete": - sel.Action = ActionDelete - err = com.compileArgAction(sel, arg) } if err != nil { @@ -461,6 +425,45 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error { return nil } +func (com *Compiler) setMutationType(qc *QCode, args []Arg) error { + setActionVar := func(arg *Arg) error { + if arg.Val.Type != nodeVar { + return fmt.Errorf("value for argument '%s' must be a variable", arg.Name) + } + qc.ActionVar = arg.Val.Val + return nil + } + + for i := range args { + arg := &args[i] + + switch arg.Name { + case "insert": + qc.Type = QTInsert + return setActionVar(arg) + case "update": + qc.Type = QTUpdate + return setActionVar(arg) + case "upsert": + qc.Type = QTUpsert + return setActionVar(arg) + case "delete": + qc.Type = QTDelete + + if arg.Val.Type != nodeBool { + return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name) + } + + if arg.Val.Val == "false" { + qc.Type = QTQuery + } + return nil + } + } + + return nil +} + func (com *Compiler) compileArgObj(st *util.Stack, arg *Arg) (*Exp, error) { if arg.Val.Type != nodeObj { return nil, fmt.Errorf("expecting an object") @@ -540,6 +543,10 @@ func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (* } func (com *Compiler) compileArgID(sel *Select, arg *Arg) error { + if sel.ID != 0 { + return nil + } + if sel.Where != nil && sel.Where.Op == OpEqID { return nil } @@ -732,26 +739,6 @@ func (com *Compiler) compileArgOffset(sel *Select, arg *Arg) error { return nil } -func (com *Compiler) compileArgAction(sel *Select, arg *Arg) error { - switch sel.Action { - case ActionDelete: - if arg.Val.Type != nodeBool { - return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name) - } - if arg.Val.Val == "false" { - sel.Action = 0 - } - - default: - if arg.Val.Type != nodeVar { - return fmt.Errorf("value for argument '%s' must be a variable", arg.Name) - } - sel.ActionVar = arg.Val.Val - } - - return nil -} - func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) { name := node.Name if name[0] == '_' { diff --git a/serv/config.go b/serv/config.go index 1fb7f63..b2ebf21 100644 --- a/serv/config.go +++ b/serv/config.go @@ -71,18 +71,14 @@ type config struct { } `mapstructure:"database"` Tables []configTable + Roles []configRoles } type configTable struct { - Name string - Filter []string - FilterQuery []string `mapstructure:"filter_query"` - FilterInsert []string `mapstructure:"filter_insert"` - FilterUpdate []string `mapstructure:"filter_update"` - FilterDelete []string `mapstructure:"filter_delete"` - Table string - Blocklist []string - Remotes []configRemote + Name string + Table string + Blocklist []string + Remotes []configRemote } type configRemote struct { @@ -98,6 +94,41 @@ type configRemote struct { } `mapstructure:"set_headers"` } +type configRoles struct { + Name string + Tables []struct { + Name string + + Query struct { + Limit int + Filter []string + Columns []string + DisableAggregation bool `mapstructure:"disable_aggregation"` + Deny bool + } + + Insert struct { + Filter []string + Columns []string + Set map[string]string + Deny bool + } + + Update struct { + Filter []string + Columns []string + Set map[string]string + Deny bool + } + + Delete struct { + Filter []string + Columns []string + Deny bool + } + } +} + func newConfig() *viper.Viper { vi := viper.New() diff --git a/serv/core.go b/serv/core.go index 9c7f3ee..8da4eab 100644 --- a/serv/core.go +++ b/serv/core.go @@ -59,7 +59,7 @@ func (c *coreContext) execQuery() ([]byte, error) { } else { - qc, err = qcompile.Compile([]byte(c.req.Query)) + qc, err = qcompile.Compile([]byte(c.req.Query), "user") if err != nil { return nil, err } diff --git a/serv/prepare.go b/serv/prepare.go index bf9a475..ac0eb5f 100644 --- a/serv/prepare.go +++ b/serv/prepare.go @@ -40,7 +40,7 @@ func prepareStmt(key, gql string, varBytes json.RawMessage) error { return nil } - qc, err := qcompile.Compile([]byte(gql)) + qc, err := qcompile.Compile([]byte(gql), "user") if err != nil { return err } diff --git a/serv/serv.go b/serv/serv.go index 1006520..980eb23 100644 --- a/serv/serv.go +++ b/serv/serv.go @@ -12,7 +12,6 @@ import ( rice "github.com/GeertJohan/go.rice" "github.com/dosco/super-graph/psql" "github.com/dosco/super-graph/qcode" - "github.com/gobuffalo/flect" ) func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) { @@ -22,49 +21,50 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) { } conf := qcode.Config{ - DefaultFilter: c.DB.Defaults.Filter, - FilterMap: qcode.Filters{ - All: make(map[string][]string, len(c.Tables)), - Query: make(map[string][]string, len(c.Tables)), - Insert: make(map[string][]string, len(c.Tables)), - Update: make(map[string][]string, len(c.Tables)), - Delete: make(map[string][]string, len(c.Tables)), - }, Blocklist: c.DB.Defaults.Blocklist, KeepArgs: false, } - for i := range c.Tables { - t := c.Tables[i] - - singular := flect.Singularize(t.Name) - plural := flect.Pluralize(t.Name) - - setFilter := func(fm map[string][]string, fil []string) { - switch { - case len(fil) == 0: - return - case fil[0] == "none" || len(fil[0]) == 0: - fm[singular] = []string{} - fm[plural] = []string{} - default: - fm[singular] = t.Filter - fm[plural] = t.Filter - } - } - - setFilter(conf.FilterMap.All, t.Filter) - setFilter(conf.FilterMap.Query, t.FilterQuery) - setFilter(conf.FilterMap.Insert, t.FilterInsert) - setFilter(conf.FilterMap.Update, t.FilterUpdate) - setFilter(conf.FilterMap.Delete, t.FilterDelete) - } - qc, err := qcode.NewCompiler(conf) if err != nil { return nil, nil, err } + for _, r := range c.Roles { + for _, t := range r.Tables { + query := qcode.QueryConfig{ + Limit: t.Query.Limit, + Filter: t.Query.Filter, + Columns: t.Query.Columns, + DisableFunctions: t.Query.DisableAggregation, + } + + insert := qcode.InsertConfig{ + Filter: t.Insert.Filter, + Columns: t.Insert.Columns, + Set: t.Insert.Set, + } + + update := qcode.UpdateConfig{ + Filter: t.Insert.Filter, + Columns: t.Insert.Columns, + Set: t.Insert.Set, + } + + delete := qcode.DeleteConfig{ + Filter: t.Insert.Filter, + Columns: t.Insert.Columns, + } + + qc.AddRole(r.Name, t.Name, qcode.TRConfig{ + Query: query, + Insert: insert, + Update: update, + Delete: delete, + }) + } + } + pc := psql.NewCompiler(psql.Config{ Schema: schema, Vars: c.getVariables(), From c797deb4d069142914dd29daeb96421f32e2891c Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Tue, 15 Oct 2019 02:30:19 -0400 Subject: [PATCH 3/6] Add built in 'anon' and 'user' roles --- config/allow.list | 202 +++++++++++++++++------ config/dev.yml | 4 +- examples/rails-app/Gemfile.lock | 273 ++++++++++++++++++++++++++++++++ psql/query.go | 18 ++- qcode/qcode.go | 15 +- serv/cmd.go | 5 +- serv/cmd_seed.go | 2 +- serv/core.go | 16 +- serv/http.go | 37 +---- serv/introsp.go | 36 +++++ serv/prepare.go | 2 +- serv/sqllog.go | 45 ++++++ 12 files changed, 553 insertions(+), 102 deletions(-) create mode 100644 serv/introsp.go create mode 100644 serv/sqllog.go diff --git a/config/allow.list b/config/allow.list index a17a562..196e46b 100644 --- a/config/allow.list +++ b/config/allow.list @@ -1,5 +1,27 @@ # http://localhost:8080/ +variables { + "data": [ + { + "name": "Protect Ya Neck", + "created_at": "now", + "updated_at": "now" + }, + { + "name": "Enter the Wu-Tang", + "created_at": "now", + "updated_at": "now" + } + ] +} + +mutation { + products(insert: $data) { + id + name + } +} + variables { "update": { "name": "Wu-Tang", @@ -16,16 +38,16 @@ mutation { } } -variables { - "data": { - "product_id": 5 - } -} - -mutation { - products(id: $product_id, delete: true) { +query { + users { id - name + email + picture: avatar + products(limit: 2, where: {price: {gt: 10}}) { + id + name + description + } } } @@ -73,6 +95,118 @@ query { } } +variables { + "data": [ + { + "name": "Gumbo1", + "created_at": "now", + "updated_at": "now" + }, + { + "name": "Gumbo2", + "created_at": "now", + "updated_at": "now" + } + ] +} + +query { + products { + id + name + user { + email + } + } +} + +variables { + "data": { + "product_id": 5 + } +} + +mutation { + products(id: $product_id, delete: true) { + id + name + } +} + +variables { + "data": [ + { + "name": "Gumbo1", + "created_at": "now", + "updated_at": "now" + }, + { + "name": "Gumbo2", + "created_at": "now", + "updated_at": "now" + } + ] +} + +query { + products { + id + name + price + users { + email + } + } +} + + +variables { + "data": { + "email": "gfk@myspace.com", + "full_name": "Ghostface Killah", + "created_at": "now", + "updated_at": "now" + } +} + +mutation { + user(insert: $data) { + id + } +} + +variables { + "data": [ + { + "name": "Gumbo1", + "created_at": "now", + "updated_at": "now" + }, + { + "name": "Gumbo2", + "created_at": "now", + "updated_at": "now" + } + ] +} + +query { + products { + id + name + users { + email + } + } +} + +query { + me { + id + email + full_name + } +} variables { "update": { @@ -112,62 +246,30 @@ query { } } - -query { - me { - id - email - full_name - } -} - -variables { - "data": { - "email": "gfk@myspace.com", - "full_name": "Ghostface Killah", - "created_at": "now", - "updated_at": "now" - } -} - -mutation { - user(insert: $data) { - id - } -} - -query { - users { - id - email - picture: avatar - products(limit: 2, where: {price: {gt: 10}}) { - id - name - description - } - } -} - variables { "data": [ { - "name": "Protect Ya Neck", + "name": "Gumbo1", "created_at": "now", "updated_at": "now" }, { - "name": "Enter the Wu-Tang", + "name": "Gumbo2", "created_at": "now", "updated_at": "now" } ] } -mutation { - products(insert: $data) { +query { + products { id name + description + users { + email + } } } + diff --git a/config/dev.yml b/config/dev.yml index a314dcd..c94e4bb 100644 --- a/config/dev.yml +++ b/config/dev.yml @@ -193,5 +193,5 @@ roles: tables: - name: users - select: - filter: ["{ account_id: { _eq: $account_id } }"] + select: + filter: ["{ account_id: { _eq: $account_id } }"] diff --git a/examples/rails-app/Gemfile.lock b/examples/rails-app/Gemfile.lock index e69de29..7a6e370 100644 --- a/examples/rails-app/Gemfile.lock +++ b/examples/rails-app/Gemfile.lock @@ -0,0 +1,273 @@ +GIT + remote: https://github.com/stympy/faker.git + revision: 4e9144825fcc9ba5c83cc0fd037779ab82f3120b + branch: master + specs: + faker (2.6.0) + i18n (>= 1.6, < 1.8) + +GEM + remote: https://rubygems.org/ + specs: + actioncable (6.0.0) + actionpack (= 6.0.0) + nio4r (~> 2.0) + websocket-driver (>= 0.6.1) + actionmailbox (6.0.0) + actionpack (= 6.0.0) + activejob (= 6.0.0) + activerecord (= 6.0.0) + activestorage (= 6.0.0) + activesupport (= 6.0.0) + mail (>= 2.7.1) + actionmailer (6.0.0) + actionpack (= 6.0.0) + actionview (= 6.0.0) + activejob (= 6.0.0) + mail (~> 2.5, >= 2.5.4) + rails-dom-testing (~> 2.0) + actionpack (6.0.0) + actionview (= 6.0.0) + activesupport (= 6.0.0) + rack (~> 2.0) + rack-test (>= 0.6.3) + rails-dom-testing (~> 2.0) + rails-html-sanitizer (~> 1.0, >= 1.2.0) + actiontext (6.0.0) + actionpack (= 6.0.0) + activerecord (= 6.0.0) + activestorage (= 6.0.0) + activesupport (= 6.0.0) + nokogiri (>= 1.8.5) + actionview (6.0.0) + activesupport (= 6.0.0) + builder (~> 3.1) + erubi (~> 1.4) + rails-dom-testing (~> 2.0) + rails-html-sanitizer (~> 1.1, >= 1.2.0) + activejob (6.0.0) + activesupport (= 6.0.0) + globalid (>= 0.3.6) + activemodel (6.0.0) + activesupport (= 6.0.0) + activerecord (6.0.0) + activemodel (= 6.0.0) + activesupport (= 6.0.0) + activestorage (6.0.0) + actionpack (= 6.0.0) + activejob (= 6.0.0) + activerecord (= 6.0.0) + marcel (~> 0.3.1) + activesupport (6.0.0) + concurrent-ruby (~> 1.0, >= 1.0.2) + i18n (>= 0.7, < 2) + minitest (~> 5.1) + tzinfo (~> 1.1) + zeitwerk (~> 2.1, >= 2.1.8) + addressable (2.7.0) + public_suffix (>= 2.0.2, < 5.0) + archive-zip (0.12.0) + io-like (~> 0.3.0) + bcrypt (3.1.13) + bindex (0.8.1) + bootsnap (1.4.5) + msgpack (~> 1.0) + builder (3.2.3) + byebug (11.0.1) + capybara (3.29.0) + addressable + mini_mime (>= 0.1.3) + nokogiri (~> 1.8) + rack (>= 1.6.0) + rack-test (>= 0.6.3) + regexp_parser (~> 1.5) + xpath (~> 3.2) + childprocess (3.0.0) + chromedriver-helper (2.1.1) + archive-zip (~> 0.10) + nokogiri (~> 1.8) + coffee-rails (4.2.2) + coffee-script (>= 2.2.0) + railties (>= 4.0.0) + coffee-script (2.4.1) + coffee-script-source + execjs + coffee-script-source (1.12.2) + concurrent-ruby (1.1.5) + crass (1.0.4) + devise (4.7.1) + bcrypt (~> 3.0) + orm_adapter (~> 0.1) + railties (>= 4.1.0) + responders + warden (~> 1.2.3) + erubi (1.9.0) + execjs (2.7.0) + ffi (1.11.1) + globalid (0.4.2) + activesupport (>= 4.2.0) + i18n (1.7.0) + concurrent-ruby (~> 1.0) + io-like (0.3.0) + jbuilder (2.9.1) + activesupport (>= 4.2.0) + listen (3.1.5) + rb-fsevent (~> 0.9, >= 0.9.4) + rb-inotify (~> 0.9, >= 0.9.7) + ruby_dep (~> 1.2) + loofah (2.3.0) + crass (~> 1.0.2) + nokogiri (>= 1.5.9) + mail (2.7.1) + mini_mime (>= 0.1.1) + marcel (0.3.3) + mimemagic (~> 0.3.2) + method_source (0.9.2) + mimemagic (0.3.3) + mini_mime (1.0.2) + mini_portile2 (2.4.0) + minitest (5.12.2) + msgpack (1.3.1) + nio4r (2.5.2) + nokogiri (1.10.4) + mini_portile2 (~> 2.4.0) + orm_adapter (0.5.0) + pg (1.1.4) + public_suffix (4.0.1) + puma (3.12.1) + rack (2.0.7) + rack-test (1.1.0) + rack (>= 1.0, < 3) + rails (6.0.0) + actioncable (= 6.0.0) + actionmailbox (= 6.0.0) + actionmailer (= 6.0.0) + actionpack (= 6.0.0) + actiontext (= 6.0.0) + actionview (= 6.0.0) + activejob (= 6.0.0) + activemodel (= 6.0.0) + activerecord (= 6.0.0) + activestorage (= 6.0.0) + activesupport (= 6.0.0) + bundler (>= 1.3.0) + railties (= 6.0.0) + sprockets-rails (>= 2.0.0) + rails-dom-testing (2.0.3) + activesupport (>= 4.2.0) + nokogiri (>= 1.6) + rails-html-sanitizer (1.3.0) + loofah (~> 2.3) + railties (6.0.0) + actionpack (= 6.0.0) + activesupport (= 6.0.0) + method_source + rake (>= 0.8.7) + thor (>= 0.20.3, < 2.0) + rake (13.0.0) + rb-fsevent (0.10.3) + rb-inotify (0.10.0) + ffi (~> 1.0) + redis (4.1.3) + redis-actionpack (5.1.0) + actionpack (>= 4.0, < 7) + redis-rack (>= 1, < 3) + redis-store (>= 1.1.0, < 2) + redis-activesupport (5.2.0) + activesupport (>= 3, < 7) + redis-store (>= 1.3, < 2) + redis-rack (2.0.6) + rack (>= 1.5, < 3) + redis-store (>= 1.2, < 2) + redis-rails (5.0.2) + redis-actionpack (>= 5.0, < 6) + redis-activesupport (>= 5.0, < 6) + redis-store (>= 1.2, < 2) + redis-store (1.8.0) + redis (>= 4, < 5) + regexp_parser (1.6.0) + responders (3.0.0) + actionpack (>= 5.0) + railties (>= 5.0) + ruby_dep (1.5.0) + rubyzip (2.0.0) + sass (3.7.4) + sass-listen (~> 4.0.0) + sass-listen (4.0.0) + rb-fsevent (~> 0.9, >= 0.9.4) + rb-inotify (~> 0.9, >= 0.9.7) + sass-rails (5.1.0) + railties (>= 5.2.0) + sass (~> 3.1) + sprockets (>= 2.8, < 4.0) + sprockets-rails (>= 2.0, < 4.0) + tilt (>= 1.1, < 3) + selenium-webdriver (3.142.6) + childprocess (>= 0.5, < 4.0) + rubyzip (>= 1.2.2) + spring (2.1.0) + spring-watcher-listen (2.0.1) + listen (>= 2.7, < 4.0) + spring (>= 1.2, < 3.0) + sprockets (3.7.2) + concurrent-ruby (~> 1.0) + rack (> 1, < 3) + sprockets-rails (3.2.1) + actionpack (>= 4.0) + activesupport (>= 4.0) + sprockets (>= 3.0.0) + thor (0.20.3) + thread_safe (0.3.6) + tilt (2.0.10) + turbolinks (5.2.1) + turbolinks-source (~> 5.2) + turbolinks-source (5.2.0) + tzinfo (1.2.5) + thread_safe (~> 0.1) + uglifier (4.2.0) + execjs (>= 0.3.0, < 3) + warden (1.2.8) + rack (>= 2.0.6) + web-console (4.0.1) + actionview (>= 6.0.0) + activemodel (>= 6.0.0) + bindex (>= 0.4.0) + railties (>= 6.0.0) + websocket-driver (0.7.1) + websocket-extensions (>= 0.1.0) + websocket-extensions (0.1.4) + xpath (3.2.0) + nokogiri (~> 1.8) + zeitwerk (2.2.0) + +PLATFORMS + ruby + +DEPENDENCIES + bootsnap (>= 1.1.0) + byebug + capybara (>= 2.15) + chromedriver-helper + coffee-rails (~> 4.2) + devise + faker! + jbuilder (~> 2.5) + listen (>= 3.0.5, < 3.2) + pg (>= 0.18, < 2.0) + puma (~> 3.11) + rails (~> 6.0.0.rc1) + redis-rails + sass-rails (~> 5.0) + selenium-webdriver + spring + spring-watcher-listen (~> 2.0.0) + turbolinks (~> 5) + tzinfo-data + uglifier (>= 1.3.0) + web-console (>= 3.3.0) + +RUBY VERSION + ruby 2.5.7p206 + +BUNDLED WITH + 1.17.3 diff --git a/psql/query.go b/psql/query.go index 8d70e43..70ea0a2 100644 --- a/psql/query.go +++ b/psql/query.go @@ -435,10 +435,24 @@ func (c *compilerContext) renderJoinedColumns(sel *qcode.Select, ti *DBTableInfo } childSel := &c.s[id] + cti, err := c.schema.GetTable(childSel.Table) + if err != nil { + continue + } + //fmt.Fprintf(w, `"%s_%d_join"."%s" AS "%s"`, //s.Table, s.ID, s.Table, s.FieldName) - colWithTableIDSuffixAlias(c.w, childSel.Table, childSel.ID, - "_join", childSel.Table, childSel.FieldName) + if cti.Singular { + c.w.WriteString(`"sel_json_`) + int2string(c.w, childSel.ID) + c.w.WriteString(`" AS "`) + c.w.WriteString(childSel.FieldName) + c.w.WriteString(`"`) + + } else { + colWithTableIDSuffixAlias(c.w, childSel.Table, childSel.ID, + "_join", childSel.Table, childSel.FieldName) + } } return nil diff --git a/qcode/qcode.go b/qcode/qcode.go index 05e8e98..30bc724 100644 --- a/qcode/qcode.go +++ b/qcode/qcode.go @@ -287,10 +287,7 @@ func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error { continue } - trv, ok := com.tr[role][field.Name] - if !ok { - continue - } + trv := com.getRole(role, field.Name) selects = append(selects, Select{ ID: id, @@ -739,6 +736,16 @@ func (com *Compiler) compileArgOffset(sel *Select, arg *Arg) error { return nil } +var zeroTrv = &trval{} + +func (com *Compiler) getRole(role, field string) *trval { + if trv, ok := com.tr[role][field]; ok { + return trv + } else { + return zeroTrv + } +} + func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) { name := node.Name if name[0] == '_' { diff --git a/serv/cmd.go b/serv/cmd.go index 87c1660..e73b4fb 100644 --- a/serv/cmd.go +++ b/serv/cmd.go @@ -10,7 +10,6 @@ import ( "github.com/dosco/super-graph/qcode" "github.com/gobuffalo/flect" "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/log/zerologadapter" "github.com/jackc/pgx/v4/pgxpool" "github.com/rs/zerolog" "github.com/spf13/cobra" @@ -217,7 +216,7 @@ func initDB(c *config, useDB bool) (*pgx.Conn, error) { config.LogLevel = pgx.LogLevelNone } - config.Logger = zerologadapter.NewLogger(*logger) + config.Logger = NewSQLLogger(*logger) db, err := pgx.ConnectConfig(context.Background(), config) if err != nil { @@ -252,7 +251,7 @@ func initDBPool(c *config) (*pgxpool.Pool, error) { config.ConnConfig.LogLevel = pgx.LogLevelNone } - config.ConnConfig.Logger = zerologadapter.NewLogger(*logger) + config.ConnConfig.Logger = NewSQLLogger(*logger) // if c.DB.MaxRetries != 0 { // opt.MaxRetries = c.DB.MaxRetries diff --git a/serv/cmd_seed.go b/serv/cmd_seed.go index f0cc2d4..f9b152e 100644 --- a/serv/cmd_seed.go +++ b/serv/cmd_seed.go @@ -67,7 +67,7 @@ func graphQLFunc(query string, data interface{}) map[string]interface{} { c.req.Query = query c.req.Vars = b - res, err := c.execQuery() + res, err := c.execQuery("user") if err != nil { logger.Fatal().Err(err).Msg("graphql query failed") } diff --git a/serv/core.go b/serv/core.go index 8da4eab..edc4779 100644 --- a/serv/core.go +++ b/serv/core.go @@ -32,7 +32,15 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error { c.req.ref = req.Referer() c.req.hdr = req.Header - b, err := c.execQuery() + var role string + + if authCheck(c) { + role = "user" + } else { + role = "anon" + } + + b, err := c.execQuery(role) if err != nil { return err } @@ -40,12 +48,14 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error { return c.render(w, b) } -func (c *coreContext) execQuery() ([]byte, error) { +func (c *coreContext) execQuery(role string) ([]byte, error) { var err error var skipped uint32 var qc *qcode.QCode var data []byte + logger.Debug().Str("role", role).Msg(c.req.Query) + if conf.UseAllowList { var ps *preparedItem @@ -59,7 +69,7 @@ func (c *coreContext) execQuery() ([]byte, error) { } else { - qc, err = qcompile.Compile([]byte(c.req.Query), "user") + qc, err = qcompile.Compile([]byte(c.req.Query), role) if err != nil { return nil, err } diff --git a/serv/http.go b/serv/http.go index c943110..d419d8b 100644 --- a/serv/http.go +++ b/serv/http.go @@ -94,42 +94,7 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) { } if strings.EqualFold(ctx.req.OpName, introspectionQuery) { - // dat, err := ioutil.ReadFile("test.schema") - // if err != nil { - // http.Error(w, err.Error(), http.StatusInternalServerError) - // return - // } - //w.Write(dat) - w.Header().Set("Content-Type", "application/json") - w.Write([]byte(`{ - "data": { - "__schema": { - "queryType": { - "name": "Query" - }, - "mutationType": null, - "subscriptionType": null - } - }, - "extensions":{ - "tracing":{ - "version":1, - "startTime":"2019-06-04T19:53:31.093Z", - "endTime":"2019-06-04T19:53:31.108Z", - "duration":15219720, - "execution": { - "resolvers": [{ - "path": ["__schema"], - "parentType": "Query", - "fieldName": "__schema", - "returnType": "__Schema!", - "startOffset": 50950, - "duration": 17187 - }] - } - } - } - }`)) + introspect(w) return } diff --git a/serv/introsp.go b/serv/introsp.go new file mode 100644 index 0000000..2fbf26f --- /dev/null +++ b/serv/introsp.go @@ -0,0 +1,36 @@ +package serv + +import "net/http" + +func introspect(w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ + "data": { + "__schema": { + "queryType": { + "name": "Query" + }, + "mutationType": null, + "subscriptionType": null + } + }, + "extensions":{ + "tracing":{ + "version":1, + "startTime":"2019-06-04T19:53:31.093Z", + "endTime":"2019-06-04T19:53:31.108Z", + "duration":15219720, + "execution": { + "resolvers": [{ + "path": ["__schema"], + "parentType": "Query", + "fieldName": "__schema", + "returnType": "__Schema!", + "startOffset": 50950, + "duration": 17187 + }] + } + } + } + }`)) +} diff --git a/serv/prepare.go b/serv/prepare.go index ac0eb5f..697d80a 100644 --- a/serv/prepare.go +++ b/serv/prepare.go @@ -30,7 +30,7 @@ func initPreparedList() { for k, v := range _allowList.list { err := prepareStmt(k, v.gql, v.vars) if err != nil { - logger.Warn().Err(err).Send() + logger.Warn().Str("gql", v.gql).Err(err).Send() } } } diff --git a/serv/sqllog.go b/serv/sqllog.go new file mode 100644 index 0000000..3fccbea --- /dev/null +++ b/serv/sqllog.go @@ -0,0 +1,45 @@ +package serv + +import ( + "context" + + "github.com/jackc/pgx/v4" + "github.com/rs/zerolog" +) + +type Logger struct { + logger zerolog.Logger +} + +// NewLogger accepts a zerolog.Logger as input and returns a new custom pgx +// logging fascade as output. +func NewSQLLogger(logger zerolog.Logger) *Logger { + return &Logger{ + logger: logger.With().Logger(), + } +} + +func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { + var zlevel zerolog.Level + switch level { + case pgx.LogLevelNone: + zlevel = zerolog.NoLevel + case pgx.LogLevelError: + zlevel = zerolog.ErrorLevel + case pgx.LogLevelWarn: + zlevel = zerolog.WarnLevel + case pgx.LogLevelInfo: + zlevel = zerolog.InfoLevel + case pgx.LogLevelDebug: + zlevel = zerolog.DebugLevel + default: + zlevel = zerolog.DebugLevel + } + + if sql, ok := data["sql"]; ok { + delete(data, "sql") + pl.logger.WithLevel(zlevel).Fields(data).Msg(sql.(string)) + } else { + pl.logger.WithLevel(zlevel).Fields(data).Msg(msg) + } +} From 6bc66d28bc83a82d99cc521318d1acfd3260c7ca Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Thu, 24 Oct 2019 02:07:42 -0400 Subject: [PATCH 4/6] Get RBAC working for queries and mutations --- config/dev.yml | 17 +- config/prod.yml | 4 - psql/mutate.go | 83 ++++---- psql/query.go | 463 ++++++++++++++++++++++----------------------- serv/allow.go | 8 +- serv/auth.go | 34 +++- serv/auth_jwt.go | 6 - serv/auth_rails.go | 23 +-- serv/cmd.go | 27 ++- serv/cmd_seed.go | 3 +- serv/config.go | 55 +++--- serv/core.go | 293 +++++++++++++++------------- serv/core_build.go | 144 ++++++++++++++ serv/http.go | 7 +- serv/prepare.go | 125 ++++++++---- serv/serv.go | 2 +- serv/utils.go | 29 ++- serv/vars.go | 24 ++- tmpl/prod.yml | 123 ++++++++---- 19 files changed, 902 insertions(+), 568 deletions(-) create mode 100644 serv/core_build.go diff --git a/config/dev.yml b/config/dev.yml index c94e4bb..ade488b 100644 --- a/config/dev.yml +++ b/config/dev.yml @@ -22,7 +22,7 @@ enable_tracing: true # Watch the config folder and reload Super Graph # with the new configs when a change is detected -reload_on_config_change: false +reload_on_config_change: true # File that points to the database seeding script # seed_file: seed.js @@ -53,7 +53,7 @@ auth: # Comment this out if you want to disable setting # the user_id via a header. Good for testing - header: X-User-ID + creds_in_header: true rails: # Rails version this is used for reading the @@ -143,6 +143,8 @@ tables: name: me table: users +roles_query: "SELECT * FROM users as usr WHERE id = $user_id" + roles: - name: anon tables: @@ -164,6 +166,10 @@ roles: - name: user tables: + - name: users + query: + filter: ["{ id: { _eq: $user_id } }"] + - name: products query: @@ -189,9 +195,10 @@ roles: delete: deny: true - - name: manager + - name: admin + match: id = 1 tables: - name: users - select: - filter: ["{ account_id: { _eq: $account_id } }"] + # select: + # filter: ["{ account_id: { _eq: $account_id } }"] diff --git a/config/prod.yml b/config/prod.yml index fa5f932..a52af3d 100644 --- a/config/prod.yml +++ b/config/prod.yml @@ -47,10 +47,6 @@ auth: type: rails cookie: _app_session - # Comment this out if you want to disable setting - # the user_id via a header. Good for testing - header: X-User-ID - rails: # Rails version this is used for reading the # various cookies formats. diff --git a/psql/mutate.go b/psql/mutate.go index 63b3a5a..067270e 100644 --- a/psql/mutate.go +++ b/psql/mutate.go @@ -1,7 +1,6 @@ package psql import ( - "bytes" "errors" "fmt" "io" @@ -12,7 +11,7 @@ import ( var noLimit = qcode.Paging{NoLimit: true} -func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) { +func (co *Compiler) compileMutation(qc *qcode.QCode, w io.Writer, vars Variables) (uint32, error) { if len(qc.Selects) == 0 { return 0, errors.New("empty query") } @@ -25,9 +24,9 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia return 0, err } - c.w.WriteString(`WITH `) + io.WriteString(c.w, `WITH `) quoted(c.w, ti.Name) - c.w.WriteString(` AS `) + io.WriteString(c.w, ` AS `) switch qc.Type { case qcode.QTInsert: @@ -67,7 +66,7 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia return c.compileQuery(qc, w) } -func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer, +func (c *compilerContext) renderInsert(qc *qcode.QCode, w io.Writer, vars Variables, ti *DBTableInfo) (uint32, error) { insert, ok := vars[qc.ActionVar] @@ -80,32 +79,32 @@ func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer, return 0, err } - c.w.WriteString(`(WITH "input" AS (SELECT {{`) - c.w.WriteString(qc.ActionVar) - c.w.WriteString(`}}::json AS j) INSERT INTO `) + io.WriteString(c.w, `(WITH "input" AS (SELECT {{`) + io.WriteString(c.w, qc.ActionVar) + io.WriteString(c.w, `}}::json AS j) INSERT INTO `) quoted(c.w, ti.Name) io.WriteString(c.w, ` (`) c.renderInsertUpdateColumns(qc, w, jt, ti) io.WriteString(c.w, `)`) - c.w.WriteString(` SELECT `) + io.WriteString(c.w, ` SELECT `) c.renderInsertUpdateColumns(qc, w, jt, ti) - c.w.WriteString(` FROM input i, `) + io.WriteString(c.w, ` FROM input i, `) if array { - c.w.WriteString(`json_populate_recordset`) + io.WriteString(c.w, `json_populate_recordset`) } else { - c.w.WriteString(`json_populate_record`) + io.WriteString(c.w, `json_populate_record`) } - c.w.WriteString(`(NULL::`) - c.w.WriteString(ti.Name) - c.w.WriteString(`, i.j) t`) + io.WriteString(c.w, `(NULL::`) + io.WriteString(c.w, ti.Name) + io.WriteString(c.w, `, i.j) t`) return 0, nil } -func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w *bytes.Buffer, +func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w io.Writer, jt map[string]interface{}, ti *DBTableInfo) (uint32, error) { root := &qc.Selects[0] @@ -122,14 +121,14 @@ func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w *bytes.Bu if i != 0 { io.WriteString(c.w, `, `) } - c.w.WriteString(cn) + io.WriteString(c.w, cn) i++ } return 0, nil } -func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer, +func (c *compilerContext) renderUpdate(qc *qcode.QCode, w io.Writer, vars Variables, ti *DBTableInfo) (uint32, error) { root := &qc.Selects[0] @@ -143,26 +142,26 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer, return 0, err } - c.w.WriteString(`(WITH "input" AS (SELECT {{`) - c.w.WriteString(qc.ActionVar) - c.w.WriteString(`}}::json AS j) UPDATE `) + io.WriteString(c.w, `(WITH "input" AS (SELECT {{`) + io.WriteString(c.w, qc.ActionVar) + io.WriteString(c.w, `}}::json AS j) UPDATE `) quoted(c.w, ti.Name) io.WriteString(c.w, ` SET (`) c.renderInsertUpdateColumns(qc, w, jt, ti) - c.w.WriteString(`) = (SELECT `) + io.WriteString(c.w, `) = (SELECT `) c.renderInsertUpdateColumns(qc, w, jt, ti) - c.w.WriteString(` FROM input i, `) + io.WriteString(c.w, ` FROM input i, `) if array { - c.w.WriteString(`json_populate_recordset`) + io.WriteString(c.w, `json_populate_recordset`) } else { - c.w.WriteString(`json_populate_record`) + io.WriteString(c.w, `json_populate_record`) } - c.w.WriteString(`(NULL::`) - c.w.WriteString(ti.Name) - c.w.WriteString(`, i.j) t)`) + io.WriteString(c.w, `(NULL::`) + io.WriteString(c.w, ti.Name) + io.WriteString(c.w, `, i.j) t)`) io.WriteString(c.w, ` WHERE `) @@ -173,11 +172,11 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer, return 0, nil } -func (c *compilerContext) renderDelete(qc *qcode.QCode, w *bytes.Buffer, +func (c *compilerContext) renderDelete(qc *qcode.QCode, w io.Writer, vars Variables, ti *DBTableInfo) (uint32, error) { root := &qc.Selects[0] - c.w.WriteString(`(DELETE FROM `) + io.WriteString(c.w, `(DELETE FROM `) quoted(c.w, ti.Name) io.WriteString(c.w, ` WHERE `) @@ -188,7 +187,7 @@ func (c *compilerContext) renderDelete(qc *qcode.QCode, w *bytes.Buffer, return 0, nil } -func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer, +func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer, vars Variables, ti *DBTableInfo) (uint32, error) { upsert, ok := vars[qc.ActionVar] @@ -205,7 +204,7 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer, return 0, err } - c.w.WriteString(` ON CONFLICT DO (`) + io.WriteString(c.w, ` ON CONFLICT DO (`) i := 0 for _, cn := range ti.ColumnNames { @@ -220,15 +219,15 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer, if i != 0 { io.WriteString(c.w, `, `) } - c.w.WriteString(cn) + io.WriteString(c.w, cn) i++ } if i == 0 { - c.w.WriteString(ti.PrimaryCol) + io.WriteString(c.w, ti.PrimaryCol) } - c.w.WriteString(`) DO `) + io.WriteString(c.w, `) DO `) - c.w.WriteString(`UPDATE `) + io.WriteString(c.w, `UPDATE `) io.WriteString(c.w, ` SET `) i = 0 @@ -239,17 +238,17 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer, if i != 0 { io.WriteString(c.w, `, `) } - c.w.WriteString(cn) + io.WriteString(c.w, cn) io.WriteString(c.w, ` = EXCLUDED.`) - c.w.WriteString(cn) + io.WriteString(c.w, cn) i++ } return 0, nil } -func quoted(w *bytes.Buffer, identifier string) { - w.WriteString(`"`) - w.WriteString(identifier) - w.WriteString(`"`) +func quoted(w io.Writer, identifier string) { + io.WriteString(w, `"`) + io.WriteString(w, identifier) + io.WriteString(w, `"`) } diff --git a/psql/query.go b/psql/query.go index 70ea0a2..c06ab95 100644 --- a/psql/query.go +++ b/psql/query.go @@ -49,7 +49,7 @@ func (c *Compiler) IDColumn(table string) (string, error) { } type compilerContext struct { - w *bytes.Buffer + w io.Writer s []qcode.Select *Compiler } @@ -60,7 +60,7 @@ func (co *Compiler) CompileEx(qc *qcode.QCode, vars Variables) (uint32, []byte, return skipped, w.Bytes(), err } -func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) { +func (co *Compiler) Compile(qc *qcode.QCode, w io.Writer, vars Variables) (uint32, error) { switch qc.Type { case qcode.QTQuery: return co.compileQuery(qc, w) @@ -71,7 +71,7 @@ func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (u return 0, fmt.Errorf("Unknown operation type %d", qc.Type) } -func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) { +func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) { if len(qc.Selects) == 0 { return 0, errors.New("empty query") } @@ -90,17 +90,17 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, erro //fmt.Fprintf(w, `SELECT json_object_agg('%s', %s) FROM (`, //root.FieldName, root.Table) - c.w.WriteString(`SELECT json_object_agg('`) - c.w.WriteString(root.FieldName) - c.w.WriteString(`', `) + io.WriteString(c.w, `SELECT json_object_agg('`) + io.WriteString(c.w, root.FieldName) + io.WriteString(c.w, `', `) if ti.Singular == false { - c.w.WriteString(root.Table) + io.WriteString(c.w, root.Table) } else { - c.w.WriteString("sel_json_") + io.WriteString(c.w, "sel_json_") int2string(c.w, root.ID) } - c.w.WriteString(`) FROM (`) + io.WriteString(c.w, `) FROM (`) var ignored uint32 @@ -161,9 +161,8 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, erro } } - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) alias(c.w, `done_1337`) - c.w.WriteString(`;`) return ignored, nil } @@ -219,10 +218,10 @@ func (c *compilerContext) renderSelect(sel *qcode.Select, ti *DBTableInfo) (uint // SELECT if ti.Singular == false { //fmt.Fprintf(w, `SELECT coalesce(json_agg("%s"`, c.sel.Table) - c.w.WriteString(`SELECT coalesce(json_agg("`) - c.w.WriteString("sel_json_") + io.WriteString(c.w, `SELECT coalesce(json_agg("`) + io.WriteString(c.w, "sel_json_") int2string(c.w, sel.ID) - c.w.WriteString(`"`) + io.WriteString(c.w, `"`) if hasOrder { err := c.renderOrderBy(sel, ti) @@ -232,24 +231,24 @@ func (c *compilerContext) renderSelect(sel *qcode.Select, ti *DBTableInfo) (uint } //fmt.Fprintf(w, `), '[]') AS "%s" FROM (`, c.sel.Table) - c.w.WriteString(`), '[]')`) + io.WriteString(c.w, `), '[]')`) alias(c.w, sel.Table) - c.w.WriteString(` FROM (`) + io.WriteString(c.w, ` FROM (`) } // ROW-TO-JSON - c.w.WriteString(`SELECT `) + io.WriteString(c.w, `SELECT `) if len(sel.DistinctOn) != 0 { c.renderDistinctOn(sel, ti) } - c.w.WriteString(`row_to_json((`) + io.WriteString(c.w, `row_to_json((`) //fmt.Fprintf(w, `SELECT "sel_%d" FROM (SELECT `, c.sel.ID) - c.w.WriteString(`SELECT "sel_`) + io.WriteString(c.w, `SELECT "sel_`) int2string(c.w, sel.ID) - c.w.WriteString(`" FROM (SELECT `) + io.WriteString(c.w, `" FROM (SELECT `) // Combined column names c.renderColumns(sel, ti) @@ -262,11 +261,11 @@ func (c *compilerContext) renderSelect(sel *qcode.Select, ti *DBTableInfo) (uint } //fmt.Fprintf(w, `) AS "sel_%d"`, c.sel.ID) - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) aliasWithID(c.w, "sel", sel.ID) //fmt.Fprintf(w, `)) AS "%s"`, c.sel.Table) - c.w.WriteString(`))`) + io.WriteString(c.w, `))`) aliasWithID(c.w, "sel_json", sel.ID) // END-ROW-TO-JSON @@ -301,27 +300,27 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo) case len(sel.Paging.Limit) != 0: //fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit) - c.w.WriteString(` LIMIT ('`) - c.w.WriteString(sel.Paging.Limit) - c.w.WriteString(`') :: integer`) + io.WriteString(c.w, ` LIMIT ('`) + io.WriteString(c.w, sel.Paging.Limit) + io.WriteString(c.w, `') :: integer`) case ti.Singular: - c.w.WriteString(` LIMIT ('1') :: integer`) + io.WriteString(c.w, ` LIMIT ('1') :: integer`) default: - c.w.WriteString(` LIMIT ('20') :: integer`) + io.WriteString(c.w, ` LIMIT ('20') :: integer`) } if len(sel.Paging.Offset) != 0 { //fmt.Fprintf(w, ` OFFSET ('%s') :: integer`, c.sel.Paging.Offset) - c.w.WriteString(`OFFSET ('`) - c.w.WriteString(sel.Paging.Offset) - c.w.WriteString(`') :: integer`) + io.WriteString(c.w, `OFFSET ('`) + io.WriteString(c.w, sel.Paging.Offset) + io.WriteString(c.w, `') :: integer`) } if ti.Singular == false { //fmt.Fprintf(w, `) AS "sel_json_agg_%d"`, c.sel.ID) - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) aliasWithID(c.w, "sel_json_agg", sel.ID) } @@ -329,15 +328,15 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo) } func (c *compilerContext) renderJoin(sel *qcode.Select) error { - c.w.WriteString(` LEFT OUTER JOIN LATERAL (`) + io.WriteString(c.w, ` LEFT OUTER JOIN LATERAL (`) return nil } func (c *compilerContext) renderJoinClose(sel *qcode.Select) error { //fmt.Fprintf(w, `) AS "%s_%d_join" ON ('true')`, c.sel.Table, c.sel.ID) - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) aliasWithIDSuffix(c.w, sel.Table, sel.ID, "_join") - c.w.WriteString(` ON ('true')`) + io.WriteString(c.w, ` ON ('true')`) return nil } @@ -360,13 +359,13 @@ func (c *compilerContext) renderJoinTable(sel *qcode.Select) error { //fmt.Fprintf(w, ` LEFT OUTER JOIN "%s" ON (("%s"."%s") = ("%s_%d"."%s"))`, //rel.Through, rel.Through, rel.ColT, c.parent.Table, c.parent.ID, rel.Col1) - c.w.WriteString(` LEFT OUTER JOIN "`) - c.w.WriteString(rel.Through) - c.w.WriteString(`" ON ((`) + io.WriteString(c.w, ` LEFT OUTER JOIN "`) + io.WriteString(c.w, rel.Through) + io.WriteString(c.w, `" ON ((`) colWithTable(c.w, rel.Through, rel.ColT) - c.w.WriteString(`) = (`) + io.WriteString(c.w, `) = (`) colWithTableID(c.w, pt.Name, parent.ID, rel.Col1) - c.w.WriteString(`))`) + io.WriteString(c.w, `))`) return nil } @@ -443,11 +442,11 @@ func (c *compilerContext) renderJoinedColumns(sel *qcode.Select, ti *DBTableInfo //fmt.Fprintf(w, `"%s_%d_join"."%s" AS "%s"`, //s.Table, s.ID, s.Table, s.FieldName) if cti.Singular { - c.w.WriteString(`"sel_json_`) + io.WriteString(c.w, `"sel_json_`) int2string(c.w, childSel.ID) - c.w.WriteString(`" AS "`) - c.w.WriteString(childSel.FieldName) - c.w.WriteString(`"`) + io.WriteString(c.w, `" AS "`) + io.WriteString(c.w, childSel.FieldName) + io.WriteString(c.w, `"`) } else { colWithTableIDSuffixAlias(c.w, childSel.Table, childSel.ID, @@ -467,7 +466,7 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, isSearch := sel.Args["search"] != nil isAgg := false - c.w.WriteString(` FROM (SELECT `) + io.WriteString(c.w, ` FROM (SELECT `) i := 0 for n, col := range sel.Cols { @@ -483,15 +482,15 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, arg := sel.Args["search"] if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } //fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`, //c.sel.Table, cn, arg.Val, col.Name) - c.w.WriteString(`ts_rank(`) + io.WriteString(c.w, `ts_rank(`) colWithTable(c.w, ti.Name, cn) - c.w.WriteString(`, to_tsquery('`) - c.w.WriteString(arg.Val) - c.w.WriteString(`')`) + io.WriteString(c.w, `, to_tsquery('`) + io.WriteString(c.w, arg.Val) + io.WriteString(c.w, `')`) alias(c.w, col.Name) i++ @@ -500,15 +499,15 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, arg := sel.Args["search"] if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } //fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`, //c.sel.Table, cn, arg.Val, col.Name) - c.w.WriteString(`ts_headlinek(`) + io.WriteString(c.w, `ts_headlinek(`) colWithTable(c.w, ti.Name, cn) - c.w.WriteString(`, to_tsquery('`) - c.w.WriteString(arg.Val) - c.w.WriteString(`')`) + io.WriteString(c.w, `, to_tsquery('`) + io.WriteString(c.w, arg.Val) + io.WriteString(c.w, `')`) alias(c.w, col.Name) i++ @@ -517,12 +516,12 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, pl := funcPrefixLen(cn) if pl == 0 { if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } //fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name) - c.w.WriteString(`'`) - c.w.WriteString(cn) - c.w.WriteString(` not defined'`) + io.WriteString(c.w, `'`) + io.WriteString(c.w, cn) + io.WriteString(c.w, ` not defined'`) alias(c.w, col.Name) i++ @@ -532,16 +531,16 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, continue } if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } fn := cn[0 : pl-1] isAgg = true //fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, c.sel.Table, cn, col.Name) - c.w.WriteString(fn) - c.w.WriteString(`(`) + io.WriteString(c.w, fn) + io.WriteString(c.w, `(`) colWithTable(c.w, ti.Name, cn1) - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) alias(c.w, col.Name) i++ @@ -551,7 +550,7 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, groupBy = append(groupBy, n) //fmt.Fprintf(w, `"%s"."%s"`, c.sel.Table, cn) if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } colWithTable(c.w, ti.Name, cn) i++ @@ -561,7 +560,7 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, for _, col := range childCols { if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } //fmt.Fprintf(w, `"%s"."%s"`, col.Table, col.Name) @@ -569,29 +568,29 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, i++ } - c.w.WriteString(` FROM `) + io.WriteString(c.w, ` FROM `) //fmt.Fprintf(w, ` FROM "%s"`, c.sel.Table) - c.w.WriteString(`"`) - c.w.WriteString(ti.Name) - c.w.WriteString(`"`) + io.WriteString(c.w, `"`) + io.WriteString(c.w, ti.Name) + io.WriteString(c.w, `"`) // if tn, ok := c.tmap[sel.Table]; ok { // //fmt.Fprintf(w, ` FROM "%s" AS "%s"`, tn, c.sel.Table) // tableWithAlias(c.w, ti.Name, sel.Table) // } else { // //fmt.Fprintf(w, ` FROM "%s"`, c.sel.Table) - // c.w.WriteString(`"`) - // c.w.WriteString(sel.Table) - // c.w.WriteString(`"`) + // io.WriteString(c.w, `"`) + // io.WriteString(c.w, sel.Table) + // io.WriteString(c.w, `"`) // } if isRoot && isFil { - c.w.WriteString(` WHERE (`) + io.WriteString(c.w, ` WHERE (`) if err := c.renderWhere(sel, ti); err != nil { return err } - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) } if !isRoot { @@ -599,28 +598,28 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, return err } - c.w.WriteString(` WHERE (`) + io.WriteString(c.w, ` WHERE (`) if err := c.renderRelationship(sel, ti); err != nil { return err } if isFil { - c.w.WriteString(` AND `) + io.WriteString(c.w, ` AND `) if err := c.renderWhere(sel, ti); err != nil { return err } } - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) } if isAgg { if len(groupBy) != 0 { - c.w.WriteString(` GROUP BY `) + io.WriteString(c.w, ` GROUP BY `) for i, id := range groupBy { if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } //fmt.Fprintf(w, `"%s"."%s"`, c.sel.Table, c.sel.Cols[id].Name) colWithTable(c.w, ti.Name, sel.Cols[id].Name) @@ -634,26 +633,26 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, case len(sel.Paging.Limit) != 0: //fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit) - c.w.WriteString(` LIMIT ('`) - c.w.WriteString(sel.Paging.Limit) - c.w.WriteString(`') :: integer`) + io.WriteString(c.w, ` LIMIT ('`) + io.WriteString(c.w, sel.Paging.Limit) + io.WriteString(c.w, `') :: integer`) case ti.Singular: - c.w.WriteString(` LIMIT ('1') :: integer`) + io.WriteString(c.w, ` LIMIT ('1') :: integer`) default: - c.w.WriteString(` LIMIT ('20') :: integer`) + io.WriteString(c.w, ` LIMIT ('20') :: integer`) } if len(sel.Paging.Offset) != 0 { //fmt.Fprintf(w, ` OFFSET ('%s') :: integer`, c.sel.Paging.Offset) - c.w.WriteString(` OFFSET ('`) - c.w.WriteString(sel.Paging.Offset) - c.w.WriteString(`') :: integer`) + io.WriteString(c.w, ` OFFSET ('`) + io.WriteString(c.w, sel.Paging.Offset) + io.WriteString(c.w, `') :: integer`) } //fmt.Fprintf(w, `) AS "%s_%d"`, c.sel.Table, c.sel.ID) - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) aliasWithID(c.w, ti.Name, sel.ID) return nil } @@ -664,7 +663,7 @@ func (c *compilerContext) renderOrderByColumns(sel *qcode.Select, ti *DBTableInf for i := range sel.OrderBy { if colsRendered { //io.WriteString(w, ", ") - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } col := sel.OrderBy[i].Col @@ -672,7 +671,7 @@ func (c *compilerContext) renderOrderByColumns(sel *qcode.Select, ti *DBTableInf //c.sel.Table, c.sel.ID, c, //c.sel.Table, c.sel.ID, c) colWithTableID(c.w, ti.Name, sel.ID, col) - c.w.WriteString(` AS `) + io.WriteString(c.w, ` AS `) tableIDColSuffix(c.w, sel.Table, sel.ID, col, "_ob") } } @@ -689,29 +688,29 @@ func (c *compilerContext) renderRelationship(sel *qcode.Select, ti *DBTableInfo) case RelBelongTo: //fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`, //c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2) - c.w.WriteString(`((`) + io.WriteString(c.w, `((`) colWithTable(c.w, ti.Name, rel.Col1) - c.w.WriteString(`) = (`) + io.WriteString(c.w, `) = (`) colWithTableID(c.w, parent.Table, parent.ID, rel.Col2) - c.w.WriteString(`))`) + io.WriteString(c.w, `))`) case RelOneToMany: //fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`, //c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2) - c.w.WriteString(`((`) + io.WriteString(c.w, `((`) colWithTable(c.w, ti.Name, rel.Col1) - c.w.WriteString(`) = (`) + io.WriteString(c.w, `) = (`) colWithTableID(c.w, parent.Table, parent.ID, rel.Col2) - c.w.WriteString(`))`) + io.WriteString(c.w, `))`) case RelOneToManyThrough: //fmt.Fprintf(w, `(("%s"."%s") = ("%s"."%s"))`, //c.sel.Table, rel.Col1, rel.Through, rel.Col2) - c.w.WriteString(`((`) + io.WriteString(c.w, `((`) colWithTable(c.w, ti.Name, rel.Col1) - c.w.WriteString(`) = (`) + io.WriteString(c.w, `) = (`) colWithTable(c.w, rel.Through, rel.Col2) - c.w.WriteString(`))`) + io.WriteString(c.w, `))`) } return nil @@ -735,11 +734,11 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error case qcode.ExpOp: switch val { case qcode.OpAnd: - c.w.WriteString(` AND `) + io.WriteString(c.w, ` AND `) case qcode.OpOr: - c.w.WriteString(` OR `) + io.WriteString(c.w, ` OR `) case qcode.OpNot: - c.w.WriteString(`NOT `) + io.WriteString(c.w, `NOT `) default: return fmt.Errorf("11: unexpected value %v (%t)", intf, intf) } @@ -763,62 +762,62 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error default: if val.NestedCol { //fmt.Fprintf(w, `(("%s") `, val.Col) - c.w.WriteString(`(("`) - c.w.WriteString(val.Col) - c.w.WriteString(`") `) + io.WriteString(c.w, `(("`) + io.WriteString(c.w, val.Col) + io.WriteString(c.w, `") `) } else if len(val.Col) != 0 { //fmt.Fprintf(w, `(("%s"."%s") `, c.sel.Table, val.Col) - c.w.WriteString(`((`) + io.WriteString(c.w, `((`) colWithTable(c.w, ti.Name, val.Col) - c.w.WriteString(`) `) + io.WriteString(c.w, `) `) } valExists := true switch val.Op { case qcode.OpEquals: - c.w.WriteString(`=`) + io.WriteString(c.w, `=`) case qcode.OpNotEquals: - c.w.WriteString(`!=`) + io.WriteString(c.w, `!=`) case qcode.OpGreaterOrEquals: - c.w.WriteString(`>=`) + io.WriteString(c.w, `>=`) case qcode.OpLesserOrEquals: - c.w.WriteString(`<=`) + io.WriteString(c.w, `<=`) case qcode.OpGreaterThan: - c.w.WriteString(`>`) + io.WriteString(c.w, `>`) case qcode.OpLesserThan: - c.w.WriteString(`<`) + io.WriteString(c.w, `<`) case qcode.OpIn: - c.w.WriteString(`IN`) + io.WriteString(c.w, `IN`) case qcode.OpNotIn: - c.w.WriteString(`NOT IN`) + io.WriteString(c.w, `NOT IN`) case qcode.OpLike: - c.w.WriteString(`LIKE`) + io.WriteString(c.w, `LIKE`) case qcode.OpNotLike: - c.w.WriteString(`NOT LIKE`) + io.WriteString(c.w, `NOT LIKE`) case qcode.OpILike: - c.w.WriteString(`ILIKE`) + io.WriteString(c.w, `ILIKE`) case qcode.OpNotILike: - c.w.WriteString(`NOT ILIKE`) + io.WriteString(c.w, `NOT ILIKE`) case qcode.OpSimilar: - c.w.WriteString(`SIMILAR TO`) + io.WriteString(c.w, `SIMILAR TO`) case qcode.OpNotSimilar: - c.w.WriteString(`NOT SIMILAR TO`) + io.WriteString(c.w, `NOT SIMILAR TO`) case qcode.OpContains: - c.w.WriteString(`@>`) + io.WriteString(c.w, `@>`) case qcode.OpContainedIn: - c.w.WriteString(`<@`) + io.WriteString(c.w, `<@`) case qcode.OpHasKey: - c.w.WriteString(`?`) + io.WriteString(c.w, `?`) case qcode.OpHasKeyAny: - c.w.WriteString(`?|`) + io.WriteString(c.w, `?|`) case qcode.OpHasKeyAll: - c.w.WriteString(`?&`) + io.WriteString(c.w, `?&`) case qcode.OpIsNull: if strings.EqualFold(val.Val, "true") { - c.w.WriteString(`IS NULL)`) + io.WriteString(c.w, `IS NULL)`) } else { - c.w.WriteString(`IS NOT NULL)`) + io.WriteString(c.w, `IS NOT NULL)`) } valExists = false case qcode.OpEqID: @@ -826,20 +825,20 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error return fmt.Errorf("no primary key column defined for %s", ti.Name) } //fmt.Fprintf(w, `(("%s") =`, c.ti.PrimaryCol) - c.w.WriteString(`((`) + io.WriteString(c.w, `((`) colWithTable(c.w, ti.Name, ti.PrimaryCol) - //c.w.WriteString(ti.PrimaryCol) - c.w.WriteString(`) =`) + //io.WriteString(c.w, ti.PrimaryCol) + io.WriteString(c.w, `) =`) case qcode.OpTsQuery: if len(ti.TSVCol) == 0 { return fmt.Errorf("no tsv column defined for %s", ti.Name) } //fmt.Fprintf(w, `(("%s") @@ to_tsquery('%s'))`, c.ti.TSVCol, val.Val) - c.w.WriteString(`(("`) - c.w.WriteString(ti.TSVCol) - c.w.WriteString(`") @@ to_tsquery('`) - c.w.WriteString(val.Val) - c.w.WriteString(`'))`) + io.WriteString(c.w, `(("`) + io.WriteString(c.w, ti.TSVCol) + io.WriteString(c.w, `") @@ to_tsquery('`) + io.WriteString(c.w, val.Val) + io.WriteString(c.w, `'))`) valExists = false default: @@ -852,7 +851,7 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error } else { c.renderVal(val, c.vars) } - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) } qcode.FreeExp(val) @@ -868,10 +867,10 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error } func (c *compilerContext) renderOrderBy(sel *qcode.Select, ti *DBTableInfo) error { - c.w.WriteString(` ORDER BY `) + io.WriteString(c.w, ` ORDER BY `) for i := range sel.OrderBy { if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } ob := sel.OrderBy[i] @@ -879,27 +878,27 @@ func (c *compilerContext) renderOrderBy(sel *qcode.Select, ti *DBTableInfo) erro case qcode.OrderAsc: //fmt.Fprintf(w, `"%s_%d.ob.%s" ASC`, sel.Table, sel.ID, ob.Col) tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") - c.w.WriteString(` ASC`) + io.WriteString(c.w, ` ASC`) case qcode.OrderDesc: //fmt.Fprintf(w, `"%s_%d.ob.%s" DESC`, sel.Table, sel.ID, ob.Col) tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") - c.w.WriteString(` DESC`) + io.WriteString(c.w, ` DESC`) case qcode.OrderAscNullsFirst: //fmt.Fprintf(w, `"%s_%d.ob.%s" ASC NULLS FIRST`, sel.Table, sel.ID, ob.Col) tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") - c.w.WriteString(` ASC NULLS FIRST`) + io.WriteString(c.w, ` ASC NULLS FIRST`) case qcode.OrderDescNullsFirst: //fmt.Fprintf(w, `%s_%d.ob.%s DESC NULLS FIRST`, sel.Table, sel.ID, ob.Col) tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") - c.w.WriteString(` DESC NULLLS FIRST`) + io.WriteString(c.w, ` DESC NULLLS FIRST`) case qcode.OrderAscNullsLast: //fmt.Fprintf(w, `"%s_%d.ob.%s ASC NULLS LAST`, sel.Table, sel.ID, ob.Col) tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") - c.w.WriteString(` ASC NULLS LAST`) + io.WriteString(c.w, ` ASC NULLS LAST`) case qcode.OrderDescNullsLast: //fmt.Fprintf(w, `%s_%d.ob.%s DESC NULLS LAST`, sel.Table, sel.ID, ob.Col) tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") - c.w.WriteString(` DESC NULLS LAST`) + io.WriteString(c.w, ` DESC NULLS LAST`) default: return fmt.Errorf("13: unexpected value %v", ob.Order) } @@ -911,30 +910,30 @@ func (c *compilerContext) renderDistinctOn(sel *qcode.Select, ti *DBTableInfo) { io.WriteString(c.w, `DISTINCT ON (`) for i := range sel.DistinctOn { if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } //fmt.Fprintf(w, `"%s_%d.ob.%s"`, c.sel.Table, c.sel.ID, c.sel.DistinctOn[i]) tableIDColSuffix(c.w, ti.Name, sel.ID, sel.DistinctOn[i], "_ob") } - c.w.WriteString(`) `) + io.WriteString(c.w, `) `) } func (c *compilerContext) renderList(ex *qcode.Exp) { io.WriteString(c.w, ` (`) for i := range ex.ListVal { if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } switch ex.ListType { case qcode.ValBool, qcode.ValInt, qcode.ValFloat: - c.w.WriteString(ex.ListVal[i]) + io.WriteString(c.w, ex.ListVal[i]) case qcode.ValStr: - c.w.WriteString(`'`) - c.w.WriteString(ex.ListVal[i]) - c.w.WriteString(`'`) + io.WriteString(c.w, `'`) + io.WriteString(c.w, ex.ListVal[i]) + io.WriteString(c.w, `'`) } } - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) } func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) { @@ -943,27 +942,27 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) { switch ex.Type { case qcode.ValBool, qcode.ValInt, qcode.ValFloat: if len(ex.Val) != 0 { - c.w.WriteString(ex.Val) + io.WriteString(c.w, ex.Val) } else { - c.w.WriteString(`''`) + io.WriteString(c.w, `''`) } case qcode.ValStr: - c.w.WriteString(`'`) - c.w.WriteString(ex.Val) - c.w.WriteString(`'`) + io.WriteString(c.w, `'`) + io.WriteString(c.w, ex.Val) + io.WriteString(c.w, `'`) case qcode.ValVar: if val, ok := vars[ex.Val]; ok { - c.w.WriteString(val) + io.WriteString(c.w, val) } else { //fmt.Fprintf(w, `'{{%s}}'`, ex.Val) - c.w.WriteString(`{{`) - c.w.WriteString(ex.Val) - c.w.WriteString(`}}`) + io.WriteString(c.w, `{{`) + io.WriteString(c.w, ex.Val) + io.WriteString(c.w, `}}`) } } - //c.w.WriteString(`)`) + //io.WriteString(c.w, `)`) } func funcPrefixLen(fn string) int { @@ -999,105 +998,105 @@ func hasBit(n uint32, pos uint32) bool { return (val > 0) } -func alias(w *bytes.Buffer, alias string) { - w.WriteString(` AS "`) - w.WriteString(alias) - w.WriteString(`"`) +func alias(w io.Writer, alias string) { + io.WriteString(w, ` AS "`) + io.WriteString(w, alias) + io.WriteString(w, `"`) } -func aliasWithID(w *bytes.Buffer, alias string, id int32) { - w.WriteString(` AS "`) - w.WriteString(alias) - w.WriteString(`_`) +func aliasWithID(w io.Writer, alias string, id int32) { + io.WriteString(w, ` AS "`) + io.WriteString(w, alias) + io.WriteString(w, `_`) int2string(w, id) - w.WriteString(`"`) + io.WriteString(w, `"`) } -func aliasWithIDSuffix(w *bytes.Buffer, alias string, id int32, suffix string) { - w.WriteString(` AS "`) - w.WriteString(alias) - w.WriteString(`_`) +func aliasWithIDSuffix(w io.Writer, alias string, id int32, suffix string) { + io.WriteString(w, ` AS "`) + io.WriteString(w, alias) + io.WriteString(w, `_`) int2string(w, id) - w.WriteString(suffix) - w.WriteString(`"`) + io.WriteString(w, suffix) + io.WriteString(w, `"`) } -func colWithAlias(w *bytes.Buffer, col, alias string) { - w.WriteString(`"`) - w.WriteString(col) - w.WriteString(`" AS "`) - w.WriteString(alias) - w.WriteString(`"`) +func colWithAlias(w io.Writer, col, alias string) { + io.WriteString(w, `"`) + io.WriteString(w, col) + io.WriteString(w, `" AS "`) + io.WriteString(w, alias) + io.WriteString(w, `"`) } -func tableWithAlias(w *bytes.Buffer, table, alias string) { - w.WriteString(`"`) - w.WriteString(table) - w.WriteString(`" AS "`) - w.WriteString(alias) - w.WriteString(`"`) +func tableWithAlias(w io.Writer, table, alias string) { + io.WriteString(w, `"`) + io.WriteString(w, table) + io.WriteString(w, `" AS "`) + io.WriteString(w, alias) + io.WriteString(w, `"`) } -func colWithTable(w *bytes.Buffer, table, col string) { - w.WriteString(`"`) - w.WriteString(table) - w.WriteString(`"."`) - w.WriteString(col) - w.WriteString(`"`) +func colWithTable(w io.Writer, table, col string) { + io.WriteString(w, `"`) + io.WriteString(w, table) + io.WriteString(w, `"."`) + io.WriteString(w, col) + io.WriteString(w, `"`) } -func colWithTableID(w *bytes.Buffer, table string, id int32, col string) { - w.WriteString(`"`) - w.WriteString(table) - w.WriteString(`_`) +func colWithTableID(w io.Writer, table string, id int32, col string) { + io.WriteString(w, `"`) + io.WriteString(w, table) + io.WriteString(w, `_`) int2string(w, id) - w.WriteString(`"."`) - w.WriteString(col) - w.WriteString(`"`) + io.WriteString(w, `"."`) + io.WriteString(w, col) + io.WriteString(w, `"`) } -func colWithTableIDAlias(w *bytes.Buffer, table string, id int32, col, alias string) { - w.WriteString(`"`) - w.WriteString(table) - w.WriteString(`_`) +func colWithTableIDAlias(w io.Writer, table string, id int32, col, alias string) { + io.WriteString(w, `"`) + io.WriteString(w, table) + io.WriteString(w, `_`) int2string(w, id) - w.WriteString(`"."`) - w.WriteString(col) - w.WriteString(`" AS "`) - w.WriteString(alias) - w.WriteString(`"`) + io.WriteString(w, `"."`) + io.WriteString(w, col) + io.WriteString(w, `" AS "`) + io.WriteString(w, alias) + io.WriteString(w, `"`) } -func colWithTableIDSuffixAlias(w *bytes.Buffer, table string, id int32, +func colWithTableIDSuffixAlias(w io.Writer, table string, id int32, suffix, col, alias string) { - w.WriteString(`"`) - w.WriteString(table) - w.WriteString(`_`) + io.WriteString(w, `"`) + io.WriteString(w, table) + io.WriteString(w, `_`) int2string(w, id) - w.WriteString(suffix) - w.WriteString(`"."`) - w.WriteString(col) - w.WriteString(`" AS "`) - w.WriteString(alias) - w.WriteString(`"`) + io.WriteString(w, suffix) + io.WriteString(w, `"."`) + io.WriteString(w, col) + io.WriteString(w, `" AS "`) + io.WriteString(w, alias) + io.WriteString(w, `"`) } -func tableIDColSuffix(w *bytes.Buffer, table string, id int32, col, suffix string) { - w.WriteString(`"`) - w.WriteString(table) - w.WriteString(`_`) +func tableIDColSuffix(w io.Writer, table string, id int32, col, suffix string) { + io.WriteString(w, `"`) + io.WriteString(w, table) + io.WriteString(w, `_`) int2string(w, id) - w.WriteString(`_`) - w.WriteString(col) - w.WriteString(suffix) - w.WriteString(`"`) + io.WriteString(w, `_`) + io.WriteString(w, col) + io.WriteString(w, suffix) + io.WriteString(w, `"`) } const charset = "0123456789" -func int2string(w *bytes.Buffer, val int32) { +func int2string(w io.Writer, val int32) { if val < 10 { - w.WriteByte(charset[val]) + w.Write([]byte{charset[val]}) return } @@ -1113,7 +1112,7 @@ func int2string(w *bytes.Buffer, val int32) { for val3 > 0 { d := val3 % 10 val3 /= 10 - w.WriteByte(charset[d]) + w.Write([]byte{charset[d]}) } } diff --git a/serv/allow.go b/serv/allow.go index 729ef61..1279170 100644 --- a/serv/allow.go +++ b/serv/allow.go @@ -182,7 +182,7 @@ func (al *allowList) load() { item.vars = varBytes } - al.list[gqlHash(q, varBytes)] = item + al.list[gqlHash(q, varBytes, "")] = item varBytes = nil } else if ty == AL_VARS { @@ -203,7 +203,11 @@ func (al *allowList) save(item *allowItem) { if al.active == false { return } - al.list[gqlHash(item.gql, item.vars)] = item + h := gqlHash(item.gql, item.vars, "") + if _, ok := al.list[h]; ok { + return + } + al.list[gqlHash(item.gql, item.vars, "")] = item f, err := os.Create(al.filepath) if err != nil { diff --git a/serv/auth.go b/serv/auth.go index e597a9b..77942eb 100644 --- a/serv/auth.go +++ b/serv/auth.go @@ -9,26 +9,40 @@ import ( var ( userIDProviderKey = struct{}{} userIDKey = struct{}{} + userRoleKey = struct{}{} ) -func headerAuth(r *http.Request, c *config) *http.Request { - if len(c.Auth.Header) == 0 { - return nil - } +func headerAuth(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() - userID := r.Header.Get(c.Auth.Header) - if len(userID) != 0 { - ctx := context.WithValue(r.Context(), userIDKey, userID) - return r.WithContext(ctx) - } + userIDProvider := r.Header.Get("X-User-ID-Provider") + if len(userIDProvider) != 0 { + ctx = context.WithValue(ctx, userIDProviderKey, userIDProvider) + } - return nil + userID := r.Header.Get("X-User-ID") + if len(userID) != 0 { + ctx = context.WithValue(ctx, userIDKey, userID) + } + + userRole := r.Header.Get("X-User-Role") + if len(userRole) != 0 { + ctx = context.WithValue(ctx, userRoleKey, userRole) + } + + next.ServeHTTP(w, r.WithContext(ctx)) + } } func withAuth(next http.HandlerFunc) http.HandlerFunc { at := conf.Auth.Type ru := conf.Auth.Rails.URL + if conf.Auth.CredsInHeader { + next = headerAuth(next) + } + switch at { case "rails": if strings.HasPrefix(ru, "memcache:") { diff --git a/serv/auth_jwt.go b/serv/auth_jwt.go index 25ed785..ef4f834 100644 --- a/serv/auth_jwt.go +++ b/serv/auth_jwt.go @@ -58,11 +58,6 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var tok string - if rn := headerAuth(r, conf); rn != nil { - next.ServeHTTP(w, rn) - return - } - if len(cookie) != 0 { ck, err := r.Cookie(cookie) if err != nil { @@ -102,7 +97,6 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc { } next.ServeHTTP(w, r.WithContext(ctx)) } - next.ServeHTTP(w, r) } } diff --git a/serv/auth_rails.go b/serv/auth_rails.go index c72c0d7..cd0b327 100644 --- a/serv/auth_rails.go +++ b/serv/auth_rails.go @@ -42,11 +42,6 @@ func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc { } return func(w http.ResponseWriter, r *http.Request) { - if rn := headerAuth(r, conf); rn != nil { - next.ServeHTTP(w, rn) - return - } - ck, err := r.Cookie(cookie) if err != nil { next.ServeHTTP(w, r) @@ -83,17 +78,12 @@ func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc { rURL, err := url.Parse(conf.Auth.Rails.URL) if err != nil { - logger.Fatal().Err(err) + logger.Fatal().Err(err).Send() } mc := memcache.New(rURL.Host) return func(w http.ResponseWriter, r *http.Request) { - if rn := headerAuth(r, conf); rn != nil { - next.ServeHTTP(w, rn) - return - } - ck, err := r.Cookie(cookie) if err != nil { next.ServeHTTP(w, r) @@ -126,25 +116,20 @@ func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc { ra, err := railsAuth(conf) if err != nil { - logger.Fatal().Err(err) + logger.Fatal().Err(err).Send() } return func(w http.ResponseWriter, r *http.Request) { - if rn := headerAuth(r, conf); rn != nil { - next.ServeHTTP(w, rn) - return - } - ck, err := r.Cookie(cookie) if err != nil { - logger.Error().Err(err) + logger.Warn().Err(err).Send() next.ServeHTTP(w, r) return } userID, err := ra.ParseCookie(ck.Value) if err != nil { - logger.Error().Err(err) + logger.Warn().Err(err).Send() next.ServeHTTP(w, r) return } diff --git a/serv/cmd.go b/serv/cmd.go index e73b4fb..12b3ce7 100644 --- a/serv/cmd.go +++ b/serv/cmd.go @@ -183,7 +183,32 @@ func initConf() (*config, error) { } zerolog.SetGlobalLevel(logLevel) - //fmt.Printf("%#v", c) + for k, v := range c.DB.Vars { + c.DB.Vars[k] = sanitize(v) + } + + c.RolesQuery = sanitize(c.RolesQuery) + + rolesMap := make(map[string]struct{}) + + for i := range c.Roles { + role := &c.Roles[i] + + if _, ok := rolesMap[role.Name]; ok { + logger.Fatal().Msgf("duplicate role '%s' found", role.Name) + } + role.Name = sanitize(role.Name) + role.Match = sanitize(role.Match) + rolesMap[role.Name] = struct{}{} + } + + if _, ok := rolesMap["user"]; !ok { + c.Roles = append(c.Roles, configRole{Name: "user"}) + } + + if _, ok := rolesMap["anon"]; !ok { + c.Roles = append(c.Roles, configRole{Name: "anon"}) + } return c, nil } diff --git a/serv/cmd_seed.go b/serv/cmd_seed.go index f9b152e..514c543 100644 --- a/serv/cmd_seed.go +++ b/serv/cmd_seed.go @@ -66,8 +66,9 @@ func graphQLFunc(query string, data interface{}) map[string]interface{} { c := &coreContext{Context: context.Background()} c.req.Query = query c.req.Vars = b + c.req.role = "user" - res, err := c.execQuery("user") + res, err := c.execQuery() if err != nil { logger.Fatal().Err(err).Msg("graphql query failed") } diff --git a/serv/config.go b/serv/config.go index b2ebf21..8420c66 100644 --- a/serv/config.go +++ b/serv/config.go @@ -1,7 +1,9 @@ package serv import ( + "regexp" "strings" + "unicode" "github.com/spf13/viper" ) @@ -24,9 +26,9 @@ type config struct { Inflections map[string]string Auth struct { - Type string - Cookie string - Header string + Type string + Cookie string + CredsInHeader bool `mapstructure:"creds_in_header"` Rails struct { Version string @@ -60,7 +62,7 @@ type config struct { MaxRetries int `mapstructure:"max_retries"` LogLevel string `mapstructure:"log_level"` - vars map[string][]byte `mapstructure:"variables"` + Vars map[string]string `mapstructure:"variables"` Defaults struct { Filter []string @@ -71,7 +73,9 @@ type config struct { } `mapstructure:"database"` Tables []configTable - Roles []configRoles + + RolesQuery string `mapstructure:"roles_query"` + Roles []configRole } type configTable struct { @@ -94,8 +98,9 @@ type configRemote struct { } `mapstructure:"set_headers"` } -type configRoles struct { +type configRole struct { Name string + Match string Tables []struct { Name string @@ -163,26 +168,6 @@ func newConfig() *viper.Viper { return vi } -func (c *config) getVariables() map[string]string { - vars := make(map[string]string, len(c.DB.vars)) - - for k, v := range c.DB.vars { - isVar := false - - for i := range v { - if v[i] == '$' { - isVar = true - } else if v[i] == ' ' { - isVar = false - } else if isVar && v[i] >= 'a' && v[i] <= 'z' { - v[i] = 'A' + (v[i] - 'a') - } - } - vars[k] = string(v) - } - return vars -} - func (c *config) getAliasMap() map[string][]string { m := make(map[string][]string, len(c.Tables)) @@ -198,3 +183,21 @@ func (c *config) getAliasMap() map[string][]string { } return m } + +var varRe1 = regexp.MustCompile(`(?mi)\$([a-zA-Z0-9_.]+)`) +var varRe2 = regexp.MustCompile(`\{\{([a-zA-Z0-9_.]+)\}\}`) + +func sanitize(s string) string { + s0 := varRe1.ReplaceAllString(s, `{{$1}}`) + + s1 := strings.Map(func(r rune) rune { + if unicode.IsSpace(r) { + return ' ' + } + return r + }, s0) + + return varRe2.ReplaceAllStringFunc(s1, func(m string) string { + return strings.ToLower(m) + }) +} diff --git a/serv/core.go b/serv/core.go index edc4779..8a007ac 100644 --- a/serv/core.go +++ b/serv/core.go @@ -13,8 +13,8 @@ import ( "github.com/cespare/xxhash/v2" "github.com/dosco/super-graph/jsn" - "github.com/dosco/super-graph/psql" "github.com/dosco/super-graph/qcode" + "github.com/jackc/pgx/v4" "github.com/valyala/fasttemplate" ) @@ -32,15 +32,13 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error { c.req.ref = req.Referer() c.req.hdr = req.Header - var role string - if authCheck(c) { - role = "user" + c.req.role = "user" } else { - role = "anon" + c.req.role = "anon" } - b, err := c.execQuery(role) + b, err := c.execQuery() if err != nil { return err } @@ -48,18 +46,18 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error { return c.render(w, b) } -func (c *coreContext) execQuery(role string) ([]byte, error) { +func (c *coreContext) execQuery() ([]byte, error) { var err error var skipped uint32 var qc *qcode.QCode var data []byte - logger.Debug().Str("role", role).Msg(c.req.Query) + logger.Debug().Str("role", c.req.role).Msg(c.req.Query) if conf.UseAllowList { var ps *preparedItem - data, ps, err = c.resolvePreparedSQL(c.req.Query) + data, ps, err = c.resolvePreparedSQL() if err != nil { return nil, err } @@ -69,12 +67,7 @@ func (c *coreContext) execQuery(role string) ([]byte, error) { } else { - qc, err = qcompile.Compile([]byte(c.req.Query), role) - if err != nil { - return nil, err - } - - data, skipped, err = c.resolveSQL(qc) + data, skipped, err = c.resolveSQL() if err != nil { return nil, err } @@ -122,6 +115,152 @@ func (c *coreContext) execQuery(role string) ([]byte, error) { return ob.Bytes(), nil } +func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) { + tx, err := db.Begin(c) + if err != nil { + return nil, nil, err + } + defer tx.Rollback(c) + + if v := c.Value(userIDKey); v != nil { + _, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v)) + + if err != nil { + return nil, nil, err + } + } + + var role string + useRoleQuery := len(conf.RolesQuery) != 0 && isMutation(c.req.Query) + + if useRoleQuery { + if role, err = c.executeRoleQuery(tx); err != nil { + return nil, nil, err + } + } else if v := c.Value(userRoleKey); v != nil { + role = v.(string) + } else { + role = c.req.role + } + + ps, ok := _preparedList[gqlHash(c.req.Query, c.req.Vars, role)] + if !ok { + return nil, nil, errUnauthorized + } + + var root []byte + vars := varList(c, ps.args) + + err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root) + if err != nil { + return nil, nil, err + } + + if err := tx.Commit(c); err != nil { + return nil, nil, err + } + + return root, ps, nil +} + +func (c *coreContext) resolveSQL() ([]byte, uint32, error) { + tx, err := db.Begin(c) + if err != nil { + return nil, 0, err + } + defer tx.Rollback(c) + + mutation := isMutation(c.req.Query) + useRoleQuery := len(conf.RolesQuery) != 0 && mutation + + if useRoleQuery { + if c.req.role, err = c.executeRoleQuery(tx); err != nil { + return nil, 0, err + } + + } else if v := c.Value(userRoleKey); v != nil { + c.req.role = v.(string) + } + + stmts, err := c.buildStmt() + if err != nil { + return nil, 0, err + } + + var st *stmt + + if mutation { + st = findStmt(c.req.role, stmts) + } else { + st = &stmts[0] + } + + t := fasttemplate.New(st.sql, openVar, closeVar) + + buf := &bytes.Buffer{} + _, err = t.ExecuteFunc(buf, varMap(c)) + + if err == errNoUserID && + authFailBlock == authFailBlockPerQuery && + authCheck(c) == false { + return nil, 0, errUnauthorized + } + + if err != nil { + return nil, 0, err + } + + finalSQL := buf.String() + + var stime time.Time + + if conf.EnableTracing { + stime = time.Now() + } + + if v := c.Value(userIDKey); v != nil { + _, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v)) + + if err != nil { + return nil, 0, err + } + } + + var root []byte + + if mutation { + err = tx.QueryRow(c, finalSQL).Scan(&root) + } else { + err = tx.QueryRow(c, finalSQL).Scan(&c.req.role, &root) + } + if err != nil { + return nil, 0, err + } + + if err := tx.Commit(c); err != nil { + return nil, 0, err + } + + if mutation { + st = findStmt(c.req.role, stmts) + } else { + st = &stmts[0] + } + + if conf.EnableTracing && len(st.qc.Selects) != 0 { + c.addTrace( + st.qc.Selects, + st.qc.Selects[0].ID, + stime) + } + + if conf.UseAllowList == false { + _allowList.add(&c.req) + } + + return root, st.skipped, nil +} + func (c *coreContext) resolveRemote( hdr http.Header, h *xxhash.Digest, @@ -269,125 +408,15 @@ func (c *coreContext) resolveRemotes( return to, cerr } -func (c *coreContext) resolvePreparedSQL(gql string) ([]byte, *preparedItem, error) { - ps, ok := _preparedList[gqlHash(gql, c.req.Vars)] - if !ok { - return nil, nil, errUnauthorized +func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) { + var role string + row := tx.QueryRow(c, "_sg_get_role", c.req.role, 1) + + if err := row.Scan(&role); err != nil { + return "", err } - var root []byte - vars := varList(c, ps.args) - - tx, err := db.Begin(c) - if err != nil { - return nil, nil, err - } - defer tx.Rollback(c) - - if v := c.Value(userIDKey); v != nil { - _, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v)) - - if err != nil { - return nil, nil, err - } - } - - err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root) - if err != nil { - return nil, nil, err - } - - if err := tx.Commit(c); err != nil { - return nil, nil, err - } - - fmt.Printf("PRE: %v\n", ps.stmt) - - return root, ps, nil -} - -func (c *coreContext) resolveSQL(qc *qcode.QCode) ([]byte, uint32, error) { - var vars map[string]json.RawMessage - stmt := &bytes.Buffer{} - - if len(c.req.Vars) != 0 { - if err := json.Unmarshal(c.req.Vars, &vars); err != nil { - return nil, 0, err - } - } - - skipped, err := pcompile.Compile(qc, stmt, psql.Variables(vars)) - if err != nil { - return nil, 0, err - } - - t := fasttemplate.New(stmt.String(), openVar, closeVar) - - stmt.Reset() - _, err = t.ExecuteFunc(stmt, varMap(c)) - - if err == errNoUserID && - authFailBlock == authFailBlockPerQuery && - authCheck(c) == false { - return nil, 0, errUnauthorized - } - - if err != nil { - return nil, 0, err - } - - finalSQL := stmt.String() - - // if conf.LogLevel == "debug" { - // os.Stdout.WriteString(finalSQL) - // os.Stdout.WriteString("\n\n") - // } - - var st time.Time - - if conf.EnableTracing { - st = time.Now() - } - - tx, err := db.Begin(c) - if err != nil { - return nil, 0, err - } - defer tx.Rollback(c) - - if v := c.Value(userIDKey); v != nil { - _, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v)) - - if err != nil { - return nil, 0, err - } - } - - //fmt.Printf("\nRAW: %#v\n", finalSQL) - - var root []byte - - err = tx.QueryRow(c, finalSQL).Scan(&root) - if err != nil { - return nil, 0, err - } - - if err := tx.Commit(c); err != nil { - return nil, 0, err - } - - if conf.EnableTracing && len(qc.Selects) != 0 { - c.addTrace( - qc.Selects, - qc.Selects[0].ID, - st) - } - - if conf.UseAllowList == false { - _allowList.add(&c.req) - } - - return root, skipped, nil + return role, nil } func (c *coreContext) render(w io.Writer, data []byte) error { diff --git a/serv/core_build.go b/serv/core_build.go new file mode 100644 index 0000000..fd86003 --- /dev/null +++ b/serv/core_build.go @@ -0,0 +1,144 @@ +package serv + +import ( + "bytes" + "encoding/json" + "errors" + "io" + + "github.com/dosco/super-graph/psql" + "github.com/dosco/super-graph/qcode" +) + +type stmt struct { + role *configRole + qc *qcode.QCode + skipped uint32 + sql string +} + +func (c *coreContext) buildStmt() ([]stmt, error) { + var vars map[string]json.RawMessage + + if len(c.req.Vars) != 0 { + if err := json.Unmarshal(c.req.Vars, &vars); err != nil { + return nil, err + } + } + + gql := []byte(c.req.Query) + + if len(conf.Roles) == 0 { + return nil, errors.New(`no roles found ('user' and 'anon' required)`) + } + + qc, err := qcompile.Compile(gql, conf.Roles[0].Name) + if err != nil { + return nil, err + } + + stmts := make([]stmt, 0, len(conf.Roles)) + mutation := (qc.Type != qcode.QTQuery) + w := &bytes.Buffer{} + + for i := range conf.Roles { + role := &conf.Roles[i] + + if mutation && len(c.req.role) != 0 && role.Name != c.req.role { + continue + } + + if i > 0 { + qc, err = qcompile.Compile(gql, role.Name) + if err != nil { + return nil, err + } + } + + stmts = append(stmts, stmt{role: role, qc: qc}) + + if mutation { + skipped, err := pcompile.Compile(qc, w, psql.Variables(vars)) + if err != nil { + return nil, err + } + + s := &stmts[len(stmts)-1] + s.skipped = skipped + s.sql = w.String() + w.Reset() + } + } + + if mutation { + return stmts, nil + } + + io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `) + + for _, s := range stmts { + io.WriteString(w, `WHEN '`) + io.WriteString(w, s.role.Name) + io.WriteString(w, `' THEN (`) + + s.skipped, err = pcompile.Compile(s.qc, w, psql.Variables(vars)) + if err != nil { + return nil, err + } + + io.WriteString(w, `) `) + } + io.WriteString(w, `END) FROM (`) + + if len(conf.RolesQuery) == 0 { + v := c.Value(userRoleKey) + + io.WriteString(w, `VALUES ("`) + if v != nil { + io.WriteString(w, v.(string)) + } else { + io.WriteString(w, c.req.role) + } + io.WriteString(w, `")) AS "_sg_auth_info"(role) LIMIT 1;`) + + } else { + + io.WriteString(w, `SELECT (CASE WHEN EXISTS (`) + io.WriteString(w, conf.RolesQuery) + io.WriteString(w, `) THEN `) + + io.WriteString(w, `(SELECT (CASE`) + for _, s := range stmts { + if len(s.role.Match) == 0 { + continue + } + io.WriteString(w, ` WHEN `) + io.WriteString(w, s.role.Match) + io.WriteString(w, ` THEN '`) + io.WriteString(w, s.role.Name) + io.WriteString(w, `'`) + } + + if len(c.req.role) == 0 { + io.WriteString(w, ` ELSE 'anon' END) FROM (`) + } else { + io.WriteString(w, ` ELSE '`) + io.WriteString(w, c.req.role) + io.WriteString(w, `' END) FROM (`) + } + + io.WriteString(w, conf.RolesQuery) + io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) ELSE '`) + if len(c.req.role) == 0 { + io.WriteString(w, `anon`) + } else { + io.WriteString(w, c.req.role) + } + io.WriteString(w, `' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `) + } + + stmts[0].sql = w.String() + stmts[0].role = nil + + return stmts, nil +} diff --git a/serv/http.go b/serv/http.go index d419d8b..ae8ff84 100644 --- a/serv/http.go +++ b/serv/http.go @@ -30,6 +30,7 @@ type gqlReq struct { Query string `json:"query"` Vars json.RawMessage `json:"variables"` ref string + role string hdr http.Header } @@ -101,13 +102,11 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) { err = ctx.handleReq(w, r) if err == errUnauthorized { - err := "Not authorized" - logger.Debug().Msg(err) - http.Error(w, err, 401) + http.Error(w, "Not authorized", 401) } if err != nil { - logger.Err(err).Msg("Failed to handle request") + logger.Err(err).Msg("failed to handle request") errorResp(w, err) } } diff --git a/serv/prepare.go b/serv/prepare.go index 697d80a..8415397 100644 --- a/serv/prepare.go +++ b/serv/prepare.go @@ -7,7 +7,6 @@ import ( "fmt" "io" - "github.com/dosco/super-graph/psql" "github.com/dosco/super-graph/qcode" "github.com/jackc/pgconn" "github.com/valyala/fasttemplate" @@ -27,55 +26,100 @@ var ( func initPreparedList() { _preparedList = make(map[string]*preparedItem) - for k, v := range _allowList.list { - err := prepareStmt(k, v.gql, v.vars) + if err := prepareRoleStmt(); err != nil { + logger.Fatal().Err(err).Msg("failed to prepare get role statement") + } + + for _, v := range _allowList.list { + err := prepareStmt(v.gql, v.vars) if err != nil { logger.Warn().Str("gql", v.gql).Err(err).Send() } } } -func prepareStmt(key, gql string, varBytes json.RawMessage) error { - if len(gql) == 0 || len(key) == 0 { +func prepareStmt(gql string, varBytes json.RawMessage) error { + if len(gql) == 0 { return nil } - qc, err := qcompile.Compile([]byte(gql), "user") + c := &coreContext{Context: context.Background()} + c.req.Query = gql + c.req.Vars = varBytes + + stmts, err := c.buildStmt() if err != nil { return err } - var vars map[string]json.RawMessage + for _, s := range stmts { + if len(s.sql) == 0 { + continue + } - if len(varBytes) != 0 { - vars = make(map[string]json.RawMessage) + finalSQL, am := processTemplate(s.sql) - if err := json.Unmarshal(varBytes, &vars); err != nil { + ctx := context.Background() + + tx, err := db.Begin(ctx) + if err != nil { + return err + } + defer tx.Rollback(ctx) + + pstmt, err := tx.Prepare(ctx, "", finalSQL) + if err != nil { + return err + } + + var key string + + if s.role == nil { + key = gqlHash(gql, varBytes, "") + } else { + key = gqlHash(gql, varBytes, s.role.Name) + } + + _preparedList[key] = &preparedItem{ + stmt: pstmt, + args: am, + skipped: s.skipped, + qc: s.qc, + } + + if err := tx.Commit(ctx); err != nil { return err } } - buf := &bytes.Buffer{} + return nil +} - skipped, err := pcompile.Compile(qc, buf, psql.Variables(vars)) - if err != nil { - return err +func prepareRoleStmt() error { + if len(conf.RolesQuery) == 0 { + return nil } - t := fasttemplate.New(buf.String(), `{{`, `}}`) - am := make([][]byte, 0, 5) - i := 0 + w := &bytes.Buffer{} - finalSQL := t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) { - am = append(am, []byte(tag)) - i++ - return w.Write([]byte(fmt.Sprintf("$%d", i))) - }) - - if err != nil { - return err + io.WriteString(w, `SELECT (CASE`) + for _, role := range conf.Roles { + if len(role.Match) == 0 { + continue + } + io.WriteString(w, ` WHEN `) + io.WriteString(w, role.Match) + io.WriteString(w, ` THEN '`) + io.WriteString(w, role.Name) + io.WriteString(w, `'`) } + io.WriteString(w, ` ELSE {{role}} END) FROM (`) + io.WriteString(w, conf.RolesQuery) + io.WriteString(w, `) AS "_sg_auth_roles_query"`) + + roleSQL, _ := processTemplate(w.String()) + ctx := context.Background() tx, err := db.Begin(ctx) @@ -84,21 +128,28 @@ func prepareStmt(key, gql string, varBytes json.RawMessage) error { } defer tx.Rollback(ctx) - pstmt, err := tx.Prepare(ctx, "", finalSQL) + _, err = tx.Prepare(ctx, "_sg_get_role", roleSQL) if err != nil { return err } - _preparedList[key] = &preparedItem{ - stmt: pstmt, - args: am, - skipped: skipped, - qc: qc, - } - - if err := tx.Commit(ctx); err != nil { - return err - } - return nil } + +func processTemplate(tmpl string) (string, [][]byte) { + t := fasttemplate.New(tmpl, `{{`, `}}`) + am := make([][]byte, 0, 5) + i := 0 + + vmap := make(map[string]int) + + return t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) { + if n, ok := vmap[tag]; ok { + return w.Write([]byte(fmt.Sprintf("$%d", n))) + } + am = append(am, []byte(tag)) + i++ + vmap[tag] = i + return w.Write([]byte(fmt.Sprintf("$%d", i))) + }), am +} diff --git a/serv/serv.go b/serv/serv.go index 980eb23..e03b9c2 100644 --- a/serv/serv.go +++ b/serv/serv.go @@ -67,7 +67,7 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) { pc := psql.NewCompiler(psql.Config{ Schema: schema, - Vars: c.getVariables(), + Vars: c.DB.Vars, }) return qc, pc, nil diff --git a/serv/utils.go b/serv/utils.go index baf49c3..9e095e8 100644 --- a/serv/utils.go +++ b/serv/utils.go @@ -21,7 +21,7 @@ func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 { return v } -func gqlHash(b string, vars []byte) string { +func gqlHash(b string, vars []byte, role string) string { b = strings.TrimSpace(b) h := sha1.New() @@ -56,6 +56,10 @@ func gqlHash(b string, vars []byte) string { } } + if len(role) != 0 { + io.WriteString(h, role) + } + if vars == nil || len(vars) == 0 { return hex.EncodeToString(h.Sum(nil)) } @@ -80,3 +84,26 @@ func ws(b byte) bool { func al(b byte) bool { return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9') } + +func isMutation(sql string) bool { + for i := range sql { + b := sql[i] + if b == '{' { + return false + } + if al(b) { + return (b == 'm' || b == 'M') + } + } + return false +} + +func findStmt(role string, stmts []stmt) *stmt { + for i := range stmts { + if stmts[i].role.Name != role { + continue + } + return &stmts[i] + } + return nil +} diff --git a/serv/vars.go b/serv/vars.go index 6ad9da6..f20627c 100644 --- a/serv/vars.go +++ b/serv/vars.go @@ -11,17 +11,27 @@ import ( func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) { return func(w io.Writer, tag string) (int, error) { switch tag { - case "user_id": - if v := ctx.Value(userIDKey); v != nil { - return stringVar(w, v.(string)) - } - return 0, errNoUserID - case "user_id_provider": if v := ctx.Value(userIDProviderKey); v != nil { return stringVar(w, v.(string)) } - return 0, errNoUserID + io.WriteString(w, "null") + return 0, nil + + case "user_id": + if v := ctx.Value(userIDKey); v != nil { + return stringVar(w, v.(string)) + } + + io.WriteString(w, "null") + return 0, nil + + case "user_role": + if v := ctx.Value(userRoleKey); v != nil { + return stringVar(w, v.(string)) + } + io.WriteString(w, "null") + return 0, nil } fields := jsn.Get(ctx.req.Vars, [][]byte{[]byte(tag)}) diff --git a/tmpl/prod.yml b/tmpl/prod.yml index 9597d7a..29c6b45 100644 --- a/tmpl/prod.yml +++ b/tmpl/prod.yml @@ -47,10 +47,6 @@ auth: type: rails cookie: _app_session - # Comment this out if you want to disable setting - # the user_id via a header. Good for testing - header: X-User-ID - rails: # Rails version this is used for reading the # various cookies formats. @@ -106,42 +102,93 @@ database: - token tables: - - name: users - # This filter will overwrite defaults.filter - # filter: ["{ id: { eq: $user_id } }"] - # filter_query: ["{ id: { eq: $user_id } }"] - filter_update: ["{ id: { eq: $user_id } }"] - filter_delete: ["{ id: { eq: $user_id } }"] + - name: users + # This filter will overwrite defaults.filter + # filter: ["{ id: { eq: $user_id } }"] - - name: products - # Multiple filters are AND'd together - filter: [ - "{ price: { gt: 0 } }", - "{ price: { lt: 8 } }" - ] + # - name: products + # # Multiple filters are AND'd together + # filter: [ + # "{ price: { gt: 0 } }", + # "{ price: { lt: 8 } }" + # ] - - name: customers - # No filter is used for this field not - # even defaults.filter - filter: none + - name: customers + remotes: + - name: payments + id: stripe_id + url: http://rails_app:3000/stripe/$id + path: data + # debug: true + pass_headers: + - cookie + set_headers: + - name: Host + value: 0.0.0.0 + # - name: Authorization + # value: Bearer - # remotes: - # - name: payments - # id: stripe_id - # url: http://rails_app:3000/stripe/$id - # path: data - # # pass_headers: - # # - cookie - # # - host - # set_headers: - # - name: Authorization - # value: Bearer + - # You can create new fields that have a + # real db table backing them + name: me + table: users - - # You can create new fields that have a - # real db table backing them - name: me - table: users - filter: ["{ id: { eq: $user_id } }"] +roles_query: "SELECT * FROM users as usr WHERE id = $user_id" - # - name: posts - # filter: ["{ account_id: { _eq: $account_id } }"] \ No newline at end of file +roles: + - name: anon + tables: + - name: products + limit: 10 + + query: + columns: ["id", "name", "description" ] + aggregation: false + + insert: + allow: false + + update: + allow: false + + delete: + allow: false + + - name: user + tables: + - name: users + query: + filter: ["{ id: { _eq: $user_id } }"] + + - name: products + + query: + limit: 50 + filter: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + disable_aggregation: false + + insert: + filter: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + set: + - created_at: "now" + + update: + filter: ["{ user_id: { eq: $user_id } }"] + columns: + - id + - name + set: + - updated_at: "now" + + delete: + deny: true + + - name: admin + match: id = 1 + tables: + - name: users + + # select: + # filter: ["{ account_id: { _eq: $account_id } }"] From 4edc15eb9895a5c84f68dade668c1a7137b8b4ed Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Fri, 25 Oct 2019 00:01:22 -0400 Subject: [PATCH 5/6] Optimize prepared statement flow for RBAC --- psql/mutate.go | 2 -- psql/query_test.go | 32 ++++++++++++------------- qcode/parse.go | 2 ++ qcode/qcode.go | 2 ++ serv/allow.go | 2 ++ serv/auth.go | 6 ++--- serv/auth_rails.go | 4 ++-- serv/cmd.go | 2 ++ serv/config.go | 26 ++++++++++++++++++++ serv/core.go | 14 ++++++++--- serv/http.go | 8 ++++--- serv/prepare.go | 9 +++++-- serv/utils.go | 14 +++++++++++ serv/utils_test.go | 59 ++++++++++++++++++++++++++++++++++------------ 14 files changed, 136 insertions(+), 46 deletions(-) diff --git a/psql/mutate.go b/psql/mutate.go index 067270e..84bb122 100644 --- a/psql/mutate.go +++ b/psql/mutate.go @@ -61,8 +61,6 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w io.Writer, vars Variables root.Where = nil root.Args = nil - qc.Type = qcode.QTQuery - return c.compileQuery(qc, w) } diff --git a/psql/query_test.go b/psql/query_test.go index ea65645..c24b439 100644 --- a/psql/query_test.go +++ b/psql/query_test.go @@ -201,7 +201,7 @@ func withComplexArgs(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") < 28) AND (("products"."id") >= 20)) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") < 28) AND (("products"."id") >= 20)) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -229,7 +229,7 @@ func withWhereMultiOr(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") < 20) OR (("products"."price") > 10) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") < 20) OR (("products"."price") > 10) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -255,7 +255,7 @@ func withWhereIsNull(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -281,7 +281,7 @@ func withWhereAndList(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -301,7 +301,7 @@ func fetchByID(t *testing.T) { } }` - sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 15)) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";` + sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 15)) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -321,7 +321,7 @@ func searchQuery(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -344,7 +344,7 @@ func oneToMany(t *testing.T) { } }` - sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email", "users"."id" FROM "users" LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email", "users"."id" FROM "users" LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -367,7 +367,7 @@ func belongsTo(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "sel_json_1" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "sel_json_1" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -390,7 +390,7 @@ func manyToMany(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -413,7 +413,7 @@ func manyToManyReverse(t *testing.T) { } }` - sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -433,7 +433,7 @@ func aggFunction(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -453,7 +453,7 @@ func aggFunctionBlockedByCol(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "anon") if err != nil { @@ -473,7 +473,7 @@ func aggFunctionDisabled(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "anon1") if err != nil { @@ -493,7 +493,7 @@ func aggFunctionWithFilter(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") > 10)) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") > 10)) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -513,7 +513,7 @@ func queryWithVariables(t *testing.T) { } }` - sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") = {{product_price}}) AND (("products"."id") = {{product_id}})) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";` + sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") = {{product_price}}) AND (("products"."id") = {{product_id}})) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { @@ -532,7 +532,7 @@ func syntheticTables(t *testing.T) { } }` - sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT ) AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337";` + sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT ) AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337"` resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { diff --git a/qcode/parse.go b/qcode/parse.go index c07ab45..0fe6c34 100644 --- a/qcode/parse.go +++ b/qcode/parse.go @@ -18,7 +18,9 @@ type parserType int32 const ( maxFields = 100 maxArgs = 10 +) +const ( parserError parserType = iota parserEOF opQuery diff --git a/qcode/qcode.go b/qcode/qcode.go index 30bc724..a1d55e3 100644 --- a/qcode/qcode.go +++ b/qcode/qcode.go @@ -16,7 +16,9 @@ type Action int const ( maxSelectors = 30 +) +const ( QTQuery QType = iota + 1 QTInsert QTUpdate diff --git a/serv/allow.go b/serv/allow.go index 1279170..f960fc0 100644 --- a/serv/allow.go +++ b/serv/allow.go @@ -182,6 +182,8 @@ func (al *allowList) load() { item.vars = varBytes } + //fmt.Println("%%", item.gql, string(item.vars)) + al.list[gqlHash(q, varBytes, "")] = item varBytes = nil diff --git a/serv/auth.go b/serv/auth.go index 77942eb..22ab698 100644 --- a/serv/auth.go +++ b/serv/auth.go @@ -7,9 +7,9 @@ import ( ) var ( - userIDProviderKey = struct{}{} - userIDKey = struct{}{} - userRoleKey = struct{}{} + userIDProviderKey = "user_id_provider" + userIDKey = "user_id" + userRoleKey = "user_role" ) func headerAuth(next http.HandlerFunc) http.HandlerFunc { diff --git a/serv/auth_rails.go b/serv/auth_rails.go index cd0b327..7f78da0 100644 --- a/serv/auth_rails.go +++ b/serv/auth_rails.go @@ -122,14 +122,14 @@ func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { ck, err := r.Cookie(cookie) if err != nil { - logger.Warn().Err(err).Send() + logger.Warn().Err(err).Msg("rails cookie missing") next.ServeHTTP(w, r) return } userID, err := ra.ParseCookie(ck.Value) if err != nil { - logger.Warn().Err(err).Send() + logger.Warn().Err(err).Msg("failed to parse rails cookie") next.ServeHTTP(w, r) return } diff --git a/serv/cmd.go b/serv/cmd.go index 12b3ce7..6fbc09f 100644 --- a/serv/cmd.go +++ b/serv/cmd.go @@ -210,6 +210,8 @@ func initConf() (*config, error) { c.Roles = append(c.Roles, configRole{Name: "anon"}) } + c.Validate() + return c, nil } diff --git a/serv/config.go b/serv/config.go index 8420c66..39e3d53 100644 --- a/serv/config.go +++ b/serv/config.go @@ -168,6 +168,32 @@ func newConfig() *viper.Viper { return vi } +func (c *config) Validate() { + rm := make(map[string]struct{}) + + for i := range c.Roles { + name := strings.ToLower(c.Roles[i].Name) + if _, ok := rm[name]; ok { + logger.Fatal().Msgf("duplicate config for role '%s'", c.Roles[i].Name) + } + rm[name] = struct{}{} + } + + tm := make(map[string]struct{}) + + for i := range c.Tables { + name := strings.ToLower(c.Tables[i].Name) + if _, ok := tm[name]; ok { + logger.Fatal().Msgf("duplicate config for table '%s'", c.Tables[i].Name) + } + tm[name] = struct{}{} + } + + if len(c.RolesQuery) == 0 { + logger.Warn().Msgf("no 'roles_query' defined.") + } +} + func (c *config) getAliasMap() map[string][]string { m := make(map[string][]string, len(c.Tables)) diff --git a/serv/core.go b/serv/core.go index 8a007ac..dc5d42b 100644 --- a/serv/core.go +++ b/serv/core.go @@ -131,16 +131,20 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) { } var role string - useRoleQuery := len(conf.RolesQuery) != 0 && isMutation(c.req.Query) + mutation := isMutation(c.req.Query) + useRoleQuery := len(conf.RolesQuery) != 0 && mutation if useRoleQuery { if role, err = c.executeRoleQuery(tx); err != nil { return nil, nil, err } + } else if v := c.Value(userRoleKey); v != nil { role = v.(string) - } else { + + } else if mutation { role = c.req.role + } ps, ok := _preparedList[gqlHash(c.req.Query, c.req.Vars, role)] @@ -151,7 +155,11 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) { var root []byte vars := varList(c, ps.args) - err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root) + if mutation { + err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root) + } else { + err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&c.req.role, &root) + } if err != nil { return nil, nil, err } diff --git a/serv/http.go b/serv/http.go index ae8ff84..737006d 100644 --- a/serv/http.go +++ b/serv/http.go @@ -37,8 +37,8 @@ type gqlReq struct { type variables map[string]json.RawMessage type gqlResp struct { - Error string `json:"error,omitempty"` - Data json.RawMessage `json:"data"` + Error string `json:"message,omitempty"` + Data json.RawMessage `json:"data,omitempty"` Extensions *extensions `json:"extensions,omitempty"` } @@ -102,7 +102,9 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) { err = ctx.handleReq(w, r) if err == errUnauthorized { - http.Error(w, "Not authorized", 401) + w.WriteHeader(http.StatusUnauthorized) + json.NewEncoder(w).Encode(gqlResp{Error: err.Error()}) + return } if err != nil { diff --git a/serv/prepare.go b/serv/prepare.go index 8415397..2329578 100644 --- a/serv/prepare.go +++ b/serv/prepare.go @@ -31,6 +31,7 @@ func initPreparedList() { } for _, v := range _allowList.list { + err := prepareStmt(v.gql, v.vars) if err != nil { logger.Warn().Str("gql", v.gql).Err(err).Send() @@ -52,6 +53,10 @@ func prepareStmt(gql string, varBytes json.RawMessage) error { return err } + if len(stmts) != 0 && stmts[0].qc.Type == qcode.QTQuery { + c.req.Vars = nil + } + for _, s := range stmts { if len(s.sql) == 0 { continue @@ -75,9 +80,9 @@ func prepareStmt(gql string, varBytes json.RawMessage) error { var key string if s.role == nil { - key = gqlHash(gql, varBytes, "") + key = gqlHash(gql, c.req.Vars, "") } else { - key = gqlHash(gql, varBytes, s.role.Name) + key = gqlHash(gql, c.req.Vars, s.role.Name) } _preparedList[key] = &preparedItem{ diff --git a/serv/utils.go b/serv/utils.go index 9e095e8..b59dded 100644 --- a/serv/utils.go +++ b/serv/utils.go @@ -24,13 +24,26 @@ func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 { func gqlHash(b string, vars []byte, role string) string { b = strings.TrimSpace(b) h := sha1.New() + query := "query" s, e := 0, 0 space := []byte{' '} + starting := true var b0, b1 byte for { + if starting && b[e] == 'q' { + n := 0 + se := e + for e < len(b) && n < len(query) && b[e] == query[n] { + n++ + e++ + } + if n != len(query) { + io.WriteString(h, strings.ToLower(b[se:e])) + } + } if ws(b[e]) { for e < len(b) && ws(b[e]) { e++ @@ -42,6 +55,7 @@ func gqlHash(b string, vars []byte, role string) string { h.Write(space) } } else { + starting = false s = e for e < len(b) && ws(b[e]) == false { e++ diff --git a/serv/utils_test.go b/serv/utils_test.go index 17d91b7..b8babeb 100644 --- a/serv/utils_test.go +++ b/serv/utils_test.go @@ -5,7 +5,7 @@ import ( "testing" ) -func TestRelaxHash1(t *testing.T) { +func TestGQLHash1(t *testing.T) { var v1 = ` products( limit: 30, @@ -24,15 +24,15 @@ func TestRelaxHash1(t *testing.T) { price } ` - h1 := gqlHash(v1, nil) - h2 := gqlHash(v2, nil) + h1 := gqlHash(v1, nil, "") + h2 := gqlHash(v2, nil, "") if strings.Compare(h1, h2) != 0 { t.Fatal("Hashes don't match they should") } } -func TestRelaxHash2(t *testing.T) { +func TestGQLHash2(t *testing.T) { var v1 = ` { products( @@ -53,15 +53,15 @@ func TestRelaxHash2(t *testing.T) { var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } ` - h1 := gqlHash(v1, nil) - h2 := gqlHash(v2, nil) + h1 := gqlHash(v1, nil, "") + h2 := gqlHash(v2, nil, "") if strings.Compare(h1, h2) != 0 { t.Fatal("Hashes don't match they should") } } -func TestRelaxHash3(t *testing.T) { +func TestGQLHash3(t *testing.T) { var v1 = `users { id email @@ -86,15 +86,44 @@ func TestRelaxHash3(t *testing.T) { } ` - h1 := gqlHash(v1, nil) - h2 := gqlHash(v2, nil) + h1 := gqlHash(v1, nil, "") + h2 := gqlHash(v2, nil, "") if strings.Compare(h1, h2) != 0 { t.Fatal("Hashes don't match they should") } } -func TestRelaxHashWithVars1(t *testing.T) { +func TestGQLHash4(t *testing.T) { + var v1 = ` + query { + products( + limit: 30 + order_by: { price: desc } + distinct: [price] + where: { id: { and: { greater_or_equals: 20, lt: 28 } } } + ) { + id + name + price + user { + id + email + } + } + }` + + var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } ` + + h1 := gqlHash(v1, nil, "") + h2 := gqlHash(v2, nil, "") + + if strings.Compare(h1, h2) != 0 { + t.Fatal("Hashes don't match they should") + } +} + +func TestGQLHashWithVars1(t *testing.T) { var q1 = ` products( limit: 30, @@ -136,15 +165,15 @@ func TestRelaxHashWithVars1(t *testing.T) { "user": 123 }` - h1 := gqlHash(q1, []byte(v1)) - h2 := gqlHash(q2, []byte(v2)) + h1 := gqlHash(q1, []byte(v1), "user") + h2 := gqlHash(q2, []byte(v2), "user") if strings.Compare(h1, h2) != 0 { t.Fatal("Hashes don't match they should") } } -func TestRelaxHashWithVars2(t *testing.T) { +func TestGQLHashWithVars2(t *testing.T) { var q1 = ` products( limit: 30, @@ -193,8 +222,8 @@ func TestRelaxHashWithVars2(t *testing.T) { "user": 123 }` - h1 := gqlHash(q1, []byte(v1)) - h2 := gqlHash(q2, []byte(v2)) + h1 := gqlHash(q1, []byte(v1), "user") + h2 := gqlHash(q2, []byte(v2), "user") if strings.Compare(h1, h2) != 0 { t.Fatal("Hashes don't match they should") From cabd2d81ae2024010d1a87b94d474d95f942d7bf Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Fri, 25 Oct 2019 01:39:59 -0400 Subject: [PATCH 6/6] Preserve allow.list ordering on save --- .gitignore | 1 + .wtc.yaml | 13 ++++ Dockerfile | 5 +- README.md | 3 + config/dev.yml | 29 +++----- config/prod.yml | 80 ++++++++++++++++------ docker-compose.yml | 2 +- docs/guide.md | 153 ++++++++++++++++++++++++++++-------------- psql/mutate_test.go | 14 ++-- psql/query_test.go | 8 +-- qcode/config.go | 8 +-- qcode/qcode.go | 8 +-- serv/allow.go | 37 +++++----- serv/config.go | 10 +-- serv/serv.go | 8 +-- slides/overview.slide | 8 +-- tmpl/dev.yml | 118 +++++++++++++++++++++----------- tmpl/prod.yml | 59 ++++++---------- 18 files changed, 341 insertions(+), 223 deletions(-) create mode 100644 .wtc.yaml diff --git a/.gitignore b/.gitignore index 5c88189..fdcd5db 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,4 @@ main .DS_Store .swp main +super-graph diff --git a/.wtc.yaml b/.wtc.yaml new file mode 100644 index 0000000..215573f --- /dev/null +++ b/.wtc.yaml @@ -0,0 +1,13 @@ +no_trace: false +debounce: 300 # if rule has no debounce, this will be used instead +ignore: \.git/ +trig: [start, run] # will run on start +rules: + - name: start + - name: run + match: \.go$ + ignore: web|examples|docs|_test\.go$ + command: go run main.go serv + - name: test + match: _test\.go$ + command: go test -cover {PKG} \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index c2487cc..805cc0d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,9 +11,8 @@ RUN apk update && \ apk add --no-cache git && \ apk add --no-cache upx=3.95-r2 -RUN go get -u github.com/shanzi/wu && \ - go install github.com/shanzi/wu && \ - go get github.com/GeertJohan/go.rice/rice +RUN go get -u github.com/rafaelsq/wtc && \ + go get -u github.com/GeertJohan/go.rice/rice WORKDIR /app COPY . /app diff --git a/README.md b/README.md index ec614f4..e6fbee3 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,9 @@ This compiler is what sits at the heart of Super Graph with layers of useful fun ## Contact me +I'm happy to help you deploy Super Graph so feel free to reach out over +Twitter or Discord. + [twitter/dosco](https://twitter.com/dosco) [chat/super-graph](https://discord.gg/6pSWCTZ) diff --git a/config/dev.yml b/config/dev.yml index ade488b..ffe9f64 100644 --- a/config/dev.yml +++ b/config/dev.yml @@ -100,7 +100,7 @@ database: # Define defaults to for the field key and values below defaults: - # filter: ["{ user_id: { eq: $user_id } }"] + # filters: ["{ user_id: { eq: $user_id } }"] # Field and table names that you wish to block blocklist: @@ -112,17 +112,6 @@ database: - token tables: - - name: users - # This filter will overwrite defaults.filter - # filter: ["{ id: { eq: $user_id } }"] - - # - name: products - # # Multiple filters are AND'd together - # filter: [ - # "{ price: { gt: 0 } }", - # "{ price: { lt: 8 } }" - # ] - - name: customers remotes: - name: payments @@ -168,24 +157,23 @@ roles: tables: - name: users query: - filter: ["{ id: { _eq: $user_id } }"] + filters: ["{ id: { _eq: $user_id } }"] - name: products - query: limit: 50 - filter: ["{ user_id: { eq: $user_id } }"] + filters: ["{ user_id: { eq: $user_id } }"] columns: ["id", "name", "description" ] disable_aggregation: false insert: - filter: ["{ user_id: { eq: $user_id } }"] + filters: ["{ user_id: { eq: $user_id } }"] columns: ["id", "name", "description" ] set: - - created_at: "now" + - created_at: "now" update: - filter: ["{ user_id: { eq: $user_id } }"] + filters: ["{ user_id: { eq: $user_id } }"] columns: - id - name @@ -199,6 +187,5 @@ roles: match: id = 1 tables: - name: users - - # select: - # filter: ["{ account_id: { _eq: $account_id } }"] + # query: + # filters: ["{ account_id: { _eq: $account_id } }"] diff --git a/config/prod.yml b/config/prod.yml index a52af3d..95abfb7 100644 --- a/config/prod.yml +++ b/config/prod.yml @@ -90,7 +90,7 @@ database: # Define defaults to for the field key and values below defaults: - filter: ["{ user_id: { eq: $user_id } }"] + filters: ["{ user_id: { eq: $user_id } }"] # Field and table names that you wish to block blocklist: @@ -102,25 +102,7 @@ database: - token tables: - - name: users - # This filter will overwrite defaults.filter - # filter: ["{ id: { eq: $user_id } }"] - # filter_query: ["{ id: { eq: $user_id } }"] - filter_update: ["{ id: { eq: $user_id } }"] - filter_delete: ["{ id: { eq: $user_id } }"] - - - name: products - # Multiple filters are AND'd together - filter: [ - "{ price: { gt: 0 } }", - "{ price: { lt: 8 } }" - ] - - name: customers - # No filter is used for this field not - # even defaults.filter - filter: none - # remotes: # - name: payments # id: stripe_id @@ -137,7 +119,61 @@ tables: # real db table backing them name: me table: users - filter: ["{ id: { eq: $user_id } }"] - # - name: posts - # filter: ["{ account_id: { _eq: $account_id } }"] \ No newline at end of file +roles_query: "SELECT * FROM users as usr WHERE id = $user_id" + +roles: + - name: anon + tables: + - name: products + limit: 10 + + query: + columns: ["id", "name", "description" ] + aggregation: false + + insert: + allow: false + + update: + allow: false + + delete: + allow: false + + - name: user + tables: + - name: users + query: + filters: ["{ id: { _eq: $user_id } }"] + + - name: products + query: + limit: 50 + filters: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + disable_aggregation: false + + insert: + filters: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + set: + - created_at: "now" + + update: + filters: ["{ user_id: { eq: $user_id } }"] + columns: + - id + - name + set: + - updated_at: "now" + + delete: + deny: true + + - name: admin + match: id = 1 + tables: + - name: users + # query: + # filters: ["{ account_id: { _eq: $account_id } }"] diff --git a/docker-compose.yml b/docker-compose.yml index d41e9f5..b3beb6e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -34,7 +34,7 @@ services: volumes: - .:/app working_dir: /app - command: wu -pattern="*.go" go run main.go serv + command: wtc depends_on: - db - rails_app diff --git a/docs/guide.md b/docs/guide.md index 13d6338..55f0b43 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -1043,26 +1043,35 @@ We're tried to ensure that the config file is self documenting and easy to work app_name: "Super Graph Development" host_port: 0.0.0.0:8080 web_ui: true -debug_level: 1 -# debug, info, warn, error, fatal, panic, disable -log_level: "info" +# debug, info, warn, error, fatal, panic +log_level: "debug" # Disable this in development to get a list of # queries used. When enabled super graph # will only allow queries from this list # List saved to ./config/allow.list -use_allow_list: true +use_allow_list: false # Throw a 401 on auth failure for queries that need auth # valid values: always, per_query, never -auth_fail_block: always +auth_fail_block: never # Latency tracing for database queries and remote joins # the resulting latency information is returned with the # response enable_tracing: true +# Watch the config folder and reload Super Graph +# with the new configs when a change is detected +reload_on_config_change: true + +# File that points to the database seeding script +# seed_file: seed.js + +# Path pointing to where the migrations can be found +migrations_path: ./config/migrations + # Postgres related environment Variables # SG_DATABASE_HOST # SG_DATABASE_PORT @@ -1086,7 +1095,7 @@ auth: # Comment this out if you want to disable setting # the user_id via a header. Good for testing - header: X-User-ID + creds_in_header: true rails: # Rails version this is used for reading the @@ -1097,10 +1106,10 @@ auth: secret_key_base: 0a248500a64c01184edb4d7ad3a805488f8097ac761b76aaa6c17c01dcb7af03a2f18ba61b2868134b9c7b79a122bc0dadff4367414a2d173297bfea92be5566 # Remote cookie store. (memcache or redis) - # url: redis://127.0.0.1:6379 - # password: test - # max_idle: 80, - # max_active: 12000, + # url: redis://redis:6379 + # password: "" + # max_idle: 80 + # max_active: 12000 # In most cases you don't need these # salt: "encrypted cookie" @@ -1120,20 +1129,23 @@ database: dbname: app_development user: postgres password: '' - # pool_size: 10 - # max_retries: 0 - # log_level: "debug" + + #schema: "public" + #pool_size: 10 + #max_retries: 0 + #log_level: "debug" # Define variables here that you want to use in filters + # sub-queries must be wrapped in () variables: - account_id: "select account_id from users where id = $user_id" + account_id: "(select account_id from users where id = $user_id)" # Define defaults to for the field key and values below defaults: - filter: ["{ user_id: { eq: $user_id } }"] + # filters: ["{ user_id: { eq: $user_id } }"] # Field and table names that you wish to block - blacklist: + blocklist: - ar_internal_metadata - schema_migrations - secret @@ -1141,46 +1153,85 @@ database: - encrypted - token - tables: - - name: users - # This filter will overwrite defaults.filter - # filter: ["{ id: { eq: $user_id } }"] - # filter_query: ["{ id: { eq: $user_id } }"] - filter_update: ["{ id: { eq: $user_id } }"] - filter_delete: ["{ id: { eq: $user_id } }"] +tables: + - name: customers + remotes: + - name: payments + id: stripe_id + url: http://rails_app:3000/stripe/$id + path: data + # debug: true + pass_headers: + - cookie + set_headers: + - name: Host + value: 0.0.0.0 + # - name: Authorization + # value: Bearer - - name: products - # Multiple filters are AND'd together - filter: [ - "{ price: { gt: 0 } }", - "{ price: { lt: 8 } }" - ] + - # You can create new fields that have a + # real db table backing them + name: me + table: users - - name: customers - # No filter is used for this field not - # even defaults.filter - filter: none +roles_query: "SELECT * FROM users as usr WHERE id = $user_id" - remotes: - - name: payments - id: stripe_id - url: http://rails_app:3000/stripe/$id - path: data - # pass_headers: - # - cookie - # - host - set_headers: - - name: Authorization - value: Bearer +roles: + - name: anon + tables: + - name: products + limit: 10 - - # You can create new fields that have a - # real db table backing them - name: me - table: users - filter: ["{ id: { eq: $user_id } }"] + query: + columns: ["id", "name", "description" ] + aggregation: false + + insert: + allow: false + + update: + allow: false + + delete: + allow: false + + - name: user + tables: + - name: users + query: + filters: ["{ id: { _eq: $user_id } }"] + + - name: products + query: + limit: 50 + filters: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + disable_aggregation: false + + insert: + filters: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + set: + - created_at: "now" + + update: + filters: ["{ user_id: { eq: $user_id } }"] + columns: + - id + - name + set: + - updated_at: "now" + + delete: + deny: true + + - name: admin + match: id = 1 + tables: + - name: users + # query: + # filters: ["{ account_id: { _eq: $account_id } }"] - # - name: posts - # filter: ["{ account_id: { _eq: $account_id } }"] ``` If deploying into environments like Kubernetes it's useful to be able to configure things like secrets and hosts though environment variables therfore we expose the below environment variables. This is escpecially useful for secrets since they are usually injected in via a secrets management framework ie. Kubernetes Secrets diff --git a/psql/mutate_test.go b/psql/mutate_test.go index 8555f74..390c301 100644 --- a/psql/mutate_test.go +++ b/psql/mutate_test.go @@ -12,7 +12,7 @@ func simpleInsert(t *testing.T) { } }` - sql := `WITH "users" AS (WITH "input" AS (SELECT {{data}}::json AS j) INSERT INTO "users" (full_name, email) SELECT full_name, email FROM input i, json_populate_record(NULL::users, i.j) t RETURNING *) SELECT json_object_agg('user', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."id" AS "id") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."id" FROM "users") AS "users_0") AS "done_1337";` + sql := `WITH "users" AS (WITH "input" AS (SELECT {{data}}::json AS j) INSERT INTO "users" (full_name, email) SELECT full_name, email FROM input i, json_populate_record(NULL::users, i.j) t RETURNING *) SELECT json_object_agg('user', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."id" AS "id") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."id" FROM "users") AS "users_0") AS "done_1337"` vars := map[string]json.RawMessage{ "data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`), @@ -36,7 +36,7 @@ func singleInsert(t *testing.T) { } }` - sql := `WITH "products" AS (WITH "input" AS (SELECT {{insert}}::json AS j) INSERT INTO "products" (name, description, user_id) SELECT name, description, user_id FROM input i, json_populate_record(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";` + sql := `WITH "products" AS (WITH "input" AS (SELECT {{insert}}::json AS j) INSERT INTO "products" (name, description, user_id) SELECT name, description, user_id FROM input i, json_populate_record(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"` vars := map[string]json.RawMessage{ "insert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc", "user_id": 5 }`), @@ -60,7 +60,7 @@ func bulkInsert(t *testing.T) { } }` - sql := `WITH "products" AS (WITH "input" AS (SELECT {{insert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_recordset(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";` + sql := `WITH "products" AS (WITH "input" AS (SELECT {{insert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_recordset(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"` vars := map[string]json.RawMessage{ "insert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`), @@ -84,7 +84,7 @@ func singleUpsert(t *testing.T) { } }` - sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";` + sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"` vars := map[string]json.RawMessage{ "upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), @@ -108,7 +108,7 @@ func bulkUpsert(t *testing.T) { } }` - sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_recordset(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";` + sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_recordset(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"` vars := map[string]json.RawMessage{ "upsert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`), @@ -132,7 +132,7 @@ func singleUpdate(t *testing.T) { } }` - sql := `WITH "products" AS (WITH "input" AS (SELECT {{update}}::json AS j) UPDATE "products" SET (name, description) = (SELECT name, description FROM input i, json_populate_record(NULL::products, i.j) t) WHERE (("products"."user_id") = {{user_id}}) AND (("products"."id") = 1) AND (("products"."id") = 15) RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";` + sql := `WITH "products" AS (WITH "input" AS (SELECT {{update}}::json AS j) UPDATE "products" SET (name, description) = (SELECT name, description FROM input i, json_populate_record(NULL::products, i.j) t) WHERE (("products"."user_id") = {{user_id}}) AND (("products"."id") = 1) AND (("products"."id") = 15) RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"` vars := map[string]json.RawMessage{ "update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), @@ -156,7 +156,7 @@ func delete(t *testing.T) { } }` - sql := `WITH "products" AS (DELETE FROM "products" WHERE (("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 1) RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";` + sql := `WITH "products" AS (DELETE FROM "products" WHERE (("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 1) RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"` vars := map[string]json.RawMessage{ "update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), diff --git a/psql/query_test.go b/psql/query_test.go index c24b439..78330b6 100644 --- a/psql/query_test.go +++ b/psql/query_test.go @@ -32,16 +32,16 @@ func TestMain(m *testing.M) { qcompile.AddRole("user", "product", qcode.TRConfig{ Query: qcode.QueryConfig{ Columns: []string{"id", "name", "price", "users", "customers"}, - Filter: []string{ + Filters: []string{ "{ price: { gt: 0 } }", "{ price: { lt: 8 } }", }, }, Update: qcode.UpdateConfig{ - Filter: []string{"{ user_id: { eq: $user_id } }"}, + Filters: []string{"{ user_id: { eq: $user_id } }"}, }, Delete: qcode.DeleteConfig{ - Filter: []string{ + Filters: []string{ "{ price: { gt: 0 } }", "{ price: { lt: 8 } }", }, @@ -70,7 +70,7 @@ func TestMain(m *testing.M) { qcompile.AddRole("user", "mes", qcode.TRConfig{ Query: qcode.QueryConfig{ Columns: []string{"id", "full_name", "avatar"}, - Filter: []string{ + Filters: []string{ "{ id: { eq: $user_id } }", }, }, diff --git a/qcode/config.go b/qcode/config.go index b2ab9c4..c68a3d1 100644 --- a/qcode/config.go +++ b/qcode/config.go @@ -7,25 +7,25 @@ type Config struct { type QueryConfig struct { Limit int - Filter []string + Filters []string Columns []string DisableFunctions bool } type InsertConfig struct { - Filter []string + Filters []string Columns []string Set map[string]string } type UpdateConfig struct { - Filter []string + Filters []string Columns []string Set map[string]string } type DeleteConfig struct { - Filter []string + Filters []string Columns []string } diff --git a/qcode/qcode.go b/qcode/qcode.go index a1d55e3..8c90d93 100644 --- a/qcode/qcode.go +++ b/qcode/qcode.go @@ -190,7 +190,7 @@ func (com *Compiler) AddRole(role, table string, trc TRConfig) error { } // query config - trv.query.fil, err = compileFilter(trc.Query.Filter) + trv.query.fil, err = compileFilter(trc.Query.Filters) if err != nil { return err } @@ -201,20 +201,20 @@ func (com *Compiler) AddRole(role, table string, trc TRConfig) error { trv.query.disable.funcs = trc.Query.DisableFunctions // insert config - if trv.insert.fil, err = compileFilter(trc.Insert.Filter); err != nil { + if trv.insert.fil, err = compileFilter(trc.Insert.Filters); err != nil { return err } trv.insert.cols = toMap(trc.Insert.Columns) // update config - if trv.update.fil, err = compileFilter(trc.Update.Filter); err != nil { + if trv.update.fil, err = compileFilter(trc.Update.Filters); err != nil { return err } trv.insert.cols = toMap(trc.Insert.Columns) trv.insert.set = trc.Insert.Set // delete config - if trv.delete.fil, err = compileFilter(trc.Delete.Filter); err != nil { + if trv.delete.fil, err = compileFilter(trc.Delete.Filters); err != nil { return err } trv.delete.cols = toMap(trc.Delete.Columns) diff --git a/serv/allow.go b/serv/allow.go index f960fc0..f8f02ad 100644 --- a/serv/allow.go +++ b/serv/allow.go @@ -26,7 +26,8 @@ type allowItem struct { var _allowList allowList type allowList struct { - list map[string]*allowItem + list []*allowItem + index map[string]int filepath string saveChan chan *allowItem active bool @@ -34,7 +35,7 @@ type allowList struct { func initAllowList(cpath string) { _allowList = allowList{ - list: make(map[string]*allowItem), + index: make(map[string]int), saveChan: make(chan *allowItem), active: true, } @@ -172,19 +173,21 @@ func (al *allowList) load() { if c == 0 { if ty == AL_QUERY { q := string(b[s:(e + 1)]) + key := gqlHash(q, varBytes, "") - item := &allowItem{ - uri: uri, - gql: q, - } - - if len(varBytes) != 0 { + if idx, ok := al.index[key]; !ok { + al.list = append(al.list, &allowItem{ + uri: uri, + gql: q, + vars: varBytes, + }) + al.index[key] = len(al.list) - 1 + } else { + item := al.list[idx] + item.gql = q item.vars = varBytes } - //fmt.Println("%%", item.gql, string(item.vars)) - - al.list[gqlHash(q, varBytes, "")] = item varBytes = nil } else if ty == AL_VARS { @@ -205,11 +208,15 @@ func (al *allowList) save(item *allowItem) { if al.active == false { return } - h := gqlHash(item.gql, item.vars, "") - if _, ok := al.list[h]; ok { - return + + key := gqlHash(item.gql, item.vars, "") + + if idx, ok := al.index[key]; ok { + al.list[idx] = item + } else { + al.list = append(al.list, item) + al.index[key] = len(al.list) - 1 } - al.list[gqlHash(item.gql, item.vars, "")] = item f, err := os.Create(al.filepath) if err != nil { diff --git a/serv/config.go b/serv/config.go index 39e3d53..160cd4b 100644 --- a/serv/config.go +++ b/serv/config.go @@ -65,7 +65,7 @@ type config struct { Vars map[string]string `mapstructure:"variables"` Defaults struct { - Filter []string + Filters []string Blocklist []string } @@ -106,28 +106,28 @@ type configRole struct { Query struct { Limit int - Filter []string + Filters []string Columns []string DisableAggregation bool `mapstructure:"disable_aggregation"` Deny bool } Insert struct { - Filter []string + Filters []string Columns []string Set map[string]string Deny bool } Update struct { - Filter []string + Filters []string Columns []string Set map[string]string Deny bool } Delete struct { - Filter []string + Filters []string Columns []string Deny bool } diff --git a/serv/serv.go b/serv/serv.go index e03b9c2..a98c16e 100644 --- a/serv/serv.go +++ b/serv/serv.go @@ -34,25 +34,25 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) { for _, t := range r.Tables { query := qcode.QueryConfig{ Limit: t.Query.Limit, - Filter: t.Query.Filter, + Filters: t.Query.Filters, Columns: t.Query.Columns, DisableFunctions: t.Query.DisableAggregation, } insert := qcode.InsertConfig{ - Filter: t.Insert.Filter, + Filters: t.Insert.Filters, Columns: t.Insert.Columns, Set: t.Insert.Set, } update := qcode.UpdateConfig{ - Filter: t.Insert.Filter, + Filters: t.Insert.Filters, Columns: t.Insert.Columns, Set: t.Insert.Set, } delete := qcode.DeleteConfig{ - Filter: t.Insert.Filter, + Filters: t.Insert.Filters, Columns: t.Insert.Columns, } diff --git a/slides/overview.slide b/slides/overview.slide index 7888781..e52ff40 100644 --- a/slides/overview.slide +++ b/slides/overview.slide @@ -80,7 +80,7 @@ SQL Output account_id: "select account_id from users where id = $user_id" defaults: - filter: ["{ user_id: { eq: $user_id } }"] + Filters: ["{ user_id: { eq: $user_id } }"] blacklist: - password @@ -88,14 +88,14 @@ SQL Output fields: - name: users - filter: ["{ id: { eq: $user_id } }"] + Filters: ["{ id: { eq: $user_id } }"] - name: products - filter: [ + Filters: [ "{ price: { gt: 0 } }", "{ price: { lt: 8 } }" ] - name: me table: users - filter: ["{ id: { eq: $user_id } }"] + Filters: ["{ id: { eq: $user_id } }"] diff --git a/tmpl/dev.yml b/tmpl/dev.yml index b53a4d5..ffe9f64 100644 --- a/tmpl/dev.yml +++ b/tmpl/dev.yml @@ -1,4 +1,4 @@ -app_name: "{% app_name %} Development" +app_name: "Super Graph Development" host_port: 0.0.0.0:8080 web_ui: true @@ -53,7 +53,7 @@ auth: # Comment this out if you want to disable setting # the user_id via a header. Good for testing - header: X-User-ID + creds_in_header: true rails: # Rails version this is used for reading the @@ -84,7 +84,7 @@ database: type: postgres host: db port: 5432 - dbname: {% app_name_slug %}_development + dbname: app_development user: postgres password: '' @@ -100,7 +100,7 @@ database: # Define defaults to for the field key and values below defaults: - # filter: ["{ user_id: { eq: $user_id } }"] + # filters: ["{ user_id: { eq: $user_id } }"] # Field and table names that you wish to block blocklist: @@ -111,45 +111,81 @@ database: - encrypted - token - tables: - - name: users - # This filter will overwrite defaults.filter - # filter: ["{ id: { eq: $user_id } }"] - # filter_query: ["{ id: { eq: $user_id } }"] - filter_update: ["{ id: { eq: $user_id } }"] - filter_delete: ["{ id: { eq: $user_id } }"] +tables: + - name: customers + remotes: + - name: payments + id: stripe_id + url: http://rails_app:3000/stripe/$id + path: data + # debug: true + pass_headers: + - cookie + set_headers: + - name: Host + value: 0.0.0.0 + # - name: Authorization + # value: Bearer - # - name: products - # # Multiple filters are AND'd together - # filter: [ - # "{ price: { gt: 0 } }", - # "{ price: { lt: 8 } }" - # ] + - # You can create new fields that have a + # real db table backing them + name: me + table: users - - name: customers - # No filter is used for this field not - # even defaults.filter - filter: none +roles_query: "SELECT * FROM users as usr WHERE id = $user_id" - remotes: - - name: payments - id: stripe_id - url: http://rails_app:3000/stripe/$id - path: data - # debug: true - pass_headers: - - cookie - set_headers: - - name: Host - value: 0.0.0.0 - # - name: Authorization - # value: Bearer +roles: + - name: anon + tables: + - name: products + limit: 10 - - # You can create new fields that have a - # real db table backing them - name: me - table: users - filter: ["{ id: { eq: $user_id } }"] + query: + columns: ["id", "name", "description" ] + aggregation: false - # - name: posts - # filter: ["{ account_id: { _eq: $account_id } }"] \ No newline at end of file + insert: + allow: false + + update: + allow: false + + delete: + allow: false + + - name: user + tables: + - name: users + query: + filters: ["{ id: { _eq: $user_id } }"] + + - name: products + query: + limit: 50 + filters: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + disable_aggregation: false + + insert: + filters: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + set: + - created_at: "now" + + update: + filters: ["{ user_id: { eq: $user_id } }"] + columns: + - id + - name + set: + - updated_at: "now" + + delete: + deny: true + + - name: admin + match: id = 1 + tables: + - name: users + # query: + # filters: ["{ account_id: { _eq: $account_id } }"] diff --git a/tmpl/prod.yml b/tmpl/prod.yml index 29c6b45..95abfb7 100644 --- a/tmpl/prod.yml +++ b/tmpl/prod.yml @@ -1,4 +1,4 @@ -app_name: "{% app_name %} Production" +app_name: "Super Graph Production" host_port: 0.0.0.0:8080 web_ui: false @@ -76,7 +76,7 @@ database: type: postgres host: db port: 5432 - dbname: {% app_name_slug %}_production + dbname: {{app_name_slug}}_development user: postgres password: '' #pool_size: 10 @@ -90,7 +90,7 @@ database: # Define defaults to for the field key and values below defaults: - filter: ["{ user_id: { eq: $user_id } }"] + filters: ["{ user_id: { eq: $user_id } }"] # Field and table names that you wish to block blocklist: @@ -101,32 +101,19 @@ database: - encrypted - token - tables: - - name: users - # This filter will overwrite defaults.filter - # filter: ["{ id: { eq: $user_id } }"] - - # - name: products - # # Multiple filters are AND'd together - # filter: [ - # "{ price: { gt: 0 } }", - # "{ price: { lt: 8 } }" - # ] - +tables: - name: customers - remotes: - - name: payments - id: stripe_id - url: http://rails_app:3000/stripe/$id - path: data - # debug: true - pass_headers: - - cookie - set_headers: - - name: Host - value: 0.0.0.0 - # - name: Authorization - # value: Bearer + # remotes: + # - name: payments + # id: stripe_id + # url: http://rails_app:3000/stripe/$id + # path: data + # # pass_headers: + # # - cookie + # # - host + # set_headers: + # - name: Authorization + # value: Bearer - # You can create new fields that have a # real db table backing them @@ -158,24 +145,23 @@ roles: tables: - name: users query: - filter: ["{ id: { _eq: $user_id } }"] + filters: ["{ id: { _eq: $user_id } }"] - name: products - query: limit: 50 - filter: ["{ user_id: { eq: $user_id } }"] + filters: ["{ user_id: { eq: $user_id } }"] columns: ["id", "name", "description" ] disable_aggregation: false insert: - filter: ["{ user_id: { eq: $user_id } }"] + filters: ["{ user_id: { eq: $user_id } }"] columns: ["id", "name", "description" ] set: - - created_at: "now" + - created_at: "now" update: - filter: ["{ user_id: { eq: $user_id } }"] + filters: ["{ user_id: { eq: $user_id } }"] columns: - id - name @@ -189,6 +175,5 @@ roles: match: id = 1 tables: - name: users - - # select: - # filter: ["{ account_id: { _eq: $account_id } }"] + # query: + # filters: ["{ account_id: { _eq: $account_id } }"]