Add insert mutation with bulk insert
This commit is contained in:
parent
5b9105ff0c
commit
c0a21e448f
|
@ -47,4 +47,79 @@ query {
|
||||||
email
|
email
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
variables {
|
||||||
|
"insert": {
|
||||||
|
"name": "Hello",
|
||||||
|
"description": "World",
|
||||||
|
"created_at": "now",
|
||||||
|
"updated_at": "now"
|
||||||
|
},
|
||||||
|
"user": 123
|
||||||
|
}
|
||||||
|
|
||||||
|
mutation {
|
||||||
|
products(insert: $insert) {
|
||||||
|
id
|
||||||
|
name
|
||||||
|
description
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
variables {
|
||||||
|
"insert": {
|
||||||
|
"name": "Hello",
|
||||||
|
"description": "World",
|
||||||
|
"created_at": "now",
|
||||||
|
"updated_at": "now"
|
||||||
|
},
|
||||||
|
"user": 123
|
||||||
|
}
|
||||||
|
|
||||||
|
mutation {
|
||||||
|
products(insert: $insert) {
|
||||||
|
id
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
variables {
|
||||||
|
"insert": {
|
||||||
|
"description": "World3",
|
||||||
|
"name": "Hello3",
|
||||||
|
"created_at": "now",
|
||||||
|
"updated_at": "now"
|
||||||
|
},
|
||||||
|
"user": 123
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
customers {
|
||||||
|
id
|
||||||
|
email
|
||||||
|
payments {
|
||||||
|
customer_id
|
||||||
|
amount
|
||||||
|
billing_details
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
variables {
|
||||||
|
"insert": {
|
||||||
|
"description": "World3",
|
||||||
|
"name": "Hello3",
|
||||||
|
"created_at": "now",
|
||||||
|
"updated_at": "now"
|
||||||
|
},
|
||||||
|
"user": 123
|
||||||
|
}
|
||||||
|
|
||||||
|
{
|
||||||
|
me {
|
||||||
|
id
|
||||||
|
full_name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -84,8 +84,9 @@ database:
|
||||||
#log_level: "debug"
|
#log_level: "debug"
|
||||||
|
|
||||||
# Define variables here that you want to use in filters
|
# Define variables here that you want to use in filters
|
||||||
|
# sub-queries must be wrapped in ()
|
||||||
variables:
|
variables:
|
||||||
account_id: "select account_id from users where id = $user_id"
|
account_id: "(select account_id from users where id = $user_id)"
|
||||||
|
|
||||||
# Define defaults to for the field key and values below
|
# Define defaults to for the field key and values below
|
||||||
defaults:
|
defaults:
|
||||||
|
@ -105,12 +106,12 @@ database:
|
||||||
# This filter will overwrite defaults.filter
|
# This filter will overwrite defaults.filter
|
||||||
# filter: ["{ id: { eq: $user_id } }"]
|
# filter: ["{ id: { eq: $user_id } }"]
|
||||||
|
|
||||||
- name: products
|
# - name: products
|
||||||
# Multiple filters are AND'd together
|
# # Multiple filters are AND'd together
|
||||||
filter: [
|
# filter: [
|
||||||
"{ price: { gt: 0 } }",
|
# "{ price: { gt: 0 } }",
|
||||||
"{ price: { lt: 8 } }"
|
# "{ price: { lt: 8 } }"
|
||||||
]
|
# ]
|
||||||
|
|
||||||
- name: customers
|
- name: customers
|
||||||
# No filter is used for this field not
|
# No filter is used for this field not
|
||||||
|
|
|
@ -82,8 +82,9 @@ database:
|
||||||
#log_level: "debug"
|
#log_level: "debug"
|
||||||
|
|
||||||
# Define variables here that you want to use in filters
|
# Define variables here that you want to use in filters
|
||||||
|
# sub-queries must be wrapped in ()
|
||||||
variables:
|
variables:
|
||||||
account_id: "select account_id from users where id = $user_id"
|
account_id: "(select account_id from users where id = $user_id)"
|
||||||
|
|
||||||
# Define defaults to for the field key and values below
|
# Define defaults to for the field key and values below
|
||||||
defaults:
|
defaults:
|
||||||
|
|
|
@ -2,6 +2,8 @@ version: '3.4'
|
||||||
services:
|
services:
|
||||||
db:
|
db:
|
||||||
image: postgres
|
image: postgres
|
||||||
|
ports:
|
||||||
|
- "5432:5432"
|
||||||
|
|
||||||
# redis:
|
# redis:
|
||||||
# image: redis:alpine
|
# image: redis:alpine
|
||||||
|
|
|
@ -0,0 +1,39 @@
|
||||||
|
# JSN - Fast low allocation JSON library
|
||||||
|
## Design
|
||||||
|
|
||||||
|
This libary is designed as a set of seperate functions to extract data and mutate
|
||||||
|
JSON. All functions are focused on keeping allocations to a minimum and be as fast
|
||||||
|
as possible. The functions don't validate the JSON a seperate `Validate` function
|
||||||
|
does that.
|
||||||
|
|
||||||
|
The JSON parsing algo processes each object `{}` or array `[]` in a bottom up way
|
||||||
|
where once the end of the array or object is found only then the keys within it are
|
||||||
|
parsed from the top down.
|
||||||
|
|
||||||
|
```
|
||||||
|
{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}], "full_name":"FN1","email":"E1" }
|
||||||
|
|
||||||
|
id: 1
|
||||||
|
|
||||||
|
posts: [{"title":"PT1-1","description":"PD1-1"}]
|
||||||
|
|
||||||
|
[{"title":"PT1-1","description":"PD1-1"}]
|
||||||
|
|
||||||
|
{"title":"PT1-1","description":"PD1-1"}
|
||||||
|
|
||||||
|
title: "PT1-1"
|
||||||
|
|
||||||
|
description: "PD1-1
|
||||||
|
|
||||||
|
full_name: "FN1"
|
||||||
|
|
||||||
|
email: "E1"
|
||||||
|
```
|
||||||
|
|
||||||
|
## Functions
|
||||||
|
|
||||||
|
- Strip: Strip a path from the root to a child node and return the rest
|
||||||
|
- Replace: Replace values by key
|
||||||
|
- Get: Get all keys
|
||||||
|
- Filter: Extract specific keys from an object
|
||||||
|
- Tree: Fetch unique keys from an array or object
|
|
@ -43,7 +43,7 @@ func Get(b []byte, keys [][]byte) []Field {
|
||||||
kmap[xxhash.Sum64(keys[i])] = struct{}{}
|
kmap[xxhash.Sum64(keys[i])] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
res := make([]Field, 20)
|
res := make([]Field, 0, 20)
|
||||||
|
|
||||||
s, e, d := 0, 0, 0
|
s, e, d := 0, 0, 0
|
||||||
|
|
||||||
|
@ -127,7 +127,7 @@ func Get(b []byte, keys [][]byte) []Field {
|
||||||
_, ok := kmap[xxhash.Sum64(k)]
|
_, ok := kmap[xxhash.Sum64(k)]
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
res[n] = Field{k, b[s:(e + 1)]}
|
res = append(res, Field{k, b[s:(e + 1)]})
|
||||||
n++
|
n++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,10 @@ var (
|
||||||
"full_name": "Caroll Orn Sr.",
|
"full_name": "Caroll Orn Sr.",
|
||||||
"email": "joannarau@hegmann.io",
|
"email": "joannarau@hegmann.io",
|
||||||
"__twitter_id": "ABC123"
|
"__twitter_id": "ABC123"
|
||||||
|
"more": [{
|
||||||
|
"__twitter_id": "more123",
|
||||||
|
"hello: "world
|
||||||
|
}]
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
|
@ -163,6 +167,7 @@ func TestGet(t *testing.T) {
|
||||||
{[]byte("__twitter_id"), []byte(`"ABCD"`)},
|
{[]byte("__twitter_id"), []byte(`"ABCD"`)},
|
||||||
{[]byte("__twitter_id"), []byte(`"2048666903444506956"`)},
|
{[]byte("__twitter_id"), []byte(`"2048666903444506956"`)},
|
||||||
{[]byte("__twitter_id"), []byte(`"ABC123"`)},
|
{[]byte("__twitter_id"), []byte(`"ABC123"`)},
|
||||||
|
{[]byte("__twitter_id"), []byte(`"more123"`)},
|
||||||
{[]byte("__twitter_id"),
|
{[]byte("__twitter_id"),
|
||||||
[]byte(`[{ "name": "hello" }, { "name": "world"}]`)},
|
[]byte(`[{ "name": "hello" }, { "name": "world"}]`)},
|
||||||
{[]byte("__twitter_id"),
|
{[]byte("__twitter_id"),
|
||||||
|
@ -340,6 +345,80 @@ func TestReplaceEmpty(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestKeys1(t *testing.T) {
|
||||||
|
json := `[{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}, {"title":"PT1-2","description":"PD1-2"}], "full_name":"FN1","email":"E1","books": [{"name":"BN1-1","description":"BD1-1"},{"name":"BN1-2","description":"BD1-2"},{"name":"BN1-2","description":"BD1-2"}]},{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}, {"title":"PT1-2","description":"PD1-2"}], "full_name":"FN1","email":"E1","books": [{"name":"BN1-1","description":"BD1-1"},{"name":"BN1-2","description":"BD1-2"},{"name":"BN1-2","description":"BD1-2"}]},{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}, {"title":"PT1-2","description":"PD1-2"}], "full_name":"FN1","email":"E1","books": [{"name":"BN1-1","description":"BD1-1"},{"name":"BN1-2","description":"BD1-2"},{"name":"BN1-2","description":"BD1-2"}]}]`
|
||||||
|
|
||||||
|
fields := Keys([]byte(json))
|
||||||
|
|
||||||
|
exp := []string{
|
||||||
|
"id", "posts", "title", "description", "full_name", "email", "books", "name", "description",
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(exp) != len(fields) {
|
||||||
|
t.Errorf("Expected %d fields %d", len(exp), len(fields))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range exp {
|
||||||
|
if string(fields[i]) != exp[i] {
|
||||||
|
t.Errorf("Expected field '%s' got '%s'", string(exp[i]), fields[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeys2(t *testing.T) {
|
||||||
|
json := `{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}, {"title":"PT1-2","description":"PD1-2"}], "full_name":"FN1","email":"E1","books": [{"name":"BN1-1","description":"BD1-1"},{"name":"BN1-2","description":"BD1-2"},{"name":"BN1-2","description":"BD1-2"}]}`
|
||||||
|
|
||||||
|
fields := Keys([]byte(json))
|
||||||
|
|
||||||
|
exp := []string{
|
||||||
|
"id", "posts", "title", "description", "full_name", "email", "books", "name", "description",
|
||||||
|
}
|
||||||
|
|
||||||
|
// for i := range fields {
|
||||||
|
// fmt.Println("-->", string(fields[i]))
|
||||||
|
// }
|
||||||
|
|
||||||
|
if len(exp) != len(fields) {
|
||||||
|
t.Errorf("Expected %d fields %d", len(exp), len(fields))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range exp {
|
||||||
|
if string(fields[i]) != exp[i] {
|
||||||
|
t.Errorf("Expected field '%s' got '%s'", string(exp[i]), fields[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeys3(t *testing.T) {
|
||||||
|
json := `{
|
||||||
|
"insert": {
|
||||||
|
"created_at": "now",
|
||||||
|
"test": { "type1": "a", "type2": "b" },
|
||||||
|
"name": "Hello",
|
||||||
|
"updated_at": "now",
|
||||||
|
"description": "World"
|
||||||
|
},
|
||||||
|
"user": 123
|
||||||
|
}`
|
||||||
|
|
||||||
|
fields := Keys([]byte(json))
|
||||||
|
|
||||||
|
exp := []string{
|
||||||
|
"insert", "created_at", "test", "type1", "type2", "name", "updated_at", "description",
|
||||||
|
"user",
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(exp) != len(fields) {
|
||||||
|
t.Errorf("Expected %d fields %d", len(exp), len(fields))
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range exp {
|
||||||
|
if string(fields[i]) != exp[i] {
|
||||||
|
t.Errorf("Expected field '%s' got '%s'", string(exp[i]), fields[i])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func BenchmarkGet(b *testing.B) {
|
func BenchmarkGet(b *testing.B) {
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,122 @@
|
||||||
|
package jsn
|
||||||
|
|
||||||
|
func Keys(b []byte) [][]byte {
|
||||||
|
res := make([][]byte, 0, 20)
|
||||||
|
|
||||||
|
s, e, d := 0, 0, 0
|
||||||
|
|
||||||
|
var k []byte
|
||||||
|
state := expectValue
|
||||||
|
|
||||||
|
st := NewStack()
|
||||||
|
ae := 0
|
||||||
|
|
||||||
|
for i := 0; i < len(b); i++ {
|
||||||
|
if state == expectObjClose || state == expectListClose {
|
||||||
|
switch b[i] {
|
||||||
|
case '{', '[':
|
||||||
|
d++
|
||||||
|
case '}', ']':
|
||||||
|
d--
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
si := st.Peek()
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case state == expectKey && si != nil && i >= si.ss:
|
||||||
|
i = si.se + 1
|
||||||
|
st.Pop()
|
||||||
|
|
||||||
|
case state == expectKey && b[i] == '{':
|
||||||
|
state = expectObjClose
|
||||||
|
s = i
|
||||||
|
d++
|
||||||
|
|
||||||
|
case state == expectObjClose && d == 0 && b[i] == '}':
|
||||||
|
state = expectKey
|
||||||
|
if ae != 0 {
|
||||||
|
st.Push(skipInfo{i, ae})
|
||||||
|
ae = 0
|
||||||
|
}
|
||||||
|
e = i
|
||||||
|
i = s
|
||||||
|
|
||||||
|
case state == expectKey && b[i] == '"':
|
||||||
|
state = expectKeyClose
|
||||||
|
s = i
|
||||||
|
|
||||||
|
case state == expectKeyClose && b[i] == '"':
|
||||||
|
state = expectColon
|
||||||
|
k = b[(s + 1):i]
|
||||||
|
|
||||||
|
case state == expectColon && b[i] == ':':
|
||||||
|
state = expectValue
|
||||||
|
|
||||||
|
case state == expectValue && b[i] == '"':
|
||||||
|
state = expectString
|
||||||
|
s = i
|
||||||
|
|
||||||
|
case state == expectString && b[i] == '"':
|
||||||
|
e = i
|
||||||
|
|
||||||
|
case state == expectValue && b[i] == '{':
|
||||||
|
state = expectObjClose
|
||||||
|
s = i
|
||||||
|
d++
|
||||||
|
|
||||||
|
case state == expectObjClose && d == 0 && b[i] == '}':
|
||||||
|
state = expectKey
|
||||||
|
e = i
|
||||||
|
i = s
|
||||||
|
|
||||||
|
case state == expectValue && b[i] == '[':
|
||||||
|
state = expectListClose
|
||||||
|
s = i
|
||||||
|
d++
|
||||||
|
|
||||||
|
case state == expectListClose && d == 0 && b[i] == ']':
|
||||||
|
state = expectKey
|
||||||
|
ae = i
|
||||||
|
e = i
|
||||||
|
i = s
|
||||||
|
|
||||||
|
case state == expectValue && (b[i] >= '0' && b[i] <= '9'):
|
||||||
|
state = expectNumClose
|
||||||
|
s = i
|
||||||
|
|
||||||
|
case state == expectNumClose &&
|
||||||
|
((b[i] < '0' || b[i] > '9') &&
|
||||||
|
(b[i] != '.' && b[i] != 'e' && b[i] != 'E' && b[i] != '+' && b[i] != '-')):
|
||||||
|
i--
|
||||||
|
e = i
|
||||||
|
|
||||||
|
case state == expectValue &&
|
||||||
|
(b[i] == 'f' || b[i] == 'F' || b[i] == 't' || b[i] == 'T'):
|
||||||
|
state = expectBoolClose
|
||||||
|
s = i
|
||||||
|
|
||||||
|
case state == expectBoolClose && (b[i] == 'e' || b[i] == 'E'):
|
||||||
|
e = i
|
||||||
|
|
||||||
|
case state == expectValue && b[i] == 'n':
|
||||||
|
state = expectNull
|
||||||
|
|
||||||
|
case state == expectNull && b[i] == 'l':
|
||||||
|
e = i
|
||||||
|
}
|
||||||
|
|
||||||
|
if e != 0 {
|
||||||
|
if k != nil {
|
||||||
|
res = append(res, k)
|
||||||
|
}
|
||||||
|
|
||||||
|
state = expectKey
|
||||||
|
k = nil
|
||||||
|
e = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return res
|
||||||
|
}
|
|
@ -0,0 +1,51 @@
|
||||||
|
package jsn
|
||||||
|
|
||||||
|
type skipInfo struct {
|
||||||
|
ss, se int
|
||||||
|
}
|
||||||
|
|
||||||
|
type Stack struct {
|
||||||
|
stA [20]skipInfo
|
||||||
|
st []skipInfo
|
||||||
|
top int
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new Stack
|
||||||
|
func NewStack() *Stack {
|
||||||
|
s := &Stack{top: -1}
|
||||||
|
s.st = s.stA[:0]
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
// Return the number of items in the Stack
|
||||||
|
func (s *Stack) Len() int {
|
||||||
|
return (s.top + 1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// View the top item on the Stack
|
||||||
|
func (s *Stack) Peek() *skipInfo {
|
||||||
|
if s.top == -1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return &s.st[s.top]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pop the top item of the Stack and return it
|
||||||
|
func (s *Stack) Pop() *skipInfo {
|
||||||
|
if s.top == -1 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.top--
|
||||||
|
return &s.st[(s.top + 1)]
|
||||||
|
}
|
||||||
|
|
||||||
|
// Push a value onto the top of the Stack
|
||||||
|
func (s *Stack) Push(value skipInfo) {
|
||||||
|
s.top++
|
||||||
|
if len(s.st) <= s.top {
|
||||||
|
s.st = append(s.st, value)
|
||||||
|
} else {
|
||||||
|
s.st[s.top] = value
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,37 @@
|
||||||
|
package jsn
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
|
)
|
||||||
|
|
||||||
|
func Tree(v []byte) (map[string]interface{}, bool, error) {
|
||||||
|
dec := json.NewDecoder(bytes.NewReader(v))
|
||||||
|
array := false
|
||||||
|
|
||||||
|
// read open bracket
|
||||||
|
|
||||||
|
for i := range v {
|
||||||
|
if v[i] != ' ' {
|
||||||
|
array = (v[i] == '[')
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if array {
|
||||||
|
if _, err := dec.Token(); err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// while the array contains values
|
||||||
|
var m map[string]interface{}
|
||||||
|
|
||||||
|
// decode an array value (Message)
|
||||||
|
err := dec.Decode(&m)
|
||||||
|
if err != nil {
|
||||||
|
return nil, false, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return m, array, nil
|
||||||
|
}
|
|
@ -1,7 +0,0 @@
|
||||||
goos: darwin
|
|
||||||
goarch: amd64
|
|
||||||
pkg: github.com/dosco/super-graph/psql
|
|
||||||
BenchmarkCompile-8 100000 16476 ns/op 3282 B/op 66 allocs/op
|
|
||||||
BenchmarkCompileParallel-8 300000 4639 ns/op 3324 B/op 66 allocs/op
|
|
||||||
PASS
|
|
||||||
ok github.com/dosco/super-graph/psql 3.274s
|
|
|
@ -1,7 +0,0 @@
|
||||||
goos: darwin
|
|
||||||
goarch: amd64
|
|
||||||
pkg: github.com/dosco/super-graph/psql
|
|
||||||
BenchmarkCompile-8 100000 15728 ns/op 3000 B/op 60 allocs/op
|
|
||||||
BenchmarkCompileParallel-8 300000 5077 ns/op 3023 B/op 60 allocs/op
|
|
||||||
PASS
|
|
||||||
ok github.com/dosco/super-graph/psql 3.318s
|
|
|
@ -1,7 +0,0 @@
|
||||||
goos: darwin
|
|
||||||
goarch: amd64
|
|
||||||
pkg: github.com/dosco/super-graph/psql
|
|
||||||
BenchmarkCompile-8 1000000 15997 ns/op 3048 B/op 58 allocs/op
|
|
||||||
BenchmarkCompileParallel-8 3000000 4722 ns/op 3073 B/op 58 allocs/op
|
|
||||||
PASS
|
|
||||||
ok github.com/dosco/super-graph/psql 35.024s
|
|
|
@ -1,7 +0,0 @@
|
||||||
goos: darwin
|
|
||||||
goarch: amd64
|
|
||||||
pkg: github.com/dosco/super-graph/psql
|
|
||||||
BenchmarkCompile-8 100000 16829 ns/op 2887 B/op 57 allocs/op
|
|
||||||
BenchmarkCompileParallel-8 300000 5450 ns/op 2911 B/op 57 allocs/op
|
|
||||||
PASS
|
|
||||||
ok github.com/dosco/super-graph/psql 3.561s
|
|
|
@ -0,0 +1,90 @@
|
||||||
|
package psql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"errors"
|
||||||
|
"io"
|
||||||
|
|
||||||
|
"github.com/dosco/super-graph/jsn"
|
||||||
|
"github.com/dosco/super-graph/qcode"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) {
|
||||||
|
if len(qc.Selects) == 0 {
|
||||||
|
return 0, errors.New("empty query")
|
||||||
|
}
|
||||||
|
|
||||||
|
c := &compilerContext{w, qc.Selects, co}
|
||||||
|
root := &qc.Selects[0]
|
||||||
|
|
||||||
|
c.w.WriteString(`WITH `)
|
||||||
|
c.w.WriteString(root.Table)
|
||||||
|
c.w.WriteString(` AS (`)
|
||||||
|
|
||||||
|
if _, err := c.renderInsert(qc, w, vars); err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.w.WriteString(`) `)
|
||||||
|
|
||||||
|
return c.compileQuery(qc, w)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) {
|
||||||
|
root := &qc.Selects[0]
|
||||||
|
|
||||||
|
insert, ok := vars["insert"]
|
||||||
|
if !ok {
|
||||||
|
return 0, errors.New("Variable 'insert' not defined")
|
||||||
|
}
|
||||||
|
|
||||||
|
jt, array, err := jsn.Tree(insert)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c.w.WriteString(`WITH input AS (SELECT {{insert}}::json AS j) INSERT INTO `)
|
||||||
|
c.w.WriteString(root.Table)
|
||||||
|
io.WriteString(c.w, " (")
|
||||||
|
c.renderInsertColumns(qc, w, jt)
|
||||||
|
io.WriteString(c.w, ")")
|
||||||
|
|
||||||
|
c.w.WriteString(` SELECT `)
|
||||||
|
c.renderInsertColumns(qc, w, jt)
|
||||||
|
c.w.WriteString(` FROM input i, `)
|
||||||
|
|
||||||
|
if array {
|
||||||
|
c.w.WriteString(`json_populate_recordset`)
|
||||||
|
} else {
|
||||||
|
c.w.WriteString(`json_populate_record`)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.w.WriteString(`(NULL::`)
|
||||||
|
c.w.WriteString(root.Table)
|
||||||
|
c.w.WriteString(`, i.j) t RETURNING * `)
|
||||||
|
|
||||||
|
return 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *compilerContext) renderInsertColumns(qc *qcode.QCode, w *bytes.Buffer,
|
||||||
|
jt map[string]interface{}) (uint32, error) {
|
||||||
|
|
||||||
|
ti, err := c.schema.GetTable(qc.Selects[0].Table)
|
||||||
|
if err != nil {
|
||||||
|
return 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
i := 0
|
||||||
|
for _, cn := range ti.ColumnNames {
|
||||||
|
if _, ok := jt[cn]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if i != 0 {
|
||||||
|
io.WriteString(c.w, ", ")
|
||||||
|
}
|
||||||
|
c.w.WriteString(cn)
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, nil
|
||||||
|
}
|
|
@ -0,0 +1,65 @@
|
||||||
|
package psql
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func singleInsert(t *testing.T) {
|
||||||
|
gql := `mutation {
|
||||||
|
product(id: 15, insert: $insert) {
|
||||||
|
id
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
sql := `test`
|
||||||
|
|
||||||
|
vars := map[string]json.RawMessage{
|
||||||
|
"insert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
|
||||||
|
}
|
||||||
|
|
||||||
|
resSQL, err := compileGQLToPSQL(gql, vars)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(">", string(resSQL))
|
||||||
|
|
||||||
|
if string(resSQL) != sql {
|
||||||
|
t.Fatal(errNotExpected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func bulkInsert(t *testing.T) {
|
||||||
|
gql := `mutation {
|
||||||
|
product(id: 15, insert: $insert) {
|
||||||
|
id
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
sql := `test`
|
||||||
|
|
||||||
|
vars := map[string]json.RawMessage{
|
||||||
|
"insert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
|
||||||
|
}
|
||||||
|
|
||||||
|
resSQL, err := compileGQLToPSQL(gql, vars)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(">", string(resSQL))
|
||||||
|
|
||||||
|
if string(resSQL) != sql {
|
||||||
|
t.Fatal(errNotExpected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompileInsert(t *testing.T) {
|
||||||
|
t.Run("singleInsert", singleInsert)
|
||||||
|
t.Run("bulkInsert", bulkInsert)
|
||||||
|
|
||||||
|
}
|
|
@ -2,6 +2,7 @@ package psql
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
@ -18,6 +19,8 @@ const (
|
||||||
closeBlock = 500
|
closeBlock = 500
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type Variables map[string]json.RawMessage
|
||||||
|
|
||||||
type Config struct {
|
type Config struct {
|
||||||
Schema *DBSchema
|
Schema *DBSchema
|
||||||
Vars map[string]string
|
Vars map[string]string
|
||||||
|
@ -51,19 +54,30 @@ type compilerContext struct {
|
||||||
*Compiler
|
*Compiler
|
||||||
}
|
}
|
||||||
|
|
||||||
func (co *Compiler) CompileEx(qc *qcode.QCode) (uint32, []byte, error) {
|
func (co *Compiler) CompileEx(qc *qcode.QCode, vars Variables) (uint32, []byte, error) {
|
||||||
w := &bytes.Buffer{}
|
w := &bytes.Buffer{}
|
||||||
skipped, err := co.Compile(qc, w)
|
skipped, err := co.Compile(qc, w, vars)
|
||||||
return skipped, w.Bytes(), err
|
return skipped, w.Bytes(), err
|
||||||
}
|
}
|
||||||
|
|
||||||
func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) {
|
func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) {
|
||||||
if len(qc.Query.Selects) == 0 {
|
switch qc.Type {
|
||||||
|
case qcode.QTQuery:
|
||||||
|
return co.compileQuery(qc, w)
|
||||||
|
case qcode.QTMutation:
|
||||||
|
return co.compileMutation(qc, w, vars)
|
||||||
|
}
|
||||||
|
|
||||||
|
return 0, errors.New("unknown operation")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) {
|
||||||
|
if len(qc.Selects) == 0 {
|
||||||
return 0, errors.New("empty query")
|
return 0, errors.New("empty query")
|
||||||
}
|
}
|
||||||
|
|
||||||
c := &compilerContext{w, qc.Query.Selects, co}
|
c := &compilerContext{w, qc.Selects, co}
|
||||||
root := &qc.Query.Selects[0]
|
root := &qc.Selects[0]
|
||||||
|
|
||||||
st := NewStack()
|
st := NewStack()
|
||||||
st.Push(root.ID + closeBlock)
|
st.Push(root.ID + closeBlock)
|
||||||
|
@ -844,7 +858,7 @@ func (c *compilerContext) renderList(ex *qcode.Exp) {
|
||||||
func (c *compilerContext) renderVal(ex *qcode.Exp,
|
func (c *compilerContext) renderVal(ex *qcode.Exp,
|
||||||
vars map[string]string) {
|
vars map[string]string) {
|
||||||
|
|
||||||
io.WriteString(c.w, ` (`)
|
//io.WriteString(c.w, ` (`)
|
||||||
switch ex.Type {
|
switch ex.Type {
|
||||||
case qcode.ValBool, qcode.ValInt, qcode.ValFloat:
|
case qcode.ValBool, qcode.ValInt, qcode.ValFloat:
|
||||||
if len(ex.Val) != 0 {
|
if len(ex.Val) != 0 {
|
||||||
|
@ -852,21 +866,23 @@ func (c *compilerContext) renderVal(ex *qcode.Exp,
|
||||||
} else {
|
} else {
|
||||||
c.w.WriteString(`''`)
|
c.w.WriteString(`''`)
|
||||||
}
|
}
|
||||||
|
|
||||||
case qcode.ValStr:
|
case qcode.ValStr:
|
||||||
c.w.WriteString(`'`)
|
c.w.WriteString(`'`)
|
||||||
c.w.WriteString(ex.Val)
|
c.w.WriteString(ex.Val)
|
||||||
c.w.WriteString(`'`)
|
c.w.WriteString(`'`)
|
||||||
|
|
||||||
case qcode.ValVar:
|
case qcode.ValVar:
|
||||||
if val, ok := vars[ex.Val]; ok {
|
if val, ok := vars[ex.Val]; ok {
|
||||||
c.w.WriteString(val)
|
c.w.WriteString(val)
|
||||||
} else {
|
} else {
|
||||||
//fmt.Fprintf(w, `'{{%s}}'`, ex.Val)
|
//fmt.Fprintf(w, `'{{%s}}'`, ex.Val)
|
||||||
c.w.WriteString(`'{{`)
|
c.w.WriteString(`{{`)
|
||||||
c.w.WriteString(ex.Val)
|
c.w.WriteString(ex.Val)
|
||||||
c.w.WriteString(`}}'`)
|
c.w.WriteString(`}}`)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.w.WriteString(`)`)
|
//c.w.WriteString(`)`)
|
||||||
}
|
}
|
||||||
|
|
||||||
func funcPrefixLen(fn string) int {
|
func funcPrefixLen(fn string) int {
|
|
@ -125,13 +125,13 @@ func TestMain(m *testing.M) {
|
||||||
os.Exit(m.Run())
|
os.Exit(m.Run())
|
||||||
}
|
}
|
||||||
|
|
||||||
func compileGQLToPSQL(gql string) ([]byte, error) {
|
func compileGQLToPSQL(gql string, vars Variables) ([]byte, error) {
|
||||||
qc, err := qcompile.Compile([]byte(gql))
|
qc, err := qcompile.Compile([]byte(gql))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
_, sqlStmt, err := pcompile.CompileEx(qc)
|
_, sqlStmt, err := pcompile.CompileEx(qc, vars)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -164,7 +164,7 @@ func withComplexArgs(t *testing.T) {
|
||||||
|
|
||||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."id") < (28)) AND (("products"."id") >= (20))) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "products_0") AS "done_1337";`
|
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."id") < (28)) AND (("products"."id") >= (20))) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "products_0") AS "done_1337";`
|
||||||
|
|
||||||
resSQL, err := compileGQLToPSQL(gql)
|
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -192,7 +192,7 @@ func withWhereMultiOr(t *testing.T) {
|
||||||
|
|
||||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") < (20)) OR (("products"."price") > (10)) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") < (20)) OR (("products"."price") > (10)) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
||||||
|
|
||||||
resSQL, err := compileGQLToPSQL(gql)
|
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -218,7 +218,7 @@ func withWhereIsNull(t *testing.T) {
|
||||||
|
|
||||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
||||||
|
|
||||||
resSQL, err := compileGQLToPSQL(gql)
|
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -244,7 +244,7 @@ func withWhereAndList(t *testing.T) {
|
||||||
|
|
||||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
||||||
|
|
||||||
resSQL, err := compileGQLToPSQL(gql)
|
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -264,7 +264,7 @@ func fetchByID(t *testing.T) {
|
||||||
|
|
||||||
sql := `SELECT json_object_agg('product', product) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "product_0"."id" AS "id", "product_0"."name" AS "name") AS "sel_0")) AS "product" FROM (SELECT "product"."id", "product"."name" FROM "products" AS "product" WHERE ((("product"."price") > (0)) AND (("product"."price") < (8)) AND (("id") = (15))) LIMIT ('1') :: integer) AS "product_0" LIMIT ('1') :: integer) AS "done_1337";`
|
sql := `SELECT json_object_agg('product', product) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "product_0"."id" AS "id", "product_0"."name" AS "name") AS "sel_0")) AS "product" FROM (SELECT "product"."id", "product"."name" FROM "products" AS "product" WHERE ((("product"."price") > (0)) AND (("product"."price") < (8)) AND (("id") = (15))) LIMIT ('1') :: integer) AS "product_0" LIMIT ('1') :: integer) AS "done_1337";`
|
||||||
|
|
||||||
resSQL, err := compileGQLToPSQL(gql)
|
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -284,7 +284,7 @@ func searchQuery(t *testing.T) {
|
||||||
|
|
||||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
||||||
|
|
||||||
resSQL, err := compileGQLToPSQL(gql)
|
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -307,7 +307,7 @@ func oneToMany(t *testing.T) {
|
||||||
|
|
||||||
sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("users"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "users" FROM (SELECT "users"."email", "users"."id" FROM "users" WHERE ((("users"."id") = ('{{user_id}}'))) LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "products" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "products_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "users_0") AS "done_1337";`
|
sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("users"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "users" FROM (SELECT "users"."email", "users"."id" FROM "users" WHERE ((("users"."id") = ('{{user_id}}'))) LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "products" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "products_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "users_0") AS "done_1337";`
|
||||||
|
|
||||||
resSQL, err := compileGQLToPSQL(gql)
|
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -330,7 +330,7 @@ func belongsTo(t *testing.T) {
|
||||||
|
|
||||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("users"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "users" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "users_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("users"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "users" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "users_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
||||||
|
|
||||||
resSQL, err := compileGQLToPSQL(gql)
|
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -353,7 +353,7 @@ func manyToMany(t *testing.T) {
|
||||||
|
|
||||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("customers"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "customers" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "customers_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("customers"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "customers" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "customers_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
||||||
|
|
||||||
resSQL, err := compileGQLToPSQL(gql)
|
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -376,7 +376,7 @@ func manyToManyReverse(t *testing.T) {
|
||||||
|
|
||||||
sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("customers"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "customers" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "products" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "products_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "customers_0") AS "done_1337";`
|
sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("customers"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "customers" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "products" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "products_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "customers_0") AS "done_1337";`
|
||||||
|
|
||||||
resSQL, err := compileGQLToPSQL(gql)
|
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -396,7 +396,7 @@ func aggFunction(t *testing.T) {
|
||||||
|
|
||||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "products" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "products" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
||||||
|
|
||||||
resSQL, err := compileGQLToPSQL(gql)
|
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -416,7 +416,7 @@ func aggFunctionWithFilter(t *testing.T) {
|
||||||
|
|
||||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."id") > (10))) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."id") > (10))) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
||||||
|
|
||||||
resSQL, err := compileGQLToPSQL(gql)
|
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -436,7 +436,7 @@ func queryWithVariables(t *testing.T) {
|
||||||
|
|
||||||
sql := `SELECT json_object_agg('product', product) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "product_0"."id" AS "id", "product_0"."name" AS "name") AS "sel_0")) AS "product" FROM (SELECT "product"."id", "product"."name" FROM "products" AS "product" WHERE ((("product"."price") > (0)) AND (("product"."price") < (8)) AND (("product"."price") = ('{{product_price}}')) AND (("id") = ('{{product_id}}'))) LIMIT ('1') :: integer) AS "product_0" LIMIT ('1') :: integer) AS "done_1337";`
|
sql := `SELECT json_object_agg('product', product) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "product_0"."id" AS "id", "product_0"."name" AS "name") AS "sel_0")) AS "product" FROM (SELECT "product"."id", "product"."name" FROM "products" AS "product" WHERE ((("product"."price") > (0)) AND (("product"."price") < (8)) AND (("product"."price") = ('{{product_price}}')) AND (("id") = ('{{product_id}}'))) LIMIT ('1') :: integer) AS "product_0" LIMIT ('1') :: integer) AS "done_1337";`
|
||||||
|
|
||||||
resSQL, err := compileGQLToPSQL(gql)
|
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -455,7 +455,7 @@ func syntheticTables(t *testing.T) {
|
||||||
|
|
||||||
sql := `SELECT json_object_agg('me', me) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "me_0"."email" AS "email") AS "sel_0")) AS "me" FROM (SELECT "me"."email" FROM "users" AS "me" WHERE ((("me"."id") = ('{{user_id}}'))) LIMIT ('1') :: integer) AS "me_0" LIMIT ('1') :: integer) AS "done_1337";`
|
sql := `SELECT json_object_agg('me', me) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "me_0"."email" AS "email") AS "sel_0")) AS "me" FROM (SELECT "me"."email" FROM "users" AS "me" WHERE ((("me"."id") = ('{{user_id}}'))) LIMIT ('1') :: integer) AS "me_0" LIMIT ('1') :: integer) AS "done_1337";`
|
||||||
|
|
||||||
resSQL, err := compileGQLToPSQL(gql)
|
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -465,7 +465,7 @@ func syntheticTables(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestCompileGQL(t *testing.T) {
|
func TestCompileSelect(t *testing.T) {
|
||||||
t.Run("withComplexArgs", withComplexArgs)
|
t.Run("withComplexArgs", withComplexArgs)
|
||||||
t.Run("withWhereAndList", withWhereAndList)
|
t.Run("withWhereAndList", withWhereAndList)
|
||||||
t.Run("withWhereIsNull", withWhereIsNull)
|
t.Run("withWhereIsNull", withWhereIsNull)
|
||||||
|
@ -519,7 +519,7 @@ func BenchmarkCompile(b *testing.B) {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = pcompile.Compile(qc, w)
|
_, err = pcompile.Compile(qc, w, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@ -540,7 +540,7 @@ func BenchmarkCompileParallel(b *testing.B) {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = pcompile.Compile(qc, w)
|
_, err = pcompile.Compile(qc, w, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
|
@ -106,11 +106,12 @@ type DBSchema struct {
|
||||||
}
|
}
|
||||||
|
|
||||||
type DBTableInfo struct {
|
type DBTableInfo struct {
|
||||||
Name string
|
Name string
|
||||||
Singular bool
|
Singular bool
|
||||||
PrimaryCol string
|
PrimaryCol string
|
||||||
TSVCol string
|
TSVCol string
|
||||||
Columns map[string]*DBColumn
|
Columns map[string]*DBColumn
|
||||||
|
ColumnNames []string
|
||||||
}
|
}
|
||||||
|
|
||||||
type RelType int
|
type RelType int
|
||||||
|
@ -162,25 +163,30 @@ func (s *DBSchema) updateSchema(
|
||||||
// Foreign key columns in current table
|
// Foreign key columns in current table
|
||||||
colByID := make(map[int]*DBColumn)
|
colByID := make(map[int]*DBColumn)
|
||||||
columns := make(map[string]*DBColumn, len(cols))
|
columns := make(map[string]*DBColumn, len(cols))
|
||||||
|
colNames := make([]string, len(cols))
|
||||||
|
|
||||||
for i := range cols {
|
for i := range cols {
|
||||||
c := cols[i]
|
c := cols[i]
|
||||||
columns[strings.ToLower(c.Name)] = cols[i]
|
name := strings.ToLower(c.Name)
|
||||||
|
columns[name] = cols[i]
|
||||||
|
colNames = append(colNames, name)
|
||||||
colByID[c.ID] = cols[i]
|
colByID[c.ID] = cols[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
singular := strings.ToLower(flect.Singularize(t.Name))
|
singular := strings.ToLower(flect.Singularize(t.Name))
|
||||||
s.t[singular] = &DBTableInfo{
|
s.t[singular] = &DBTableInfo{
|
||||||
Name: t.Name,
|
Name: t.Name,
|
||||||
Singular: true,
|
Singular: true,
|
||||||
Columns: columns,
|
Columns: columns,
|
||||||
|
ColumnNames: colNames,
|
||||||
}
|
}
|
||||||
|
|
||||||
plural := strings.ToLower(flect.Pluralize(t.Name))
|
plural := strings.ToLower(flect.Pluralize(t.Name))
|
||||||
s.t[plural] = &DBTableInfo{
|
s.t[plural] = &DBTableInfo{
|
||||||
Name: t.Name,
|
Name: t.Name,
|
||||||
Singular: false,
|
Singular: false,
|
||||||
Columns: columns,
|
Columns: columns,
|
||||||
|
ColumnNames: colNames,
|
||||||
}
|
}
|
||||||
|
|
||||||
ct := strings.ToLower(t.Name)
|
ct := strings.ToLower(t.Name)
|
||||||
|
|
|
@ -100,19 +100,7 @@ var lexPool = sync.Pool{
|
||||||
}
|
}
|
||||||
|
|
||||||
func Parse(gql []byte) (*Operation, error) {
|
func Parse(gql []byte) (*Operation, error) {
|
||||||
return parseSelectionSet(nil, gql)
|
return parseSelectionSet(gql)
|
||||||
}
|
|
||||||
|
|
||||||
func ParseQuery(gql []byte) (*Operation, error) {
|
|
||||||
op := opPool.Get().(*Operation)
|
|
||||||
op.Reset()
|
|
||||||
|
|
||||||
op.Type = opQuery
|
|
||||||
op.Name = ""
|
|
||||||
op.Fields = op.fieldsA[:0]
|
|
||||||
op.Args = op.argsA[:0]
|
|
||||||
|
|
||||||
return parseSelectionSet(op, gql)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ParseArgValue(argVal string) (*Node, error) {
|
func ParseArgValue(argVal string) (*Node, error) {
|
||||||
|
@ -134,7 +122,7 @@ func ParseArgValue(argVal string) (*Node, error) {
|
||||||
return op, err
|
return op, err
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseSelectionSet(op *Operation, gql []byte) (*Operation, error) {
|
func parseSelectionSet(gql []byte) (*Operation, error) {
|
||||||
var err error
|
var err error
|
||||||
|
|
||||||
if len(gql) == 0 {
|
if len(gql) == 0 {
|
||||||
|
@ -154,14 +142,28 @@ func parseSelectionSet(op *Operation, gql []byte) (*Operation, error) {
|
||||||
items: l.items,
|
items: l.items,
|
||||||
}
|
}
|
||||||
|
|
||||||
if op == nil {
|
var op *Operation
|
||||||
op, err = p.parseOp()
|
|
||||||
} else {
|
|
||||||
if p.peek(itemObjOpen) {
|
|
||||||
p.ignore()
|
|
||||||
}
|
|
||||||
|
|
||||||
|
if p.peek(itemObjOpen) {
|
||||||
|
p.ignore()
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.peek(itemName) {
|
||||||
|
op = opPool.Get().(*Operation)
|
||||||
|
op.Reset()
|
||||||
|
|
||||||
|
op.Type = opQuery
|
||||||
|
op.Name = ""
|
||||||
|
op.Fields = op.fieldsA[:0]
|
||||||
|
op.Args = op.argsA[:0]
|
||||||
op.Fields, err = p.parseFields(op.Fields)
|
op.Fields, err = p.parseFields(op.Fields)
|
||||||
|
|
||||||
|
} else {
|
||||||
|
op, err = p.parseOp()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
lexPool.Put(l)
|
lexPool.Put(l)
|
||||||
|
|
|
@ -45,10 +45,10 @@ func compareOp(op1, op2 Operation) error {
|
||||||
}
|
}
|
||||||
*/
|
*/
|
||||||
|
|
||||||
func TestCompile(t *testing.T) {
|
func TestCompile1(t *testing.T) {
|
||||||
qcompile, _ := NewCompiler(Config{})
|
qcompile, _ := NewCompiler(Config{})
|
||||||
|
|
||||||
_, err := qcompile.CompileQuery([]byte(`
|
_, err := qcompile.Compile([]byte(`
|
||||||
product(id: 15) {
|
product(id: 15) {
|
||||||
id
|
id
|
||||||
name
|
name
|
||||||
|
@ -59,9 +59,39 @@ func TestCompile(t *testing.T) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestCompile2(t *testing.T) {
|
||||||
|
qcompile, _ := NewCompiler(Config{})
|
||||||
|
|
||||||
|
_, err := qcompile.Compile([]byte(`
|
||||||
|
query { product(id: 15) {
|
||||||
|
id
|
||||||
|
name
|
||||||
|
} }`))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompile3(t *testing.T) {
|
||||||
|
qcompile, _ := NewCompiler(Config{})
|
||||||
|
|
||||||
|
_, err := qcompile.Compile([]byte(`
|
||||||
|
mutation {
|
||||||
|
product(id: 15, name: "Test") {
|
||||||
|
id
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}`))
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestInvalidCompile1(t *testing.T) {
|
func TestInvalidCompile1(t *testing.T) {
|
||||||
qcompile, _ := NewCompiler(Config{})
|
qcompile, _ := NewCompiler(Config{})
|
||||||
_, err := qcompile.CompileQuery([]byte(`#`))
|
_, err := qcompile.Compile([]byte(`#`))
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal(errors.New("expecting an error"))
|
t.Fatal(errors.New("expecting an error"))
|
||||||
|
@ -70,7 +100,7 @@ func TestInvalidCompile1(t *testing.T) {
|
||||||
|
|
||||||
func TestInvalidCompile2(t *testing.T) {
|
func TestInvalidCompile2(t *testing.T) {
|
||||||
qcompile, _ := NewCompiler(Config{})
|
qcompile, _ := NewCompiler(Config{})
|
||||||
_, err := qcompile.CompileQuery([]byte(`{u(where:{not:0})}`))
|
_, err := qcompile.Compile([]byte(`{u(where:{not:0})}`))
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal(errors.New("expecting an error"))
|
t.Fatal(errors.New("expecting an error"))
|
||||||
|
@ -79,7 +109,7 @@ func TestInvalidCompile2(t *testing.T) {
|
||||||
|
|
||||||
func TestEmptyCompile(t *testing.T) {
|
func TestEmptyCompile(t *testing.T) {
|
||||||
qcompile, _ := NewCompiler(Config{})
|
qcompile, _ := NewCompiler(Config{})
|
||||||
_, err := qcompile.CompileQuery([]byte(``))
|
_, err := qcompile.Compile([]byte(``))
|
||||||
|
|
||||||
if err == nil {
|
if err == nil {
|
||||||
t.Fatal(errors.New("expecting an error"))
|
t.Fatal(errors.New("expecting an error"))
|
||||||
|
@ -114,7 +144,7 @@ func BenchmarkQCompile(b *testing.B) {
|
||||||
b.ReportAllocs()
|
b.ReportAllocs()
|
||||||
|
|
||||||
for n := 0; n < b.N; n++ {
|
for n := 0; n < b.N; n++ {
|
||||||
_, err := qcompile.CompileQuery(gql)
|
_, err := qcompile.Compile(gql)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
|
@ -130,7 +160,7 @@ func BenchmarkQCompileP(b *testing.B) {
|
||||||
|
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
for pb.Next() {
|
for pb.Next() {
|
||||||
_, err := qcompile.CompileQuery(gql)
|
_, err := qcompile.Compile(gql)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
|
|
|
@ -10,15 +10,17 @@ import (
|
||||||
"github.com/gobuffalo/flect"
|
"github.com/gobuffalo/flect"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type QType int
|
||||||
|
|
||||||
const (
|
const (
|
||||||
maxSelectors = 30
|
maxSelectors = 30
|
||||||
|
|
||||||
|
QTQuery QType = iota + 1
|
||||||
|
QTMutation
|
||||||
)
|
)
|
||||||
|
|
||||||
type QCode struct {
|
type QCode struct {
|
||||||
Query *Query
|
Type QType
|
||||||
}
|
|
||||||
|
|
||||||
type Query struct {
|
|
||||||
Selects []Select
|
Selects []Select
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -149,6 +151,11 @@ type Compiler struct {
|
||||||
ka bool
|
ka bool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var opMap = map[parserType]QType{
|
||||||
|
opQuery: QTQuery,
|
||||||
|
opMutate: QTMutation,
|
||||||
|
}
|
||||||
|
|
||||||
var expPool = sync.Pool{
|
var expPool = sync.Pool{
|
||||||
New: func() interface{} { return new(Exp) },
|
New: func() interface{} { return new(Exp) },
|
||||||
}
|
}
|
||||||
|
@ -196,44 +203,23 @@ func (com *Compiler) Compile(query []byte) (*QCode, error) {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
switch op.Type {
|
qc.Selects, err = com.compileQuery(op)
|
||||||
case opQuery:
|
|
||||||
qc.Query, err = com.compileQuery(op)
|
|
||||||
case opMutate:
|
|
||||||
case opSub:
|
|
||||||
default:
|
|
||||||
err = fmt.Errorf("Unknown operation type %d", op.Type)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if t, ok := opMap[op.Type]; ok {
|
||||||
|
qc.Type = t
|
||||||
|
} else {
|
||||||
|
return nil, fmt.Errorf("Unknown operation type %d", op.Type)
|
||||||
|
}
|
||||||
|
|
||||||
opPool.Put(op)
|
opPool.Put(op)
|
||||||
|
|
||||||
return &qc, nil
|
return &qc, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (com *Compiler) CompileQuery(query []byte) (*QCode, error) {
|
func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
|
||||||
var err error
|
|
||||||
|
|
||||||
op, err := ParseQuery(query)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
qc := &QCode{}
|
|
||||||
qc.Query, err = com.compileQuery(op)
|
|
||||||
opPool.Put(op)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
return qc, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
|
|
||||||
id := int32(0)
|
id := int32(0)
|
||||||
parentID := int32(0)
|
parentID := int32(0)
|
||||||
|
|
||||||
|
@ -344,7 +330,7 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
|
||||||
return nil, errors.New("invalid query")
|
return nil, errors.New("invalid query")
|
||||||
}
|
}
|
||||||
|
|
||||||
return &Query{selects[:id]}, nil
|
return selects[:id], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
|
func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
|
||||||
|
@ -661,14 +647,6 @@ func (com *Compiler) compileArgOffset(sel *Select, arg *Arg) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func compileMutate() (*Query, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func compileSub() (*Query, error) {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {
|
func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {
|
||||||
name := node.Name
|
name := node.Name
|
||||||
if name[0] == '_' {
|
if name[0] == '_' {
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package serv
|
package serv
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"log"
|
"log"
|
||||||
|
@ -9,9 +10,15 @@ import (
|
||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
AL_QUERY int = iota + 1
|
||||||
|
AL_VARS
|
||||||
|
)
|
||||||
|
|
||||||
type allowItem struct {
|
type allowItem struct {
|
||||||
uri string
|
uri string
|
||||||
gql string
|
gql string
|
||||||
|
vars json.RawMessage
|
||||||
}
|
}
|
||||||
|
|
||||||
var _allowList allowList
|
var _allowList allowList
|
||||||
|
@ -77,8 +84,9 @@ func (al *allowList) add(req *gqlReq) {
|
||||||
}
|
}
|
||||||
|
|
||||||
al.saveChan <- &allowItem{
|
al.saveChan <- &allowItem{
|
||||||
uri: req.ref,
|
uri: req.ref,
|
||||||
gql: req.Query,
|
gql: req.Query,
|
||||||
|
vars: req.Vars,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -93,32 +101,62 @@ func (al *allowList) load() {
|
||||||
}
|
}
|
||||||
|
|
||||||
var uri string
|
var uri string
|
||||||
|
var varBytes []byte
|
||||||
|
|
||||||
s, e, c := 0, 0, 0
|
s, e, c := 0, 0, 0
|
||||||
|
ty := 0
|
||||||
|
|
||||||
for {
|
for {
|
||||||
if c == 0 && b[e] == '#' {
|
if c == 0 && b[e] == '#' {
|
||||||
s = e
|
s = e
|
||||||
for b[e] != '\n' && e < len(b) {
|
for e < len(b) && b[e] != '\n' {
|
||||||
e++
|
e++
|
||||||
}
|
}
|
||||||
if (e - s) > 2 {
|
if (e - s) > 2 {
|
||||||
uri = strings.TrimSpace(string(b[(s + 1):e]))
|
uri = strings.TrimSpace(string(b[(s + 1):e]))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if b[e] == '{' {
|
|
||||||
|
if e >= len(b) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if matchPrefix(b, e, "query") || matchPrefix(b, e, "mutation") {
|
||||||
if c == 0 {
|
if c == 0 {
|
||||||
s = e
|
s = e
|
||||||
}
|
}
|
||||||
|
ty = AL_QUERY
|
||||||
|
} else if matchPrefix(b, e, "variables") {
|
||||||
|
if c == 0 {
|
||||||
|
s = e + len("variables") + 1
|
||||||
|
}
|
||||||
|
ty = AL_VARS
|
||||||
|
} else if b[e] == '{' {
|
||||||
c++
|
c++
|
||||||
|
|
||||||
} else if b[e] == '}' {
|
} else if b[e] == '}' {
|
||||||
c--
|
c--
|
||||||
|
|
||||||
if c == 0 {
|
if c == 0 {
|
||||||
q := b[s:(e + 1)]
|
if ty == AL_QUERY {
|
||||||
al.list[gqlHash(q)] = &allowItem{
|
q := string(b[s:(e + 1)])
|
||||||
uri: uri,
|
|
||||||
gql: string(q),
|
item := &allowItem{
|
||||||
|
uri: uri,
|
||||||
|
gql: q,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(varBytes) != 0 {
|
||||||
|
item.vars = varBytes
|
||||||
|
}
|
||||||
|
|
||||||
|
al.list[gqlHash(q, varBytes)] = item
|
||||||
|
varBytes = nil
|
||||||
|
|
||||||
|
} else if ty == AL_VARS {
|
||||||
|
varBytes = b[s:(e + 1)]
|
||||||
}
|
}
|
||||||
|
ty = 0
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -130,7 +168,7 @@ func (al *allowList) load() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (al *allowList) save(item *allowItem) {
|
func (al *allowList) save(item *allowItem) {
|
||||||
al.list[gqlHash([]byte(item.gql))] = item
|
al.list[gqlHash(item.gql, item.vars)] = item
|
||||||
|
|
||||||
f, err := os.Create(al.filepath)
|
f, err := os.Create(al.filepath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -141,10 +179,10 @@ func (al *allowList) save(item *allowItem) {
|
||||||
defer f.Close()
|
defer f.Close()
|
||||||
|
|
||||||
keys := []string{}
|
keys := []string{}
|
||||||
urlMap := make(map[string][]string)
|
urlMap := make(map[string][]*allowItem)
|
||||||
|
|
||||||
for _, v := range al.list {
|
for _, v := range al.list {
|
||||||
urlMap[v.uri] = append(urlMap[v.uri], v.gql)
|
urlMap[v.uri] = append(urlMap[v.uri], v)
|
||||||
}
|
}
|
||||||
|
|
||||||
for k := range urlMap {
|
for k := range urlMap {
|
||||||
|
@ -159,7 +197,28 @@ func (al *allowList) save(item *allowItem) {
|
||||||
f.WriteString(fmt.Sprintf("# %s\n\n", k))
|
f.WriteString(fmt.Sprintf("# %s\n\n", k))
|
||||||
|
|
||||||
for i := range v {
|
for i := range v {
|
||||||
f.WriteString(fmt.Sprintf("query %s\n\n", v[i]))
|
if len(v[i].vars) != 0 {
|
||||||
|
vj, err := json.MarshalIndent(v[i].vars, "", "\t")
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn().Err(err).Msg("Failed to write allow list 'vars' to file")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
f.WriteString(fmt.Sprintf("variables %s\n\n", vj))
|
||||||
|
}
|
||||||
|
|
||||||
|
f.WriteString(fmt.Sprintf("%s\n\n", v[i].gql))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func matchPrefix(b []byte, i int, s string) bool {
|
||||||
|
if (len(b) - i) < len(s) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for n := 0; n < len(s); n++ {
|
||||||
|
if b[(i+n)] != s[n] {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
30
serv/core.go
30
serv/core.go
|
@ -14,6 +14,7 @@ import (
|
||||||
|
|
||||||
"github.com/cespare/xxhash/v2"
|
"github.com/cespare/xxhash/v2"
|
||||||
"github.com/dosco/super-graph/jsn"
|
"github.com/dosco/super-graph/jsn"
|
||||||
|
"github.com/dosco/super-graph/psql"
|
||||||
"github.com/dosco/super-graph/qcode"
|
"github.com/dosco/super-graph/qcode"
|
||||||
"github.com/go-pg/pg"
|
"github.com/go-pg/pg"
|
||||||
"github.com/valyala/fasttemplate"
|
"github.com/valyala/fasttemplate"
|
||||||
|
@ -42,7 +43,7 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
|
||||||
if conf.UseAllowList {
|
if conf.UseAllowList {
|
||||||
var ps *preparedItem
|
var ps *preparedItem
|
||||||
|
|
||||||
data, ps, err = c.resolvePreparedSQL([]byte(c.req.Query))
|
data, ps, err = c.resolvePreparedSQL(c.req.Query)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -52,7 +53,7 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
|
|
||||||
qc, err = qcompile.CompileQuery([]byte(c.req.Query))
|
qc, err = qcompile.Compile([]byte(c.req.Query))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -67,7 +68,7 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
|
||||||
return c.render(w, data)
|
return c.render(w, data)
|
||||||
}
|
}
|
||||||
|
|
||||||
sel := qc.Query.Selects
|
sel := qc.Selects
|
||||||
h := xxhash.New()
|
h := xxhash.New()
|
||||||
|
|
||||||
// fetch the field name used within the db response json
|
// fetch the field name used within the db response json
|
||||||
|
@ -252,8 +253,8 @@ func (c *coreContext) resolveRemotes(
|
||||||
return to, cerr
|
return to, cerr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *coreContext) resolvePreparedSQL(gql []byte) ([]byte, *preparedItem, error) {
|
func (c *coreContext) resolvePreparedSQL(gql string) ([]byte, *preparedItem, error) {
|
||||||
ps, ok := _preparedList[gqlHash(gql)]
|
ps, ok := _preparedList[gqlHash(gql, c.req.Vars)]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, nil, errUnauthorized
|
return nil, nil, errUnauthorized
|
||||||
}
|
}
|
||||||
|
@ -266,17 +267,22 @@ func (c *coreContext) resolvePreparedSQL(gql []byte) ([]byte, *preparedItem, err
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
fmt.Printf("PRE: %#v %#v\n", ps.stmt, vars)
|
fmt.Printf("PRE: %v\n", ps.stmt)
|
||||||
|
|
||||||
return []byte(root), ps, nil
|
return []byte(root), ps, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *coreContext) resolveSQL(qc *qcode.QCode) (
|
func (c *coreContext) resolveSQL(qc *qcode.QCode) (
|
||||||
[]byte, uint32, error) {
|
[]byte, uint32, error) {
|
||||||
|
|
||||||
stmt := &bytes.Buffer{}
|
stmt := &bytes.Buffer{}
|
||||||
|
|
||||||
skipped, err := pcompile.Compile(qc, stmt)
|
vars := make(map[string]json.RawMessage)
|
||||||
|
|
||||||
|
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
|
||||||
|
return nil, 0, err
|
||||||
|
}
|
||||||
|
|
||||||
|
skipped, err := pcompile.Compile(qc, stmt, psql.Variables(vars))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
@ -284,7 +290,7 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode) (
|
||||||
t := fasttemplate.New(stmt.String(), openVar, closeVar)
|
t := fasttemplate.New(stmt.String(), openVar, closeVar)
|
||||||
|
|
||||||
stmt.Reset()
|
stmt.Reset()
|
||||||
_, err = t.Execute(stmt, varMap(c))
|
_, err = t.ExecuteFunc(stmt, varMap(c))
|
||||||
|
|
||||||
if err == errNoUserID &&
|
if err == errNoUserID &&
|
||||||
authFailBlock == authFailBlockPerQuery &&
|
authFailBlock == authFailBlockPerQuery &&
|
||||||
|
@ -317,10 +323,10 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode) (
|
||||||
return nil, 0, err
|
return nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if conf.EnableTracing && len(qc.Query.Selects) != 0 {
|
if conf.EnableTracing && len(qc.Selects) != 0 {
|
||||||
c.addTrace(
|
c.addTrace(
|
||||||
qc.Query.Selects,
|
qc.Selects,
|
||||||
qc.Query.Selects[0].ID,
|
qc.Selects[0].ID,
|
||||||
st)
|
st)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,35 @@
|
||||||
|
package serv
|
||||||
|
|
||||||
|
/*
|
||||||
|
|
||||||
|
func simpleMutation(t *testing.T) {
|
||||||
|
gql := `mutation {
|
||||||
|
product(id: 15, insert: { name: "Test", price: 20.5 }) {
|
||||||
|
id
|
||||||
|
name
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
sql := `test`
|
||||||
|
|
||||||
|
backgroundCtx := context.Background()
|
||||||
|
ctx := &coreContext{Context: backgroundCtx}
|
||||||
|
|
||||||
|
resSQL, err := compileGQLToPSQL(gql)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
fmt.Println(">", string(resSQL))
|
||||||
|
|
||||||
|
if string(resSQL) != sql {
|
||||||
|
t.Fatal(errNotExpected)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCompileGQL(t *testing.T) {
|
||||||
|
t.Run("withComplexArgs", withComplexArgs)
|
||||||
|
t.Run("simpleMutation", simpleMutation)
|
||||||
|
}
|
||||||
|
|
||||||
|
*/
|
|
@ -26,13 +26,13 @@ var (
|
||||||
)
|
)
|
||||||
|
|
||||||
type gqlReq struct {
|
type gqlReq struct {
|
||||||
OpName string `json:"operationName"`
|
OpName string `json:"operationName"`
|
||||||
Query string `json:"query"`
|
Query string `json:"query"`
|
||||||
Vars variables `json:"variables"`
|
Vars json.RawMessage `json:"variables"`
|
||||||
ref string
|
ref string
|
||||||
}
|
}
|
||||||
|
|
||||||
type variables map[string]interface{}
|
type variables map[string]json.RawMessage
|
||||||
|
|
||||||
type gqlResp struct {
|
type gqlResp struct {
|
||||||
Error string `json:"error,omitempty"`
|
Error string `json:"error,omitempty"`
|
||||||
|
|
|
@ -2,9 +2,11 @@ package serv
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"encoding/json"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
|
||||||
|
"github.com/dosco/super-graph/psql"
|
||||||
"github.com/dosco/super-graph/qcode"
|
"github.com/dosco/super-graph/qcode"
|
||||||
"github.com/go-pg/pg"
|
"github.com/go-pg/pg"
|
||||||
"github.com/valyala/fasttemplate"
|
"github.com/valyala/fasttemplate"
|
||||||
|
@ -12,7 +14,7 @@ import (
|
||||||
|
|
||||||
type preparedItem struct {
|
type preparedItem struct {
|
||||||
stmt *pg.Stmt
|
stmt *pg.Stmt
|
||||||
args []string
|
args [][]byte
|
||||||
skipped uint32
|
skipped uint32
|
||||||
qc *qcode.QCode
|
qc *qcode.QCode
|
||||||
}
|
}
|
||||||
|
@ -25,36 +27,46 @@ func initPreparedList() {
|
||||||
_preparedList = make(map[string]*preparedItem)
|
_preparedList = make(map[string]*preparedItem)
|
||||||
|
|
||||||
for k, v := range _allowList.list {
|
for k, v := range _allowList.list {
|
||||||
err := prepareStmt(k, v.gql)
|
err := prepareStmt(k, v.gql, v.vars)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
panic(err)
|
panic(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func prepareStmt(key, gql string) error {
|
func prepareStmt(key, gql string, varBytes json.RawMessage) error {
|
||||||
if len(gql) == 0 || len(key) == 0 {
|
if len(gql) == 0 || len(key) == 0 {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
qc, err := qcompile.CompileQuery([]byte(gql))
|
qc, err := qcompile.Compile([]byte(gql))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var vars map[string]json.RawMessage
|
||||||
|
|
||||||
|
if len(varBytes) != 0 {
|
||||||
|
vars = make(map[string]json.RawMessage)
|
||||||
|
|
||||||
|
if err := json.Unmarshal(varBytes, &vars); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
buf := &bytes.Buffer{}
|
buf := &bytes.Buffer{}
|
||||||
|
|
||||||
skipped, err := pcompile.Compile(qc, buf)
|
skipped, err := pcompile.Compile(qc, buf, psql.Variables(vars))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
t := fasttemplate.New(buf.String(), `('{{`, `}}')`)
|
t := fasttemplate.New(buf.String(), `{{`, `}}`)
|
||||||
am := make([]string, 0, 5)
|
am := make([][]byte, 0, 5)
|
||||||
i := 0
|
i := 0
|
||||||
|
|
||||||
finalSQL := t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) {
|
finalSQL := t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) {
|
||||||
am = append(am, tag)
|
am = append(am, []byte(tag))
|
||||||
i++
|
i++
|
||||||
return w.Write([]byte(fmt.Sprintf("$%d", i)))
|
return w.Write([]byte(fmt.Sprintf("$%d", i)))
|
||||||
})
|
})
|
||||||
|
|
|
@ -4,8 +4,12 @@ import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
|
"io"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
"github.com/cespare/xxhash/v2"
|
"github.com/cespare/xxhash/v2"
|
||||||
|
"github.com/dosco/super-graph/jsn"
|
||||||
)
|
)
|
||||||
|
|
||||||
func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
|
func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
|
||||||
|
@ -17,8 +21,8 @@ func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
|
||||||
return v
|
return v
|
||||||
}
|
}
|
||||||
|
|
||||||
func gqlHash(b []byte) string {
|
func gqlHash(b string, vars []byte) string {
|
||||||
b = bytes.TrimSpace(b)
|
b = strings.TrimSpace(b)
|
||||||
h := sha1.New()
|
h := sha1.New()
|
||||||
|
|
||||||
s, e := 0, 0
|
s, e := 0, 0
|
||||||
|
@ -45,13 +49,27 @@ func gqlHash(b []byte) string {
|
||||||
if e != 0 {
|
if e != 0 {
|
||||||
b0 = b[(e - 1)]
|
b0 = b[(e - 1)]
|
||||||
}
|
}
|
||||||
h.Write(bytes.ToLower(b[s:e]))
|
io.WriteString(h, strings.ToLower(b[s:e]))
|
||||||
}
|
}
|
||||||
if e >= len(b) {
|
if e >= len(b) {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if vars == nil {
|
||||||
|
return hex.EncodeToString(h.Sum(nil))
|
||||||
|
}
|
||||||
|
|
||||||
|
fields := jsn.Keys([]byte(vars))
|
||||||
|
|
||||||
|
sort.Slice(fields, func(i, j int) bool {
|
||||||
|
return bytes.Compare(fields[i], fields[j]) == -1
|
||||||
|
})
|
||||||
|
|
||||||
|
for i := range fields {
|
||||||
|
h.Write(fields[i])
|
||||||
|
}
|
||||||
|
|
||||||
return hex.EncodeToString(h.Sum(nil))
|
return hex.EncodeToString(h.Sum(nil))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -6,7 +6,7 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestRelaxHash1(t *testing.T) {
|
func TestRelaxHash1(t *testing.T) {
|
||||||
var v1 = []byte(`
|
var v1 = `
|
||||||
products(
|
products(
|
||||||
limit: 30,
|
limit: 30,
|
||||||
|
|
||||||
|
@ -14,18 +14,18 @@ func TestRelaxHash1(t *testing.T) {
|
||||||
id
|
id
|
||||||
name
|
name
|
||||||
price
|
price
|
||||||
}`)
|
}`
|
||||||
|
|
||||||
var v2 = []byte(`
|
var v2 = `
|
||||||
products
|
products
|
||||||
(limit: 30, where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
|
(limit: 30, where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
|
||||||
id
|
id
|
||||||
name
|
name
|
||||||
price
|
price
|
||||||
} `)
|
} `
|
||||||
|
|
||||||
h1 := gqlHash(v1)
|
h1 := gqlHash(v1, nil)
|
||||||
h2 := gqlHash(v2)
|
h2 := gqlHash(v2, nil)
|
||||||
|
|
||||||
if strings.Compare(h1, h2) != 0 {
|
if strings.Compare(h1, h2) != 0 {
|
||||||
t.Fatal("Hashes don't match they should")
|
t.Fatal("Hashes don't match they should")
|
||||||
|
@ -33,7 +33,7 @@ func TestRelaxHash1(t *testing.T) {
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestRelaxHash2(t *testing.T) {
|
func TestRelaxHash2(t *testing.T) {
|
||||||
var v1 = []byte(`
|
var v1 = `
|
||||||
{
|
{
|
||||||
products(
|
products(
|
||||||
limit: 30
|
limit: 30
|
||||||
|
@ -49,12 +49,119 @@ func TestRelaxHash2(t *testing.T) {
|
||||||
email
|
email
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}`)
|
}`
|
||||||
|
|
||||||
var v2 = []byte(` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } `)
|
var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } `
|
||||||
|
|
||||||
h1 := gqlHash(v1)
|
h1 := gqlHash(v1, nil)
|
||||||
h2 := gqlHash(v2)
|
h2 := gqlHash(v2, nil)
|
||||||
|
|
||||||
|
if strings.Compare(h1, h2) != 0 {
|
||||||
|
t.Fatal("Hashes don't match they should")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelaxHashWithVars1(t *testing.T) {
|
||||||
|
var q1 = `
|
||||||
|
products(
|
||||||
|
limit: 30,
|
||||||
|
|
||||||
|
where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
|
||||||
|
id
|
||||||
|
name
|
||||||
|
price
|
||||||
|
}`
|
||||||
|
|
||||||
|
var v1 = `
|
||||||
|
{
|
||||||
|
"insert": {
|
||||||
|
"name": "Hello",
|
||||||
|
"description": "World",
|
||||||
|
"created_at": "now",
|
||||||
|
"updated_at": "now",
|
||||||
|
"test": { "type2": "b", "type1": "a" }
|
||||||
|
},
|
||||||
|
"user": 123
|
||||||
|
}`
|
||||||
|
|
||||||
|
var q2 = `
|
||||||
|
products
|
||||||
|
(limit: 30, where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
|
||||||
|
id
|
||||||
|
name
|
||||||
|
price
|
||||||
|
} `
|
||||||
|
|
||||||
|
var v2 = `{
|
||||||
|
"insert": {
|
||||||
|
"created_at": "now",
|
||||||
|
"test": { "type1": "a", "type2": "b" },
|
||||||
|
"name": "Hello",
|
||||||
|
"updated_at": "now",
|
||||||
|
"description": "World"
|
||||||
|
},
|
||||||
|
"user": 123
|
||||||
|
}`
|
||||||
|
|
||||||
|
h1 := gqlHash(q1, []byte(va1))
|
||||||
|
h2 := gqlHash(q2, []byte(va2))
|
||||||
|
|
||||||
|
if strings.Compare(h1, h2) != 0 {
|
||||||
|
t.Fatal("Hashes don't match they should")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRelaxHashWithVars2(t *testing.T) {
|
||||||
|
var q1 = `
|
||||||
|
products(
|
||||||
|
limit: 30,
|
||||||
|
|
||||||
|
where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
|
||||||
|
id
|
||||||
|
name
|
||||||
|
price
|
||||||
|
}`
|
||||||
|
|
||||||
|
var v1 = `
|
||||||
|
{
|
||||||
|
"insert": [{
|
||||||
|
"name": "Hello",
|
||||||
|
"description": "World",
|
||||||
|
"created_at": "now",
|
||||||
|
"updated_at": "now",
|
||||||
|
"test": { "type2": "b", "type1": "a" }
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"name": "Hello",
|
||||||
|
"description": "World",
|
||||||
|
"created_at": "now",
|
||||||
|
"updated_at": "now",
|
||||||
|
"test": { "type2": "b", "type1": "a" }
|
||||||
|
}],
|
||||||
|
"user": 123
|
||||||
|
}`
|
||||||
|
|
||||||
|
var q2 = `
|
||||||
|
products
|
||||||
|
(limit: 30, where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
|
||||||
|
id
|
||||||
|
name
|
||||||
|
price
|
||||||
|
} `
|
||||||
|
|
||||||
|
var v2 = `{
|
||||||
|
"insert": {
|
||||||
|
"created_at": "now",
|
||||||
|
"test": { "type1": "a", "type2": "b" },
|
||||||
|
"name": "Hello",
|
||||||
|
"updated_at": "now",
|
||||||
|
"description": "World"
|
||||||
|
},
|
||||||
|
"user": 123
|
||||||
|
}`
|
||||||
|
|
||||||
|
h1 := gqlHash(q1, []byte(va1))
|
||||||
|
h2 := gqlHash(q2, []byte(va2))
|
||||||
|
|
||||||
if strings.Compare(h1, h2) != 0 {
|
if strings.Compare(h1, h2) != 0 {
|
||||||
t.Fatal("Hashes don't match they should")
|
t.Fatal("Hashes don't match they should")
|
||||||
|
|
140
serv/vars.go
140
serv/vars.go
|
@ -1,95 +1,107 @@
|
||||||
package serv
|
package serv
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"strconv"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/valyala/fasttemplate"
|
"github.com/dosco/super-graph/jsn"
|
||||||
)
|
)
|
||||||
|
|
||||||
func varMap(ctx *coreContext) variables {
|
func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
|
||||||
userIDFn := func(w io.Writer, _ string) (int, error) {
|
return func(w io.Writer, tag string) (int, error) {
|
||||||
if v := ctx.Value(userIDKey); v != nil {
|
switch tag {
|
||||||
return w.Write([]byte(v.(string)))
|
case "user_id":
|
||||||
}
|
if v := ctx.Value(userIDKey); v != nil {
|
||||||
return 0, errNoUserID
|
return stringVar(w, v.(string))
|
||||||
}
|
}
|
||||||
|
return 0, errNoUserID
|
||||||
|
|
||||||
userIDProviderFn := func(w io.Writer, _ string) (int, error) {
|
case "user_id_provider":
|
||||||
if v := ctx.Value(userIDProviderKey); v != nil {
|
if v := ctx.Value(userIDProviderKey); v != nil {
|
||||||
return w.Write([]byte(v.(string)))
|
return stringVar(w, v.(string))
|
||||||
}
|
}
|
||||||
return 0, errNoUserID
|
return 0, errNoUserID
|
||||||
}
|
|
||||||
|
|
||||||
userIDTag := fasttemplate.TagFunc(userIDFn)
|
|
||||||
userIDProviderTag := fasttemplate.TagFunc(userIDProviderFn)
|
|
||||||
|
|
||||||
vm := variables{
|
|
||||||
"user_id": userIDTag,
|
|
||||||
"user_id_provider": userIDProviderTag,
|
|
||||||
"USER_ID": userIDTag,
|
|
||||||
"USER_ID_PROVIDER": userIDProviderTag,
|
|
||||||
}
|
|
||||||
|
|
||||||
for k, v := range ctx.req.Vars {
|
|
||||||
var buf []byte
|
|
||||||
k = strings.ToLower(k)
|
|
||||||
|
|
||||||
if _, ok := vm[k]; ok {
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
switch val := v.(type) {
|
fields := jsn.Get(ctx.req.Vars, [][]byte{[]byte(tag)})
|
||||||
case string:
|
if len(fields) == 0 {
|
||||||
vm[k] = val
|
return 0, fmt.Errorf("variable '%s' not found", tag)
|
||||||
case int:
|
|
||||||
vm[k] = strconv.AppendInt(buf, int64(val), 10)
|
|
||||||
case int64:
|
|
||||||
vm[k] = strconv.AppendInt(buf, val, 10)
|
|
||||||
case float64:
|
|
||||||
vm[k] = strconv.AppendFloat(buf, val, 'f', -1, 64)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
is := false
|
||||||
|
|
||||||
|
for i := range fields[0].Value {
|
||||||
|
c := fields[0].Value[i]
|
||||||
|
if c != ' ' {
|
||||||
|
is = (c == '"') || (c == '{') || (c == '[')
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if is {
|
||||||
|
return stringVarB(w, fields[0].Value)
|
||||||
|
}
|
||||||
|
|
||||||
|
w.Write(fields[0].Value)
|
||||||
|
return 0, nil
|
||||||
}
|
}
|
||||||
return vm
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func varList(ctx *coreContext, args []string) []interface{} {
|
func varList(ctx *coreContext, args [][]byte) []interface{} {
|
||||||
vars := make([]interface{}, 0, len(args))
|
vars := make([]interface{}, len(args))
|
||||||
|
|
||||||
for k, v := range ctx.req.Vars {
|
var fields map[string]interface{}
|
||||||
ctx.req.Vars[strings.ToLower(k)] = v
|
var err error
|
||||||
|
|
||||||
|
if len(ctx.req.Vars) != 0 {
|
||||||
|
fields, _, err = jsn.Tree(ctx.req.Vars)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.Warn().Err(err).Msg("Failed to parse variables")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for i := range args {
|
for i := range args {
|
||||||
arg := strings.ToLower(args[i])
|
av := args[i]
|
||||||
|
|
||||||
if arg == "user_id" {
|
switch {
|
||||||
|
case bytes.Equal(av, []byte("user_id")):
|
||||||
if v := ctx.Value(userIDKey); v != nil {
|
if v := ctx.Value(userIDKey); v != nil {
|
||||||
vars = append(vars, v.(string))
|
vars[i] = v.(string)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if arg == "user_id_provider" {
|
case bytes.Equal(av, []byte("user_id_provider")):
|
||||||
if v := ctx.Value(userIDProviderKey); v != nil {
|
if v := ctx.Value(userIDProviderKey); v != nil {
|
||||||
vars = append(vars, v.(string))
|
vars[i] = v.(string)
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if v, ok := ctx.req.Vars[arg]; ok {
|
default:
|
||||||
switch val := v.(type) {
|
if v, ok := fields[string(av)]; ok {
|
||||||
case string:
|
vars[i] = v
|
||||||
vars = append(vars, val)
|
|
||||||
case int:
|
|
||||||
vars = append(vars, strconv.FormatInt(int64(val), 10))
|
|
||||||
case int64:
|
|
||||||
vars = append(vars, strconv.FormatInt(int64(val), 10))
|
|
||||||
case float64:
|
|
||||||
vars = append(vars, strconv.FormatFloat(val, 'f', -1, 64))
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return vars
|
return vars
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func stringVar(w io.Writer, v string) (int, error) {
|
||||||
|
if n, err := w.Write([]byte(`'`)); err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
if n, err := w.Write([]byte(v)); err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
return w.Write([]byte(`'`))
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringVarB(w io.Writer, v []byte) (int, error) {
|
||||||
|
if n, err := w.Write([]byte(`'`)); err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
if n, err := w.Write(v); err != nil {
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
return w.Write([]byte(`'`))
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue