Add nested where clause to filter based on related tables

This commit is contained in:
Vikram Rangnekar 2019-11-04 23:44:42 -05:00
parent 77a51924a7
commit 89bc93e159
13 changed files with 358 additions and 206 deletions

View File

@ -1,14 +1,14 @@
<a href="https://supergraph.dev"><img src="https://supergraph.dev/hologram.svg" width="100" height="100" align="right" /></a>
# Super Graph - Build web products faster. Instant GraphQL APIs for your apps
# Super Graph - Instant GraphQL APIs for your apps.
## Build web products faster. No code needed. GraphQL auto. transformed into efficient database queries.
![MIT license](https://img.shields.io/github/license/dosco/super-graph.svg)
![Docker build](https://img.shields.io/docker/cloud/build/dosco/super-graph.svg)
![Cloud native](https://img.shields.io/badge/cloud--native-enabled-blue.svg)
[![Discord Chat](https://img.shields.io/discord/628796009539043348.svg)](https://discord.gg/6pSWCTZ)
Get an instant high performance GraphQL API for Postgres. No code needed. GraphQL is automatically transformed into efficient database queries.
![GraphQL](docs/.vuepress/public/graphql.png?raw=true "")
## The story of Super Graph?
@ -25,6 +25,7 @@ This compiler is what sits at the heart of Super Graph with layers of useful fun
## Features
- Role based access control
- Works with Rails database schemas
- Automatically learns schemas and relationships
- Belongs-To, One-To-Many and Many-To-Many table relationships

View File

@ -51,7 +51,8 @@ auth:
cookie: _app_session
# Comment this out if you want to disable setting
# the user_id via a header. Good for testing
# the user_id via a header for testing.
# Disable in production
creds_in_header: true
rails:
@ -92,6 +93,10 @@ database:
#max_retries: 0
#log_level: "debug"
# Set session variable "user.id" to the user id
# Enable this if you need the user id in triggers, etc
set_user_id: false
# Define variables here that you want to use in filters
# sub-queries must be wrapped in ()
variables:
@ -131,7 +136,7 @@ tables:
name: me
table: users
roles_query: "SELECT * FROM users as usr WHERE id = $user_id"
roles_query: "SELECT * FROM users WHERE id = $user_id"
roles:
- name: anon

View File

@ -85,3 +85,7 @@ database:
#pool_size: 10
#max_retries: 0
#log_level: "debug"
# Set session variable "user.id" to the user id
# Enable this if you need the user id in triggers, etc
set_user_id: false

View File

@ -214,6 +214,7 @@ func (c *compilerContext) renderDelete(qc *qcode.QCode, w io.Writer,
func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer,
vars Variables, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0]
upsert, ok := vars[qc.ActionVar]
if !ok {
@ -229,7 +230,7 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer,
return 0, err
}
io.WriteString(c.w, ` ON CONFLICT DO (`)
io.WriteString(c.w, ` ON CONFLICT (`)
i := 0
for _, cn := range ti.ColumnNames {
@ -250,10 +251,17 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer,
if i == 0 {
io.WriteString(c.w, ti.PrimaryCol)
}
io.WriteString(c.w, `) DO `)
io.WriteString(c.w, `)`)
io.WriteString(c.w, `UPDATE `)
io.WriteString(c.w, ` SET `)
if root.Where != nil {
io.WriteString(c.w, ` WHERE `)
if err := c.renderWhere(root, ti); err != nil {
return 0, err
}
}
io.WriteString(c.w, ` DO UPDATE SET `)
i = 0
for _, cn := range ti.ColumnNames {

View File

@ -78,13 +78,37 @@ func bulkInsert(t *testing.T) {
func singleUpsert(t *testing.T) {
gql := `mutation {
product(id: 15, upsert: $upsert) {
product(upsert: $upsert) {
id
name
}
}`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
vars := map[string]json.RawMessage{
"upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
}
resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil {
t.Fatal(err)
}
if string(resSQL) != sql {
t.Fatal(errNotExpected)
}
}
func singleUpsertWhere(t *testing.T) {
gql := `mutation {
product(upsert: $upsert, where: { price : { gt: 3 } }) {
id
name
}
}`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT (id) WHERE (("products"."price") > 3) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
vars := map[string]json.RawMessage{
"upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
@ -102,13 +126,13 @@ func singleUpsert(t *testing.T) {
func bulkUpsert(t *testing.T) {
gql := `mutation {
product(id: 15, upsert: $upsert) {
product(upsert: $upsert) {
id
name
}
}`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_recordset(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_recordset(NULL::products, i.j) t ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
vars := map[string]json.RawMessage{
"upsert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
@ -271,6 +295,7 @@ func TestCompileMutate(t *testing.T) {
t.Run("bulkInsert", bulkInsert)
t.Run("singleUpdate", singleUpdate)
t.Run("singleUpsert", singleUpsert)
t.Run("singleUpsertWhere", singleUpsertWhere)
t.Run("bulkUpsert", bulkUpsert)
t.Run("delete", delete)
t.Run("blockedInsert", blockedInsert)

View File

@ -120,7 +120,7 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) {
}
if sel.ID != 0 {
if err = c.renderJoin(sel); err != nil {
if err = c.renderLateralJoin(sel); err != nil {
return 0, err
}
}
@ -154,7 +154,7 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) {
}
if sel.ID != 0 {
if err = c.renderJoinClose(sel); err != nil {
if err = c.renderLateralJoinClose(sel); err != nil {
return 0, err
}
}
@ -327,12 +327,12 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo)
return nil
}
func (c *compilerContext) renderJoin(sel *qcode.Select) error {
func (c *compilerContext) renderLateralJoin(sel *qcode.Select) error {
io.WriteString(c.w, ` LEFT OUTER JOIN LATERAL (`)
return nil
}
func (c *compilerContext) renderJoinClose(sel *qcode.Select) error {
func (c *compilerContext) renderLateralJoinClose(sel *qcode.Select) error {
//fmt.Fprintf(w, `) AS "%s_%d_join" ON ('true')`, c.sel.Table, c.sel.ID)
io.WriteString(c.w, `)`)
aliasWithIDSuffix(c.w, sel.Table, sel.ID, "_join")
@ -340,19 +340,24 @@ func (c *compilerContext) renderJoinClose(sel *qcode.Select) error {
return nil
}
func (c *compilerContext) renderJoinTable(sel *qcode.Select) error {
func (c *compilerContext) renderJoin(sel *qcode.Select) error {
parent := &c.s[sel.ParentID]
return c.renderJoinByName(sel.Table, parent.Table, parent.ID)
}
rel, err := c.schema.GetRel(sel.Table, parent.Table)
func (c *compilerContext) renderJoinByName(table, parent string, id int32) error {
rel, err := c.schema.GetRel(table, parent)
if err != nil {
return err
}
// This join is only required for one-to-many relations since
// these make use of join tables that need to be pulled in.
if rel.Type != RelOneToManyThrough {
return err
}
pt, err := c.schema.GetTable(parent.Table)
pt, err := c.schema.GetTable(parent)
if err != nil {
return err
}
@ -364,7 +369,11 @@ func (c *compilerContext) renderJoinTable(sel *qcode.Select) error {
io.WriteString(c.w, `" ON ((`)
colWithTable(c.w, rel.Through, rel.ColT)
io.WriteString(c.w, `) = (`)
colWithTableID(c.w, pt.Name, parent.ID, rel.Col1)
if id != -1 {
colWithTableID(c.w, pt.Name, id, rel.Col1)
} else {
colWithTable(c.w, pt.Name, rel.Col1)
}
io.WriteString(c.w, `))`)
return nil
@ -438,7 +447,7 @@ func (c *compilerContext) renderJoinedColumns(sel *qcode.Select, ti *DBTableInfo
cti, err := c.schema.GetTable(childSel.Table)
if err != nil {
continue
return err
}
//fmt.Fprintf(w, `"%s_%d_join"."%s" AS "%s"`,
@ -529,8 +538,10 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
} else if sel.Functions {
cn1 := cn[pl:]
if _, ok := sel.Allowed[cn1]; !ok {
continue
if len(sel.Allowed) != 0 {
if _, ok := sel.Allowed[cn1]; !ok {
continue
}
}
if i != 0 {
io.WriteString(c.w, `, `)
@ -596,7 +607,7 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
}
if !isRoot {
if err := c.renderJoinTable(sel); err != nil {
if err := c.renderJoin(sel); err != nil {
return err
}
@ -680,8 +691,11 @@ func (c *compilerContext) renderOrderByColumns(sel *qcode.Select, ti *DBTableInf
func (c *compilerContext) renderRelationship(sel *qcode.Select, ti *DBTableInfo) error {
parent := c.s[sel.ParentID]
return c.renderRelationshipByName(sel.Table, parent.Table, parent.ID)
}
rel, err := c.schema.GetRel(sel.Table, parent.Table)
func (c *compilerContext) renderRelationshipByName(table, parent string, id int32) error {
rel, err := c.schema.GetRel(table, parent)
if err != nil {
return err
}
@ -691,25 +705,34 @@ func (c *compilerContext) renderRelationship(sel *qcode.Select, ti *DBTableInfo)
//fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`,
//c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2)
io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, rel.Col1)
colWithTable(c.w, table, rel.Col1)
io.WriteString(c.w, `) = (`)
colWithTableID(c.w, parent.Table, parent.ID, rel.Col2)
if id != -1 {
colWithTableID(c.w, parent, id, rel.Col2)
} else {
colWithTable(c.w, parent, rel.Col2)
}
io.WriteString(c.w, `))`)
case RelOneToMany:
//fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`,
//c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2)
io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, rel.Col1)
colWithTable(c.w, table, rel.Col1)
io.WriteString(c.w, `) = (`)
colWithTableID(c.w, parent.Table, parent.ID, rel.Col2)
if id != -1 {
colWithTableID(c.w, parent, id, rel.Col2)
} else {
colWithTable(c.w, parent, rel.Col2)
}
io.WriteString(c.w, `))`)
case RelOneToManyThrough:
// This requires the through table to be joined onto this select
//fmt.Fprintf(w, `(("%s"."%s") = ("%s"."%s"))`,
//c.sel.Table, rel.Col1, rel.Through, rel.Col2)
io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, rel.Col1)
colWithTable(c.w, table, rel.Col1)
io.WriteString(c.w, `) = (`)
colWithTable(c.w, rel.Through, rel.Col2)
io.WriteString(c.w, `))`)
@ -768,101 +791,20 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
qcode.FreeExp(val)
default:
if val.NestedCol {
//fmt.Fprintf(w, `(("%s") `, val.Col)
io.WriteString(c.w, `(("`)
io.WriteString(c.w, val.Col)
io.WriteString(c.w, `") `)
if len(val.NestedCols) != 0 {
io.WriteString(c.w, `EXISTS `)
} else if len(val.Col) != 0 {
if err := c.renderNestedWhere(val, sel, ti); err != nil {
return err
}
} else {
//fmt.Fprintf(w, `(("%s"."%s") `, c.sel.Table, val.Col)
io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, val.Col)
io.WriteString(c.w, `) `)
if err := c.renderOp(val, sel, ti); err != nil {
return err
}
qcode.FreeExp(val)
}
valExists := true
switch val.Op {
case qcode.OpEquals:
io.WriteString(c.w, `=`)
case qcode.OpNotEquals:
io.WriteString(c.w, `!=`)
case qcode.OpGreaterOrEquals:
io.WriteString(c.w, `>=`)
case qcode.OpLesserOrEquals:
io.WriteString(c.w, `<=`)
case qcode.OpGreaterThan:
io.WriteString(c.w, `>`)
case qcode.OpLesserThan:
io.WriteString(c.w, `<`)
case qcode.OpIn:
io.WriteString(c.w, `IN`)
case qcode.OpNotIn:
io.WriteString(c.w, `NOT IN`)
case qcode.OpLike:
io.WriteString(c.w, `LIKE`)
case qcode.OpNotLike:
io.WriteString(c.w, `NOT LIKE`)
case qcode.OpILike:
io.WriteString(c.w, `ILIKE`)
case qcode.OpNotILike:
io.WriteString(c.w, `NOT ILIKE`)
case qcode.OpSimilar:
io.WriteString(c.w, `SIMILAR TO`)
case qcode.OpNotSimilar:
io.WriteString(c.w, `NOT SIMILAR TO`)
case qcode.OpContains:
io.WriteString(c.w, `@>`)
case qcode.OpContainedIn:
io.WriteString(c.w, `<@`)
case qcode.OpHasKey:
io.WriteString(c.w, `?`)
case qcode.OpHasKeyAny:
io.WriteString(c.w, `?|`)
case qcode.OpHasKeyAll:
io.WriteString(c.w, `?&`)
case qcode.OpIsNull:
if strings.EqualFold(val.Val, "true") {
io.WriteString(c.w, `IS NULL)`)
} else {
io.WriteString(c.w, `IS NOT NULL)`)
}
valExists = false
case qcode.OpEqID:
if len(ti.PrimaryCol) == 0 {
return fmt.Errorf("no primary key column defined for %s", ti.Name)
}
//fmt.Fprintf(w, `(("%s") =`, c.ti.PrimaryCol)
io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, ti.PrimaryCol)
//io.WriteString(c.w, ti.PrimaryCol)
io.WriteString(c.w, `) =`)
case qcode.OpTsQuery:
if len(ti.TSVCol) == 0 {
return fmt.Errorf("no tsv column defined for %s", ti.Name)
}
//fmt.Fprintf(w, `(("%s") @@ to_tsquery('%s'))`, c.ti.TSVCol, val.Val)
io.WriteString(c.w, `(("`)
io.WriteString(c.w, ti.TSVCol)
io.WriteString(c.w, `") @@ to_tsquery('`)
io.WriteString(c.w, val.Val)
io.WriteString(c.w, `'))`)
valExists = false
default:
return fmt.Errorf("[Where] unexpected op code %d", val.Op)
}
if valExists {
if val.Type == qcode.ValList {
c.renderList(val)
} else {
c.renderVal(val, c.vars)
}
io.WriteString(c.w, `)`)
}
qcode.FreeExp(val)
}
default:
@ -874,6 +816,128 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
return nil
}
func (c *compilerContext) renderNestedWhere(ex *qcode.Exp, sel *qcode.Select, ti *DBTableInfo) error {
for i := 0; i < len(ex.NestedCols)-1; i++ {
cti, err := c.schema.GetTable(ex.NestedCols[i])
if err != nil {
return err
}
if i != 0 {
io.WriteString(c.w, ` AND `)
}
io.WriteString(c.w, `(SELECT 1 FROM `)
io.WriteString(c.w, cti.Name)
if err := c.renderJoinByName(cti.Name, ti.Name, -1); err != nil {
return err
}
io.WriteString(c.w, ` WHERE `)
if err := c.renderRelationshipByName(cti.Name, ti.Name, -1); err != nil {
return err
}
}
for i := 0; i < len(ex.NestedCols)-1; i++ {
io.WriteString(c.w, `)`)
}
return nil
}
func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTableInfo) error {
if len(ex.Col) != 0 {
io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, ex.Col)
io.WriteString(c.w, `) `)
}
switch ex.Op {
case qcode.OpEquals:
io.WriteString(c.w, `=`)
case qcode.OpNotEquals:
io.WriteString(c.w, `!=`)
case qcode.OpGreaterOrEquals:
io.WriteString(c.w, `>=`)
case qcode.OpLesserOrEquals:
io.WriteString(c.w, `<=`)
case qcode.OpGreaterThan:
io.WriteString(c.w, `>`)
case qcode.OpLesserThan:
io.WriteString(c.w, `<`)
case qcode.OpIn:
io.WriteString(c.w, `IN`)
case qcode.OpNotIn:
io.WriteString(c.w, `NOT IN`)
case qcode.OpLike:
io.WriteString(c.w, `LIKE`)
case qcode.OpNotLike:
io.WriteString(c.w, `NOT LIKE`)
case qcode.OpILike:
io.WriteString(c.w, `ILIKE`)
case qcode.OpNotILike:
io.WriteString(c.w, `NOT ILIKE`)
case qcode.OpSimilar:
io.WriteString(c.w, `SIMILAR TO`)
case qcode.OpNotSimilar:
io.WriteString(c.w, `NOT SIMILAR TO`)
case qcode.OpContains:
io.WriteString(c.w, `@>`)
case qcode.OpContainedIn:
io.WriteString(c.w, `<@`)
case qcode.OpHasKey:
io.WriteString(c.w, `?`)
case qcode.OpHasKeyAny:
io.WriteString(c.w, `?|`)
case qcode.OpHasKeyAll:
io.WriteString(c.w, `?&`)
case qcode.OpIsNull:
if strings.EqualFold(ex.Val, "true") {
io.WriteString(c.w, `IS NULL)`)
} else {
io.WriteString(c.w, `IS NOT NULL)`)
}
return nil
case qcode.OpEqID:
if len(ti.PrimaryCol) == 0 {
return fmt.Errorf("no primary key column defined for %s", ti.Name)
}
//fmt.Fprintf(w, `(("%s") =`, c.ti.PrimaryCol)
io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, ti.PrimaryCol)
//io.WriteString(c.w, ti.PrimaryCol)
io.WriteString(c.w, `) =`)
case qcode.OpTsQuery:
if len(ti.TSVCol) == 0 {
return fmt.Errorf("no tsv column defined for %s", ti.Name)
}
//fmt.Fprintf(w, `(("%s") @@ to_tsquery('%s'))`, c.ti.TSVCol, val.Val)
io.WriteString(c.w, `(("`)
io.WriteString(c.w, ti.TSVCol)
io.WriteString(c.w, `") @@ to_tsquery('`)
io.WriteString(c.w, ex.Val)
io.WriteString(c.w, `'))`)
return nil
default:
return fmt.Errorf("[Where] unexpected op code %d", ex.Op)
}
if ex.Type == qcode.ValList {
c.renderList(ex)
} else {
c.renderVal(ex, c.vars)
}
io.WriteString(c.w, `)`)
return nil
}
func (c *compilerContext) renderOrderBy(sel *qcode.Select, ti *DBTableInfo) error {
io.WriteString(c.w, ` ORDER BY `)
for i := range sel.OrderBy {

View File

@ -332,6 +332,25 @@ func aggFunctionWithFilter(t *testing.T) {
}
}
func syntheticTables(t *testing.T) {
gql := `query {
me {
email
}
}`
sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT ) AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
if string(resSQL) != sql {
t.Fatal(errNotExpected)
}
}
func queryWithVariables(t *testing.T) {
gql := `query {
product(id: $PRODUCT_ID, where: { price: { eq: $PRODUCT_PRICE } }) {
@ -352,14 +371,21 @@ func queryWithVariables(t *testing.T) {
}
}
func syntheticTables(t *testing.T) {
func withWhereOnRelations(t *testing.T) {
gql := `query {
me {
users(where: {
not: {
products: {
price: { gt: 3 }
}
}
}) {
id
email
}
}`
sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT ) AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337"`
sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."id" AS "id", "users_0"."email" AS "email") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."id", "users"."email" FROM "users" WHERE (NOT EXISTS (SELECT 1 FROM products WHERE (("products"."user_id") = ("users"."id")))) LIMIT ('20') :: integer) AS "users_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
@ -371,50 +397,6 @@ func syntheticTables(t *testing.T) {
}
}
func TestCompileQuery(t *testing.T) {
t.Run("withComplexArgs", withComplexArgs)
t.Run("withWhereAndList", withWhereAndList)
t.Run("withWhereIsNull", withWhereIsNull)
t.Run("withWhereMultiOr", withWhereMultiOr)
t.Run("fetchByID", fetchByID)
t.Run("searchQuery", searchQuery)
t.Run("belongsTo", belongsTo)
t.Run("oneToMany", oneToMany)
t.Run("manyToMany", manyToMany)
t.Run("manyToManyReverse", manyToManyReverse)
t.Run("aggFunction", aggFunction)
t.Run("aggFunctionBlockedByCol", aggFunctionBlockedByCol)
t.Run("aggFunctionDisabled", aggFunctionDisabled)
t.Run("aggFunctionWithFilter", aggFunctionWithFilter)
t.Run("syntheticTables", syntheticTables)
t.Run("queryWithVariables", queryWithVariables)
t.Run("blockedQuery", blockedQuery)
t.Run("blockedFunctions", blockedFunctions)
}
var benchGQL = []byte(`query {
proDUcts(
# returns only 30 items
limit: 30,
# starts from item 10, commented out for now
# offset: 10,
# orders the response items by highest price
order_by: { price: desc },
# only items with an id >= 30 and < 30 are returned
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) {
id
NAME
price
user {
full_name
picture : avatar
}
}
}`)
func blockedQuery(t *testing.T) {
gql := `query {
user(id: 5, where: { id: { gt: 3 } }) {
@ -456,6 +438,51 @@ func blockedFunctions(t *testing.T) {
}
}
func TestCompileQuery(t *testing.T) {
t.Run("withComplexArgs", withComplexArgs)
t.Run("withWhereAndList", withWhereAndList)
t.Run("withWhereIsNull", withWhereIsNull)
t.Run("withWhereMultiOr", withWhereMultiOr)
t.Run("fetchByID", fetchByID)
t.Run("searchQuery", searchQuery)
t.Run("belongsTo", belongsTo)
t.Run("oneToMany", oneToMany)
t.Run("manyToMany", manyToMany)
t.Run("manyToManyReverse", manyToManyReverse)
t.Run("aggFunction", aggFunction)
t.Run("aggFunctionBlockedByCol", aggFunctionBlockedByCol)
t.Run("aggFunctionDisabled", aggFunctionDisabled)
t.Run("aggFunctionWithFilter", aggFunctionWithFilter)
t.Run("syntheticTables", syntheticTables)
t.Run("queryWithVariables", queryWithVariables)
t.Run("withWhereOnRelations", withWhereOnRelations)
t.Run("blockedQuery", blockedQuery)
t.Run("blockedFunctions", blockedFunctions)
}
var benchGQL = []byte(`query {
proDUcts(
# returns only 30 items
limit: 30,
# starts from item 10, commented out for now
# offset: 10,
# orders the response items by highest price
order_by: { price: desc },
# only items with an id >= 30 and < 30 are returned
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) {
id
NAME
price
user {
full_name
picture : avatar
}
}
}`)
func BenchmarkCompile(b *testing.B) {
w := &bytes.Buffer{}

View File

@ -57,16 +57,16 @@ type Column struct {
}
type Exp struct {
Op ExpOp
Col string
NestedCol bool
Type ValType
Val string
ListType ValType
ListVal []string
Children []*Exp
childrenA [5]*Exp
doFree bool
Op ExpOp
Col string
NestedCols []string
Type ValType
Val string
ListType ValType
ListVal []string
Children []*Exp
childrenA [5]*Exp
doFree bool
}
var zeroExp = Exp{doFree: true}
@ -918,10 +918,8 @@ func setWhereColName(ex *Exp, node *Node) {
}
if len(list) == 1 {
ex.Col = list[0]
} else if len(list) > 2 {
ex.Col = buildPath(list)
ex.NestedCol = true
} else if len(list) > 1 {
ex.NestedCols = list
}
}

View File

@ -8,19 +8,19 @@ import (
"github.com/dosco/super-graph/jsn"
)
func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
func argMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
return func(w io.Writer, tag string) (int, error) {
switch tag {
case "user_id_provider":
if v := ctx.Value(userIDProviderKey); v != nil {
return stringVar(w, v.(string))
return stringArg(w, v.(string))
}
io.WriteString(w, "null")
return 0, nil
case "user_id":
if v := ctx.Value(userIDKey); v != nil {
return stringVar(w, v.(string))
return stringArg(w, v.(string))
}
io.WriteString(w, "null")
@ -28,7 +28,7 @@ func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
case "user_role":
if v := ctx.Value(userRoleKey); v != nil {
return stringVar(w, v.(string))
return stringArg(w, v.(string))
}
io.WriteString(w, "null")
return 0, nil
@ -50,7 +50,7 @@ func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
}
if is {
return stringVarB(w, fields[0].Value)
return stringArgB(w, fields[0].Value)
}
w.Write(fields[0].Value)
@ -58,7 +58,7 @@ func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
}
}
func varList(ctx *coreContext, args [][]byte) []interface{} {
func argList(ctx *coreContext, args [][]byte) []interface{} {
vars := make([]interface{}, len(args))
var fields map[string]interface{}
@ -86,6 +86,11 @@ func varList(ctx *coreContext, args [][]byte) []interface{} {
vars[i] = v.(string)
}
case bytes.Equal(av, []byte("user_role")):
if v := ctx.Value(userRoleKey); v != nil {
vars[i] = v.(string)
}
default:
if v, ok := fields[string(av)]; ok {
vars[i] = v
@ -96,7 +101,7 @@ func varList(ctx *coreContext, args [][]byte) []interface{} {
return vars
}
func stringVar(w io.Writer, v string) (int, error) {
func stringArg(w io.Writer, v string) (int, error) {
if n, err := w.Write([]byte(`'`)); err != nil {
return n, err
}
@ -106,7 +111,7 @@ func stringVar(w io.Writer, v string) (int, error) {
return w.Write([]byte(`'`))
}
func stringVarB(w io.Writer, v []byte) (int, error) {
func stringArgB(w io.Writer, v []byte) (int, error) {
if n, err := w.Write([]byte(`'`)); err != nil {
return n, err
}

View File

@ -66,6 +66,7 @@ type config struct {
PoolSize int32 `mapstructure:"pool_size"`
MaxRetries int `mapstructure:"max_retries"`
LogLevel string `mapstructure:"log_level"`
SetUserID bool `mapstructure:"set_user_id"`
Vars map[string]string `mapstructure:"variables"`

View File

@ -122,10 +122,8 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
}
defer tx.Rollback(c)
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
if err != nil {
if conf.DB.SetUserID {
if err := c.setLocalUserID(tx); err != nil {
return nil, nil, err
}
}
@ -153,7 +151,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
}
var root []byte
vars := varList(c, ps.args)
vars := argList(c, ps.args)
if mutation {
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
@ -206,7 +204,7 @@ func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
t := fasttemplate.New(st.sql, openVar, closeVar)
buf := &bytes.Buffer{}
_, err = t.ExecuteFunc(buf, varMap(c))
_, err = t.ExecuteFunc(buf, argMap(c))
if err == errNoUserID {
logger.Warn().Msg("no user id found. query requires an authenicated request")
@ -224,10 +222,8 @@ func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
stime = time.Now()
}
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
if err != nil {
if conf.DB.SetUserID {
if err := c.setLocalUserID(tx); err != nil {
return nil, 0, err
}
}
@ -425,6 +421,15 @@ func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
return role, nil
}
func (c *coreContext) setLocalUserID(tx pgx.Tx) error {
var err error
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
}
return err
}
func (c *coreContext) render(w io.Writer, data []byte) error {
c.res.Data = json.RawMessage(data)
return json.NewEncoder(w).Encode(c.res)

View File

@ -51,7 +51,8 @@ auth:
cookie: _{{app_name_slug}}_session
# Comment this out if you want to disable setting
# the user_id via a header. Good for testing
# the user_id via a header for testing.
# Disable in production
creds_in_header: true
rails:
@ -92,6 +93,10 @@ database:
#max_retries: 0
#log_level: "debug"
# Set session variable "user.id" to the user id
# Enable this if you need the user id in triggers, etc
set_user_id: false
# Define variables here that you want to use in filters
# sub-queries must be wrapped in ()
variables:
@ -131,7 +136,7 @@ tables:
name: me
table: users
roles_query: "SELECT * FROM users as usr WHERE id = $user_id"
roles_query: "SELECT * FROM users WHERE id = $user_id"
roles:
- name: anon

View File

@ -85,3 +85,7 @@ database:
#pool_size: 10
#max_retries: 0
#log_level: "debug"
# Set session variable "user.id" to the user id
# Enable this if you need the user id in triggers, etc
set_user_id: false