Redesign config file architecture

This commit is contained in:
Vikram Rangnekar 2019-04-08 21:24:29 -04:00
parent e3660473cc
commit 2d02f2afda
6 changed files with 198 additions and 98 deletions

24
dev.yml
View File

@ -64,7 +64,7 @@ database:
# Define defaults to for the field key and values below
defaults:
filter: ["{ id: { _eq: $user_id } }"]
filter: ["{ user_id: { eq: $user_id } }"]
# Fields and table names that you wish to block
blacklist:
@ -77,12 +77,20 @@ database:
fields:
- name: users
filter: ["{ id: { _eq: $user_id } }"]
filter: ["{ id: { eq: $user_id } }"]
- name: products
filter: [
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }"
]
- name: customers
filter: none
- name: me
table: users
filter: ["{ id: { eq: $user_id } }"]
# - name: posts
# filter: ["{ account_id: { _eq: $account_id } }"]
- name: my_products
table: products
filter: ["{ id: { _eq: $user_id } }"]
# filter: ["{ account_id: { _eq: $account_id } }"]

View File

@ -11,22 +11,27 @@ import (
)
type Config struct {
Schema *DBSchema
Vars map[string]string
Schema *DBSchema
Vars map[string]string
TableMap map[string]string
}
type Compiler struct {
schema *DBSchema
vars map[string]string
tmap map[string]string
}
func NewCompiler(conf Config) *Compiler {
return &Compiler{conf.Schema, conf.Vars}
return &Compiler{conf.Schema, conf.Vars, conf.TableMap}
}
func (c *Compiler) Compile(w io.Writer, qc *qcode.QCode) error {
st := util.NewStack()
ti, _ := c.schema.GetTable(qc.Query.Select.Table)
ti, err := c.getTable(qc.Query.Select)
if err != nil {
return err
}
st.Push(&selectBlockClose{nil, qc.Query.Select})
st.Push(&selectBlock{nil, qc.Query.Select, ti, c})
@ -47,12 +52,13 @@ func (c *Compiler) Compile(w io.Writer, qc *qcode.QCode) error {
v.render(w, c.schema, childCols, childIDs)
for i := range childIDs {
ti, err := c.schema.GetTable(v.sel.Table)
if err != nil {
continue
}
sub := v.sel.Joins[childIDs[i]]
ti, err := c.getTable(sub)
if err != nil {
return err
}
st.Push(&joinClose{sub})
st.Push(&selectBlockClose{v.sel, sub})
st.Push(&selectBlock{v.sel, sub, ti, c})
@ -75,6 +81,13 @@ func (c *Compiler) Compile(w io.Writer, qc *qcode.QCode) error {
return nil
}
func (c *Compiler) getTable(sel *qcode.Select) (*DBTableInfo, error) {
if tn, ok := c.tmap[sel.Table]; ok {
return c.schema.GetTable(tn)
}
return c.schema.GetTable(sel.Table)
}
func (c *Compiler) relationshipColumns(parent *qcode.Select) (
cols []*qcode.Column, childIDs []int) {
@ -275,33 +288,34 @@ func (v *selectBlock) renderBaseSelect(w io.Writer, schema *DBSchema, childCols
isRoot := v.parent == nil
isFil := v.sel.Where != nil
isSearch := v.sel.Args["search"] != nil
isAgg := false
_, isSearch := v.sel.Args["search"]
io.WriteString(w, " FROM (SELECT ")
for i, col := range v.sel.Cols {
cn := col.Name
_, isRealCol := v.schema.ColMap[TCKey{v.sel.Table, cn}]
_, isRealCol := v.ti.Columns[cn]
if !isRealCol {
switch {
case isSearch && cn == "search_rank":
cn = v.ti.TSVCol
arg := v.sel.Args["search"]
if isSearch {
switch {
case cn == "search_rank":
cn = v.ti.TSVCol
arg := v.sel.Args["search"]
fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`,
v.sel.Table, cn, arg.Val, col.Name)
fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`,
v.sel.Table, cn, arg.Val, col.Name)
case isSearch && strings.HasPrefix(cn, "search_headline_"):
cn = cn[16:]
arg := v.sel.Args["search"]
case strings.HasPrefix(cn, "search_headline_"):
cn = cn[16:]
arg := v.sel.Args["search"]
fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`,
v.sel.Table, cn, arg.Val, col.Name)
default:
fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`,
v.sel.Table, cn, arg.Val, col.Name)
}
} else {
pl := funcPrefixLen(cn)
if pl == 0 {
fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name)
@ -330,7 +344,11 @@ func (v *selectBlock) renderBaseSelect(w io.Writer, schema *DBSchema, childCols
}
}
fmt.Fprintf(w, ` FROM "%s"`, v.sel.Table)
if tn, ok := v.tmap[v.sel.Table]; ok {
fmt.Fprintf(w, ` FROM "%s" AS "%s"`, tn, v.sel.Table)
} else {
fmt.Fprintf(w, ` FROM "%s"`, v.sel.Table)
}
if isRoot && isFil {
io.WriteString(w, ` WHERE (`)

View File

@ -1,6 +1,7 @@
package psql
import (
"log"
"os"
"strings"
"testing"
@ -18,18 +19,35 @@ var (
)
func TestMain(m *testing.M) {
fm := qcode.NewFilterMap(map[string]string{
"users": "{ id: { _eq: $user_id } }",
"posts": "{ account_id: { _eq: $account_id } }",
var err error
qcompile, err = qcode.NewCompiler(qcode.Config{
Filter: []string{
`{ user_id: { _eq: $user_id } }`,
},
FilterMap: map[string][]string{
"users": []string{
"{ id: { eq: $user_id } }",
},
"products": []string{
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }",
},
"customers": []string{},
"mes": []string{
"{ id: { eq: $user_id } }",
},
},
Blacklist: []string{
"secret",
"password",
"token",
},
})
bl := qcode.NewBlacklist([]string{
"secret",
"password",
"token",
})
qcompile = qcode.NewCompiler(fm, bl)
if err != nil {
log.Fatal(err)
}
tables := []*DBTable{
&DBTable{Name: "customers", Type: "table"},
@ -81,17 +99,26 @@ func TestMain(m *testing.M) {
&DBColumn{ID: 7, Name: "returned", Type: "timestamp without time zone", NotNull: false, PrimaryKey: false, Uniquekey: false, FKeyTable: "", FKeyColID: []int(nil)}},
}
schema := initSchema()
schema := &DBSchema{
Tables: make(map[string]*DBTableInfo),
RelMap: make(map[TTKey]*DBRel),
}
for i, t := range tables {
updateSchema(schema, t, columns[i])
schema.updateSchema(t, columns[i])
}
vars := NewVariables(map[string]string{
"account_id": "select account_id from users where id = $user_id",
})
pcompile = NewCompiler(schema, vars)
pcompile = NewCompiler(Config{
Schema: schema,
Vars: vars,
TableMap: map[string]string{
"mes": "users",
},
})
os.Exit(m.Run())
}
@ -134,7 +161,7 @@ func withComplexArgs(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products" ORDER BY "products_0.ob.price" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0.ob.price") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products", "products_0"."price" AS "products_0.ob.price" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."id") < (28)) AND (("products"."id") >= (20))) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0.ob.price" DESC LIMIT ('30') :: integer) AS "products_0") AS "done_1337";`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products" ORDER BY "products_0.ob.price" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0.ob.price") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products", "products_0"."price" AS "products_0.ob.price" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."id") < (28)) AND (("products"."id") >= (20))) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0.ob.price" DESC LIMIT ('30') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
if err != nil {
@ -162,7 +189,7 @@ func withWhereMultiOr(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") < (20)) OR (("products"."price") > (10)) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") < (20)) OR (("products"."price") > (10)) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
if err != nil {
@ -188,7 +215,7 @@ func withWhereIsNull(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
if err != nil {
@ -214,7 +241,7 @@ func withWhereAndList(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
if err != nil {
@ -260,7 +287,7 @@ func belongsTo(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1.join"."users" AS "users") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("users"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "users" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "users_1") AS "users_1.join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1.join"."users" AS "users") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("users"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "users" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "users_1") AS "users_1.join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
if err != nil {
@ -283,7 +310,7 @@ func manyToMany(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1.join"."customers" AS "customers") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."id" FROM "products" LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("customers"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "customers" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "customers_1") AS "customers_1.join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1.join"."customers" AS "customers") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("customers"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "customers" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "customers_1") AS "customers_1.join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
if err != nil {
@ -320,13 +347,13 @@ func manyToManyReverse(t *testing.T) {
func fetchByID(t *testing.T) {
gql := `query {
product(id: 4) {
product(id: 15) {
id
name
}
}`
sql := `SELECT json_object_agg('product', products) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("id") = ('4'))) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";`
sql := `SELECT json_object_agg('product', products) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("id") = ('15'))) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
if err != nil {
@ -340,13 +367,13 @@ func fetchByID(t *testing.T) {
func searchQuery(t *testing.T) {
gql := `query {
products(search: "Amazing") {
products(search: "Imperial") {
id
name
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("tsv") @@ to_tsquery('Amazing'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
if err != nil {
@ -366,7 +393,7 @@ func aggFunction(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "products" FROM (SELECT "products"."name", count("products"."price") AS count_price FROM "products" GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "products" FROM (SELECT "products"."name", count("products"."price") AS count_price FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
if err != nil {
@ -386,7 +413,26 @@ func aggFunctionWithFilter(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", max("products"."price") AS max_price FROM "products" WHERE ((("products"."id") > (10))) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", max("products"."price") AS max_price FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."id") > (10))) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
if err != nil {
t.Fatal(err)
}
if resSQL != sql {
t.Fatal(errNotExpected)
}
}
func syntheticTables(t *testing.T) {
gql := `query {
me {
email
}
}`
sql := `SELECT json_object_agg('me', mes) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "mes_0"."email" AS "email") AS "sel_0")) AS "mes" FROM (SELECT "mes"."email" FROM "users" AS "mes" WHERE ((("mes"."id") = ('{{user_id}}'))) LIMIT ('1') :: integer) AS "mes_0" LIMIT ('1') :: integer) AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
if err != nil {
@ -411,6 +457,7 @@ func TestCompileGQL(t *testing.T) {
t.Run("manyToManyReverse", manyToManyReverse)
t.Run("aggFunction", aggFunction)
t.Run("aggFunctionWithFilter", aggFunctionWithFilter)
t.Run("syntheticTables", syntheticTables)
}
func BenchmarkCompileGQLToSQL(b *testing.B) {

View File

@ -16,16 +16,15 @@ type TTKey struct {
}
type DBSchema struct {
ColMap map[TCKey]*DBColumn
ColIDMap map[int]*DBColumn
Tables map[string]*DBTableInfo
Tables map[string]*DBTableInfo
RelMap map[TTKey]*DBRel
}
type DBTableInfo struct {
Name string
PrimaryCol string
TSVCol string
Columns map[string]*DBColumn
}
type RelType int
@ -45,7 +44,10 @@ type DBRel struct {
}
func NewDBSchema(db *pg.DB) (*DBSchema, error) {
schema := initSchema()
schema := &DBSchema{
Tables: make(map[string]*DBTableInfo),
RelMap: make(map[TTKey]*DBRel),
}
tables, err := GetTables(db)
if err != nil {
@ -58,41 +60,39 @@ func NewDBSchema(db *pg.DB) (*DBSchema, error) {
return nil, err
}
updateSchema(schema, t, cols)
schema.updateSchema(t, cols)
}
return schema, nil
}
func initSchema() *DBSchema {
return &DBSchema{
ColMap: make(map[TCKey]*DBColumn),
ColIDMap: make(map[int]*DBColumn),
Tables: make(map[string]*DBTableInfo),
RelMap: make(map[TTKey]*DBRel),
}
}
func updateSchema(schema *DBSchema, t *DBTable, cols []*DBColumn) {
func (s *DBSchema) updateSchema(t *DBTable, cols []*DBColumn) {
// Current table
ct := strings.ToLower(t.Name)
schema.Tables[ct] = &DBTableInfo{}
ti := &DBTableInfo{
Name: t.Name,
Columns: make(map[string]*DBColumn, len(cols)),
}
// Foreign key columns in current table
var jcols []*DBColumn
colByID := make(map[int]*DBColumn)
for _, c := range cols {
schema.ColMap[TCKey{ct, strings.ToLower(c.Name)}] = c
schema.ColIDMap[c.ID] = c
for i := range cols {
c := cols[i]
ti.Columns[strings.ToLower(c.Name)] = cols[i]
colByID[c.ID] = cols[i]
}
ct := strings.ToLower(t.Name)
s.Tables[ct] = ti
for _, c := range cols {
switch {
case c.Type == "tsvector":
schema.Tables[ct].TSVCol = c.Name
s.Tables[ct].TSVCol = c.Name
case c.PrimaryKey:
schema.Tables[ct].PrimaryCol = c.Name
s.Tables[ct].PrimaryCol = c.Name
case len(c.FKeyTable) != 0:
if len(c.FKeyColID) == 0 {
@ -101,7 +101,7 @@ func updateSchema(schema *DBSchema, t *DBTable, cols []*DBColumn) {
// Foreign key column name
ft := strings.ToLower(c.FKeyTable)
fc, ok := schema.ColIDMap[c.FKeyColID[0]]
fc, ok := colByID[c.FKeyColID[0]]
if !ok {
continue
}
@ -109,12 +109,12 @@ func updateSchema(schema *DBSchema, t *DBTable, cols []*DBColumn) {
// Belongs-to relation between current table and the
// table in the foreign key
rel1 := &DBRel{RelBelongTo, "", "", c.Name, fc.Name}
schema.RelMap[TTKey{ct, ft}] = rel1
s.RelMap[TTKey{ct, ft}] = rel1
// One-to-many relation between the foreign key table and the
// the current table
rel2 := &DBRel{RelOneToMany, "", "", fc.Name, c.Name}
schema.RelMap[TTKey{ft, ct}] = rel2
s.RelMap[TTKey{ft, ct}] = rel2
jcols = append(jcols, c)
}
@ -130,22 +130,24 @@ func updateSchema(schema *DBSchema, t *DBTable, cols []*DBColumn) {
for i := range jcols {
for n := range jcols {
if n != i {
updateSchemaOTMT(schema, ct, jcols[i], jcols[n])
s.updateSchemaOTMT(ct, jcols[i], jcols[n], colByID)
}
}
}
}
}
func updateSchemaOTMT(schema *DBSchema, ct string, col1, col2 *DBColumn) {
func (s *DBSchema) updateSchemaOTMT(
ct string, col1, col2 *DBColumn, colByID map[int]*DBColumn) {
t1 := strings.ToLower(col1.FKeyTable)
t2 := strings.ToLower(col2.FKeyTable)
fc1, ok := schema.ColIDMap[col1.FKeyColID[0]]
fc1, ok := colByID[col1.FKeyColID[0]]
if !ok {
return
}
fc2, ok := schema.ColIDMap[col2.FKeyColID[0]]
fc2, ok := colByID[col2.FKeyColID[0]]
if !ok {
return
}
@ -154,13 +156,13 @@ func updateSchemaOTMT(schema *DBSchema, ct string, col1, col2 *DBColumn) {
// 2nd foreign key table
//rel1 := &DBRel{RelOneToManyThrough, ct, fc1.Name, col1.Name}
rel1 := &DBRel{RelOneToManyThrough, ct, col2.Name, fc2.Name, col1.Name}
schema.RelMap[TTKey{t1, t2}] = rel1
s.RelMap[TTKey{t1, t2}] = rel1
// One-to-many-through relation between 2nd foreign key table and the
// 1nd foreign key table
//rel2 := &DBRel{RelOneToManyThrough, ct, col2.Name, fc2.Name}
rel2 := &DBRel{RelOneToManyThrough, ct, col1.Name, fc1.Name, col2.Name}
schema.RelMap[TTKey{t2, t1}] = rel2
s.RelMap[TTKey{t2, t1}] = rel2
}
type DBTable struct {
@ -242,13 +244,13 @@ WHERE c.relkind = 'r'::char
stmt, err := db.Prepare(sqlStmt)
if err != nil {
return nil, fmt.Errorf("Error fetching columns: %s", err)
return nil, fmt.Errorf("error fetching columns: %s", err)
}
var t []*DBColumn
_, err = stmt.Query(&t, schema, table)
if err != nil {
return nil, fmt.Errorf("Error fetching columns: %s", err)
return nil, fmt.Errorf("error fetching columns: %s", err)
}
return t, nil
@ -257,7 +259,7 @@ WHERE c.relkind = 'r'::char
func (s *DBSchema) GetTable(table string) (*DBTableInfo, error) {
t, ok := s.Tables[table]
if !ok {
return nil, fmt.Errorf("table info not found '%s'", table)
return nil, fmt.Errorf("unknown table '%s'", table)
}
return t, nil
}

View File

@ -61,7 +61,8 @@ type Paging struct {
type ExpOp int
const (
OpAnd ExpOp = iota + 1
OpNop ExpOp = iota
OpAnd
OpOr
OpNot
OpEquals
@ -92,6 +93,8 @@ func (t ExpOp) String() string {
var v string
switch t {
case OpNop:
v = "op-nop"
case OpAnd:
v = "op-and"
case OpOr:
@ -333,7 +336,12 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
}
}
if fil, ok := com.fm[selRoot.Table]; ok {
fil, ok := com.fm[selRoot.Table]
if !ok {
fil = com.fl
}
if fil != nil && fil.Op != OpNop {
if selRoot.Where != nil {
selRoot.Where = &Exp{Op: OpAnd, Children: []*Exp{fil, selRoot.Where}}
} else {
@ -788,6 +796,10 @@ func compileFilter(filter []string) (*Exp, error) {
var fl *Exp
com := &Compiler{}
if len(filter) == 0 {
return &Exp{Op: OpNop}, nil
}
for i := range filter {
node, err := ParseArgValue(filter[i])
if err != nil {

View File

@ -189,9 +189,21 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
cdb := c.DB
fm := make(map[string][]string, len(cdb.Fields))
tmap := make(map[string]string, len(cdb.Fields))
for i := range cdb.Fields {
f := cdb.Fields[i]
fm[strings.ToLower(f.Name)] = f.Filter
name := flect.Pluralize(strings.ToLower(f.Name))
if len(f.Filter) != 0 {
if f.Filter[0] == "none" {
fm[name] = []string{}
} else {
fm[name] = f.Filter
}
}
if len(f.Table) != 0 {
tmap[name] = f.Table
}
}
qc, err := qcode.NewCompiler(qcode.Config{
@ -209,8 +221,9 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
}
pc := psql.NewCompiler(psql.Config{
Schema: schema,
Vars: cdb.Variables,
Schema: schema,
Vars: cdb.Variables,
TableMap: tmap,
})
return qc, pc, nil