Merge pull request #11 from dosco/rbac

Role based access control and other fixes
This commit is contained in:
Vikram Rangnekar 2019-10-25 01:49:37 -04:00 committed by GitHub
commit ff13f651d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
39 changed files with 2439 additions and 1148 deletions

1
.gitignore vendored
View File

@ -28,3 +28,4 @@ main
.DS_Store
.swp
main
super-graph

13
.wtc.yaml Normal file
View File

@ -0,0 +1,13 @@
no_trace: false
debounce: 300 # if rule has no debounce, this will be used instead
ignore: \.git/
trig: [start, run] # will run on start
rules:
- name: start
- name: run
match: \.go$
ignore: web|examples|docs|_test\.go$
command: go run main.go serv
- name: test
match: _test\.go$
command: go test -cover {PKG}

View File

@ -11,9 +11,8 @@ RUN apk update && \
apk add --no-cache git && \
apk add --no-cache upx=3.95-r2
RUN go get -u github.com/shanzi/wu && \
go install github.com/shanzi/wu && \
go get github.com/GeertJohan/go.rice/rice
RUN go get -u github.com/rafaelsq/wtc && \
go get -u github.com/GeertJohan/go.rice/rice
WORKDIR /app
COPY . /app

View File

@ -46,6 +46,9 @@ This compiler is what sits at the heart of Super Graph with layers of useful fun
## Contact me
I'm happy to help you deploy Super Graph so feel free to reach out over
Twitter or Discord.
[twitter/dosco](https://twitter.com/dosco)
[chat/super-graph](https://discord.gg/6pSWCTZ)

View File

@ -1,5 +1,27 @@
# http://localhost:8080/
variables {
"data": [
{
"name": "Protect Ya Neck",
"created_at": "now",
"updated_at": "now"
},
{
"name": "Enter the Wu-Tang",
"created_at": "now",
"updated_at": "now"
}
]
}
mutation {
products(insert: $data) {
id
name
}
}
variables {
"update": {
"name": "Wu-Tang",
@ -16,16 +38,16 @@ mutation {
}
}
variables {
"data": {
"product_id": 5
}
}
mutation {
products(id: $product_id, delete: true) {
query {
users {
id
name
email
picture: avatar
products(limit: 2, where: {price: {gt: 10}}) {
id
name
description
}
}
}
@ -73,6 +95,118 @@ query {
}
}
variables {
"data": [
{
"name": "Gumbo1",
"created_at": "now",
"updated_at": "now"
},
{
"name": "Gumbo2",
"created_at": "now",
"updated_at": "now"
}
]
}
query {
products {
id
name
user {
email
}
}
}
variables {
"data": {
"product_id": 5
}
}
mutation {
products(id: $product_id, delete: true) {
id
name
}
}
variables {
"data": [
{
"name": "Gumbo1",
"created_at": "now",
"updated_at": "now"
},
{
"name": "Gumbo2",
"created_at": "now",
"updated_at": "now"
}
]
}
query {
products {
id
name
price
users {
email
}
}
}
variables {
"data": {
"email": "gfk@myspace.com",
"full_name": "Ghostface Killah",
"created_at": "now",
"updated_at": "now"
}
}
mutation {
user(insert: $data) {
id
}
}
variables {
"data": [
{
"name": "Gumbo1",
"created_at": "now",
"updated_at": "now"
},
{
"name": "Gumbo2",
"created_at": "now",
"updated_at": "now"
}
]
}
query {
products {
id
name
users {
email
}
}
}
query {
me {
id
email
full_name
}
}
variables {
"update": {
@ -112,62 +246,30 @@ query {
}
}
query {
me {
id
email
full_name
}
}
variables {
"data": {
"email": "gfk@myspace.com",
"full_name": "Ghostface Killah",
"created_at": "now",
"updated_at": "now"
}
}
mutation {
user(insert: $data) {
id
}
}
query {
users {
id
email
picture: avatar
products(limit: 2, where: {price: {gt: 10}}) {
id
name
description
}
}
}
variables {
"data": [
{
"name": "Protect Ya Neck",
"name": "Gumbo1",
"created_at": "now",
"updated_at": "now"
},
{
"name": "Enter the Wu-Tang",
"name": "Gumbo2",
"created_at": "now",
"updated_at": "now"
}
]
}
mutation {
products(insert: $data) {
query {
products {
id
name
description
users {
email
}
}
}

View File

@ -22,7 +22,7 @@ enable_tracing: true
# Watch the config folder and reload Super Graph
# with the new configs when a change is detected
reload_on_config_change: false
reload_on_config_change: true
# File that points to the database seeding script
# seed_file: seed.js
@ -53,7 +53,7 @@ auth:
# Comment this out if you want to disable setting
# the user_id via a header. Good for testing
header: X-User-ID
creds_in_header: true
rails:
# Rails version this is used for reading the
@ -100,7 +100,7 @@ database:
# Define defaults to for the field key and values below
defaults:
# filter: ["{ user_id: { eq: $user_id } }"]
# filters: ["{ user_id: { eq: $user_id } }"]
# Field and table names that you wish to block
blocklist:
@ -112,25 +112,7 @@ database:
- token
tables:
- name: users
# This filter will overwrite defaults.filter
# filter: ["{ id: { eq: $user_id } }"]
# filter_query: ["{ id: { eq: $user_id } }"]
filter_update: ["{ id: { eq: $user_id } }"]
filter_delete: ["{ id: { eq: $user_id } }"]
# - name: products
# # Multiple filters are AND'd together
# filter: [
# "{ price: { gt: 0 } }",
# "{ price: { lt: 8 } }"
# ]
- name: customers
# No filter is used for this field not
# even defaults.filter
filter: none
remotes:
- name: payments
id: stripe_id
@ -149,7 +131,61 @@ tables:
# real db table backing them
name: me
table: users
filter: ["{ id: { eq: $user_id } }"]
# - name: posts
# filter: ["{ account_id: { _eq: $account_id } }"]
roles_query: "SELECT * FROM users as usr WHERE id = $user_id"
roles:
- name: anon
tables:
- name: products
limit: 10
query:
columns: ["id", "name", "description" ]
aggregation: false
insert:
allow: false
update:
allow: false
delete:
allow: false
- name: user
tables:
- name: users
query:
filters: ["{ id: { _eq: $user_id } }"]
- name: products
query:
limit: 50
filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
disable_aggregation: false
insert:
filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
set:
- created_at: "now"
update:
filters: ["{ user_id: { eq: $user_id } }"]
columns:
- id
- name
set:
- updated_at: "now"
delete:
deny: true
- name: admin
match: id = 1
tables:
- name: users
# query:
# filters: ["{ account_id: { _eq: $account_id } }"]

View File

@ -47,10 +47,6 @@ auth:
type: rails
cookie: _app_session
# Comment this out if you want to disable setting
# the user_id via a header. Good for testing
header: X-User-ID
rails:
# Rails version this is used for reading the
# various cookies formats.
@ -94,7 +90,7 @@ database:
# Define defaults to for the field key and values below
defaults:
filter: ["{ user_id: { eq: $user_id } }"]
filters: ["{ user_id: { eq: $user_id } }"]
# Field and table names that you wish to block
blocklist:
@ -106,25 +102,7 @@ database:
- token
tables:
- name: users
# This filter will overwrite defaults.filter
# filter: ["{ id: { eq: $user_id } }"]
# filter_query: ["{ id: { eq: $user_id } }"]
filter_update: ["{ id: { eq: $user_id } }"]
filter_delete: ["{ id: { eq: $user_id } }"]
- name: products
# Multiple filters are AND'd together
filter: [
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }"
]
- name: customers
# No filter is used for this field not
# even defaults.filter
filter: none
# remotes:
# - name: payments
# id: stripe_id
@ -141,7 +119,61 @@ tables:
# real db table backing them
name: me
table: users
filter: ["{ id: { eq: $user_id } }"]
# - name: posts
# filter: ["{ account_id: { _eq: $account_id } }"]
roles_query: "SELECT * FROM users as usr WHERE id = $user_id"
roles:
- name: anon
tables:
- name: products
limit: 10
query:
columns: ["id", "name", "description" ]
aggregation: false
insert:
allow: false
update:
allow: false
delete:
allow: false
- name: user
tables:
- name: users
query:
filters: ["{ id: { _eq: $user_id } }"]
- name: products
query:
limit: 50
filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
disable_aggregation: false
insert:
filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
set:
- created_at: "now"
update:
filters: ["{ user_id: { eq: $user_id } }"]
columns:
- id
- name
set:
- updated_at: "now"
delete:
deny: true
- name: admin
match: id = 1
tables:
- name: users
# query:
# filters: ["{ account_id: { _eq: $account_id } }"]

View File

@ -34,7 +34,7 @@ services:
volumes:
- .:/app
working_dir: /app
command: wu -pattern="*.go" go run main.go serv
command: wtc
depends_on:
- db
- rails_app

View File

@ -1043,26 +1043,35 @@ We're tried to ensure that the config file is self documenting and easy to work
app_name: "Super Graph Development"
host_port: 0.0.0.0:8080
web_ui: true
debug_level: 1
# debug, info, warn, error, fatal, panic, disable
log_level: "info"
# debug, info, warn, error, fatal, panic
log_level: "debug"
# Disable this in development to get a list of
# queries used. When enabled super graph
# will only allow queries from this list
# List saved to ./config/allow.list
use_allow_list: true
use_allow_list: false
# Throw a 401 on auth failure for queries that need auth
# valid values: always, per_query, never
auth_fail_block: always
auth_fail_block: never
# Latency tracing for database queries and remote joins
# the resulting latency information is returned with the
# response
enable_tracing: true
# Watch the config folder and reload Super Graph
# with the new configs when a change is detected
reload_on_config_change: true
# File that points to the database seeding script
# seed_file: seed.js
# Path pointing to where the migrations can be found
migrations_path: ./config/migrations
# Postgres related environment Variables
# SG_DATABASE_HOST
# SG_DATABASE_PORT
@ -1086,7 +1095,7 @@ auth:
# Comment this out if you want to disable setting
# the user_id via a header. Good for testing
header: X-User-ID
creds_in_header: true
rails:
# Rails version this is used for reading the
@ -1097,10 +1106,10 @@ auth:
secret_key_base: 0a248500a64c01184edb4d7ad3a805488f8097ac761b76aaa6c17c01dcb7af03a2f18ba61b2868134b9c7b79a122bc0dadff4367414a2d173297bfea92be5566
# Remote cookie store. (memcache or redis)
# url: redis://127.0.0.1:6379
# password: test
# max_idle: 80,
# max_active: 12000,
# url: redis://redis:6379
# password: ""
# max_idle: 80
# max_active: 12000
# In most cases you don't need these
# salt: "encrypted cookie"
@ -1120,20 +1129,23 @@ database:
dbname: app_development
user: postgres
password: ''
# pool_size: 10
# max_retries: 0
# log_level: "debug"
#schema: "public"
#pool_size: 10
#max_retries: 0
#log_level: "debug"
# Define variables here that you want to use in filters
# sub-queries must be wrapped in ()
variables:
account_id: "select account_id from users where id = $user_id"
account_id: "(select account_id from users where id = $user_id)"
# Define defaults to for the field key and values below
defaults:
filter: ["{ user_id: { eq: $user_id } }"]
# filters: ["{ user_id: { eq: $user_id } }"]
# Field and table names that you wish to block
blacklist:
blocklist:
- ar_internal_metadata
- schema_migrations
- secret
@ -1141,43 +1153,85 @@ database:
- encrypted
- token
tables:
- name: users
# This filter will overwrite defaults.filter
filter: ["{ id: { eq: $user_id } }"]
tables:
- name: customers
remotes:
- name: payments
id: stripe_id
url: http://rails_app:3000/stripe/$id
path: data
# debug: true
pass_headers:
- cookie
set_headers:
- name: Host
value: 0.0.0.0
# - name: Authorization
# value: Bearer <stripe_api_key>
- name: products
# Multiple filters are AND'd together
filter: [
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }"
]
- # You can create new fields that have a
# real db table backing them
name: me
table: users
- name: customers
# No filter is used for this field not
# even defaults.filter
filter: none
roles_query: "SELECT * FROM users as usr WHERE id = $user_id"
remotes:
- name: payments
id: stripe_id
url: http://rails_app:3000/stripe/$id
path: data
# pass_headers:
# - cookie
# - host
set_headers:
- name: Authorization
value: Bearer <stripe_api_key>
roles:
- name: anon
tables:
- name: products
limit: 10
- # You can create new fields that have a
# real db table backing them
name: me
table: users
filter: ["{ id: { eq: $user_id } }"]
query:
columns: ["id", "name", "description" ]
aggregation: false
insert:
allow: false
update:
allow: false
delete:
allow: false
- name: user
tables:
- name: users
query:
filters: ["{ id: { _eq: $user_id } }"]
- name: products
query:
limit: 50
filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
disable_aggregation: false
insert:
filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
set:
- created_at: "now"
update:
filters: ["{ user_id: { eq: $user_id } }"]
columns:
- id
- name
set:
- updated_at: "now"
delete:
deny: true
- name: admin
match: id = 1
tables:
- name: users
# query:
# filters: ["{ account_id: { _eq: $account_id } }"]
# - name: posts
# filter: ["{ account_id: { _eq: $account_id } }"]
```
If deploying into environments like Kubernetes it's useful to be able to configure things like secrets and hosts though environment variables therfore we expose the below environment variables. This is escpecially useful for secrets since they are usually injected in via a secrets management framework ie. Kubernetes Secrets

View File

@ -0,0 +1,273 @@
GIT
remote: https://github.com/stympy/faker.git
revision: 4e9144825fcc9ba5c83cc0fd037779ab82f3120b
branch: master
specs:
faker (2.6.0)
i18n (>= 1.6, < 1.8)
GEM
remote: https://rubygems.org/
specs:
actioncable (6.0.0)
actionpack (= 6.0.0)
nio4r (~> 2.0)
websocket-driver (>= 0.6.1)
actionmailbox (6.0.0)
actionpack (= 6.0.0)
activejob (= 6.0.0)
activerecord (= 6.0.0)
activestorage (= 6.0.0)
activesupport (= 6.0.0)
mail (>= 2.7.1)
actionmailer (6.0.0)
actionpack (= 6.0.0)
actionview (= 6.0.0)
activejob (= 6.0.0)
mail (~> 2.5, >= 2.5.4)
rails-dom-testing (~> 2.0)
actionpack (6.0.0)
actionview (= 6.0.0)
activesupport (= 6.0.0)
rack (~> 2.0)
rack-test (>= 0.6.3)
rails-dom-testing (~> 2.0)
rails-html-sanitizer (~> 1.0, >= 1.2.0)
actiontext (6.0.0)
actionpack (= 6.0.0)
activerecord (= 6.0.0)
activestorage (= 6.0.0)
activesupport (= 6.0.0)
nokogiri (>= 1.8.5)
actionview (6.0.0)
activesupport (= 6.0.0)
builder (~> 3.1)
erubi (~> 1.4)
rails-dom-testing (~> 2.0)
rails-html-sanitizer (~> 1.1, >= 1.2.0)
activejob (6.0.0)
activesupport (= 6.0.0)
globalid (>= 0.3.6)
activemodel (6.0.0)
activesupport (= 6.0.0)
activerecord (6.0.0)
activemodel (= 6.0.0)
activesupport (= 6.0.0)
activestorage (6.0.0)
actionpack (= 6.0.0)
activejob (= 6.0.0)
activerecord (= 6.0.0)
marcel (~> 0.3.1)
activesupport (6.0.0)
concurrent-ruby (~> 1.0, >= 1.0.2)
i18n (>= 0.7, < 2)
minitest (~> 5.1)
tzinfo (~> 1.1)
zeitwerk (~> 2.1, >= 2.1.8)
addressable (2.7.0)
public_suffix (>= 2.0.2, < 5.0)
archive-zip (0.12.0)
io-like (~> 0.3.0)
bcrypt (3.1.13)
bindex (0.8.1)
bootsnap (1.4.5)
msgpack (~> 1.0)
builder (3.2.3)
byebug (11.0.1)
capybara (3.29.0)
addressable
mini_mime (>= 0.1.3)
nokogiri (~> 1.8)
rack (>= 1.6.0)
rack-test (>= 0.6.3)
regexp_parser (~> 1.5)
xpath (~> 3.2)
childprocess (3.0.0)
chromedriver-helper (2.1.1)
archive-zip (~> 0.10)
nokogiri (~> 1.8)
coffee-rails (4.2.2)
coffee-script (>= 2.2.0)
railties (>= 4.0.0)
coffee-script (2.4.1)
coffee-script-source
execjs
coffee-script-source (1.12.2)
concurrent-ruby (1.1.5)
crass (1.0.4)
devise (4.7.1)
bcrypt (~> 3.0)
orm_adapter (~> 0.1)
railties (>= 4.1.0)
responders
warden (~> 1.2.3)
erubi (1.9.0)
execjs (2.7.0)
ffi (1.11.1)
globalid (0.4.2)
activesupport (>= 4.2.0)
i18n (1.7.0)
concurrent-ruby (~> 1.0)
io-like (0.3.0)
jbuilder (2.9.1)
activesupport (>= 4.2.0)
listen (3.1.5)
rb-fsevent (~> 0.9, >= 0.9.4)
rb-inotify (~> 0.9, >= 0.9.7)
ruby_dep (~> 1.2)
loofah (2.3.0)
crass (~> 1.0.2)
nokogiri (>= 1.5.9)
mail (2.7.1)
mini_mime (>= 0.1.1)
marcel (0.3.3)
mimemagic (~> 0.3.2)
method_source (0.9.2)
mimemagic (0.3.3)
mini_mime (1.0.2)
mini_portile2 (2.4.0)
minitest (5.12.2)
msgpack (1.3.1)
nio4r (2.5.2)
nokogiri (1.10.4)
mini_portile2 (~> 2.4.0)
orm_adapter (0.5.0)
pg (1.1.4)
public_suffix (4.0.1)
puma (3.12.1)
rack (2.0.7)
rack-test (1.1.0)
rack (>= 1.0, < 3)
rails (6.0.0)
actioncable (= 6.0.0)
actionmailbox (= 6.0.0)
actionmailer (= 6.0.0)
actionpack (= 6.0.0)
actiontext (= 6.0.0)
actionview (= 6.0.0)
activejob (= 6.0.0)
activemodel (= 6.0.0)
activerecord (= 6.0.0)
activestorage (= 6.0.0)
activesupport (= 6.0.0)
bundler (>= 1.3.0)
railties (= 6.0.0)
sprockets-rails (>= 2.0.0)
rails-dom-testing (2.0.3)
activesupport (>= 4.2.0)
nokogiri (>= 1.6)
rails-html-sanitizer (1.3.0)
loofah (~> 2.3)
railties (6.0.0)
actionpack (= 6.0.0)
activesupport (= 6.0.0)
method_source
rake (>= 0.8.7)
thor (>= 0.20.3, < 2.0)
rake (13.0.0)
rb-fsevent (0.10.3)
rb-inotify (0.10.0)
ffi (~> 1.0)
redis (4.1.3)
redis-actionpack (5.1.0)
actionpack (>= 4.0, < 7)
redis-rack (>= 1, < 3)
redis-store (>= 1.1.0, < 2)
redis-activesupport (5.2.0)
activesupport (>= 3, < 7)
redis-store (>= 1.3, < 2)
redis-rack (2.0.6)
rack (>= 1.5, < 3)
redis-store (>= 1.2, < 2)
redis-rails (5.0.2)
redis-actionpack (>= 5.0, < 6)
redis-activesupport (>= 5.0, < 6)
redis-store (>= 1.2, < 2)
redis-store (1.8.0)
redis (>= 4, < 5)
regexp_parser (1.6.0)
responders (3.0.0)
actionpack (>= 5.0)
railties (>= 5.0)
ruby_dep (1.5.0)
rubyzip (2.0.0)
sass (3.7.4)
sass-listen (~> 4.0.0)
sass-listen (4.0.0)
rb-fsevent (~> 0.9, >= 0.9.4)
rb-inotify (~> 0.9, >= 0.9.7)
sass-rails (5.1.0)
railties (>= 5.2.0)
sass (~> 3.1)
sprockets (>= 2.8, < 4.0)
sprockets-rails (>= 2.0, < 4.0)
tilt (>= 1.1, < 3)
selenium-webdriver (3.142.6)
childprocess (>= 0.5, < 4.0)
rubyzip (>= 1.2.2)
spring (2.1.0)
spring-watcher-listen (2.0.1)
listen (>= 2.7, < 4.0)
spring (>= 1.2, < 3.0)
sprockets (3.7.2)
concurrent-ruby (~> 1.0)
rack (> 1, < 3)
sprockets-rails (3.2.1)
actionpack (>= 4.0)
activesupport (>= 4.0)
sprockets (>= 3.0.0)
thor (0.20.3)
thread_safe (0.3.6)
tilt (2.0.10)
turbolinks (5.2.1)
turbolinks-source (~> 5.2)
turbolinks-source (5.2.0)
tzinfo (1.2.5)
thread_safe (~> 0.1)
uglifier (4.2.0)
execjs (>= 0.3.0, < 3)
warden (1.2.8)
rack (>= 2.0.6)
web-console (4.0.1)
actionview (>= 6.0.0)
activemodel (>= 6.0.0)
bindex (>= 0.4.0)
railties (>= 6.0.0)
websocket-driver (0.7.1)
websocket-extensions (>= 0.1.0)
websocket-extensions (0.1.4)
xpath (3.2.0)
nokogiri (~> 1.8)
zeitwerk (2.2.0)
PLATFORMS
ruby
DEPENDENCIES
bootsnap (>= 1.1.0)
byebug
capybara (>= 2.15)
chromedriver-helper
coffee-rails (~> 4.2)
devise
faker!
jbuilder (~> 2.5)
listen (>= 3.0.5, < 3.2)
pg (>= 0.18, < 2.0)
puma (~> 3.11)
rails (~> 6.0.0.rc1)
redis-rails
sass-rails (~> 5.0)
selenium-webdriver
spring
spring-watcher-listen (~> 2.0.0)
turbolinks (~> 5)
tzinfo-data
uglifier (>= 1.3.0)
web-console (>= 3.3.0)
RUBY VERSION
ruby 2.5.7p206
BUNDLED WITH
1.17.3

View File

@ -1,7 +1,6 @@
package psql
import (
"bytes"
"errors"
"fmt"
"io"
@ -10,9 +9,9 @@ import (
"github.com/dosco/super-graph/qcode"
)
var zeroPaging = qcode.Paging{}
var noLimit = qcode.Paging{NoLimit: true}
func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) {
func (co *Compiler) compileMutation(qc *qcode.QCode, w io.Writer, vars Variables) (uint32, error) {
if len(qc.Selects) == 0 {
return 0, errors.New("empty query")
}
@ -25,27 +24,27 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia
return 0, err
}
c.w.WriteString(`WITH `)
io.WriteString(c.w, `WITH `)
quoted(c.w, ti.Name)
c.w.WriteString(` AS `)
io.WriteString(c.w, ` AS `)
switch root.Action {
case qcode.ActionInsert:
switch qc.Type {
case qcode.QTInsert:
if _, err := c.renderInsert(qc, w, vars, ti); err != nil {
return 0, err
}
case qcode.ActionUpdate:
case qcode.QTUpdate:
if _, err := c.renderUpdate(qc, w, vars, ti); err != nil {
return 0, err
}
case qcode.ActionUpsert:
case qcode.QTUpsert:
if _, err := c.renderUpsert(qc, w, vars, ti); err != nil {
return 0, err
}
case qcode.ActionDelete:
case qcode.QTDelete:
if _, err := c.renderDelete(qc, w, vars, ti); err != nil {
return 0, err
}
@ -56,7 +55,7 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia
io.WriteString(c.w, ` RETURNING *) `)
root.Paging = zeroPaging
root.Paging = noLimit
root.DistinctOn = root.DistinctOn[:]
root.OrderBy = root.OrderBy[:]
root.Where = nil
@ -65,13 +64,12 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia
return c.compileQuery(qc, w)
}
func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer,
func (c *compilerContext) renderInsert(qc *qcode.QCode, w io.Writer,
vars Variables, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0]
insert, ok := vars[root.ActionVar]
insert, ok := vars[qc.ActionVar]
if !ok {
return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar)
return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar)
}
jt, array, err := jsn.Tree(insert)
@ -79,56 +77,62 @@ func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer,
return 0, err
}
c.w.WriteString(`(WITH "input" AS (SELECT {{`)
c.w.WriteString(root.ActionVar)
c.w.WriteString(`}}::json AS j) INSERT INTO `)
io.WriteString(c.w, `(WITH "input" AS (SELECT {{`)
io.WriteString(c.w, qc.ActionVar)
io.WriteString(c.w, `}}::json AS j) INSERT INTO `)
quoted(c.w, ti.Name)
io.WriteString(c.w, ` (`)
c.renderInsertUpdateColumns(qc, w, jt, ti)
io.WriteString(c.w, `)`)
c.w.WriteString(` SELECT `)
io.WriteString(c.w, ` SELECT `)
c.renderInsertUpdateColumns(qc, w, jt, ti)
c.w.WriteString(` FROM input i, `)
io.WriteString(c.w, ` FROM input i, `)
if array {
c.w.WriteString(`json_populate_recordset`)
io.WriteString(c.w, `json_populate_recordset`)
} else {
c.w.WriteString(`json_populate_record`)
io.WriteString(c.w, `json_populate_record`)
}
c.w.WriteString(`(NULL::`)
c.w.WriteString(ti.Name)
c.w.WriteString(`, i.j) t`)
io.WriteString(c.w, `(NULL::`)
io.WriteString(c.w, ti.Name)
io.WriteString(c.w, `, i.j) t`)
return 0, nil
}
func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w *bytes.Buffer,
func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w io.Writer,
jt map[string]interface{}, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0]
i := 0
for _, cn := range ti.ColumnNames {
if _, ok := jt[cn]; !ok {
continue
}
if len(root.Allowed) != 0 {
if _, ok := root.Allowed[cn]; !ok {
continue
}
}
if i != 0 {
io.WriteString(c.w, `, `)
}
c.w.WriteString(cn)
io.WriteString(c.w, cn)
i++
}
return 0, nil
}
func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer,
func (c *compilerContext) renderUpdate(qc *qcode.QCode, w io.Writer,
vars Variables, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0]
update, ok := vars[root.ActionVar]
update, ok := vars[qc.ActionVar]
if !ok {
return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar)
return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar)
}
jt, array, err := jsn.Tree(update)
@ -136,26 +140,26 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer,
return 0, err
}
c.w.WriteString(`(WITH "input" AS (SELECT {{`)
c.w.WriteString(root.ActionVar)
c.w.WriteString(`}}::json AS j) UPDATE `)
io.WriteString(c.w, `(WITH "input" AS (SELECT {{`)
io.WriteString(c.w, qc.ActionVar)
io.WriteString(c.w, `}}::json AS j) UPDATE `)
quoted(c.w, ti.Name)
io.WriteString(c.w, ` SET (`)
c.renderInsertUpdateColumns(qc, w, jt, ti)
c.w.WriteString(`) = (SELECT `)
io.WriteString(c.w, `) = (SELECT `)
c.renderInsertUpdateColumns(qc, w, jt, ti)
c.w.WriteString(` FROM input i, `)
io.WriteString(c.w, ` FROM input i, `)
if array {
c.w.WriteString(`json_populate_recordset`)
io.WriteString(c.w, `json_populate_recordset`)
} else {
c.w.WriteString(`json_populate_record`)
io.WriteString(c.w, `json_populate_record`)
}
c.w.WriteString(`(NULL::`)
c.w.WriteString(ti.Name)
c.w.WriteString(`, i.j) t)`)
io.WriteString(c.w, `(NULL::`)
io.WriteString(c.w, ti.Name)
io.WriteString(c.w, `, i.j) t)`)
io.WriteString(c.w, ` WHERE `)
@ -166,11 +170,11 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer,
return 0, nil
}
func (c *compilerContext) renderDelete(qc *qcode.QCode, w *bytes.Buffer,
func (c *compilerContext) renderDelete(qc *qcode.QCode, w io.Writer,
vars Variables, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0]
c.w.WriteString(`(DELETE FROM `)
io.WriteString(c.w, `(DELETE FROM `)
quoted(c.w, ti.Name)
io.WriteString(c.w, ` WHERE `)
@ -181,13 +185,12 @@ func (c *compilerContext) renderDelete(qc *qcode.QCode, w *bytes.Buffer,
return 0, nil
}
func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer,
func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer,
vars Variables, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0]
upsert, ok := vars[root.ActionVar]
upsert, ok := vars[qc.ActionVar]
if !ok {
return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar)
return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar)
}
jt, _, err := jsn.Tree(upsert)
@ -199,7 +202,7 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer,
return 0, err
}
c.w.WriteString(` ON CONFLICT DO (`)
io.WriteString(c.w, ` ON CONFLICT DO (`)
i := 0
for _, cn := range ti.ColumnNames {
@ -214,15 +217,15 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer,
if i != 0 {
io.WriteString(c.w, `, `)
}
c.w.WriteString(cn)
io.WriteString(c.w, cn)
i++
}
if i == 0 {
c.w.WriteString(ti.PrimaryCol)
io.WriteString(c.w, ti.PrimaryCol)
}
c.w.WriteString(`) DO `)
io.WriteString(c.w, `) DO `)
c.w.WriteString(`UPDATE `)
io.WriteString(c.w, `UPDATE `)
io.WriteString(c.w, ` SET `)
i = 0
@ -233,17 +236,17 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer,
if i != 0 {
io.WriteString(c.w, `, `)
}
c.w.WriteString(cn)
io.WriteString(c.w, cn)
io.WriteString(c.w, ` = EXCLUDED.`)
c.w.WriteString(cn)
io.WriteString(c.w, cn)
i++
}
return 0, nil
}
func quoted(w *bytes.Buffer, identifier string) {
w.WriteString(`"`)
w.WriteString(identifier)
w.WriteString(`"`)
func quoted(w io.Writer, identifier string) {
io.WriteString(w, `"`)
io.WriteString(w, identifier)
io.WriteString(w, `"`)
}

View File

@ -12,13 +12,13 @@ func simpleInsert(t *testing.T) {
}
}`
sql := `WITH "users" AS (WITH "input" AS (SELECT {{data}}::json AS j) INSERT INTO "users" (full_name, email) SELECT full_name, email FROM input i, json_populate_record(NULL::users, i.j) t RETURNING *) SELECT json_object_agg('user', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."id" AS "id") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."id" FROM "users") AS "users_0") AS "done_1337";`
sql := `WITH "users" AS (WITH "input" AS (SELECT {{data}}::json AS j) INSERT INTO "users" (full_name, email) SELECT full_name, email FROM input i, json_populate_record(NULL::users, i.j) t RETURNING *) SELECT json_object_agg('user', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."id" AS "id") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."id" FROM "users") AS "users_0") AS "done_1337"`
vars := map[string]json.RawMessage{
"data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`),
}
resSQL, err := compileGQLToPSQL(gql, vars)
resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil {
t.Fatal(err)
}
@ -36,13 +36,13 @@ func singleInsert(t *testing.T) {
}
}`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{insert}}::json AS j) INSERT INTO "products" (name, description, user_id) SELECT name, description, user_id FROM input i, json_populate_record(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{insert}}::json AS j) INSERT INTO "products" (name, description, user_id) SELECT name, description, user_id FROM input i, json_populate_record(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
vars := map[string]json.RawMessage{
"insert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc", "user_id": 5 }`),
}
resSQL, err := compileGQLToPSQL(gql, vars)
resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil {
t.Fatal(err)
}
@ -54,19 +54,19 @@ func singleInsert(t *testing.T) {
func bulkInsert(t *testing.T) {
gql := `mutation {
product(id: 15, insert: $insert) {
product(name: "test", id: 15, insert: $insert) {
id
name
}
}`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{insert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_recordset(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{insert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_recordset(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
vars := map[string]json.RawMessage{
"insert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
}
resSQL, err := compileGQLToPSQL(gql, vars)
resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil {
t.Fatal(err)
}
@ -84,13 +84,13 @@ func singleUpsert(t *testing.T) {
}
}`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
vars := map[string]json.RawMessage{
"upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
}
resSQL, err := compileGQLToPSQL(gql, vars)
resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil {
t.Fatal(err)
}
@ -108,13 +108,13 @@ func bulkUpsert(t *testing.T) {
}
}`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_recordset(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_recordset(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
vars := map[string]json.RawMessage{
"upsert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
}
resSQL, err := compileGQLToPSQL(gql, vars)
resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil {
t.Fatal(err)
}
@ -132,13 +132,13 @@ func singleUpdate(t *testing.T) {
}
}`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{update}}::json AS j) UPDATE "products" SET (name, description) = (SELECT name, description FROM input i, json_populate_record(NULL::products, i.j) t) WHERE (("products"."user_id") = {{user_id}}) AND (("products"."id") = 1) AND (("products"."id") = 15) RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{update}}::json AS j) UPDATE "products" SET (name, description) = (SELECT name, description FROM input i, json_populate_record(NULL::products, i.j) t) WHERE (("products"."user_id") = {{user_id}}) AND (("products"."id") = 1) AND (("products"."id") = 15) RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
vars := map[string]json.RawMessage{
"update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
}
resSQL, err := compileGQLToPSQL(gql, vars)
resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil {
t.Fatal(err)
}
@ -156,13 +156,13 @@ func delete(t *testing.T) {
}
}`
sql := `WITH "products" AS (DELETE FROM "products" WHERE (("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 1) RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";`
sql := `WITH "products" AS (DELETE FROM "products" WHERE (("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 1) RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
vars := map[string]json.RawMessage{
"update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
}
resSQL, err := compileGQLToPSQL(gql, vars)
resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil {
t.Fatal(err)
}
@ -172,7 +172,7 @@ func delete(t *testing.T) {
}
}
func TestCompileInsert(t *testing.T) {
func TestCompileMutate(t *testing.T) {
t.Run("simpleInsert", simpleInsert)
t.Run("singleInsert", singleInsert)
t.Run("bulkInsert", bulkInsert)

View File

@ -49,7 +49,7 @@ func (c *Compiler) IDColumn(table string) (string, error) {
}
type compilerContext struct {
w *bytes.Buffer
w io.Writer
s []qcode.Select
*Compiler
}
@ -60,18 +60,18 @@ func (co *Compiler) CompileEx(qc *qcode.QCode, vars Variables) (uint32, []byte,
return skipped, w.Bytes(), err
}
func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) {
func (co *Compiler) Compile(qc *qcode.QCode, w io.Writer, vars Variables) (uint32, error) {
switch qc.Type {
case qcode.QTQuery:
return co.compileQuery(qc, w)
case qcode.QTMutation:
case qcode.QTInsert, qcode.QTUpdate, qcode.QTDelete, qcode.QTUpsert:
return co.compileMutation(qc, w, vars)
}
return 0, errors.New("unknown operation")
return 0, fmt.Errorf("Unknown operation type %d", qc.Type)
}
func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) {
func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) {
if len(qc.Selects) == 0 {
return 0, errors.New("empty query")
}
@ -90,17 +90,17 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, erro
//fmt.Fprintf(w, `SELECT json_object_agg('%s', %s) FROM (`,
//root.FieldName, root.Table)
c.w.WriteString(`SELECT json_object_agg('`)
c.w.WriteString(root.FieldName)
c.w.WriteString(`', `)
io.WriteString(c.w, `SELECT json_object_agg('`)
io.WriteString(c.w, root.FieldName)
io.WriteString(c.w, `', `)
if ti.Singular == false {
c.w.WriteString(root.Table)
io.WriteString(c.w, root.Table)
} else {
c.w.WriteString("sel_json_")
io.WriteString(c.w, "sel_json_")
int2string(c.w, root.ID)
}
c.w.WriteString(`) FROM (`)
io.WriteString(c.w, `) FROM (`)
var ignored uint32
@ -161,9 +161,8 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, erro
}
}
c.w.WriteString(`)`)
io.WriteString(c.w, `)`)
alias(c.w, `done_1337`)
c.w.WriteString(`;`)
return ignored, nil
}
@ -219,10 +218,10 @@ func (c *compilerContext) renderSelect(sel *qcode.Select, ti *DBTableInfo) (uint
// SELECT
if ti.Singular == false {
//fmt.Fprintf(w, `SELECT coalesce(json_agg("%s"`, c.sel.Table)
c.w.WriteString(`SELECT coalesce(json_agg("`)
c.w.WriteString("sel_json_")
io.WriteString(c.w, `SELECT coalesce(json_agg("`)
io.WriteString(c.w, "sel_json_")
int2string(c.w, sel.ID)
c.w.WriteString(`"`)
io.WriteString(c.w, `"`)
if hasOrder {
err := c.renderOrderBy(sel, ti)
@ -232,24 +231,24 @@ func (c *compilerContext) renderSelect(sel *qcode.Select, ti *DBTableInfo) (uint
}
//fmt.Fprintf(w, `), '[]') AS "%s" FROM (`, c.sel.Table)
c.w.WriteString(`), '[]')`)
io.WriteString(c.w, `), '[]')`)
alias(c.w, sel.Table)
c.w.WriteString(` FROM (`)
io.WriteString(c.w, ` FROM (`)
}
// ROW-TO-JSON
c.w.WriteString(`SELECT `)
io.WriteString(c.w, `SELECT `)
if len(sel.DistinctOn) != 0 {
c.renderDistinctOn(sel, ti)
}
c.w.WriteString(`row_to_json((`)
io.WriteString(c.w, `row_to_json((`)
//fmt.Fprintf(w, `SELECT "sel_%d" FROM (SELECT `, c.sel.ID)
c.w.WriteString(`SELECT "sel_`)
io.WriteString(c.w, `SELECT "sel_`)
int2string(c.w, sel.ID)
c.w.WriteString(`" FROM (SELECT `)
io.WriteString(c.w, `" FROM (SELECT `)
// Combined column names
c.renderColumns(sel, ti)
@ -262,11 +261,11 @@ func (c *compilerContext) renderSelect(sel *qcode.Select, ti *DBTableInfo) (uint
}
//fmt.Fprintf(w, `) AS "sel_%d"`, c.sel.ID)
c.w.WriteString(`)`)
io.WriteString(c.w, `)`)
aliasWithID(c.w, "sel", sel.ID)
//fmt.Fprintf(w, `)) AS "%s"`, c.sel.Table)
c.w.WriteString(`))`)
io.WriteString(c.w, `))`)
aliasWithID(c.w, "sel_json", sel.ID)
// END-ROW-TO-JSON
@ -295,31 +294,33 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo)
}
}
if sel.Action == 0 {
if len(sel.Paging.Limit) != 0 {
//fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit)
c.w.WriteString(` LIMIT ('`)
c.w.WriteString(sel.Paging.Limit)
c.w.WriteString(`') :: integer`)
switch {
case sel.Paging.NoLimit:
break
} else if ti.Singular {
c.w.WriteString(` LIMIT ('1') :: integer`)
case len(sel.Paging.Limit) != 0:
//fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit)
io.WriteString(c.w, ` LIMIT ('`)
io.WriteString(c.w, sel.Paging.Limit)
io.WriteString(c.w, `') :: integer`)
} else {
c.w.WriteString(` LIMIT ('20') :: integer`)
}
case ti.Singular:
io.WriteString(c.w, ` LIMIT ('1') :: integer`)
default:
io.WriteString(c.w, ` LIMIT ('20') :: integer`)
}
if len(sel.Paging.Offset) != 0 {
//fmt.Fprintf(w, ` OFFSET ('%s') :: integer`, c.sel.Paging.Offset)
c.w.WriteString(`OFFSET ('`)
c.w.WriteString(sel.Paging.Offset)
c.w.WriteString(`') :: integer`)
io.WriteString(c.w, `OFFSET ('`)
io.WriteString(c.w, sel.Paging.Offset)
io.WriteString(c.w, `') :: integer`)
}
if ti.Singular == false {
//fmt.Fprintf(w, `) AS "sel_json_agg_%d"`, c.sel.ID)
c.w.WriteString(`)`)
io.WriteString(c.w, `)`)
aliasWithID(c.w, "sel_json_agg", sel.ID)
}
@ -327,15 +328,15 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo)
}
func (c *compilerContext) renderJoin(sel *qcode.Select) error {
c.w.WriteString(` LEFT OUTER JOIN LATERAL (`)
io.WriteString(c.w, ` LEFT OUTER JOIN LATERAL (`)
return nil
}
func (c *compilerContext) renderJoinClose(sel *qcode.Select) error {
//fmt.Fprintf(w, `) AS "%s_%d_join" ON ('true')`, c.sel.Table, c.sel.ID)
c.w.WriteString(`)`)
io.WriteString(c.w, `)`)
aliasWithIDSuffix(c.w, sel.Table, sel.ID, "_join")
c.w.WriteString(` ON ('true')`)
io.WriteString(c.w, ` ON ('true')`)
return nil
}
@ -358,25 +359,43 @@ func (c *compilerContext) renderJoinTable(sel *qcode.Select) error {
//fmt.Fprintf(w, ` LEFT OUTER JOIN "%s" ON (("%s"."%s") = ("%s_%d"."%s"))`,
//rel.Through, rel.Through, rel.ColT, c.parent.Table, c.parent.ID, rel.Col1)
c.w.WriteString(` LEFT OUTER JOIN "`)
c.w.WriteString(rel.Through)
c.w.WriteString(`" ON ((`)
io.WriteString(c.w, ` LEFT OUTER JOIN "`)
io.WriteString(c.w, rel.Through)
io.WriteString(c.w, `" ON ((`)
colWithTable(c.w, rel.Through, rel.ColT)
c.w.WriteString(`) = (`)
io.WriteString(c.w, `) = (`)
colWithTableID(c.w, pt.Name, parent.ID, rel.Col1)
c.w.WriteString(`))`)
io.WriteString(c.w, `))`)
return nil
}
func (c *compilerContext) renderColumns(sel *qcode.Select, ti *DBTableInfo) {
for i, col := range sel.Cols {
i := 0
for _, col := range sel.Cols {
if len(sel.Allowed) != 0 {
n := funcPrefixLen(col.Name)
if n != 0 {
if sel.Functions == false {
continue
}
if _, ok := sel.Allowed[col.Name[n:]]; !ok {
continue
}
} else {
if _, ok := sel.Allowed[col.Name]; !ok {
continue
}
}
}
if i != 0 {
io.WriteString(c.w, ", ")
}
//fmt.Fprintf(w, `"%s_%d"."%s" AS "%s"`,
//c.sel.Table, c.sel.ID, col.Name, col.FieldName)
colWithTableIDAlias(c.w, ti.Name, sel.ID, col.Name, col.FieldName)
i++
}
}
@ -415,10 +434,24 @@ func (c *compilerContext) renderJoinedColumns(sel *qcode.Select, ti *DBTableInfo
}
childSel := &c.s[id]
cti, err := c.schema.GetTable(childSel.Table)
if err != nil {
continue
}
//fmt.Fprintf(w, `"%s_%d_join"."%s" AS "%s"`,
//s.Table, s.ID, s.Table, s.FieldName)
colWithTableIDSuffixAlias(c.w, childSel.Table, childSel.ID,
"_join", childSel.Table, childSel.FieldName)
if cti.Singular {
io.WriteString(c.w, `"sel_json_`)
int2string(c.w, childSel.ID)
io.WriteString(c.w, `" AS "`)
io.WriteString(c.w, childSel.FieldName)
io.WriteString(c.w, `"`)
} else {
colWithTableIDSuffixAlias(c.w, childSel.Table, childSel.ID,
"_join", childSel.Table, childSel.FieldName)
}
}
return nil
@ -433,9 +466,10 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
isSearch := sel.Args["search"] != nil
isAgg := false
c.w.WriteString(` FROM (SELECT `)
io.WriteString(c.w, ` FROM (SELECT `)
for i, col := range sel.Cols {
i := 0
for n, col := range sel.Cols {
cn := col.Name
_, isRealCol := ti.Columns[cn]
@ -447,93 +481,116 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
cn = ti.TSVCol
arg := sel.Args["search"]
if i != 0 {
io.WriteString(c.w, `, `)
}
//fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`,
//c.sel.Table, cn, arg.Val, col.Name)
c.w.WriteString(`ts_rank(`)
io.WriteString(c.w, `ts_rank(`)
colWithTable(c.w, ti.Name, cn)
c.w.WriteString(`, to_tsquery('`)
c.w.WriteString(arg.Val)
c.w.WriteString(`')`)
io.WriteString(c.w, `, to_tsquery('`)
io.WriteString(c.w, arg.Val)
io.WriteString(c.w, `')`)
alias(c.w, col.Name)
i++
case strings.HasPrefix(cn, "search_headline_"):
cn = cn[16:]
arg := sel.Args["search"]
if i != 0 {
io.WriteString(c.w, `, `)
}
//fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`,
//c.sel.Table, cn, arg.Val, col.Name)
c.w.WriteString(`ts_headlinek(`)
io.WriteString(c.w, `ts_headlinek(`)
colWithTable(c.w, ti.Name, cn)
c.w.WriteString(`, to_tsquery('`)
c.w.WriteString(arg.Val)
c.w.WriteString(`')`)
io.WriteString(c.w, `, to_tsquery('`)
io.WriteString(c.w, arg.Val)
io.WriteString(c.w, `')`)
alias(c.w, col.Name)
i++
}
} else {
pl := funcPrefixLen(cn)
if pl == 0 {
if i != 0 {
io.WriteString(c.w, `, `)
}
//fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name)
c.w.WriteString(`'`)
c.w.WriteString(cn)
c.w.WriteString(` not defined'`)
io.WriteString(c.w, `'`)
io.WriteString(c.w, cn)
io.WriteString(c.w, ` not defined'`)
alias(c.w, col.Name)
} else {
isAgg = true
i++
} else if sel.Functions {
cn1 := cn[pl:]
if _, ok := sel.Allowed[cn1]; !ok {
continue
}
if i != 0 {
io.WriteString(c.w, `, `)
}
fn := cn[0 : pl-1]
cn := cn[pl:]
isAgg = true
//fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, c.sel.Table, cn, col.Name)
c.w.WriteString(fn)
c.w.WriteString(`(`)
colWithTable(c.w, ti.Name, cn)
c.w.WriteString(`)`)
io.WriteString(c.w, fn)
io.WriteString(c.w, `(`)
colWithTable(c.w, ti.Name, cn1)
io.WriteString(c.w, `)`)
alias(c.w, col.Name)
i++
}
}
} else {
groupBy = append(groupBy, i)
groupBy = append(groupBy, n)
//fmt.Fprintf(w, `"%s"."%s"`, c.sel.Table, cn)
if i != 0 {
io.WriteString(c.w, `, `)
}
colWithTable(c.w, ti.Name, cn)
}
i++
if i < len(sel.Cols)-1 || len(childCols) != 0 {
//io.WriteString(w, ", ")
c.w.WriteString(`, `)
}
}
for i, col := range childCols {
for _, col := range childCols {
if i != 0 {
//io.WriteString(w, ", ")
c.w.WriteString(`, `)
io.WriteString(c.w, `, `)
}
//fmt.Fprintf(w, `"%s"."%s"`, col.Table, col.Name)
colWithTable(c.w, col.Table, col.Name)
i++
}
c.w.WriteString(` FROM `)
io.WriteString(c.w, ` FROM `)
//fmt.Fprintf(w, ` FROM "%s"`, c.sel.Table)
c.w.WriteString(`"`)
c.w.WriteString(ti.Name)
c.w.WriteString(`"`)
io.WriteString(c.w, `"`)
io.WriteString(c.w, ti.Name)
io.WriteString(c.w, `"`)
// if tn, ok := c.tmap[sel.Table]; ok {
// //fmt.Fprintf(w, ` FROM "%s" AS "%s"`, tn, c.sel.Table)
// tableWithAlias(c.w, ti.Name, sel.Table)
// } else {
// //fmt.Fprintf(w, ` FROM "%s"`, c.sel.Table)
// c.w.WriteString(`"`)
// c.w.WriteString(sel.Table)
// c.w.WriteString(`"`)
// io.WriteString(c.w, `"`)
// io.WriteString(c.w, sel.Table)
// io.WriteString(c.w, `"`)
// }
if isRoot && isFil {
c.w.WriteString(` WHERE (`)
io.WriteString(c.w, ` WHERE (`)
if err := c.renderWhere(sel, ti); err != nil {
return err
}
c.w.WriteString(`)`)
io.WriteString(c.w, `)`)
}
if !isRoot {
@ -541,28 +598,28 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
return err
}
c.w.WriteString(` WHERE (`)
io.WriteString(c.w, ` WHERE (`)
if err := c.renderRelationship(sel, ti); err != nil {
return err
}
if isFil {
c.w.WriteString(` AND `)
io.WriteString(c.w, ` AND `)
if err := c.renderWhere(sel, ti); err != nil {
return err
}
}
c.w.WriteString(`)`)
io.WriteString(c.w, `)`)
}
if isAgg {
if len(groupBy) != 0 {
c.w.WriteString(` GROUP BY `)
io.WriteString(c.w, ` GROUP BY `)
for i, id := range groupBy {
if i != 0 {
c.w.WriteString(`, `)
io.WriteString(c.w, `, `)
}
//fmt.Fprintf(w, `"%s"."%s"`, c.sel.Table, c.sel.Cols[id].Name)
colWithTable(c.w, ti.Name, sel.Cols[id].Name)
@ -570,30 +627,32 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
}
}
if sel.Action == 0 {
if len(sel.Paging.Limit) != 0 {
//fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit)
c.w.WriteString(` LIMIT ('`)
c.w.WriteString(sel.Paging.Limit)
c.w.WriteString(`') :: integer`)
switch {
case sel.Paging.NoLimit:
break
} else if ti.Singular {
c.w.WriteString(` LIMIT ('1') :: integer`)
case len(sel.Paging.Limit) != 0:
//fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit)
io.WriteString(c.w, ` LIMIT ('`)
io.WriteString(c.w, sel.Paging.Limit)
io.WriteString(c.w, `') :: integer`)
} else {
c.w.WriteString(` LIMIT ('20') :: integer`)
}
case ti.Singular:
io.WriteString(c.w, ` LIMIT ('1') :: integer`)
default:
io.WriteString(c.w, ` LIMIT ('20') :: integer`)
}
if len(sel.Paging.Offset) != 0 {
//fmt.Fprintf(w, ` OFFSET ('%s') :: integer`, c.sel.Paging.Offset)
c.w.WriteString(` OFFSET ('`)
c.w.WriteString(sel.Paging.Offset)
c.w.WriteString(`') :: integer`)
io.WriteString(c.w, ` OFFSET ('`)
io.WriteString(c.w, sel.Paging.Offset)
io.WriteString(c.w, `') :: integer`)
}
//fmt.Fprintf(w, `) AS "%s_%d"`, c.sel.Table, c.sel.ID)
c.w.WriteString(`)`)
io.WriteString(c.w, `)`)
aliasWithID(c.w, ti.Name, sel.ID)
return nil
}
@ -604,7 +663,7 @@ func (c *compilerContext) renderOrderByColumns(sel *qcode.Select, ti *DBTableInf
for i := range sel.OrderBy {
if colsRendered {
//io.WriteString(w, ", ")
c.w.WriteString(`, `)
io.WriteString(c.w, `, `)
}
col := sel.OrderBy[i].Col
@ -612,7 +671,7 @@ func (c *compilerContext) renderOrderByColumns(sel *qcode.Select, ti *DBTableInf
//c.sel.Table, c.sel.ID, c,
//c.sel.Table, c.sel.ID, c)
colWithTableID(c.w, ti.Name, sel.ID, col)
c.w.WriteString(` AS `)
io.WriteString(c.w, ` AS `)
tableIDColSuffix(c.w, sel.Table, sel.ID, col, "_ob")
}
}
@ -629,29 +688,29 @@ func (c *compilerContext) renderRelationship(sel *qcode.Select, ti *DBTableInfo)
case RelBelongTo:
//fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`,
//c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2)
c.w.WriteString(`((`)
io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, rel.Col1)
c.w.WriteString(`) = (`)
io.WriteString(c.w, `) = (`)
colWithTableID(c.w, parent.Table, parent.ID, rel.Col2)
c.w.WriteString(`))`)
io.WriteString(c.w, `))`)
case RelOneToMany:
//fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`,
//c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2)
c.w.WriteString(`((`)
io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, rel.Col1)
c.w.WriteString(`) = (`)
io.WriteString(c.w, `) = (`)
colWithTableID(c.w, parent.Table, parent.ID, rel.Col2)
c.w.WriteString(`))`)
io.WriteString(c.w, `))`)
case RelOneToManyThrough:
//fmt.Fprintf(w, `(("%s"."%s") = ("%s"."%s"))`,
//c.sel.Table, rel.Col1, rel.Through, rel.Col2)
c.w.WriteString(`((`)
io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, rel.Col1)
c.w.WriteString(`) = (`)
io.WriteString(c.w, `) = (`)
colWithTable(c.w, rel.Through, rel.Col2)
c.w.WriteString(`))`)
io.WriteString(c.w, `))`)
}
return nil
@ -675,11 +734,11 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
case qcode.ExpOp:
switch val {
case qcode.OpAnd:
c.w.WriteString(` AND `)
io.WriteString(c.w, ` AND `)
case qcode.OpOr:
c.w.WriteString(` OR `)
io.WriteString(c.w, ` OR `)
case qcode.OpNot:
c.w.WriteString(`NOT `)
io.WriteString(c.w, `NOT `)
default:
return fmt.Errorf("11: unexpected value %v (%t)", intf, intf)
}
@ -703,62 +762,62 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
default:
if val.NestedCol {
//fmt.Fprintf(w, `(("%s") `, val.Col)
c.w.WriteString(`(("`)
c.w.WriteString(val.Col)
c.w.WriteString(`") `)
io.WriteString(c.w, `(("`)
io.WriteString(c.w, val.Col)
io.WriteString(c.w, `") `)
} else if len(val.Col) != 0 {
//fmt.Fprintf(w, `(("%s"."%s") `, c.sel.Table, val.Col)
c.w.WriteString(`((`)
io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, val.Col)
c.w.WriteString(`) `)
io.WriteString(c.w, `) `)
}
valExists := true
switch val.Op {
case qcode.OpEquals:
c.w.WriteString(`=`)
io.WriteString(c.w, `=`)
case qcode.OpNotEquals:
c.w.WriteString(`!=`)
io.WriteString(c.w, `!=`)
case qcode.OpGreaterOrEquals:
c.w.WriteString(`>=`)
io.WriteString(c.w, `>=`)
case qcode.OpLesserOrEquals:
c.w.WriteString(`<=`)
io.WriteString(c.w, `<=`)
case qcode.OpGreaterThan:
c.w.WriteString(`>`)
io.WriteString(c.w, `>`)
case qcode.OpLesserThan:
c.w.WriteString(`<`)
io.WriteString(c.w, `<`)
case qcode.OpIn:
c.w.WriteString(`IN`)
io.WriteString(c.w, `IN`)
case qcode.OpNotIn:
c.w.WriteString(`NOT IN`)
io.WriteString(c.w, `NOT IN`)
case qcode.OpLike:
c.w.WriteString(`LIKE`)
io.WriteString(c.w, `LIKE`)
case qcode.OpNotLike:
c.w.WriteString(`NOT LIKE`)
io.WriteString(c.w, `NOT LIKE`)
case qcode.OpILike:
c.w.WriteString(`ILIKE`)
io.WriteString(c.w, `ILIKE`)
case qcode.OpNotILike:
c.w.WriteString(`NOT ILIKE`)
io.WriteString(c.w, `NOT ILIKE`)
case qcode.OpSimilar:
c.w.WriteString(`SIMILAR TO`)
io.WriteString(c.w, `SIMILAR TO`)
case qcode.OpNotSimilar:
c.w.WriteString(`NOT SIMILAR TO`)
io.WriteString(c.w, `NOT SIMILAR TO`)
case qcode.OpContains:
c.w.WriteString(`@>`)
io.WriteString(c.w, `@>`)
case qcode.OpContainedIn:
c.w.WriteString(`<@`)
io.WriteString(c.w, `<@`)
case qcode.OpHasKey:
c.w.WriteString(`?`)
io.WriteString(c.w, `?`)
case qcode.OpHasKeyAny:
c.w.WriteString(`?|`)
io.WriteString(c.w, `?|`)
case qcode.OpHasKeyAll:
c.w.WriteString(`?&`)
io.WriteString(c.w, `?&`)
case qcode.OpIsNull:
if strings.EqualFold(val.Val, "true") {
c.w.WriteString(`IS NULL)`)
io.WriteString(c.w, `IS NULL)`)
} else {
c.w.WriteString(`IS NOT NULL)`)
io.WriteString(c.w, `IS NOT NULL)`)
}
valExists = false
case qcode.OpEqID:
@ -766,20 +825,20 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
return fmt.Errorf("no primary key column defined for %s", ti.Name)
}
//fmt.Fprintf(w, `(("%s") =`, c.ti.PrimaryCol)
c.w.WriteString(`((`)
io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, ti.PrimaryCol)
//c.w.WriteString(ti.PrimaryCol)
c.w.WriteString(`) =`)
//io.WriteString(c.w, ti.PrimaryCol)
io.WriteString(c.w, `) =`)
case qcode.OpTsQuery:
if len(ti.TSVCol) == 0 {
return fmt.Errorf("no tsv column defined for %s", ti.Name)
}
//fmt.Fprintf(w, `(("%s") @@ to_tsquery('%s'))`, c.ti.TSVCol, val.Val)
c.w.WriteString(`(("`)
c.w.WriteString(ti.TSVCol)
c.w.WriteString(`") @@ to_tsquery('`)
c.w.WriteString(val.Val)
c.w.WriteString(`'))`)
io.WriteString(c.w, `(("`)
io.WriteString(c.w, ti.TSVCol)
io.WriteString(c.w, `") @@ to_tsquery('`)
io.WriteString(c.w, val.Val)
io.WriteString(c.w, `'))`)
valExists = false
default:
@ -792,7 +851,7 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
} else {
c.renderVal(val, c.vars)
}
c.w.WriteString(`)`)
io.WriteString(c.w, `)`)
}
qcode.FreeExp(val)
@ -808,10 +867,10 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
}
func (c *compilerContext) renderOrderBy(sel *qcode.Select, ti *DBTableInfo) error {
c.w.WriteString(` ORDER BY `)
io.WriteString(c.w, ` ORDER BY `)
for i := range sel.OrderBy {
if i != 0 {
c.w.WriteString(`, `)
io.WriteString(c.w, `, `)
}
ob := sel.OrderBy[i]
@ -819,27 +878,27 @@ func (c *compilerContext) renderOrderBy(sel *qcode.Select, ti *DBTableInfo) erro
case qcode.OrderAsc:
//fmt.Fprintf(w, `"%s_%d.ob.%s" ASC`, sel.Table, sel.ID, ob.Col)
tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob")
c.w.WriteString(` ASC`)
io.WriteString(c.w, ` ASC`)
case qcode.OrderDesc:
//fmt.Fprintf(w, `"%s_%d.ob.%s" DESC`, sel.Table, sel.ID, ob.Col)
tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob")
c.w.WriteString(` DESC`)
io.WriteString(c.w, ` DESC`)
case qcode.OrderAscNullsFirst:
//fmt.Fprintf(w, `"%s_%d.ob.%s" ASC NULLS FIRST`, sel.Table, sel.ID, ob.Col)
tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob")
c.w.WriteString(` ASC NULLS FIRST`)
io.WriteString(c.w, ` ASC NULLS FIRST`)
case qcode.OrderDescNullsFirst:
//fmt.Fprintf(w, `%s_%d.ob.%s DESC NULLS FIRST`, sel.Table, sel.ID, ob.Col)
tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob")
c.w.WriteString(` DESC NULLLS FIRST`)
io.WriteString(c.w, ` DESC NULLLS FIRST`)
case qcode.OrderAscNullsLast:
//fmt.Fprintf(w, `"%s_%d.ob.%s ASC NULLS LAST`, sel.Table, sel.ID, ob.Col)
tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob")
c.w.WriteString(` ASC NULLS LAST`)
io.WriteString(c.w, ` ASC NULLS LAST`)
case qcode.OrderDescNullsLast:
//fmt.Fprintf(w, `%s_%d.ob.%s DESC NULLS LAST`, sel.Table, sel.ID, ob.Col)
tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob")
c.w.WriteString(` DESC NULLS LAST`)
io.WriteString(c.w, ` DESC NULLS LAST`)
default:
return fmt.Errorf("13: unexpected value %v", ob.Order)
}
@ -851,30 +910,30 @@ func (c *compilerContext) renderDistinctOn(sel *qcode.Select, ti *DBTableInfo) {
io.WriteString(c.w, `DISTINCT ON (`)
for i := range sel.DistinctOn {
if i != 0 {
c.w.WriteString(`, `)
io.WriteString(c.w, `, `)
}
//fmt.Fprintf(w, `"%s_%d.ob.%s"`, c.sel.Table, c.sel.ID, c.sel.DistinctOn[i])
tableIDColSuffix(c.w, ti.Name, sel.ID, sel.DistinctOn[i], "_ob")
}
c.w.WriteString(`) `)
io.WriteString(c.w, `) `)
}
func (c *compilerContext) renderList(ex *qcode.Exp) {
io.WriteString(c.w, ` (`)
for i := range ex.ListVal {
if i != 0 {
c.w.WriteString(`, `)
io.WriteString(c.w, `, `)
}
switch ex.ListType {
case qcode.ValBool, qcode.ValInt, qcode.ValFloat:
c.w.WriteString(ex.ListVal[i])
io.WriteString(c.w, ex.ListVal[i])
case qcode.ValStr:
c.w.WriteString(`'`)
c.w.WriteString(ex.ListVal[i])
c.w.WriteString(`'`)
io.WriteString(c.w, `'`)
io.WriteString(c.w, ex.ListVal[i])
io.WriteString(c.w, `'`)
}
}
c.w.WriteString(`)`)
io.WriteString(c.w, `)`)
}
func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) {
@ -883,27 +942,27 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) {
switch ex.Type {
case qcode.ValBool, qcode.ValInt, qcode.ValFloat:
if len(ex.Val) != 0 {
c.w.WriteString(ex.Val)
io.WriteString(c.w, ex.Val)
} else {
c.w.WriteString(`''`)
io.WriteString(c.w, `''`)
}
case qcode.ValStr:
c.w.WriteString(`'`)
c.w.WriteString(ex.Val)
c.w.WriteString(`'`)
io.WriteString(c.w, `'`)
io.WriteString(c.w, ex.Val)
io.WriteString(c.w, `'`)
case qcode.ValVar:
if val, ok := vars[ex.Val]; ok {
c.w.WriteString(val)
io.WriteString(c.w, val)
} else {
//fmt.Fprintf(w, `'{{%s}}'`, ex.Val)
c.w.WriteString(`{{`)
c.w.WriteString(ex.Val)
c.w.WriteString(`}}`)
io.WriteString(c.w, `{{`)
io.WriteString(c.w, ex.Val)
io.WriteString(c.w, `}}`)
}
}
//c.w.WriteString(`)`)
//io.WriteString(c.w, `)`)
}
func funcPrefixLen(fn string) int {
@ -939,105 +998,105 @@ func hasBit(n uint32, pos uint32) bool {
return (val > 0)
}
func alias(w *bytes.Buffer, alias string) {
w.WriteString(` AS "`)
w.WriteString(alias)
w.WriteString(`"`)
func alias(w io.Writer, alias string) {
io.WriteString(w, ` AS "`)
io.WriteString(w, alias)
io.WriteString(w, `"`)
}
func aliasWithID(w *bytes.Buffer, alias string, id int32) {
w.WriteString(` AS "`)
w.WriteString(alias)
w.WriteString(`_`)
func aliasWithID(w io.Writer, alias string, id int32) {
io.WriteString(w, ` AS "`)
io.WriteString(w, alias)
io.WriteString(w, `_`)
int2string(w, id)
w.WriteString(`"`)
io.WriteString(w, `"`)
}
func aliasWithIDSuffix(w *bytes.Buffer, alias string, id int32, suffix string) {
w.WriteString(` AS "`)
w.WriteString(alias)
w.WriteString(`_`)
func aliasWithIDSuffix(w io.Writer, alias string, id int32, suffix string) {
io.WriteString(w, ` AS "`)
io.WriteString(w, alias)
io.WriteString(w, `_`)
int2string(w, id)
w.WriteString(suffix)
w.WriteString(`"`)
io.WriteString(w, suffix)
io.WriteString(w, `"`)
}
func colWithAlias(w *bytes.Buffer, col, alias string) {
w.WriteString(`"`)
w.WriteString(col)
w.WriteString(`" AS "`)
w.WriteString(alias)
w.WriteString(`"`)
func colWithAlias(w io.Writer, col, alias string) {
io.WriteString(w, `"`)
io.WriteString(w, col)
io.WriteString(w, `" AS "`)
io.WriteString(w, alias)
io.WriteString(w, `"`)
}
func tableWithAlias(w *bytes.Buffer, table, alias string) {
w.WriteString(`"`)
w.WriteString(table)
w.WriteString(`" AS "`)
w.WriteString(alias)
w.WriteString(`"`)
func tableWithAlias(w io.Writer, table, alias string) {
io.WriteString(w, `"`)
io.WriteString(w, table)
io.WriteString(w, `" AS "`)
io.WriteString(w, alias)
io.WriteString(w, `"`)
}
func colWithTable(w *bytes.Buffer, table, col string) {
w.WriteString(`"`)
w.WriteString(table)
w.WriteString(`"."`)
w.WriteString(col)
w.WriteString(`"`)
func colWithTable(w io.Writer, table, col string) {
io.WriteString(w, `"`)
io.WriteString(w, table)
io.WriteString(w, `"."`)
io.WriteString(w, col)
io.WriteString(w, `"`)
}
func colWithTableID(w *bytes.Buffer, table string, id int32, col string) {
w.WriteString(`"`)
w.WriteString(table)
w.WriteString(`_`)
func colWithTableID(w io.Writer, table string, id int32, col string) {
io.WriteString(w, `"`)
io.WriteString(w, table)
io.WriteString(w, `_`)
int2string(w, id)
w.WriteString(`"."`)
w.WriteString(col)
w.WriteString(`"`)
io.WriteString(w, `"."`)
io.WriteString(w, col)
io.WriteString(w, `"`)
}
func colWithTableIDAlias(w *bytes.Buffer, table string, id int32, col, alias string) {
w.WriteString(`"`)
w.WriteString(table)
w.WriteString(`_`)
func colWithTableIDAlias(w io.Writer, table string, id int32, col, alias string) {
io.WriteString(w, `"`)
io.WriteString(w, table)
io.WriteString(w, `_`)
int2string(w, id)
w.WriteString(`"."`)
w.WriteString(col)
w.WriteString(`" AS "`)
w.WriteString(alias)
w.WriteString(`"`)
io.WriteString(w, `"."`)
io.WriteString(w, col)
io.WriteString(w, `" AS "`)
io.WriteString(w, alias)
io.WriteString(w, `"`)
}
func colWithTableIDSuffixAlias(w *bytes.Buffer, table string, id int32,
func colWithTableIDSuffixAlias(w io.Writer, table string, id int32,
suffix, col, alias string) {
w.WriteString(`"`)
w.WriteString(table)
w.WriteString(`_`)
io.WriteString(w, `"`)
io.WriteString(w, table)
io.WriteString(w, `_`)
int2string(w, id)
w.WriteString(suffix)
w.WriteString(`"."`)
w.WriteString(col)
w.WriteString(`" AS "`)
w.WriteString(alias)
w.WriteString(`"`)
io.WriteString(w, suffix)
io.WriteString(w, `"."`)
io.WriteString(w, col)
io.WriteString(w, `" AS "`)
io.WriteString(w, alias)
io.WriteString(w, `"`)
}
func tableIDColSuffix(w *bytes.Buffer, table string, id int32, col, suffix string) {
w.WriteString(`"`)
w.WriteString(table)
w.WriteString(`_`)
func tableIDColSuffix(w io.Writer, table string, id int32, col, suffix string) {
io.WriteString(w, `"`)
io.WriteString(w, table)
io.WriteString(w, `_`)
int2string(w, id)
w.WriteString(`_`)
w.WriteString(col)
w.WriteString(suffix)
w.WriteString(`"`)
io.WriteString(w, `_`)
io.WriteString(w, col)
io.WriteString(w, suffix)
io.WriteString(w, `"`)
}
const charset = "0123456789"
func int2string(w *bytes.Buffer, val int32) {
func int2string(w io.Writer, val int32) {
if val < 10 {
w.WriteByte(charset[val])
w.Write([]byte{charset[val]})
return
}
@ -1053,7 +1112,7 @@ func int2string(w *bytes.Buffer, val int32) {
for val3 > 0 {
d := val3 % 10
val3 /= 10
w.WriteByte(charset[d])
w.Write([]byte{charset[d]})
}
}

View File

@ -22,32 +22,6 @@ func TestMain(m *testing.M) {
var err error
qcompile, err = qcode.NewCompiler(qcode.Config{
DefaultFilter: []string{
`{ user_id: { _eq: $user_id } }`,
},
FilterMap: qcode.Filters{
All: map[string][]string{
"users": []string{
"{ id: { eq: $user_id } }",
},
"products": []string{
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }",
},
"customers": []string{},
"mes": []string{
"{ id: { eq: $user_id } }",
},
},
Query: map[string][]string{
"users": []string{},
},
Update: map[string][]string{
"products": []string{
"{ user_id: { eq: $user_id } }",
},
},
},
Blocklist: []string{
"secret",
"password",
@ -55,6 +29,59 @@ func TestMain(m *testing.M) {
},
})
qcompile.AddRole("user", "product", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "name", "price", "users", "customers"},
Filters: []string{
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }",
},
},
Update: qcode.UpdateConfig{
Filters: []string{"{ user_id: { eq: $user_id } }"},
},
Delete: qcode.DeleteConfig{
Filters: []string{
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }",
},
},
})
qcompile.AddRole("anon", "product", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "name"},
},
})
qcompile.AddRole("anon1", "product", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "name", "price"},
DisableFunctions: true,
},
})
qcompile.AddRole("user", "users", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "full_name", "avatar", "email", "products"},
},
})
qcompile.AddRole("user", "mes", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "full_name", "avatar"},
Filters: []string{
"{ id: { eq: $user_id } }",
},
},
})
qcompile.AddRole("user", "customers", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "email", "full_name", "products"},
},
})
if err != nil {
log.Fatal(err)
}
@ -135,9 +162,8 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}
func compileGQLToPSQL(gql string, vars Variables) ([]byte, error) {
qc, err := qcompile.Compile([]byte(gql))
func compileGQLToPSQL(gql string, vars Variables, role string) ([]byte, error) {
qc, err := qcompile.Compile([]byte(gql), role)
if err != nil {
return nil, err
}
@ -147,6 +173,8 @@ func compileGQLToPSQL(gql string, vars Variables) ([]byte, error) {
return nil, err
}
//fmt.Println(string(sqlStmt))
return sqlStmt, nil
}
@ -173,9 +201,9 @@ func withComplexArgs(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") < 28) AND (("products"."id") >= 20)) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "sel_json_agg_0") AS "done_1337";`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") < 28) AND (("products"."id") >= 20)) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -201,9 +229,9 @@ func withWhereMultiOr(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") < 20) OR (("products"."price") > 10) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") < 20) OR (("products"."price") > 10) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -227,9 +255,9 @@ func withWhereIsNull(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -253,9 +281,9 @@ func withWhereAndList(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -273,9 +301,9 @@ func fetchByID(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 15)) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";`
sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 15)) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -293,9 +321,9 @@ func searchQuery(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -316,9 +344,9 @@ func oneToMany(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email", "users"."id" FROM "users" LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email", "users"."id" FROM "users" LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -339,9 +367,9 @@ func belongsTo(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "sel_json_1" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "sel_json_1" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -362,9 +390,9 @@ func manyToMany(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -385,9 +413,9 @@ func manyToManyReverse(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -405,9 +433,49 @@ func aggFunction(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
if string(resSQL) != sql {
t.Fatal(errNotExpected)
}
}
func aggFunctionBlockedByCol(t *testing.T) {
gql := `query {
products {
name
count_price
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil, "anon")
if err != nil {
t.Fatal(err)
}
if string(resSQL) != sql {
t.Fatal(errNotExpected)
}
}
func aggFunctionDisabled(t *testing.T) {
gql := `query {
products {
name
count_price
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil, "anon1")
if err != nil {
t.Fatal(err)
}
@ -425,9 +493,9 @@ func aggFunctionWithFilter(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") > 10)) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") > 10)) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -445,9 +513,9 @@ func queryWithVariables(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") = {{product_price}}) AND (("products"."id") = {{product_id}})) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";`
sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") = {{product_price}}) AND (("products"."id") = {{product_id}})) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -464,9 +532,9 @@ func syntheticTables(t *testing.T) {
}
}`
sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337";`
sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT ) AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil)
resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
@ -476,7 +544,7 @@ func syntheticTables(t *testing.T) {
}
}
func TestCompileSelect(t *testing.T) {
func TestCompileQuery(t *testing.T) {
t.Run("withComplexArgs", withComplexArgs)
t.Run("withWhereAndList", withWhereAndList)
t.Run("withWhereIsNull", withWhereIsNull)
@ -488,10 +556,11 @@ func TestCompileSelect(t *testing.T) {
t.Run("manyToMany", manyToMany)
t.Run("manyToManyReverse", manyToManyReverse)
t.Run("aggFunction", aggFunction)
t.Run("aggFunctionBlockedByCol", aggFunctionBlockedByCol)
t.Run("aggFunctionDisabled", aggFunctionDisabled)
t.Run("aggFunctionWithFilter", aggFunctionWithFilter)
t.Run("syntheticTables", syntheticTables)
t.Run("queryWithVariables", queryWithVariables)
}
var benchGQL = []byte(`query {
@ -526,7 +595,7 @@ func BenchmarkCompile(b *testing.B) {
for n := 0; n < b.N; n++ {
w.Reset()
qc, err := qcompile.Compile(benchGQL)
qc, err := qcompile.Compile(benchGQL, "user")
if err != nil {
b.Fatal(err)
}
@ -547,7 +616,7 @@ func BenchmarkCompileParallel(b *testing.B) {
for pb.Next() {
w.Reset()
qc, err := qcompile.Compile(benchGQL)
qc, err := qcompile.Compile(benchGQL, "user")
if err != nil {
b.Fatal(err)
}

99
qcode/config.go Normal file
View File

@ -0,0 +1,99 @@
package qcode
type Config struct {
Blocklist []string
KeepArgs bool
}
type QueryConfig struct {
Limit int
Filters []string
Columns []string
DisableFunctions bool
}
type InsertConfig struct {
Filters []string
Columns []string
Set map[string]string
}
type UpdateConfig struct {
Filters []string
Columns []string
Set map[string]string
}
type DeleteConfig struct {
Filters []string
Columns []string
}
type TRConfig struct {
Query QueryConfig
Insert InsertConfig
Update UpdateConfig
Delete DeleteConfig
}
type trval struct {
query struct {
limit string
fil *Exp
cols map[string]struct{}
disable struct {
funcs bool
}
}
insert struct {
fil *Exp
cols map[string]struct{}
set map[string]string
}
update struct {
fil *Exp
cols map[string]struct{}
set map[string]string
}
delete struct {
fil *Exp
cols map[string]struct{}
}
}
func (trv *trval) allowedColumns(qt QType) map[string]struct{} {
switch qt {
case QTQuery:
return trv.query.cols
case QTInsert:
return trv.insert.cols
case QTUpdate:
return trv.update.cols
case QTDelete:
return trv.insert.cols
case QTUpsert:
return trv.insert.cols
}
return nil
}
func (trv *trval) filter(qt QType) *Exp {
switch qt {
case QTQuery:
return trv.query.fil
case QTInsert:
return trv.insert.fil
case QTUpdate:
return trv.update.fil
case QTDelete:
return trv.delete.fil
case QTUpsert:
return trv.insert.fil
}
return nil
}

View File

@ -5,7 +5,7 @@ func FuzzerEntrypoint(data []byte) int {
//testData := string(data)
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile(data)
_, err := qcompile.Compile(data, "user")
if err != nil {
return -1
}

View File

@ -18,7 +18,9 @@ type parserType int32
const (
maxFields = 100
maxArgs = 10
)
const (
parserError parserType = iota
parserEOF
opQuery

View File

@ -46,13 +46,18 @@ func compareOp(op1, op2 Operation) error {
*/
func TestCompile1(t *testing.T) {
qcompile, _ := NewCompiler(Config{})
qc, _ := NewCompiler(Config{})
qc.AddRole("user", "product", TRConfig{
Query: QueryConfig{
Columns: []string{"id", "Name"},
},
})
_, err := qcompile.Compile([]byte(`
_, err := qc.Compile([]byte(`
product(id: 15) {
id
name
}`))
}`), "user")
if err != nil {
t.Fatal(err)
@ -60,13 +65,18 @@ func TestCompile1(t *testing.T) {
}
func TestCompile2(t *testing.T) {
qcompile, _ := NewCompiler(Config{})
qc, _ := NewCompiler(Config{})
qc.AddRole("user", "product", TRConfig{
Query: QueryConfig{
Columns: []string{"ID"},
},
})
_, err := qcompile.Compile([]byte(`
_, err := qc.Compile([]byte(`
query { product(id: 15) {
id
name
} }`))
} }`), "user")
if err != nil {
t.Fatal(err)
@ -74,15 +84,20 @@ func TestCompile2(t *testing.T) {
}
func TestCompile3(t *testing.T) {
qcompile, _ := NewCompiler(Config{})
qc, _ := NewCompiler(Config{})
qc.AddRole("user", "product", TRConfig{
Query: QueryConfig{
Columns: []string{"ID"},
},
})
_, err := qcompile.Compile([]byte(`
_, err := qc.Compile([]byte(`
mutation {
product(id: 15, name: "Test") {
id
name
}
}`))
}`), "user")
if err != nil {
t.Fatal(err)
@ -91,7 +106,7 @@ func TestCompile3(t *testing.T) {
func TestInvalidCompile1(t *testing.T) {
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile([]byte(`#`))
_, err := qcompile.Compile([]byte(`#`), "user")
if err == nil {
t.Fatal(errors.New("expecting an error"))
@ -100,7 +115,7 @@ func TestInvalidCompile1(t *testing.T) {
func TestInvalidCompile2(t *testing.T) {
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile([]byte(`{u(where:{not:0})}`))
_, err := qcompile.Compile([]byte(`{u(where:{not:0})}`), "user")
if err == nil {
t.Fatal(errors.New("expecting an error"))
@ -109,7 +124,7 @@ func TestInvalidCompile2(t *testing.T) {
func TestEmptyCompile(t *testing.T) {
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile([]byte(``))
_, err := qcompile.Compile([]byte(``), "user")
if err == nil {
t.Fatal(errors.New("expecting an error"))
@ -144,7 +159,7 @@ func BenchmarkQCompile(b *testing.B) {
b.ReportAllocs()
for n := 0; n < b.N; n++ {
_, err := qcompile.Compile(gql)
_, err := qcompile.Compile(gql, "user")
if err != nil {
b.Fatal(err)
@ -160,7 +175,7 @@ func BenchmarkQCompileP(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := qcompile.Compile(gql)
_, err := qcompile.Compile(gql, "user")
if err != nil {
b.Fatal(err)

View File

@ -3,6 +3,7 @@ package qcode
import (
"errors"
"fmt"
"strconv"
"strings"
"sync"
@ -15,25 +16,20 @@ type Action int
const (
maxSelectors = 30
)
const (
QTQuery QType = iota + 1
QTMutation
ActionInsert Action = iota + 1
ActionUpdate
ActionDelete
ActionUpsert
QTInsert
QTUpdate
QTDelete
QTUpsert
)
type QCode struct {
Type QType
Selects []Select
}
type Column struct {
Table string
Name string
FieldName string
Type QType
ActionVar string
Selects []Select
}
type Select struct {
@ -47,9 +43,15 @@ type Select struct {
OrderBy []*OrderBy
DistinctOn []string
Paging Paging
Action Action
ActionVar string
Children []int32
Functions bool
Allowed map[string]struct{}
}
type Column struct {
Table string
Name string
FieldName string
}
type Exp struct {
@ -77,8 +79,9 @@ type OrderBy struct {
}
type Paging struct {
Limit string
Offset string
Limit string
Offset string
NoLimit bool
}
type ExpOp int
@ -145,81 +148,23 @@ const (
OrderDescNullsLast
)
type Filters struct {
All map[string][]string
Query map[string][]string
Insert map[string][]string
Update map[string][]string
Delete map[string][]string
}
type Config struct {
DefaultFilter []string
FilterMap Filters
Blocklist []string
KeepArgs bool
}
type Compiler struct {
df *Exp
fm struct {
all map[string]*Exp
query map[string]*Exp
insert map[string]*Exp
update map[string]*Exp
delete map[string]*Exp
}
tr map[string]map[string]*trval
bl map[string]struct{}
ka bool
}
var opMap = map[parserType]QType{
opQuery: QTQuery,
opMutate: QTMutation,
}
var expPool = sync.Pool{
New: func() interface{} { return &Exp{doFree: true} },
}
func NewCompiler(c Config) (*Compiler, error) {
var err error
co := &Compiler{ka: c.KeepArgs}
co.tr = make(map[string]map[string]*trval)
co.bl = make(map[string]struct{}, len(c.Blocklist))
for i := range c.Blocklist {
co.bl[c.Blocklist[i]] = struct{}{}
}
co.df, err = compileFilter(c.DefaultFilter)
if err != nil {
return nil, err
}
co.fm.all, err = buildFilters(c.FilterMap.All)
if err != nil {
return nil, err
}
co.fm.query, err = buildFilters(c.FilterMap.Query)
if err != nil {
return nil, err
}
co.fm.insert, err = buildFilters(c.FilterMap.Insert)
if err != nil {
return nil, err
}
co.fm.update, err = buildFilters(c.FilterMap.Update)
if err != nil {
return nil, err
}
co.fm.delete, err = buildFilters(c.FilterMap.Delete)
if err != nil {
return nil, err
co.bl[strings.ToLower(c.Blocklist[i])] = struct{}{}
}
seedExp := [100]Exp{}
@ -232,58 +177,99 @@ func NewCompiler(c Config) (*Compiler, error) {
return co, nil
}
func buildFilters(filMap map[string][]string) (map[string]*Exp, error) {
fm := make(map[string]*Exp, len(filMap))
func (com *Compiler) AddRole(role, table string, trc TRConfig) error {
var err error
trv := &trval{}
for k, v := range filMap {
fil, err := compileFilter(v)
if err != nil {
return nil, err
toMap := func(cols []string) map[string]struct{} {
m := make(map[string]struct{}, len(cols))
for i := range cols {
m[strings.ToLower(cols[i])] = struct{}{}
}
singular := flect.Singularize(k)
plural := flect.Pluralize(k)
fm[singular] = fil
fm[plural] = fil
return m
}
return fm, nil
// query config
trv.query.fil, err = compileFilter(trc.Query.Filters)
if err != nil {
return err
}
if trc.Query.Limit > 0 {
trv.query.limit = strconv.Itoa(trc.Query.Limit)
}
trv.query.cols = toMap(trc.Query.Columns)
trv.query.disable.funcs = trc.Query.DisableFunctions
// insert config
if trv.insert.fil, err = compileFilter(trc.Insert.Filters); err != nil {
return err
}
trv.insert.cols = toMap(trc.Insert.Columns)
// update config
if trv.update.fil, err = compileFilter(trc.Update.Filters); err != nil {
return err
}
trv.insert.cols = toMap(trc.Insert.Columns)
trv.insert.set = trc.Insert.Set
// delete config
if trv.delete.fil, err = compileFilter(trc.Delete.Filters); err != nil {
return err
}
trv.delete.cols = toMap(trc.Delete.Columns)
singular := flect.Singularize(table)
plural := flect.Pluralize(table)
if _, ok := com.tr[role]; !ok {
com.tr[role] = make(map[string]*trval)
}
com.tr[role][singular] = trv
com.tr[role][plural] = trv
return nil
}
func (com *Compiler) Compile(query []byte) (*QCode, error) {
var qc QCode
func (com *Compiler) Compile(query []byte, role string) (*QCode, error) {
var err error
qc := QCode{Type: QTQuery}
op, err := Parse(query)
if err != nil {
return nil, err
}
qc.Selects, err = com.compileQuery(op)
if err != nil {
if err = com.compileQuery(&qc, op, role); err != nil {
return nil, err
}
if t, ok := opMap[op.Type]; ok {
qc.Type = t
} else {
return nil, fmt.Errorf("Unknown operation type %d", op.Type)
}
opPool.Put(op)
return &qc, nil
}
func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error {
id := int32(0)
parentID := int32(0)
if len(op.Fields) == 0 {
return errors.New("invalid graphql no query found")
}
if op.Type == opMutate {
if err := com.setMutationType(qc, op.Fields[0].Args); err != nil {
return err
}
}
selects := make([]Select, 0, 5)
st := NewStack()
action := qc.Type
if len(op.Fields) == 0 {
return nil, errors.New("empty query")
return errors.New("empty query")
}
st.Push(op.Fields[0].ID)
@ -293,7 +279,7 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
}
if id >= maxSelectors {
return nil, fmt.Errorf("selector limit reached (%d)", maxSelectors)
return fmt.Errorf("selector limit reached (%d)", maxSelectors)
}
fid := st.Pop()
@ -303,14 +289,25 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
continue
}
trv := com.getRole(role, field.Name)
selects = append(selects, Select{
ID: id,
ParentID: parentID,
Table: field.Name,
Children: make([]int32, 0, 5),
Allowed: trv.allowedColumns(action),
})
s := &selects[(len(selects) - 1)]
if action == QTQuery {
s.Functions = !trv.query.disable.funcs
if len(trv.query.limit) != 0 {
s.Paging.Limit = trv.query.limit
}
}
if s.ID != 0 {
p := &selects[s.ParentID]
p.Children = append(p.Children, s.ID)
@ -322,12 +319,13 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
s.FieldName = s.Table
}
err := com.compileArgs(s, field.Args)
err := com.compileArgs(qc, s, field.Args)
if err != nil {
return nil, err
return err
}
s.Cols = make([]Column, 0, len(field.Children))
action = QTQuery
for _, cid := range field.Children {
f := op.Fields[cid]
@ -356,36 +354,14 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
}
if id == 0 {
return nil, errors.New("invalid query")
return errors.New("invalid query")
}
var fil *Exp
root := &selects[0]
switch op.Type {
case opQuery:
fil, _ = com.fm.query[root.Table]
case opMutate:
switch root.Action {
case ActionInsert:
fil, _ = com.fm.insert[root.Table]
case ActionUpdate:
fil, _ = com.fm.update[root.Table]
case ActionDelete:
fil, _ = com.fm.delete[root.Table]
case ActionUpsert:
fil, _ = com.fm.insert[root.Table]
}
}
if fil == nil {
fil, _ = com.fm.all[root.Table]
}
if fil == nil {
fil = com.df
if trv, ok := com.tr[role][op.Fields[0].Name]; ok {
fil = trv.filter(qc.Type)
}
if fil != nil && fil.Op != OpNop {
@ -403,10 +379,11 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
}
}
return selects[:id], nil
qc.Selects = selects[:id]
return nil
}
func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg) error {
var err error
if com.ka {
@ -418,9 +395,7 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
switch arg.Name {
case "id":
if sel.ID == 0 {
err = com.compileArgID(sel, arg)
}
err = com.compileArgID(sel, arg)
case "search":
err = com.compileArgSearch(sel, arg)
case "where":
@ -433,18 +408,6 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
err = com.compileArgLimit(sel, arg)
case "offset":
err = com.compileArgOffset(sel, arg)
case "insert":
sel.Action = ActionInsert
err = com.compileArgAction(sel, arg)
case "update":
sel.Action = ActionUpdate
err = com.compileArgAction(sel, arg)
case "upsert":
sel.Action = ActionUpsert
err = com.compileArgAction(sel, arg)
case "delete":
sel.Action = ActionDelete
err = com.compileArgAction(sel, arg)
}
if err != nil {
@ -461,6 +424,45 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
return nil
}
func (com *Compiler) setMutationType(qc *QCode, args []Arg) error {
setActionVar := func(arg *Arg) error {
if arg.Val.Type != nodeVar {
return fmt.Errorf("value for argument '%s' must be a variable", arg.Name)
}
qc.ActionVar = arg.Val.Val
return nil
}
for i := range args {
arg := &args[i]
switch arg.Name {
case "insert":
qc.Type = QTInsert
return setActionVar(arg)
case "update":
qc.Type = QTUpdate
return setActionVar(arg)
case "upsert":
qc.Type = QTUpsert
return setActionVar(arg)
case "delete":
qc.Type = QTDelete
if arg.Val.Type != nodeBool {
return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name)
}
if arg.Val.Val == "false" {
qc.Type = QTQuery
}
return nil
}
}
return nil
}
func (com *Compiler) compileArgObj(st *util.Stack, arg *Arg) (*Exp, error) {
if arg.Val.Type != nodeObj {
return nil, fmt.Errorf("expecting an object")
@ -540,6 +542,10 @@ func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (*
}
func (com *Compiler) compileArgID(sel *Select, arg *Arg) error {
if sel.ID != 0 {
return nil
}
if sel.Where != nil && sel.Where.Op == OpEqID {
return nil
}
@ -732,24 +738,14 @@ func (com *Compiler) compileArgOffset(sel *Select, arg *Arg) error {
return nil
}
func (com *Compiler) compileArgAction(sel *Select, arg *Arg) error {
switch sel.Action {
case ActionDelete:
if arg.Val.Type != nodeBool {
return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name)
}
if arg.Val.Val == "false" {
sel.Action = 0
}
var zeroTrv = &trval{}
default:
if arg.Val.Type != nodeVar {
return fmt.Errorf("value for argument '%s' must be a variable", arg.Name)
}
sel.ActionVar = arg.Val.Val
func (com *Compiler) getRole(role, field string) *trval {
if trv, ok := com.tr[role][field]; ok {
return trv
} else {
return zeroTrv
}
return nil
}
func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {

View File

@ -26,7 +26,8 @@ type allowItem struct {
var _allowList allowList
type allowList struct {
list map[string]*allowItem
list []*allowItem
index map[string]int
filepath string
saveChan chan *allowItem
active bool
@ -34,7 +35,7 @@ type allowList struct {
func initAllowList(cpath string) {
_allowList = allowList{
list: make(map[string]*allowItem),
index: make(map[string]int),
saveChan: make(chan *allowItem),
active: true,
}
@ -172,17 +173,21 @@ func (al *allowList) load() {
if c == 0 {
if ty == AL_QUERY {
q := string(b[s:(e + 1)])
key := gqlHash(q, varBytes, "")
item := &allowItem{
uri: uri,
gql: q,
}
if len(varBytes) != 0 {
if idx, ok := al.index[key]; !ok {
al.list = append(al.list, &allowItem{
uri: uri,
gql: q,
vars: varBytes,
})
al.index[key] = len(al.list) - 1
} else {
item := al.list[idx]
item.gql = q
item.vars = varBytes
}
al.list[gqlHash(q, varBytes)] = item
varBytes = nil
} else if ty == AL_VARS {
@ -203,7 +208,15 @@ func (al *allowList) save(item *allowItem) {
if al.active == false {
return
}
al.list[gqlHash(item.gql, item.vars)] = item
key := gqlHash(item.gql, item.vars, "")
if idx, ok := al.index[key]; ok {
al.list[idx] = item
} else {
al.list = append(al.list, item)
al.index[key] = len(al.list) - 1
}
f, err := os.Create(al.filepath)
if err != nil {

View File

@ -7,28 +7,42 @@ import (
)
var (
userIDProviderKey = struct{}{}
userIDKey = struct{}{}
userIDProviderKey = "user_id_provider"
userIDKey = "user_id"
userRoleKey = "user_role"
)
func headerAuth(r *http.Request, c *config) *http.Request {
if len(c.Auth.Header) == 0 {
return nil
}
func headerAuth(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
userID := r.Header.Get(c.Auth.Header)
if len(userID) != 0 {
ctx := context.WithValue(r.Context(), userIDKey, userID)
return r.WithContext(ctx)
}
userIDProvider := r.Header.Get("X-User-ID-Provider")
if len(userIDProvider) != 0 {
ctx = context.WithValue(ctx, userIDProviderKey, userIDProvider)
}
return nil
userID := r.Header.Get("X-User-ID")
if len(userID) != 0 {
ctx = context.WithValue(ctx, userIDKey, userID)
}
userRole := r.Header.Get("X-User-Role")
if len(userRole) != 0 {
ctx = context.WithValue(ctx, userRoleKey, userRole)
}
next.ServeHTTP(w, r.WithContext(ctx))
}
}
func withAuth(next http.HandlerFunc) http.HandlerFunc {
at := conf.Auth.Type
ru := conf.Auth.Rails.URL
if conf.Auth.CredsInHeader {
next = headerAuth(next)
}
switch at {
case "rails":
if strings.HasPrefix(ru, "memcache:") {

View File

@ -58,11 +58,6 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
var tok string
if rn := headerAuth(r, conf); rn != nil {
next.ServeHTTP(w, rn)
return
}
if len(cookie) != 0 {
ck, err := r.Cookie(cookie)
if err != nil {
@ -102,7 +97,6 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
}
next.ServeHTTP(w, r.WithContext(ctx))
}
next.ServeHTTP(w, r)
}
}

View File

@ -42,11 +42,6 @@ func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc {
}
return func(w http.ResponseWriter, r *http.Request) {
if rn := headerAuth(r, conf); rn != nil {
next.ServeHTTP(w, rn)
return
}
ck, err := r.Cookie(cookie)
if err != nil {
next.ServeHTTP(w, r)
@ -83,17 +78,12 @@ func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc {
rURL, err := url.Parse(conf.Auth.Rails.URL)
if err != nil {
logger.Fatal().Err(err)
logger.Fatal().Err(err).Send()
}
mc := memcache.New(rURL.Host)
return func(w http.ResponseWriter, r *http.Request) {
if rn := headerAuth(r, conf); rn != nil {
next.ServeHTTP(w, rn)
return
}
ck, err := r.Cookie(cookie)
if err != nil {
next.ServeHTTP(w, r)
@ -126,25 +116,20 @@ func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc {
ra, err := railsAuth(conf)
if err != nil {
logger.Fatal().Err(err)
logger.Fatal().Err(err).Send()
}
return func(w http.ResponseWriter, r *http.Request) {
if rn := headerAuth(r, conf); rn != nil {
next.ServeHTTP(w, rn)
return
}
ck, err := r.Cookie(cookie)
if err != nil {
logger.Error().Err(err)
logger.Warn().Err(err).Msg("rails cookie missing")
next.ServeHTTP(w, r)
return
}
userID, err := ra.ParseCookie(ck.Value)
if err != nil {
logger.Error().Err(err)
logger.Warn().Err(err).Msg("failed to parse rails cookie")
next.ServeHTTP(w, r)
return
}

View File

@ -10,7 +10,6 @@ import (
"github.com/dosco/super-graph/qcode"
"github.com/gobuffalo/flect"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/log/zerologadapter"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/rs/zerolog"
"github.com/spf13/cobra"
@ -184,7 +183,34 @@ func initConf() (*config, error) {
}
zerolog.SetGlobalLevel(logLevel)
//fmt.Printf("%#v", c)
for k, v := range c.DB.Vars {
c.DB.Vars[k] = sanitize(v)
}
c.RolesQuery = sanitize(c.RolesQuery)
rolesMap := make(map[string]struct{})
for i := range c.Roles {
role := &c.Roles[i]
if _, ok := rolesMap[role.Name]; ok {
logger.Fatal().Msgf("duplicate role '%s' found", role.Name)
}
role.Name = sanitize(role.Name)
role.Match = sanitize(role.Match)
rolesMap[role.Name] = struct{}{}
}
if _, ok := rolesMap["user"]; !ok {
c.Roles = append(c.Roles, configRole{Name: "user"})
}
if _, ok := rolesMap["anon"]; !ok {
c.Roles = append(c.Roles, configRole{Name: "anon"})
}
c.Validate()
return c, nil
}
@ -217,7 +243,7 @@ func initDB(c *config, useDB bool) (*pgx.Conn, error) {
config.LogLevel = pgx.LogLevelNone
}
config.Logger = zerologadapter.NewLogger(*logger)
config.Logger = NewSQLLogger(*logger)
db, err := pgx.ConnectConfig(context.Background(), config)
if err != nil {
@ -252,7 +278,7 @@ func initDBPool(c *config) (*pgxpool.Pool, error) {
config.ConnConfig.LogLevel = pgx.LogLevelNone
}
config.ConnConfig.Logger = zerologadapter.NewLogger(*logger)
config.ConnConfig.Logger = NewSQLLogger(*logger)
// if c.DB.MaxRetries != 0 {
// opt.MaxRetries = c.DB.MaxRetries

View File

@ -66,6 +66,7 @@ func graphQLFunc(query string, data interface{}) map[string]interface{} {
c := &coreContext{Context: context.Background()}
c.req.Query = query
c.req.Vars = b
c.req.role = "user"
res, err := c.execQuery()
if err != nil {

View File

@ -1,7 +1,9 @@
package serv
import (
"regexp"
"strings"
"unicode"
"github.com/spf13/viper"
)
@ -24,9 +26,9 @@ type config struct {
Inflections map[string]string
Auth struct {
Type string
Cookie string
Header string
Type string
Cookie string
CredsInHeader bool `mapstructure:"creds_in_header"`
Rails struct {
Version string
@ -60,10 +62,10 @@ type config struct {
MaxRetries int `mapstructure:"max_retries"`
LogLevel string `mapstructure:"log_level"`
vars map[string][]byte `mapstructure:"variables"`
Vars map[string]string `mapstructure:"variables"`
Defaults struct {
Filter []string
Filters []string
Blocklist []string
}
@ -71,18 +73,16 @@ type config struct {
} `mapstructure:"database"`
Tables []configTable
RolesQuery string `mapstructure:"roles_query"`
Roles []configRole
}
type configTable struct {
Name string
Filter []string
FilterQuery []string `mapstructure:"filter_query"`
FilterInsert []string `mapstructure:"filter_insert"`
FilterUpdate []string `mapstructure:"filter_update"`
FilterDelete []string `mapstructure:"filter_delete"`
Table string
Blocklist []string
Remotes []configRemote
Name string
Table string
Blocklist []string
Remotes []configRemote
}
type configRemote struct {
@ -98,6 +98,42 @@ type configRemote struct {
} `mapstructure:"set_headers"`
}
type configRole struct {
Name string
Match string
Tables []struct {
Name string
Query struct {
Limit int
Filters []string
Columns []string
DisableAggregation bool `mapstructure:"disable_aggregation"`
Deny bool
}
Insert struct {
Filters []string
Columns []string
Set map[string]string
Deny bool
}
Update struct {
Filters []string
Columns []string
Set map[string]string
Deny bool
}
Delete struct {
Filters []string
Columns []string
Deny bool
}
}
}
func newConfig() *viper.Viper {
vi := viper.New()
@ -132,24 +168,30 @@ func newConfig() *viper.Viper {
return vi
}
func (c *config) getVariables() map[string]string {
vars := make(map[string]string, len(c.DB.vars))
func (c *config) Validate() {
rm := make(map[string]struct{})
for k, v := range c.DB.vars {
isVar := false
for i := range v {
if v[i] == '$' {
isVar = true
} else if v[i] == ' ' {
isVar = false
} else if isVar && v[i] >= 'a' && v[i] <= 'z' {
v[i] = 'A' + (v[i] - 'a')
}
for i := range c.Roles {
name := strings.ToLower(c.Roles[i].Name)
if _, ok := rm[name]; ok {
logger.Fatal().Msgf("duplicate config for role '%s'", c.Roles[i].Name)
}
vars[k] = string(v)
rm[name] = struct{}{}
}
tm := make(map[string]struct{})
for i := range c.Tables {
name := strings.ToLower(c.Tables[i].Name)
if _, ok := tm[name]; ok {
logger.Fatal().Msgf("duplicate config for table '%s'", c.Tables[i].Name)
}
tm[name] = struct{}{}
}
if len(c.RolesQuery) == 0 {
logger.Warn().Msgf("no 'roles_query' defined.")
}
return vars
}
func (c *config) getAliasMap() map[string][]string {
@ -167,3 +209,21 @@ func (c *config) getAliasMap() map[string][]string {
}
return m
}
var varRe1 = regexp.MustCompile(`(?mi)\$([a-zA-Z0-9_.]+)`)
var varRe2 = regexp.MustCompile(`\{\{([a-zA-Z0-9_.]+)\}\}`)
func sanitize(s string) string {
s0 := varRe1.ReplaceAllString(s, `{{$1}}`)
s1 := strings.Map(func(r rune) rune {
if unicode.IsSpace(r) {
return ' '
}
return r
}, s0)
return varRe2.ReplaceAllStringFunc(s1, func(m string) string {
return strings.ToLower(m)
})
}

View File

@ -13,8 +13,8 @@ import (
"github.com/cespare/xxhash/v2"
"github.com/dosco/super-graph/jsn"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode"
"github.com/jackc/pgx/v4"
"github.com/valyala/fasttemplate"
)
@ -32,6 +32,12 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
c.req.ref = req.Referer()
c.req.hdr = req.Header
if authCheck(c) {
c.req.role = "user"
} else {
c.req.role = "anon"
}
b, err := c.execQuery()
if err != nil {
return err
@ -46,10 +52,12 @@ func (c *coreContext) execQuery() ([]byte, error) {
var qc *qcode.QCode
var data []byte
logger.Debug().Str("role", c.req.role).Msg(c.req.Query)
if conf.UseAllowList {
var ps *preparedItem
data, ps, err = c.resolvePreparedSQL(c.req.Query)
data, ps, err = c.resolvePreparedSQL()
if err != nil {
return nil, err
}
@ -59,12 +67,7 @@ func (c *coreContext) execQuery() ([]byte, error) {
} else {
qc, err = qcompile.Compile([]byte(c.req.Query))
if err != nil {
return nil, err
}
data, skipped, err = c.resolveSQL(qc)
data, skipped, err = c.resolveSQL()
if err != nil {
return nil, err
}
@ -112,6 +115,160 @@ func (c *coreContext) execQuery() ([]byte, error) {
return ob.Bytes(), nil
}
func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
tx, err := db.Begin(c)
if err != nil {
return nil, nil, err
}
defer tx.Rollback(c)
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
if err != nil {
return nil, nil, err
}
}
var role string
mutation := isMutation(c.req.Query)
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
if useRoleQuery {
if role, err = c.executeRoleQuery(tx); err != nil {
return nil, nil, err
}
} else if v := c.Value(userRoleKey); v != nil {
role = v.(string)
} else if mutation {
role = c.req.role
}
ps, ok := _preparedList[gqlHash(c.req.Query, c.req.Vars, role)]
if !ok {
return nil, nil, errUnauthorized
}
var root []byte
vars := varList(c, ps.args)
if mutation {
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
} else {
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&c.req.role, &root)
}
if err != nil {
return nil, nil, err
}
if err := tx.Commit(c); err != nil {
return nil, nil, err
}
return root, ps, nil
}
func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
tx, err := db.Begin(c)
if err != nil {
return nil, 0, err
}
defer tx.Rollback(c)
mutation := isMutation(c.req.Query)
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
if useRoleQuery {
if c.req.role, err = c.executeRoleQuery(tx); err != nil {
return nil, 0, err
}
} else if v := c.Value(userRoleKey); v != nil {
c.req.role = v.(string)
}
stmts, err := c.buildStmt()
if err != nil {
return nil, 0, err
}
var st *stmt
if mutation {
st = findStmt(c.req.role, stmts)
} else {
st = &stmts[0]
}
t := fasttemplate.New(st.sql, openVar, closeVar)
buf := &bytes.Buffer{}
_, err = t.ExecuteFunc(buf, varMap(c))
if err == errNoUserID &&
authFailBlock == authFailBlockPerQuery &&
authCheck(c) == false {
return nil, 0, errUnauthorized
}
if err != nil {
return nil, 0, err
}
finalSQL := buf.String()
var stime time.Time
if conf.EnableTracing {
stime = time.Now()
}
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
if err != nil {
return nil, 0, err
}
}
var root []byte
if mutation {
err = tx.QueryRow(c, finalSQL).Scan(&root)
} else {
err = tx.QueryRow(c, finalSQL).Scan(&c.req.role, &root)
}
if err != nil {
return nil, 0, err
}
if err := tx.Commit(c); err != nil {
return nil, 0, err
}
if mutation {
st = findStmt(c.req.role, stmts)
} else {
st = &stmts[0]
}
if conf.EnableTracing && len(st.qc.Selects) != 0 {
c.addTrace(
st.qc.Selects,
st.qc.Selects[0].ID,
stime)
}
if conf.UseAllowList == false {
_allowList.add(&c.req)
}
return root, st.skipped, nil
}
func (c *coreContext) resolveRemote(
hdr http.Header,
h *xxhash.Digest,
@ -259,125 +416,15 @@ func (c *coreContext) resolveRemotes(
return to, cerr
}
func (c *coreContext) resolvePreparedSQL(gql string) ([]byte, *preparedItem, error) {
ps, ok := _preparedList[gqlHash(gql, c.req.Vars)]
if !ok {
return nil, nil, errUnauthorized
func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
var role string
row := tx.QueryRow(c, "_sg_get_role", c.req.role, 1)
if err := row.Scan(&role); err != nil {
return "", err
}
var root []byte
vars := varList(c, ps.args)
tx, err := db.Begin(c)
if err != nil {
return nil, nil, err
}
defer tx.Rollback(c)
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
if err != nil {
return nil, nil, err
}
}
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
if err != nil {
return nil, nil, err
}
if err := tx.Commit(c); err != nil {
return nil, nil, err
}
fmt.Printf("PRE: %v\n", ps.stmt)
return root, ps, nil
}
func (c *coreContext) resolveSQL(qc *qcode.QCode) ([]byte, uint32, error) {
var vars map[string]json.RawMessage
stmt := &bytes.Buffer{}
if len(c.req.Vars) != 0 {
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
return nil, 0, err
}
}
skipped, err := pcompile.Compile(qc, stmt, psql.Variables(vars))
if err != nil {
return nil, 0, err
}
t := fasttemplate.New(stmt.String(), openVar, closeVar)
stmt.Reset()
_, err = t.ExecuteFunc(stmt, varMap(c))
if err == errNoUserID &&
authFailBlock == authFailBlockPerQuery &&
authCheck(c) == false {
return nil, 0, errUnauthorized
}
if err != nil {
return nil, 0, err
}
finalSQL := stmt.String()
// if conf.LogLevel == "debug" {
// os.Stdout.WriteString(finalSQL)
// os.Stdout.WriteString("\n\n")
// }
var st time.Time
if conf.EnableTracing {
st = time.Now()
}
tx, err := db.Begin(c)
if err != nil {
return nil, 0, err
}
defer tx.Rollback(c)
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
if err != nil {
return nil, 0, err
}
}
//fmt.Printf("\nRAW: %#v\n", finalSQL)
var root []byte
err = tx.QueryRow(c, finalSQL).Scan(&root)
if err != nil {
return nil, 0, err
}
if err := tx.Commit(c); err != nil {
return nil, 0, err
}
if conf.EnableTracing && len(qc.Selects) != 0 {
c.addTrace(
qc.Selects,
qc.Selects[0].ID,
st)
}
if conf.UseAllowList == false {
_allowList.add(&c.req)
}
return root, skipped, nil
return role, nil
}
func (c *coreContext) render(w io.Writer, data []byte) error {

144
serv/core_build.go Normal file
View File

@ -0,0 +1,144 @@
package serv
import (
"bytes"
"encoding/json"
"errors"
"io"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode"
)
type stmt struct {
role *configRole
qc *qcode.QCode
skipped uint32
sql string
}
func (c *coreContext) buildStmt() ([]stmt, error) {
var vars map[string]json.RawMessage
if len(c.req.Vars) != 0 {
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
return nil, err
}
}
gql := []byte(c.req.Query)
if len(conf.Roles) == 0 {
return nil, errors.New(`no roles found ('user' and 'anon' required)`)
}
qc, err := qcompile.Compile(gql, conf.Roles[0].Name)
if err != nil {
return nil, err
}
stmts := make([]stmt, 0, len(conf.Roles))
mutation := (qc.Type != qcode.QTQuery)
w := &bytes.Buffer{}
for i := range conf.Roles {
role := &conf.Roles[i]
if mutation && len(c.req.role) != 0 && role.Name != c.req.role {
continue
}
if i > 0 {
qc, err = qcompile.Compile(gql, role.Name)
if err != nil {
return nil, err
}
}
stmts = append(stmts, stmt{role: role, qc: qc})
if mutation {
skipped, err := pcompile.Compile(qc, w, psql.Variables(vars))
if err != nil {
return nil, err
}
s := &stmts[len(stmts)-1]
s.skipped = skipped
s.sql = w.String()
w.Reset()
}
}
if mutation {
return stmts, nil
}
io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
for _, s := range stmts {
io.WriteString(w, `WHEN '`)
io.WriteString(w, s.role.Name)
io.WriteString(w, `' THEN (`)
s.skipped, err = pcompile.Compile(s.qc, w, psql.Variables(vars))
if err != nil {
return nil, err
}
io.WriteString(w, `) `)
}
io.WriteString(w, `END) FROM (`)
if len(conf.RolesQuery) == 0 {
v := c.Value(userRoleKey)
io.WriteString(w, `VALUES ("`)
if v != nil {
io.WriteString(w, v.(string))
} else {
io.WriteString(w, c.req.role)
}
io.WriteString(w, `")) AS "_sg_auth_info"(role) LIMIT 1;`)
} else {
io.WriteString(w, `SELECT (CASE WHEN EXISTS (`)
io.WriteString(w, conf.RolesQuery)
io.WriteString(w, `) THEN `)
io.WriteString(w, `(SELECT (CASE`)
for _, s := range stmts {
if len(s.role.Match) == 0 {
continue
}
io.WriteString(w, ` WHEN `)
io.WriteString(w, s.role.Match)
io.WriteString(w, ` THEN '`)
io.WriteString(w, s.role.Name)
io.WriteString(w, `'`)
}
if len(c.req.role) == 0 {
io.WriteString(w, ` ELSE 'anon' END) FROM (`)
} else {
io.WriteString(w, ` ELSE '`)
io.WriteString(w, c.req.role)
io.WriteString(w, `' END) FROM (`)
}
io.WriteString(w, conf.RolesQuery)
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) ELSE '`)
if len(c.req.role) == 0 {
io.WriteString(w, `anon`)
} else {
io.WriteString(w, c.req.role)
}
io.WriteString(w, `' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `)
}
stmts[0].sql = w.String()
stmts[0].role = nil
return stmts, nil
}

View File

@ -30,14 +30,15 @@ type gqlReq struct {
Query string `json:"query"`
Vars json.RawMessage `json:"variables"`
ref string
role string
hdr http.Header
}
type variables map[string]json.RawMessage
type gqlResp struct {
Error string `json:"error,omitempty"`
Data json.RawMessage `json:"data"`
Error string `json:"message,omitempty"`
Data json.RawMessage `json:"data,omitempty"`
Extensions *extensions `json:"extensions,omitempty"`
}
@ -94,55 +95,20 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
}
if strings.EqualFold(ctx.req.OpName, introspectionQuery) {
// dat, err := ioutil.ReadFile("test.schema")
// if err != nil {
// http.Error(w, err.Error(), http.StatusInternalServerError)
// return
// }
//w.Write(dat)
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{
"data": {
"__schema": {
"queryType": {
"name": "Query"
},
"mutationType": null,
"subscriptionType": null
}
},
"extensions":{
"tracing":{
"version":1,
"startTime":"2019-06-04T19:53:31.093Z",
"endTime":"2019-06-04T19:53:31.108Z",
"duration":15219720,
"execution": {
"resolvers": [{
"path": ["__schema"],
"parentType": "Query",
"fieldName": "__schema",
"returnType": "__Schema!",
"startOffset": 50950,
"duration": 17187
}]
}
}
}
}`))
introspect(w)
return
}
err = ctx.handleReq(w, r)
if err == errUnauthorized {
err := "Not authorized"
logger.Debug().Msg(err)
http.Error(w, err, 401)
w.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(w).Encode(gqlResp{Error: err.Error()})
return
}
if err != nil {
logger.Err(err).Msg("Failed to handle request")
logger.Err(err).Msg("failed to handle request")
errorResp(w, err)
}
}

36
serv/introsp.go Normal file
View File

@ -0,0 +1,36 @@
package serv
import "net/http"
func introspect(w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{
"data": {
"__schema": {
"queryType": {
"name": "Query"
},
"mutationType": null,
"subscriptionType": null
}
},
"extensions":{
"tracing":{
"version":1,
"startTime":"2019-06-04T19:53:31.093Z",
"endTime":"2019-06-04T19:53:31.108Z",
"duration":15219720,
"execution": {
"resolvers": [{
"path": ["__schema"],
"parentType": "Query",
"fieldName": "__schema",
"returnType": "__Schema!",
"startOffset": 50950,
"duration": 17187
}]
}
}
}
}`))
}

View File

@ -7,7 +7,6 @@ import (
"fmt"
"io"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode"
"github.com/jackc/pgconn"
"github.com/valyala/fasttemplate"
@ -27,55 +26,105 @@ var (
func initPreparedList() {
_preparedList = make(map[string]*preparedItem)
for k, v := range _allowList.list {
err := prepareStmt(k, v.gql, v.vars)
if err := prepareRoleStmt(); err != nil {
logger.Fatal().Err(err).Msg("failed to prepare get role statement")
}
for _, v := range _allowList.list {
err := prepareStmt(v.gql, v.vars)
if err != nil {
logger.Warn().Err(err).Send()
logger.Warn().Str("gql", v.gql).Err(err).Send()
}
}
}
func prepareStmt(key, gql string, varBytes json.RawMessage) error {
if len(gql) == 0 || len(key) == 0 {
func prepareStmt(gql string, varBytes json.RawMessage) error {
if len(gql) == 0 {
return nil
}
qc, err := qcompile.Compile([]byte(gql))
c := &coreContext{Context: context.Background()}
c.req.Query = gql
c.req.Vars = varBytes
stmts, err := c.buildStmt()
if err != nil {
return err
}
var vars map[string]json.RawMessage
if len(stmts) != 0 && stmts[0].qc.Type == qcode.QTQuery {
c.req.Vars = nil
}
if len(varBytes) != 0 {
vars = make(map[string]json.RawMessage)
for _, s := range stmts {
if len(s.sql) == 0 {
continue
}
if err := json.Unmarshal(varBytes, &vars); err != nil {
finalSQL, am := processTemplate(s.sql)
ctx := context.Background()
tx, err := db.Begin(ctx)
if err != nil {
return err
}
defer tx.Rollback(ctx)
pstmt, err := tx.Prepare(ctx, "", finalSQL)
if err != nil {
return err
}
var key string
if s.role == nil {
key = gqlHash(gql, c.req.Vars, "")
} else {
key = gqlHash(gql, c.req.Vars, s.role.Name)
}
_preparedList[key] = &preparedItem{
stmt: pstmt,
args: am,
skipped: s.skipped,
qc: s.qc,
}
if err := tx.Commit(ctx); err != nil {
return err
}
}
buf := &bytes.Buffer{}
return nil
}
skipped, err := pcompile.Compile(qc, buf, psql.Variables(vars))
if err != nil {
return err
func prepareRoleStmt() error {
if len(conf.RolesQuery) == 0 {
return nil
}
t := fasttemplate.New(buf.String(), `{{`, `}}`)
am := make([][]byte, 0, 5)
i := 0
w := &bytes.Buffer{}
finalSQL := t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) {
am = append(am, []byte(tag))
i++
return w.Write([]byte(fmt.Sprintf("$%d", i)))
})
if err != nil {
return err
io.WriteString(w, `SELECT (CASE`)
for _, role := range conf.Roles {
if len(role.Match) == 0 {
continue
}
io.WriteString(w, ` WHEN `)
io.WriteString(w, role.Match)
io.WriteString(w, ` THEN '`)
io.WriteString(w, role.Name)
io.WriteString(w, `'`)
}
io.WriteString(w, ` ELSE {{role}} END) FROM (`)
io.WriteString(w, conf.RolesQuery)
io.WriteString(w, `) AS "_sg_auth_roles_query"`)
roleSQL, _ := processTemplate(w.String())
ctx := context.Background()
tx, err := db.Begin(ctx)
@ -84,21 +133,28 @@ func prepareStmt(key, gql string, varBytes json.RawMessage) error {
}
defer tx.Rollback(ctx)
pstmt, err := tx.Prepare(ctx, "", finalSQL)
_, err = tx.Prepare(ctx, "_sg_get_role", roleSQL)
if err != nil {
return err
}
_preparedList[key] = &preparedItem{
stmt: pstmt,
args: am,
skipped: skipped,
qc: qc,
}
if err := tx.Commit(ctx); err != nil {
return err
}
return nil
}
func processTemplate(tmpl string) (string, [][]byte) {
t := fasttemplate.New(tmpl, `{{`, `}}`)
am := make([][]byte, 0, 5)
i := 0
vmap := make(map[string]int)
return t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) {
if n, ok := vmap[tag]; ok {
return w.Write([]byte(fmt.Sprintf("$%d", n)))
}
am = append(am, []byte(tag))
i++
vmap[tag] = i
return w.Write([]byte(fmt.Sprintf("$%d", i)))
}), am
}

View File

@ -12,7 +12,6 @@ import (
rice "github.com/GeertJohan/go.rice"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode"
"github.com/gobuffalo/flect"
)
func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
@ -22,52 +21,53 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
}
conf := qcode.Config{
DefaultFilter: c.DB.Defaults.Filter,
FilterMap: qcode.Filters{
All: make(map[string][]string, len(c.Tables)),
Query: make(map[string][]string, len(c.Tables)),
Insert: make(map[string][]string, len(c.Tables)),
Update: make(map[string][]string, len(c.Tables)),
Delete: make(map[string][]string, len(c.Tables)),
},
Blocklist: c.DB.Defaults.Blocklist,
KeepArgs: false,
}
for i := range c.Tables {
t := c.Tables[i]
singular := flect.Singularize(t.Name)
plural := flect.Pluralize(t.Name)
setFilter := func(fm map[string][]string, fil []string) {
switch {
case len(fil) == 0:
return
case fil[0] == "none" || len(fil[0]) == 0:
fm[singular] = []string{}
fm[plural] = []string{}
default:
fm[singular] = t.Filter
fm[plural] = t.Filter
}
}
setFilter(conf.FilterMap.All, t.Filter)
setFilter(conf.FilterMap.Query, t.FilterQuery)
setFilter(conf.FilterMap.Insert, t.FilterInsert)
setFilter(conf.FilterMap.Update, t.FilterUpdate)
setFilter(conf.FilterMap.Delete, t.FilterDelete)
}
qc, err := qcode.NewCompiler(conf)
if err != nil {
return nil, nil, err
}
for _, r := range c.Roles {
for _, t := range r.Tables {
query := qcode.QueryConfig{
Limit: t.Query.Limit,
Filters: t.Query.Filters,
Columns: t.Query.Columns,
DisableFunctions: t.Query.DisableAggregation,
}
insert := qcode.InsertConfig{
Filters: t.Insert.Filters,
Columns: t.Insert.Columns,
Set: t.Insert.Set,
}
update := qcode.UpdateConfig{
Filters: t.Insert.Filters,
Columns: t.Insert.Columns,
Set: t.Insert.Set,
}
delete := qcode.DeleteConfig{
Filters: t.Insert.Filters,
Columns: t.Insert.Columns,
}
qc.AddRole(r.Name, t.Name, qcode.TRConfig{
Query: query,
Insert: insert,
Update: update,
Delete: delete,
})
}
}
pc := psql.NewCompiler(psql.Config{
Schema: schema,
Vars: c.getVariables(),
Vars: c.DB.Vars,
})
return qc, pc, nil

45
serv/sqllog.go Normal file
View File

@ -0,0 +1,45 @@
package serv
import (
"context"
"github.com/jackc/pgx/v4"
"github.com/rs/zerolog"
)
type Logger struct {
logger zerolog.Logger
}
// NewLogger accepts a zerolog.Logger as input and returns a new custom pgx
// logging fascade as output.
func NewSQLLogger(logger zerolog.Logger) *Logger {
return &Logger{
logger: logger.With().Logger(),
}
}
func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
var zlevel zerolog.Level
switch level {
case pgx.LogLevelNone:
zlevel = zerolog.NoLevel
case pgx.LogLevelError:
zlevel = zerolog.ErrorLevel
case pgx.LogLevelWarn:
zlevel = zerolog.WarnLevel
case pgx.LogLevelInfo:
zlevel = zerolog.InfoLevel
case pgx.LogLevelDebug:
zlevel = zerolog.DebugLevel
default:
zlevel = zerolog.DebugLevel
}
if sql, ok := data["sql"]; ok {
delete(data, "sql")
pl.logger.WithLevel(zlevel).Fields(data).Msg(sql.(string))
} else {
pl.logger.WithLevel(zlevel).Fields(data).Msg(msg)
}
}

View File

@ -21,16 +21,29 @@ func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
return v
}
func gqlHash(b string, vars []byte) string {
func gqlHash(b string, vars []byte, role string) string {
b = strings.TrimSpace(b)
h := sha1.New()
query := "query"
s, e := 0, 0
space := []byte{' '}
starting := true
var b0, b1 byte
for {
if starting && b[e] == 'q' {
n := 0
se := e
for e < len(b) && n < len(query) && b[e] == query[n] {
n++
e++
}
if n != len(query) {
io.WriteString(h, strings.ToLower(b[se:e]))
}
}
if ws(b[e]) {
for e < len(b) && ws(b[e]) {
e++
@ -42,6 +55,7 @@ func gqlHash(b string, vars []byte) string {
h.Write(space)
}
} else {
starting = false
s = e
for e < len(b) && ws(b[e]) == false {
e++
@ -56,6 +70,10 @@ func gqlHash(b string, vars []byte) string {
}
}
if len(role) != 0 {
io.WriteString(h, role)
}
if vars == nil || len(vars) == 0 {
return hex.EncodeToString(h.Sum(nil))
}
@ -80,3 +98,26 @@ func ws(b byte) bool {
func al(b byte) bool {
return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9')
}
func isMutation(sql string) bool {
for i := range sql {
b := sql[i]
if b == '{' {
return false
}
if al(b) {
return (b == 'm' || b == 'M')
}
}
return false
}
func findStmt(role string, stmts []stmt) *stmt {
for i := range stmts {
if stmts[i].role.Name != role {
continue
}
return &stmts[i]
}
return nil
}

View File

@ -5,7 +5,7 @@ import (
"testing"
)
func TestRelaxHash1(t *testing.T) {
func TestGQLHash1(t *testing.T) {
var v1 = `
products(
limit: 30,
@ -24,15 +24,15 @@ func TestRelaxHash1(t *testing.T) {
price
} `
h1 := gqlHash(v1, nil)
h2 := gqlHash(v2, nil)
h1 := gqlHash(v1, nil, "")
h2 := gqlHash(v2, nil, "")
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")
}
}
func TestRelaxHash2(t *testing.T) {
func TestGQLHash2(t *testing.T) {
var v1 = `
{
products(
@ -53,15 +53,15 @@ func TestRelaxHash2(t *testing.T) {
var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } `
h1 := gqlHash(v1, nil)
h2 := gqlHash(v2, nil)
h1 := gqlHash(v1, nil, "")
h2 := gqlHash(v2, nil, "")
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")
}
}
func TestRelaxHash3(t *testing.T) {
func TestGQLHash3(t *testing.T) {
var v1 = `users {
id
email
@ -86,15 +86,44 @@ func TestRelaxHash3(t *testing.T) {
}
`
h1 := gqlHash(v1, nil)
h2 := gqlHash(v2, nil)
h1 := gqlHash(v1, nil, "")
h2 := gqlHash(v2, nil, "")
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")
}
}
func TestRelaxHashWithVars1(t *testing.T) {
func TestGQLHash4(t *testing.T) {
var v1 = `
query {
products(
limit: 30
order_by: { price: desc }
distinct: [price]
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }
) {
id
name
price
user {
id
email
}
}
}`
var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } `
h1 := gqlHash(v1, nil, "")
h2 := gqlHash(v2, nil, "")
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")
}
}
func TestGQLHashWithVars1(t *testing.T) {
var q1 = `
products(
limit: 30,
@ -136,15 +165,15 @@ func TestRelaxHashWithVars1(t *testing.T) {
"user": 123
}`
h1 := gqlHash(q1, []byte(v1))
h2 := gqlHash(q2, []byte(v2))
h1 := gqlHash(q1, []byte(v1), "user")
h2 := gqlHash(q2, []byte(v2), "user")
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")
}
}
func TestRelaxHashWithVars2(t *testing.T) {
func TestGQLHashWithVars2(t *testing.T) {
var q1 = `
products(
limit: 30,
@ -193,8 +222,8 @@ func TestRelaxHashWithVars2(t *testing.T) {
"user": 123
}`
h1 := gqlHash(q1, []byte(v1))
h2 := gqlHash(q2, []byte(v2))
h1 := gqlHash(q1, []byte(v1), "user")
h2 := gqlHash(q2, []byte(v2), "user")
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")

View File

@ -11,17 +11,27 @@ import (
func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
return func(w io.Writer, tag string) (int, error) {
switch tag {
case "user_id":
if v := ctx.Value(userIDKey); v != nil {
return stringVar(w, v.(string))
}
return 0, errNoUserID
case "user_id_provider":
if v := ctx.Value(userIDProviderKey); v != nil {
return stringVar(w, v.(string))
}
return 0, errNoUserID
io.WriteString(w, "null")
return 0, nil
case "user_id":
if v := ctx.Value(userIDKey); v != nil {
return stringVar(w, v.(string))
}
io.WriteString(w, "null")
return 0, nil
case "user_role":
if v := ctx.Value(userRoleKey); v != nil {
return stringVar(w, v.(string))
}
io.WriteString(w, "null")
return 0, nil
}
fields := jsn.Get(ctx.req.Vars, [][]byte{[]byte(tag)})

View File

@ -80,7 +80,7 @@ SQL Output
account_id: "select account_id from users where id = $user_id"
defaults:
filter: ["{ user_id: { eq: $user_id } }"]
Filters: ["{ user_id: { eq: $user_id } }"]
blacklist:
- password
@ -88,14 +88,14 @@ SQL Output
fields:
- name: users
filter: ["{ id: { eq: $user_id } }"]
Filters: ["{ id: { eq: $user_id } }"]
- name: products
filter: [
Filters: [
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }"
]
- name: me
table: users
filter: ["{ id: { eq: $user_id } }"]
Filters: ["{ id: { eq: $user_id } }"]

View File

@ -1,4 +1,4 @@
app_name: "{% app_name %} Development"
app_name: "Super Graph Development"
host_port: 0.0.0.0:8080
web_ui: true
@ -53,7 +53,7 @@ auth:
# Comment this out if you want to disable setting
# the user_id via a header. Good for testing
header: X-User-ID
creds_in_header: true
rails:
# Rails version this is used for reading the
@ -84,7 +84,7 @@ database:
type: postgres
host: db
port: 5432
dbname: {% app_name_slug %}_development
dbname: app_development
user: postgres
password: ''
@ -100,7 +100,7 @@ database:
# Define defaults to for the field key and values below
defaults:
# filter: ["{ user_id: { eq: $user_id } }"]
# filters: ["{ user_id: { eq: $user_id } }"]
# Field and table names that you wish to block
blocklist:
@ -111,45 +111,81 @@ database:
- encrypted
- token
tables:
- name: users
# This filter will overwrite defaults.filter
# filter: ["{ id: { eq: $user_id } }"]
# filter_query: ["{ id: { eq: $user_id } }"]
filter_update: ["{ id: { eq: $user_id } }"]
filter_delete: ["{ id: { eq: $user_id } }"]
tables:
- name: customers
remotes:
- name: payments
id: stripe_id
url: http://rails_app:3000/stripe/$id
path: data
# debug: true
pass_headers:
- cookie
set_headers:
- name: Host
value: 0.0.0.0
# - name: Authorization
# value: Bearer <stripe_api_key>
# - name: products
# # Multiple filters are AND'd together
# filter: [
# "{ price: { gt: 0 } }",
# "{ price: { lt: 8 } }"
# ]
- # You can create new fields that have a
# real db table backing them
name: me
table: users
- name: customers
# No filter is used for this field not
# even defaults.filter
filter: none
roles_query: "SELECT * FROM users as usr WHERE id = $user_id"
remotes:
- name: payments
id: stripe_id
url: http://rails_app:3000/stripe/$id
path: data
# debug: true
pass_headers:
- cookie
set_headers:
- name: Host
value: 0.0.0.0
# - name: Authorization
# value: Bearer <stripe_api_key>
roles:
- name: anon
tables:
- name: products
limit: 10
- # You can create new fields that have a
# real db table backing them
name: me
table: users
filter: ["{ id: { eq: $user_id } }"]
query:
columns: ["id", "name", "description" ]
aggregation: false
# - name: posts
# filter: ["{ account_id: { _eq: $account_id } }"]
insert:
allow: false
update:
allow: false
delete:
allow: false
- name: user
tables:
- name: users
query:
filters: ["{ id: { _eq: $user_id } }"]
- name: products
query:
limit: 50
filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
disable_aggregation: false
insert:
filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
set:
- created_at: "now"
update:
filters: ["{ user_id: { eq: $user_id } }"]
columns:
- id
- name
set:
- updated_at: "now"
delete:
deny: true
- name: admin
match: id = 1
tables:
- name: users
# query:
# filters: ["{ account_id: { _eq: $account_id } }"]

View File

@ -1,4 +1,4 @@
app_name: "{% app_name %} Production"
app_name: "Super Graph Production"
host_port: 0.0.0.0:8080
web_ui: false
@ -47,10 +47,6 @@ auth:
type: rails
cookie: _app_session
# Comment this out if you want to disable setting
# the user_id via a header. Good for testing
header: X-User-ID
rails:
# Rails version this is used for reading the
# various cookies formats.
@ -80,7 +76,7 @@ database:
type: postgres
host: db
port: 5432
dbname: {% app_name_slug %}_production
dbname: {{app_name_slug}}_development
user: postgres
password: ''
#pool_size: 10
@ -94,7 +90,7 @@ database:
# Define defaults to for the field key and values below
defaults:
filter: ["{ user_id: { eq: $user_id } }"]
filters: ["{ user_id: { eq: $user_id } }"]
# Field and table names that you wish to block
blocklist:
@ -105,43 +101,79 @@ database:
- encrypted
- token
tables:
- name: users
# This filter will overwrite defaults.filter
# filter: ["{ id: { eq: $user_id } }"]
# filter_query: ["{ id: { eq: $user_id } }"]
filter_update: ["{ id: { eq: $user_id } }"]
filter_delete: ["{ id: { eq: $user_id } }"]
tables:
- name: customers
# remotes:
# - name: payments
# id: stripe_id
# url: http://rails_app:3000/stripe/$id
# path: data
# # pass_headers:
# # - cookie
# # - host
# set_headers:
# - name: Authorization
# value: Bearer <stripe_api_key>
- name: products
# Multiple filters are AND'd together
filter: [
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }"
]
- # You can create new fields that have a
# real db table backing them
name: me
table: users
- name: customers
# No filter is used for this field not
# even defaults.filter
filter: none
roles_query: "SELECT * FROM users as usr WHERE id = $user_id"
# remotes:
# - name: payments
# id: stripe_id
# url: http://rails_app:3000/stripe/$id
# path: data
# # pass_headers:
# # - cookie
# # - host
# set_headers:
# - name: Authorization
# value: Bearer <stripe_api_key>
roles:
- name: anon
tables:
- name: products
limit: 10
- # You can create new fields that have a
# real db table backing them
name: me
table: users
filter: ["{ id: { eq: $user_id } }"]
query:
columns: ["id", "name", "description" ]
aggregation: false
# - name: posts
# filter: ["{ account_id: { _eq: $account_id } }"]
insert:
allow: false
update:
allow: false
delete:
allow: false
- name: user
tables:
- name: users
query:
filters: ["{ id: { _eq: $user_id } }"]
- name: products
query:
limit: 50
filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
disable_aggregation: false
insert:
filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
set:
- created_at: "now"
update:
filters: ["{ user_id: { eq: $user_id } }"]
columns:
- id
- name
set:
- updated_at: "now"
delete:
deny: true
- name: admin
match: id = 1
tables:
- name: users
# query:
# filters: ["{ account_id: { _eq: $account_id } }"]