Add nested where clause to filter based on related tables
This commit is contained in:
parent
77a51924a7
commit
89bc93e159
|
@ -1,14 +1,14 @@
|
|||
<a href="https://supergraph.dev"><img src="https://supergraph.dev/hologram.svg" width="100" height="100" align="right" /></a>
|
||||
|
||||
# Super Graph - Build web products faster. Instant GraphQL APIs for your apps
|
||||
# Super Graph - Instant GraphQL APIs for your apps.
|
||||
|
||||
## Build web products faster. No code needed. GraphQL auto. transformed into efficient database queries.
|
||||
|
||||
![MIT license](https://img.shields.io/github/license/dosco/super-graph.svg)
|
||||
![Docker build](https://img.shields.io/docker/cloud/build/dosco/super-graph.svg)
|
||||
![Cloud native](https://img.shields.io/badge/cloud--native-enabled-blue.svg)
|
||||
[![Discord Chat](https://img.shields.io/discord/628796009539043348.svg)](https://discord.gg/6pSWCTZ)
|
||||
|
||||
Get an instant high performance GraphQL API for Postgres. No code needed. GraphQL is automatically transformed into efficient database queries.
|
||||
|
||||
![GraphQL](docs/.vuepress/public/graphql.png?raw=true "")
|
||||
|
||||
## The story of Super Graph?
|
||||
|
@ -25,6 +25,7 @@ This compiler is what sits at the heart of Super Graph with layers of useful fun
|
|||
|
||||
## Features
|
||||
|
||||
- Role based access control
|
||||
- Works with Rails database schemas
|
||||
- Automatically learns schemas and relationships
|
||||
- Belongs-To, One-To-Many and Many-To-Many table relationships
|
||||
|
|
|
@ -51,7 +51,8 @@ auth:
|
|||
cookie: _app_session
|
||||
|
||||
# Comment this out if you want to disable setting
|
||||
# the user_id via a header. Good for testing
|
||||
# the user_id via a header for testing.
|
||||
# Disable in production
|
||||
creds_in_header: true
|
||||
|
||||
rails:
|
||||
|
@ -92,6 +93,10 @@ database:
|
|||
#max_retries: 0
|
||||
#log_level: "debug"
|
||||
|
||||
# Set session variable "user.id" to the user id
|
||||
# Enable this if you need the user id in triggers, etc
|
||||
set_user_id: false
|
||||
|
||||
# Define variables here that you want to use in filters
|
||||
# sub-queries must be wrapped in ()
|
||||
variables:
|
||||
|
@ -131,7 +136,7 @@ tables:
|
|||
name: me
|
||||
table: users
|
||||
|
||||
roles_query: "SELECT * FROM users as usr WHERE id = $user_id"
|
||||
roles_query: "SELECT * FROM users WHERE id = $user_id"
|
||||
|
||||
roles:
|
||||
- name: anon
|
||||
|
|
|
@ -85,3 +85,7 @@ database:
|
|||
#pool_size: 10
|
||||
#max_retries: 0
|
||||
#log_level: "debug"
|
||||
|
||||
# Set session variable "user.id" to the user id
|
||||
# Enable this if you need the user id in triggers, etc
|
||||
set_user_id: false
|
|
@ -214,6 +214,7 @@ func (c *compilerContext) renderDelete(qc *qcode.QCode, w io.Writer,
|
|||
|
||||
func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer,
|
||||
vars Variables, ti *DBTableInfo) (uint32, error) {
|
||||
root := &qc.Selects[0]
|
||||
|
||||
upsert, ok := vars[qc.ActionVar]
|
||||
if !ok {
|
||||
|
@ -229,7 +230,7 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer,
|
|||
return 0, err
|
||||
}
|
||||
|
||||
io.WriteString(c.w, ` ON CONFLICT DO (`)
|
||||
io.WriteString(c.w, ` ON CONFLICT (`)
|
||||
i := 0
|
||||
|
||||
for _, cn := range ti.ColumnNames {
|
||||
|
@ -250,10 +251,17 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer,
|
|||
if i == 0 {
|
||||
io.WriteString(c.w, ti.PrimaryCol)
|
||||
}
|
||||
io.WriteString(c.w, `) DO `)
|
||||
io.WriteString(c.w, `)`)
|
||||
|
||||
io.WriteString(c.w, `UPDATE `)
|
||||
io.WriteString(c.w, ` SET `)
|
||||
if root.Where != nil {
|
||||
io.WriteString(c.w, ` WHERE `)
|
||||
|
||||
if err := c.renderWhere(root, ti); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
io.WriteString(c.w, ` DO UPDATE SET `)
|
||||
|
||||
i = 0
|
||||
for _, cn := range ti.ColumnNames {
|
||||
|
|
|
@ -78,13 +78,37 @@ func bulkInsert(t *testing.T) {
|
|||
|
||||
func singleUpsert(t *testing.T) {
|
||||
gql := `mutation {
|
||||
product(id: 15, upsert: $upsert) {
|
||||
product(upsert: $upsert) {
|
||||
id
|
||||
name
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
|
||||
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
|
||||
|
||||
vars := map[string]json.RawMessage{
|
||||
"upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
|
||||
}
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, vars, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(resSQL) != sql {
|
||||
t.Fatal(errNotExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func singleUpsertWhere(t *testing.T) {
|
||||
gql := `mutation {
|
||||
product(upsert: $upsert, where: { price : { gt: 3 } }) {
|
||||
id
|
||||
name
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT (id) WHERE (("products"."price") > 3) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
|
||||
|
||||
vars := map[string]json.RawMessage{
|
||||
"upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
|
||||
|
@ -102,13 +126,13 @@ func singleUpsert(t *testing.T) {
|
|||
|
||||
func bulkUpsert(t *testing.T) {
|
||||
gql := `mutation {
|
||||
product(id: 15, upsert: $upsert) {
|
||||
product(upsert: $upsert) {
|
||||
id
|
||||
name
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_recordset(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
|
||||
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_recordset(NULL::products, i.j) t ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
|
||||
|
||||
vars := map[string]json.RawMessage{
|
||||
"upsert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
|
||||
|
@ -271,6 +295,7 @@ func TestCompileMutate(t *testing.T) {
|
|||
t.Run("bulkInsert", bulkInsert)
|
||||
t.Run("singleUpdate", singleUpdate)
|
||||
t.Run("singleUpsert", singleUpsert)
|
||||
t.Run("singleUpsertWhere", singleUpsertWhere)
|
||||
t.Run("bulkUpsert", bulkUpsert)
|
||||
t.Run("delete", delete)
|
||||
t.Run("blockedInsert", blockedInsert)
|
||||
|
|
154
psql/query.go
154
psql/query.go
|
@ -120,7 +120,7 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) {
|
|||
}
|
||||
|
||||
if sel.ID != 0 {
|
||||
if err = c.renderJoin(sel); err != nil {
|
||||
if err = c.renderLateralJoin(sel); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
@ -154,7 +154,7 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) {
|
|||
}
|
||||
|
||||
if sel.ID != 0 {
|
||||
if err = c.renderJoinClose(sel); err != nil {
|
||||
if err = c.renderLateralJoinClose(sel); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
@ -327,12 +327,12 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo)
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *compilerContext) renderJoin(sel *qcode.Select) error {
|
||||
func (c *compilerContext) renderLateralJoin(sel *qcode.Select) error {
|
||||
io.WriteString(c.w, ` LEFT OUTER JOIN LATERAL (`)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *compilerContext) renderJoinClose(sel *qcode.Select) error {
|
||||
func (c *compilerContext) renderLateralJoinClose(sel *qcode.Select) error {
|
||||
//fmt.Fprintf(w, `) AS "%s_%d_join" ON ('true')`, c.sel.Table, c.sel.ID)
|
||||
io.WriteString(c.w, `)`)
|
||||
aliasWithIDSuffix(c.w, sel.Table, sel.ID, "_join")
|
||||
|
@ -340,19 +340,24 @@ func (c *compilerContext) renderJoinClose(sel *qcode.Select) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func (c *compilerContext) renderJoinTable(sel *qcode.Select) error {
|
||||
func (c *compilerContext) renderJoin(sel *qcode.Select) error {
|
||||
parent := &c.s[sel.ParentID]
|
||||
return c.renderJoinByName(sel.Table, parent.Table, parent.ID)
|
||||
}
|
||||
|
||||
rel, err := c.schema.GetRel(sel.Table, parent.Table)
|
||||
func (c *compilerContext) renderJoinByName(table, parent string, id int32) error {
|
||||
rel, err := c.schema.GetRel(table, parent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// This join is only required for one-to-many relations since
|
||||
// these make use of join tables that need to be pulled in.
|
||||
if rel.Type != RelOneToManyThrough {
|
||||
return err
|
||||
}
|
||||
|
||||
pt, err := c.schema.GetTable(parent.Table)
|
||||
pt, err := c.schema.GetTable(parent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -364,7 +369,11 @@ func (c *compilerContext) renderJoinTable(sel *qcode.Select) error {
|
|||
io.WriteString(c.w, `" ON ((`)
|
||||
colWithTable(c.w, rel.Through, rel.ColT)
|
||||
io.WriteString(c.w, `) = (`)
|
||||
colWithTableID(c.w, pt.Name, parent.ID, rel.Col1)
|
||||
if id != -1 {
|
||||
colWithTableID(c.w, pt.Name, id, rel.Col1)
|
||||
} else {
|
||||
colWithTable(c.w, pt.Name, rel.Col1)
|
||||
}
|
||||
io.WriteString(c.w, `))`)
|
||||
|
||||
return nil
|
||||
|
@ -438,7 +447,7 @@ func (c *compilerContext) renderJoinedColumns(sel *qcode.Select, ti *DBTableInfo
|
|||
|
||||
cti, err := c.schema.GetTable(childSel.Table)
|
||||
if err != nil {
|
||||
continue
|
||||
return err
|
||||
}
|
||||
|
||||
//fmt.Fprintf(w, `"%s_%d_join"."%s" AS "%s"`,
|
||||
|
@ -529,9 +538,11 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
|
|||
|
||||
} else if sel.Functions {
|
||||
cn1 := cn[pl:]
|
||||
if len(sel.Allowed) != 0 {
|
||||
if _, ok := sel.Allowed[cn1]; !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
if i != 0 {
|
||||
io.WriteString(c.w, `, `)
|
||||
}
|
||||
|
@ -596,7 +607,7 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
|
|||
}
|
||||
|
||||
if !isRoot {
|
||||
if err := c.renderJoinTable(sel); err != nil {
|
||||
if err := c.renderJoin(sel); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
|
@ -680,8 +691,11 @@ func (c *compilerContext) renderOrderByColumns(sel *qcode.Select, ti *DBTableInf
|
|||
|
||||
func (c *compilerContext) renderRelationship(sel *qcode.Select, ti *DBTableInfo) error {
|
||||
parent := c.s[sel.ParentID]
|
||||
return c.renderRelationshipByName(sel.Table, parent.Table, parent.ID)
|
||||
}
|
||||
|
||||
rel, err := c.schema.GetRel(sel.Table, parent.Table)
|
||||
func (c *compilerContext) renderRelationshipByName(table, parent string, id int32) error {
|
||||
rel, err := c.schema.GetRel(table, parent)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -691,25 +705,34 @@ func (c *compilerContext) renderRelationship(sel *qcode.Select, ti *DBTableInfo)
|
|||
//fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`,
|
||||
//c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2)
|
||||
io.WriteString(c.w, `((`)
|
||||
colWithTable(c.w, ti.Name, rel.Col1)
|
||||
colWithTable(c.w, table, rel.Col1)
|
||||
io.WriteString(c.w, `) = (`)
|
||||
colWithTableID(c.w, parent.Table, parent.ID, rel.Col2)
|
||||
if id != -1 {
|
||||
colWithTableID(c.w, parent, id, rel.Col2)
|
||||
} else {
|
||||
colWithTable(c.w, parent, rel.Col2)
|
||||
}
|
||||
io.WriteString(c.w, `))`)
|
||||
|
||||
case RelOneToMany:
|
||||
//fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`,
|
||||
//c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2)
|
||||
io.WriteString(c.w, `((`)
|
||||
colWithTable(c.w, ti.Name, rel.Col1)
|
||||
colWithTable(c.w, table, rel.Col1)
|
||||
io.WriteString(c.w, `) = (`)
|
||||
colWithTableID(c.w, parent.Table, parent.ID, rel.Col2)
|
||||
if id != -1 {
|
||||
colWithTableID(c.w, parent, id, rel.Col2)
|
||||
} else {
|
||||
colWithTable(c.w, parent, rel.Col2)
|
||||
}
|
||||
io.WriteString(c.w, `))`)
|
||||
|
||||
case RelOneToManyThrough:
|
||||
// This requires the through table to be joined onto this select
|
||||
//fmt.Fprintf(w, `(("%s"."%s") = ("%s"."%s"))`,
|
||||
//c.sel.Table, rel.Col1, rel.Through, rel.Col2)
|
||||
io.WriteString(c.w, `((`)
|
||||
colWithTable(c.w, ti.Name, rel.Col1)
|
||||
colWithTable(c.w, table, rel.Col1)
|
||||
io.WriteString(c.w, `) = (`)
|
||||
colWithTable(c.w, rel.Through, rel.Col2)
|
||||
io.WriteString(c.w, `))`)
|
||||
|
@ -768,21 +791,73 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
|
|||
qcode.FreeExp(val)
|
||||
|
||||
default:
|
||||
if val.NestedCol {
|
||||
//fmt.Fprintf(w, `(("%s") `, val.Col)
|
||||
io.WriteString(c.w, `(("`)
|
||||
io.WriteString(c.w, val.Col)
|
||||
io.WriteString(c.w, `") `)
|
||||
if len(val.NestedCols) != 0 {
|
||||
io.WriteString(c.w, `EXISTS `)
|
||||
|
||||
} else if len(val.Col) != 0 {
|
||||
if err := c.renderNestedWhere(val, sel, ti); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
} else {
|
||||
//fmt.Fprintf(w, `(("%s"."%s") `, c.sel.Table, val.Col)
|
||||
if err := c.renderOp(val, sel, ti); err != nil {
|
||||
return err
|
||||
}
|
||||
qcode.FreeExp(val)
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("12: unexpected value %v (%t)", intf, intf)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *compilerContext) renderNestedWhere(ex *qcode.Exp, sel *qcode.Select, ti *DBTableInfo) error {
|
||||
|
||||
for i := 0; i < len(ex.NestedCols)-1; i++ {
|
||||
cti, err := c.schema.GetTable(ex.NestedCols[i])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if i != 0 {
|
||||
io.WriteString(c.w, ` AND `)
|
||||
}
|
||||
|
||||
io.WriteString(c.w, `(SELECT 1 FROM `)
|
||||
io.WriteString(c.w, cti.Name)
|
||||
|
||||
if err := c.renderJoinByName(cti.Name, ti.Name, -1); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
io.WriteString(c.w, ` WHERE `)
|
||||
|
||||
if err := c.renderRelationshipByName(cti.Name, ti.Name, -1); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
for i := 0; i < len(ex.NestedCols)-1; i++ {
|
||||
io.WriteString(c.w, `)`)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTableInfo) error {
|
||||
if len(ex.Col) != 0 {
|
||||
io.WriteString(c.w, `((`)
|
||||
colWithTable(c.w, ti.Name, val.Col)
|
||||
colWithTable(c.w, ti.Name, ex.Col)
|
||||
io.WriteString(c.w, `) `)
|
||||
}
|
||||
valExists := true
|
||||
|
||||
switch val.Op {
|
||||
switch ex.Op {
|
||||
case qcode.OpEquals:
|
||||
io.WriteString(c.w, `=`)
|
||||
case qcode.OpNotEquals:
|
||||
|
@ -822,12 +897,12 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
|
|||
case qcode.OpHasKeyAll:
|
||||
io.WriteString(c.w, `?&`)
|
||||
case qcode.OpIsNull:
|
||||
if strings.EqualFold(val.Val, "true") {
|
||||
if strings.EqualFold(ex.Val, "true") {
|
||||
io.WriteString(c.w, `IS NULL)`)
|
||||
} else {
|
||||
io.WriteString(c.w, `IS NOT NULL)`)
|
||||
}
|
||||
valExists = false
|
||||
return nil
|
||||
case qcode.OpEqID:
|
||||
if len(ti.PrimaryCol) == 0 {
|
||||
return fmt.Errorf("no primary key column defined for %s", ti.Name)
|
||||
|
@ -845,32 +920,21 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
|
|||
io.WriteString(c.w, `(("`)
|
||||
io.WriteString(c.w, ti.TSVCol)
|
||||
io.WriteString(c.w, `") @@ to_tsquery('`)
|
||||
io.WriteString(c.w, val.Val)
|
||||
io.WriteString(c.w, ex.Val)
|
||||
io.WriteString(c.w, `'))`)
|
||||
valExists = false
|
||||
return nil
|
||||
|
||||
default:
|
||||
return fmt.Errorf("[Where] unexpected op code %d", val.Op)
|
||||
return fmt.Errorf("[Where] unexpected op code %d", ex.Op)
|
||||
}
|
||||
|
||||
if valExists {
|
||||
if val.Type == qcode.ValList {
|
||||
c.renderList(val)
|
||||
if ex.Type == qcode.ValList {
|
||||
c.renderList(ex)
|
||||
} else {
|
||||
c.renderVal(val, c.vars)
|
||||
c.renderVal(ex, c.vars)
|
||||
}
|
||||
|
||||
io.WriteString(c.w, `)`)
|
||||
}
|
||||
|
||||
qcode.FreeExp(val)
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("12: unexpected value %v (%t)", intf, intf)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
|
|
@ -332,6 +332,25 @@ func aggFunctionWithFilter(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func syntheticTables(t *testing.T) {
|
||||
gql := `query {
|
||||
me {
|
||||
email
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT ) AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337"`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(resSQL) != sql {
|
||||
t.Fatal(errNotExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func queryWithVariables(t *testing.T) {
|
||||
gql := `query {
|
||||
product(id: $PRODUCT_ID, where: { price: { eq: $PRODUCT_PRICE } }) {
|
||||
|
@ -352,14 +371,21 @@ func queryWithVariables(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func syntheticTables(t *testing.T) {
|
||||
func withWhereOnRelations(t *testing.T) {
|
||||
gql := `query {
|
||||
me {
|
||||
users(where: {
|
||||
not: {
|
||||
products: {
|
||||
price: { gt: 3 }
|
||||
}
|
||||
}
|
||||
}) {
|
||||
id
|
||||
email
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT ) AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337"`
|
||||
sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."id" AS "id", "users_0"."email" AS "email") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."id", "users"."email" FROM "users" WHERE (NOT EXISTS (SELECT 1 FROM products WHERE (("products"."user_id") = ("users"."id")))) LIMIT ('20') :: integer) AS "users_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
|
@ -371,50 +397,6 @@ func syntheticTables(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCompileQuery(t *testing.T) {
|
||||
t.Run("withComplexArgs", withComplexArgs)
|
||||
t.Run("withWhereAndList", withWhereAndList)
|
||||
t.Run("withWhereIsNull", withWhereIsNull)
|
||||
t.Run("withWhereMultiOr", withWhereMultiOr)
|
||||
t.Run("fetchByID", fetchByID)
|
||||
t.Run("searchQuery", searchQuery)
|
||||
t.Run("belongsTo", belongsTo)
|
||||
t.Run("oneToMany", oneToMany)
|
||||
t.Run("manyToMany", manyToMany)
|
||||
t.Run("manyToManyReverse", manyToManyReverse)
|
||||
t.Run("aggFunction", aggFunction)
|
||||
t.Run("aggFunctionBlockedByCol", aggFunctionBlockedByCol)
|
||||
t.Run("aggFunctionDisabled", aggFunctionDisabled)
|
||||
t.Run("aggFunctionWithFilter", aggFunctionWithFilter)
|
||||
t.Run("syntheticTables", syntheticTables)
|
||||
t.Run("queryWithVariables", queryWithVariables)
|
||||
t.Run("blockedQuery", blockedQuery)
|
||||
t.Run("blockedFunctions", blockedFunctions)
|
||||
}
|
||||
|
||||
var benchGQL = []byte(`query {
|
||||
proDUcts(
|
||||
# returns only 30 items
|
||||
limit: 30,
|
||||
|
||||
# starts from item 10, commented out for now
|
||||
# offset: 10,
|
||||
|
||||
# orders the response items by highest price
|
||||
order_by: { price: desc },
|
||||
|
||||
# only items with an id >= 30 and < 30 are returned
|
||||
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) {
|
||||
id
|
||||
NAME
|
||||
price
|
||||
user {
|
||||
full_name
|
||||
picture : avatar
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
func blockedQuery(t *testing.T) {
|
||||
gql := `query {
|
||||
user(id: 5, where: { id: { gt: 3 } }) {
|
||||
|
@ -456,6 +438,51 @@ func blockedFunctions(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCompileQuery(t *testing.T) {
|
||||
t.Run("withComplexArgs", withComplexArgs)
|
||||
t.Run("withWhereAndList", withWhereAndList)
|
||||
t.Run("withWhereIsNull", withWhereIsNull)
|
||||
t.Run("withWhereMultiOr", withWhereMultiOr)
|
||||
t.Run("fetchByID", fetchByID)
|
||||
t.Run("searchQuery", searchQuery)
|
||||
t.Run("belongsTo", belongsTo)
|
||||
t.Run("oneToMany", oneToMany)
|
||||
t.Run("manyToMany", manyToMany)
|
||||
t.Run("manyToManyReverse", manyToManyReverse)
|
||||
t.Run("aggFunction", aggFunction)
|
||||
t.Run("aggFunctionBlockedByCol", aggFunctionBlockedByCol)
|
||||
t.Run("aggFunctionDisabled", aggFunctionDisabled)
|
||||
t.Run("aggFunctionWithFilter", aggFunctionWithFilter)
|
||||
t.Run("syntheticTables", syntheticTables)
|
||||
t.Run("queryWithVariables", queryWithVariables)
|
||||
t.Run("withWhereOnRelations", withWhereOnRelations)
|
||||
t.Run("blockedQuery", blockedQuery)
|
||||
t.Run("blockedFunctions", blockedFunctions)
|
||||
}
|
||||
|
||||
var benchGQL = []byte(`query {
|
||||
proDUcts(
|
||||
# returns only 30 items
|
||||
limit: 30,
|
||||
|
||||
# starts from item 10, commented out for now
|
||||
# offset: 10,
|
||||
|
||||
# orders the response items by highest price
|
||||
order_by: { price: desc },
|
||||
|
||||
# only items with an id >= 30 and < 30 are returned
|
||||
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) {
|
||||
id
|
||||
NAME
|
||||
price
|
||||
user {
|
||||
full_name
|
||||
picture : avatar
|
||||
}
|
||||
}
|
||||
}`)
|
||||
|
||||
func BenchmarkCompile(b *testing.B) {
|
||||
w := &bytes.Buffer{}
|
||||
|
||||
|
|
|
@ -59,7 +59,7 @@ type Column struct {
|
|||
type Exp struct {
|
||||
Op ExpOp
|
||||
Col string
|
||||
NestedCol bool
|
||||
NestedCols []string
|
||||
Type ValType
|
||||
Val string
|
||||
ListType ValType
|
||||
|
@ -918,10 +918,8 @@ func setWhereColName(ex *Exp, node *Node) {
|
|||
}
|
||||
if len(list) == 1 {
|
||||
ex.Col = list[0]
|
||||
|
||||
} else if len(list) > 2 {
|
||||
ex.Col = buildPath(list)
|
||||
ex.NestedCol = true
|
||||
} else if len(list) > 1 {
|
||||
ex.NestedCols = list
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
@ -8,19 +8,19 @@ import (
|
|||
"github.com/dosco/super-graph/jsn"
|
||||
)
|
||||
|
||||
func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
|
||||
func argMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
|
||||
return func(w io.Writer, tag string) (int, error) {
|
||||
switch tag {
|
||||
case "user_id_provider":
|
||||
if v := ctx.Value(userIDProviderKey); v != nil {
|
||||
return stringVar(w, v.(string))
|
||||
return stringArg(w, v.(string))
|
||||
}
|
||||
io.WriteString(w, "null")
|
||||
return 0, nil
|
||||
|
||||
case "user_id":
|
||||
if v := ctx.Value(userIDKey); v != nil {
|
||||
return stringVar(w, v.(string))
|
||||
return stringArg(w, v.(string))
|
||||
}
|
||||
|
||||
io.WriteString(w, "null")
|
||||
|
@ -28,7 +28,7 @@ func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
|
|||
|
||||
case "user_role":
|
||||
if v := ctx.Value(userRoleKey); v != nil {
|
||||
return stringVar(w, v.(string))
|
||||
return stringArg(w, v.(string))
|
||||
}
|
||||
io.WriteString(w, "null")
|
||||
return 0, nil
|
||||
|
@ -50,7 +50,7 @@ func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
|
|||
}
|
||||
|
||||
if is {
|
||||
return stringVarB(w, fields[0].Value)
|
||||
return stringArgB(w, fields[0].Value)
|
||||
}
|
||||
|
||||
w.Write(fields[0].Value)
|
||||
|
@ -58,7 +58,7 @@ func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
|
|||
}
|
||||
}
|
||||
|
||||
func varList(ctx *coreContext, args [][]byte) []interface{} {
|
||||
func argList(ctx *coreContext, args [][]byte) []interface{} {
|
||||
vars := make([]interface{}, len(args))
|
||||
|
||||
var fields map[string]interface{}
|
||||
|
@ -86,6 +86,11 @@ func varList(ctx *coreContext, args [][]byte) []interface{} {
|
|||
vars[i] = v.(string)
|
||||
}
|
||||
|
||||
case bytes.Equal(av, []byte("user_role")):
|
||||
if v := ctx.Value(userRoleKey); v != nil {
|
||||
vars[i] = v.(string)
|
||||
}
|
||||
|
||||
default:
|
||||
if v, ok := fields[string(av)]; ok {
|
||||
vars[i] = v
|
||||
|
@ -96,7 +101,7 @@ func varList(ctx *coreContext, args [][]byte) []interface{} {
|
|||
return vars
|
||||
}
|
||||
|
||||
func stringVar(w io.Writer, v string) (int, error) {
|
||||
func stringArg(w io.Writer, v string) (int, error) {
|
||||
if n, err := w.Write([]byte(`'`)); err != nil {
|
||||
return n, err
|
||||
}
|
||||
|
@ -106,7 +111,7 @@ func stringVar(w io.Writer, v string) (int, error) {
|
|||
return w.Write([]byte(`'`))
|
||||
}
|
||||
|
||||
func stringVarB(w io.Writer, v []byte) (int, error) {
|
||||
func stringArgB(w io.Writer, v []byte) (int, error) {
|
||||
if n, err := w.Write([]byte(`'`)); err != nil {
|
||||
return n, err
|
||||
}
|
|
@ -66,6 +66,7 @@ type config struct {
|
|||
PoolSize int32 `mapstructure:"pool_size"`
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
LogLevel string `mapstructure:"log_level"`
|
||||
SetUserID bool `mapstructure:"set_user_id"`
|
||||
|
||||
Vars map[string]string `mapstructure:"variables"`
|
||||
|
||||
|
|
25
serv/core.go
25
serv/core.go
|
@ -122,10 +122,8 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
|
|||
}
|
||||
defer tx.Rollback(c)
|
||||
|
||||
if v := c.Value(userIDKey); v != nil {
|
||||
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
|
||||
|
||||
if err != nil {
|
||||
if conf.DB.SetUserID {
|
||||
if err := c.setLocalUserID(tx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
@ -153,7 +151,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
|
|||
}
|
||||
|
||||
var root []byte
|
||||
vars := varList(c, ps.args)
|
||||
vars := argList(c, ps.args)
|
||||
|
||||
if mutation {
|
||||
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
|
||||
|
@ -206,7 +204,7 @@ func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
|
|||
t := fasttemplate.New(st.sql, openVar, closeVar)
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
_, err = t.ExecuteFunc(buf, varMap(c))
|
||||
_, err = t.ExecuteFunc(buf, argMap(c))
|
||||
|
||||
if err == errNoUserID {
|
||||
logger.Warn().Msg("no user id found. query requires an authenicated request")
|
||||
|
@ -224,10 +222,8 @@ func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
|
|||
stime = time.Now()
|
||||
}
|
||||
|
||||
if v := c.Value(userIDKey); v != nil {
|
||||
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
|
||||
|
||||
if err != nil {
|
||||
if conf.DB.SetUserID {
|
||||
if err := c.setLocalUserID(tx); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
|
@ -425,6 +421,15 @@ func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
|
|||
return role, nil
|
||||
}
|
||||
|
||||
func (c *coreContext) setLocalUserID(tx pgx.Tx) error {
|
||||
var err error
|
||||
if v := c.Value(userIDKey); v != nil {
|
||||
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *coreContext) render(w io.Writer, data []byte) error {
|
||||
c.res.Data = json.RawMessage(data)
|
||||
return json.NewEncoder(w).Encode(c.res)
|
||||
|
|
|
@ -51,7 +51,8 @@ auth:
|
|||
cookie: _{{app_name_slug}}_session
|
||||
|
||||
# Comment this out if you want to disable setting
|
||||
# the user_id via a header. Good for testing
|
||||
# the user_id via a header for testing.
|
||||
# Disable in production
|
||||
creds_in_header: true
|
||||
|
||||
rails:
|
||||
|
@ -92,6 +93,10 @@ database:
|
|||
#max_retries: 0
|
||||
#log_level: "debug"
|
||||
|
||||
# Set session variable "user.id" to the user id
|
||||
# Enable this if you need the user id in triggers, etc
|
||||
set_user_id: false
|
||||
|
||||
# Define variables here that you want to use in filters
|
||||
# sub-queries must be wrapped in ()
|
||||
variables:
|
||||
|
@ -131,7 +136,7 @@ tables:
|
|||
name: me
|
||||
table: users
|
||||
|
||||
roles_query: "SELECT * FROM users as usr WHERE id = $user_id"
|
||||
roles_query: "SELECT * FROM users WHERE id = $user_id"
|
||||
|
||||
roles:
|
||||
- name: anon
|
||||
|
|
|
@ -85,3 +85,7 @@ database:
|
|||
#pool_size: 10
|
||||
#max_retries: 0
|
||||
#log_level: "debug"
|
||||
|
||||
# Set session variable "user.id" to the user id
|
||||
# Enable this if you need the user id in triggers, etc
|
||||
set_user_id: false
|
Loading…
Reference in New Issue