Add role based access control
This commit is contained in:
@ -10,7 +10,7 @@ import (
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
)
|
||||
|
||||
var zeroPaging = qcode.Paging{}
|
||||
var noLimit = qcode.Paging{NoLimit: true}
|
||||
|
||||
func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) {
|
||||
if len(qc.Selects) == 0 {
|
||||
@ -29,23 +29,23 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia
|
||||
quoted(c.w, ti.Name)
|
||||
c.w.WriteString(` AS `)
|
||||
|
||||
switch root.Action {
|
||||
case qcode.ActionInsert:
|
||||
switch qc.Type {
|
||||
case qcode.QTInsert:
|
||||
if _, err := c.renderInsert(qc, w, vars, ti); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
case qcode.ActionUpdate:
|
||||
case qcode.QTUpdate:
|
||||
if _, err := c.renderUpdate(qc, w, vars, ti); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
case qcode.ActionUpsert:
|
||||
case qcode.QTUpsert:
|
||||
if _, err := c.renderUpsert(qc, w, vars, ti); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
case qcode.ActionDelete:
|
||||
case qcode.QTDelete:
|
||||
if _, err := c.renderDelete(qc, w, vars, ti); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
@ -56,22 +56,23 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia
|
||||
|
||||
io.WriteString(c.w, ` RETURNING *) `)
|
||||
|
||||
root.Paging = zeroPaging
|
||||
root.Paging = noLimit
|
||||
root.DistinctOn = root.DistinctOn[:]
|
||||
root.OrderBy = root.OrderBy[:]
|
||||
root.Where = nil
|
||||
root.Args = nil
|
||||
|
||||
qc.Type = qcode.QTQuery
|
||||
|
||||
return c.compileQuery(qc, w)
|
||||
}
|
||||
|
||||
func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer,
|
||||
vars Variables, ti *DBTableInfo) (uint32, error) {
|
||||
root := &qc.Selects[0]
|
||||
|
||||
insert, ok := vars[root.ActionVar]
|
||||
insert, ok := vars[qc.ActionVar]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar)
|
||||
return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar)
|
||||
}
|
||||
|
||||
jt, array, err := jsn.Tree(insert)
|
||||
@ -80,7 +81,7 @@ func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer,
|
||||
}
|
||||
|
||||
c.w.WriteString(`(WITH "input" AS (SELECT {{`)
|
||||
c.w.WriteString(root.ActionVar)
|
||||
c.w.WriteString(qc.ActionVar)
|
||||
c.w.WriteString(`}}::json AS j) INSERT INTO `)
|
||||
quoted(c.w, ti.Name)
|
||||
io.WriteString(c.w, ` (`)
|
||||
@ -106,12 +107,18 @@ func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer,
|
||||
|
||||
func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w *bytes.Buffer,
|
||||
jt map[string]interface{}, ti *DBTableInfo) (uint32, error) {
|
||||
root := &qc.Selects[0]
|
||||
|
||||
i := 0
|
||||
for _, cn := range ti.ColumnNames {
|
||||
if _, ok := jt[cn]; !ok {
|
||||
continue
|
||||
}
|
||||
if len(root.Allowed) != 0 {
|
||||
if _, ok := root.Allowed[cn]; !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if i != 0 {
|
||||
io.WriteString(c.w, `, `)
|
||||
}
|
||||
@ -126,9 +133,9 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer,
|
||||
vars Variables, ti *DBTableInfo) (uint32, error) {
|
||||
root := &qc.Selects[0]
|
||||
|
||||
update, ok := vars[root.ActionVar]
|
||||
update, ok := vars[qc.ActionVar]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar)
|
||||
return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar)
|
||||
}
|
||||
|
||||
jt, array, err := jsn.Tree(update)
|
||||
@ -137,7 +144,7 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer,
|
||||
}
|
||||
|
||||
c.w.WriteString(`(WITH "input" AS (SELECT {{`)
|
||||
c.w.WriteString(root.ActionVar)
|
||||
c.w.WriteString(qc.ActionVar)
|
||||
c.w.WriteString(`}}::json AS j) UPDATE `)
|
||||
quoted(c.w, ti.Name)
|
||||
io.WriteString(c.w, ` SET (`)
|
||||
@ -183,11 +190,10 @@ func (c *compilerContext) renderDelete(qc *qcode.QCode, w *bytes.Buffer,
|
||||
|
||||
func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer,
|
||||
vars Variables, ti *DBTableInfo) (uint32, error) {
|
||||
root := &qc.Selects[0]
|
||||
|
||||
upsert, ok := vars[root.ActionVar]
|
||||
upsert, ok := vars[qc.ActionVar]
|
||||
if !ok {
|
||||
return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar)
|
||||
return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar)
|
||||
}
|
||||
|
||||
jt, _, err := jsn.Tree(upsert)
|
@ -18,7 +18,7 @@ func simpleInsert(t *testing.T) {
|
||||
"data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`),
|
||||
}
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, vars)
|
||||
resSQL, err := compileGQLToPSQL(gql, vars, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -42,7 +42,7 @@ func singleInsert(t *testing.T) {
|
||||
"insert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc", "user_id": 5 }`),
|
||||
}
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, vars)
|
||||
resSQL, err := compileGQLToPSQL(gql, vars, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -54,7 +54,7 @@ func singleInsert(t *testing.T) {
|
||||
|
||||
func bulkInsert(t *testing.T) {
|
||||
gql := `mutation {
|
||||
product(id: 15, insert: $insert) {
|
||||
product(name: "test", id: 15, insert: $insert) {
|
||||
id
|
||||
name
|
||||
}
|
||||
@ -66,7 +66,7 @@ func bulkInsert(t *testing.T) {
|
||||
"insert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
|
||||
}
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, vars)
|
||||
resSQL, err := compileGQLToPSQL(gql, vars, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -90,7 +90,7 @@ func singleUpsert(t *testing.T) {
|
||||
"upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
|
||||
}
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, vars)
|
||||
resSQL, err := compileGQLToPSQL(gql, vars, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -114,7 +114,7 @@ func bulkUpsert(t *testing.T) {
|
||||
"upsert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
|
||||
}
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, vars)
|
||||
resSQL, err := compileGQLToPSQL(gql, vars, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -138,7 +138,7 @@ func singleUpdate(t *testing.T) {
|
||||
"update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
|
||||
}
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, vars)
|
||||
resSQL, err := compileGQLToPSQL(gql, vars, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -162,7 +162,7 @@ func delete(t *testing.T) {
|
||||
"update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
|
||||
}
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, vars)
|
||||
resSQL, err := compileGQLToPSQL(gql, vars, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -172,7 +172,7 @@ func delete(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileInsert(t *testing.T) {
|
||||
func TestCompileMutate(t *testing.T) {
|
||||
t.Run("simpleInsert", simpleInsert)
|
||||
t.Run("singleInsert", singleInsert)
|
||||
t.Run("bulkInsert", bulkInsert)
|
@ -64,11 +64,11 @@ func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (u
|
||||
switch qc.Type {
|
||||
case qcode.QTQuery:
|
||||
return co.compileQuery(qc, w)
|
||||
case qcode.QTMutation:
|
||||
case qcode.QTInsert, qcode.QTUpdate, qcode.QTDelete, qcode.QTUpsert:
|
||||
return co.compileMutation(qc, w, vars)
|
||||
}
|
||||
|
||||
return 0, errors.New("unknown operation")
|
||||
return 0, fmt.Errorf("Unknown operation type %d", qc.Type)
|
||||
}
|
||||
|
||||
func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) {
|
||||
@ -295,19 +295,21 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo)
|
||||
}
|
||||
}
|
||||
|
||||
if sel.Action == 0 {
|
||||
if len(sel.Paging.Limit) != 0 {
|
||||
//fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit)
|
||||
c.w.WriteString(` LIMIT ('`)
|
||||
c.w.WriteString(sel.Paging.Limit)
|
||||
c.w.WriteString(`') :: integer`)
|
||||
switch {
|
||||
case sel.Paging.NoLimit:
|
||||
break
|
||||
|
||||
} else if ti.Singular {
|
||||
c.w.WriteString(` LIMIT ('1') :: integer`)
|
||||
case len(sel.Paging.Limit) != 0:
|
||||
//fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit)
|
||||
c.w.WriteString(` LIMIT ('`)
|
||||
c.w.WriteString(sel.Paging.Limit)
|
||||
c.w.WriteString(`') :: integer`)
|
||||
|
||||
} else {
|
||||
c.w.WriteString(` LIMIT ('20') :: integer`)
|
||||
}
|
||||
case ti.Singular:
|
||||
c.w.WriteString(` LIMIT ('1') :: integer`)
|
||||
|
||||
default:
|
||||
c.w.WriteString(` LIMIT ('20') :: integer`)
|
||||
}
|
||||
|
||||
if len(sel.Paging.Offset) != 0 {
|
||||
@ -370,13 +372,31 @@ func (c *compilerContext) renderJoinTable(sel *qcode.Select) error {
|
||||
}
|
||||
|
||||
func (c *compilerContext) renderColumns(sel *qcode.Select, ti *DBTableInfo) {
|
||||
for i, col := range sel.Cols {
|
||||
i := 0
|
||||
for _, col := range sel.Cols {
|
||||
if len(sel.Allowed) != 0 {
|
||||
n := funcPrefixLen(col.Name)
|
||||
if n != 0 {
|
||||
if sel.Functions == false {
|
||||
continue
|
||||
}
|
||||
if _, ok := sel.Allowed[col.Name[n:]]; !ok {
|
||||
continue
|
||||
}
|
||||
} else {
|
||||
if _, ok := sel.Allowed[col.Name]; !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if i != 0 {
|
||||
io.WriteString(c.w, ", ")
|
||||
}
|
||||
//fmt.Fprintf(w, `"%s_%d"."%s" AS "%s"`,
|
||||
//c.sel.Table, c.sel.ID, col.Name, col.FieldName)
|
||||
colWithTableIDAlias(c.w, ti.Name, sel.ID, col.Name, col.FieldName)
|
||||
i++
|
||||
}
|
||||
}
|
||||
|
||||
@ -435,7 +455,8 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
|
||||
|
||||
c.w.WriteString(` FROM (SELECT `)
|
||||
|
||||
for i, col := range sel.Cols {
|
||||
i := 0
|
||||
for n, col := range sel.Cols {
|
||||
cn := col.Name
|
||||
|
||||
_, isRealCol := ti.Columns[cn]
|
||||
@ -447,6 +468,9 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
|
||||
cn = ti.TSVCol
|
||||
arg := sel.Args["search"]
|
||||
|
||||
if i != 0 {
|
||||
c.w.WriteString(`, `)
|
||||
}
|
||||
//fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`,
|
||||
//c.sel.Table, cn, arg.Val, col.Name)
|
||||
c.w.WriteString(`ts_rank(`)
|
||||
@ -455,11 +479,15 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
|
||||
c.w.WriteString(arg.Val)
|
||||
c.w.WriteString(`')`)
|
||||
alias(c.w, col.Name)
|
||||
i++
|
||||
|
||||
case strings.HasPrefix(cn, "search_headline_"):
|
||||
cn = cn[16:]
|
||||
arg := sel.Args["search"]
|
||||
|
||||
if i != 0 {
|
||||
c.w.WriteString(`, `)
|
||||
}
|
||||
//fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`,
|
||||
//c.sel.Table, cn, arg.Val, col.Name)
|
||||
c.w.WriteString(`ts_headlinek(`)
|
||||
@ -468,47 +496,63 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
|
||||
c.w.WriteString(arg.Val)
|
||||
c.w.WriteString(`')`)
|
||||
alias(c.w, col.Name)
|
||||
i++
|
||||
|
||||
}
|
||||
} else {
|
||||
pl := funcPrefixLen(cn)
|
||||
if pl == 0 {
|
||||
if i != 0 {
|
||||
c.w.WriteString(`, `)
|
||||
}
|
||||
//fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name)
|
||||
c.w.WriteString(`'`)
|
||||
c.w.WriteString(cn)
|
||||
c.w.WriteString(` not defined'`)
|
||||
alias(c.w, col.Name)
|
||||
} else {
|
||||
isAgg = true
|
||||
i++
|
||||
|
||||
} else if sel.Functions {
|
||||
cn1 := cn[pl:]
|
||||
if _, ok := sel.Allowed[cn1]; !ok {
|
||||
continue
|
||||
}
|
||||
if i != 0 {
|
||||
c.w.WriteString(`, `)
|
||||
}
|
||||
fn := cn[0 : pl-1]
|
||||
cn := cn[pl:]
|
||||
isAgg = true
|
||||
|
||||
//fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, c.sel.Table, cn, col.Name)
|
||||
c.w.WriteString(fn)
|
||||
c.w.WriteString(`(`)
|
||||
colWithTable(c.w, ti.Name, cn)
|
||||
colWithTable(c.w, ti.Name, cn1)
|
||||
c.w.WriteString(`)`)
|
||||
alias(c.w, col.Name)
|
||||
i++
|
||||
|
||||
}
|
||||
}
|
||||
} else {
|
||||
groupBy = append(groupBy, i)
|
||||
groupBy = append(groupBy, n)
|
||||
//fmt.Fprintf(w, `"%s"."%s"`, c.sel.Table, cn)
|
||||
if i != 0 {
|
||||
c.w.WriteString(`, `)
|
||||
}
|
||||
colWithTable(c.w, ti.Name, cn)
|
||||
}
|
||||
i++
|
||||
|
||||
if i < len(sel.Cols)-1 || len(childCols) != 0 {
|
||||
//io.WriteString(w, ", ")
|
||||
c.w.WriteString(`, `)
|
||||
}
|
||||
}
|
||||
|
||||
for i, col := range childCols {
|
||||
for _, col := range childCols {
|
||||
if i != 0 {
|
||||
//io.WriteString(w, ", ")
|
||||
c.w.WriteString(`, `)
|
||||
}
|
||||
|
||||
//fmt.Fprintf(w, `"%s"."%s"`, col.Table, col.Name)
|
||||
colWithTable(c.w, col.Table, col.Name)
|
||||
i++
|
||||
}
|
||||
|
||||
c.w.WriteString(` FROM `)
|
||||
@ -570,19 +614,21 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
|
||||
}
|
||||
}
|
||||
|
||||
if sel.Action == 0 {
|
||||
if len(sel.Paging.Limit) != 0 {
|
||||
//fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit)
|
||||
c.w.WriteString(` LIMIT ('`)
|
||||
c.w.WriteString(sel.Paging.Limit)
|
||||
c.w.WriteString(`') :: integer`)
|
||||
switch {
|
||||
case sel.Paging.NoLimit:
|
||||
break
|
||||
|
||||
} else if ti.Singular {
|
||||
c.w.WriteString(` LIMIT ('1') :: integer`)
|
||||
case len(sel.Paging.Limit) != 0:
|
||||
//fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit)
|
||||
c.w.WriteString(` LIMIT ('`)
|
||||
c.w.WriteString(sel.Paging.Limit)
|
||||
c.w.WriteString(`') :: integer`)
|
||||
|
||||
} else {
|
||||
c.w.WriteString(` LIMIT ('20') :: integer`)
|
||||
}
|
||||
case ti.Singular:
|
||||
c.w.WriteString(` LIMIT ('1') :: integer`)
|
||||
|
||||
default:
|
||||
c.w.WriteString(` LIMIT ('20') :: integer`)
|
||||
}
|
||||
|
||||
if len(sel.Paging.Offset) != 0 {
|
@ -22,32 +22,6 @@ func TestMain(m *testing.M) {
|
||||
var err error
|
||||
|
||||
qcompile, err = qcode.NewCompiler(qcode.Config{
|
||||
DefaultFilter: []string{
|
||||
`{ user_id: { _eq: $user_id } }`,
|
||||
},
|
||||
FilterMap: qcode.Filters{
|
||||
All: map[string][]string{
|
||||
"users": []string{
|
||||
"{ id: { eq: $user_id } }",
|
||||
},
|
||||
"products": []string{
|
||||
"{ price: { gt: 0 } }",
|
||||
"{ price: { lt: 8 } }",
|
||||
},
|
||||
"customers": []string{},
|
||||
"mes": []string{
|
||||
"{ id: { eq: $user_id } }",
|
||||
},
|
||||
},
|
||||
Query: map[string][]string{
|
||||
"users": []string{},
|
||||
},
|
||||
Update: map[string][]string{
|
||||
"products": []string{
|
||||
"{ user_id: { eq: $user_id } }",
|
||||
},
|
||||
},
|
||||
},
|
||||
Blocklist: []string{
|
||||
"secret",
|
||||
"password",
|
||||
@ -55,6 +29,59 @@ func TestMain(m *testing.M) {
|
||||
},
|
||||
})
|
||||
|
||||
qcompile.AddRole("user", "product", qcode.TRConfig{
|
||||
Query: qcode.QueryConfig{
|
||||
Columns: []string{"id", "name", "price", "users", "customers"},
|
||||
Filter: []string{
|
||||
"{ price: { gt: 0 } }",
|
||||
"{ price: { lt: 8 } }",
|
||||
},
|
||||
},
|
||||
Update: qcode.UpdateConfig{
|
||||
Filter: []string{"{ user_id: { eq: $user_id } }"},
|
||||
},
|
||||
Delete: qcode.DeleteConfig{
|
||||
Filter: []string{
|
||||
"{ price: { gt: 0 } }",
|
||||
"{ price: { lt: 8 } }",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
qcompile.AddRole("anon", "product", qcode.TRConfig{
|
||||
Query: qcode.QueryConfig{
|
||||
Columns: []string{"id", "name"},
|
||||
},
|
||||
})
|
||||
|
||||
qcompile.AddRole("anon1", "product", qcode.TRConfig{
|
||||
Query: qcode.QueryConfig{
|
||||
Columns: []string{"id", "name", "price"},
|
||||
DisableFunctions: true,
|
||||
},
|
||||
})
|
||||
|
||||
qcompile.AddRole("user", "users", qcode.TRConfig{
|
||||
Query: qcode.QueryConfig{
|
||||
Columns: []string{"id", "full_name", "avatar", "email", "products"},
|
||||
},
|
||||
})
|
||||
|
||||
qcompile.AddRole("user", "mes", qcode.TRConfig{
|
||||
Query: qcode.QueryConfig{
|
||||
Columns: []string{"id", "full_name", "avatar"},
|
||||
Filter: []string{
|
||||
"{ id: { eq: $user_id } }",
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
qcompile.AddRole("user", "customers", qcode.TRConfig{
|
||||
Query: qcode.QueryConfig{
|
||||
Columns: []string{"id", "email", "full_name", "products"},
|
||||
},
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
@ -135,9 +162,8 @@ func TestMain(m *testing.M) {
|
||||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func compileGQLToPSQL(gql string, vars Variables) ([]byte, error) {
|
||||
|
||||
qc, err := qcompile.Compile([]byte(gql))
|
||||
func compileGQLToPSQL(gql string, vars Variables, role string) ([]byte, error) {
|
||||
qc, err := qcompile.Compile([]byte(gql), role)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -147,6 +173,8 @@ func compileGQLToPSQL(gql string, vars Variables) ([]byte, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//fmt.Println(string(sqlStmt))
|
||||
|
||||
return sqlStmt, nil
|
||||
}
|
||||
|
||||
@ -175,7 +203,7 @@ func withComplexArgs(t *testing.T) {
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") < 28) AND (("products"."id") >= 20)) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -203,7 +231,7 @@ func withWhereMultiOr(t *testing.T) {
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") < 20) OR (("products"."price") > 10) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -229,7 +257,7 @@ func withWhereIsNull(t *testing.T) {
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -255,7 +283,7 @@ func withWhereAndList(t *testing.T) {
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -275,7 +303,7 @@ func fetchByID(t *testing.T) {
|
||||
|
||||
sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 15)) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -295,7 +323,7 @@ func searchQuery(t *testing.T) {
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -318,7 +346,7 @@ func oneToMany(t *testing.T) {
|
||||
|
||||
sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email", "users"."id" FROM "users" LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -341,7 +369,7 @@ func belongsTo(t *testing.T) {
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "sel_json_1" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -364,7 +392,7 @@ func manyToMany(t *testing.T) {
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -387,7 +415,7 @@ func manyToManyReverse(t *testing.T) {
|
||||
|
||||
sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -407,7 +435,47 @@ func aggFunction(t *testing.T) {
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(resSQL) != sql {
|
||||
t.Fatal(errNotExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func aggFunctionBlockedByCol(t *testing.T) {
|
||||
gql := `query {
|
||||
products {
|
||||
name
|
||||
count_price
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "anon")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(resSQL) != sql {
|
||||
t.Fatal(errNotExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func aggFunctionDisabled(t *testing.T) {
|
||||
gql := `query {
|
||||
products {
|
||||
name
|
||||
count_price
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "anon1")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -427,7 +495,7 @@ func aggFunctionWithFilter(t *testing.T) {
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") > 10)) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -447,7 +515,7 @@ func queryWithVariables(t *testing.T) {
|
||||
|
||||
sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") = {{product_price}}) AND (("products"."id") = {{product_id}})) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -464,9 +532,9 @@ func syntheticTables(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337";`
|
||||
sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT ) AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -476,7 +544,7 @@ func syntheticTables(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileSelect(t *testing.T) {
|
||||
func TestCompileQuery(t *testing.T) {
|
||||
t.Run("withComplexArgs", withComplexArgs)
|
||||
t.Run("withWhereAndList", withWhereAndList)
|
||||
t.Run("withWhereIsNull", withWhereIsNull)
|
||||
@ -488,10 +556,11 @@ func TestCompileSelect(t *testing.T) {
|
||||
t.Run("manyToMany", manyToMany)
|
||||
t.Run("manyToManyReverse", manyToManyReverse)
|
||||
t.Run("aggFunction", aggFunction)
|
||||
t.Run("aggFunctionBlockedByCol", aggFunctionBlockedByCol)
|
||||
t.Run("aggFunctionDisabled", aggFunctionDisabled)
|
||||
t.Run("aggFunctionWithFilter", aggFunctionWithFilter)
|
||||
t.Run("syntheticTables", syntheticTables)
|
||||
t.Run("queryWithVariables", queryWithVariables)
|
||||
|
||||
}
|
||||
|
||||
var benchGQL = []byte(`query {
|
||||
@ -526,7 +595,7 @@ func BenchmarkCompile(b *testing.B) {
|
||||
for n := 0; n < b.N; n++ {
|
||||
w.Reset()
|
||||
|
||||
qc, err := qcompile.Compile(benchGQL)
|
||||
qc, err := qcompile.Compile(benchGQL, "user")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
@ -547,7 +616,7 @@ func BenchmarkCompileParallel(b *testing.B) {
|
||||
for pb.Next() {
|
||||
w.Reset()
|
||||
|
||||
qc, err := qcompile.Compile(benchGQL)
|
||||
qc, err := qcompile.Compile(benchGQL, "user")
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
Reference in New Issue
Block a user