diff --git a/config/dev.yml b/config/dev.yml index 7c3ba46..a314dcd 100644 --- a/config/dev.yml +++ b/config/dev.yml @@ -115,9 +115,6 @@ tables: - name: users # This filter will overwrite defaults.filter # filter: ["{ id: { eq: $user_id } }"] - # filter_query: ["{ id: { eq: $user_id } }"] - filter_update: ["{ id: { eq: $user_id } }"] - filter_delete: ["{ id: { eq: $user_id } }"] # - name: products # # Multiple filters are AND'd together @@ -127,10 +124,6 @@ tables: # ] - name: customers - # No filter is used for this field not - # even defaults.filter - filter: none - remotes: - name: payments id: stripe_id @@ -149,7 +142,56 @@ tables: # real db table backing them name: me table: users - filter: ["{ id: { eq: $user_id } }"] - # - name: posts - # filter: ["{ account_id: { _eq: $account_id } }"] \ No newline at end of file +roles: + - name: anon + tables: + - name: products + limit: 10 + + query: + columns: ["id", "name", "description" ] + aggregation: false + + insert: + allow: false + + update: + allow: false + + delete: + allow: false + + - name: user + tables: + - name: products + + query: + limit: 50 + filter: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + disable_aggregation: false + + insert: + filter: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + set: + - created_at: "now" + + update: + filter: ["{ user_id: { eq: $user_id } }"] + columns: + - id + - name + set: + - updated_at: "now" + + delete: + deny: true + + - name: manager + tables: + - name: users + + select: + filter: ["{ account_id: { _eq: $account_id } }"] diff --git a/psql/insert.go b/psql/mutate.go similarity index 87% rename from psql/insert.go rename to psql/mutate.go index 027588e..63b3a5a 100644 --- a/psql/insert.go +++ b/psql/mutate.go @@ -10,7 +10,7 @@ import ( "github.com/dosco/super-graph/qcode" ) -var zeroPaging = qcode.Paging{} +var noLimit = qcode.Paging{NoLimit: true} func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) { if len(qc.Selects) == 0 { @@ -29,23 +29,23 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia quoted(c.w, ti.Name) c.w.WriteString(` AS `) - switch root.Action { - case qcode.ActionInsert: + switch qc.Type { + case qcode.QTInsert: if _, err := c.renderInsert(qc, w, vars, ti); err != nil { return 0, err } - case qcode.ActionUpdate: + case qcode.QTUpdate: if _, err := c.renderUpdate(qc, w, vars, ti); err != nil { return 0, err } - case qcode.ActionUpsert: + case qcode.QTUpsert: if _, err := c.renderUpsert(qc, w, vars, ti); err != nil { return 0, err } - case qcode.ActionDelete: + case qcode.QTDelete: if _, err := c.renderDelete(qc, w, vars, ti); err != nil { return 0, err } @@ -56,22 +56,23 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia io.WriteString(c.w, ` RETURNING *) `) - root.Paging = zeroPaging + root.Paging = noLimit root.DistinctOn = root.DistinctOn[:] root.OrderBy = root.OrderBy[:] root.Where = nil root.Args = nil + qc.Type = qcode.QTQuery + return c.compileQuery(qc, w) } func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer, vars Variables, ti *DBTableInfo) (uint32, error) { - root := &qc.Selects[0] - insert, ok := vars[root.ActionVar] + insert, ok := vars[qc.ActionVar] if !ok { - return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar) + return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar) } jt, array, err := jsn.Tree(insert) @@ -80,7 +81,7 @@ func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer, } c.w.WriteString(`(WITH "input" AS (SELECT {{`) - c.w.WriteString(root.ActionVar) + c.w.WriteString(qc.ActionVar) c.w.WriteString(`}}::json AS j) INSERT INTO `) quoted(c.w, ti.Name) io.WriteString(c.w, ` (`) @@ -106,12 +107,18 @@ func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer, func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w *bytes.Buffer, jt map[string]interface{}, ti *DBTableInfo) (uint32, error) { + root := &qc.Selects[0] i := 0 for _, cn := range ti.ColumnNames { if _, ok := jt[cn]; !ok { continue } + if len(root.Allowed) != 0 { + if _, ok := root.Allowed[cn]; !ok { + continue + } + } if i != 0 { io.WriteString(c.w, `, `) } @@ -126,9 +133,9 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer, vars Variables, ti *DBTableInfo) (uint32, error) { root := &qc.Selects[0] - update, ok := vars[root.ActionVar] + update, ok := vars[qc.ActionVar] if !ok { - return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar) + return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar) } jt, array, err := jsn.Tree(update) @@ -137,7 +144,7 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer, } c.w.WriteString(`(WITH "input" AS (SELECT {{`) - c.w.WriteString(root.ActionVar) + c.w.WriteString(qc.ActionVar) c.w.WriteString(`}}::json AS j) UPDATE `) quoted(c.w, ti.Name) io.WriteString(c.w, ` SET (`) @@ -183,11 +190,10 @@ func (c *compilerContext) renderDelete(qc *qcode.QCode, w *bytes.Buffer, func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer, vars Variables, ti *DBTableInfo) (uint32, error) { - root := &qc.Selects[0] - upsert, ok := vars[root.ActionVar] + upsert, ok := vars[qc.ActionVar] if !ok { - return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar) + return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar) } jt, _, err := jsn.Tree(upsert) diff --git a/psql/insert_test.go b/psql/mutate_test.go similarity index 93% rename from psql/insert_test.go rename to psql/mutate_test.go index bd03d7f..8555f74 100644 --- a/psql/insert_test.go +++ b/psql/mutate_test.go @@ -18,7 +18,7 @@ func simpleInsert(t *testing.T) { "data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`), } - resSQL, err := compileGQLToPSQL(gql, vars) + resSQL, err := compileGQLToPSQL(gql, vars, "user") if err != nil { t.Fatal(err) } @@ -42,7 +42,7 @@ func singleInsert(t *testing.T) { "insert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc", "user_id": 5 }`), } - resSQL, err := compileGQLToPSQL(gql, vars) + resSQL, err := compileGQLToPSQL(gql, vars, "user") if err != nil { t.Fatal(err) } @@ -54,7 +54,7 @@ func singleInsert(t *testing.T) { func bulkInsert(t *testing.T) { gql := `mutation { - product(id: 15, insert: $insert) { + product(name: "test", id: 15, insert: $insert) { id name } @@ -66,7 +66,7 @@ func bulkInsert(t *testing.T) { "insert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`), } - resSQL, err := compileGQLToPSQL(gql, vars) + resSQL, err := compileGQLToPSQL(gql, vars, "user") if err != nil { t.Fatal(err) } @@ -90,7 +90,7 @@ func singleUpsert(t *testing.T) { "upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), } - resSQL, err := compileGQLToPSQL(gql, vars) + resSQL, err := compileGQLToPSQL(gql, vars, "user") if err != nil { t.Fatal(err) } @@ -114,7 +114,7 @@ func bulkUpsert(t *testing.T) { "upsert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`), } - resSQL, err := compileGQLToPSQL(gql, vars) + resSQL, err := compileGQLToPSQL(gql, vars, "user") if err != nil { t.Fatal(err) } @@ -138,7 +138,7 @@ func singleUpdate(t *testing.T) { "update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), } - resSQL, err := compileGQLToPSQL(gql, vars) + resSQL, err := compileGQLToPSQL(gql, vars, "user") if err != nil { t.Fatal(err) } @@ -162,7 +162,7 @@ func delete(t *testing.T) { "update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), } - resSQL, err := compileGQLToPSQL(gql, vars) + resSQL, err := compileGQLToPSQL(gql, vars, "user") if err != nil { t.Fatal(err) } @@ -172,7 +172,7 @@ func delete(t *testing.T) { } } -func TestCompileInsert(t *testing.T) { +func TestCompileMutate(t *testing.T) { t.Run("simpleInsert", simpleInsert) t.Run("singleInsert", singleInsert) t.Run("bulkInsert", bulkInsert) diff --git a/psql/select.go b/psql/query.go similarity index 93% rename from psql/select.go rename to psql/query.go index 1058a62..8d70e43 100644 --- a/psql/select.go +++ b/psql/query.go @@ -64,11 +64,11 @@ func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (u switch qc.Type { case qcode.QTQuery: return co.compileQuery(qc, w) - case qcode.QTMutation: + case qcode.QTInsert, qcode.QTUpdate, qcode.QTDelete, qcode.QTUpsert: return co.compileMutation(qc, w, vars) } - return 0, errors.New("unknown operation") + return 0, fmt.Errorf("Unknown operation type %d", qc.Type) } func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) { @@ -295,19 +295,21 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo) } } - if sel.Action == 0 { - if len(sel.Paging.Limit) != 0 { - //fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit) - c.w.WriteString(` LIMIT ('`) - c.w.WriteString(sel.Paging.Limit) - c.w.WriteString(`') :: integer`) + switch { + case sel.Paging.NoLimit: + break - } else if ti.Singular { - c.w.WriteString(` LIMIT ('1') :: integer`) + case len(sel.Paging.Limit) != 0: + //fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit) + c.w.WriteString(` LIMIT ('`) + c.w.WriteString(sel.Paging.Limit) + c.w.WriteString(`') :: integer`) - } else { - c.w.WriteString(` LIMIT ('20') :: integer`) - } + case ti.Singular: + c.w.WriteString(` LIMIT ('1') :: integer`) + + default: + c.w.WriteString(` LIMIT ('20') :: integer`) } if len(sel.Paging.Offset) != 0 { @@ -370,13 +372,31 @@ func (c *compilerContext) renderJoinTable(sel *qcode.Select) error { } func (c *compilerContext) renderColumns(sel *qcode.Select, ti *DBTableInfo) { - for i, col := range sel.Cols { + i := 0 + for _, col := range sel.Cols { + if len(sel.Allowed) != 0 { + n := funcPrefixLen(col.Name) + if n != 0 { + if sel.Functions == false { + continue + } + if _, ok := sel.Allowed[col.Name[n:]]; !ok { + continue + } + } else { + if _, ok := sel.Allowed[col.Name]; !ok { + continue + } + } + } + if i != 0 { io.WriteString(c.w, ", ") } //fmt.Fprintf(w, `"%s_%d"."%s" AS "%s"`, //c.sel.Table, c.sel.ID, col.Name, col.FieldName) colWithTableIDAlias(c.w, ti.Name, sel.ID, col.Name, col.FieldName) + i++ } } @@ -435,7 +455,8 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, c.w.WriteString(` FROM (SELECT `) - for i, col := range sel.Cols { + i := 0 + for n, col := range sel.Cols { cn := col.Name _, isRealCol := ti.Columns[cn] @@ -447,6 +468,9 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, cn = ti.TSVCol arg := sel.Args["search"] + if i != 0 { + c.w.WriteString(`, `) + } //fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`, //c.sel.Table, cn, arg.Val, col.Name) c.w.WriteString(`ts_rank(`) @@ -455,11 +479,15 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, c.w.WriteString(arg.Val) c.w.WriteString(`')`) alias(c.w, col.Name) + i++ case strings.HasPrefix(cn, "search_headline_"): cn = cn[16:] arg := sel.Args["search"] + if i != 0 { + c.w.WriteString(`, `) + } //fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`, //c.sel.Table, cn, arg.Val, col.Name) c.w.WriteString(`ts_headlinek(`) @@ -468,47 +496,63 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, c.w.WriteString(arg.Val) c.w.WriteString(`')`) alias(c.w, col.Name) + i++ + } } else { pl := funcPrefixLen(cn) if pl == 0 { + if i != 0 { + c.w.WriteString(`, `) + } //fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name) c.w.WriteString(`'`) c.w.WriteString(cn) c.w.WriteString(` not defined'`) alias(c.w, col.Name) - } else { - isAgg = true + i++ + + } else if sel.Functions { + cn1 := cn[pl:] + if _, ok := sel.Allowed[cn1]; !ok { + continue + } + if i != 0 { + c.w.WriteString(`, `) + } fn := cn[0 : pl-1] - cn := cn[pl:] + isAgg = true + //fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, c.sel.Table, cn, col.Name) c.w.WriteString(fn) c.w.WriteString(`(`) - colWithTable(c.w, ti.Name, cn) + colWithTable(c.w, ti.Name, cn1) c.w.WriteString(`)`) alias(c.w, col.Name) + i++ + } } } else { - groupBy = append(groupBy, i) + groupBy = append(groupBy, n) //fmt.Fprintf(w, `"%s"."%s"`, c.sel.Table, cn) + if i != 0 { + c.w.WriteString(`, `) + } colWithTable(c.w, ti.Name, cn) - } + i++ - if i < len(sel.Cols)-1 || len(childCols) != 0 { - //io.WriteString(w, ", ") - c.w.WriteString(`, `) } } - for i, col := range childCols { + for _, col := range childCols { if i != 0 { - //io.WriteString(w, ", ") c.w.WriteString(`, `) } //fmt.Fprintf(w, `"%s"."%s"`, col.Table, col.Name) colWithTable(c.w, col.Table, col.Name) + i++ } c.w.WriteString(` FROM `) @@ -570,19 +614,21 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, } } - if sel.Action == 0 { - if len(sel.Paging.Limit) != 0 { - //fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit) - c.w.WriteString(` LIMIT ('`) - c.w.WriteString(sel.Paging.Limit) - c.w.WriteString(`') :: integer`) + switch { + case sel.Paging.NoLimit: + break - } else if ti.Singular { - c.w.WriteString(` LIMIT ('1') :: integer`) + case len(sel.Paging.Limit) != 0: + //fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit) + c.w.WriteString(` LIMIT ('`) + c.w.WriteString(sel.Paging.Limit) + c.w.WriteString(`') :: integer`) - } else { - c.w.WriteString(` LIMIT ('20') :: integer`) - } + case ti.Singular: + c.w.WriteString(` LIMIT ('1') :: integer`) + + default: + c.w.WriteString(` LIMIT ('20') :: integer`) } if len(sel.Paging.Offset) != 0 { diff --git a/psql/select_test.go b/psql/query_test.go similarity index 84% rename from psql/select_test.go rename to psql/query_test.go index 553f576..ea65645 100644 --- a/psql/select_test.go +++ b/psql/query_test.go @@ -22,32 +22,6 @@ func TestMain(m *testing.M) { var err error qcompile, err = qcode.NewCompiler(qcode.Config{ - DefaultFilter: []string{ - `{ user_id: { _eq: $user_id } }`, - }, - FilterMap: qcode.Filters{ - All: map[string][]string{ - "users": []string{ - "{ id: { eq: $user_id } }", - }, - "products": []string{ - "{ price: { gt: 0 } }", - "{ price: { lt: 8 } }", - }, - "customers": []string{}, - "mes": []string{ - "{ id: { eq: $user_id } }", - }, - }, - Query: map[string][]string{ - "users": []string{}, - }, - Update: map[string][]string{ - "products": []string{ - "{ user_id: { eq: $user_id } }", - }, - }, - }, Blocklist: []string{ "secret", "password", @@ -55,6 +29,59 @@ func TestMain(m *testing.M) { }, }) + qcompile.AddRole("user", "product", qcode.TRConfig{ + Query: qcode.QueryConfig{ + Columns: []string{"id", "name", "price", "users", "customers"}, + Filter: []string{ + "{ price: { gt: 0 } }", + "{ price: { lt: 8 } }", + }, + }, + Update: qcode.UpdateConfig{ + Filter: []string{"{ user_id: { eq: $user_id } }"}, + }, + Delete: qcode.DeleteConfig{ + Filter: []string{ + "{ price: { gt: 0 } }", + "{ price: { lt: 8 } }", + }, + }, + }) + + qcompile.AddRole("anon", "product", qcode.TRConfig{ + Query: qcode.QueryConfig{ + Columns: []string{"id", "name"}, + }, + }) + + qcompile.AddRole("anon1", "product", qcode.TRConfig{ + Query: qcode.QueryConfig{ + Columns: []string{"id", "name", "price"}, + DisableFunctions: true, + }, + }) + + qcompile.AddRole("user", "users", qcode.TRConfig{ + Query: qcode.QueryConfig{ + Columns: []string{"id", "full_name", "avatar", "email", "products"}, + }, + }) + + qcompile.AddRole("user", "mes", qcode.TRConfig{ + Query: qcode.QueryConfig{ + Columns: []string{"id", "full_name", "avatar"}, + Filter: []string{ + "{ id: { eq: $user_id } }", + }, + }, + }) + + qcompile.AddRole("user", "customers", qcode.TRConfig{ + Query: qcode.QueryConfig{ + Columns: []string{"id", "email", "full_name", "products"}, + }, + }) + if err != nil { log.Fatal(err) } @@ -135,9 +162,8 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -func compileGQLToPSQL(gql string, vars Variables) ([]byte, error) { - - qc, err := qcompile.Compile([]byte(gql)) +func compileGQLToPSQL(gql string, vars Variables, role string) ([]byte, error) { + qc, err := qcompile.Compile([]byte(gql), role) if err != nil { return nil, err } @@ -147,6 +173,8 @@ func compileGQLToPSQL(gql string, vars Variables) ([]byte, error) { return nil, err } + //fmt.Println(string(sqlStmt)) + return sqlStmt, nil } @@ -175,7 +203,7 @@ func withComplexArgs(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") < 28) AND (("products"."id") >= 20)) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "sel_json_agg_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -203,7 +231,7 @@ func withWhereMultiOr(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") < 20) OR (("products"."price") > 10) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -229,7 +257,7 @@ func withWhereIsNull(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -255,7 +283,7 @@ func withWhereAndList(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -275,7 +303,7 @@ func fetchByID(t *testing.T) { sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 15)) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -295,7 +323,7 @@ func searchQuery(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -318,7 +346,7 @@ func oneToMany(t *testing.T) { sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email", "users"."id" FROM "users" LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -341,7 +369,7 @@ func belongsTo(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "sel_json_1" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -364,7 +392,7 @@ func manyToMany(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -387,7 +415,7 @@ func manyToManyReverse(t *testing.T) { sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -407,7 +435,47 @@ func aggFunction(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") + if err != nil { + t.Fatal(err) + } + + if string(resSQL) != sql { + t.Fatal(errNotExpected) + } +} + +func aggFunctionBlockedByCol(t *testing.T) { + gql := `query { + products { + name + count_price + } + }` + + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + + resSQL, err := compileGQLToPSQL(gql, nil, "anon") + if err != nil { + t.Fatal(err) + } + + if string(resSQL) != sql { + t.Fatal(errNotExpected) + } +} + +func aggFunctionDisabled(t *testing.T) { + gql := `query { + products { + name + count_price + } + }` + + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + + resSQL, err := compileGQLToPSQL(gql, nil, "anon1") if err != nil { t.Fatal(err) } @@ -427,7 +495,7 @@ func aggFunctionWithFilter(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") > 10)) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -447,7 +515,7 @@ func queryWithVariables(t *testing.T) { sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") = {{product_price}}) AND (("products"."id") = {{product_id}})) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -464,9 +532,9 @@ func syntheticTables(t *testing.T) { } }` - sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337";` + sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT ) AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -476,7 +544,7 @@ func syntheticTables(t *testing.T) { } } -func TestCompileSelect(t *testing.T) { +func TestCompileQuery(t *testing.T) { t.Run("withComplexArgs", withComplexArgs) t.Run("withWhereAndList", withWhereAndList) t.Run("withWhereIsNull", withWhereIsNull) @@ -488,10 +556,11 @@ func TestCompileSelect(t *testing.T) { t.Run("manyToMany", manyToMany) t.Run("manyToManyReverse", manyToManyReverse) t.Run("aggFunction", aggFunction) + t.Run("aggFunctionBlockedByCol", aggFunctionBlockedByCol) + t.Run("aggFunctionDisabled", aggFunctionDisabled) t.Run("aggFunctionWithFilter", aggFunctionWithFilter) t.Run("syntheticTables", syntheticTables) t.Run("queryWithVariables", queryWithVariables) - } var benchGQL = []byte(`query { @@ -526,7 +595,7 @@ func BenchmarkCompile(b *testing.B) { for n := 0; n < b.N; n++ { w.Reset() - qc, err := qcompile.Compile(benchGQL) + qc, err := qcompile.Compile(benchGQL, "user") if err != nil { b.Fatal(err) } @@ -547,7 +616,7 @@ func BenchmarkCompileParallel(b *testing.B) { for pb.Next() { w.Reset() - qc, err := qcompile.Compile(benchGQL) + qc, err := qcompile.Compile(benchGQL, "user") if err != nil { b.Fatal(err) } diff --git a/qcode/config.go b/qcode/config.go new file mode 100644 index 0000000..b2ab9c4 --- /dev/null +++ b/qcode/config.go @@ -0,0 +1,99 @@ +package qcode + +type Config struct { + Blocklist []string + KeepArgs bool +} + +type QueryConfig struct { + Limit int + Filter []string + Columns []string + DisableFunctions bool +} + +type InsertConfig struct { + Filter []string + Columns []string + Set map[string]string +} + +type UpdateConfig struct { + Filter []string + Columns []string + Set map[string]string +} + +type DeleteConfig struct { + Filter []string + Columns []string +} + +type TRConfig struct { + Query QueryConfig + Insert InsertConfig + Update UpdateConfig + Delete DeleteConfig +} + +type trval struct { + query struct { + limit string + fil *Exp + cols map[string]struct{} + disable struct { + funcs bool + } + } + + insert struct { + fil *Exp + cols map[string]struct{} + set map[string]string + } + + update struct { + fil *Exp + cols map[string]struct{} + set map[string]string + } + + delete struct { + fil *Exp + cols map[string]struct{} + } +} + +func (trv *trval) allowedColumns(qt QType) map[string]struct{} { + switch qt { + case QTQuery: + return trv.query.cols + case QTInsert: + return trv.insert.cols + case QTUpdate: + return trv.update.cols + case QTDelete: + return trv.insert.cols + case QTUpsert: + return trv.insert.cols + } + + return nil +} + +func (trv *trval) filter(qt QType) *Exp { + switch qt { + case QTQuery: + return trv.query.fil + case QTInsert: + return trv.insert.fil + case QTUpdate: + return trv.update.fil + case QTDelete: + return trv.delete.fil + case QTUpsert: + return trv.insert.fil + } + + return nil +} diff --git a/qcode/fuzz.go b/qcode/fuzz.go index db8f3c8..89a4a3c 100644 --- a/qcode/fuzz.go +++ b/qcode/fuzz.go @@ -5,7 +5,7 @@ func FuzzerEntrypoint(data []byte) int { //testData := string(data) qcompile, _ := NewCompiler(Config{}) - _, err := qcompile.Compile(data) + _, err := qcompile.Compile(data, "user") if err != nil { return -1 } diff --git a/qcode/parse_test.go b/qcode/parse_test.go index dba397f..0e04ed2 100644 --- a/qcode/parse_test.go +++ b/qcode/parse_test.go @@ -46,13 +46,18 @@ func compareOp(op1, op2 Operation) error { */ func TestCompile1(t *testing.T) { - qcompile, _ := NewCompiler(Config{}) + qc, _ := NewCompiler(Config{}) + qc.AddRole("user", "product", TRConfig{ + Query: QueryConfig{ + Columns: []string{"id", "Name"}, + }, + }) - _, err := qcompile.Compile([]byte(` + _, err := qc.Compile([]byte(` product(id: 15) { id name - }`)) + }`), "user") if err != nil { t.Fatal(err) @@ -60,13 +65,18 @@ func TestCompile1(t *testing.T) { } func TestCompile2(t *testing.T) { - qcompile, _ := NewCompiler(Config{}) + qc, _ := NewCompiler(Config{}) + qc.AddRole("user", "product", TRConfig{ + Query: QueryConfig{ + Columns: []string{"ID"}, + }, + }) - _, err := qcompile.Compile([]byte(` + _, err := qc.Compile([]byte(` query { product(id: 15) { id name - } }`)) + } }`), "user") if err != nil { t.Fatal(err) @@ -74,15 +84,20 @@ func TestCompile2(t *testing.T) { } func TestCompile3(t *testing.T) { - qcompile, _ := NewCompiler(Config{}) + qc, _ := NewCompiler(Config{}) + qc.AddRole("user", "product", TRConfig{ + Query: QueryConfig{ + Columns: []string{"ID"}, + }, + }) - _, err := qcompile.Compile([]byte(` + _, err := qc.Compile([]byte(` mutation { product(id: 15, name: "Test") { id name } - }`)) + }`), "user") if err != nil { t.Fatal(err) @@ -91,7 +106,7 @@ func TestCompile3(t *testing.T) { func TestInvalidCompile1(t *testing.T) { qcompile, _ := NewCompiler(Config{}) - _, err := qcompile.Compile([]byte(`#`)) + _, err := qcompile.Compile([]byte(`#`), "user") if err == nil { t.Fatal(errors.New("expecting an error")) @@ -100,7 +115,7 @@ func TestInvalidCompile1(t *testing.T) { func TestInvalidCompile2(t *testing.T) { qcompile, _ := NewCompiler(Config{}) - _, err := qcompile.Compile([]byte(`{u(where:{not:0})}`)) + _, err := qcompile.Compile([]byte(`{u(where:{not:0})}`), "user") if err == nil { t.Fatal(errors.New("expecting an error")) @@ -109,7 +124,7 @@ func TestInvalidCompile2(t *testing.T) { func TestEmptyCompile(t *testing.T) { qcompile, _ := NewCompiler(Config{}) - _, err := qcompile.Compile([]byte(``)) + _, err := qcompile.Compile([]byte(``), "user") if err == nil { t.Fatal(errors.New("expecting an error")) @@ -144,7 +159,7 @@ func BenchmarkQCompile(b *testing.B) { b.ReportAllocs() for n := 0; n < b.N; n++ { - _, err := qcompile.Compile(gql) + _, err := qcompile.Compile(gql, "user") if err != nil { b.Fatal(err) @@ -160,7 +175,7 @@ func BenchmarkQCompileP(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - _, err := qcompile.Compile(gql) + _, err := qcompile.Compile(gql, "user") if err != nil { b.Fatal(err) diff --git a/qcode/qcode.go b/qcode/qcode.go index bafe3e4..05e8e98 100644 --- a/qcode/qcode.go +++ b/qcode/qcode.go @@ -3,6 +3,7 @@ package qcode import ( "errors" "fmt" + "strconv" "strings" "sync" @@ -17,23 +18,16 @@ const ( maxSelectors = 30 QTQuery QType = iota + 1 - QTMutation - - ActionInsert Action = iota + 1 - ActionUpdate - ActionDelete - ActionUpsert + QTInsert + QTUpdate + QTDelete + QTUpsert ) type QCode struct { - Type QType - Selects []Select -} - -type Column struct { - Table string - Name string - FieldName string + Type QType + ActionVar string + Selects []Select } type Select struct { @@ -47,9 +41,15 @@ type Select struct { OrderBy []*OrderBy DistinctOn []string Paging Paging - Action Action - ActionVar string Children []int32 + Functions bool + Allowed map[string]struct{} +} + +type Column struct { + Table string + Name string + FieldName string } type Exp struct { @@ -77,8 +77,9 @@ type OrderBy struct { } type Paging struct { - Limit string - Offset string + Limit string + Offset string + NoLimit bool } type ExpOp int @@ -145,81 +146,23 @@ const ( OrderDescNullsLast ) -type Filters struct { - All map[string][]string - Query map[string][]string - Insert map[string][]string - Update map[string][]string - Delete map[string][]string -} - -type Config struct { - DefaultFilter []string - FilterMap Filters - Blocklist []string - KeepArgs bool -} - type Compiler struct { - df *Exp - fm struct { - all map[string]*Exp - query map[string]*Exp - insert map[string]*Exp - update map[string]*Exp - delete map[string]*Exp - } + tr map[string]map[string]*trval bl map[string]struct{} ka bool } -var opMap = map[parserType]QType{ - opQuery: QTQuery, - opMutate: QTMutation, -} - var expPool = sync.Pool{ New: func() interface{} { return &Exp{doFree: true} }, } func NewCompiler(c Config) (*Compiler, error) { - var err error co := &Compiler{ka: c.KeepArgs} - + co.tr = make(map[string]map[string]*trval) co.bl = make(map[string]struct{}, len(c.Blocklist)) for i := range c.Blocklist { - co.bl[c.Blocklist[i]] = struct{}{} - } - - co.df, err = compileFilter(c.DefaultFilter) - if err != nil { - return nil, err - } - - co.fm.all, err = buildFilters(c.FilterMap.All) - if err != nil { - return nil, err - } - - co.fm.query, err = buildFilters(c.FilterMap.Query) - if err != nil { - return nil, err - } - - co.fm.insert, err = buildFilters(c.FilterMap.Insert) - if err != nil { - return nil, err - } - - co.fm.update, err = buildFilters(c.FilterMap.Update) - if err != nil { - return nil, err - } - - co.fm.delete, err = buildFilters(c.FilterMap.Delete) - if err != nil { - return nil, err + co.bl[strings.ToLower(c.Blocklist[i])] = struct{}{} } seedExp := [100]Exp{} @@ -232,58 +175,99 @@ func NewCompiler(c Config) (*Compiler, error) { return co, nil } -func buildFilters(filMap map[string][]string) (map[string]*Exp, error) { - fm := make(map[string]*Exp, len(filMap)) +func (com *Compiler) AddRole(role, table string, trc TRConfig) error { + var err error + trv := &trval{} - for k, v := range filMap { - fil, err := compileFilter(v) - if err != nil { - return nil, err + toMap := func(cols []string) map[string]struct{} { + m := make(map[string]struct{}, len(cols)) + for i := range cols { + m[strings.ToLower(cols[i])] = struct{}{} } - singular := flect.Singularize(k) - plural := flect.Pluralize(k) - - fm[singular] = fil - fm[plural] = fil + return m } - return fm, nil + // query config + trv.query.fil, err = compileFilter(trc.Query.Filter) + if err != nil { + return err + } + if trc.Query.Limit > 0 { + trv.query.limit = strconv.Itoa(trc.Query.Limit) + } + trv.query.cols = toMap(trc.Query.Columns) + trv.query.disable.funcs = trc.Query.DisableFunctions + + // insert config + if trv.insert.fil, err = compileFilter(trc.Insert.Filter); err != nil { + return err + } + trv.insert.cols = toMap(trc.Insert.Columns) + + // update config + if trv.update.fil, err = compileFilter(trc.Update.Filter); err != nil { + return err + } + trv.insert.cols = toMap(trc.Insert.Columns) + trv.insert.set = trc.Insert.Set + + // delete config + if trv.delete.fil, err = compileFilter(trc.Delete.Filter); err != nil { + return err + } + trv.delete.cols = toMap(trc.Delete.Columns) + + singular := flect.Singularize(table) + plural := flect.Pluralize(table) + + if _, ok := com.tr[role]; !ok { + com.tr[role] = make(map[string]*trval) + } + + com.tr[role][singular] = trv + com.tr[role][plural] = trv + return nil } -func (com *Compiler) Compile(query []byte) (*QCode, error) { - var qc QCode +func (com *Compiler) Compile(query []byte, role string) (*QCode, error) { var err error + qc := QCode{Type: QTQuery} + op, err := Parse(query) if err != nil { return nil, err } - qc.Selects, err = com.compileQuery(op) - if err != nil { + if err = com.compileQuery(&qc, op, role); err != nil { return nil, err } - if t, ok := opMap[op.Type]; ok { - qc.Type = t - } else { - return nil, fmt.Errorf("Unknown operation type %d", op.Type) - } - opPool.Put(op) return &qc, nil } -func (com *Compiler) compileQuery(op *Operation) ([]Select, error) { +func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error { id := int32(0) parentID := int32(0) + if len(op.Fields) == 0 { + return errors.New("invalid graphql no query found") + } + + if op.Type == opMutate { + if err := com.setMutationType(qc, op.Fields[0].Args); err != nil { + return err + } + } + selects := make([]Select, 0, 5) st := NewStack() + action := qc.Type if len(op.Fields) == 0 { - return nil, errors.New("empty query") + return errors.New("empty query") } st.Push(op.Fields[0].ID) @@ -293,7 +277,7 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) { } if id >= maxSelectors { - return nil, fmt.Errorf("selector limit reached (%d)", maxSelectors) + return fmt.Errorf("selector limit reached (%d)", maxSelectors) } fid := st.Pop() @@ -303,14 +287,28 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) { continue } + trv, ok := com.tr[role][field.Name] + if !ok { + continue + } + selects = append(selects, Select{ ID: id, ParentID: parentID, Table: field.Name, Children: make([]int32, 0, 5), + Allowed: trv.allowedColumns(action), }) s := &selects[(len(selects) - 1)] + if action == QTQuery { + s.Functions = !trv.query.disable.funcs + + if len(trv.query.limit) != 0 { + s.Paging.Limit = trv.query.limit + } + } + if s.ID != 0 { p := &selects[s.ParentID] p.Children = append(p.Children, s.ID) @@ -322,12 +320,13 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) { s.FieldName = s.Table } - err := com.compileArgs(s, field.Args) + err := com.compileArgs(qc, s, field.Args) if err != nil { - return nil, err + return err } s.Cols = make([]Column, 0, len(field.Children)) + action = QTQuery for _, cid := range field.Children { f := op.Fields[cid] @@ -356,36 +355,14 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) { } if id == 0 { - return nil, errors.New("invalid query") + return errors.New("invalid query") } var fil *Exp - root := &selects[0] - switch op.Type { - case opQuery: - fil, _ = com.fm.query[root.Table] - - case opMutate: - switch root.Action { - case ActionInsert: - fil, _ = com.fm.insert[root.Table] - case ActionUpdate: - fil, _ = com.fm.update[root.Table] - case ActionDelete: - fil, _ = com.fm.delete[root.Table] - case ActionUpsert: - fil, _ = com.fm.insert[root.Table] - } - } - - if fil == nil { - fil, _ = com.fm.all[root.Table] - } - - if fil == nil { - fil = com.df + if trv, ok := com.tr[role][op.Fields[0].Name]; ok { + fil = trv.filter(qc.Type) } if fil != nil && fil.Op != OpNop { @@ -403,10 +380,11 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) { } } - return selects[:id], nil + qc.Selects = selects[:id] + return nil } -func (com *Compiler) compileArgs(sel *Select, args []Arg) error { +func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg) error { var err error if com.ka { @@ -418,9 +396,7 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error { switch arg.Name { case "id": - if sel.ID == 0 { - err = com.compileArgID(sel, arg) - } + err = com.compileArgID(sel, arg) case "search": err = com.compileArgSearch(sel, arg) case "where": @@ -433,18 +409,6 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error { err = com.compileArgLimit(sel, arg) case "offset": err = com.compileArgOffset(sel, arg) - case "insert": - sel.Action = ActionInsert - err = com.compileArgAction(sel, arg) - case "update": - sel.Action = ActionUpdate - err = com.compileArgAction(sel, arg) - case "upsert": - sel.Action = ActionUpsert - err = com.compileArgAction(sel, arg) - case "delete": - sel.Action = ActionDelete - err = com.compileArgAction(sel, arg) } if err != nil { @@ -461,6 +425,45 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error { return nil } +func (com *Compiler) setMutationType(qc *QCode, args []Arg) error { + setActionVar := func(arg *Arg) error { + if arg.Val.Type != nodeVar { + return fmt.Errorf("value for argument '%s' must be a variable", arg.Name) + } + qc.ActionVar = arg.Val.Val + return nil + } + + for i := range args { + arg := &args[i] + + switch arg.Name { + case "insert": + qc.Type = QTInsert + return setActionVar(arg) + case "update": + qc.Type = QTUpdate + return setActionVar(arg) + case "upsert": + qc.Type = QTUpsert + return setActionVar(arg) + case "delete": + qc.Type = QTDelete + + if arg.Val.Type != nodeBool { + return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name) + } + + if arg.Val.Val == "false" { + qc.Type = QTQuery + } + return nil + } + } + + return nil +} + func (com *Compiler) compileArgObj(st *util.Stack, arg *Arg) (*Exp, error) { if arg.Val.Type != nodeObj { return nil, fmt.Errorf("expecting an object") @@ -540,6 +543,10 @@ func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (* } func (com *Compiler) compileArgID(sel *Select, arg *Arg) error { + if sel.ID != 0 { + return nil + } + if sel.Where != nil && sel.Where.Op == OpEqID { return nil } @@ -732,26 +739,6 @@ func (com *Compiler) compileArgOffset(sel *Select, arg *Arg) error { return nil } -func (com *Compiler) compileArgAction(sel *Select, arg *Arg) error { - switch sel.Action { - case ActionDelete: - if arg.Val.Type != nodeBool { - return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name) - } - if arg.Val.Val == "false" { - sel.Action = 0 - } - - default: - if arg.Val.Type != nodeVar { - return fmt.Errorf("value for argument '%s' must be a variable", arg.Name) - } - sel.ActionVar = arg.Val.Val - } - - return nil -} - func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) { name := node.Name if name[0] == '_' { diff --git a/serv/config.go b/serv/config.go index 1fb7f63..b2ebf21 100644 --- a/serv/config.go +++ b/serv/config.go @@ -71,18 +71,14 @@ type config struct { } `mapstructure:"database"` Tables []configTable + Roles []configRoles } type configTable struct { - Name string - Filter []string - FilterQuery []string `mapstructure:"filter_query"` - FilterInsert []string `mapstructure:"filter_insert"` - FilterUpdate []string `mapstructure:"filter_update"` - FilterDelete []string `mapstructure:"filter_delete"` - Table string - Blocklist []string - Remotes []configRemote + Name string + Table string + Blocklist []string + Remotes []configRemote } type configRemote struct { @@ -98,6 +94,41 @@ type configRemote struct { } `mapstructure:"set_headers"` } +type configRoles struct { + Name string + Tables []struct { + Name string + + Query struct { + Limit int + Filter []string + Columns []string + DisableAggregation bool `mapstructure:"disable_aggregation"` + Deny bool + } + + Insert struct { + Filter []string + Columns []string + Set map[string]string + Deny bool + } + + Update struct { + Filter []string + Columns []string + Set map[string]string + Deny bool + } + + Delete struct { + Filter []string + Columns []string + Deny bool + } + } +} + func newConfig() *viper.Viper { vi := viper.New() diff --git a/serv/core.go b/serv/core.go index 9c7f3ee..8da4eab 100644 --- a/serv/core.go +++ b/serv/core.go @@ -59,7 +59,7 @@ func (c *coreContext) execQuery() ([]byte, error) { } else { - qc, err = qcompile.Compile([]byte(c.req.Query)) + qc, err = qcompile.Compile([]byte(c.req.Query), "user") if err != nil { return nil, err } diff --git a/serv/prepare.go b/serv/prepare.go index bf9a475..ac0eb5f 100644 --- a/serv/prepare.go +++ b/serv/prepare.go @@ -40,7 +40,7 @@ func prepareStmt(key, gql string, varBytes json.RawMessage) error { return nil } - qc, err := qcompile.Compile([]byte(gql)) + qc, err := qcompile.Compile([]byte(gql), "user") if err != nil { return err } diff --git a/serv/serv.go b/serv/serv.go index 1006520..980eb23 100644 --- a/serv/serv.go +++ b/serv/serv.go @@ -12,7 +12,6 @@ import ( rice "github.com/GeertJohan/go.rice" "github.com/dosco/super-graph/psql" "github.com/dosco/super-graph/qcode" - "github.com/gobuffalo/flect" ) func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) { @@ -22,49 +21,50 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) { } conf := qcode.Config{ - DefaultFilter: c.DB.Defaults.Filter, - FilterMap: qcode.Filters{ - All: make(map[string][]string, len(c.Tables)), - Query: make(map[string][]string, len(c.Tables)), - Insert: make(map[string][]string, len(c.Tables)), - Update: make(map[string][]string, len(c.Tables)), - Delete: make(map[string][]string, len(c.Tables)), - }, Blocklist: c.DB.Defaults.Blocklist, KeepArgs: false, } - for i := range c.Tables { - t := c.Tables[i] - - singular := flect.Singularize(t.Name) - plural := flect.Pluralize(t.Name) - - setFilter := func(fm map[string][]string, fil []string) { - switch { - case len(fil) == 0: - return - case fil[0] == "none" || len(fil[0]) == 0: - fm[singular] = []string{} - fm[plural] = []string{} - default: - fm[singular] = t.Filter - fm[plural] = t.Filter - } - } - - setFilter(conf.FilterMap.All, t.Filter) - setFilter(conf.FilterMap.Query, t.FilterQuery) - setFilter(conf.FilterMap.Insert, t.FilterInsert) - setFilter(conf.FilterMap.Update, t.FilterUpdate) - setFilter(conf.FilterMap.Delete, t.FilterDelete) - } - qc, err := qcode.NewCompiler(conf) if err != nil { return nil, nil, err } + for _, r := range c.Roles { + for _, t := range r.Tables { + query := qcode.QueryConfig{ + Limit: t.Query.Limit, + Filter: t.Query.Filter, + Columns: t.Query.Columns, + DisableFunctions: t.Query.DisableAggregation, + } + + insert := qcode.InsertConfig{ + Filter: t.Insert.Filter, + Columns: t.Insert.Columns, + Set: t.Insert.Set, + } + + update := qcode.UpdateConfig{ + Filter: t.Insert.Filter, + Columns: t.Insert.Columns, + Set: t.Insert.Set, + } + + delete := qcode.DeleteConfig{ + Filter: t.Insert.Filter, + Columns: t.Insert.Columns, + } + + qc.AddRole(r.Name, t.Name, qcode.TRConfig{ + Query: query, + Insert: insert, + Update: update, + Delete: delete, + }) + } + } + pc := psql.NewCompiler(psql.Config{ Schema: schema, Vars: c.getVariables(),