Redesign config file architecture

This commit is contained in:
Vikram Rangnekar 2019-04-08 21:24:29 -04:00
parent e3660473cc
commit 2d02f2afda
6 changed files with 198 additions and 98 deletions

24
dev.yml
View File

@ -64,7 +64,7 @@ database:
# Define defaults to for the field key and values below # Define defaults to for the field key and values below
defaults: defaults:
filter: ["{ id: { _eq: $user_id } }"] filter: ["{ user_id: { eq: $user_id } }"]
# Fields and table names that you wish to block # Fields and table names that you wish to block
blacklist: blacklist:
@ -77,12 +77,20 @@ database:
fields: fields:
- name: users - name: users
filter: ["{ id: { _eq: $user_id } }"] filter: ["{ id: { eq: $user_id } }"]
- name: products
filter: [
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }"
]
- name: customers
filter: none
- name: me
table: users
filter: ["{ id: { eq: $user_id } }"]
# - name: posts # - name: posts
# filter: ["{ account_id: { _eq: $account_id } }"] # filter: ["{ account_id: { _eq: $account_id } }"]
- name: my_products
table: products
filter: ["{ id: { _eq: $user_id } }"]

View File

@ -11,22 +11,27 @@ import (
) )
type Config struct { type Config struct {
Schema *DBSchema Schema *DBSchema
Vars map[string]string Vars map[string]string
TableMap map[string]string
} }
type Compiler struct { type Compiler struct {
schema *DBSchema schema *DBSchema
vars map[string]string vars map[string]string
tmap map[string]string
} }
func NewCompiler(conf Config) *Compiler { func NewCompiler(conf Config) *Compiler {
return &Compiler{conf.Schema, conf.Vars} return &Compiler{conf.Schema, conf.Vars, conf.TableMap}
} }
func (c *Compiler) Compile(w io.Writer, qc *qcode.QCode) error { func (c *Compiler) Compile(w io.Writer, qc *qcode.QCode) error {
st := util.NewStack() st := util.NewStack()
ti, _ := c.schema.GetTable(qc.Query.Select.Table) ti, err := c.getTable(qc.Query.Select)
if err != nil {
return err
}
st.Push(&selectBlockClose{nil, qc.Query.Select}) st.Push(&selectBlockClose{nil, qc.Query.Select})
st.Push(&selectBlock{nil, qc.Query.Select, ti, c}) st.Push(&selectBlock{nil, qc.Query.Select, ti, c})
@ -47,12 +52,13 @@ func (c *Compiler) Compile(w io.Writer, qc *qcode.QCode) error {
v.render(w, c.schema, childCols, childIDs) v.render(w, c.schema, childCols, childIDs)
for i := range childIDs { for i := range childIDs {
ti, err := c.schema.GetTable(v.sel.Table)
if err != nil {
continue
}
sub := v.sel.Joins[childIDs[i]] sub := v.sel.Joins[childIDs[i]]
ti, err := c.getTable(sub)
if err != nil {
return err
}
st.Push(&joinClose{sub}) st.Push(&joinClose{sub})
st.Push(&selectBlockClose{v.sel, sub}) st.Push(&selectBlockClose{v.sel, sub})
st.Push(&selectBlock{v.sel, sub, ti, c}) st.Push(&selectBlock{v.sel, sub, ti, c})
@ -75,6 +81,13 @@ func (c *Compiler) Compile(w io.Writer, qc *qcode.QCode) error {
return nil return nil
} }
func (c *Compiler) getTable(sel *qcode.Select) (*DBTableInfo, error) {
if tn, ok := c.tmap[sel.Table]; ok {
return c.schema.GetTable(tn)
}
return c.schema.GetTable(sel.Table)
}
func (c *Compiler) relationshipColumns(parent *qcode.Select) ( func (c *Compiler) relationshipColumns(parent *qcode.Select) (
cols []*qcode.Column, childIDs []int) { cols []*qcode.Column, childIDs []int) {
@ -275,33 +288,34 @@ func (v *selectBlock) renderBaseSelect(w io.Writer, schema *DBSchema, childCols
isRoot := v.parent == nil isRoot := v.parent == nil
isFil := v.sel.Where != nil isFil := v.sel.Where != nil
isSearch := v.sel.Args["search"] != nil
isAgg := false isAgg := false
_, isSearch := v.sel.Args["search"]
io.WriteString(w, " FROM (SELECT ") io.WriteString(w, " FROM (SELECT ")
for i, col := range v.sel.Cols { for i, col := range v.sel.Cols {
cn := col.Name cn := col.Name
_, isRealCol := v.schema.ColMap[TCKey{v.sel.Table, cn}]
_, isRealCol := v.ti.Columns[cn]
if !isRealCol { if !isRealCol {
switch { if isSearch {
case isSearch && cn == "search_rank": switch {
cn = v.ti.TSVCol case cn == "search_rank":
arg := v.sel.Args["search"] cn = v.ti.TSVCol
arg := v.sel.Args["search"]
fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`, fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`,
v.sel.Table, cn, arg.Val, col.Name) v.sel.Table, cn, arg.Val, col.Name)
case isSearch && strings.HasPrefix(cn, "search_headline_"): case strings.HasPrefix(cn, "search_headline_"):
cn = cn[16:] cn = cn[16:]
arg := v.sel.Args["search"] arg := v.sel.Args["search"]
fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`, fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`,
v.sel.Table, cn, arg.Val, col.Name) v.sel.Table, cn, arg.Val, col.Name)
}
default: } else {
pl := funcPrefixLen(cn) pl := funcPrefixLen(cn)
if pl == 0 { if pl == 0 {
fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name) fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name)
@ -330,7 +344,11 @@ func (v *selectBlock) renderBaseSelect(w io.Writer, schema *DBSchema, childCols
} }
} }
fmt.Fprintf(w, ` FROM "%s"`, v.sel.Table) if tn, ok := v.tmap[v.sel.Table]; ok {
fmt.Fprintf(w, ` FROM "%s" AS "%s"`, tn, v.sel.Table)
} else {
fmt.Fprintf(w, ` FROM "%s"`, v.sel.Table)
}
if isRoot && isFil { if isRoot && isFil {
io.WriteString(w, ` WHERE (`) io.WriteString(w, ` WHERE (`)

View File

@ -1,6 +1,7 @@
package psql package psql
import ( import (
"log"
"os" "os"
"strings" "strings"
"testing" "testing"
@ -18,18 +19,35 @@ var (
) )
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
fm := qcode.NewFilterMap(map[string]string{ var err error
"users": "{ id: { _eq: $user_id } }",
"posts": "{ account_id: { _eq: $account_id } }", qcompile, err = qcode.NewCompiler(qcode.Config{
Filter: []string{
`{ user_id: { _eq: $user_id } }`,
},
FilterMap: map[string][]string{
"users": []string{
"{ id: { eq: $user_id } }",
},
"products": []string{
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }",
},
"customers": []string{},
"mes": []string{
"{ id: { eq: $user_id } }",
},
},
Blacklist: []string{
"secret",
"password",
"token",
},
}) })
bl := qcode.NewBlacklist([]string{ if err != nil {
"secret", log.Fatal(err)
"password", }
"token",
})
qcompile = qcode.NewCompiler(fm, bl)
tables := []*DBTable{ tables := []*DBTable{
&DBTable{Name: "customers", Type: "table"}, &DBTable{Name: "customers", Type: "table"},
@ -81,17 +99,26 @@ func TestMain(m *testing.M) {
&DBColumn{ID: 7, Name: "returned", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, Uniquekey: false, FKeyTable: "", FKeyColID: []int(nil)}}, &DBColumn{ID: 7, Name: "returned", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, Uniquekey: false, FKeyTable: "", FKeyColID: []int(nil)}},
} }
schema := initSchema() schema := &DBSchema{
Tables: make(map[string]*DBTableInfo),
RelMap: make(map[TTKey]*DBRel),
}
for i, t := range tables { for i, t := range tables {
updateSchema(schema, t, columns[i]) schema.updateSchema(t, columns[i])
} }
vars := NewVariables(map[string]string{ vars := NewVariables(map[string]string{
"account_id": "select account_id from users where id = $user_id", "account_id": "select account_id from users where id = $user_id",
}) })
pcompile = NewCompiler(schema, vars) pcompile = NewCompiler(Config{
Schema: schema,
Vars: vars,
TableMap: map[string]string{
"mes": "users",
},
})
os.Exit(m.Run()) os.Exit(m.Run())
} }
@ -134,7 +161,7 @@ func withComplexArgs(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products" ORDER BY "products_0.ob.price" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0.ob.price") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products", "products_0"."price" AS "products_0.ob.price" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."id") < (28)) AND (("products"."id") >= (20))) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0.ob.price" DESC LIMIT ('30') :: integer) AS "products_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products" ORDER BY "products_0.ob.price" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0.ob.price") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products", "products_0"."price" AS "products_0.ob.price" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."id") < (28)) AND (("products"."id") >= (20))) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0.ob.price" DESC LIMIT ('30') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql)
if err != nil { if err != nil {
@ -162,7 +189,7 @@ func withWhereMultiOr(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") < (20)) OR (("products"."price") > (10)) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") < (20)) OR (("products"."price") > (10)) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql)
if err != nil { if err != nil {
@ -188,7 +215,7 @@ func withWhereIsNull(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql)
if err != nil { if err != nil {
@ -214,7 +241,7 @@ func withWhereAndList(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql)
if err != nil { if err != nil {
@ -260,7 +287,7 @@ func belongsTo(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1.join"."users" AS "users") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("users"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "users" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "users_1") AS "users_1.join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1.join"."users" AS "users") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("users"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "users" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "users_1") AS "users_1.join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql)
if err != nil { if err != nil {
@ -283,7 +310,7 @@ func manyToMany(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1.join"."customers" AS "customers") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."id" FROM "products" LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("customers"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "customers" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "customers_1") AS "customers_1.join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1.join"."customers" AS "customers") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("customers"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "customers" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "customers_1") AS "customers_1.join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql)
if err != nil { if err != nil {
@ -320,13 +347,13 @@ func manyToManyReverse(t *testing.T) {
func fetchByID(t *testing.T) { func fetchByID(t *testing.T) {
gql := `query { gql := `query {
product(id: 4) { product(id: 15) {
id id
name name
} }
}` }`
sql := `SELECT json_object_agg('product', products) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("id") = ('4'))) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";` sql := `SELECT json_object_agg('product', products) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("id") = ('15'))) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql)
if err != nil { if err != nil {
@ -340,13 +367,13 @@ func fetchByID(t *testing.T) {
func searchQuery(t *testing.T) { func searchQuery(t *testing.T) {
gql := `query { gql := `query {
products(search: "Amazing") { products(search: "Imperial") {
id id
name name
} }
}` }`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("tsv") @@ to_tsquery('Amazing'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql)
if err != nil { if err != nil {
@ -366,7 +393,7 @@ func aggFunction(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "products" FROM (SELECT "products"."name", count("products"."price") AS count_price FROM "products" GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "products" FROM (SELECT "products"."name", count("products"."price") AS count_price FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql)
if err != nil { if err != nil {
@ -386,7 +413,26 @@ func aggFunctionWithFilter(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", max("products"."price") AS max_price FROM "products" WHERE ((("products"."id") > (10))) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", max("products"."price") AS max_price FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."id") > (10))) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
if err != nil {
t.Fatal(err)
}
if resSQL != sql {
t.Fatal(errNotExpected)
}
}
func syntheticTables(t *testing.T) {
gql := `query {
me {
email
}
}`
sql := `SELECT json_object_agg('me', mes) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "mes_0"."email" AS "email") AS "sel_0")) AS "mes" FROM (SELECT "mes"."email" FROM "users" AS "mes" WHERE ((("mes"."id") = ('{{user_id}}'))) LIMIT ('1') :: integer) AS "mes_0" LIMIT ('1') :: integer) AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql)
if err != nil { if err != nil {
@ -411,6 +457,7 @@ func TestCompileGQL(t *testing.T) {
t.Run("manyToManyReverse", manyToManyReverse) t.Run("manyToManyReverse", manyToManyReverse)
t.Run("aggFunction", aggFunction) t.Run("aggFunction", aggFunction)
t.Run("aggFunctionWithFilter", aggFunctionWithFilter) t.Run("aggFunctionWithFilter", aggFunctionWithFilter)
t.Run("syntheticTables", syntheticTables)
} }
func BenchmarkCompileGQLToSQL(b *testing.B) { func BenchmarkCompileGQLToSQL(b *testing.B) {

View File

@ -16,16 +16,15 @@ type TTKey struct {
} }
type DBSchema struct { type DBSchema struct {
ColMap map[TCKey]*DBColumn Tables map[string]*DBTableInfo
ColIDMap map[int]*DBColumn
Tables map[string]*DBTableInfo
RelMap map[TTKey]*DBRel RelMap map[TTKey]*DBRel
} }
type DBTableInfo struct { type DBTableInfo struct {
Name string
PrimaryCol string PrimaryCol string
TSVCol string TSVCol string
Columns map[string]*DBColumn
} }
type RelType int type RelType int
@ -45,7 +44,10 @@ type DBRel struct {
} }
func NewDBSchema(db *pg.DB) (*DBSchema, error) { func NewDBSchema(db *pg.DB) (*DBSchema, error) {
schema := initSchema() schema := &DBSchema{
Tables: make(map[string]*DBTableInfo),
RelMap: make(map[TTKey]*DBRel),
}
tables, err := GetTables(db) tables, err := GetTables(db)
if err != nil { if err != nil {
@ -58,41 +60,39 @@ func NewDBSchema(db *pg.DB) (*DBSchema, error) {
return nil, err return nil, err
} }
updateSchema(schema, t, cols) schema.updateSchema(t, cols)
} }
return schema, nil return schema, nil
} }
func initSchema() *DBSchema { func (s *DBSchema) updateSchema(t *DBTable, cols []*DBColumn) {
return &DBSchema{
ColMap: make(map[TCKey]*DBColumn),
ColIDMap: make(map[int]*DBColumn),
Tables: make(map[string]*DBTableInfo),
RelMap: make(map[TTKey]*DBRel),
}
}
func updateSchema(schema *DBSchema, t *DBTable, cols []*DBColumn) {
// Current table // Current table
ct := strings.ToLower(t.Name) ti := &DBTableInfo{
schema.Tables[ct] = &DBTableInfo{} Name: t.Name,
Columns: make(map[string]*DBColumn, len(cols)),
}
// Foreign key columns in current table // Foreign key columns in current table
var jcols []*DBColumn var jcols []*DBColumn
colByID := make(map[int]*DBColumn)
for _, c := range cols { for i := range cols {
schema.ColMap[TCKey{ct, strings.ToLower(c.Name)}] = c c := cols[i]
schema.ColIDMap[c.ID] = c ti.Columns[strings.ToLower(c.Name)] = cols[i]
colByID[c.ID] = cols[i]
} }
ct := strings.ToLower(t.Name)
s.Tables[ct] = ti
for _, c := range cols { for _, c := range cols {
switch { switch {
case c.Type == "tsvector": case c.Type == "tsvector":
schema.Tables[ct].TSVCol = c.Name s.Tables[ct].TSVCol = c.Name
case c.PrimaryKey: case c.PrimaryKey:
schema.Tables[ct].PrimaryCol = c.Name s.Tables[ct].PrimaryCol = c.Name
case len(c.FKeyTable) != 0: case len(c.FKeyTable) != 0:
if len(c.FKeyColID) == 0 { if len(c.FKeyColID) == 0 {
@ -101,7 +101,7 @@ func updateSchema(schema *DBSchema, t *DBTable, cols []*DBColumn) {
// Foreign key column name // Foreign key column name
ft := strings.ToLower(c.FKeyTable) ft := strings.ToLower(c.FKeyTable)
fc, ok := schema.ColIDMap[c.FKeyColID[0]] fc, ok := colByID[c.FKeyColID[0]]
if !ok { if !ok {
continue continue
} }
@ -109,12 +109,12 @@ func updateSchema(schema *DBSchema, t *DBTable, cols []*DBColumn) {
// Belongs-to relation between current table and the // Belongs-to relation between current table and the
// table in the foreign key // table in the foreign key
rel1 := &DBRel{RelBelongTo, "", "", c.Name, fc.Name} rel1 := &DBRel{RelBelongTo, "", "", c.Name, fc.Name}
schema.RelMap[TTKey{ct, ft}] = rel1 s.RelMap[TTKey{ct, ft}] = rel1
// One-to-many relation between the foreign key table and the // One-to-many relation between the foreign key table and the
// the current table // the current table
rel2 := &DBRel{RelOneToMany, "", "", fc.Name, c.Name} rel2 := &DBRel{RelOneToMany, "", "", fc.Name, c.Name}
schema.RelMap[TTKey{ft, ct}] = rel2 s.RelMap[TTKey{ft, ct}] = rel2
jcols = append(jcols, c) jcols = append(jcols, c)
} }
@ -130,22 +130,24 @@ func updateSchema(schema *DBSchema, t *DBTable, cols []*DBColumn) {
for i := range jcols { for i := range jcols {
for n := range jcols { for n := range jcols {
if n != i { if n != i {
updateSchemaOTMT(schema, ct, jcols[i], jcols[n]) s.updateSchemaOTMT(ct, jcols[i], jcols[n], colByID)
} }
} }
} }
} }
} }
func updateSchemaOTMT(schema *DBSchema, ct string, col1, col2 *DBColumn) { func (s *DBSchema) updateSchemaOTMT(
ct string, col1, col2 *DBColumn, colByID map[int]*DBColumn) {
t1 := strings.ToLower(col1.FKeyTable) t1 := strings.ToLower(col1.FKeyTable)
t2 := strings.ToLower(col2.FKeyTable) t2 := strings.ToLower(col2.FKeyTable)
fc1, ok := schema.ColIDMap[col1.FKeyColID[0]] fc1, ok := colByID[col1.FKeyColID[0]]
if !ok { if !ok {
return return
} }
fc2, ok := schema.ColIDMap[col2.FKeyColID[0]] fc2, ok := colByID[col2.FKeyColID[0]]
if !ok { if !ok {
return return
} }
@ -154,13 +156,13 @@ func updateSchemaOTMT(schema *DBSchema, ct string, col1, col2 *DBColumn) {
// 2nd foreign key table // 2nd foreign key table
//rel1 := &DBRel{RelOneToManyThrough, ct, fc1.Name, col1.Name} //rel1 := &DBRel{RelOneToManyThrough, ct, fc1.Name, col1.Name}
rel1 := &DBRel{RelOneToManyThrough, ct, col2.Name, fc2.Name, col1.Name} rel1 := &DBRel{RelOneToManyThrough, ct, col2.Name, fc2.Name, col1.Name}
schema.RelMap[TTKey{t1, t2}] = rel1 s.RelMap[TTKey{t1, t2}] = rel1
// One-to-many-through relation between 2nd foreign key table and the // One-to-many-through relation between 2nd foreign key table and the
// 1nd foreign key table // 1nd foreign key table
//rel2 := &DBRel{RelOneToManyThrough, ct, col2.Name, fc2.Name} //rel2 := &DBRel{RelOneToManyThrough, ct, col2.Name, fc2.Name}
rel2 := &DBRel{RelOneToManyThrough, ct, col1.Name, fc1.Name, col2.Name} rel2 := &DBRel{RelOneToManyThrough, ct, col1.Name, fc1.Name, col2.Name}
schema.RelMap[TTKey{t2, t1}] = rel2 s.RelMap[TTKey{t2, t1}] = rel2
} }
type DBTable struct { type DBTable struct {
@ -242,13 +244,13 @@ WHERE c.relkind = 'r'::char
stmt, err := db.Prepare(sqlStmt) stmt, err := db.Prepare(sqlStmt)
if err != nil { if err != nil {
return nil, fmt.Errorf("Error fetching columns: %s", err) return nil, fmt.Errorf("error fetching columns: %s", err)
} }
var t []*DBColumn var t []*DBColumn
_, err = stmt.Query(&t, schema, table) _, err = stmt.Query(&t, schema, table)
if err != nil { if err != nil {
return nil, fmt.Errorf("Error fetching columns: %s", err) return nil, fmt.Errorf("error fetching columns: %s", err)
} }
return t, nil return t, nil
@ -257,7 +259,7 @@ WHERE c.relkind = 'r'::char
func (s *DBSchema) GetTable(table string) (*DBTableInfo, error) { func (s *DBSchema) GetTable(table string) (*DBTableInfo, error) {
t, ok := s.Tables[table] t, ok := s.Tables[table]
if !ok { if !ok {
return nil, fmt.Errorf("table info not found '%s'", table) return nil, fmt.Errorf("unknown table '%s'", table)
} }
return t, nil return t, nil
} }

View File

@ -61,7 +61,8 @@ type Paging struct {
type ExpOp int type ExpOp int
const ( const (
OpAnd ExpOp = iota + 1 OpNop ExpOp = iota
OpAnd
OpOr OpOr
OpNot OpNot
OpEquals OpEquals
@ -92,6 +93,8 @@ func (t ExpOp) String() string {
var v string var v string
switch t { switch t {
case OpNop:
v = "op-nop"
case OpAnd: case OpAnd:
v = "op-and" v = "op-and"
case OpOr: case OpOr:
@ -333,7 +336,12 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
} }
} }
if fil, ok := com.fm[selRoot.Table]; ok { fil, ok := com.fm[selRoot.Table]
if !ok {
fil = com.fl
}
if fil != nil && fil.Op != OpNop {
if selRoot.Where != nil { if selRoot.Where != nil {
selRoot.Where = &Exp{Op: OpAnd, Children: []*Exp{fil, selRoot.Where}} selRoot.Where = &Exp{Op: OpAnd, Children: []*Exp{fil, selRoot.Where}}
} else { } else {
@ -788,6 +796,10 @@ func compileFilter(filter []string) (*Exp, error) {
var fl *Exp var fl *Exp
com := &Compiler{} com := &Compiler{}
if len(filter) == 0 {
return &Exp{Op: OpNop}, nil
}
for i := range filter { for i := range filter {
node, err := ParseArgValue(filter[i]) node, err := ParseArgValue(filter[i])
if err != nil { if err != nil {

View File

@ -189,9 +189,21 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
cdb := c.DB cdb := c.DB
fm := make(map[string][]string, len(cdb.Fields)) fm := make(map[string][]string, len(cdb.Fields))
tmap := make(map[string]string, len(cdb.Fields))
for i := range cdb.Fields { for i := range cdb.Fields {
f := cdb.Fields[i] f := cdb.Fields[i]
fm[strings.ToLower(f.Name)] = f.Filter name := flect.Pluralize(strings.ToLower(f.Name))
if len(f.Filter) != 0 {
if f.Filter[0] == "none" {
fm[name] = []string{}
} else {
fm[name] = f.Filter
}
}
if len(f.Table) != 0 {
tmap[name] = f.Table
}
} }
qc, err := qcode.NewCompiler(qcode.Config{ qc, err := qcode.NewCompiler(qcode.Config{
@ -209,8 +221,9 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
} }
pc := psql.NewCompiler(psql.Config{ pc := psql.NewCompiler(psql.Config{
Schema: schema, Schema: schema,
Vars: cdb.Variables, Vars: cdb.Variables,
TableMap: tmap,
}) })
return qc, pc, nil return qc, pc, nil