Add nested where clause to filter based on related tables

This commit is contained in:
Vikram Rangnekar 2019-11-04 23:44:42 -05:00
parent 77a51924a7
commit 89bc93e159
13 changed files with 358 additions and 206 deletions

View File

@ -1,14 +1,14 @@
<a href="https://supergraph.dev"><img src="https://supergraph.dev/hologram.svg" width="100" height="100" align="right" /></a> <a href="https://supergraph.dev"><img src="https://supergraph.dev/hologram.svg" width="100" height="100" align="right" /></a>
# Super Graph - Build web products faster. Instant GraphQL APIs for your apps # Super Graph - Instant GraphQL APIs for your apps.
## Build web products faster. No code needed. GraphQL auto. transformed into efficient database queries.
![MIT license](https://img.shields.io/github/license/dosco/super-graph.svg) ![MIT license](https://img.shields.io/github/license/dosco/super-graph.svg)
![Docker build](https://img.shields.io/docker/cloud/build/dosco/super-graph.svg) ![Docker build](https://img.shields.io/docker/cloud/build/dosco/super-graph.svg)
![Cloud native](https://img.shields.io/badge/cloud--native-enabled-blue.svg) ![Cloud native](https://img.shields.io/badge/cloud--native-enabled-blue.svg)
[![Discord Chat](https://img.shields.io/discord/628796009539043348.svg)](https://discord.gg/6pSWCTZ) [![Discord Chat](https://img.shields.io/discord/628796009539043348.svg)](https://discord.gg/6pSWCTZ)
Get an instant high performance GraphQL API for Postgres. No code needed. GraphQL is automatically transformed into efficient database queries.
![GraphQL](docs/.vuepress/public/graphql.png?raw=true "") ![GraphQL](docs/.vuepress/public/graphql.png?raw=true "")
## The story of Super Graph? ## The story of Super Graph?
@ -25,6 +25,7 @@ This compiler is what sits at the heart of Super Graph with layers of useful fun
## Features ## Features
- Role based access control
- Works with Rails database schemas - Works with Rails database schemas
- Automatically learns schemas and relationships - Automatically learns schemas and relationships
- Belongs-To, One-To-Many and Many-To-Many table relationships - Belongs-To, One-To-Many and Many-To-Many table relationships

View File

@ -51,7 +51,8 @@ auth:
cookie: _app_session cookie: _app_session
# Comment this out if you want to disable setting # Comment this out if you want to disable setting
# the user_id via a header. Good for testing # the user_id via a header for testing.
# Disable in production
creds_in_header: true creds_in_header: true
rails: rails:
@ -92,6 +93,10 @@ database:
#max_retries: 0 #max_retries: 0
#log_level: "debug" #log_level: "debug"
# Set session variable "user.id" to the user id
# Enable this if you need the user id in triggers, etc
set_user_id: false
# Define variables here that you want to use in filters # Define variables here that you want to use in filters
# sub-queries must be wrapped in () # sub-queries must be wrapped in ()
variables: variables:
@ -131,7 +136,7 @@ tables:
name: me name: me
table: users table: users
roles_query: "SELECT * FROM users as usr WHERE id = $user_id" roles_query: "SELECT * FROM users WHERE id = $user_id"
roles: roles:
- name: anon - name: anon

View File

@ -85,3 +85,7 @@ database:
#pool_size: 10 #pool_size: 10
#max_retries: 0 #max_retries: 0
#log_level: "debug" #log_level: "debug"
# Set session variable "user.id" to the user id
# Enable this if you need the user id in triggers, etc
set_user_id: false

View File

@ -214,6 +214,7 @@ func (c *compilerContext) renderDelete(qc *qcode.QCode, w io.Writer,
func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer, func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer,
vars Variables, ti *DBTableInfo) (uint32, error) { vars Variables, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0]
upsert, ok := vars[qc.ActionVar] upsert, ok := vars[qc.ActionVar]
if !ok { if !ok {
@ -229,7 +230,7 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer,
return 0, err return 0, err
} }
io.WriteString(c.w, ` ON CONFLICT DO (`) io.WriteString(c.w, ` ON CONFLICT (`)
i := 0 i := 0
for _, cn := range ti.ColumnNames { for _, cn := range ti.ColumnNames {
@ -250,10 +251,17 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer,
if i == 0 { if i == 0 {
io.WriteString(c.w, ti.PrimaryCol) io.WriteString(c.w, ti.PrimaryCol)
} }
io.WriteString(c.w, `) DO `) io.WriteString(c.w, `)`)
io.WriteString(c.w, `UPDATE `) if root.Where != nil {
io.WriteString(c.w, ` SET `) io.WriteString(c.w, ` WHERE `)
if err := c.renderWhere(root, ti); err != nil {
return 0, err
}
}
io.WriteString(c.w, ` DO UPDATE SET `)
i = 0 i = 0
for _, cn := range ti.ColumnNames { for _, cn := range ti.ColumnNames {

View File

@ -78,13 +78,37 @@ func bulkInsert(t *testing.T) {
func singleUpsert(t *testing.T) { func singleUpsert(t *testing.T) {
gql := `mutation { gql := `mutation {
product(id: 15, upsert: $upsert) { product(upsert: $upsert) {
id id
name name
} }
}` }`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"` sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
vars := map[string]json.RawMessage{
"upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
}
resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil {
t.Fatal(err)
}
if string(resSQL) != sql {
t.Fatal(errNotExpected)
}
}
func singleUpsertWhere(t *testing.T) {
gql := `mutation {
product(upsert: $upsert, where: { price : { gt: 3 } }) {
id
name
}
}`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT (id) WHERE (("products"."price") > 3) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
vars := map[string]json.RawMessage{ vars := map[string]json.RawMessage{
"upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), "upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
@ -102,13 +126,13 @@ func singleUpsert(t *testing.T) {
func bulkUpsert(t *testing.T) { func bulkUpsert(t *testing.T) {
gql := `mutation { gql := `mutation {
product(id: 15, upsert: $upsert) { product(upsert: $upsert) {
id id
name name
} }
}` }`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_recordset(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"` sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_recordset(NULL::products, i.j) t ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
vars := map[string]json.RawMessage{ vars := map[string]json.RawMessage{
"upsert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`), "upsert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
@ -271,6 +295,7 @@ func TestCompileMutate(t *testing.T) {
t.Run("bulkInsert", bulkInsert) t.Run("bulkInsert", bulkInsert)
t.Run("singleUpdate", singleUpdate) t.Run("singleUpdate", singleUpdate)
t.Run("singleUpsert", singleUpsert) t.Run("singleUpsert", singleUpsert)
t.Run("singleUpsertWhere", singleUpsertWhere)
t.Run("bulkUpsert", bulkUpsert) t.Run("bulkUpsert", bulkUpsert)
t.Run("delete", delete) t.Run("delete", delete)
t.Run("blockedInsert", blockedInsert) t.Run("blockedInsert", blockedInsert)

View File

@ -120,7 +120,7 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) {
} }
if sel.ID != 0 { if sel.ID != 0 {
if err = c.renderJoin(sel); err != nil { if err = c.renderLateralJoin(sel); err != nil {
return 0, err return 0, err
} }
} }
@ -154,7 +154,7 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) {
} }
if sel.ID != 0 { if sel.ID != 0 {
if err = c.renderJoinClose(sel); err != nil { if err = c.renderLateralJoinClose(sel); err != nil {
return 0, err return 0, err
} }
} }
@ -327,12 +327,12 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo)
return nil return nil
} }
func (c *compilerContext) renderJoin(sel *qcode.Select) error { func (c *compilerContext) renderLateralJoin(sel *qcode.Select) error {
io.WriteString(c.w, ` LEFT OUTER JOIN LATERAL (`) io.WriteString(c.w, ` LEFT OUTER JOIN LATERAL (`)
return nil return nil
} }
func (c *compilerContext) renderJoinClose(sel *qcode.Select) error { func (c *compilerContext) renderLateralJoinClose(sel *qcode.Select) error {
//fmt.Fprintf(w, `) AS "%s_%d_join" ON ('true')`, c.sel.Table, c.sel.ID) //fmt.Fprintf(w, `) AS "%s_%d_join" ON ('true')`, c.sel.Table, c.sel.ID)
io.WriteString(c.w, `)`) io.WriteString(c.w, `)`)
aliasWithIDSuffix(c.w, sel.Table, sel.ID, "_join") aliasWithIDSuffix(c.w, sel.Table, sel.ID, "_join")
@ -340,19 +340,24 @@ func (c *compilerContext) renderJoinClose(sel *qcode.Select) error {
return nil return nil
} }
func (c *compilerContext) renderJoinTable(sel *qcode.Select) error { func (c *compilerContext) renderJoin(sel *qcode.Select) error {
parent := &c.s[sel.ParentID] parent := &c.s[sel.ParentID]
return c.renderJoinByName(sel.Table, parent.Table, parent.ID)
}
rel, err := c.schema.GetRel(sel.Table, parent.Table) func (c *compilerContext) renderJoinByName(table, parent string, id int32) error {
rel, err := c.schema.GetRel(table, parent)
if err != nil { if err != nil {
return err return err
} }
// This join is only required for one-to-many relations since
// these make use of join tables that need to be pulled in.
if rel.Type != RelOneToManyThrough { if rel.Type != RelOneToManyThrough {
return err return err
} }
pt, err := c.schema.GetTable(parent.Table) pt, err := c.schema.GetTable(parent)
if err != nil { if err != nil {
return err return err
} }
@ -364,7 +369,11 @@ func (c *compilerContext) renderJoinTable(sel *qcode.Select) error {
io.WriteString(c.w, `" ON ((`) io.WriteString(c.w, `" ON ((`)
colWithTable(c.w, rel.Through, rel.ColT) colWithTable(c.w, rel.Through, rel.ColT)
io.WriteString(c.w, `) = (`) io.WriteString(c.w, `) = (`)
colWithTableID(c.w, pt.Name, parent.ID, rel.Col1) if id != -1 {
colWithTableID(c.w, pt.Name, id, rel.Col1)
} else {
colWithTable(c.w, pt.Name, rel.Col1)
}
io.WriteString(c.w, `))`) io.WriteString(c.w, `))`)
return nil return nil
@ -438,7 +447,7 @@ func (c *compilerContext) renderJoinedColumns(sel *qcode.Select, ti *DBTableInfo
cti, err := c.schema.GetTable(childSel.Table) cti, err := c.schema.GetTable(childSel.Table)
if err != nil { if err != nil {
continue return err
} }
//fmt.Fprintf(w, `"%s_%d_join"."%s" AS "%s"`, //fmt.Fprintf(w, `"%s_%d_join"."%s" AS "%s"`,
@ -529,9 +538,11 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
} else if sel.Functions { } else if sel.Functions {
cn1 := cn[pl:] cn1 := cn[pl:]
if len(sel.Allowed) != 0 {
if _, ok := sel.Allowed[cn1]; !ok { if _, ok := sel.Allowed[cn1]; !ok {
continue continue
} }
}
if i != 0 { if i != 0 {
io.WriteString(c.w, `, `) io.WriteString(c.w, `, `)
} }
@ -596,7 +607,7 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
} }
if !isRoot { if !isRoot {
if err := c.renderJoinTable(sel); err != nil { if err := c.renderJoin(sel); err != nil {
return err return err
} }
@ -680,8 +691,11 @@ func (c *compilerContext) renderOrderByColumns(sel *qcode.Select, ti *DBTableInf
func (c *compilerContext) renderRelationship(sel *qcode.Select, ti *DBTableInfo) error { func (c *compilerContext) renderRelationship(sel *qcode.Select, ti *DBTableInfo) error {
parent := c.s[sel.ParentID] parent := c.s[sel.ParentID]
return c.renderRelationshipByName(sel.Table, parent.Table, parent.ID)
}
rel, err := c.schema.GetRel(sel.Table, parent.Table) func (c *compilerContext) renderRelationshipByName(table, parent string, id int32) error {
rel, err := c.schema.GetRel(table, parent)
if err != nil { if err != nil {
return err return err
} }
@ -691,25 +705,34 @@ func (c *compilerContext) renderRelationship(sel *qcode.Select, ti *DBTableInfo)
//fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`, //fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`,
//c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2) //c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2)
io.WriteString(c.w, `((`) io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, rel.Col1) colWithTable(c.w, table, rel.Col1)
io.WriteString(c.w, `) = (`) io.WriteString(c.w, `) = (`)
colWithTableID(c.w, parent.Table, parent.ID, rel.Col2) if id != -1 {
colWithTableID(c.w, parent, id, rel.Col2)
} else {
colWithTable(c.w, parent, rel.Col2)
}
io.WriteString(c.w, `))`) io.WriteString(c.w, `))`)
case RelOneToMany: case RelOneToMany:
//fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`, //fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`,
//c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2) //c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2)
io.WriteString(c.w, `((`) io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, rel.Col1) colWithTable(c.w, table, rel.Col1)
io.WriteString(c.w, `) = (`) io.WriteString(c.w, `) = (`)
colWithTableID(c.w, parent.Table, parent.ID, rel.Col2) if id != -1 {
colWithTableID(c.w, parent, id, rel.Col2)
} else {
colWithTable(c.w, parent, rel.Col2)
}
io.WriteString(c.w, `))`) io.WriteString(c.w, `))`)
case RelOneToManyThrough: case RelOneToManyThrough:
// This requires the through table to be joined onto this select
//fmt.Fprintf(w, `(("%s"."%s") = ("%s"."%s"))`, //fmt.Fprintf(w, `(("%s"."%s") = ("%s"."%s"))`,
//c.sel.Table, rel.Col1, rel.Through, rel.Col2) //c.sel.Table, rel.Col1, rel.Through, rel.Col2)
io.WriteString(c.w, `((`) io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, rel.Col1) colWithTable(c.w, table, rel.Col1)
io.WriteString(c.w, `) = (`) io.WriteString(c.w, `) = (`)
colWithTable(c.w, rel.Through, rel.Col2) colWithTable(c.w, rel.Through, rel.Col2)
io.WriteString(c.w, `))`) io.WriteString(c.w, `))`)
@ -768,21 +791,73 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
qcode.FreeExp(val) qcode.FreeExp(val)
default: default:
if val.NestedCol { if len(val.NestedCols) != 0 {
//fmt.Fprintf(w, `(("%s") `, val.Col) io.WriteString(c.w, `EXISTS `)
io.WriteString(c.w, `(("`)
io.WriteString(c.w, val.Col)
io.WriteString(c.w, `") `)
} else if len(val.Col) != 0 { if err := c.renderNestedWhere(val, sel, ti); err != nil {
return err
}
} else {
//fmt.Fprintf(w, `(("%s"."%s") `, c.sel.Table, val.Col) //fmt.Fprintf(w, `(("%s"."%s") `, c.sel.Table, val.Col)
io.WriteString(c.w, `((`) if err := c.renderOp(val, sel, ti); err != nil {
colWithTable(c.w, ti.Name, val.Col) return err
}
qcode.FreeExp(val)
}
}
default:
return fmt.Errorf("12: unexpected value %v (%t)", intf, intf)
}
}
return nil
}
func (c *compilerContext) renderNestedWhere(ex *qcode.Exp, sel *qcode.Select, ti *DBTableInfo) error {
for i := 0; i < len(ex.NestedCols)-1; i++ {
cti, err := c.schema.GetTable(ex.NestedCols[i])
if err != nil {
return err
}
if i != 0 {
io.WriteString(c.w, ` AND `)
}
io.WriteString(c.w, `(SELECT 1 FROM `)
io.WriteString(c.w, cti.Name)
if err := c.renderJoinByName(cti.Name, ti.Name, -1); err != nil {
return err
}
io.WriteString(c.w, ` WHERE `)
if err := c.renderRelationshipByName(cti.Name, ti.Name, -1); err != nil {
return err
}
}
for i := 0; i < len(ex.NestedCols)-1; i++ {
io.WriteString(c.w, `)`) io.WriteString(c.w, `)`)
} }
valExists := true
switch val.Op { return nil
}
func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTableInfo) error {
if len(ex.Col) != 0 {
io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, ex.Col)
io.WriteString(c.w, `) `)
}
switch ex.Op {
case qcode.OpEquals: case qcode.OpEquals:
io.WriteString(c.w, `=`) io.WriteString(c.w, `=`)
case qcode.OpNotEquals: case qcode.OpNotEquals:
@ -822,12 +897,12 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
case qcode.OpHasKeyAll: case qcode.OpHasKeyAll:
io.WriteString(c.w, `?&`) io.WriteString(c.w, `?&`)
case qcode.OpIsNull: case qcode.OpIsNull:
if strings.EqualFold(val.Val, "true") { if strings.EqualFold(ex.Val, "true") {
io.WriteString(c.w, `IS NULL)`) io.WriteString(c.w, `IS NULL)`)
} else { } else {
io.WriteString(c.w, `IS NOT NULL)`) io.WriteString(c.w, `IS NOT NULL)`)
} }
valExists = false return nil
case qcode.OpEqID: case qcode.OpEqID:
if len(ti.PrimaryCol) == 0 { if len(ti.PrimaryCol) == 0 {
return fmt.Errorf("no primary key column defined for %s", ti.Name) return fmt.Errorf("no primary key column defined for %s", ti.Name)
@ -845,32 +920,21 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
io.WriteString(c.w, `(("`) io.WriteString(c.w, `(("`)
io.WriteString(c.w, ti.TSVCol) io.WriteString(c.w, ti.TSVCol)
io.WriteString(c.w, `") @@ to_tsquery('`) io.WriteString(c.w, `") @@ to_tsquery('`)
io.WriteString(c.w, val.Val) io.WriteString(c.w, ex.Val)
io.WriteString(c.w, `'))`) io.WriteString(c.w, `'))`)
valExists = false return nil
default: default:
return fmt.Errorf("[Where] unexpected op code %d", val.Op) return fmt.Errorf("[Where] unexpected op code %d", ex.Op)
} }
if valExists { if ex.Type == qcode.ValList {
if val.Type == qcode.ValList { c.renderList(ex)
c.renderList(val)
} else { } else {
c.renderVal(val, c.vars) c.renderVal(ex, c.vars)
} }
io.WriteString(c.w, `)`) io.WriteString(c.w, `)`)
}
qcode.FreeExp(val)
}
default:
return fmt.Errorf("12: unexpected value %v (%t)", intf, intf)
}
}
return nil return nil
} }

View File

@ -332,6 +332,25 @@ func aggFunctionWithFilter(t *testing.T) {
} }
} }
func syntheticTables(t *testing.T) {
gql := `query {
me {
email
}
}`
sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT ) AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
if string(resSQL) != sql {
t.Fatal(errNotExpected)
}
}
func queryWithVariables(t *testing.T) { func queryWithVariables(t *testing.T) {
gql := `query { gql := `query {
product(id: $PRODUCT_ID, where: { price: { eq: $PRODUCT_PRICE } }) { product(id: $PRODUCT_ID, where: { price: { eq: $PRODUCT_PRICE } }) {
@ -352,14 +371,21 @@ func queryWithVariables(t *testing.T) {
} }
} }
func syntheticTables(t *testing.T) { func withWhereOnRelations(t *testing.T) {
gql := `query { gql := `query {
me { users(where: {
not: {
products: {
price: { gt: 3 }
}
}
}) {
id
email email
} }
}` }`
sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT ) AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337"` sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."id" AS "id", "users_0"."email" AS "email") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."id", "users"."email" FROM "users" WHERE (NOT EXISTS (SELECT 1 FROM products WHERE (("products"."user_id") = ("users"."id")))) LIMIT ('20') :: integer) AS "users_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil, "user") resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
@ -371,50 +397,6 @@ func syntheticTables(t *testing.T) {
} }
} }
func TestCompileQuery(t *testing.T) {
t.Run("withComplexArgs", withComplexArgs)
t.Run("withWhereAndList", withWhereAndList)
t.Run("withWhereIsNull", withWhereIsNull)
t.Run("withWhereMultiOr", withWhereMultiOr)
t.Run("fetchByID", fetchByID)
t.Run("searchQuery", searchQuery)
t.Run("belongsTo", belongsTo)
t.Run("oneToMany", oneToMany)
t.Run("manyToMany", manyToMany)
t.Run("manyToManyReverse", manyToManyReverse)
t.Run("aggFunction", aggFunction)
t.Run("aggFunctionBlockedByCol", aggFunctionBlockedByCol)
t.Run("aggFunctionDisabled", aggFunctionDisabled)
t.Run("aggFunctionWithFilter", aggFunctionWithFilter)
t.Run("syntheticTables", syntheticTables)
t.Run("queryWithVariables", queryWithVariables)
t.Run("blockedQuery", blockedQuery)
t.Run("blockedFunctions", blockedFunctions)
}
var benchGQL = []byte(`query {
proDUcts(
# returns only 30 items
limit: 30,
# starts from item 10, commented out for now
# offset: 10,
# orders the response items by highest price
order_by: { price: desc },
# only items with an id >= 30 and < 30 are returned
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) {
id
NAME
price
user {
full_name
picture : avatar
}
}
}`)
func blockedQuery(t *testing.T) { func blockedQuery(t *testing.T) {
gql := `query { gql := `query {
user(id: 5, where: { id: { gt: 3 } }) { user(id: 5, where: { id: { gt: 3 } }) {
@ -456,6 +438,51 @@ func blockedFunctions(t *testing.T) {
} }
} }
func TestCompileQuery(t *testing.T) {
t.Run("withComplexArgs", withComplexArgs)
t.Run("withWhereAndList", withWhereAndList)
t.Run("withWhereIsNull", withWhereIsNull)
t.Run("withWhereMultiOr", withWhereMultiOr)
t.Run("fetchByID", fetchByID)
t.Run("searchQuery", searchQuery)
t.Run("belongsTo", belongsTo)
t.Run("oneToMany", oneToMany)
t.Run("manyToMany", manyToMany)
t.Run("manyToManyReverse", manyToManyReverse)
t.Run("aggFunction", aggFunction)
t.Run("aggFunctionBlockedByCol", aggFunctionBlockedByCol)
t.Run("aggFunctionDisabled", aggFunctionDisabled)
t.Run("aggFunctionWithFilter", aggFunctionWithFilter)
t.Run("syntheticTables", syntheticTables)
t.Run("queryWithVariables", queryWithVariables)
t.Run("withWhereOnRelations", withWhereOnRelations)
t.Run("blockedQuery", blockedQuery)
t.Run("blockedFunctions", blockedFunctions)
}
var benchGQL = []byte(`query {
proDUcts(
# returns only 30 items
limit: 30,
# starts from item 10, commented out for now
# offset: 10,
# orders the response items by highest price
order_by: { price: desc },
# only items with an id >= 30 and < 30 are returned
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) {
id
NAME
price
user {
full_name
picture : avatar
}
}
}`)
func BenchmarkCompile(b *testing.B) { func BenchmarkCompile(b *testing.B) {
w := &bytes.Buffer{} w := &bytes.Buffer{}

View File

@ -59,7 +59,7 @@ type Column struct {
type Exp struct { type Exp struct {
Op ExpOp Op ExpOp
Col string Col string
NestedCol bool NestedCols []string
Type ValType Type ValType
Val string Val string
ListType ValType ListType ValType
@ -918,10 +918,8 @@ func setWhereColName(ex *Exp, node *Node) {
} }
if len(list) == 1 { if len(list) == 1 {
ex.Col = list[0] ex.Col = list[0]
} else if len(list) > 1 {
} else if len(list) > 2 { ex.NestedCols = list
ex.Col = buildPath(list)
ex.NestedCol = true
} }
} }

View File

@ -8,19 +8,19 @@ import (
"github.com/dosco/super-graph/jsn" "github.com/dosco/super-graph/jsn"
) )
func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) { func argMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
return func(w io.Writer, tag string) (int, error) { return func(w io.Writer, tag string) (int, error) {
switch tag { switch tag {
case "user_id_provider": case "user_id_provider":
if v := ctx.Value(userIDProviderKey); v != nil { if v := ctx.Value(userIDProviderKey); v != nil {
return stringVar(w, v.(string)) return stringArg(w, v.(string))
} }
io.WriteString(w, "null") io.WriteString(w, "null")
return 0, nil return 0, nil
case "user_id": case "user_id":
if v := ctx.Value(userIDKey); v != nil { if v := ctx.Value(userIDKey); v != nil {
return stringVar(w, v.(string)) return stringArg(w, v.(string))
} }
io.WriteString(w, "null") io.WriteString(w, "null")
@ -28,7 +28,7 @@ func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
case "user_role": case "user_role":
if v := ctx.Value(userRoleKey); v != nil { if v := ctx.Value(userRoleKey); v != nil {
return stringVar(w, v.(string)) return stringArg(w, v.(string))
} }
io.WriteString(w, "null") io.WriteString(w, "null")
return 0, nil return 0, nil
@ -50,7 +50,7 @@ func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
} }
if is { if is {
return stringVarB(w, fields[0].Value) return stringArgB(w, fields[0].Value)
} }
w.Write(fields[0].Value) w.Write(fields[0].Value)
@ -58,7 +58,7 @@ func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
} }
} }
func varList(ctx *coreContext, args [][]byte) []interface{} { func argList(ctx *coreContext, args [][]byte) []interface{} {
vars := make([]interface{}, len(args)) vars := make([]interface{}, len(args))
var fields map[string]interface{} var fields map[string]interface{}
@ -86,6 +86,11 @@ func varList(ctx *coreContext, args [][]byte) []interface{} {
vars[i] = v.(string) vars[i] = v.(string)
} }
case bytes.Equal(av, []byte("user_role")):
if v := ctx.Value(userRoleKey); v != nil {
vars[i] = v.(string)
}
default: default:
if v, ok := fields[string(av)]; ok { if v, ok := fields[string(av)]; ok {
vars[i] = v vars[i] = v
@ -96,7 +101,7 @@ func varList(ctx *coreContext, args [][]byte) []interface{} {
return vars return vars
} }
func stringVar(w io.Writer, v string) (int, error) { func stringArg(w io.Writer, v string) (int, error) {
if n, err := w.Write([]byte(`'`)); err != nil { if n, err := w.Write([]byte(`'`)); err != nil {
return n, err return n, err
} }
@ -106,7 +111,7 @@ func stringVar(w io.Writer, v string) (int, error) {
return w.Write([]byte(`'`)) return w.Write([]byte(`'`))
} }
func stringVarB(w io.Writer, v []byte) (int, error) { func stringArgB(w io.Writer, v []byte) (int, error) {
if n, err := w.Write([]byte(`'`)); err != nil { if n, err := w.Write([]byte(`'`)); err != nil {
return n, err return n, err
} }

View File

@ -66,6 +66,7 @@ type config struct {
PoolSize int32 `mapstructure:"pool_size"` PoolSize int32 `mapstructure:"pool_size"`
MaxRetries int `mapstructure:"max_retries"` MaxRetries int `mapstructure:"max_retries"`
LogLevel string `mapstructure:"log_level"` LogLevel string `mapstructure:"log_level"`
SetUserID bool `mapstructure:"set_user_id"`
Vars map[string]string `mapstructure:"variables"` Vars map[string]string `mapstructure:"variables"`

View File

@ -122,10 +122,8 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
} }
defer tx.Rollback(c) defer tx.Rollback(c)
if v := c.Value(userIDKey); v != nil { if conf.DB.SetUserID {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v)) if err := c.setLocalUserID(tx); err != nil {
if err != nil {
return nil, nil, err return nil, nil, err
} }
} }
@ -153,7 +151,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
} }
var root []byte var root []byte
vars := varList(c, ps.args) vars := argList(c, ps.args)
if mutation { if mutation {
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root) err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
@ -206,7 +204,7 @@ func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
t := fasttemplate.New(st.sql, openVar, closeVar) t := fasttemplate.New(st.sql, openVar, closeVar)
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
_, err = t.ExecuteFunc(buf, varMap(c)) _, err = t.ExecuteFunc(buf, argMap(c))
if err == errNoUserID { if err == errNoUserID {
logger.Warn().Msg("no user id found. query requires an authenicated request") logger.Warn().Msg("no user id found. query requires an authenicated request")
@ -224,10 +222,8 @@ func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
stime = time.Now() stime = time.Now()
} }
if v := c.Value(userIDKey); v != nil { if conf.DB.SetUserID {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v)) if err := c.setLocalUserID(tx); err != nil {
if err != nil {
return nil, 0, err return nil, 0, err
} }
} }
@ -425,6 +421,15 @@ func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
return role, nil return role, nil
} }
func (c *coreContext) setLocalUserID(tx pgx.Tx) error {
var err error
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
}
return err
}
func (c *coreContext) render(w io.Writer, data []byte) error { func (c *coreContext) render(w io.Writer, data []byte) error {
c.res.Data = json.RawMessage(data) c.res.Data = json.RawMessage(data)
return json.NewEncoder(w).Encode(c.res) return json.NewEncoder(w).Encode(c.res)

View File

@ -51,7 +51,8 @@ auth:
cookie: _{{app_name_slug}}_session cookie: _{{app_name_slug}}_session
# Comment this out if you want to disable setting # Comment this out if you want to disable setting
# the user_id via a header. Good for testing # the user_id via a header for testing.
# Disable in production
creds_in_header: true creds_in_header: true
rails: rails:
@ -92,6 +93,10 @@ database:
#max_retries: 0 #max_retries: 0
#log_level: "debug" #log_level: "debug"
# Set session variable "user.id" to the user id
# Enable this if you need the user id in triggers, etc
set_user_id: false
# Define variables here that you want to use in filters # Define variables here that you want to use in filters
# sub-queries must be wrapped in () # sub-queries must be wrapped in ()
variables: variables:
@ -131,7 +136,7 @@ tables:
name: me name: me
table: users table: users
roles_query: "SELECT * FROM users as usr WHERE id = $user_id" roles_query: "SELECT * FROM users WHERE id = $user_id"
roles: roles:
- name: anon - name: anon

View File

@ -85,3 +85,7 @@ database:
#pool_size: 10 #pool_size: 10
#max_retries: 0 #max_retries: 0
#log_level: "debug" #log_level: "debug"
# Set session variable "user.id" to the user id
# Enable this if you need the user id in triggers, etc
set_user_id: false