Merge pull request #11 from dosco/rbac

Role based access control and other fixes
This commit is contained in:
Vikram Rangnekar 2019-10-25 01:49:37 -04:00 committed by GitHub
commit ff13f651d6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
39 changed files with 2439 additions and 1148 deletions

1
.gitignore vendored
View File

@ -28,3 +28,4 @@ main
.DS_Store .DS_Store
.swp .swp
main main
super-graph

13
.wtc.yaml Normal file
View File

@ -0,0 +1,13 @@
no_trace: false
debounce: 300 # if rule has no debounce, this will be used instead
ignore: \.git/
trig: [start, run] # will run on start
rules:
- name: start
- name: run
match: \.go$
ignore: web|examples|docs|_test\.go$
command: go run main.go serv
- name: test
match: _test\.go$
command: go test -cover {PKG}

View File

@ -11,9 +11,8 @@ RUN apk update && \
apk add --no-cache git && \ apk add --no-cache git && \
apk add --no-cache upx=3.95-r2 apk add --no-cache upx=3.95-r2
RUN go get -u github.com/shanzi/wu && \ RUN go get -u github.com/rafaelsq/wtc && \
go install github.com/shanzi/wu && \ go get -u github.com/GeertJohan/go.rice/rice
go get github.com/GeertJohan/go.rice/rice
WORKDIR /app WORKDIR /app
COPY . /app COPY . /app

View File

@ -46,6 +46,9 @@ This compiler is what sits at the heart of Super Graph with layers of useful fun
## Contact me ## Contact me
I'm happy to help you deploy Super Graph so feel free to reach out over
Twitter or Discord.
[twitter/dosco](https://twitter.com/dosco) [twitter/dosco](https://twitter.com/dosco)
[chat/super-graph](https://discord.gg/6pSWCTZ) [chat/super-graph](https://discord.gg/6pSWCTZ)

View File

@ -1,5 +1,27 @@
# http://localhost:8080/ # http://localhost:8080/
variables {
"data": [
{
"name": "Protect Ya Neck",
"created_at": "now",
"updated_at": "now"
},
{
"name": "Enter the Wu-Tang",
"created_at": "now",
"updated_at": "now"
}
]
}
mutation {
products(insert: $data) {
id
name
}
}
variables { variables {
"update": { "update": {
"name": "Wu-Tang", "name": "Wu-Tang",
@ -16,16 +38,16 @@ mutation {
} }
} }
variables { query {
"data": { users {
"product_id": 5 id
} email
} picture: avatar
products(limit: 2, where: {price: {gt: 10}}) {
mutation {
products(id: $product_id, delete: true) {
id id
name name
description
}
} }
} }
@ -73,6 +95,118 @@ query {
} }
} }
variables {
"data": [
{
"name": "Gumbo1",
"created_at": "now",
"updated_at": "now"
},
{
"name": "Gumbo2",
"created_at": "now",
"updated_at": "now"
}
]
}
query {
products {
id
name
user {
email
}
}
}
variables {
"data": {
"product_id": 5
}
}
mutation {
products(id: $product_id, delete: true) {
id
name
}
}
variables {
"data": [
{
"name": "Gumbo1",
"created_at": "now",
"updated_at": "now"
},
{
"name": "Gumbo2",
"created_at": "now",
"updated_at": "now"
}
]
}
query {
products {
id
name
price
users {
email
}
}
}
variables {
"data": {
"email": "gfk@myspace.com",
"full_name": "Ghostface Killah",
"created_at": "now",
"updated_at": "now"
}
}
mutation {
user(insert: $data) {
id
}
}
variables {
"data": [
{
"name": "Gumbo1",
"created_at": "now",
"updated_at": "now"
},
{
"name": "Gumbo2",
"created_at": "now",
"updated_at": "now"
}
]
}
query {
products {
id
name
users {
email
}
}
}
query {
me {
id
email
full_name
}
}
variables { variables {
"update": { "update": {
@ -112,62 +246,30 @@ query {
} }
} }
query {
me {
id
email
full_name
}
}
variables {
"data": {
"email": "gfk@myspace.com",
"full_name": "Ghostface Killah",
"created_at": "now",
"updated_at": "now"
}
}
mutation {
user(insert: $data) {
id
}
}
query {
users {
id
email
picture: avatar
products(limit: 2, where: {price: {gt: 10}}) {
id
name
description
}
}
}
variables { variables {
"data": [ "data": [
{ {
"name": "Protect Ya Neck", "name": "Gumbo1",
"created_at": "now", "created_at": "now",
"updated_at": "now" "updated_at": "now"
}, },
{ {
"name": "Enter the Wu-Tang", "name": "Gumbo2",
"created_at": "now", "created_at": "now",
"updated_at": "now" "updated_at": "now"
} }
] ]
} }
mutation { query {
products(insert: $data) { products {
id id
name name
description
users {
email
}
} }
} }

View File

@ -22,7 +22,7 @@ enable_tracing: true
# Watch the config folder and reload Super Graph # Watch the config folder and reload Super Graph
# with the new configs when a change is detected # with the new configs when a change is detected
reload_on_config_change: false reload_on_config_change: true
# File that points to the database seeding script # File that points to the database seeding script
# seed_file: seed.js # seed_file: seed.js
@ -53,7 +53,7 @@ auth:
# Comment this out if you want to disable setting # Comment this out if you want to disable setting
# the user_id via a header. Good for testing # the user_id via a header. Good for testing
header: X-User-ID creds_in_header: true
rails: rails:
# Rails version this is used for reading the # Rails version this is used for reading the
@ -100,7 +100,7 @@ database:
# Define defaults to for the field key and values below # Define defaults to for the field key and values below
defaults: defaults:
# filter: ["{ user_id: { eq: $user_id } }"] # filters: ["{ user_id: { eq: $user_id } }"]
# Field and table names that you wish to block # Field and table names that you wish to block
blocklist: blocklist:
@ -112,25 +112,7 @@ database:
- token - token
tables: tables:
- name: users
# This filter will overwrite defaults.filter
# filter: ["{ id: { eq: $user_id } }"]
# filter_query: ["{ id: { eq: $user_id } }"]
filter_update: ["{ id: { eq: $user_id } }"]
filter_delete: ["{ id: { eq: $user_id } }"]
# - name: products
# # Multiple filters are AND'd together
# filter: [
# "{ price: { gt: 0 } }",
# "{ price: { lt: 8 } }"
# ]
- name: customers - name: customers
# No filter is used for this field not
# even defaults.filter
filter: none
remotes: remotes:
- name: payments - name: payments
id: stripe_id id: stripe_id
@ -149,7 +131,61 @@ tables:
# real db table backing them # real db table backing them
name: me name: me
table: users table: users
filter: ["{ id: { eq: $user_id } }"]
# - name: posts roles_query: "SELECT * FROM users as usr WHERE id = $user_id"
# filter: ["{ account_id: { _eq: $account_id } }"]
roles:
- name: anon
tables:
- name: products
limit: 10
query:
columns: ["id", "name", "description" ]
aggregation: false
insert:
allow: false
update:
allow: false
delete:
allow: false
- name: user
tables:
- name: users
query:
filters: ["{ id: { _eq: $user_id } }"]
- name: products
query:
limit: 50
filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
disable_aggregation: false
insert:
filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
set:
- created_at: "now"
update:
filters: ["{ user_id: { eq: $user_id } }"]
columns:
- id
- name
set:
- updated_at: "now"
delete:
deny: true
- name: admin
match: id = 1
tables:
- name: users
# query:
# filters: ["{ account_id: { _eq: $account_id } }"]

View File

@ -47,10 +47,6 @@ auth:
type: rails type: rails
cookie: _app_session cookie: _app_session
# Comment this out if you want to disable setting
# the user_id via a header. Good for testing
header: X-User-ID
rails: rails:
# Rails version this is used for reading the # Rails version this is used for reading the
# various cookies formats. # various cookies formats.
@ -94,7 +90,7 @@ database:
# Define defaults to for the field key and values below # Define defaults to for the field key and values below
defaults: defaults:
filter: ["{ user_id: { eq: $user_id } }"] filters: ["{ user_id: { eq: $user_id } }"]
# Field and table names that you wish to block # Field and table names that you wish to block
blocklist: blocklist:
@ -106,25 +102,7 @@ database:
- token - token
tables: tables:
- name: users
# This filter will overwrite defaults.filter
# filter: ["{ id: { eq: $user_id } }"]
# filter_query: ["{ id: { eq: $user_id } }"]
filter_update: ["{ id: { eq: $user_id } }"]
filter_delete: ["{ id: { eq: $user_id } }"]
- name: products
# Multiple filters are AND'd together
filter: [
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }"
]
- name: customers - name: customers
# No filter is used for this field not
# even defaults.filter
filter: none
# remotes: # remotes:
# - name: payments # - name: payments
# id: stripe_id # id: stripe_id
@ -141,7 +119,61 @@ tables:
# real db table backing them # real db table backing them
name: me name: me
table: users table: users
filter: ["{ id: { eq: $user_id } }"]
# - name: posts roles_query: "SELECT * FROM users as usr WHERE id = $user_id"
# filter: ["{ account_id: { _eq: $account_id } }"]
roles:
- name: anon
tables:
- name: products
limit: 10
query:
columns: ["id", "name", "description" ]
aggregation: false
insert:
allow: false
update:
allow: false
delete:
allow: false
- name: user
tables:
- name: users
query:
filters: ["{ id: { _eq: $user_id } }"]
- name: products
query:
limit: 50
filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
disable_aggregation: false
insert:
filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
set:
- created_at: "now"
update:
filters: ["{ user_id: { eq: $user_id } }"]
columns:
- id
- name
set:
- updated_at: "now"
delete:
deny: true
- name: admin
match: id = 1
tables:
- name: users
# query:
# filters: ["{ account_id: { _eq: $account_id } }"]

View File

@ -34,7 +34,7 @@ services:
volumes: volumes:
- .:/app - .:/app
working_dir: /app working_dir: /app
command: wu -pattern="*.go" go run main.go serv command: wtc
depends_on: depends_on:
- db - db
- rails_app - rails_app

View File

@ -1043,26 +1043,35 @@ We're tried to ensure that the config file is self documenting and easy to work
app_name: "Super Graph Development" app_name: "Super Graph Development"
host_port: 0.0.0.0:8080 host_port: 0.0.0.0:8080
web_ui: true web_ui: true
debug_level: 1
# debug, info, warn, error, fatal, panic, disable # debug, info, warn, error, fatal, panic
log_level: "info" log_level: "debug"
# Disable this in development to get a list of # Disable this in development to get a list of
# queries used. When enabled super graph # queries used. When enabled super graph
# will only allow queries from this list # will only allow queries from this list
# List saved to ./config/allow.list # List saved to ./config/allow.list
use_allow_list: true use_allow_list: false
# Throw a 401 on auth failure for queries that need auth # Throw a 401 on auth failure for queries that need auth
# valid values: always, per_query, never # valid values: always, per_query, never
auth_fail_block: always auth_fail_block: never
# Latency tracing for database queries and remote joins # Latency tracing for database queries and remote joins
# the resulting latency information is returned with the # the resulting latency information is returned with the
# response # response
enable_tracing: true enable_tracing: true
# Watch the config folder and reload Super Graph
# with the new configs when a change is detected
reload_on_config_change: true
# File that points to the database seeding script
# seed_file: seed.js
# Path pointing to where the migrations can be found
migrations_path: ./config/migrations
# Postgres related environment Variables # Postgres related environment Variables
# SG_DATABASE_HOST # SG_DATABASE_HOST
# SG_DATABASE_PORT # SG_DATABASE_PORT
@ -1086,7 +1095,7 @@ auth:
# Comment this out if you want to disable setting # Comment this out if you want to disable setting
# the user_id via a header. Good for testing # the user_id via a header. Good for testing
header: X-User-ID creds_in_header: true
rails: rails:
# Rails version this is used for reading the # Rails version this is used for reading the
@ -1097,10 +1106,10 @@ auth:
secret_key_base: 0a248500a64c01184edb4d7ad3a805488f8097ac761b76aaa6c17c01dcb7af03a2f18ba61b2868134b9c7b79a122bc0dadff4367414a2d173297bfea92be5566 secret_key_base: 0a248500a64c01184edb4d7ad3a805488f8097ac761b76aaa6c17c01dcb7af03a2f18ba61b2868134b9c7b79a122bc0dadff4367414a2d173297bfea92be5566
# Remote cookie store. (memcache or redis) # Remote cookie store. (memcache or redis)
# url: redis://127.0.0.1:6379 # url: redis://redis:6379
# password: test # password: ""
# max_idle: 80, # max_idle: 80
# max_active: 12000, # max_active: 12000
# In most cases you don't need these # In most cases you don't need these
# salt: "encrypted cookie" # salt: "encrypted cookie"
@ -1120,20 +1129,23 @@ database:
dbname: app_development dbname: app_development
user: postgres user: postgres
password: '' password: ''
#schema: "public"
#pool_size: 10 #pool_size: 10
#max_retries: 0 #max_retries: 0
#log_level: "debug" #log_level: "debug"
# Define variables here that you want to use in filters # Define variables here that you want to use in filters
# sub-queries must be wrapped in ()
variables: variables:
account_id: "select account_id from users where id = $user_id" account_id: "(select account_id from users where id = $user_id)"
# Define defaults to for the field key and values below # Define defaults to for the field key and values below
defaults: defaults:
filter: ["{ user_id: { eq: $user_id } }"] # filters: ["{ user_id: { eq: $user_id } }"]
# Field and table names that you wish to block # Field and table names that you wish to block
blacklist: blocklist:
- ar_internal_metadata - ar_internal_metadata
- schema_migrations - schema_migrations
- secret - secret
@ -1142,42 +1154,84 @@ database:
- token - token
tables: tables:
- name: users
# This filter will overwrite defaults.filter
filter: ["{ id: { eq: $user_id } }"]
- name: products
# Multiple filters are AND'd together
filter: [
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }"
]
- name: customers - name: customers
# No filter is used for this field not
# even defaults.filter
filter: none
remotes: remotes:
- name: payments - name: payments
id: stripe_id id: stripe_id
url: http://rails_app:3000/stripe/$id url: http://rails_app:3000/stripe/$id
path: data path: data
# pass_headers: # debug: true
# - cookie pass_headers:
# - host - cookie
set_headers: set_headers:
- name: Authorization - name: Host
value: Bearer <stripe_api_key> value: 0.0.0.0
# - name: Authorization
# value: Bearer <stripe_api_key>
- # You can create new fields that have a - # You can create new fields that have a
# real db table backing them # real db table backing them
name: me name: me
table: users table: users
filter: ["{ id: { eq: $user_id } }"]
# - name: posts roles_query: "SELECT * FROM users as usr WHERE id = $user_id"
# filter: ["{ account_id: { _eq: $account_id } }"]
roles:
- name: anon
tables:
- name: products
limit: 10
query:
columns: ["id", "name", "description" ]
aggregation: false
insert:
allow: false
update:
allow: false
delete:
allow: false
- name: user
tables:
- name: users
query:
filters: ["{ id: { _eq: $user_id } }"]
- name: products
query:
limit: 50
filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
disable_aggregation: false
insert:
filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
set:
- created_at: "now"
update:
filters: ["{ user_id: { eq: $user_id } }"]
columns:
- id
- name
set:
- updated_at: "now"
delete:
deny: true
- name: admin
match: id = 1
tables:
- name: users
# query:
# filters: ["{ account_id: { _eq: $account_id } }"]
``` ```
If deploying into environments like Kubernetes it's useful to be able to configure things like secrets and hosts though environment variables therfore we expose the below environment variables. This is escpecially useful for secrets since they are usually injected in via a secrets management framework ie. Kubernetes Secrets If deploying into environments like Kubernetes it's useful to be able to configure things like secrets and hosts though environment variables therfore we expose the below environment variables. This is escpecially useful for secrets since they are usually injected in via a secrets management framework ie. Kubernetes Secrets

View File

@ -0,0 +1,273 @@
GIT
remote: https://github.com/stympy/faker.git
revision: 4e9144825fcc9ba5c83cc0fd037779ab82f3120b
branch: master
specs:
faker (2.6.0)
i18n (>= 1.6, < 1.8)
GEM
remote: https://rubygems.org/
specs:
actioncable (6.0.0)
actionpack (= 6.0.0)
nio4r (~> 2.0)
websocket-driver (>= 0.6.1)
actionmailbox (6.0.0)
actionpack (= 6.0.0)
activejob (= 6.0.0)
activerecord (= 6.0.0)
activestorage (= 6.0.0)
activesupport (= 6.0.0)
mail (>= 2.7.1)
actionmailer (6.0.0)
actionpack (= 6.0.0)
actionview (= 6.0.0)
activejob (= 6.0.0)
mail (~> 2.5, >= 2.5.4)
rails-dom-testing (~> 2.0)
actionpack (6.0.0)
actionview (= 6.0.0)
activesupport (= 6.0.0)
rack (~> 2.0)
rack-test (>= 0.6.3)
rails-dom-testing (~> 2.0)
rails-html-sanitizer (~> 1.0, >= 1.2.0)
actiontext (6.0.0)
actionpack (= 6.0.0)
activerecord (= 6.0.0)
activestorage (= 6.0.0)
activesupport (= 6.0.0)
nokogiri (>= 1.8.5)
actionview (6.0.0)
activesupport (= 6.0.0)
builder (~> 3.1)
erubi (~> 1.4)
rails-dom-testing (~> 2.0)
rails-html-sanitizer (~> 1.1, >= 1.2.0)
activejob (6.0.0)
activesupport (= 6.0.0)
globalid (>= 0.3.6)
activemodel (6.0.0)
activesupport (= 6.0.0)
activerecord (6.0.0)
activemodel (= 6.0.0)
activesupport (= 6.0.0)
activestorage (6.0.0)
actionpack (= 6.0.0)
activejob (= 6.0.0)
activerecord (= 6.0.0)
marcel (~> 0.3.1)
activesupport (6.0.0)
concurrent-ruby (~> 1.0, >= 1.0.2)
i18n (>= 0.7, < 2)
minitest (~> 5.1)
tzinfo (~> 1.1)
zeitwerk (~> 2.1, >= 2.1.8)
addressable (2.7.0)
public_suffix (>= 2.0.2, < 5.0)
archive-zip (0.12.0)
io-like (~> 0.3.0)
bcrypt (3.1.13)
bindex (0.8.1)
bootsnap (1.4.5)
msgpack (~> 1.0)
builder (3.2.3)
byebug (11.0.1)
capybara (3.29.0)
addressable
mini_mime (>= 0.1.3)
nokogiri (~> 1.8)
rack (>= 1.6.0)
rack-test (>= 0.6.3)
regexp_parser (~> 1.5)
xpath (~> 3.2)
childprocess (3.0.0)
chromedriver-helper (2.1.1)
archive-zip (~> 0.10)
nokogiri (~> 1.8)
coffee-rails (4.2.2)
coffee-script (>= 2.2.0)
railties (>= 4.0.0)
coffee-script (2.4.1)
coffee-script-source
execjs
coffee-script-source (1.12.2)
concurrent-ruby (1.1.5)
crass (1.0.4)
devise (4.7.1)
bcrypt (~> 3.0)
orm_adapter (~> 0.1)
railties (>= 4.1.0)
responders
warden (~> 1.2.3)
erubi (1.9.0)
execjs (2.7.0)
ffi (1.11.1)
globalid (0.4.2)
activesupport (>= 4.2.0)
i18n (1.7.0)
concurrent-ruby (~> 1.0)
io-like (0.3.0)
jbuilder (2.9.1)
activesupport (>= 4.2.0)
listen (3.1.5)
rb-fsevent (~> 0.9, >= 0.9.4)
rb-inotify (~> 0.9, >= 0.9.7)
ruby_dep (~> 1.2)
loofah (2.3.0)
crass (~> 1.0.2)
nokogiri (>= 1.5.9)
mail (2.7.1)
mini_mime (>= 0.1.1)
marcel (0.3.3)
mimemagic (~> 0.3.2)
method_source (0.9.2)
mimemagic (0.3.3)
mini_mime (1.0.2)
mini_portile2 (2.4.0)
minitest (5.12.2)
msgpack (1.3.1)
nio4r (2.5.2)
nokogiri (1.10.4)
mini_portile2 (~> 2.4.0)
orm_adapter (0.5.0)
pg (1.1.4)
public_suffix (4.0.1)
puma (3.12.1)
rack (2.0.7)
rack-test (1.1.0)
rack (>= 1.0, < 3)
rails (6.0.0)
actioncable (= 6.0.0)
actionmailbox (= 6.0.0)
actionmailer (= 6.0.0)
actionpack (= 6.0.0)
actiontext (= 6.0.0)
actionview (= 6.0.0)
activejob (= 6.0.0)
activemodel (= 6.0.0)
activerecord (= 6.0.0)
activestorage (= 6.0.0)
activesupport (= 6.0.0)
bundler (>= 1.3.0)
railties (= 6.0.0)
sprockets-rails (>= 2.0.0)
rails-dom-testing (2.0.3)
activesupport (>= 4.2.0)
nokogiri (>= 1.6)
rails-html-sanitizer (1.3.0)
loofah (~> 2.3)
railties (6.0.0)
actionpack (= 6.0.0)
activesupport (= 6.0.0)
method_source
rake (>= 0.8.7)
thor (>= 0.20.3, < 2.0)
rake (13.0.0)
rb-fsevent (0.10.3)
rb-inotify (0.10.0)
ffi (~> 1.0)
redis (4.1.3)
redis-actionpack (5.1.0)
actionpack (>= 4.0, < 7)
redis-rack (>= 1, < 3)
redis-store (>= 1.1.0, < 2)
redis-activesupport (5.2.0)
activesupport (>= 3, < 7)
redis-store (>= 1.3, < 2)
redis-rack (2.0.6)
rack (>= 1.5, < 3)
redis-store (>= 1.2, < 2)
redis-rails (5.0.2)
redis-actionpack (>= 5.0, < 6)
redis-activesupport (>= 5.0, < 6)
redis-store (>= 1.2, < 2)
redis-store (1.8.0)
redis (>= 4, < 5)
regexp_parser (1.6.0)
responders (3.0.0)
actionpack (>= 5.0)
railties (>= 5.0)
ruby_dep (1.5.0)
rubyzip (2.0.0)
sass (3.7.4)
sass-listen (~> 4.0.0)
sass-listen (4.0.0)
rb-fsevent (~> 0.9, >= 0.9.4)
rb-inotify (~> 0.9, >= 0.9.7)
sass-rails (5.1.0)
railties (>= 5.2.0)
sass (~> 3.1)
sprockets (>= 2.8, < 4.0)
sprockets-rails (>= 2.0, < 4.0)
tilt (>= 1.1, < 3)
selenium-webdriver (3.142.6)
childprocess (>= 0.5, < 4.0)
rubyzip (>= 1.2.2)
spring (2.1.0)
spring-watcher-listen (2.0.1)
listen (>= 2.7, < 4.0)
spring (>= 1.2, < 3.0)
sprockets (3.7.2)
concurrent-ruby (~> 1.0)
rack (> 1, < 3)
sprockets-rails (3.2.1)
actionpack (>= 4.0)
activesupport (>= 4.0)
sprockets (>= 3.0.0)
thor (0.20.3)
thread_safe (0.3.6)
tilt (2.0.10)
turbolinks (5.2.1)
turbolinks-source (~> 5.2)
turbolinks-source (5.2.0)
tzinfo (1.2.5)
thread_safe (~> 0.1)
uglifier (4.2.0)
execjs (>= 0.3.0, < 3)
warden (1.2.8)
rack (>= 2.0.6)
web-console (4.0.1)
actionview (>= 6.0.0)
activemodel (>= 6.0.0)
bindex (>= 0.4.0)
railties (>= 6.0.0)
websocket-driver (0.7.1)
websocket-extensions (>= 0.1.0)
websocket-extensions (0.1.4)
xpath (3.2.0)
nokogiri (~> 1.8)
zeitwerk (2.2.0)
PLATFORMS
ruby
DEPENDENCIES
bootsnap (>= 1.1.0)
byebug
capybara (>= 2.15)
chromedriver-helper
coffee-rails (~> 4.2)
devise
faker!
jbuilder (~> 2.5)
listen (>= 3.0.5, < 3.2)
pg (>= 0.18, < 2.0)
puma (~> 3.11)
rails (~> 6.0.0.rc1)
redis-rails
sass-rails (~> 5.0)
selenium-webdriver
spring
spring-watcher-listen (~> 2.0.0)
turbolinks (~> 5)
tzinfo-data
uglifier (>= 1.3.0)
web-console (>= 3.3.0)
RUBY VERSION
ruby 2.5.7p206
BUNDLED WITH
1.17.3

View File

@ -1,7 +1,6 @@
package psql package psql
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -10,9 +9,9 @@ import (
"github.com/dosco/super-graph/qcode" "github.com/dosco/super-graph/qcode"
) )
var zeroPaging = qcode.Paging{} var noLimit = qcode.Paging{NoLimit: true}
func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) { func (co *Compiler) compileMutation(qc *qcode.QCode, w io.Writer, vars Variables) (uint32, error) {
if len(qc.Selects) == 0 { if len(qc.Selects) == 0 {
return 0, errors.New("empty query") return 0, errors.New("empty query")
} }
@ -25,27 +24,27 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia
return 0, err return 0, err
} }
c.w.WriteString(`WITH `) io.WriteString(c.w, `WITH `)
quoted(c.w, ti.Name) quoted(c.w, ti.Name)
c.w.WriteString(` AS `) io.WriteString(c.w, ` AS `)
switch root.Action { switch qc.Type {
case qcode.ActionInsert: case qcode.QTInsert:
if _, err := c.renderInsert(qc, w, vars, ti); err != nil { if _, err := c.renderInsert(qc, w, vars, ti); err != nil {
return 0, err return 0, err
} }
case qcode.ActionUpdate: case qcode.QTUpdate:
if _, err := c.renderUpdate(qc, w, vars, ti); err != nil { if _, err := c.renderUpdate(qc, w, vars, ti); err != nil {
return 0, err return 0, err
} }
case qcode.ActionUpsert: case qcode.QTUpsert:
if _, err := c.renderUpsert(qc, w, vars, ti); err != nil { if _, err := c.renderUpsert(qc, w, vars, ti); err != nil {
return 0, err return 0, err
} }
case qcode.ActionDelete: case qcode.QTDelete:
if _, err := c.renderDelete(qc, w, vars, ti); err != nil { if _, err := c.renderDelete(qc, w, vars, ti); err != nil {
return 0, err return 0, err
} }
@ -56,7 +55,7 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia
io.WriteString(c.w, ` RETURNING *) `) io.WriteString(c.w, ` RETURNING *) `)
root.Paging = zeroPaging root.Paging = noLimit
root.DistinctOn = root.DistinctOn[:] root.DistinctOn = root.DistinctOn[:]
root.OrderBy = root.OrderBy[:] root.OrderBy = root.OrderBy[:]
root.Where = nil root.Where = nil
@ -65,13 +64,12 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia
return c.compileQuery(qc, w) return c.compileQuery(qc, w)
} }
func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer, func (c *compilerContext) renderInsert(qc *qcode.QCode, w io.Writer,
vars Variables, ti *DBTableInfo) (uint32, error) { vars Variables, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0]
insert, ok := vars[root.ActionVar] insert, ok := vars[qc.ActionVar]
if !ok { if !ok {
return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar) return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar)
} }
jt, array, err := jsn.Tree(insert) jt, array, err := jsn.Tree(insert)
@ -79,56 +77,62 @@ func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer,
return 0, err return 0, err
} }
c.w.WriteString(`(WITH "input" AS (SELECT {{`) io.WriteString(c.w, `(WITH "input" AS (SELECT {{`)
c.w.WriteString(root.ActionVar) io.WriteString(c.w, qc.ActionVar)
c.w.WriteString(`}}::json AS j) INSERT INTO `) io.WriteString(c.w, `}}::json AS j) INSERT INTO `)
quoted(c.w, ti.Name) quoted(c.w, ti.Name)
io.WriteString(c.w, ` (`) io.WriteString(c.w, ` (`)
c.renderInsertUpdateColumns(qc, w, jt, ti) c.renderInsertUpdateColumns(qc, w, jt, ti)
io.WriteString(c.w, `)`) io.WriteString(c.w, `)`)
c.w.WriteString(` SELECT `) io.WriteString(c.w, ` SELECT `)
c.renderInsertUpdateColumns(qc, w, jt, ti) c.renderInsertUpdateColumns(qc, w, jt, ti)
c.w.WriteString(` FROM input i, `) io.WriteString(c.w, ` FROM input i, `)
if array { if array {
c.w.WriteString(`json_populate_recordset`) io.WriteString(c.w, `json_populate_recordset`)
} else { } else {
c.w.WriteString(`json_populate_record`) io.WriteString(c.w, `json_populate_record`)
} }
c.w.WriteString(`(NULL::`) io.WriteString(c.w, `(NULL::`)
c.w.WriteString(ti.Name) io.WriteString(c.w, ti.Name)
c.w.WriteString(`, i.j) t`) io.WriteString(c.w, `, i.j) t`)
return 0, nil return 0, nil
} }
func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w *bytes.Buffer, func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w io.Writer,
jt map[string]interface{}, ti *DBTableInfo) (uint32, error) { jt map[string]interface{}, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0]
i := 0 i := 0
for _, cn := range ti.ColumnNames { for _, cn := range ti.ColumnNames {
if _, ok := jt[cn]; !ok { if _, ok := jt[cn]; !ok {
continue continue
} }
if len(root.Allowed) != 0 {
if _, ok := root.Allowed[cn]; !ok {
continue
}
}
if i != 0 { if i != 0 {
io.WriteString(c.w, `, `) io.WriteString(c.w, `, `)
} }
c.w.WriteString(cn) io.WriteString(c.w, cn)
i++ i++
} }
return 0, nil return 0, nil
} }
func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer, func (c *compilerContext) renderUpdate(qc *qcode.QCode, w io.Writer,
vars Variables, ti *DBTableInfo) (uint32, error) { vars Variables, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0] root := &qc.Selects[0]
update, ok := vars[root.ActionVar] update, ok := vars[qc.ActionVar]
if !ok { if !ok {
return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar) return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar)
} }
jt, array, err := jsn.Tree(update) jt, array, err := jsn.Tree(update)
@ -136,26 +140,26 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer,
return 0, err return 0, err
} }
c.w.WriteString(`(WITH "input" AS (SELECT {{`) io.WriteString(c.w, `(WITH "input" AS (SELECT {{`)
c.w.WriteString(root.ActionVar) io.WriteString(c.w, qc.ActionVar)
c.w.WriteString(`}}::json AS j) UPDATE `) io.WriteString(c.w, `}}::json AS j) UPDATE `)
quoted(c.w, ti.Name) quoted(c.w, ti.Name)
io.WriteString(c.w, ` SET (`) io.WriteString(c.w, ` SET (`)
c.renderInsertUpdateColumns(qc, w, jt, ti) c.renderInsertUpdateColumns(qc, w, jt, ti)
c.w.WriteString(`) = (SELECT `) io.WriteString(c.w, `) = (SELECT `)
c.renderInsertUpdateColumns(qc, w, jt, ti) c.renderInsertUpdateColumns(qc, w, jt, ti)
c.w.WriteString(` FROM input i, `) io.WriteString(c.w, ` FROM input i, `)
if array { if array {
c.w.WriteString(`json_populate_recordset`) io.WriteString(c.w, `json_populate_recordset`)
} else { } else {
c.w.WriteString(`json_populate_record`) io.WriteString(c.w, `json_populate_record`)
} }
c.w.WriteString(`(NULL::`) io.WriteString(c.w, `(NULL::`)
c.w.WriteString(ti.Name) io.WriteString(c.w, ti.Name)
c.w.WriteString(`, i.j) t)`) io.WriteString(c.w, `, i.j) t)`)
io.WriteString(c.w, ` WHERE `) io.WriteString(c.w, ` WHERE `)
@ -166,11 +170,11 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer,
return 0, nil return 0, nil
} }
func (c *compilerContext) renderDelete(qc *qcode.QCode, w *bytes.Buffer, func (c *compilerContext) renderDelete(qc *qcode.QCode, w io.Writer,
vars Variables, ti *DBTableInfo) (uint32, error) { vars Variables, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0] root := &qc.Selects[0]
c.w.WriteString(`(DELETE FROM `) io.WriteString(c.w, `(DELETE FROM `)
quoted(c.w, ti.Name) quoted(c.w, ti.Name)
io.WriteString(c.w, ` WHERE `) io.WriteString(c.w, ` WHERE `)
@ -181,13 +185,12 @@ func (c *compilerContext) renderDelete(qc *qcode.QCode, w *bytes.Buffer,
return 0, nil return 0, nil
} }
func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer, func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer,
vars Variables, ti *DBTableInfo) (uint32, error) { vars Variables, ti *DBTableInfo) (uint32, error) {
root := &qc.Selects[0]
upsert, ok := vars[root.ActionVar] upsert, ok := vars[qc.ActionVar]
if !ok { if !ok {
return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar) return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar)
} }
jt, _, err := jsn.Tree(upsert) jt, _, err := jsn.Tree(upsert)
@ -199,7 +202,7 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer,
return 0, err return 0, err
} }
c.w.WriteString(` ON CONFLICT DO (`) io.WriteString(c.w, ` ON CONFLICT DO (`)
i := 0 i := 0
for _, cn := range ti.ColumnNames { for _, cn := range ti.ColumnNames {
@ -214,15 +217,15 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer,
if i != 0 { if i != 0 {
io.WriteString(c.w, `, `) io.WriteString(c.w, `, `)
} }
c.w.WriteString(cn) io.WriteString(c.w, cn)
i++ i++
} }
if i == 0 { if i == 0 {
c.w.WriteString(ti.PrimaryCol) io.WriteString(c.w, ti.PrimaryCol)
} }
c.w.WriteString(`) DO `) io.WriteString(c.w, `) DO `)
c.w.WriteString(`UPDATE `) io.WriteString(c.w, `UPDATE `)
io.WriteString(c.w, ` SET `) io.WriteString(c.w, ` SET `)
i = 0 i = 0
@ -233,17 +236,17 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer,
if i != 0 { if i != 0 {
io.WriteString(c.w, `, `) io.WriteString(c.w, `, `)
} }
c.w.WriteString(cn) io.WriteString(c.w, cn)
io.WriteString(c.w, ` = EXCLUDED.`) io.WriteString(c.w, ` = EXCLUDED.`)
c.w.WriteString(cn) io.WriteString(c.w, cn)
i++ i++
} }
return 0, nil return 0, nil
} }
func quoted(w *bytes.Buffer, identifier string) { func quoted(w io.Writer, identifier string) {
w.WriteString(`"`) io.WriteString(w, `"`)
w.WriteString(identifier) io.WriteString(w, identifier)
w.WriteString(`"`) io.WriteString(w, `"`)
} }

View File

@ -12,13 +12,13 @@ func simpleInsert(t *testing.T) {
} }
}` }`
sql := `WITH "users" AS (WITH "input" AS (SELECT {{data}}::json AS j) INSERT INTO "users" (full_name, email) SELECT full_name, email FROM input i, json_populate_record(NULL::users, i.j) t RETURNING *) SELECT json_object_agg('user', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."id" AS "id") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."id" FROM "users") AS "users_0") AS "done_1337";` sql := `WITH "users" AS (WITH "input" AS (SELECT {{data}}::json AS j) INSERT INTO "users" (full_name, email) SELECT full_name, email FROM input i, json_populate_record(NULL::users, i.j) t RETURNING *) SELECT json_object_agg('user', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."id" AS "id") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."id" FROM "users") AS "users_0") AS "done_1337"`
vars := map[string]json.RawMessage{ vars := map[string]json.RawMessage{
"data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`), "data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`),
} }
resSQL, err := compileGQLToPSQL(gql, vars) resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -36,13 +36,13 @@ func singleInsert(t *testing.T) {
} }
}` }`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{insert}}::json AS j) INSERT INTO "products" (name, description, user_id) SELECT name, description, user_id FROM input i, json_populate_record(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";` sql := `WITH "products" AS (WITH "input" AS (SELECT {{insert}}::json AS j) INSERT INTO "products" (name, description, user_id) SELECT name, description, user_id FROM input i, json_populate_record(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
vars := map[string]json.RawMessage{ vars := map[string]json.RawMessage{
"insert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc", "user_id": 5 }`), "insert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc", "user_id": 5 }`),
} }
resSQL, err := compileGQLToPSQL(gql, vars) resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -54,19 +54,19 @@ func singleInsert(t *testing.T) {
func bulkInsert(t *testing.T) { func bulkInsert(t *testing.T) {
gql := `mutation { gql := `mutation {
product(id: 15, insert: $insert) { product(name: "test", id: 15, insert: $insert) {
id id
name name
} }
}` }`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{insert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_recordset(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";` sql := `WITH "products" AS (WITH "input" AS (SELECT {{insert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_recordset(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
vars := map[string]json.RawMessage{ vars := map[string]json.RawMessage{
"insert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`), "insert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
} }
resSQL, err := compileGQLToPSQL(gql, vars) resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -84,13 +84,13 @@ func singleUpsert(t *testing.T) {
} }
}` }`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";` sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
vars := map[string]json.RawMessage{ vars := map[string]json.RawMessage{
"upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), "upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
} }
resSQL, err := compileGQLToPSQL(gql, vars) resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -108,13 +108,13 @@ func bulkUpsert(t *testing.T) {
} }
}` }`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_recordset(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";` sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_recordset(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
vars := map[string]json.RawMessage{ vars := map[string]json.RawMessage{
"upsert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`), "upsert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
} }
resSQL, err := compileGQLToPSQL(gql, vars) resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -132,13 +132,13 @@ func singleUpdate(t *testing.T) {
} }
}` }`
sql := `WITH "products" AS (WITH "input" AS (SELECT {{update}}::json AS j) UPDATE "products" SET (name, description) = (SELECT name, description FROM input i, json_populate_record(NULL::products, i.j) t) WHERE (("products"."user_id") = {{user_id}}) AND (("products"."id") = 1) AND (("products"."id") = 15) RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";` sql := `WITH "products" AS (WITH "input" AS (SELECT {{update}}::json AS j) UPDATE "products" SET (name, description) = (SELECT name, description FROM input i, json_populate_record(NULL::products, i.j) t) WHERE (("products"."user_id") = {{user_id}}) AND (("products"."id") = 1) AND (("products"."id") = 15) RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
vars := map[string]json.RawMessage{ vars := map[string]json.RawMessage{
"update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), "update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
} }
resSQL, err := compileGQLToPSQL(gql, vars) resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -156,13 +156,13 @@ func delete(t *testing.T) {
} }
}` }`
sql := `WITH "products" AS (DELETE FROM "products" WHERE (("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 1) RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";` sql := `WITH "products" AS (DELETE FROM "products" WHERE (("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 1) RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"`
vars := map[string]json.RawMessage{ vars := map[string]json.RawMessage{
"update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), "update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
} }
resSQL, err := compileGQLToPSQL(gql, vars) resSQL, err := compileGQLToPSQL(gql, vars, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -172,7 +172,7 @@ func delete(t *testing.T) {
} }
} }
func TestCompileInsert(t *testing.T) { func TestCompileMutate(t *testing.T) {
t.Run("simpleInsert", simpleInsert) t.Run("simpleInsert", simpleInsert)
t.Run("singleInsert", singleInsert) t.Run("singleInsert", singleInsert)
t.Run("bulkInsert", bulkInsert) t.Run("bulkInsert", bulkInsert)

View File

@ -49,7 +49,7 @@ func (c *Compiler) IDColumn(table string) (string, error) {
} }
type compilerContext struct { type compilerContext struct {
w *bytes.Buffer w io.Writer
s []qcode.Select s []qcode.Select
*Compiler *Compiler
} }
@ -60,18 +60,18 @@ func (co *Compiler) CompileEx(qc *qcode.QCode, vars Variables) (uint32, []byte,
return skipped, w.Bytes(), err return skipped, w.Bytes(), err
} }
func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) { func (co *Compiler) Compile(qc *qcode.QCode, w io.Writer, vars Variables) (uint32, error) {
switch qc.Type { switch qc.Type {
case qcode.QTQuery: case qcode.QTQuery:
return co.compileQuery(qc, w) return co.compileQuery(qc, w)
case qcode.QTMutation: case qcode.QTInsert, qcode.QTUpdate, qcode.QTDelete, qcode.QTUpsert:
return co.compileMutation(qc, w, vars) return co.compileMutation(qc, w, vars)
} }
return 0, errors.New("unknown operation") return 0, fmt.Errorf("Unknown operation type %d", qc.Type)
} }
func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) { func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) {
if len(qc.Selects) == 0 { if len(qc.Selects) == 0 {
return 0, errors.New("empty query") return 0, errors.New("empty query")
} }
@ -90,17 +90,17 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, erro
//fmt.Fprintf(w, `SELECT json_object_agg('%s', %s) FROM (`, //fmt.Fprintf(w, `SELECT json_object_agg('%s', %s) FROM (`,
//root.FieldName, root.Table) //root.FieldName, root.Table)
c.w.WriteString(`SELECT json_object_agg('`) io.WriteString(c.w, `SELECT json_object_agg('`)
c.w.WriteString(root.FieldName) io.WriteString(c.w, root.FieldName)
c.w.WriteString(`', `) io.WriteString(c.w, `', `)
if ti.Singular == false { if ti.Singular == false {
c.w.WriteString(root.Table) io.WriteString(c.w, root.Table)
} else { } else {
c.w.WriteString("sel_json_") io.WriteString(c.w, "sel_json_")
int2string(c.w, root.ID) int2string(c.w, root.ID)
} }
c.w.WriteString(`) FROM (`) io.WriteString(c.w, `) FROM (`)
var ignored uint32 var ignored uint32
@ -161,9 +161,8 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, erro
} }
} }
c.w.WriteString(`)`) io.WriteString(c.w, `)`)
alias(c.w, `done_1337`) alias(c.w, `done_1337`)
c.w.WriteString(`;`)
return ignored, nil return ignored, nil
} }
@ -219,10 +218,10 @@ func (c *compilerContext) renderSelect(sel *qcode.Select, ti *DBTableInfo) (uint
// SELECT // SELECT
if ti.Singular == false { if ti.Singular == false {
//fmt.Fprintf(w, `SELECT coalesce(json_agg("%s"`, c.sel.Table) //fmt.Fprintf(w, `SELECT coalesce(json_agg("%s"`, c.sel.Table)
c.w.WriteString(`SELECT coalesce(json_agg("`) io.WriteString(c.w, `SELECT coalesce(json_agg("`)
c.w.WriteString("sel_json_") io.WriteString(c.w, "sel_json_")
int2string(c.w, sel.ID) int2string(c.w, sel.ID)
c.w.WriteString(`"`) io.WriteString(c.w, `"`)
if hasOrder { if hasOrder {
err := c.renderOrderBy(sel, ti) err := c.renderOrderBy(sel, ti)
@ -232,24 +231,24 @@ func (c *compilerContext) renderSelect(sel *qcode.Select, ti *DBTableInfo) (uint
} }
//fmt.Fprintf(w, `), '[]') AS "%s" FROM (`, c.sel.Table) //fmt.Fprintf(w, `), '[]') AS "%s" FROM (`, c.sel.Table)
c.w.WriteString(`), '[]')`) io.WriteString(c.w, `), '[]')`)
alias(c.w, sel.Table) alias(c.w, sel.Table)
c.w.WriteString(` FROM (`) io.WriteString(c.w, ` FROM (`)
} }
// ROW-TO-JSON // ROW-TO-JSON
c.w.WriteString(`SELECT `) io.WriteString(c.w, `SELECT `)
if len(sel.DistinctOn) != 0 { if len(sel.DistinctOn) != 0 {
c.renderDistinctOn(sel, ti) c.renderDistinctOn(sel, ti)
} }
c.w.WriteString(`row_to_json((`) io.WriteString(c.w, `row_to_json((`)
//fmt.Fprintf(w, `SELECT "sel_%d" FROM (SELECT `, c.sel.ID) //fmt.Fprintf(w, `SELECT "sel_%d" FROM (SELECT `, c.sel.ID)
c.w.WriteString(`SELECT "sel_`) io.WriteString(c.w, `SELECT "sel_`)
int2string(c.w, sel.ID) int2string(c.w, sel.ID)
c.w.WriteString(`" FROM (SELECT `) io.WriteString(c.w, `" FROM (SELECT `)
// Combined column names // Combined column names
c.renderColumns(sel, ti) c.renderColumns(sel, ti)
@ -262,11 +261,11 @@ func (c *compilerContext) renderSelect(sel *qcode.Select, ti *DBTableInfo) (uint
} }
//fmt.Fprintf(w, `) AS "sel_%d"`, c.sel.ID) //fmt.Fprintf(w, `) AS "sel_%d"`, c.sel.ID)
c.w.WriteString(`)`) io.WriteString(c.w, `)`)
aliasWithID(c.w, "sel", sel.ID) aliasWithID(c.w, "sel", sel.ID)
//fmt.Fprintf(w, `)) AS "%s"`, c.sel.Table) //fmt.Fprintf(w, `)) AS "%s"`, c.sel.Table)
c.w.WriteString(`))`) io.WriteString(c.w, `))`)
aliasWithID(c.w, "sel_json", sel.ID) aliasWithID(c.w, "sel_json", sel.ID)
// END-ROW-TO-JSON // END-ROW-TO-JSON
@ -295,31 +294,33 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo)
} }
} }
if sel.Action == 0 { switch {
if len(sel.Paging.Limit) != 0 { case sel.Paging.NoLimit:
break
case len(sel.Paging.Limit) != 0:
//fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit) //fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit)
c.w.WriteString(` LIMIT ('`) io.WriteString(c.w, ` LIMIT ('`)
c.w.WriteString(sel.Paging.Limit) io.WriteString(c.w, sel.Paging.Limit)
c.w.WriteString(`') :: integer`) io.WriteString(c.w, `') :: integer`)
} else if ti.Singular { case ti.Singular:
c.w.WriteString(` LIMIT ('1') :: integer`) io.WriteString(c.w, ` LIMIT ('1') :: integer`)
} else { default:
c.w.WriteString(` LIMIT ('20') :: integer`) io.WriteString(c.w, ` LIMIT ('20') :: integer`)
}
} }
if len(sel.Paging.Offset) != 0 { if len(sel.Paging.Offset) != 0 {
//fmt.Fprintf(w, ` OFFSET ('%s') :: integer`, c.sel.Paging.Offset) //fmt.Fprintf(w, ` OFFSET ('%s') :: integer`, c.sel.Paging.Offset)
c.w.WriteString(`OFFSET ('`) io.WriteString(c.w, `OFFSET ('`)
c.w.WriteString(sel.Paging.Offset) io.WriteString(c.w, sel.Paging.Offset)
c.w.WriteString(`') :: integer`) io.WriteString(c.w, `') :: integer`)
} }
if ti.Singular == false { if ti.Singular == false {
//fmt.Fprintf(w, `) AS "sel_json_agg_%d"`, c.sel.ID) //fmt.Fprintf(w, `) AS "sel_json_agg_%d"`, c.sel.ID)
c.w.WriteString(`)`) io.WriteString(c.w, `)`)
aliasWithID(c.w, "sel_json_agg", sel.ID) aliasWithID(c.w, "sel_json_agg", sel.ID)
} }
@ -327,15 +328,15 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo)
} }
func (c *compilerContext) renderJoin(sel *qcode.Select) error { func (c *compilerContext) renderJoin(sel *qcode.Select) error {
c.w.WriteString(` LEFT OUTER JOIN LATERAL (`) io.WriteString(c.w, ` LEFT OUTER JOIN LATERAL (`)
return nil return nil
} }
func (c *compilerContext) renderJoinClose(sel *qcode.Select) error { func (c *compilerContext) renderJoinClose(sel *qcode.Select) error {
//fmt.Fprintf(w, `) AS "%s_%d_join" ON ('true')`, c.sel.Table, c.sel.ID) //fmt.Fprintf(w, `) AS "%s_%d_join" ON ('true')`, c.sel.Table, c.sel.ID)
c.w.WriteString(`)`) io.WriteString(c.w, `)`)
aliasWithIDSuffix(c.w, sel.Table, sel.ID, "_join") aliasWithIDSuffix(c.w, sel.Table, sel.ID, "_join")
c.w.WriteString(` ON ('true')`) io.WriteString(c.w, ` ON ('true')`)
return nil return nil
} }
@ -358,25 +359,43 @@ func (c *compilerContext) renderJoinTable(sel *qcode.Select) error {
//fmt.Fprintf(w, ` LEFT OUTER JOIN "%s" ON (("%s"."%s") = ("%s_%d"."%s"))`, //fmt.Fprintf(w, ` LEFT OUTER JOIN "%s" ON (("%s"."%s") = ("%s_%d"."%s"))`,
//rel.Through, rel.Through, rel.ColT, c.parent.Table, c.parent.ID, rel.Col1) //rel.Through, rel.Through, rel.ColT, c.parent.Table, c.parent.ID, rel.Col1)
c.w.WriteString(` LEFT OUTER JOIN "`) io.WriteString(c.w, ` LEFT OUTER JOIN "`)
c.w.WriteString(rel.Through) io.WriteString(c.w, rel.Through)
c.w.WriteString(`" ON ((`) io.WriteString(c.w, `" ON ((`)
colWithTable(c.w, rel.Through, rel.ColT) colWithTable(c.w, rel.Through, rel.ColT)
c.w.WriteString(`) = (`) io.WriteString(c.w, `) = (`)
colWithTableID(c.w, pt.Name, parent.ID, rel.Col1) colWithTableID(c.w, pt.Name, parent.ID, rel.Col1)
c.w.WriteString(`))`) io.WriteString(c.w, `))`)
return nil return nil
} }
func (c *compilerContext) renderColumns(sel *qcode.Select, ti *DBTableInfo) { func (c *compilerContext) renderColumns(sel *qcode.Select, ti *DBTableInfo) {
for i, col := range sel.Cols { i := 0
for _, col := range sel.Cols {
if len(sel.Allowed) != 0 {
n := funcPrefixLen(col.Name)
if n != 0 {
if sel.Functions == false {
continue
}
if _, ok := sel.Allowed[col.Name[n:]]; !ok {
continue
}
} else {
if _, ok := sel.Allowed[col.Name]; !ok {
continue
}
}
}
if i != 0 { if i != 0 {
io.WriteString(c.w, ", ") io.WriteString(c.w, ", ")
} }
//fmt.Fprintf(w, `"%s_%d"."%s" AS "%s"`, //fmt.Fprintf(w, `"%s_%d"."%s" AS "%s"`,
//c.sel.Table, c.sel.ID, col.Name, col.FieldName) //c.sel.Table, c.sel.ID, col.Name, col.FieldName)
colWithTableIDAlias(c.w, ti.Name, sel.ID, col.Name, col.FieldName) colWithTableIDAlias(c.w, ti.Name, sel.ID, col.Name, col.FieldName)
i++
} }
} }
@ -415,11 +434,25 @@ func (c *compilerContext) renderJoinedColumns(sel *qcode.Select, ti *DBTableInfo
} }
childSel := &c.s[id] childSel := &c.s[id]
cti, err := c.schema.GetTable(childSel.Table)
if err != nil {
continue
}
//fmt.Fprintf(w, `"%s_%d_join"."%s" AS "%s"`, //fmt.Fprintf(w, `"%s_%d_join"."%s" AS "%s"`,
//s.Table, s.ID, s.Table, s.FieldName) //s.Table, s.ID, s.Table, s.FieldName)
if cti.Singular {
io.WriteString(c.w, `"sel_json_`)
int2string(c.w, childSel.ID)
io.WriteString(c.w, `" AS "`)
io.WriteString(c.w, childSel.FieldName)
io.WriteString(c.w, `"`)
} else {
colWithTableIDSuffixAlias(c.w, childSel.Table, childSel.ID, colWithTableIDSuffixAlias(c.w, childSel.Table, childSel.ID,
"_join", childSel.Table, childSel.FieldName) "_join", childSel.Table, childSel.FieldName)
} }
}
return nil return nil
} }
@ -433,9 +466,10 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
isSearch := sel.Args["search"] != nil isSearch := sel.Args["search"] != nil
isAgg := false isAgg := false
c.w.WriteString(` FROM (SELECT `) io.WriteString(c.w, ` FROM (SELECT `)
for i, col := range sel.Cols { i := 0
for n, col := range sel.Cols {
cn := col.Name cn := col.Name
_, isRealCol := ti.Columns[cn] _, isRealCol := ti.Columns[cn]
@ -447,93 +481,116 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
cn = ti.TSVCol cn = ti.TSVCol
arg := sel.Args["search"] arg := sel.Args["search"]
if i != 0 {
io.WriteString(c.w, `, `)
}
//fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`, //fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`,
//c.sel.Table, cn, arg.Val, col.Name) //c.sel.Table, cn, arg.Val, col.Name)
c.w.WriteString(`ts_rank(`) io.WriteString(c.w, `ts_rank(`)
colWithTable(c.w, ti.Name, cn) colWithTable(c.w, ti.Name, cn)
c.w.WriteString(`, to_tsquery('`) io.WriteString(c.w, `, to_tsquery('`)
c.w.WriteString(arg.Val) io.WriteString(c.w, arg.Val)
c.w.WriteString(`')`) io.WriteString(c.w, `')`)
alias(c.w, col.Name) alias(c.w, col.Name)
i++
case strings.HasPrefix(cn, "search_headline_"): case strings.HasPrefix(cn, "search_headline_"):
cn = cn[16:] cn = cn[16:]
arg := sel.Args["search"] arg := sel.Args["search"]
if i != 0 {
io.WriteString(c.w, `, `)
}
//fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`, //fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`,
//c.sel.Table, cn, arg.Val, col.Name) //c.sel.Table, cn, arg.Val, col.Name)
c.w.WriteString(`ts_headlinek(`) io.WriteString(c.w, `ts_headlinek(`)
colWithTable(c.w, ti.Name, cn) colWithTable(c.w, ti.Name, cn)
c.w.WriteString(`, to_tsquery('`) io.WriteString(c.w, `, to_tsquery('`)
c.w.WriteString(arg.Val) io.WriteString(c.w, arg.Val)
c.w.WriteString(`')`) io.WriteString(c.w, `')`)
alias(c.w, col.Name) alias(c.w, col.Name)
i++
} }
} else { } else {
pl := funcPrefixLen(cn) pl := funcPrefixLen(cn)
if pl == 0 { if pl == 0 {
//fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name)
c.w.WriteString(`'`)
c.w.WriteString(cn)
c.w.WriteString(` not defined'`)
alias(c.w, col.Name)
} else {
isAgg = true
fn := cn[0 : pl-1]
cn := cn[pl:]
//fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, c.sel.Table, cn, col.Name)
c.w.WriteString(fn)
c.w.WriteString(`(`)
colWithTable(c.w, ti.Name, cn)
c.w.WriteString(`)`)
alias(c.w, col.Name)
}
}
} else {
groupBy = append(groupBy, i)
//fmt.Fprintf(w, `"%s"."%s"`, c.sel.Table, cn)
colWithTable(c.w, ti.Name, cn)
}
if i < len(sel.Cols)-1 || len(childCols) != 0 {
//io.WriteString(w, ", ")
c.w.WriteString(`, `)
}
}
for i, col := range childCols {
if i != 0 { if i != 0 {
//io.WriteString(w, ", ") io.WriteString(c.w, `, `)
c.w.WriteString(`, `) }
//fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name)
io.WriteString(c.w, `'`)
io.WriteString(c.w, cn)
io.WriteString(c.w, ` not defined'`)
alias(c.w, col.Name)
i++
} else if sel.Functions {
cn1 := cn[pl:]
if _, ok := sel.Allowed[cn1]; !ok {
continue
}
if i != 0 {
io.WriteString(c.w, `, `)
}
fn := cn[0 : pl-1]
isAgg = true
//fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, c.sel.Table, cn, col.Name)
io.WriteString(c.w, fn)
io.WriteString(c.w, `(`)
colWithTable(c.w, ti.Name, cn1)
io.WriteString(c.w, `)`)
alias(c.w, col.Name)
i++
}
}
} else {
groupBy = append(groupBy, n)
//fmt.Fprintf(w, `"%s"."%s"`, c.sel.Table, cn)
if i != 0 {
io.WriteString(c.w, `, `)
}
colWithTable(c.w, ti.Name, cn)
i++
}
}
for _, col := range childCols {
if i != 0 {
io.WriteString(c.w, `, `)
} }
//fmt.Fprintf(w, `"%s"."%s"`, col.Table, col.Name) //fmt.Fprintf(w, `"%s"."%s"`, col.Table, col.Name)
colWithTable(c.w, col.Table, col.Name) colWithTable(c.w, col.Table, col.Name)
i++
} }
c.w.WriteString(` FROM `) io.WriteString(c.w, ` FROM `)
//fmt.Fprintf(w, ` FROM "%s"`, c.sel.Table) //fmt.Fprintf(w, ` FROM "%s"`, c.sel.Table)
c.w.WriteString(`"`) io.WriteString(c.w, `"`)
c.w.WriteString(ti.Name) io.WriteString(c.w, ti.Name)
c.w.WriteString(`"`) io.WriteString(c.w, `"`)
// if tn, ok := c.tmap[sel.Table]; ok { // if tn, ok := c.tmap[sel.Table]; ok {
// //fmt.Fprintf(w, ` FROM "%s" AS "%s"`, tn, c.sel.Table) // //fmt.Fprintf(w, ` FROM "%s" AS "%s"`, tn, c.sel.Table)
// tableWithAlias(c.w, ti.Name, sel.Table) // tableWithAlias(c.w, ti.Name, sel.Table)
// } else { // } else {
// //fmt.Fprintf(w, ` FROM "%s"`, c.sel.Table) // //fmt.Fprintf(w, ` FROM "%s"`, c.sel.Table)
// c.w.WriteString(`"`) // io.WriteString(c.w, `"`)
// c.w.WriteString(sel.Table) // io.WriteString(c.w, sel.Table)
// c.w.WriteString(`"`) // io.WriteString(c.w, `"`)
// } // }
if isRoot && isFil { if isRoot && isFil {
c.w.WriteString(` WHERE (`) io.WriteString(c.w, ` WHERE (`)
if err := c.renderWhere(sel, ti); err != nil { if err := c.renderWhere(sel, ti); err != nil {
return err return err
} }
c.w.WriteString(`)`) io.WriteString(c.w, `)`)
} }
if !isRoot { if !isRoot {
@ -541,28 +598,28 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
return err return err
} }
c.w.WriteString(` WHERE (`) io.WriteString(c.w, ` WHERE (`)
if err := c.renderRelationship(sel, ti); err != nil { if err := c.renderRelationship(sel, ti); err != nil {
return err return err
} }
if isFil { if isFil {
c.w.WriteString(` AND `) io.WriteString(c.w, ` AND `)
if err := c.renderWhere(sel, ti); err != nil { if err := c.renderWhere(sel, ti); err != nil {
return err return err
} }
} }
c.w.WriteString(`)`) io.WriteString(c.w, `)`)
} }
if isAgg { if isAgg {
if len(groupBy) != 0 { if len(groupBy) != 0 {
c.w.WriteString(` GROUP BY `) io.WriteString(c.w, ` GROUP BY `)
for i, id := range groupBy { for i, id := range groupBy {
if i != 0 { if i != 0 {
c.w.WriteString(`, `) io.WriteString(c.w, `, `)
} }
//fmt.Fprintf(w, `"%s"."%s"`, c.sel.Table, c.sel.Cols[id].Name) //fmt.Fprintf(w, `"%s"."%s"`, c.sel.Table, c.sel.Cols[id].Name)
colWithTable(c.w, ti.Name, sel.Cols[id].Name) colWithTable(c.w, ti.Name, sel.Cols[id].Name)
@ -570,30 +627,32 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
} }
} }
if sel.Action == 0 { switch {
if len(sel.Paging.Limit) != 0 { case sel.Paging.NoLimit:
break
case len(sel.Paging.Limit) != 0:
//fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit) //fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit)
c.w.WriteString(` LIMIT ('`) io.WriteString(c.w, ` LIMIT ('`)
c.w.WriteString(sel.Paging.Limit) io.WriteString(c.w, sel.Paging.Limit)
c.w.WriteString(`') :: integer`) io.WriteString(c.w, `') :: integer`)
} else if ti.Singular { case ti.Singular:
c.w.WriteString(` LIMIT ('1') :: integer`) io.WriteString(c.w, ` LIMIT ('1') :: integer`)
} else { default:
c.w.WriteString(` LIMIT ('20') :: integer`) io.WriteString(c.w, ` LIMIT ('20') :: integer`)
}
} }
if len(sel.Paging.Offset) != 0 { if len(sel.Paging.Offset) != 0 {
//fmt.Fprintf(w, ` OFFSET ('%s') :: integer`, c.sel.Paging.Offset) //fmt.Fprintf(w, ` OFFSET ('%s') :: integer`, c.sel.Paging.Offset)
c.w.WriteString(` OFFSET ('`) io.WriteString(c.w, ` OFFSET ('`)
c.w.WriteString(sel.Paging.Offset) io.WriteString(c.w, sel.Paging.Offset)
c.w.WriteString(`') :: integer`) io.WriteString(c.w, `') :: integer`)
} }
//fmt.Fprintf(w, `) AS "%s_%d"`, c.sel.Table, c.sel.ID) //fmt.Fprintf(w, `) AS "%s_%d"`, c.sel.Table, c.sel.ID)
c.w.WriteString(`)`) io.WriteString(c.w, `)`)
aliasWithID(c.w, ti.Name, sel.ID) aliasWithID(c.w, ti.Name, sel.ID)
return nil return nil
} }
@ -604,7 +663,7 @@ func (c *compilerContext) renderOrderByColumns(sel *qcode.Select, ti *DBTableInf
for i := range sel.OrderBy { for i := range sel.OrderBy {
if colsRendered { if colsRendered {
//io.WriteString(w, ", ") //io.WriteString(w, ", ")
c.w.WriteString(`, `) io.WriteString(c.w, `, `)
} }
col := sel.OrderBy[i].Col col := sel.OrderBy[i].Col
@ -612,7 +671,7 @@ func (c *compilerContext) renderOrderByColumns(sel *qcode.Select, ti *DBTableInf
//c.sel.Table, c.sel.ID, c, //c.sel.Table, c.sel.ID, c,
//c.sel.Table, c.sel.ID, c) //c.sel.Table, c.sel.ID, c)
colWithTableID(c.w, ti.Name, sel.ID, col) colWithTableID(c.w, ti.Name, sel.ID, col)
c.w.WriteString(` AS `) io.WriteString(c.w, ` AS `)
tableIDColSuffix(c.w, sel.Table, sel.ID, col, "_ob") tableIDColSuffix(c.w, sel.Table, sel.ID, col, "_ob")
} }
} }
@ -629,29 +688,29 @@ func (c *compilerContext) renderRelationship(sel *qcode.Select, ti *DBTableInfo)
case RelBelongTo: case RelBelongTo:
//fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`, //fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`,
//c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2) //c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2)
c.w.WriteString(`((`) io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, rel.Col1) colWithTable(c.w, ti.Name, rel.Col1)
c.w.WriteString(`) = (`) io.WriteString(c.w, `) = (`)
colWithTableID(c.w, parent.Table, parent.ID, rel.Col2) colWithTableID(c.w, parent.Table, parent.ID, rel.Col2)
c.w.WriteString(`))`) io.WriteString(c.w, `))`)
case RelOneToMany: case RelOneToMany:
//fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`, //fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`,
//c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2) //c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2)
c.w.WriteString(`((`) io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, rel.Col1) colWithTable(c.w, ti.Name, rel.Col1)
c.w.WriteString(`) = (`) io.WriteString(c.w, `) = (`)
colWithTableID(c.w, parent.Table, parent.ID, rel.Col2) colWithTableID(c.w, parent.Table, parent.ID, rel.Col2)
c.w.WriteString(`))`) io.WriteString(c.w, `))`)
case RelOneToManyThrough: case RelOneToManyThrough:
//fmt.Fprintf(w, `(("%s"."%s") = ("%s"."%s"))`, //fmt.Fprintf(w, `(("%s"."%s") = ("%s"."%s"))`,
//c.sel.Table, rel.Col1, rel.Through, rel.Col2) //c.sel.Table, rel.Col1, rel.Through, rel.Col2)
c.w.WriteString(`((`) io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, rel.Col1) colWithTable(c.w, ti.Name, rel.Col1)
c.w.WriteString(`) = (`) io.WriteString(c.w, `) = (`)
colWithTable(c.w, rel.Through, rel.Col2) colWithTable(c.w, rel.Through, rel.Col2)
c.w.WriteString(`))`) io.WriteString(c.w, `))`)
} }
return nil return nil
@ -675,11 +734,11 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
case qcode.ExpOp: case qcode.ExpOp:
switch val { switch val {
case qcode.OpAnd: case qcode.OpAnd:
c.w.WriteString(` AND `) io.WriteString(c.w, ` AND `)
case qcode.OpOr: case qcode.OpOr:
c.w.WriteString(` OR `) io.WriteString(c.w, ` OR `)
case qcode.OpNot: case qcode.OpNot:
c.w.WriteString(`NOT `) io.WriteString(c.w, `NOT `)
default: default:
return fmt.Errorf("11: unexpected value %v (%t)", intf, intf) return fmt.Errorf("11: unexpected value %v (%t)", intf, intf)
} }
@ -703,62 +762,62 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
default: default:
if val.NestedCol { if val.NestedCol {
//fmt.Fprintf(w, `(("%s") `, val.Col) //fmt.Fprintf(w, `(("%s") `, val.Col)
c.w.WriteString(`(("`) io.WriteString(c.w, `(("`)
c.w.WriteString(val.Col) io.WriteString(c.w, val.Col)
c.w.WriteString(`") `) io.WriteString(c.w, `") `)
} else if len(val.Col) != 0 { } else if len(val.Col) != 0 {
//fmt.Fprintf(w, `(("%s"."%s") `, c.sel.Table, val.Col) //fmt.Fprintf(w, `(("%s"."%s") `, c.sel.Table, val.Col)
c.w.WriteString(`((`) io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, val.Col) colWithTable(c.w, ti.Name, val.Col)
c.w.WriteString(`) `) io.WriteString(c.w, `) `)
} }
valExists := true valExists := true
switch val.Op { switch val.Op {
case qcode.OpEquals: case qcode.OpEquals:
c.w.WriteString(`=`) io.WriteString(c.w, `=`)
case qcode.OpNotEquals: case qcode.OpNotEquals:
c.w.WriteString(`!=`) io.WriteString(c.w, `!=`)
case qcode.OpGreaterOrEquals: case qcode.OpGreaterOrEquals:
c.w.WriteString(`>=`) io.WriteString(c.w, `>=`)
case qcode.OpLesserOrEquals: case qcode.OpLesserOrEquals:
c.w.WriteString(`<=`) io.WriteString(c.w, `<=`)
case qcode.OpGreaterThan: case qcode.OpGreaterThan:
c.w.WriteString(`>`) io.WriteString(c.w, `>`)
case qcode.OpLesserThan: case qcode.OpLesserThan:
c.w.WriteString(`<`) io.WriteString(c.w, `<`)
case qcode.OpIn: case qcode.OpIn:
c.w.WriteString(`IN`) io.WriteString(c.w, `IN`)
case qcode.OpNotIn: case qcode.OpNotIn:
c.w.WriteString(`NOT IN`) io.WriteString(c.w, `NOT IN`)
case qcode.OpLike: case qcode.OpLike:
c.w.WriteString(`LIKE`) io.WriteString(c.w, `LIKE`)
case qcode.OpNotLike: case qcode.OpNotLike:
c.w.WriteString(`NOT LIKE`) io.WriteString(c.w, `NOT LIKE`)
case qcode.OpILike: case qcode.OpILike:
c.w.WriteString(`ILIKE`) io.WriteString(c.w, `ILIKE`)
case qcode.OpNotILike: case qcode.OpNotILike:
c.w.WriteString(`NOT ILIKE`) io.WriteString(c.w, `NOT ILIKE`)
case qcode.OpSimilar: case qcode.OpSimilar:
c.w.WriteString(`SIMILAR TO`) io.WriteString(c.w, `SIMILAR TO`)
case qcode.OpNotSimilar: case qcode.OpNotSimilar:
c.w.WriteString(`NOT SIMILAR TO`) io.WriteString(c.w, `NOT SIMILAR TO`)
case qcode.OpContains: case qcode.OpContains:
c.w.WriteString(`@>`) io.WriteString(c.w, `@>`)
case qcode.OpContainedIn: case qcode.OpContainedIn:
c.w.WriteString(`<@`) io.WriteString(c.w, `<@`)
case qcode.OpHasKey: case qcode.OpHasKey:
c.w.WriteString(`?`) io.WriteString(c.w, `?`)
case qcode.OpHasKeyAny: case qcode.OpHasKeyAny:
c.w.WriteString(`?|`) io.WriteString(c.w, `?|`)
case qcode.OpHasKeyAll: case qcode.OpHasKeyAll:
c.w.WriteString(`?&`) io.WriteString(c.w, `?&`)
case qcode.OpIsNull: case qcode.OpIsNull:
if strings.EqualFold(val.Val, "true") { if strings.EqualFold(val.Val, "true") {
c.w.WriteString(`IS NULL)`) io.WriteString(c.w, `IS NULL)`)
} else { } else {
c.w.WriteString(`IS NOT NULL)`) io.WriteString(c.w, `IS NOT NULL)`)
} }
valExists = false valExists = false
case qcode.OpEqID: case qcode.OpEqID:
@ -766,20 +825,20 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
return fmt.Errorf("no primary key column defined for %s", ti.Name) return fmt.Errorf("no primary key column defined for %s", ti.Name)
} }
//fmt.Fprintf(w, `(("%s") =`, c.ti.PrimaryCol) //fmt.Fprintf(w, `(("%s") =`, c.ti.PrimaryCol)
c.w.WriteString(`((`) io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, ti.PrimaryCol) colWithTable(c.w, ti.Name, ti.PrimaryCol)
//c.w.WriteString(ti.PrimaryCol) //io.WriteString(c.w, ti.PrimaryCol)
c.w.WriteString(`) =`) io.WriteString(c.w, `) =`)
case qcode.OpTsQuery: case qcode.OpTsQuery:
if len(ti.TSVCol) == 0 { if len(ti.TSVCol) == 0 {
return fmt.Errorf("no tsv column defined for %s", ti.Name) return fmt.Errorf("no tsv column defined for %s", ti.Name)
} }
//fmt.Fprintf(w, `(("%s") @@ to_tsquery('%s'))`, c.ti.TSVCol, val.Val) //fmt.Fprintf(w, `(("%s") @@ to_tsquery('%s'))`, c.ti.TSVCol, val.Val)
c.w.WriteString(`(("`) io.WriteString(c.w, `(("`)
c.w.WriteString(ti.TSVCol) io.WriteString(c.w, ti.TSVCol)
c.w.WriteString(`") @@ to_tsquery('`) io.WriteString(c.w, `") @@ to_tsquery('`)
c.w.WriteString(val.Val) io.WriteString(c.w, val.Val)
c.w.WriteString(`'))`) io.WriteString(c.w, `'))`)
valExists = false valExists = false
default: default:
@ -792,7 +851,7 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
} else { } else {
c.renderVal(val, c.vars) c.renderVal(val, c.vars)
} }
c.w.WriteString(`)`) io.WriteString(c.w, `)`)
} }
qcode.FreeExp(val) qcode.FreeExp(val)
@ -808,10 +867,10 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error
} }
func (c *compilerContext) renderOrderBy(sel *qcode.Select, ti *DBTableInfo) error { func (c *compilerContext) renderOrderBy(sel *qcode.Select, ti *DBTableInfo) error {
c.w.WriteString(` ORDER BY `) io.WriteString(c.w, ` ORDER BY `)
for i := range sel.OrderBy { for i := range sel.OrderBy {
if i != 0 { if i != 0 {
c.w.WriteString(`, `) io.WriteString(c.w, `, `)
} }
ob := sel.OrderBy[i] ob := sel.OrderBy[i]
@ -819,27 +878,27 @@ func (c *compilerContext) renderOrderBy(sel *qcode.Select, ti *DBTableInfo) erro
case qcode.OrderAsc: case qcode.OrderAsc:
//fmt.Fprintf(w, `"%s_%d.ob.%s" ASC`, sel.Table, sel.ID, ob.Col) //fmt.Fprintf(w, `"%s_%d.ob.%s" ASC`, sel.Table, sel.ID, ob.Col)
tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob")
c.w.WriteString(` ASC`) io.WriteString(c.w, ` ASC`)
case qcode.OrderDesc: case qcode.OrderDesc:
//fmt.Fprintf(w, `"%s_%d.ob.%s" DESC`, sel.Table, sel.ID, ob.Col) //fmt.Fprintf(w, `"%s_%d.ob.%s" DESC`, sel.Table, sel.ID, ob.Col)
tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob")
c.w.WriteString(` DESC`) io.WriteString(c.w, ` DESC`)
case qcode.OrderAscNullsFirst: case qcode.OrderAscNullsFirst:
//fmt.Fprintf(w, `"%s_%d.ob.%s" ASC NULLS FIRST`, sel.Table, sel.ID, ob.Col) //fmt.Fprintf(w, `"%s_%d.ob.%s" ASC NULLS FIRST`, sel.Table, sel.ID, ob.Col)
tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob")
c.w.WriteString(` ASC NULLS FIRST`) io.WriteString(c.w, ` ASC NULLS FIRST`)
case qcode.OrderDescNullsFirst: case qcode.OrderDescNullsFirst:
//fmt.Fprintf(w, `%s_%d.ob.%s DESC NULLS FIRST`, sel.Table, sel.ID, ob.Col) //fmt.Fprintf(w, `%s_%d.ob.%s DESC NULLS FIRST`, sel.Table, sel.ID, ob.Col)
tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob")
c.w.WriteString(` DESC NULLLS FIRST`) io.WriteString(c.w, ` DESC NULLLS FIRST`)
case qcode.OrderAscNullsLast: case qcode.OrderAscNullsLast:
//fmt.Fprintf(w, `"%s_%d.ob.%s ASC NULLS LAST`, sel.Table, sel.ID, ob.Col) //fmt.Fprintf(w, `"%s_%d.ob.%s ASC NULLS LAST`, sel.Table, sel.ID, ob.Col)
tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob")
c.w.WriteString(` ASC NULLS LAST`) io.WriteString(c.w, ` ASC NULLS LAST`)
case qcode.OrderDescNullsLast: case qcode.OrderDescNullsLast:
//fmt.Fprintf(w, `%s_%d.ob.%s DESC NULLS LAST`, sel.Table, sel.ID, ob.Col) //fmt.Fprintf(w, `%s_%d.ob.%s DESC NULLS LAST`, sel.Table, sel.ID, ob.Col)
tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob")
c.w.WriteString(` DESC NULLS LAST`) io.WriteString(c.w, ` DESC NULLS LAST`)
default: default:
return fmt.Errorf("13: unexpected value %v", ob.Order) return fmt.Errorf("13: unexpected value %v", ob.Order)
} }
@ -851,30 +910,30 @@ func (c *compilerContext) renderDistinctOn(sel *qcode.Select, ti *DBTableInfo) {
io.WriteString(c.w, `DISTINCT ON (`) io.WriteString(c.w, `DISTINCT ON (`)
for i := range sel.DistinctOn { for i := range sel.DistinctOn {
if i != 0 { if i != 0 {
c.w.WriteString(`, `) io.WriteString(c.w, `, `)
} }
//fmt.Fprintf(w, `"%s_%d.ob.%s"`, c.sel.Table, c.sel.ID, c.sel.DistinctOn[i]) //fmt.Fprintf(w, `"%s_%d.ob.%s"`, c.sel.Table, c.sel.ID, c.sel.DistinctOn[i])
tableIDColSuffix(c.w, ti.Name, sel.ID, sel.DistinctOn[i], "_ob") tableIDColSuffix(c.w, ti.Name, sel.ID, sel.DistinctOn[i], "_ob")
} }
c.w.WriteString(`) `) io.WriteString(c.w, `) `)
} }
func (c *compilerContext) renderList(ex *qcode.Exp) { func (c *compilerContext) renderList(ex *qcode.Exp) {
io.WriteString(c.w, ` (`) io.WriteString(c.w, ` (`)
for i := range ex.ListVal { for i := range ex.ListVal {
if i != 0 { if i != 0 {
c.w.WriteString(`, `) io.WriteString(c.w, `, `)
} }
switch ex.ListType { switch ex.ListType {
case qcode.ValBool, qcode.ValInt, qcode.ValFloat: case qcode.ValBool, qcode.ValInt, qcode.ValFloat:
c.w.WriteString(ex.ListVal[i]) io.WriteString(c.w, ex.ListVal[i])
case qcode.ValStr: case qcode.ValStr:
c.w.WriteString(`'`) io.WriteString(c.w, `'`)
c.w.WriteString(ex.ListVal[i]) io.WriteString(c.w, ex.ListVal[i])
c.w.WriteString(`'`) io.WriteString(c.w, `'`)
} }
} }
c.w.WriteString(`)`) io.WriteString(c.w, `)`)
} }
func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) { func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) {
@ -883,27 +942,27 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) {
switch ex.Type { switch ex.Type {
case qcode.ValBool, qcode.ValInt, qcode.ValFloat: case qcode.ValBool, qcode.ValInt, qcode.ValFloat:
if len(ex.Val) != 0 { if len(ex.Val) != 0 {
c.w.WriteString(ex.Val) io.WriteString(c.w, ex.Val)
} else { } else {
c.w.WriteString(`''`) io.WriteString(c.w, `''`)
} }
case qcode.ValStr: case qcode.ValStr:
c.w.WriteString(`'`) io.WriteString(c.w, `'`)
c.w.WriteString(ex.Val) io.WriteString(c.w, ex.Val)
c.w.WriteString(`'`) io.WriteString(c.w, `'`)
case qcode.ValVar: case qcode.ValVar:
if val, ok := vars[ex.Val]; ok { if val, ok := vars[ex.Val]; ok {
c.w.WriteString(val) io.WriteString(c.w, val)
} else { } else {
//fmt.Fprintf(w, `'{{%s}}'`, ex.Val) //fmt.Fprintf(w, `'{{%s}}'`, ex.Val)
c.w.WriteString(`{{`) io.WriteString(c.w, `{{`)
c.w.WriteString(ex.Val) io.WriteString(c.w, ex.Val)
c.w.WriteString(`}}`) io.WriteString(c.w, `}}`)
} }
} }
//c.w.WriteString(`)`) //io.WriteString(c.w, `)`)
} }
func funcPrefixLen(fn string) int { func funcPrefixLen(fn string) int {
@ -939,105 +998,105 @@ func hasBit(n uint32, pos uint32) bool {
return (val > 0) return (val > 0)
} }
func alias(w *bytes.Buffer, alias string) { func alias(w io.Writer, alias string) {
w.WriteString(` AS "`) io.WriteString(w, ` AS "`)
w.WriteString(alias) io.WriteString(w, alias)
w.WriteString(`"`) io.WriteString(w, `"`)
} }
func aliasWithID(w *bytes.Buffer, alias string, id int32) { func aliasWithID(w io.Writer, alias string, id int32) {
w.WriteString(` AS "`) io.WriteString(w, ` AS "`)
w.WriteString(alias) io.WriteString(w, alias)
w.WriteString(`_`) io.WriteString(w, `_`)
int2string(w, id) int2string(w, id)
w.WriteString(`"`) io.WriteString(w, `"`)
} }
func aliasWithIDSuffix(w *bytes.Buffer, alias string, id int32, suffix string) { func aliasWithIDSuffix(w io.Writer, alias string, id int32, suffix string) {
w.WriteString(` AS "`) io.WriteString(w, ` AS "`)
w.WriteString(alias) io.WriteString(w, alias)
w.WriteString(`_`) io.WriteString(w, `_`)
int2string(w, id) int2string(w, id)
w.WriteString(suffix) io.WriteString(w, suffix)
w.WriteString(`"`) io.WriteString(w, `"`)
} }
func colWithAlias(w *bytes.Buffer, col, alias string) { func colWithAlias(w io.Writer, col, alias string) {
w.WriteString(`"`) io.WriteString(w, `"`)
w.WriteString(col) io.WriteString(w, col)
w.WriteString(`" AS "`) io.WriteString(w, `" AS "`)
w.WriteString(alias) io.WriteString(w, alias)
w.WriteString(`"`) io.WriteString(w, `"`)
} }
func tableWithAlias(w *bytes.Buffer, table, alias string) { func tableWithAlias(w io.Writer, table, alias string) {
w.WriteString(`"`) io.WriteString(w, `"`)
w.WriteString(table) io.WriteString(w, table)
w.WriteString(`" AS "`) io.WriteString(w, `" AS "`)
w.WriteString(alias) io.WriteString(w, alias)
w.WriteString(`"`) io.WriteString(w, `"`)
} }
func colWithTable(w *bytes.Buffer, table, col string) { func colWithTable(w io.Writer, table, col string) {
w.WriteString(`"`) io.WriteString(w, `"`)
w.WriteString(table) io.WriteString(w, table)
w.WriteString(`"."`) io.WriteString(w, `"."`)
w.WriteString(col) io.WriteString(w, col)
w.WriteString(`"`) io.WriteString(w, `"`)
} }
func colWithTableID(w *bytes.Buffer, table string, id int32, col string) { func colWithTableID(w io.Writer, table string, id int32, col string) {
w.WriteString(`"`) io.WriteString(w, `"`)
w.WriteString(table) io.WriteString(w, table)
w.WriteString(`_`) io.WriteString(w, `_`)
int2string(w, id) int2string(w, id)
w.WriteString(`"."`) io.WriteString(w, `"."`)
w.WriteString(col) io.WriteString(w, col)
w.WriteString(`"`) io.WriteString(w, `"`)
} }
func colWithTableIDAlias(w *bytes.Buffer, table string, id int32, col, alias string) { func colWithTableIDAlias(w io.Writer, table string, id int32, col, alias string) {
w.WriteString(`"`) io.WriteString(w, `"`)
w.WriteString(table) io.WriteString(w, table)
w.WriteString(`_`) io.WriteString(w, `_`)
int2string(w, id) int2string(w, id)
w.WriteString(`"."`) io.WriteString(w, `"."`)
w.WriteString(col) io.WriteString(w, col)
w.WriteString(`" AS "`) io.WriteString(w, `" AS "`)
w.WriteString(alias) io.WriteString(w, alias)
w.WriteString(`"`) io.WriteString(w, `"`)
} }
func colWithTableIDSuffixAlias(w *bytes.Buffer, table string, id int32, func colWithTableIDSuffixAlias(w io.Writer, table string, id int32,
suffix, col, alias string) { suffix, col, alias string) {
w.WriteString(`"`) io.WriteString(w, `"`)
w.WriteString(table) io.WriteString(w, table)
w.WriteString(`_`) io.WriteString(w, `_`)
int2string(w, id) int2string(w, id)
w.WriteString(suffix) io.WriteString(w, suffix)
w.WriteString(`"."`) io.WriteString(w, `"."`)
w.WriteString(col) io.WriteString(w, col)
w.WriteString(`" AS "`) io.WriteString(w, `" AS "`)
w.WriteString(alias) io.WriteString(w, alias)
w.WriteString(`"`) io.WriteString(w, `"`)
} }
func tableIDColSuffix(w *bytes.Buffer, table string, id int32, col, suffix string) { func tableIDColSuffix(w io.Writer, table string, id int32, col, suffix string) {
w.WriteString(`"`) io.WriteString(w, `"`)
w.WriteString(table) io.WriteString(w, table)
w.WriteString(`_`) io.WriteString(w, `_`)
int2string(w, id) int2string(w, id)
w.WriteString(`_`) io.WriteString(w, `_`)
w.WriteString(col) io.WriteString(w, col)
w.WriteString(suffix) io.WriteString(w, suffix)
w.WriteString(`"`) io.WriteString(w, `"`)
} }
const charset = "0123456789" const charset = "0123456789"
func int2string(w *bytes.Buffer, val int32) { func int2string(w io.Writer, val int32) {
if val < 10 { if val < 10 {
w.WriteByte(charset[val]) w.Write([]byte{charset[val]})
return return
} }
@ -1053,7 +1112,7 @@ func int2string(w *bytes.Buffer, val int32) {
for val3 > 0 { for val3 > 0 {
d := val3 % 10 d := val3 % 10
val3 /= 10 val3 /= 10
w.WriteByte(charset[d]) w.Write([]byte{charset[d]})
} }
} }

View File

@ -22,32 +22,6 @@ func TestMain(m *testing.M) {
var err error var err error
qcompile, err = qcode.NewCompiler(qcode.Config{ qcompile, err = qcode.NewCompiler(qcode.Config{
DefaultFilter: []string{
`{ user_id: { _eq: $user_id } }`,
},
FilterMap: qcode.Filters{
All: map[string][]string{
"users": []string{
"{ id: { eq: $user_id } }",
},
"products": []string{
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }",
},
"customers": []string{},
"mes": []string{
"{ id: { eq: $user_id } }",
},
},
Query: map[string][]string{
"users": []string{},
},
Update: map[string][]string{
"products": []string{
"{ user_id: { eq: $user_id } }",
},
},
},
Blocklist: []string{ Blocklist: []string{
"secret", "secret",
"password", "password",
@ -55,6 +29,59 @@ func TestMain(m *testing.M) {
}, },
}) })
qcompile.AddRole("user", "product", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "name", "price", "users", "customers"},
Filters: []string{
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }",
},
},
Update: qcode.UpdateConfig{
Filters: []string{"{ user_id: { eq: $user_id } }"},
},
Delete: qcode.DeleteConfig{
Filters: []string{
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }",
},
},
})
qcompile.AddRole("anon", "product", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "name"},
},
})
qcompile.AddRole("anon1", "product", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "name", "price"},
DisableFunctions: true,
},
})
qcompile.AddRole("user", "users", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "full_name", "avatar", "email", "products"},
},
})
qcompile.AddRole("user", "mes", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "full_name", "avatar"},
Filters: []string{
"{ id: { eq: $user_id } }",
},
},
})
qcompile.AddRole("user", "customers", qcode.TRConfig{
Query: qcode.QueryConfig{
Columns: []string{"id", "email", "full_name", "products"},
},
})
if err != nil { if err != nil {
log.Fatal(err) log.Fatal(err)
} }
@ -135,9 +162,8 @@ func TestMain(m *testing.M) {
os.Exit(m.Run()) os.Exit(m.Run())
} }
func compileGQLToPSQL(gql string, vars Variables) ([]byte, error) { func compileGQLToPSQL(gql string, vars Variables, role string) ([]byte, error) {
qc, err := qcompile.Compile([]byte(gql), role)
qc, err := qcompile.Compile([]byte(gql))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -147,6 +173,8 @@ func compileGQLToPSQL(gql string, vars Variables) ([]byte, error) {
return nil, err return nil, err
} }
//fmt.Println(string(sqlStmt))
return sqlStmt, nil return sqlStmt, nil
} }
@ -173,9 +201,9 @@ func withComplexArgs(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") < 28) AND (("products"."id") >= 20)) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "sel_json_agg_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") < 28) AND (("products"."id") >= 20)) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -201,9 +229,9 @@ func withWhereMultiOr(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") < 20) OR (("products"."price") > 10) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") < 20) OR (("products"."price") > 10) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -227,9 +255,9 @@ func withWhereIsNull(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -253,9 +281,9 @@ func withWhereAndList(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -273,9 +301,9 @@ func fetchByID(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 15)) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";` sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 15)) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -293,9 +321,9 @@ func searchQuery(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -316,9 +344,9 @@ func oneToMany(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email", "users"."id" FROM "users" LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email", "users"."id" FROM "users" LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -339,9 +367,9 @@ func belongsTo(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "sel_json_1" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "sel_json_1" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -362,9 +390,9 @@ func manyToMany(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -385,9 +413,9 @@ func manyToManyReverse(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -405,9 +433,49 @@ func aggFunction(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil {
t.Fatal(err)
}
if string(resSQL) != sql {
t.Fatal(errNotExpected)
}
}
func aggFunctionBlockedByCol(t *testing.T) {
gql := `query {
products {
name
count_price
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil, "anon")
if err != nil {
t.Fatal(err)
}
if string(resSQL) != sql {
t.Fatal(errNotExpected)
}
}
func aggFunctionDisabled(t *testing.T) {
gql := `query {
products {
name
count_price
}
}`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil, "anon1")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -425,9 +493,9 @@ func aggFunctionWithFilter(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") > 10)) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") > 10)) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -445,9 +513,9 @@ func queryWithVariables(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") = {{product_price}}) AND (("products"."id") = {{product_id}})) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";` sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") = {{product_price}}) AND (("products"."id") = {{product_id}})) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -464,9 +532,9 @@ func syntheticTables(t *testing.T) {
} }
}` }`
sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337";` sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT ) AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337"`
resSQL, err := compileGQLToPSQL(gql, nil) resSQL, err := compileGQLToPSQL(gql, nil, "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -476,7 +544,7 @@ func syntheticTables(t *testing.T) {
} }
} }
func TestCompileSelect(t *testing.T) { func TestCompileQuery(t *testing.T) {
t.Run("withComplexArgs", withComplexArgs) t.Run("withComplexArgs", withComplexArgs)
t.Run("withWhereAndList", withWhereAndList) t.Run("withWhereAndList", withWhereAndList)
t.Run("withWhereIsNull", withWhereIsNull) t.Run("withWhereIsNull", withWhereIsNull)
@ -488,10 +556,11 @@ func TestCompileSelect(t *testing.T) {
t.Run("manyToMany", manyToMany) t.Run("manyToMany", manyToMany)
t.Run("manyToManyReverse", manyToManyReverse) t.Run("manyToManyReverse", manyToManyReverse)
t.Run("aggFunction", aggFunction) t.Run("aggFunction", aggFunction)
t.Run("aggFunctionBlockedByCol", aggFunctionBlockedByCol)
t.Run("aggFunctionDisabled", aggFunctionDisabled)
t.Run("aggFunctionWithFilter", aggFunctionWithFilter) t.Run("aggFunctionWithFilter", aggFunctionWithFilter)
t.Run("syntheticTables", syntheticTables) t.Run("syntheticTables", syntheticTables)
t.Run("queryWithVariables", queryWithVariables) t.Run("queryWithVariables", queryWithVariables)
} }
var benchGQL = []byte(`query { var benchGQL = []byte(`query {
@ -526,7 +595,7 @@ func BenchmarkCompile(b *testing.B) {
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
w.Reset() w.Reset()
qc, err := qcompile.Compile(benchGQL) qc, err := qcompile.Compile(benchGQL, "user")
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -547,7 +616,7 @@ func BenchmarkCompileParallel(b *testing.B) {
for pb.Next() { for pb.Next() {
w.Reset() w.Reset()
qc, err := qcompile.Compile(benchGQL) qc, err := qcompile.Compile(benchGQL, "user")
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }

99
qcode/config.go Normal file
View File

@ -0,0 +1,99 @@
package qcode
type Config struct {
Blocklist []string
KeepArgs bool
}
type QueryConfig struct {
Limit int
Filters []string
Columns []string
DisableFunctions bool
}
type InsertConfig struct {
Filters []string
Columns []string
Set map[string]string
}
type UpdateConfig struct {
Filters []string
Columns []string
Set map[string]string
}
type DeleteConfig struct {
Filters []string
Columns []string
}
type TRConfig struct {
Query QueryConfig
Insert InsertConfig
Update UpdateConfig
Delete DeleteConfig
}
type trval struct {
query struct {
limit string
fil *Exp
cols map[string]struct{}
disable struct {
funcs bool
}
}
insert struct {
fil *Exp
cols map[string]struct{}
set map[string]string
}
update struct {
fil *Exp
cols map[string]struct{}
set map[string]string
}
delete struct {
fil *Exp
cols map[string]struct{}
}
}
func (trv *trval) allowedColumns(qt QType) map[string]struct{} {
switch qt {
case QTQuery:
return trv.query.cols
case QTInsert:
return trv.insert.cols
case QTUpdate:
return trv.update.cols
case QTDelete:
return trv.insert.cols
case QTUpsert:
return trv.insert.cols
}
return nil
}
func (trv *trval) filter(qt QType) *Exp {
switch qt {
case QTQuery:
return trv.query.fil
case QTInsert:
return trv.insert.fil
case QTUpdate:
return trv.update.fil
case QTDelete:
return trv.delete.fil
case QTUpsert:
return trv.insert.fil
}
return nil
}

View File

@ -5,7 +5,7 @@ func FuzzerEntrypoint(data []byte) int {
//testData := string(data) //testData := string(data)
qcompile, _ := NewCompiler(Config{}) qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile(data) _, err := qcompile.Compile(data, "user")
if err != nil { if err != nil {
return -1 return -1
} }

View File

@ -18,7 +18,9 @@ type parserType int32
const ( const (
maxFields = 100 maxFields = 100
maxArgs = 10 maxArgs = 10
)
const (
parserError parserType = iota parserError parserType = iota
parserEOF parserEOF
opQuery opQuery

View File

@ -46,13 +46,18 @@ func compareOp(op1, op2 Operation) error {
*/ */
func TestCompile1(t *testing.T) { func TestCompile1(t *testing.T) {
qcompile, _ := NewCompiler(Config{}) qc, _ := NewCompiler(Config{})
qc.AddRole("user", "product", TRConfig{
Query: QueryConfig{
Columns: []string{"id", "Name"},
},
})
_, err := qcompile.Compile([]byte(` _, err := qc.Compile([]byte(`
product(id: 15) { product(id: 15) {
id id
name name
}`)) }`), "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -60,13 +65,18 @@ func TestCompile1(t *testing.T) {
} }
func TestCompile2(t *testing.T) { func TestCompile2(t *testing.T) {
qcompile, _ := NewCompiler(Config{}) qc, _ := NewCompiler(Config{})
qc.AddRole("user", "product", TRConfig{
Query: QueryConfig{
Columns: []string{"ID"},
},
})
_, err := qcompile.Compile([]byte(` _, err := qc.Compile([]byte(`
query { product(id: 15) { query { product(id: 15) {
id id
name name
} }`)) } }`), "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -74,15 +84,20 @@ func TestCompile2(t *testing.T) {
} }
func TestCompile3(t *testing.T) { func TestCompile3(t *testing.T) {
qcompile, _ := NewCompiler(Config{}) qc, _ := NewCompiler(Config{})
qc.AddRole("user", "product", TRConfig{
Query: QueryConfig{
Columns: []string{"ID"},
},
})
_, err := qcompile.Compile([]byte(` _, err := qc.Compile([]byte(`
mutation { mutation {
product(id: 15, name: "Test") { product(id: 15, name: "Test") {
id id
name name
} }
}`)) }`), "user")
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
@ -91,7 +106,7 @@ func TestCompile3(t *testing.T) {
func TestInvalidCompile1(t *testing.T) { func TestInvalidCompile1(t *testing.T) {
qcompile, _ := NewCompiler(Config{}) qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile([]byte(`#`)) _, err := qcompile.Compile([]byte(`#`), "user")
if err == nil { if err == nil {
t.Fatal(errors.New("expecting an error")) t.Fatal(errors.New("expecting an error"))
@ -100,7 +115,7 @@ func TestInvalidCompile1(t *testing.T) {
func TestInvalidCompile2(t *testing.T) { func TestInvalidCompile2(t *testing.T) {
qcompile, _ := NewCompiler(Config{}) qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile([]byte(`{u(where:{not:0})}`)) _, err := qcompile.Compile([]byte(`{u(where:{not:0})}`), "user")
if err == nil { if err == nil {
t.Fatal(errors.New("expecting an error")) t.Fatal(errors.New("expecting an error"))
@ -109,7 +124,7 @@ func TestInvalidCompile2(t *testing.T) {
func TestEmptyCompile(t *testing.T) { func TestEmptyCompile(t *testing.T) {
qcompile, _ := NewCompiler(Config{}) qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile([]byte(``)) _, err := qcompile.Compile([]byte(``), "user")
if err == nil { if err == nil {
t.Fatal(errors.New("expecting an error")) t.Fatal(errors.New("expecting an error"))
@ -144,7 +159,7 @@ func BenchmarkQCompile(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
_, err := qcompile.Compile(gql) _, err := qcompile.Compile(gql, "user")
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
@ -160,7 +175,7 @@ func BenchmarkQCompileP(b *testing.B) {
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
_, err := qcompile.Compile(gql) _, err := qcompile.Compile(gql, "user")
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)

View File

@ -3,6 +3,7 @@ package qcode
import ( import (
"errors" "errors"
"fmt" "fmt"
"strconv"
"strings" "strings"
"sync" "sync"
@ -15,27 +16,22 @@ type Action int
const ( const (
maxSelectors = 30 maxSelectors = 30
)
const (
QTQuery QType = iota + 1 QTQuery QType = iota + 1
QTMutation QTInsert
QTUpdate
ActionInsert Action = iota + 1 QTDelete
ActionUpdate QTUpsert
ActionDelete
ActionUpsert
) )
type QCode struct { type QCode struct {
Type QType Type QType
ActionVar string
Selects []Select Selects []Select
} }
type Column struct {
Table string
Name string
FieldName string
}
type Select struct { type Select struct {
ID int32 ID int32
ParentID int32 ParentID int32
@ -47,9 +43,15 @@ type Select struct {
OrderBy []*OrderBy OrderBy []*OrderBy
DistinctOn []string DistinctOn []string
Paging Paging Paging Paging
Action Action
ActionVar string
Children []int32 Children []int32
Functions bool
Allowed map[string]struct{}
}
type Column struct {
Table string
Name string
FieldName string
} }
type Exp struct { type Exp struct {
@ -79,6 +81,7 @@ type OrderBy struct {
type Paging struct { type Paging struct {
Limit string Limit string
Offset string Offset string
NoLimit bool
} }
type ExpOp int type ExpOp int
@ -145,81 +148,23 @@ const (
OrderDescNullsLast OrderDescNullsLast
) )
type Filters struct {
All map[string][]string
Query map[string][]string
Insert map[string][]string
Update map[string][]string
Delete map[string][]string
}
type Config struct {
DefaultFilter []string
FilterMap Filters
Blocklist []string
KeepArgs bool
}
type Compiler struct { type Compiler struct {
df *Exp tr map[string]map[string]*trval
fm struct {
all map[string]*Exp
query map[string]*Exp
insert map[string]*Exp
update map[string]*Exp
delete map[string]*Exp
}
bl map[string]struct{} bl map[string]struct{}
ka bool ka bool
} }
var opMap = map[parserType]QType{
opQuery: QTQuery,
opMutate: QTMutation,
}
var expPool = sync.Pool{ var expPool = sync.Pool{
New: func() interface{} { return &Exp{doFree: true} }, New: func() interface{} { return &Exp{doFree: true} },
} }
func NewCompiler(c Config) (*Compiler, error) { func NewCompiler(c Config) (*Compiler, error) {
var err error
co := &Compiler{ka: c.KeepArgs} co := &Compiler{ka: c.KeepArgs}
co.tr = make(map[string]map[string]*trval)
co.bl = make(map[string]struct{}, len(c.Blocklist)) co.bl = make(map[string]struct{}, len(c.Blocklist))
for i := range c.Blocklist { for i := range c.Blocklist {
co.bl[c.Blocklist[i]] = struct{}{} co.bl[strings.ToLower(c.Blocklist[i])] = struct{}{}
}
co.df, err = compileFilter(c.DefaultFilter)
if err != nil {
return nil, err
}
co.fm.all, err = buildFilters(c.FilterMap.All)
if err != nil {
return nil, err
}
co.fm.query, err = buildFilters(c.FilterMap.Query)
if err != nil {
return nil, err
}
co.fm.insert, err = buildFilters(c.FilterMap.Insert)
if err != nil {
return nil, err
}
co.fm.update, err = buildFilters(c.FilterMap.Update)
if err != nil {
return nil, err
}
co.fm.delete, err = buildFilters(c.FilterMap.Delete)
if err != nil {
return nil, err
} }
seedExp := [100]Exp{} seedExp := [100]Exp{}
@ -232,58 +177,99 @@ func NewCompiler(c Config) (*Compiler, error) {
return co, nil return co, nil
} }
func buildFilters(filMap map[string][]string) (map[string]*Exp, error) { func (com *Compiler) AddRole(role, table string, trc TRConfig) error {
fm := make(map[string]*Exp, len(filMap))
for k, v := range filMap {
fil, err := compileFilter(v)
if err != nil {
return nil, err
}
singular := flect.Singularize(k)
plural := flect.Pluralize(k)
fm[singular] = fil
fm[plural] = fil
}
return fm, nil
}
func (com *Compiler) Compile(query []byte) (*QCode, error) {
var qc QCode
var err error var err error
trv := &trval{}
toMap := func(cols []string) map[string]struct{} {
m := make(map[string]struct{}, len(cols))
for i := range cols {
m[strings.ToLower(cols[i])] = struct{}{}
}
return m
}
// query config
trv.query.fil, err = compileFilter(trc.Query.Filters)
if err != nil {
return err
}
if trc.Query.Limit > 0 {
trv.query.limit = strconv.Itoa(trc.Query.Limit)
}
trv.query.cols = toMap(trc.Query.Columns)
trv.query.disable.funcs = trc.Query.DisableFunctions
// insert config
if trv.insert.fil, err = compileFilter(trc.Insert.Filters); err != nil {
return err
}
trv.insert.cols = toMap(trc.Insert.Columns)
// update config
if trv.update.fil, err = compileFilter(trc.Update.Filters); err != nil {
return err
}
trv.insert.cols = toMap(trc.Insert.Columns)
trv.insert.set = trc.Insert.Set
// delete config
if trv.delete.fil, err = compileFilter(trc.Delete.Filters); err != nil {
return err
}
trv.delete.cols = toMap(trc.Delete.Columns)
singular := flect.Singularize(table)
plural := flect.Pluralize(table)
if _, ok := com.tr[role]; !ok {
com.tr[role] = make(map[string]*trval)
}
com.tr[role][singular] = trv
com.tr[role][plural] = trv
return nil
}
func (com *Compiler) Compile(query []byte, role string) (*QCode, error) {
var err error
qc := QCode{Type: QTQuery}
op, err := Parse(query) op, err := Parse(query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
qc.Selects, err = com.compileQuery(op) if err = com.compileQuery(&qc, op, role); err != nil {
if err != nil {
return nil, err return nil, err
} }
if t, ok := opMap[op.Type]; ok {
qc.Type = t
} else {
return nil, fmt.Errorf("Unknown operation type %d", op.Type)
}
opPool.Put(op) opPool.Put(op)
return &qc, nil return &qc, nil
} }
func (com *Compiler) compileQuery(op *Operation) ([]Select, error) { func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error {
id := int32(0) id := int32(0)
parentID := int32(0) parentID := int32(0)
if len(op.Fields) == 0 {
return errors.New("invalid graphql no query found")
}
if op.Type == opMutate {
if err := com.setMutationType(qc, op.Fields[0].Args); err != nil {
return err
}
}
selects := make([]Select, 0, 5) selects := make([]Select, 0, 5)
st := NewStack() st := NewStack()
action := qc.Type
if len(op.Fields) == 0 { if len(op.Fields) == 0 {
return nil, errors.New("empty query") return errors.New("empty query")
} }
st.Push(op.Fields[0].ID) st.Push(op.Fields[0].ID)
@ -293,7 +279,7 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
} }
if id >= maxSelectors { if id >= maxSelectors {
return nil, fmt.Errorf("selector limit reached (%d)", maxSelectors) return fmt.Errorf("selector limit reached (%d)", maxSelectors)
} }
fid := st.Pop() fid := st.Pop()
@ -303,14 +289,25 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
continue continue
} }
trv := com.getRole(role, field.Name)
selects = append(selects, Select{ selects = append(selects, Select{
ID: id, ID: id,
ParentID: parentID, ParentID: parentID,
Table: field.Name, Table: field.Name,
Children: make([]int32, 0, 5), Children: make([]int32, 0, 5),
Allowed: trv.allowedColumns(action),
}) })
s := &selects[(len(selects) - 1)] s := &selects[(len(selects) - 1)]
if action == QTQuery {
s.Functions = !trv.query.disable.funcs
if len(trv.query.limit) != 0 {
s.Paging.Limit = trv.query.limit
}
}
if s.ID != 0 { if s.ID != 0 {
p := &selects[s.ParentID] p := &selects[s.ParentID]
p.Children = append(p.Children, s.ID) p.Children = append(p.Children, s.ID)
@ -322,12 +319,13 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
s.FieldName = s.Table s.FieldName = s.Table
} }
err := com.compileArgs(s, field.Args) err := com.compileArgs(qc, s, field.Args)
if err != nil { if err != nil {
return nil, err return err
} }
s.Cols = make([]Column, 0, len(field.Children)) s.Cols = make([]Column, 0, len(field.Children))
action = QTQuery
for _, cid := range field.Children { for _, cid := range field.Children {
f := op.Fields[cid] f := op.Fields[cid]
@ -356,36 +354,14 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
} }
if id == 0 { if id == 0 {
return nil, errors.New("invalid query") return errors.New("invalid query")
} }
var fil *Exp var fil *Exp
root := &selects[0] root := &selects[0]
switch op.Type { if trv, ok := com.tr[role][op.Fields[0].Name]; ok {
case opQuery: fil = trv.filter(qc.Type)
fil, _ = com.fm.query[root.Table]
case opMutate:
switch root.Action {
case ActionInsert:
fil, _ = com.fm.insert[root.Table]
case ActionUpdate:
fil, _ = com.fm.update[root.Table]
case ActionDelete:
fil, _ = com.fm.delete[root.Table]
case ActionUpsert:
fil, _ = com.fm.insert[root.Table]
}
}
if fil == nil {
fil, _ = com.fm.all[root.Table]
}
if fil == nil {
fil = com.df
} }
if fil != nil && fil.Op != OpNop { if fil != nil && fil.Op != OpNop {
@ -403,10 +379,11 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
} }
} }
return selects[:id], nil qc.Selects = selects[:id]
return nil
} }
func (com *Compiler) compileArgs(sel *Select, args []Arg) error { func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg) error {
var err error var err error
if com.ka { if com.ka {
@ -418,9 +395,7 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
switch arg.Name { switch arg.Name {
case "id": case "id":
if sel.ID == 0 {
err = com.compileArgID(sel, arg) err = com.compileArgID(sel, arg)
}
case "search": case "search":
err = com.compileArgSearch(sel, arg) err = com.compileArgSearch(sel, arg)
case "where": case "where":
@ -433,18 +408,6 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
err = com.compileArgLimit(sel, arg) err = com.compileArgLimit(sel, arg)
case "offset": case "offset":
err = com.compileArgOffset(sel, arg) err = com.compileArgOffset(sel, arg)
case "insert":
sel.Action = ActionInsert
err = com.compileArgAction(sel, arg)
case "update":
sel.Action = ActionUpdate
err = com.compileArgAction(sel, arg)
case "upsert":
sel.Action = ActionUpsert
err = com.compileArgAction(sel, arg)
case "delete":
sel.Action = ActionDelete
err = com.compileArgAction(sel, arg)
} }
if err != nil { if err != nil {
@ -461,6 +424,45 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
return nil return nil
} }
func (com *Compiler) setMutationType(qc *QCode, args []Arg) error {
setActionVar := func(arg *Arg) error {
if arg.Val.Type != nodeVar {
return fmt.Errorf("value for argument '%s' must be a variable", arg.Name)
}
qc.ActionVar = arg.Val.Val
return nil
}
for i := range args {
arg := &args[i]
switch arg.Name {
case "insert":
qc.Type = QTInsert
return setActionVar(arg)
case "update":
qc.Type = QTUpdate
return setActionVar(arg)
case "upsert":
qc.Type = QTUpsert
return setActionVar(arg)
case "delete":
qc.Type = QTDelete
if arg.Val.Type != nodeBool {
return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name)
}
if arg.Val.Val == "false" {
qc.Type = QTQuery
}
return nil
}
}
return nil
}
func (com *Compiler) compileArgObj(st *util.Stack, arg *Arg) (*Exp, error) { func (com *Compiler) compileArgObj(st *util.Stack, arg *Arg) (*Exp, error) {
if arg.Val.Type != nodeObj { if arg.Val.Type != nodeObj {
return nil, fmt.Errorf("expecting an object") return nil, fmt.Errorf("expecting an object")
@ -540,6 +542,10 @@ func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (*
} }
func (com *Compiler) compileArgID(sel *Select, arg *Arg) error { func (com *Compiler) compileArgID(sel *Select, arg *Arg) error {
if sel.ID != 0 {
return nil
}
if sel.Where != nil && sel.Where.Op == OpEqID { if sel.Where != nil && sel.Where.Op == OpEqID {
return nil return nil
} }
@ -732,24 +738,14 @@ func (com *Compiler) compileArgOffset(sel *Select, arg *Arg) error {
return nil return nil
} }
func (com *Compiler) compileArgAction(sel *Select, arg *Arg) error { var zeroTrv = &trval{}
switch sel.Action {
case ActionDelete:
if arg.Val.Type != nodeBool {
return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name)
}
if arg.Val.Val == "false" {
sel.Action = 0
}
default: func (com *Compiler) getRole(role, field string) *trval {
if arg.Val.Type != nodeVar { if trv, ok := com.tr[role][field]; ok {
return fmt.Errorf("value for argument '%s' must be a variable", arg.Name) return trv
} else {
return zeroTrv
} }
sel.ActionVar = arg.Val.Val
}
return nil
} }
func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) { func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {

View File

@ -26,7 +26,8 @@ type allowItem struct {
var _allowList allowList var _allowList allowList
type allowList struct { type allowList struct {
list map[string]*allowItem list []*allowItem
index map[string]int
filepath string filepath string
saveChan chan *allowItem saveChan chan *allowItem
active bool active bool
@ -34,7 +35,7 @@ type allowList struct {
func initAllowList(cpath string) { func initAllowList(cpath string) {
_allowList = allowList{ _allowList = allowList{
list: make(map[string]*allowItem), index: make(map[string]int),
saveChan: make(chan *allowItem), saveChan: make(chan *allowItem),
active: true, active: true,
} }
@ -172,17 +173,21 @@ func (al *allowList) load() {
if c == 0 { if c == 0 {
if ty == AL_QUERY { if ty == AL_QUERY {
q := string(b[s:(e + 1)]) q := string(b[s:(e + 1)])
key := gqlHash(q, varBytes, "")
item := &allowItem{ if idx, ok := al.index[key]; !ok {
al.list = append(al.list, &allowItem{
uri: uri, uri: uri,
gql: q, gql: q,
} vars: varBytes,
})
if len(varBytes) != 0 { al.index[key] = len(al.list) - 1
} else {
item := al.list[idx]
item.gql = q
item.vars = varBytes item.vars = varBytes
} }
al.list[gqlHash(q, varBytes)] = item
varBytes = nil varBytes = nil
} else if ty == AL_VARS { } else if ty == AL_VARS {
@ -203,7 +208,15 @@ func (al *allowList) save(item *allowItem) {
if al.active == false { if al.active == false {
return return
} }
al.list[gqlHash(item.gql, item.vars)] = item
key := gqlHash(item.gql, item.vars, "")
if idx, ok := al.index[key]; ok {
al.list[idx] = item
} else {
al.list = append(al.list, item)
al.index[key] = len(al.list) - 1
}
f, err := os.Create(al.filepath) f, err := os.Create(al.filepath)
if err != nil { if err != nil {

View File

@ -7,28 +7,42 @@ import (
) )
var ( var (
userIDProviderKey = struct{}{} userIDProviderKey = "user_id_provider"
userIDKey = struct{}{} userIDKey = "user_id"
userRoleKey = "user_role"
) )
func headerAuth(r *http.Request, c *config) *http.Request { func headerAuth(next http.HandlerFunc) http.HandlerFunc {
if len(c.Auth.Header) == 0 { return func(w http.ResponseWriter, r *http.Request) {
return nil ctx := r.Context()
userIDProvider := r.Header.Get("X-User-ID-Provider")
if len(userIDProvider) != 0 {
ctx = context.WithValue(ctx, userIDProviderKey, userIDProvider)
} }
userID := r.Header.Get(c.Auth.Header) userID := r.Header.Get("X-User-ID")
if len(userID) != 0 { if len(userID) != 0 {
ctx := context.WithValue(r.Context(), userIDKey, userID) ctx = context.WithValue(ctx, userIDKey, userID)
return r.WithContext(ctx)
} }
return nil userRole := r.Header.Get("X-User-Role")
if len(userRole) != 0 {
ctx = context.WithValue(ctx, userRoleKey, userRole)
}
next.ServeHTTP(w, r.WithContext(ctx))
}
} }
func withAuth(next http.HandlerFunc) http.HandlerFunc { func withAuth(next http.HandlerFunc) http.HandlerFunc {
at := conf.Auth.Type at := conf.Auth.Type
ru := conf.Auth.Rails.URL ru := conf.Auth.Rails.URL
if conf.Auth.CredsInHeader {
next = headerAuth(next)
}
switch at { switch at {
case "rails": case "rails":
if strings.HasPrefix(ru, "memcache:") { if strings.HasPrefix(ru, "memcache:") {

View File

@ -58,11 +58,6 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
var tok string var tok string
if rn := headerAuth(r, conf); rn != nil {
next.ServeHTTP(w, rn)
return
}
if len(cookie) != 0 { if len(cookie) != 0 {
ck, err := r.Cookie(cookie) ck, err := r.Cookie(cookie)
if err != nil { if err != nil {
@ -102,7 +97,6 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
} }
next.ServeHTTP(w, r.WithContext(ctx)) next.ServeHTTP(w, r.WithContext(ctx))
} }
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
} }
} }

View File

@ -42,11 +42,6 @@ func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc {
} }
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if rn := headerAuth(r, conf); rn != nil {
next.ServeHTTP(w, rn)
return
}
ck, err := r.Cookie(cookie) ck, err := r.Cookie(cookie)
if err != nil { if err != nil {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
@ -83,17 +78,12 @@ func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc {
rURL, err := url.Parse(conf.Auth.Rails.URL) rURL, err := url.Parse(conf.Auth.Rails.URL)
if err != nil { if err != nil {
logger.Fatal().Err(err) logger.Fatal().Err(err).Send()
} }
mc := memcache.New(rURL.Host) mc := memcache.New(rURL.Host)
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if rn := headerAuth(r, conf); rn != nil {
next.ServeHTTP(w, rn)
return
}
ck, err := r.Cookie(cookie) ck, err := r.Cookie(cookie)
if err != nil { if err != nil {
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
@ -126,25 +116,20 @@ func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc {
ra, err := railsAuth(conf) ra, err := railsAuth(conf)
if err != nil { if err != nil {
logger.Fatal().Err(err) logger.Fatal().Err(err).Send()
} }
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {
if rn := headerAuth(r, conf); rn != nil {
next.ServeHTTP(w, rn)
return
}
ck, err := r.Cookie(cookie) ck, err := r.Cookie(cookie)
if err != nil { if err != nil {
logger.Error().Err(err) logger.Warn().Err(err).Msg("rails cookie missing")
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
} }
userID, err := ra.ParseCookie(ck.Value) userID, err := ra.ParseCookie(ck.Value)
if err != nil { if err != nil {
logger.Error().Err(err) logger.Warn().Err(err).Msg("failed to parse rails cookie")
next.ServeHTTP(w, r) next.ServeHTTP(w, r)
return return
} }

View File

@ -10,7 +10,6 @@ import (
"github.com/dosco/super-graph/qcode" "github.com/dosco/super-graph/qcode"
"github.com/gobuffalo/flect" "github.com/gobuffalo/flect"
"github.com/jackc/pgx/v4" "github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/log/zerologadapter"
"github.com/jackc/pgx/v4/pgxpool" "github.com/jackc/pgx/v4/pgxpool"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/spf13/cobra" "github.com/spf13/cobra"
@ -184,7 +183,34 @@ func initConf() (*config, error) {
} }
zerolog.SetGlobalLevel(logLevel) zerolog.SetGlobalLevel(logLevel)
//fmt.Printf("%#v", c) for k, v := range c.DB.Vars {
c.DB.Vars[k] = sanitize(v)
}
c.RolesQuery = sanitize(c.RolesQuery)
rolesMap := make(map[string]struct{})
for i := range c.Roles {
role := &c.Roles[i]
if _, ok := rolesMap[role.Name]; ok {
logger.Fatal().Msgf("duplicate role '%s' found", role.Name)
}
role.Name = sanitize(role.Name)
role.Match = sanitize(role.Match)
rolesMap[role.Name] = struct{}{}
}
if _, ok := rolesMap["user"]; !ok {
c.Roles = append(c.Roles, configRole{Name: "user"})
}
if _, ok := rolesMap["anon"]; !ok {
c.Roles = append(c.Roles, configRole{Name: "anon"})
}
c.Validate()
return c, nil return c, nil
} }
@ -217,7 +243,7 @@ func initDB(c *config, useDB bool) (*pgx.Conn, error) {
config.LogLevel = pgx.LogLevelNone config.LogLevel = pgx.LogLevelNone
} }
config.Logger = zerologadapter.NewLogger(*logger) config.Logger = NewSQLLogger(*logger)
db, err := pgx.ConnectConfig(context.Background(), config) db, err := pgx.ConnectConfig(context.Background(), config)
if err != nil { if err != nil {
@ -252,7 +278,7 @@ func initDBPool(c *config) (*pgxpool.Pool, error) {
config.ConnConfig.LogLevel = pgx.LogLevelNone config.ConnConfig.LogLevel = pgx.LogLevelNone
} }
config.ConnConfig.Logger = zerologadapter.NewLogger(*logger) config.ConnConfig.Logger = NewSQLLogger(*logger)
// if c.DB.MaxRetries != 0 { // if c.DB.MaxRetries != 0 {
// opt.MaxRetries = c.DB.MaxRetries // opt.MaxRetries = c.DB.MaxRetries

View File

@ -66,6 +66,7 @@ func graphQLFunc(query string, data interface{}) map[string]interface{} {
c := &coreContext{Context: context.Background()} c := &coreContext{Context: context.Background()}
c.req.Query = query c.req.Query = query
c.req.Vars = b c.req.Vars = b
c.req.role = "user"
res, err := c.execQuery() res, err := c.execQuery()
if err != nil { if err != nil {

View File

@ -1,7 +1,9 @@
package serv package serv
import ( import (
"regexp"
"strings" "strings"
"unicode"
"github.com/spf13/viper" "github.com/spf13/viper"
) )
@ -26,7 +28,7 @@ type config struct {
Auth struct { Auth struct {
Type string Type string
Cookie string Cookie string
Header string CredsInHeader bool `mapstructure:"creds_in_header"`
Rails struct { Rails struct {
Version string Version string
@ -60,10 +62,10 @@ type config struct {
MaxRetries int `mapstructure:"max_retries"` MaxRetries int `mapstructure:"max_retries"`
LogLevel string `mapstructure:"log_level"` LogLevel string `mapstructure:"log_level"`
vars map[string][]byte `mapstructure:"variables"` Vars map[string]string `mapstructure:"variables"`
Defaults struct { Defaults struct {
Filter []string Filters []string
Blocklist []string Blocklist []string
} }
@ -71,15 +73,13 @@ type config struct {
} `mapstructure:"database"` } `mapstructure:"database"`
Tables []configTable Tables []configTable
RolesQuery string `mapstructure:"roles_query"`
Roles []configRole
} }
type configTable struct { type configTable struct {
Name string Name string
Filter []string
FilterQuery []string `mapstructure:"filter_query"`
FilterInsert []string `mapstructure:"filter_insert"`
FilterUpdate []string `mapstructure:"filter_update"`
FilterDelete []string `mapstructure:"filter_delete"`
Table string Table string
Blocklist []string Blocklist []string
Remotes []configRemote Remotes []configRemote
@ -98,6 +98,42 @@ type configRemote struct {
} `mapstructure:"set_headers"` } `mapstructure:"set_headers"`
} }
type configRole struct {
Name string
Match string
Tables []struct {
Name string
Query struct {
Limit int
Filters []string
Columns []string
DisableAggregation bool `mapstructure:"disable_aggregation"`
Deny bool
}
Insert struct {
Filters []string
Columns []string
Set map[string]string
Deny bool
}
Update struct {
Filters []string
Columns []string
Set map[string]string
Deny bool
}
Delete struct {
Filters []string
Columns []string
Deny bool
}
}
}
func newConfig() *viper.Viper { func newConfig() *viper.Viper {
vi := viper.New() vi := viper.New()
@ -132,24 +168,30 @@ func newConfig() *viper.Viper {
return vi return vi
} }
func (c *config) getVariables() map[string]string { func (c *config) Validate() {
vars := make(map[string]string, len(c.DB.vars)) rm := make(map[string]struct{})
for k, v := range c.DB.vars { for i := range c.Roles {
isVar := false name := strings.ToLower(c.Roles[i].Name)
if _, ok := rm[name]; ok {
logger.Fatal().Msgf("duplicate config for role '%s'", c.Roles[i].Name)
}
rm[name] = struct{}{}
}
for i := range v { tm := make(map[string]struct{})
if v[i] == '$' {
isVar = true for i := range c.Tables {
} else if v[i] == ' ' { name := strings.ToLower(c.Tables[i].Name)
isVar = false if _, ok := tm[name]; ok {
} else if isVar && v[i] >= 'a' && v[i] <= 'z' { logger.Fatal().Msgf("duplicate config for table '%s'", c.Tables[i].Name)
v[i] = 'A' + (v[i] - 'a')
} }
tm[name] = struct{}{}
} }
vars[k] = string(v)
if len(c.RolesQuery) == 0 {
logger.Warn().Msgf("no 'roles_query' defined.")
} }
return vars
} }
func (c *config) getAliasMap() map[string][]string { func (c *config) getAliasMap() map[string][]string {
@ -167,3 +209,21 @@ func (c *config) getAliasMap() map[string][]string {
} }
return m return m
} }
var varRe1 = regexp.MustCompile(`(?mi)\$([a-zA-Z0-9_.]+)`)
var varRe2 = regexp.MustCompile(`\{\{([a-zA-Z0-9_.]+)\}\}`)
func sanitize(s string) string {
s0 := varRe1.ReplaceAllString(s, `{{$1}}`)
s1 := strings.Map(func(r rune) rune {
if unicode.IsSpace(r) {
return ' '
}
return r
}, s0)
return varRe2.ReplaceAllStringFunc(s1, func(m string) string {
return strings.ToLower(m)
})
}

View File

@ -13,8 +13,8 @@ import (
"github.com/cespare/xxhash/v2" "github.com/cespare/xxhash/v2"
"github.com/dosco/super-graph/jsn" "github.com/dosco/super-graph/jsn"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode" "github.com/dosco/super-graph/qcode"
"github.com/jackc/pgx/v4"
"github.com/valyala/fasttemplate" "github.com/valyala/fasttemplate"
) )
@ -32,6 +32,12 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
c.req.ref = req.Referer() c.req.ref = req.Referer()
c.req.hdr = req.Header c.req.hdr = req.Header
if authCheck(c) {
c.req.role = "user"
} else {
c.req.role = "anon"
}
b, err := c.execQuery() b, err := c.execQuery()
if err != nil { if err != nil {
return err return err
@ -46,10 +52,12 @@ func (c *coreContext) execQuery() ([]byte, error) {
var qc *qcode.QCode var qc *qcode.QCode
var data []byte var data []byte
logger.Debug().Str("role", c.req.role).Msg(c.req.Query)
if conf.UseAllowList { if conf.UseAllowList {
var ps *preparedItem var ps *preparedItem
data, ps, err = c.resolvePreparedSQL(c.req.Query) data, ps, err = c.resolvePreparedSQL()
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -59,12 +67,7 @@ func (c *coreContext) execQuery() ([]byte, error) {
} else { } else {
qc, err = qcompile.Compile([]byte(c.req.Query)) data, skipped, err = c.resolveSQL()
if err != nil {
return nil, err
}
data, skipped, err = c.resolveSQL(qc)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -112,6 +115,160 @@ func (c *coreContext) execQuery() ([]byte, error) {
return ob.Bytes(), nil return ob.Bytes(), nil
} }
func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
tx, err := db.Begin(c)
if err != nil {
return nil, nil, err
}
defer tx.Rollback(c)
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
if err != nil {
return nil, nil, err
}
}
var role string
mutation := isMutation(c.req.Query)
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
if useRoleQuery {
if role, err = c.executeRoleQuery(tx); err != nil {
return nil, nil, err
}
} else if v := c.Value(userRoleKey); v != nil {
role = v.(string)
} else if mutation {
role = c.req.role
}
ps, ok := _preparedList[gqlHash(c.req.Query, c.req.Vars, role)]
if !ok {
return nil, nil, errUnauthorized
}
var root []byte
vars := varList(c, ps.args)
if mutation {
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
} else {
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&c.req.role, &root)
}
if err != nil {
return nil, nil, err
}
if err := tx.Commit(c); err != nil {
return nil, nil, err
}
return root, ps, nil
}
func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
tx, err := db.Begin(c)
if err != nil {
return nil, 0, err
}
defer tx.Rollback(c)
mutation := isMutation(c.req.Query)
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
if useRoleQuery {
if c.req.role, err = c.executeRoleQuery(tx); err != nil {
return nil, 0, err
}
} else if v := c.Value(userRoleKey); v != nil {
c.req.role = v.(string)
}
stmts, err := c.buildStmt()
if err != nil {
return nil, 0, err
}
var st *stmt
if mutation {
st = findStmt(c.req.role, stmts)
} else {
st = &stmts[0]
}
t := fasttemplate.New(st.sql, openVar, closeVar)
buf := &bytes.Buffer{}
_, err = t.ExecuteFunc(buf, varMap(c))
if err == errNoUserID &&
authFailBlock == authFailBlockPerQuery &&
authCheck(c) == false {
return nil, 0, errUnauthorized
}
if err != nil {
return nil, 0, err
}
finalSQL := buf.String()
var stime time.Time
if conf.EnableTracing {
stime = time.Now()
}
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
if err != nil {
return nil, 0, err
}
}
var root []byte
if mutation {
err = tx.QueryRow(c, finalSQL).Scan(&root)
} else {
err = tx.QueryRow(c, finalSQL).Scan(&c.req.role, &root)
}
if err != nil {
return nil, 0, err
}
if err := tx.Commit(c); err != nil {
return nil, 0, err
}
if mutation {
st = findStmt(c.req.role, stmts)
} else {
st = &stmts[0]
}
if conf.EnableTracing && len(st.qc.Selects) != 0 {
c.addTrace(
st.qc.Selects,
st.qc.Selects[0].ID,
stime)
}
if conf.UseAllowList == false {
_allowList.add(&c.req)
}
return root, st.skipped, nil
}
func (c *coreContext) resolveRemote( func (c *coreContext) resolveRemote(
hdr http.Header, hdr http.Header,
h *xxhash.Digest, h *xxhash.Digest,
@ -259,125 +416,15 @@ func (c *coreContext) resolveRemotes(
return to, cerr return to, cerr
} }
func (c *coreContext) resolvePreparedSQL(gql string) ([]byte, *preparedItem, error) { func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
ps, ok := _preparedList[gqlHash(gql, c.req.Vars)] var role string
if !ok { row := tx.QueryRow(c, "_sg_get_role", c.req.role, 1)
return nil, nil, errUnauthorized
if err := row.Scan(&role); err != nil {
return "", err
} }
var root []byte return role, nil
vars := varList(c, ps.args)
tx, err := db.Begin(c)
if err != nil {
return nil, nil, err
}
defer tx.Rollback(c)
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
if err != nil {
return nil, nil, err
}
}
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
if err != nil {
return nil, nil, err
}
if err := tx.Commit(c); err != nil {
return nil, nil, err
}
fmt.Printf("PRE: %v\n", ps.stmt)
return root, ps, nil
}
func (c *coreContext) resolveSQL(qc *qcode.QCode) ([]byte, uint32, error) {
var vars map[string]json.RawMessage
stmt := &bytes.Buffer{}
if len(c.req.Vars) != 0 {
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
return nil, 0, err
}
}
skipped, err := pcompile.Compile(qc, stmt, psql.Variables(vars))
if err != nil {
return nil, 0, err
}
t := fasttemplate.New(stmt.String(), openVar, closeVar)
stmt.Reset()
_, err = t.ExecuteFunc(stmt, varMap(c))
if err == errNoUserID &&
authFailBlock == authFailBlockPerQuery &&
authCheck(c) == false {
return nil, 0, errUnauthorized
}
if err != nil {
return nil, 0, err
}
finalSQL := stmt.String()
// if conf.LogLevel == "debug" {
// os.Stdout.WriteString(finalSQL)
// os.Stdout.WriteString("\n\n")
// }
var st time.Time
if conf.EnableTracing {
st = time.Now()
}
tx, err := db.Begin(c)
if err != nil {
return nil, 0, err
}
defer tx.Rollback(c)
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
if err != nil {
return nil, 0, err
}
}
//fmt.Printf("\nRAW: %#v\n", finalSQL)
var root []byte
err = tx.QueryRow(c, finalSQL).Scan(&root)
if err != nil {
return nil, 0, err
}
if err := tx.Commit(c); err != nil {
return nil, 0, err
}
if conf.EnableTracing && len(qc.Selects) != 0 {
c.addTrace(
qc.Selects,
qc.Selects[0].ID,
st)
}
if conf.UseAllowList == false {
_allowList.add(&c.req)
}
return root, skipped, nil
} }
func (c *coreContext) render(w io.Writer, data []byte) error { func (c *coreContext) render(w io.Writer, data []byte) error {

144
serv/core_build.go Normal file
View File

@ -0,0 +1,144 @@
package serv
import (
"bytes"
"encoding/json"
"errors"
"io"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode"
)
type stmt struct {
role *configRole
qc *qcode.QCode
skipped uint32
sql string
}
func (c *coreContext) buildStmt() ([]stmt, error) {
var vars map[string]json.RawMessage
if len(c.req.Vars) != 0 {
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
return nil, err
}
}
gql := []byte(c.req.Query)
if len(conf.Roles) == 0 {
return nil, errors.New(`no roles found ('user' and 'anon' required)`)
}
qc, err := qcompile.Compile(gql, conf.Roles[0].Name)
if err != nil {
return nil, err
}
stmts := make([]stmt, 0, len(conf.Roles))
mutation := (qc.Type != qcode.QTQuery)
w := &bytes.Buffer{}
for i := range conf.Roles {
role := &conf.Roles[i]
if mutation && len(c.req.role) != 0 && role.Name != c.req.role {
continue
}
if i > 0 {
qc, err = qcompile.Compile(gql, role.Name)
if err != nil {
return nil, err
}
}
stmts = append(stmts, stmt{role: role, qc: qc})
if mutation {
skipped, err := pcompile.Compile(qc, w, psql.Variables(vars))
if err != nil {
return nil, err
}
s := &stmts[len(stmts)-1]
s.skipped = skipped
s.sql = w.String()
w.Reset()
}
}
if mutation {
return stmts, nil
}
io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
for _, s := range stmts {
io.WriteString(w, `WHEN '`)
io.WriteString(w, s.role.Name)
io.WriteString(w, `' THEN (`)
s.skipped, err = pcompile.Compile(s.qc, w, psql.Variables(vars))
if err != nil {
return nil, err
}
io.WriteString(w, `) `)
}
io.WriteString(w, `END) FROM (`)
if len(conf.RolesQuery) == 0 {
v := c.Value(userRoleKey)
io.WriteString(w, `VALUES ("`)
if v != nil {
io.WriteString(w, v.(string))
} else {
io.WriteString(w, c.req.role)
}
io.WriteString(w, `")) AS "_sg_auth_info"(role) LIMIT 1;`)
} else {
io.WriteString(w, `SELECT (CASE WHEN EXISTS (`)
io.WriteString(w, conf.RolesQuery)
io.WriteString(w, `) THEN `)
io.WriteString(w, `(SELECT (CASE`)
for _, s := range stmts {
if len(s.role.Match) == 0 {
continue
}
io.WriteString(w, ` WHEN `)
io.WriteString(w, s.role.Match)
io.WriteString(w, ` THEN '`)
io.WriteString(w, s.role.Name)
io.WriteString(w, `'`)
}
if len(c.req.role) == 0 {
io.WriteString(w, ` ELSE 'anon' END) FROM (`)
} else {
io.WriteString(w, ` ELSE '`)
io.WriteString(w, c.req.role)
io.WriteString(w, `' END) FROM (`)
}
io.WriteString(w, conf.RolesQuery)
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) ELSE '`)
if len(c.req.role) == 0 {
io.WriteString(w, `anon`)
} else {
io.WriteString(w, c.req.role)
}
io.WriteString(w, `' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `)
}
stmts[0].sql = w.String()
stmts[0].role = nil
return stmts, nil
}

View File

@ -30,14 +30,15 @@ type gqlReq struct {
Query string `json:"query"` Query string `json:"query"`
Vars json.RawMessage `json:"variables"` Vars json.RawMessage `json:"variables"`
ref string ref string
role string
hdr http.Header hdr http.Header
} }
type variables map[string]json.RawMessage type variables map[string]json.RawMessage
type gqlResp struct { type gqlResp struct {
Error string `json:"error,omitempty"` Error string `json:"message,omitempty"`
Data json.RawMessage `json:"data"` Data json.RawMessage `json:"data,omitempty"`
Extensions *extensions `json:"extensions,omitempty"` Extensions *extensions `json:"extensions,omitempty"`
} }
@ -94,55 +95,20 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
} }
if strings.EqualFold(ctx.req.OpName, introspectionQuery) { if strings.EqualFold(ctx.req.OpName, introspectionQuery) {
// dat, err := ioutil.ReadFile("test.schema") introspect(w)
// if err != nil {
// http.Error(w, err.Error(), http.StatusInternalServerError)
// return
// }
//w.Write(dat)
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{
"data": {
"__schema": {
"queryType": {
"name": "Query"
},
"mutationType": null,
"subscriptionType": null
}
},
"extensions":{
"tracing":{
"version":1,
"startTime":"2019-06-04T19:53:31.093Z",
"endTime":"2019-06-04T19:53:31.108Z",
"duration":15219720,
"execution": {
"resolvers": [{
"path": ["__schema"],
"parentType": "Query",
"fieldName": "__schema",
"returnType": "__Schema!",
"startOffset": 50950,
"duration": 17187
}]
}
}
}
}`))
return return
} }
err = ctx.handleReq(w, r) err = ctx.handleReq(w, r)
if err == errUnauthorized { if err == errUnauthorized {
err := "Not authorized" w.WriteHeader(http.StatusUnauthorized)
logger.Debug().Msg(err) json.NewEncoder(w).Encode(gqlResp{Error: err.Error()})
http.Error(w, err, 401) return
} }
if err != nil { if err != nil {
logger.Err(err).Msg("Failed to handle request") logger.Err(err).Msg("failed to handle request")
errorResp(w, err) errorResp(w, err)
} }
} }

36
serv/introsp.go Normal file
View File

@ -0,0 +1,36 @@
package serv
import "net/http"
func introspect(w http.ResponseWriter) {
w.Header().Set("Content-Type", "application/json")
w.Write([]byte(`{
"data": {
"__schema": {
"queryType": {
"name": "Query"
},
"mutationType": null,
"subscriptionType": null
}
},
"extensions":{
"tracing":{
"version":1,
"startTime":"2019-06-04T19:53:31.093Z",
"endTime":"2019-06-04T19:53:31.108Z",
"duration":15219720,
"execution": {
"resolvers": [{
"path": ["__schema"],
"parentType": "Query",
"fieldName": "__schema",
"returnType": "__Schema!",
"startOffset": 50950,
"duration": 17187
}]
}
}
}
}`))
}

View File

@ -7,7 +7,6 @@ import (
"fmt" "fmt"
"io" "io"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode" "github.com/dosco/super-graph/qcode"
"github.com/jackc/pgconn" "github.com/jackc/pgconn"
"github.com/valyala/fasttemplate" "github.com/valyala/fasttemplate"
@ -27,54 +26,43 @@ var (
func initPreparedList() { func initPreparedList() {
_preparedList = make(map[string]*preparedItem) _preparedList = make(map[string]*preparedItem)
for k, v := range _allowList.list { if err := prepareRoleStmt(); err != nil {
err := prepareStmt(k, v.gql, v.vars) logger.Fatal().Err(err).Msg("failed to prepare get role statement")
}
for _, v := range _allowList.list {
err := prepareStmt(v.gql, v.vars)
if err != nil { if err != nil {
logger.Warn().Err(err).Send() logger.Warn().Str("gql", v.gql).Err(err).Send()
} }
} }
} }
func prepareStmt(key, gql string, varBytes json.RawMessage) error { func prepareStmt(gql string, varBytes json.RawMessage) error {
if len(gql) == 0 || len(key) == 0 { if len(gql) == 0 {
return nil return nil
} }
qc, err := qcompile.Compile([]byte(gql)) c := &coreContext{Context: context.Background()}
c.req.Query = gql
c.req.Vars = varBytes
stmts, err := c.buildStmt()
if err != nil { if err != nil {
return err return err
} }
var vars map[string]json.RawMessage if len(stmts) != 0 && stmts[0].qc.Type == qcode.QTQuery {
c.req.Vars = nil
if len(varBytes) != 0 {
vars = make(map[string]json.RawMessage)
if err := json.Unmarshal(varBytes, &vars); err != nil {
return err
}
} }
buf := &bytes.Buffer{} for _, s := range stmts {
if len(s.sql) == 0 {
skipped, err := pcompile.Compile(qc, buf, psql.Variables(vars)) continue
if err != nil {
return err
} }
t := fasttemplate.New(buf.String(), `{{`, `}}`) finalSQL, am := processTemplate(s.sql)
am := make([][]byte, 0, 5)
i := 0
finalSQL := t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) {
am = append(am, []byte(tag))
i++
return w.Write([]byte(fmt.Sprintf("$%d", i)))
})
if err != nil {
return err
}
ctx := context.Background() ctx := context.Background()
@ -89,16 +77,84 @@ func prepareStmt(key, gql string, varBytes json.RawMessage) error {
return err return err
} }
var key string
if s.role == nil {
key = gqlHash(gql, c.req.Vars, "")
} else {
key = gqlHash(gql, c.req.Vars, s.role.Name)
}
_preparedList[key] = &preparedItem{ _preparedList[key] = &preparedItem{
stmt: pstmt, stmt: pstmt,
args: am, args: am,
skipped: skipped, skipped: s.skipped,
qc: qc, qc: s.qc,
} }
if err := tx.Commit(ctx); err != nil { if err := tx.Commit(ctx); err != nil {
return err return err
} }
}
return nil return nil
} }
func prepareRoleStmt() error {
if len(conf.RolesQuery) == 0 {
return nil
}
w := &bytes.Buffer{}
io.WriteString(w, `SELECT (CASE`)
for _, role := range conf.Roles {
if len(role.Match) == 0 {
continue
}
io.WriteString(w, ` WHEN `)
io.WriteString(w, role.Match)
io.WriteString(w, ` THEN '`)
io.WriteString(w, role.Name)
io.WriteString(w, `'`)
}
io.WriteString(w, ` ELSE {{role}} END) FROM (`)
io.WriteString(w, conf.RolesQuery)
io.WriteString(w, `) AS "_sg_auth_roles_query"`)
roleSQL, _ := processTemplate(w.String())
ctx := context.Background()
tx, err := db.Begin(ctx)
if err != nil {
return err
}
defer tx.Rollback(ctx)
_, err = tx.Prepare(ctx, "_sg_get_role", roleSQL)
if err != nil {
return err
}
return nil
}
func processTemplate(tmpl string) (string, [][]byte) {
t := fasttemplate.New(tmpl, `{{`, `}}`)
am := make([][]byte, 0, 5)
i := 0
vmap := make(map[string]int)
return t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) {
if n, ok := vmap[tag]; ok {
return w.Write([]byte(fmt.Sprintf("$%d", n)))
}
am = append(am, []byte(tag))
i++
vmap[tag] = i
return w.Write([]byte(fmt.Sprintf("$%d", i)))
}), am
}

View File

@ -12,7 +12,6 @@ import (
rice "github.com/GeertJohan/go.rice" rice "github.com/GeertJohan/go.rice"
"github.com/dosco/super-graph/psql" "github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode" "github.com/dosco/super-graph/qcode"
"github.com/gobuffalo/flect"
) )
func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) { func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
@ -22,52 +21,53 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
} }
conf := qcode.Config{ conf := qcode.Config{
DefaultFilter: c.DB.Defaults.Filter,
FilterMap: qcode.Filters{
All: make(map[string][]string, len(c.Tables)),
Query: make(map[string][]string, len(c.Tables)),
Insert: make(map[string][]string, len(c.Tables)),
Update: make(map[string][]string, len(c.Tables)),
Delete: make(map[string][]string, len(c.Tables)),
},
Blocklist: c.DB.Defaults.Blocklist, Blocklist: c.DB.Defaults.Blocklist,
KeepArgs: false, KeepArgs: false,
} }
for i := range c.Tables {
t := c.Tables[i]
singular := flect.Singularize(t.Name)
plural := flect.Pluralize(t.Name)
setFilter := func(fm map[string][]string, fil []string) {
switch {
case len(fil) == 0:
return
case fil[0] == "none" || len(fil[0]) == 0:
fm[singular] = []string{}
fm[plural] = []string{}
default:
fm[singular] = t.Filter
fm[plural] = t.Filter
}
}
setFilter(conf.FilterMap.All, t.Filter)
setFilter(conf.FilterMap.Query, t.FilterQuery)
setFilter(conf.FilterMap.Insert, t.FilterInsert)
setFilter(conf.FilterMap.Update, t.FilterUpdate)
setFilter(conf.FilterMap.Delete, t.FilterDelete)
}
qc, err := qcode.NewCompiler(conf) qc, err := qcode.NewCompiler(conf)
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
for _, r := range c.Roles {
for _, t := range r.Tables {
query := qcode.QueryConfig{
Limit: t.Query.Limit,
Filters: t.Query.Filters,
Columns: t.Query.Columns,
DisableFunctions: t.Query.DisableAggregation,
}
insert := qcode.InsertConfig{
Filters: t.Insert.Filters,
Columns: t.Insert.Columns,
Set: t.Insert.Set,
}
update := qcode.UpdateConfig{
Filters: t.Insert.Filters,
Columns: t.Insert.Columns,
Set: t.Insert.Set,
}
delete := qcode.DeleteConfig{
Filters: t.Insert.Filters,
Columns: t.Insert.Columns,
}
qc.AddRole(r.Name, t.Name, qcode.TRConfig{
Query: query,
Insert: insert,
Update: update,
Delete: delete,
})
}
}
pc := psql.NewCompiler(psql.Config{ pc := psql.NewCompiler(psql.Config{
Schema: schema, Schema: schema,
Vars: c.getVariables(), Vars: c.DB.Vars,
}) })
return qc, pc, nil return qc, pc, nil

45
serv/sqllog.go Normal file
View File

@ -0,0 +1,45 @@
package serv
import (
"context"
"github.com/jackc/pgx/v4"
"github.com/rs/zerolog"
)
type Logger struct {
logger zerolog.Logger
}
// NewLogger accepts a zerolog.Logger as input and returns a new custom pgx
// logging fascade as output.
func NewSQLLogger(logger zerolog.Logger) *Logger {
return &Logger{
logger: logger.With().Logger(),
}
}
func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) {
var zlevel zerolog.Level
switch level {
case pgx.LogLevelNone:
zlevel = zerolog.NoLevel
case pgx.LogLevelError:
zlevel = zerolog.ErrorLevel
case pgx.LogLevelWarn:
zlevel = zerolog.WarnLevel
case pgx.LogLevelInfo:
zlevel = zerolog.InfoLevel
case pgx.LogLevelDebug:
zlevel = zerolog.DebugLevel
default:
zlevel = zerolog.DebugLevel
}
if sql, ok := data["sql"]; ok {
delete(data, "sql")
pl.logger.WithLevel(zlevel).Fields(data).Msg(sql.(string))
} else {
pl.logger.WithLevel(zlevel).Fields(data).Msg(msg)
}
}

View File

@ -21,16 +21,29 @@ func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
return v return v
} }
func gqlHash(b string, vars []byte) string { func gqlHash(b string, vars []byte, role string) string {
b = strings.TrimSpace(b) b = strings.TrimSpace(b)
h := sha1.New() h := sha1.New()
query := "query"
s, e := 0, 0 s, e := 0, 0
space := []byte{' '} space := []byte{' '}
starting := true
var b0, b1 byte var b0, b1 byte
for { for {
if starting && b[e] == 'q' {
n := 0
se := e
for e < len(b) && n < len(query) && b[e] == query[n] {
n++
e++
}
if n != len(query) {
io.WriteString(h, strings.ToLower(b[se:e]))
}
}
if ws(b[e]) { if ws(b[e]) {
for e < len(b) && ws(b[e]) { for e < len(b) && ws(b[e]) {
e++ e++
@ -42,6 +55,7 @@ func gqlHash(b string, vars []byte) string {
h.Write(space) h.Write(space)
} }
} else { } else {
starting = false
s = e s = e
for e < len(b) && ws(b[e]) == false { for e < len(b) && ws(b[e]) == false {
e++ e++
@ -56,6 +70,10 @@ func gqlHash(b string, vars []byte) string {
} }
} }
if len(role) != 0 {
io.WriteString(h, role)
}
if vars == nil || len(vars) == 0 { if vars == nil || len(vars) == 0 {
return hex.EncodeToString(h.Sum(nil)) return hex.EncodeToString(h.Sum(nil))
} }
@ -80,3 +98,26 @@ func ws(b byte) bool {
func al(b byte) bool { func al(b byte) bool {
return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9') return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9')
} }
func isMutation(sql string) bool {
for i := range sql {
b := sql[i]
if b == '{' {
return false
}
if al(b) {
return (b == 'm' || b == 'M')
}
}
return false
}
func findStmt(role string, stmts []stmt) *stmt {
for i := range stmts {
if stmts[i].role.Name != role {
continue
}
return &stmts[i]
}
return nil
}

View File

@ -5,7 +5,7 @@ import (
"testing" "testing"
) )
func TestRelaxHash1(t *testing.T) { func TestGQLHash1(t *testing.T) {
var v1 = ` var v1 = `
products( products(
limit: 30, limit: 30,
@ -24,15 +24,15 @@ func TestRelaxHash1(t *testing.T) {
price price
} ` } `
h1 := gqlHash(v1, nil) h1 := gqlHash(v1, nil, "")
h2 := gqlHash(v2, nil) h2 := gqlHash(v2, nil, "")
if strings.Compare(h1, h2) != 0 { if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should") t.Fatal("Hashes don't match they should")
} }
} }
func TestRelaxHash2(t *testing.T) { func TestGQLHash2(t *testing.T) {
var v1 = ` var v1 = `
{ {
products( products(
@ -53,15 +53,15 @@ func TestRelaxHash2(t *testing.T) {
var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } ` var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } `
h1 := gqlHash(v1, nil) h1 := gqlHash(v1, nil, "")
h2 := gqlHash(v2, nil) h2 := gqlHash(v2, nil, "")
if strings.Compare(h1, h2) != 0 { if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should") t.Fatal("Hashes don't match they should")
} }
} }
func TestRelaxHash3(t *testing.T) { func TestGQLHash3(t *testing.T) {
var v1 = `users { var v1 = `users {
id id
email email
@ -86,15 +86,44 @@ func TestRelaxHash3(t *testing.T) {
} }
` `
h1 := gqlHash(v1, nil) h1 := gqlHash(v1, nil, "")
h2 := gqlHash(v2, nil) h2 := gqlHash(v2, nil, "")
if strings.Compare(h1, h2) != 0 { if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should") t.Fatal("Hashes don't match they should")
} }
} }
func TestRelaxHashWithVars1(t *testing.T) { func TestGQLHash4(t *testing.T) {
var v1 = `
query {
products(
limit: 30
order_by: { price: desc }
distinct: [price]
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }
) {
id
name
price
user {
id
email
}
}
}`
var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } `
h1 := gqlHash(v1, nil, "")
h2 := gqlHash(v2, nil, "")
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")
}
}
func TestGQLHashWithVars1(t *testing.T) {
var q1 = ` var q1 = `
products( products(
limit: 30, limit: 30,
@ -136,15 +165,15 @@ func TestRelaxHashWithVars1(t *testing.T) {
"user": 123 "user": 123
}` }`
h1 := gqlHash(q1, []byte(v1)) h1 := gqlHash(q1, []byte(v1), "user")
h2 := gqlHash(q2, []byte(v2)) h2 := gqlHash(q2, []byte(v2), "user")
if strings.Compare(h1, h2) != 0 { if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should") t.Fatal("Hashes don't match they should")
} }
} }
func TestRelaxHashWithVars2(t *testing.T) { func TestGQLHashWithVars2(t *testing.T) {
var q1 = ` var q1 = `
products( products(
limit: 30, limit: 30,
@ -193,8 +222,8 @@ func TestRelaxHashWithVars2(t *testing.T) {
"user": 123 "user": 123
}` }`
h1 := gqlHash(q1, []byte(v1)) h1 := gqlHash(q1, []byte(v1), "user")
h2 := gqlHash(q2, []byte(v2)) h2 := gqlHash(q2, []byte(v2), "user")
if strings.Compare(h1, h2) != 0 { if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should") t.Fatal("Hashes don't match they should")

View File

@ -11,17 +11,27 @@ import (
func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) { func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
return func(w io.Writer, tag string) (int, error) { return func(w io.Writer, tag string) (int, error) {
switch tag { switch tag {
case "user_id":
if v := ctx.Value(userIDKey); v != nil {
return stringVar(w, v.(string))
}
return 0, errNoUserID
case "user_id_provider": case "user_id_provider":
if v := ctx.Value(userIDProviderKey); v != nil { if v := ctx.Value(userIDProviderKey); v != nil {
return stringVar(w, v.(string)) return stringVar(w, v.(string))
} }
return 0, errNoUserID io.WriteString(w, "null")
return 0, nil
case "user_id":
if v := ctx.Value(userIDKey); v != nil {
return stringVar(w, v.(string))
}
io.WriteString(w, "null")
return 0, nil
case "user_role":
if v := ctx.Value(userRoleKey); v != nil {
return stringVar(w, v.(string))
}
io.WriteString(w, "null")
return 0, nil
} }
fields := jsn.Get(ctx.req.Vars, [][]byte{[]byte(tag)}) fields := jsn.Get(ctx.req.Vars, [][]byte{[]byte(tag)})

View File

@ -80,7 +80,7 @@ SQL Output
account_id: "select account_id from users where id = $user_id" account_id: "select account_id from users where id = $user_id"
defaults: defaults:
filter: ["{ user_id: { eq: $user_id } }"] Filters: ["{ user_id: { eq: $user_id } }"]
blacklist: blacklist:
- password - password
@ -88,14 +88,14 @@ SQL Output
fields: fields:
- name: users - name: users
filter: ["{ id: { eq: $user_id } }"] Filters: ["{ id: { eq: $user_id } }"]
- name: products - name: products
filter: [ Filters: [
"{ price: { gt: 0 } }", "{ price: { gt: 0 } }",
"{ price: { lt: 8 } }" "{ price: { lt: 8 } }"
] ]
- name: me - name: me
table: users table: users
filter: ["{ id: { eq: $user_id } }"] Filters: ["{ id: { eq: $user_id } }"]

View File

@ -1,4 +1,4 @@
app_name: "{% app_name %} Development" app_name: "Super Graph Development"
host_port: 0.0.0.0:8080 host_port: 0.0.0.0:8080
web_ui: true web_ui: true
@ -53,7 +53,7 @@ auth:
# Comment this out if you want to disable setting # Comment this out if you want to disable setting
# the user_id via a header. Good for testing # the user_id via a header. Good for testing
header: X-User-ID creds_in_header: true
rails: rails:
# Rails version this is used for reading the # Rails version this is used for reading the
@ -84,7 +84,7 @@ database:
type: postgres type: postgres
host: db host: db
port: 5432 port: 5432
dbname: {% app_name_slug %}_development dbname: app_development
user: postgres user: postgres
password: '' password: ''
@ -100,7 +100,7 @@ database:
# Define defaults to for the field key and values below # Define defaults to for the field key and values below
defaults: defaults:
# filter: ["{ user_id: { eq: $user_id } }"] # filters: ["{ user_id: { eq: $user_id } }"]
# Field and table names that you wish to block # Field and table names that you wish to block
blocklist: blocklist:
@ -112,25 +112,7 @@ database:
- token - token
tables: tables:
- name: users
# This filter will overwrite defaults.filter
# filter: ["{ id: { eq: $user_id } }"]
# filter_query: ["{ id: { eq: $user_id } }"]
filter_update: ["{ id: { eq: $user_id } }"]
filter_delete: ["{ id: { eq: $user_id } }"]
# - name: products
# # Multiple filters are AND'd together
# filter: [
# "{ price: { gt: 0 } }",
# "{ price: { lt: 8 } }"
# ]
- name: customers - name: customers
# No filter is used for this field not
# even defaults.filter
filter: none
remotes: remotes:
- name: payments - name: payments
id: stripe_id id: stripe_id
@ -149,7 +131,61 @@ database:
# real db table backing them # real db table backing them
name: me name: me
table: users table: users
filter: ["{ id: { eq: $user_id } }"]
# - name: posts roles_query: "SELECT * FROM users as usr WHERE id = $user_id"
# filter: ["{ account_id: { _eq: $account_id } }"]
roles:
- name: anon
tables:
- name: products
limit: 10
query:
columns: ["id", "name", "description" ]
aggregation: false
insert:
allow: false
update:
allow: false
delete:
allow: false
- name: user
tables:
- name: users
query:
filters: ["{ id: { _eq: $user_id } }"]
- name: products
query:
limit: 50
filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
disable_aggregation: false
insert:
filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
set:
- created_at: "now"
update:
filters: ["{ user_id: { eq: $user_id } }"]
columns:
- id
- name
set:
- updated_at: "now"
delete:
deny: true
- name: admin
match: id = 1
tables:
- name: users
# query:
# filters: ["{ account_id: { _eq: $account_id } }"]

View File

@ -1,4 +1,4 @@
app_name: "{% app_name %} Production" app_name: "Super Graph Production"
host_port: 0.0.0.0:8080 host_port: 0.0.0.0:8080
web_ui: false web_ui: false
@ -47,10 +47,6 @@ auth:
type: rails type: rails
cookie: _app_session cookie: _app_session
# Comment this out if you want to disable setting
# the user_id via a header. Good for testing
header: X-User-ID
rails: rails:
# Rails version this is used for reading the # Rails version this is used for reading the
# various cookies formats. # various cookies formats.
@ -80,7 +76,7 @@ database:
type: postgres type: postgres
host: db host: db
port: 5432 port: 5432
dbname: {% app_name_slug %}_production dbname: {{app_name_slug}}_development
user: postgres user: postgres
password: '' password: ''
#pool_size: 10 #pool_size: 10
@ -94,7 +90,7 @@ database:
# Define defaults to for the field key and values below # Define defaults to for the field key and values below
defaults: defaults:
filter: ["{ user_id: { eq: $user_id } }"] filters: ["{ user_id: { eq: $user_id } }"]
# Field and table names that you wish to block # Field and table names that you wish to block
blocklist: blocklist:
@ -106,25 +102,7 @@ database:
- token - token
tables: tables:
- name: users
# This filter will overwrite defaults.filter
# filter: ["{ id: { eq: $user_id } }"]
# filter_query: ["{ id: { eq: $user_id } }"]
filter_update: ["{ id: { eq: $user_id } }"]
filter_delete: ["{ id: { eq: $user_id } }"]
- name: products
# Multiple filters are AND'd together
filter: [
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }"
]
- name: customers - name: customers
# No filter is used for this field not
# even defaults.filter
filter: none
# remotes: # remotes:
# - name: payments # - name: payments
# id: stripe_id # id: stripe_id
@ -141,7 +119,61 @@ database:
# real db table backing them # real db table backing them
name: me name: me
table: users table: users
filter: ["{ id: { eq: $user_id } }"]
# - name: posts roles_query: "SELECT * FROM users as usr WHERE id = $user_id"
# filter: ["{ account_id: { _eq: $account_id } }"]
roles:
- name: anon
tables:
- name: products
limit: 10
query:
columns: ["id", "name", "description" ]
aggregation: false
insert:
allow: false
update:
allow: false
delete:
allow: false
- name: user
tables:
- name: users
query:
filters: ["{ id: { _eq: $user_id } }"]
- name: products
query:
limit: 50
filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
disable_aggregation: false
insert:
filters: ["{ user_id: { eq: $user_id } }"]
columns: ["id", "name", "description" ]
set:
- created_at: "now"
update:
filters: ["{ user_id: { eq: $user_id } }"]
columns:
- id
- name
set:
- updated_at: "now"
delete:
deny: true
- name: admin
match: id = 1
tables:
- name: users
# query:
# filters: ["{ account_id: { _eq: $account_id } }"]