Add role based access control

This commit is contained in:
Vikram Rangnekar 2019-10-14 02:51:36 -04:00
parent 85a74ed30c
commit deb5b93c81
13 changed files with 645 additions and 350 deletions

View File

@ -115,9 +115,6 @@ tables:
- name: users - name: users
# This filter will overwrite defaults.filter # This filter will overwrite defaults.filter
# filter: ["{ id: { eq: $user_id } }"] # filter: ["{ id: { eq: $user_id } }"]
# filter_query: ["{ id: { eq: $user_id } }"]
filter_update: ["{ id: { eq: $user_id } }"]
filter_delete: ["{ id: { eq: $user_id } }"]
# - name: products # - name: products
# # Multiple filters are AND'd together # # Multiple filters are AND'd together
@ -127,10 +124,6 @@ tables:
# ] # ]
- name: customers - name: customers
# No filter is used for this field not
# even defaults.filter
filter: none
remotes: remotes:
- name: payments - name: payments
id: stripe_id id: stripe_id
@ -149,7 +142,56 @@ tables:
# real db table backing them # real db table backing them
name: me name: me
table: users table: users
filter: ["{ id: { eq: $user_id } }"]
# - name: posts roles:
# filter: ["{ account_id: { _eq: $account_id } }"] - name: anon
tables:
- name: products
limit: 10
query:
columns: ["id", "name", "description" ]
aggregation: false
insert:
allow: false
update:
allow: false
delete:
allow: false
- name: user
tables:
- name: products
query:
limit: 50
filter: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
disable_aggregation: false
insert:
filter: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
set:
- created_at: "now"
update:
filter: ["{ user_id: { eq: $user_id } }"]
columns:
- id
- name
set:
- updated_at: "now"
delete:
deny: true
- name: manager
tables:
- name: users
select:
filter: ["{ account_id: { _eq: $account_id } }"]

View File

@ -10,7 +10,7 @@ import (
"github.com/dosco/super-graph/qcode" "github.com/dosco/super-graph/qcode"
) )
var zeroPaging = qcode.Paging{} var noLimit = qcode.Paging{NoLimit: true}
func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) { func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) {
if len(qc.Selects) == 0 { if len(qc.Selects) == 0 {
@ -29,23 +29,23 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia
quoted(c.w, ti.Name) quoted(c.w, ti.Name)
c.w.WriteString(` AS `) c.w.WriteString(` AS `)
switch root.Action { switch qc.Type {
case qcode.ActionInsert: case qcode.QTInsert:
if _, err := c.renderInsert(qc, w, vars, ti); err != nil { if _, err := c.renderInsert(qc, w, vars, ti); err != nil {
return 0, err return 0, err
} }
case qcode.ActionUpdate: case qcode.QTUpdate:
if _, err := c.renderUpdate(qc, w, vars, ti); err != nil { if _, err := c.renderUpdate(qc, w, vars, ti); err != nil {
return 0, err return 0, err
} }
case qcode.ActionUpsert: case qcode.QTUpsert:
if _, err := c.renderUpsert(qc, w, vars, ti); err != nil { if _, err := c.renderUpsert(qc, w, vars, ti); err != nil {
return 0, err return 0, err
} }
case qcode.ActionDelete: case qcode.QTDelete:
if _, err := c.renderDelete(qc, w, vars, ti); err != nil { if _, err := c.renderDelete(qc, w, vars, ti); err != nil {
return 0, err return 0, err
} }
@ -56,22 +56,23 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia
io.WriteString(c.w, ` RETURNING *) `) io.WriteString(c.w, ` RETURNING *) `)
root.Paging = zeroPaging root.Paging = noLimit
root.DistinctOn = root.DistinctOn[:] root.DistinctOn = root.DistinctOn[:]
root.OrderBy = root.OrderBy[:] root.OrderBy = root.OrderBy[:]
root.Where = nil root.Where = nil
root.Args = nil root.Args = nil
qc.Type = qcode.QTQuery
return c.compileQuery(qc, w) return c.compileQuery(qc, w)
} }
func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer, func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer,
vars Variables, ti *DBTableInfo) (uint32, error) { vars Variables, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0]
insert, ok := vars[root.ActionVar] insert, ok := vars[qc.ActionVar]
if !ok { if !ok {
return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar) return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar)
} }
jt, array, err := jsn.Tree(insert) jt, array, err := jsn.Tree(insert)
@ -80,7 +81,7 @@ func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer,
} }
c.w.WriteString(`(WITH "input" AS (SELECT {{`) c.w.WriteString(`(WITH "input" AS (SELECT {{`)
c.w.WriteString(root.ActionVar) c.w.WriteString(qc.ActionVar)
c.w.WriteString(`}}::json AS j) INSERT INTO `) c.w.WriteString(`}}::json AS j) INSERT INTO `)
quoted(c.w, ti.Name) quoted(c.w, ti.Name)
io.WriteString(c.w, ` (`) io.WriteString(c.w, ` (`)
@ -106,12 +107,18 @@ func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer,
func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w *bytes.Buffer, func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w *bytes.Buffer,
jt map[string]interface{}, ti *DBTableInfo) (uint32, error) { jt map[string]interface{}, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0]
i := 0 i := 0
for _, cn := range ti.ColumnNames { for _, cn := range ti.ColumnNames {
if _, ok := jt[cn]; !ok { if _, ok := jt[cn]; !ok {
continue continue
} }
if len(root.Allowed) != 0 {
if _, ok := root.Allowed[cn]; !ok {
continue
}
}
if i != 0 { if i != 0 {
io.WriteString(c.w, `, `) io.WriteString(c.w, `, `)
} }
@ -126,9 +133,9 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer,
vars Variables, ti *DBTableInfo) (uint32, error) { vars Variables, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0] root := &qc.Selects[0]
update, ok := vars[root.ActionVar] update, ok := vars[qc.ActionVar]
if !ok { if !ok {
return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar) return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar)
} }
jt, array, err := jsn.Tree(update) jt, array, err := jsn.Tree(update)
@ -137,7 +144,7 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer,
} }
c.w.WriteString(`(WITH "input" AS (SELECT {{`) c.w.WriteString(`(WITH "input" AS (SELECT {{`)
c.w.WriteString(root.ActionVar) c.w.WriteString(qc.ActionVar)
c.w.WriteString(`}}::json AS j) UPDATE `) c.w.WriteString(`}}::json AS j) UPDATE `)
quoted(c.w, ti.Name) quoted(c.w, ti.Name)
io.WriteString(c.w, ` SET (`) io.WriteString(c.w, ` SET (`)
@ -183,11 +190,10 @@ func (c *compilerContext) renderDelete(qc *qcode.QCode, w *bytes.Buffer,
func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer, func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer,
vars Variables, ti *DBTableInfo) (uint32, error) { vars Variables, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0]
upsert, ok := vars[root.ActionVar] upsert, ok := vars[qc.ActionVar]
if !ok { if !ok {
return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar) return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar)
} }
jt, _, err := jsn.Tree(upsert) jt, _, err := jsn.Tree(upsert)

View File

@ -18,7 +18,7 @@ func simpleInsert(t *testing.T) {
"data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`), "data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`),
} }
resSQL, err := compileGQLToPSQL(gql, vars) resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -42,7 +42,7 @@ func singleInsert(t *testing.T) {
"insert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc", "user_id": 5 }`), "insert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc", "user_id": 5 }`),
} }
resSQL, err := compileGQLToPSQL(gql, vars) resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -54,7 +54,7 @@ func singleInsert(t *testing.T) {
func bulkInsert(t *testing.T) { func bulkInsert(t *testing.T) {
gql := `mutation { gql := `mutation {
product(id: 15, insert: $insert) { product(name: "test", id: 15, insert: $insert) {
id id
name name
} }
@ -66,7 +66,7 @@ func bulkInsert(t *testing.T) {
"insert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`), "insert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
} }
resSQL, err := compileGQLToPSQL(gql, vars) resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -90,7 +90,7 @@ func singleUpsert(t *testing.T) {
"upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), "upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
} }
resSQL, err := compileGQLToPSQL(gql, vars) resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -114,7 +114,7 @@ func bulkUpsert(t *testing.T) {
"upsert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`), "upsert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
} }
resSQL, err := compileGQLToPSQL(gql, vars) resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -138,7 +138,7 @@ func singleUpdate(t *testing.T) {
"update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), "update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
} }
resSQL, err := compileGQLToPSQL(gql, vars) resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -162,7 +162,7 @@ func delete(t *testing.T) {
"update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), "update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
} }
resSQL, err := compileGQLToPSQL(gql, vars) resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -172,7 +172,7 @@ func delete(t *testing.T) {
} }
} }
func TestCompileInsert(t *testing.T) { func TestCompileMutate(t *testing.T) {
t.Run("simpleInsert", simpleInsert) t.Run("simpleInsert", simpleInsert)
t.Run("singleInsert", singleInsert) t.Run("singleInsert", singleInsert)
t.Run("bulkInsert", bulkInsert) t.Run("bulkInsert", bulkInsert)

View File

@ -64,11 +64,11 @@ func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (u
switch qc.Type { switch qc.Type {
case qcode.QTQuery: case qcode.QTQuery:
return co.compileQuery(qc, w) return co.compileQuery(qc, w)
case qcode.QTMutation: case qcode.QTInsert, qcode.QTUpdate, qcode.QTDelete, qcode.QTUpsert:
return co.compileMutation(qc, w, vars) return co.compileMutation(qc, w, vars)
} }
return 0, errors.New("unknown operation") return 0, fmt.Errorf("Unknown operation type %d", qc.Type)
} }
func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) { func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) {
@ -295,19 +295,21 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo)
} }
} }
if sel.Action == 0 { switch {
if len(sel.Paging.Limit) != 0 { case sel.Paging.NoLimit:
//fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit) break
c.w.WriteString(` LIMIT ('`)
c.w.WriteString(sel.Paging.Limit)
c.w.WriteString(`') :: integer`)
} else if ti.Singular { case len(sel.Paging.Limit) != 0:
c.w.WriteString(` LIMIT ('1') :: integer`) //fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit)
c.w.WriteString(` LIMIT ('`)
c.w.WriteString(sel.Paging.Limit)
c.w.WriteString(`') :: integer`)
} else { case ti.Singular:
c.w.WriteString(` LIMIT ('20') :: integer`) c.w.WriteString(` LIMIT ('1') :: integer`)
}
default:
c.w.WriteString(` LIMIT ('20') :: integer`)
} }
if len(sel.Paging.Offset) != 0 { if len(sel.Paging.Offset) != 0 {
@ -370,13 +372,31 @@ func (c *compilerContext) renderJoinTable(sel *qcode.Select) error {
} }
func (c *compilerContext) renderColumns(sel *qcode.Select, ti *DBTableInfo) { func (c *compilerContext) renderColumns(sel *qcode.Select, ti *DBTableInfo) {
for i, col := range sel.Cols { i := 0
for _, col := range sel.Cols {
if len(sel.Allowed) != 0 {
n := funcPrefixLen(col.Name)
if n != 0 {
if sel.Functions == false {
continue
}
if _, ok := sel.Allowed[col.Name[n:]]; !ok {
continue
}
} else {
if _, ok := sel.Allowed[col.Name]; !ok {
continue
}
}
}
if i != 0 { if i != 0 {
io.WriteString(c.w, ", ") io.WriteString(c.w, ", ")
} }
//fmt.Fprintf(w, `"%s_%d"."%s" AS "%s"`, //fmt.Fprintf(w, `"%s_%d"."%s" AS "%s"`,
//c.sel.Table, c.sel.ID, col.Name, col.FieldName) //c.sel.Table, c.sel.ID, col.Name, col.FieldName)
colWithTableIDAlias(c.w, ti.Name, sel.ID, col.Name, col.FieldName) colWithTableIDAlias(c.w, ti.Name, sel.ID, col.Name, col.FieldName)
i++
} }
} }
@ -435,7 +455,8 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
c.w.WriteString(` FROM (SELECT `) c.w.WriteString(` FROM (SELECT `)
for i, col := range sel.Cols { i := 0
for n, col := range sel.Cols {
cn := col.Name cn := col.Name
_, isRealCol := ti.Columns[cn] _, isRealCol := ti.Columns[cn]
@ -447,6 +468,9 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
cn = ti.TSVCol cn = ti.TSVCol
arg := sel.Args["search"] arg := sel.Args["search"]
if i != 0 {
c.w.WriteString(`, `)
}
//fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`, //fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`,
//c.sel.Table, cn, arg.Val, col.Name) //c.sel.Table, cn, arg.Val, col.Name)
c.w.WriteString(`ts_rank(`) c.w.WriteString(`ts_rank(`)
@ -455,11 +479,15 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
c.w.WriteString(arg.Val) c.w.WriteString(arg.Val)
c.w.WriteString(`')`) c.w.WriteString(`')`)
alias(c.w, col.Name) alias(c.w, col.Name)
i++
case strings.HasPrefix(cn, "search_headline_"): case strings.HasPrefix(cn, "search_headline_"):
cn = cn[16:] cn = cn[16:]
arg := sel.Args["search"] arg := sel.Args["search"]
if i != 0 {
c.w.WriteString(`, `)
}
//fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`, //fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`,
//c.sel.Table, cn, arg.Val, col.Name) //c.sel.Table, cn, arg.Val, col.Name)
c.w.WriteString(`ts_headlinek(`) c.w.WriteString(`ts_headlinek(`)
@ -468,47 +496,63 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
c.w.WriteString(arg.Val) c.w.WriteString(arg.Val)
c.w.WriteString(`')`) c.w.WriteString(`')`)
alias(c.w, col.Name) alias(c.w, col.Name)
i++
} }
} else { } else {
pl := funcPrefixLen(cn) pl := funcPrefixLen(cn)
if pl == 0 { if pl == 0 {
if i != 0 {
c.w.WriteString(`, `)
}
//fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name) //fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name)
c.w.WriteString(`'`) c.w.WriteString(`'`)
c.w.WriteString(cn) c.w.WriteString(cn)
c.w.WriteString(` not defined'`) c.w.WriteString(` not defined'`)
alias(c.w, col.Name) alias(c.w, col.Name)
} else { i++
isAgg = true
} else if sel.Functions {
cn1 := cn[pl:]
if _, ok := sel.Allowed[cn1]; !ok {
continue
}
if i != 0 {
c.w.WriteString(`, `)
}
fn := cn[0 : pl-1] fn := cn[0 : pl-1]
cn := cn[pl:] isAgg = true
//fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, c.sel.Table, cn, col.Name) //fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, c.sel.Table, cn, col.Name)
c.w.WriteString(fn) c.w.WriteString(fn)
c.w.WriteString(`(`) c.w.WriteString(`(`)
colWithTable(c.w, ti.Name, cn) colWithTable(c.w, ti.Name, cn1)
c.w.WriteString(`)`) c.w.WriteString(`)`)
alias(c.w, col.Name) alias(c.w, col.Name)
i++
} }
} }
} else { } else {
groupBy = append(groupBy, i) groupBy = append(groupBy, n)
//fmt.Fprintf(w, `"%s"."%s"`, c.sel.Table, cn) //fmt.Fprintf(w, `"%s"."%s"`, c.sel.Table, cn)
if i != 0 {
c.w.WriteString(`, `)
}
colWithTable(c.w, ti.Name, cn) colWithTable(c.w, ti.Name, cn)
} i++
if i < len(sel.Cols)-1 || len(childCols) != 0 {
//io.WriteString(w, ", ")
c.w.WriteString(`, `)
} }
} }
for i, col := range childCols { for _, col := range childCols {
if i != 0 { if i != 0 {
//io.WriteString(w, ", ")
c.w.WriteString(`, `) c.w.WriteString(`, `)
} }
//fmt.Fprintf(w, `"%s"."%s"`, col.Table, col.Name) //fmt.Fprintf(w, `"%s"."%s"`, col.Table, col.Name)
colWithTable(c.w, col.Table, col.Name) colWithTable(c.w, col.Table, col.Name)
i++
} }
c.w.WriteString(` FROM `) c.w.WriteString(` FROM `)
@ -570,19 +614,21 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
} }
} }
if sel.Action == 0 { switch {
if len(sel.Paging.Limit) != 0 { case sel.Paging.NoLimit:
//fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit) break
c.w.WriteString(` LIMIT ('`)
c.w.WriteString(sel.Paging.Limit)
c.w.WriteString(`') :: integer`)
} else if ti.Singular { case len(sel.Paging.Limit) != 0:
c.w.WriteString(` LIMIT ('1') :: integer`) //fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit)
c.w.WriteString(` LIMIT ('`)
c.w.WriteString(sel.Paging.Limit)
c.w.WriteString(`') :: integer`)
} else { case ti.Singular:
c.w.WriteString(` LIMIT ('20') :: integer`) c.w.WriteString(` LIMIT ('1') :: integer`)
}
default:
c.w.WriteString(` LIMIT ('20') :: integer`)
} }
if len(sel.Paging.Offset) != 0 { if len(sel.Paging.Offset) != 0 {

View File

@ -22,32 +22,6 @@ func TestMain(m *testing.M) {
var err error var err error
qcompile, err = qcode.NewCompiler(qcode.Config{ qcompile, err = qcode.NewCompiler(qcode.Config{
DefaultFilter: []string{
`{ user_id: { _eq: $user_id } }`,
},
FilterMap: qcode.Filters{
All: map[string][]string{
"users": []string{
"{ id: { eq: $user_id } }",
},
"products": []string{
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }",
},
"customers": []string{},
"mes": []string{
"{ id: { eq: $user_id } }",
},
},
Query: map[string][]string{
"users": []string{},
},
Update: map[string][]string{
"products": []string{
"{ user_id: { eq: $user_id } }",
},
},
},
Blocklist: []string{ Blocklist: []string{
"secret", "secret",
"password", "password",
@ -55,6 +29,59 @@ func TestMain(m *testing.M) {
}, },
}) })
qcompile.AddRole("user", "product", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "name", "price", "users", "customers"},
Filter: []string{
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }",
},
},
Update: qcode.UpdateConfig{
Filter: []string{"{ user_id: { eq: $user_id } }"},
},
Delete: qcode.DeleteConfig{
Filter: []string{
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }",
},
},
})
qcompile.AddRole("anon", "product", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "name"},
},
})
qcompile.AddRole("anon1", "product", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "name", "price"},
DisableFunctions: true,
},
})
qcompile.AddRole("user", "users", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "full_name", "avatar", "email", "products"},
},
})
qcompile.AddRole("user", "mes", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "full_name", "avatar"},
Filter: []string{
"{ id: { eq: $user_id } }",
},
},
})
qcompile.AddRole("user", "customers", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "email", "full_name", "products"},
},
})
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -135,9 +162,8 @@ func TestMain(m *testing.M) {
os.Exit(m.Run()) os.Exit(m.Run())
} }
func compileGQLToPSQL(gql string, vars Variables) ([]byte, error) { func compileGQLToPSQL(gql string, vars Variables, role string) ([]byte, error) {
qc, err := qcompile.Compile([]byte(gql), role)
qc, err := qcompile.Compile([]byte(gql))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -147,6 +173,8 @@ func compileGQLToPSQL(gql string, vars Variables) ([]byte, error) {
return nil, err return nil, err
} }
//fmt.Println(string(sqlStmt))
return sqlStmt, nil return sqlStmt, nil
} }
@ -175,7 +203,7 @@ func withComplexArgs(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") < 28) AND (("products"."id") >= 20)) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "sel_json_agg_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") < 28) AND (("products"."id") >= 20)) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -203,7 +231,7 @@ func withWhereMultiOr(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") < 20) OR (("products"."price") > 10) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") < 20) OR (("products"."price") > 10) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -229,7 +257,7 @@ func withWhereIsNull(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -255,7 +283,7 @@ func withWhereAndList(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -275,7 +303,7 @@ func fetchByID(t *testing.T) {
sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 15)) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";` sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 15)) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -295,7 +323,7 @@ func searchQuery(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -318,7 +346,7 @@ func oneToMany(t *testing.T) {
sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email", "users"."id" FROM "users" LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email", "users"."id" FROM "users" LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -341,7 +369,7 @@ func belongsTo(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "sel_json_1" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "sel_json_1" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -364,7 +392,7 @@ func manyToMany(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -387,7 +415,7 @@ func manyToManyReverse(t *testing.T) {
sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -407,7 +435,47 @@ func aggFunction(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
if string(resSQL) != sql {
t.Fatal(errNotExpected)
}
}
func aggFunctionBlockedByCol(t *testing.T) {
gql := `query {
products {
name
count_price
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil, "anon")
if err != nil {
t.Fatal(err)
}
if string(resSQL) != sql {
t.Fatal(errNotExpected)
}
}
func aggFunctionDisabled(t *testing.T) {
gql := `query {
products {
name
count_price
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil, "anon1")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -427,7 +495,7 @@ func aggFunctionWithFilter(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") > 10)) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") > 10)) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -447,7 +515,7 @@ func queryWithVariables(t *testing.T) {
sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") = {{product_price}}) AND (("products"."id") = {{product_id}})) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";` sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") = {{product_price}}) AND (("products"."id") = {{product_id}})) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -464,9 +532,9 @@ func syntheticTables(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337";` sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT ) AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -476,7 +544,7 @@ func syntheticTables(t *testing.T) {
} }
} }
func TestCompileSelect(t *testing.T) { func TestCompileQuery(t *testing.T) {
t.Run("withComplexArgs", withComplexArgs) t.Run("withComplexArgs", withComplexArgs)
t.Run("withWhereAndList", withWhereAndList) t.Run("withWhereAndList", withWhereAndList)
t.Run("withWhereIsNull", withWhereIsNull) t.Run("withWhereIsNull", withWhereIsNull)
@ -488,10 +556,11 @@ func TestCompileSelect(t *testing.T) {
t.Run("manyToMany", manyToMany) t.Run("manyToMany", manyToMany)
t.Run("manyToManyReverse", manyToManyReverse) t.Run("manyToManyReverse", manyToManyReverse)
t.Run("aggFunction", aggFunction) t.Run("aggFunction", aggFunction)
t.Run("aggFunctionBlockedByCol", aggFunctionBlockedByCol)
t.Run("aggFunctionDisabled", aggFunctionDisabled)
t.Run("aggFunctionWithFilter", aggFunctionWithFilter) t.Run("aggFunctionWithFilter", aggFunctionWithFilter)
t.Run("syntheticTables", syntheticTables) t.Run("syntheticTables", syntheticTables)
t.Run("queryWithVariables", queryWithVariables) t.Run("queryWithVariables", queryWithVariables)
} }
var benchGQL = []byte(`query { var benchGQL = []byte(`query {
@ -526,7 +595,7 @@ func BenchmarkCompile(b *testing.B) {
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
w.Reset() w.Reset()
qc, err := qcompile.Compile(benchGQL) qc, err := qcompile.Compile(benchGQL, "user")
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -547,7 +616,7 @@ func BenchmarkCompileParallel(b *testing.B) {
for pb.Next() { for pb.Next() {
w.Reset() w.Reset()
qc, err := qcompile.Compile(benchGQL) qc, err := qcompile.Compile(benchGQL, "user")
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }

99
qcode/config.go Normal file
View File

@ -0,0 +1,99 @@
package qcode
type Config struct {
Blocklist []string
KeepArgs bool
}
type QueryConfig struct {
Limit int
Filter []string
Columns []string
DisableFunctions bool
}
type InsertConfig struct {
Filter []string
Columns []string
Set map[string]string
}
type UpdateConfig struct {
Filter []string
Columns []string
Set map[string]string
}
type DeleteConfig struct {
Filter []string
Columns []string
}
type TRConfig struct {
Query QueryConfig
Insert InsertConfig
Update UpdateConfig
Delete DeleteConfig
}
type trval struct {
query struct {
limit string
fil *Exp
cols map[string]struct{}
disable struct {
funcs bool
}
}
insert struct {
fil *Exp
cols map[string]struct{}
set map[string]string
}
update struct {
fil *Exp
cols map[string]struct{}
set map[string]string
}
delete struct {
fil *Exp
cols map[string]struct{}
}
}
func (trv *trval) allowedColumns(qt QType) map[string]struct{} {
switch qt {
case QTQuery:
return trv.query.cols
case QTInsert:
return trv.insert.cols
case QTUpdate:
return trv.update.cols
case QTDelete:
return trv.insert.cols
case QTUpsert:
return trv.insert.cols
}
return nil
}
func (trv *trval) filter(qt QType) *Exp {
switch qt {
case QTQuery:
return trv.query.fil
case QTInsert:
return trv.insert.fil
case QTUpdate:
return trv.update.fil
case QTDelete:
return trv.delete.fil
case QTUpsert:
return trv.insert.fil
}
return nil
}

View File

@ -5,7 +5,7 @@ func FuzzerEntrypoint(data []byte) int {
//testData := string(data) //testData := string(data)
qcompile, _ := NewCompiler(Config{}) qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile(data) _, err := qcompile.Compile(data, "user")
if err != nil { if err != nil {
return -1 return -1
} }

View File

@ -46,13 +46,18 @@ func compareOp(op1, op2 Operation) error {
*/ */
func TestCompile1(t *testing.T) { func TestCompile1(t *testing.T) {
qcompile, _ := NewCompiler(Config{}) qc, _ := NewCompiler(Config{})
qc.AddRole("user", "product", TRConfig{
Query: QueryConfig{
Columns: []string{"id", "Name"},
},
})
_, err := qcompile.Compile([]byte(` _, err := qc.Compile([]byte(`
product(id: 15) { product(id: 15) {
id id
name name
}`)) }`), "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -60,13 +65,18 @@ func TestCompile1(t *testing.T) {
} }
func TestCompile2(t *testing.T) { func TestCompile2(t *testing.T) {
qcompile, _ := NewCompiler(Config{}) qc, _ := NewCompiler(Config{})
qc.AddRole("user", "product", TRConfig{
Query: QueryConfig{
Columns: []string{"ID"},
},
})
_, err := qcompile.Compile([]byte(` _, err := qc.Compile([]byte(`
query { product(id: 15) { query { product(id: 15) {
id id
name name
} }`)) } }`), "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -74,15 +84,20 @@ func TestCompile2(t *testing.T) {
} }
func TestCompile3(t *testing.T) { func TestCompile3(t *testing.T) {
qcompile, _ := NewCompiler(Config{}) qc, _ := NewCompiler(Config{})
qc.AddRole("user", "product", TRConfig{
Query: QueryConfig{
Columns: []string{"ID"},
},
})
_, err := qcompile.Compile([]byte(` _, err := qc.Compile([]byte(`
mutation { mutation {
product(id: 15, name: "Test") { product(id: 15, name: "Test") {
id id
name name
} }
}`)) }`), "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -91,7 +106,7 @@ func TestCompile3(t *testing.T) {
func TestInvalidCompile1(t *testing.T) { func TestInvalidCompile1(t *testing.T) {
qcompile, _ := NewCompiler(Config{}) qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile([]byte(`#`)) _, err := qcompile.Compile([]byte(`#`), "user")
if err == nil { if err == nil {
t.Fatal(errors.New("expecting an error")) t.Fatal(errors.New("expecting an error"))
@ -100,7 +115,7 @@ func TestInvalidCompile1(t *testing.T) {
func TestInvalidCompile2(t *testing.T) { func TestInvalidCompile2(t *testing.T) {
qcompile, _ := NewCompiler(Config{}) qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile([]byte(`{u(where:{not:0})}`)) _, err := qcompile.Compile([]byte(`{u(where:{not:0})}`), "user")
if err == nil { if err == nil {
t.Fatal(errors.New("expecting an error")) t.Fatal(errors.New("expecting an error"))
@ -109,7 +124,7 @@ func TestInvalidCompile2(t *testing.T) {
func TestEmptyCompile(t *testing.T) { func TestEmptyCompile(t *testing.T) {
qcompile, _ := NewCompiler(Config{}) qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile([]byte(``)) _, err := qcompile.Compile([]byte(``), "user")
if err == nil { if err == nil {
t.Fatal(errors.New("expecting an error")) t.Fatal(errors.New("expecting an error"))
@ -144,7 +159,7 @@ func BenchmarkQCompile(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
_, err := qcompile.Compile(gql) _, err := qcompile.Compile(gql, "user")
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
@ -160,7 +175,7 @@ func BenchmarkQCompileP(b *testing.B) {
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
_, err := qcompile.Compile(gql) _, err := qcompile.Compile(gql, "user")
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)

View File

@ -3,6 +3,7 @@ package qcode
import ( import (
"errors" "errors"
"fmt" "fmt"
"strconv"
"strings" "strings"
"sync" "sync"
@ -17,23 +18,16 @@ const (
maxSelectors = 30 maxSelectors = 30
QTQuery QType = iota + 1 QTQuery QType = iota + 1
QTMutation QTInsert
QTUpdate
ActionInsert Action = iota + 1 QTDelete
ActionUpdate QTUpsert
ActionDelete
ActionUpsert
) )
type QCode struct { type QCode struct {
Type QType Type QType
Selects []Select ActionVar string
} Selects []Select
type Column struct {
Table string
Name string
FieldName string
} }
type Select struct { type Select struct {
@ -47,9 +41,15 @@ type Select struct {
OrderBy []*OrderBy OrderBy []*OrderBy
DistinctOn []string DistinctOn []string
Paging Paging Paging Paging
Action Action
ActionVar string
Children []int32 Children []int32
Functions bool
Allowed map[string]struct{}
}
type Column struct {
Table string
Name string
FieldName string
} }
type Exp struct { type Exp struct {
@ -77,8 +77,9 @@ type OrderBy struct {
} }
type Paging struct { type Paging struct {
Limit string Limit string
Offset string Offset string
NoLimit bool
} }
type ExpOp int type ExpOp int
@ -145,81 +146,23 @@ const (
OrderDescNullsLast OrderDescNullsLast
) )
type Filters struct {
All map[string][]string
Query map[string][]string
Insert map[string][]string
Update map[string][]string
Delete map[string][]string
}
type Config struct {
DefaultFilter []string
FilterMap Filters
Blocklist []string
KeepArgs bool
}
type Compiler struct { type Compiler struct {
df *Exp tr map[string]map[string]*trval
fm struct {
all map[string]*Exp
query map[string]*Exp
insert map[string]*Exp
update map[string]*Exp
delete map[string]*Exp
}
bl map[string]struct{} bl map[string]struct{}
ka bool ka bool
} }
var opMap = map[parserType]QType{
opQuery: QTQuery,
opMutate: QTMutation,
}
var expPool = sync.Pool{ var expPool = sync.Pool{
New: func() interface{} { return &Exp{doFree: true} }, New: func() interface{} { return &Exp{doFree: true} },
} }
func NewCompiler(c Config) (*Compiler, error) { func NewCompiler(c Config) (*Compiler, error) {
var err error
co := &Compiler{ka: c.KeepArgs} co := &Compiler{ka: c.KeepArgs}
co.tr = make(map[string]map[string]*trval)
co.bl = make(map[string]struct{}, len(c.Blocklist)) co.bl = make(map[string]struct{}, len(c.Blocklist))
for i := range c.Blocklist { for i := range c.Blocklist {
co.bl[c.Blocklist[i]] = struct{}{} co.bl[strings.ToLower(c.Blocklist[i])] = struct{}{}
}
co.df, err = compileFilter(c.DefaultFilter)
if err != nil {
return nil, err
}
co.fm.all, err = buildFilters(c.FilterMap.All)
if err != nil {
return nil, err
}
co.fm.query, err = buildFilters(c.FilterMap.Query)
if err != nil {
return nil, err
}
co.fm.insert, err = buildFilters(c.FilterMap.Insert)
if err != nil {
return nil, err
}
co.fm.update, err = buildFilters(c.FilterMap.Update)
if err != nil {
return nil, err
}
co.fm.delete, err = buildFilters(c.FilterMap.Delete)
if err != nil {
return nil, err
} }
seedExp := [100]Exp{} seedExp := [100]Exp{}
@ -232,58 +175,99 @@ func NewCompiler(c Config) (*Compiler, error) {
return co, nil return co, nil
} }
func buildFilters(filMap map[string][]string) (map[string]*Exp, error) { func (com *Compiler) AddRole(role, table string, trc TRConfig) error {
fm := make(map[string]*Exp, len(filMap)) var err error
trv := &trval{}
for k, v := range filMap { toMap := func(cols []string) map[string]struct{} {
fil, err := compileFilter(v) m := make(map[string]struct{}, len(cols))
if err != nil { for i := range cols {
return nil, err m[strings.ToLower(cols[i])] = struct{}{}
} }
singular := flect.Singularize(k) return m
plural := flect.Pluralize(k)
fm[singular] = fil
fm[plural] = fil
} }
return fm, nil // query config
trv.query.fil, err = compileFilter(trc.Query.Filter)
if err != nil {
return err
}
if trc.Query.Limit > 0 {
trv.query.limit = strconv.Itoa(trc.Query.Limit)
}
trv.query.cols = toMap(trc.Query.Columns)
trv.query.disable.funcs = trc.Query.DisableFunctions
// insert config
if trv.insert.fil, err = compileFilter(trc.Insert.Filter); err != nil {
return err
}
trv.insert.cols = toMap(trc.Insert.Columns)
// update config
if trv.update.fil, err = compileFilter(trc.Update.Filter); err != nil {
return err
}
trv.insert.cols = toMap(trc.Insert.Columns)
trv.insert.set = trc.Insert.Set
// delete config
if trv.delete.fil, err = compileFilter(trc.Delete.Filter); err != nil {
return err
}
trv.delete.cols = toMap(trc.Delete.Columns)
singular := flect.Singularize(table)
plural := flect.Pluralize(table)
if _, ok := com.tr[role]; !ok {
com.tr[role] = make(map[string]*trval)
}
com.tr[role][singular] = trv
com.tr[role][plural] = trv
return nil
} }
func (com *Compiler) Compile(query []byte) (*QCode, error) { func (com *Compiler) Compile(query []byte, role string) (*QCode, error) {
var qc QCode
var err error var err error
qc := QCode{Type: QTQuery}
op, err := Parse(query) op, err := Parse(query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
qc.Selects, err = com.compileQuery(op) if err = com.compileQuery(&qc, op, role); err != nil {
if err != nil {
return nil, err return nil, err
} }
if t, ok := opMap[op.Type]; ok {
qc.Type = t
} else {
return nil, fmt.Errorf("Unknown operation type %d", op.Type)
}
opPool.Put(op) opPool.Put(op)
return &qc, nil return &qc, nil
} }
func (com *Compiler) compileQuery(op *Operation) ([]Select, error) { func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error {
id := int32(0) id := int32(0)
parentID := int32(0) parentID := int32(0)
if len(op.Fields) == 0 {
return errors.New("invalid graphql no query found")
}
if op.Type == opMutate {
if err := com.setMutationType(qc, op.Fields[0].Args); err != nil {
return err
}
}
selects := make([]Select, 0, 5) selects := make([]Select, 0, 5)
st := NewStack() st := NewStack()
action := qc.Type
if len(op.Fields) == 0 { if len(op.Fields) == 0 {
return nil, errors.New("empty query") return errors.New("empty query")
} }
st.Push(op.Fields[0].ID) st.Push(op.Fields[0].ID)
@ -293,7 +277,7 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
} }
if id >= maxSelectors { if id >= maxSelectors {
return nil, fmt.Errorf("selector limit reached (%d)", maxSelectors) return fmt.Errorf("selector limit reached (%d)", maxSelectors)
} }
fid := st.Pop() fid := st.Pop()
@ -303,14 +287,28 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
continue continue
} }
trv, ok := com.tr[role][field.Name]
if !ok {
continue
}
selects = append(selects, Select{ selects = append(selects, Select{
ID: id, ID: id,
ParentID: parentID, ParentID: parentID,
Table: field.Name, Table: field.Name,
Children: make([]int32, 0, 5), Children: make([]int32, 0, 5),
Allowed: trv.allowedColumns(action),
}) })
s := &selects[(len(selects) - 1)] s := &selects[(len(selects) - 1)]
if action == QTQuery {
s.Functions = !trv.query.disable.funcs
if len(trv.query.limit) != 0 {
s.Paging.Limit = trv.query.limit
}
}
if s.ID != 0 { if s.ID != 0 {
p := &selects[s.ParentID] p := &selects[s.ParentID]
p.Children = append(p.Children, s.ID) p.Children = append(p.Children, s.ID)
@ -322,12 +320,13 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
s.FieldName = s.Table s.FieldName = s.Table
} }
err := com.compileArgs(s, field.Args) err := com.compileArgs(qc, s, field.Args)
if err != nil { if err != nil {
return nil, err return err
} }
s.Cols = make([]Column, 0, len(field.Children)) s.Cols = make([]Column, 0, len(field.Children))
action = QTQuery
for _, cid := range field.Children { for _, cid := range field.Children {
f := op.Fields[cid] f := op.Fields[cid]
@ -356,36 +355,14 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
} }
if id == 0 { if id == 0 {
return nil, errors.New("invalid query") return errors.New("invalid query")
} }
var fil *Exp var fil *Exp
root := &selects[0] root := &selects[0]
switch op.Type { if trv, ok := com.tr[role][op.Fields[0].Name]; ok {
case opQuery: fil = trv.filter(qc.Type)
fil, _ = com.fm.query[root.Table]
case opMutate:
switch root.Action {
case ActionInsert:
fil, _ = com.fm.insert[root.Table]
case ActionUpdate:
fil, _ = com.fm.update[root.Table]
case ActionDelete:
fil, _ = com.fm.delete[root.Table]
case ActionUpsert:
fil, _ = com.fm.insert[root.Table]
}
}
if fil == nil {
fil, _ = com.fm.all[root.Table]
}
if fil == nil {
fil = com.df
} }
if fil != nil && fil.Op != OpNop { if fil != nil && fil.Op != OpNop {
@ -403,10 +380,11 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
} }
} }
return selects[:id], nil qc.Selects = selects[:id]
return nil
} }
func (com *Compiler) compileArgs(sel *Select, args []Arg) error { func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg) error {
var err error var err error
if com.ka { if com.ka {
@ -418,9 +396,7 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
switch arg.Name { switch arg.Name {
case "id": case "id":
if sel.ID == 0 { err = com.compileArgID(sel, arg)
err = com.compileArgID(sel, arg)
}
case "search": case "search":
err = com.compileArgSearch(sel, arg) err = com.compileArgSearch(sel, arg)
case "where": case "where":
@ -433,18 +409,6 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
err = com.compileArgLimit(sel, arg) err = com.compileArgLimit(sel, arg)
case "offset": case "offset":
err = com.compileArgOffset(sel, arg) err = com.compileArgOffset(sel, arg)
case "insert":
sel.Action = ActionInsert
err = com.compileArgAction(sel, arg)
case "update":
sel.Action = ActionUpdate
err = com.compileArgAction(sel, arg)
case "upsert":
sel.Action = ActionUpsert
err = com.compileArgAction(sel, arg)
case "delete":
sel.Action = ActionDelete
err = com.compileArgAction(sel, arg)
} }
if err != nil { if err != nil {
@ -461,6 +425,45 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
return nil return nil
} }
func (com *Compiler) setMutationType(qc *QCode, args []Arg) error {
setActionVar := func(arg *Arg) error {
if arg.Val.Type != nodeVar {
return fmt.Errorf("value for argument '%s' must be a variable", arg.Name)
}
qc.ActionVar = arg.Val.Val
return nil
}
for i := range args {
arg := &args[i]
switch arg.Name {
case "insert":
qc.Type = QTInsert
return setActionVar(arg)
case "update":
qc.Type = QTUpdate
return setActionVar(arg)
case "upsert":
qc.Type = QTUpsert
return setActionVar(arg)
case "delete":
qc.Type = QTDelete
if arg.Val.Type != nodeBool {
return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name)
}
if arg.Val.Val == "false" {
qc.Type = QTQuery
}
return nil
}
}
return nil
}
func (com *Compiler) compileArgObj(st *util.Stack, arg *Arg) (*Exp, error) { func (com *Compiler) compileArgObj(st *util.Stack, arg *Arg) (*Exp, error) {
if arg.Val.Type != nodeObj { if arg.Val.Type != nodeObj {
return nil, fmt.Errorf("expecting an object") return nil, fmt.Errorf("expecting an object")
@ -540,6 +543,10 @@ func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (*
} }
func (com *Compiler) compileArgID(sel *Select, arg *Arg) error { func (com *Compiler) compileArgID(sel *Select, arg *Arg) error {
if sel.ID != 0 {
return nil
}
if sel.Where != nil && sel.Where.Op == OpEqID { if sel.Where != nil && sel.Where.Op == OpEqID {
return nil return nil
} }
@ -732,26 +739,6 @@ func (com *Compiler) compileArgOffset(sel *Select, arg *Arg) error {
return nil return nil
} }
func (com *Compiler) compileArgAction(sel *Select, arg *Arg) error {
switch sel.Action {
case ActionDelete:
if arg.Val.Type != nodeBool {
return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name)
}
if arg.Val.Val == "false" {
sel.Action = 0
}
default:
if arg.Val.Type != nodeVar {
return fmt.Errorf("value for argument '%s' must be a variable", arg.Name)
}
sel.ActionVar = arg.Val.Val
}
return nil
}
func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) { func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {
name := node.Name name := node.Name
if name[0] == '_' { if name[0] == '_' {

View File

@ -71,18 +71,14 @@ type config struct {
} `mapstructure:"database"` } `mapstructure:"database"`
Tables []configTable Tables []configTable
Roles []configRoles
} }
type configTable struct { type configTable struct {
Name string Name string
Filter []string Table string
FilterQuery []string `mapstructure:"filter_query"` Blocklist []string
FilterInsert []string `mapstructure:"filter_insert"` Remotes []configRemote
FilterUpdate []string `mapstructure:"filter_update"`
FilterDelete []string `mapstructure:"filter_delete"`
Table string
Blocklist []string
Remotes []configRemote
} }
type configRemote struct { type configRemote struct {
@ -98,6 +94,41 @@ type configRemote struct {
} `mapstructure:"set_headers"` } `mapstructure:"set_headers"`
} }
type configRoles struct {
Name string
Tables []struct {
Name string
Query struct {
Limit int
Filter []string
Columns []string
DisableAggregation bool `mapstructure:"disable_aggregation"`
Deny bool
}
Insert struct {
Filter []string
Columns []string
Set map[string]string
Deny bool
}
Update struct {
Filter []string
Columns []string
Set map[string]string
Deny bool
}
Delete struct {
Filter []string
Columns []string
Deny bool
}
}
}
func newConfig() *viper.Viper { func newConfig() *viper.Viper {
vi := viper.New() vi := viper.New()

View File

@ -59,7 +59,7 @@ func (c *coreContext) execQuery() ([]byte, error) {
} else { } else {
qc, err = qcompile.Compile([]byte(c.req.Query)) qc, err = qcompile.Compile([]byte(c.req.Query), "user")
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@ -40,7 +40,7 @@ func prepareStmt(key, gql string, varBytes json.RawMessage) error {
return nil return nil
} }
qc, err := qcompile.Compile([]byte(gql)) qc, err := qcompile.Compile([]byte(gql), "user")
if err != nil { if err != nil {
return err return err
} }

View File

@ -12,7 +12,6 @@ import (
rice "github.com/GeertJohan/go.rice" rice "github.com/GeertJohan/go.rice"
"github.com/dosco/super-graph/psql" "github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode" "github.com/dosco/super-graph/qcode"
"github.com/gobuffalo/flect"
) )
func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) { func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
@ -22,49 +21,50 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
} }
conf := qcode.Config{ conf := qcode.Config{
DefaultFilter: c.DB.Defaults.Filter,
FilterMap: qcode.Filters{
All: make(map[string][]string, len(c.Tables)),
Query: make(map[string][]string, len(c.Tables)),
Insert: make(map[string][]string, len(c.Tables)),
Update: make(map[string][]string, len(c.Tables)),
Delete: make(map[string][]string, len(c.Tables)),
},
Blocklist: c.DB.Defaults.Blocklist, Blocklist: c.DB.Defaults.Blocklist,
KeepArgs: false, KeepArgs: false,
} }
for i := range c.Tables {
t := c.Tables[i]
singular := flect.Singularize(t.Name)
plural := flect.Pluralize(t.Name)
setFilter := func(fm map[string][]string, fil []string) {
switch {
case len(fil) == 0:
return
case fil[0] == "none" || len(fil[0]) == 0:
fm[singular] = []string{}
fm[plural] = []string{}
default:
fm[singular] = t.Filter
fm[plural] = t.Filter
}
}
setFilter(conf.FilterMap.All, t.Filter)
setFilter(conf.FilterMap.Query, t.FilterQuery)
setFilter(conf.FilterMap.Insert, t.FilterInsert)
setFilter(conf.FilterMap.Update, t.FilterUpdate)
setFilter(conf.FilterMap.Delete, t.FilterDelete)
}
qc, err := qcode.NewCompiler(conf) qc, err := qcode.NewCompiler(conf)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
for _, r := range c.Roles {
for _, t := range r.Tables {
query := qcode.QueryConfig{
Limit: t.Query.Limit,
Filter: t.Query.Filter,
Columns: t.Query.Columns,
DisableFunctions: t.Query.DisableAggregation,
}
insert := qcode.InsertConfig{
Filter: t.Insert.Filter,
Columns: t.Insert.Columns,
Set: t.Insert.Set,
}
update := qcode.UpdateConfig{
Filter: t.Insert.Filter,
Columns: t.Insert.Columns,
Set: t.Insert.Set,
}
delete := qcode.DeleteConfig{
Filter: t.Insert.Filter,
Columns: t.Insert.Columns,
}
qc.AddRole(r.Name, t.Name, qcode.TRConfig{
Query: query,
Insert: insert,
Update: update,
Delete: delete,
})
}
}
pc := psql.NewCompiler(psql.Config{ pc := psql.NewCompiler(psql.Config{
Schema: schema, Schema: schema,
Vars: c.getVariables(), Vars: c.getVariables(),