Add nested where clause to filter based on related tables
This commit is contained in:
parent
77a51924a7
commit
89bc93e159
|
@ -1,14 +1,14 @@
|
||||||
<a href="https://supergraph.dev"><img src="https://supergraph.dev/hologram.svg" width="100" height="100" align="right" /></a>
|
<a href="https://supergraph.dev"><img src="https://supergraph.dev/hologram.svg" width="100" height="100" align="right" /></a>
|
||||||
|
|
||||||
# Super Graph - Build web products faster. Instant GraphQL APIs for your apps
|
# Super Graph - Instant GraphQL APIs for your apps.
|
||||||
|
|
||||||
|
## Build web products faster. No code needed. GraphQL auto. transformed into efficient database queries.
|
||||||
|
|
||||||
![MIT license](https://img.shields.io/github/license/dosco/super-graph.svg)
|
![MIT license](https://img.shields.io/github/license/dosco/super-graph.svg)
|
||||||
![Docker build](https://img.shields.io/docker/cloud/build/dosco/super-graph.svg)
|
![Docker build](https://img.shields.io/docker/cloud/build/dosco/super-graph.svg)
|
||||||
![Cloud native](https://img.shields.io/badge/cloud--native-enabled-blue.svg)
|
![Cloud native](https://img.shields.io/badge/cloud--native-enabled-blue.svg)
|
||||||
[![Discord Chat](https://img.shields.io/discord/628796009539043348.svg)](https://discord.gg/6pSWCTZ)
|
[![Discord Chat](https://img.shields.io/discord/628796009539043348.svg)](https://discord.gg/6pSWCTZ)
|
||||||
|
|
||||||
Get an instant high performance GraphQL API for Postgres. No code needed. GraphQL is automatically transformed into efficient database queries.
|
|
||||||
|
|
||||||
![GraphQL](docs/.vuepress/public/graphql.png?raw=true "")
|
![GraphQL](docs/.vuepress/public/graphql.png?raw=true "")
|
||||||
|
|
||||||
## The story of Super Graph?
|
## The story of Super Graph?
|
||||||
|
@ -25,6 +25,7 @@ This compiler is what sits at the heart of Super Graph with layers of useful fun
|
||||||
|
|
||||||
## Features
|
## Features
|
||||||
|
|
||||||
|
- Role based access control
|
||||||
- Works with Rails database schemas
|
- Works with Rails database schemas
|
||||||
- Automatically learns schemas and relationships
|
- Automatically learns schemas and relationships
|
||||||
- Belongs-To, One-To-Many and Many-To-Many table relationships
|
- Belongs-To, One-To-Many and Many-To-Many table relationships
|
||||||
|
|
|
@ -51,7 +51,8 @@ auth:
|
||||||
cookie: _app_session
|
cookie: _app_session
|
||||||
|
|
||||||
# Comment this out if you want to disable setting
|
# Comment this out if you want to disable setting
|
||||||
# the user_id via a header. Good for testing
|
# the user_id via a header for testing.
|
||||||
|
# Disable in production
|
||||||
creds_in_header: true
|
creds_in_header: true
|
||||||
|
|
||||||
rails:
|
rails:
|
||||||
|
@ -92,6 +93,10 @@ database:
|
||||||
#max_retries: 0
|
#max_retries: 0
|
||||||
#log_level: "debug"
|
#log_level: "debug"
|
||||||
|
|
||||||
|
# Set session variable "user.id" to the user id
|
||||||
|
# Enable this if you need the user id in triggers, etc
|
||||||
|
set_user_id: false
|
||||||
|
|
||||||
# Define variables here that you want to use in filters
|
# Define variables here that you want to use in filters
|
||||||
# sub-queries must be wrapped in ()
|
# sub-queries must be wrapped in ()
|
||||||
variables:
|
variables:
|
||||||
|
@ -131,7 +136,7 @@ tables:
|
||||||
name: me
|
name: me
|
||||||
table: users
|
table: users
|
||||||
|
|
||||||
roles_query: "SELECT * FROM users as usr WHERE id = $user_id"
|
roles_query: "SELECT * FROM users WHERE id = $user_id"
|
||||||
|
|
||||||
roles:
|
roles:
|
||||||
- name: anon
|
- name: anon
|
||||||
|
|
|
@ -85,3 +85,7 @@ database:
|
||||||
#pool_size: 10
|
#pool_size: 10
|
||||||
#max_retries: 0
|
#max_retries: 0
|
||||||
#log_level: "debug"
|
#log_level: "debug"
|
||||||
|
|
||||||
|
# Set session variable "user.id" to the user id
|
||||||
|
# Enable this if you need the user id in triggers, etc
|
||||||
|
set_user_id: false
|
|
@ -214,6 +214,7 @@ func (c *compilerContext) renderDelete(qc *qcode.QCode, w io.Writer,
|
||||||
|
|
||||||
func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer,
|
func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer,
|
||||||
vars Variables, ti *DBTableInfo) (uint32, error) {
|
vars Variables, ti *DBTableInfo) (uint32, error) {
|
||||||
|
root := &qc.Selects[0]
|
||||||
|
|
||||||
upsert, ok := vars[qc.ActionVar]
|
upsert, ok := vars[qc.ActionVar]
|
||||||
if !ok {
|
if !ok {
|
||||||
|
@ -229,7 +230,7 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer,
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
io.WriteString(c.w, ` ON CONFLICT DO (`)
|
io.WriteString(c.w, ` ON CONFLICT (`)
|
||||||
i := 0
|
i := 0
|
||||||
|
|
||||||
for _, cn := range ti.ColumnNames {
|
for _, cn := range ti.ColumnNames {
|
||||||
|
@ -250,10 +251,17 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer,
|
||||||
if i == 0 {
|
if i == 0 {
|
||||||
io.WriteString(c.w, ti.PrimaryCol)
|
io.WriteString(c.w, ti.PrimaryCol)
|
||||||
}
|
}
|
||||||
io.WriteString(c.w, `) DO `)
|
io.WriteString(c.w, `)`)
|
||||||
|
|
||||||
io.WriteString(c.w, `UPDATE `)
|
if root.Where != nil {
|
||||||
io.WriteString(c.w, ` SET `)
|
io.WriteString(c.w, ` WHERE `)
|
||||||
|
|
||||||
|
if err := c.renderWhere(root, ti); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
io.WriteString(c.w, ` DO UPDATE SET `)
|
||||||
|
|
||||||
i = 0
|
i = 0
|
||||||
for _, cn := range ti.ColumnNames {
|
for _, cn := range ti.ColumnNames {
|
||||||
|
|
|
@ -78,13 +78,37 @@ func bulkInsert(t *testing.T) {
|
||||||
|
|
||||||
func singleUpsert(t *testing.T) {
|
func singleUpsert(t *testing.T) {
|
||||||
gql := `mutation {
|
gql := `mutation {
|
||||||
product(id: 15, upsert: $upsert) {
|
product(upsert: $upsert) {
|
||||||
id
|
id
|
||||||
name
|
name
|
||||||
}
|
}
|
||||||
}`
|
}`
|
||||||
|
|
||||||
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
|
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
|
||||||
|
|
||||||
|
vars := map[string]json.RawMessage{
|
||||||
|
"upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
|
||||||
|
}
|
||||||
|
|
||||||
|
resSQL, err := compileGQLToPSQL(gql, vars, "user")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(resSQL) != sql {
|
||||||
|
t.Fatal(errNotExpected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func singleUpsertWhere(t *testing.T) {
|
||||||
|
gql := `mutation {
|
||||||
|
product(upsert: $upsert, where: { price : { gt: 3 } }) {
|
||||||
|
id
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT (id) WHERE (("products"."price") > 3) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
|
||||||
|
|
||||||
vars := map[string]json.RawMessage{
|
vars := map[string]json.RawMessage{
|
||||||
"upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
|
"upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
|
||||||
|
@ -102,13 +126,13 @@ func singleUpsert(t *testing.T) {
|
||||||
|
|
||||||
func bulkUpsert(t *testing.T) {
|
func bulkUpsert(t *testing.T) {
|
||||||
gql := `mutation {
|
gql := `mutation {
|
||||||
product(id: 15, upsert: $upsert) {
|
product(upsert: $upsert) {
|
||||||
id
|
id
|
||||||
name
|
name
|
||||||
}
|
}
|
||||||
}`
|
}`
|
||||||
|
|
||||||
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_recordset(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
|
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_recordset(NULL::products, i.j) t ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
|
||||||
|
|
||||||
vars := map[string]json.RawMessage{
|
vars := map[string]json.RawMessage{
|
||||||
"upsert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
|
"upsert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
|
||||||
|
@ -271,6 +295,7 @@ func TestCompileMutate(t *testing.T) {
|
||||||
t.Run("bulkInsert", bulkInsert)
|
t.Run("bulkInsert", bulkInsert)
|
||||||
t.Run("singleUpdate", singleUpdate)
|
t.Run("singleUpdate", singleUpdate)
|
||||||
t.Run("singleUpsert", singleUpsert)
|
t.Run("singleUpsert", singleUpsert)
|
||||||
|
t.Run("singleUpsertWhere", singleUpsertWhere)
|
||||||
t.Run("bulkUpsert", bulkUpsert)
|
t.Run("bulkUpsert", bulkUpsert)
|
||||||
t.Run("delete", delete)
|
t.Run("delete", delete)
|
||||||
t.Run("blockedInsert", blockedInsert)
|
t.Run("blockedInsert", blockedInsert)
|
||||||
|
|
284
psql/query.go
284
psql/query.go
|
@ -120,7 +120,7 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if sel.ID != 0 {
|
if sel.ID != 0 {
|
||||||
if err = c.renderJoin(sel); err != nil {
|
if err = c.renderLateralJoin(sel); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -154,7 +154,7 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if sel.ID != 0 {
|
if sel.ID != 0 {
|
||||||
if err = c.renderJoinClose(sel); err != nil {
|
if err = c.renderLateralJoinClose(sel); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -327,12 +327,12 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *compilerContext) renderJoin(sel *qcode.Select) error {
|
func (c *compilerContext) renderLateralJoin(sel *qcode.Select) error {
|
||||||
io.WriteString(c.w, ` LEFT OUTER JOIN LATERAL (`)
|
io.WriteString(c.w, ` LEFT OUTER JOIN LATERAL (`)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *compilerContext) renderJoinClose(sel *qcode.Select) error {
|
func (c *compilerContext) renderLateralJoinClose(sel *qcode.Select) error {
|
||||||
//fmt.Fprintf(w, `) AS "%s_%d_join" ON ('true')`, c.sel.Table, c.sel.ID)
|
//fmt.Fprintf(w, `) AS "%s_%d_join" ON ('true')`, c.sel.Table, c.sel.ID)
|
||||||
io.WriteString(c.w, `)`)
|
io.WriteString(c.w, `)`)
|
||||||
aliasWithIDSuffix(c.w, sel.Table, sel.ID, "_join")
|
aliasWithIDSuffix(c.w, sel.Table, sel.ID, "_join")
|
||||||
|
@ -340,19 +340,24 @@ func (c *compilerContext) renderJoinClose(sel *qcode.Select) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *compilerContext) renderJoinTable(sel *qcode.Select) error {
|
func (c *compilerContext) renderJoin(sel *qcode.Select) error {
|
||||||
parent := &c.s[sel.ParentID]
|
parent := &c.s[sel.ParentID]
|
||||||
|
return c.renderJoinByName(sel.Table, parent.Table, parent.ID)
|
||||||
|
}
|
||||||
|
|
||||||
rel, err := c.schema.GetRel(sel.Table, parent.Table)
|
func (c *compilerContext) renderJoinByName(table, parent string, id int32) error {
|
||||||
|
rel, err := c.schema.GetRel(table, parent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// This join is only required for one-to-many relations since
|
||||||
|
// these make use of join tables that need to be pulled in.
|
||||||
if rel.Type != RelOneToManyThrough {
|
if rel.Type != RelOneToManyThrough {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
pt, err := c.schema.GetTable(parent.Table)
|
pt, err := c.schema.GetTable(parent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -364,7 +369,11 @@ func (c *compilerContext) renderJoinTable(sel *qcode.Select) error {
|
||||||
io.WriteString(c.w, `" ON ((`)
|
io.WriteString(c.w, `" ON ((`)
|
||||||
colWithTable(c.w, rel.Through, rel.ColT)
|
colWithTable(c.w, rel.Through, rel.ColT)
|
||||||
io.WriteString(c.w, `) = (`)
|
io.WriteString(c.w, `) = (`)
|
||||||
colWithTableID(c.w, pt.Name, parent.ID, rel.Col1)
|
if id != -1 {
|
||||||
|
colWithTableID(c.w, pt.Name, id, rel.Col1)
|
||||||
|
} else {
|
||||||
|
colWithTable(c.w, pt.Name, rel.Col1)
|
||||||
|
}
|
||||||
io.WriteString(c.w, `))`)
|
io.WriteString(c.w, `))`)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
@ -438,7 +447,7 @@ func (c *compilerContext) renderJoinedColumns(sel *qcode.Select, ti *DBTableInfo
|
||||||
|
|
||||||
cti, err := c.schema.GetTable(childSel.Table)
|
cti, err := c.schema.GetTable(childSel.Table)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
continue
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
//fmt.Fprintf(w, `"%s_%d_join"."%s" AS "%s"`,
|
//fmt.Fprintf(w, `"%s_%d_join"."%s" AS "%s"`,
|
||||||
|
@ -529,8 +538,10 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
|
||||||
|
|
||||||
} else if sel.Functions {
|
} else if sel.Functions {
|
||||||
cn1 := cn[pl:]
|
cn1 := cn[pl:]
|
||||||
if _, ok := sel.Allowed[cn1]; !ok {
|
if len(sel.Allowed) != 0 {
|
||||||
continue
|
if _, ok := sel.Allowed[cn1]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
if i != 0 {
|
if i != 0 {
|
||||||
io.WriteString(c.w, `, `)
|
io.WriteString(c.w, `, `)
|
||||||
|
@ -596,7 +607,7 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
|
||||||
}
|
}
|
||||||
|
|
||||||
if !isRoot {
|
if !isRoot {
|
||||||
if err := c.renderJoinTable(sel); err != nil {
|
if err := c.renderJoin(sel); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -680,8 +691,11 @@ func (c *compilerContext) renderOrderByColumns(sel *qcode.Select, ti *DBTableInf
|
||||||
|
|
||||||
func (c *compilerContext) renderRelationship(sel *qcode.Select, ti *DBTableInfo) error {
|
func (c *compilerContext) renderRelationship(sel *qcode.Select, ti *DBTableInfo) error {
|
||||||
parent := c.s[sel.ParentID]
|
parent := c.s[sel.ParentID]
|
||||||
|
return c.renderRelationshipByName(sel.Table, parent.Table, parent.ID)
|
||||||
|
}
|
||||||
|
|
||||||
rel, err := c.schema.GetRel(sel.Table, parent.Table)
|
func (c *compilerContext) renderRelationshipByName(table, parent string, id int32) error {
|
||||||
|
rel, err := c.schema.GetRel(table, parent)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -691,25 +705,34 @@ func (c *compilerContext) renderRelationship(sel *qcode.Select, ti *DBTableInfo)
|
||||||
//fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`,
|
//fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`,
|
||||||
//c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2)
|
//c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2)
|
||||||
io.WriteString(c.w, `((`)
|
io.WriteString(c.w, `((`)
|
||||||
colWithTable(c.w, ti.Name, rel.Col1)
|
colWithTable(c.w, table, rel.Col1)
|
||||||
io.WriteString(c.w, `) = (`)
|
io.WriteString(c.w, `) = (`)
|
||||||
colWithTableID(c.w, parent.Table, parent.ID, rel.Col2)
|
if id != -1 {
|
||||||
|
colWithTableID(c.w, parent, id, rel.Col2)
|
||||||
|
} else {
|
||||||
|
colWithTable(c.w, parent, rel.Col2)
|
||||||
|
}
|
||||||
io.WriteString(c.w, `))`)
|
io.WriteString(c.w, `))`)
|
||||||
|
|
||||||
case RelOneToMany:
|
case RelOneToMany:
|
||||||
//fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`,
|
//fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`,
|
||||||
//c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2)
|
//c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2)
|
||||||
io.WriteString(c.w, `((`)
|
io.WriteString(c.w, `((`)
|
||||||
colWithTable(c.w, ti.Name, rel.Col1)
|
colWithTable(c.w, table, rel.Col1)
|
||||||
io.WriteString(c.w, `) = (`)
|
io.WriteString(c.w, `) = (`)
|
||||||
colWithTableID(c.w, parent.Table, parent.ID, rel.Col2)
|
if id != -1 {
|
||||||
|
colWithTableID(c.w, parent, id, rel.Col2)
|
||||||
|
} else {
|
||||||
|
colWithTable(c.w, parent, rel.Col2)
|
||||||
|
}
|
||||||
io.WriteString(c.w, `))`)
|
io.WriteString(c.w, `))`)
|
||||||
|
|
||||||
case RelOneToManyThrough:
|
case RelOneToManyThrough:
|
||||||
|
// This requires the through table to be joined onto this select
|
||||||
//fmt.Fprintf(w, `(("%s"."%s") = ("%s"."%s"))`,
|
//fmt.Fprintf(w, `(("%s"."%s") = ("%s"."%s"))`,
|
||||||
//c.sel.Table, rel.Col1, rel.Through, rel.Col2)
|
//c.sel.Table, rel.Col1, rel.Through, rel.Col2)
|
||||||
io.WriteString(c.w, `((`)
|
io.WriteString(c.w, `((`)
|
||||||
colWithTable(c.w, ti.Name, rel.Col1)
|
colWithTable(c.w, table, rel.Col1)
|
||||||
io.WriteString(c.w, `) = (`)
|
io.WriteString(c.w, `) = (`)
|
||||||
colWithTable(c.w, rel.Through, rel.Col2)
|
colWithTable(c.w, rel.Through, rel.Col2)
|
||||||
io.WriteString(c.w, `))`)
|
io.WriteString(c.w, `))`)
|
||||||
|
@ -768,101 +791,20 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
|
||||||
qcode.FreeExp(val)
|
qcode.FreeExp(val)
|
||||||
|
|
||||||
default:
|
default:
|
||||||
if val.NestedCol {
|
if len(val.NestedCols) != 0 {
|
||||||
//fmt.Fprintf(w, `(("%s") `, val.Col)
|
io.WriteString(c.w, `EXISTS `)
|
||||||
io.WriteString(c.w, `(("`)
|
|
||||||
io.WriteString(c.w, val.Col)
|
|
||||||
io.WriteString(c.w, `") `)
|
|
||||||
|
|
||||||
} else if len(val.Col) != 0 {
|
if err := c.renderNestedWhere(val, sel, ti); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
//fmt.Fprintf(w, `(("%s"."%s") `, c.sel.Table, val.Col)
|
//fmt.Fprintf(w, `(("%s"."%s") `, c.sel.Table, val.Col)
|
||||||
io.WriteString(c.w, `((`)
|
if err := c.renderOp(val, sel, ti); err != nil {
|
||||||
colWithTable(c.w, ti.Name, val.Col)
|
return err
|
||||||
io.WriteString(c.w, `) `)
|
}
|
||||||
|
qcode.FreeExp(val)
|
||||||
}
|
}
|
||||||
valExists := true
|
|
||||||
|
|
||||||
switch val.Op {
|
|
||||||
case qcode.OpEquals:
|
|
||||||
io.WriteString(c.w, `=`)
|
|
||||||
case qcode.OpNotEquals:
|
|
||||||
io.WriteString(c.w, `!=`)
|
|
||||||
case qcode.OpGreaterOrEquals:
|
|
||||||
io.WriteString(c.w, `>=`)
|
|
||||||
case qcode.OpLesserOrEquals:
|
|
||||||
io.WriteString(c.w, `<=`)
|
|
||||||
case qcode.OpGreaterThan:
|
|
||||||
io.WriteString(c.w, `>`)
|
|
||||||
case qcode.OpLesserThan:
|
|
||||||
io.WriteString(c.w, `<`)
|
|
||||||
case qcode.OpIn:
|
|
||||||
io.WriteString(c.w, `IN`)
|
|
||||||
case qcode.OpNotIn:
|
|
||||||
io.WriteString(c.w, `NOT IN`)
|
|
||||||
case qcode.OpLike:
|
|
||||||
io.WriteString(c.w, `LIKE`)
|
|
||||||
case qcode.OpNotLike:
|
|
||||||
io.WriteString(c.w, `NOT LIKE`)
|
|
||||||
case qcode.OpILike:
|
|
||||||
io.WriteString(c.w, `ILIKE`)
|
|
||||||
case qcode.OpNotILike:
|
|
||||||
io.WriteString(c.w, `NOT ILIKE`)
|
|
||||||
case qcode.OpSimilar:
|
|
||||||
io.WriteString(c.w, `SIMILAR TO`)
|
|
||||||
case qcode.OpNotSimilar:
|
|
||||||
io.WriteString(c.w, `NOT SIMILAR TO`)
|
|
||||||
case qcode.OpContains:
|
|
||||||
io.WriteString(c.w, `@>`)
|
|
||||||
case qcode.OpContainedIn:
|
|
||||||
io.WriteString(c.w, `<@`)
|
|
||||||
case qcode.OpHasKey:
|
|
||||||
io.WriteString(c.w, `?`)
|
|
||||||
case qcode.OpHasKeyAny:
|
|
||||||
io.WriteString(c.w, `?|`)
|
|
||||||
case qcode.OpHasKeyAll:
|
|
||||||
io.WriteString(c.w, `?&`)
|
|
||||||
case qcode.OpIsNull:
|
|
||||||
if strings.EqualFold(val.Val, "true") {
|
|
||||||
io.WriteString(c.w, `IS NULL)`)
|
|
||||||
} else {
|
|
||||||
io.WriteString(c.w, `IS NOT NULL)`)
|
|
||||||
}
|
|
||||||
valExists = false
|
|
||||||
case qcode.OpEqID:
|
|
||||||
if len(ti.PrimaryCol) == 0 {
|
|
||||||
return fmt.Errorf("no primary key column defined for %s", ti.Name)
|
|
||||||
}
|
|
||||||
//fmt.Fprintf(w, `(("%s") =`, c.ti.PrimaryCol)
|
|
||||||
io.WriteString(c.w, `((`)
|
|
||||||
colWithTable(c.w, ti.Name, ti.PrimaryCol)
|
|
||||||
//io.WriteString(c.w, ti.PrimaryCol)
|
|
||||||
io.WriteString(c.w, `) =`)
|
|
||||||
case qcode.OpTsQuery:
|
|
||||||
if len(ti.TSVCol) == 0 {
|
|
||||||
return fmt.Errorf("no tsv column defined for %s", ti.Name)
|
|
||||||
}
|
|
||||||
//fmt.Fprintf(w, `(("%s") @@ to_tsquery('%s'))`, c.ti.TSVCol, val.Val)
|
|
||||||
io.WriteString(c.w, `(("`)
|
|
||||||
io.WriteString(c.w, ti.TSVCol)
|
|
||||||
io.WriteString(c.w, `") @@ to_tsquery('`)
|
|
||||||
io.WriteString(c.w, val.Val)
|
|
||||||
io.WriteString(c.w, `'))`)
|
|
||||||
valExists = false
|
|
||||||
|
|
||||||
default:
|
|
||||||
return fmt.Errorf("[Where] unexpected op code %d", val.Op)
|
|
||||||
}
|
|
||||||
|
|
||||||
if valExists {
|
|
||||||
if val.Type == qcode.ValList {
|
|
||||||
c.renderList(val)
|
|
||||||
} else {
|
|
||||||
c.renderVal(val, c.vars)
|
|
||||||
}
|
|
||||||
io.WriteString(c.w, `)`)
|
|
||||||
}
|
|
||||||
|
|
||||||
qcode.FreeExp(val)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
|
@ -874,6 +816,128 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *compilerContext) renderNestedWhere(ex *qcode.Exp, sel *qcode.Select, ti *DBTableInfo) error {
|
||||||
|
|
||||||
|
for i := 0; i < len(ex.NestedCols)-1; i++ {
|
||||||
|
cti, err := c.schema.GetTable(ex.NestedCols[i])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if i != 0 {
|
||||||
|
io.WriteString(c.w, ` AND `)
|
||||||
|
}
|
||||||
|
|
||||||
|
io.WriteString(c.w, `(SELECT 1 FROM `)
|
||||||
|
io.WriteString(c.w, cti.Name)
|
||||||
|
|
||||||
|
if err := c.renderJoinByName(cti.Name, ti.Name, -1); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
io.WriteString(c.w, ` WHERE `)
|
||||||
|
|
||||||
|
if err := c.renderRelationshipByName(cti.Name, ti.Name, -1); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(ex.NestedCols)-1; i++ {
|
||||||
|
io.WriteString(c.w, `)`)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTableInfo) error {
|
||||||
|
if len(ex.Col) != 0 {
|
||||||
|
io.WriteString(c.w, `((`)
|
||||||
|
colWithTable(c.w, ti.Name, ex.Col)
|
||||||
|
io.WriteString(c.w, `) `)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch ex.Op {
|
||||||
|
case qcode.OpEquals:
|
||||||
|
io.WriteString(c.w, `=`)
|
||||||
|
case qcode.OpNotEquals:
|
||||||
|
io.WriteString(c.w, `!=`)
|
||||||
|
case qcode.OpGreaterOrEquals:
|
||||||
|
io.WriteString(c.w, `>=`)
|
||||||
|
case qcode.OpLesserOrEquals:
|
||||||
|
io.WriteString(c.w, `<=`)
|
||||||
|
case qcode.OpGreaterThan:
|
||||||
|
io.WriteString(c.w, `>`)
|
||||||
|
case qcode.OpLesserThan:
|
||||||
|
io.WriteString(c.w, `<`)
|
||||||
|
case qcode.OpIn:
|
||||||
|
io.WriteString(c.w, `IN`)
|
||||||
|
case qcode.OpNotIn:
|
||||||
|
io.WriteString(c.w, `NOT IN`)
|
||||||
|
case qcode.OpLike:
|
||||||
|
io.WriteString(c.w, `LIKE`)
|
||||||
|
case qcode.OpNotLike:
|
||||||
|
io.WriteString(c.w, `NOT LIKE`)
|
||||||
|
case qcode.OpILike:
|
||||||
|
io.WriteString(c.w, `ILIKE`)
|
||||||
|
case qcode.OpNotILike:
|
||||||
|
io.WriteString(c.w, `NOT ILIKE`)
|
||||||
|
case qcode.OpSimilar:
|
||||||
|
io.WriteString(c.w, `SIMILAR TO`)
|
||||||
|
case qcode.OpNotSimilar:
|
||||||
|
io.WriteString(c.w, `NOT SIMILAR TO`)
|
||||||
|
case qcode.OpContains:
|
||||||
|
io.WriteString(c.w, `@>`)
|
||||||
|
case qcode.OpContainedIn:
|
||||||
|
io.WriteString(c.w, `<@`)
|
||||||
|
case qcode.OpHasKey:
|
||||||
|
io.WriteString(c.w, `?`)
|
||||||
|
case qcode.OpHasKeyAny:
|
||||||
|
io.WriteString(c.w, `?|`)
|
||||||
|
case qcode.OpHasKeyAll:
|
||||||
|
io.WriteString(c.w, `?&`)
|
||||||
|
case qcode.OpIsNull:
|
||||||
|
if strings.EqualFold(ex.Val, "true") {
|
||||||
|
io.WriteString(c.w, `IS NULL)`)
|
||||||
|
} else {
|
||||||
|
io.WriteString(c.w, `IS NOT NULL)`)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
case qcode.OpEqID:
|
||||||
|
if len(ti.PrimaryCol) == 0 {
|
||||||
|
return fmt.Errorf("no primary key column defined for %s", ti.Name)
|
||||||
|
}
|
||||||
|
//fmt.Fprintf(w, `(("%s") =`, c.ti.PrimaryCol)
|
||||||
|
io.WriteString(c.w, `((`)
|
||||||
|
colWithTable(c.w, ti.Name, ti.PrimaryCol)
|
||||||
|
//io.WriteString(c.w, ti.PrimaryCol)
|
||||||
|
io.WriteString(c.w, `) =`)
|
||||||
|
case qcode.OpTsQuery:
|
||||||
|
if len(ti.TSVCol) == 0 {
|
||||||
|
return fmt.Errorf("no tsv column defined for %s", ti.Name)
|
||||||
|
}
|
||||||
|
//fmt.Fprintf(w, `(("%s") @@ to_tsquery('%s'))`, c.ti.TSVCol, val.Val)
|
||||||
|
io.WriteString(c.w, `(("`)
|
||||||
|
io.WriteString(c.w, ti.TSVCol)
|
||||||
|
io.WriteString(c.w, `") @@ to_tsquery('`)
|
||||||
|
io.WriteString(c.w, ex.Val)
|
||||||
|
io.WriteString(c.w, `'))`)
|
||||||
|
return nil
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("[Where] unexpected op code %d", ex.Op)
|
||||||
|
}
|
||||||
|
|
||||||
|
if ex.Type == qcode.ValList {
|
||||||
|
c.renderList(ex)
|
||||||
|
} else {
|
||||||
|
c.renderVal(ex, c.vars)
|
||||||
|
}
|
||||||
|
|
||||||
|
io.WriteString(c.w, `)`)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *compilerContext) renderOrderBy(sel *qcode.Select, ti *DBTableInfo) error {
|
func (c *compilerContext) renderOrderBy(sel *qcode.Select, ti *DBTableInfo) error {
|
||||||
io.WriteString(c.w, ` ORDER BY `)
|
io.WriteString(c.w, ` ORDER BY `)
|
||||||
for i := range sel.OrderBy {
|
for i := range sel.OrderBy {
|
||||||
|
|
|
@ -332,6 +332,25 @@ func aggFunctionWithFilter(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func syntheticTables(t *testing.T) {
|
||||||
|
gql := `query {
|
||||||
|
me {
|
||||||
|
email
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT ) AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337"`
|
||||||
|
|
||||||
|
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(resSQL) != sql {
|
||||||
|
t.Fatal(errNotExpected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func queryWithVariables(t *testing.T) {
|
func queryWithVariables(t *testing.T) {
|
||||||
gql := `query {
|
gql := `query {
|
||||||
product(id: $PRODUCT_ID, where: { price: { eq: $PRODUCT_PRICE } }) {
|
product(id: $PRODUCT_ID, where: { price: { eq: $PRODUCT_PRICE } }) {
|
||||||
|
@ -352,14 +371,21 @@ func queryWithVariables(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func syntheticTables(t *testing.T) {
|
func withWhereOnRelations(t *testing.T) {
|
||||||
gql := `query {
|
gql := `query {
|
||||||
me {
|
users(where: {
|
||||||
|
not: {
|
||||||
|
products: {
|
||||||
|
price: { gt: 3 }
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}) {
|
||||||
|
id
|
||||||
email
|
email
|
||||||
}
|
}
|
||||||
}`
|
}`
|
||||||
|
|
||||||
sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT ) AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337"`
|
sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."id" AS "id", "users_0"."email" AS "email") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."id", "users"."email" FROM "users" WHERE (NOT EXISTS (SELECT 1 FROM products WHERE (("products"."user_id") = ("users"."id")))) LIMIT ('20') :: integer) AS "users_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
|
||||||
|
|
||||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -371,50 +397,6 @@ func syntheticTables(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCompileQuery(t *testing.T) {
|
|
||||||
t.Run("withComplexArgs", withComplexArgs)
|
|
||||||
t.Run("withWhereAndList", withWhereAndList)
|
|
||||||
t.Run("withWhereIsNull", withWhereIsNull)
|
|
||||||
t.Run("withWhereMultiOr", withWhereMultiOr)
|
|
||||||
t.Run("fetchByID", fetchByID)
|
|
||||||
t.Run("searchQuery", searchQuery)
|
|
||||||
t.Run("belongsTo", belongsTo)
|
|
||||||
t.Run("oneToMany", oneToMany)
|
|
||||||
t.Run("manyToMany", manyToMany)
|
|
||||||
t.Run("manyToManyReverse", manyToManyReverse)
|
|
||||||
t.Run("aggFunction", aggFunction)
|
|
||||||
t.Run("aggFunctionBlockedByCol", aggFunctionBlockedByCol)
|
|
||||||
t.Run("aggFunctionDisabled", aggFunctionDisabled)
|
|
||||||
t.Run("aggFunctionWithFilter", aggFunctionWithFilter)
|
|
||||||
t.Run("syntheticTables", syntheticTables)
|
|
||||||
t.Run("queryWithVariables", queryWithVariables)
|
|
||||||
t.Run("blockedQuery", blockedQuery)
|
|
||||||
t.Run("blockedFunctions", blockedFunctions)
|
|
||||||
}
|
|
||||||
|
|
||||||
var benchGQL = []byte(`query {
|
|
||||||
proDUcts(
|
|
||||||
# returns only 30 items
|
|
||||||
limit: 30,
|
|
||||||
|
|
||||||
# starts from item 10, commented out for now
|
|
||||||
# offset: 10,
|
|
||||||
|
|
||||||
# orders the response items by highest price
|
|
||||||
order_by: { price: desc },
|
|
||||||
|
|
||||||
# only items with an id >= 30 and < 30 are returned
|
|
||||||
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) {
|
|
||||||
id
|
|
||||||
NAME
|
|
||||||
price
|
|
||||||
user {
|
|
||||||
full_name
|
|
||||||
picture : avatar
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}`)
|
|
||||||
|
|
||||||
func blockedQuery(t *testing.T) {
|
func blockedQuery(t *testing.T) {
|
||||||
gql := `query {
|
gql := `query {
|
||||||
user(id: 5, where: { id: { gt: 3 } }) {
|
user(id: 5, where: { id: { gt: 3 } }) {
|
||||||
|
@ -456,6 +438,51 @@ func blockedFunctions(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCompileQuery(t *testing.T) {
|
||||||
|
t.Run("withComplexArgs", withComplexArgs)
|
||||||
|
t.Run("withWhereAndList", withWhereAndList)
|
||||||
|
t.Run("withWhereIsNull", withWhereIsNull)
|
||||||
|
t.Run("withWhereMultiOr", withWhereMultiOr)
|
||||||
|
t.Run("fetchByID", fetchByID)
|
||||||
|
t.Run("searchQuery", searchQuery)
|
||||||
|
t.Run("belongsTo", belongsTo)
|
||||||
|
t.Run("oneToMany", oneToMany)
|
||||||
|
t.Run("manyToMany", manyToMany)
|
||||||
|
t.Run("manyToManyReverse", manyToManyReverse)
|
||||||
|
t.Run("aggFunction", aggFunction)
|
||||||
|
t.Run("aggFunctionBlockedByCol", aggFunctionBlockedByCol)
|
||||||
|
t.Run("aggFunctionDisabled", aggFunctionDisabled)
|
||||||
|
t.Run("aggFunctionWithFilter", aggFunctionWithFilter)
|
||||||
|
t.Run("syntheticTables", syntheticTables)
|
||||||
|
t.Run("queryWithVariables", queryWithVariables)
|
||||||
|
t.Run("withWhereOnRelations", withWhereOnRelations)
|
||||||
|
t.Run("blockedQuery", blockedQuery)
|
||||||
|
t.Run("blockedFunctions", blockedFunctions)
|
||||||
|
}
|
||||||
|
|
||||||
|
var benchGQL = []byte(`query {
|
||||||
|
proDUcts(
|
||||||
|
# returns only 30 items
|
||||||
|
limit: 30,
|
||||||
|
|
||||||
|
# starts from item 10, commented out for now
|
||||||
|
# offset: 10,
|
||||||
|
|
||||||
|
# orders the response items by highest price
|
||||||
|
order_by: { price: desc },
|
||||||
|
|
||||||
|
# only items with an id >= 30 and < 30 are returned
|
||||||
|
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) {
|
||||||
|
id
|
||||||
|
NAME
|
||||||
|
price
|
||||||
|
user {
|
||||||
|
full_name
|
||||||
|
picture : avatar
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}`)
|
||||||
|
|
||||||
func BenchmarkCompile(b *testing.B) {
|
func BenchmarkCompile(b *testing.B) {
|
||||||
w := &bytes.Buffer{}
|
w := &bytes.Buffer{}
|
||||||
|
|
||||||
|
|
|
@ -57,16 +57,16 @@ type Column struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type Exp struct {
|
type Exp struct {
|
||||||
Op ExpOp
|
Op ExpOp
|
||||||
Col string
|
Col string
|
||||||
NestedCol bool
|
NestedCols []string
|
||||||
Type ValType
|
Type ValType
|
||||||
Val string
|
Val string
|
||||||
ListType ValType
|
ListType ValType
|
||||||
ListVal []string
|
ListVal []string
|
||||||
Children []*Exp
|
Children []*Exp
|
||||||
childrenA [5]*Exp
|
childrenA [5]*Exp
|
||||||
doFree bool
|
doFree bool
|
||||||
}
|
}
|
||||||
|
|
||||||
var zeroExp = Exp{doFree: true}
|
var zeroExp = Exp{doFree: true}
|
||||||
|
@ -918,10 +918,8 @@ func setWhereColName(ex *Exp, node *Node) {
|
||||||
}
|
}
|
||||||
if len(list) == 1 {
|
if len(list) == 1 {
|
||||||
ex.Col = list[0]
|
ex.Col = list[0]
|
||||||
|
} else if len(list) > 1 {
|
||||||
} else if len(list) > 2 {
|
ex.NestedCols = list
|
||||||
ex.Col = buildPath(list)
|
|
||||||
ex.NestedCol = true
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -8,19 +8,19 @@ import (
|
||||||
"github.com/dosco/super-graph/jsn"
|
"github.com/dosco/super-graph/jsn"
|
||||||
)
|
)
|
||||||
|
|
||||||
func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
|
func argMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
|
||||||
return func(w io.Writer, tag string) (int, error) {
|
return func(w io.Writer, tag string) (int, error) {
|
||||||
switch tag {
|
switch tag {
|
||||||
case "user_id_provider":
|
case "user_id_provider":
|
||||||
if v := ctx.Value(userIDProviderKey); v != nil {
|
if v := ctx.Value(userIDProviderKey); v != nil {
|
||||||
return stringVar(w, v.(string))
|
return stringArg(w, v.(string))
|
||||||
}
|
}
|
||||||
io.WriteString(w, "null")
|
io.WriteString(w, "null")
|
||||||
return 0, nil
|
return 0, nil
|
||||||
|
|
||||||
case "user_id":
|
case "user_id":
|
||||||
if v := ctx.Value(userIDKey); v != nil {
|
if v := ctx.Value(userIDKey); v != nil {
|
||||||
return stringVar(w, v.(string))
|
return stringArg(w, v.(string))
|
||||||
}
|
}
|
||||||
|
|
||||||
io.WriteString(w, "null")
|
io.WriteString(w, "null")
|
||||||
|
@ -28,7 +28,7 @@ func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
|
||||||
|
|
||||||
case "user_role":
|
case "user_role":
|
||||||
if v := ctx.Value(userRoleKey); v != nil {
|
if v := ctx.Value(userRoleKey); v != nil {
|
||||||
return stringVar(w, v.(string))
|
return stringArg(w, v.(string))
|
||||||
}
|
}
|
||||||
io.WriteString(w, "null")
|
io.WriteString(w, "null")
|
||||||
return 0, nil
|
return 0, nil
|
||||||
|
@ -50,7 +50,7 @@ func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
if is {
|
if is {
|
||||||
return stringVarB(w, fields[0].Value)
|
return stringArgB(w, fields[0].Value)
|
||||||
}
|
}
|
||||||
|
|
||||||
w.Write(fields[0].Value)
|
w.Write(fields[0].Value)
|
||||||
|
@ -58,7 +58,7 @@ func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func varList(ctx *coreContext, args [][]byte) []interface{} {
|
func argList(ctx *coreContext, args [][]byte) []interface{} {
|
||||||
vars := make([]interface{}, len(args))
|
vars := make([]interface{}, len(args))
|
||||||
|
|
||||||
var fields map[string]interface{}
|
var fields map[string]interface{}
|
||||||
|
@ -86,6 +86,11 @@ func varList(ctx *coreContext, args [][]byte) []interface{} {
|
||||||
vars[i] = v.(string)
|
vars[i] = v.(string)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
case bytes.Equal(av, []byte("user_role")):
|
||||||
|
if v := ctx.Value(userRoleKey); v != nil {
|
||||||
|
vars[i] = v.(string)
|
||||||
|
}
|
||||||
|
|
||||||
default:
|
default:
|
||||||
if v, ok := fields[string(av)]; ok {
|
if v, ok := fields[string(av)]; ok {
|
||||||
vars[i] = v
|
vars[i] = v
|
||||||
|
@ -96,7 +101,7 @@ func varList(ctx *coreContext, args [][]byte) []interface{} {
|
||||||
return vars
|
return vars
|
||||||
}
|
}
|
||||||
|
|
||||||
func stringVar(w io.Writer, v string) (int, error) {
|
func stringArg(w io.Writer, v string) (int, error) {
|
||||||
if n, err := w.Write([]byte(`'`)); err != nil {
|
if n, err := w.Write([]byte(`'`)); err != nil {
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
||||||
|
@ -106,7 +111,7 @@ func stringVar(w io.Writer, v string) (int, error) {
|
||||||
return w.Write([]byte(`'`))
|
return w.Write([]byte(`'`))
|
||||||
}
|
}
|
||||||
|
|
||||||
func stringVarB(w io.Writer, v []byte) (int, error) {
|
func stringArgB(w io.Writer, v []byte) (int, error) {
|
||||||
if n, err := w.Write([]byte(`'`)); err != nil {
|
if n, err := w.Write([]byte(`'`)); err != nil {
|
||||||
return n, err
|
return n, err
|
||||||
}
|
}
|
|
@ -66,6 +66,7 @@ type config struct {
|
||||||
PoolSize int32 `mapstructure:"pool_size"`
|
PoolSize int32 `mapstructure:"pool_size"`
|
||||||
MaxRetries int `mapstructure:"max_retries"`
|
MaxRetries int `mapstructure:"max_retries"`
|
||||||
LogLevel string `mapstructure:"log_level"`
|
LogLevel string `mapstructure:"log_level"`
|
||||||
|
SetUserID bool `mapstructure:"set_user_id"`
|
||||||
|
|
||||||
Vars map[string]string `mapstructure:"variables"`
|
Vars map[string]string `mapstructure:"variables"`
|
||||||
|
|
||||||
|
|
25
serv/core.go
25
serv/core.go
|
@ -122,10 +122,8 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
|
||||||
}
|
}
|
||||||
defer tx.Rollback(c)
|
defer tx.Rollback(c)
|
||||||
|
|
||||||
if v := c.Value(userIDKey); v != nil {
|
if conf.DB.SetUserID {
|
||||||
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
|
if err := c.setLocalUserID(tx); err != nil {
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -153,7 +151,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
var root []byte
|
var root []byte
|
||||||
vars := varList(c, ps.args)
|
vars := argList(c, ps.args)
|
||||||
|
|
||||||
if mutation {
|
if mutation {
|
||||||
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
|
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
|
||||||
|
@ -206,7 +204,7 @@ func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
|
||||||
t := fasttemplate.New(st.sql, openVar, closeVar)
|
t := fasttemplate.New(st.sql, openVar, closeVar)
|
||||||
|
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
_, err = t.ExecuteFunc(buf, varMap(c))
|
_, err = t.ExecuteFunc(buf, argMap(c))
|
||||||
|
|
||||||
if err == errNoUserID {
|
if err == errNoUserID {
|
||||||
logger.Warn().Msg("no user id found. query requires an authenicated request")
|
logger.Warn().Msg("no user id found. query requires an authenicated request")
|
||||||
|
@ -224,10 +222,8 @@ func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
|
||||||
stime = time.Now()
|
stime = time.Now()
|
||||||
}
|
}
|
||||||
|
|
||||||
if v := c.Value(userIDKey); v != nil {
|
if conf.DB.SetUserID {
|
||||||
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
|
if err := c.setLocalUserID(tx); err != nil {
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -425,6 +421,15 @@ func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
|
||||||
return role, nil
|
return role, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *coreContext) setLocalUserID(tx pgx.Tx) error {
|
||||||
|
var err error
|
||||||
|
if v := c.Value(userIDKey); v != nil {
|
||||||
|
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
|
||||||
|
}
|
||||||
|
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
func (c *coreContext) render(w io.Writer, data []byte) error {
|
func (c *coreContext) render(w io.Writer, data []byte) error {
|
||||||
c.res.Data = json.RawMessage(data)
|
c.res.Data = json.RawMessage(data)
|
||||||
return json.NewEncoder(w).Encode(c.res)
|
return json.NewEncoder(w).Encode(c.res)
|
||||||
|
|
|
@ -51,7 +51,8 @@ auth:
|
||||||
cookie: _{{app_name_slug}}_session
|
cookie: _{{app_name_slug}}_session
|
||||||
|
|
||||||
# Comment this out if you want to disable setting
|
# Comment this out if you want to disable setting
|
||||||
# the user_id via a header. Good for testing
|
# the user_id via a header for testing.
|
||||||
|
# Disable in production
|
||||||
creds_in_header: true
|
creds_in_header: true
|
||||||
|
|
||||||
rails:
|
rails:
|
||||||
|
@ -92,6 +93,10 @@ database:
|
||||||
#max_retries: 0
|
#max_retries: 0
|
||||||
#log_level: "debug"
|
#log_level: "debug"
|
||||||
|
|
||||||
|
# Set session variable "user.id" to the user id
|
||||||
|
# Enable this if you need the user id in triggers, etc
|
||||||
|
set_user_id: false
|
||||||
|
|
||||||
# Define variables here that you want to use in filters
|
# Define variables here that you want to use in filters
|
||||||
# sub-queries must be wrapped in ()
|
# sub-queries must be wrapped in ()
|
||||||
variables:
|
variables:
|
||||||
|
@ -131,7 +136,7 @@ tables:
|
||||||
name: me
|
name: me
|
||||||
table: users
|
table: users
|
||||||
|
|
||||||
roles_query: "SELECT * FROM users as usr WHERE id = $user_id"
|
roles_query: "SELECT * FROM users WHERE id = $user_id"
|
||||||
|
|
||||||
roles:
|
roles:
|
||||||
- name: anon
|
- name: anon
|
||||||
|
|
|
@ -85,3 +85,7 @@ database:
|
||||||
#pool_size: 10
|
#pool_size: 10
|
||||||
#max_retries: 0
|
#max_retries: 0
|
||||||
#log_level: "debug"
|
#log_level: "debug"
|
||||||
|
|
||||||
|
# Set session variable "user.id" to the user id
|
||||||
|
# Enable this if you need the user id in triggers, etc
|
||||||
|
set_user_id: false
|
Loading…
Reference in New Issue