Compare commits
13 Commits
Author | SHA1 | Date | |
---|---|---|---|
aff2a13ba4 | |||
f518d5fc69 | |||
7583326d21 | |||
8024a8420b | |||
6bd27d7c79 | |||
a4c09dedd5 | |||
176514b5f1 | |||
396f4bcfd8 | |||
51c5f037f4 | |||
c576f1ae16 | |||
0a6995eaca | |||
06ed823a51 | |||
3438c3efb9 |
@ -73,43 +73,6 @@ mutation {
|
||||
}
|
||||
}
|
||||
|
||||
variables {
|
||||
"data": [
|
||||
{
|
||||
"name": "Gumbo1",
|
||||
"created_at": "now",
|
||||
"updated_at": "now"
|
||||
},
|
||||
{
|
||||
"name": "Gumbo2",
|
||||
"created_at": "now",
|
||||
"updated_at": "now"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
query {
|
||||
products {
|
||||
id
|
||||
name
|
||||
}
|
||||
}
|
||||
|
||||
variables {
|
||||
"data": [
|
||||
{
|
||||
"name": "Gumbo1",
|
||||
"created_at": "now",
|
||||
"updated_at": "now"
|
||||
},
|
||||
{
|
||||
"name": "Gumbo2",
|
||||
"created_at": "now",
|
||||
"updated_at": "now"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
query {
|
||||
products {
|
||||
id
|
||||
@ -133,21 +96,6 @@ mutation {
|
||||
}
|
||||
}
|
||||
|
||||
variables {
|
||||
"data": [
|
||||
{
|
||||
"name": "Gumbo1",
|
||||
"created_at": "now",
|
||||
"updated_at": "now"
|
||||
},
|
||||
{
|
||||
"name": "Gumbo2",
|
||||
"created_at": "now",
|
||||
"updated_at": "now"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
query {
|
||||
products {
|
||||
id
|
||||
@ -174,39 +122,6 @@ mutation {
|
||||
}
|
||||
}
|
||||
|
||||
variables {
|
||||
"data": [
|
||||
{
|
||||
"name": "Gumbo1",
|
||||
"created_at": "now",
|
||||
"updated_at": "now"
|
||||
},
|
||||
{
|
||||
"name": "Gumbo2",
|
||||
"created_at": "now",
|
||||
"updated_at": "now"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
query {
|
||||
products {
|
||||
id
|
||||
name
|
||||
users {
|
||||
email
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
query {
|
||||
me {
|
||||
id
|
||||
email
|
||||
full_name
|
||||
}
|
||||
}
|
||||
|
||||
variables {
|
||||
"update": {
|
||||
"name": "Helloo",
|
||||
@ -224,66 +139,23 @@ mutation {
|
||||
}
|
||||
|
||||
variables {
|
||||
"data": [
|
||||
{
|
||||
"name": "Gumbo1",
|
||||
"created_at": "now",
|
||||
"updated_at": "now"
|
||||
},
|
||||
{
|
||||
"name": "Gumbo2",
|
||||
"created_at": "now",
|
||||
"updated_at": "now"
|
||||
}
|
||||
]
|
||||
"data": {
|
||||
"name": "WOOO",
|
||||
"price": 50.5
|
||||
}
|
||||
}
|
||||
|
||||
query {
|
||||
product {
|
||||
mutation {
|
||||
products(insert: $data) {
|
||||
id
|
||||
name
|
||||
}
|
||||
}
|
||||
|
||||
variables {
|
||||
"data": [
|
||||
{
|
||||
"name": "Gumbo1",
|
||||
"created_at": "now",
|
||||
"updated_at": "now"
|
||||
},
|
||||
{
|
||||
"name": "Gumbo2",
|
||||
"created_at": "now",
|
||||
"updated_at": "now"
|
||||
}
|
||||
]
|
||||
}
|
||||
|
||||
query {
|
||||
query getProducts {
|
||||
products {
|
||||
id
|
||||
name
|
||||
description
|
||||
users {
|
||||
email
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
query {
|
||||
users {
|
||||
id
|
||||
email
|
||||
picture: avatar
|
||||
password
|
||||
full_name
|
||||
products(limit: 2, where: {price: {gt: 10}}) {
|
||||
id
|
||||
name
|
||||
description
|
||||
price
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -9,7 +9,7 @@ log_level: "debug"
|
||||
# from the allow list are permitted.
|
||||
# When it's 'false' all queries are saved to the
|
||||
# the allow list in ./config/allow.list
|
||||
production: true
|
||||
production: false
|
||||
|
||||
# Throw a 401 on auth failure for queries that need auth
|
||||
auth_fail_block: false
|
||||
@ -97,23 +97,18 @@ database:
|
||||
# Enable this if you need the user id in triggers, etc
|
||||
set_user_id: false
|
||||
|
||||
# Define variables here that you want to use in filters
|
||||
# sub-queries must be wrapped in ()
|
||||
# Define additional variables here to be used with filters
|
||||
variables:
|
||||
account_id: "(select account_id from users where id = $user_id)"
|
||||
admin_account_id: "5"
|
||||
|
||||
# Define defaults to for the field key and values below
|
||||
defaults:
|
||||
# filters: ["{ user_id: { eq: $user_id } }"]
|
||||
|
||||
# Field and table names that you wish to block
|
||||
blocklist:
|
||||
- ar_internal_metadata
|
||||
- schema_migrations
|
||||
- secret
|
||||
- password
|
||||
- encrypted
|
||||
- token
|
||||
# Field and table names that you wish to block
|
||||
blocklist:
|
||||
- ar_internal_metadata
|
||||
- schema_migrations
|
||||
- secret
|
||||
- password
|
||||
- encrypted
|
||||
- token
|
||||
|
||||
tables:
|
||||
- name: customers
|
||||
@ -141,6 +136,7 @@ roles_query: "SELECT * FROM users WHERE id = $user_id"
|
||||
roles:
|
||||
- name: anon
|
||||
tables:
|
||||
- name: users
|
||||
- name: products
|
||||
limit: 10
|
||||
|
||||
@ -174,8 +170,10 @@ roles:
|
||||
filters: ["{ user_id: { eq: $user_id } }"]
|
||||
columns: ["id", "name", "description" ]
|
||||
presets:
|
||||
- user_id: "$user_id"
|
||||
- created_at: "now"
|
||||
|
||||
- updated_at: "now"
|
||||
|
||||
update:
|
||||
filters: ["{ user_id: { eq: $user_id } }"]
|
||||
columns:
|
||||
@ -188,8 +186,7 @@ roles:
|
||||
block: true
|
||||
|
||||
- name: admin
|
||||
match: id = 1
|
||||
match: id = 1000
|
||||
tables:
|
||||
- name: users
|
||||
# query:
|
||||
# filters: ["{ account_id: { _eq: $account_id } }"]
|
||||
filters: []
|
||||
|
@ -41,40 +41,6 @@ enable_tracing: true
|
||||
# SG_AUTH_RAILS_REDIS_PASSWORD
|
||||
# SG_AUTH_JWT_PUBLIC_KEY_FILE
|
||||
|
||||
# inflections:
|
||||
# person: people
|
||||
# sheep: sheep
|
||||
|
||||
auth:
|
||||
# Can be 'rails' or 'jwt'
|
||||
type: rails
|
||||
cookie: _app_session
|
||||
|
||||
rails:
|
||||
# Rails version this is used for reading the
|
||||
# various cookies formats.
|
||||
version: 5.2
|
||||
|
||||
# Found in 'Rails.application.config.secret_key_base'
|
||||
secret_key_base: 0a248500a64c01184edb4d7ad3a805488f8097ac761b76aaa6c17c01dcb7af03a2f18ba61b2868134b9c7b79a122bc0dadff4367414a2d173297bfea92be5566
|
||||
|
||||
# Remote cookie store. (memcache or redis)
|
||||
# url: redis://127.0.0.1:6379
|
||||
# password: test
|
||||
# max_idle: 80,
|
||||
# max_active: 12000,
|
||||
|
||||
# In most cases you don't need these
|
||||
# salt: "encrypted cookie"
|
||||
# sign_salt: "signed encrypted cookie"
|
||||
# auth_salt: "authenticated encrypted cookie"
|
||||
|
||||
# jwt:
|
||||
# provider: auth0
|
||||
# secret: abc335bfcfdb04e50db5bb0a4d67ab9
|
||||
# public_key_file: /secrets/public_key.pem
|
||||
# public_key_type: ecdsa #rsa
|
||||
|
||||
database:
|
||||
type: postgres
|
||||
host: db
|
||||
|
@ -46,10 +46,10 @@ for (i = 0; i < product_count; i++) {
|
||||
var data = {
|
||||
name: fake.beer_name(),
|
||||
description: desc,
|
||||
price: fake.price(),
|
||||
user_id: user.id,
|
||||
created_at: "now",
|
||||
updated_at: "now"
|
||||
price: fake.price()
|
||||
//user_id: user.id,
|
||||
//created_at: "now",
|
||||
//updated_at: "now"
|
||||
}
|
||||
|
||||
var res = graphql(" \
|
||||
@ -57,7 +57,9 @@ for (i = 0; i < product_count; i++) {
|
||||
product(insert: $data) { \
|
||||
id \
|
||||
} \
|
||||
}", { data: data })
|
||||
}", { data: data }, {
|
||||
user_id: 5
|
||||
})
|
||||
products.push(res.product)
|
||||
}
|
||||
|
||||
|
@ -24,6 +24,12 @@
|
||||
:item="actionLink"
|
||||
/>
|
||||
|
||||
<a
|
||||
class="px-4 py-3 my-8 border-2 border-gray-500 text-gray-600 font-bold rounded"
|
||||
href="https://github.com/dosco/super-graph"
|
||||
target="_blank"
|
||||
>Github</a>
|
||||
|
||||
</div>
|
||||
</div>
|
||||
|
||||
|
@ -1,6 +1,11 @@
|
||||
let ogprefix = 'og: http://ogp.me/ns#'
|
||||
let title = 'Super Graph'
|
||||
let description = 'An instant GraphQL API for your app. No code needed.'
|
||||
let color = '#f42525'
|
||||
|
||||
module.exports = {
|
||||
title: 'Super Graph',
|
||||
description: 'Get an instant GraphQL API for your Rails apps.',
|
||||
title: title,
|
||||
description: description,
|
||||
|
||||
themeConfig: {
|
||||
logo: '/hologram.svg',
|
||||
@ -15,6 +20,22 @@ module.exports = {
|
||||
serviceWorker: {
|
||||
updatePopup: true
|
||||
},
|
||||
|
||||
head: [
|
||||
//['link', { rel: 'icon', href: `/assets/favicon.ico` }],
|
||||
['meta', { prefix: ogprefix, property: 'og:title', content: title }],
|
||||
['meta', { prefix: ogprefix, property: 'twitter:title', content: title }],
|
||||
['meta', { prefix: ogprefix, property: 'og:type', content: 'website' }],
|
||||
['meta', { prefix: ogprefix, property: 'og:url', content: 'https://supergraph.dev' }],
|
||||
['meta', { prefix: ogprefix, property: 'og:description', content: description }],
|
||||
//['meta', { prefix: ogprefix, property: 'og:image', content: 'https://wireupyourfrontend.com/assets/logo.png' }],
|
||||
// ['meta', { name: 'apple-mobile-web-app-capable', content: 'yes' }],
|
||||
// ['meta', { name: 'apple-mobile-web-app-status-bar-style', content: 'black' }],
|
||||
// ['link', { rel: 'apple-touch-icon', href: `/assets/apple-touch-icon.png` }],
|
||||
// ['link', { rel: 'mask-icon', href: '/assets/safari-pinned-tab.svg', color: color }],
|
||||
// ['meta', { name: 'msapplication-TileImage', content: '/assets/mstile-150x150.png' }],
|
||||
// ['meta', { name: 'msapplication-TileColor', content: color }],
|
||||
],
|
||||
},
|
||||
|
||||
postcss: {
|
||||
|
112
docs/guide.md
112
docs/guide.md
@ -276,7 +276,7 @@ transmission_gear_type
|
||||
// Text
|
||||
word
|
||||
sentence
|
||||
paragrph
|
||||
paragraph
|
||||
question
|
||||
quote
|
||||
|
||||
@ -476,6 +476,21 @@ query {
|
||||
}
|
||||
```
|
||||
|
||||
Multiple tables can also be fetched using a single GraphQL query. This is very fast since the entire query is converted into a single SQL query which the database can efficiently run.
|
||||
|
||||
```graphql
|
||||
query {
|
||||
user {
|
||||
full_name
|
||||
email
|
||||
}
|
||||
products {
|
||||
name
|
||||
description
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### Fetching data
|
||||
|
||||
To fetch a specific `product` by it's ID you can use the `id` argument. The real name id field will be resolved automatically so this query will work even if your id column is named something like `product_id`.
|
||||
@ -908,6 +923,40 @@ class AddSearchColumn < ActiveRecord::Migration[5.1]
|
||||
end
|
||||
```
|
||||
|
||||
## GraphQL with React
|
||||
|
||||
This is a quick simple example using `graphql.js` [https://github.com/f/graphql.js/](https://github.com/f/graphql.js/)
|
||||
|
||||
```js
|
||||
import React, { useState, useEffect } from 'react'
|
||||
import graphql from 'graphql.js'
|
||||
|
||||
// Create a GraphQL client pointing to Super Graph
|
||||
var graph = graphql("http://localhost:3000/api/v1/graphql", { asJSON: true })
|
||||
|
||||
const App = () => {
|
||||
const [user, setUser] = useState(null)
|
||||
|
||||
useEffect(() => {
|
||||
async function action() {
|
||||
// Use the GraphQL client to execute a graphQL query
|
||||
// The second argument to the client are the variables you need to pass
|
||||
const result = await graph(`{ user { id first_name last_name picture_url } }`)()
|
||||
setUser(result)
|
||||
}
|
||||
action()
|
||||
}, []);
|
||||
|
||||
return (
|
||||
<div className="App">
|
||||
<h1>{ JSON.stringify(user) }</h1>
|
||||
</div>
|
||||
);
|
||||
}
|
||||
```
|
||||
|
||||
export default App;
|
||||
|
||||
## Authentication
|
||||
|
||||
You can only have one type of auth enabled. You can either pick Rails or JWT.
|
||||
@ -1034,25 +1083,6 @@ must be run to help figure out a users role. This query can be as complex as you
|
||||
|
||||
The individual roles are defined under the `roles` parameter and this includes each table the role has a custom setting for. The role is dynamically matched using the `match` parameter for example in the above case `users.id = 1` means that when the `roles_query` is executed a user with the id `1` willbe assigned the admin role and those that don't match get the `user` role if authenticated successfully or the `anon` role.
|
||||
|
||||
This below example would work for SAAS apps where an account (tenant) is usually the top parent table to everything else.
|
||||
|
||||
```yaml
|
||||
roles_query: "SELECT * FROM users JOIN accounts on accounts.id = users.account_id WHERE users.id = $user_id"
|
||||
|
||||
roles:
|
||||
- name: user
|
||||
tables:
|
||||
- name: users
|
||||
...
|
||||
|
||||
- name: admin
|
||||
match: accounts.admin_id = $user_id
|
||||
tables:
|
||||
- name: users
|
||||
query:
|
||||
filters: [{ accounts: { id: { eq: $account_id } } }]
|
||||
```
|
||||
|
||||
## Remote Joins
|
||||
|
||||
It often happens that after fetching some data from the DB we need to call another API to fetch some more data and all this combined into a single JSON response. For example along with a list of users you need their last 5 payments from Stripe. This requires you to query your DB for the users and Stripe for the payments. Super Graph handles all this for you also only the fields you requested from the Stripe API are returned.
|
||||
@ -1149,11 +1179,11 @@ web_ui: true
|
||||
# debug, info, warn, error, fatal, panic
|
||||
log_level: "debug"
|
||||
|
||||
# Disable this in development to get a list of
|
||||
# queries used. When enabled super graph
|
||||
# will only allow queries from this list
|
||||
# List saved to ./config/allow.list
|
||||
use_allow_list: false
|
||||
# When production mode is 'true' only queries
|
||||
# from the allow list are permitted.
|
||||
# When it's 'false' all queries are saved to the
|
||||
# the allow list in ./config/allow.list
|
||||
production: false
|
||||
|
||||
# Throw a 401 on auth failure for queries that need auth
|
||||
auth_fail_block: false
|
||||
@ -1241,23 +1271,18 @@ database:
|
||||
# Enable this if you need the user id in triggers, etc
|
||||
set_user_id: false
|
||||
|
||||
# Define variables here that you want to use in filters
|
||||
# sub-queries must be wrapped in ()
|
||||
# Define additional variables here to be used with filters
|
||||
variables:
|
||||
account_id: "(select account_id from users where id = $user_id)"
|
||||
admin_account_id: "5"
|
||||
|
||||
# Define defaults to for the field key and values below
|
||||
defaults:
|
||||
# filters: ["{ user_id: { eq: $user_id } }"]
|
||||
|
||||
# Field and table names that you wish to block
|
||||
blocklist:
|
||||
- ar_internal_metadata
|
||||
- schema_migrations
|
||||
- secret
|
||||
- password
|
||||
- encrypted
|
||||
- token
|
||||
# Field and table names that you wish to block
|
||||
blocklist:
|
||||
- ar_internal_metadata
|
||||
- schema_migrations
|
||||
- secret
|
||||
- password
|
||||
- encrypted
|
||||
- token
|
||||
|
||||
tables:
|
||||
- name: customers
|
||||
@ -1329,14 +1354,13 @@ roles:
|
||||
- updated_at: "now"
|
||||
|
||||
delete:
|
||||
deny: true
|
||||
block: true
|
||||
|
||||
- name: admin
|
||||
match: id = 1
|
||||
match: id = 1000
|
||||
tables:
|
||||
- name: users
|
||||
# query:
|
||||
# filters: ["{ account_id: { _eq: $account_id } }"]
|
||||
filters: []
|
||||
|
||||
```
|
||||
|
||||
|
@ -374,10 +374,6 @@ func TestKeys2(t *testing.T) {
|
||||
"id", "posts", "title", "description", "full_name", "email", "books", "name", "description",
|
||||
}
|
||||
|
||||
// for i := range fields {
|
||||
// fmt.Println("-->", string(fields[i]))
|
||||
// }
|
||||
|
||||
if len(exp) != len(fields) {
|
||||
t.Errorf("Expected %d fields %d", len(exp), len(fields))
|
||||
}
|
||||
|
@ -16,7 +16,7 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var migrationPattern = regexp.MustCompile(`\A(\d+)_.+\.sql\z`)
|
||||
var migrationPattern = regexp.MustCompile(`\A(\d+)_[^\.]+\.sql\z`)
|
||||
|
||||
var ErrNoFwMigration = errors.Errorf("no sql in forward migration step")
|
||||
|
||||
@ -127,7 +127,7 @@ func FindMigrationsEx(path string, fs MigratorFS) ([]string, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
mcount := len(paths) + 100
|
||||
mcount := len(paths)
|
||||
|
||||
if n < int64(mcount) {
|
||||
return nil, fmt.Errorf("Duplicate migration %d", n)
|
||||
|
@ -77,9 +77,9 @@ func (c *compilerContext) renderInsert(qc *qcode.QCode, w io.Writer,
|
||||
return 0, err
|
||||
}
|
||||
|
||||
io.WriteString(c.w, `(WITH "input" AS (SELECT {{`)
|
||||
io.WriteString(c.w, `(WITH "input" AS (SELECT '{{`)
|
||||
io.WriteString(c.w, qc.ActionVar)
|
||||
io.WriteString(c.w, `}}::json AS j) INSERT INTO `)
|
||||
io.WriteString(c.w, `}}' :: json AS j) INSERT INTO `)
|
||||
quoted(c.w, ti.Name)
|
||||
io.WriteString(c.w, ` (`)
|
||||
c.renderInsertUpdateColumns(qc, w, jt, ti, false)
|
||||
@ -174,9 +174,9 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w io.Writer,
|
||||
return 0, err
|
||||
}
|
||||
|
||||
io.WriteString(c.w, `(WITH "input" AS (SELECT {{`)
|
||||
io.WriteString(c.w, `(WITH "input" AS (SELECT '{{`)
|
||||
io.WriteString(c.w, qc.ActionVar)
|
||||
io.WriteString(c.w, `}}::json AS j) UPDATE `)
|
||||
io.WriteString(c.w, `}}' :: json AS j) UPDATE `)
|
||||
quoted(c.w, ti.Name)
|
||||
io.WriteString(c.w, ` SET (`)
|
||||
c.renderInsertUpdateColumns(qc, w, jt, ti, false)
|
||||
|
@ -12,7 +12,7 @@ func simpleInsert(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `WITH "users" AS (WITH "input" AS (SELECT {{data}}::json AS j) INSERT INTO "users" ("full_name", "email") SELECT "full_name", "email" FROM input i, json_populate_record(NULL::users, i.j) t RETURNING *) SELECT json_object_agg('user', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."id" AS "id") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."id" FROM "users") AS "users_0") AS "done_1337"`
|
||||
sql := `WITH "users" AS (WITH "input" AS (SELECT '{{data}}' :: json AS j) INSERT INTO "users" ("full_name", "email") SELECT "full_name", "email" FROM input i, json_populate_record(NULL::users, i.j) t RETURNING *) SELECT json_object_agg('user', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "users_0"."id" AS "id") AS "json_row_0")) AS "json_0" FROM (SELECT "users"."id" FROM "users" LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "sel_0"`
|
||||
|
||||
vars := map[string]json.RawMessage{
|
||||
"data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`),
|
||||
@ -36,7 +36,7 @@ func singleInsert(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `WITH "products" AS (WITH "input" AS (SELECT {{insert}}::json AS j) INSERT INTO "products" ("name", "description", "user_id") SELECT "name", "description", "user_id" FROM input i, json_populate_record(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
|
||||
sql := `WITH "products" AS (WITH "input" AS (SELECT '{{insert}}' :: json AS j) INSERT INTO "products" ("name", "description", "user_id") SELECT "name", "description", "user_id" FROM input i, json_populate_record(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
|
||||
|
||||
vars := map[string]json.RawMessage{
|
||||
"insert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc", "user_id": 5 }`),
|
||||
@ -60,7 +60,7 @@ func bulkInsert(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `WITH "products" AS (WITH "input" AS (SELECT {{insert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_recordset(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
|
||||
sql := `WITH "products" AS (WITH "input" AS (SELECT '{{insert}}' :: json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_recordset(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
|
||||
|
||||
vars := map[string]json.RawMessage{
|
||||
"insert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
|
||||
@ -84,7 +84,7 @@ func singleUpsert(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
|
||||
sql := `WITH "products" AS (WITH "input" AS (SELECT '{{upsert}}' :: json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
|
||||
|
||||
vars := map[string]json.RawMessage{
|
||||
"upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
|
||||
@ -108,7 +108,7 @@ func singleUpsertWhere(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT (id) WHERE (("products"."price") > 3) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
|
||||
sql := `WITH "products" AS (WITH "input" AS (SELECT '{{upsert}}' :: json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT (id) WHERE (("products"."price") > 3) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
|
||||
|
||||
vars := map[string]json.RawMessage{
|
||||
"upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
|
||||
@ -132,7 +132,7 @@ func bulkUpsert(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_recordset(NULL::products, i.j) t ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
|
||||
sql := `WITH "products" AS (WITH "input" AS (SELECT '{{upsert}}' :: json AS j) INSERT INTO "products" ("name", "description") SELECT "name", "description" FROM input i, json_populate_recordset(NULL::products, i.j) t ON CONFLICT (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
|
||||
|
||||
vars := map[string]json.RawMessage{
|
||||
"upsert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
|
||||
@ -156,7 +156,7 @@ func singleUpdate(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `WITH "products" AS (WITH "input" AS (SELECT {{update}}::json AS j) UPDATE "products" SET ("name", "description") = (SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t) WHERE (("products"."id") = 1) AND (("products"."id") = 15) RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
|
||||
sql := `WITH "products" AS (WITH "input" AS (SELECT '{{update}}' :: json AS j) UPDATE "products" SET ("name", "description") = (SELECT "name", "description" FROM input i, json_populate_record(NULL::products, i.j) t) WHERE (("products"."id") = 1) AND (("products"."id") = 15) RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
|
||||
|
||||
vars := map[string]json.RawMessage{
|
||||
"update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
|
||||
@ -180,7 +180,7 @@ func delete(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `WITH "products" AS (DELETE FROM "products" WHERE (("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 1) RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
|
||||
sql := `WITH "products" AS (DELETE FROM "products" WHERE (("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 1) RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
|
||||
|
||||
vars := map[string]json.RawMessage{
|
||||
"update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
|
||||
@ -203,7 +203,7 @@ func blockedInsert(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `WITH "users" AS (WITH "input" AS (SELECT {{data}}::json AS j) INSERT INTO "users" ("full_name", "email") SELECT "full_name", "email" FROM input i, json_populate_record(NULL::users, i.j) t WHERE false RETURNING *) SELECT json_object_agg('user', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."id" AS "id") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."id" FROM "users") AS "users_0") AS "done_1337"`
|
||||
sql := `WITH "users" AS (WITH "input" AS (SELECT '{{data}}' :: json AS j) INSERT INTO "users" ("full_name", "email") SELECT "full_name", "email" FROM input i, json_populate_record(NULL::users, i.j) t WHERE false RETURNING *) SELECT json_object_agg('user', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "users_0"."id" AS "id") AS "json_row_0")) AS "json_0" FROM (SELECT "users"."id" FROM "users" LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "sel_0"`
|
||||
|
||||
vars := map[string]json.RawMessage{
|
||||
"data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`),
|
||||
@ -227,7 +227,7 @@ func blockedUpdate(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `WITH "users" AS (WITH "input" AS (SELECT {{data}}::json AS j) UPDATE "users" SET ("full_name", "email") = (SELECT "full_name", "email" FROM input i, json_populate_record(NULL::users, i.j) t) WHERE false RETURNING *) SELECT json_object_agg('user', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."id" AS "id", "users_0"."email" AS "email") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."id", "users"."email" FROM "users") AS "users_0") AS "done_1337"`
|
||||
sql := `WITH "users" AS (WITH "input" AS (SELECT '{{data}}' :: json AS j) UPDATE "users" SET ("full_name", "email") = (SELECT "full_name", "email" FROM input i, json_populate_record(NULL::users, i.j) t) WHERE false RETURNING *) SELECT json_object_agg('user', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "users_0"."id" AS "id", "users_0"."email" AS "email") AS "json_row_0")) AS "json_0" FROM (SELECT "users"."id", "users"."email" FROM "users" LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "sel_0"`
|
||||
|
||||
vars := map[string]json.RawMessage{
|
||||
"data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`),
|
||||
@ -250,7 +250,7 @@ func simpleInsertWithPresets(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `WITH "products" AS (WITH "input" AS (SELECT {{data}}::json AS j) INSERT INTO "products" ("name", "price", "created_at", "updated_at", "user_id") SELECT "name", "price", 'now' :: timestamp without time zone, 'now' :: timestamp without time zone, '$user_id' :: bigint FROM input i, json_populate_record(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id" FROM "products") AS "products_0") AS "done_1337"`
|
||||
sql := `WITH "products" AS (WITH "input" AS (SELECT '{{data}}' :: json AS j) INSERT INTO "products" ("name", "price", "created_at", "updated_at", "user_id") SELECT "name", "price", 'now' :: timestamp without time zone, 'now' :: timestamp without time zone, '{{user_id}}' :: bigint FROM input i, json_populate_record(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
|
||||
|
||||
vars := map[string]json.RawMessage{
|
||||
"data": json.RawMessage(`{"name": "Tomato", "price": 5.76}`),
|
||||
@ -273,7 +273,7 @@ func simpleUpdateWithPresets(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `WITH "products" AS (WITH "input" AS (SELECT {{data}}::json AS j) UPDATE "products" SET ("name", "price", "updated_at") = (SELECT "name", "price", 'now' :: timestamp without time zone FROM input i, json_populate_record(NULL::products, i.j) t) WHERE (("products"."user_id") = {{user_id}}) RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id" FROM "products") AS "products_0") AS "done_1337"`
|
||||
sql := `WITH "products" AS (WITH "input" AS (SELECT '{{data}}' :: json AS j) UPDATE "products" SET ("name", "price", "updated_at") = (SELECT "name", "price", 'now' :: timestamp without time zone FROM input i, json_populate_record(NULL::products, i.j) t) WHERE (("products"."user_id") = '{{user_id}}' :: bigint) RETURNING *) SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id" FROM "products" LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
|
||||
|
||||
vars := map[string]json.RawMessage{
|
||||
"data": json.RawMessage(`{"name": "Apple", "price": 1.25}`),
|
||||
|
@ -171,7 +171,7 @@ func TestMain(m *testing.M) {
|
||||
}
|
||||
|
||||
vars := NewVariables(map[string]string{
|
||||
"account_id": "select account_id from users where id = $user_id",
|
||||
"admin_account_id": "5",
|
||||
})
|
||||
|
||||
pcompile = NewCompiler(Config{
|
||||
|
170
psql/query.go
170
psql/query.go
@ -77,30 +77,49 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) {
|
||||
}
|
||||
|
||||
c := &compilerContext{w, qc.Selects, co}
|
||||
root := &qc.Selects[0]
|
||||
|
||||
ti, err := c.schema.GetTable(root.Table)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
multiRoot := (len(qc.Roots) > 1)
|
||||
|
||||
st := NewStack()
|
||||
st.Push(root.ID + closeBlock)
|
||||
st.Push(root.ID)
|
||||
|
||||
//fmt.Fprintf(w, `SELECT json_object_agg('%s', %s) FROM (`,
|
||||
//root.FieldName, root.Table)
|
||||
io.WriteString(c.w, `SELECT json_object_agg('`)
|
||||
io.WriteString(c.w, root.FieldName)
|
||||
io.WriteString(c.w, `', `)
|
||||
if multiRoot {
|
||||
io.WriteString(c.w, `SELECT row_to_json("json_root") FROM (SELECT `)
|
||||
|
||||
for i, id := range qc.Roots {
|
||||
root := qc.Selects[id]
|
||||
|
||||
st.Push(root.ID + closeBlock)
|
||||
st.Push(root.ID)
|
||||
|
||||
if i != 0 {
|
||||
io.WriteString(c.w, `, `)
|
||||
}
|
||||
|
||||
io.WriteString(c.w, `"sel_`)
|
||||
int2string(c.w, root.ID)
|
||||
io.WriteString(c.w, `"."json_`)
|
||||
int2string(c.w, root.ID)
|
||||
io.WriteString(c.w, `"`)
|
||||
|
||||
alias(c.w, root.FieldName)
|
||||
}
|
||||
|
||||
io.WriteString(c.w, ` FROM `)
|
||||
|
||||
if ti.Singular == false {
|
||||
io.WriteString(c.w, root.Table)
|
||||
} else {
|
||||
io.WriteString(c.w, "sel_json_")
|
||||
root := qc.Selects[0]
|
||||
|
||||
io.WriteString(c.w, `SELECT json_object_agg(`)
|
||||
io.WriteString(c.w, `'`)
|
||||
io.WriteString(c.w, root.FieldName)
|
||||
io.WriteString(c.w, `', `)
|
||||
io.WriteString(c.w, `json_`)
|
||||
int2string(c.w, root.ID)
|
||||
|
||||
st.Push(root.ID + closeBlock)
|
||||
st.Push(root.ID)
|
||||
|
||||
io.WriteString(c.w, `) FROM `)
|
||||
}
|
||||
io.WriteString(c.w, `) FROM (`)
|
||||
|
||||
var ignored uint32
|
||||
|
||||
@ -114,16 +133,21 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) {
|
||||
if id < closeBlock {
|
||||
sel := &c.s[id]
|
||||
|
||||
if sel.ParentID == -1 {
|
||||
io.WriteString(c.w, `(`)
|
||||
}
|
||||
|
||||
ti, err := c.schema.GetTable(sel.Table)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if sel.ID != 0 {
|
||||
if sel.ParentID != -1 {
|
||||
if err = c.renderLateralJoin(sel); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
|
||||
skipped, err := c.renderSelect(sel, ti)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
@ -153,16 +177,25 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
if sel.ID != 0 {
|
||||
if sel.ParentID != -1 {
|
||||
if err = c.renderLateralJoinClose(sel); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
} else {
|
||||
io.WriteString(c.w, `)`)
|
||||
aliasWithID(c.w, `sel`, sel.ID)
|
||||
|
||||
if st.Len() != 0 {
|
||||
io.WriteString(c.w, `, `)
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
io.WriteString(c.w, `)`)
|
||||
alias(c.w, `done_1337`)
|
||||
if multiRoot {
|
||||
io.WriteString(c.w, `) AS "json_root"`)
|
||||
}
|
||||
|
||||
return ignored, nil
|
||||
}
|
||||
@ -191,15 +224,15 @@ func (c *compilerContext) processChildren(sel *qcode.Select, ti *DBTableInfo) (u
|
||||
fallthrough
|
||||
case RelBelongTo:
|
||||
if _, ok := colmap[rel.Col2]; !ok {
|
||||
cols = append(cols, &qcode.Column{ti.Name, rel.Col2, rel.Col2})
|
||||
cols = append(cols, &qcode.Column{Table: ti.Name, Name: rel.Col2, FieldName: rel.Col2})
|
||||
}
|
||||
case RelOneToManyThrough:
|
||||
if _, ok := colmap[rel.Col1]; !ok {
|
||||
cols = append(cols, &qcode.Column{ti.Name, rel.Col1, rel.Col1})
|
||||
cols = append(cols, &qcode.Column{Table: ti.Name, Name: rel.Col1, FieldName: rel.Col1})
|
||||
}
|
||||
case RelRemote:
|
||||
if _, ok := colmap[rel.Col1]; !ok {
|
||||
cols = append(cols, &qcode.Column{ti.Name, rel.Col1, rel.Col2})
|
||||
cols = append(cols, &qcode.Column{Table: ti.Name, Name: rel.Col1, FieldName: rel.Col2})
|
||||
}
|
||||
skipped |= (1 << uint(id))
|
||||
|
||||
@ -219,7 +252,7 @@ func (c *compilerContext) renderSelect(sel *qcode.Select, ti *DBTableInfo) (uint
|
||||
if ti.Singular == false {
|
||||
//fmt.Fprintf(w, `SELECT coalesce(json_agg("%s"`, c.sel.Table)
|
||||
io.WriteString(c.w, `SELECT coalesce(json_agg("`)
|
||||
io.WriteString(c.w, "sel_json_")
|
||||
io.WriteString(c.w, "json_")
|
||||
int2string(c.w, sel.ID)
|
||||
io.WriteString(c.w, `"`)
|
||||
|
||||
@ -232,7 +265,7 @@ func (c *compilerContext) renderSelect(sel *qcode.Select, ti *DBTableInfo) (uint
|
||||
|
||||
//fmt.Fprintf(w, `), '[]') AS "%s" FROM (`, c.sel.Table)
|
||||
io.WriteString(c.w, `), '[]')`)
|
||||
alias(c.w, sel.Table)
|
||||
aliasWithID(c.w, "json", sel.ID)
|
||||
io.WriteString(c.w, ` FROM (`)
|
||||
}
|
||||
|
||||
@ -245,8 +278,8 @@ func (c *compilerContext) renderSelect(sel *qcode.Select, ti *DBTableInfo) (uint
|
||||
|
||||
io.WriteString(c.w, `row_to_json((`)
|
||||
|
||||
//fmt.Fprintf(w, `SELECT "sel_%d" FROM (SELECT `, c.sel.ID)
|
||||
io.WriteString(c.w, `SELECT "sel_`)
|
||||
//fmt.Fprintf(w, `SELECT "%d" FROM (SELECT `, c.sel.ID)
|
||||
io.WriteString(c.w, `SELECT "json_row_`)
|
||||
int2string(c.w, sel.ID)
|
||||
io.WriteString(c.w, `" FROM (SELECT `)
|
||||
|
||||
@ -260,13 +293,13 @@ func (c *compilerContext) renderSelect(sel *qcode.Select, ti *DBTableInfo) (uint
|
||||
return skipped, err
|
||||
}
|
||||
|
||||
//fmt.Fprintf(w, `) AS "sel_%d"`, c.sel.ID)
|
||||
//fmt.Fprintf(w, `) AS "%d"`, c.sel.ID)
|
||||
io.WriteString(c.w, `)`)
|
||||
aliasWithID(c.w, "sel", sel.ID)
|
||||
aliasWithID(c.w, "json_row", sel.ID)
|
||||
|
||||
//fmt.Fprintf(w, `)) AS "%s"`, c.sel.Table)
|
||||
io.WriteString(c.w, `))`)
|
||||
aliasWithID(c.w, "sel_json", sel.ID)
|
||||
aliasWithID(c.w, "json", sel.ID)
|
||||
// END-ROW-TO-JSON
|
||||
|
||||
if hasOrder {
|
||||
@ -295,8 +328,8 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo)
|
||||
}
|
||||
|
||||
switch {
|
||||
case sel.Paging.NoLimit:
|
||||
break
|
||||
case ti.Singular:
|
||||
io.WriteString(c.w, ` LIMIT ('1') :: integer`)
|
||||
|
||||
case len(sel.Paging.Limit) != 0:
|
||||
//fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit)
|
||||
@ -304,8 +337,8 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo)
|
||||
io.WriteString(c.w, sel.Paging.Limit)
|
||||
io.WriteString(c.w, `') :: integer`)
|
||||
|
||||
case ti.Singular:
|
||||
io.WriteString(c.w, ` LIMIT ('1') :: integer`)
|
||||
case sel.Paging.NoLimit:
|
||||
break
|
||||
|
||||
default:
|
||||
io.WriteString(c.w, ` LIMIT ('20') :: integer`)
|
||||
@ -319,9 +352,9 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo)
|
||||
}
|
||||
|
||||
if ti.Singular == false {
|
||||
//fmt.Fprintf(w, `) AS "sel_json_agg_%d"`, c.sel.ID)
|
||||
//fmt.Fprintf(w, `) AS "json_agg_%d"`, c.sel.ID)
|
||||
io.WriteString(c.w, `)`)
|
||||
aliasWithID(c.w, "sel_json_agg", sel.ID)
|
||||
aliasWithID(c.w, "json_agg", sel.ID)
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -445,24 +478,18 @@ func (c *compilerContext) renderJoinedColumns(sel *qcode.Select, ti *DBTableInfo
|
||||
}
|
||||
childSel := &c.s[id]
|
||||
|
||||
cti, err := c.schema.GetTable(childSel.Table)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
//fmt.Fprintf(w, `"%s_%d_join"."%s" AS "%s"`,
|
||||
//s.Table, s.ID, s.Table, s.FieldName)
|
||||
if cti.Singular {
|
||||
io.WriteString(c.w, `"sel_json_`)
|
||||
int2string(c.w, childSel.ID)
|
||||
io.WriteString(c.w, `" AS "`)
|
||||
io.WriteString(c.w, childSel.FieldName)
|
||||
io.WriteString(c.w, `"`)
|
||||
|
||||
} else {
|
||||
colWithTableIDSuffixAlias(c.w, childSel.Table, childSel.ID,
|
||||
"_join", childSel.Table, childSel.FieldName)
|
||||
}
|
||||
//if cti.Singular {
|
||||
io.WriteString(c.w, `"`)
|
||||
io.WriteString(c.w, childSel.Table)
|
||||
io.WriteString(c.w, `_`)
|
||||
int2string(c.w, childSel.ID)
|
||||
io.WriteString(c.w, `_join"."json_`)
|
||||
int2string(c.w, childSel.ID)
|
||||
io.WriteString(c.w, `" AS "`)
|
||||
io.WriteString(c.w, childSel.FieldName)
|
||||
io.WriteString(c.w, `"`)
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -472,8 +499,8 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
|
||||
childCols []*qcode.Column, skipped uint32) error {
|
||||
var groupBy []int
|
||||
|
||||
isRoot := sel.ID == 0
|
||||
isFil := sel.Where != nil
|
||||
isRoot := sel.ParentID == -1
|
||||
isFil := (sel.Where != nil && sel.Where.Op != qcode.OpNop)
|
||||
isSearch := sel.Args["search"] != nil
|
||||
isAgg := false
|
||||
|
||||
@ -641,8 +668,8 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
|
||||
}
|
||||
|
||||
switch {
|
||||
case sel.Paging.NoLimit:
|
||||
break
|
||||
case ti.Singular:
|
||||
io.WriteString(c.w, ` LIMIT ('1') :: integer`)
|
||||
|
||||
case len(sel.Paging.Limit) != 0:
|
||||
//fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit)
|
||||
@ -650,8 +677,8 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
|
||||
io.WriteString(c.w, sel.Paging.Limit)
|
||||
io.WriteString(c.w, `') :: integer`)
|
||||
|
||||
case ti.Singular:
|
||||
io.WriteString(c.w, ` LIMIT ('1') :: integer`)
|
||||
case sel.Paging.NoLimit:
|
||||
break
|
||||
|
||||
default:
|
||||
io.WriteString(c.w, ` LIMIT ('20') :: integer`)
|
||||
@ -817,7 +844,6 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
|
||||
}
|
||||
|
||||
func (c *compilerContext) renderNestedWhere(ex *qcode.Exp, sel *qcode.Select, ti *DBTableInfo) error {
|
||||
|
||||
for i := 0; i < len(ex.NestedCols)-1; i++ {
|
||||
cti, err := c.schema.GetTable(ex.NestedCols[i])
|
||||
if err != nil {
|
||||
@ -851,7 +877,18 @@ func (c *compilerContext) renderNestedWhere(ex *qcode.Exp, sel *qcode.Select, ti
|
||||
}
|
||||
|
||||
func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTableInfo) error {
|
||||
var col *DBColumn
|
||||
var ok bool
|
||||
|
||||
if ex.Op == qcode.OpNop {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(ex.Col) != 0 {
|
||||
if col, ok = ti.Columns[ex.Col]; !ok {
|
||||
return fmt.Errorf("no column '%s' found ", ex.Col)
|
||||
}
|
||||
|
||||
io.WriteString(c.w, `((`)
|
||||
colWithTable(c.w, ti.Name, ex.Col)
|
||||
io.WriteString(c.w, `) `)
|
||||
@ -907,6 +944,9 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
|
||||
if len(ti.PrimaryCol) == 0 {
|
||||
return fmt.Errorf("no primary key column defined for %s", ti.Name)
|
||||
}
|
||||
if col, ok = ti.Columns[ti.PrimaryCol]; !ok {
|
||||
return fmt.Errorf("no primary key column '%s' found ", ti.PrimaryCol)
|
||||
}
|
||||
//fmt.Fprintf(w, `(("%s") =`, c.ti.PrimaryCol)
|
||||
io.WriteString(c.w, `((`)
|
||||
colWithTable(c.w, ti.Name, ti.PrimaryCol)
|
||||
@ -916,6 +956,9 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
|
||||
if len(ti.TSVCol) == 0 {
|
||||
return fmt.Errorf("no tsv column defined for %s", ti.Name)
|
||||
}
|
||||
if col, ok = ti.Columns[ti.TSVCol]; !ok {
|
||||
return fmt.Errorf("no tsv column '%s' found ", ti.TSVCol)
|
||||
}
|
||||
//fmt.Fprintf(w, `(("%s") @@ to_tsquery('%s'))`, c.ti.TSVCol, val.Val)
|
||||
io.WriteString(c.w, `(("`)
|
||||
io.WriteString(c.w, ti.TSVCol)
|
||||
@ -931,7 +974,7 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
|
||||
if ex.Type == qcode.ValList {
|
||||
c.renderList(ex)
|
||||
} else {
|
||||
c.renderVal(ex, c.vars)
|
||||
c.renderVal(ex, c.vars, col)
|
||||
}
|
||||
|
||||
io.WriteString(c.w, `)`)
|
||||
@ -1008,7 +1051,7 @@ func (c *compilerContext) renderList(ex *qcode.Exp) {
|
||||
io.WriteString(c.w, `)`)
|
||||
}
|
||||
|
||||
func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) {
|
||||
func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string, col *DBColumn) {
|
||||
io.WriteString(c.w, ` `)
|
||||
|
||||
switch ex.Type {
|
||||
@ -1025,6 +1068,7 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) {
|
||||
io.WriteString(c.w, `'`)
|
||||
|
||||
case qcode.ValVar:
|
||||
io.WriteString(c.w, `'`)
|
||||
if val, ok := vars[ex.Val]; ok {
|
||||
io.WriteString(c.w, val)
|
||||
} else {
|
||||
@ -1033,6 +1077,8 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) {
|
||||
io.WriteString(c.w, ex.Val)
|
||||
io.WriteString(c.w, `}}`)
|
||||
}
|
||||
io.WriteString(c.w, `' :: `)
|
||||
io.WriteString(c.w, col.Type)
|
||||
}
|
||||
//io.WriteString(c.w, `)`)
|
||||
}
|
||||
|
@ -28,7 +28,7 @@ func withComplexArgs(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") < 28) AND (("products"."id") >= 20)) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "sel_json_agg_0") AS "done_1337"`
|
||||
sql := `SELECT json_object_agg('products', json_0) FROM (SELECT coalesce(json_agg("json_0" ORDER BY "products_0_price_ob" DESC), '[]') AS "json_0" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "json_row_0")) AS "json_0", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") < 28) AND (("products"."id") >= 20)) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "json_agg_0") AS "sel_0"`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
@ -56,7 +56,7 @@ func withWhereMultiOr(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") < 20) OR (("products"."price") > 10) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
|
||||
sql := `SELECT json_object_agg('products', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") < 20) OR (("products"."price") > 10) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
@ -82,7 +82,7 @@ func withWhereIsNull(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
|
||||
sql := `SELECT json_object_agg('products', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
@ -108,7 +108,7 @@ func withWhereAndList(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
|
||||
sql := `SELECT json_object_agg('products', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
@ -128,7 +128,7 @@ func fetchByID(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 15)) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337"`
|
||||
sql := `SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 15)) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
@ -148,7 +148,7 @@ func searchQuery(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
|
||||
sql := `SELECT json_object_agg('products', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
@ -171,7 +171,7 @@ func oneToMany(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email", "users"."id" FROM "users" LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
|
||||
sql := `SELECT json_object_agg('users', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."json_1" AS "products") AS "json_row_0")) AS "json_0" FROM (SELECT "users"."email", "users"."id" FROM "users" LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("json_1"), '[]') AS "json_1" FROM (SELECT row_to_json((SELECT "json_row_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "json_row_1")) AS "json_1" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
@ -194,7 +194,7 @@ func belongsTo(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "sel_json_1" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
|
||||
sql := `SELECT json_object_agg('products', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."json_1" AS "users") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("json_1"), '[]') AS "json_1" FROM (SELECT row_to_json((SELECT "json_row_1" FROM (SELECT "users_1"."email" AS "email") AS "json_row_1")) AS "json_1" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "json_agg_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
@ -217,7 +217,7 @@ func manyToMany(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
|
||||
sql := `SELECT json_object_agg('products', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."json_1" AS "customers") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("json_1"), '[]') AS "json_1" FROM (SELECT row_to_json((SELECT "json_row_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "json_row_1")) AS "json_1" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "json_agg_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
@ -240,7 +240,7 @@ func manyToManyReverse(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
|
||||
sql := `SELECT json_object_agg('customers', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."json_1" AS "products") AS "json_row_0")) AS "json_0" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("json_1"), '[]') AS "json_1" FROM (SELECT row_to_json((SELECT "json_row_1" FROM (SELECT "products_1"."name" AS "name") AS "json_row_1")) AS "json_1" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
@ -260,7 +260,7 @@ func aggFunction(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
|
||||
sql := `SELECT json_object_agg('products', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
@ -280,7 +280,7 @@ func aggFunctionBlockedByCol(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
|
||||
sql := `SELECT json_object_agg('products', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "anon")
|
||||
if err != nil {
|
||||
@ -300,7 +300,7 @@ func aggFunctionDisabled(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
|
||||
sql := `SELECT json_object_agg('products', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "anon1")
|
||||
if err != nil {
|
||||
@ -320,7 +320,7 @@ func aggFunctionWithFilter(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") > 10)) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
|
||||
sql := `SELECT json_object_agg('products', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") > 10)) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
@ -339,7 +339,7 @@ func syntheticTables(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT ) AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337"`
|
||||
sql := `SELECT json_object_agg('me', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT ) AS "json_row_0")) AS "json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = '{{user_id}}' :: bigint)) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "sel_0"`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
@ -359,7 +359,7 @@ func queryWithVariables(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") = {{product_price}}) AND (("products"."id") = {{product_id}})) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337"`
|
||||
sql := `SELECT json_object_agg('product', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") = '{{product_price}}' :: numeric(7,2)) AND (("products"."id") = '{{product_id}}' :: bigint)) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "sel_0"`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
@ -385,7 +385,40 @@ func withWhereOnRelations(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."id" AS "id", "users_0"."email" AS "email") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."id", "users"."email" FROM "users" WHERE (NOT EXISTS (SELECT 1 FROM products WHERE (("products"."user_id") = ("users"."id")))) LIMIT ('20') :: integer) AS "users_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
|
||||
sql := `SELECT json_object_agg('users', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "users_0"."id" AS "id", "users_0"."email" AS "email") AS "json_row_0")) AS "json_0" FROM (SELECT "users"."id", "users"."email" FROM "users" WHERE (NOT EXISTS (SELECT 1 FROM products WHERE (("products"."user_id") = ("users"."id")))) LIMIT ('20') :: integer) AS "users_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if string(resSQL) != sql {
|
||||
t.Fatal(errNotExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func multiRoot(t *testing.T) {
|
||||
gql := `query {
|
||||
product {
|
||||
id
|
||||
name
|
||||
customer {
|
||||
email
|
||||
}
|
||||
customers {
|
||||
email
|
||||
}
|
||||
}
|
||||
user {
|
||||
id
|
||||
email
|
||||
}
|
||||
customer {
|
||||
id
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT row_to_json("json_root") FROM (SELECT "sel_0"."json_0" AS "customer", "sel_1"."json_1" AS "user", "sel_2"."json_2" AS "product" FROM (SELECT row_to_json((SELECT "json_row_2" FROM (SELECT "products_2"."id" AS "id", "products_2"."name" AS "name", "customers_3_join"."json_3" AS "customers", "customer_4_join"."json_4" AS "customer") AS "json_row_2")) AS "json_2" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('1') :: integer) AS "products_2" LEFT OUTER JOIN LATERAL (SELECT row_to_json((SELECT "json_row_4" FROM (SELECT "customers_4"."email" AS "email") AS "json_row_4")) AS "json_4" FROM (SELECT "customers"."email" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_2"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('1') :: integer) AS "customers_4" LIMIT ('1') :: integer) AS "customer_4_join" ON ('true') LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("json_3"), '[]') AS "json_3" FROM (SELECT row_to_json((SELECT "json_row_3" FROM (SELECT "customers_3"."email" AS "email") AS "json_row_3")) AS "json_3" FROM (SELECT "customers"."email" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_2"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_3" LIMIT ('20') :: integer) AS "json_agg_3") AS "customers_3_join" ON ('true') LIMIT ('1') :: integer) AS "sel_2", (SELECT row_to_json((SELECT "json_row_1" FROM (SELECT "users_1"."id" AS "id", "users_1"."email" AS "email") AS "json_row_1")) AS "json_1" FROM (SELECT "users"."id", "users"."email" FROM "users" LIMIT ('1') :: integer) AS "users_1" LIMIT ('1') :: integer) AS "sel_1", (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "customers_0"."id" AS "id") AS "json_row_0")) AS "json_0" FROM (SELECT "customers"."id" FROM "customers" LIMIT ('1') :: integer) AS "customers_0" LIMIT ('1') :: integer) AS "sel_0") AS "json_root"`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "user")
|
||||
if err != nil {
|
||||
@ -406,7 +439,7 @@ func blockedQuery(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('user', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."id" AS "id", "users_0"."full_name" AS "full_name", "users_0"."email" AS "email") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."id", "users"."full_name", "users"."email" FROM "users" WHERE (false) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337"`
|
||||
sql := `SELECT json_object_agg('user', json_0) FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "users_0"."id" AS "id", "users_0"."full_name" AS "full_name", "users_0"."email" AS "email") AS "json_row_0")) AS "json_0" FROM (SELECT "users"."id", "users"."full_name", "users"."email" FROM "users" WHERE (false) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "sel_0"`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "bad_dude")
|
||||
if err != nil {
|
||||
@ -426,7 +459,7 @@ func blockedFunctions(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE (false) LIMIT ('20') :: integer) AS "users_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
|
||||
sql := `SELECT json_object_agg('users', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "users_0"."email" AS "email") AS "json_row_0")) AS "json_0" FROM (SELECT "users"."email" FROM "users" WHERE (false) LIMIT ('20') :: integer) AS "users_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, nil, "bad_dude")
|
||||
if err != nil {
|
||||
@ -456,6 +489,7 @@ func TestCompileQuery(t *testing.T) {
|
||||
t.Run("syntheticTables", syntheticTables)
|
||||
t.Run("queryWithVariables", queryWithVariables)
|
||||
t.Run("withWhereOnRelations", withWhereOnRelations)
|
||||
t.Run("multiRoot", multiRoot)
|
||||
t.Run("blockedQuery", blockedQuery)
|
||||
t.Run("blockedFunctions", blockedFunctions)
|
||||
}
|
||||
|
2
qcode/cleanup.sh
Executable file
2
qcode/cleanup.sh
Executable file
@ -0,0 +1,2 @@
|
||||
#!/bin/sh
|
||||
cd corpus && rm -rf $(find . ! -name '00?.gql')
|
@ -1,6 +1,7 @@
|
||||
package qcode
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
@ -121,3 +122,12 @@ func mapToList(m map[string]string) []string {
|
||||
sort.Strings(list)
|
||||
return list
|
||||
}
|
||||
|
||||
var varRe = regexp.MustCompile(`\$([a-zA-Z0-9_]+)`)
|
||||
|
||||
func parsePresets(m map[string]string) map[string]string {
|
||||
for k, v := range m {
|
||||
m[k] = varRe.ReplaceAllString(v, `{{$1}}`)
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
@ -4,6 +4,8 @@ package qcode
|
||||
|
||||
// FuzzerEntrypoint for Fuzzbuzz
|
||||
func Fuzz(data []byte) int {
|
||||
GetQType(string(data))
|
||||
|
||||
qcompile, _ := NewCompiler(Config{})
|
||||
_, err := qcompile.Compile(data, "user")
|
||||
if err != nil {
|
||||
|
@ -148,24 +148,13 @@ func parseSelectionSet(gql []byte) (*Operation, error) {
|
||||
|
||||
if p.peek(itemObjOpen) {
|
||||
p.ignore()
|
||||
}
|
||||
|
||||
if p.peek(itemName) {
|
||||
op = opPool.Get().(*Operation)
|
||||
op.Reset()
|
||||
|
||||
op.Type = opQuery
|
||||
op.Name = ""
|
||||
op.Fields = op.fieldsA[:0]
|
||||
op.Args = op.argsA[:0]
|
||||
op.Fields, err = p.parseFields(op.Fields)
|
||||
|
||||
op, err = p.parseQueryOp()
|
||||
} else {
|
||||
op, err = p.parseOp()
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
lexPool.Put(l)
|
||||
@ -259,6 +248,37 @@ func (p *Parser) parseOp() (*Operation, error) {
|
||||
|
||||
if p.peek(itemObjOpen) {
|
||||
p.ignore()
|
||||
|
||||
for n := 0; n < 10; n++ {
|
||||
if p.peek(itemName) == false {
|
||||
break
|
||||
}
|
||||
|
||||
op.Fields, err = p.parseFields(op.Fields)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseQueryOp() (*Operation, error) {
|
||||
op := opPool.Get().(*Operation)
|
||||
op.Reset()
|
||||
|
||||
op.Type = opQuery
|
||||
op.Fields = op.fieldsA[:0]
|
||||
op.Args = op.argsA[:0]
|
||||
|
||||
var err error
|
||||
|
||||
for n := 0; n < 10; n++ {
|
||||
if p.peek(itemName) == false {
|
||||
break
|
||||
}
|
||||
|
||||
op.Fields, err = p.parseFields(op.Fields)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -300,16 +320,12 @@ func (p *Parser) parseFields(fields []Field) ([]Field, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if f.ID != 0 {
|
||||
intf := st.Peek()
|
||||
pid, ok := intf.(int32)
|
||||
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("14: unexpected value %v (%t)", intf, intf)
|
||||
}
|
||||
|
||||
intf := st.Peek()
|
||||
if pid, ok := intf.(int32); ok {
|
||||
f.ParentID = pid
|
||||
fields[pid].Children = append(fields[pid].Children, f.ID)
|
||||
} else {
|
||||
f.ParentID = -1
|
||||
}
|
||||
|
||||
if p.peek(itemObjOpen) {
|
||||
|
@ -14,10 +14,10 @@ func TestCompile1(t *testing.T) {
|
||||
})
|
||||
|
||||
_, err := qc.Compile([]byte(`
|
||||
product(id: 15) {
|
||||
{ product(id: 15) {
|
||||
id
|
||||
name
|
||||
}`), "user")
|
||||
} }`), "user")
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
|
@ -20,6 +20,7 @@ const (
|
||||
|
||||
const (
|
||||
QTQuery QType = iota + 1
|
||||
QTMutation
|
||||
QTInsert
|
||||
QTUpdate
|
||||
QTDelete
|
||||
@ -30,6 +31,8 @@ type QCode struct {
|
||||
Type QType
|
||||
ActionVar string
|
||||
Selects []Select
|
||||
Roots []int32
|
||||
rootsA [5]int32
|
||||
}
|
||||
|
||||
type Select struct {
|
||||
@ -200,7 +203,7 @@ func (com *Compiler) AddRole(role, table string, trc TRConfig) error {
|
||||
return err
|
||||
}
|
||||
trv.insert.cols = listToMap(trc.Insert.Columns)
|
||||
trv.insert.psmap = trc.Insert.Presets
|
||||
trv.insert.psmap = parsePresets(trc.Insert.Presets)
|
||||
trv.insert.pslist = mapToList(trv.insert.psmap)
|
||||
|
||||
// update config
|
||||
@ -208,7 +211,7 @@ func (com *Compiler) AddRole(role, table string, trc TRConfig) error {
|
||||
return err
|
||||
}
|
||||
trv.update.cols = listToMap(trc.Update.Columns)
|
||||
trv.update.psmap = trc.Update.Presets
|
||||
trv.update.psmap = parsePresets(trc.Update.Presets)
|
||||
trv.update.pslist = mapToList(trv.update.psmap)
|
||||
|
||||
// delete config
|
||||
@ -233,6 +236,7 @@ func (com *Compiler) Compile(query []byte, role string) (*QCode, error) {
|
||||
var err error
|
||||
|
||||
qc := QCode{Type: QTQuery}
|
||||
qc.Roots = qc.rootsA[:0]
|
||||
|
||||
op, err := Parse(query)
|
||||
if err != nil {
|
||||
@ -250,7 +254,7 @@ func (com *Compiler) Compile(query []byte, role string) (*QCode, error) {
|
||||
|
||||
func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error {
|
||||
id := int32(0)
|
||||
parentID := int32(0)
|
||||
parentID := int32(-1)
|
||||
|
||||
if len(op.Fields) == 0 {
|
||||
return errors.New("invalid graphql no query found")
|
||||
@ -269,7 +273,12 @@ func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error {
|
||||
if len(op.Fields) == 0 {
|
||||
return errors.New("empty query")
|
||||
}
|
||||
st.Push(op.Fields[0].ID)
|
||||
|
||||
for i := range op.Fields {
|
||||
if op.Fields[i].ParentID == -1 {
|
||||
st.Push(op.Fields[i].ID)
|
||||
}
|
||||
}
|
||||
|
||||
for {
|
||||
if st.Len() == 0 {
|
||||
@ -313,11 +322,6 @@ func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error {
|
||||
s.PresetList = trv.update.pslist
|
||||
}
|
||||
|
||||
if s.ID != 0 {
|
||||
p := &selects[s.ParentID]
|
||||
p.Children = append(p.Children, s.ID)
|
||||
}
|
||||
|
||||
if len(field.Alias) != 0 {
|
||||
s.FieldName = field.Alias
|
||||
} else {
|
||||
@ -329,6 +333,15 @@ func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
// Order is important addFilters must come after compileArgs
|
||||
if s.ParentID == -1 {
|
||||
qc.Roots = append(qc.Roots, s.ID)
|
||||
com.addFilters(qc, s, role)
|
||||
} else {
|
||||
p := &selects[s.ParentID]
|
||||
p.Children = append(p.Children, s.ID)
|
||||
}
|
||||
|
||||
s.Cols = make([]Column, 0, len(field.Children))
|
||||
action = QTQuery
|
||||
|
||||
@ -362,36 +375,40 @@ func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error {
|
||||
return errors.New("invalid query")
|
||||
}
|
||||
|
||||
var fil *Exp
|
||||
root := &selects[0]
|
||||
qc.Selects = selects[:id]
|
||||
return nil
|
||||
}
|
||||
|
||||
if trv, ok := com.tr[role][op.Fields[0].Name]; ok {
|
||||
func (com *Compiler) addFilters(qc *QCode, root *Select, role string) {
|
||||
var fil *Exp
|
||||
|
||||
if trv, ok := com.tr[role][root.Table]; ok {
|
||||
fil = trv.filter(qc.Type)
|
||||
}
|
||||
|
||||
if fil != nil {
|
||||
switch fil.Op {
|
||||
case OpNop:
|
||||
case OpFalse:
|
||||
root.Where = fil
|
||||
default:
|
||||
if root.Where != nil {
|
||||
ow := root.Where
|
||||
|
||||
root.Where = expPool.Get().(*Exp)
|
||||
root.Where.Reset()
|
||||
root.Where.Op = OpAnd
|
||||
root.Where.Children = root.Where.childrenA[:2]
|
||||
root.Where.Children[0] = fil
|
||||
root.Where.Children[1] = ow
|
||||
} else {
|
||||
root.Where = fil
|
||||
}
|
||||
}
|
||||
if fil == nil {
|
||||
return
|
||||
}
|
||||
|
||||
qc.Selects = selects[:id]
|
||||
return nil
|
||||
switch fil.Op {
|
||||
case OpNop:
|
||||
case OpFalse:
|
||||
root.Where = fil
|
||||
|
||||
default:
|
||||
if root.Where != nil {
|
||||
ow := root.Where
|
||||
|
||||
root.Where = expPool.Get().(*Exp)
|
||||
root.Where.Reset()
|
||||
root.Where.Op = OpAnd
|
||||
root.Where.Children = root.Where.childrenA[:2]
|
||||
root.Where.Children[0] = fil
|
||||
root.Where.Children[1] = ow
|
||||
} else {
|
||||
root.Where = fil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg) error {
|
||||
@ -577,7 +594,7 @@ func (com *Compiler) compileArgID(sel *Select, arg *Arg) error {
|
||||
case nodeVar:
|
||||
ex.Type = ValVar
|
||||
default:
|
||||
fmt.Errorf("expecting a string, int, float or variable")
|
||||
return fmt.Errorf("expecting a string, int, float or variable")
|
||||
}
|
||||
|
||||
sel.Where = ex
|
||||
|
23
qcode/utils.go
Normal file
23
qcode/utils.go
Normal file
@ -0,0 +1,23 @@
|
||||
package qcode
|
||||
|
||||
func GetQType(gql string) QType {
|
||||
for i := range gql {
|
||||
b := gql[i]
|
||||
if b == '{' {
|
||||
return QTQuery
|
||||
}
|
||||
if al(b) {
|
||||
switch b {
|
||||
case 'm', 'M':
|
||||
return QTMutation
|
||||
case 'q', 'Q':
|
||||
return QTQuery
|
||||
}
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func al(b byte) bool {
|
||||
return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9')
|
||||
}
|
@ -18,6 +18,8 @@ const (
|
||||
)
|
||||
|
||||
type allowItem struct {
|
||||
name string
|
||||
hash string
|
||||
uri string
|
||||
gql string
|
||||
vars json.RawMessage
|
||||
@ -46,7 +48,7 @@ func initAllowList(cpath string) {
|
||||
if _, err := os.Stat(fp); err == nil {
|
||||
_allowList.filepath = fp
|
||||
} else if !os.IsNotExist(err) {
|
||||
logger.Fatal().Err(err).Send()
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
}
|
||||
|
||||
@ -56,7 +58,7 @@ func initAllowList(cpath string) {
|
||||
if _, err := os.Stat(fp); err == nil {
|
||||
_allowList.filepath = fp
|
||||
} else if !os.IsNotExist(err) {
|
||||
logger.Fatal().Err(err).Send()
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
}
|
||||
|
||||
@ -66,13 +68,13 @@ func initAllowList(cpath string) {
|
||||
if _, err := os.Stat(fp); err == nil {
|
||||
_allowList.filepath = fp
|
||||
} else if !os.IsNotExist(err) {
|
||||
logger.Fatal().Err(err).Send()
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
}
|
||||
|
||||
if len(_allowList.filepath) == 0 {
|
||||
if conf.Production {
|
||||
logger.Fatal().Msg("allow.list not found")
|
||||
errlog.Fatal().Msg("allow.list not found")
|
||||
}
|
||||
|
||||
if len(cpath) == 0 {
|
||||
@ -94,7 +96,7 @@ func initAllowList(cpath string) {
|
||||
}
|
||||
|
||||
func (al *allowList) add(req *gqlReq) {
|
||||
if al.active == false || len(req.ref) == 0 || len(req.Query) == 0 {
|
||||
if len(req.ref) == 0 || len(req.Query) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
@ -119,11 +121,39 @@ func (al *allowList) add(req *gqlReq) {
|
||||
}
|
||||
}
|
||||
|
||||
func (al *allowList) load() {
|
||||
if al.active == false {
|
||||
return
|
||||
func (al *allowList) upsert(query, vars []byte, uri string) {
|
||||
q := string(query)
|
||||
hash := gqlHash(q, vars, "")
|
||||
name := gqlName(q)
|
||||
|
||||
var key string
|
||||
|
||||
if len(name) == 0 {
|
||||
key = hash
|
||||
} else {
|
||||
key = name
|
||||
}
|
||||
|
||||
if i, ok := al.index[key]; !ok {
|
||||
al.list = append(al.list, &allowItem{
|
||||
name: name,
|
||||
hash: hash,
|
||||
uri: uri,
|
||||
gql: q,
|
||||
vars: vars,
|
||||
})
|
||||
al.index[key] = len(al.list) - 1
|
||||
} else {
|
||||
item := al.list[i]
|
||||
item.name = name
|
||||
item.hash = hash
|
||||
item.gql = q
|
||||
item.vars = vars
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func (al *allowList) load() {
|
||||
b, err := ioutil.ReadFile(al.filepath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
@ -172,22 +202,7 @@ func (al *allowList) load() {
|
||||
|
||||
if c == 0 {
|
||||
if ty == AL_QUERY {
|
||||
q := string(b[s:(e + 1)])
|
||||
key := gqlHash(q, varBytes, "")
|
||||
|
||||
if idx, ok := al.index[key]; !ok {
|
||||
al.list = append(al.list, &allowItem{
|
||||
uri: uri,
|
||||
gql: q,
|
||||
vars: varBytes,
|
||||
})
|
||||
al.index[key] = len(al.list) - 1
|
||||
} else {
|
||||
item := al.list[idx]
|
||||
item.gql = q
|
||||
item.vars = varBytes
|
||||
}
|
||||
|
||||
al.upsert(b[s:(e+1)], varBytes, uri)
|
||||
varBytes = nil
|
||||
|
||||
} else if ty == AL_VARS {
|
||||
@ -205,19 +220,33 @@ func (al *allowList) load() {
|
||||
}
|
||||
|
||||
func (al *allowList) save(item *allowItem) {
|
||||
if al.active == false {
|
||||
return
|
||||
item.hash = gqlHash(item.gql, item.vars, "")
|
||||
item.name = gqlName(item.gql)
|
||||
|
||||
if len(item.name) == 0 {
|
||||
key := item.hash
|
||||
|
||||
if _, ok := al.index[key]; ok {
|
||||
return
|
||||
}
|
||||
|
||||
al.list = append(al.list, item)
|
||||
al.index[key] = len(al.list) - 1
|
||||
|
||||
} else {
|
||||
key := item.name
|
||||
|
||||
if i, ok := al.index[key]; ok {
|
||||
if al.list[i].hash == item.hash {
|
||||
return
|
||||
}
|
||||
al.list[i] = item
|
||||
} else {
|
||||
al.list = append(al.list, item)
|
||||
al.index[key] = len(al.list) - 1
|
||||
}
|
||||
}
|
||||
|
||||
key := gqlHash(item.gql, item.vars, "")
|
||||
|
||||
if _, ok := al.index[key]; ok {
|
||||
return
|
||||
}
|
||||
|
||||
al.list = append(al.list, item)
|
||||
al.index[key] = len(al.list) - 1
|
||||
|
||||
f, err := os.Create(al.filepath)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msgf("Failed to write allow list: %s", al.filepath)
|
||||
|
80
serv/args.go
80
serv/args.go
@ -2,63 +2,46 @@ package serv
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/dosco/super-graph/jsn"
|
||||
)
|
||||
|
||||
func argMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
|
||||
func argMap(ctx context.Context, vars []byte) func(w io.Writer, tag string) (int, error) {
|
||||
return func(w io.Writer, tag string) (int, error) {
|
||||
switch tag {
|
||||
case "user_id_provider":
|
||||
if v := ctx.Value(userIDProviderKey); v != nil {
|
||||
return stringArg(w, v.(string))
|
||||
return io.WriteString(w, v.(string))
|
||||
}
|
||||
io.WriteString(w, "null")
|
||||
return 0, nil
|
||||
return 0, errors.New("query requires variable $user_id_provider")
|
||||
|
||||
case "user_id":
|
||||
if v := ctx.Value(userIDKey); v != nil {
|
||||
return stringArg(w, v.(string))
|
||||
return io.WriteString(w, v.(string))
|
||||
}
|
||||
|
||||
io.WriteString(w, "null")
|
||||
return 0, nil
|
||||
return 0, errors.New("query requires variable $user_id")
|
||||
|
||||
case "user_role":
|
||||
if v := ctx.Value(userRoleKey); v != nil {
|
||||
return stringArg(w, v.(string))
|
||||
return io.WriteString(w, v.(string))
|
||||
}
|
||||
io.WriteString(w, "null")
|
||||
return 0, errors.New("query requires variable $user_role")
|
||||
}
|
||||
|
||||
fields := jsn.Get(vars, [][]byte{[]byte(tag)})
|
||||
if len(fields) == 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
fields := jsn.Get(ctx.req.Vars, [][]byte{[]byte(tag)})
|
||||
if len(fields) == 0 {
|
||||
return 0, fmt.Errorf("variable '%s' not found", tag)
|
||||
}
|
||||
|
||||
is := false
|
||||
|
||||
for i := range fields[0].Value {
|
||||
c := fields[0].Value[i]
|
||||
if c != ' ' {
|
||||
is = (c == '"') || (c == '{') || (c == '[')
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if is {
|
||||
return stringArgB(w, fields[0].Value)
|
||||
}
|
||||
|
||||
w.Write(fields[0].Value)
|
||||
return 0, nil
|
||||
return w.Write(fields[0].Value)
|
||||
}
|
||||
}
|
||||
|
||||
func argList(ctx *coreContext, args [][]byte) []interface{} {
|
||||
func argList(ctx *coreContext, args [][]byte) ([]interface{}, error) {
|
||||
vars := make([]interface{}, len(args))
|
||||
|
||||
var fields map[string]interface{}
|
||||
@ -68,7 +51,7 @@ func argList(ctx *coreContext, args [][]byte) []interface{} {
|
||||
fields, _, err = jsn.Tree(ctx.req.Vars)
|
||||
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to parse variables")
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
@ -79,44 +62,33 @@ func argList(ctx *coreContext, args [][]byte) []interface{} {
|
||||
case bytes.Equal(av, []byte("user_id")):
|
||||
if v := ctx.Value(userIDKey); v != nil {
|
||||
vars[i] = v.(string)
|
||||
} else {
|
||||
return nil, errors.New("query requires variable $user_id")
|
||||
}
|
||||
|
||||
case bytes.Equal(av, []byte("user_id_provider")):
|
||||
if v := ctx.Value(userIDProviderKey); v != nil {
|
||||
vars[i] = v.(string)
|
||||
} else {
|
||||
return nil, errors.New("query requires variable $user_id_provider")
|
||||
}
|
||||
|
||||
case bytes.Equal(av, []byte("user_role")):
|
||||
if v := ctx.Value(userRoleKey); v != nil {
|
||||
vars[i] = v.(string)
|
||||
} else {
|
||||
return nil, errors.New("query requires variable $user_role")
|
||||
}
|
||||
|
||||
default:
|
||||
if v, ok := fields[string(av)]; ok {
|
||||
vars[i] = v
|
||||
} else {
|
||||
return nil, fmt.Errorf("query requires variable $%s", string(av))
|
||||
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return vars
|
||||
}
|
||||
|
||||
func stringArg(w io.Writer, v string) (int, error) {
|
||||
if n, err := w.Write([]byte(`'`)); err != nil {
|
||||
return n, err
|
||||
}
|
||||
if n, err := w.Write([]byte(v)); err != nil {
|
||||
return n, err
|
||||
}
|
||||
return w.Write([]byte(`'`))
|
||||
}
|
||||
|
||||
func stringArgB(w io.Writer, v []byte) (int, error) {
|
||||
if n, err := w.Write([]byte(`'`)); err != nil {
|
||||
return n, err
|
||||
}
|
||||
if n, err := w.Write(v); err != nil {
|
||||
return n, err
|
||||
}
|
||||
return w.Write([]byte(`'`))
|
||||
return vars, nil
|
||||
}
|
||||
|
@ -35,7 +35,7 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
case len(publicKeyFile) != 0:
|
||||
kd, err := ioutil.ReadFile(publicKeyFile)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Send()
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
|
||||
switch conf.Auth.JWT.PubKeyType {
|
||||
@ -51,7 +51,7 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Send()
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
}
|
||||
|
||||
@ -95,8 +95,11 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
} else {
|
||||
ctx = context.WithValue(ctx, userIDKey, claims.Subject)
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
@ -15,11 +15,11 @@ import (
|
||||
func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
cookie := conf.Auth.Cookie
|
||||
if len(cookie) == 0 {
|
||||
logger.Fatal().Msg("no auth.cookie defined")
|
||||
errlog.Fatal().Msg("no auth.cookie defined")
|
||||
}
|
||||
|
||||
if len(conf.Auth.Rails.URL) == 0 {
|
||||
logger.Fatal().Msg("no auth.rails.url defined")
|
||||
errlog.Fatal().Msg("no auth.rails.url defined")
|
||||
}
|
||||
|
||||
rp := &redis.Pool{
|
||||
@ -28,13 +28,13 @@ func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
Dial: func() (redis.Conn, error) {
|
||||
c, err := redis.DialURL(conf.Auth.Rails.URL)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Send()
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
|
||||
pwd := conf.Auth.Rails.Password
|
||||
if len(pwd) != 0 {
|
||||
if _, err := c.Do("AUTH", pwd); err != nil {
|
||||
logger.Fatal().Err(err).Send()
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
}
|
||||
return c, err
|
||||
@ -69,16 +69,16 @@ func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
cookie := conf.Auth.Cookie
|
||||
if len(cookie) == 0 {
|
||||
logger.Fatal().Msg("no auth.cookie defined")
|
||||
errlog.Fatal().Msg("no auth.cookie defined")
|
||||
}
|
||||
|
||||
if len(conf.Auth.Rails.URL) == 0 {
|
||||
logger.Fatal().Msg("no auth.rails.url defined")
|
||||
errlog.Fatal().Msg("no auth.rails.url defined")
|
||||
}
|
||||
|
||||
rURL, err := url.Parse(conf.Auth.Rails.URL)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Send()
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
|
||||
mc := memcache.New(rURL.Host)
|
||||
@ -111,12 +111,12 @@ func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
cookie := conf.Auth.Cookie
|
||||
if len(cookie) == 0 {
|
||||
logger.Fatal().Msg("no auth.cookie defined")
|
||||
errlog.Fatal().Msg("no auth.cookie defined")
|
||||
}
|
||||
|
||||
ra, err := railsAuth(conf)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Send()
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
|
48
serv/cmd.go
48
serv/cmd.go
@ -22,16 +22,18 @@ const (
|
||||
)
|
||||
|
||||
var (
|
||||
logger *zerolog.Logger
|
||||
logger zerolog.Logger
|
||||
errlog zerolog.Logger
|
||||
conf *config
|
||||
confPath string
|
||||
db *pgxpool.Pool
|
||||
schema *psql.DBSchema
|
||||
qcompile *qcode.Compiler
|
||||
pcompile *psql.Compiler
|
||||
)
|
||||
|
||||
func Init() {
|
||||
logger = initLog()
|
||||
initLog()
|
||||
|
||||
rootCmd := &cobra.Command{
|
||||
Use: "super-graph",
|
||||
@ -110,6 +112,13 @@ e.g. db:migrate -+1
|
||||
Run: cmdDBSetup,
|
||||
})
|
||||
|
||||
rootCmd.AddCommand(&cobra.Command{
|
||||
Use: "db:reset",
|
||||
Short: "Reset database",
|
||||
Long: "This command will drop, create, migrate and seed the database (won't run in production)",
|
||||
Run: cmdDBReset,
|
||||
})
|
||||
|
||||
rootCmd.AddCommand(&cobra.Command{
|
||||
Use: "new APP-NAME",
|
||||
Short: "Create a new application",
|
||||
@ -128,19 +137,14 @@ e.g. db:migrate -+1
|
||||
"path", "./config", "path to config files")
|
||||
|
||||
if err := rootCmd.Execute(); err != nil {
|
||||
logger.Fatal().Err(err).Send()
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
}
|
||||
|
||||
func initLog() *zerolog.Logger {
|
||||
func initLog() {
|
||||
out := zerolog.ConsoleWriter{Out: os.Stderr}
|
||||
logger := zerolog.New(out).
|
||||
With().
|
||||
Timestamp().
|
||||
Caller().
|
||||
Logger()
|
||||
|
||||
return &logger
|
||||
logger = zerolog.New(out).With().Timestamp().Logger()
|
||||
errlog = logger.With().Caller().Logger()
|
||||
}
|
||||
|
||||
func initConf() (*config, error) {
|
||||
@ -159,7 +163,7 @@ func initConf() (*config, error) {
|
||||
}
|
||||
|
||||
if vi.IsSet("inherits") {
|
||||
logger.Fatal().Msgf("inherited config (%s) cannot itself inherit (%s)",
|
||||
errlog.Fatal().Msgf("inherited config (%s) cannot itself inherit (%s)",
|
||||
inherits,
|
||||
vi.GetString("inherits"))
|
||||
}
|
||||
@ -176,7 +180,7 @@ func initConf() (*config, error) {
|
||||
|
||||
logLevel, err := zerolog.ParseLevel(c.LogLevel)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msg("error setting log_level")
|
||||
errlog.Error().Err(err).Msg("error setting log_level")
|
||||
}
|
||||
zerolog.SetGlobalLevel(logLevel)
|
||||
|
||||
@ -211,7 +215,7 @@ func initDB(c *config, useDB bool) (*pgx.Conn, error) {
|
||||
config.LogLevel = pgx.LogLevelNone
|
||||
}
|
||||
|
||||
config.Logger = NewSQLLogger(*logger)
|
||||
config.Logger = NewSQLLogger(logger)
|
||||
|
||||
db, err := pgx.ConnectConfig(context.Background(), config)
|
||||
if err != nil {
|
||||
@ -246,7 +250,7 @@ func initDBPool(c *config) (*pgxpool.Pool, error) {
|
||||
config.ConnConfig.LogLevel = pgx.LogLevelNone
|
||||
}
|
||||
|
||||
config.ConnConfig.Logger = NewSQLLogger(*logger)
|
||||
config.ConnConfig.Logger = NewSQLLogger(logger)
|
||||
|
||||
// if c.DB.MaxRetries != 0 {
|
||||
// opt.MaxRetries = c.DB.MaxRetries
|
||||
@ -269,10 +273,20 @@ func initCompiler() {
|
||||
|
||||
qcompile, pcompile, err = initCompilers(conf)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to initialize compilers")
|
||||
errlog.Fatal().Err(err).Msg("failed to initialize compilers")
|
||||
}
|
||||
|
||||
if err := initResolvers(); err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to initialized resolvers")
|
||||
errlog.Fatal().Err(err).Msg("failed to initialized resolvers")
|
||||
}
|
||||
}
|
||||
|
||||
func initConfOnce() {
|
||||
var err error
|
||||
|
||||
if conf == nil {
|
||||
if conf, err = initConf(); err != nil {
|
||||
errlog.Fatal().Err(err).Msg("failed to read config")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -17,11 +17,11 @@ func cmdConfDump(cmd *cobra.Command, args []string) {
|
||||
|
||||
conf, err := initConf()
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to read config")
|
||||
errlog.Fatal().Err(err).Msg("failed to read config")
|
||||
}
|
||||
|
||||
if err := conf.Viper.WriteConfigAs(fname); err != nil {
|
||||
logger.Fatal().Err(err).Send()
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
|
||||
logger.Info().Msgf("config dumped to ./%s", fname)
|
||||
|
@ -36,6 +36,7 @@ var newMigrationText = `-- Write your migrate up statements here
|
||||
`
|
||||
|
||||
func cmdDBSetup(cmd *cobra.Command, args []string) {
|
||||
initConfOnce()
|
||||
cmdDBCreate(cmd, []string{})
|
||||
cmdDBMigrate(cmd, []string{"up"})
|
||||
|
||||
@ -48,24 +49,30 @@ func cmdDBSetup(cmd *cobra.Command, args []string) {
|
||||
}
|
||||
|
||||
if os.IsNotExist(err) == false {
|
||||
logger.Fatal().Err(err).Msgf("unable to check if '%s' exists", sfile)
|
||||
errlog.Fatal().Err(err).Msgf("unable to check if '%s' exists", sfile)
|
||||
}
|
||||
|
||||
logger.Warn().Msgf("failed to read seed file '%s'", sfile)
|
||||
}
|
||||
|
||||
func cmdDBCreate(cmd *cobra.Command, args []string) {
|
||||
var err error
|
||||
func cmdDBReset(cmd *cobra.Command, args []string) {
|
||||
initConfOnce()
|
||||
|
||||
if conf, err = initConf(); err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to read config")
|
||||
if conf.Production {
|
||||
errlog.Fatal().Msg("db:reset does not work in production")
|
||||
return
|
||||
}
|
||||
cmdDBDrop(cmd, []string{})
|
||||
cmdDBSetup(cmd, []string{})
|
||||
}
|
||||
|
||||
func cmdDBCreate(cmd *cobra.Command, args []string) {
|
||||
initConfOnce()
|
||||
ctx := context.Background()
|
||||
|
||||
conn, err := initDB(conf, false)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to connect to database")
|
||||
errlog.Fatal().Err(err).Msg("failed to connect to database")
|
||||
}
|
||||
defer conn.Close(ctx)
|
||||
|
||||
@ -73,24 +80,19 @@ func cmdDBCreate(cmd *cobra.Command, args []string) {
|
||||
|
||||
_, err = conn.Exec(ctx, sql)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to create database")
|
||||
errlog.Fatal().Err(err).Msg("failed to create database")
|
||||
}
|
||||
|
||||
logger.Info().Msgf("created database '%s'", conf.DB.DBName)
|
||||
}
|
||||
|
||||
func cmdDBDrop(cmd *cobra.Command, args []string) {
|
||||
var err error
|
||||
|
||||
if conf, err = initConf(); err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to read config")
|
||||
}
|
||||
|
||||
initConfOnce()
|
||||
ctx := context.Background()
|
||||
|
||||
conn, err := initDB(conf, false)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to connect to database")
|
||||
errlog.Fatal().Err(err).Msg("failed to connect to database")
|
||||
}
|
||||
defer conn.Close(ctx)
|
||||
|
||||
@ -98,7 +100,7 @@ func cmdDBDrop(cmd *cobra.Command, args []string) {
|
||||
|
||||
_, err = conn.Exec(ctx, sql)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to create database")
|
||||
errlog.Fatal().Err(err).Msg("failed to create database")
|
||||
}
|
||||
|
||||
logger.Info().Msgf("dropped database '%s'", conf.DB.DBName)
|
||||
@ -110,12 +112,7 @@ func cmdDBNew(cmd *cobra.Command, args []string) {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
var err error
|
||||
|
||||
if conf, err = initConf(); err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to read config")
|
||||
}
|
||||
|
||||
initConfOnce()
|
||||
name := args[0]
|
||||
|
||||
m, err := migrate.FindMigrations(conf.MigrationsPath)
|
||||
@ -124,7 +121,7 @@ func cmdDBNew(cmd *cobra.Command, args []string) {
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
mname := fmt.Sprintf("%03d_%s.sql", len(m)+100, name)
|
||||
mname := fmt.Sprintf("%d_%s.sql", len(m), name)
|
||||
|
||||
// Write new migration
|
||||
mpath := filepath.Join(conf.MigrationsPath, mname)
|
||||
@ -144,39 +141,34 @@ func cmdDBNew(cmd *cobra.Command, args []string) {
|
||||
}
|
||||
|
||||
func cmdDBMigrate(cmd *cobra.Command, args []string) {
|
||||
var err error
|
||||
|
||||
if len(args) == 0 {
|
||||
cmd.Help()
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
initConfOnce()
|
||||
dest := args[0]
|
||||
|
||||
if conf, err = initConf(); err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to read config")
|
||||
}
|
||||
|
||||
conn, err := initDB(conf, true)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to connect to database")
|
||||
errlog.Fatal().Err(err).Msg("failed to connect to database")
|
||||
}
|
||||
defer conn.Close(context.Background())
|
||||
|
||||
m, err := migrate.NewMigrator(conn, "schema_version")
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to initializing migrator")
|
||||
errlog.Fatal().Err(err).Msg("failed to initializing migrator")
|
||||
}
|
||||
|
||||
m.Data = getMigrationVars()
|
||||
|
||||
err = m.LoadMigrations(conf.MigrationsPath)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to load migrations")
|
||||
errlog.Fatal().Err(err).Msg("failed to load migrations")
|
||||
}
|
||||
|
||||
if len(m.Migrations) == 0 {
|
||||
logger.Fatal().Msg("No migrations found")
|
||||
errlog.Fatal().Msg("No migrations found")
|
||||
}
|
||||
|
||||
m.OnStart = func(sequence int32, name, direction, sql string) {
|
||||
@ -195,7 +187,7 @@ func cmdDBMigrate(cmd *cobra.Command, args []string) {
|
||||
var n int64
|
||||
n, err = strconv.ParseInt(d, 10, 32)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("invalid destination")
|
||||
errlog.Fatal().Err(err).Msg("invalid destination")
|
||||
}
|
||||
return int32(n)
|
||||
}
|
||||
@ -226,17 +218,15 @@ func cmdDBMigrate(cmd *cobra.Command, args []string) {
|
||||
if err != nil {
|
||||
logger.Info().Err(err).Send()
|
||||
|
||||
// logger.Info().Err(err).Send()
|
||||
|
||||
// if err, ok := err.(m.MigrationPgError); ok {
|
||||
// if err.Detail != "" {
|
||||
// logger.Info().Err(err).Msg(err.Detail)
|
||||
// info.Err(err).Msg(err.Detail)
|
||||
// }
|
||||
|
||||
// if err.Position != 0 {
|
||||
// ele, err := ExtractErrorLine(err.Sql, int(err.Position))
|
||||
// if err != nil {
|
||||
// logger.Fatal().Err(err).Send()
|
||||
// errlog.Fatal().Err(err).Send()
|
||||
// }
|
||||
|
||||
// prefix := fmt.Sprintf()
|
||||
@ -251,37 +241,33 @@ func cmdDBMigrate(cmd *cobra.Command, args []string) {
|
||||
}
|
||||
|
||||
func cmdDBStatus(cmd *cobra.Command, args []string) {
|
||||
var err error
|
||||
|
||||
if conf, err = initConf(); err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to read config")
|
||||
}
|
||||
initConfOnce()
|
||||
|
||||
conn, err := initDB(conf, true)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to connect to database")
|
||||
errlog.Fatal().Err(err).Msg("failed to connect to database")
|
||||
}
|
||||
defer conn.Close(context.Background())
|
||||
|
||||
m, err := migrate.NewMigrator(conn, "schema_version")
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to initialize migrator")
|
||||
errlog.Fatal().Err(err).Msg("failed to initialize migrator")
|
||||
}
|
||||
|
||||
m.Data = getMigrationVars()
|
||||
|
||||
err = m.LoadMigrations(conf.MigrationsPath)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to load migrations")
|
||||
errlog.Fatal().Err(err).Msg("failed to load migrations")
|
||||
}
|
||||
|
||||
if len(m.Migrations) == 0 {
|
||||
logger.Fatal().Msg("no migrations found")
|
||||
errlog.Fatal().Msg("no migrations found")
|
||||
}
|
||||
|
||||
mver, err := m.GetCurrentVersion()
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to retrieve migration")
|
||||
errlog.Fatal().Err(err).Msg("failed to retrieve migration")
|
||||
}
|
||||
|
||||
var status string
|
||||
|
@ -134,12 +134,12 @@ func ifNotExists(filePath string, doFn func(string) error) {
|
||||
}
|
||||
|
||||
if os.IsNotExist(err) == false {
|
||||
logger.Fatal().Err(err).Msgf("unable to check if '%s' exists", filePath)
|
||||
errlog.Fatal().Err(err).Msgf("unable to check if '%s' exists", filePath)
|
||||
}
|
||||
|
||||
err = doFn(filePath)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msgf("unable to create '%s'", filePath)
|
||||
errlog.Fatal().Err(err).Msgf("unable to create '%s'", filePath)
|
||||
}
|
||||
|
||||
logger.Info().Msgf("created '%s'", filePath)
|
||||
|
@ -1,6 +1,7 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
@ -12,20 +13,21 @@ import (
|
||||
"github.com/brianvoe/gofakeit"
|
||||
"github.com/dop251/goja"
|
||||
"github.com/spf13/cobra"
|
||||
"github.com/valyala/fasttemplate"
|
||||
)
|
||||
|
||||
func cmdDBSeed(cmd *cobra.Command, args []string) {
|
||||
var err error
|
||||
|
||||
if conf, err = initConf(); err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to read config")
|
||||
errlog.Fatal().Err(err).Msg("failed to read config")
|
||||
}
|
||||
|
||||
conf.Production = false
|
||||
|
||||
db, err = initDBPool(conf)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to connect to database")
|
||||
errlog.Fatal().Err(err).Msg("failed to connect to database")
|
||||
}
|
||||
|
||||
initCompiler()
|
||||
@ -34,7 +36,7 @@ func cmdDBSeed(cmd *cobra.Command, args []string) {
|
||||
|
||||
b, err := ioutil.ReadFile(path.Join(confPath, conf.SeedFile))
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msgf("failed to read seed file '%s'", sfile)
|
||||
errlog.Fatal().Err(err).Msgf("failed to read seed file '%s'", sfile)
|
||||
}
|
||||
|
||||
vm := goja.New()
|
||||
@ -50,34 +52,77 @@ func cmdDBSeed(cmd *cobra.Command, args []string) {
|
||||
|
||||
_, err = vm.RunScript("seed.js", string(b))
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to execute script")
|
||||
errlog.Fatal().Err(err).Msg("failed to execute script")
|
||||
}
|
||||
|
||||
logger.Info().Msg("seed script done")
|
||||
}
|
||||
|
||||
//func runFunc(call goja.FunctionCall) {
|
||||
func graphQLFunc(query string, data interface{}) map[string]interface{} {
|
||||
b, err := json.Marshal(data)
|
||||
func graphQLFunc(query string, data interface{}, opt map[string]string) map[string]interface{} {
|
||||
vars, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to json serialize")
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
|
||||
c := &coreContext{Context: context.Background()}
|
||||
c.req.Query = query
|
||||
c.req.Vars = b
|
||||
c.req.role = "user"
|
||||
c := context.Background()
|
||||
|
||||
res, err := c.execQuery()
|
||||
if v, ok := opt["user_id"]; ok && len(v) != 0 {
|
||||
c = context.WithValue(c, userIDKey, v)
|
||||
}
|
||||
|
||||
var role string
|
||||
|
||||
if v, ok := opt["role"]; ok && len(v) != 0 {
|
||||
role = v
|
||||
} else {
|
||||
role = "user"
|
||||
}
|
||||
|
||||
stmts, err := buildRoleStmt([]byte(query), vars, role)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("graphql query failed")
|
||||
errlog.Fatal().Err(err).Msg("graphql query failed")
|
||||
}
|
||||
st := stmts[0]
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
t := fasttemplate.New(st.sql, openVar, closeVar)
|
||||
_, err = t.ExecuteFunc(buf, argMap(c, vars))
|
||||
|
||||
if err != nil {
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
|
||||
finalSQL := buf.String()
|
||||
|
||||
tx, err := db.Begin(c)
|
||||
if err != nil {
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
defer tx.Rollback(c)
|
||||
|
||||
if conf.DB.SetUserID {
|
||||
if err := setLocalUserID(c, tx); err != nil {
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
}
|
||||
|
||||
var root []byte
|
||||
|
||||
if err = tx.QueryRow(c, finalSQL).Scan(&root); err != nil {
|
||||
errlog.Fatal().Err(err).Msg("sql query failed")
|
||||
}
|
||||
|
||||
if err := tx.Commit(c); err != nil {
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
|
||||
val := make(map[string]interface{})
|
||||
|
||||
err = json.Unmarshal(res, &val)
|
||||
err = json.Unmarshal(root, &val)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to deserialize json")
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
|
||||
return val
|
||||
@ -156,10 +201,9 @@ func setFakeFuncs(f *goja.Object) {
|
||||
f.Set("transmission_gear_type", gofakeit.TransmissionGearType)
|
||||
|
||||
// Text
|
||||
|
||||
f.Set("word", gofakeit.Word)
|
||||
f.Set("sentence", gofakeit.Sentence)
|
||||
f.Set("paragrph", gofakeit.Paragraph)
|
||||
f.Set("paragraph", gofakeit.Paragraph)
|
||||
f.Set("question", gofakeit.Question)
|
||||
f.Set("quote", gofakeit.Quote)
|
||||
|
||||
|
@ -8,12 +8,12 @@ func cmdServ(cmd *cobra.Command, args []string) {
|
||||
var err error
|
||||
|
||||
if conf, err = initConf(); err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to read config")
|
||||
errlog.Fatal().Err(err).Msg("failed to read config")
|
||||
}
|
||||
|
||||
db, err = initDBPool(conf)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to connect to database")
|
||||
errlog.Fatal().Err(err).Msg("failed to connect to database")
|
||||
}
|
||||
|
||||
initCompiler()
|
||||
|
@ -64,17 +64,12 @@ type config struct {
|
||||
User string
|
||||
Password string
|
||||
Schema string
|
||||
PoolSize int32 `mapstructure:"pool_size"`
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
LogLevel string `mapstructure:"log_level"`
|
||||
SetUserID bool `mapstructure:"set_user_id"`
|
||||
PoolSize int32 `mapstructure:"pool_size"`
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
SetUserID bool `mapstructure:"set_user_id"`
|
||||
|
||||
Vars map[string]string `mapstructure:"variables"`
|
||||
|
||||
Defaults struct {
|
||||
Filters []string
|
||||
Blocklist []string
|
||||
}
|
||||
Vars map[string]string `mapstructure:"variables"`
|
||||
Blocklist []string
|
||||
|
||||
Tables []configTable
|
||||
} `mapstructure:"database"`
|
||||
@ -83,6 +78,7 @@ type config struct {
|
||||
|
||||
RolesQuery string `mapstructure:"roles_query"`
|
||||
Roles []configRole
|
||||
roles map[string]*configRole
|
||||
}
|
||||
|
||||
type configTable struct {
|
||||
@ -221,16 +217,15 @@ func (c *config) Init(vi *viper.Viper) error {
|
||||
}
|
||||
|
||||
c.RolesQuery = sanitize(c.RolesQuery)
|
||||
|
||||
rolesMap := make(map[string]struct{})
|
||||
c.roles = make(map[string]*configRole)
|
||||
|
||||
for i := range c.Roles {
|
||||
role := &c.Roles[i]
|
||||
|
||||
if _, ok := rolesMap[role.Name]; ok {
|
||||
logger.Fatal().Msgf("duplicate role '%s' found", role.Name)
|
||||
if _, ok := c.roles[role.Name]; ok {
|
||||
errlog.Fatal().Msgf("duplicate role '%s' found", role.Name)
|
||||
}
|
||||
role.Name = sanitize(role.Name)
|
||||
role.Name = strings.ToLower(role.Name)
|
||||
role.Match = sanitize(role.Match)
|
||||
role.tablesMap = make(map[string]*configRoleTable)
|
||||
|
||||
@ -238,14 +233,16 @@ func (c *config) Init(vi *viper.Viper) error {
|
||||
role.tablesMap[table.Name] = &role.Tables[n]
|
||||
}
|
||||
|
||||
rolesMap[role.Name] = struct{}{}
|
||||
c.roles[role.Name] = role
|
||||
}
|
||||
|
||||
if _, ok := rolesMap["user"]; !ok {
|
||||
c.Roles = append(c.Roles, configRole{Name: "user"})
|
||||
if _, ok := c.roles["user"]; !ok {
|
||||
u := configRole{Name: "user"}
|
||||
c.Roles = append(c.Roles, u)
|
||||
c.roles["user"] = &u
|
||||
}
|
||||
|
||||
if _, ok := rolesMap["anon"]; !ok {
|
||||
if _, ok := c.roles["anon"]; !ok {
|
||||
logger.Warn().Msg("unauthenticated requests will be blocked. no role 'anon' defined")
|
||||
c.AuthFailBlock = true
|
||||
}
|
||||
@ -262,7 +259,7 @@ func (c *config) validate() {
|
||||
name := c.Roles[i].Name
|
||||
|
||||
if _, ok := rm[name]; ok {
|
||||
logger.Fatal().Msgf("duplicate config for role '%s'", c.Roles[i].Name)
|
||||
errlog.Fatal().Msgf("duplicate config for role '%s'", c.Roles[i].Name)
|
||||
}
|
||||
rm[name] = struct{}{}
|
||||
}
|
||||
@ -273,7 +270,7 @@ func (c *config) validate() {
|
||||
name := c.Tables[i].Name
|
||||
|
||||
if _, ok := tm[name]; ok {
|
||||
logger.Fatal().Msgf("duplicate config for table '%s'", c.Tables[i].Name)
|
||||
errlog.Fatal().Msgf("duplicate config for table '%s'", c.Tables[i].Name)
|
||||
}
|
||||
tm[name] = struct{}{}
|
||||
}
|
||||
|
413
serv/core.go
413
serv/core.go
@ -8,11 +8,9 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/dosco/super-graph/jsn"
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/valyala/fasttemplate"
|
||||
@ -32,6 +30,10 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
|
||||
c.req.ref = req.Referer()
|
||||
c.req.hdr = req.Header
|
||||
|
||||
if len(c.req.Vars) == 2 {
|
||||
c.req.Vars = nil
|
||||
}
|
||||
|
||||
if authCheck(c) {
|
||||
c.req.role = "user"
|
||||
} else {
|
||||
@ -47,90 +49,55 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
|
||||
}
|
||||
|
||||
func (c *coreContext) execQuery() ([]byte, error) {
|
||||
var err error
|
||||
var skipped uint32
|
||||
var qc *qcode.QCode
|
||||
var data []byte
|
||||
|
||||
logger.Debug().Str("role", c.req.role).Msg(c.req.Query)
|
||||
var st *stmt
|
||||
var err error
|
||||
|
||||
if conf.Production {
|
||||
var ps *preparedItem
|
||||
|
||||
data, ps, err = c.resolvePreparedSQL()
|
||||
data, st, err = c.resolvePreparedSQL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
logger.Error().
|
||||
Err(err).
|
||||
Str("default_role", c.req.role).
|
||||
Msg(c.req.Query)
|
||||
|
||||
skipped = ps.skipped
|
||||
qc = ps.qc
|
||||
return nil, errors.New("query failed. check logs for error")
|
||||
}
|
||||
|
||||
} else {
|
||||
|
||||
data, skipped, err = c.resolveSQL()
|
||||
if err != nil {
|
||||
if data, st, err = c.resolveSQL(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(data) == 0 || skipped == 0 {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
sel := qc.Selects
|
||||
h := xxhash.New()
|
||||
|
||||
// fetch the field name used within the db response json
|
||||
// that are used to mark insertion points and the mapping between
|
||||
// those field names and their select objects
|
||||
fids, sfmap := parentFieldIds(h, sel, skipped)
|
||||
|
||||
// fetch the field values of the marked insertion points
|
||||
// these values contain the id to be used with fetching remote data
|
||||
from := jsn.Get(data, fids)
|
||||
|
||||
var to []jsn.Field
|
||||
switch {
|
||||
case len(from) == 1:
|
||||
to, err = c.resolveRemote(c.req.hdr, h, from[0], sel, sfmap)
|
||||
|
||||
case len(from) > 1:
|
||||
to, err = c.resolveRemotes(c.req.hdr, h, from, sel, sfmap)
|
||||
|
||||
default:
|
||||
return nil, errors.New("something wrong no remote ids found in db response")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ob bytes.Buffer
|
||||
|
||||
err = jsn.Replace(&ob, data, from, to)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ob.Bytes(), nil
|
||||
return execRemoteJoin(st, data, c.req.hdr)
|
||||
}
|
||||
|
||||
func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
|
||||
tx, err := db.Begin(c)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) {
|
||||
var tx pgx.Tx
|
||||
var err error
|
||||
|
||||
qt := qcode.GetQType(c.req.Query)
|
||||
mutation := (qt == qcode.QTMutation)
|
||||
anonQuery := (qt == qcode.QTQuery && c.req.role == "anon")
|
||||
|
||||
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
|
||||
useTx := useRoleQuery || conf.DB.SetUserID
|
||||
|
||||
if useTx {
|
||||
if tx, err = db.Begin(c); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer tx.Rollback(c)
|
||||
}
|
||||
defer tx.Rollback(c)
|
||||
|
||||
if conf.DB.SetUserID {
|
||||
if err := c.setLocalUserID(tx); err != nil {
|
||||
if err := setLocalUserID(c, tx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var role string
|
||||
mutation := isMutation(c.req.Query)
|
||||
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
|
||||
|
||||
if useRoleQuery {
|
||||
if role, err = c.executeRoleQuery(tx); err != nil {
|
||||
@ -140,7 +107,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
|
||||
} else if v := c.Value(userRoleKey); v != nil {
|
||||
role = v.(string)
|
||||
|
||||
} else if mutation {
|
||||
} else {
|
||||
role = c.req.role
|
||||
|
||||
}
|
||||
@ -151,69 +118,92 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
|
||||
}
|
||||
|
||||
var root []byte
|
||||
vars := argList(c, ps.args)
|
||||
var row pgx.Row
|
||||
|
||||
if mutation {
|
||||
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
|
||||
} else {
|
||||
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&c.req.role, &root)
|
||||
}
|
||||
vars, err := argList(c, ps.args)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(c); err != nil {
|
||||
if useTx {
|
||||
row = tx.QueryRow(c, ps.sd.SQL, vars...)
|
||||
} else {
|
||||
row = db.QueryRow(c, ps.sd.SQL, vars...)
|
||||
}
|
||||
|
||||
if mutation || anonQuery {
|
||||
err = row.Scan(&root)
|
||||
} else {
|
||||
err = row.Scan(&role, &root)
|
||||
}
|
||||
|
||||
if len(role) == 0 {
|
||||
logger.Debug().Str("default_role", c.req.role).Msg(c.req.Query)
|
||||
} else {
|
||||
logger.Debug().Str("default_role", c.req.role).Str("role", role).Msg(c.req.Query)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return root, ps, nil
|
||||
c.req.role = role
|
||||
|
||||
if useTx {
|
||||
if err := tx.Commit(c); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return root, ps.st, nil
|
||||
}
|
||||
|
||||
func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
|
||||
tx, err := db.Begin(c)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer tx.Rollback(c)
|
||||
func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
|
||||
var tx pgx.Tx
|
||||
var err error
|
||||
|
||||
qt := qcode.GetQType(c.req.Query)
|
||||
mutation := (qt == qcode.QTMutation)
|
||||
//anonQuery := (qt == qcode.QTQuery && c.req.role == "anon")
|
||||
|
||||
mutation := isMutation(c.req.Query)
|
||||
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
|
||||
useTx := useRoleQuery || conf.DB.SetUserID
|
||||
|
||||
if useTx {
|
||||
if tx, err = db.Begin(c); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer tx.Rollback(c)
|
||||
}
|
||||
|
||||
if conf.DB.SetUserID {
|
||||
if err := setLocalUserID(c, tx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if useRoleQuery {
|
||||
if c.req.role, err = c.executeRoleQuery(tx); err != nil {
|
||||
return nil, 0, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
} else if v := c.Value(userRoleKey); v != nil {
|
||||
c.req.role = v.(string)
|
||||
}
|
||||
|
||||
stmts, err := c.buildStmt()
|
||||
stmts, err := buildStmt(qt, []byte(c.req.Query), c.req.Vars, c.req.role)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var st *stmt
|
||||
|
||||
if mutation {
|
||||
st = findStmt(c.req.role, stmts)
|
||||
} else {
|
||||
st = &stmts[0]
|
||||
return nil, nil, err
|
||||
}
|
||||
st := &stmts[0]
|
||||
|
||||
t := fasttemplate.New(st.sql, openVar, closeVar)
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
_, err = t.ExecuteFunc(buf, argMap(c))
|
||||
|
||||
if err == errNoUserID {
|
||||
logger.Warn().Msg("no user id found. query requires an authenicated request")
|
||||
}
|
||||
|
||||
_, err = t.ExecuteFunc(buf, argMap(c, c.req.Vars))
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
finalSQL := buf.String()
|
||||
|
||||
var stime time.Time
|
||||
@ -222,192 +212,57 @@ func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
|
||||
stime = time.Now()
|
||||
}
|
||||
|
||||
if conf.DB.SetUserID {
|
||||
if err := c.setLocalUserID(tx); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
var root []byte
|
||||
var role string
|
||||
var row pgx.Row
|
||||
|
||||
if mutation {
|
||||
err = tx.QueryRow(c, finalSQL).Scan(&root)
|
||||
defaultRole := c.req.role
|
||||
|
||||
if useTx {
|
||||
row = tx.QueryRow(c, finalSQL)
|
||||
} else {
|
||||
err = tx.QueryRow(c, finalSQL).Scan(&c.req.role, &root)
|
||||
row = db.QueryRow(c, finalSQL)
|
||||
}
|
||||
|
||||
if len(stmts) == 1 {
|
||||
err = row.Scan(&root)
|
||||
} else {
|
||||
err = row.Scan(&role, &root)
|
||||
}
|
||||
|
||||
if len(role) == 0 {
|
||||
logger.Debug().Str("default_role", defaultRole).Msg(c.req.Query)
|
||||
} else {
|
||||
logger.Debug().Str("default_role", defaultRole).Str("role", role).Msg(c.req.Query)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(c); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if mutation {
|
||||
st = findStmt(c.req.role, stmts)
|
||||
} else {
|
||||
st = &stmts[0]
|
||||
}
|
||||
|
||||
if conf.EnableTracing && len(st.qc.Selects) != 0 {
|
||||
c.addTrace(
|
||||
st.qc.Selects,
|
||||
st.qc.Selects[0].ID,
|
||||
stime)
|
||||
if useTx {
|
||||
if err := tx.Commit(c); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if conf.Production == false {
|
||||
_allowList.add(&c.req)
|
||||
}
|
||||
|
||||
return root, st.skipped, nil
|
||||
}
|
||||
|
||||
func (c *coreContext) resolveRemote(
|
||||
hdr http.Header,
|
||||
h *xxhash.Digest,
|
||||
field jsn.Field,
|
||||
sel []qcode.Select,
|
||||
sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) {
|
||||
|
||||
// replacement data for the marked insertion points
|
||||
// key and value will be replaced by whats below
|
||||
toA := [1]jsn.Field{}
|
||||
to := toA[:1]
|
||||
|
||||
// use the json key to find the related Select object
|
||||
k1 := xxhash.Sum64(field.Key)
|
||||
|
||||
s, ok := sfmap[k1]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
p := sel[s.ParentID]
|
||||
|
||||
// then use the Table nme in the Select and it's parent
|
||||
// to find the resolver to use for this relationship
|
||||
k2 := mkkey(h, s.Table, p.Table)
|
||||
|
||||
r, ok := rmap[k2]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
id := jsn.Value(field.Value)
|
||||
if len(id) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
st := time.Now()
|
||||
|
||||
b, err := r.Fn(hdr, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
if len(stmts) > 1 {
|
||||
if st = findStmt(role, stmts); st == nil {
|
||||
return nil, nil, fmt.Errorf("invalid role '%s' returned", role)
|
||||
}
|
||||
}
|
||||
|
||||
if conf.EnableTracing {
|
||||
c.addTrace(sel, s.ID, st)
|
||||
for _, id := range st.qc.Roots {
|
||||
c.addTrace(st.qc.Selects, id, stime)
|
||||
}
|
||||
}
|
||||
|
||||
if len(r.Path) != 0 {
|
||||
b = jsn.Strip(b, r.Path)
|
||||
}
|
||||
|
||||
var ob bytes.Buffer
|
||||
|
||||
if len(s.Cols) != 0 {
|
||||
err = jsn.Filter(&ob, b, colsToList(s.Cols))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
} else {
|
||||
ob.WriteString("null")
|
||||
}
|
||||
|
||||
to[0] = jsn.Field{[]byte(s.FieldName), ob.Bytes()}
|
||||
return to, nil
|
||||
}
|
||||
|
||||
func (c *coreContext) resolveRemotes(
|
||||
hdr http.Header,
|
||||
h *xxhash.Digest,
|
||||
from []jsn.Field,
|
||||
sel []qcode.Select,
|
||||
sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) {
|
||||
|
||||
// replacement data for the marked insertion points
|
||||
// key and value will be replaced by whats below
|
||||
to := make([]jsn.Field, len(from))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(from))
|
||||
|
||||
var cerr error
|
||||
|
||||
for i, id := range from {
|
||||
|
||||
// use the json key to find the related Select object
|
||||
k1 := xxhash.Sum64(id.Key)
|
||||
|
||||
s, ok := sfmap[k1]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
p := sel[s.ParentID]
|
||||
|
||||
// then use the Table nme in the Select and it's parent
|
||||
// to find the resolver to use for this relationship
|
||||
k2 := mkkey(h, s.Table, p.Table)
|
||||
|
||||
r, ok := rmap[k2]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
id := jsn.Value(id.Value)
|
||||
if len(id) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
go func(n int, id []byte, s *qcode.Select) {
|
||||
defer wg.Done()
|
||||
|
||||
st := time.Now()
|
||||
|
||||
b, err := r.Fn(hdr, id)
|
||||
if err != nil {
|
||||
cerr = fmt.Errorf("%s: %s", s.Table, err)
|
||||
return
|
||||
}
|
||||
|
||||
if conf.EnableTracing {
|
||||
c.addTrace(sel, s.ID, st)
|
||||
}
|
||||
|
||||
if len(r.Path) != 0 {
|
||||
b = jsn.Strip(b, r.Path)
|
||||
}
|
||||
|
||||
var ob bytes.Buffer
|
||||
|
||||
if len(s.Cols) != 0 {
|
||||
err = jsn.Filter(&ob, b, colsToList(s.Cols))
|
||||
if err != nil {
|
||||
cerr = fmt.Errorf("%s: %s", s.Table, err)
|
||||
return
|
||||
}
|
||||
|
||||
} else {
|
||||
ob.WriteString("null")
|
||||
}
|
||||
|
||||
to[n] = jsn.Field{[]byte(s.FieldName), ob.Bytes()}
|
||||
}(i, id, s)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
return to, cerr
|
||||
return root, st, nil
|
||||
}
|
||||
|
||||
func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
|
||||
@ -421,15 +276,6 @@ func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
|
||||
return role, nil
|
||||
}
|
||||
|
||||
func (c *coreContext) setLocalUserID(tx pgx.Tx) error {
|
||||
var err error
|
||||
if v := c.Value(userIDKey); v != nil {
|
||||
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *coreContext) render(w io.Writer, data []byte) error {
|
||||
c.res.Data = json.RawMessage(data)
|
||||
return json.NewEncoder(w).Encode(c.res)
|
||||
@ -451,7 +297,7 @@ func (c *coreContext) addTrace(sel []qcode.Select, id int32, st time.Time) {
|
||||
c.res.Extensions.Tracing.Duration = du
|
||||
|
||||
n := 1
|
||||
for i := id; i != 0; i = sel[i].ParentID {
|
||||
for i := id; i != -1; i = sel[i].ParentID {
|
||||
n++
|
||||
}
|
||||
path := make([]string, n)
|
||||
@ -459,7 +305,7 @@ func (c *coreContext) addTrace(sel []qcode.Select, id int32, st time.Time) {
|
||||
n--
|
||||
for i := id; ; i = sel[i].ParentID {
|
||||
path[n] = sel[i].Table
|
||||
if sel[i].ID == 0 {
|
||||
if sel[i].ParentID == -1 {
|
||||
break
|
||||
}
|
||||
n--
|
||||
@ -521,6 +367,15 @@ func parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) (
|
||||
return fm, sm
|
||||
}
|
||||
|
||||
func setLocalUserID(c context.Context, tx pgx.Tx) error {
|
||||
var err error
|
||||
if v := c.Value(userIDKey); v != nil {
|
||||
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func isSkipped(n uint32, pos uint32) bool {
|
||||
return ((n & (1 << pos)) != 0)
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/dosco/super-graph/psql"
|
||||
@ -17,133 +18,171 @@ type stmt struct {
|
||||
sql string
|
||||
}
|
||||
|
||||
func (c *coreContext) buildStmt() ([]stmt, error) {
|
||||
var vars map[string]json.RawMessage
|
||||
func buildStmt(qt qcode.QType, gql, vars []byte, role string) ([]stmt, error) {
|
||||
switch qt {
|
||||
case qcode.QTMutation:
|
||||
return buildRoleStmt(gql, vars, role)
|
||||
|
||||
if len(c.req.Vars) != 0 {
|
||||
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
|
||||
case qcode.QTQuery:
|
||||
switch {
|
||||
case role == "anon":
|
||||
return buildRoleStmt(gql, vars, role)
|
||||
|
||||
default:
|
||||
return buildMultiStmt(gql, vars)
|
||||
}
|
||||
|
||||
default:
|
||||
return nil, fmt.Errorf("unknown query type '%d'", qt)
|
||||
}
|
||||
}
|
||||
|
||||
func buildRoleStmt(gql, vars []byte, role string) ([]stmt, error) {
|
||||
ro, ok := conf.roles[role]
|
||||
if !ok {
|
||||
return nil, fmt.Errorf(`roles '%s' not defined in config`, role)
|
||||
}
|
||||
|
||||
var vm map[string]json.RawMessage
|
||||
var err error
|
||||
|
||||
if len(vars) != 0 {
|
||||
if err := json.Unmarshal(vars, &vm); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
gql := []byte(c.req.Query)
|
||||
|
||||
if len(conf.Roles) == 0 {
|
||||
return nil, errors.New(`no roles found ('user' and 'anon' required)`)
|
||||
}
|
||||
|
||||
qc, err := qcompile.Compile(gql, conf.Roles[0].Name)
|
||||
qc, err := qcompile.Compile(gql, ro.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stmts := make([]stmt, 0, len(conf.Roles))
|
||||
mutation := (qc.Type != qcode.QTQuery)
|
||||
// For the 'anon' role in production only compile
|
||||
// queries for tables defined in the config file.
|
||||
if conf.Production &&
|
||||
ro.Name == "anon" &&
|
||||
hasTablesWithConfig(qc, ro) == false {
|
||||
return nil, errors.New("query contains tables with no 'anon' role config")
|
||||
}
|
||||
|
||||
stmts := []stmt{stmt{role: ro, qc: qc}}
|
||||
w := &bytes.Buffer{}
|
||||
|
||||
for i := 1; i < len(conf.Roles); i++ {
|
||||
skipped, err := pcompile.Compile(qc, w, psql.Variables(vm))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stmts[0].skipped = skipped
|
||||
stmts[0].sql = w.String()
|
||||
|
||||
return stmts, nil
|
||||
}
|
||||
|
||||
func buildMultiStmt(gql, vars []byte) ([]stmt, error) {
|
||||
var vm map[string]json.RawMessage
|
||||
var err error
|
||||
|
||||
if len(vars) != 0 {
|
||||
if err := json.Unmarshal(vars, &vm); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(conf.RolesQuery) == 0 {
|
||||
return buildRoleStmt(gql, vars, "user")
|
||||
}
|
||||
|
||||
stmts := make([]stmt, 0, len(conf.Roles))
|
||||
w := &bytes.Buffer{}
|
||||
|
||||
for i := 0; i < len(conf.Roles); i++ {
|
||||
role := &conf.Roles[i]
|
||||
|
||||
// For mutations only render sql for a single role from the request
|
||||
if mutation && len(c.req.role) != 0 && role.Name != c.req.role {
|
||||
continue
|
||||
}
|
||||
|
||||
qc, err = qcompile.Compile(gql, role.Name)
|
||||
qc, err := qcompile.Compile(gql, role.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if conf.Production && role.Name == "anon" {
|
||||
if _, ok := role.tablesMap[qc.Selects[0].Table]; !ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
stmts = append(stmts, stmt{role: role, qc: qc})
|
||||
|
||||
if mutation {
|
||||
skipped, err := pcompile.Compile(qc, w, psql.Variables(vars))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := &stmts[len(stmts)-1]
|
||||
s.skipped = skipped
|
||||
s.sql = w.String()
|
||||
w.Reset()
|
||||
skipped, err := pcompile.Compile(qc, w, psql.Variables(vm))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := &stmts[len(stmts)-1]
|
||||
s.skipped = skipped
|
||||
s.sql = w.String()
|
||||
w.Reset()
|
||||
}
|
||||
|
||||
if mutation {
|
||||
return stmts, nil
|
||||
sql, err := renderUserQuery(stmts, vm)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stmts[0].sql = sql
|
||||
return stmts, nil
|
||||
}
|
||||
|
||||
func renderUserQuery(
|
||||
stmts []stmt, vars map[string]json.RawMessage) (string, error) {
|
||||
|
||||
var err error
|
||||
w := &bytes.Buffer{}
|
||||
|
||||
io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
|
||||
|
||||
for _, s := range stmts {
|
||||
if len(s.role.Match) == 0 &&
|
||||
s.role.Name != "user" && s.role.Name != "anon" {
|
||||
continue
|
||||
}
|
||||
io.WriteString(w, `WHEN '`)
|
||||
io.WriteString(w, s.role.Name)
|
||||
io.WriteString(w, `' THEN (`)
|
||||
|
||||
s.skipped, err = pcompile.Compile(s.qc, w, psql.Variables(vars))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return "", err
|
||||
}
|
||||
|
||||
io.WriteString(w, `) `)
|
||||
}
|
||||
io.WriteString(w, `END) FROM (`)
|
||||
|
||||
if len(conf.RolesQuery) == 0 {
|
||||
v := c.Value(userRoleKey)
|
||||
io.WriteString(w, `END) FROM (SELECT (CASE WHEN EXISTS (`)
|
||||
io.WriteString(w, conf.RolesQuery)
|
||||
io.WriteString(w, `) THEN `)
|
||||
|
||||
io.WriteString(w, `VALUES ("`)
|
||||
if v != nil {
|
||||
io.WriteString(w, v.(string))
|
||||
} else {
|
||||
io.WriteString(w, c.req.role)
|
||||
io.WriteString(w, `(SELECT (CASE`)
|
||||
for _, s := range stmts {
|
||||
if len(s.role.Match) == 0 {
|
||||
continue
|
||||
}
|
||||
io.WriteString(w, `")) AS "_sg_auth_info"(role) LIMIT 1;`)
|
||||
|
||||
} else {
|
||||
|
||||
io.WriteString(w, `SELECT (CASE WHEN EXISTS (`)
|
||||
io.WriteString(w, conf.RolesQuery)
|
||||
io.WriteString(w, `) THEN `)
|
||||
|
||||
io.WriteString(w, `(SELECT (CASE`)
|
||||
for _, s := range stmts {
|
||||
if len(s.role.Match) == 0 {
|
||||
continue
|
||||
}
|
||||
io.WriteString(w, ` WHEN `)
|
||||
io.WriteString(w, s.role.Match)
|
||||
io.WriteString(w, ` THEN '`)
|
||||
io.WriteString(w, s.role.Name)
|
||||
io.WriteString(w, `'`)
|
||||
}
|
||||
|
||||
if len(c.req.role) == 0 {
|
||||
io.WriteString(w, ` ELSE 'anon' END) FROM (`)
|
||||
} else {
|
||||
io.WriteString(w, ` ELSE '`)
|
||||
io.WriteString(w, c.req.role)
|
||||
io.WriteString(w, `' END) FROM (`)
|
||||
}
|
||||
|
||||
io.WriteString(w, conf.RolesQuery)
|
||||
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) ELSE '`)
|
||||
if len(c.req.role) == 0 {
|
||||
io.WriteString(w, `anon`)
|
||||
} else {
|
||||
io.WriteString(w, c.req.role)
|
||||
}
|
||||
io.WriteString(w, `' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `)
|
||||
io.WriteString(w, ` WHEN `)
|
||||
io.WriteString(w, s.role.Match)
|
||||
io.WriteString(w, ` THEN '`)
|
||||
io.WriteString(w, s.role.Name)
|
||||
io.WriteString(w, `'`)
|
||||
}
|
||||
|
||||
stmts[0].sql = w.String()
|
||||
stmts[0].role = nil
|
||||
io.WriteString(w, ` ELSE 'user' END) FROM (`)
|
||||
io.WriteString(w, conf.RolesQuery)
|
||||
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `)
|
||||
io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `)
|
||||
|
||||
return stmts, nil
|
||||
return w.String(), nil
|
||||
}
|
||||
|
||||
func hasTablesWithConfig(qc *qcode.QCode, role *configRole) bool {
|
||||
for _, id := range qc.Roots {
|
||||
t, err := schema.GetTable(qc.Selects[id].Table)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
if _, ok := role.tablesMap[t.Name]; !ok {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
197
serv/core_remote.go
Normal file
197
serv/core_remote.go
Normal file
@ -0,0 +1,197 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"sync"
|
||||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/dosco/super-graph/jsn"
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
)
|
||||
|
||||
func execRemoteJoin(st *stmt, data []byte, hdr http.Header) ([]byte, error) {
|
||||
var err error
|
||||
|
||||
if len(data) == 0 || st.skipped == 0 {
|
||||
return data, nil
|
||||
}
|
||||
|
||||
sel := st.qc.Selects
|
||||
h := xxhash.New()
|
||||
|
||||
// fetch the field name used within the db response json
|
||||
// that are used to mark insertion points and the mapping between
|
||||
// those field names and their select objects
|
||||
fids, sfmap := parentFieldIds(h, sel, st.skipped)
|
||||
|
||||
// fetch the field values of the marked insertion points
|
||||
// these values contain the id to be used with fetching remote data
|
||||
from := jsn.Get(data, fids)
|
||||
var to []jsn.Field
|
||||
|
||||
switch {
|
||||
case len(from) == 1:
|
||||
to, err = resolveRemote(hdr, h, from[0], sel, sfmap)
|
||||
|
||||
case len(from) > 1:
|
||||
to, err = resolveRemotes(hdr, h, from, sel, sfmap)
|
||||
|
||||
default:
|
||||
return nil, errors.New("something wrong no remote ids found in db response")
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var ob bytes.Buffer
|
||||
|
||||
err = jsn.Replace(&ob, data, from, to)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return ob.Bytes(), nil
|
||||
}
|
||||
|
||||
func resolveRemote(
|
||||
hdr http.Header,
|
||||
h *xxhash.Digest,
|
||||
field jsn.Field,
|
||||
sel []qcode.Select,
|
||||
sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) {
|
||||
|
||||
// replacement data for the marked insertion points
|
||||
// key and value will be replaced by whats below
|
||||
toA := [1]jsn.Field{}
|
||||
to := toA[:1]
|
||||
|
||||
// use the json key to find the related Select object
|
||||
k1 := xxhash.Sum64(field.Key)
|
||||
|
||||
s, ok := sfmap[k1]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
p := sel[s.ParentID]
|
||||
|
||||
// then use the Table nme in the Select and it's parent
|
||||
// to find the resolver to use for this relationship
|
||||
k2 := mkkey(h, s.Table, p.Table)
|
||||
|
||||
r, ok := rmap[k2]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
id := jsn.Value(field.Value)
|
||||
if len(id) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
//st := time.Now()
|
||||
|
||||
b, err := r.Fn(hdr, id)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if len(r.Path) != 0 {
|
||||
b = jsn.Strip(b, r.Path)
|
||||
}
|
||||
|
||||
var ob bytes.Buffer
|
||||
|
||||
if len(s.Cols) != 0 {
|
||||
err = jsn.Filter(&ob, b, colsToList(s.Cols))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
} else {
|
||||
ob.WriteString("null")
|
||||
}
|
||||
|
||||
to[0] = jsn.Field{Key: []byte(s.FieldName), Value: ob.Bytes()}
|
||||
return to, nil
|
||||
}
|
||||
|
||||
func resolveRemotes(
|
||||
hdr http.Header,
|
||||
h *xxhash.Digest,
|
||||
from []jsn.Field,
|
||||
sel []qcode.Select,
|
||||
sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) {
|
||||
|
||||
// replacement data for the marked insertion points
|
||||
// key and value will be replaced by whats below
|
||||
to := make([]jsn.Field, len(from))
|
||||
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(len(from))
|
||||
|
||||
var cerr error
|
||||
|
||||
for i, id := range from {
|
||||
|
||||
// use the json key to find the related Select object
|
||||
k1 := xxhash.Sum64(id.Key)
|
||||
|
||||
s, ok := sfmap[k1]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
p := sel[s.ParentID]
|
||||
|
||||
// then use the Table nme in the Select and it's parent
|
||||
// to find the resolver to use for this relationship
|
||||
k2 := mkkey(h, s.Table, p.Table)
|
||||
|
||||
r, ok := rmap[k2]
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
id := jsn.Value(id.Value)
|
||||
if len(id) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
go func(n int, id []byte, s *qcode.Select) {
|
||||
defer wg.Done()
|
||||
|
||||
//st := time.Now()
|
||||
|
||||
b, err := r.Fn(hdr, id)
|
||||
if err != nil {
|
||||
cerr = fmt.Errorf("%s: %s", s.Table, err)
|
||||
return
|
||||
}
|
||||
|
||||
if len(r.Path) != 0 {
|
||||
b = jsn.Strip(b, r.Path)
|
||||
}
|
||||
|
||||
var ob bytes.Buffer
|
||||
|
||||
if len(s.Cols) != 0 {
|
||||
err = jsn.Filter(&ob, b, colsToList(s.Cols))
|
||||
if err != nil {
|
||||
cerr = fmt.Errorf("%s: %s", s.Table, err)
|
||||
return
|
||||
}
|
||||
|
||||
} else {
|
||||
ob.WriteString("null")
|
||||
}
|
||||
|
||||
to[n] = jsn.Field{Key: []byte(s.FieldName), Value: ob.Bytes()}
|
||||
}(i, id, s)
|
||||
}
|
||||
wg.Wait()
|
||||
|
||||
return to, cerr
|
||||
}
|
@ -4,7 +4,7 @@ package serv
|
||||
|
||||
func Fuzz(data []byte) int {
|
||||
gql := string(data)
|
||||
isMutation(gql)
|
||||
gqlName(gql)
|
||||
gqlHash(gql, nil, "")
|
||||
|
||||
return 1
|
||||
|
@ -10,7 +10,7 @@ func TestFuzzCrashers(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, f := range crashers {
|
||||
isMutation(f)
|
||||
gqlName(f)
|
||||
gqlHash(f, nil, "")
|
||||
}
|
||||
}
|
||||
|
13
serv/http.go
13
serv/http.go
@ -21,7 +21,6 @@ const (
|
||||
|
||||
var (
|
||||
upgrader = websocket.Upgrader{}
|
||||
errNoUserID = errors.New("no user_id available")
|
||||
errUnauthorized = errors.New("not authorized")
|
||||
)
|
||||
|
||||
@ -77,18 +76,16 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(io.LimitReader(r.Body, maxReadBytes))
|
||||
defer r.Body.Close()
|
||||
|
||||
if err != nil {
|
||||
logger.Err(err).Msg("failed to read request body")
|
||||
errlog.Error().Err(err).Msg("failed to read request body")
|
||||
errorResp(w, err)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
|
||||
err = json.Unmarshal(b, &ctx.req)
|
||||
|
||||
if err != nil {
|
||||
logger.Err(err).Msg("failed to decode json request body")
|
||||
errlog.Error().Err(err).Msg("failed to decode json request body")
|
||||
errorResp(w, err)
|
||||
return
|
||||
}
|
||||
@ -107,12 +104,12 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Err(err).Msg("failed to handle request")
|
||||
errlog.Error().Err(err).Msg("failed to handle request")
|
||||
errorResp(w, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func errorResp(w http.ResponseWriter, err error) {
|
||||
w.WriteHeader(http.StatusBadRequest)
|
||||
json.NewEncoder(w).Encode(gqlResp{Error: err.Error()})
|
||||
}
|
||||
|
209
serv/prepare.go
209
serv/prepare.go
@ -3,20 +3,19 @@ package serv
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/valyala/fasttemplate"
|
||||
)
|
||||
|
||||
type preparedItem struct {
|
||||
stmt *pgconn.StatementDescription
|
||||
args [][]byte
|
||||
skipped uint32
|
||||
qc *qcode.QCode
|
||||
sd *pgconn.StatementDescription
|
||||
args [][]byte
|
||||
st *stmt
|
||||
}
|
||||
|
||||
var (
|
||||
@ -24,83 +23,125 @@ var (
|
||||
)
|
||||
|
||||
func initPreparedList() {
|
||||
c := context.Background()
|
||||
_preparedList = make(map[string]*preparedItem)
|
||||
|
||||
if err := prepareRoleStmt(); err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to prepare get role statement")
|
||||
tx, err := db.Begin(c)
|
||||
if err != nil {
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
defer tx.Rollback(c)
|
||||
|
||||
err = prepareRoleStmt(c, tx)
|
||||
if err != nil {
|
||||
errlog.Fatal().Err(err).Msg("failed to prepare get role statement")
|
||||
}
|
||||
|
||||
if err := tx.Commit(c); err != nil {
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
|
||||
success := 0
|
||||
|
||||
for _, v := range _allowList.list {
|
||||
|
||||
err := prepareStmt(v.gql, v.vars)
|
||||
if err != nil {
|
||||
logger.Warn().Str("gql", v.gql).Err(err).Send()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func prepareStmt(gql string, varBytes json.RawMessage) error {
|
||||
if len(gql) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
c := &coreContext{Context: context.Background()}
|
||||
c.req.Query = gql
|
||||
c.req.Vars = varBytes
|
||||
|
||||
stmts, err := c.buildStmt()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(stmts) != 0 && stmts[0].qc.Type == qcode.QTQuery {
|
||||
c.req.Vars = nil
|
||||
}
|
||||
|
||||
for _, s := range stmts {
|
||||
if len(s.sql) == 0 {
|
||||
if len(v.gql) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
finalSQL, am := processTemplate(s.sql)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
tx, err := db.Begin(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
pstmt, err := tx.Prepare(ctx, "", finalSQL)
|
||||
if err != nil {
|
||||
return err
|
||||
err := prepareStmt(c, v.gql, v.vars)
|
||||
if err == nil {
|
||||
success++
|
||||
continue
|
||||
}
|
||||
|
||||
var key string
|
||||
|
||||
if s.role == nil {
|
||||
key = gqlHash(gql, c.req.Vars, "")
|
||||
if len(v.vars) == 0 {
|
||||
logger.Warn().Err(err).Msg(v.gql)
|
||||
} else {
|
||||
key = gqlHash(gql, c.req.Vars, s.role.Name)
|
||||
logger.Warn().Err(err).Msgf("%s %s", v.vars, v.gql)
|
||||
}
|
||||
}
|
||||
|
||||
_preparedList[key] = &preparedItem{
|
||||
stmt: pstmt,
|
||||
args: am,
|
||||
skipped: s.skipped,
|
||||
qc: s.qc,
|
||||
}
|
||||
logger.Info().
|
||||
Msgf("Registered %d of %d queries from allow.list as prepared statements",
|
||||
success, len(_allowList.list))
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
func prepareStmt(c context.Context, gql string, vars []byte) error {
|
||||
qt := qcode.GetQType(gql)
|
||||
q := []byte(gql)
|
||||
|
||||
tx, err := db.Begin(c)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(c)
|
||||
|
||||
switch qt {
|
||||
case qcode.QTQuery:
|
||||
stmts1, err := buildMultiStmt(q, vars)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = prepare(c, tx, &stmts1[0], gqlHash(gql, vars, "user"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stmts2, err := buildRoleStmt(q, vars, "anon")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = prepare(c, tx, &stmts2[0], gqlHash(gql, vars, "anon"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case qcode.QTMutation:
|
||||
for _, role := range conf.Roles {
|
||||
stmts, err := buildRoleStmt(q, vars, role.Name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
err = prepare(c, tx, &stmts[0], gqlHash(gql, vars, role.Name))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(vars) == 0 {
|
||||
logger.Debug().Msgf("Building prepared statement for:\n %s", gql)
|
||||
} else {
|
||||
logger.Debug().Msgf("Building prepared statement:\n %s\n%s", vars, gql)
|
||||
}
|
||||
|
||||
if err := tx.Commit(c); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func prepareRoleStmt() error {
|
||||
func prepare(c context.Context, tx pgx.Tx, st *stmt, key string) error {
|
||||
finalSQL, am := processTemplate(st.sql)
|
||||
|
||||
sd, err := tx.Prepare(c, "", finalSQL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_preparedList[key] = &preparedItem{
|
||||
sd: sd,
|
||||
args: am,
|
||||
st: st,
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func prepareRoleStmt(c context.Context, tx pgx.Tx) error {
|
||||
if len(conf.RolesQuery) == 0 {
|
||||
return nil
|
||||
}
|
||||
@ -125,15 +166,7 @@ func prepareRoleStmt() error {
|
||||
|
||||
roleSQL, _ := processTemplate(w.String())
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
tx, err := db.Begin(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
_, err = tx.Prepare(ctx, "_sg_get_role", roleSQL)
|
||||
_, err := tx.Prepare(c, "_sg_get_role", roleSQL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -142,19 +175,31 @@ func prepareRoleStmt() error {
|
||||
}
|
||||
|
||||
func processTemplate(tmpl string) (string, [][]byte) {
|
||||
t := fasttemplate.New(tmpl, `{{`, `}}`)
|
||||
am := make([][]byte, 0, 5)
|
||||
i := 0
|
||||
st := struct {
|
||||
vmap map[string]int
|
||||
am [][]byte
|
||||
i int
|
||||
}{
|
||||
vmap: make(map[string]int),
|
||||
am: make([][]byte, 0, 5),
|
||||
i: 0,
|
||||
}
|
||||
|
||||
vmap := make(map[string]int)
|
||||
|
||||
return t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) {
|
||||
if n, ok := vmap[tag]; ok {
|
||||
execFunc := func(w io.Writer, tag string) (int, error) {
|
||||
if n, ok := st.vmap[tag]; ok {
|
||||
return w.Write([]byte(fmt.Sprintf("$%d", n)))
|
||||
}
|
||||
am = append(am, []byte(tag))
|
||||
i++
|
||||
vmap[tag] = i
|
||||
return w.Write([]byte(fmt.Sprintf("$%d", i)))
|
||||
}), am
|
||||
st.am = append(st.am, []byte(tag))
|
||||
st.i++
|
||||
st.vmap[tag] = st.i
|
||||
return w.Write([]byte(fmt.Sprintf("$%d", st.i)))
|
||||
}
|
||||
|
||||
t1 := fasttemplate.New(tmpl, `'{{`, `}}'`)
|
||||
ts1 := t1.ExecuteFuncString(execFunc)
|
||||
|
||||
t2 := fasttemplate.New(ts1, `{{`, `}}`)
|
||||
ts2 := t2.ExecuteFuncString(execFunc)
|
||||
|
||||
return ts2, st.am
|
||||
}
|
||||
|
@ -168,7 +168,7 @@ func Do(log func(string, ...interface{}), additional ...dir) error {
|
||||
func ReExec() {
|
||||
err := syscall.Exec(binSelf, append([]string{binSelf}, os.Args[1:]...), os.Environ())
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("cannot restart")
|
||||
errlog.Fatal().Err(err).Msg("cannot restart")
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -117,7 +117,7 @@ func buildFn(r configRemote) func(http.Header, []byte) ([]byte, error) {
|
||||
|
||||
res, err := client.Do(req)
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Msgf("Failed to connect to: %s", uri)
|
||||
errlog.Error().Err(err).Msgf("Failed to connect to: %s", uri)
|
||||
return nil, err
|
||||
}
|
||||
defer res.Body.Close()
|
||||
|
27
serv/serv.go
27
serv/serv.go
@ -15,13 +15,15 @@ import (
|
||||
)
|
||||
|
||||
func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
|
||||
schema, err := psql.NewDBSchema(db, c.getAliasMap())
|
||||
var err error
|
||||
|
||||
schema, err = psql.NewDBSchema(db, c.getAliasMap())
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
conf := qcode.Config{
|
||||
Blocklist: c.DB.Defaults.Blocklist,
|
||||
Blocklist: c.DB.Blocklist,
|
||||
KeepArgs: false,
|
||||
}
|
||||
|
||||
@ -106,7 +108,7 @@ func initWatcher(cpath string) {
|
||||
go func() {
|
||||
err := Do(logger.Printf, d)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Send()
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
}()
|
||||
}
|
||||
@ -139,7 +141,7 @@ func startHTTP() {
|
||||
<-sigint
|
||||
|
||||
if err := srv.Shutdown(context.Background()); err != nil {
|
||||
logger.Error().Err(err).Msg("shutdown signal received")
|
||||
errlog.Error().Err(err).Msg("shutdown signal received")
|
||||
}
|
||||
close(idleConnsClosed)
|
||||
}()
|
||||
@ -148,18 +150,14 @@ func startHTTP() {
|
||||
db.Close()
|
||||
})
|
||||
|
||||
var ident string
|
||||
|
||||
if len(conf.AppName) == 0 {
|
||||
ident = conf.Env
|
||||
} else {
|
||||
ident = conf.AppName
|
||||
}
|
||||
|
||||
fmt.Printf("%s listening on %s (%s)\n", serverName, hostPort, ident)
|
||||
logger.Info().
|
||||
Str("host_post", hostPort).
|
||||
Str("app_name", conf.AppName).
|
||||
Str("env", conf.Env).
|
||||
Msgf("%s listening", serverName)
|
||||
|
||||
if err := srv.ListenAndServe(); err != http.ErrServerClosed {
|
||||
logger.Error().Err(err).Msg("server closed")
|
||||
errlog.Error().Err(err).Msg("server closed")
|
||||
}
|
||||
|
||||
<-idleConnsClosed
|
||||
@ -169,6 +167,7 @@ func routeHandler() http.Handler {
|
||||
mux := http.NewServeMux()
|
||||
|
||||
mux.Handle("/api/v1/graphql", withAuth(apiv1Http))
|
||||
|
||||
if conf.WebUI {
|
||||
mux.Handle("/", http.FileServer(rice.MustFindBox("../web/build").HTTPBox()))
|
||||
}
|
||||
|
@ -28,9 +28,7 @@ func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data
|
||||
zlevel = zerolog.ErrorLevel
|
||||
case pgx.LogLevelWarn:
|
||||
zlevel = zerolog.WarnLevel
|
||||
case pgx.LogLevelInfo:
|
||||
zlevel = zerolog.InfoLevel
|
||||
case pgx.LogLevelDebug:
|
||||
case pgx.LogLevelDebug, pgx.LogLevelInfo:
|
||||
zlevel = zerolog.DebugLevel
|
||||
default:
|
||||
zlevel = zerolog.DebugLevel
|
||||
|
@ -106,17 +106,28 @@ func al(b byte) bool {
|
||||
return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9')
|
||||
}
|
||||
|
||||
func isMutation(sql string) bool {
|
||||
for i := range sql {
|
||||
b := sql[i]
|
||||
if b == '{' {
|
||||
return false
|
||||
}
|
||||
if al(b) {
|
||||
return (b == 'm' || b == 'M')
|
||||
func gqlName(b string) string {
|
||||
state, s := 0, 0
|
||||
|
||||
for i := 0; i < len(b); i++ {
|
||||
switch {
|
||||
case state == 2 && b[i] == '{':
|
||||
return b[s:i]
|
||||
case state == 2 && b[i] == ' ':
|
||||
return b[s:i]
|
||||
case state == 1 && b[i] == '{':
|
||||
return ""
|
||||
case state == 1 && b[i] != ' ':
|
||||
s = i
|
||||
state = 2
|
||||
case state == 1 && b[i] == ' ':
|
||||
continue
|
||||
case i != 0 && b[i] == ' ' && (b[i-1] == 'n' || b[i-1] == 'y'):
|
||||
state = 1
|
||||
}
|
||||
}
|
||||
return false
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func findStmt(role string, stmts []stmt) *stmt {
|
||||
|
@ -229,3 +229,80 @@ func TestGQLHashWithVars2(t *testing.T) {
|
||||
t.Fatal("Hashes don't match they should")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGQLName1(t *testing.T) {
|
||||
var q = `
|
||||
query {
|
||||
products(
|
||||
distinct: [price]
|
||||
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }
|
||||
) { id name } }`
|
||||
|
||||
name := gqlName(q)
|
||||
|
||||
if len(name) != 0 {
|
||||
t.Fatal("Name should be empty, not ", name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGQLName2(t *testing.T) {
|
||||
var q = `
|
||||
query hakuna_matata {
|
||||
products(
|
||||
distinct: [price]
|
||||
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }
|
||||
) {
|
||||
id
|
||||
name
|
||||
}
|
||||
}`
|
||||
|
||||
name := gqlName(q)
|
||||
|
||||
if name != "hakuna_matata" {
|
||||
t.Fatal("Name should be 'hakuna_matata', not ", name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGQLName3(t *testing.T) {
|
||||
var q = `
|
||||
mutation means{ users { id } }`
|
||||
|
||||
// var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } `
|
||||
|
||||
name := gqlName(q)
|
||||
|
||||
if name != "means" {
|
||||
t.Fatal("Name should be 'means', not ", name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGQLName4(t *testing.T) {
|
||||
var q = `
|
||||
query no_worries
|
||||
users {
|
||||
id
|
||||
}
|
||||
}`
|
||||
|
||||
name := gqlName(q)
|
||||
|
||||
if name != "no_worries" {
|
||||
t.Fatal("Name should be 'no_worries', not ", name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGQLName5(t *testing.T) {
|
||||
var q = `
|
||||
{
|
||||
users {
|
||||
id
|
||||
}
|
||||
}`
|
||||
|
||||
name := gqlName(q)
|
||||
|
||||
if len(name) != 0 {
|
||||
t.Fatal("Name should be empty, not ", name)
|
||||
}
|
||||
}
|
||||
|
@ -77,7 +77,7 @@ SQL Output
|
||||
|
||||
database:
|
||||
variables:
|
||||
account_id: "select account_id from users where id = $user_id"
|
||||
admin_account_id: "5"
|
||||
|
||||
defaults:
|
||||
Filters: ["{ user_id: { eq: $user_id } }"]
|
||||
|
40
tmpl/dev.yml
40
tmpl/dev.yml
@ -1,9 +1,9 @@
|
||||
app_name: "{{app_name}} Development"
|
||||
app_name: "{% app_name %} Development"
|
||||
host_port: 0.0.0.0:8080
|
||||
web_ui: true
|
||||
|
||||
# debug, info, warn, error, fatal, panic
|
||||
log_level: "debug"
|
||||
log_level: "info"
|
||||
|
||||
# When production mode is 'true' only queries
|
||||
# from the allow list are permitted.
|
||||
@ -48,7 +48,7 @@ migrations_path: ./config/migrations
|
||||
auth:
|
||||
# Can be 'rails' or 'jwt'
|
||||
type: rails
|
||||
cookie: _{{app_name_slug}}_session
|
||||
cookie: _{% app_name_slug %}_session
|
||||
|
||||
# Comment this out if you want to disable setting
|
||||
# the user_id via a header for testing.
|
||||
@ -84,7 +84,7 @@ database:
|
||||
type: postgres
|
||||
host: db
|
||||
port: 5432
|
||||
dbname: {{app_name_slug}}_development
|
||||
dbname: {% app_name_slug %}_development
|
||||
user: postgres
|
||||
password: ''
|
||||
|
||||
@ -97,23 +97,18 @@ database:
|
||||
# Enable this if you need the user id in triggers, etc
|
||||
set_user_id: false
|
||||
|
||||
# Define variables here that you want to use in filters
|
||||
# sub-queries must be wrapped in ()
|
||||
# Define additional variables here to be used with filters
|
||||
variables:
|
||||
account_id: "(select account_id from users where id = $user_id)"
|
||||
admin_account_id: "5"
|
||||
|
||||
# Define defaults to for the field key and values below
|
||||
defaults:
|
||||
# filters: ["{ user_id: { eq: $user_id } }"]
|
||||
|
||||
# Field and table names that you wish to block
|
||||
blocklist:
|
||||
- ar_internal_metadata
|
||||
- schema_migrations
|
||||
- secret
|
||||
- password
|
||||
- encrypted
|
||||
- token
|
||||
# Field and table names that you wish to block
|
||||
blocklist:
|
||||
- ar_internal_metadata
|
||||
- schema_migrations
|
||||
- secret
|
||||
- password
|
||||
- encrypted
|
||||
- token
|
||||
|
||||
tables:
|
||||
- name: customers
|
||||
@ -185,11 +180,10 @@ roles:
|
||||
- updated_at: "now"
|
||||
|
||||
delete:
|
||||
deny: true
|
||||
block: true
|
||||
|
||||
- name: admin
|
||||
match: id = 1
|
||||
match: id = 1000
|
||||
tables:
|
||||
- name: users
|
||||
# query:
|
||||
# filters: ["{ account_id: { _eq: $account_id } }"]
|
||||
filters: []
|
||||
|
@ -1,4 +1,4 @@
|
||||
version: '3'
|
||||
version: '3.4'
|
||||
services:
|
||||
db:
|
||||
image: postgres
|
||||
|
@ -2,12 +2,12 @@
|
||||
# so I only need to overwrite some values
|
||||
inherits: dev
|
||||
|
||||
app_name: "{{app_name}} Production"
|
||||
app_name: "{% app_name %} Production"
|
||||
host_port: 0.0.0.0:8080
|
||||
web_ui: false
|
||||
|
||||
# debug, info, warn, error, fatal, panic, disable
|
||||
log_level: "info"
|
||||
log_level: "warn"
|
||||
# When production mode is 'true' only queries
|
||||
# from the allow list are permitted.
|
||||
# When it's 'false' all queries are saved to the
|
||||
@ -40,45 +40,11 @@ enable_tracing: true
|
||||
# SG_AUTH_RAILS_REDIS_PASSWORD
|
||||
# SG_AUTH_JWT_PUBLIC_KEY_FILE
|
||||
|
||||
# inflections:
|
||||
# person: people
|
||||
# sheep: sheep
|
||||
|
||||
auth:
|
||||
# Can be 'rails' or 'jwt'
|
||||
type: rails
|
||||
cookie: _{{app_name_slug}}_session
|
||||
|
||||
rails:
|
||||
# Rails version this is used for reading the
|
||||
# various cookies formats.
|
||||
version: 5.2
|
||||
|
||||
# Found in 'Rails.application.config.secret_key_base'
|
||||
secret_key_base: 0a248500a64c01184edb4d7ad3a805488f8097ac761b76aaa6c17c01dcb7af03a2f18ba61b2868134b9c7b79a122bc0dadff4367414a2d173297bfea92be5566
|
||||
|
||||
# Remote cookie store. (memcache or redis)
|
||||
# url: redis://127.0.0.1:6379
|
||||
# password: test
|
||||
# max_idle: 80,
|
||||
# max_active: 12000,
|
||||
|
||||
# In most cases you don't need these
|
||||
# salt: "encrypted cookie"
|
||||
# sign_salt: "signed encrypted cookie"
|
||||
# auth_salt: "authenticated encrypted cookie"
|
||||
|
||||
# jwt:
|
||||
# provider: auth0
|
||||
# secret: abc335bfcfdb04e50db5bb0a4d67ab9
|
||||
# public_key_file: /secrets/public_key.pem
|
||||
# public_key_type: ecdsa #rsa
|
||||
|
||||
database:
|
||||
type: postgres
|
||||
host: db
|
||||
port: 5432
|
||||
dbname: {{app_name_slug}}_development
|
||||
dbname: {% app_name_slug %}_development
|
||||
user: postgres
|
||||
password: ''
|
||||
#pool_size: 10
|
||||
|
Reference in New Issue
Block a user