Add role based access control
This commit is contained in:
parent
85a74ed30c
commit
deb5b93c81
|
@ -115,9 +115,6 @@ tables:
|
|||
- name: users
|
||||
# This filter will overwrite defaults.filter
|
||||
# filter: ["{ id: { eq: $user_id } }"]
|
||||
# filter_query: ["{ id: { eq: $user_id } }"]
|
||||
filter_update: ["{ id: { eq: $user_id } }"]
|
||||
filter_delete: ["{ id: { eq: $user_id } }"]
|
||||
|
||||
# - name: products
|
||||
# # Multiple filters are AND'd together
|
||||
|
@ -127,10 +124,6 @@ tables:
|
|||
# ]
|
||||
|
||||
- name: customers
|
||||
# No filter is used for this field not
|
||||
# even defaults.filter
|
||||
filter: none
|
||||
|
||||
remotes:
|
||||
- name: payments
|
||||
id: stripe_id
|
||||
|
@ -149,7 +142,56 @@ tables:
|
|||
# real db table backing them
|
||||
name: me
|
||||
table: users
|
||||
filter: ["{ id: { eq: $user_id } }"]
|
||||
|
||||
# - name: posts
|
||||
# filter: ["{ account_id: { _eq: $account_id } }"]
|
||||
roles:
|
||||
- name: anon
|
||||
tables:
|
||||
- name: products
|
||||
limit: 10
|
||||
|
||||
query:
|
||||
columns: ["id", "name", "description" ]
|
||||
aggregation: false
|
||||
|
||||
insert:
|
||||
allow: false
|
||||
|
||||
update:
|
||||
allow: false
|
||||
|
||||
delete:
|
||||
allow: false
|
||||
|
||||
- name: user
|
||||
tables:
|
||||
- name: products
|
||||
|
||||
query:
|
||||
limit: 50
|
||||
filter: ["{ user_id: { eq: $user_id } }"]
|
||||
columns: ["id", "name", "description" ]
|
||||
disable_aggregation: false
|
||||
|
||||
insert:
|
||||
filter: ["{ user_id: { eq: $user_id } }"]
|
||||
columns: ["id", "name", "description" ]
|
||||
set:
|
||||
- created_at: "now"
|
||||
|
||||
update:
|
||||
filter: ["{ user_id: { eq: $user_id } }"]
|
||||
columns:
|
||||
- id
|
||||
- name
|
||||
set:
|
||||
- updated_at: "now"
|
||||
|
||||
delete:
|
||||
deny: true
|
||||
|
||||
- name: manager
|
||||
tables:
|
||||
- name: users
|
||||
|
||||
select:
|
||||
filter: ["{ account_id: { _eq: $account_id } }"]
|
||||
|
|
|
@ -10,7 +10,7 @@ import (
|
|||
"github.com/dosco/super-graph/qcode"
|
||||
)
|
||||
|
||||
var zeroPaging = qcode.Paging{}
|
||||
var noLimit = qcode.Paging{NoLimit: true}
|
||||
|
||||
func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) {
|
||||
if len(qc.Selects) == 0 {
|
||||
|
@ -29,23 +29,23 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia
|
|||
quoted(c.w, ti.Name)
|
||||
c.w.WriteString(` AS `)
|
||||
|
||||
switch root.Action {
|
||||
case qcode.ActionInsert:
|
||||
switch qc.Type {
|
||||
case qcode.QTInsert:
|
||||
if _, err := c.renderInsert(qc, w, vars, ti); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
case qcode.ActionUpdate:
|
||||
case qcode.QTUpdate:
|
||||
if _, err := c.renderUpdate(qc, w, vars, ti); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
case qcode.ActionUpsert:
|
||||
case qcode.QTUpsert:
|
||||
if _, err := c.renderUpsert(qc, w, vars, ti); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
case qcode.ActionDelete:
|
||||
case qcode.QTDelete:
|
||||
if _, err := c.renderDelete(qc, w, vars, ti); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
@ -56,22 +56,23 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia
|
|||
|
||||
io.WriteString(c.w, ` RETURNING *) `)
|
||||
|
||||
root.Paging = zeroPaging
|
||||
root.Paging = noLimit
|
||||
root.DistinctOn = root.DistinctOn[:]
|
||||
root.OrderBy = root.OrderBy[:]
|
||||
root.Where = nil
|
||||
root.Args = nil
|
||||
|
||||
qc.Type = qcode.QTQuery
|
||||
|
||||
return c.compileQuery(qc, w)
|
||||
}
|
||||
|
||||
func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer,
|
||||
vars Variables, ti *DBTableInfo) (uint32, error) {
|
||||
root := &qc.Selects[0]
|
||||
|
||||
insert, ok := vars[root.ActionVar]
|
||||
insert, ok := vars[qc.ActionVar]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar)
|
||||
return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar)
|
||||
}
|
||||
|
||||
jt, array, err := jsn.Tree(insert)
|
||||
|
@ -80,7 +81,7 @@ func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer,
|
|||
}
|
||||
|
||||
c.w.WriteString(`(WITH "input" AS (SELECT {{`)
|
||||
c.w.WriteString(root.ActionVar)
|
||||
c.w.WriteString(qc.ActionVar)
|
||||
c.w.WriteString(`}}::json AS j) INSERT INTO `)
|
||||
quoted(c.w, ti.Name)
|
||||
io.WriteString(c.w, ` (`)
|
||||
|
@ -106,12 +107,18 @@ func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer,
|
|||
|
||||
func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w *bytes.Buffer,
|
||||
jt map[string]interface{}, ti *DBTableInfo) (uint32, error) {
|
||||
root := &qc.Selects[0]
|
||||
|
||||
i := 0
|
||||
for _, cn := range ti.ColumnNames {
|
||||
if _, ok := jt[cn]; !ok {
|
||||
continue
|
||||
}
|
||||
if len(root.Allowed) != 0 {
|
||||
if _, ok := root.Allowed[cn]; !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if i != 0 {
|
||||
io.WriteString(c.w, `, `)
|
||||
}
|
||||
|
@ -126,9 +133,9 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer,
|
|||
vars Variables, ti *DBTableInfo) (uint32, error) {
|
||||
root := &qc.Selects[0]
|
||||
|
||||
update, ok := vars[root.ActionVar]
|
||||
update, ok := vars[qc.ActionVar]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar)
|
||||
return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar)
|
||||
}
|
||||
|
||||
jt, array, err := jsn.Tree(update)
|
||||
|
@ -137,7 +144,7 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer,
|
|||
}
|
||||
|
||||
c.w.WriteString(`(WITH "input" AS (SELECT {{`)
|
||||
c.w.WriteString(root.ActionVar)
|
||||
c.w.WriteString(qc.ActionVar)
|
||||
c.w.WriteString(`}}::json AS j) UPDATE `)
|
||||
quoted(c.w, ti.Name)
|
||||
io.WriteString(c.w, ` SET (`)
|
||||
|
@ -183,11 +190,10 @@ func (c *compilerContext) renderDelete(qc *qcode.QCode, w *bytes.Buffer,
|
|||
|
||||
func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer,
|
||||
vars Variables, ti *DBTableInfo) (uint32, error) {
|
||||
root := &qc.Selects[0]
|
||||
|
||||
upsert, ok := vars[root.ActionVar]
|
||||
upsert, ok := vars[qc.ActionVar]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar)
|
||||
return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar)
|
||||
}
|
||||
|
||||
jt, _, err := jsn.Tree(upsert)
|
|
@ -18,7 +18,7 @@ func simpleInsert(t *testing.T) {
|
|||
"data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`),
|
||||
}
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, vars)
|
||||
resSQL, err := compileGQLToPSQL(gql, vars, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -42,7 +42,7 @@ func singleInsert(t *testing.T) {
|
|||
"insert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc", "user_id": 5 }`),
|
||||
}
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, vars)
|
||||
resSQL, err := compileGQLToPSQL(gql, vars, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -54,7 +54,7 @@ func singleInsert(t *testing.T) {
|
|||
|
||||
func bulkInsert(t *testing.T) {
|
||||
gql := `mutation {
|
||||
product(id: 15, insert: $insert) {
|
||||
product(name: "test", id: 15, insert: $insert) {
|
||||
id
|
||||
name
|
||||
}
|
||||
|
@ -66,7 +66,7 @@ func bulkInsert(t *testing.T) {
|
|||
"insert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
|
||||
}
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, vars)
|
||||
resSQL, err := compileGQLToPSQL(gql, vars, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -90,7 +90,7 @@ func singleUpsert(t *testing.T) {
|
|||
"upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
|
||||
}
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, vars)
|
||||
resSQL, err := compileGQLToPSQL(gql, vars, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -114,7 +114,7 @@ func bulkUpsert(t *testing.T) {
|
|||
"upsert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
|
||||
}
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, vars)
|
||||
resSQL, err := compileGQLToPSQL(gql, vars, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -138,7 +138,7 @@ func singleUpdate(t *testing.T) {
|
|||
"update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
|
||||
}
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, vars)
|
||||
resSQL, err := compileGQLToPSQL(gql, vars, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -162,7 +162,7 @@ func delete(t *testing.T) {
|
|||
"update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
|
||||
}
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, vars)
|
||||
resSQL, err := compileGQLToPSQL(gql, vars, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -172,7 +172,7 @@ func delete(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCompileInsert(t *testing.T) {
|
||||
func TestCompileMutate(t *testing.T) {
|
||||
t.Run("simpleInsert", simpleInsert)
|
||||
t.Run("singleInsert", singleInsert)
|
||||
t.Run("bulkInsert", bulkInsert)
|
|
@ -64,11 +64,11 @@ func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (u
|
|||
switch qc.Type {
|
||||
case qcode.QTQuery:
|
||||
return co.compileQuery(qc, w)
|
||||
case qcode.QTMutation:
|
||||
case qcode.QTInsert, qcode.QTUpdate, qcode.QTDelete, qcode.QTUpsert:
|
||||
return co.compileMutation(qc, w, vars)
|
||||
}
|
||||
|
||||
return 0, errors.New("unknown operation")
|
||||
return 0, fmt.Errorf("Unknown operation type %d", qc.Type)
|
||||
}
|
||||
|
||||
func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) {
|
||||
|
@ -295,20 +295,22 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo)
|
|||
}
|
||||
}
|
||||
|
||||
if sel.Action == 0 {
|
||||
if len(sel.Paging.Limit) != 0 {
|
||||
switch {
|
||||
case sel.Paging.NoLimit:
|
||||
break
|
||||
|
||||
case len(sel.Paging.Limit) != 0:
|
||||
//fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit)
|
||||
c.w.WriteString(` LIMIT ('`)
|
||||
c.w.WriteString(sel.Paging.Limit)
|
||||
c.w.WriteString(`') :: integer`)
|
||||
|
||||
} else if ti.Singular {
|
||||
case ti.Singular:
|
||||
c.w.WriteString(` LIMIT ('1') :: integer`)
|
||||
|
||||
} else {
|
||||
default:
|
||||
c.w.WriteString(` LIMIT ('20') :: integer`)
|
||||
}
|
||||
}
|
||||
|
||||
if len(sel.Paging.Offset) != 0 {
|
||||
//fmt.Fprintf(w, ` OFFSET ('%s') :: integer`, c.sel.Paging.Offset)
|
||||
|
@ -370,13 +372,31 @@ func (c *compilerContext) renderJoinTable(sel *qcode.Select) error {
|
|||
}
|
||||
|
||||
func (c *compilerContext) renderColumns(sel *qcode.Select, ti *DBTableInfo) {
|
||||
for i, col := range sel.Cols {
|
||||
i := 0
|
||||
for _, col := range sel.Cols {
|
||||
if len(sel.Allowed) != 0 {
|
||||
n := funcPrefixLen(col.Name)
|
||||
if n != 0 {
|
||||
if sel.Functions == false {
|
||||
continue
|
||||
}
|
||||
if _, ok := sel.Allowed[col.Name[n:]]; !ok {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if _, ok := sel.Allowed[col.Name]; !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if i != 0 {
|
||||
io.WriteString(c.w, ", ")
|
||||
}
|
||||
//fmt.Fprintf(w, `"%s_%d"."%s" AS "%s"`,
|
||||
//c.sel.Table, c.sel.ID, col.Name, col.FieldName)
|
||||
colWithTableIDAlias(c.w, ti.Name, sel.ID, col.Name, col.FieldName)
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -435,7 +455,8 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
|
|||
|
||||
c.w.WriteString(` FROM (SELECT `)
|
||||
|
||||
for i, col := range sel.Cols {
|
||||
i := 0
|
||||
for n, col := range sel.Cols {
|
||||
cn := col.Name
|
||||
|
||||
_, isRealCol := ti.Columns[cn]
|
||||
|
@ -447,6 +468,9 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
|
|||
cn = ti.TSVCol
|
||||
arg := sel.Args["search"]
|
||||
|
||||
if i != 0 {
|
||||
c.w.WriteString(`, `)
|
||||
}
|
||||
//fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`,
|
||||
//c.sel.Table, cn, arg.Val, col.Name)
|
||||
c.w.WriteString(`ts_rank(`)
|
||||
|
@ -455,11 +479,15 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
|
|||
c.w.WriteString(arg.Val)
|
||||
c.w.WriteString(`')`)
|
||||
alias(c.w, col.Name)
|
||||
i++
|
||||
|
||||
case strings.HasPrefix(cn, "search_headline_"):
|
||||
cn = cn[16:]
|
||||
arg := sel.Args["search"]
|
||||
|
||||
if i != 0 {
|
||||
c.w.WriteString(`, `)
|
||||
}
|
||||
//fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`,
|
||||
//c.sel.Table, cn, arg.Val, col.Name)
|
||||
c.w.WriteString(`ts_headlinek(`)
|
||||
|
@ -468,47 +496,63 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
|
|||
c.w.WriteString(arg.Val)
|
||||
c.w.WriteString(`')`)
|
||||
alias(c.w, col.Name)
|
||||
i++
|
||||
|
||||
}
|
||||
} else {
|
||||
pl := funcPrefixLen(cn)
|
||||
if pl == 0 {
|
||||
if i != 0 {
|
||||
c.w.WriteString(`, `)
|
||||
}
|
||||
//fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name)
|
||||
c.w.WriteString(`'`)
|
||||
c.w.WriteString(cn)
|
||||
c.w.WriteString(` not defined'`)
|
||||
alias(c.w, col.Name)
|
||||
} else {
|
||||
isAgg = true
|
||||
i++
|
||||
|
||||
} else if sel.Functions {
|
||||
cn1 := cn[pl:]
|
||||
if _, ok := sel.Allowed[cn1]; !ok {
|
||||
continue
|
||||
}
|
||||
if i != 0 {
|
||||
c.w.WriteString(`, `)
|
||||
}
|
||||
fn := cn[0 : pl-1]
|
||||
cn := cn[pl:]
|
||||
isAgg = true
|
||||
|
||||
//fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, c.sel.Table, cn, col.Name)
|
||||
c.w.WriteString(fn)
|
||||
c.w.WriteString(`(`)
|
||||
colWithTable(c.w, ti.Name, cn)
|
||||
colWithTable(c.w, ti.Name, cn1)
|
||||
c.w.WriteString(`)`)
|
||||
alias(c.w, col.Name)
|
||||
i++
|
||||
|
||||
}
|
||||
}
|
||||
} else {
|
||||
groupBy = append(groupBy, i)
|
||||
groupBy = append(groupBy, n)
|
||||
//fmt.Fprintf(w, `"%s"."%s"`, c.sel.Table, cn)
|
||||
colWithTable(c.w, ti.Name, cn)
|
||||
}
|
||||
|
||||
if i < len(sel.Cols)-1 || len(childCols) != 0 {
|
||||
//io.WriteString(w, ", ")
|
||||
if i != 0 {
|
||||
c.w.WriteString(`, `)
|
||||
}
|
||||
colWithTable(c.w, ti.Name, cn)
|
||||
i++
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
for i, col := range childCols {
|
||||
for _, col := range childCols {
|
||||
if i != 0 {
|
||||
//io.WriteString(w, ", ")
|
||||
c.w.WriteString(`, `)
|
||||
}
|
||||
|
||||
//fmt.Fprintf(w, `"%s"."%s"`, col.Table, col.Name)
|
||||
colWithTable(c.w, col.Table, col.Name)
|
||||
i++
|
||||
}
|
||||
|
||||
c.w.WriteString(` FROM `)
|
||||
|
@ -570,20 +614,22 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
|
|||
}
|
||||
}
|
||||
|
||||
if sel.Action == 0 {
|
||||
if len(sel.Paging.Limit) != 0 {
|
||||
switch {
|
||||
case sel.Paging.NoLimit:
|
||||
break
|
||||
|
||||
case len(sel.Paging.Limit) != 0:
|
||||
//fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit)
|
||||
c.w.WriteString(` LIMIT ('`)
|
||||
c.w.WriteString(sel.Paging.Limit)
|
||||
c.w.WriteString(`') :: integer`)
|
||||
|
||||
} else if ti.Singular {
|
||||
case ti.Singular:
|
||||
c.w.WriteString(` LIMIT ('1') :: integer`)
|
||||
|
||||
} else {
|
||||
default:
|
||||
c.w.WriteString(` LIMIT ('20') :: integer`)
|
||||
}
|
||||
}
|
||||
|
||||
if len(sel.Paging.Offset) != 0 {
|
||||
//fmt.Fprintf(w, ` OFFSET ('%s') :: integer`, c.sel.Paging.Offset)
|
|
@ -22,32 +22,6 @@ func TestMain(m *testing.M) {
|
|||
var err error
|
||||
|
||||
qcompile, err = qcode.NewCompiler(qcode.Config{
|
||||
DefaultFilter: []string{
|
||||
`{ user_id: { _eq: $user_id } }`,
|
||||
},
|
||||
FilterMap: qcode.Filters{
|
||||
All: map[string][]string{
|
||||
"users": []string{
|
||||
"{ id: { eq: $user_id } }",
|
||||
},
|
||||
"products": []string{
|
||||
"{ price: { gt: 0 } }",
|
||||
"{ price: { lt: 8 } }",
|
||||
},
|
||||
"customers": []string{},
|
||||
"mes": []string{
|
||||
"{ id: { eq: $user_id } }",
|
||||
},
|
||||
},
|
||||
Query: map[string][]string{
|
||||
"users": []string{},
|
||||
},
|
||||
Update: map[string][]string{
|
||||
"products": []string{
|
||||
"{ user_id: { eq: $user_id } }",
|
||||
},
|
||||
},
|
||||
},
|
||||
Blocklist: []string{
|
||||
"secret",
|
||||
"password",
|
||||
|
@ -55,6 +29,59 @@ func TestMain(m *testing.M) {
|
|||
},
|
||||
})
|
||||
|
||||
qcompile.AddRole("user", "product", qcode.TRConfig{
|
||||
Query: qcode.QueryConfig{
|
||||
Columns: []string{"id", "name", "price", "users", "customers"},
|
||||
Filter: []string{
|
||||
"{ price: { gt: 0 } }",
|
||||
"{ price: { lt: 8 } }",
|
||||
},
|
||||
},
|
||||
Update: qcode.UpdateConfig{
|
||||
Filter: []string{"{ user_id: { eq: $user_id } }"},
|
||||
},
|
||||
Delete: qcode.DeleteConfig{
|
||||
Filter: []string{
|
||||
"{ price: { gt: 0 } }",
|
||||
"{ price: { lt: 8 } }",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
qcompile.AddRole("anon", "product", qcode.TRConfig{
|
||||
Query: qcode.QueryConfig{
|
||||
Columns: []string{"id", "name"},
|
||||
},
|
||||
})
|
||||
|
||||
qcompile.AddRole("anon1", "product", qcode.TRConfig{
|
||||
Query: qcode.QueryConfig{
|
||||
Columns: []string{"id", "name", "price"},
|
||||
DisableFunctions: true,
|
||||
},
|
||||
})
|
||||
|
||||
qcompile.AddRole("user", "users", qcode.TRConfig{
|
||||
Query: qcode.QueryConfig{
|
||||
Columns: []string{"id", "full_name", "avatar", "email", "products"},
|
||||
},
|
||||
})
|
||||
|
||||
qcompile.AddRole("user", "mes", qcode.TRConfig{
|
||||
Query: qcode.QueryConfig{
|
||||
Columns: []string{"id", "full_name", "avatar"},
|
||||
Filter: []string{
|
||||
"{ id: { eq: $user_id } }",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
qcompile.AddRole("user", "customers", qcode.TRConfig{
|
||||
Query: qcode.QueryConfig{
|
||||
Columns: []string{"id", "email", "full_name", "products"},
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
@ -135,9 +162,8 @@ func TestMain(m *testing.M) {
|
|||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func compileGQLToPSQL(gql string, vars Variables) ([]byte, error) {
|
||||
|
||||
qc, err := qcompile.Compile([]byte(gql))
|
||||
func compileGQLToPSQL(gql string, vars Variables, role string) ([]byte, error) {
|
||||
qc, err := qcompile.Compile([]byte(gql), role)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -147,6 +173,8 @@ func compileGQLToPSQL(gql string, vars Variables) ([]byte, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
//fmt.Println(string(sqlStmt))
|
||||
|
||||
return sqlStmt, nil
|
||||
}
|
||||
|
||||
|
@ -175,7 +203,7 @@ func withComplexArgs(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") < 28) AND (("products"."id") >= 20)) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -203,7 +231,7 @@ func withWhereMultiOr(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") < 20) OR (("products"."price") > 10) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -229,7 +257,7 @@ func withWhereIsNull(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -255,7 +283,7 @@ func withWhereAndList(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -275,7 +303,7 @@ func fetchByID(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 15)) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -295,7 +323,7 @@ func searchQuery(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -318,7 +346,7 @@ func oneToMany(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email", "users"."id" FROM "users" LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -341,7 +369,7 @@ func belongsTo(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "sel_json_1" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -364,7 +392,7 @@ func manyToMany(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -387,7 +415,7 @@ func manyToManyReverse(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -407,7 +435,47 @@ func aggFunction(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(resSQL) != sql {
|
||||
t.Fatal(errNotExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func aggFunctionBlockedByCol(t *testing.T) {
|
||||
gql := `query {
|
||||
products {
|
||||
name
|
||||
count_price
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "anon")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(resSQL) != sql {
|
||||
t.Fatal(errNotExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func aggFunctionDisabled(t *testing.T) {
|
||||
gql := `query {
|
||||
products {
|
||||
name
|
||||
count_price
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "anon1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -427,7 +495,7 @@ func aggFunctionWithFilter(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") > 10)) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -447,7 +515,7 @@ func queryWithVariables(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") = {{product_price}}) AND (("products"."id") = {{product_id}})) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -464,9 +532,9 @@ func syntheticTables(t *testing.T) {
|
|||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337";`
|
||||
sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT ) AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -476,7 +544,7 @@ func syntheticTables(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCompileSelect(t *testing.T) {
|
||||
func TestCompileQuery(t *testing.T) {
|
||||
t.Run("withComplexArgs", withComplexArgs)
|
||||
t.Run("withWhereAndList", withWhereAndList)
|
||||
t.Run("withWhereIsNull", withWhereIsNull)
|
||||
|
@ -488,10 +556,11 @@ func TestCompileSelect(t *testing.T) {
|
|||
t.Run("manyToMany", manyToMany)
|
||||
t.Run("manyToManyReverse", manyToManyReverse)
|
||||
t.Run("aggFunction", aggFunction)
|
||||
t.Run("aggFunctionBlockedByCol", aggFunctionBlockedByCol)
|
||||
t.Run("aggFunctionDisabled", aggFunctionDisabled)
|
||||
t.Run("aggFunctionWithFilter", aggFunctionWithFilter)
|
||||
t.Run("syntheticTables", syntheticTables)
|
||||
t.Run("queryWithVariables", queryWithVariables)
|
||||
|
||||
}
|
||||
|
||||
var benchGQL = []byte(`query {
|
||||
|
@ -526,7 +595,7 @@ func BenchmarkCompile(b *testing.B) {
|
|||
for n := 0; n < b.N; n++ {
|
||||
w.Reset()
|
||||
|
||||
qc, err := qcompile.Compile(benchGQL)
|
||||
qc, err := qcompile.Compile(benchGQL, "user")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
@ -547,7 +616,7 @@ func BenchmarkCompileParallel(b *testing.B) {
|
|||
for pb.Next() {
|
||||
w.Reset()
|
||||
|
||||
qc, err := qcompile.Compile(benchGQL)
|
||||
qc, err := qcompile.Compile(benchGQL, "user")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
|
@ -0,0 +1,99 @@
|
|||
package qcode
|
||||
|
||||
type Config struct {
|
||||
Blocklist []string
|
||||
KeepArgs bool
|
||||
}
|
||||
|
||||
type QueryConfig struct {
|
||||
Limit int
|
||||
Filter []string
|
||||
Columns []string
|
||||
DisableFunctions bool
|
||||
}
|
||||
|
||||
type InsertConfig struct {
|
||||
Filter []string
|
||||
Columns []string
|
||||
Set map[string]string
|
||||
}
|
||||
|
||||
type UpdateConfig struct {
|
||||
Filter []string
|
||||
Columns []string
|
||||
Set map[string]string
|
||||
}
|
||||
|
||||
type DeleteConfig struct {
|
||||
Filter []string
|
||||
Columns []string
|
||||
}
|
||||
|
||||
type TRConfig struct {
|
||||
Query QueryConfig
|
||||
Insert InsertConfig
|
||||
Update UpdateConfig
|
||||
Delete DeleteConfig
|
||||
}
|
||||
|
||||
type trval struct {
|
||||
query struct {
|
||||
limit string
|
||||
fil *Exp
|
||||
cols map[string]struct{}
|
||||
disable struct {
|
||||
funcs bool
|
||||
}
|
||||
}
|
||||
|
||||
insert struct {
|
||||
fil *Exp
|
||||
cols map[string]struct{}
|
||||
set map[string]string
|
||||
}
|
||||
|
||||
update struct {
|
||||
fil *Exp
|
||||
cols map[string]struct{}
|
||||
set map[string]string
|
||||
}
|
||||
|
||||
delete struct {
|
||||
fil *Exp
|
||||
cols map[string]struct{}
|
||||
}
|
||||
}
|
||||
|
||||
func (trv *trval) allowedColumns(qt QType) map[string]struct{} {
|
||||
switch qt {
|
||||
case QTQuery:
|
||||
return trv.query.cols
|
||||
case QTInsert:
|
||||
return trv.insert.cols
|
||||
case QTUpdate:
|
||||
return trv.update.cols
|
||||
case QTDelete:
|
||||
return trv.insert.cols
|
||||
case QTUpsert:
|
||||
return trv.insert.cols
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (trv *trval) filter(qt QType) *Exp {
|
||||
switch qt {
|
||||
case QTQuery:
|
||||
return trv.query.fil
|
||||
case QTInsert:
|
||||
return trv.insert.fil
|
||||
case QTUpdate:
|
||||
return trv.update.fil
|
||||
case QTDelete:
|
||||
return trv.delete.fil
|
||||
case QTUpsert:
|
||||
return trv.insert.fil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
|
@ -5,7 +5,7 @@ func FuzzerEntrypoint(data []byte) int {
|
|||
//testData := string(data)
|
||||
|
||||
qcompile, _ := NewCompiler(Config{})
|
||||
_, err := qcompile.Compile(data)
|
||||
_, err := qcompile.Compile(data, "user")
|
||||
if err != nil {
|
||||
return -1
|
||||
}
|
||||
|
|
|
@ -46,13 +46,18 @@ func compareOp(op1, op2 Operation) error {
|
|||
*/
|
||||
|
||||
func TestCompile1(t *testing.T) {
|
||||
qcompile, _ := NewCompiler(Config{})
|
||||
qc, _ := NewCompiler(Config{})
|
||||
qc.AddRole("user", "product", TRConfig{
|
||||
Query: QueryConfig{
|
||||
Columns: []string{"id", "Name"},
|
||||
},
|
||||
})
|
||||
|
||||
_, err := qcompile.Compile([]byte(`
|
||||
_, err := qc.Compile([]byte(`
|
||||
product(id: 15) {
|
||||
id
|
||||
name
|
||||
}`))
|
||||
}`), "user")
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -60,13 +65,18 @@ func TestCompile1(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCompile2(t *testing.T) {
|
||||
qcompile, _ := NewCompiler(Config{})
|
||||
qc, _ := NewCompiler(Config{})
|
||||
qc.AddRole("user", "product", TRConfig{
|
||||
Query: QueryConfig{
|
||||
Columns: []string{"ID"},
|
||||
},
|
||||
})
|
||||
|
||||
_, err := qcompile.Compile([]byte(`
|
||||
_, err := qc.Compile([]byte(`
|
||||
query { product(id: 15) {
|
||||
id
|
||||
name
|
||||
} }`))
|
||||
} }`), "user")
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -74,15 +84,20 @@ func TestCompile2(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestCompile3(t *testing.T) {
|
||||
qcompile, _ := NewCompiler(Config{})
|
||||
qc, _ := NewCompiler(Config{})
|
||||
qc.AddRole("user", "product", TRConfig{
|
||||
Query: QueryConfig{
|
||||
Columns: []string{"ID"},
|
||||
},
|
||||
})
|
||||
|
||||
_, err := qcompile.Compile([]byte(`
|
||||
_, err := qc.Compile([]byte(`
|
||||
mutation {
|
||||
product(id: 15, name: "Test") {
|
||||
id
|
||||
name
|
||||
}
|
||||
}`))
|
||||
}`), "user")
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -91,7 +106,7 @@ func TestCompile3(t *testing.T) {
|
|||
|
||||
func TestInvalidCompile1(t *testing.T) {
|
||||
qcompile, _ := NewCompiler(Config{})
|
||||
_, err := qcompile.Compile([]byte(`#`))
|
||||
_, err := qcompile.Compile([]byte(`#`), "user")
|
||||
|
||||
if err == nil {
|
||||
t.Fatal(errors.New("expecting an error"))
|
||||
|
@ -100,7 +115,7 @@ func TestInvalidCompile1(t *testing.T) {
|
|||
|
||||
func TestInvalidCompile2(t *testing.T) {
|
||||
qcompile, _ := NewCompiler(Config{})
|
||||
_, err := qcompile.Compile([]byte(`{u(where:{not:0})}`))
|
||||
_, err := qcompile.Compile([]byte(`{u(where:{not:0})}`), "user")
|
||||
|
||||
if err == nil {
|
||||
t.Fatal(errors.New("expecting an error"))
|
||||
|
@ -109,7 +124,7 @@ func TestInvalidCompile2(t *testing.T) {
|
|||
|
||||
func TestEmptyCompile(t *testing.T) {
|
||||
qcompile, _ := NewCompiler(Config{})
|
||||
_, err := qcompile.Compile([]byte(``))
|
||||
_, err := qcompile.Compile([]byte(``), "user")
|
||||
|
||||
if err == nil {
|
||||
t.Fatal(errors.New("expecting an error"))
|
||||
|
@ -144,7 +159,7 @@ func BenchmarkQCompile(b *testing.B) {
|
|||
b.ReportAllocs()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
_, err := qcompile.Compile(gql)
|
||||
_, err := qcompile.Compile(gql, "user")
|
||||
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
|
@ -160,7 +175,7 @@ func BenchmarkQCompileP(b *testing.B) {
|
|||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_, err := qcompile.Compile(gql)
|
||||
_, err := qcompile.Compile(gql, "user")
|
||||
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
|
|
325
qcode/qcode.go
325
qcode/qcode.go
|
@ -3,6 +3,7 @@ package qcode
|
|||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
|
@ -17,25 +18,18 @@ const (
|
|||
maxSelectors = 30
|
||||
|
||||
QTQuery QType = iota + 1
|
||||
QTMutation
|
||||
|
||||
ActionInsert Action = iota + 1
|
||||
ActionUpdate
|
||||
ActionDelete
|
||||
ActionUpsert
|
||||
QTInsert
|
||||
QTUpdate
|
||||
QTDelete
|
||||
QTUpsert
|
||||
)
|
||||
|
||||
type QCode struct {
|
||||
Type QType
|
||||
ActionVar string
|
||||
Selects []Select
|
||||
}
|
||||
|
||||
type Column struct {
|
||||
Table string
|
||||
Name string
|
||||
FieldName string
|
||||
}
|
||||
|
||||
type Select struct {
|
||||
ID int32
|
||||
ParentID int32
|
||||
|
@ -47,9 +41,15 @@ type Select struct {
|
|||
OrderBy []*OrderBy
|
||||
DistinctOn []string
|
||||
Paging Paging
|
||||
Action Action
|
||||
ActionVar string
|
||||
Children []int32
|
||||
Functions bool
|
||||
Allowed map[string]struct{}
|
||||
}
|
||||
|
||||
type Column struct {
|
||||
Table string
|
||||
Name string
|
||||
FieldName string
|
||||
}
|
||||
|
||||
type Exp struct {
|
||||
|
@ -79,6 +79,7 @@ type OrderBy struct {
|
|||
type Paging struct {
|
||||
Limit string
|
||||
Offset string
|
||||
NoLimit bool
|
||||
}
|
||||
|
||||
type ExpOp int
|
||||
|
@ -145,81 +146,23 @@ const (
|
|||
OrderDescNullsLast
|
||||
)
|
||||
|
||||
type Filters struct {
|
||||
All map[string][]string
|
||||
Query map[string][]string
|
||||
Insert map[string][]string
|
||||
Update map[string][]string
|
||||
Delete map[string][]string
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
DefaultFilter []string
|
||||
FilterMap Filters
|
||||
Blocklist []string
|
||||
KeepArgs bool
|
||||
}
|
||||
|
||||
type Compiler struct {
|
||||
df *Exp
|
||||
fm struct {
|
||||
all map[string]*Exp
|
||||
query map[string]*Exp
|
||||
insert map[string]*Exp
|
||||
update map[string]*Exp
|
||||
delete map[string]*Exp
|
||||
}
|
||||
tr map[string]map[string]*trval
|
||||
bl map[string]struct{}
|
||||
ka bool
|
||||
}
|
||||
|
||||
var opMap = map[parserType]QType{
|
||||
opQuery: QTQuery,
|
||||
opMutate: QTMutation,
|
||||
}
|
||||
|
||||
var expPool = sync.Pool{
|
||||
New: func() interface{} { return &Exp{doFree: true} },
|
||||
}
|
||||
|
||||
func NewCompiler(c Config) (*Compiler, error) {
|
||||
var err error
|
||||
co := &Compiler{ka: c.KeepArgs}
|
||||
|
||||
co.tr = make(map[string]map[string]*trval)
|
||||
co.bl = make(map[string]struct{}, len(c.Blocklist))
|
||||
|
||||
for i := range c.Blocklist {
|
||||
co.bl[c.Blocklist[i]] = struct{}{}
|
||||
}
|
||||
|
||||
co.df, err = compileFilter(c.DefaultFilter)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
co.fm.all, err = buildFilters(c.FilterMap.All)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
co.fm.query, err = buildFilters(c.FilterMap.Query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
co.fm.insert, err = buildFilters(c.FilterMap.Insert)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
co.fm.update, err = buildFilters(c.FilterMap.Update)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
co.fm.delete, err = buildFilters(c.FilterMap.Delete)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
co.bl[strings.ToLower(c.Blocklist[i])] = struct{}{}
|
||||
}
|
||||
|
||||
seedExp := [100]Exp{}
|
||||
|
@ -232,58 +175,99 @@ func NewCompiler(c Config) (*Compiler, error) {
|
|||
return co, nil
|
||||
}
|
||||
|
||||
func buildFilters(filMap map[string][]string) (map[string]*Exp, error) {
|
||||
fm := make(map[string]*Exp, len(filMap))
|
||||
|
||||
for k, v := range filMap {
|
||||
fil, err := compileFilter(v)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
singular := flect.Singularize(k)
|
||||
plural := flect.Pluralize(k)
|
||||
|
||||
fm[singular] = fil
|
||||
fm[plural] = fil
|
||||
}
|
||||
|
||||
return fm, nil
|
||||
}
|
||||
|
||||
func (com *Compiler) Compile(query []byte) (*QCode, error) {
|
||||
var qc QCode
|
||||
func (com *Compiler) AddRole(role, table string, trc TRConfig) error {
|
||||
var err error
|
||||
trv := &trval{}
|
||||
|
||||
toMap := func(cols []string) map[string]struct{} {
|
||||
m := make(map[string]struct{}, len(cols))
|
||||
for i := range cols {
|
||||
m[strings.ToLower(cols[i])] = struct{}{}
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
// query config
|
||||
trv.query.fil, err = compileFilter(trc.Query.Filter)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if trc.Query.Limit > 0 {
|
||||
trv.query.limit = strconv.Itoa(trc.Query.Limit)
|
||||
}
|
||||
trv.query.cols = toMap(trc.Query.Columns)
|
||||
trv.query.disable.funcs = trc.Query.DisableFunctions
|
||||
|
||||
// insert config
|
||||
if trv.insert.fil, err = compileFilter(trc.Insert.Filter); err != nil {
|
||||
return err
|
||||
}
|
||||
trv.insert.cols = toMap(trc.Insert.Columns)
|
||||
|
||||
// update config
|
||||
if trv.update.fil, err = compileFilter(trc.Update.Filter); err != nil {
|
||||
return err
|
||||
}
|
||||
trv.insert.cols = toMap(trc.Insert.Columns)
|
||||
trv.insert.set = trc.Insert.Set
|
||||
|
||||
// delete config
|
||||
if trv.delete.fil, err = compileFilter(trc.Delete.Filter); err != nil {
|
||||
return err
|
||||
}
|
||||
trv.delete.cols = toMap(trc.Delete.Columns)
|
||||
|
||||
singular := flect.Singularize(table)
|
||||
plural := flect.Pluralize(table)
|
||||
|
||||
if _, ok := com.tr[role]; !ok {
|
||||
com.tr[role] = make(map[string]*trval)
|
||||
}
|
||||
|
||||
com.tr[role][singular] = trv
|
||||
com.tr[role][plural] = trv
|
||||
return nil
|
||||
}
|
||||
|
||||
func (com *Compiler) Compile(query []byte, role string) (*QCode, error) {
|
||||
var err error
|
||||
|
||||
qc := QCode{Type: QTQuery}
|
||||
|
||||
op, err := Parse(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
qc.Selects, err = com.compileQuery(op)
|
||||
if err != nil {
|
||||
if err = com.compileQuery(&qc, op, role); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if t, ok := opMap[op.Type]; ok {
|
||||
qc.Type = t
|
||||
} else {
|
||||
return nil, fmt.Errorf("Unknown operation type %d", op.Type)
|
||||
}
|
||||
|
||||
opPool.Put(op)
|
||||
|
||||
return &qc, nil
|
||||
}
|
||||
|
||||
func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
|
||||
func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error {
|
||||
id := int32(0)
|
||||
parentID := int32(0)
|
||||
|
||||
if len(op.Fields) == 0 {
|
||||
return errors.New("invalid graphql no query found")
|
||||
}
|
||||
|
||||
if op.Type == opMutate {
|
||||
if err := com.setMutationType(qc, op.Fields[0].Args); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
selects := make([]Select, 0, 5)
|
||||
st := NewStack()
|
||||
action := qc.Type
|
||||
|
||||
if len(op.Fields) == 0 {
|
||||
return nil, errors.New("empty query")
|
||||
return errors.New("empty query")
|
||||
}
|
||||
st.Push(op.Fields[0].ID)
|
||||
|
||||
|
@ -293,7 +277,7 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
|
|||
}
|
||||
|
||||
if id >= maxSelectors {
|
||||
return nil, fmt.Errorf("selector limit reached (%d)", maxSelectors)
|
||||
return fmt.Errorf("selector limit reached (%d)", maxSelectors)
|
||||
}
|
||||
|
||||
fid := st.Pop()
|
||||
|
@ -303,14 +287,28 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
|
|||
continue
|
||||
}
|
||||
|
||||
trv, ok := com.tr[role][field.Name]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
selects = append(selects, Select{
|
||||
ID: id,
|
||||
ParentID: parentID,
|
||||
Table: field.Name,
|
||||
Children: make([]int32, 0, 5),
|
||||
Allowed: trv.allowedColumns(action),
|
||||
})
|
||||
s := &selects[(len(selects) - 1)]
|
||||
|
||||
if action == QTQuery {
|
||||
s.Functions = !trv.query.disable.funcs
|
||||
|
||||
if len(trv.query.limit) != 0 {
|
||||
s.Paging.Limit = trv.query.limit
|
||||
}
|
||||
}
|
||||
|
||||
if s.ID != 0 {
|
||||
p := &selects[s.ParentID]
|
||||
p.Children = append(p.Children, s.ID)
|
||||
|
@ -322,12 +320,13 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
|
|||
s.FieldName = s.Table
|
||||
}
|
||||
|
||||
err := com.compileArgs(s, field.Args)
|
||||
err := com.compileArgs(qc, s, field.Args)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return err
|
||||
}
|
||||
|
||||
s.Cols = make([]Column, 0, len(field.Children))
|
||||
action = QTQuery
|
||||
|
||||
for _, cid := range field.Children {
|
||||
f := op.Fields[cid]
|
||||
|
@ -356,36 +355,14 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
|
|||
}
|
||||
|
||||
if id == 0 {
|
||||
return nil, errors.New("invalid query")
|
||||
return errors.New("invalid query")
|
||||
}
|
||||
|
||||
var fil *Exp
|
||||
|
||||
root := &selects[0]
|
||||
|
||||
switch op.Type {
|
||||
case opQuery:
|
||||
fil, _ = com.fm.query[root.Table]
|
||||
|
||||
case opMutate:
|
||||
switch root.Action {
|
||||
case ActionInsert:
|
||||
fil, _ = com.fm.insert[root.Table]
|
||||
case ActionUpdate:
|
||||
fil, _ = com.fm.update[root.Table]
|
||||
case ActionDelete:
|
||||
fil, _ = com.fm.delete[root.Table]
|
||||
case ActionUpsert:
|
||||
fil, _ = com.fm.insert[root.Table]
|
||||
}
|
||||
}
|
||||
|
||||
if fil == nil {
|
||||
fil, _ = com.fm.all[root.Table]
|
||||
}
|
||||
|
||||
if fil == nil {
|
||||
fil = com.df
|
||||
if trv, ok := com.tr[role][op.Fields[0].Name]; ok {
|
||||
fil = trv.filter(qc.Type)
|
||||
}
|
||||
|
||||
if fil != nil && fil.Op != OpNop {
|
||||
|
@ -403,10 +380,11 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
|
|||
}
|
||||
}
|
||||
|
||||
return selects[:id], nil
|
||||
qc.Selects = selects[:id]
|
||||
return nil
|
||||
}
|
||||
|
||||
func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
|
||||
func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg) error {
|
||||
var err error
|
||||
|
||||
if com.ka {
|
||||
|
@ -418,9 +396,7 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
|
|||
|
||||
switch arg.Name {
|
||||
case "id":
|
||||
if sel.ID == 0 {
|
||||
err = com.compileArgID(sel, arg)
|
||||
}
|
||||
case "search":
|
||||
err = com.compileArgSearch(sel, arg)
|
||||
case "where":
|
||||
|
@ -433,18 +409,6 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
|
|||
err = com.compileArgLimit(sel, arg)
|
||||
case "offset":
|
||||
err = com.compileArgOffset(sel, arg)
|
||||
case "insert":
|
||||
sel.Action = ActionInsert
|
||||
err = com.compileArgAction(sel, arg)
|
||||
case "update":
|
||||
sel.Action = ActionUpdate
|
||||
err = com.compileArgAction(sel, arg)
|
||||
case "upsert":
|
||||
sel.Action = ActionUpsert
|
||||
err = com.compileArgAction(sel, arg)
|
||||
case "delete":
|
||||
sel.Action = ActionDelete
|
||||
err = com.compileArgAction(sel, arg)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
@ -461,6 +425,45 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (com *Compiler) setMutationType(qc *QCode, args []Arg) error {
|
||||
setActionVar := func(arg *Arg) error {
|
||||
if arg.Val.Type != nodeVar {
|
||||
return fmt.Errorf("value for argument '%s' must be a variable", arg.Name)
|
||||
}
|
||||
qc.ActionVar = arg.Val.Val
|
||||
return nil
|
||||
}
|
||||
|
||||
for i := range args {
|
||||
arg := &args[i]
|
||||
|
||||
switch arg.Name {
|
||||
case "insert":
|
||||
qc.Type = QTInsert
|
||||
return setActionVar(arg)
|
||||
case "update":
|
||||
qc.Type = QTUpdate
|
||||
return setActionVar(arg)
|
||||
case "upsert":
|
||||
qc.Type = QTUpsert
|
||||
return setActionVar(arg)
|
||||
case "delete":
|
||||
qc.Type = QTDelete
|
||||
|
||||
if arg.Val.Type != nodeBool {
|
||||
return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name)
|
||||
}
|
||||
|
||||
if arg.Val.Val == "false" {
|
||||
qc.Type = QTQuery
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (com *Compiler) compileArgObj(st *util.Stack, arg *Arg) (*Exp, error) {
|
||||
if arg.Val.Type != nodeObj {
|
||||
return nil, fmt.Errorf("expecting an object")
|
||||
|
@ -540,6 +543,10 @@ func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (*
|
|||
}
|
||||
|
||||
func (com *Compiler) compileArgID(sel *Select, arg *Arg) error {
|
||||
if sel.ID != 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if sel.Where != nil && sel.Where.Op == OpEqID {
|
||||
return nil
|
||||
}
|
||||
|
@ -732,26 +739,6 @@ func (com *Compiler) compileArgOffset(sel *Select, arg *Arg) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (com *Compiler) compileArgAction(sel *Select, arg *Arg) error {
|
||||
switch sel.Action {
|
||||
case ActionDelete:
|
||||
if arg.Val.Type != nodeBool {
|
||||
return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name)
|
||||
}
|
||||
if arg.Val.Val == "false" {
|
||||
sel.Action = 0
|
||||
}
|
||||
|
||||
default:
|
||||
if arg.Val.Type != nodeVar {
|
||||
return fmt.Errorf("value for argument '%s' must be a variable", arg.Name)
|
||||
}
|
||||
sel.ActionVar = arg.Val.Val
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {
|
||||
name := node.Name
|
||||
if name[0] == '_' {
|
||||
|
|
|
@ -71,15 +71,11 @@ type config struct {
|
|||
} `mapstructure:"database"`
|
||||
|
||||
Tables []configTable
|
||||
Roles []configRoles
|
||||
}
|
||||
|
||||
type configTable struct {
|
||||
Name string
|
||||
Filter []string
|
||||
FilterQuery []string `mapstructure:"filter_query"`
|
||||
FilterInsert []string `mapstructure:"filter_insert"`
|
||||
FilterUpdate []string `mapstructure:"filter_update"`
|
||||
FilterDelete []string `mapstructure:"filter_delete"`
|
||||
Table string
|
||||
Blocklist []string
|
||||
Remotes []configRemote
|
||||
|
@ -98,6 +94,41 @@ type configRemote struct {
|
|||
} `mapstructure:"set_headers"`
|
||||
}
|
||||
|
||||
type configRoles struct {
|
||||
Name string
|
||||
Tables []struct {
|
||||
Name string
|
||||
|
||||
Query struct {
|
||||
Limit int
|
||||
Filter []string
|
||||
Columns []string
|
||||
DisableAggregation bool `mapstructure:"disable_aggregation"`
|
||||
Deny bool
|
||||
}
|
||||
|
||||
Insert struct {
|
||||
Filter []string
|
||||
Columns []string
|
||||
Set map[string]string
|
||||
Deny bool
|
||||
}
|
||||
|
||||
Update struct {
|
||||
Filter []string
|
||||
Columns []string
|
||||
Set map[string]string
|
||||
Deny bool
|
||||
}
|
||||
|
||||
Delete struct {
|
||||
Filter []string
|
||||
Columns []string
|
||||
Deny bool
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newConfig() *viper.Viper {
|
||||
vi := viper.New()
|
||||
|
||||
|
|
|
@ -59,7 +59,7 @@ func (c *coreContext) execQuery() ([]byte, error) {
|
|||
|
||||
} else {
|
||||
|
||||
qc, err = qcompile.Compile([]byte(c.req.Query))
|
||||
qc, err = qcompile.Compile([]byte(c.req.Query), "user")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
|
|
@ -40,7 +40,7 @@ func prepareStmt(key, gql string, varBytes json.RawMessage) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
qc, err := qcompile.Compile([]byte(gql))
|
||||
qc, err := qcompile.Compile([]byte(gql), "user")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
70
serv/serv.go
70
serv/serv.go
|
@ -12,7 +12,6 @@ import (
|
|||
rice "github.com/GeertJohan/go.rice"
|
||||
"github.com/dosco/super-graph/psql"
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
"github.com/gobuffalo/flect"
|
||||
)
|
||||
|
||||
func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
|
||||
|
@ -22,49 +21,50 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
|
|||
}
|
||||
|
||||
conf := qcode.Config{
|
||||
DefaultFilter: c.DB.Defaults.Filter,
|
||||
FilterMap: qcode.Filters{
|
||||
All: make(map[string][]string, len(c.Tables)),
|
||||
Query: make(map[string][]string, len(c.Tables)),
|
||||
Insert: make(map[string][]string, len(c.Tables)),
|
||||
Update: make(map[string][]string, len(c.Tables)),
|
||||
Delete: make(map[string][]string, len(c.Tables)),
|
||||
},
|
||||
Blocklist: c.DB.Defaults.Blocklist,
|
||||
KeepArgs: false,
|
||||
}
|
||||
|
||||
for i := range c.Tables {
|
||||
t := c.Tables[i]
|
||||
|
||||
singular := flect.Singularize(t.Name)
|
||||
plural := flect.Pluralize(t.Name)
|
||||
|
||||
setFilter := func(fm map[string][]string, fil []string) {
|
||||
switch {
|
||||
case len(fil) == 0:
|
||||
return
|
||||
case fil[0] == "none" || len(fil[0]) == 0:
|
||||
fm[singular] = []string{}
|
||||
fm[plural] = []string{}
|
||||
default:
|
||||
fm[singular] = t.Filter
|
||||
fm[plural] = t.Filter
|
||||
}
|
||||
}
|
||||
|
||||
setFilter(conf.FilterMap.All, t.Filter)
|
||||
setFilter(conf.FilterMap.Query, t.FilterQuery)
|
||||
setFilter(conf.FilterMap.Insert, t.FilterInsert)
|
||||
setFilter(conf.FilterMap.Update, t.FilterUpdate)
|
||||
setFilter(conf.FilterMap.Delete, t.FilterDelete)
|
||||
}
|
||||
|
||||
qc, err := qcode.NewCompiler(conf)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
for _, r := range c.Roles {
|
||||
for _, t := range r.Tables {
|
||||
query := qcode.QueryConfig{
|
||||
Limit: t.Query.Limit,
|
||||
Filter: t.Query.Filter,
|
||||
Columns: t.Query.Columns,
|
||||
DisableFunctions: t.Query.DisableAggregation,
|
||||
}
|
||||
|
||||
insert := qcode.InsertConfig{
|
||||
Filter: t.Insert.Filter,
|
||||
Columns: t.Insert.Columns,
|
||||
Set: t.Insert.Set,
|
||||
}
|
||||
|
||||
update := qcode.UpdateConfig{
|
||||
Filter: t.Insert.Filter,
|
||||
Columns: t.Insert.Columns,
|
||||
Set: t.Insert.Set,
|
||||
}
|
||||
|
||||
delete := qcode.DeleteConfig{
|
||||
Filter: t.Insert.Filter,
|
||||
Columns: t.Insert.Columns,
|
||||
}
|
||||
|
||||
qc.AddRole(r.Name, t.Name, qcode.TRConfig{
|
||||
Query: query,
|
||||
Insert: insert,
|
||||
Update: update,
|
||||
Delete: delete,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
pc := psql.NewCompiler(psql.Config{
|
||||
Schema: schema,
|
||||
Vars: c.getVariables(),
|
||||
|
|
Loading…
Reference in New Issue