Add role based access control

This commit is contained in:
Vikram Rangnekar
2019-10-14 02:51:36 -04:00
parent 85a74ed30c
commit deb5b93c81
13 changed files with 645 additions and 350 deletions

View File

@ -10,7 +10,7 @@ import (
"github.com/dosco/super-graph/qcode"
)
var zeroPaging = qcode.Paging{}
var noLimit = qcode.Paging{NoLimit: true}
func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) {
if len(qc.Selects) == 0 {
@ -29,23 +29,23 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia
quoted(c.w, ti.Name)
c.w.WriteString(` AS `)
switch root.Action {
case qcode.ActionInsert:
switch qc.Type {
case qcode.QTInsert:
if _, err := c.renderInsert(qc, w, vars, ti); err != nil {
return 0, err
}
case qcode.ActionUpdate:
case qcode.QTUpdate:
if _, err := c.renderUpdate(qc, w, vars, ti); err != nil {
return 0, err
}
case qcode.ActionUpsert:
case qcode.QTUpsert:
if _, err := c.renderUpsert(qc, w, vars, ti); err != nil {
return 0, err
}
case qcode.ActionDelete:
case qcode.QTDelete:
if _, err := c.renderDelete(qc, w, vars, ti); err != nil {
return 0, err
}
@ -56,22 +56,23 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia
io.WriteString(c.w, ` RETURNING *) `)
root.Paging = zeroPaging
root.Paging = noLimit
root.DistinctOn = root.DistinctOn[:]
root.OrderBy = root.OrderBy[:]
root.Where = nil
root.Args = nil
qc.Type = qcode.QTQuery
return c.compileQuery(qc, w)
}
func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer,
vars Variables, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0]
insert, ok := vars[root.ActionVar]
insert, ok := vars[qc.ActionVar]
if !ok {
return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar)
return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar)
}
jt, array, err := jsn.Tree(insert)
@ -80,7 +81,7 @@ func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer,
}
c.w.WriteString(`(WITH "input" AS (SELECT {{`)
c.w.WriteString(root.ActionVar)
c.w.WriteString(qc.ActionVar)
c.w.WriteString(`}}::json AS j) INSERT INTO `)
quoted(c.w, ti.Name)
io.WriteString(c.w, ` (`)
@ -106,12 +107,18 @@ func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer,
func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w *bytes.Buffer,
jt map[string]interface{}, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0]
i := 0
for _, cn := range ti.ColumnNames {
if _, ok := jt[cn]; !ok {
continue
}
if len(root.Allowed) != 0 {
if _, ok := root.Allowed[cn]; !ok {
continue
}
}
if i != 0 {
io.WriteString(c.w, `, `)
}
@ -126,9 +133,9 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer,
vars Variables, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0]
update, ok := vars[root.ActionVar]
update, ok := vars[qc.ActionVar]
if !ok {
return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar)
return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar)
}
jt, array, err := jsn.Tree(update)
@ -137,7 +144,7 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer,
}
c.w.WriteString(`(WITH "input" AS (SELECT {{`)
c.w.WriteString(root.ActionVar)
c.w.WriteString(qc.ActionVar)
c.w.WriteString(`}}::json AS j) UPDATE `)
quoted(c.w, ti.Name)
io.WriteString(c.w, ` SET (`)
@ -183,11 +190,10 @@ func (c *compilerContext) renderDelete(qc *qcode.QCode, w *bytes.Buffer,
func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer,
vars Variables, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0]
upsert, ok := vars[root.ActionVar]
upsert, ok := vars[qc.ActionVar]
if !ok {
return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar)
return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar)
}
jt, _, err := jsn.Tree(upsert)

View File

@ -18,7 +18,7 @@ func simpleInsert(t *testing.T) {
"data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`),
}
resSQL, err := compileGQLToPSQL(gql, vars)
resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil {
t.Fatal(err)
}
@ -42,7 +42,7 @@ func singleInsert(t *testing.T) {
"insert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc", "user_id": 5 }`),
}
resSQL, err := compileGQLToPSQL(gql, vars)
resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil {
t.Fatal(err)
}
@ -54,7 +54,7 @@ func singleInsert(t *testing.T) {
func bulkInsert(t *testing.T) {
gql := `mutation {
product(id: 15, insert: $insert) {
product(name: "test", id: 15, insert: $insert) {
id
name
}
@ -66,7 +66,7 @@ func bulkInsert(t *testing.T) {
"insert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
}
resSQL, err := compileGQLToPSQL(gql, vars)
resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil {
t.Fatal(err)
}
@ -90,7 +90,7 @@ func singleUpsert(t *testing.T) {
"upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
}
resSQL, err := compileGQLToPSQL(gql, vars)
resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil {
t.Fatal(err)
}
@ -114,7 +114,7 @@ func bulkUpsert(t *testing.T) {
"upsert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
}
resSQL, err := compileGQLToPSQL(gql, vars)
resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil {
t.Fatal(err)
}
@ -138,7 +138,7 @@ func singleUpdate(t *testing.T) {
"update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
}
resSQL, err := compileGQLToPSQL(gql, vars)
resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil {
t.Fatal(err)
}
@ -162,7 +162,7 @@ func delete(t *testing.T) {
"update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
}
resSQL, err := compileGQLToPSQL(gql, vars)
resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil {
t.Fatal(err)
}
@ -172,7 +172,7 @@ func delete(t *testing.T) {
}
}
func TestCompileInsert(t *testing.T) {
func TestCompileMutate(t *testing.T) {
t.Run("simpleInsert", simpleInsert)
t.Run("singleInsert", singleInsert)
t.Run("bulkInsert", bulkInsert)

View File

@ -64,11 +64,11 @@ func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (u
switch qc.Type {
case qcode.QTQuery:
return co.compileQuery(qc, w)
case qcode.QTMutation:
case qcode.QTInsert, qcode.QTUpdate, qcode.QTDelete, qcode.QTUpsert:
return co.compileMutation(qc, w, vars)
}
return 0, errors.New("unknown operation")
return 0, fmt.Errorf("Unknown operation type %d", qc.Type)
}
func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) {
@ -295,19 +295,21 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo)
}
}
if sel.Action == 0 {
if len(sel.Paging.Limit) != 0 {
//fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit)
c.w.WriteString(` LIMIT ('`)
c.w.WriteString(sel.Paging.Limit)
c.w.WriteString(`') :: integer`)
switch {
case sel.Paging.NoLimit:
break
} else if ti.Singular {
c.w.WriteString(` LIMIT ('1') :: integer`)
case len(sel.Paging.Limit) != 0:
//fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit)
c.w.WriteString(` LIMIT ('`)
c.w.WriteString(sel.Paging.Limit)
c.w.WriteString(`') :: integer`)
} else {
c.w.WriteString(` LIMIT ('20') :: integer`)
}
case ti.Singular:
c.w.WriteString(` LIMIT ('1') :: integer`)
default:
c.w.WriteString(` LIMIT ('20') :: integer`)
}
if len(sel.Paging.Offset) != 0 {
@ -370,13 +372,31 @@ func (c *compilerContext) renderJoinTable(sel *qcode.Select) error {
}
func (c *compilerContext) renderColumns(sel *qcode.Select, ti *DBTableInfo) {
for i, col := range sel.Cols {
i := 0
for _, col := range sel.Cols {
if len(sel.Allowed) != 0 {
n := funcPrefixLen(col.Name)
if n != 0 {
if sel.Functions == false {
continue
}
if _, ok := sel.Allowed[col.Name[n:]]; !ok {
continue
}
} else {
if _, ok := sel.Allowed[col.Name]; !ok {
continue
}
}
}
if i != 0 {
io.WriteString(c.w, ", ")
}
//fmt.Fprintf(w, `"%s_%d"."%s" AS "%s"`,
//c.sel.Table, c.sel.ID, col.Name, col.FieldName)
colWithTableIDAlias(c.w, ti.Name, sel.ID, col.Name, col.FieldName)
i++
}
}
@ -435,7 +455,8 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
c.w.WriteString(` FROM (SELECT `)
for i, col := range sel.Cols {
i := 0
for n, col := range sel.Cols {
cn := col.Name
_, isRealCol := ti.Columns[cn]
@ -447,6 +468,9 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
cn = ti.TSVCol
arg := sel.Args["search"]
if i != 0 {
c.w.WriteString(`, `)
}
//fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`,
//c.sel.Table, cn, arg.Val, col.Name)
c.w.WriteString(`ts_rank(`)
@ -455,11 +479,15 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
c.w.WriteString(arg.Val)
c.w.WriteString(`')`)
alias(c.w, col.Name)
i++
case strings.HasPrefix(cn, "search_headline_"):
cn = cn[16:]
arg := sel.Args["search"]
if i != 0 {
c.w.WriteString(`, `)
}
//fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`,
//c.sel.Table, cn, arg.Val, col.Name)
c.w.WriteString(`ts_headlinek(`)
@ -468,47 +496,63 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
c.w.WriteString(arg.Val)
c.w.WriteString(`')`)
alias(c.w, col.Name)
i++
}
} else {
pl := funcPrefixLen(cn)
if pl == 0 {
if i != 0 {
c.w.WriteString(`, `)
}
//fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name)
c.w.WriteString(`'`)
c.w.WriteString(cn)
c.w.WriteString(` not defined'`)
alias(c.w, col.Name)
} else {
isAgg = true
i++
} else if sel.Functions {
cn1 := cn[pl:]
if _, ok := sel.Allowed[cn1]; !ok {
continue
}
if i != 0 {
c.w.WriteString(`, `)
}
fn := cn[0 : pl-1]
cn := cn[pl:]
isAgg = true
//fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, c.sel.Table, cn, col.Name)
c.w.WriteString(fn)
c.w.WriteString(`(`)
colWithTable(c.w, ti.Name, cn)
colWithTable(c.w, ti.Name, cn1)
c.w.WriteString(`)`)
alias(c.w, col.Name)
i++
}
}
} else {
groupBy = append(groupBy, i)
groupBy = append(groupBy, n)
//fmt.Fprintf(w, `"%s"."%s"`, c.sel.Table, cn)
if i != 0 {
c.w.WriteString(`, `)
}
colWithTable(c.w, ti.Name, cn)
}
i++
if i < len(sel.Cols)-1 || len(childCols) != 0 {
//io.WriteString(w, ", ")
c.w.WriteString(`, `)
}
}
for i, col := range childCols {
for _, col := range childCols {
if i != 0 {
//io.WriteString(w, ", ")
c.w.WriteString(`, `)
}
//fmt.Fprintf(w, `"%s"."%s"`, col.Table, col.Name)
colWithTable(c.w, col.Table, col.Name)
i++
}
c.w.WriteString(` FROM `)
@ -570,19 +614,21 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
}
}
if sel.Action == 0 {
if len(sel.Paging.Limit) != 0 {
//fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit)
c.w.WriteString(` LIMIT ('`)
c.w.WriteString(sel.Paging.Limit)
c.w.WriteString(`') :: integer`)
switch {
case sel.Paging.NoLimit:
break
} else if ti.Singular {
c.w.WriteString(` LIMIT ('1') :: integer`)
case len(sel.Paging.Limit) != 0:
//fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit)
c.w.WriteString(` LIMIT ('`)
c.w.WriteString(sel.Paging.Limit)
c.w.WriteString(`') :: integer`)
} else {
c.w.WriteString(` LIMIT ('20') :: integer`)
}
case ti.Singular:
c.w.WriteString(` LIMIT ('1') :: integer`)
default:
c.w.WriteString(` LIMIT ('20') :: integer`)
}
if len(sel.Paging.Offset) != 0 {

View File

@ -22,32 +22,6 @@ func TestMain(m *testing.M) {
var err error
qcompile, err = qcode.NewCompiler(qcode.Config{
DefaultFilter: []string{
`{ user_id: { _eq: $user_id } }`,
},
FilterMap: qcode.Filters{
All: map[string][]string{
"users": []string{
"{ id: { eq: $user_id } }",
},
"products": []string{
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }",
},
"customers": []string{},
"mes": []string{
"{ id: { eq: $user_id } }",
},
},
Query: map[string][]string{
"users": []string{},
},
Update: map[string][]string{
"products": []string{
"{ user_id: { eq: $user_id } }",
},
},
},
Blocklist: []string{
"secret",
"password",
@ -55,6 +29,59 @@ func TestMain(m *testing.M) {
},
})
qcompile.AddRole("user", "product", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "name", "price", "users", "customers"},
Filter: []string{
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }",
},
},
Update: qcode.UpdateConfig{
Filter: []string{"{ user_id: { eq: $user_id } }"},
},
Delete: qcode.DeleteConfig{
Filter: []string{
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }",
},
},
})
qcompile.AddRole("anon", "product", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "name"},
},
})
qcompile.AddRole("anon1", "product", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "name", "price"},
DisableFunctions: true,
},
})
qcompile.AddRole("user", "users", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "full_name", "avatar", "email", "products"},
},
})
qcompile.AddRole("user", "mes", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "full_name", "avatar"},
Filter: []string{
"{ id: { eq: $user_id } }",
},
},
})
qcompile.AddRole("user", "customers", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "email", "full_name", "products"},
},
})
if err != nil {
log.Fatal(err)
}
@ -135,9 +162,8 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}
func compileGQLToPSQL(gql string, vars Variables) ([]byte, error) {
qc, err := qcompile.Compile([]byte(gql))
func compileGQLToPSQL(gql string, vars Variables, role string) ([]byte, error) {
qc, err := qcompile.Compile([]byte(gql), role)
if err != nil {
return nil, err
}
@ -147,6 +173,8 @@ func compileGQLToPSQL(gql string, vars Variables) ([]byte, error) {
return nil, err
}
//fmt.Println(string(sqlStmt))
return sqlStmt, nil
}
@ -175,7 +203,7 @@ func withComplexArgs(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") < 28) AND (("products"."id") >= 20)) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -203,7 +231,7 @@ func withWhereMultiOr(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") < 20) OR (("products"."price") > 10) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -229,7 +257,7 @@ func withWhereIsNull(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -255,7 +283,7 @@ func withWhereAndList(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -275,7 +303,7 @@ func fetchByID(t *testing.T) {
sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 15)) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -295,7 +323,7 @@ func searchQuery(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -318,7 +346,7 @@ func oneToMany(t *testing.T) {
sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email", "users"."id" FROM "users" LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -341,7 +369,7 @@ func belongsTo(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "sel_json_1" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -364,7 +392,7 @@ func manyToMany(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -387,7 +415,7 @@ func manyToManyReverse(t *testing.T) {
sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -407,7 +435,47 @@ func aggFunction(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
if string(resSQL) != sql {
t.Fatal(errNotExpected)
}
}
func aggFunctionBlockedByCol(t *testing.T) {
gql := `query {
products {
name
count_price
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil, "anon")
if err != nil {
t.Fatal(err)
}
if string(resSQL) != sql {
t.Fatal(errNotExpected)
}
}
func aggFunctionDisabled(t *testing.T) {
gql := `query {
products {
name
count_price
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil, "anon1")
if err != nil {
t.Fatal(err)
}
@ -427,7 +495,7 @@ func aggFunctionWithFilter(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") > 10)) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -447,7 +515,7 @@ func queryWithVariables(t *testing.T) {
sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") = {{product_price}}) AND (("products"."id") = {{product_id}})) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -464,9 +532,9 @@ func syntheticTables(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337";`
sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT ) AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -476,7 +544,7 @@ func syntheticTables(t *testing.T) {
}
}
func TestCompileSelect(t *testing.T) {
func TestCompileQuery(t *testing.T) {
t.Run("withComplexArgs", withComplexArgs)
t.Run("withWhereAndList", withWhereAndList)
t.Run("withWhereIsNull", withWhereIsNull)
@ -488,10 +556,11 @@ func TestCompileSelect(t *testing.T) {
t.Run("manyToMany", manyToMany)
t.Run("manyToManyReverse", manyToManyReverse)
t.Run("aggFunction", aggFunction)
t.Run("aggFunctionBlockedByCol", aggFunctionBlockedByCol)
t.Run("aggFunctionDisabled", aggFunctionDisabled)
t.Run("aggFunctionWithFilter", aggFunctionWithFilter)
t.Run("syntheticTables", syntheticTables)
t.Run("queryWithVariables", queryWithVariables)
}
var benchGQL = []byte(`query {
@ -526,7 +595,7 @@ func BenchmarkCompile(b *testing.B) {
for n := 0; n < b.N; n++ {
w.Reset()
qc, err := qcompile.Compile(benchGQL)
qc, err := qcompile.Compile(benchGQL, "user")
if err != nil {
b.Fatal(err)
}
@ -547,7 +616,7 @@ func BenchmarkCompileParallel(b *testing.B) {
for pb.Next() {
w.Reset()
qc, err := qcompile.Compile(benchGQL)
qc, err := qcompile.Compile(benchGQL, "user")
if err != nil {
b.Fatal(err)
}