diff --git a/dev.yml b/dev.yml index 8a91aa2..7cc2f40 100644 --- a/dev.yml +++ b/dev.yml @@ -64,7 +64,7 @@ database: # Define defaults to for the field key and values below defaults: - filter: ["{ id: { _eq: $user_id } }"] + filter: ["{ user_id: { eq: $user_id } }"] # Fields and table names that you wish to block blacklist: @@ -77,12 +77,20 @@ database: fields: - name: users - filter: ["{ id: { _eq: $user_id } }"] + filter: ["{ id: { eq: $user_id } }"] + + - name: products + filter: [ + "{ price: { gt: 0 } }", + "{ price: { lt: 8 } }" + ] + + - name: customers + filter: none + + - name: me + table: users + filter: ["{ id: { eq: $user_id } }"] # - name: posts - # filter: ["{ account_id: { _eq: $account_id } }"] - - - name: my_products - table: products - filter: ["{ id: { _eq: $user_id } }"] - + # filter: ["{ account_id: { _eq: $account_id } }"] \ No newline at end of file diff --git a/psql/psql.go b/psql/psql.go index e6a8e0f..2d80df4 100644 --- a/psql/psql.go +++ b/psql/psql.go @@ -11,22 +11,27 @@ import ( ) type Config struct { - Schema *DBSchema - Vars map[string]string + Schema *DBSchema + Vars map[string]string + TableMap map[string]string } type Compiler struct { schema *DBSchema vars map[string]string + tmap map[string]string } func NewCompiler(conf Config) *Compiler { - return &Compiler{conf.Schema, conf.Vars} + return &Compiler{conf.Schema, conf.Vars, conf.TableMap} } func (c *Compiler) Compile(w io.Writer, qc *qcode.QCode) error { st := util.NewStack() - ti, _ := c.schema.GetTable(qc.Query.Select.Table) + ti, err := c.getTable(qc.Query.Select) + if err != nil { + return err + } st.Push(&selectBlockClose{nil, qc.Query.Select}) st.Push(&selectBlock{nil, qc.Query.Select, ti, c}) @@ -47,12 +52,13 @@ func (c *Compiler) Compile(w io.Writer, qc *qcode.QCode) error { v.render(w, c.schema, childCols, childIDs) for i := range childIDs { - ti, err := c.schema.GetTable(v.sel.Table) - if err != nil { - continue - } sub := v.sel.Joins[childIDs[i]] + ti, err := c.getTable(sub) + if err != nil { + return err + } + st.Push(&joinClose{sub}) st.Push(&selectBlockClose{v.sel, sub}) st.Push(&selectBlock{v.sel, sub, ti, c}) @@ -75,6 +81,13 @@ func (c *Compiler) Compile(w io.Writer, qc *qcode.QCode) error { return nil } +func (c *Compiler) getTable(sel *qcode.Select) (*DBTableInfo, error) { + if tn, ok := c.tmap[sel.Table]; ok { + return c.schema.GetTable(tn) + } + return c.schema.GetTable(sel.Table) +} + func (c *Compiler) relationshipColumns(parent *qcode.Select) ( cols []*qcode.Column, childIDs []int) { @@ -275,33 +288,34 @@ func (v *selectBlock) renderBaseSelect(w io.Writer, schema *DBSchema, childCols isRoot := v.parent == nil isFil := v.sel.Where != nil + isSearch := v.sel.Args["search"] != nil isAgg := false - _, isSearch := v.sel.Args["search"] - io.WriteString(w, " FROM (SELECT ") for i, col := range v.sel.Cols { cn := col.Name - _, isRealCol := v.schema.ColMap[TCKey{v.sel.Table, cn}] + + _, isRealCol := v.ti.Columns[cn] if !isRealCol { - switch { - case isSearch && cn == "search_rank": - cn = v.ti.TSVCol - arg := v.sel.Args["search"] + if isSearch { + switch { + case cn == "search_rank": + cn = v.ti.TSVCol + arg := v.sel.Args["search"] - fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`, - v.sel.Table, cn, arg.Val, col.Name) + fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`, + v.sel.Table, cn, arg.Val, col.Name) - case isSearch && strings.HasPrefix(cn, "search_headline_"): - cn = cn[16:] - arg := v.sel.Args["search"] + case strings.HasPrefix(cn, "search_headline_"): + cn = cn[16:] + arg := v.sel.Args["search"] - fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`, - v.sel.Table, cn, arg.Val, col.Name) - - default: + fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`, + v.sel.Table, cn, arg.Val, col.Name) + } + } else { pl := funcPrefixLen(cn) if pl == 0 { fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name) @@ -330,7 +344,11 @@ func (v *selectBlock) renderBaseSelect(w io.Writer, schema *DBSchema, childCols } } - fmt.Fprintf(w, ` FROM "%s"`, v.sel.Table) + if tn, ok := v.tmap[v.sel.Table]; ok { + fmt.Fprintf(w, ` FROM "%s" AS "%s"`, tn, v.sel.Table) + } else { + fmt.Fprintf(w, ` FROM "%s"`, v.sel.Table) + } if isRoot && isFil { io.WriteString(w, ` WHERE (`) diff --git a/psql/psql_test.go b/psql/psql_test.go index 945f205..628bd70 100644 --- a/psql/psql_test.go +++ b/psql/psql_test.go @@ -1,6 +1,7 @@ package psql import ( + "log" "os" "strings" "testing" @@ -18,18 +19,35 @@ var ( ) func TestMain(m *testing.M) { - fm := qcode.NewFilterMap(map[string]string{ - "users": "{ id: { _eq: $user_id } }", - "posts": "{ account_id: { _eq: $account_id } }", + var err error + + qcompile, err = qcode.NewCompiler(qcode.Config{ + Filter: []string{ + `{ user_id: { _eq: $user_id } }`, + }, + FilterMap: map[string][]string{ + "users": []string{ + "{ id: { eq: $user_id } }", + }, + "products": []string{ + "{ price: { gt: 0 } }", + "{ price: { lt: 8 } }", + }, + "customers": []string{}, + "mes": []string{ + "{ id: { eq: $user_id } }", + }, + }, + Blacklist: []string{ + "secret", + "password", + "token", + }, }) - bl := qcode.NewBlacklist([]string{ - "secret", - "password", - "token", - }) - - qcompile = qcode.NewCompiler(fm, bl) + if err != nil { + log.Fatal(err) + } tables := []*DBTable{ &DBTable{Name: "customers", Type: "table"}, @@ -81,17 +99,26 @@ func TestMain(m *testing.M) { &DBColumn{ID: 7, Name: "returned", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, Uniquekey: false, FKeyTable: "", FKeyColID: []int(nil)}}, } - schema := initSchema() + schema := &DBSchema{ + Tables: make(map[string]*DBTableInfo), + RelMap: make(map[TTKey]*DBRel), + } for i, t := range tables { - updateSchema(schema, t, columns[i]) + schema.updateSchema(t, columns[i]) } vars := NewVariables(map[string]string{ "account_id": "select account_id from users where id = $user_id", }) - pcompile = NewCompiler(schema, vars) + pcompile = NewCompiler(Config{ + Schema: schema, + Vars: vars, + TableMap: map[string]string{ + "mes": "users", + }, + }) os.Exit(m.Run()) } @@ -134,7 +161,7 @@ func withComplexArgs(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products" ORDER BY "products_0.ob.price" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0.ob.price") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products", "products_0"."price" AS "products_0.ob.price" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."id") < (28)) AND (("products"."id") >= (20))) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0.ob.price" DESC LIMIT ('30') :: integer) AS "products_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products" ORDER BY "products_0.ob.price" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0.ob.price") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products", "products_0"."price" AS "products_0.ob.price" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."id") < (28)) AND (("products"."id") >= (20))) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0.ob.price" DESC LIMIT ('30') :: integer) AS "products_0") AS "done_1337";` resSQL, err := compileGQLToPSQL(gql) if err != nil { @@ -162,7 +189,7 @@ func withWhereMultiOr(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") < (20)) OR (("products"."price") > (10)) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") < (20)) OR (("products"."price") > (10)) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` resSQL, err := compileGQLToPSQL(gql) if err != nil { @@ -188,7 +215,7 @@ func withWhereIsNull(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` resSQL, err := compileGQLToPSQL(gql) if err != nil { @@ -214,7 +241,7 @@ func withWhereAndList(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` resSQL, err := compileGQLToPSQL(gql) if err != nil { @@ -260,7 +287,7 @@ func belongsTo(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1.join"."users" AS "users") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("users"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "users" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "users_1") AS "users_1.join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1.join"."users" AS "users") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("users"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "users" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "users_1") AS "users_1.join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` resSQL, err := compileGQLToPSQL(gql) if err != nil { @@ -283,7 +310,7 @@ func manyToMany(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1.join"."customers" AS "customers") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."id" FROM "products" LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("customers"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "customers" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "customers_1") AS "customers_1.join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1.join"."customers" AS "customers") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("customers"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "customers" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "customers_1") AS "customers_1.join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` resSQL, err := compileGQLToPSQL(gql) if err != nil { @@ -320,13 +347,13 @@ func manyToManyReverse(t *testing.T) { func fetchByID(t *testing.T) { gql := `query { - product(id: 4) { + product(id: 15) { id name } }` - sql := `SELECT json_object_agg('product', products) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("id") = ('4'))) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";` + sql := `SELECT json_object_agg('product', products) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("id") = ('15'))) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";` resSQL, err := compileGQLToPSQL(gql) if err != nil { @@ -340,13 +367,13 @@ func fetchByID(t *testing.T) { func searchQuery(t *testing.T) { gql := `query { - products(search: "Amazing") { + products(search: "Imperial") { id name } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("tsv") @@ to_tsquery('Amazing'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` resSQL, err := compileGQLToPSQL(gql) if err != nil { @@ -366,7 +393,7 @@ func aggFunction(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "products" FROM (SELECT "products"."name", count("products"."price") AS count_price FROM "products" GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "products" FROM (SELECT "products"."name", count("products"."price") AS count_price FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` resSQL, err := compileGQLToPSQL(gql) if err != nil { @@ -386,7 +413,26 @@ func aggFunctionWithFilter(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", max("products"."price") AS max_price FROM "products" WHERE ((("products"."id") > (10))) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", max("products"."price") AS max_price FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."id") > (10))) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` + + resSQL, err := compileGQLToPSQL(gql) + if err != nil { + t.Fatal(err) + } + + if resSQL != sql { + t.Fatal(errNotExpected) + } +} + +func syntheticTables(t *testing.T) { + gql := `query { + me { + email + } + }` + + sql := `SELECT json_object_agg('me', mes) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "mes_0"."email" AS "email") AS "sel_0")) AS "mes" FROM (SELECT "mes"."email" FROM "users" AS "mes" WHERE ((("mes"."id") = ('{{user_id}}'))) LIMIT ('1') :: integer) AS "mes_0" LIMIT ('1') :: integer) AS "done_1337";` resSQL, err := compileGQLToPSQL(gql) if err != nil { @@ -411,6 +457,7 @@ func TestCompileGQL(t *testing.T) { t.Run("manyToManyReverse", manyToManyReverse) t.Run("aggFunction", aggFunction) t.Run("aggFunctionWithFilter", aggFunctionWithFilter) + t.Run("syntheticTables", syntheticTables) } func BenchmarkCompileGQLToSQL(b *testing.B) { diff --git a/psql/tables.go b/psql/tables.go index ffe4537..dff183f 100644 --- a/psql/tables.go +++ b/psql/tables.go @@ -16,16 +16,15 @@ type TTKey struct { } type DBSchema struct { - ColMap map[TCKey]*DBColumn - ColIDMap map[int]*DBColumn - Tables map[string]*DBTableInfo - + Tables map[string]*DBTableInfo RelMap map[TTKey]*DBRel } type DBTableInfo struct { + Name string PrimaryCol string TSVCol string + Columns map[string]*DBColumn } type RelType int @@ -45,7 +44,10 @@ type DBRel struct { } func NewDBSchema(db *pg.DB) (*DBSchema, error) { - schema := initSchema() + schema := &DBSchema{ + Tables: make(map[string]*DBTableInfo), + RelMap: make(map[TTKey]*DBRel), + } tables, err := GetTables(db) if err != nil { @@ -58,41 +60,39 @@ func NewDBSchema(db *pg.DB) (*DBSchema, error) { return nil, err } - updateSchema(schema, t, cols) + schema.updateSchema(t, cols) } return schema, nil } -func initSchema() *DBSchema { - return &DBSchema{ - ColMap: make(map[TCKey]*DBColumn), - ColIDMap: make(map[int]*DBColumn), - Tables: make(map[string]*DBTableInfo), - RelMap: make(map[TTKey]*DBRel), - } -} - -func updateSchema(schema *DBSchema, t *DBTable, cols []*DBColumn) { +func (s *DBSchema) updateSchema(t *DBTable, cols []*DBColumn) { // Current table - ct := strings.ToLower(t.Name) - schema.Tables[ct] = &DBTableInfo{} + ti := &DBTableInfo{ + Name: t.Name, + Columns: make(map[string]*DBColumn, len(cols)), + } // Foreign key columns in current table var jcols []*DBColumn + colByID := make(map[int]*DBColumn) - for _, c := range cols { - schema.ColMap[TCKey{ct, strings.ToLower(c.Name)}] = c - schema.ColIDMap[c.ID] = c + for i := range cols { + c := cols[i] + ti.Columns[strings.ToLower(c.Name)] = cols[i] + colByID[c.ID] = cols[i] } + ct := strings.ToLower(t.Name) + s.Tables[ct] = ti + for _, c := range cols { switch { case c.Type == "tsvector": - schema.Tables[ct].TSVCol = c.Name + s.Tables[ct].TSVCol = c.Name case c.PrimaryKey: - schema.Tables[ct].PrimaryCol = c.Name + s.Tables[ct].PrimaryCol = c.Name case len(c.FKeyTable) != 0: if len(c.FKeyColID) == 0 { @@ -101,7 +101,7 @@ func updateSchema(schema *DBSchema, t *DBTable, cols []*DBColumn) { // Foreign key column name ft := strings.ToLower(c.FKeyTable) - fc, ok := schema.ColIDMap[c.FKeyColID[0]] + fc, ok := colByID[c.FKeyColID[0]] if !ok { continue } @@ -109,12 +109,12 @@ func updateSchema(schema *DBSchema, t *DBTable, cols []*DBColumn) { // Belongs-to relation between current table and the // table in the foreign key rel1 := &DBRel{RelBelongTo, "", "", c.Name, fc.Name} - schema.RelMap[TTKey{ct, ft}] = rel1 + s.RelMap[TTKey{ct, ft}] = rel1 // One-to-many relation between the foreign key table and the // the current table rel2 := &DBRel{RelOneToMany, "", "", fc.Name, c.Name} - schema.RelMap[TTKey{ft, ct}] = rel2 + s.RelMap[TTKey{ft, ct}] = rel2 jcols = append(jcols, c) } @@ -130,22 +130,24 @@ func updateSchema(schema *DBSchema, t *DBTable, cols []*DBColumn) { for i := range jcols { for n := range jcols { if n != i { - updateSchemaOTMT(schema, ct, jcols[i], jcols[n]) + s.updateSchemaOTMT(ct, jcols[i], jcols[n], colByID) } } } } } -func updateSchemaOTMT(schema *DBSchema, ct string, col1, col2 *DBColumn) { +func (s *DBSchema) updateSchemaOTMT( + ct string, col1, col2 *DBColumn, colByID map[int]*DBColumn) { + t1 := strings.ToLower(col1.FKeyTable) t2 := strings.ToLower(col2.FKeyTable) - fc1, ok := schema.ColIDMap[col1.FKeyColID[0]] + fc1, ok := colByID[col1.FKeyColID[0]] if !ok { return } - fc2, ok := schema.ColIDMap[col2.FKeyColID[0]] + fc2, ok := colByID[col2.FKeyColID[0]] if !ok { return } @@ -154,13 +156,13 @@ func updateSchemaOTMT(schema *DBSchema, ct string, col1, col2 *DBColumn) { // 2nd foreign key table //rel1 := &DBRel{RelOneToManyThrough, ct, fc1.Name, col1.Name} rel1 := &DBRel{RelOneToManyThrough, ct, col2.Name, fc2.Name, col1.Name} - schema.RelMap[TTKey{t1, t2}] = rel1 + s.RelMap[TTKey{t1, t2}] = rel1 // One-to-many-through relation between 2nd foreign key table and the // 1nd foreign key table //rel2 := &DBRel{RelOneToManyThrough, ct, col2.Name, fc2.Name} rel2 := &DBRel{RelOneToManyThrough, ct, col1.Name, fc1.Name, col2.Name} - schema.RelMap[TTKey{t2, t1}] = rel2 + s.RelMap[TTKey{t2, t1}] = rel2 } type DBTable struct { @@ -242,13 +244,13 @@ WHERE c.relkind = 'r'::char stmt, err := db.Prepare(sqlStmt) if err != nil { - return nil, fmt.Errorf("Error fetching columns: %s", err) + return nil, fmt.Errorf("error fetching columns: %s", err) } var t []*DBColumn _, err = stmt.Query(&t, schema, table) if err != nil { - return nil, fmt.Errorf("Error fetching columns: %s", err) + return nil, fmt.Errorf("error fetching columns: %s", err) } return t, nil @@ -257,7 +259,7 @@ WHERE c.relkind = 'r'::char func (s *DBSchema) GetTable(table string) (*DBTableInfo, error) { t, ok := s.Tables[table] if !ok { - return nil, fmt.Errorf("table info not found '%s'", table) + return nil, fmt.Errorf("unknown table '%s'", table) } return t, nil } diff --git a/qcode/qcode.go b/qcode/qcode.go index a64e383..3ac6f64 100644 --- a/qcode/qcode.go +++ b/qcode/qcode.go @@ -61,7 +61,8 @@ type Paging struct { type ExpOp int const ( - OpAnd ExpOp = iota + 1 + OpNop ExpOp = iota + OpAnd OpOr OpNot OpEquals @@ -92,6 +93,8 @@ func (t ExpOp) String() string { var v string switch t { + case OpNop: + v = "op-nop" case OpAnd: v = "op-and" case OpOr: @@ -333,7 +336,12 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) { } } - if fil, ok := com.fm[selRoot.Table]; ok { + fil, ok := com.fm[selRoot.Table] + if !ok { + fil = com.fl + } + + if fil != nil && fil.Op != OpNop { if selRoot.Where != nil { selRoot.Where = &Exp{Op: OpAnd, Children: []*Exp{fil, selRoot.Where}} } else { @@ -788,6 +796,10 @@ func compileFilter(filter []string) (*Exp, error) { var fl *Exp com := &Compiler{} + if len(filter) == 0 { + return &Exp{Op: OpNop}, nil + } + for i := range filter { node, err := ParseArgValue(filter[i]) if err != nil { diff --git a/serv/serv.go b/serv/serv.go index ab1b7c0..16f7341 100644 --- a/serv/serv.go +++ b/serv/serv.go @@ -189,9 +189,21 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) { cdb := c.DB fm := make(map[string][]string, len(cdb.Fields)) + tmap := make(map[string]string, len(cdb.Fields)) + for i := range cdb.Fields { f := cdb.Fields[i] - fm[strings.ToLower(f.Name)] = f.Filter + name := flect.Pluralize(strings.ToLower(f.Name)) + if len(f.Filter) != 0 { + if f.Filter[0] == "none" { + fm[name] = []string{} + } else { + fm[name] = f.Filter + } + } + if len(f.Table) != 0 { + tmap[name] = f.Table + } } qc, err := qcode.NewCompiler(qcode.Config{ @@ -209,8 +221,9 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) { } pc := psql.NewCompiler(psql.Config{ - Schema: schema, - Vars: cdb.Variables, + Schema: schema, + Vars: cdb.Variables, + TableMap: tmap, }) return qc, pc, nil