diff --git a/.gitignore b/.gitignore index 5c88189..fdcd5db 100644 --- a/.gitignore +++ b/.gitignore @@ -28,3 +28,4 @@ main .DS_Store .swp main +super-graph diff --git a/.wtc.yaml b/.wtc.yaml new file mode 100644 index 0000000..215573f --- /dev/null +++ b/.wtc.yaml @@ -0,0 +1,13 @@ +no_trace: false +debounce: 300 # if rule has no debounce, this will be used instead +ignore: \.git/ +trig: [start, run] # will run on start +rules: + - name: start + - name: run + match: \.go$ + ignore: web|examples|docs|_test\.go$ + command: go run main.go serv + - name: test + match: _test\.go$ + command: go test -cover {PKG} \ No newline at end of file diff --git a/Dockerfile b/Dockerfile index c2487cc..805cc0d 100644 --- a/Dockerfile +++ b/Dockerfile @@ -11,9 +11,8 @@ RUN apk update && \ apk add --no-cache git && \ apk add --no-cache upx=3.95-r2 -RUN go get -u github.com/shanzi/wu && \ - go install github.com/shanzi/wu && \ - go get github.com/GeertJohan/go.rice/rice +RUN go get -u github.com/rafaelsq/wtc && \ + go get -u github.com/GeertJohan/go.rice/rice WORKDIR /app COPY . /app diff --git a/README.md b/README.md index ec614f4..e6fbee3 100644 --- a/README.md +++ b/README.md @@ -46,6 +46,9 @@ This compiler is what sits at the heart of Super Graph with layers of useful fun ## Contact me +I'm happy to help you deploy Super Graph so feel free to reach out over +Twitter or Discord. 
+ [twitter/dosco](https://twitter.com/dosco) [chat/super-graph](https://discord.gg/6pSWCTZ) diff --git a/config/allow.list b/config/allow.list index a17a562..196e46b 100644 --- a/config/allow.list +++ b/config/allow.list @@ -1,5 +1,27 @@ # http://localhost:8080/ +variables { + "data": [ + { + "name": "Protect Ya Neck", + "created_at": "now", + "updated_at": "now" + }, + { + "name": "Enter the Wu-Tang", + "created_at": "now", + "updated_at": "now" + } + ] +} + +mutation { + products(insert: $data) { + id + name + } +} + variables { "update": { "name": "Wu-Tang", @@ -16,16 +38,16 @@ mutation { } } -variables { - "data": { - "product_id": 5 - } -} - -mutation { - products(id: $product_id, delete: true) { +query { + users { id - name + email + picture: avatar + products(limit: 2, where: {price: {gt: 10}}) { + id + name + description + } } } @@ -73,6 +95,118 @@ query { } } +variables { + "data": [ + { + "name": "Gumbo1", + "created_at": "now", + "updated_at": "now" + }, + { + "name": "Gumbo2", + "created_at": "now", + "updated_at": "now" + } + ] +} + +query { + products { + id + name + user { + email + } + } +} + +variables { + "data": { + "product_id": 5 + } +} + +mutation { + products(id: $product_id, delete: true) { + id + name + } +} + +variables { + "data": [ + { + "name": "Gumbo1", + "created_at": "now", + "updated_at": "now" + }, + { + "name": "Gumbo2", + "created_at": "now", + "updated_at": "now" + } + ] +} + +query { + products { + id + name + price + users { + email + } + } +} + + +variables { + "data": { + "email": "gfk@myspace.com", + "full_name": "Ghostface Killah", + "created_at": "now", + "updated_at": "now" + } +} + +mutation { + user(insert: $data) { + id + } +} + +variables { + "data": [ + { + "name": "Gumbo1", + "created_at": "now", + "updated_at": "now" + }, + { + "name": "Gumbo2", + "created_at": "now", + "updated_at": "now" + } + ] +} + +query { + products { + id + name + users { + email + } + } +} + +query { + me { + id + email + full_name + } +} 
variables { "update": { @@ -112,62 +246,30 @@ query { } } - -query { - me { - id - email - full_name - } -} - -variables { - "data": { - "email": "gfk@myspace.com", - "full_name": "Ghostface Killah", - "created_at": "now", - "updated_at": "now" - } -} - -mutation { - user(insert: $data) { - id - } -} - -query { - users { - id - email - picture: avatar - products(limit: 2, where: {price: {gt: 10}}) { - id - name - description - } - } -} - variables { "data": [ { - "name": "Protect Ya Neck", + "name": "Gumbo1", "created_at": "now", "updated_at": "now" }, { - "name": "Enter the Wu-Tang", + "name": "Gumbo2", "created_at": "now", "updated_at": "now" } ] } -mutation { - products(insert: $data) { +query { + products { id name + description + users { + email + } } } + diff --git a/config/dev.yml b/config/dev.yml index 7c3ba46..ffe9f64 100644 --- a/config/dev.yml +++ b/config/dev.yml @@ -22,7 +22,7 @@ enable_tracing: true # Watch the config folder and reload Super Graph # with the new configs when a change is detected -reload_on_config_change: false +reload_on_config_change: true # File that points to the database seeding script # seed_file: seed.js @@ -53,7 +53,7 @@ auth: # Comment this out if you want to disable setting # the user_id via a header. 
Good for testing - header: X-User-ID + creds_in_header: true rails: # Rails version this is used for reading the @@ -100,7 +100,7 @@ database: # Define defaults to for the field key and values below defaults: - # filter: ["{ user_id: { eq: $user_id } }"] + # filters: ["{ user_id: { eq: $user_id } }"] # Field and table names that you wish to block blocklist: @@ -112,25 +112,7 @@ database: - token tables: - - name: users - # This filter will overwrite defaults.filter - # filter: ["{ id: { eq: $user_id } }"] - # filter_query: ["{ id: { eq: $user_id } }"] - filter_update: ["{ id: { eq: $user_id } }"] - filter_delete: ["{ id: { eq: $user_id } }"] - - # - name: products - # # Multiple filters are AND'd together - # filter: [ - # "{ price: { gt: 0 } }", - # "{ price: { lt: 8 } }" - # ] - - name: customers - # No filter is used for this field not - # even defaults.filter - filter: none - remotes: - name: payments id: stripe_id @@ -149,7 +131,61 @@ tables: # real db table backing them name: me table: users - filter: ["{ id: { eq: $user_id } }"] - # - name: posts - # filter: ["{ account_id: { _eq: $account_id } }"] \ No newline at end of file +roles_query: "SELECT * FROM users as usr WHERE id = $user_id" + +roles: + - name: anon + tables: + - name: products + limit: 10 + + query: + columns: ["id", "name", "description" ] + aggregation: false + + insert: + allow: false + + update: + allow: false + + delete: + allow: false + + - name: user + tables: + - name: users + query: + filters: ["{ id: { _eq: $user_id } }"] + + - name: products + query: + limit: 50 + filters: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + disable_aggregation: false + + insert: + filters: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + set: + - created_at: "now" + + update: + filters: ["{ user_id: { eq: $user_id } }"] + columns: + - id + - name + set: + - updated_at: "now" + + delete: + deny: true + + - name: admin + match: id = 1 + tables: + - 
name: users + # query: + # filters: ["{ account_id: { _eq: $account_id } }"] diff --git a/config/prod.yml b/config/prod.yml index fa5f932..95abfb7 100644 --- a/config/prod.yml +++ b/config/prod.yml @@ -47,10 +47,6 @@ auth: type: rails cookie: _app_session - # Comment this out if you want to disable setting - # the user_id via a header. Good for testing - header: X-User-ID - rails: # Rails version this is used for reading the # various cookies formats. @@ -94,7 +90,7 @@ database: # Define defaults to for the field key and values below defaults: - filter: ["{ user_id: { eq: $user_id } }"] + filters: ["{ user_id: { eq: $user_id } }"] # Field and table names that you wish to block blocklist: @@ -106,25 +102,7 @@ database: - token tables: - - name: users - # This filter will overwrite defaults.filter - # filter: ["{ id: { eq: $user_id } }"] - # filter_query: ["{ id: { eq: $user_id } }"] - filter_update: ["{ id: { eq: $user_id } }"] - filter_delete: ["{ id: { eq: $user_id } }"] - - - name: products - # Multiple filters are AND'd together - filter: [ - "{ price: { gt: 0 } }", - "{ price: { lt: 8 } }" - ] - - name: customers - # No filter is used for this field not - # even defaults.filter - filter: none - # remotes: # - name: payments # id: stripe_id @@ -141,7 +119,61 @@ tables: # real db table backing them name: me table: users - filter: ["{ id: { eq: $user_id } }"] - # - name: posts - # filter: ["{ account_id: { _eq: $account_id } }"] \ No newline at end of file +roles_query: "SELECT * FROM users as usr WHERE id = $user_id" + +roles: + - name: anon + tables: + - name: products + limit: 10 + + query: + columns: ["id", "name", "description" ] + aggregation: false + + insert: + allow: false + + update: + allow: false + + delete: + allow: false + + - name: user + tables: + - name: users + query: + filters: ["{ id: { _eq: $user_id } }"] + + - name: products + query: + limit: 50 + filters: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + 
disable_aggregation: false + + insert: + filters: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + set: + - created_at: "now" + + update: + filters: ["{ user_id: { eq: $user_id } }"] + columns: + - id + - name + set: + - updated_at: "now" + + delete: + deny: true + + - name: admin + match: id = 1 + tables: + - name: users + # query: + # filters: ["{ account_id: { _eq: $account_id } }"] diff --git a/docker-compose.yml b/docker-compose.yml index d41e9f5..b3beb6e 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -34,7 +34,7 @@ services: volumes: - .:/app working_dir: /app - command: wu -pattern="*.go" go run main.go serv + command: wtc depends_on: - db - rails_app diff --git a/docs/guide.md b/docs/guide.md index e0aacef..55f0b43 100644 --- a/docs/guide.md +++ b/docs/guide.md @@ -1043,26 +1043,35 @@ We're tried to ensure that the config file is self documenting and easy to work app_name: "Super Graph Development" host_port: 0.0.0.0:8080 web_ui: true -debug_level: 1 -# debug, info, warn, error, fatal, panic, disable -log_level: "info" +# debug, info, warn, error, fatal, panic +log_level: "debug" # Disable this in development to get a list of # queries used. 
When enabled super graph # will only allow queries from this list # List saved to ./config/allow.list -use_allow_list: true +use_allow_list: false # Throw a 401 on auth failure for queries that need auth # valid values: always, per_query, never -auth_fail_block: always +auth_fail_block: never # Latency tracing for database queries and remote joins # the resulting latency information is returned with the # response enable_tracing: true +# Watch the config folder and reload Super Graph +# with the new configs when a change is detected +reload_on_config_change: true + +# File that points to the database seeding script +# seed_file: seed.js + +# Path pointing to where the migrations can be found +migrations_path: ./config/migrations + # Postgres related environment Variables # SG_DATABASE_HOST # SG_DATABASE_PORT @@ -1086,7 +1095,7 @@ auth: # Comment this out if you want to disable setting # the user_id via a header. Good for testing - header: X-User-ID + creds_in_header: true rails: # Rails version this is used for reading the @@ -1097,10 +1106,10 @@ auth: secret_key_base: 0a248500a64c01184edb4d7ad3a805488f8097ac761b76aaa6c17c01dcb7af03a2f18ba61b2868134b9c7b79a122bc0dadff4367414a2d173297bfea92be5566 # Remote cookie store. 
(memcache or redis) - # url: redis://127.0.0.1:6379 - # password: test - # max_idle: 80, - # max_active: 12000, + # url: redis://redis:6379 + # password: "" + # max_idle: 80 + # max_active: 12000 # In most cases you don't need these # salt: "encrypted cookie" @@ -1120,20 +1129,23 @@ database: dbname: app_development user: postgres password: '' - # pool_size: 10 - # max_retries: 0 - # log_level: "debug" + + #schema: "public" + #pool_size: 10 + #max_retries: 0 + #log_level: "debug" # Define variables here that you want to use in filters + # sub-queries must be wrapped in () variables: - account_id: "select account_id from users where id = $user_id" + account_id: "(select account_id from users where id = $user_id)" # Define defaults to for the field key and values below defaults: - filter: ["{ user_id: { eq: $user_id } }"] + # filters: ["{ user_id: { eq: $user_id } }"] # Field and table names that you wish to block - blacklist: + blocklist: - ar_internal_metadata - schema_migrations - secret @@ -1141,43 +1153,85 @@ database: - encrypted - token - tables: - - name: users - # This filter will overwrite defaults.filter - filter: ["{ id: { eq: $user_id } }"] +tables: + - name: customers + remotes: + - name: payments + id: stripe_id + url: http://rails_app:3000/stripe/$id + path: data + # debug: true + pass_headers: + - cookie + set_headers: + - name: Host + value: 0.0.0.0 + # - name: Authorization + # value: Bearer - - name: products - # Multiple filters are AND'd together - filter: [ - "{ price: { gt: 0 } }", - "{ price: { lt: 8 } }" - ] + - # You can create new fields that have a + # real db table backing them + name: me + table: users - - name: customers - # No filter is used for this field not - # even defaults.filter - filter: none +roles_query: "SELECT * FROM users as usr WHERE id = $user_id" - remotes: - - name: payments - id: stripe_id - url: http://rails_app:3000/stripe/$id - path: data - # pass_headers: - # - cookie - # - host - set_headers: - - name: 
Authorization - value: Bearer +roles: + - name: anon + tables: + - name: products + limit: 10 - - # You can create new fields that have a - # real db table backing them - name: me - table: users - filter: ["{ id: { eq: $user_id } }"] + query: + columns: ["id", "name", "description" ] + aggregation: false + + insert: + allow: false + + update: + allow: false + + delete: + allow: false + + - name: user + tables: + - name: users + query: + filters: ["{ id: { _eq: $user_id } }"] + + - name: products + query: + limit: 50 + filters: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + disable_aggregation: false + + insert: + filters: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + set: + - created_at: "now" + + update: + filters: ["{ user_id: { eq: $user_id } }"] + columns: + - id + - name + set: + - updated_at: "now" + + delete: + deny: true + + - name: admin + match: id = 1 + tables: + - name: users + # query: + # filters: ["{ account_id: { _eq: $account_id } }"] - # - name: posts - # filter: ["{ account_id: { _eq: $account_id } }"] ``` If deploying into environments like Kubernetes it's useful to be able to configure things like secrets and hosts though environment variables therfore we expose the below environment variables. This is escpecially useful for secrets since they are usually injected in via a secrets management framework ie. 
Kubernetes Secrets diff --git a/examples/rails-app/Gemfile.lock b/examples/rails-app/Gemfile.lock index e69de29..7a6e370 100644 --- a/examples/rails-app/Gemfile.lock +++ b/examples/rails-app/Gemfile.lock @@ -0,0 +1,273 @@ +GIT + remote: https://github.com/stympy/faker.git + revision: 4e9144825fcc9ba5c83cc0fd037779ab82f3120b + branch: master + specs: + faker (2.6.0) + i18n (>= 1.6, < 1.8) + +GEM + remote: https://rubygems.org/ + specs: + actioncable (6.0.0) + actionpack (= 6.0.0) + nio4r (~> 2.0) + websocket-driver (>= 0.6.1) + actionmailbox (6.0.0) + actionpack (= 6.0.0) + activejob (= 6.0.0) + activerecord (= 6.0.0) + activestorage (= 6.0.0) + activesupport (= 6.0.0) + mail (>= 2.7.1) + actionmailer (6.0.0) + actionpack (= 6.0.0) + actionview (= 6.0.0) + activejob (= 6.0.0) + mail (~> 2.5, >= 2.5.4) + rails-dom-testing (~> 2.0) + actionpack (6.0.0) + actionview (= 6.0.0) + activesupport (= 6.0.0) + rack (~> 2.0) + rack-test (>= 0.6.3) + rails-dom-testing (~> 2.0) + rails-html-sanitizer (~> 1.0, >= 1.2.0) + actiontext (6.0.0) + actionpack (= 6.0.0) + activerecord (= 6.0.0) + activestorage (= 6.0.0) + activesupport (= 6.0.0) + nokogiri (>= 1.8.5) + actionview (6.0.0) + activesupport (= 6.0.0) + builder (~> 3.1) + erubi (~> 1.4) + rails-dom-testing (~> 2.0) + rails-html-sanitizer (~> 1.1, >= 1.2.0) + activejob (6.0.0) + activesupport (= 6.0.0) + globalid (>= 0.3.6) + activemodel (6.0.0) + activesupport (= 6.0.0) + activerecord (6.0.0) + activemodel (= 6.0.0) + activesupport (= 6.0.0) + activestorage (6.0.0) + actionpack (= 6.0.0) + activejob (= 6.0.0) + activerecord (= 6.0.0) + marcel (~> 0.3.1) + activesupport (6.0.0) + concurrent-ruby (~> 1.0, >= 1.0.2) + i18n (>= 0.7, < 2) + minitest (~> 5.1) + tzinfo (~> 1.1) + zeitwerk (~> 2.1, >= 2.1.8) + addressable (2.7.0) + public_suffix (>= 2.0.2, < 5.0) + archive-zip (0.12.0) + io-like (~> 0.3.0) + bcrypt (3.1.13) + bindex (0.8.1) + bootsnap (1.4.5) + msgpack (~> 1.0) + builder (3.2.3) + byebug (11.0.1) + capybara (3.29.0) 
+ addressable + mini_mime (>= 0.1.3) + nokogiri (~> 1.8) + rack (>= 1.6.0) + rack-test (>= 0.6.3) + regexp_parser (~> 1.5) + xpath (~> 3.2) + childprocess (3.0.0) + chromedriver-helper (2.1.1) + archive-zip (~> 0.10) + nokogiri (~> 1.8) + coffee-rails (4.2.2) + coffee-script (>= 2.2.0) + railties (>= 4.0.0) + coffee-script (2.4.1) + coffee-script-source + execjs + coffee-script-source (1.12.2) + concurrent-ruby (1.1.5) + crass (1.0.4) + devise (4.7.1) + bcrypt (~> 3.0) + orm_adapter (~> 0.1) + railties (>= 4.1.0) + responders + warden (~> 1.2.3) + erubi (1.9.0) + execjs (2.7.0) + ffi (1.11.1) + globalid (0.4.2) + activesupport (>= 4.2.0) + i18n (1.7.0) + concurrent-ruby (~> 1.0) + io-like (0.3.0) + jbuilder (2.9.1) + activesupport (>= 4.2.0) + listen (3.1.5) + rb-fsevent (~> 0.9, >= 0.9.4) + rb-inotify (~> 0.9, >= 0.9.7) + ruby_dep (~> 1.2) + loofah (2.3.0) + crass (~> 1.0.2) + nokogiri (>= 1.5.9) + mail (2.7.1) + mini_mime (>= 0.1.1) + marcel (0.3.3) + mimemagic (~> 0.3.2) + method_source (0.9.2) + mimemagic (0.3.3) + mini_mime (1.0.2) + mini_portile2 (2.4.0) + minitest (5.12.2) + msgpack (1.3.1) + nio4r (2.5.2) + nokogiri (1.10.4) + mini_portile2 (~> 2.4.0) + orm_adapter (0.5.0) + pg (1.1.4) + public_suffix (4.0.1) + puma (3.12.1) + rack (2.0.7) + rack-test (1.1.0) + rack (>= 1.0, < 3) + rails (6.0.0) + actioncable (= 6.0.0) + actionmailbox (= 6.0.0) + actionmailer (= 6.0.0) + actionpack (= 6.0.0) + actiontext (= 6.0.0) + actionview (= 6.0.0) + activejob (= 6.0.0) + activemodel (= 6.0.0) + activerecord (= 6.0.0) + activestorage (= 6.0.0) + activesupport (= 6.0.0) + bundler (>= 1.3.0) + railties (= 6.0.0) + sprockets-rails (>= 2.0.0) + rails-dom-testing (2.0.3) + activesupport (>= 4.2.0) + nokogiri (>= 1.6) + rails-html-sanitizer (1.3.0) + loofah (~> 2.3) + railties (6.0.0) + actionpack (= 6.0.0) + activesupport (= 6.0.0) + method_source + rake (>= 0.8.7) + thor (>= 0.20.3, < 2.0) + rake (13.0.0) + rb-fsevent (0.10.3) + rb-inotify (0.10.0) + ffi (~> 1.0) + redis 
(4.1.3) + redis-actionpack (5.1.0) + actionpack (>= 4.0, < 7) + redis-rack (>= 1, < 3) + redis-store (>= 1.1.0, < 2) + redis-activesupport (5.2.0) + activesupport (>= 3, < 7) + redis-store (>= 1.3, < 2) + redis-rack (2.0.6) + rack (>= 1.5, < 3) + redis-store (>= 1.2, < 2) + redis-rails (5.0.2) + redis-actionpack (>= 5.0, < 6) + redis-activesupport (>= 5.0, < 6) + redis-store (>= 1.2, < 2) + redis-store (1.8.0) + redis (>= 4, < 5) + regexp_parser (1.6.0) + responders (3.0.0) + actionpack (>= 5.0) + railties (>= 5.0) + ruby_dep (1.5.0) + rubyzip (2.0.0) + sass (3.7.4) + sass-listen (~> 4.0.0) + sass-listen (4.0.0) + rb-fsevent (~> 0.9, >= 0.9.4) + rb-inotify (~> 0.9, >= 0.9.7) + sass-rails (5.1.0) + railties (>= 5.2.0) + sass (~> 3.1) + sprockets (>= 2.8, < 4.0) + sprockets-rails (>= 2.0, < 4.0) + tilt (>= 1.1, < 3) + selenium-webdriver (3.142.6) + childprocess (>= 0.5, < 4.0) + rubyzip (>= 1.2.2) + spring (2.1.0) + spring-watcher-listen (2.0.1) + listen (>= 2.7, < 4.0) + spring (>= 1.2, < 3.0) + sprockets (3.7.2) + concurrent-ruby (~> 1.0) + rack (> 1, < 3) + sprockets-rails (3.2.1) + actionpack (>= 4.0) + activesupport (>= 4.0) + sprockets (>= 3.0.0) + thor (0.20.3) + thread_safe (0.3.6) + tilt (2.0.10) + turbolinks (5.2.1) + turbolinks-source (~> 5.2) + turbolinks-source (5.2.0) + tzinfo (1.2.5) + thread_safe (~> 0.1) + uglifier (4.2.0) + execjs (>= 0.3.0, < 3) + warden (1.2.8) + rack (>= 2.0.6) + web-console (4.0.1) + actionview (>= 6.0.0) + activemodel (>= 6.0.0) + bindex (>= 0.4.0) + railties (>= 6.0.0) + websocket-driver (0.7.1) + websocket-extensions (>= 0.1.0) + websocket-extensions (0.1.4) + xpath (3.2.0) + nokogiri (~> 1.8) + zeitwerk (2.2.0) + +PLATFORMS + ruby + +DEPENDENCIES + bootsnap (>= 1.1.0) + byebug + capybara (>= 2.15) + chromedriver-helper + coffee-rails (~> 4.2) + devise + faker! 
+ jbuilder (~> 2.5) + listen (>= 3.0.5, < 3.2) + pg (>= 0.18, < 2.0) + puma (~> 3.11) + rails (~> 6.0.0.rc1) + redis-rails + sass-rails (~> 5.0) + selenium-webdriver + spring + spring-watcher-listen (~> 2.0.0) + turbolinks (~> 5) + tzinfo-data + uglifier (>= 1.3.0) + web-console (>= 3.3.0) + +RUBY VERSION + ruby 2.5.7p206 + +BUNDLED WITH + 1.17.3 diff --git a/psql/insert.go b/psql/mutate.go similarity index 57% rename from psql/insert.go rename to psql/mutate.go index 027588e..84bb122 100644 --- a/psql/insert.go +++ b/psql/mutate.go @@ -1,7 +1,6 @@ package psql import ( - "bytes" "errors" "fmt" "io" @@ -10,9 +9,9 @@ import ( "github.com/dosco/super-graph/qcode" ) -var zeroPaging = qcode.Paging{} +var noLimit = qcode.Paging{NoLimit: true} -func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) { +func (co *Compiler) compileMutation(qc *qcode.QCode, w io.Writer, vars Variables) (uint32, error) { if len(qc.Selects) == 0 { return 0, errors.New("empty query") } @@ -25,27 +24,27 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia return 0, err } - c.w.WriteString(`WITH `) + io.WriteString(c.w, `WITH `) quoted(c.w, ti.Name) - c.w.WriteString(` AS `) + io.WriteString(c.w, ` AS `) - switch root.Action { - case qcode.ActionInsert: + switch qc.Type { + case qcode.QTInsert: if _, err := c.renderInsert(qc, w, vars, ti); err != nil { return 0, err } - case qcode.ActionUpdate: + case qcode.QTUpdate: if _, err := c.renderUpdate(qc, w, vars, ti); err != nil { return 0, err } - case qcode.ActionUpsert: + case qcode.QTUpsert: if _, err := c.renderUpsert(qc, w, vars, ti); err != nil { return 0, err } - case qcode.ActionDelete: + case qcode.QTDelete: if _, err := c.renderDelete(qc, w, vars, ti); err != nil { return 0, err } @@ -56,7 +55,7 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia io.WriteString(c.w, ` RETURNING *) `) - root.Paging = zeroPaging + root.Paging = noLimit 
root.DistinctOn = root.DistinctOn[:] root.OrderBy = root.OrderBy[:] root.Where = nil @@ -65,13 +64,12 @@ func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Varia return c.compileQuery(qc, w) } -func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer, +func (c *compilerContext) renderInsert(qc *qcode.QCode, w io.Writer, vars Variables, ti *DBTableInfo) (uint32, error) { - root := &qc.Selects[0] - insert, ok := vars[root.ActionVar] + insert, ok := vars[qc.ActionVar] if !ok { - return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar) + return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar) } jt, array, err := jsn.Tree(insert) @@ -79,56 +77,62 @@ func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer, return 0, err } - c.w.WriteString(`(WITH "input" AS (SELECT {{`) - c.w.WriteString(root.ActionVar) - c.w.WriteString(`}}::json AS j) INSERT INTO `) + io.WriteString(c.w, `(WITH "input" AS (SELECT {{`) + io.WriteString(c.w, qc.ActionVar) + io.WriteString(c.w, `}}::json AS j) INSERT INTO `) quoted(c.w, ti.Name) io.WriteString(c.w, ` (`) c.renderInsertUpdateColumns(qc, w, jt, ti) io.WriteString(c.w, `)`) - c.w.WriteString(` SELECT `) + io.WriteString(c.w, ` SELECT `) c.renderInsertUpdateColumns(qc, w, jt, ti) - c.w.WriteString(` FROM input i, `) + io.WriteString(c.w, ` FROM input i, `) if array { - c.w.WriteString(`json_populate_recordset`) + io.WriteString(c.w, `json_populate_recordset`) } else { - c.w.WriteString(`json_populate_record`) + io.WriteString(c.w, `json_populate_record`) } - c.w.WriteString(`(NULL::`) - c.w.WriteString(ti.Name) - c.w.WriteString(`, i.j) t`) + io.WriteString(c.w, `(NULL::`) + io.WriteString(c.w, ti.Name) + io.WriteString(c.w, `, i.j) t`) return 0, nil } -func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w *bytes.Buffer, +func (c *compilerContext) renderInsertUpdateColumns(qc *qcode.QCode, w io.Writer, jt map[string]interface{}, ti *DBTableInfo) (uint32, 
error) { + root := &qc.Selects[0] i := 0 for _, cn := range ti.ColumnNames { if _, ok := jt[cn]; !ok { continue } + if len(root.Allowed) != 0 { + if _, ok := root.Allowed[cn]; !ok { + continue + } + } if i != 0 { io.WriteString(c.w, `, `) } - c.w.WriteString(cn) + io.WriteString(c.w, cn) i++ } return 0, nil } -func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer, +func (c *compilerContext) renderUpdate(qc *qcode.QCode, w io.Writer, vars Variables, ti *DBTableInfo) (uint32, error) { root := &qc.Selects[0] - update, ok := vars[root.ActionVar] + update, ok := vars[qc.ActionVar] if !ok { - return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar) + return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar) } jt, array, err := jsn.Tree(update) @@ -136,26 +140,26 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer, return 0, err } - c.w.WriteString(`(WITH "input" AS (SELECT {{`) - c.w.WriteString(root.ActionVar) - c.w.WriteString(`}}::json AS j) UPDATE `) + io.WriteString(c.w, `(WITH "input" AS (SELECT {{`) + io.WriteString(c.w, qc.ActionVar) + io.WriteString(c.w, `}}::json AS j) UPDATE `) quoted(c.w, ti.Name) io.WriteString(c.w, ` SET (`) c.renderInsertUpdateColumns(qc, w, jt, ti) - c.w.WriteString(`) = (SELECT `) + io.WriteString(c.w, `) = (SELECT `) c.renderInsertUpdateColumns(qc, w, jt, ti) - c.w.WriteString(` FROM input i, `) + io.WriteString(c.w, ` FROM input i, `) if array { - c.w.WriteString(`json_populate_recordset`) + io.WriteString(c.w, `json_populate_recordset`) } else { - c.w.WriteString(`json_populate_record`) + io.WriteString(c.w, `json_populate_record`) } - c.w.WriteString(`(NULL::`) - c.w.WriteString(ti.Name) - c.w.WriteString(`, i.j) t)`) + io.WriteString(c.w, `(NULL::`) + io.WriteString(c.w, ti.Name) + io.WriteString(c.w, `, i.j) t)`) io.WriteString(c.w, ` WHERE `) @@ -166,11 +170,11 @@ func (c *compilerContext) renderUpdate(qc *qcode.QCode, w *bytes.Buffer, return 0, nil } -func (c 
*compilerContext) renderDelete(qc *qcode.QCode, w *bytes.Buffer, +func (c *compilerContext) renderDelete(qc *qcode.QCode, w io.Writer, vars Variables, ti *DBTableInfo) (uint32, error) { root := &qc.Selects[0] - c.w.WriteString(`(DELETE FROM `) + io.WriteString(c.w, `(DELETE FROM `) quoted(c.w, ti.Name) io.WriteString(c.w, ` WHERE `) @@ -181,13 +185,12 @@ func (c *compilerContext) renderDelete(qc *qcode.QCode, w *bytes.Buffer, return 0, nil } -func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer, +func (c *compilerContext) renderUpsert(qc *qcode.QCode, w io.Writer, vars Variables, ti *DBTableInfo) (uint32, error) { - root := &qc.Selects[0] - upsert, ok := vars[root.ActionVar] + upsert, ok := vars[qc.ActionVar] if !ok { - return 0, fmt.Errorf("Variable '%s' not defined", root.ActionVar) + return 0, fmt.Errorf("Variable '%s' not defined", qc.ActionVar) } jt, _, err := jsn.Tree(upsert) @@ -199,7 +202,7 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer, return 0, err } - c.w.WriteString(` ON CONFLICT DO (`) + io.WriteString(c.w, ` ON CONFLICT DO (`) i := 0 for _, cn := range ti.ColumnNames { @@ -214,15 +217,15 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer, if i != 0 { io.WriteString(c.w, `, `) } - c.w.WriteString(cn) + io.WriteString(c.w, cn) i++ } if i == 0 { - c.w.WriteString(ti.PrimaryCol) + io.WriteString(c.w, ti.PrimaryCol) } - c.w.WriteString(`) DO `) + io.WriteString(c.w, `) DO `) - c.w.WriteString(`UPDATE `) + io.WriteString(c.w, `UPDATE `) io.WriteString(c.w, ` SET `) i = 0 @@ -233,17 +236,17 @@ func (c *compilerContext) renderUpsert(qc *qcode.QCode, w *bytes.Buffer, if i != 0 { io.WriteString(c.w, `, `) } - c.w.WriteString(cn) + io.WriteString(c.w, cn) io.WriteString(c.w, ` = EXCLUDED.`) - c.w.WriteString(cn) + io.WriteString(c.w, cn) i++ } return 0, nil } -func quoted(w *bytes.Buffer, identifier string) { - w.WriteString(`"`) - w.WriteString(identifier) - w.WriteString(`"`) +func 
quoted(w io.Writer, identifier string) { + io.WriteString(w, `"`) + io.WriteString(w, identifier) + io.WriteString(w, `"`) } diff --git a/psql/insert_test.go b/psql/mutate_test.go similarity index 90% rename from psql/insert_test.go rename to psql/mutate_test.go index bd03d7f..390c301 100644 --- a/psql/insert_test.go +++ b/psql/mutate_test.go @@ -12,13 +12,13 @@ func simpleInsert(t *testing.T) { } }` - sql := `WITH "users" AS (WITH "input" AS (SELECT {{data}}::json AS j) INSERT INTO "users" (full_name, email) SELECT full_name, email FROM input i, json_populate_record(NULL::users, i.j) t RETURNING *) SELECT json_object_agg('user', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."id" AS "id") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."id" FROM "users") AS "users_0") AS "done_1337";` + sql := `WITH "users" AS (WITH "input" AS (SELECT {{data}}::json AS j) INSERT INTO "users" (full_name, email) SELECT full_name, email FROM input i, json_populate_record(NULL::users, i.j) t RETURNING *) SELECT json_object_agg('user', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."id" AS "id") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."id" FROM "users") AS "users_0") AS "done_1337"` vars := map[string]json.RawMessage{ "data": json.RawMessage(`{"email": "reannagreenholt@orn.com", "full_name": "Flo Barton"}`), } - resSQL, err := compileGQLToPSQL(gql, vars) + resSQL, err := compileGQLToPSQL(gql, vars, "user") if err != nil { t.Fatal(err) } @@ -36,13 +36,13 @@ func singleInsert(t *testing.T) { } }` - sql := `WITH "products" AS (WITH "input" AS (SELECT {{insert}}::json AS j) INSERT INTO "products" (name, description, user_id) SELECT name, description, user_id FROM input i, json_populate_record(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM 
(SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";` + sql := `WITH "products" AS (WITH "input" AS (SELECT {{insert}}::json AS j) INSERT INTO "products" (name, description, user_id) SELECT name, description, user_id FROM input i, json_populate_record(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"` vars := map[string]json.RawMessage{ "insert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc", "user_id": 5 }`), } - resSQL, err := compileGQLToPSQL(gql, vars) + resSQL, err := compileGQLToPSQL(gql, vars, "user") if err != nil { t.Fatal(err) } @@ -54,19 +54,19 @@ func singleInsert(t *testing.T) { func bulkInsert(t *testing.T) { gql := `mutation { - product(id: 15, insert: $insert) { + product(name: "test", id: 15, insert: $insert) { id name } }` - sql := `WITH "products" AS (WITH "input" AS (SELECT {{insert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_recordset(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";` + sql := `WITH "products" AS (WITH "input" AS (SELECT {{insert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_recordset(NULL::products, i.j) t RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) 
AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"` vars := map[string]json.RawMessage{ "insert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`), } - resSQL, err := compileGQLToPSQL(gql, vars) + resSQL, err := compileGQLToPSQL(gql, vars, "user") if err != nil { t.Fatal(err) } @@ -84,13 +84,13 @@ func singleUpsert(t *testing.T) { } }` - sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";` + sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_record(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"` vars := map[string]json.RawMessage{ "upsert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), } - resSQL, err := compileGQLToPSQL(gql, vars) + resSQL, err := compileGQLToPSQL(gql, vars, "user") if err != nil { t.Fatal(err) } @@ -108,13 +108,13 @@ func bulkUpsert(t *testing.T) { } }` - sql := `WITH "products" AS (WITH 
"input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_recordset(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";` + sql := `WITH "products" AS (WITH "input" AS (SELECT {{upsert}}::json AS j) INSERT INTO "products" (name, description) SELECT name, description FROM input i, json_populate_recordset(NULL::products, i.j) t ON CONFLICT DO (id) DO UPDATE SET name = EXCLUDED.name, description = EXCLUDED.description RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"` vars := map[string]json.RawMessage{ "upsert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`), } - resSQL, err := compileGQLToPSQL(gql, vars) + resSQL, err := compileGQLToPSQL(gql, vars, "user") if err != nil { t.Fatal(err) } @@ -132,13 +132,13 @@ func singleUpdate(t *testing.T) { } }` - sql := `WITH "products" AS (WITH "input" AS (SELECT {{update}}::json AS j) UPDATE "products" SET (name, description) = (SELECT name, description FROM input i, json_populate_record(NULL::products, i.j) t) WHERE (("products"."user_id") = {{user_id}}) AND (("products"."id") = 1) AND (("products"."id") = 15) RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" 
FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";` + sql := `WITH "products" AS (WITH "input" AS (SELECT {{update}}::json AS j) UPDATE "products" SET (name, description) = (SELECT name, description FROM input i, json_populate_record(NULL::products, i.j) t) WHERE (("products"."user_id") = {{user_id}}) AND (("products"."id") = 1) AND (("products"."id") = 15) RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"` vars := map[string]json.RawMessage{ "update": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), } - resSQL, err := compileGQLToPSQL(gql, vars) + resSQL, err := compileGQLToPSQL(gql, vars, "user") if err != nil { t.Fatal(err) } @@ -156,13 +156,13 @@ func delete(t *testing.T) { } }` - sql := `WITH "products" AS (DELETE FROM "products" WHERE (("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 1) RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337";` + sql := `WITH "products" AS (DELETE FROM "products" WHERE (("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 1) RETURNING *) SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products") AS "products_0") AS "done_1337"` vars := map[string]json.RawMessage{ "update": json.RawMessage(` { "name": 
"my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), } - resSQL, err := compileGQLToPSQL(gql, vars) + resSQL, err := compileGQLToPSQL(gql, vars, "user") if err != nil { t.Fatal(err) } @@ -172,7 +172,7 @@ func delete(t *testing.T) { } } -func TestCompileInsert(t *testing.T) { +func TestCompileMutate(t *testing.T) { t.Run("simpleInsert", simpleInsert) t.Run("singleInsert", singleInsert) t.Run("bulkInsert", bulkInsert) diff --git a/psql/select.go b/psql/query.go similarity index 65% rename from psql/select.go rename to psql/query.go index 1058a62..c06ab95 100644 --- a/psql/select.go +++ b/psql/query.go @@ -49,7 +49,7 @@ func (c *Compiler) IDColumn(table string) (string, error) { } type compilerContext struct { - w *bytes.Buffer + w io.Writer s []qcode.Select *Compiler } @@ -60,18 +60,18 @@ func (co *Compiler) CompileEx(qc *qcode.QCode, vars Variables) (uint32, []byte, return skipped, w.Bytes(), err } -func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) { +func (co *Compiler) Compile(qc *qcode.QCode, w io.Writer, vars Variables) (uint32, error) { switch qc.Type { case qcode.QTQuery: return co.compileQuery(qc, w) - case qcode.QTMutation: + case qcode.QTInsert, qcode.QTUpdate, qcode.QTDelete, qcode.QTUpsert: return co.compileMutation(qc, w, vars) } - return 0, errors.New("unknown operation") + return 0, fmt.Errorf("Unknown operation type %d", qc.Type) } -func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) { +func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) { if len(qc.Selects) == 0 { return 0, errors.New("empty query") } @@ -90,17 +90,17 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, erro //fmt.Fprintf(w, `SELECT json_object_agg('%s', %s) FROM (`, //root.FieldName, root.Table) - c.w.WriteString(`SELECT json_object_agg('`) - c.w.WriteString(root.FieldName) - c.w.WriteString(`', `) + io.WriteString(c.w, `SELECT 
json_object_agg('`) + io.WriteString(c.w, root.FieldName) + io.WriteString(c.w, `', `) if ti.Singular == false { - c.w.WriteString(root.Table) + io.WriteString(c.w, root.Table) } else { - c.w.WriteString("sel_json_") + io.WriteString(c.w, "sel_json_") int2string(c.w, root.ID) } - c.w.WriteString(`) FROM (`) + io.WriteString(c.w, `) FROM (`) var ignored uint32 @@ -161,9 +161,8 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, erro } } - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) alias(c.w, `done_1337`) - c.w.WriteString(`;`) return ignored, nil } @@ -219,10 +218,10 @@ func (c *compilerContext) renderSelect(sel *qcode.Select, ti *DBTableInfo) (uint // SELECT if ti.Singular == false { //fmt.Fprintf(w, `SELECT coalesce(json_agg("%s"`, c.sel.Table) - c.w.WriteString(`SELECT coalesce(json_agg("`) - c.w.WriteString("sel_json_") + io.WriteString(c.w, `SELECT coalesce(json_agg("`) + io.WriteString(c.w, "sel_json_") int2string(c.w, sel.ID) - c.w.WriteString(`"`) + io.WriteString(c.w, `"`) if hasOrder { err := c.renderOrderBy(sel, ti) @@ -232,24 +231,24 @@ func (c *compilerContext) renderSelect(sel *qcode.Select, ti *DBTableInfo) (uint } //fmt.Fprintf(w, `), '[]') AS "%s" FROM (`, c.sel.Table) - c.w.WriteString(`), '[]')`) + io.WriteString(c.w, `), '[]')`) alias(c.w, sel.Table) - c.w.WriteString(` FROM (`) + io.WriteString(c.w, ` FROM (`) } // ROW-TO-JSON - c.w.WriteString(`SELECT `) + io.WriteString(c.w, `SELECT `) if len(sel.DistinctOn) != 0 { c.renderDistinctOn(sel, ti) } - c.w.WriteString(`row_to_json((`) + io.WriteString(c.w, `row_to_json((`) //fmt.Fprintf(w, `SELECT "sel_%d" FROM (SELECT `, c.sel.ID) - c.w.WriteString(`SELECT "sel_`) + io.WriteString(c.w, `SELECT "sel_`) int2string(c.w, sel.ID) - c.w.WriteString(`" FROM (SELECT `) + io.WriteString(c.w, `" FROM (SELECT `) // Combined column names c.renderColumns(sel, ti) @@ -262,11 +261,11 @@ func (c *compilerContext) renderSelect(sel *qcode.Select, ti *DBTableInfo) (uint } 
//fmt.Fprintf(w, `) AS "sel_%d"`, c.sel.ID) - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) aliasWithID(c.w, "sel", sel.ID) //fmt.Fprintf(w, `)) AS "%s"`, c.sel.Table) - c.w.WriteString(`))`) + io.WriteString(c.w, `))`) aliasWithID(c.w, "sel_json", sel.ID) // END-ROW-TO-JSON @@ -295,31 +294,33 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo) } } - if sel.Action == 0 { - if len(sel.Paging.Limit) != 0 { - //fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit) - c.w.WriteString(` LIMIT ('`) - c.w.WriteString(sel.Paging.Limit) - c.w.WriteString(`') :: integer`) + switch { + case sel.Paging.NoLimit: + break - } else if ti.Singular { - c.w.WriteString(` LIMIT ('1') :: integer`) + case len(sel.Paging.Limit) != 0: + //fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit) + io.WriteString(c.w, ` LIMIT ('`) + io.WriteString(c.w, sel.Paging.Limit) + io.WriteString(c.w, `') :: integer`) - } else { - c.w.WriteString(` LIMIT ('20') :: integer`) - } + case ti.Singular: + io.WriteString(c.w, ` LIMIT ('1') :: integer`) + + default: + io.WriteString(c.w, ` LIMIT ('20') :: integer`) } if len(sel.Paging.Offset) != 0 { //fmt.Fprintf(w, ` OFFSET ('%s') :: integer`, c.sel.Paging.Offset) - c.w.WriteString(`OFFSET ('`) - c.w.WriteString(sel.Paging.Offset) - c.w.WriteString(`') :: integer`) + io.WriteString(c.w, `OFFSET ('`) + io.WriteString(c.w, sel.Paging.Offset) + io.WriteString(c.w, `') :: integer`) } if ti.Singular == false { //fmt.Fprintf(w, `) AS "sel_json_agg_%d"`, c.sel.ID) - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) aliasWithID(c.w, "sel_json_agg", sel.ID) } @@ -327,15 +328,15 @@ func (c *compilerContext) renderSelectClose(sel *qcode.Select, ti *DBTableInfo) } func (c *compilerContext) renderJoin(sel *qcode.Select) error { - c.w.WriteString(` LEFT OUTER JOIN LATERAL (`) + io.WriteString(c.w, ` LEFT OUTER JOIN LATERAL (`) return nil } func (c *compilerContext) renderJoinClose(sel *qcode.Select) error { //fmt.Fprintf(w, `) 
AS "%s_%d_join" ON ('true')`, c.sel.Table, c.sel.ID) - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) aliasWithIDSuffix(c.w, sel.Table, sel.ID, "_join") - c.w.WriteString(` ON ('true')`) + io.WriteString(c.w, ` ON ('true')`) return nil } @@ -358,25 +359,43 @@ func (c *compilerContext) renderJoinTable(sel *qcode.Select) error { //fmt.Fprintf(w, ` LEFT OUTER JOIN "%s" ON (("%s"."%s") = ("%s_%d"."%s"))`, //rel.Through, rel.Through, rel.ColT, c.parent.Table, c.parent.ID, rel.Col1) - c.w.WriteString(` LEFT OUTER JOIN "`) - c.w.WriteString(rel.Through) - c.w.WriteString(`" ON ((`) + io.WriteString(c.w, ` LEFT OUTER JOIN "`) + io.WriteString(c.w, rel.Through) + io.WriteString(c.w, `" ON ((`) colWithTable(c.w, rel.Through, rel.ColT) - c.w.WriteString(`) = (`) + io.WriteString(c.w, `) = (`) colWithTableID(c.w, pt.Name, parent.ID, rel.Col1) - c.w.WriteString(`))`) + io.WriteString(c.w, `))`) return nil } func (c *compilerContext) renderColumns(sel *qcode.Select, ti *DBTableInfo) { - for i, col := range sel.Cols { + i := 0 + for _, col := range sel.Cols { + if len(sel.Allowed) != 0 { + n := funcPrefixLen(col.Name) + if n != 0 { + if sel.Functions == false { + continue + } + if _, ok := sel.Allowed[col.Name[n:]]; !ok { + continue + } + } else { + if _, ok := sel.Allowed[col.Name]; !ok { + continue + } + } + } + if i != 0 { io.WriteString(c.w, ", ") } //fmt.Fprintf(w, `"%s_%d"."%s" AS "%s"`, //c.sel.Table, c.sel.ID, col.Name, col.FieldName) colWithTableIDAlias(c.w, ti.Name, sel.ID, col.Name, col.FieldName) + i++ } } @@ -415,10 +434,24 @@ func (c *compilerContext) renderJoinedColumns(sel *qcode.Select, ti *DBTableInfo } childSel := &c.s[id] + cti, err := c.schema.GetTable(childSel.Table) + if err != nil { + continue + } + //fmt.Fprintf(w, `"%s_%d_join"."%s" AS "%s"`, //s.Table, s.ID, s.Table, s.FieldName) - colWithTableIDSuffixAlias(c.w, childSel.Table, childSel.ID, - "_join", childSel.Table, childSel.FieldName) + if cti.Singular { + io.WriteString(c.w, `"sel_json_`) + 
int2string(c.w, childSel.ID) + io.WriteString(c.w, `" AS "`) + io.WriteString(c.w, childSel.FieldName) + io.WriteString(c.w, `"`) + + } else { + colWithTableIDSuffixAlias(c.w, childSel.Table, childSel.ID, + "_join", childSel.Table, childSel.FieldName) + } } return nil @@ -433,9 +466,10 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, isSearch := sel.Args["search"] != nil isAgg := false - c.w.WriteString(` FROM (SELECT `) + io.WriteString(c.w, ` FROM (SELECT `) - for i, col := range sel.Cols { + i := 0 + for n, col := range sel.Cols { cn := col.Name _, isRealCol := ti.Columns[cn] @@ -447,93 +481,116 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, cn = ti.TSVCol arg := sel.Args["search"] + if i != 0 { + io.WriteString(c.w, `, `) + } //fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`, //c.sel.Table, cn, arg.Val, col.Name) - c.w.WriteString(`ts_rank(`) + io.WriteString(c.w, `ts_rank(`) colWithTable(c.w, ti.Name, cn) - c.w.WriteString(`, to_tsquery('`) - c.w.WriteString(arg.Val) - c.w.WriteString(`')`) + io.WriteString(c.w, `, to_tsquery('`) + io.WriteString(c.w, arg.Val) + io.WriteString(c.w, `')`) alias(c.w, col.Name) + i++ case strings.HasPrefix(cn, "search_headline_"): cn = cn[16:] arg := sel.Args["search"] + if i != 0 { + io.WriteString(c.w, `, `) + } //fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`, //c.sel.Table, cn, arg.Val, col.Name) - c.w.WriteString(`ts_headlinek(`) + io.WriteString(c.w, `ts_headlinek(`) colWithTable(c.w, ti.Name, cn) - c.w.WriteString(`, to_tsquery('`) - c.w.WriteString(arg.Val) - c.w.WriteString(`')`) + io.WriteString(c.w, `, to_tsquery('`) + io.WriteString(c.w, arg.Val) + io.WriteString(c.w, `')`) alias(c.w, col.Name) + i++ + } } else { pl := funcPrefixLen(cn) if pl == 0 { + if i != 0 { + io.WriteString(c.w, `, `) + } //fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name) - c.w.WriteString(`'`) - c.w.WriteString(cn) - c.w.WriteString(` not 
defined'`) + io.WriteString(c.w, `'`) + io.WriteString(c.w, cn) + io.WriteString(c.w, ` not defined'`) alias(c.w, col.Name) - } else { - isAgg = true + i++ + + } else if sel.Functions { + cn1 := cn[pl:] + if _, ok := sel.Allowed[cn1]; !ok { + continue + } + if i != 0 { + io.WriteString(c.w, `, `) + } fn := cn[0 : pl-1] - cn := cn[pl:] + isAgg = true + //fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, c.sel.Table, cn, col.Name) - c.w.WriteString(fn) - c.w.WriteString(`(`) - colWithTable(c.w, ti.Name, cn) - c.w.WriteString(`)`) + io.WriteString(c.w, fn) + io.WriteString(c.w, `(`) + colWithTable(c.w, ti.Name, cn1) + io.WriteString(c.w, `)`) alias(c.w, col.Name) + i++ + } } } else { - groupBy = append(groupBy, i) + groupBy = append(groupBy, n) //fmt.Fprintf(w, `"%s"."%s"`, c.sel.Table, cn) + if i != 0 { + io.WriteString(c.w, `, `) + } colWithTable(c.w, ti.Name, cn) - } + i++ - if i < len(sel.Cols)-1 || len(childCols) != 0 { - //io.WriteString(w, ", ") - c.w.WriteString(`, `) } } - for i, col := range childCols { + for _, col := range childCols { if i != 0 { - //io.WriteString(w, ", ") - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } //fmt.Fprintf(w, `"%s"."%s"`, col.Table, col.Name) colWithTable(c.w, col.Table, col.Name) + i++ } - c.w.WriteString(` FROM `) + io.WriteString(c.w, ` FROM `) //fmt.Fprintf(w, ` FROM "%s"`, c.sel.Table) - c.w.WriteString(`"`) - c.w.WriteString(ti.Name) - c.w.WriteString(`"`) + io.WriteString(c.w, `"`) + io.WriteString(c.w, ti.Name) + io.WriteString(c.w, `"`) // if tn, ok := c.tmap[sel.Table]; ok { // //fmt.Fprintf(w, ` FROM "%s" AS "%s"`, tn, c.sel.Table) // tableWithAlias(c.w, ti.Name, sel.Table) // } else { // //fmt.Fprintf(w, ` FROM "%s"`, c.sel.Table) - // c.w.WriteString(`"`) - // c.w.WriteString(sel.Table) - // c.w.WriteString(`"`) + // io.WriteString(c.w, `"`) + // io.WriteString(c.w, sel.Table) + // io.WriteString(c.w, `"`) // } if isRoot && isFil { - c.w.WriteString(` WHERE (`) + io.WriteString(c.w, ` WHERE (`) if err := 
c.renderWhere(sel, ti); err != nil { return err } - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) } if !isRoot { @@ -541,28 +598,28 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, return err } - c.w.WriteString(` WHERE (`) + io.WriteString(c.w, ` WHERE (`) if err := c.renderRelationship(sel, ti); err != nil { return err } if isFil { - c.w.WriteString(` AND `) + io.WriteString(c.w, ` AND `) if err := c.renderWhere(sel, ti); err != nil { return err } } - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) } if isAgg { if len(groupBy) != 0 { - c.w.WriteString(` GROUP BY `) + io.WriteString(c.w, ` GROUP BY `) for i, id := range groupBy { if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } //fmt.Fprintf(w, `"%s"."%s"`, c.sel.Table, c.sel.Cols[id].Name) colWithTable(c.w, ti.Name, sel.Cols[id].Name) @@ -570,30 +627,32 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo, } } - if sel.Action == 0 { - if len(sel.Paging.Limit) != 0 { - //fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit) - c.w.WriteString(` LIMIT ('`) - c.w.WriteString(sel.Paging.Limit) - c.w.WriteString(`') :: integer`) + switch { + case sel.Paging.NoLimit: + break - } else if ti.Singular { - c.w.WriteString(` LIMIT ('1') :: integer`) + case len(sel.Paging.Limit) != 0: + //fmt.Fprintf(w, ` LIMIT ('%s') :: integer`, c.sel.Paging.Limit) + io.WriteString(c.w, ` LIMIT ('`) + io.WriteString(c.w, sel.Paging.Limit) + io.WriteString(c.w, `') :: integer`) - } else { - c.w.WriteString(` LIMIT ('20') :: integer`) - } + case ti.Singular: + io.WriteString(c.w, ` LIMIT ('1') :: integer`) + + default: + io.WriteString(c.w, ` LIMIT ('20') :: integer`) } if len(sel.Paging.Offset) != 0 { //fmt.Fprintf(w, ` OFFSET ('%s') :: integer`, c.sel.Paging.Offset) - c.w.WriteString(` OFFSET ('`) - c.w.WriteString(sel.Paging.Offset) - c.w.WriteString(`') :: integer`) + io.WriteString(c.w, ` OFFSET ('`) + io.WriteString(c.w, sel.Paging.Offset) + 
io.WriteString(c.w, `') :: integer`) } //fmt.Fprintf(w, `) AS "%s_%d"`, c.sel.Table, c.sel.ID) - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) aliasWithID(c.w, ti.Name, sel.ID) return nil } @@ -604,7 +663,7 @@ func (c *compilerContext) renderOrderByColumns(sel *qcode.Select, ti *DBTableInf for i := range sel.OrderBy { if colsRendered { //io.WriteString(w, ", ") - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } col := sel.OrderBy[i].Col @@ -612,7 +671,7 @@ func (c *compilerContext) renderOrderByColumns(sel *qcode.Select, ti *DBTableInf //c.sel.Table, c.sel.ID, c, //c.sel.Table, c.sel.ID, c) colWithTableID(c.w, ti.Name, sel.ID, col) - c.w.WriteString(` AS `) + io.WriteString(c.w, ` AS `) tableIDColSuffix(c.w, sel.Table, sel.ID, col, "_ob") } } @@ -629,29 +688,29 @@ func (c *compilerContext) renderRelationship(sel *qcode.Select, ti *DBTableInfo) case RelBelongTo: //fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`, //c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2) - c.w.WriteString(`((`) + io.WriteString(c.w, `((`) colWithTable(c.w, ti.Name, rel.Col1) - c.w.WriteString(`) = (`) + io.WriteString(c.w, `) = (`) colWithTableID(c.w, parent.Table, parent.ID, rel.Col2) - c.w.WriteString(`))`) + io.WriteString(c.w, `))`) case RelOneToMany: //fmt.Fprintf(w, `(("%s"."%s") = ("%s_%d"."%s"))`, //c.sel.Table, rel.Col1, c.parent.Table, c.parent.ID, rel.Col2) - c.w.WriteString(`((`) + io.WriteString(c.w, `((`) colWithTable(c.w, ti.Name, rel.Col1) - c.w.WriteString(`) = (`) + io.WriteString(c.w, `) = (`) colWithTableID(c.w, parent.Table, parent.ID, rel.Col2) - c.w.WriteString(`))`) + io.WriteString(c.w, `))`) case RelOneToManyThrough: //fmt.Fprintf(w, `(("%s"."%s") = ("%s"."%s"))`, //c.sel.Table, rel.Col1, rel.Through, rel.Col2) - c.w.WriteString(`((`) + io.WriteString(c.w, `((`) colWithTable(c.w, ti.Name, rel.Col1) - c.w.WriteString(`) = (`) + io.WriteString(c.w, `) = (`) colWithTable(c.w, rel.Through, rel.Col2) - c.w.WriteString(`))`) + io.WriteString(c.w, `))`) } 
return nil @@ -675,11 +734,11 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error case qcode.ExpOp: switch val { case qcode.OpAnd: - c.w.WriteString(` AND `) + io.WriteString(c.w, ` AND `) case qcode.OpOr: - c.w.WriteString(` OR `) + io.WriteString(c.w, ` OR `) case qcode.OpNot: - c.w.WriteString(`NOT `) + io.WriteString(c.w, `NOT `) default: return fmt.Errorf("11: unexpected value %v (%t)", intf, intf) } @@ -703,62 +762,62 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error default: if val.NestedCol { //fmt.Fprintf(w, `(("%s") `, val.Col) - c.w.WriteString(`(("`) - c.w.WriteString(val.Col) - c.w.WriteString(`") `) + io.WriteString(c.w, `(("`) + io.WriteString(c.w, val.Col) + io.WriteString(c.w, `") `) } else if len(val.Col) != 0 { //fmt.Fprintf(w, `(("%s"."%s") `, c.sel.Table, val.Col) - c.w.WriteString(`((`) + io.WriteString(c.w, `((`) colWithTable(c.w, ti.Name, val.Col) - c.w.WriteString(`) `) + io.WriteString(c.w, `) `) } valExists := true switch val.Op { case qcode.OpEquals: - c.w.WriteString(`=`) + io.WriteString(c.w, `=`) case qcode.OpNotEquals: - c.w.WriteString(`!=`) + io.WriteString(c.w, `!=`) case qcode.OpGreaterOrEquals: - c.w.WriteString(`>=`) + io.WriteString(c.w, `>=`) case qcode.OpLesserOrEquals: - c.w.WriteString(`<=`) + io.WriteString(c.w, `<=`) case qcode.OpGreaterThan: - c.w.WriteString(`>`) + io.WriteString(c.w, `>`) case qcode.OpLesserThan: - c.w.WriteString(`<`) + io.WriteString(c.w, `<`) case qcode.OpIn: - c.w.WriteString(`IN`) + io.WriteString(c.w, `IN`) case qcode.OpNotIn: - c.w.WriteString(`NOT IN`) + io.WriteString(c.w, `NOT IN`) case qcode.OpLike: - c.w.WriteString(`LIKE`) + io.WriteString(c.w, `LIKE`) case qcode.OpNotLike: - c.w.WriteString(`NOT LIKE`) + io.WriteString(c.w, `NOT LIKE`) case qcode.OpILike: - c.w.WriteString(`ILIKE`) + io.WriteString(c.w, `ILIKE`) case qcode.OpNotILike: - c.w.WriteString(`NOT ILIKE`) + io.WriteString(c.w, `NOT ILIKE`) case qcode.OpSimilar: - 
c.w.WriteString(`SIMILAR TO`) + io.WriteString(c.w, `SIMILAR TO`) case qcode.OpNotSimilar: - c.w.WriteString(`NOT SIMILAR TO`) + io.WriteString(c.w, `NOT SIMILAR TO`) case qcode.OpContains: - c.w.WriteString(`@>`) + io.WriteString(c.w, `@>`) case qcode.OpContainedIn: - c.w.WriteString(`<@`) + io.WriteString(c.w, `<@`) case qcode.OpHasKey: - c.w.WriteString(`?`) + io.WriteString(c.w, `?`) case qcode.OpHasKeyAny: - c.w.WriteString(`?|`) + io.WriteString(c.w, `?|`) case qcode.OpHasKeyAll: - c.w.WriteString(`?&`) + io.WriteString(c.w, `?&`) case qcode.OpIsNull: if strings.EqualFold(val.Val, "true") { - c.w.WriteString(`IS NULL)`) + io.WriteString(c.w, `IS NULL)`) } else { - c.w.WriteString(`IS NOT NULL)`) + io.WriteString(c.w, `IS NOT NULL)`) } valExists = false case qcode.OpEqID: @@ -766,20 +825,20 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error return fmt.Errorf("no primary key column defined for %s", ti.Name) } //fmt.Fprintf(w, `(("%s") =`, c.ti.PrimaryCol) - c.w.WriteString(`((`) + io.WriteString(c.w, `((`) colWithTable(c.w, ti.Name, ti.PrimaryCol) - //c.w.WriteString(ti.PrimaryCol) - c.w.WriteString(`) =`) + //io.WriteString(c.w, ti.PrimaryCol) + io.WriteString(c.w, `) =`) case qcode.OpTsQuery: if len(ti.TSVCol) == 0 { return fmt.Errorf("no tsv column defined for %s", ti.Name) } //fmt.Fprintf(w, `(("%s") @@ to_tsquery('%s'))`, c.ti.TSVCol, val.Val) - c.w.WriteString(`(("`) - c.w.WriteString(ti.TSVCol) - c.w.WriteString(`") @@ to_tsquery('`) - c.w.WriteString(val.Val) - c.w.WriteString(`'))`) + io.WriteString(c.w, `(("`) + io.WriteString(c.w, ti.TSVCol) + io.WriteString(c.w, `") @@ to_tsquery('`) + io.WriteString(c.w, val.Val) + io.WriteString(c.w, `'))`) valExists = false default: @@ -792,7 +851,7 @@ func (c *compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error } else { c.renderVal(val, c.vars) } - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) } qcode.FreeExp(val) @@ -808,10 +867,10 @@ func (c 
*compilerContext) renderWhere(sel *qcode.Select, ti *DBTableInfo) error } func (c *compilerContext) renderOrderBy(sel *qcode.Select, ti *DBTableInfo) error { - c.w.WriteString(` ORDER BY `) + io.WriteString(c.w, ` ORDER BY `) for i := range sel.OrderBy { if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } ob := sel.OrderBy[i] @@ -819,27 +878,27 @@ func (c *compilerContext) renderOrderBy(sel *qcode.Select, ti *DBTableInfo) erro case qcode.OrderAsc: //fmt.Fprintf(w, `"%s_%d.ob.%s" ASC`, sel.Table, sel.ID, ob.Col) tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") - c.w.WriteString(` ASC`) + io.WriteString(c.w, ` ASC`) case qcode.OrderDesc: //fmt.Fprintf(w, `"%s_%d.ob.%s" DESC`, sel.Table, sel.ID, ob.Col) tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") - c.w.WriteString(` DESC`) + io.WriteString(c.w, ` DESC`) case qcode.OrderAscNullsFirst: //fmt.Fprintf(w, `"%s_%d.ob.%s" ASC NULLS FIRST`, sel.Table, sel.ID, ob.Col) tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") - c.w.WriteString(` ASC NULLS FIRST`) + io.WriteString(c.w, ` ASC NULLS FIRST`) case qcode.OrderDescNullsFirst: //fmt.Fprintf(w, `%s_%d.ob.%s DESC NULLS FIRST`, sel.Table, sel.ID, ob.Col) tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") - c.w.WriteString(` DESC NULLLS FIRST`) + io.WriteString(c.w, ` DESC NULLLS FIRST`) case qcode.OrderAscNullsLast: //fmt.Fprintf(w, `"%s_%d.ob.%s ASC NULLS LAST`, sel.Table, sel.ID, ob.Col) tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") - c.w.WriteString(` ASC NULLS LAST`) + io.WriteString(c.w, ` ASC NULLS LAST`) case qcode.OrderDescNullsLast: //fmt.Fprintf(w, `%s_%d.ob.%s DESC NULLS LAST`, sel.Table, sel.ID, ob.Col) tableIDColSuffix(c.w, ti.Name, sel.ID, ob.Col, "_ob") - c.w.WriteString(` DESC NULLS LAST`) + io.WriteString(c.w, ` DESC NULLS LAST`) default: return fmt.Errorf("13: unexpected value %v", ob.Order) } @@ -851,30 +910,30 @@ func (c *compilerContext) renderDistinctOn(sel *qcode.Select, ti *DBTableInfo) { io.WriteString(c.w, 
`DISTINCT ON (`) for i := range sel.DistinctOn { if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } //fmt.Fprintf(w, `"%s_%d.ob.%s"`, c.sel.Table, c.sel.ID, c.sel.DistinctOn[i]) tableIDColSuffix(c.w, ti.Name, sel.ID, sel.DistinctOn[i], "_ob") } - c.w.WriteString(`) `) + io.WriteString(c.w, `) `) } func (c *compilerContext) renderList(ex *qcode.Exp) { io.WriteString(c.w, ` (`) for i := range ex.ListVal { if i != 0 { - c.w.WriteString(`, `) + io.WriteString(c.w, `, `) } switch ex.ListType { case qcode.ValBool, qcode.ValInt, qcode.ValFloat: - c.w.WriteString(ex.ListVal[i]) + io.WriteString(c.w, ex.ListVal[i]) case qcode.ValStr: - c.w.WriteString(`'`) - c.w.WriteString(ex.ListVal[i]) - c.w.WriteString(`'`) + io.WriteString(c.w, `'`) + io.WriteString(c.w, ex.ListVal[i]) + io.WriteString(c.w, `'`) } } - c.w.WriteString(`)`) + io.WriteString(c.w, `)`) } func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) { @@ -883,27 +942,27 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) { switch ex.Type { case qcode.ValBool, qcode.ValInt, qcode.ValFloat: if len(ex.Val) != 0 { - c.w.WriteString(ex.Val) + io.WriteString(c.w, ex.Val) } else { - c.w.WriteString(`''`) + io.WriteString(c.w, `''`) } case qcode.ValStr: - c.w.WriteString(`'`) - c.w.WriteString(ex.Val) - c.w.WriteString(`'`) + io.WriteString(c.w, `'`) + io.WriteString(c.w, ex.Val) + io.WriteString(c.w, `'`) case qcode.ValVar: if val, ok := vars[ex.Val]; ok { - c.w.WriteString(val) + io.WriteString(c.w, val) } else { //fmt.Fprintf(w, `'{{%s}}'`, ex.Val) - c.w.WriteString(`{{`) - c.w.WriteString(ex.Val) - c.w.WriteString(`}}`) + io.WriteString(c.w, `{{`) + io.WriteString(c.w, ex.Val) + io.WriteString(c.w, `}}`) } } - //c.w.WriteString(`)`) + //io.WriteString(c.w, `)`) } func funcPrefixLen(fn string) int { @@ -939,105 +998,105 @@ func hasBit(n uint32, pos uint32) bool { return (val > 0) } -func alias(w *bytes.Buffer, alias string) { - w.WriteString(` AS "`) - 
w.WriteString(alias) - w.WriteString(`"`) +func alias(w io.Writer, alias string) { + io.WriteString(w, ` AS "`) + io.WriteString(w, alias) + io.WriteString(w, `"`) } -func aliasWithID(w *bytes.Buffer, alias string, id int32) { - w.WriteString(` AS "`) - w.WriteString(alias) - w.WriteString(`_`) +func aliasWithID(w io.Writer, alias string, id int32) { + io.WriteString(w, ` AS "`) + io.WriteString(w, alias) + io.WriteString(w, `_`) int2string(w, id) - w.WriteString(`"`) + io.WriteString(w, `"`) } -func aliasWithIDSuffix(w *bytes.Buffer, alias string, id int32, suffix string) { - w.WriteString(` AS "`) - w.WriteString(alias) - w.WriteString(`_`) +func aliasWithIDSuffix(w io.Writer, alias string, id int32, suffix string) { + io.WriteString(w, ` AS "`) + io.WriteString(w, alias) + io.WriteString(w, `_`) int2string(w, id) - w.WriteString(suffix) - w.WriteString(`"`) + io.WriteString(w, suffix) + io.WriteString(w, `"`) } -func colWithAlias(w *bytes.Buffer, col, alias string) { - w.WriteString(`"`) - w.WriteString(col) - w.WriteString(`" AS "`) - w.WriteString(alias) - w.WriteString(`"`) +func colWithAlias(w io.Writer, col, alias string) { + io.WriteString(w, `"`) + io.WriteString(w, col) + io.WriteString(w, `" AS "`) + io.WriteString(w, alias) + io.WriteString(w, `"`) } -func tableWithAlias(w *bytes.Buffer, table, alias string) { - w.WriteString(`"`) - w.WriteString(table) - w.WriteString(`" AS "`) - w.WriteString(alias) - w.WriteString(`"`) +func tableWithAlias(w io.Writer, table, alias string) { + io.WriteString(w, `"`) + io.WriteString(w, table) + io.WriteString(w, `" AS "`) + io.WriteString(w, alias) + io.WriteString(w, `"`) } -func colWithTable(w *bytes.Buffer, table, col string) { - w.WriteString(`"`) - w.WriteString(table) - w.WriteString(`"."`) - w.WriteString(col) - w.WriteString(`"`) +func colWithTable(w io.Writer, table, col string) { + io.WriteString(w, `"`) + io.WriteString(w, table) + io.WriteString(w, `"."`) + io.WriteString(w, col) + io.WriteString(w, `"`) 
} -func colWithTableID(w *bytes.Buffer, table string, id int32, col string) { - w.WriteString(`"`) - w.WriteString(table) - w.WriteString(`_`) +func colWithTableID(w io.Writer, table string, id int32, col string) { + io.WriteString(w, `"`) + io.WriteString(w, table) + io.WriteString(w, `_`) int2string(w, id) - w.WriteString(`"."`) - w.WriteString(col) - w.WriteString(`"`) + io.WriteString(w, `"."`) + io.WriteString(w, col) + io.WriteString(w, `"`) } -func colWithTableIDAlias(w *bytes.Buffer, table string, id int32, col, alias string) { - w.WriteString(`"`) - w.WriteString(table) - w.WriteString(`_`) +func colWithTableIDAlias(w io.Writer, table string, id int32, col, alias string) { + io.WriteString(w, `"`) + io.WriteString(w, table) + io.WriteString(w, `_`) int2string(w, id) - w.WriteString(`"."`) - w.WriteString(col) - w.WriteString(`" AS "`) - w.WriteString(alias) - w.WriteString(`"`) + io.WriteString(w, `"."`) + io.WriteString(w, col) + io.WriteString(w, `" AS "`) + io.WriteString(w, alias) + io.WriteString(w, `"`) } -func colWithTableIDSuffixAlias(w *bytes.Buffer, table string, id int32, +func colWithTableIDSuffixAlias(w io.Writer, table string, id int32, suffix, col, alias string) { - w.WriteString(`"`) - w.WriteString(table) - w.WriteString(`_`) + io.WriteString(w, `"`) + io.WriteString(w, table) + io.WriteString(w, `_`) int2string(w, id) - w.WriteString(suffix) - w.WriteString(`"."`) - w.WriteString(col) - w.WriteString(`" AS "`) - w.WriteString(alias) - w.WriteString(`"`) + io.WriteString(w, suffix) + io.WriteString(w, `"."`) + io.WriteString(w, col) + io.WriteString(w, `" AS "`) + io.WriteString(w, alias) + io.WriteString(w, `"`) } -func tableIDColSuffix(w *bytes.Buffer, table string, id int32, col, suffix string) { - w.WriteString(`"`) - w.WriteString(table) - w.WriteString(`_`) +func tableIDColSuffix(w io.Writer, table string, id int32, col, suffix string) { + io.WriteString(w, `"`) + io.WriteString(w, table) + io.WriteString(w, `_`) int2string(w, id) - 
w.WriteString(`_`) - w.WriteString(col) - w.WriteString(suffix) - w.WriteString(`"`) + io.WriteString(w, `_`) + io.WriteString(w, col) + io.WriteString(w, suffix) + io.WriteString(w, `"`) } const charset = "0123456789" -func int2string(w *bytes.Buffer, val int32) { +func int2string(w io.Writer, val int32) { if val < 10 { - w.WriteByte(charset[val]) + w.Write([]byte{charset[val]}) return } @@ -1053,7 +1112,7 @@ func int2string(w *bytes.Buffer, val int32) { for val3 > 0 { d := val3 % 10 val3 /= 10 - w.WriteByte(charset[d]) + w.Write([]byte{charset[d]}) } } diff --git a/psql/select_test.go b/psql/query_test.go similarity index 83% rename from psql/select_test.go rename to psql/query_test.go index 553f576..78330b6 100644 --- a/psql/select_test.go +++ b/psql/query_test.go @@ -22,32 +22,6 @@ func TestMain(m *testing.M) { var err error qcompile, err = qcode.NewCompiler(qcode.Config{ - DefaultFilter: []string{ - `{ user_id: { _eq: $user_id } }`, - }, - FilterMap: qcode.Filters{ - All: map[string][]string{ - "users": []string{ - "{ id: { eq: $user_id } }", - }, - "products": []string{ - "{ price: { gt: 0 } }", - "{ price: { lt: 8 } }", - }, - "customers": []string{}, - "mes": []string{ - "{ id: { eq: $user_id } }", - }, - }, - Query: map[string][]string{ - "users": []string{}, - }, - Update: map[string][]string{ - "products": []string{ - "{ user_id: { eq: $user_id } }", - }, - }, - }, Blocklist: []string{ "secret", "password", @@ -55,6 +29,59 @@ func TestMain(m *testing.M) { }, }) + qcompile.AddRole("user", "product", qcode.TRConfig{ + Query: qcode.QueryConfig{ + Columns: []string{"id", "name", "price", "users", "customers"}, + Filters: []string{ + "{ price: { gt: 0 } }", + "{ price: { lt: 8 } }", + }, + }, + Update: qcode.UpdateConfig{ + Filters: []string{"{ user_id: { eq: $user_id } }"}, + }, + Delete: qcode.DeleteConfig{ + Filters: []string{ + "{ price: { gt: 0 } }", + "{ price: { lt: 8 } }", + }, + }, + }) + + qcompile.AddRole("anon", "product", qcode.TRConfig{ + Query: 
qcode.QueryConfig{ + Columns: []string{"id", "name"}, + }, + }) + + qcompile.AddRole("anon1", "product", qcode.TRConfig{ + Query: qcode.QueryConfig{ + Columns: []string{"id", "name", "price"}, + DisableFunctions: true, + }, + }) + + qcompile.AddRole("user", "users", qcode.TRConfig{ + Query: qcode.QueryConfig{ + Columns: []string{"id", "full_name", "avatar", "email", "products"}, + }, + }) + + qcompile.AddRole("user", "mes", qcode.TRConfig{ + Query: qcode.QueryConfig{ + Columns: []string{"id", "full_name", "avatar"}, + Filters: []string{ + "{ id: { eq: $user_id } }", + }, + }, + }) + + qcompile.AddRole("user", "customers", qcode.TRConfig{ + Query: qcode.QueryConfig{ + Columns: []string{"id", "email", "full_name", "products"}, + }, + }) + if err != nil { log.Fatal(err) } @@ -135,9 +162,8 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -func compileGQLToPSQL(gql string, vars Variables) ([]byte, error) { - - qc, err := qcompile.Compile([]byte(gql)) +func compileGQLToPSQL(gql string, vars Variables, role string) ([]byte, error) { + qc, err := qcompile.Compile([]byte(gql), role) if err != nil { return nil, err } @@ -147,6 +173,8 @@ func compileGQLToPSQL(gql string, vars Variables) ([]byte, error) { return nil, err } + //fmt.Println(string(sqlStmt)) + return sqlStmt, nil } @@ -173,9 +201,9 @@ func withComplexArgs(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") < 28) AND (("products"."id") >= 20)) LIMIT ('30') :: 
integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") < 28) AND (("products"."id") >= 20)) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "sel_json_agg_0") AS "done_1337"` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -201,9 +229,9 @@ func withWhereMultiOr(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") < 20) OR (("products"."price") > 10) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS 
"sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") < 20) OR (("products"."price") > 10) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -227,9 +255,9 @@ func withWhereIsNull(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -253,9 +281,9 @@ func withWhereAndList(t *testing.T) { } }` - sql := 
`SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") > 10) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -273,9 +301,9 @@ func fetchByID(t *testing.T) { } }` - sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 15)) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";` + sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", 
"products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") = 15)) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337"` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -293,9 +321,9 @@ func searchQuery(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -316,9 +344,9 @@ func oneToMany(t *testing.T) { } }` - sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS 
"sel_json_0" FROM (SELECT "users"."email", "users"."id" FROM "users" LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email", "users"."id" FROM "users" LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -339,9 +367,9 @@ func belongsTo(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS 
"sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "sel_json_1" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "sel_json_1" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -362,9 +390,9 @@ func manyToMany(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" 
FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) 
AS "sel_json_agg_0") AS "done_1337"` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -385,9 +413,9 @@ func manyToManyReverse(t *testing.T) { } }` - sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "sel_json_0" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("sel_json_1"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "sel_json_1" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON 
(("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "sel_json_agg_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -405,9 +433,49 @@ func aggFunction(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8)) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") + if err != nil { + t.Fatal(err) + } + + if string(resSQL) != sql { + t.Fatal(errNotExpected) + } +} + +func aggFunctionBlockedByCol(t *testing.T) { + gql := `query { + products { + name + count_price + } + }` + + sql := `SELECT json_object_agg('products', products) FROM (SELECT 
coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` + + resSQL, err := compileGQLToPSQL(gql, nil, "anon") + if err != nil { + t.Fatal(err) + } + + if string(resSQL) != sql { + t.Fatal(errNotExpected) + } +} + +func aggFunctionDisabled(t *testing.T) { + gql := `query { + products { + name + count_price + } + }` + + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` + + resSQL, err := compileGQLToPSQL(gql, nil, "anon1") if err != nil { t.Fatal(err) } @@ -425,9 +493,9 @@ func aggFunctionWithFilter(t *testing.T) { } }` - sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") > 10)) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337";` + sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("sel_json_0"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "sel_json_0" FROM (SELECT 
"products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."id") > 10)) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "sel_json_agg_0") AS "done_1337"` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -445,9 +513,9 @@ func queryWithVariables(t *testing.T) { } }` - sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") = {{product_price}}) AND (("products"."id") = {{product_id}})) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337";` + sql := `SELECT json_object_agg('product', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "sel_json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("products"."price") = {{product_price}}) AND (("products"."id") = {{product_id}})) LIMIT ('1') :: integer) AS "products_0" LIMIT ('1') :: integer) AS "done_1337"` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -464,9 +532,9 @@ func syntheticTables(t *testing.T) { } }` - sql := `SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email") AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337";` + sql := 
`SELECT json_object_agg('me', sel_json_0) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT ) AS "sel_0")) AS "sel_json_0" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = {{user_id}})) LIMIT ('1') :: integer) AS "users_0" LIMIT ('1') :: integer) AS "done_1337"` - resSQL, err := compileGQLToPSQL(gql, nil) + resSQL, err := compileGQLToPSQL(gql, nil, "user") if err != nil { t.Fatal(err) } @@ -476,7 +544,7 @@ func syntheticTables(t *testing.T) { } } -func TestCompileSelect(t *testing.T) { +func TestCompileQuery(t *testing.T) { t.Run("withComplexArgs", withComplexArgs) t.Run("withWhereAndList", withWhereAndList) t.Run("withWhereIsNull", withWhereIsNull) @@ -488,10 +556,11 @@ func TestCompileSelect(t *testing.T) { t.Run("manyToMany", manyToMany) t.Run("manyToManyReverse", manyToManyReverse) t.Run("aggFunction", aggFunction) + t.Run("aggFunctionBlockedByCol", aggFunctionBlockedByCol) + t.Run("aggFunctionDisabled", aggFunctionDisabled) t.Run("aggFunctionWithFilter", aggFunctionWithFilter) t.Run("syntheticTables", syntheticTables) t.Run("queryWithVariables", queryWithVariables) - } var benchGQL = []byte(`query { @@ -526,7 +595,7 @@ func BenchmarkCompile(b *testing.B) { for n := 0; n < b.N; n++ { w.Reset() - qc, err := qcompile.Compile(benchGQL) + qc, err := qcompile.Compile(benchGQL, "user") if err != nil { b.Fatal(err) } @@ -547,7 +616,7 @@ func BenchmarkCompileParallel(b *testing.B) { for pb.Next() { w.Reset() - qc, err := qcompile.Compile(benchGQL) + qc, err := qcompile.Compile(benchGQL, "user") if err != nil { b.Fatal(err) } diff --git a/qcode/config.go b/qcode/config.go new file mode 100644 index 0000000..c68a3d1 --- /dev/null +++ b/qcode/config.go @@ -0,0 +1,99 @@ +package qcode + +type Config struct { + Blocklist []string + KeepArgs bool +} + +type QueryConfig struct { + Limit int + Filters []string + Columns []string + DisableFunctions bool +} + +type InsertConfig struct { + Filters []string + Columns []string + Set map[string]string +} + 
+type UpdateConfig struct { + Filters []string + Columns []string + Set map[string]string +} + +type DeleteConfig struct { + Filters []string + Columns []string +} + +type TRConfig struct { + Query QueryConfig + Insert InsertConfig + Update UpdateConfig + Delete DeleteConfig +} + +type trval struct { + query struct { + limit string + fil *Exp + cols map[string]struct{} + disable struct { + funcs bool + } + } + + insert struct { + fil *Exp + cols map[string]struct{} + set map[string]string + } + + update struct { + fil *Exp + cols map[string]struct{} + set map[string]string + } + + delete struct { + fil *Exp + cols map[string]struct{} + } +} + +func (trv *trval) allowedColumns(qt QType) map[string]struct{} { + switch qt { + case QTQuery: + return trv.query.cols + case QTInsert: + return trv.insert.cols + case QTUpdate: + return trv.update.cols + case QTDelete: + return trv.insert.cols + case QTUpsert: + return trv.insert.cols + } + + return nil +} + +func (trv *trval) filter(qt QType) *Exp { + switch qt { + case QTQuery: + return trv.query.fil + case QTInsert: + return trv.insert.fil + case QTUpdate: + return trv.update.fil + case QTDelete: + return trv.delete.fil + case QTUpsert: + return trv.insert.fil + } + + return nil +} diff --git a/qcode/fuzz.go b/qcode/fuzz.go index db8f3c8..89a4a3c 100644 --- a/qcode/fuzz.go +++ b/qcode/fuzz.go @@ -5,7 +5,7 @@ func FuzzerEntrypoint(data []byte) int { //testData := string(data) qcompile, _ := NewCompiler(Config{}) - _, err := qcompile.Compile(data) + _, err := qcompile.Compile(data, "user") if err != nil { return -1 } diff --git a/qcode/parse.go b/qcode/parse.go index c07ab45..0fe6c34 100644 --- a/qcode/parse.go +++ b/qcode/parse.go @@ -18,7 +18,9 @@ type parserType int32 const ( maxFields = 100 maxArgs = 10 +) +const ( parserError parserType = iota parserEOF opQuery diff --git a/qcode/parse_test.go b/qcode/parse_test.go index dba397f..0e04ed2 100644 --- a/qcode/parse_test.go +++ b/qcode/parse_test.go @@ -46,13 +46,18 @@ 
func compareOp(op1, op2 Operation) error { */ func TestCompile1(t *testing.T) { - qcompile, _ := NewCompiler(Config{}) + qc, _ := NewCompiler(Config{}) + qc.AddRole("user", "product", TRConfig{ + Query: QueryConfig{ + Columns: []string{"id", "Name"}, + }, + }) - _, err := qcompile.Compile([]byte(` + _, err := qc.Compile([]byte(` product(id: 15) { id name - }`)) + }`), "user") if err != nil { t.Fatal(err) @@ -60,13 +65,18 @@ func TestCompile1(t *testing.T) { } func TestCompile2(t *testing.T) { - qcompile, _ := NewCompiler(Config{}) + qc, _ := NewCompiler(Config{}) + qc.AddRole("user", "product", TRConfig{ + Query: QueryConfig{ + Columns: []string{"ID"}, + }, + }) - _, err := qcompile.Compile([]byte(` + _, err := qc.Compile([]byte(` query { product(id: 15) { id name - } }`)) + } }`), "user") if err != nil { t.Fatal(err) @@ -74,15 +84,20 @@ func TestCompile2(t *testing.T) { } func TestCompile3(t *testing.T) { - qcompile, _ := NewCompiler(Config{}) + qc, _ := NewCompiler(Config{}) + qc.AddRole("user", "product", TRConfig{ + Query: QueryConfig{ + Columns: []string{"ID"}, + }, + }) - _, err := qcompile.Compile([]byte(` + _, err := qc.Compile([]byte(` mutation { product(id: 15, name: "Test") { id name } - }`)) + }`), "user") if err != nil { t.Fatal(err) @@ -91,7 +106,7 @@ func TestCompile3(t *testing.T) { func TestInvalidCompile1(t *testing.T) { qcompile, _ := NewCompiler(Config{}) - _, err := qcompile.Compile([]byte(`#`)) + _, err := qcompile.Compile([]byte(`#`), "user") if err == nil { t.Fatal(errors.New("expecting an error")) @@ -100,7 +115,7 @@ func TestInvalidCompile1(t *testing.T) { func TestInvalidCompile2(t *testing.T) { qcompile, _ := NewCompiler(Config{}) - _, err := qcompile.Compile([]byte(`{u(where:{not:0})}`)) + _, err := qcompile.Compile([]byte(`{u(where:{not:0})}`), "user") if err == nil { t.Fatal(errors.New("expecting an error")) @@ -109,7 +124,7 @@ func TestInvalidCompile2(t *testing.T) { func TestEmptyCompile(t *testing.T) { qcompile, _ := 
NewCompiler(Config{}) - _, err := qcompile.Compile([]byte(``)) + _, err := qcompile.Compile([]byte(``), "user") if err == nil { t.Fatal(errors.New("expecting an error")) @@ -144,7 +159,7 @@ func BenchmarkQCompile(b *testing.B) { b.ReportAllocs() for n := 0; n < b.N; n++ { - _, err := qcompile.Compile(gql) + _, err := qcompile.Compile(gql, "user") if err != nil { b.Fatal(err) @@ -160,7 +175,7 @@ func BenchmarkQCompileP(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - _, err := qcompile.Compile(gql) + _, err := qcompile.Compile(gql, "user") if err != nil { b.Fatal(err) diff --git a/qcode/qcode.go b/qcode/qcode.go index bafe3e4..8c90d93 100644 --- a/qcode/qcode.go +++ b/qcode/qcode.go @@ -3,6 +3,7 @@ package qcode import ( "errors" "fmt" + "strconv" "strings" "sync" @@ -15,25 +16,20 @@ type Action int const ( maxSelectors = 30 +) +const ( QTQuery QType = iota + 1 - QTMutation - - ActionInsert Action = iota + 1 - ActionUpdate - ActionDelete - ActionUpsert + QTInsert + QTUpdate + QTDelete + QTUpsert ) type QCode struct { - Type QType - Selects []Select -} - -type Column struct { - Table string - Name string - FieldName string + Type QType + ActionVar string + Selects []Select } type Select struct { @@ -47,9 +43,15 @@ type Select struct { OrderBy []*OrderBy DistinctOn []string Paging Paging - Action Action - ActionVar string Children []int32 + Functions bool + Allowed map[string]struct{} +} + +type Column struct { + Table string + Name string + FieldName string } type Exp struct { @@ -77,8 +79,9 @@ type OrderBy struct { } type Paging struct { - Limit string - Offset string + Limit string + Offset string + NoLimit bool } type ExpOp int @@ -145,81 +148,23 @@ const ( OrderDescNullsLast ) -type Filters struct { - All map[string][]string - Query map[string][]string - Insert map[string][]string - Update map[string][]string - Delete map[string][]string -} - -type Config struct { - DefaultFilter []string - FilterMap Filters - Blocklist []string - KeepArgs 
bool -} - type Compiler struct { - df *Exp - fm struct { - all map[string]*Exp - query map[string]*Exp - insert map[string]*Exp - update map[string]*Exp - delete map[string]*Exp - } + tr map[string]map[string]*trval bl map[string]struct{} ka bool } -var opMap = map[parserType]QType{ - opQuery: QTQuery, - opMutate: QTMutation, -} - var expPool = sync.Pool{ New: func() interface{} { return &Exp{doFree: true} }, } func NewCompiler(c Config) (*Compiler, error) { - var err error co := &Compiler{ka: c.KeepArgs} - + co.tr = make(map[string]map[string]*trval) co.bl = make(map[string]struct{}, len(c.Blocklist)) for i := range c.Blocklist { - co.bl[c.Blocklist[i]] = struct{}{} - } - - co.df, err = compileFilter(c.DefaultFilter) - if err != nil { - return nil, err - } - - co.fm.all, err = buildFilters(c.FilterMap.All) - if err != nil { - return nil, err - } - - co.fm.query, err = buildFilters(c.FilterMap.Query) - if err != nil { - return nil, err - } - - co.fm.insert, err = buildFilters(c.FilterMap.Insert) - if err != nil { - return nil, err - } - - co.fm.update, err = buildFilters(c.FilterMap.Update) - if err != nil { - return nil, err - } - - co.fm.delete, err = buildFilters(c.FilterMap.Delete) - if err != nil { - return nil, err + co.bl[strings.ToLower(c.Blocklist[i])] = struct{}{} } seedExp := [100]Exp{} @@ -232,58 +177,99 @@ func NewCompiler(c Config) (*Compiler, error) { return co, nil } -func buildFilters(filMap map[string][]string) (map[string]*Exp, error) { - fm := make(map[string]*Exp, len(filMap)) +func (com *Compiler) AddRole(role, table string, trc TRConfig) error { + var err error + trv := &trval{} - for k, v := range filMap { - fil, err := compileFilter(v) - if err != nil { - return nil, err + toMap := func(cols []string) map[string]struct{} { + m := make(map[string]struct{}, len(cols)) + for i := range cols { + m[strings.ToLower(cols[i])] = struct{}{} } - singular := flect.Singularize(k) - plural := flect.Pluralize(k) - - fm[singular] = fil - fm[plural] = fil 
+ return m } - return fm, nil + // query config + trv.query.fil, err = compileFilter(trc.Query.Filters) + if err != nil { + return err + } + if trc.Query.Limit > 0 { + trv.query.limit = strconv.Itoa(trc.Query.Limit) + } + trv.query.cols = toMap(trc.Query.Columns) + trv.query.disable.funcs = trc.Query.DisableFunctions + + // insert config + if trv.insert.fil, err = compileFilter(trc.Insert.Filters); err != nil { + return err + } + trv.insert.cols = toMap(trc.Insert.Columns) + + // update config + if trv.update.fil, err = compileFilter(trc.Update.Filters); err != nil { + return err + } + trv.insert.cols = toMap(trc.Insert.Columns) + trv.insert.set = trc.Insert.Set + + // delete config + if trv.delete.fil, err = compileFilter(trc.Delete.Filters); err != nil { + return err + } + trv.delete.cols = toMap(trc.Delete.Columns) + + singular := flect.Singularize(table) + plural := flect.Pluralize(table) + + if _, ok := com.tr[role]; !ok { + com.tr[role] = make(map[string]*trval) + } + + com.tr[role][singular] = trv + com.tr[role][plural] = trv + return nil } -func (com *Compiler) Compile(query []byte) (*QCode, error) { - var qc QCode +func (com *Compiler) Compile(query []byte, role string) (*QCode, error) { var err error + qc := QCode{Type: QTQuery} + op, err := Parse(query) if err != nil { return nil, err } - qc.Selects, err = com.compileQuery(op) - if err != nil { + if err = com.compileQuery(&qc, op, role); err != nil { return nil, err } - if t, ok := opMap[op.Type]; ok { - qc.Type = t - } else { - return nil, fmt.Errorf("Unknown operation type %d", op.Type) - } - opPool.Put(op) return &qc, nil } -func (com *Compiler) compileQuery(op *Operation) ([]Select, error) { +func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error { id := int32(0) parentID := int32(0) + if len(op.Fields) == 0 { + return errors.New("invalid graphql no query found") + } + + if op.Type == opMutate { + if err := com.setMutationType(qc, op.Fields[0].Args); err != nil { + return err 
+ } + } + selects := make([]Select, 0, 5) st := NewStack() + action := qc.Type if len(op.Fields) == 0 { - return nil, errors.New("empty query") + return errors.New("empty query") } st.Push(op.Fields[0].ID) @@ -293,7 +279,7 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) { } if id >= maxSelectors { - return nil, fmt.Errorf("selector limit reached (%d)", maxSelectors) + return fmt.Errorf("selector limit reached (%d)", maxSelectors) } fid := st.Pop() @@ -303,14 +289,25 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) { continue } + trv := com.getRole(role, field.Name) + selects = append(selects, Select{ ID: id, ParentID: parentID, Table: field.Name, Children: make([]int32, 0, 5), + Allowed: trv.allowedColumns(action), }) s := &selects[(len(selects) - 1)] + if action == QTQuery { + s.Functions = !trv.query.disable.funcs + + if len(trv.query.limit) != 0 { + s.Paging.Limit = trv.query.limit + } + } + if s.ID != 0 { p := &selects[s.ParentID] p.Children = append(p.Children, s.ID) @@ -322,12 +319,13 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) { s.FieldName = s.Table } - err := com.compileArgs(s, field.Args) + err := com.compileArgs(qc, s, field.Args) if err != nil { - return nil, err + return err } s.Cols = make([]Column, 0, len(field.Children)) + action = QTQuery for _, cid := range field.Children { f := op.Fields[cid] @@ -356,36 +354,14 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) { } if id == 0 { - return nil, errors.New("invalid query") + return errors.New("invalid query") } var fil *Exp - root := &selects[0] - switch op.Type { - case opQuery: - fil, _ = com.fm.query[root.Table] - - case opMutate: - switch root.Action { - case ActionInsert: - fil, _ = com.fm.insert[root.Table] - case ActionUpdate: - fil, _ = com.fm.update[root.Table] - case ActionDelete: - fil, _ = com.fm.delete[root.Table] - case ActionUpsert: - fil, _ = com.fm.insert[root.Table] - } - } - - if fil == nil { - 
fil, _ = com.fm.all[root.Table] - } - - if fil == nil { - fil = com.df + if trv, ok := com.tr[role][op.Fields[0].Name]; ok { + fil = trv.filter(qc.Type) } if fil != nil && fil.Op != OpNop { @@ -403,10 +379,11 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) { } } - return selects[:id], nil + qc.Selects = selects[:id] + return nil } -func (com *Compiler) compileArgs(sel *Select, args []Arg) error { +func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg) error { var err error if com.ka { @@ -418,9 +395,7 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error { switch arg.Name { case "id": - if sel.ID == 0 { - err = com.compileArgID(sel, arg) - } + err = com.compileArgID(sel, arg) case "search": err = com.compileArgSearch(sel, arg) case "where": @@ -433,18 +408,6 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error { err = com.compileArgLimit(sel, arg) case "offset": err = com.compileArgOffset(sel, arg) - case "insert": - sel.Action = ActionInsert - err = com.compileArgAction(sel, arg) - case "update": - sel.Action = ActionUpdate - err = com.compileArgAction(sel, arg) - case "upsert": - sel.Action = ActionUpsert - err = com.compileArgAction(sel, arg) - case "delete": - sel.Action = ActionDelete - err = com.compileArgAction(sel, arg) } if err != nil { @@ -461,6 +424,45 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error { return nil } +func (com *Compiler) setMutationType(qc *QCode, args []Arg) error { + setActionVar := func(arg *Arg) error { + if arg.Val.Type != nodeVar { + return fmt.Errorf("value for argument '%s' must be a variable", arg.Name) + } + qc.ActionVar = arg.Val.Val + return nil + } + + for i := range args { + arg := &args[i] + + switch arg.Name { + case "insert": + qc.Type = QTInsert + return setActionVar(arg) + case "update": + qc.Type = QTUpdate + return setActionVar(arg) + case "upsert": + qc.Type = QTUpsert + return setActionVar(arg) + case "delete": + qc.Type = QTDelete + + 
if arg.Val.Type != nodeBool { + return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name) + } + + if arg.Val.Val == "false" { + qc.Type = QTQuery + } + return nil + } + } + + return nil +} + func (com *Compiler) compileArgObj(st *util.Stack, arg *Arg) (*Exp, error) { if arg.Val.Type != nodeObj { return nil, fmt.Errorf("expecting an object") @@ -540,6 +542,10 @@ func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (* } func (com *Compiler) compileArgID(sel *Select, arg *Arg) error { + if sel.ID != 0 { + return nil + } + if sel.Where != nil && sel.Where.Op == OpEqID { return nil } @@ -732,24 +738,14 @@ func (com *Compiler) compileArgOffset(sel *Select, arg *Arg) error { return nil } -func (com *Compiler) compileArgAction(sel *Select, arg *Arg) error { - switch sel.Action { - case ActionDelete: - if arg.Val.Type != nodeBool { - return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name) - } - if arg.Val.Val == "false" { - sel.Action = 0 - } +var zeroTrv = &trval{} - default: - if arg.Val.Type != nodeVar { - return fmt.Errorf("value for argument '%s' must be a variable", arg.Name) - } - sel.ActionVar = arg.Val.Val +func (com *Compiler) getRole(role, field string) *trval { + if trv, ok := com.tr[role][field]; ok { + return trv + } else { + return zeroTrv } - - return nil } func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) { diff --git a/serv/allow.go b/serv/allow.go index 729ef61..f8f02ad 100644 --- a/serv/allow.go +++ b/serv/allow.go @@ -26,7 +26,8 @@ type allowItem struct { var _allowList allowList type allowList struct { - list map[string]*allowItem + list []*allowItem + index map[string]int filepath string saveChan chan *allowItem active bool @@ -34,7 +35,7 @@ type allowList struct { func initAllowList(cpath string) { _allowList = allowList{ - list: make(map[string]*allowItem), + index: make(map[string]int), saveChan: make(chan *allowItem), active: true, } @@ -172,17 +173,21 @@ func (al 
*allowList) load() { if c == 0 { if ty == AL_QUERY { q := string(b[s:(e + 1)]) + key := gqlHash(q, varBytes, "") - item := &allowItem{ - uri: uri, - gql: q, - } - - if len(varBytes) != 0 { + if idx, ok := al.index[key]; !ok { + al.list = append(al.list, &allowItem{ + uri: uri, + gql: q, + vars: varBytes, + }) + al.index[key] = len(al.list) - 1 + } else { + item := al.list[idx] + item.gql = q item.vars = varBytes } - al.list[gqlHash(q, varBytes)] = item varBytes = nil } else if ty == AL_VARS { @@ -203,7 +208,15 @@ func (al *allowList) save(item *allowItem) { if al.active == false { return } - al.list[gqlHash(item.gql, item.vars)] = item + + key := gqlHash(item.gql, item.vars, "") + + if idx, ok := al.index[key]; ok { + al.list[idx] = item + } else { + al.list = append(al.list, item) + al.index[key] = len(al.list) - 1 + } f, err := os.Create(al.filepath) if err != nil { diff --git a/serv/auth.go b/serv/auth.go index e597a9b..22ab698 100644 --- a/serv/auth.go +++ b/serv/auth.go @@ -7,28 +7,42 @@ import ( ) var ( - userIDProviderKey = struct{}{} - userIDKey = struct{}{} + userIDProviderKey = "user_id_provider" + userIDKey = "user_id" + userRoleKey = "user_role" ) -func headerAuth(r *http.Request, c *config) *http.Request { - if len(c.Auth.Header) == 0 { - return nil - } +func headerAuth(next http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() - userID := r.Header.Get(c.Auth.Header) - if len(userID) != 0 { - ctx := context.WithValue(r.Context(), userIDKey, userID) - return r.WithContext(ctx) - } + userIDProvider := r.Header.Get("X-User-ID-Provider") + if len(userIDProvider) != 0 { + ctx = context.WithValue(ctx, userIDProviderKey, userIDProvider) + } - return nil + userID := r.Header.Get("X-User-ID") + if len(userID) != 0 { + ctx = context.WithValue(ctx, userIDKey, userID) + } + + userRole := r.Header.Get("X-User-Role") + if len(userRole) != 0 { + ctx = context.WithValue(ctx, userRoleKey, userRole) + } + + 
next.ServeHTTP(w, r.WithContext(ctx)) + } } func withAuth(next http.HandlerFunc) http.HandlerFunc { at := conf.Auth.Type ru := conf.Auth.Rails.URL + if conf.Auth.CredsInHeader { + next = headerAuth(next) + } + switch at { case "rails": if strings.HasPrefix(ru, "memcache:") { diff --git a/serv/auth_jwt.go b/serv/auth_jwt.go index 25ed785..ef4f834 100644 --- a/serv/auth_jwt.go +++ b/serv/auth_jwt.go @@ -58,11 +58,6 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { var tok string - if rn := headerAuth(r, conf); rn != nil { - next.ServeHTTP(w, rn) - return - } - if len(cookie) != 0 { ck, err := r.Cookie(cookie) if err != nil { @@ -102,7 +97,6 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc { } next.ServeHTTP(w, r.WithContext(ctx)) } - next.ServeHTTP(w, r) } } diff --git a/serv/auth_rails.go b/serv/auth_rails.go index c72c0d7..7f78da0 100644 --- a/serv/auth_rails.go +++ b/serv/auth_rails.go @@ -42,11 +42,6 @@ func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc { } return func(w http.ResponseWriter, r *http.Request) { - if rn := headerAuth(r, conf); rn != nil { - next.ServeHTTP(w, rn) - return - } - ck, err := r.Cookie(cookie) if err != nil { next.ServeHTTP(w, r) @@ -83,17 +78,12 @@ func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc { rURL, err := url.Parse(conf.Auth.Rails.URL) if err != nil { - logger.Fatal().Err(err) + logger.Fatal().Err(err).Send() } mc := memcache.New(rURL.Host) return func(w http.ResponseWriter, r *http.Request) { - if rn := headerAuth(r, conf); rn != nil { - next.ServeHTTP(w, rn) - return - } - ck, err := r.Cookie(cookie) if err != nil { next.ServeHTTP(w, r) @@ -126,25 +116,20 @@ func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc { ra, err := railsAuth(conf) if err != nil { - logger.Fatal().Err(err) + logger.Fatal().Err(err).Send() } return func(w http.ResponseWriter, r *http.Request) { - if rn := headerAuth(r, conf); rn != nil { - 
next.ServeHTTP(w, rn) - return - } - ck, err := r.Cookie(cookie) if err != nil { - logger.Error().Err(err) + logger.Warn().Err(err).Msg("rails cookie missing") next.ServeHTTP(w, r) return } userID, err := ra.ParseCookie(ck.Value) if err != nil { - logger.Error().Err(err) + logger.Warn().Err(err).Msg("failed to parse rails cookie") next.ServeHTTP(w, r) return } diff --git a/serv/cmd.go b/serv/cmd.go index 87c1660..6fbc09f 100644 --- a/serv/cmd.go +++ b/serv/cmd.go @@ -10,7 +10,6 @@ import ( "github.com/dosco/super-graph/qcode" "github.com/gobuffalo/flect" "github.com/jackc/pgx/v4" - "github.com/jackc/pgx/v4/log/zerologadapter" "github.com/jackc/pgx/v4/pgxpool" "github.com/rs/zerolog" "github.com/spf13/cobra" @@ -184,7 +183,34 @@ func initConf() (*config, error) { } zerolog.SetGlobalLevel(logLevel) - //fmt.Printf("%#v", c) + for k, v := range c.DB.Vars { + c.DB.Vars[k] = sanitize(v) + } + + c.RolesQuery = sanitize(c.RolesQuery) + + rolesMap := make(map[string]struct{}) + + for i := range c.Roles { + role := &c.Roles[i] + + if _, ok := rolesMap[role.Name]; ok { + logger.Fatal().Msgf("duplicate role '%s' found", role.Name) + } + role.Name = sanitize(role.Name) + role.Match = sanitize(role.Match) + rolesMap[role.Name] = struct{}{} + } + + if _, ok := rolesMap["user"]; !ok { + c.Roles = append(c.Roles, configRole{Name: "user"}) + } + + if _, ok := rolesMap["anon"]; !ok { + c.Roles = append(c.Roles, configRole{Name: "anon"}) + } + + c.Validate() return c, nil } @@ -217,7 +243,7 @@ func initDB(c *config, useDB bool) (*pgx.Conn, error) { config.LogLevel = pgx.LogLevelNone } - config.Logger = zerologadapter.NewLogger(*logger) + config.Logger = NewSQLLogger(*logger) db, err := pgx.ConnectConfig(context.Background(), config) if err != nil { @@ -252,7 +278,7 @@ func initDBPool(c *config) (*pgxpool.Pool, error) { config.ConnConfig.LogLevel = pgx.LogLevelNone } - config.ConnConfig.Logger = zerologadapter.NewLogger(*logger) + config.ConnConfig.Logger = NewSQLLogger(*logger) // if 
c.DB.MaxRetries != 0 { // opt.MaxRetries = c.DB.MaxRetries diff --git a/serv/cmd_seed.go b/serv/cmd_seed.go index f0cc2d4..514c543 100644 --- a/serv/cmd_seed.go +++ b/serv/cmd_seed.go @@ -66,6 +66,7 @@ func graphQLFunc(query string, data interface{}) map[string]interface{} { c := &coreContext{Context: context.Background()} c.req.Query = query c.req.Vars = b + c.req.role = "user" res, err := c.execQuery() if err != nil { diff --git a/serv/config.go b/serv/config.go index 1fb7f63..160cd4b 100644 --- a/serv/config.go +++ b/serv/config.go @@ -1,7 +1,9 @@ package serv import ( + "regexp" "strings" + "unicode" "github.com/spf13/viper" ) @@ -24,9 +26,9 @@ type config struct { Inflections map[string]string Auth struct { - Type string - Cookie string - Header string + Type string + Cookie string + CredsInHeader bool `mapstructure:"creds_in_header"` Rails struct { Version string @@ -60,10 +62,10 @@ type config struct { MaxRetries int `mapstructure:"max_retries"` LogLevel string `mapstructure:"log_level"` - vars map[string][]byte `mapstructure:"variables"` + Vars map[string]string `mapstructure:"variables"` Defaults struct { - Filter []string + Filters []string Blocklist []string } @@ -71,18 +73,16 @@ type config struct { } `mapstructure:"database"` Tables []configTable + + RolesQuery string `mapstructure:"roles_query"` + Roles []configRole } type configTable struct { - Name string - Filter []string - FilterQuery []string `mapstructure:"filter_query"` - FilterInsert []string `mapstructure:"filter_insert"` - FilterUpdate []string `mapstructure:"filter_update"` - FilterDelete []string `mapstructure:"filter_delete"` - Table string - Blocklist []string - Remotes []configRemote + Name string + Table string + Blocklist []string + Remotes []configRemote } type configRemote struct { @@ -98,6 +98,42 @@ type configRemote struct { } `mapstructure:"set_headers"` } +type configRole struct { + Name string + Match string + Tables []struct { + Name string + + Query struct { + Limit int + 
Filters []string + Columns []string + DisableAggregation bool `mapstructure:"disable_aggregation"` + Deny bool + } + + Insert struct { + Filters []string + Columns []string + Set map[string]string + Deny bool + } + + Update struct { + Filters []string + Columns []string + Set map[string]string + Deny bool + } + + Delete struct { + Filters []string + Columns []string + Deny bool + } + } +} + func newConfig() *viper.Viper { vi := viper.New() @@ -132,24 +168,30 @@ func newConfig() *viper.Viper { return vi } -func (c *config) getVariables() map[string]string { - vars := make(map[string]string, len(c.DB.vars)) +func (c *config) Validate() { + rm := make(map[string]struct{}) - for k, v := range c.DB.vars { - isVar := false - - for i := range v { - if v[i] == '$' { - isVar = true - } else if v[i] == ' ' { - isVar = false - } else if isVar && v[i] >= 'a' && v[i] <= 'z' { - v[i] = 'A' + (v[i] - 'a') - } + for i := range c.Roles { + name := strings.ToLower(c.Roles[i].Name) + if _, ok := rm[name]; ok { + logger.Fatal().Msgf("duplicate config for role '%s'", c.Roles[i].Name) } - vars[k] = string(v) + rm[name] = struct{}{} + } + + tm := make(map[string]struct{}) + + for i := range c.Tables { + name := strings.ToLower(c.Tables[i].Name) + if _, ok := tm[name]; ok { + logger.Fatal().Msgf("duplicate config for table '%s'", c.Tables[i].Name) + } + tm[name] = struct{}{} + } + + if len(c.RolesQuery) == 0 { + logger.Warn().Msgf("no 'roles_query' defined.") } - return vars } func (c *config) getAliasMap() map[string][]string { @@ -167,3 +209,21 @@ func (c *config) getAliasMap() map[string][]string { } return m } + +var varRe1 = regexp.MustCompile(`(?mi)\$([a-zA-Z0-9_.]+)`) +var varRe2 = regexp.MustCompile(`\{\{([a-zA-Z0-9_.]+)\}\}`) + +func sanitize(s string) string { + s0 := varRe1.ReplaceAllString(s, `{{$1}}`) + + s1 := strings.Map(func(r rune) rune { + if unicode.IsSpace(r) { + return ' ' + } + return r + }, s0) + + return varRe2.ReplaceAllStringFunc(s1, func(m string) string { + 
return strings.ToLower(m) + }) +} diff --git a/serv/core.go b/serv/core.go index 9c7f3ee..dc5d42b 100644 --- a/serv/core.go +++ b/serv/core.go @@ -13,8 +13,8 @@ import ( "github.com/cespare/xxhash/v2" "github.com/dosco/super-graph/jsn" - "github.com/dosco/super-graph/psql" "github.com/dosco/super-graph/qcode" + "github.com/jackc/pgx/v4" "github.com/valyala/fasttemplate" ) @@ -32,6 +32,12 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error { c.req.ref = req.Referer() c.req.hdr = req.Header + if authCheck(c) { + c.req.role = "user" + } else { + c.req.role = "anon" + } + b, err := c.execQuery() if err != nil { return err @@ -46,10 +52,12 @@ func (c *coreContext) execQuery() ([]byte, error) { var qc *qcode.QCode var data []byte + logger.Debug().Str("role", c.req.role).Msg(c.req.Query) + if conf.UseAllowList { var ps *preparedItem - data, ps, err = c.resolvePreparedSQL(c.req.Query) + data, ps, err = c.resolvePreparedSQL() if err != nil { return nil, err } @@ -59,12 +67,7 @@ func (c *coreContext) execQuery() ([]byte, error) { } else { - qc, err = qcompile.Compile([]byte(c.req.Query)) - if err != nil { - return nil, err - } - - data, skipped, err = c.resolveSQL(qc) + data, skipped, err = c.resolveSQL() if err != nil { return nil, err } @@ -112,6 +115,160 @@ func (c *coreContext) execQuery() ([]byte, error) { return ob.Bytes(), nil } +func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) { + tx, err := db.Begin(c) + if err != nil { + return nil, nil, err + } + defer tx.Rollback(c) + + if v := c.Value(userIDKey); v != nil { + _, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v)) + + if err != nil { + return nil, nil, err + } + } + + var role string + mutation := isMutation(c.req.Query) + useRoleQuery := len(conf.RolesQuery) != 0 && mutation + + if useRoleQuery { + if role, err = c.executeRoleQuery(tx); err != nil { + return nil, nil, err + } + + } else if v := c.Value(userRoleKey); v != nil { + role = v.(string) + + } else 
if mutation { + role = c.req.role + + } + + ps, ok := _preparedList[gqlHash(c.req.Query, c.req.Vars, role)] + if !ok { + return nil, nil, errUnauthorized + } + + var root []byte + vars := varList(c, ps.args) + + if mutation { + err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root) + } else { + err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&c.req.role, &root) + } + if err != nil { + return nil, nil, err + } + + if err := tx.Commit(c); err != nil { + return nil, nil, err + } + + return root, ps, nil +} + +func (c *coreContext) resolveSQL() ([]byte, uint32, error) { + tx, err := db.Begin(c) + if err != nil { + return nil, 0, err + } + defer tx.Rollback(c) + + mutation := isMutation(c.req.Query) + useRoleQuery := len(conf.RolesQuery) != 0 && mutation + + if useRoleQuery { + if c.req.role, err = c.executeRoleQuery(tx); err != nil { + return nil, 0, err + } + + } else if v := c.Value(userRoleKey); v != nil { + c.req.role = v.(string) + } + + stmts, err := c.buildStmt() + if err != nil { + return nil, 0, err + } + + var st *stmt + + if mutation { + st = findStmt(c.req.role, stmts) + } else { + st = &stmts[0] + } + + t := fasttemplate.New(st.sql, openVar, closeVar) + + buf := &bytes.Buffer{} + _, err = t.ExecuteFunc(buf, varMap(c)) + + if err == errNoUserID && + authFailBlock == authFailBlockPerQuery && + authCheck(c) == false { + return nil, 0, errUnauthorized + } + + if err != nil { + return nil, 0, err + } + + finalSQL := buf.String() + + var stime time.Time + + if conf.EnableTracing { + stime = time.Now() + } + + if v := c.Value(userIDKey); v != nil { + _, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v)) + + if err != nil { + return nil, 0, err + } + } + + var root []byte + + if mutation { + err = tx.QueryRow(c, finalSQL).Scan(&root) + } else { + err = tx.QueryRow(c, finalSQL).Scan(&c.req.role, &root) + } + if err != nil { + return nil, 0, err + } + + if err := tx.Commit(c); err != nil { + return nil, 0, err + } + + if mutation { + st = 
findStmt(c.req.role, stmts) + } else { + st = &stmts[0] + } + + if conf.EnableTracing && len(st.qc.Selects) != 0 { + c.addTrace( + st.qc.Selects, + st.qc.Selects[0].ID, + stime) + } + + if conf.UseAllowList == false { + _allowList.add(&c.req) + } + + return root, st.skipped, nil +} + func (c *coreContext) resolveRemote( hdr http.Header, h *xxhash.Digest, @@ -259,125 +416,15 @@ func (c *coreContext) resolveRemotes( return to, cerr } -func (c *coreContext) resolvePreparedSQL(gql string) ([]byte, *preparedItem, error) { - ps, ok := _preparedList[gqlHash(gql, c.req.Vars)] - if !ok { - return nil, nil, errUnauthorized +func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) { + var role string + row := tx.QueryRow(c, "_sg_get_role", c.req.role, 1) + + if err := row.Scan(&role); err != nil { + return "", err } - var root []byte - vars := varList(c, ps.args) - - tx, err := db.Begin(c) - if err != nil { - return nil, nil, err - } - defer tx.Rollback(c) - - if v := c.Value(userIDKey); v != nil { - _, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v)) - - if err != nil { - return nil, nil, err - } - } - - err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root) - if err != nil { - return nil, nil, err - } - - if err := tx.Commit(c); err != nil { - return nil, nil, err - } - - fmt.Printf("PRE: %v\n", ps.stmt) - - return root, ps, nil -} - -func (c *coreContext) resolveSQL(qc *qcode.QCode) ([]byte, uint32, error) { - var vars map[string]json.RawMessage - stmt := &bytes.Buffer{} - - if len(c.req.Vars) != 0 { - if err := json.Unmarshal(c.req.Vars, &vars); err != nil { - return nil, 0, err - } - } - - skipped, err := pcompile.Compile(qc, stmt, psql.Variables(vars)) - if err != nil { - return nil, 0, err - } - - t := fasttemplate.New(stmt.String(), openVar, closeVar) - - stmt.Reset() - _, err = t.ExecuteFunc(stmt, varMap(c)) - - if err == errNoUserID && - authFailBlock == authFailBlockPerQuery && - authCheck(c) == false { - return nil, 0, errUnauthorized - } - - 
if err != nil { - return nil, 0, err - } - - finalSQL := stmt.String() - - // if conf.LogLevel == "debug" { - // os.Stdout.WriteString(finalSQL) - // os.Stdout.WriteString("\n\n") - // } - - var st time.Time - - if conf.EnableTracing { - st = time.Now() - } - - tx, err := db.Begin(c) - if err != nil { - return nil, 0, err - } - defer tx.Rollback(c) - - if v := c.Value(userIDKey); v != nil { - _, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v)) - - if err != nil { - return nil, 0, err - } - } - - //fmt.Printf("\nRAW: %#v\n", finalSQL) - - var root []byte - - err = tx.QueryRow(c, finalSQL).Scan(&root) - if err != nil { - return nil, 0, err - } - - if err := tx.Commit(c); err != nil { - return nil, 0, err - } - - if conf.EnableTracing && len(qc.Selects) != 0 { - c.addTrace( - qc.Selects, - qc.Selects[0].ID, - st) - } - - if conf.UseAllowList == false { - _allowList.add(&c.req) - } - - return root, skipped, nil + return role, nil } func (c *coreContext) render(w io.Writer, data []byte) error { diff --git a/serv/core_build.go b/serv/core_build.go new file mode 100644 index 0000000..fd86003 --- /dev/null +++ b/serv/core_build.go @@ -0,0 +1,144 @@ +package serv + +import ( + "bytes" + "encoding/json" + "errors" + "io" + + "github.com/dosco/super-graph/psql" + "github.com/dosco/super-graph/qcode" +) + +type stmt struct { + role *configRole + qc *qcode.QCode + skipped uint32 + sql string +} + +func (c *coreContext) buildStmt() ([]stmt, error) { + var vars map[string]json.RawMessage + + if len(c.req.Vars) != 0 { + if err := json.Unmarshal(c.req.Vars, &vars); err != nil { + return nil, err + } + } + + gql := []byte(c.req.Query) + + if len(conf.Roles) == 0 { + return nil, errors.New(`no roles found ('user' and 'anon' required)`) + } + + qc, err := qcompile.Compile(gql, conf.Roles[0].Name) + if err != nil { + return nil, err + } + + stmts := make([]stmt, 0, len(conf.Roles)) + mutation := (qc.Type != qcode.QTQuery) + w := &bytes.Buffer{} + + for i := range 
conf.Roles { + role := &conf.Roles[i] + + if mutation && len(c.req.role) != 0 && role.Name != c.req.role { + continue + } + + if i > 0 { + qc, err = qcompile.Compile(gql, role.Name) + if err != nil { + return nil, err + } + } + + stmts = append(stmts, stmt{role: role, qc: qc}) + + if mutation { + skipped, err := pcompile.Compile(qc, w, psql.Variables(vars)) + if err != nil { + return nil, err + } + + s := &stmts[len(stmts)-1] + s.skipped = skipped + s.sql = w.String() + w.Reset() + } + } + + if mutation { + return stmts, nil + } + + io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `) + + for _, s := range stmts { + io.WriteString(w, `WHEN '`) + io.WriteString(w, s.role.Name) + io.WriteString(w, `' THEN (`) + + s.skipped, err = pcompile.Compile(s.qc, w, psql.Variables(vars)) + if err != nil { + return nil, err + } + + io.WriteString(w, `) `) + } + io.WriteString(w, `END) FROM (`) + + if len(conf.RolesQuery) == 0 { + v := c.Value(userRoleKey) + + io.WriteString(w, `VALUES ("`) + if v != nil { + io.WriteString(w, v.(string)) + } else { + io.WriteString(w, c.req.role) + } + io.WriteString(w, `")) AS "_sg_auth_info"(role) LIMIT 1;`) + + } else { + + io.WriteString(w, `SELECT (CASE WHEN EXISTS (`) + io.WriteString(w, conf.RolesQuery) + io.WriteString(w, `) THEN `) + + io.WriteString(w, `(SELECT (CASE`) + for _, s := range stmts { + if len(s.role.Match) == 0 { + continue + } + io.WriteString(w, ` WHEN `) + io.WriteString(w, s.role.Match) + io.WriteString(w, ` THEN '`) + io.WriteString(w, s.role.Name) + io.WriteString(w, `'`) + } + + if len(c.req.role) == 0 { + io.WriteString(w, ` ELSE 'anon' END) FROM (`) + } else { + io.WriteString(w, ` ELSE '`) + io.WriteString(w, c.req.role) + io.WriteString(w, `' END) FROM (`) + } + + io.WriteString(w, conf.RolesQuery) + io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) ELSE '`) + if len(c.req.role) == 0 { + io.WriteString(w, `anon`) + } else { + io.WriteString(w, c.req.role) + } + io.WriteString(w, 
`' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `) + } + + stmts[0].sql = w.String() + stmts[0].role = nil + + return stmts, nil +} diff --git a/serv/http.go b/serv/http.go index c943110..737006d 100644 --- a/serv/http.go +++ b/serv/http.go @@ -30,14 +30,15 @@ type gqlReq struct { Query string `json:"query"` Vars json.RawMessage `json:"variables"` ref string + role string hdr http.Header } type variables map[string]json.RawMessage type gqlResp struct { - Error string `json:"error,omitempty"` - Data json.RawMessage `json:"data"` + Error string `json:"message,omitempty"` + Data json.RawMessage `json:"data,omitempty"` Extensions *extensions `json:"extensions,omitempty"` } @@ -94,55 +95,20 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) { } if strings.EqualFold(ctx.req.OpName, introspectionQuery) { - // dat, err := ioutil.ReadFile("test.schema") - // if err != nil { - // http.Error(w, err.Error(), http.StatusInternalServerError) - // return - // } - //w.Write(dat) - w.Header().Set("Content-Type", "application/json") - w.Write([]byte(`{ - "data": { - "__schema": { - "queryType": { - "name": "Query" - }, - "mutationType": null, - "subscriptionType": null - } - }, - "extensions":{ - "tracing":{ - "version":1, - "startTime":"2019-06-04T19:53:31.093Z", - "endTime":"2019-06-04T19:53:31.108Z", - "duration":15219720, - "execution": { - "resolvers": [{ - "path": ["__schema"], - "parentType": "Query", - "fieldName": "__schema", - "returnType": "__Schema!", - "startOffset": 50950, - "duration": 17187 - }] - } - } - } - }`)) + introspect(w) return } err = ctx.handleReq(w, r) if err == errUnauthorized { - err := "Not authorized" - logger.Debug().Msg(err) - http.Error(w, err, 401) + w.WriteHeader(http.StatusUnauthorized) + json.NewEncoder(w).Encode(gqlResp{Error: err.Error()}) + return } if err != nil { - logger.Err(err).Msg("Failed to handle request") + logger.Err(err).Msg("failed to handle request") errorResp(w, err) } } diff --git 
a/serv/introsp.go b/serv/introsp.go new file mode 100644 index 0000000..2fbf26f --- /dev/null +++ b/serv/introsp.go @@ -0,0 +1,36 @@ +package serv + +import "net/http" + +func introspect(w http.ResponseWriter) { + w.Header().Set("Content-Type", "application/json") + w.Write([]byte(`{ + "data": { + "__schema": { + "queryType": { + "name": "Query" + }, + "mutationType": null, + "subscriptionType": null + } + }, + "extensions":{ + "tracing":{ + "version":1, + "startTime":"2019-06-04T19:53:31.093Z", + "endTime":"2019-06-04T19:53:31.108Z", + "duration":15219720, + "execution": { + "resolvers": [{ + "path": ["__schema"], + "parentType": "Query", + "fieldName": "__schema", + "returnType": "__Schema!", + "startOffset": 50950, + "duration": 17187 + }] + } + } + } + }`)) +} diff --git a/serv/prepare.go b/serv/prepare.go index bf9a475..2329578 100644 --- a/serv/prepare.go +++ b/serv/prepare.go @@ -7,7 +7,6 @@ import ( "fmt" "io" - "github.com/dosco/super-graph/psql" "github.com/dosco/super-graph/qcode" "github.com/jackc/pgconn" "github.com/valyala/fasttemplate" @@ -27,55 +26,105 @@ var ( func initPreparedList() { _preparedList = make(map[string]*preparedItem) - for k, v := range _allowList.list { - err := prepareStmt(k, v.gql, v.vars) + if err := prepareRoleStmt(); err != nil { + logger.Fatal().Err(err).Msg("failed to prepare get role statement") + } + + for _, v := range _allowList.list { + + err := prepareStmt(v.gql, v.vars) if err != nil { - logger.Warn().Err(err).Send() + logger.Warn().Str("gql", v.gql).Err(err).Send() } } } -func prepareStmt(key, gql string, varBytes json.RawMessage) error { - if len(gql) == 0 || len(key) == 0 { +func prepareStmt(gql string, varBytes json.RawMessage) error { + if len(gql) == 0 { return nil } - qc, err := qcompile.Compile([]byte(gql)) + c := &coreContext{Context: context.Background()} + c.req.Query = gql + c.req.Vars = varBytes + + stmts, err := c.buildStmt() if err != nil { return err } - var vars map[string]json.RawMessage + if 
len(stmts) != 0 && stmts[0].qc.Type == qcode.QTQuery { + c.req.Vars = nil + } - if len(varBytes) != 0 { - vars = make(map[string]json.RawMessage) + for _, s := range stmts { + if len(s.sql) == 0 { + continue + } - if err := json.Unmarshal(varBytes, &vars); err != nil { + finalSQL, am := processTemplate(s.sql) + + ctx := context.Background() + + tx, err := db.Begin(ctx) + if err != nil { + return err + } + defer tx.Rollback(ctx) + + pstmt, err := tx.Prepare(ctx, "", finalSQL) + if err != nil { + return err + } + + var key string + + if s.role == nil { + key = gqlHash(gql, c.req.Vars, "") + } else { + key = gqlHash(gql, c.req.Vars, s.role.Name) + } + + _preparedList[key] = &preparedItem{ + stmt: pstmt, + args: am, + skipped: s.skipped, + qc: s.qc, + } + + if err := tx.Commit(ctx); err != nil { return err } } - buf := &bytes.Buffer{} + return nil +} - skipped, err := pcompile.Compile(qc, buf, psql.Variables(vars)) - if err != nil { - return err +func prepareRoleStmt() error { + if len(conf.RolesQuery) == 0 { + return nil } - t := fasttemplate.New(buf.String(), `{{`, `}}`) - am := make([][]byte, 0, 5) - i := 0 + w := &bytes.Buffer{} - finalSQL := t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) { - am = append(am, []byte(tag)) - i++ - return w.Write([]byte(fmt.Sprintf("$%d", i))) - }) - - if err != nil { - return err + io.WriteString(w, `SELECT (CASE`) + for _, role := range conf.Roles { + if len(role.Match) == 0 { + continue + } + io.WriteString(w, ` WHEN `) + io.WriteString(w, role.Match) + io.WriteString(w, ` THEN '`) + io.WriteString(w, role.Name) + io.WriteString(w, `'`) } + io.WriteString(w, ` ELSE {{role}} END) FROM (`) + io.WriteString(w, conf.RolesQuery) + io.WriteString(w, `) AS "_sg_auth_roles_query"`) + + roleSQL, _ := processTemplate(w.String()) + ctx := context.Background() tx, err := db.Begin(ctx) @@ -84,21 +133,28 @@ func prepareStmt(key, gql string, varBytes json.RawMessage) error { } defer tx.Rollback(ctx) - pstmt, err := tx.Prepare(ctx, 
"", finalSQL) + _, err = tx.Prepare(ctx, "_sg_get_role", roleSQL) if err != nil { return err } - _preparedList[key] = &preparedItem{ - stmt: pstmt, - args: am, - skipped: skipped, - qc: qc, - } - - if err := tx.Commit(ctx); err != nil { - return err - } - return nil } + +func processTemplate(tmpl string) (string, [][]byte) { + t := fasttemplate.New(tmpl, `{{`, `}}`) + am := make([][]byte, 0, 5) + i := 0 + + vmap := make(map[string]int) + + return t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) { + if n, ok := vmap[tag]; ok { + return w.Write([]byte(fmt.Sprintf("$%d", n))) + } + am = append(am, []byte(tag)) + i++ + vmap[tag] = i + return w.Write([]byte(fmt.Sprintf("$%d", i))) + }), am +} diff --git a/serv/serv.go b/serv/serv.go index 1006520..a98c16e 100644 --- a/serv/serv.go +++ b/serv/serv.go @@ -12,7 +12,6 @@ import ( rice "github.com/GeertJohan/go.rice" "github.com/dosco/super-graph/psql" "github.com/dosco/super-graph/qcode" - "github.com/gobuffalo/flect" ) func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) { @@ -22,52 +21,53 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) { } conf := qcode.Config{ - DefaultFilter: c.DB.Defaults.Filter, - FilterMap: qcode.Filters{ - All: make(map[string][]string, len(c.Tables)), - Query: make(map[string][]string, len(c.Tables)), - Insert: make(map[string][]string, len(c.Tables)), - Update: make(map[string][]string, len(c.Tables)), - Delete: make(map[string][]string, len(c.Tables)), - }, Blocklist: c.DB.Defaults.Blocklist, KeepArgs: false, } - for i := range c.Tables { - t := c.Tables[i] - - singular := flect.Singularize(t.Name) - plural := flect.Pluralize(t.Name) - - setFilter := func(fm map[string][]string, fil []string) { - switch { - case len(fil) == 0: - return - case fil[0] == "none" || len(fil[0]) == 0: - fm[singular] = []string{} - fm[plural] = []string{} - default: - fm[singular] = t.Filter - fm[plural] = t.Filter - } - } - - setFilter(conf.FilterMap.All, 
t.Filter) - setFilter(conf.FilterMap.Query, t.FilterQuery) - setFilter(conf.FilterMap.Insert, t.FilterInsert) - setFilter(conf.FilterMap.Update, t.FilterUpdate) - setFilter(conf.FilterMap.Delete, t.FilterDelete) - } - qc, err := qcode.NewCompiler(conf) if err != nil { return nil, nil, err } + for _, r := range c.Roles { + for _, t := range r.Tables { + query := qcode.QueryConfig{ + Limit: t.Query.Limit, + Filters: t.Query.Filters, + Columns: t.Query.Columns, + DisableFunctions: t.Query.DisableAggregation, + } + + insert := qcode.InsertConfig{ + Filters: t.Insert.Filters, + Columns: t.Insert.Columns, + Set: t.Insert.Set, + } + + update := qcode.UpdateConfig{ + Filters: t.Update.Filters, + Columns: t.Update.Columns, + Set: t.Update.Set, + } + + delete := qcode.DeleteConfig{ + Filters: t.Delete.Filters, + Columns: t.Delete.Columns, + } + + qc.AddRole(r.Name, t.Name, qcode.TRConfig{ + Query: query, + Insert: insert, + Update: update, + Delete: delete, + }) + } + } + pc := psql.NewCompiler(psql.Config{ Schema: schema, - Vars: c.getVariables(), + Vars: c.DB.Vars, }) return qc, pc, nil diff --git a/serv/sqllog.go b/serv/sqllog.go new file mode 100644 index 0000000..3fccbea --- /dev/null +++ b/serv/sqllog.go @@ -0,0 +1,45 @@ +package serv + +import ( + "context" + + "github.com/jackc/pgx/v4" + "github.com/rs/zerolog" +) + +type Logger struct { + logger zerolog.Logger +} + +// NewSQLLogger accepts a zerolog.Logger as input and returns a new custom pgx +// logging facade as output. 
+func NewSQLLogger(logger zerolog.Logger) *Logger { + return &Logger{ + logger: logger.With().Logger(), + } +} + +func (pl *Logger) Log(ctx context.Context, level pgx.LogLevel, msg string, data map[string]interface{}) { + var zlevel zerolog.Level + switch level { + case pgx.LogLevelNone: + zlevel = zerolog.NoLevel + case pgx.LogLevelError: + zlevel = zerolog.ErrorLevel + case pgx.LogLevelWarn: + zlevel = zerolog.WarnLevel + case pgx.LogLevelInfo: + zlevel = zerolog.InfoLevel + case pgx.LogLevelDebug: + zlevel = zerolog.DebugLevel + default: + zlevel = zerolog.DebugLevel + } + + if sql, ok := data["sql"]; ok { + delete(data, "sql") + pl.logger.WithLevel(zlevel).Fields(data).Msg(sql.(string)) + } else { + pl.logger.WithLevel(zlevel).Fields(data).Msg(msg) + } +} diff --git a/serv/utils.go b/serv/utils.go index baf49c3..b59dded 100644 --- a/serv/utils.go +++ b/serv/utils.go @@ -21,16 +21,29 @@ func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 { return v } -func gqlHash(b string, vars []byte) string { +func gqlHash(b string, vars []byte, role string) string { b = strings.TrimSpace(b) h := sha1.New() + query := "query" s, e := 0, 0 space := []byte{' '} + starting := true var b0, b1 byte for { + if starting && b[e] == 'q' { + n := 0 + se := e + for e < len(b) && n < len(query) && b[e] == query[n] { + n++ + e++ + } + if n != len(query) { + io.WriteString(h, strings.ToLower(b[se:e])) + } + } if ws(b[e]) { for e < len(b) && ws(b[e]) { e++ @@ -42,6 +55,7 @@ func gqlHash(b string, vars []byte) string { h.Write(space) } } else { + starting = false s = e for e < len(b) && ws(b[e]) == false { e++ @@ -56,6 +70,10 @@ func gqlHash(b string, vars []byte) string { } } + if len(role) != 0 { + io.WriteString(h, role) + } + if vars == nil || len(vars) == 0 { return hex.EncodeToString(h.Sum(nil)) } @@ -80,3 +98,26 @@ func ws(b byte) bool { func al(b byte) bool { return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9') } + +func isMutation(sql string) bool 
{ + for i := range sql { + b := sql[i] + if b == '{' { + return false + } + if al(b) { + return (b == 'm' || b == 'M') + } + } + return false +} + +func findStmt(role string, stmts []stmt) *stmt { + for i := range stmts { + if stmts[i].role.Name != role { + continue + } + return &stmts[i] + } + return nil +} diff --git a/serv/utils_test.go b/serv/utils_test.go index 17d91b7..b8babeb 100644 --- a/serv/utils_test.go +++ b/serv/utils_test.go @@ -5,7 +5,7 @@ import ( "testing" ) -func TestRelaxHash1(t *testing.T) { +func TestGQLHash1(t *testing.T) { var v1 = ` products( limit: 30, @@ -24,15 +24,15 @@ func TestRelaxHash1(t *testing.T) { price } ` - h1 := gqlHash(v1, nil) - h2 := gqlHash(v2, nil) + h1 := gqlHash(v1, nil, "") + h2 := gqlHash(v2, nil, "") if strings.Compare(h1, h2) != 0 { t.Fatal("Hashes don't match they should") } } -func TestRelaxHash2(t *testing.T) { +func TestGQLHash2(t *testing.T) { var v1 = ` { products( @@ -53,15 +53,15 @@ func TestRelaxHash2(t *testing.T) { var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } ` - h1 := gqlHash(v1, nil) - h2 := gqlHash(v2, nil) + h1 := gqlHash(v1, nil, "") + h2 := gqlHash(v2, nil, "") if strings.Compare(h1, h2) != 0 { t.Fatal("Hashes don't match they should") } } -func TestRelaxHash3(t *testing.T) { +func TestGQLHash3(t *testing.T) { var v1 = `users { id email @@ -86,15 +86,44 @@ func TestRelaxHash3(t *testing.T) { } ` - h1 := gqlHash(v1, nil) - h2 := gqlHash(v2, nil) + h1 := gqlHash(v1, nil, "") + h2 := gqlHash(v2, nil, "") if strings.Compare(h1, h2) != 0 { t.Fatal("Hashes don't match they should") } } -func TestRelaxHashWithVars1(t *testing.T) { +func TestGQLHash4(t *testing.T) { + var v1 = ` + query { + products( + limit: 30 + order_by: { price: desc } + distinct: [price] + where: { id: { and: { greater_or_equals: 20, lt: 28 } } } + ) { + id + name + price + user { + id + email + } + } + }` 
+ + var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } ` + + h1 := gqlHash(v1, nil, "") + h2 := gqlHash(v2, nil, "") + + if strings.Compare(h1, h2) != 0 { + t.Fatal("Hashes don't match they should") + } +} + +func TestGQLHashWithVars1(t *testing.T) { var q1 = ` products( limit: 30, @@ -136,15 +165,15 @@ func TestRelaxHashWithVars1(t *testing.T) { "user": 123 }` - h1 := gqlHash(q1, []byte(v1)) - h2 := gqlHash(q2, []byte(v2)) + h1 := gqlHash(q1, []byte(v1), "user") + h2 := gqlHash(q2, []byte(v2), "user") if strings.Compare(h1, h2) != 0 { t.Fatal("Hashes don't match they should") } } -func TestRelaxHashWithVars2(t *testing.T) { +func TestGQLHashWithVars2(t *testing.T) { var q1 = ` products( limit: 30, @@ -193,8 +222,8 @@ func TestRelaxHashWithVars2(t *testing.T) { "user": 123 }` - h1 := gqlHash(q1, []byte(v1)) - h2 := gqlHash(q2, []byte(v2)) + h1 := gqlHash(q1, []byte(v1), "user") + h2 := gqlHash(q2, []byte(v2), "user") if strings.Compare(h1, h2) != 0 { t.Fatal("Hashes don't match they should") diff --git a/serv/vars.go b/serv/vars.go index 6ad9da6..f20627c 100644 --- a/serv/vars.go +++ b/serv/vars.go @@ -11,17 +11,27 @@ import ( func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) { return func(w io.Writer, tag string) (int, error) { switch tag { - case "user_id": - if v := ctx.Value(userIDKey); v != nil { - return stringVar(w, v.(string)) - } - return 0, errNoUserID - case "user_id_provider": if v := ctx.Value(userIDProviderKey); v != nil { return stringVar(w, v.(string)) } - return 0, errNoUserID + io.WriteString(w, "null") + return 0, nil + + case "user_id": + if v := ctx.Value(userIDKey); v != nil { + return stringVar(w, v.(string)) + } + + io.WriteString(w, "null") + return 0, nil + + case "user_role": + if v := ctx.Value(userRoleKey); v != nil { + return stringVar(w, v.(string)) + } + io.WriteString(w, "null") + 
return 0, nil } fields := jsn.Get(ctx.req.Vars, [][]byte{[]byte(tag)}) diff --git a/slides/overview.slide b/slides/overview.slide index 7888781..e52ff40 100644 --- a/slides/overview.slide +++ b/slides/overview.slide @@ -80,7 +80,7 @@ SQL Output account_id: "select account_id from users where id = $user_id" defaults: - filter: ["{ user_id: { eq: $user_id } }"] + Filters: ["{ user_id: { eq: $user_id } }"] blacklist: - password @@ -88,14 +88,14 @@ SQL Output fields: - name: users - filter: ["{ id: { eq: $user_id } }"] + Filters: ["{ id: { eq: $user_id } }"] - name: products - filter: [ + Filters: [ "{ price: { gt: 0 } }", "{ price: { lt: 8 } }" ] - name: me table: users - filter: ["{ id: { eq: $user_id } }"] + Filters: ["{ id: { eq: $user_id } }"] diff --git a/tmpl/dev.yml b/tmpl/dev.yml index b53a4d5..ffe9f64 100644 --- a/tmpl/dev.yml +++ b/tmpl/dev.yml @@ -1,4 +1,4 @@ -app_name: "{% app_name %} Development" +app_name: "Super Graph Development" host_port: 0.0.0.0:8080 web_ui: true @@ -53,7 +53,7 @@ auth: # Comment this out if you want to disable setting # the user_id via a header. 
Good for testing - header: X-User-ID + creds_in_header: true rails: # Rails version this is used for reading the @@ -84,7 +84,7 @@ database: type: postgres host: db port: 5432 - dbname: {% app_name_slug %}_development + dbname: app_development user: postgres password: '' @@ -100,7 +100,7 @@ database: # Define defaults to for the field key and values below defaults: - # filter: ["{ user_id: { eq: $user_id } }"] + # filters: ["{ user_id: { eq: $user_id } }"] # Field and table names that you wish to block blocklist: @@ -111,45 +111,81 @@ database: - encrypted - token - tables: - - name: users - # This filter will overwrite defaults.filter - # filter: ["{ id: { eq: $user_id } }"] - # filter_query: ["{ id: { eq: $user_id } }"] - filter_update: ["{ id: { eq: $user_id } }"] - filter_delete: ["{ id: { eq: $user_id } }"] +tables: + - name: customers + remotes: + - name: payments + id: stripe_id + url: http://rails_app:3000/stripe/$id + path: data + # debug: true + pass_headers: + - cookie + set_headers: + - name: Host + value: 0.0.0.0 + # - name: Authorization + # value: Bearer - # - name: products - # # Multiple filters are AND'd together - # filter: [ - # "{ price: { gt: 0 } }", - # "{ price: { lt: 8 } }" - # ] + - # You can create new fields that have a + # real db table backing them + name: me + table: users - - name: customers - # No filter is used for this field not - # even defaults.filter - filter: none +roles_query: "SELECT * FROM users as usr WHERE id = $user_id" - remotes: - - name: payments - id: stripe_id - url: http://rails_app:3000/stripe/$id - path: data - # debug: true - pass_headers: - - cookie - set_headers: - - name: Host - value: 0.0.0.0 - # - name: Authorization - # value: Bearer +roles: + - name: anon + tables: + - name: products + limit: 10 - - # You can create new fields that have a - # real db table backing them - name: me - table: users - filter: ["{ id: { eq: $user_id } }"] + query: + columns: ["id", "name", "description" ] + aggregation: false - 
# - name: posts - # filter: ["{ account_id: { _eq: $account_id } }"] \ No newline at end of file + insert: + allow: false + + update: + allow: false + + delete: + allow: false + + - name: user + tables: + - name: users + query: + filters: ["{ id: { _eq: $user_id } }"] + + - name: products + query: + limit: 50 + filters: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + disable_aggregation: false + + insert: + filters: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + set: + - created_at: "now" + + update: + filters: ["{ user_id: { eq: $user_id } }"] + columns: + - id + - name + set: + - updated_at: "now" + + delete: + deny: true + + - name: admin + match: id = 1 + tables: + - name: users + # query: + # filters: ["{ account_id: { _eq: $account_id } }"] diff --git a/tmpl/prod.yml b/tmpl/prod.yml index 9597d7a..95abfb7 100644 --- a/tmpl/prod.yml +++ b/tmpl/prod.yml @@ -1,4 +1,4 @@ -app_name: "{% app_name %} Production" +app_name: "Super Graph Production" host_port: 0.0.0.0:8080 web_ui: false @@ -47,10 +47,6 @@ auth: type: rails cookie: _app_session - # Comment this out if you want to disable setting - # the user_id via a header. Good for testing - header: X-User-ID - rails: # Rails version this is used for reading the # various cookies formats. 
@@ -80,7 +76,7 @@ database: type: postgres host: db port: 5432 - dbname: {% app_name_slug %}_production + dbname: {{app_name_slug}}_production user: postgres password: '' #pool_size: 10 @@ -94,7 +90,7 @@ # Define defaults to for the field key and values below defaults: - filter: ["{ user_id: { eq: $user_id } }"] + filters: ["{ user_id: { eq: $user_id } }"] # Field and table names that you wish to block blocklist: @@ -105,43 +101,79 @@ - encrypted - token - tables: - - name: users - # This filter will overwrite defaults.filter - # filter: ["{ id: { eq: $user_id } }"] - # filter_query: ["{ id: { eq: $user_id } }"] - filter_update: ["{ id: { eq: $user_id } }"] - filter_delete: ["{ id: { eq: $user_id } }"] +tables: + - name: customers + # remotes: + # - name: payments + # id: stripe_id + # url: http://rails_app:3000/stripe/$id + # path: data + # # pass_headers: + # # - cookie + # # - host + # set_headers: + # - name: Authorization + # value: Bearer - - name: products - # Multiple filters are AND'd together - filter: [ - "{ price: { gt: 0 } }", - "{ price: { lt: 8 } }" - ] + - # You can create new fields that have a + # real db table backing them + name: me + table: users - - name: customers - # No filter is used for this field not - # even defaults.filter - filter: none +roles_query: "SELECT * FROM users as usr WHERE id = $user_id" - # remotes: - # - name: payments - # id: stripe_id - # url: http://rails_app:3000/stripe/$id - # path: data - # # pass_headers: - # # - cookie - # # - host - # set_headers: - # - name: Authorization - # value: Bearer +roles: + - name: anon + tables: + - name: products + limit: 10 - - # You can create new fields that have a - # real db table backing them - name: me - table: users - filter: ["{ id: { eq: $user_id } }"] + query: + columns: ["id", "name", "description" ] + aggregation: false - # - name: posts - # filter: ["{ account_id: { _eq: $account_id } }"] \ No newline at end of file + insert: + allow: false + + 
update: + allow: false + + delete: + allow: false + + - name: user + tables: + - name: users + query: + filters: ["{ id: { _eq: $user_id } }"] + + - name: products + query: + limit: 50 + filters: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + disable_aggregation: false + + insert: + filters: ["{ user_id: { eq: $user_id } }"] + columns: ["id", "name", "description" ] + set: + - created_at: "now" + + update: + filters: ["{ user_id: { eq: $user_id } }"] + columns: + - id + - name + set: + - updated_at: "now" + + delete: + deny: true + + - name: admin + match: id = 1 + tables: + - name: users + # query: + # filters: ["{ account_id: { _eq: $account_id } }"]