Add insert mutation with bulk insert
This commit is contained in:
parent
5b9105ff0c
commit
c0a21e448f
|
@ -47,4 +47,79 @@ query {
|
|||
email
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
variables {
|
||||
"insert": {
|
||||
"name": "Hello",
|
||||
"description": "World",
|
||||
"created_at": "now",
|
||||
"updated_at": "now"
|
||||
},
|
||||
"user": 123
|
||||
}
|
||||
|
||||
mutation {
|
||||
products(insert: $insert) {
|
||||
id
|
||||
name
|
||||
description
|
||||
}
|
||||
}
|
||||
|
||||
variables {
|
||||
"insert": {
|
||||
"name": "Hello",
|
||||
"description": "World",
|
||||
"created_at": "now",
|
||||
"updated_at": "now"
|
||||
},
|
||||
"user": 123
|
||||
}
|
||||
|
||||
mutation {
|
||||
products(insert: $insert) {
|
||||
id
|
||||
}
|
||||
}
|
||||
|
||||
variables {
|
||||
"insert": {
|
||||
"description": "World3",
|
||||
"name": "Hello3",
|
||||
"created_at": "now",
|
||||
"updated_at": "now"
|
||||
},
|
||||
"user": 123
|
||||
}
|
||||
|
||||
{
|
||||
customers {
|
||||
id
|
||||
email
|
||||
payments {
|
||||
customer_id
|
||||
amount
|
||||
billing_details
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
variables {
|
||||
"insert": {
|
||||
"description": "World3",
|
||||
"name": "Hello3",
|
||||
"created_at": "now",
|
||||
"updated_at": "now"
|
||||
},
|
||||
"user": 123
|
||||
}
|
||||
|
||||
{
|
||||
me {
|
||||
id
|
||||
full_name
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -84,8 +84,9 @@ database:
|
|||
#log_level: "debug"
|
||||
|
||||
# Define variables here that you want to use in filters
|
||||
# sub-queries must be wrapped in ()
|
||||
variables:
|
||||
account_id: "select account_id from users where id = $user_id"
|
||||
account_id: "(select account_id from users where id = $user_id)"
|
||||
|
||||
# Define defaults to for the field key and values below
|
||||
defaults:
|
||||
|
@ -105,12 +106,12 @@ database:
|
|||
# This filter will overwrite defaults.filter
|
||||
# filter: ["{ id: { eq: $user_id } }"]
|
||||
|
||||
- name: products
|
||||
# Multiple filters are AND'd together
|
||||
filter: [
|
||||
"{ price: { gt: 0 } }",
|
||||
"{ price: { lt: 8 } }"
|
||||
]
|
||||
# - name: products
|
||||
# # Multiple filters are AND'd together
|
||||
# filter: [
|
||||
# "{ price: { gt: 0 } }",
|
||||
# "{ price: { lt: 8 } }"
|
||||
# ]
|
||||
|
||||
- name: customers
|
||||
# No filter is used for this field not
|
||||
|
|
|
@ -82,8 +82,9 @@ database:
|
|||
#log_level: "debug"
|
||||
|
||||
# Define variables here that you want to use in filters
|
||||
# sub-queries must be wrapped in ()
|
||||
variables:
|
||||
account_id: "select account_id from users where id = $user_id"
|
||||
account_id: "(select account_id from users where id = $user_id)"
|
||||
|
||||
# Define defaults to for the field key and values below
|
||||
defaults:
|
||||
|
|
|
@ -2,6 +2,8 @@ version: '3.4'
|
|||
services:
|
||||
db:
|
||||
image: postgres
|
||||
ports:
|
||||
- "5432:5432"
|
||||
|
||||
# redis:
|
||||
# image: redis:alpine
|
||||
|
|
|
@ -0,0 +1,39 @@
|
|||
# JSN - Fast low allocation JSON library
|
||||
## Design
|
||||
|
||||
This libary is designed as a set of seperate functions to extract data and mutate
|
||||
JSON. All functions are focused on keeping allocations to a minimum and be as fast
|
||||
as possible. The functions don't validate the JSON a seperate `Validate` function
|
||||
does that.
|
||||
|
||||
The JSON parsing algo processes each object `{}` or array `[]` in a bottom up way
|
||||
where once the end of the array or object is found only then the keys within it are
|
||||
parsed from the top down.
|
||||
|
||||
```
|
||||
{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}], "full_name":"FN1","email":"E1" }
|
||||
|
||||
id: 1
|
||||
|
||||
posts: [{"title":"PT1-1","description":"PD1-1"}]
|
||||
|
||||
[{"title":"PT1-1","description":"PD1-1"}]
|
||||
|
||||
{"title":"PT1-1","description":"PD1-1"}
|
||||
|
||||
title: "PT1-1"
|
||||
|
||||
description: "PD1-1
|
||||
|
||||
full_name: "FN1"
|
||||
|
||||
email: "E1"
|
||||
```
|
||||
|
||||
## Functions
|
||||
|
||||
- Strip: Strip a path from the root to a child node and return the rest
|
||||
- Replace: Replace values by key
|
||||
- Get: Get all keys
|
||||
- Filter: Extract specific keys from an object
|
||||
- Tree: Fetch unique keys from an array or object
|
|
@ -43,7 +43,7 @@ func Get(b []byte, keys [][]byte) []Field {
|
|||
kmap[xxhash.Sum64(keys[i])] = struct{}{}
|
||||
}
|
||||
|
||||
res := make([]Field, 20)
|
||||
res := make([]Field, 0, 20)
|
||||
|
||||
s, e, d := 0, 0, 0
|
||||
|
||||
|
@ -127,7 +127,7 @@ func Get(b []byte, keys [][]byte) []Field {
|
|||
_, ok := kmap[xxhash.Sum64(k)]
|
||||
|
||||
if ok {
|
||||
res[n] = Field{k, b[s:(e + 1)]}
|
||||
res = append(res, Field{k, b[s:(e + 1)]})
|
||||
n++
|
||||
}
|
||||
|
||||
|
|
|
@ -21,6 +21,10 @@ var (
|
|||
"full_name": "Caroll Orn Sr.",
|
||||
"email": "joannarau@hegmann.io",
|
||||
"__twitter_id": "ABC123"
|
||||
"more": [{
|
||||
"__twitter_id": "more123",
|
||||
"hello: "world
|
||||
}]
|
||||
}
|
||||
},
|
||||
{
|
||||
|
@ -163,6 +167,7 @@ func TestGet(t *testing.T) {
|
|||
{[]byte("__twitter_id"), []byte(`"ABCD"`)},
|
||||
{[]byte("__twitter_id"), []byte(`"2048666903444506956"`)},
|
||||
{[]byte("__twitter_id"), []byte(`"ABC123"`)},
|
||||
{[]byte("__twitter_id"), []byte(`"more123"`)},
|
||||
{[]byte("__twitter_id"),
|
||||
[]byte(`[{ "name": "hello" }, { "name": "world"}]`)},
|
||||
{[]byte("__twitter_id"),
|
||||
|
@ -340,6 +345,80 @@ func TestReplaceEmpty(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestKeys1(t *testing.T) {
|
||||
json := `[{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}, {"title":"PT1-2","description":"PD1-2"}], "full_name":"FN1","email":"E1","books": [{"name":"BN1-1","description":"BD1-1"},{"name":"BN1-2","description":"BD1-2"},{"name":"BN1-2","description":"BD1-2"}]},{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}, {"title":"PT1-2","description":"PD1-2"}], "full_name":"FN1","email":"E1","books": [{"name":"BN1-1","description":"BD1-1"},{"name":"BN1-2","description":"BD1-2"},{"name":"BN1-2","description":"BD1-2"}]},{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}, {"title":"PT1-2","description":"PD1-2"}], "full_name":"FN1","email":"E1","books": [{"name":"BN1-1","description":"BD1-1"},{"name":"BN1-2","description":"BD1-2"},{"name":"BN1-2","description":"BD1-2"}]}]`
|
||||
|
||||
fields := Keys([]byte(json))
|
||||
|
||||
exp := []string{
|
||||
"id", "posts", "title", "description", "full_name", "email", "books", "name", "description",
|
||||
}
|
||||
|
||||
if len(exp) != len(fields) {
|
||||
t.Errorf("Expected %d fields %d", len(exp), len(fields))
|
||||
}
|
||||
|
||||
for i := range exp {
|
||||
if string(fields[i]) != exp[i] {
|
||||
t.Errorf("Expected field '%s' got '%s'", string(exp[i]), fields[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeys2(t *testing.T) {
|
||||
json := `{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}, {"title":"PT1-2","description":"PD1-2"}], "full_name":"FN1","email":"E1","books": [{"name":"BN1-1","description":"BD1-1"},{"name":"BN1-2","description":"BD1-2"},{"name":"BN1-2","description":"BD1-2"}]}`
|
||||
|
||||
fields := Keys([]byte(json))
|
||||
|
||||
exp := []string{
|
||||
"id", "posts", "title", "description", "full_name", "email", "books", "name", "description",
|
||||
}
|
||||
|
||||
// for i := range fields {
|
||||
// fmt.Println("-->", string(fields[i]))
|
||||
// }
|
||||
|
||||
if len(exp) != len(fields) {
|
||||
t.Errorf("Expected %d fields %d", len(exp), len(fields))
|
||||
}
|
||||
|
||||
for i := range exp {
|
||||
if string(fields[i]) != exp[i] {
|
||||
t.Errorf("Expected field '%s' got '%s'", string(exp[i]), fields[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestKeys3(t *testing.T) {
|
||||
json := `{
|
||||
"insert": {
|
||||
"created_at": "now",
|
||||
"test": { "type1": "a", "type2": "b" },
|
||||
"name": "Hello",
|
||||
"updated_at": "now",
|
||||
"description": "World"
|
||||
},
|
||||
"user": 123
|
||||
}`
|
||||
|
||||
fields := Keys([]byte(json))
|
||||
|
||||
exp := []string{
|
||||
"insert", "created_at", "test", "type1", "type2", "name", "updated_at", "description",
|
||||
"user",
|
||||
}
|
||||
|
||||
if len(exp) != len(fields) {
|
||||
t.Errorf("Expected %d fields %d", len(exp), len(fields))
|
||||
}
|
||||
|
||||
for i := range exp {
|
||||
if string(fields[i]) != exp[i] {
|
||||
t.Errorf("Expected field '%s' got '%s'", string(exp[i]), fields[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkGet(b *testing.B) {
|
||||
b.ReportAllocs()
|
||||
|
||||
|
|
|
@ -0,0 +1,122 @@
|
|||
package jsn
|
||||
|
||||
func Keys(b []byte) [][]byte {
|
||||
res := make([][]byte, 0, 20)
|
||||
|
||||
s, e, d := 0, 0, 0
|
||||
|
||||
var k []byte
|
||||
state := expectValue
|
||||
|
||||
st := NewStack()
|
||||
ae := 0
|
||||
|
||||
for i := 0; i < len(b); i++ {
|
||||
if state == expectObjClose || state == expectListClose {
|
||||
switch b[i] {
|
||||
case '{', '[':
|
||||
d++
|
||||
case '}', ']':
|
||||
d--
|
||||
}
|
||||
}
|
||||
|
||||
si := st.Peek()
|
||||
|
||||
switch {
|
||||
case state == expectKey && si != nil && i >= si.ss:
|
||||
i = si.se + 1
|
||||
st.Pop()
|
||||
|
||||
case state == expectKey && b[i] == '{':
|
||||
state = expectObjClose
|
||||
s = i
|
||||
d++
|
||||
|
||||
case state == expectObjClose && d == 0 && b[i] == '}':
|
||||
state = expectKey
|
||||
if ae != 0 {
|
||||
st.Push(skipInfo{i, ae})
|
||||
ae = 0
|
||||
}
|
||||
e = i
|
||||
i = s
|
||||
|
||||
case state == expectKey && b[i] == '"':
|
||||
state = expectKeyClose
|
||||
s = i
|
||||
|
||||
case state == expectKeyClose && b[i] == '"':
|
||||
state = expectColon
|
||||
k = b[(s + 1):i]
|
||||
|
||||
case state == expectColon && b[i] == ':':
|
||||
state = expectValue
|
||||
|
||||
case state == expectValue && b[i] == '"':
|
||||
state = expectString
|
||||
s = i
|
||||
|
||||
case state == expectString && b[i] == '"':
|
||||
e = i
|
||||
|
||||
case state == expectValue && b[i] == '{':
|
||||
state = expectObjClose
|
||||
s = i
|
||||
d++
|
||||
|
||||
case state == expectObjClose && d == 0 && b[i] == '}':
|
||||
state = expectKey
|
||||
e = i
|
||||
i = s
|
||||
|
||||
case state == expectValue && b[i] == '[':
|
||||
state = expectListClose
|
||||
s = i
|
||||
d++
|
||||
|
||||
case state == expectListClose && d == 0 && b[i] == ']':
|
||||
state = expectKey
|
||||
ae = i
|
||||
e = i
|
||||
i = s
|
||||
|
||||
case state == expectValue && (b[i] >= '0' && b[i] <= '9'):
|
||||
state = expectNumClose
|
||||
s = i
|
||||
|
||||
case state == expectNumClose &&
|
||||
((b[i] < '0' || b[i] > '9') &&
|
||||
(b[i] != '.' && b[i] != 'e' && b[i] != 'E' && b[i] != '+' && b[i] != '-')):
|
||||
i--
|
||||
e = i
|
||||
|
||||
case state == expectValue &&
|
||||
(b[i] == 'f' || b[i] == 'F' || b[i] == 't' || b[i] == 'T'):
|
||||
state = expectBoolClose
|
||||
s = i
|
||||
|
||||
case state == expectBoolClose && (b[i] == 'e' || b[i] == 'E'):
|
||||
e = i
|
||||
|
||||
case state == expectValue && b[i] == 'n':
|
||||
state = expectNull
|
||||
|
||||
case state == expectNull && b[i] == 'l':
|
||||
e = i
|
||||
}
|
||||
|
||||
if e != 0 {
|
||||
if k != nil {
|
||||
res = append(res, k)
|
||||
}
|
||||
|
||||
state = expectKey
|
||||
k = nil
|
||||
e = 0
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
return res
|
||||
}
|
|
@ -0,0 +1,51 @@
|
|||
package jsn
|
||||
|
||||
type skipInfo struct {
|
||||
ss, se int
|
||||
}
|
||||
|
||||
type Stack struct {
|
||||
stA [20]skipInfo
|
||||
st []skipInfo
|
||||
top int
|
||||
}
|
||||
|
||||
// Create a new Stack
|
||||
func NewStack() *Stack {
|
||||
s := &Stack{top: -1}
|
||||
s.st = s.stA[:0]
|
||||
return s
|
||||
}
|
||||
|
||||
// Return the number of items in the Stack
|
||||
func (s *Stack) Len() int {
|
||||
return (s.top + 1)
|
||||
}
|
||||
|
||||
// View the top item on the Stack
|
||||
func (s *Stack) Peek() *skipInfo {
|
||||
if s.top == -1 {
|
||||
return nil
|
||||
}
|
||||
return &s.st[s.top]
|
||||
}
|
||||
|
||||
// Pop the top item of the Stack and return it
|
||||
func (s *Stack) Pop() *skipInfo {
|
||||
if s.top == -1 {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.top--
|
||||
return &s.st[(s.top + 1)]
|
||||
}
|
||||
|
||||
// Push a value onto the top of the Stack
|
||||
func (s *Stack) Push(value skipInfo) {
|
||||
s.top++
|
||||
if len(s.st) <= s.top {
|
||||
s.st = append(s.st, value)
|
||||
} else {
|
||||
s.st[s.top] = value
|
||||
}
|
||||
}
|
|
@ -0,0 +1,37 @@
|
|||
package jsn
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
)
|
||||
|
||||
func Tree(v []byte) (map[string]interface{}, bool, error) {
|
||||
dec := json.NewDecoder(bytes.NewReader(v))
|
||||
array := false
|
||||
|
||||
// read open bracket
|
||||
|
||||
for i := range v {
|
||||
if v[i] != ' ' {
|
||||
array = (v[i] == '[')
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if array {
|
||||
if _, err := dec.Token(); err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
}
|
||||
|
||||
// while the array contains values
|
||||
var m map[string]interface{}
|
||||
|
||||
// decode an array value (Message)
|
||||
err := dec.Decode(&m)
|
||||
if err != nil {
|
||||
return nil, false, err
|
||||
}
|
||||
|
||||
return m, array, nil
|
||||
}
|
|
@ -1,7 +0,0 @@
|
|||
goos: darwin
|
||||
goarch: amd64
|
||||
pkg: github.com/dosco/super-graph/psql
|
||||
BenchmarkCompile-8 100000 16476 ns/op 3282 B/op 66 allocs/op
|
||||
BenchmarkCompileParallel-8 300000 4639 ns/op 3324 B/op 66 allocs/op
|
||||
PASS
|
||||
ok github.com/dosco/super-graph/psql 3.274s
|
|
@ -1,7 +0,0 @@
|
|||
goos: darwin
|
||||
goarch: amd64
|
||||
pkg: github.com/dosco/super-graph/psql
|
||||
BenchmarkCompile-8 100000 15728 ns/op 3000 B/op 60 allocs/op
|
||||
BenchmarkCompileParallel-8 300000 5077 ns/op 3023 B/op 60 allocs/op
|
||||
PASS
|
||||
ok github.com/dosco/super-graph/psql 3.318s
|
|
@ -1,7 +0,0 @@
|
|||
goos: darwin
|
||||
goarch: amd64
|
||||
pkg: github.com/dosco/super-graph/psql
|
||||
BenchmarkCompile-8 1000000 15997 ns/op 3048 B/op 58 allocs/op
|
||||
BenchmarkCompileParallel-8 3000000 4722 ns/op 3073 B/op 58 allocs/op
|
||||
PASS
|
||||
ok github.com/dosco/super-graph/psql 35.024s
|
|
@ -1,7 +0,0 @@
|
|||
goos: darwin
|
||||
goarch: amd64
|
||||
pkg: github.com/dosco/super-graph/psql
|
||||
BenchmarkCompile-8 100000 16829 ns/op 2887 B/op 57 allocs/op
|
||||
BenchmarkCompileParallel-8 300000 5450 ns/op 2911 B/op 57 allocs/op
|
||||
PASS
|
||||
ok github.com/dosco/super-graph/psql 3.561s
|
|
@ -0,0 +1,90 @@
|
|||
package psql
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/dosco/super-graph/jsn"
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
)
|
||||
|
||||
func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) {
|
||||
if len(qc.Selects) == 0 {
|
||||
return 0, errors.New("empty query")
|
||||
}
|
||||
|
||||
c := &compilerContext{w, qc.Selects, co}
|
||||
root := &qc.Selects[0]
|
||||
|
||||
c.w.WriteString(`WITH `)
|
||||
c.w.WriteString(root.Table)
|
||||
c.w.WriteString(` AS (`)
|
||||
|
||||
if _, err := c.renderInsert(qc, w, vars); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
c.w.WriteString(`) `)
|
||||
|
||||
return c.compileQuery(qc, w)
|
||||
}
|
||||
|
||||
func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) {
|
||||
root := &qc.Selects[0]
|
||||
|
||||
insert, ok := vars["insert"]
|
||||
if !ok {
|
||||
return 0, errors.New("Variable 'insert' not defined")
|
||||
}
|
||||
|
||||
jt, array, err := jsn.Tree(insert)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
c.w.WriteString(`WITH input AS (SELECT {{insert}}::json AS j) INSERT INTO `)
|
||||
c.w.WriteString(root.Table)
|
||||
io.WriteString(c.w, " (")
|
||||
c.renderInsertColumns(qc, w, jt)
|
||||
io.WriteString(c.w, ")")
|
||||
|
||||
c.w.WriteString(` SELECT `)
|
||||
c.renderInsertColumns(qc, w, jt)
|
||||
c.w.WriteString(` FROM input i, `)
|
||||
|
||||
if array {
|
||||
c.w.WriteString(`json_populate_recordset`)
|
||||
} else {
|
||||
c.w.WriteString(`json_populate_record`)
|
||||
}
|
||||
|
||||
c.w.WriteString(`(NULL::`)
|
||||
c.w.WriteString(root.Table)
|
||||
c.w.WriteString(`, i.j) t RETURNING * `)
|
||||
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (c *compilerContext) renderInsertColumns(qc *qcode.QCode, w *bytes.Buffer,
|
||||
jt map[string]interface{}) (uint32, error) {
|
||||
|
||||
ti, err := c.schema.GetTable(qc.Selects[0].Table)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
i := 0
|
||||
for _, cn := range ti.ColumnNames {
|
||||
if _, ok := jt[cn]; !ok {
|
||||
continue
|
||||
}
|
||||
if i != 0 {
|
||||
io.WriteString(c.w, ", ")
|
||||
}
|
||||
c.w.WriteString(cn)
|
||||
i++
|
||||
}
|
||||
|
||||
return 0, nil
|
||||
}
|
|
@ -0,0 +1,65 @@
|
|||
package psql
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func singleInsert(t *testing.T) {
|
||||
gql := `mutation {
|
||||
product(id: 15, insert: $insert) {
|
||||
id
|
||||
name
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `test`
|
||||
|
||||
vars := map[string]json.RawMessage{
|
||||
"insert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
|
||||
}
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, vars)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Println(">", string(resSQL))
|
||||
|
||||
if string(resSQL) != sql {
|
||||
t.Fatal(errNotExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func bulkInsert(t *testing.T) {
|
||||
gql := `mutation {
|
||||
product(id: 15, insert: $insert) {
|
||||
id
|
||||
name
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `test`
|
||||
|
||||
vars := map[string]json.RawMessage{
|
||||
"insert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
|
||||
}
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql, vars)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Println(">", string(resSQL))
|
||||
|
||||
if string(resSQL) != sql {
|
||||
t.Fatal(errNotExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileInsert(t *testing.T) {
|
||||
t.Run("singleInsert", singleInsert)
|
||||
t.Run("bulkInsert", bulkInsert)
|
||||
|
||||
}
|
|
@ -2,6 +2,7 @@ package psql
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
@ -18,6 +19,8 @@ const (
|
|||
closeBlock = 500
|
||||
)
|
||||
|
||||
type Variables map[string]json.RawMessage
|
||||
|
||||
type Config struct {
|
||||
Schema *DBSchema
|
||||
Vars map[string]string
|
||||
|
@ -51,19 +54,30 @@ type compilerContext struct {
|
|||
*Compiler
|
||||
}
|
||||
|
||||
func (co *Compiler) CompileEx(qc *qcode.QCode) (uint32, []byte, error) {
|
||||
func (co *Compiler) CompileEx(qc *qcode.QCode, vars Variables) (uint32, []byte, error) {
|
||||
w := &bytes.Buffer{}
|
||||
skipped, err := co.Compile(qc, w)
|
||||
skipped, err := co.Compile(qc, w, vars)
|
||||
return skipped, w.Bytes(), err
|
||||
}
|
||||
|
||||
func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) {
|
||||
if len(qc.Query.Selects) == 0 {
|
||||
func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) {
|
||||
switch qc.Type {
|
||||
case qcode.QTQuery:
|
||||
return co.compileQuery(qc, w)
|
||||
case qcode.QTMutation:
|
||||
return co.compileMutation(qc, w, vars)
|
||||
}
|
||||
|
||||
return 0, errors.New("unknown operation")
|
||||
}
|
||||
|
||||
func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) {
|
||||
if len(qc.Selects) == 0 {
|
||||
return 0, errors.New("empty query")
|
||||
}
|
||||
|
||||
c := &compilerContext{w, qc.Query.Selects, co}
|
||||
root := &qc.Query.Selects[0]
|
||||
c := &compilerContext{w, qc.Selects, co}
|
||||
root := &qc.Selects[0]
|
||||
|
||||
st := NewStack()
|
||||
st.Push(root.ID + closeBlock)
|
||||
|
@ -844,7 +858,7 @@ func (c *compilerContext) renderList(ex *qcode.Exp) {
|
|||
func (c *compilerContext) renderVal(ex *qcode.Exp,
|
||||
vars map[string]string) {
|
||||
|
||||
io.WriteString(c.w, ` (`)
|
||||
//io.WriteString(c.w, ` (`)
|
||||
switch ex.Type {
|
||||
case qcode.ValBool, qcode.ValInt, qcode.ValFloat:
|
||||
if len(ex.Val) != 0 {
|
||||
|
@ -852,21 +866,23 @@ func (c *compilerContext) renderVal(ex *qcode.Exp,
|
|||
} else {
|
||||
c.w.WriteString(`''`)
|
||||
}
|
||||
|
||||
case qcode.ValStr:
|
||||
c.w.WriteString(`'`)
|
||||
c.w.WriteString(ex.Val)
|
||||
c.w.WriteString(`'`)
|
||||
|
||||
case qcode.ValVar:
|
||||
if val, ok := vars[ex.Val]; ok {
|
||||
c.w.WriteString(val)
|
||||
} else {
|
||||
//fmt.Fprintf(w, `'{{%s}}'`, ex.Val)
|
||||
c.w.WriteString(`'{{`)
|
||||
c.w.WriteString(`{{`)
|
||||
c.w.WriteString(ex.Val)
|
||||
c.w.WriteString(`}}'`)
|
||||
c.w.WriteString(`}}`)
|
||||
}
|
||||
}
|
||||
c.w.WriteString(`)`)
|
||||
//c.w.WriteString(`)`)
|
||||
}
|
||||
|
||||
func funcPrefixLen(fn string) int {
|
|
@ -125,13 +125,13 @@ func TestMain(m *testing.M) {
|
|||
os.Exit(m.Run())
|
||||
}
|
||||
|
||||
func compileGQLToPSQL(gql string) ([]byte, error) {
|
||||
func compileGQLToPSQL(gql string, vars Variables) ([]byte, error) {
|
||||
qc, err := qcompile.Compile([]byte(gql))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
_, sqlStmt, err := pcompile.CompileEx(qc)
|
||||
_, sqlStmt, err := pcompile.CompileEx(qc, vars)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -164,7 +164,7 @@ func withComplexArgs(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."id") < (28)) AND (("products"."id") >= (20))) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "products_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -192,7 +192,7 @@ func withWhereMultiOr(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") < (20)) OR (("products"."price") > (10)) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -218,7 +218,7 @@ func withWhereIsNull(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -244,7 +244,7 @@ func withWhereAndList(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -264,7 +264,7 @@ func fetchByID(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('product', product) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "product_0"."id" AS "id", "product_0"."name" AS "name") AS "sel_0")) AS "product" FROM (SELECT "product"."id", "product"."name" FROM "products" AS "product" WHERE ((("product"."price") > (0)) AND (("product"."price") < (8)) AND (("id") = (15))) LIMIT ('1') :: integer) AS "product_0" LIMIT ('1') :: integer) AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -284,7 +284,7 @@ func searchQuery(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -307,7 +307,7 @@ func oneToMany(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("users"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "users" FROM (SELECT "users"."email", "users"."id" FROM "users" WHERE ((("users"."id") = ('{{user_id}}'))) LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "products" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "products_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "users_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -330,7 +330,7 @@ func belongsTo(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("users"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "users" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "users_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -353,7 +353,7 @@ func manyToMany(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("customers"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "customers" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "customers_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -376,7 +376,7 @@ func manyToManyReverse(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("customers"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "customers" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "products" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "products_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "customers_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -396,7 +396,7 @@ func aggFunction(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "products" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -416,7 +416,7 @@ func aggFunctionWithFilter(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."id") > (10))) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -436,7 +436,7 @@ func queryWithVariables(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('product', product) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "product_0"."id" AS "id", "product_0"."name" AS "name") AS "sel_0")) AS "product" FROM (SELECT "product"."id", "product"."name" FROM "products" AS "product" WHERE ((("product"."price") > (0)) AND (("product"."price") < (8)) AND (("product"."price") = ('{{product_price}}')) AND (("id") = ('{{product_id}}'))) LIMIT ('1') :: integer) AS "product_0" LIMIT ('1') :: integer) AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -455,7 +455,7 @@ func syntheticTables(t *testing.T) {
|
|||
|
||||
sql := `SELECT json_object_agg('me', me) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "me_0"."email" AS "email") AS "sel_0")) AS "me" FROM (SELECT "me"."email" FROM "users" AS "me" WHERE ((("me"."id") = ('{{user_id}}'))) LIMIT ('1') :: integer) AS "me_0" LIMIT ('1') :: integer) AS "done_1337";`
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql)
|
||||
resSQL, err := compileGQLToPSQL(gql, nil)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -465,7 +465,7 @@ func syntheticTables(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCompileGQL(t *testing.T) {
|
||||
func TestCompileSelect(t *testing.T) {
|
||||
t.Run("withComplexArgs", withComplexArgs)
|
||||
t.Run("withWhereAndList", withWhereAndList)
|
||||
t.Run("withWhereIsNull", withWhereIsNull)
|
||||
|
@ -519,7 +519,7 @@ func BenchmarkCompile(b *testing.B) {
|
|||
b.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = pcompile.Compile(qc, w)
|
||||
_, err = pcompile.Compile(qc, w, nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
@ -540,7 +540,7 @@ func BenchmarkCompileParallel(b *testing.B) {
|
|||
b.Fatal(err)
|
||||
}
|
||||
|
||||
_, err = pcompile.Compile(qc, w)
|
||||
_, err = pcompile.Compile(qc, w, nil)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
|
@ -106,11 +106,12 @@ type DBSchema struct {
|
|||
}
|
||||
|
||||
type DBTableInfo struct {
|
||||
Name string
|
||||
Singular bool
|
||||
PrimaryCol string
|
||||
TSVCol string
|
||||
Columns map[string]*DBColumn
|
||||
Name string
|
||||
Singular bool
|
||||
PrimaryCol string
|
||||
TSVCol string
|
||||
Columns map[string]*DBColumn
|
||||
ColumnNames []string
|
||||
}
|
||||
|
||||
type RelType int
|
||||
|
@ -162,25 +163,30 @@ func (s *DBSchema) updateSchema(
|
|||
// Foreign key columns in current table
|
||||
colByID := make(map[int]*DBColumn)
|
||||
columns := make(map[string]*DBColumn, len(cols))
|
||||
colNames := make([]string, len(cols))
|
||||
|
||||
for i := range cols {
|
||||
c := cols[i]
|
||||
columns[strings.ToLower(c.Name)] = cols[i]
|
||||
name := strings.ToLower(c.Name)
|
||||
columns[name] = cols[i]
|
||||
colNames = append(colNames, name)
|
||||
colByID[c.ID] = cols[i]
|
||||
}
|
||||
|
||||
singular := strings.ToLower(flect.Singularize(t.Name))
|
||||
s.t[singular] = &DBTableInfo{
|
||||
Name: t.Name,
|
||||
Singular: true,
|
||||
Columns: columns,
|
||||
Name: t.Name,
|
||||
Singular: true,
|
||||
Columns: columns,
|
||||
ColumnNames: colNames,
|
||||
}
|
||||
|
||||
plural := strings.ToLower(flect.Pluralize(t.Name))
|
||||
s.t[plural] = &DBTableInfo{
|
||||
Name: t.Name,
|
||||
Singular: false,
|
||||
Columns: columns,
|
||||
Name: t.Name,
|
||||
Singular: false,
|
||||
Columns: columns,
|
||||
ColumnNames: colNames,
|
||||
}
|
||||
|
||||
ct := strings.ToLower(t.Name)
|
||||
|
|
|
@ -100,19 +100,7 @@ var lexPool = sync.Pool{
|
|||
}
|
||||
|
||||
func Parse(gql []byte) (*Operation, error) {
|
||||
return parseSelectionSet(nil, gql)
|
||||
}
|
||||
|
||||
func ParseQuery(gql []byte) (*Operation, error) {
|
||||
op := opPool.Get().(*Operation)
|
||||
op.Reset()
|
||||
|
||||
op.Type = opQuery
|
||||
op.Name = ""
|
||||
op.Fields = op.fieldsA[:0]
|
||||
op.Args = op.argsA[:0]
|
||||
|
||||
return parseSelectionSet(op, gql)
|
||||
return parseSelectionSet(gql)
|
||||
}
|
||||
|
||||
func ParseArgValue(argVal string) (*Node, error) {
|
||||
|
@ -134,7 +122,7 @@ func ParseArgValue(argVal string) (*Node, error) {
|
|||
return op, err
|
||||
}
|
||||
|
||||
func parseSelectionSet(op *Operation, gql []byte) (*Operation, error) {
|
||||
func parseSelectionSet(gql []byte) (*Operation, error) {
|
||||
var err error
|
||||
|
||||
if len(gql) == 0 {
|
||||
|
@ -154,14 +142,28 @@ func parseSelectionSet(op *Operation, gql []byte) (*Operation, error) {
|
|||
items: l.items,
|
||||
}
|
||||
|
||||
if op == nil {
|
||||
op, err = p.parseOp()
|
||||
} else {
|
||||
if p.peek(itemObjOpen) {
|
||||
p.ignore()
|
||||
}
|
||||
var op *Operation
|
||||
|
||||
if p.peek(itemObjOpen) {
|
||||
p.ignore()
|
||||
}
|
||||
|
||||
if p.peek(itemName) {
|
||||
op = opPool.Get().(*Operation)
|
||||
op.Reset()
|
||||
|
||||
op.Type = opQuery
|
||||
op.Name = ""
|
||||
op.Fields = op.fieldsA[:0]
|
||||
op.Args = op.argsA[:0]
|
||||
op.Fields, err = p.parseFields(op.Fields)
|
||||
|
||||
} else {
|
||||
op, err = p.parseOp()
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
lexPool.Put(l)
|
||||
|
|
|
@ -45,10 +45,10 @@ func compareOp(op1, op2 Operation) error {
|
|||
}
|
||||
*/
|
||||
|
||||
func TestCompile(t *testing.T) {
|
||||
func TestCompile1(t *testing.T) {
|
||||
qcompile, _ := NewCompiler(Config{})
|
||||
|
||||
_, err := qcompile.CompileQuery([]byte(`
|
||||
_, err := qcompile.Compile([]byte(`
|
||||
product(id: 15) {
|
||||
id
|
||||
name
|
||||
|
@ -59,9 +59,39 @@ func TestCompile(t *testing.T) {
|
|||
}
|
||||
}
|
||||
|
||||
func TestCompile2(t *testing.T) {
|
||||
qcompile, _ := NewCompiler(Config{})
|
||||
|
||||
_, err := qcompile.Compile([]byte(`
|
||||
query { product(id: 15) {
|
||||
id
|
||||
name
|
||||
} }`))
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompile3(t *testing.T) {
|
||||
qcompile, _ := NewCompiler(Config{})
|
||||
|
||||
_, err := qcompile.Compile([]byte(`
|
||||
mutation {
|
||||
product(id: 15, name: "Test") {
|
||||
id
|
||||
name
|
||||
}
|
||||
}`))
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestInvalidCompile1(t *testing.T) {
|
||||
qcompile, _ := NewCompiler(Config{})
|
||||
_, err := qcompile.CompileQuery([]byte(`#`))
|
||||
_, err := qcompile.Compile([]byte(`#`))
|
||||
|
||||
if err == nil {
|
||||
t.Fatal(errors.New("expecting an error"))
|
||||
|
@ -70,7 +100,7 @@ func TestInvalidCompile1(t *testing.T) {
|
|||
|
||||
func TestInvalidCompile2(t *testing.T) {
|
||||
qcompile, _ := NewCompiler(Config{})
|
||||
_, err := qcompile.CompileQuery([]byte(`{u(where:{not:0})}`))
|
||||
_, err := qcompile.Compile([]byte(`{u(where:{not:0})}`))
|
||||
|
||||
if err == nil {
|
||||
t.Fatal(errors.New("expecting an error"))
|
||||
|
@ -79,7 +109,7 @@ func TestInvalidCompile2(t *testing.T) {
|
|||
|
||||
func TestEmptyCompile(t *testing.T) {
|
||||
qcompile, _ := NewCompiler(Config{})
|
||||
_, err := qcompile.CompileQuery([]byte(``))
|
||||
_, err := qcompile.Compile([]byte(``))
|
||||
|
||||
if err == nil {
|
||||
t.Fatal(errors.New("expecting an error"))
|
||||
|
@ -114,7 +144,7 @@ func BenchmarkQCompile(b *testing.B) {
|
|||
b.ReportAllocs()
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
_, err := qcompile.CompileQuery(gql)
|
||||
_, err := qcompile.Compile(gql)
|
||||
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
|
@ -130,7 +160,7 @@ func BenchmarkQCompileP(b *testing.B) {
|
|||
|
||||
b.RunParallel(func(pb *testing.PB) {
|
||||
for pb.Next() {
|
||||
_, err := qcompile.CompileQuery(gql)
|
||||
_, err := qcompile.Compile(gql)
|
||||
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
|
|
|
@ -10,15 +10,17 @@ import (
|
|||
"github.com/gobuffalo/flect"
|
||||
)
|
||||
|
||||
type QType int
|
||||
|
||||
const (
|
||||
maxSelectors = 30
|
||||
|
||||
QTQuery QType = iota + 1
|
||||
QTMutation
|
||||
)
|
||||
|
||||
type QCode struct {
|
||||
Query *Query
|
||||
}
|
||||
|
||||
type Query struct {
|
||||
Type QType
|
||||
Selects []Select
|
||||
}
|
||||
|
||||
|
@ -149,6 +151,11 @@ type Compiler struct {
|
|||
ka bool
|
||||
}
|
||||
|
||||
var opMap = map[parserType]QType{
|
||||
opQuery: QTQuery,
|
||||
opMutate: QTMutation,
|
||||
}
|
||||
|
||||
var expPool = sync.Pool{
|
||||
New: func() interface{} { return new(Exp) },
|
||||
}
|
||||
|
@ -196,44 +203,23 @@ func (com *Compiler) Compile(query []byte) (*QCode, error) {
|
|||
return nil, err
|
||||
}
|
||||
|
||||
switch op.Type {
|
||||
case opQuery:
|
||||
qc.Query, err = com.compileQuery(op)
|
||||
case opMutate:
|
||||
case opSub:
|
||||
default:
|
||||
err = fmt.Errorf("Unknown operation type %d", op.Type)
|
||||
}
|
||||
|
||||
qc.Selects, err = com.compileQuery(op)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if t, ok := opMap[op.Type]; ok {
|
||||
qc.Type = t
|
||||
} else {
|
||||
return nil, fmt.Errorf("Unknown operation type %d", op.Type)
|
||||
}
|
||||
|
||||
opPool.Put(op)
|
||||
|
||||
return &qc, nil
|
||||
}
|
||||
|
||||
func (com *Compiler) CompileQuery(query []byte) (*QCode, error) {
|
||||
var err error
|
||||
|
||||
op, err := ParseQuery(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
qc := &QCode{}
|
||||
qc.Query, err = com.compileQuery(op)
|
||||
opPool.Put(op)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return qc, nil
|
||||
}
|
||||
|
||||
func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
|
||||
func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
|
||||
id := int32(0)
|
||||
parentID := int32(0)
|
||||
|
||||
|
@ -344,7 +330,7 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
|
|||
return nil, errors.New("invalid query")
|
||||
}
|
||||
|
||||
return &Query{selects[:id]}, nil
|
||||
return selects[:id], nil
|
||||
}
|
||||
|
||||
func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
|
||||
|
@ -661,14 +647,6 @@ func (com *Compiler) compileArgOffset(sel *Select, arg *Arg) error {
|
|||
return nil
|
||||
}
|
||||
|
||||
func compileMutate() (*Query, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func compileSub() (*Query, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {
|
||||
name := node.Name
|
||||
if name[0] == '_' {
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
package serv
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
|
@ -9,9 +10,15 @@ import (
|
|||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
AL_QUERY int = iota + 1
|
||||
AL_VARS
|
||||
)
|
||||
|
||||
type allowItem struct {
|
||||
uri string
|
||||
gql string
|
||||
uri string
|
||||
gql string
|
||||
vars json.RawMessage
|
||||
}
|
||||
|
||||
var _allowList allowList
|
||||
|
@ -77,8 +84,9 @@ func (al *allowList) add(req *gqlReq) {
|
|||
}
|
||||
|
||||
al.saveChan <- &allowItem{
|
||||
uri: req.ref,
|
||||
gql: req.Query,
|
||||
uri: req.ref,
|
||||
gql: req.Query,
|
||||
vars: req.Vars,
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -93,32 +101,62 @@ func (al *allowList) load() {
|
|||
}
|
||||
|
||||
var uri string
|
||||
var varBytes []byte
|
||||
|
||||
s, e, c := 0, 0, 0
|
||||
ty := 0
|
||||
|
||||
for {
|
||||
if c == 0 && b[e] == '#' {
|
||||
s = e
|
||||
for b[e] != '\n' && e < len(b) {
|
||||
for e < len(b) && b[e] != '\n' {
|
||||
e++
|
||||
}
|
||||
if (e - s) > 2 {
|
||||
uri = strings.TrimSpace(string(b[(s + 1):e]))
|
||||
}
|
||||
}
|
||||
if b[e] == '{' {
|
||||
|
||||
if e >= len(b) {
|
||||
break
|
||||
}
|
||||
|
||||
if matchPrefix(b, e, "query") || matchPrefix(b, e, "mutation") {
|
||||
if c == 0 {
|
||||
s = e
|
||||
}
|
||||
ty = AL_QUERY
|
||||
} else if matchPrefix(b, e, "variables") {
|
||||
if c == 0 {
|
||||
s = e + len("variables") + 1
|
||||
}
|
||||
ty = AL_VARS
|
||||
} else if b[e] == '{' {
|
||||
c++
|
||||
|
||||
} else if b[e] == '}' {
|
||||
c--
|
||||
|
||||
if c == 0 {
|
||||
q := b[s:(e + 1)]
|
||||
al.list[gqlHash(q)] = &allowItem{
|
||||
uri: uri,
|
||||
gql: string(q),
|
||||
if ty == AL_QUERY {
|
||||
q := string(b[s:(e + 1)])
|
||||
|
||||
item := &allowItem{
|
||||
uri: uri,
|
||||
gql: q,
|
||||
}
|
||||
|
||||
if len(varBytes) != 0 {
|
||||
item.vars = varBytes
|
||||
}
|
||||
|
||||
al.list[gqlHash(q, varBytes)] = item
|
||||
varBytes = nil
|
||||
|
||||
} else if ty == AL_VARS {
|
||||
varBytes = b[s:(e + 1)]
|
||||
}
|
||||
ty = 0
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -130,7 +168,7 @@ func (al *allowList) load() {
|
|||
}
|
||||
|
||||
func (al *allowList) save(item *allowItem) {
|
||||
al.list[gqlHash([]byte(item.gql))] = item
|
||||
al.list[gqlHash(item.gql, item.vars)] = item
|
||||
|
||||
f, err := os.Create(al.filepath)
|
||||
if err != nil {
|
||||
|
@ -141,10 +179,10 @@ func (al *allowList) save(item *allowItem) {
|
|||
defer f.Close()
|
||||
|
||||
keys := []string{}
|
||||
urlMap := make(map[string][]string)
|
||||
urlMap := make(map[string][]*allowItem)
|
||||
|
||||
for _, v := range al.list {
|
||||
urlMap[v.uri] = append(urlMap[v.uri], v.gql)
|
||||
urlMap[v.uri] = append(urlMap[v.uri], v)
|
||||
}
|
||||
|
||||
for k := range urlMap {
|
||||
|
@ -159,7 +197,28 @@ func (al *allowList) save(item *allowItem) {
|
|||
f.WriteString(fmt.Sprintf("# %s\n\n", k))
|
||||
|
||||
for i := range v {
|
||||
f.WriteString(fmt.Sprintf("query %s\n\n", v[i]))
|
||||
if len(v[i].vars) != 0 {
|
||||
vj, err := json.MarshalIndent(v[i].vars, "", "\t")
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to write allow list 'vars' to file")
|
||||
continue
|
||||
}
|
||||
f.WriteString(fmt.Sprintf("variables %s\n\n", vj))
|
||||
}
|
||||
|
||||
f.WriteString(fmt.Sprintf("%s\n\n", v[i].gql))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func matchPrefix(b []byte, i int, s string) bool {
|
||||
if (len(b) - i) < len(s) {
|
||||
return false
|
||||
}
|
||||
for n := 0; n < len(s); n++ {
|
||||
if b[(i+n)] != s[n] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
|
30
serv/core.go
30
serv/core.go
|
@ -14,6 +14,7 @@ import (
|
|||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/dosco/super-graph/jsn"
|
||||
"github.com/dosco/super-graph/psql"
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
"github.com/go-pg/pg"
|
||||
"github.com/valyala/fasttemplate"
|
||||
|
@ -42,7 +43,7 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
|
|||
if conf.UseAllowList {
|
||||
var ps *preparedItem
|
||||
|
||||
data, ps, err = c.resolvePreparedSQL([]byte(c.req.Query))
|
||||
data, ps, err = c.resolvePreparedSQL(c.req.Query)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -52,7 +53,7 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
|
|||
|
||||
} else {
|
||||
|
||||
qc, err = qcompile.CompileQuery([]byte(c.req.Query))
|
||||
qc, err = qcompile.Compile([]byte(c.req.Query))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -67,7 +68,7 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
|
|||
return c.render(w, data)
|
||||
}
|
||||
|
||||
sel := qc.Query.Selects
|
||||
sel := qc.Selects
|
||||
h := xxhash.New()
|
||||
|
||||
// fetch the field name used within the db response json
|
||||
|
@ -252,8 +253,8 @@ func (c *coreContext) resolveRemotes(
|
|||
return to, cerr
|
||||
}
|
||||
|
||||
func (c *coreContext) resolvePreparedSQL(gql []byte) ([]byte, *preparedItem, error) {
|
||||
ps, ok := _preparedList[gqlHash(gql)]
|
||||
func (c *coreContext) resolvePreparedSQL(gql string) ([]byte, *preparedItem, error) {
|
||||
ps, ok := _preparedList[gqlHash(gql, c.req.Vars)]
|
||||
if !ok {
|
||||
return nil, nil, errUnauthorized
|
||||
}
|
||||
|
@ -266,17 +267,22 @@ func (c *coreContext) resolvePreparedSQL(gql []byte) ([]byte, *preparedItem, err
|
|||
return nil, nil, err
|
||||
}
|
||||
|
||||
fmt.Printf("PRE: %#v %#v\n", ps.stmt, vars)
|
||||
fmt.Printf("PRE: %v\n", ps.stmt)
|
||||
|
||||
return []byte(root), ps, nil
|
||||
}
|
||||
|
||||
func (c *coreContext) resolveSQL(qc *qcode.QCode) (
|
||||
[]byte, uint32, error) {
|
||||
|
||||
stmt := &bytes.Buffer{}
|
||||
|
||||
skipped, err := pcompile.Compile(qc, stmt)
|
||||
vars := make(map[string]json.RawMessage)
|
||||
|
||||
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
skipped, err := pcompile.Compile(qc, stmt, psql.Variables(vars))
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
@ -284,7 +290,7 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode) (
|
|||
t := fasttemplate.New(stmt.String(), openVar, closeVar)
|
||||
|
||||
stmt.Reset()
|
||||
_, err = t.Execute(stmt, varMap(c))
|
||||
_, err = t.ExecuteFunc(stmt, varMap(c))
|
||||
|
||||
if err == errNoUserID &&
|
||||
authFailBlock == authFailBlockPerQuery &&
|
||||
|
@ -317,10 +323,10 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode) (
|
|||
return nil, 0, err
|
||||
}
|
||||
|
||||
if conf.EnableTracing && len(qc.Query.Selects) != 0 {
|
||||
if conf.EnableTracing && len(qc.Selects) != 0 {
|
||||
c.addTrace(
|
||||
qc.Query.Selects,
|
||||
qc.Query.Selects[0].ID,
|
||||
qc.Selects,
|
||||
qc.Selects[0].ID,
|
||||
st)
|
||||
}
|
||||
|
||||
|
|
|
@ -0,0 +1,35 @@
|
|||
package serv
|
||||
|
||||
/*
|
||||
|
||||
func simpleMutation(t *testing.T) {
|
||||
gql := `mutation {
|
||||
product(id: 15, insert: { name: "Test", price: 20.5 }) {
|
||||
id
|
||||
name
|
||||
}
|
||||
}`
|
||||
|
||||
sql := `test`
|
||||
|
||||
backgroundCtx := context.Background()
|
||||
ctx := &coreContext{Context: backgroundCtx}
|
||||
|
||||
resSQL, err := compileGQLToPSQL(gql)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
fmt.Println(">", string(resSQL))
|
||||
|
||||
if string(resSQL) != sql {
|
||||
t.Fatal(errNotExpected)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCompileGQL(t *testing.T) {
|
||||
t.Run("withComplexArgs", withComplexArgs)
|
||||
t.Run("simpleMutation", simpleMutation)
|
||||
}
|
||||
|
||||
*/
|
|
@ -26,13 +26,13 @@ var (
|
|||
)
|
||||
|
||||
type gqlReq struct {
|
||||
OpName string `json:"operationName"`
|
||||
Query string `json:"query"`
|
||||
Vars variables `json:"variables"`
|
||||
OpName string `json:"operationName"`
|
||||
Query string `json:"query"`
|
||||
Vars json.RawMessage `json:"variables"`
|
||||
ref string
|
||||
}
|
||||
|
||||
type variables map[string]interface{}
|
||||
type variables map[string]json.RawMessage
|
||||
|
||||
type gqlResp struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
|
|
|
@ -2,9 +2,11 @@ package serv
|
|||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/dosco/super-graph/psql"
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
"github.com/go-pg/pg"
|
||||
"github.com/valyala/fasttemplate"
|
||||
|
@ -12,7 +14,7 @@ import (
|
|||
|
||||
type preparedItem struct {
|
||||
stmt *pg.Stmt
|
||||
args []string
|
||||
args [][]byte
|
||||
skipped uint32
|
||||
qc *qcode.QCode
|
||||
}
|
||||
|
@ -25,36 +27,46 @@ func initPreparedList() {
|
|||
_preparedList = make(map[string]*preparedItem)
|
||||
|
||||
for k, v := range _allowList.list {
|
||||
err := prepareStmt(k, v.gql)
|
||||
err := prepareStmt(k, v.gql, v.vars)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func prepareStmt(key, gql string) error {
|
||||
func prepareStmt(key, gql string, varBytes json.RawMessage) error {
|
||||
if len(gql) == 0 || len(key) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
qc, err := qcompile.CompileQuery([]byte(gql))
|
||||
qc, err := qcompile.Compile([]byte(gql))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var vars map[string]json.RawMessage
|
||||
|
||||
if len(varBytes) != 0 {
|
||||
vars = make(map[string]json.RawMessage)
|
||||
|
||||
if err := json.Unmarshal(varBytes, &vars); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
|
||||
skipped, err := pcompile.Compile(qc, buf)
|
||||
skipped, err := pcompile.Compile(qc, buf, psql.Variables(vars))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
t := fasttemplate.New(buf.String(), `('{{`, `}}')`)
|
||||
am := make([]string, 0, 5)
|
||||
t := fasttemplate.New(buf.String(), `{{`, `}}`)
|
||||
am := make([][]byte, 0, 5)
|
||||
i := 0
|
||||
|
||||
finalSQL := t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) {
|
||||
am = append(am, tag)
|
||||
am = append(am, []byte(tag))
|
||||
i++
|
||||
return w.Write([]byte(fmt.Sprintf("$%d", i)))
|
||||
})
|
||||
|
|
|
@ -4,8 +4,12 @@ import (
|
|||
"bytes"
|
||||
"crypto/sha1"
|
||||
"encoding/hex"
|
||||
"io"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/dosco/super-graph/jsn"
|
||||
)
|
||||
|
||||
func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
|
||||
|
@ -17,8 +21,8 @@ func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
|
|||
return v
|
||||
}
|
||||
|
||||
func gqlHash(b []byte) string {
|
||||
b = bytes.TrimSpace(b)
|
||||
func gqlHash(b string, vars []byte) string {
|
||||
b = strings.TrimSpace(b)
|
||||
h := sha1.New()
|
||||
|
||||
s, e := 0, 0
|
||||
|
@ -45,13 +49,27 @@ func gqlHash(b []byte) string {
|
|||
if e != 0 {
|
||||
b0 = b[(e - 1)]
|
||||
}
|
||||
h.Write(bytes.ToLower(b[s:e]))
|
||||
io.WriteString(h, strings.ToLower(b[s:e]))
|
||||
}
|
||||
if e >= len(b) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if vars == nil {
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
fields := jsn.Keys([]byte(vars))
|
||||
|
||||
sort.Slice(fields, func(i, j int) bool {
|
||||
return bytes.Compare(fields[i], fields[j]) == -1
|
||||
})
|
||||
|
||||
for i := range fields {
|
||||
h.Write(fields[i])
|
||||
}
|
||||
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
|
|
|
@ -6,7 +6,7 @@ import (
|
|||
)
|
||||
|
||||
func TestRelaxHash1(t *testing.T) {
|
||||
var v1 = []byte(`
|
||||
var v1 = `
|
||||
products(
|
||||
limit: 30,
|
||||
|
||||
|
@ -14,18 +14,18 @@ func TestRelaxHash1(t *testing.T) {
|
|||
id
|
||||
name
|
||||
price
|
||||
}`)
|
||||
}`
|
||||
|
||||
var v2 = []byte(`
|
||||
var v2 = `
|
||||
products
|
||||
(limit: 30, where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
|
||||
id
|
||||
name
|
||||
price
|
||||
} `)
|
||||
} `
|
||||
|
||||
h1 := gqlHash(v1)
|
||||
h2 := gqlHash(v2)
|
||||
h1 := gqlHash(v1, nil)
|
||||
h2 := gqlHash(v2, nil)
|
||||
|
||||
if strings.Compare(h1, h2) != 0 {
|
||||
t.Fatal("Hashes don't match they should")
|
||||
|
@ -33,7 +33,7 @@ func TestRelaxHash1(t *testing.T) {
|
|||
}
|
||||
|
||||
func TestRelaxHash2(t *testing.T) {
|
||||
var v1 = []byte(`
|
||||
var v1 = `
|
||||
{
|
||||
products(
|
||||
limit: 30
|
||||
|
@ -49,12 +49,119 @@ func TestRelaxHash2(t *testing.T) {
|
|||
email
|
||||
}
|
||||
}
|
||||
}`)
|
||||
}`
|
||||
|
||||
var v2 = []byte(` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } `)
|
||||
var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } `
|
||||
|
||||
h1 := gqlHash(v1)
|
||||
h2 := gqlHash(v2)
|
||||
h1 := gqlHash(v1, nil)
|
||||
h2 := gqlHash(v2, nil)
|
||||
|
||||
if strings.Compare(h1, h2) != 0 {
|
||||
t.Fatal("Hashes don't match they should")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelaxHashWithVars1(t *testing.T) {
|
||||
var q1 = `
|
||||
products(
|
||||
limit: 30,
|
||||
|
||||
where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
|
||||
id
|
||||
name
|
||||
price
|
||||
}`
|
||||
|
||||
var v1 = `
|
||||
{
|
||||
"insert": {
|
||||
"name": "Hello",
|
||||
"description": "World",
|
||||
"created_at": "now",
|
||||
"updated_at": "now",
|
||||
"test": { "type2": "b", "type1": "a" }
|
||||
},
|
||||
"user": 123
|
||||
}`
|
||||
|
||||
var q2 = `
|
||||
products
|
||||
(limit: 30, where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
|
||||
id
|
||||
name
|
||||
price
|
||||
} `
|
||||
|
||||
var v2 = `{
|
||||
"insert": {
|
||||
"created_at": "now",
|
||||
"test": { "type1": "a", "type2": "b" },
|
||||
"name": "Hello",
|
||||
"updated_at": "now",
|
||||
"description": "World"
|
||||
},
|
||||
"user": 123
|
||||
}`
|
||||
|
||||
h1 := gqlHash(q1, []byte(va1))
|
||||
h2 := gqlHash(q2, []byte(va2))
|
||||
|
||||
if strings.Compare(h1, h2) != 0 {
|
||||
t.Fatal("Hashes don't match they should")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelaxHashWithVars2(t *testing.T) {
|
||||
var q1 = `
|
||||
products(
|
||||
limit: 30,
|
||||
|
||||
where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
|
||||
id
|
||||
name
|
||||
price
|
||||
}`
|
||||
|
||||
var v1 = `
|
||||
{
|
||||
"insert": [{
|
||||
"name": "Hello",
|
||||
"description": "World",
|
||||
"created_at": "now",
|
||||
"updated_at": "now",
|
||||
"test": { "type2": "b", "type1": "a" }
|
||||
},
|
||||
{
|
||||
"name": "Hello",
|
||||
"description": "World",
|
||||
"created_at": "now",
|
||||
"updated_at": "now",
|
||||
"test": { "type2": "b", "type1": "a" }
|
||||
}],
|
||||
"user": 123
|
||||
}`
|
||||
|
||||
var q2 = `
|
||||
products
|
||||
(limit: 30, where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
|
||||
id
|
||||
name
|
||||
price
|
||||
} `
|
||||
|
||||
var v2 = `{
|
||||
"insert": {
|
||||
"created_at": "now",
|
||||
"test": { "type1": "a", "type2": "b" },
|
||||
"name": "Hello",
|
||||
"updated_at": "now",
|
||||
"description": "World"
|
||||
},
|
||||
"user": 123
|
||||
}`
|
||||
|
||||
h1 := gqlHash(q1, []byte(va1))
|
||||
h2 := gqlHash(q2, []byte(va2))
|
||||
|
||||
if strings.Compare(h1, h2) != 0 {
|
||||
t.Fatal("Hashes don't match they should")
|
||||
|
|
140
serv/vars.go
140
serv/vars.go
|
@ -1,95 +1,107 @@
|
|||
package serv
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/valyala/fasttemplate"
|
||||
"github.com/dosco/super-graph/jsn"
|
||||
)
|
||||
|
||||
func varMap(ctx *coreContext) variables {
|
||||
userIDFn := func(w io.Writer, _ string) (int, error) {
|
||||
if v := ctx.Value(userIDKey); v != nil {
|
||||
return w.Write([]byte(v.(string)))
|
||||
}
|
||||
return 0, errNoUserID
|
||||
}
|
||||
func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
|
||||
return func(w io.Writer, tag string) (int, error) {
|
||||
switch tag {
|
||||
case "user_id":
|
||||
if v := ctx.Value(userIDKey); v != nil {
|
||||
return stringVar(w, v.(string))
|
||||
}
|
||||
return 0, errNoUserID
|
||||
|
||||
userIDProviderFn := func(w io.Writer, _ string) (int, error) {
|
||||
if v := ctx.Value(userIDProviderKey); v != nil {
|
||||
return w.Write([]byte(v.(string)))
|
||||
}
|
||||
return 0, errNoUserID
|
||||
}
|
||||
|
||||
userIDTag := fasttemplate.TagFunc(userIDFn)
|
||||
userIDProviderTag := fasttemplate.TagFunc(userIDProviderFn)
|
||||
|
||||
vm := variables{
|
||||
"user_id": userIDTag,
|
||||
"user_id_provider": userIDProviderTag,
|
||||
"USER_ID": userIDTag,
|
||||
"USER_ID_PROVIDER": userIDProviderTag,
|
||||
}
|
||||
|
||||
for k, v := range ctx.req.Vars {
|
||||
var buf []byte
|
||||
k = strings.ToLower(k)
|
||||
|
||||
if _, ok := vm[k]; ok {
|
||||
continue
|
||||
case "user_id_provider":
|
||||
if v := ctx.Value(userIDProviderKey); v != nil {
|
||||
return stringVar(w, v.(string))
|
||||
}
|
||||
return 0, errNoUserID
|
||||
}
|
||||
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
vm[k] = val
|
||||
case int:
|
||||
vm[k] = strconv.AppendInt(buf, int64(val), 10)
|
||||
case int64:
|
||||
vm[k] = strconv.AppendInt(buf, val, 10)
|
||||
case float64:
|
||||
vm[k] = strconv.AppendFloat(buf, val, 'f', -1, 64)
|
||||
fields := jsn.Get(ctx.req.Vars, [][]byte{[]byte(tag)})
|
||||
if len(fields) == 0 {
|
||||
return 0, fmt.Errorf("variable '%s' not found", tag)
|
||||
}
|
||||
|
||||
is := false
|
||||
|
||||
for i := range fields[0].Value {
|
||||
c := fields[0].Value[i]
|
||||
if c != ' ' {
|
||||
is = (c == '"') || (c == '{') || (c == '[')
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if is {
|
||||
return stringVarB(w, fields[0].Value)
|
||||
}
|
||||
|
||||
w.Write(fields[0].Value)
|
||||
return 0, nil
|
||||
}
|
||||
return vm
|
||||
}
|
||||
|
||||
func varList(ctx *coreContext, args []string) []interface{} {
|
||||
vars := make([]interface{}, 0, len(args))
|
||||
func varList(ctx *coreContext, args [][]byte) []interface{} {
|
||||
vars := make([]interface{}, len(args))
|
||||
|
||||
for k, v := range ctx.req.Vars {
|
||||
ctx.req.Vars[strings.ToLower(k)] = v
|
||||
var fields map[string]interface{}
|
||||
var err error
|
||||
|
||||
if len(ctx.req.Vars) != 0 {
|
||||
fields, _, err = jsn.Tree(ctx.req.Vars)
|
||||
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to parse variables")
|
||||
}
|
||||
}
|
||||
|
||||
for i := range args {
|
||||
arg := strings.ToLower(args[i])
|
||||
av := args[i]
|
||||
|
||||
if arg == "user_id" {
|
||||
switch {
|
||||
case bytes.Equal(av, []byte("user_id")):
|
||||
if v := ctx.Value(userIDKey); v != nil {
|
||||
vars = append(vars, v.(string))
|
||||
vars[i] = v.(string)
|
||||
}
|
||||
}
|
||||
|
||||
if arg == "user_id_provider" {
|
||||
case bytes.Equal(av, []byte("user_id_provider")):
|
||||
if v := ctx.Value(userIDProviderKey); v != nil {
|
||||
vars = append(vars, v.(string))
|
||||
vars[i] = v.(string)
|
||||
}
|
||||
}
|
||||
|
||||
if v, ok := ctx.req.Vars[arg]; ok {
|
||||
switch val := v.(type) {
|
||||
case string:
|
||||
vars = append(vars, val)
|
||||
case int:
|
||||
vars = append(vars, strconv.FormatInt(int64(val), 10))
|
||||
case int64:
|
||||
vars = append(vars, strconv.FormatInt(int64(val), 10))
|
||||
case float64:
|
||||
vars = append(vars, strconv.FormatFloat(val, 'f', -1, 64))
|
||||
default:
|
||||
if v, ok := fields[string(av)]; ok {
|
||||
vars[i] = v
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return vars
|
||||
}
|
||||
|
||||
func stringVar(w io.Writer, v string) (int, error) {
|
||||
if n, err := w.Write([]byte(`'`)); err != nil {
|
||||
return n, err
|
||||
}
|
||||
if n, err := w.Write([]byte(v)); err != nil {
|
||||
return n, err
|
||||
}
|
||||
return w.Write([]byte(`'`))
|
||||
}
|
||||
|
||||
func stringVarB(w io.Writer, v []byte) (int, error) {
|
||||
if n, err := w.Write([]byte(`'`)); err != nil {
|
||||
return n, err
|
||||
}
|
||||
if n, err := w.Write(v); err != nil {
|
||||
return n, err
|
||||
}
|
||||
return w.Write([]byte(`'`))
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue