Compare commits

..

5 Commits

39 changed files with 924 additions and 923 deletions

View File

@ -73,43 +73,6 @@ mutation {
}
}
variables {
"data": [
{
"name": "Gumbo1",
"created_at": "now",
"updated_at": "now"
},
{
"name": "Gumbo2",
"created_at": "now",
"updated_at": "now"
}
]
}
query {
products {
id
name
}
}
variables {
"data": [
{
"name": "Gumbo1",
"created_at": "now",
"updated_at": "now"
},
{
"name": "Gumbo2",
"created_at": "now",
"updated_at": "now"
}
]
}
query {
products {
id
@ -133,21 +96,6 @@ mutation {
}
}
variables {
"data": [
{
"name": "Gumbo1",
"created_at": "now",
"updated_at": "now"
},
{
"name": "Gumbo2",
"created_at": "now",
"updated_at": "now"
}
]
}
query {
products {
id
@ -174,39 +122,6 @@ mutation {
}
}
variables {
"data": [
{
"name": "Gumbo1",
"created_at": "now",
"updated_at": "now"
},
{
"name": "Gumbo2",
"created_at": "now",
"updated_at": "now"
}
]
}
query {
products {
id
name
users {
email
}
}
}
query {
me {
id
email
full_name
}
}
variables {
"update": {
"name": "Helloo",
@ -224,66 +139,23 @@ mutation {
}
variables {
"data": [
{
"name": "Gumbo1",
"created_at": "now",
"updated_at": "now"
},
{
"name": "Gumbo2",
"created_at": "now",
"updated_at": "now"
}
]
"data": {
"name": "WOOO",
"price": 50.5
}
}
query {
product {
mutation {
products(insert: $data) {
id
name
}
}
variables {
"data": [
{
"name": "Gumbo1",
"created_at": "now",
"updated_at": "now"
},
{
"name": "Gumbo2",
"created_at": "now",
"updated_at": "now"
}
]
}
query {
products {
id
name
description
users {
email
}
}
}
query {
users {
id
email
picture: avatar
password
full_name
products(limit: 2, where: {price: {gt: 10}}) {
id
name
description
price
}
}
}

View File

@ -97,23 +97,18 @@ database:
# Enable this if you need the user id in triggers, etc
set_user_id: false
# Define variables here that you want to use in filters
# sub-queries must be wrapped in ()
# Define additional variables here to be used with filters
variables:
account_id: "(select account_id from users where id = $user_id)"
admin_account_id: "5"
# Define defaults to for the field key and values below
defaults:
# filters: ["{ user_id: { eq: $user_id } }"]
# Field and table names that you wish to block
blocklist:
- ar_internal_metadata
- schema_migrations
- secret
- password
- encrypted
- token
# Field and table names that you wish to block
blocklist:
- ar_internal_metadata
- schema_migrations
- secret
- password
- encrypted
- token
tables:
- name: customers
@ -141,6 +136,7 @@ roles_query: "SELECT * FROM users WHERE id = $user_id"
roles:
- name: anon
tables:
- name: users
- name: products
limit: 10
@ -174,8 +170,10 @@ roles:
filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
presets:
- user_id: "$user_id"
- created_at: "now"
- updated_at: "now"
update:
filters: ["{ user_id: { eq: $user_id } }"]
columns:
@ -188,8 +186,7 @@ roles:
block: true
- name: admin
match: id = 1
match: id = 1000
tables:
- name: users
# query:
# filters: ["{ account_id: { _eq: $account_id } }"]
filters: []

View File

@ -46,10 +46,10 @@ for (i = 0; i < product_count; i++) {
var data = {
name: fake.beer_name(),
description: desc,
price: fake.price(),
user_id: user.id,
created_at: "now",
updated_at: "now"
price: fake.price()
//user_id: user.id,
//created_at: "now",
//updated_at: "now"
}
var res = graphql(" \
@ -57,7 +57,9 @@ for (i = 0; i < product_count; i++) {
product(insert: $data) { \
id \
} \
}", { data: data })
}", { data: data }, {
user_id: 5
})
products.push(res.product)
}

View File

@ -276,7 +276,7 @@ transmission_gear_type
// Text
word
sentence
paragrph
paragraph
question
quote
@ -1083,25 +1083,6 @@ must be run to help figure out a users role. This query can be as complex as you
The individual roles are defined under the `roles` parameter and this includes each table the role has a custom setting for. The role is dynamically matched using the `match` parameter for example in the above case `users.id = 1` means that when the `roles_query` is executed a user with the id `1` willbe assigned the admin role and those that don't match get the `user` role if authenticated successfully or the `anon` role.
This below example would work for SAAS apps where an account (tenant) is usually the top parent table to everything else.
```yaml
roles_query: "SELECT * FROM users JOIN accounts on accounts.id = users.account_id WHERE users.id = $user_id"
roles:
- name: user
tables:
- name: users
...
- name: admin
match: accounts.admin_id = $user_id
tables:
- name: users
query:
filters: [{ accounts: { id: { eq: $account_id } } }]
```
## Remote Joins
It often happens that after fetching some data from the DB we need to call another API to fetch some more data and all this combined into a single JSON response. For example along with a list of users you need their last 5 payments from Stripe. This requires you to query your DB for the users and Stripe for the payments. Super Graph handles all this for you also only the fields you requested from the Stripe API are returned.
@ -1290,23 +1271,18 @@ database:
# Enable this if you need the user id in triggers, etc
set_user_id: false
# Define variables here that you want to use in filters
# sub-queries must be wrapped in ()
# Define additional variables here to be used with filters
variables:
account_id: "(select account_id from users where id = $user_id)"
admin_account_id: "5"
# Define defaults to for the field key and values below
defaults:
# filters: ["{ user_id: { eq: $user_id } }"]
# Field and table names that you wish to block
blocklist:
- ar_internal_metadata
- schema_migrations
- secret
- password
- encrypted
- token
# Field and table names that you wish to block
blocklist:
- ar_internal_metadata
- schema_migrations
- secret
- password
- encrypted
- token
tables:
- name: customers
@ -1378,14 +1354,13 @@ roles:
- updated_at: "now"
delete:
deny: true
block: true
- name: admin
match: id = 1
match: id = 1000
tables:
- name: users
# query:
# filters: ["{ account_id: { _eq: $account_id } }"]
filters: []
```

View File

@ -77,9 +77,9 @@ func (c *compilerContext) renderInsert(qc *qcode.QCode, w io.Writer,
return 0, err
}
io.WriteString(c.w, `(WITH "input" AS (SELECT {{`)
io.WriteString(c.w, `(WITH "input" AS (SELECT '{{`)
io.WriteString(c.w, qc.ActionVar)
io.WriteString(c.w, `}}::json AS j) INSERT INTO `)
io.WriteString(c.w, `}}' :: json AS j) INSERT INTO `)
quoted(c.w, ti.Name)
io.WriteString(c.w, ` (`)
c.renderInsertUpdateColumns(qc, w, jt, ti, false)
@ -174,9 +174,9 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w io.Writer,
return 0, err
}
io.WriteString(c.w, `(WITH "input" AS (SELECT {{`)
io.WriteString(c.w, `(WITH "input" AS (SELECT '{{`)
io.WriteString(c.w, qc.ActionVar)
io.WriteString(c.w, `}}::json AS j) UPDATE `)
io.WriteString(c.w, `}}' :: json AS j) UPDATE `)
quoted(c.w, ti.Name)
io.WriteString(c.w, ` SET (`)
c.renderInsertUpdateColumns(qc, w, jt, ti, false)

View File

@ -12,7 +12,7 @@ func simpleInsert(t *testing.T) {
}
}`
sql := `WITH "users" AS (WITH "input" AS (SELECT {{data}}::json AS j) INSERT INTO "users" ("full_name", "email") SELECT "full_name", "email" FROM input i, json_populate_record(NULL::users, i.j) t RETURNING *) SELECT json_object_agg('user', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "users_0"."id" AS "id") AS "json_row_0")) AS "json_0" FROM (SELECT "users"."id" FROM "users" LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "sel_0"`
sql := `WITH "users" AS (WITH "input" AS (SELECT '{{data}}' :: json AS j) INSERT INTO "users" ("full_name", "email") SELECT "full_name", "email" FROM input i, json_populate_record(NULL::users, i.j) t RETURNING *) SELECT json_object_agg('user', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "users_0"."id" AS "id") AS "json_row_0")) AS "json_0" FROM (SELECT "users"."id" FROM "users" LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "sel_0"`
vars := map[string]json.RawMessage{
"data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`),
@ -36,7 +36,7 @@ func singleInsert(t *testing.T) {
}
}`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{insert}}::json AS j) INSERT INTO "products" ("name", "description", "user_id") SELECT "name", "description", "user_id" FROM input i, json_populate_record(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
sql := `WITH "products" AS (WITH "input" AS (SELECT '{{insert}}' :: json AS j) INSERT INTO "products" ("name", "description", "user_id") SELECT "name", "description", "user_id" FROM input i, json_populate_record(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
vars := map[string]json.RawMessage{
"insert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc", "user_id": 5 }`),
@ -60,7 +60,7 @@ func bulkInsert(t *testing.T) {
}
}`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{insert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_recordset(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
sql := `WITH "products" AS (WITH "input" AS (SELECT '{{insert}}' :: json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_recordset(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
vars := map[string]json.RawMessage{
"insert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
@ -84,7 +84,7 @@ func singleUpsert(t *testing.T) {
}
}`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
sql := `WITH "products" AS (WITH "input" AS (SELECT '{{upsert}}' :: json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
vars := map[string]json.RawMessage{
"upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
@ -108,7 +108,7 @@ func singleUpsertWhere(t *testing.T) {
}
}`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT (id) WHERE (("products"."price") > 3) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
sql := `WITH "products" AS (WITH "input" AS (SELECT '{{upsert}}' :: json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT (id) WHERE (("products"."price") > 3) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
vars := map[string]json.RawMessage{
"upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
@ -132,7 +132,7 @@ func bulkUpsert(t *testing.T) {
}
}`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_recordset(NULL::products, i.j) t ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
sql := `WITH "products" AS (WITH "input" AS (SELECT '{{upsert}}' :: json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_recordset(NULL::products, i.j) t ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
vars := map[string]json.RawMessage{
"upsert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
@ -156,7 +156,7 @@ func singleUpdate(t *testing.T) {
}
}`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{update}}::json AS j) UPDATE "products" SET ("name", "description") = (SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t) WHERE (("products"."id") = 1) AND (("products"."id") = 15) RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
sql := `WITH "products" AS (WITH "input" AS (SELECT '{{update}}' :: json AS j) UPDATE "products" SET ("name", "description") = (SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t) WHERE (("products"."id") = 1) AND (("products"."id") = 15) RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
vars := map[string]json.RawMessage{
"update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
@ -203,7 +203,7 @@ func blockedInsert(t *testing.T) {
}
}`
sql := `WITH "users" AS (WITH "input" AS (SELECT {{data}}::json AS j) INSERT INTO "users" ("full_name", "email") SELECT "full_name", "email" FROM input i, json_populate_record(NULL::users, i.j) t WHERE false RETURNING *) SELECT json_object_agg('user', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "users_0"."id" AS "id") AS "json_row_0")) AS "json_0" FROM (SELECT "users"."id" FROM "users" LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "sel_0"`
sql := `WITH "users" AS (WITH "input" AS (SELECT '{{data}}' :: json AS j) INSERT INTO "users" ("full_name", "email") SELECT "full_name", "email" FROM input i, json_populate_record(NULL::users, i.j) t WHERE false RETURNING *) SELECT json_object_agg('user', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "users_0"."id" AS "id") AS "json_row_0")) AS "json_0" FROM (SELECT "users"."id" FROM "users" LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "sel_0"`
vars := map[string]json.RawMessage{
"data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`),
@ -227,7 +227,7 @@ func blockedUpdate(t *testing.T) {
}
}`
sql := `WITH "users" AS (WITH "input" AS (SELECT {{data}}::json AS j) UPDATE "users" SET ("full_name", "email") = (SELECT "full_name", "email" FROM input i, json_populate_record(NULL::users, i.j) t) WHERE false RETURNING *) SELECT json_object_agg('user', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "users_0"."id" AS "id", "users_0"."email" AS "email") AS "json_row_0")) AS "json_0" FROM (SELECT "users"."id", "users"."email" FROM "users" LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "sel_0"`
sql := `WITH "users" AS (WITH "input" AS (SELECT '{{data}}' :: json AS j) UPDATE "users" SET ("full_name", "email") = (SELECT "full_name", "email" FROM input i, json_populate_record(NULL::users, i.j) t) WHERE false RETURNING *) SELECT json_object_agg('user', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "users_0"."id" AS "id", "users_0"."email" AS "email") AS "json_row_0")) AS "json_0" FROM (SELECT "users"."id", "users"."email" FROM "users" LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "sel_0"`
vars := map[string]json.RawMessage{
"data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`),
@ -250,7 +250,7 @@ func simpleInsertWithPresets(t *testing.T) {
}
}`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{data}}::json AS j) INSERT INTO "products" ("name", "price", "created_at", "updated_at", "user_id") SELECT "name", "price", 'now' :: timestamp without time zone, 'now' :: timestamp without time zone, '$user_id' :: bigint FROM input i, json_populate_record(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
sql := `WITH "products" AS (WITH "input" AS (SELECT '{{data}}' :: json AS j) INSERT INTO "products" ("name", "price", "created_at", "updated_at", "user_id") SELECT "name", "price", 'now' :: timestamp without time zone, 'now' :: timestamp without time zone, '{{user_id}}' :: bigint FROM input i, json_populate_record(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
vars := map[string]json.RawMessage{
"data": json.RawMessage(`{"name": "Tomato", "price": 5.76}`),
@ -273,7 +273,7 @@ func simpleUpdateWithPresets(t *testing.T) {
}
}`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{data}}::json AS j) UPDATE "products" SET ("name", "price", "updated_at") = (SELECT "name", "price", 'now' :: timestamp without time zone FROM input i, json_populate_record(NULL::products, i.j) t) WHERE (("products"."user_id") = {{user_id}}) RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
sql := `WITH "products" AS (WITH "input" AS (SELECT '{{data}}' :: json AS j) UPDATE "products" SET ("name", "price", "updated_at") = (SELECT "name", "price", 'now' :: timestamp without time zone FROM input i, json_populate_record(NULL::products, i.j) t) WHERE (("products"."user_id") = '{{user_id}}' :: bigint) RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
vars := map[string]json.RawMessage{
"data": json.RawMessage(`{"name": "Apple", "price": 1.25}`),

View File

@ -171,7 +171,7 @@ func TestMain(m *testing.M) {
}
vars := NewVariables(map[string]string{
"account_id": "select account_id from users where id = $user_id",
"admin_account_id": "5",
})
pcompile = NewCompiler(Config{

View File

@ -500,7 +500,7 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
var groupBy []int
isRoot := sel.ParentID == -1
isFil := sel.Where != nil
isFil := (sel.Where != nil && sel.Where.Op != qcode.OpNop)
isSearch := sel.Args["search"] != nil
isAgg := false
@ -844,7 +844,6 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
}
func (c *compilerContext) renderNestedWhere(ex *qcode.Exp, sel *qcode.Select, ti *DBTableInfo) error {
for i := 0; i < len(ex.NestedCols)-1; i++ {
cti, err := c.schema.GetTable(ex.NestedCols[i])
if err != nil {
@ -878,7 +877,18 @@ func (c *compilerContext) renderNestedWhere(ex *qcode.Exp, sel *qcode.Select, ti
}
func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTableInfo) error {
var col *DBColumn
var ok bool
if ex.Op == qcode.OpNop {
return nil
}
if len(ex.Col) != 0 {
if col, ok = ti.Columns[ex.Col]; !ok {
return fmt.Errorf("no column '%s' found ", ex.Col)
}
io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, ex.Col)
io.WriteString(c.w, `) `)
@ -934,6 +944,9 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
if len(ti.PrimaryCol) == 0 {
return fmt.Errorf("no primary key column defined for %s", ti.Name)
}
if col, ok = ti.Columns[ti.PrimaryCol]; !ok {
return fmt.Errorf("no primary key column '%s' found ", ti.PrimaryCol)
}
//fmt.Fprintf(w, `(("%s") =`, c.ti.PrimaryCol)
io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, ti.PrimaryCol)
@ -943,6 +956,9 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
if len(ti.TSVCol) == 0 {
return fmt.Errorf("no tsv column defined for %s", ti.Name)
}
if col, ok = ti.Columns[ti.TSVCol]; !ok {
return fmt.Errorf("no tsv column '%s' found ", ti.TSVCol)
}
//fmt.Fprintf(w, `(("%s") @@ to_tsquery('%s'))`, c.ti.TSVCol, val.Val)
io.WriteString(c.w, `(("`)
io.WriteString(c.w, ti.TSVCol)
@ -958,7 +974,7 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
if ex.Type == qcode.ValList {
c.renderList(ex)
} else {
c.renderVal(ex, c.vars)
c.renderVal(ex, c.vars, col)
}
io.WriteString(c.w, `)`)
@ -1035,7 +1051,7 @@ func (c *compilerContext) renderList(ex *qcode.Exp) {
io.WriteString(c.w, `)`)
}
func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) {
func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *DBColumn) {
io.WriteString(c.w, ` `)
switch ex.Type {
@ -1052,6 +1068,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) {
io.WriteString(c.w, `'`)
case qcode.ValVar:
io.WriteString(c.w, `'`)
if val, ok := vars[ex.Val]; ok {
io.WriteString(c.w, val)
} else {
@ -1060,6 +1077,8 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) {
io.WriteString(c.w, ex.Val)
io.WriteString(c.w, `}}`)
}
io.WriteString(c.w, `' :: `)
io.WriteString(c.w, col.Type)
}
//io.WriteString(c.w, `)`)
}

View File

@ -339,7 +339,7 @@ func syntheticTables(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('me', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT ) AS "json_row_0")) AS "json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "sel_0"`
sql := `SELECT json_object_agg('me', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT ) AS "json_row_0")) AS "json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = '{{user_id}}' :: bigint)) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "sel_0"`
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
@ -359,7 +359,7 @@ func queryWithVariables(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") = {{product_price}}) AND (("products"."id") = {{product_id}})) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
sql := `SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") = '{{product_price}}' :: numeric(7,2)) AND (("products"."id") = '{{product_id}}' :: bigint)) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {

View File

@ -1,6 +1,7 @@
package qcode
import (
"regexp"
"sort"
"strings"
)
@ -121,3 +122,12 @@ func mapToList(m map[string]string) []string {
sort.Strings(list)
return list
}
var varRe = regexp.MustCompile(`\$([a-zA-Z0-9_]+)`)
func parsePresets(m map[string]string) map[string]string {
for k, v := range m {
m[k] = varRe.ReplaceAllString(v, `{{$1}}`)
}
return m
}

View File

@ -4,6 +4,8 @@ package qcode
// FuzzerEntrypoint for Fuzzbuzz
func Fuzz(data []byte) int {
GetQType(string(data))
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile(data, "user")
if err != nil {

View File

@ -20,6 +20,7 @@ const (
const (
QTQuery QType = iota + 1
QTMutation
QTInsert
QTUpdate
QTDelete
@ -202,7 +203,7 @@ func (com *Compiler) AddRole(role, table string, trc TRConfig) error {
return err
}
trv.insert.cols = listToMap(trc.Insert.Columns)
trv.insert.psmap = trc.Insert.Presets
trv.insert.psmap = parsePresets(trc.Insert.Presets)
trv.insert.pslist = mapToList(trv.insert.psmap)
// update config
@ -210,7 +211,7 @@ func (com *Compiler) AddRole(role, table string, trc TRConfig) error {
return err
}
trv.update.cols = listToMap(trc.Update.Columns)
trv.update.psmap = trc.Update.Presets
trv.update.psmap = parsePresets(trc.Update.Presets)
trv.update.pslist = mapToList(trv.update.psmap)
// delete config

23
qcode/utils.go Normal file
View File

@ -0,0 +1,23 @@
package qcode
func GetQType(gql string) QType {
for i := range gql {
b := gql[i]
if b == '{' {
return QTQuery
}
if al(b) {
switch b {
case 'm', 'M':
return QTMutation
case 'q', 'Q':
return QTQuery
}
}
}
return -1
}
func al(b byte) bool {
return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9')
}

View File

@ -46,7 +46,7 @@ func initAllowList(cpath string) {
if _, err := os.Stat(fp); err == nil {
_allowList.filepath = fp
} else if !os.IsNotExist(err) {
logger.Fatal().Err(err).Send()
errlog.Fatal().Err(err).Send()
}
}
@ -56,7 +56,7 @@ func initAllowList(cpath string) {
if _, err := os.Stat(fp); err == nil {
_allowList.filepath = fp
} else if !os.IsNotExist(err) {
logger.Fatal().Err(err).Send()
errlog.Fatal().Err(err).Send()
}
}
@ -66,13 +66,13 @@ func initAllowList(cpath string) {
if _, err := os.Stat(fp); err == nil {
_allowList.filepath = fp
} else if !os.IsNotExist(err) {
logger.Fatal().Err(err).Send()
errlog.Fatal().Err(err).Send()
}
}
if len(_allowList.filepath) == 0 {
if conf.Production {
logger.Fatal().Msg("allow.list not found")
errlog.Fatal().Msg("allow.list not found")
}
if len(cpath) == 0 {
@ -187,7 +187,6 @@ func (al *allowList) load() {
item.gql = q
item.vars = varBytes
}
varBytes = nil
} else if ty == AL_VARS {

View File

@ -2,63 +2,46 @@ package serv
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"github.com/dosco/super-graph/jsn"
)
func argMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
func argMap(ctx context.Context, vars []byte) func(w io.Writer, tag string) (int, error) {
return func(w io.Writer, tag string) (int, error) {
switch tag {
case "user_id_provider":
if v := ctx.Value(userIDProviderKey); v != nil {
return stringArg(w, v.(string))
return io.WriteString(w, v.(string))
}
io.WriteString(w, "null")
return 0, nil
return 0, errors.New("query requires variable $user_id_provider")
case "user_id":
if v := ctx.Value(userIDKey); v != nil {
return stringArg(w, v.(string))
return io.WriteString(w, v.(string))
}
io.WriteString(w, "null")
return 0, nil
return 0, errors.New("query requires variable $user_id")
case "user_role":
if v := ctx.Value(userRoleKey); v != nil {
return stringArg(w, v.(string))
return io.WriteString(w, v.(string))
}
io.WriteString(w, "null")
return 0, errors.New("query requires variable $user_role")
}
fields := jsn.Get(vars, [][]byte{[]byte(tag)})
if len(fields) == 0 {
return 0, nil
}
fields := jsn.Get(ctx.req.Vars, [][]byte{[]byte(tag)})
if len(fields) == 0 {
return 0, fmt.Errorf("variable '%s' not found", tag)
}
is := false
for i := range fields[0].Value {
c := fields[0].Value[i]
if c != ' ' {
is = (c == '"') || (c == '{') || (c == '[')
break
}
}
if is {
return stringArgB(w, fields[0].Value)
}
w.Write(fields[0].Value)
return 0, nil
return w.Write(fields[0].Value)
}
}
func argList(ctx *coreContext, args [][]byte) []interface{} {
func argList(ctx *coreContext, args [][]byte) ([]interface{}, error) {
vars := make([]interface{}, len(args))
var fields map[string]interface{}
@ -68,7 +51,7 @@ func argList(ctx *coreContext, args [][]byte) []interface{} {
fields, _, err = jsn.Tree(ctx.req.Vars)
if err != nil {
logger.Warn().Err(err).Msg("Failed to parse variables")
return nil, err
}
}
@ -79,44 +62,33 @@ func argList(ctx *coreContext, args [][]byte) []interface{} {
case bytes.Equal(av, []byte("user_id")):
if v := ctx.Value(userIDKey); v != nil {
vars[i] = v.(string)
} else {
return nil, errors.New("query requires variable $user_id")
}
case bytes.Equal(av, []byte("user_id_provider")):
if v := ctx.Value(userIDProviderKey); v != nil {
vars[i] = v.(string)
} else {
return nil, errors.New("query requires variable $user_id_provider")
}
case bytes.Equal(av, []byte("user_role")):
if v := ctx.Value(userRoleKey); v != nil {
vars[i] = v.(string)
} else {
return nil, errors.New("query requires variable $user_role")
}
default:
if v, ok := fields[string(av)]; ok {
vars[i] = v
} else {
return nil, fmt.Errorf("query requires variable $%s", string(av))
}
}
}
return vars
}
func stringArg(w io.Writer, v string) (int, error) {
if n, err := w.Write([]byte(`'`)); err != nil {
return n, err
}
if n, err := w.Write([]byte(v)); err != nil {
return n, err
}
return w.Write([]byte(`'`))
}
func stringArgB(w io.Writer, v []byte) (int, error) {
if n, err := w.Write([]byte(`'`)); err != nil {
return n, err
}
if n, err := w.Write(v); err != nil {
return n, err
}
return w.Write([]byte(`'`))
return vars, nil
}

View File

@ -35,7 +35,7 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
case len(publicKeyFile) != 0:
kd, err := ioutil.ReadFile(publicKeyFile)
if err != nil {
logger.Fatal().Err(err).Send()
errlog.Fatal().Err(err).Send()
}
switch conf.Auth.JWT.PubKeyType {
@ -51,7 +51,7 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
}
if err != nil {
logger.Fatal().Err(err).Send()
errlog.Fatal().Err(err).Send()
}
}

View File

@ -15,11 +15,11 @@ import (
func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc {
cookie := conf.Auth.Cookie
if len(cookie) == 0 {
logger.Fatal().Msg("no auth.cookie defined")
errlog.Fatal().Msg("no auth.cookie defined")
}
if len(conf.Auth.Rails.URL) == 0 {
logger.Fatal().Msg("no auth.rails.url defined")
errlog.Fatal().Msg("no auth.rails.url defined")
}
rp := &redis.Pool{
@ -28,13 +28,13 @@ func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc {
Dial: func() (redis.Conn, error) {
c, err := redis.DialURL(conf.Auth.Rails.URL)
if err != nil {
logger.Fatal().Err(err).Send()
errlog.Fatal().Err(err).Send()
}
pwd := conf.Auth.Rails.Password
if len(pwd) != 0 {
if _, err := c.Do("AUTH", pwd); err != nil {
logger.Fatal().Err(err).Send()
errlog.Fatal().Err(err).Send()
}
}
return c, err
@ -69,16 +69,16 @@ func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc {
func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc {
cookie := conf.Auth.Cookie
if len(cookie) == 0 {
logger.Fatal().Msg("no auth.cookie defined")
errlog.Fatal().Msg("no auth.cookie defined")
}
if len(conf.Auth.Rails.URL) == 0 {
logger.Fatal().Msg("no auth.rails.url defined")
errlog.Fatal().Msg("no auth.rails.url defined")
}
rURL, err := url.Parse(conf.Auth.Rails.URL)
if err != nil {
logger.Fatal().Err(err).Send()
errlog.Fatal().Err(err).Send()
}
mc := memcache.New(rURL.Host)
@ -111,12 +111,12 @@ func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc {
func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc {
cookie := conf.Auth.Cookie
if len(cookie) == 0 {
logger.Fatal().Msg("no auth.cookie defined")
errlog.Fatal().Msg("no auth.cookie defined")
}
ra, err := railsAuth(conf)
if err != nil {
logger.Fatal().Err(err).Send()
errlog.Fatal().Err(err).Send()
}
return func(w http.ResponseWriter, r *http.Request) {

View File

@ -22,16 +22,18 @@ const (
)
var (
logger *zerolog.Logger
logger zerolog.Logger
errlog zerolog.Logger
conf *config
confPath string
db *pgxpool.Pool
schema *psql.DBSchema
qcompile *qcode.Compiler
pcompile *psql.Compiler
)
func Init() {
logger = initLog()
initLog()
rootCmd := &cobra.Command{
Use: "super-graph",
@ -110,6 +112,13 @@ e.g. db:migrate -+1
Run: cmdDBSetup,
})
rootCmd.AddCommand(&cobra.Command{
Use: "db:reset",
Short: "Reset database",
Long: "This command will drop, create, migrate and seed the database (won't run in production)",
Run: cmdDBReset,
})
rootCmd.AddCommand(&cobra.Command{
Use: "new APP-NAME",
Short: "Create a new application",
@ -128,19 +137,14 @@ e.g. db:migrate -+1
"path", "./config", "path to config files")
if err := rootCmd.Execute(); err != nil {
logger.Fatal().Err(err).Send()
errlog.Fatal().Err(err).Send()
}
}
func initLog() *zerolog.Logger {
func initLog() {
out := zerolog.ConsoleWriter{Out: os.Stderr}
logger := zerolog.New(out).
With().
Timestamp().
Caller().
Logger()
return &logger
logger = zerolog.New(out).With().Timestamp().Logger()
errlog = logger.With().Caller().Logger()
}
func initConf() (*config, error) {
@ -159,7 +163,7 @@ func initConf() (*config, error) {
}
if vi.IsSet("inherits") {
logger.Fatal().Msgf("inherited config (%s) cannot itself inherit (%s)",
errlog.Fatal().Msgf("inherited config (%s) cannot itself inherit (%s)",
inherits,
vi.GetString("inherits"))
}
@ -176,7 +180,7 @@ func initConf() (*config, error) {
logLevel, err := zerolog.ParseLevel(c.LogLevel)
if err != nil {
logger.Error().Err(err).Msg("error setting log_level")
errlog.Error().Err(err).Msg("error setting log_level")
}
zerolog.SetGlobalLevel(logLevel)
@ -211,7 +215,7 @@ func initDB(c *config, useDB bool) (*pgx.Conn, error) {
config.LogLevel = pgx.LogLevelNone
}
config.Logger = NewSQLLogger(*logger)
config.Logger = NewSQLLogger(logger)
db, err := pgx.ConnectConfig(context.Background(), config)
if err != nil {
@ -246,7 +250,7 @@ func initDBPool(c *config) (*pgxpool.Pool, error) {
config.ConnConfig.LogLevel = pgx.LogLevelNone
}
config.ConnConfig.Logger = NewSQLLogger(*logger)
config.ConnConfig.Logger = NewSQLLogger(logger)
// if c.DB.MaxRetries != 0 {
// opt.MaxRetries = c.DB.MaxRetries
@ -269,10 +273,20 @@ func initCompiler() {
qcompile, pcompile, err = initCompilers(conf)
if err != nil {
logger.Fatal().Err(err).Msg("failed to initialize compilers")
errlog.Fatal().Err(err).Msg("failed to initialize compilers")
}
if err := initResolvers(); err != nil {
logger.Fatal().Err(err).Msg("failed to initialized resolvers")
errlog.Fatal().Err(err).Msg("failed to initialized resolvers")
}
}
func initConfOnce() {
var err error
if conf == nil {
if conf, err = initConf(); err != nil {
errlog.Fatal().Err(err).Msg("failed to read config")
}
}
}

View File

@ -17,11 +17,11 @@ func cmdConfDump(cmd *cobra.Command, args []string) {
conf, err := initConf()
if err != nil {
logger.Fatal().Err(err).Msg("failed to read config")
errlog.Fatal().Err(err).Msg("failed to read config")
}
if err := conf.Viper.WriteConfigAs(fname); err != nil {
logger.Fatal().Err(err).Send()
errlog.Fatal().Err(err).Send()
}
logger.Info().Msgf("config dumped to ./%s", fname)

View File

@ -36,6 +36,7 @@ var newMigrationText = `-- Write your migrate up statements here
`
func cmdDBSetup(cmd *cobra.Command, args []string) {
initConfOnce()
cmdDBCreate(cmd, []string{})
cmdDBMigrate(cmd, []string{"up"})
@ -48,24 +49,30 @@ func cmdDBSetup(cmd *cobra.Command, args []string) {
}
if os.IsNotExist(err) == false {
logger.Fatal().Err(err).Msgf("unable to check if '%s' exists", sfile)
errlog.Fatal().Err(err).Msgf("unable to check if '%s' exists", sfile)
}
logger.Warn().Msgf("failed to read seed file '%s'", sfile)
}
func cmdDBCreate(cmd *cobra.Command, args []string) {
var err error
func cmdDBReset(cmd *cobra.Command, args []string) {
initConfOnce()
if conf, err = initConf(); err != nil {
logger.Fatal().Err(err).Msg("failed to read config")
if conf.Production {
errlog.Fatal().Msg("db:reset does not work in production")
return
}
cmdDBDrop(cmd, []string{})
cmdDBSetup(cmd, []string{})
}
func cmdDBCreate(cmd *cobra.Command, args []string) {
initConfOnce()
ctx := context.Background()
conn, err := initDB(conf, false)
if err != nil {
logger.Fatal().Err(err).Msg("failed to connect to database")
errlog.Fatal().Err(err).Msg("failed to connect to database")
}
defer conn.Close(ctx)
@ -73,24 +80,19 @@ func cmdDBCreate(cmd *cobra.Command, args []string) {
_, err = conn.Exec(ctx, sql)
if err != nil {
logger.Fatal().Err(err).Msg("failed to create database")
errlog.Fatal().Err(err).Msg("failed to create database")
}
logger.Info().Msgf("created database '%s'", conf.DB.DBName)
}
func cmdDBDrop(cmd *cobra.Command, args []string) {
var err error
if conf, err = initConf(); err != nil {
logger.Fatal().Err(err).Msg("failed to read config")
}
initConfOnce()
ctx := context.Background()
conn, err := initDB(conf, false)
if err != nil {
logger.Fatal().Err(err).Msg("failed to connect to database")
errlog.Fatal().Err(err).Msg("failed to connect to database")
}
defer conn.Close(ctx)
@ -98,7 +100,7 @@ func cmdDBDrop(cmd *cobra.Command, args []string) {
_, err = conn.Exec(ctx, sql)
if err != nil {
logger.Fatal().Err(err).Msg("failed to create database")
errlog.Fatal().Err(err).Msg("failed to create database")
}
logger.Info().Msgf("dropped database '%s'", conf.DB.DBName)
@ -110,12 +112,7 @@ func cmdDBNew(cmd *cobra.Command, args []string) {
os.Exit(1)
}
var err error
if conf, err = initConf(); err != nil {
logger.Fatal().Err(err).Msg("failed to read config")
}
initConfOnce()
name := args[0]
m, err := migrate.FindMigrations(conf.MigrationsPath)
@ -144,39 +141,34 @@ func cmdDBNew(cmd *cobra.Command, args []string) {
}
func cmdDBMigrate(cmd *cobra.Command, args []string) {
var err error
if len(args) == 0 {
cmd.Help()
os.Exit(1)
}
initConfOnce()
dest := args[0]
if conf, err = initConf(); err != nil {
logger.Fatal().Err(err).Msg("failed to read config")
}
conn, err := initDB(conf, true)
if err != nil {
logger.Fatal().Err(err).Msg("failed to connect to database")
errlog.Fatal().Err(err).Msg("failed to connect to database")
}
defer conn.Close(context.Background())
m, err := migrate.NewMigrator(conn, "schema_version")
if err != nil {
logger.Fatal().Err(err).Msg("failed to initializing migrator")
errlog.Fatal().Err(err).Msg("failed to initializing migrator")
}
m.Data = getMigrationVars()
err = m.LoadMigrations(conf.MigrationsPath)
if err != nil {
logger.Fatal().Err(err).Msg("failed to load migrations")
errlog.Fatal().Err(err).Msg("failed to load migrations")
}
if len(m.Migrations) == 0 {
logger.Fatal().Msg("No migrations found")
errlog.Fatal().Msg("No migrations found")
}
m.OnStart = func(sequence int32, name, direction, sql string) {
@ -195,7 +187,7 @@ func cmdDBMigrate(cmd *cobra.Command, args []string) {
var n int64
n, err = strconv.ParseInt(d, 10, 32)
if err != nil {
logger.Fatal().Err(err).Msg("invalid destination")
errlog.Fatal().Err(err).Msg("invalid destination")
}
return int32(n)
}
@ -226,17 +218,15 @@ func cmdDBMigrate(cmd *cobra.Command, args []string) {
if err != nil {
logger.Info().Err(err).Send()
// logger.Info().Err(err).Send()
// if err, ok := err.(m.MigrationPgError); ok {
// if err.Detail != "" {
// logger.Info().Err(err).Msg(err.Detail)
// info.Err(err).Msg(err.Detail)
// }
// if err.Position != 0 {
// ele, err := ExtractErrorLine(err.Sql, int(err.Position))
// if err != nil {
// logger.Fatal().Err(err).Send()
// errlog.Fatal().Err(err).Send()
// }
// prefix := fmt.Sprintf()
@ -251,37 +241,33 @@ func cmdDBMigrate(cmd *cobra.Command, args []string) {
}
func cmdDBStatus(cmd *cobra.Command, args []string) {
var err error
if conf, err = initConf(); err != nil {
logger.Fatal().Err(err).Msg("failed to read config")
}
initConfOnce()
conn, err := initDB(conf, true)
if err != nil {
logger.Fatal().Err(err).Msg("failed to connect to database")
errlog.Fatal().Err(err).Msg("failed to connect to database")
}
defer conn.Close(context.Background())
m, err := migrate.NewMigrator(conn, "schema_version")
if err != nil {
logger.Fatal().Err(err).Msg("failed to initialize migrator")
errlog.Fatal().Err(err).Msg("failed to initialize migrator")
}
m.Data = getMigrationVars()
err = m.LoadMigrations(conf.MigrationsPath)
if err != nil {
logger.Fatal().Err(err).Msg("failed to load migrations")
errlog.Fatal().Err(err).Msg("failed to load migrations")
}
if len(m.Migrations) == 0 {
logger.Fatal().Msg("no migrations found")
errlog.Fatal().Msg("no migrations found")
}
mver, err := m.GetCurrentVersion()
if err != nil {
logger.Fatal().Err(err).Msg("failed to retrieve migration")
errlog.Fatal().Err(err).Msg("failed to retrieve migration")
}
var status string

View File

@ -134,12 +134,12 @@ func ifNotExists(filePath string, doFn func(string) error) {
}
if os.IsNotExist(err) == false {
logger.Fatal().Err(err).Msgf("unable to check if '%s' exists", filePath)
errlog.Fatal().Err(err).Msgf("unable to check if '%s' exists", filePath)
}
err = doFn(filePath)
if err != nil {
logger.Fatal().Err(err).Msgf("unable to create '%s'", filePath)
errlog.Fatal().Err(err).Msgf("unable to create '%s'", filePath)
}
logger.Info().Msgf("created '%s'", filePath)

View File

@ -1,6 +1,7 @@
package serv
import (
"bytes"
"context"
"encoding/json"
"fmt"
@ -12,20 +13,21 @@ import (
"github.com/brianvoe/gofakeit"
"github.com/dop251/goja"
"github.com/spf13/cobra"
"github.com/valyala/fasttemplate"
)
func cmdDBSeed(cmd *cobra.Command, args []string) {
var err error
if conf, err = initConf(); err != nil {
logger.Fatal().Err(err).Msg("failed to read config")
errlog.Fatal().Err(err).Msg("failed to read config")
}
conf.Production = false
db, err = initDBPool(conf)
if err != nil {
logger.Fatal().Err(err).Msg("failed to connect to database")
errlog.Fatal().Err(err).Msg("failed to connect to database")
}
initCompiler()
@ -34,7 +36,7 @@ func cmdDBSeed(cmd *cobra.Command, args []string) {
b, err := ioutil.ReadFile(path.Join(confPath, conf.SeedFile))
if err != nil {
logger.Fatal().Err(err).Msgf("failed to read seed file '%s'", sfile)
errlog.Fatal().Err(err).Msgf("failed to read seed file '%s'", sfile)
}
vm := goja.New()
@ -50,34 +52,77 @@ func cmdDBSeed(cmd *cobra.Command, args []string) {
_, err = vm.RunScript("seed.js", string(b))
if err != nil {
logger.Fatal().Err(err).Msg("failed to execute script")
errlog.Fatal().Err(err).Msg("failed to execute script")
}
logger.Info().Msg("seed script done")
}
//func runFunc(call goja.FunctionCall) {
func graphQLFunc(query string, data interface{}) map[string]interface{} {
b, err := json.Marshal(data)
func graphQLFunc(query string, data interface{}, opt map[string]string) map[string]interface{} {
vars, err := json.Marshal(data)
if err != nil {
logger.Fatal().Err(err).Msg("failed to json serialize")
errlog.Fatal().Err(err).Send()
}
c := &coreContext{Context: context.Background()}
c.req.Query = query
c.req.Vars = b
c.req.role = "user"
c := context.Background()
res, err := c.execQuery()
if v, ok := opt["user_id"]; ok && len(v) != 0 {
c = context.WithValue(c, userIDKey, v)
}
var role string
if v, ok := opt["role"]; ok && len(v) != 0 {
role = v
} else {
role = "user"
}
stmts, err := buildRoleStmt([]byte(query), vars, role)
if err != nil {
logger.Fatal().Err(err).Msg("graphql query failed")
errlog.Fatal().Err(err).Msg("graphql query failed")
}
st := stmts[0]
buf := &bytes.Buffer{}
t := fasttemplate.New(st.sql, openVar, closeVar)
_, err = t.ExecuteFunc(buf, argMap(c, vars))
if err != nil {
errlog.Fatal().Err(err).Send()
}
finalSQL := buf.String()
tx, err := db.Begin(c)
if err != nil {
errlog.Fatal().Err(err).Send()
}
defer tx.Rollback(c)
if conf.DB.SetUserID {
if err := setLocalUserID(c, tx); err != nil {
errlog.Fatal().Err(err).Send()
}
}
var root []byte
if err = tx.QueryRow(c, finalSQL).Scan(&root); err != nil {
errlog.Fatal().Err(err).Msg("sql query failed")
}
if err := tx.Commit(c); err != nil {
errlog.Fatal().Err(err).Send()
}
val := make(map[string]interface{})
err = json.Unmarshal(res, &val)
err = json.Unmarshal(root, &val)
if err != nil {
logger.Fatal().Err(err).Msg("failed to deserialize json")
errlog.Fatal().Err(err).Send()
}
return val
@ -156,10 +201,9 @@ func setFakeFuncs(f *goja.Object) {
f.Set("transmission_gear_type", gofakeit.TransmissionGearType)
// Text
f.Set("word", gofakeit.Word)
f.Set("sentence", gofakeit.Sentence)
f.Set("paragrph", gofakeit.Paragraph)
f.Set("paragraph", gofakeit.Paragraph)
f.Set("question", gofakeit.Question)
f.Set("quote", gofakeit.Quote)

View File

@ -8,12 +8,12 @@ func cmdServ(cmd *cobra.Command, args []string) {
var err error
if conf, err = initConf(); err != nil {
logger.Fatal().Err(err).Msg("failed to read config")
errlog.Fatal().Err(err).Msg("failed to read config")
}
db, err = initDBPool(conf)
if err != nil {
logger.Fatal().Err(err).Msg("failed to connect to database")
errlog.Fatal().Err(err).Msg("failed to connect to database")
}
initCompiler()

View File

@ -64,17 +64,12 @@ type config struct {
User string
Password string
Schema string
PoolSize int32 `mapstructure:"pool_size"`
MaxRetries int `mapstructure:"max_retries"`
LogLevel string `mapstructure:"log_level"`
SetUserID bool `mapstructure:"set_user_id"`
PoolSize int32 `mapstructure:"pool_size"`
MaxRetries int `mapstructure:"max_retries"`
SetUserID bool `mapstructure:"set_user_id"`
Vars map[string]string `mapstructure:"variables"`
Defaults struct {
Filters []string
Blocklist []string
}
Vars map[string]string `mapstructure:"variables"`
Blocklist []string
Tables []configTable
} `mapstructure:"database"`
@ -83,6 +78,7 @@ type config struct {
RolesQuery string `mapstructure:"roles_query"`
Roles []configRole
roles map[string]*configRole
}
type configTable struct {
@ -221,16 +217,15 @@ func (c *config) Init(vi *viper.Viper) error {
}
c.RolesQuery = sanitize(c.RolesQuery)
rolesMap := make(map[string]struct{})
c.roles = make(map[string]*configRole)
for i := range c.Roles {
role := &c.Roles[i]
if _, ok := rolesMap[role.Name]; ok {
logger.Fatal().Msgf("duplicate role '%s' found", role.Name)
if _, ok := c.roles[role.Name]; ok {
errlog.Fatal().Msgf("duplicate role '%s' found", role.Name)
}
role.Name = sanitize(role.Name)
role.Name = strings.ToLower(role.Name)
role.Match = sanitize(role.Match)
role.tablesMap = make(map[string]*configRoleTable)
@ -238,14 +233,16 @@ func (c *config) Init(vi *viper.Viper) error {
role.tablesMap[table.Name] = &role.Tables[n]
}
rolesMap[role.Name] = struct{}{}
c.roles[role.Name] = role
}
if _, ok := rolesMap["user"]; !ok {
c.Roles = append(c.Roles, configRole{Name: "user"})
if _, ok := c.roles["user"]; !ok {
u := configRole{Name: "user"}
c.Roles = append(c.Roles, u)
c.roles["user"] = &u
}
if _, ok := rolesMap["anon"]; !ok {
if _, ok := c.roles["anon"]; !ok {
logger.Warn().Msg("unauthenticated requests will be blocked. no role 'anon' defined")
c.AuthFailBlock = true
}
@ -262,7 +259,7 @@ func (c *config) validate() {
name := c.Roles[i].Name
if _, ok := rm[name]; ok {
logger.Fatal().Msgf("duplicate config for role '%s'", c.Roles[i].Name)
errlog.Fatal().Msgf("duplicate config for role '%s'", c.Roles[i].Name)
}
rm[name] = struct{}{}
}
@ -273,7 +270,7 @@ func (c *config) validate() {
name := c.Tables[i].Name
if _, ok := tm[name]; ok {
logger.Fatal().Msgf("duplicate config for table '%s'", c.Tables[i].Name)
errlog.Fatal().Msgf("duplicate config for table '%s'", c.Tables[i].Name)
}
tm[name] = struct{}{}
}

View File

@ -8,11 +8,9 @@ import (
"fmt"
"io"
"net/http"
"sync"
"time"
"github.com/cespare/xxhash/v2"
"github.com/dosco/super-graph/jsn"
"github.com/dosco/super-graph/qcode"
"github.com/jackc/pgx/v4"
"github.com/valyala/fasttemplate"
@ -32,6 +30,10 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
c.req.ref = req.Referer()
c.req.hdr = req.Header
if len(c.req.Vars) == 2 {
c.req.Vars = nil
}
if authCheck(c) {
c.req.role = "user"
} else {
@ -47,88 +49,55 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
}
func (c *coreContext) execQuery() ([]byte, error) {
var err error
var skipped uint32
var qc *qcode.QCode
var data []byte
var st *stmt
var err error
if conf.Production {
var ps *preparedItem
data, ps, err = c.resolvePreparedSQL()
data, st, err = c.resolvePreparedSQL()
if err != nil {
return nil, err
}
logger.Error().
Err(err).
Str("default_role", c.req.role).
Msg(c.req.Query)
skipped = ps.skipped
qc = ps.qc
return nil, errors.New("query failed. check logs for error")
}
} else {
data, skipped, err = c.resolveSQL()
if err != nil {
if data, st, err = c.resolveSQL(); err != nil {
return nil, err
}
}
if len(data) == 0 || skipped == 0 {
return data, nil
}
sel := qc.Selects
h := xxhash.New()
// fetch the field name used within the db response json
// that are used to mark insertion points and the mapping between
// those field names and their select objects
fids, sfmap := parentFieldIds(h, sel, skipped)
// fetch the field values of the marked insertion points
// these values contain the id to be used with fetching remote data
from := jsn.Get(data, fids)
var to []jsn.Field
switch {
case len(from) == 1:
to, err = c.resolveRemote(c.req.hdr, h, from[0], sel, sfmap)
case len(from) > 1:
to, err = c.resolveRemotes(c.req.hdr, h, from, sel, sfmap)
default:
return nil, errors.New("something wrong no remote ids found in db response")
}
if err != nil {
return nil, err
}
var ob bytes.Buffer
err = jsn.Replace(&ob, data, from, to)
if err != nil {
return nil, err
}
return ob.Bytes(), nil
return execRemoteJoin(st, data, c.req.hdr)
}
func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
tx, err := db.Begin(c)
if err != nil {
return nil, nil, err
func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) {
var tx pgx.Tx
var err error
qt := qcode.GetQType(c.req.Query)
mutation := (qt == qcode.QTMutation)
anonQuery := (qt == qcode.QTQuery && c.req.role == "anon")
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
useTx := useRoleQuery || conf.DB.SetUserID
if useTx {
if tx, err = db.Begin(c); err != nil {
return nil, nil, err
}
defer tx.Rollback(c)
}
defer tx.Rollback(c)
if conf.DB.SetUserID {
if err := c.setLocalUserID(tx); err != nil {
if err := setLocalUserID(c, tx); err != nil {
return nil, nil, err
}
}
var role string
mutation := isMutation(c.req.Query)
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
if useRoleQuery {
if role, err = c.executeRoleQuery(tx); err != nil {
@ -138,7 +107,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
} else if v := c.Value(userRoleKey); v != nil {
role = v.(string)
} else if mutation {
} else {
role = c.req.role
}
@ -149,15 +118,30 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
}
var root []byte
vars := argList(c, ps.args)
var row pgx.Row
if mutation {
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
} else {
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&role, &root)
vars, err := argList(c, ps.args)
if err != nil {
return nil, nil, err
}
logger.Debug().Str("default_role", c.req.role).Str("role", role).Msg(c.req.Query)
if useTx {
row = tx.QueryRow(c, ps.sd.SQL, vars...)
} else {
row = db.QueryRow(c, ps.sd.SQL, vars...)
}
if mutation || anonQuery {
err = row.Scan(&root)
} else {
err = row.Scan(&role, &root)
}
if len(role) == 0 {
logger.Debug().Str("default_role", c.req.role).Msg(c.req.Query)
} else {
logger.Debug().Str("default_role", c.req.role).Str("role", role).Msg(c.req.Query)
}
if err != nil {
return nil, nil, err
@ -165,58 +149,61 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
c.req.role = role
if err := tx.Commit(c); err != nil {
return nil, nil, err
if useTx {
if err := tx.Commit(c); err != nil {
return nil, nil, err
}
}
return root, ps, nil
return root, ps.st, nil
}
func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
tx, err := db.Begin(c)
if err != nil {
return nil, 0, err
}
defer tx.Rollback(c)
func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
var tx pgx.Tx
var err error
qt := qcode.GetQType(c.req.Query)
mutation := (qt == qcode.QTMutation)
//anonQuery := (qt == qcode.QTQuery && c.req.role == "anon")
mutation := isMutation(c.req.Query)
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
useTx := useRoleQuery || conf.DB.SetUserID
if useTx {
if tx, err = db.Begin(c); err != nil {
return nil, nil, err
}
defer tx.Rollback(c)
}
if conf.DB.SetUserID {
if err := setLocalUserID(c, tx); err != nil {
return nil, nil, err
}
}
if useRoleQuery {
if c.req.role, err = c.executeRoleQuery(tx); err != nil {
return nil, 0, err
return nil, nil, err
}
} else if v := c.Value(userRoleKey); v != nil {
c.req.role = v.(string)
}
stmts, err := c.buildStmt()
stmts, err := buildStmt(qt, []byte(c.req.Query), c.req.Vars, c.req.role)
if err != nil {
return nil, 0, err
}
var st *stmt
if mutation {
st = findStmt(c.req.role, stmts)
} else {
st = &stmts[0]
return nil, nil, err
}
st := &stmts[0]
t := fasttemplate.New(st.sql, openVar, closeVar)
buf := &bytes.Buffer{}
_, err = t.ExecuteFunc(buf, argMap(c))
if err == errNoUserID {
logger.Warn().Msg("no user id found. query requires an authenicated request")
}
_, err = t.ExecuteFunc(buf, argMap(c, c.req.Vars))
if err != nil {
return nil, 0, err
return nil, nil, err
}
finalSQL := buf.String()
var stime time.Time
@ -225,202 +212,57 @@ func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
stime = time.Now()
}
if conf.DB.SetUserID {
if err := c.setLocalUserID(tx); err != nil {
return nil, 0, err
var root []byte
var role string
var row pgx.Row
defaultRole := c.req.role
if useTx {
row = tx.QueryRow(c, finalSQL)
} else {
row = db.QueryRow(c, finalSQL)
}
if len(stmts) == 1 {
err = row.Scan(&root)
} else {
err = row.Scan(&role, &root)
}
if len(role) == 0 {
logger.Debug().Str("default_role", defaultRole).Msg(c.req.Query)
} else {
logger.Debug().Str("default_role", defaultRole).Str("role", role).Msg(c.req.Query)
}
if err != nil {
return nil, nil, err
}
if useTx {
if err := tx.Commit(c); err != nil {
return nil, nil, err
}
}
var root []byte
var role string
log := logger.Debug()
if mutation {
err = tx.QueryRow(c, finalSQL).Scan(&root)
log = log.Str("role", role)
} else {
err = tx.QueryRow(c, finalSQL).Scan(&role, &root)
log = log.Str("default_role", c.req.role).Str("role", role)
c.req.role = role
// if conf.Production == false {
// _allowList.add(&c.req)
// }
if len(stmts) > 1 {
if st = findStmt(role, stmts); st == nil {
return nil, nil, fmt.Errorf("invalid role '%s' returned", role)
}
}
log.Msg(c.req.Query)
if err != nil {
return nil, 0, err
}
if err := tx.Commit(c); err != nil {
return nil, 0, err
}
if mutation {
st = findStmt(c.req.role, stmts)
} else {
st = &stmts[0]
}
if conf.EnableTracing && len(st.qc.Selects) != 0 {
if conf.EnableTracing {
for _, id := range st.qc.Roots {
c.addTrace(st.qc.Selects, id, stime)
}
}
if conf.Production == false {
_allowList.add(&c.req)
}
return root, st.skipped, nil
}
func (c *coreContext) resolveRemote(
hdr http.Header,
h *xxhash.Digest,
field jsn.Field,
sel []qcode.Select,
sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) {
// replacement data for the marked insertion points
// key and value will be replaced by whats below
toA := [1]jsn.Field{}
to := toA[:1]
// use the json key to find the related Select object
k1 := xxhash.Sum64(field.Key)
s, ok := sfmap[k1]
if !ok {
return nil, nil
}
p := sel[s.ParentID]
// then use the Table nme in the Select and it's parent
// to find the resolver to use for this relationship
k2 := mkkey(h, s.Table, p.Table)
r, ok := rmap[k2]
if !ok {
return nil, nil
}
id := jsn.Value(field.Value)
if len(id) == 0 {
return nil, nil
}
st := time.Now()
b, err := r.Fn(hdr, id)
if err != nil {
return nil, err
}
if conf.EnableTracing {
c.addTrace(sel, s.ID, st)
}
if len(r.Path) != 0 {
b = jsn.Strip(b, r.Path)
}
var ob bytes.Buffer
if len(s.Cols) != 0 {
err = jsn.Filter(&ob, b, colsToList(s.Cols))
if err != nil {
return nil, err
}
} else {
ob.WriteString("null")
}
to[0] = jsn.Field{Key: []byte(s.FieldName), Value: ob.Bytes()}
return to, nil
}
func (c *coreContext) resolveRemotes(
hdr http.Header,
h *xxhash.Digest,
from []jsn.Field,
sel []qcode.Select,
sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) {
// replacement data for the marked insertion points
// key and value will be replaced by whats below
to := make([]jsn.Field, len(from))
var wg sync.WaitGroup
wg.Add(len(from))
var cerr error
for i, id := range from {
// use the json key to find the related Select object
k1 := xxhash.Sum64(id.Key)
s, ok := sfmap[k1]
if !ok {
return nil, nil
}
p := sel[s.ParentID]
// then use the Table nme in the Select and it's parent
// to find the resolver to use for this relationship
k2 := mkkey(h, s.Table, p.Table)
r, ok := rmap[k2]
if !ok {
return nil, nil
}
id := jsn.Value(id.Value)
if len(id) == 0 {
return nil, nil
}
go func(n int, id []byte, s *qcode.Select) {
defer wg.Done()
st := time.Now()
b, err := r.Fn(hdr, id)
if err != nil {
cerr = fmt.Errorf("%s: %s", s.Table, err)
return
}
if conf.EnableTracing {
c.addTrace(sel, s.ID, st)
}
if len(r.Path) != 0 {
b = jsn.Strip(b, r.Path)
}
var ob bytes.Buffer
if len(s.Cols) != 0 {
err = jsn.Filter(&ob, b, colsToList(s.Cols))
if err != nil {
cerr = fmt.Errorf("%s: %s", s.Table, err)
return
}
} else {
ob.WriteString("null")
}
to[n] = jsn.Field{Key: []byte(s.FieldName), Value: ob.Bytes()}
}(i, id, s)
}
wg.Wait()
return to, cerr
return root, st, nil
}
func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
@ -434,15 +276,6 @@ func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
return role, nil
}
func (c *coreContext) setLocalUserID(tx pgx.Tx) error {
var err error
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
}
return err
}
func (c *coreContext) render(w io.Writer, data []byte) error {
c.res.Data = json.RawMessage(data)
return json.NewEncoder(w).Encode(c.res)
@ -534,6 +367,15 @@ func parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) (
return fm, sm
}
func setLocalUserID(c context.Context, tx pgx.Tx) error {
var err error
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
}
return err
}
func isSkipped(n uint32, pos uint32) bool {
return ((n & (1 << pos)) != 0)
}

View File

@ -4,6 +4,7 @@ import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
"github.com/dosco/super-graph/psql"
@ -17,136 +18,171 @@ type stmt struct {
sql string
}
func (c *coreContext) buildStmt() ([]stmt, error) {
var vars map[string]json.RawMessage
func buildStmt(qt qcode.QType, gql, vars []byte, role string) ([]stmt, error) {
switch qt {
case qcode.QTMutation:
return buildRoleStmt(gql, vars, role)
if len(c.req.Vars) != 0 {
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
case qcode.QTQuery:
switch {
case role == "anon":
return buildRoleStmt(gql, vars, role)
default:
return buildMultiStmt(gql, vars)
}
default:
return nil, fmt.Errorf("unknown query type '%d'", qt)
}
}
func buildRoleStmt(gql, vars []byte, role string) ([]stmt, error) {
ro, ok := conf.roles[role]
if !ok {
return nil, fmt.Errorf(`roles '%s' not defined in config`, role)
}
var vm map[string]json.RawMessage
var err error
if len(vars) != 0 {
if err := json.Unmarshal(vars, &vm); err != nil {
return nil, err
}
}
gql := []byte(c.req.Query)
if len(conf.Roles) == 0 {
return nil, errors.New(`no roles found ('user' and 'anon' required)`)
}
qc, err := qcompile.Compile(gql, conf.Roles[0].Name)
qc, err := qcompile.Compile(gql, ro.Name)
if err != nil {
return nil, err
}
stmts := make([]stmt, 0, len(conf.Roles))
mutation := (qc.Type != qcode.QTQuery)
// For the 'anon' role in production only compile
// queries for tables defined in the config file.
if conf.Production &&
ro.Name == "anon" &&
hasTablesWithConfig(qc, ro) == false {
return nil, errors.New("query contains tables with no 'anon' role config")
}
stmts := []stmt{stmt{role: ro, qc: qc}}
w := &bytes.Buffer{}
for i := 1; i < len(conf.Roles); i++ {
skipped, err := pcompile.Compile(qc, w, psql.Variables(vm))
if err != nil {
return nil, err
}
stmts[0].skipped = skipped
stmts[0].sql = w.String()
return stmts, nil
}
func buildMultiStmt(gql, vars []byte) ([]stmt, error) {
var vm map[string]json.RawMessage
var err error
if len(vars) != 0 {
if err := json.Unmarshal(vars, &vm); err != nil {
return nil, err
}
}
if len(conf.RolesQuery) == 0 {
return buildRoleStmt(gql, vars, "user")
}
stmts := make([]stmt, 0, len(conf.Roles))
w := &bytes.Buffer{}
for i := 0; i < len(conf.Roles); i++ {
role := &conf.Roles[i]
// For mutations only render sql for a single role from the request
if mutation && len(c.req.role) != 0 && role.Name != c.req.role {
continue
}
qc, err = qcompile.Compile(gql, role.Name)
qc, err := qcompile.Compile(gql, role.Name)
if err != nil {
return nil, err
}
if conf.Production && role.Name == "anon" {
for _, id := range qc.Roots {
root := qc.Selects[id]
if _, ok := role.tablesMap[root.Table]; !ok {
continue
}
}
}
stmts = append(stmts, stmt{role: role, qc: qc})
if mutation {
skipped, err := pcompile.Compile(qc, w, psql.Variables(vars))
if err != nil {
return nil, err
}
s := &stmts[len(stmts)-1]
s.skipped = skipped
s.sql = w.String()
w.Reset()
skipped, err := pcompile.Compile(qc, w, psql.Variables(vm))
if err != nil {
return nil, err
}
s := &stmts[len(stmts)-1]
s.skipped = skipped
s.sql = w.String()
w.Reset()
}
if mutation {
return stmts, nil
sql, err := renderUserQuery(stmts, vm)
if err != nil {
return nil, err
}
stmts[0].sql = sql
return stmts, nil
}
func renderUserQuery(
stmts []stmt, vars map[string]json.RawMessage) (string, error) {
var err error
w := &bytes.Buffer{}
io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
for _, s := range stmts {
if len(s.role.Match) == 0 &&
s.role.Name != "user" && s.role.Name != "anon" {
continue
}
io.WriteString(w, `WHEN '`)
io.WriteString(w, s.role.Name)
io.WriteString(w, `' THEN (`)
s.skipped, err = pcompile.Compile(s.qc, w, psql.Variables(vars))
if err != nil {
return nil, err
return "", err
}
io.WriteString(w, `) `)
}
io.WriteString(w, `END) FROM (`)
if len(conf.RolesQuery) == 0 {
v := c.Value(userRoleKey)
io.WriteString(w, `END) FROM (SELECT (CASE WHEN EXISTS (`)
io.WriteString(w, conf.RolesQuery)
io.WriteString(w, `) THEN `)
io.WriteString(w, `VALUES ("`)
if v != nil {
io.WriteString(w, v.(string))
} else {
io.WriteString(w, c.req.role)
io.WriteString(w, `(SELECT (CASE`)
for _, s := range stmts {
if len(s.role.Match) == 0 {
continue
}
io.WriteString(w, `")) AS "_sg_auth_info"(role) LIMIT 1;`)
} else {
io.WriteString(w, `SELECT (CASE WHEN EXISTS (`)
io.WriteString(w, conf.RolesQuery)
io.WriteString(w, `) THEN `)
io.WriteString(w, `(SELECT (CASE`)
for _, s := range stmts {
if len(s.role.Match) == 0 {
continue
}
io.WriteString(w, ` WHEN `)
io.WriteString(w, s.role.Match)
io.WriteString(w, ` THEN '`)
io.WriteString(w, s.role.Name)
io.WriteString(w, `'`)
}
if len(c.req.role) == 0 {
io.WriteString(w, ` ELSE 'anon' END) FROM (`)
} else {
io.WriteString(w, ` ELSE '`)
io.WriteString(w, c.req.role)
io.WriteString(w, `' END) FROM (`)
}
io.WriteString(w, conf.RolesQuery)
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) ELSE '`)
if len(c.req.role) == 0 {
io.WriteString(w, `anon`)
} else {
io.WriteString(w, c.req.role)
}
io.WriteString(w, `' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `)
io.WriteString(w, ` WHEN `)
io.WriteString(w, s.role.Match)
io.WriteString(w, ` THEN '`)
io.WriteString(w, s.role.Name)
io.WriteString(w, `'`)
}
stmts[0].sql = w.String()
stmts[0].role = nil
io.WriteString(w, ` ELSE 'user' END) FROM (`)
io.WriteString(w, conf.RolesQuery)
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `)
io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `)
return stmts, nil
return w.String(), nil
}
func hasTablesWithConfig(qc *qcode.QCode, role *configRole) bool {
for _, id := range qc.Roots {
t, err := schema.GetTable(qc.Selects[id].Table)
if err != nil {
return false
}
if _, ok := role.tablesMap[t.Name]; !ok {
return false
}
}
return true
}

197
serv/core_remote.go Normal file
View File

@ -0,0 +1,197 @@
package serv
import (
"bytes"
"errors"
"fmt"
"net/http"
"sync"
"github.com/cespare/xxhash/v2"
"github.com/dosco/super-graph/jsn"
"github.com/dosco/super-graph/qcode"
)
func execRemoteJoin(st *stmt, data []byte, hdr http.Header) ([]byte, error) {
var err error
if len(data) == 0 || st.skipped == 0 {
return data, nil
}
sel := st.qc.Selects
h := xxhash.New()
// fetch the field name used within the db response json
// that are used to mark insertion points and the mapping between
// those field names and their select objects
fids, sfmap := parentFieldIds(h, sel, st.skipped)
// fetch the field values of the marked insertion points
// these values contain the id to be used with fetching remote data
from := jsn.Get(data, fids)
var to []jsn.Field
switch {
case len(from) == 1:
to, err = resolveRemote(hdr, h, from[0], sel, sfmap)
case len(from) > 1:
to, err = resolveRemotes(hdr, h, from, sel, sfmap)
default:
return nil, errors.New("something wrong no remote ids found in db response")
}
if err != nil {
return nil, err
}
var ob bytes.Buffer
err = jsn.Replace(&ob, data, from, to)
if err != nil {
return nil, err
}
return ob.Bytes(), nil
}
func resolveRemote(
hdr http.Header,
h *xxhash.Digest,
field jsn.Field,
sel []qcode.Select,
sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) {
// replacement data for the marked insertion points
// key and value will be replaced by whats below
toA := [1]jsn.Field{}
to := toA[:1]
// use the json key to find the related Select object
k1 := xxhash.Sum64(field.Key)
s, ok := sfmap[k1]
if !ok {
return nil, nil
}
p := sel[s.ParentID]
// then use the Table nme in the Select and it's parent
// to find the resolver to use for this relationship
k2 := mkkey(h, s.Table, p.Table)
r, ok := rmap[k2]
if !ok {
return nil, nil
}
id := jsn.Value(field.Value)
if len(id) == 0 {
return nil, nil
}
//st := time.Now()
b, err := r.Fn(hdr, id)
if err != nil {
return nil, err
}
if len(r.Path) != 0 {
b = jsn.Strip(b, r.Path)
}
var ob bytes.Buffer
if len(s.Cols) != 0 {
err = jsn.Filter(&ob, b, colsToList(s.Cols))
if err != nil {
return nil, err
}
} else {
ob.WriteString("null")
}
to[0] = jsn.Field{Key: []byte(s.FieldName), Value: ob.Bytes()}
return to, nil
}
func resolveRemotes(
hdr http.Header,
h *xxhash.Digest,
from []jsn.Field,
sel []qcode.Select,
sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) {
// replacement data for the marked insertion points
// key and value will be replaced by whats below
to := make([]jsn.Field, len(from))
var wg sync.WaitGroup
wg.Add(len(from))
var cerr error
for i, id := range from {
// use the json key to find the related Select object
k1 := xxhash.Sum64(id.Key)
s, ok := sfmap[k1]
if !ok {
return nil, nil
}
p := sel[s.ParentID]
// then use the Table nme in the Select and it's parent
// to find the resolver to use for this relationship
k2 := mkkey(h, s.Table, p.Table)
r, ok := rmap[k2]
if !ok {
return nil, nil
}
id := jsn.Value(id.Value)
if len(id) == 0 {
return nil, nil
}
go func(n int, id []byte, s *qcode.Select) {
defer wg.Done()
//st := time.Now()
b, err := r.Fn(hdr, id)
if err != nil {
cerr = fmt.Errorf("%s: %s", s.Table, err)
return
}
if len(r.Path) != 0 {
b = jsn.Strip(b, r.Path)
}
var ob bytes.Buffer
if len(s.Cols) != 0 {
err = jsn.Filter(&ob, b, colsToList(s.Cols))
if err != nil {
cerr = fmt.Errorf("%s: %s", s.Table, err)
return
}
} else {
ob.WriteString("null")
}
to[n] = jsn.Field{Key: []byte(s.FieldName), Value: ob.Bytes()}
}(i, id, s)
}
wg.Wait()
return to, cerr
}

View File

@ -4,7 +4,6 @@ package serv
func Fuzz(data []byte) int {
gql := string(data)
isMutation(gql)
gqlHash(gql, nil, "")
return 1

View File

@ -10,7 +10,6 @@ func TestFuzzCrashers(t *testing.T) {
}
for _, f := range crashers {
isMutation(f)
gqlHash(f, nil, "")
}
}

View File

@ -21,7 +21,6 @@ const (
var (
upgrader = websocket.Upgrader{}
errNoUserID = errors.New("no user_id available")
errUnauthorized = errors.New("not authorized")
)
@ -78,7 +77,7 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
b, err := ioutil.ReadAll(io.LimitReader(r.Body, maxReadBytes))
if err != nil {
logger.Err(err).Msg("failed to read request body")
errlog.Error().Err(err).Msg("failed to read request body")
errorResp(w, err)
return
}
@ -86,7 +85,7 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
err = json.Unmarshal(b, &ctx.req)
if err != nil {
logger.Err(err).Msg("failed to decode json request body")
errlog.Error().Err(err).Msg("failed to decode json request body")
errorResp(w, err)
return
}
@ -105,7 +104,7 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
}
if err != nil {
logger.Err(err).Msg("failed to handle request")
errlog.Error().Err(err).Msg("failed to handle request")
errorResp(w, err)
return
}

View File

@ -3,20 +3,19 @@ package serv
import (
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"github.com/dosco/super-graph/qcode"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v4"
"github.com/valyala/fasttemplate"
)
type preparedItem struct {
stmt *pgconn.StatementDescription
args [][]byte
skipped uint32
qc *qcode.QCode
sd *pgconn.StatementDescription
args [][]byte
st *stmt
}
var (
@ -24,83 +23,119 @@ var (
)
func initPreparedList() {
c := context.Background()
_preparedList = make(map[string]*preparedItem)
if err := prepareRoleStmt(); err != nil {
logger.Fatal().Err(err).Msg("failed to prepare get role statement")
tx, err := db.Begin(c)
if err != nil {
errlog.Fatal().Err(err).Send()
}
defer tx.Rollback(c)
err = prepareRoleStmt(c, tx)
if err != nil {
errlog.Fatal().Err(err).Msg("failed to prepare get role statement")
}
if err := tx.Commit(c); err != nil {
errlog.Fatal().Err(err).Send()
}
success := 0
for _, v := range _allowList.list {
err := prepareStmt(v.gql, v.vars)
if err != nil {
logger.Warn().Str("gql", v.gql).Err(err).Send()
}
}
}
func prepareStmt(gql string, varBytes json.RawMessage) error {
if len(gql) == 0 {
return nil
}
c := &coreContext{Context: context.Background()}
c.req.Query = gql
c.req.Vars = varBytes
stmts, err := c.buildStmt()
if err != nil {
return err
}
if len(stmts) != 0 && stmts[0].qc.Type == qcode.QTQuery {
c.req.Vars = nil
}
for _, s := range stmts {
if len(s.sql) == 0 {
if len(v.gql) == 0 {
continue
}
finalSQL, am := processTemplate(s.sql)
ctx := context.Background()
tx, err := db.Begin(ctx)
if err != nil {
return err
}
defer tx.Rollback(ctx)
pstmt, err := tx.Prepare(ctx, "", finalSQL)
if err != nil {
return err
err := prepareStmt(c, v.gql, v.vars)
if err == nil {
success++
continue
}
var key string
if s.role == nil {
key = gqlHash(gql, c.req.Vars, "")
if len(v.vars) == 0 {
logger.Warn().Err(err).Msg(v.gql)
} else {
key = gqlHash(gql, c.req.Vars, s.role.Name)
logger.Warn().Err(err).Msgf("%s %s", v.vars, v.gql)
}
}
_preparedList[key] = &preparedItem{
stmt: pstmt,
args: am,
skipped: s.skipped,
qc: s.qc,
}
logger.Info().
Msgf("Registered %d of %d queries from allow.list as prepared statements",
success, len(_allowList.list))
}
if err := tx.Commit(ctx); err != nil {
func prepareStmt(c context.Context, gql string, vars []byte) error {
qt := qcode.GetQType(gql)
q := []byte(gql)
tx, err := db.Begin(c)
if err != nil {
return err
}
defer tx.Rollback(c)
switch qt {
case qcode.QTQuery:
stmts1, err := buildMultiStmt(q, vars)
if err != nil {
return err
}
err = prepare(c, tx, &stmts1[0], gqlHash(gql, vars, "user"))
if err != nil {
return err
}
stmts2, err := buildRoleStmt(q, vars, "anon")
if err != nil {
return err
}
err = prepare(c, tx, &stmts2[0], gqlHash(gql, vars, "anon"))
if err != nil {
return err
}
case qcode.QTMutation:
for _, role := range conf.Roles {
stmts, err := buildRoleStmt(q, vars, role.Name)
if err != nil {
return err
}
err = prepare(c, tx, &stmts[0], gqlHash(gql, vars, role.Name))
if err != nil {
return err
}
}
}
if err := tx.Commit(c); err != nil {
return err
}
return nil
}
func prepareRoleStmt() error {
func prepare(c context.Context, tx pgx.Tx, st *stmt, key string) error {
finalSQL, am := processTemplate(st.sql)
sd, err := tx.Prepare(c, "", finalSQL)
if err != nil {
return err
}
_preparedList[key] = &preparedItem{
sd: sd,
args: am,
st: st,
}
return nil
}
func prepareRoleStmt(c context.Context, tx pgx.Tx) error {
if len(conf.RolesQuery) == 0 {
return nil
}
@ -125,15 +160,7 @@ func prepareRoleStmt() error {
roleSQL, _ := processTemplate(w.String())
ctx := context.Background()
tx, err := db.Begin(ctx)
if err != nil {
return err
}
defer tx.Rollback(ctx)
_, err = tx.Prepare(ctx, "_sg_get_role", roleSQL)
_, err := tx.Prepare(c, "_sg_get_role", roleSQL)
if err != nil {
return err
}
@ -142,19 +169,31 @@ func prepareRoleStmt() error {
}
func processTemplate(tmpl string) (string, [][]byte) {
t := fasttemplate.New(tmpl, `{{`, `}}`)
am := make([][]byte, 0, 5)
i := 0
st := struct {
vmap map[string]int
am [][]byte
i int
}{
vmap: make(map[string]int),
am: make([][]byte, 0, 5),
i: 0,
}
vmap := make(map[string]int)
return t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) {
if n, ok := vmap[tag]; ok {
execFunc := func(w io.Writer, tag string) (int, error) {
if n, ok := st.vmap[tag]; ok {
return w.Write([]byte(fmt.Sprintf("$%d", n)))
}
am = append(am, []byte(tag))
i++
vmap[tag] = i
return w.Write([]byte(fmt.Sprintf("$%d", i)))
}), am
st.am = append(st.am, []byte(tag))
st.i++
st.vmap[tag] = st.i
return w.Write([]byte(fmt.Sprintf("$%d", st.i)))
}
t1 := fasttemplate.New(tmpl, `'{{`, `}}'`)
ts1 := t1.ExecuteFuncString(execFunc)
t2 := fasttemplate.New(ts1, `{{`, `}}`)
ts2 := t2.ExecuteFuncString(execFunc)
return ts2, st.am
}

View File

@ -168,7 +168,7 @@ func Do(log func(string, ...interface{}), additional ...dir) error {
func ReExec() {
err := syscall.Exec(binSelf, append([]string{binSelf}, os.Args[1:]...), os.Environ())
if err != nil {
logger.Fatal().Err(err).Msg("cannot restart")
errlog.Fatal().Err(err).Msg("cannot restart")
}
}

View File

@ -117,7 +117,7 @@ func buildFn(r configRemote) func(http.Header, []byte) ([]byte, error) {
res, err := client.Do(req)
if err != nil {
logger.Error().Err(err).Msgf("Failed to connect to: %s", uri)
errlog.Error().Err(err).Msgf("Failed to connect to: %s", uri)
return nil, err
}
defer res.Body.Close()

View File

@ -15,13 +15,15 @@ import (
)
func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
schema, err := psql.NewDBSchema(db, c.getAliasMap())
var err error
schema, err = psql.NewDBSchema(db, c.getAliasMap())
if err != nil {
return nil, nil, err
}
conf := qcode.Config{
Blocklist: c.DB.Defaults.Blocklist,
Blocklist: c.DB.Blocklist,
KeepArgs: false,
}
@ -106,7 +108,7 @@ func initWatcher(cpath string) {
go func() {
err := Do(logger.Printf, d)
if err != nil {
logger.Fatal().Err(err).Send()
errlog.Fatal().Err(err).Send()
}
}()
}
@ -139,7 +141,7 @@ func startHTTP() {
<-sigint
if err := srv.Shutdown(context.Background()); err != nil {
logger.Error().Err(err).Msg("shutdown signal received")
errlog.Error().Err(err).Msg("shutdown signal received")
}
close(idleConnsClosed)
}()
@ -148,18 +150,14 @@ func startHTTP() {
db.Close()
})
var ident string
if len(conf.AppName) == 0 {
ident = conf.Env
} else {
ident = conf.AppName
}
fmt.Printf("%s listening on %s (%s)\n", serverName, hostPort, ident)
logger.Info().
Str("host_post", hostPort).
Str("app_name", conf.AppName).
Str("env", conf.Env).
Msgf("%s listening", serverName)
if err := srv.ListenAndServe(); err != http.ErrServerClosed {
logger.Error().Err(err).Msg("server closed")
errlog.Error().Err(err).Msg("server closed")
}
<-idleConnsClosed

View File

@ -28,9 +28,7 @@ func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data
zlevel = zerolog.ErrorLevel
case pgx.LogLevelWarn:
zlevel = zerolog.WarnLevel
case pgx.LogLevelInfo:
zlevel = zerolog.InfoLevel
case pgx.LogLevelDebug:
case pgx.LogLevelDebug, pgx.LogLevelInfo:
zlevel = zerolog.DebugLevel
default:
zlevel = zerolog.DebugLevel

View File

@ -106,19 +106,6 @@ func al(b byte) bool {
return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9')
}
func isMutation(sql string) bool {
for i := range sql {
b := sql[i]
if b == '{' {
return false
}
if al(b) {
return (b == 'm' || b == 'M')
}
}
return false
}
func findStmt(role string, stmts []stmt) *stmt {
for i := range stmts {
if stmts[i].role.Name != role {

View File

@ -77,7 +77,7 @@ SQL Output
database:
variables:
account_id: "select account_id from users where id = $user_id"
admin_account_id: "5"
defaults:
Filters: ["{ user_id: { eq: $user_id } }"]

View File

@ -3,7 +3,7 @@ host_port: 0.0.0.0:8080
web_ui: true
# debug, info, warn, error, fatal, panic
log_level: "debug"
log_level: "info"
# When production mode is 'true' only queries
# from the allow list are permitted.
@ -97,23 +97,18 @@ database:
# Enable this if you need the user id in triggers, etc
set_user_id: false
# Define variables here that you want to use in filters
# sub-queries must be wrapped in ()
# Define additional variables here to be used with filters
variables:
account_id: "(select account_id from users where id = $user_id)"
admin_account_id: "5"
# Define defaults to for the field key and values below
defaults:
# filters: ["{ user_id: { eq: $user_id } }"]
# Field and table names that you wish to block
blocklist:
- ar_internal_metadata
- schema_migrations
- secret
- password
- encrypted
- token
# Field and table names that you wish to block
blocklist:
- ar_internal_metadata
- schema_migrations
- secret
- password
- encrypted
- token
tables:
- name: customers
@ -185,11 +180,10 @@ roles:
- updated_at: "now"
delete:
deny: true
block: true
- name: admin
match: id = 1
match: id = 1000
tables:
- name: users
# query:
# filters: ["{ account_id: { _eq: $account_id } }"]
filters: []

View File

@ -7,7 +7,7 @@ host_port: 0.0.0.0:8080
web_ui: false
# debug, info, warn, error, fatal, panic, disable
log_level: "info"
log_level: "warn"
# When production mode is 'true' only queries
# from the allow list are permitted.
# When it's 'false' all queries are saved to the