Optimize db queries limit use of transactions

This commit is contained in:
Vikram Rangnekar 2019-11-21 02:14:12 -05:00
parent 176514b5f1
commit a4c09dedd5
18 changed files with 298 additions and 196 deletions

View File

@ -287,4 +287,18 @@ query {
} }
} }
variables {
"data": {
"name": "WOOO",
"price": 50.5
}
}
mutation {
products(insert: $data) {
id
name
}
}

View File

@ -97,10 +97,9 @@ database:
# Enable this if you need the user id in triggers, etc # Enable this if you need the user id in triggers, etc
set_user_id: false set_user_id: false
# Define variables here that you want to use in filters # Define additional variables here to be used with filters
# sub-queries must be wrapped in ()
variables: variables:
account_id: "(select account_id from users where id = $user_id)" admin_account_id: "5"
# Define defaults to for the field key and values below # Define defaults to for the field key and values below
defaults: defaults:
@ -174,8 +173,10 @@ roles:
filters: ["{ user_id: { eq: $user_id } }"] filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ] columns: ["id", "name", "description" ]
presets: presets:
- user_id: "$user_id"
- created_at: "now" - created_at: "now"
- updated_at: "now"
update: update:
filters: ["{ user_id: { eq: $user_id } }"] filters: ["{ user_id: { eq: $user_id } }"]
columns: columns:
@ -188,8 +189,7 @@ roles:
block: true block: true
- name: admin - name: admin
match: id = 1 match: id = 1000
tables: tables:
- name: users - name: users
# query: filters: []
# filters: ["{ account_id: { _eq: $account_id } }"]

View File

@ -46,10 +46,10 @@ for (i = 0; i < product_count; i++) {
var data = { var data = {
name: fake.beer_name(), name: fake.beer_name(),
description: desc, description: desc,
price: fake.price(), price: fake.price()
user_id: user.id, //user_id: user.id,
created_at: "now", //created_at: "now",
updated_at: "now" //updated_at: "now"
} }
var res = graphql(" \ var res = graphql(" \
@ -57,7 +57,9 @@ for (i = 0; i < product_count; i++) {
product(insert: $data) { \ product(insert: $data) { \
id \ id \
} \ } \
}", { data: data }) }", { data: data }, {
user_id: 5
})
products.push(res.product) products.push(res.product)
} }

View File

@ -276,7 +276,7 @@ transmission_gear_type
// Text // Text
word word
sentence sentence
paragrph paragraph
question question
quote quote
@ -1083,25 +1083,6 @@ must be run to help figure out a users role. This query can be as complex as you
The individual roles are defined under the `roles` parameter and this includes each table the role has a custom setting for. The role is dynamically matched using the `match` parameter for example in the above case `users.id = 1` means that when the `roles_query` is executed a user with the id `1` willbe assigned the admin role and those that don't match get the `user` role if authenticated successfully or the `anon` role. The individual roles are defined under the `roles` parameter and this includes each table the role has a custom setting for. The role is dynamically matched using the `match` parameter for example in the above case `users.id = 1` means that when the `roles_query` is executed a user with the id `1` willbe assigned the admin role and those that don't match get the `user` role if authenticated successfully or the `anon` role.
This below example would work for SAAS apps where an account (tenant) is usually the top parent table to everything else.
```yaml
roles_query: "SELECT * FROM users JOIN accounts on accounts.id = users.account_id WHERE users.id = $user_id"
roles:
- name: user
tables:
- name: users
...
- name: admin
match: accounts.admin_id = $user_id
tables:
- name: users
query:
filters: [{ accounts: { id: { eq: $account_id } } }]
```
## Remote Joins ## Remote Joins
It often happens that after fetching some data from the DB we need to call another API to fetch some more data and all this combined into a single JSON response. For example along with a list of users you need their last 5 payments from Stripe. This requires you to query your DB for the users and Stripe for the payments. Super Graph handles all this for you also only the fields you requested from the Stripe API are returned. It often happens that after fetching some data from the DB we need to call another API to fetch some more data and all this combined into a single JSON response. For example along with a list of users you need their last 5 payments from Stripe. This requires you to query your DB for the users and Stripe for the payments. Super Graph handles all this for you also only the fields you requested from the Stripe API are returned.
@ -1290,10 +1271,9 @@ database:
# Enable this if you need the user id in triggers, etc # Enable this if you need the user id in triggers, etc
set_user_id: false set_user_id: false
# Define variables here that you want to use in filters # Define additional variables here to be used with filters
# sub-queries must be wrapped in ()
variables: variables:
account_id: "(select account_id from users where id = $user_id)" admin_account_id: "5"
# Define defaults to for the field key and values below # Define defaults to for the field key and values below
defaults: defaults:
@ -1381,11 +1361,10 @@ roles:
deny: true deny: true
- name: admin - name: admin
match: id = 1 match: id = 1000
tables: tables:
- name: users - name: users
# query: filters: []
# filters: ["{ account_id: { _eq: $account_id } }"]
``` ```

View File

@ -77,9 +77,9 @@ func (c *compilerContext) renderInsert(qc *qcode.QCode, w io.Writer,
return 0, err return 0, err
} }
io.WriteString(c.w, `(WITH "input" AS (SELECT {{`) io.WriteString(c.w, `(WITH "input" AS (SELECT '{{`)
io.WriteString(c.w, qc.ActionVar) io.WriteString(c.w, qc.ActionVar)
io.WriteString(c.w, `}}::json AS j) INSERT INTO `) io.WriteString(c.w, `}}' :: json AS j) INSERT INTO `)
quoted(c.w, ti.Name) quoted(c.w, ti.Name)
io.WriteString(c.w, ` (`) io.WriteString(c.w, ` (`)
c.renderInsertUpdateColumns(qc, w, jt, ti, false) c.renderInsertUpdateColumns(qc, w, jt, ti, false)
@ -174,9 +174,9 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w io.Writer,
return 0, err return 0, err
} }
io.WriteString(c.w, `(WITH "input" AS (SELECT {{`) io.WriteString(c.w, `(WITH "input" AS (SELECT '{{`)
io.WriteString(c.w, qc.ActionVar) io.WriteString(c.w, qc.ActionVar)
io.WriteString(c.w, `}}::json AS j) UPDATE `) io.WriteString(c.w, `}}' :: json AS j) UPDATE `)
quoted(c.w, ti.Name) quoted(c.w, ti.Name)
io.WriteString(c.w, ` SET (`) io.WriteString(c.w, ` SET (`)
c.renderInsertUpdateColumns(qc, w, jt, ti, false) c.renderInsertUpdateColumns(qc, w, jt, ti, false)

View File

@ -12,7 +12,7 @@ func simpleInsert(t *testing.T) {
} }
}` }`
sql := `WITH "users" AS (WITH "input" AS (SELECT {{data}}::json AS j) INSERT INTO "users" ("full_name", "email") SELECT "full_name", "email" FROM input i, json_populate_record(NULL::users, i.j) t RETURNING *) SELECT json_object_agg('user', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "users_0"."id" AS "id") AS "json_row_0")) AS "json_0" FROM (SELECT "users"."id" FROM "users" LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "sel_0"` sql := `WITH "users" AS (WITH "input" AS (SELECT '{{data}}' :: json AS j) INSERT INTO "users" ("full_name", "email") SELECT "full_name", "email" FROM input i, json_populate_record(NULL::users, i.j) t RETURNING *) SELECT json_object_agg('user', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "users_0"."id" AS "id") AS "json_row_0")) AS "json_0" FROM (SELECT "users"."id" FROM "users" LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "sel_0"`
vars := map[string]json.RawMessage{ vars := map[string]json.RawMessage{
"data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`), "data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`),
@ -36,7 +36,7 @@ func singleInsert(t *testing.T) {
} }
}` }`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{insert}}::json AS j) INSERT INTO "products" ("name", "description", "user_id") SELECT "name", "description", "user_id" FROM input i, json_populate_record(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"` sql := `WITH "products" AS (WITH "input" AS (SELECT '{{insert}}' :: json AS j) INSERT INTO "products" ("name", "description", "user_id") SELECT "name", "description", "user_id" FROM input i, json_populate_record(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
vars := map[string]json.RawMessage{ vars := map[string]json.RawMessage{
"insert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc", "user_id": 5 }`), "insert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc", "user_id": 5 }`),
@ -60,7 +60,7 @@ func bulkInsert(t *testing.T) {
} }
}` }`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{insert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_recordset(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"` sql := `WITH "products" AS (WITH "input" AS (SELECT '{{insert}}' :: json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_recordset(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
vars := map[string]json.RawMessage{ vars := map[string]json.RawMessage{
"insert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`), "insert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
@ -84,7 +84,7 @@ func singleUpsert(t *testing.T) {
} }
}` }`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"` sql := `WITH "products" AS (WITH "input" AS (SELECT '{{upsert}}' :: json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
vars := map[string]json.RawMessage{ vars := map[string]json.RawMessage{
"upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), "upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
@ -108,7 +108,7 @@ func singleUpsertWhere(t *testing.T) {
} }
}` }`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT (id) WHERE (("products"."price") > 3) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"` sql := `WITH "products" AS (WITH "input" AS (SELECT '{{upsert}}' :: json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT (id) WHERE (("products"."price") > 3) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
vars := map[string]json.RawMessage{ vars := map[string]json.RawMessage{
"upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), "upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
@ -132,7 +132,7 @@ func bulkUpsert(t *testing.T) {
} }
}` }`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_recordset(NULL::products, i.j) t ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"` sql := `WITH "products" AS (WITH "input" AS (SELECT '{{upsert}}' :: json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_recordset(NULL::products, i.j) t ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
vars := map[string]json.RawMessage{ vars := map[string]json.RawMessage{
"upsert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`), "upsert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
@ -156,7 +156,7 @@ func singleUpdate(t *testing.T) {
} }
}` }`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{update}}::json AS j) UPDATE "products" SET ("name", "description") = (SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t) WHERE (("products"."id") = 1) AND (("products"."id") = 15) RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"` sql := `WITH "products" AS (WITH "input" AS (SELECT '{{update}}' :: json AS j) UPDATE "products" SET ("name", "description") = (SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t) WHERE (("products"."id") = 1) AND (("products"."id") = 15) RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
vars := map[string]json.RawMessage{ vars := map[string]json.RawMessage{
"update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), "update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
@ -203,7 +203,7 @@ func blockedInsert(t *testing.T) {
} }
}` }`
sql := `WITH "users" AS (WITH "input" AS (SELECT {{data}}::json AS j) INSERT INTO "users" ("full_name", "email") SELECT "full_name", "email" FROM input i, json_populate_record(NULL::users, i.j) t WHERE false RETURNING *) SELECT json_object_agg('user', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "users_0"."id" AS "id") AS "json_row_0")) AS "json_0" FROM (SELECT "users"."id" FROM "users" LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "sel_0"` sql := `WITH "users" AS (WITH "input" AS (SELECT '{{data}}' :: json AS j) INSERT INTO "users" ("full_name", "email") SELECT "full_name", "email" FROM input i, json_populate_record(NULL::users, i.j) t WHERE false RETURNING *) SELECT json_object_agg('user', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "users_0"."id" AS "id") AS "json_row_0")) AS "json_0" FROM (SELECT "users"."id" FROM "users" LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "sel_0"`
vars := map[string]json.RawMessage{ vars := map[string]json.RawMessage{
"data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`), "data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`),
@ -227,7 +227,7 @@ func blockedUpdate(t *testing.T) {
} }
}` }`
sql := `WITH "users" AS (WITH "input" AS (SELECT {{data}}::json AS j) UPDATE "users" SET ("full_name", "email") = (SELECT "full_name", "email" FROM input i, json_populate_record(NULL::users, i.j) t) WHERE false RETURNING *) SELECT json_object_agg('user', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "users_0"."id" AS "id", "users_0"."email" AS "email") AS "json_row_0")) AS "json_0" FROM (SELECT "users"."id", "users"."email" FROM "users" LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "sel_0"` sql := `WITH "users" AS (WITH "input" AS (SELECT '{{data}}' :: json AS j) UPDATE "users" SET ("full_name", "email") = (SELECT "full_name", "email" FROM input i, json_populate_record(NULL::users, i.j) t) WHERE false RETURNING *) SELECT json_object_agg('user', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "users_0"."id" AS "id", "users_0"."email" AS "email") AS "json_row_0")) AS "json_0" FROM (SELECT "users"."id", "users"."email" FROM "users" LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "sel_0"`
vars := map[string]json.RawMessage{ vars := map[string]json.RawMessage{
"data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`), "data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`),
@ -250,7 +250,7 @@ func simpleInsertWithPresets(t *testing.T) {
} }
}` }`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{data}}::json AS j) INSERT INTO "products" ("name", "price", "created_at", "updated_at", "user_id") SELECT "name", "price", 'now' :: timestamp without time zone, 'now' :: timestamp without time zone, '$user_id' :: bigint FROM input i, json_populate_record(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"` sql := `WITH "products" AS (WITH "input" AS (SELECT '{{data}}' :: json AS j) INSERT INTO "products" ("name", "price", "created_at", "updated_at", "user_id") SELECT "name", "price", 'now' :: timestamp without time zone, 'now' :: timestamp without time zone, '{{user_id}}' :: bigint FROM input i, json_populate_record(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
vars := map[string]json.RawMessage{ vars := map[string]json.RawMessage{
"data": json.RawMessage(`{"name": "Tomato", "price": 5.76}`), "data": json.RawMessage(`{"name": "Tomato", "price": 5.76}`),
@ -273,7 +273,7 @@ func simpleUpdateWithPresets(t *testing.T) {
} }
}` }`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{data}}::json AS j) UPDATE "products" SET ("name", "price", "updated_at") = (SELECT "name", "price", 'now' :: timestamp without time zone FROM input i, json_populate_record(NULL::products, i.j) t) WHERE (("products"."user_id") = {{user_id}}) RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"` sql := `WITH "products" AS (WITH "input" AS (SELECT '{{data}}' :: json AS j) UPDATE "products" SET ("name", "price", "updated_at") = (SELECT "name", "price", 'now' :: timestamp without time zone FROM input i, json_populate_record(NULL::products, i.j) t) WHERE (("products"."user_id") = '{{user_id}}' :: bigint) RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
vars := map[string]json.RawMessage{ vars := map[string]json.RawMessage{
"data": json.RawMessage(`{"name": "Apple", "price": 1.25}`), "data": json.RawMessage(`{"name": "Apple", "price": 1.25}`),

View File

@ -171,7 +171,7 @@ func TestMain(m *testing.M) {
} }
vars := NewVariables(map[string]string{ vars := NewVariables(map[string]string{
"account_id": "select account_id from users where id = $user_id", "admin_account_id": "5",
}) })
pcompile = NewCompiler(Config{ pcompile = NewCompiler(Config{

View File

@ -844,7 +844,6 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
} }
func (c *compilerContext) renderNestedWhere(ex *qcode.Exp, sel *qcode.Select, ti *DBTableInfo) error { func (c *compilerContext) renderNestedWhere(ex *qcode.Exp, sel *qcode.Select, ti *DBTableInfo) error {
for i := 0; i < len(ex.NestedCols)-1; i++ { for i := 0; i < len(ex.NestedCols)-1; i++ {
cti, err := c.schema.GetTable(ex.NestedCols[i]) cti, err := c.schema.GetTable(ex.NestedCols[i])
if err != nil { if err != nil {
@ -878,7 +877,14 @@ func (c *compilerContext) renderNestedWhere(ex *qcode.Exp, sel *qcode.Select, ti
} }
func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTableInfo) error { func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTableInfo) error {
var col *DBColumn
var ok bool
if len(ex.Col) != 0 { if len(ex.Col) != 0 {
if col, ok = ti.Columns[ex.Col]; !ok {
return fmt.Errorf("no column '%s' found ", ex.Col)
}
io.WriteString(c.w, `((`) io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, ex.Col) colWithTable(c.w, ti.Name, ex.Col)
io.WriteString(c.w, `) `) io.WriteString(c.w, `) `)
@ -934,6 +940,9 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
if len(ti.PrimaryCol) == 0 { if len(ti.PrimaryCol) == 0 {
return fmt.Errorf("no primary key column defined for %s", ti.Name) return fmt.Errorf("no primary key column defined for %s", ti.Name)
} }
if col, ok = ti.Columns[ti.PrimaryCol]; !ok {
return fmt.Errorf("no primary key column '%s' found ", ti.PrimaryCol)
}
//fmt.Fprintf(w, `(("%s") =`, c.ti.PrimaryCol) //fmt.Fprintf(w, `(("%s") =`, c.ti.PrimaryCol)
io.WriteString(c.w, `((`) io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, ti.PrimaryCol) colWithTable(c.w, ti.Name, ti.PrimaryCol)
@ -943,6 +952,9 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
if len(ti.TSVCol) == 0 { if len(ti.TSVCol) == 0 {
return fmt.Errorf("no tsv column defined for %s", ti.Name) return fmt.Errorf("no tsv column defined for %s", ti.Name)
} }
if col, ok = ti.Columns[ti.TSVCol]; !ok {
return fmt.Errorf("no tsv column '%s' found ", ti.TSVCol)
}
//fmt.Fprintf(w, `(("%s") @@ to_tsquery('%s'))`, c.ti.TSVCol, val.Val) //fmt.Fprintf(w, `(("%s") @@ to_tsquery('%s'))`, c.ti.TSVCol, val.Val)
io.WriteString(c.w, `(("`) io.WriteString(c.w, `(("`)
io.WriteString(c.w, ti.TSVCol) io.WriteString(c.w, ti.TSVCol)
@ -958,7 +970,7 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
if ex.Type == qcode.ValList { if ex.Type == qcode.ValList {
c.renderList(ex) c.renderList(ex)
} else { } else {
c.renderVal(ex, c.vars) c.renderVal(ex, c.vars, col)
} }
io.WriteString(c.w, `)`) io.WriteString(c.w, `)`)
@ -1035,7 +1047,7 @@ func (c *compilerContext) renderList(ex *qcode.Exp) {
io.WriteString(c.w, `)`) io.WriteString(c.w, `)`)
} }
func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) { func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *DBColumn) {
io.WriteString(c.w, ` `) io.WriteString(c.w, ` `)
switch ex.Type { switch ex.Type {
@ -1052,6 +1064,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) {
io.WriteString(c.w, `'`) io.WriteString(c.w, `'`)
case qcode.ValVar: case qcode.ValVar:
io.WriteString(c.w, `'`)
if val, ok := vars[ex.Val]; ok { if val, ok := vars[ex.Val]; ok {
io.WriteString(c.w, val) io.WriteString(c.w, val)
} else { } else {
@ -1060,6 +1073,8 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) {
io.WriteString(c.w, ex.Val) io.WriteString(c.w, ex.Val)
io.WriteString(c.w, `}}`) io.WriteString(c.w, `}}`)
} }
io.WriteString(c.w, `' :: `)
io.WriteString(c.w, col.Type)
} }
//io.WriteString(c.w, `)`) //io.WriteString(c.w, `)`)
} }

View File

@ -339,7 +339,7 @@ func syntheticTables(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('me', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT ) AS "json_row_0")) AS "json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "sel_0"` sql := `SELECT json_object_agg('me', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT ) AS "json_row_0")) AS "json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = '{{user_id}}' :: bigint)) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "sel_0"`
resSQL, err := compileGQLToPSQL(gql, nil, "user") resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
@ -359,7 +359,7 @@ func queryWithVariables(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") = {{product_price}}) AND (("products"."id") = {{product_id}})) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"` sql := `SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") = '{{product_price}}' :: numeric(7,2)) AND (("products"."id") = '{{product_id}}' :: bigint)) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
resSQL, err := compileGQLToPSQL(gql, nil, "user") resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {

View File

@ -1,6 +1,7 @@
package qcode package qcode
import ( import (
"regexp"
"sort" "sort"
"strings" "strings"
) )
@ -121,3 +122,12 @@ func mapToList(m map[string]string) []string {
sort.Strings(list) sort.Strings(list)
return list return list
} }
var varRe = regexp.MustCompile(`\$([a-zA-Z0-9_]+)`)
func parsePresets(m map[string]string) map[string]string {
for k, v := range m {
m[k] = varRe.ReplaceAllString(v, `{{$1}}`)
}
return m
}

View File

@ -202,7 +202,7 @@ func (com *Compiler) AddRole(role, table string, trc TRConfig) error {
return err return err
} }
trv.insert.cols = listToMap(trc.Insert.Columns) trv.insert.cols = listToMap(trc.Insert.Columns)
trv.insert.psmap = trc.Insert.Presets trv.insert.psmap = parsePresets(trc.Insert.Presets)
trv.insert.pslist = mapToList(trv.insert.psmap) trv.insert.pslist = mapToList(trv.insert.psmap)
// update config // update config
@ -210,7 +210,7 @@ func (com *Compiler) AddRole(role, table string, trc TRConfig) error {
return err return err
} }
trv.update.cols = listToMap(trc.Update.Columns) trv.update.cols = listToMap(trc.Update.Columns)
trv.update.psmap = trc.Update.Presets trv.update.psmap = parsePresets(trc.Update.Presets)
trv.update.pslist = mapToList(trv.update.psmap) trv.update.pslist = mapToList(trv.update.psmap)
// delete config // delete config

View File

@ -13,25 +13,21 @@ func argMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
switch tag { switch tag {
case "user_id_provider": case "user_id_provider":
if v := ctx.Value(userIDProviderKey); v != nil { if v := ctx.Value(userIDProviderKey); v != nil {
return stringArg(w, v.(string)) return io.WriteString(w, v.(string))
} }
io.WriteString(w, "null") return io.WriteString(w, "null")
return 0, nil
case "user_id": case "user_id":
if v := ctx.Value(userIDKey); v != nil { if v := ctx.Value(userIDKey); v != nil {
return stringArg(w, v.(string)) return io.WriteString(w, v.(string))
} }
return io.WriteString(w, "null")
io.WriteString(w, "null")
return 0, nil
case "user_role": case "user_role":
if v := ctx.Value(userRoleKey); v != nil { if v := ctx.Value(userRoleKey); v != nil {
return stringArg(w, v.(string)) return io.WriteString(w, v.(string))
} }
io.WriteString(w, "null") return io.WriteString(w, "null")
return 0, nil
} }
fields := jsn.Get(ctx.req.Vars, [][]byte{[]byte(tag)}) fields := jsn.Get(ctx.req.Vars, [][]byte{[]byte(tag)})
@ -39,22 +35,7 @@ func argMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
return 0, fmt.Errorf("variable '%s' not found", tag) return 0, fmt.Errorf("variable '%s' not found", tag)
} }
is := false return w.Write(fields[0].Value)
for i := range fields[0].Value {
c := fields[0].Value[i]
if c != ' ' {
is = (c == '"') || (c == '{') || (c == '[')
break
}
}
if is {
return stringArgB(w, fields[0].Value)
}
w.Write(fields[0].Value)
return 0, nil
} }
} }
@ -100,23 +81,3 @@ func argList(ctx *coreContext, args [][]byte) []interface{} {
return vars return vars
} }
func stringArg(w io.Writer, v string) (int, error) {
if n, err := w.Write([]byte(`'`)); err != nil {
return n, err
}
if n, err := w.Write([]byte(v)); err != nil {
return n, err
}
return w.Write([]byte(`'`))
}
func stringArgB(w io.Writer, v []byte) (int, error) {
if n, err := w.Write([]byte(`'`)); err != nil {
return n, err
}
if n, err := w.Write(v); err != nil {
return n, err
}
return w.Write([]byte(`'`))
}

View File

@ -1,6 +1,7 @@
package serv package serv
import ( import (
"bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
@ -12,6 +13,7 @@ import (
"github.com/brianvoe/gofakeit" "github.com/brianvoe/gofakeit"
"github.com/dop251/goja" "github.com/dop251/goja"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"github.com/valyala/fasttemplate"
) )
func cmdDBSeed(cmd *cobra.Command, args []string) { func cmdDBSeed(cmd *cobra.Command, args []string) {
@ -57,22 +59,75 @@ func cmdDBSeed(cmd *cobra.Command, args []string) {
} }
//func runFunc(call goja.FunctionCall) { //func runFunc(call goja.FunctionCall) {
func graphQLFunc(query string, data interface{}) map[string]interface{} { func graphQLFunc(query string, data interface{}, opt map[string]string) map[string]interface{} {
b, err := json.Marshal(data) b, err := json.Marshal(data)
if err != nil { if err != nil {
logger.Fatal().Err(err).Msg("failed to json serialize") logger.Fatal().Err(err).Msg("failed to json serialize")
} }
c := &coreContext{Context: context.Background()} ctx := context.Background()
if v, ok := opt["user_id"]; ok && len(v) != 0 {
ctx = context.WithValue(ctx, userIDKey, v)
}
var role string
if v, ok := opt["role"]; ok && len(v) != 0 {
role = v
} else {
role = "user"
}
c := &coreContext{Context: ctx}
c.req.Query = query c.req.Query = query
c.req.Vars = b c.req.Vars = b
c.req.role = "user"
res, err := c.execQuery() st, err := c.buildStmtByRole(role)
if err != nil { if err != nil {
logger.Fatal().Err(err).Msg("graphql query failed") logger.Fatal().Err(err).Msg("graphql query failed")
} }
buf := &bytes.Buffer{}
t := fasttemplate.New(st.sql, openVar, closeVar)
_, err = t.ExecuteFunc(buf, argMap(c))
if err == errNoUserID {
logger.Fatal().Msg("query requires a user_id")
}
if err != nil {
logger.Fatal().Err(err).Send()
}
finalSQL := buf.String()
tx, err := db.Begin(c)
if err != nil {
logger.Fatal().Err(err).Send()
}
defer tx.Rollback(c)
// if err := c.setLocalUserID(tx); err != nil {
// return nil, 0, err
// }
var root []byte
if err = tx.QueryRow(c, finalSQL).Scan(&root); err != nil {
logger.Fatal().Err(err).Msg("sql query failed")
}
if err := tx.Commit(c); err != nil {
logger.Fatal().Err(err).Send()
}
res, err := c.execRemoteJoin(st.qc, st.skipped, root)
if err != nil {
logger.Fatal().Err(err).Msg("remote join failed")
}
val := make(map[string]interface{}) val := make(map[string]interface{})
err = json.Unmarshal(res, &val) err = json.Unmarshal(res, &val)
@ -156,10 +211,9 @@ func setFakeFuncs(f *goja.Object) {
f.Set("transmission_gear_type", gofakeit.TransmissionGearType) f.Set("transmission_gear_type", gofakeit.TransmissionGearType)
// Text // Text
f.Set("word", gofakeit.Word) f.Set("word", gofakeit.Word)
f.Set("sentence", gofakeit.Sentence) f.Set("sentence", gofakeit.Sentence)
f.Set("paragrph", gofakeit.Paragraph) f.Set("paragraph", gofakeit.Paragraph)
f.Set("question", gofakeit.Question) f.Set("question", gofakeit.Question)
f.Set("quote", gofakeit.Quote) f.Set("quote", gofakeit.Quote)

View File

@ -71,6 +71,12 @@ func (c *coreContext) execQuery() ([]byte, error) {
} }
} }
return c.execRemoteJoin(qc, skipped, data)
}
func (c *coreContext) execRemoteJoin(qc *qcode.QCode, skipped uint32, data []byte) ([]byte, error) {
var err error
if len(data) == 0 || skipped == 0 { if len(data) == 0 || skipped == 0 {
return data, nil return data, nil
} }
@ -114,11 +120,19 @@ func (c *coreContext) execQuery() ([]byte, error) {
} }
func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) { func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
tx, err := db.Begin(c) var tx pgx.Tx
if err != nil { var err error
return nil, nil, err
mutation := isMutation(c.req.Query)
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
useTx := useRoleQuery || conf.DB.SetUserID
if useTx {
if tx, err = db.Begin(c); err != nil {
return nil, nil, err
}
defer tx.Rollback(c)
} }
defer tx.Rollback(c)
if conf.DB.SetUserID { if conf.DB.SetUserID {
if err := c.setLocalUserID(tx); err != nil { if err := c.setLocalUserID(tx); err != nil {
@ -127,8 +141,6 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
} }
var role string var role string
mutation := isMutation(c.req.Query)
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
if useRoleQuery { if useRoleQuery {
if role, err = c.executeRoleQuery(tx); err != nil { if role, err = c.executeRoleQuery(tx); err != nil {
@ -149,12 +161,19 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
} }
var root []byte var root []byte
var row pgx.Row
vars := argList(c, ps.args) vars := argList(c, ps.args)
if mutation { if useTx {
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root) row = tx.QueryRow(c, ps.stmt.SQL, vars...)
} else { } else {
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&role, &root) row = db.QueryRow(c, ps.stmt.SQL, vars...)
}
if mutation {
err = row.Scan(&root)
} else {
err = row.Scan(&role, &root)
} }
logger.Debug().Str("default_role", c.req.role).Str("role", role).Msg(c.req.Query) logger.Debug().Str("default_role", c.req.role).Str("role", role).Msg(c.req.Query)
@ -165,22 +184,35 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
c.req.role = role c.req.role = role
if err := tx.Commit(c); err != nil { if useTx {
return nil, nil, err if err := tx.Commit(c); err != nil {
return nil, nil, err
}
} }
return root, ps, nil return root, ps, nil
} }
func (c *coreContext) resolveSQL() ([]byte, uint32, error) { func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
tx, err := db.Begin(c) var tx pgx.Tx
if err != nil { var err error
return nil, 0, err
}
defer tx.Rollback(c)
mutation := isMutation(c.req.Query) mutation := isMutation(c.req.Query)
useRoleQuery := len(conf.RolesQuery) != 0 && mutation useRoleQuery := len(conf.RolesQuery) != 0 && mutation
useTx := useRoleQuery || conf.DB.SetUserID
if useTx {
if tx, err = db.Begin(c); err != nil {
return nil, 0, err
}
defer tx.Rollback(c)
}
if conf.DB.SetUserID {
if err := c.setLocalUserID(tx); err != nil {
return nil, 0, err
}
}
if useRoleQuery { if useRoleQuery {
if c.req.role, err = c.executeRoleQuery(tx); err != nil { if c.req.role, err = c.executeRoleQuery(tx); err != nil {
@ -225,42 +257,36 @@ func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
stime = time.Now() stime = time.Now()
} }
if conf.DB.SetUserID { var root []byte
if err := c.setLocalUserID(tx); err != nil { var role, defaultRole string
return nil, 0, err var row pgx.Row
}
if useTx {
row = tx.QueryRow(c, finalSQL)
} else {
row = db.QueryRow(c, finalSQL)
} }
var root []byte
var role string
log := logger.Debug()
if mutation { if mutation {
err = tx.QueryRow(c, finalSQL).Scan(&root) err = row.Scan(&root)
log = log.Str("role", role)
} else { } else {
err = tx.QueryRow(c, finalSQL).Scan(&role, &root) err = row.Scan(&role, &root)
log = log.Str("default_role", c.req.role).Str("role", role) defaultRole = c.req.role
c.req.role = role c.req.role = role
} }
log.Msg(c.req.Query) logger.Debug().Str("default_role", defaultRole).Str("role", role).Msg(c.req.Query)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
if err := tx.Commit(c); err != nil { if useTx {
return nil, 0, err if err := tx.Commit(c); err != nil {
} return nil, 0, err
}
if mutation {
st = findStmt(c.req.role, stmts)
} else {
st = &stmts[0]
} }
if conf.EnableTracing && len(st.qc.Selects) != 0 { if conf.EnableTracing && len(st.qc.Selects) != 0 {

View File

@ -150,3 +150,39 @@ func (c *coreContext) buildStmt() ([]stmt, error) {
return stmts, nil return stmts, nil
} }
func (c *coreContext) buildStmtByRole(role string) (stmt, error) {
var st stmt
var err error
if len(role) == 0 {
return st, errors.New(`no role defined`)
}
var vars map[string]json.RawMessage
if len(c.req.Vars) != 0 {
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
return st, err
}
}
gql := []byte(c.req.Query)
st.qc, err = qcompile.Compile(gql, role)
if err != nil {
return st, err
}
w := &bytes.Buffer{}
st.skipped, err = pcompile.Compile(st.qc, w, psql.Variables(vars))
if err != nil {
return st, err
}
st.sql = w.String()
return st, nil
}

View File

@ -9,6 +9,7 @@ import (
"github.com/dosco/super-graph/qcode" "github.com/dosco/super-graph/qcode"
"github.com/jackc/pgconn" "github.com/jackc/pgconn"
"github.com/jackc/pgx/v4"
"github.com/valyala/fasttemplate" "github.com/valyala/fasttemplate"
) )
@ -24,22 +25,35 @@ var (
) )
func initPreparedList() { func initPreparedList() {
ctx := context.Background()
tx, err := db.Begin(ctx)
if err != nil {
logger.Fatal().Err(err).Send()
}
defer tx.Rollback(ctx)
_preparedList = make(map[string]*preparedItem) _preparedList = make(map[string]*preparedItem)
if err := prepareRoleStmt(); err != nil { if err := prepareRoleStmt(ctx, tx); err != nil {
logger.Fatal().Err(err).Msg("failed to prepare get role statement") logger.Fatal().Err(err).Msg("failed to prepare get role statement")
} }
for _, v := range _allowList.list { for _, v := range _allowList.list {
err := prepareStmt(ctx, tx, v.gql, v.vars)
err := prepareStmt(v.gql, v.vars)
if err != nil { if err != nil {
logger.Warn().Str("gql", v.gql).Err(err).Send() logger.Warn().Str("gql", v.gql).Err(err).Send()
} }
} }
if err := tx.Commit(ctx); err != nil {
logger.Fatal().Err(err).Send()
}
logger.Info().Msgf("Registered %d queries from allow.list as prepared statements", len(_allowList.list))
} }
func prepareStmt(gql string, varBytes json.RawMessage) error { func prepareStmt(ctx context.Context, tx pgx.Tx, gql string, varBytes json.RawMessage) error {
if len(gql) == 0 { if len(gql) == 0 {
return nil return nil
} }
@ -64,15 +78,7 @@ func prepareStmt(gql string, varBytes json.RawMessage) error {
finalSQL, am := processTemplate(s.sql) finalSQL, am := processTemplate(s.sql)
ctx := context.Background() pstmt, err := tx.Prepare(c.Context, "", finalSQL)
tx, err := db.Begin(ctx)
if err != nil {
return err
}
defer tx.Rollback(ctx)
pstmt, err := tx.Prepare(ctx, "", finalSQL)
if err != nil { if err != nil {
return err return err
} }
@ -92,15 +98,12 @@ func prepareStmt(gql string, varBytes json.RawMessage) error {
qc: s.qc, qc: s.qc,
} }
if err := tx.Commit(ctx); err != nil {
return err
}
} }
return nil return nil
} }
func prepareRoleStmt() error { func prepareRoleStmt(ctx context.Context, tx pgx.Tx) error {
if len(conf.RolesQuery) == 0 { if len(conf.RolesQuery) == 0 {
return nil return nil
} }
@ -125,15 +128,7 @@ func prepareRoleStmt() error {
roleSQL, _ := processTemplate(w.String()) roleSQL, _ := processTemplate(w.String())
ctx := context.Background() _, err := tx.Prepare(ctx, "_sg_get_role", roleSQL)
tx, err := db.Begin(ctx)
if err != nil {
return err
}
defer tx.Rollback(ctx)
_, err = tx.Prepare(ctx, "_sg_get_role", roleSQL)
if err != nil { if err != nil {
return err return err
} }
@ -142,19 +137,31 @@ func prepareRoleStmt() error {
} }
func processTemplate(tmpl string) (string, [][]byte) { func processTemplate(tmpl string) (string, [][]byte) {
t := fasttemplate.New(tmpl, `{{`, `}}`) st := struct {
am := make([][]byte, 0, 5) vmap map[string]int
i := 0 am [][]byte
i int
}{
vmap: make(map[string]int),
am: make([][]byte, 0, 5),
i: 0,
}
vmap := make(map[string]int) execFunc := func(w io.Writer, tag string) (int, error) {
if n, ok := st.vmap[tag]; ok {
return t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) {
if n, ok := vmap[tag]; ok {
return w.Write([]byte(fmt.Sprintf("$%d", n))) return w.Write([]byte(fmt.Sprintf("$%d", n)))
} }
am = append(am, []byte(tag)) st.am = append(st.am, []byte(tag))
i++ st.i++
vmap[tag] = i st.vmap[tag] = st.i
return w.Write([]byte(fmt.Sprintf("$%d", i))) return w.Write([]byte(fmt.Sprintf("$%d", st.i)))
}), am }
t1 := fasttemplate.New(tmpl, `'{{`, `}}'`)
ts1 := t1.ExecuteFuncString(execFunc)
t2 := fasttemplate.New(ts1, `{{`, `}}`)
ts2 := t2.ExecuteFuncString(execFunc)
return ts2, st.am
} }

View File

@ -77,7 +77,7 @@ SQL Output
database: database:
variables: variables:
account_id: "select account_id from users where id = $user_id" admin_account_id: "5"
defaults: defaults:
Filters: ["{ user_id: { eq: $user_id } }"] Filters: ["{ user_id: { eq: $user_id } }"]

View File

@ -97,10 +97,9 @@ database:
# Enable this if you need the user id in triggers, etc # Enable this if you need the user id in triggers, etc
set_user_id: false set_user_id: false
# Define variables here that you want to use in filters # Define additional variables here to be used with filters
# sub-queries must be wrapped in ()
variables: variables:
account_id: "(select account_id from users where id = $user_id)" admin_account_id: "5"
# Define defaults to for the field key and values below # Define defaults to for the field key and values below
defaults: defaults:
@ -188,8 +187,7 @@ roles:
deny: true deny: true
- name: admin - name: admin
match: id = 1 match: id = 1000
tables: tables:
- name: users - name: users
# query: filters: []
# filters: ["{ account_id: { _eq: $account_id } }"]