Add insert mutation with bulk insert

This commit is contained in:
Vikram Rangnekar 2019-09-05 00:09:56 -04:00
parent 5b9105ff0c
commit c0a21e448f
30 changed files with 1080 additions and 265 deletions

View File

@ -47,4 +47,79 @@ query {
email email
} }
} }
variables {
"insert": {
"name": "Hello",
"description": "World",
"created_at": "now",
"updated_at": "now"
},
"user": 123
} }
mutation {
products(insert: $insert) {
id
name
description
}
}
variables {
"insert": {
"name": "Hello",
"description": "World",
"created_at": "now",
"updated_at": "now"
},
"user": 123
}
mutation {
products(insert: $insert) {
id
}
}
variables {
"insert": {
"description": "World3",
"name": "Hello3",
"created_at": "now",
"updated_at": "now"
},
"user": 123
}
{
customers {
id
email
payments {
customer_id
amount
billing_details
}
}
}
variables {
"insert": {
"description": "World3",
"name": "Hello3",
"created_at": "now",
"updated_at": "now"
},
"user": 123
}
{
me {
id
full_name
}
}

View File

@ -84,8 +84,9 @@ database:
#log_level: "debug" #log_level: "debug"
# Define variables here that you want to use in filters # Define variables here that you want to use in filters
# sub-queries must be wrapped in ()
variables: variables:
account_id: "select account_id from users where id = $user_id" account_id: "(select account_id from users where id = $user_id)"
# Define defaults to for the field key and values below # Define defaults to for the field key and values below
defaults: defaults:
@ -105,12 +106,12 @@ database:
# This filter will overwrite defaults.filter # This filter will overwrite defaults.filter
# filter: ["{ id: { eq: $user_id } }"] # filter: ["{ id: { eq: $user_id } }"]
- name: products # - name: products
# Multiple filters are AND'd together # # Multiple filters are AND'd together
filter: [ # filter: [
"{ price: { gt: 0 } }", # "{ price: { gt: 0 } }",
"{ price: { lt: 8 } }" # "{ price: { lt: 8 } }"
] # ]
- name: customers - name: customers
# No filter is used for this field not # No filter is used for this field not

View File

@ -82,8 +82,9 @@ database:
#log_level: "debug" #log_level: "debug"
# Define variables here that you want to use in filters # Define variables here that you want to use in filters
# sub-queries must be wrapped in ()
variables: variables:
account_id: "select account_id from users where id = $user_id" account_id: "(select account_id from users where id = $user_id)"
# Define defaults to for the field key and values below # Define defaults to for the field key and values below
defaults: defaults:

View File

@ -2,6 +2,8 @@ version: '3.4'
services: services:
db: db:
image: postgres image: postgres
ports:
- "5432:5432"
# redis: # redis:
# image: redis:alpine # image: redis:alpine

39
jsn/README.md Normal file
View File

@ -0,0 +1,39 @@
# JSN - Fast low allocation JSON library
## Design
This libary is designed as a set of seperate functions to extract data and mutate
JSON. All functions are focused on keeping allocations to a minimum and be as fast
as possible. The functions don't validate the JSON a seperate `Validate` function
does that.
The JSON parsing algo processes each object `{}` or array `[]` in a bottom up way
where once the end of the array or object is found only then the keys within it are
parsed from the top down.
```
{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}], "full_name":"FN1","email":"E1" }
id: 1
posts: [{"title":"PT1-1","description":"PD1-1"}]
[{"title":"PT1-1","description":"PD1-1"}]
{"title":"PT1-1","description":"PD1-1"}
title: "PT1-1"
description: "PD1-1
full_name: "FN1"
email: "E1"
```
## Functions
- Strip: Strip a path from the root to a child node and return the rest
- Replace: Replace values by key
- Get: Get all keys
- Filter: Extract specific keys from an object
- Tree: Fetch unique keys from an array or object

View File

@ -43,7 +43,7 @@ func Get(b []byte, keys [][]byte) []Field {
kmap[xxhash.Sum64(keys[i])] = struct{}{} kmap[xxhash.Sum64(keys[i])] = struct{}{}
} }
res := make([]Field, 20) res := make([]Field, 0, 20)
s, e, d := 0, 0, 0 s, e, d := 0, 0, 0
@ -127,7 +127,7 @@ func Get(b []byte, keys [][]byte) []Field {
_, ok := kmap[xxhash.Sum64(k)] _, ok := kmap[xxhash.Sum64(k)]
if ok { if ok {
res[n] = Field{k, b[s:(e + 1)]} res = append(res, Field{k, b[s:(e + 1)]})
n++ n++
} }

View File

@ -21,6 +21,10 @@ var (
"full_name": "Caroll Orn Sr.", "full_name": "Caroll Orn Sr.",
"email": "joannarau@hegmann.io", "email": "joannarau@hegmann.io",
"__twitter_id": "ABC123" "__twitter_id": "ABC123"
"more": [{
"__twitter_id": "more123",
"hello: "world
}]
} }
}, },
{ {
@ -163,6 +167,7 @@ func TestGet(t *testing.T) {
{[]byte("__twitter_id"), []byte(`"ABCD"`)}, {[]byte("__twitter_id"), []byte(`"ABCD"`)},
{[]byte("__twitter_id"), []byte(`"2048666903444506956"`)}, {[]byte("__twitter_id"), []byte(`"2048666903444506956"`)},
{[]byte("__twitter_id"), []byte(`"ABC123"`)}, {[]byte("__twitter_id"), []byte(`"ABC123"`)},
{[]byte("__twitter_id"), []byte(`"more123"`)},
{[]byte("__twitter_id"), {[]byte("__twitter_id"),
[]byte(`[{ "name": "hello" }, { "name": "world"}]`)}, []byte(`[{ "name": "hello" }, { "name": "world"}]`)},
{[]byte("__twitter_id"), {[]byte("__twitter_id"),
@ -340,6 +345,80 @@ func TestReplaceEmpty(t *testing.T) {
} }
} }
func TestKeys1(t *testing.T) {
json := `[{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}, {"title":"PT1-2","description":"PD1-2"}], "full_name":"FN1","email":"E1","books": [{"name":"BN1-1","description":"BD1-1"},{"name":"BN1-2","description":"BD1-2"},{"name":"BN1-2","description":"BD1-2"}]},{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}, {"title":"PT1-2","description":"PD1-2"}], "full_name":"FN1","email":"E1","books": [{"name":"BN1-1","description":"BD1-1"},{"name":"BN1-2","description":"BD1-2"},{"name":"BN1-2","description":"BD1-2"}]},{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}, {"title":"PT1-2","description":"PD1-2"}], "full_name":"FN1","email":"E1","books": [{"name":"BN1-1","description":"BD1-1"},{"name":"BN1-2","description":"BD1-2"},{"name":"BN1-2","description":"BD1-2"}]}]`
fields := Keys([]byte(json))
exp := []string{
"id", "posts", "title", "description", "full_name", "email", "books", "name", "description",
}
if len(exp) != len(fields) {
t.Errorf("Expected %d fields %d", len(exp), len(fields))
}
for i := range exp {
if string(fields[i]) != exp[i] {
t.Errorf("Expected field '%s' got '%s'", string(exp[i]), fields[i])
}
}
}
func TestKeys2(t *testing.T) {
json := `{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}, {"title":"PT1-2","description":"PD1-2"}], "full_name":"FN1","email":"E1","books": [{"name":"BN1-1","description":"BD1-1"},{"name":"BN1-2","description":"BD1-2"},{"name":"BN1-2","description":"BD1-2"}]}`
fields := Keys([]byte(json))
exp := []string{
"id", "posts", "title", "description", "full_name", "email", "books", "name", "description",
}
// for i := range fields {
// fmt.Println("-->", string(fields[i]))
// }
if len(exp) != len(fields) {
t.Errorf("Expected %d fields %d", len(exp), len(fields))
}
for i := range exp {
if string(fields[i]) != exp[i] {
t.Errorf("Expected field '%s' got '%s'", string(exp[i]), fields[i])
}
}
}
func TestKeys3(t *testing.T) {
json := `{
"insert": {
"created_at": "now",
"test": { "type1": "a", "type2": "b" },
"name": "Hello",
"updated_at": "now",
"description": "World"
},
"user": 123
}`
fields := Keys([]byte(json))
exp := []string{
"insert", "created_at", "test", "type1", "type2", "name", "updated_at", "description",
"user",
}
if len(exp) != len(fields) {
t.Errorf("Expected %d fields %d", len(exp), len(fields))
}
for i := range exp {
if string(fields[i]) != exp[i] {
t.Errorf("Expected field '%s' got '%s'", string(exp[i]), fields[i])
}
}
}
func BenchmarkGet(b *testing.B) { func BenchmarkGet(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()

122
jsn/keys.go Normal file
View File

@ -0,0 +1,122 @@
package jsn
func Keys(b []byte) [][]byte {
res := make([][]byte, 0, 20)
s, e, d := 0, 0, 0
var k []byte
state := expectValue
st := NewStack()
ae := 0
for i := 0; i < len(b); i++ {
if state == expectObjClose || state == expectListClose {
switch b[i] {
case '{', '[':
d++
case '}', ']':
d--
}
}
si := st.Peek()
switch {
case state == expectKey && si != nil && i >= si.ss:
i = si.se + 1
st.Pop()
case state == expectKey && b[i] == '{':
state = expectObjClose
s = i
d++
case state == expectObjClose && d == 0 && b[i] == '}':
state = expectKey
if ae != 0 {
st.Push(skipInfo{i, ae})
ae = 0
}
e = i
i = s
case state == expectKey && b[i] == '"':
state = expectKeyClose
s = i
case state == expectKeyClose && b[i] == '"':
state = expectColon
k = b[(s + 1):i]
case state == expectColon && b[i] == ':':
state = expectValue
case state == expectValue && b[i] == '"':
state = expectString
s = i
case state == expectString && b[i] == '"':
e = i
case state == expectValue && b[i] == '{':
state = expectObjClose
s = i
d++
case state == expectObjClose && d == 0 && b[i] == '}':
state = expectKey
e = i
i = s
case state == expectValue && b[i] == '[':
state = expectListClose
s = i
d++
case state == expectListClose && d == 0 && b[i] == ']':
state = expectKey
ae = i
e = i
i = s
case state == expectValue && (b[i] >= '0' && b[i] <= '9'):
state = expectNumClose
s = i
case state == expectNumClose &&
((b[i] < '0' || b[i] > '9') &&
(b[i] != '.' && b[i] != 'e' && b[i] != 'E' && b[i] != '+' && b[i] != '-')):
i--
e = i
case state == expectValue &&
(b[i] == 'f' || b[i] == 'F' || b[i] == 't' || b[i] == 'T'):
state = expectBoolClose
s = i
case state == expectBoolClose && (b[i] == 'e' || b[i] == 'E'):
e = i
case state == expectValue && b[i] == 'n':
state = expectNull
case state == expectNull && b[i] == 'l':
e = i
}
if e != 0 {
if k != nil {
res = append(res, k)
}
state = expectKey
k = nil
e = 0
}
}
return res
}

51
jsn/stack.go Normal file
View File

@ -0,0 +1,51 @@
package jsn
type skipInfo struct {
ss, se int
}
type Stack struct {
stA [20]skipInfo
st []skipInfo
top int
}
// Create a new Stack
func NewStack() *Stack {
s := &Stack{top: -1}
s.st = s.stA[:0]
return s
}
// Return the number of items in the Stack
func (s *Stack) Len() int {
return (s.top + 1)
}
// View the top item on the Stack
func (s *Stack) Peek() *skipInfo {
if s.top == -1 {
return nil
}
return &s.st[s.top]
}
// Pop the top item of the Stack and return it
func (s *Stack) Pop() *skipInfo {
if s.top == -1 {
return nil
}
s.top--
return &s.st[(s.top + 1)]
}
// Push a value onto the top of the Stack
func (s *Stack) Push(value skipInfo) {
s.top++
if len(s.st) <= s.top {
s.st = append(s.st, value)
} else {
s.st[s.top] = value
}
}

37
jsn/tree.go Normal file
View File

@ -0,0 +1,37 @@
package jsn
import (
"bytes"
"encoding/json"
)
func Tree(v []byte) (map[string]interface{}, bool, error) {
dec := json.NewDecoder(bytes.NewReader(v))
array := false
// read open bracket
for i := range v {
if v[i] != ' ' {
array = (v[i] == '[')
break
}
}
if array {
if _, err := dec.Token(); err != nil {
return nil, false, err
}
}
// while the array contains values
var m map[string]interface{}
// decode an array value (Message)
err := dec.Decode(&m)
if err != nil {
return nil, false, err
}
return m, array, nil
}

View File

@ -1,7 +0,0 @@
goos: darwin
goarch: amd64
pkg: github.com/dosco/super-graph/psql
BenchmarkCompile-8 100000 16476 ns/op 3282 B/op 66 allocs/op
BenchmarkCompileParallel-8 300000 4639 ns/op 3324 B/op 66 allocs/op
PASS
ok github.com/dosco/super-graph/psql 3.274s

View File

@ -1,7 +0,0 @@
goos: darwin
goarch: amd64
pkg: github.com/dosco/super-graph/psql
BenchmarkCompile-8 100000 15728 ns/op 3000 B/op 60 allocs/op
BenchmarkCompileParallel-8 300000 5077 ns/op 3023 B/op 60 allocs/op
PASS
ok github.com/dosco/super-graph/psql 3.318s

View File

@ -1,7 +0,0 @@
goos: darwin
goarch: amd64
pkg: github.com/dosco/super-graph/psql
BenchmarkCompile-8 1000000 15997 ns/op 3048 B/op 58 allocs/op
BenchmarkCompileParallel-8 3000000 4722 ns/op 3073 B/op 58 allocs/op
PASS
ok github.com/dosco/super-graph/psql 35.024s

View File

@ -1,7 +0,0 @@
goos: darwin
goarch: amd64
pkg: github.com/dosco/super-graph/psql
BenchmarkCompile-8 100000 16829 ns/op 2887 B/op 57 allocs/op
BenchmarkCompileParallel-8 300000 5450 ns/op 2911 B/op 57 allocs/op
PASS
ok github.com/dosco/super-graph/psql 3.561s

90
psql/insert.go Normal file
View File

@ -0,0 +1,90 @@
package psql
import (
"bytes"
"errors"
"io"
"github.com/dosco/super-graph/jsn"
"github.com/dosco/super-graph/qcode"
)
func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) {
if len(qc.Selects) == 0 {
return 0, errors.New("empty query")
}
c := &compilerContext{w, qc.Selects, co}
root := &qc.Selects[0]
c.w.WriteString(`WITH `)
c.w.WriteString(root.Table)
c.w.WriteString(` AS (`)
if _, err := c.renderInsert(qc, w, vars); err != nil {
return 0, err
}
c.w.WriteString(`) `)
return c.compileQuery(qc, w)
}
func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) {
root := &qc.Selects[0]
insert, ok := vars["insert"]
if !ok {
return 0, errors.New("Variable 'insert' not defined")
}
jt, array, err := jsn.Tree(insert)
if err != nil {
return 0, err
}
c.w.WriteString(`WITH input AS (SELECT {{insert}}::json AS j) INSERT INTO `)
c.w.WriteString(root.Table)
io.WriteString(c.w, " (")
c.renderInsertColumns(qc, w, jt)
io.WriteString(c.w, ")")
c.w.WriteString(` SELECT `)
c.renderInsertColumns(qc, w, jt)
c.w.WriteString(` FROM input i, `)
if array {
c.w.WriteString(`json_populate_recordset`)
} else {
c.w.WriteString(`json_populate_record`)
}
c.w.WriteString(`(NULL::`)
c.w.WriteString(root.Table)
c.w.WriteString(`, i.j) t RETURNING * `)
return 0, nil
}
func (c *compilerContext) renderInsertColumns(qc *qcode.QCode, w *bytes.Buffer,
jt map[string]interface{}) (uint32, error) {
ti, err := c.schema.GetTable(qc.Selects[0].Table)
if err != nil {
return 0, err
}
i := 0
for _, cn := range ti.ColumnNames {
if _, ok := jt[cn]; !ok {
continue
}
if i != 0 {
io.WriteString(c.w, ", ")
}
c.w.WriteString(cn)
i++
}
return 0, nil
}

65
psql/insert_test.go Normal file
View File

@ -0,0 +1,65 @@
package psql
import (
"encoding/json"
"fmt"
"testing"
)
func singleInsert(t *testing.T) {
gql := `mutation {
product(id: 15, insert: $insert) {
id
name
}
}`
sql := `test`
vars := map[string]json.RawMessage{
"insert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
}
resSQL, err := compileGQLToPSQL(gql, vars)
if err != nil {
t.Fatal(err)
}
fmt.Println(">", string(resSQL))
if string(resSQL) != sql {
t.Fatal(errNotExpected)
}
}
func bulkInsert(t *testing.T) {
gql := `mutation {
product(id: 15, insert: $insert) {
id
name
}
}`
sql := `test`
vars := map[string]json.RawMessage{
"insert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
}
resSQL, err := compileGQLToPSQL(gql, vars)
if err != nil {
t.Fatal(err)
}
fmt.Println(">", string(resSQL))
if string(resSQL) != sql {
t.Fatal(errNotExpected)
}
}
func TestCompileInsert(t *testing.T) {
t.Run("singleInsert", singleInsert)
t.Run("bulkInsert", bulkInsert)
}

View File

@ -2,6 +2,7 @@ package psql
import ( import (
"bytes" "bytes"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@ -18,6 +19,8 @@ const (
closeBlock = 500 closeBlock = 500
) )
type Variables map[string]json.RawMessage
type Config struct { type Config struct {
Schema *DBSchema Schema *DBSchema
Vars map[string]string Vars map[string]string
@ -51,19 +54,30 @@ type compilerContext struct {
*Compiler *Compiler
} }
func (co *Compiler) CompileEx(qc *qcode.QCode) (uint32, []byte, error) { func (co *Compiler) CompileEx(qc *qcode.QCode, vars Variables) (uint32, []byte, error) {
w := &bytes.Buffer{} w := &bytes.Buffer{}
skipped, err := co.Compile(qc, w) skipped, err := co.Compile(qc, w, vars)
return skipped, w.Bytes(), err return skipped, w.Bytes(), err
} }
func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) { func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) {
if len(qc.Query.Selects) == 0 { switch qc.Type {
case qcode.QTQuery:
return co.compileQuery(qc, w)
case qcode.QTMutation:
return co.compileMutation(qc, w, vars)
}
return 0, errors.New("unknown operation")
}
func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) {
if len(qc.Selects) == 0 {
return 0, errors.New("empty query") return 0, errors.New("empty query")
} }
c := &compilerContext{w, qc.Query.Selects, co} c := &compilerContext{w, qc.Selects, co}
root := &qc.Query.Selects[0] root := &qc.Selects[0]
st := NewStack() st := NewStack()
st.Push(root.ID + closeBlock) st.Push(root.ID + closeBlock)
@ -844,7 +858,7 @@ func (c *compilerContext) renderList(ex *qcode.Exp) {
func (c *compilerContext) renderVal(ex *qcode.Exp, func (c *compilerContext) renderVal(ex *qcode.Exp,
vars map[string]string) { vars map[string]string) {
io.WriteString(c.w, ` (`) //io.WriteString(c.w, ` (`)
switch ex.Type { switch ex.Type {
case qcode.ValBool, qcode.ValInt, qcode.ValFloat: case qcode.ValBool, qcode.ValInt, qcode.ValFloat:
if len(ex.Val) != 0 { if len(ex.Val) != 0 {
@ -852,21 +866,23 @@ func (c *compilerContext) renderVal(ex *qcode.Exp,
} else { } else {
c.w.WriteString(`''`) c.w.WriteString(`''`)
} }
case qcode.ValStr: case qcode.ValStr:
c.w.WriteString(`'`) c.w.WriteString(`'`)
c.w.WriteString(ex.Val) c.w.WriteString(ex.Val)
c.w.WriteString(`'`) c.w.WriteString(`'`)
case qcode.ValVar: case qcode.ValVar:
if val, ok := vars[ex.Val]; ok { if val, ok := vars[ex.Val]; ok {
c.w.WriteString(val) c.w.WriteString(val)
} else { } else {
//fmt.Fprintf(w, `'{{%s}}'`, ex.Val) //fmt.Fprintf(w, `'{{%s}}'`, ex.Val)
c.w.WriteString(`'{{`) c.w.WriteString(`{{`)
c.w.WriteString(ex.Val) c.w.WriteString(ex.Val)
c.w.WriteString(`}}'`) c.w.WriteString(`}}`)
} }
} }
c.w.WriteString(`)`) //c.w.WriteString(`)`)
} }
func funcPrefixLen(fn string) int { func funcPrefixLen(fn string) int {

View File

@ -125,13 +125,13 @@ func TestMain(m *testing.M) {
os.Exit(m.Run()) os.Exit(m.Run())
} }
func compileGQLToPSQL(gql string) ([]byte, error) { func compileGQLToPSQL(gql string, vars Variables) ([]byte, error) {
qc, err := qcompile.Compile([]byte(gql)) qc, err := qcompile.Compile([]byte(gql))
if err != nil { if err != nil {
return nil, err return nil, err
} }
_, sqlStmt, err := pcompile.CompileEx(qc) _, sqlStmt, err := pcompile.CompileEx(qc, vars)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -164,7 +164,7 @@ func withComplexArgs(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."id") < (28)) AND (("products"."id") >= (20))) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "products_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."id") < (28)) AND (("products"."id") >= (20))) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -192,7 +192,7 @@ func withWhereMultiOr(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") < (20)) OR (("products"."price") > (10)) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") < (20)) OR (("products"."price") > (10)) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -218,7 +218,7 @@ func withWhereIsNull(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -244,7 +244,7 @@ func withWhereAndList(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -264,7 +264,7 @@ func fetchByID(t *testing.T) {
sql := `SELECT json_object_agg('product', product) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "product_0"."id" AS "id", "product_0"."name" AS "name") AS "sel_0")) AS "product" FROM (SELECT "product"."id", "product"."name" FROM "products" AS "product" WHERE ((("product"."price") > (0)) AND (("product"."price") < (8)) AND (("id") = (15))) LIMIT ('1') :: integer) AS "product_0" LIMIT ('1') :: integer) AS "done_1337";` sql := `SELECT json_object_agg('product', product) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "product_0"."id" AS "id", "product_0"."name" AS "name") AS "sel_0")) AS "product" FROM (SELECT "product"."id", "product"."name" FROM "products" AS "product" WHERE ((("product"."price") > (0)) AND (("product"."price") < (8)) AND (("id") = (15))) LIMIT ('1') :: integer) AS "product_0" LIMIT ('1') :: integer) AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -284,7 +284,7 @@ func searchQuery(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -307,7 +307,7 @@ func oneToMany(t *testing.T) {
sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("users"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "users" FROM (SELECT "users"."email", "users"."id" FROM "users" WHERE ((("users"."id") = ('{{user_id}}'))) LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "products" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "products_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "users_0") AS "done_1337";` sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("users"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "users" FROM (SELECT "users"."email", "users"."id" FROM "users" WHERE ((("users"."id") = ('{{user_id}}'))) LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "products" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "products_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "users_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -330,7 +330,7 @@ func belongsTo(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("users"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "users" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "users_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("users"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "users" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "users_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -353,7 +353,7 @@ func manyToMany(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("customers"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "customers" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "customers_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("customers"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "customers" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "customers_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -376,7 +376,7 @@ func manyToManyReverse(t *testing.T) {
sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("customers"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "customers" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "products" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "products_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "customers_0") AS "done_1337";` sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("customers"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "customers" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "products" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "products_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "customers_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -396,7 +396,7 @@ func aggFunction(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "products" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "products" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -416,7 +416,7 @@ func aggFunctionWithFilter(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."id") > (10))) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."id") > (10))) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -436,7 +436,7 @@ func queryWithVariables(t *testing.T) {
sql := `SELECT json_object_agg('product', product) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "product_0"."id" AS "id", "product_0"."name" AS "name") AS "sel_0")) AS "product" FROM (SELECT "product"."id", "product"."name" FROM "products" AS "product" WHERE ((("product"."price") > (0)) AND (("product"."price") < (8)) AND (("product"."price") = ('{{product_price}}')) AND (("id") = ('{{product_id}}'))) LIMIT ('1') :: integer) AS "product_0" LIMIT ('1') :: integer) AS "done_1337";` sql := `SELECT json_object_agg('product', product) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "product_0"."id" AS "id", "product_0"."name" AS "name") AS "sel_0")) AS "product" FROM (SELECT "product"."id", "product"."name" FROM "products" AS "product" WHERE ((("product"."price") > (0)) AND (("product"."price") < (8)) AND (("product"."price") = ('{{product_price}}')) AND (("id") = ('{{product_id}}'))) LIMIT ('1') :: integer) AS "product_0" LIMIT ('1') :: integer) AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -455,7 +455,7 @@ func syntheticTables(t *testing.T) {
sql := `SELECT json_object_agg('me', me) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "me_0"."email" AS "email") AS "sel_0")) AS "me" FROM (SELECT "me"."email" FROM "users" AS "me" WHERE ((("me"."id") = ('{{user_id}}'))) LIMIT ('1') :: integer) AS "me_0" LIMIT ('1') :: integer) AS "done_1337";` sql := `SELECT json_object_agg('me', me) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "me_0"."email" AS "email") AS "sel_0")) AS "me" FROM (SELECT "me"."email" FROM "users" AS "me" WHERE ((("me"."id") = ('{{user_id}}'))) LIMIT ('1') :: integer) AS "me_0" LIMIT ('1') :: integer) AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql) resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -465,7 +465,7 @@ func syntheticTables(t *testing.T) {
} }
} }
func TestCompileGQL(t *testing.T) { func TestCompileSelect(t *testing.T) {
t.Run("withComplexArgs", withComplexArgs) t.Run("withComplexArgs", withComplexArgs)
t.Run("withWhereAndList", withWhereAndList) t.Run("withWhereAndList", withWhereAndList)
t.Run("withWhereIsNull", withWhereIsNull) t.Run("withWhereIsNull", withWhereIsNull)
@ -519,7 +519,7 @@ func BenchmarkCompile(b *testing.B) {
b.Fatal(err) b.Fatal(err)
} }
_, err = pcompile.Compile(qc, w) _, err = pcompile.Compile(qc, w, nil)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -540,7 +540,7 @@ func BenchmarkCompileParallel(b *testing.B) {
b.Fatal(err) b.Fatal(err)
} }
_, err = pcompile.Compile(qc, w) _, err = pcompile.Compile(qc, w, nil)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }

View File

@ -111,6 +111,7 @@ type DBTableInfo struct {
PrimaryCol string PrimaryCol string
TSVCol string TSVCol string
Columns map[string]*DBColumn Columns map[string]*DBColumn
ColumnNames []string
} }
type RelType int type RelType int
@ -162,10 +163,13 @@ func (s *DBSchema) updateSchema(
// Foreign key columns in current table // Foreign key columns in current table
colByID := make(map[int]*DBColumn) colByID := make(map[int]*DBColumn)
columns := make(map[string]*DBColumn, len(cols)) columns := make(map[string]*DBColumn, len(cols))
colNames := make([]string, len(cols))
for i := range cols { for i := range cols {
c := cols[i] c := cols[i]
columns[strings.ToLower(c.Name)] = cols[i] name := strings.ToLower(c.Name)
columns[name] = cols[i]
colNames = append(colNames, name)
colByID[c.ID] = cols[i] colByID[c.ID] = cols[i]
} }
@ -174,6 +178,7 @@ func (s *DBSchema) updateSchema(
Name: t.Name, Name: t.Name,
Singular: true, Singular: true,
Columns: columns, Columns: columns,
ColumnNames: colNames,
} }
plural := strings.ToLower(flect.Pluralize(t.Name)) plural := strings.ToLower(flect.Pluralize(t.Name))
@ -181,6 +186,7 @@ func (s *DBSchema) updateSchema(
Name: t.Name, Name: t.Name,
Singular: false, Singular: false,
Columns: columns, Columns: columns,
ColumnNames: colNames,
} }
ct := strings.ToLower(t.Name) ct := strings.ToLower(t.Name)

View File

@ -100,19 +100,7 @@ var lexPool = sync.Pool{
} }
func Parse(gql []byte) (*Operation, error) { func Parse(gql []byte) (*Operation, error) {
return parseSelectionSet(nil, gql) return parseSelectionSet(gql)
}
func ParseQuery(gql []byte) (*Operation, error) {
op := opPool.Get().(*Operation)
op.Reset()
op.Type = opQuery
op.Name = ""
op.Fields = op.fieldsA[:0]
op.Args = op.argsA[:0]
return parseSelectionSet(op, gql)
} }
func ParseArgValue(argVal string) (*Node, error) { func ParseArgValue(argVal string) (*Node, error) {
@ -134,7 +122,7 @@ func ParseArgValue(argVal string) (*Node, error) {
return op, err return op, err
} }
func parseSelectionSet(op *Operation, gql []byte) (*Operation, error) { func parseSelectionSet(gql []byte) (*Operation, error) {
var err error var err error
if len(gql) == 0 { if len(gql) == 0 {
@ -154,14 +142,28 @@ func parseSelectionSet(op *Operation, gql []byte) (*Operation, error) {
items: l.items, items: l.items,
} }
if op == nil { var op *Operation
op, err = p.parseOp()
} else {
if p.peek(itemObjOpen) { if p.peek(itemObjOpen) {
p.ignore() p.ignore()
} }
if p.peek(itemName) {
op = opPool.Get().(*Operation)
op.Reset()
op.Type = opQuery
op.Name = ""
op.Fields = op.fieldsA[:0]
op.Args = op.argsA[:0]
op.Fields, err = p.parseFields(op.Fields) op.Fields, err = p.parseFields(op.Fields)
} else {
op, err = p.parseOp()
if err != nil {
return nil, err
}
} }
lexPool.Put(l) lexPool.Put(l)

View File

@ -45,10 +45,10 @@ func compareOp(op1, op2 Operation) error {
} }
*/ */
func TestCompile(t *testing.T) { func TestCompile1(t *testing.T) {
qcompile, _ := NewCompiler(Config{}) qcompile, _ := NewCompiler(Config{})
_, err := qcompile.CompileQuery([]byte(` _, err := qcompile.Compile([]byte(`
product(id: 15) { product(id: 15) {
id id
name name
@ -59,9 +59,39 @@ func TestCompile(t *testing.T) {
} }
} }
func TestCompile2(t *testing.T) {
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile([]byte(`
query { product(id: 15) {
id
name
} }`))
if err != nil {
t.Fatal(err)
}
}
func TestCompile3(t *testing.T) {
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile([]byte(`
mutation {
product(id: 15, name: "Test") {
id
name
}
}`))
if err != nil {
t.Fatal(err)
}
}
func TestInvalidCompile1(t *testing.T) { func TestInvalidCompile1(t *testing.T) {
qcompile, _ := NewCompiler(Config{}) qcompile, _ := NewCompiler(Config{})
_, err := qcompile.CompileQuery([]byte(`#`)) _, err := qcompile.Compile([]byte(`#`))
if err == nil { if err == nil {
t.Fatal(errors.New("expecting an error")) t.Fatal(errors.New("expecting an error"))
@ -70,7 +100,7 @@ func TestInvalidCompile1(t *testing.T) {
func TestInvalidCompile2(t *testing.T) { func TestInvalidCompile2(t *testing.T) {
qcompile, _ := NewCompiler(Config{}) qcompile, _ := NewCompiler(Config{})
_, err := qcompile.CompileQuery([]byte(`{u(where:{not:0})}`)) _, err := qcompile.Compile([]byte(`{u(where:{not:0})}`))
if err == nil { if err == nil {
t.Fatal(errors.New("expecting an error")) t.Fatal(errors.New("expecting an error"))
@ -79,7 +109,7 @@ func TestInvalidCompile2(t *testing.T) {
func TestEmptyCompile(t *testing.T) { func TestEmptyCompile(t *testing.T) {
qcompile, _ := NewCompiler(Config{}) qcompile, _ := NewCompiler(Config{})
_, err := qcompile.CompileQuery([]byte(``)) _, err := qcompile.Compile([]byte(``))
if err == nil { if err == nil {
t.Fatal(errors.New("expecting an error")) t.Fatal(errors.New("expecting an error"))
@ -114,7 +144,7 @@ func BenchmarkQCompile(b *testing.B) {
b.ReportAllocs() b.ReportAllocs()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
_, err := qcompile.CompileQuery(gql) _, err := qcompile.Compile(gql)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
@ -130,7 +160,7 @@ func BenchmarkQCompileP(b *testing.B) {
b.RunParallel(func(pb *testing.PB) { b.RunParallel(func(pb *testing.PB) {
for pb.Next() { for pb.Next() {
_, err := qcompile.CompileQuery(gql) _, err := qcompile.Compile(gql)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)

View File

@ -10,15 +10,17 @@ import (
"github.com/gobuffalo/flect" "github.com/gobuffalo/flect"
) )
type QType int
const ( const (
maxSelectors = 30 maxSelectors = 30
QTQuery QType = iota + 1
QTMutation
) )
type QCode struct { type QCode struct {
Query *Query Type QType
}
type Query struct {
Selects []Select Selects []Select
} }
@ -149,6 +151,11 @@ type Compiler struct {
ka bool ka bool
} }
var opMap = map[parserType]QType{
opQuery: QTQuery,
opMutate: QTMutation,
}
var expPool = sync.Pool{ var expPool = sync.Pool{
New: func() interface{} { return new(Exp) }, New: func() interface{} { return new(Exp) },
} }
@ -196,44 +203,23 @@ func (com *Compiler) Compile(query []byte) (*QCode, error) {
return nil, err return nil, err
} }
switch op.Type { qc.Selects, err = com.compileQuery(op)
case opQuery:
qc.Query, err = com.compileQuery(op)
case opMutate:
case opSub:
default:
err = fmt.Errorf("Unknown operation type %d", op.Type)
}
if err != nil { if err != nil {
return nil, err return nil, err
} }
if t, ok := opMap[op.Type]; ok {
qc.Type = t
} else {
return nil, fmt.Errorf("Unknown operation type %d", op.Type)
}
opPool.Put(op) opPool.Put(op)
return &qc, nil return &qc, nil
} }
func (com *Compiler) CompileQuery(query []byte) (*QCode, error) { func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
var err error
op, err := ParseQuery(query)
if err != nil {
return nil, err
}
qc := &QCode{}
qc.Query, err = com.compileQuery(op)
opPool.Put(op)
if err != nil {
return nil, err
}
return qc, nil
}
func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
id := int32(0) id := int32(0)
parentID := int32(0) parentID := int32(0)
@ -344,7 +330,7 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
return nil, errors.New("invalid query") return nil, errors.New("invalid query")
} }
return &Query{selects[:id]}, nil return selects[:id], nil
} }
func (com *Compiler) compileArgs(sel *Select, args []Arg) error { func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
@ -661,14 +647,6 @@ func (com *Compiler) compileArgOffset(sel *Select, arg *Arg) error {
return nil return nil
} }
func compileMutate() (*Query, error) {
return nil, nil
}
func compileSub() (*Query, error) {
return nil, nil
}
func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) { func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {
name := node.Name name := node.Name
if name[0] == '_' { if name[0] == '_' {

View File

@ -1,6 +1,7 @@
package serv package serv
import ( import (
"encoding/json"
"fmt" "fmt"
"io/ioutil" "io/ioutil"
"log" "log"
@ -9,9 +10,15 @@ import (
"strings" "strings"
) )
const (
AL_QUERY int = iota + 1
AL_VARS
)
type allowItem struct { type allowItem struct {
uri string uri string
gql string gql string
vars json.RawMessage
} }
var _allowList allowList var _allowList allowList
@ -79,6 +86,7 @@ func (al *allowList) add(req *gqlReq) {
al.saveChan <- &allowItem{ al.saveChan <- &allowItem{
uri: req.ref, uri: req.ref,
gql: req.Query, gql: req.Query,
vars: req.Vars,
} }
} }
@ -93,32 +101,62 @@ func (al *allowList) load() {
} }
var uri string var uri string
var varBytes []byte
s, e, c := 0, 0, 0 s, e, c := 0, 0, 0
ty := 0
for { for {
if c == 0 && b[e] == '#' { if c == 0 && b[e] == '#' {
s = e s = e
for b[e] != '\n' && e < len(b) { for e < len(b) && b[e] != '\n' {
e++ e++
} }
if (e - s) > 2 { if (e - s) > 2 {
uri = strings.TrimSpace(string(b[(s + 1):e])) uri = strings.TrimSpace(string(b[(s + 1):e]))
} }
} }
if b[e] == '{' {
if e >= len(b) {
break
}
if matchPrefix(b, e, "query") || matchPrefix(b, e, "mutation") {
if c == 0 { if c == 0 {
s = e s = e
} }
ty = AL_QUERY
} else if matchPrefix(b, e, "variables") {
if c == 0 {
s = e + len("variables") + 1
}
ty = AL_VARS
} else if b[e] == '{' {
c++ c++
} else if b[e] == '}' { } else if b[e] == '}' {
c-- c--
if c == 0 { if c == 0 {
q := b[s:(e + 1)] if ty == AL_QUERY {
al.list[gqlHash(q)] = &allowItem{ q := string(b[s:(e + 1)])
item := &allowItem{
uri: uri, uri: uri,
gql: string(q), gql: q,
} }
if len(varBytes) != 0 {
item.vars = varBytes
}
al.list[gqlHash(q, varBytes)] = item
varBytes = nil
} else if ty == AL_VARS {
varBytes = b[s:(e + 1)]
}
ty = 0
} }
} }
@ -130,7 +168,7 @@ func (al *allowList) load() {
} }
func (al *allowList) save(item *allowItem) { func (al *allowList) save(item *allowItem) {
al.list[gqlHash([]byte(item.gql))] = item al.list[gqlHash(item.gql, item.vars)] = item
f, err := os.Create(al.filepath) f, err := os.Create(al.filepath)
if err != nil { if err != nil {
@ -141,10 +179,10 @@ func (al *allowList) save(item *allowItem) {
defer f.Close() defer f.Close()
keys := []string{} keys := []string{}
urlMap := make(map[string][]string) urlMap := make(map[string][]*allowItem)
for _, v := range al.list { for _, v := range al.list {
urlMap[v.uri] = append(urlMap[v.uri], v.gql) urlMap[v.uri] = append(urlMap[v.uri], v)
} }
for k := range urlMap { for k := range urlMap {
@ -159,7 +197,28 @@ func (al *allowList) save(item *allowItem) {
f.WriteString(fmt.Sprintf("# %s\n\n", k)) f.WriteString(fmt.Sprintf("# %s\n\n", k))
for i := range v { for i := range v {
f.WriteString(fmt.Sprintf("query %s\n\n", v[i])) if len(v[i].vars) != 0 {
vj, err := json.MarshalIndent(v[i].vars, "", "\t")
if err != nil {
logger.Warn().Err(err).Msg("Failed to write allow list 'vars' to file")
continue
}
f.WriteString(fmt.Sprintf("variables %s\n\n", vj))
}
f.WriteString(fmt.Sprintf("%s\n\n", v[i].gql))
} }
} }
} }
func matchPrefix(b []byte, i int, s string) bool {
if (len(b) - i) < len(s) {
return false
}
for n := 0; n < len(s); n++ {
if b[(i+n)] != s[n] {
return false
}
}
return true
}

View File

@ -14,6 +14,7 @@ import (
"github.com/cespare/xxhash/v2" "github.com/cespare/xxhash/v2"
"github.com/dosco/super-graph/jsn" "github.com/dosco/super-graph/jsn"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode" "github.com/dosco/super-graph/qcode"
"github.com/go-pg/pg" "github.com/go-pg/pg"
"github.com/valyala/fasttemplate" "github.com/valyala/fasttemplate"
@ -42,7 +43,7 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
if conf.UseAllowList { if conf.UseAllowList {
var ps *preparedItem var ps *preparedItem
data, ps, err = c.resolvePreparedSQL([]byte(c.req.Query)) data, ps, err = c.resolvePreparedSQL(c.req.Query)
if err != nil { if err != nil {
return err return err
} }
@ -52,7 +53,7 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
} else { } else {
qc, err = qcompile.CompileQuery([]byte(c.req.Query)) qc, err = qcompile.Compile([]byte(c.req.Query))
if err != nil { if err != nil {
return err return err
} }
@ -67,7 +68,7 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
return c.render(w, data) return c.render(w, data)
} }
sel := qc.Query.Selects sel := qc.Selects
h := xxhash.New() h := xxhash.New()
// fetch the field name used within the db response json // fetch the field name used within the db response json
@ -252,8 +253,8 @@ func (c *coreContext) resolveRemotes(
return to, cerr return to, cerr
} }
func (c *coreContext) resolvePreparedSQL(gql []byte) ([]byte, *preparedItem, error) { func (c *coreContext) resolvePreparedSQL(gql string) ([]byte, *preparedItem, error) {
ps, ok := _preparedList[gqlHash(gql)] ps, ok := _preparedList[gqlHash(gql, c.req.Vars)]
if !ok { if !ok {
return nil, nil, errUnauthorized return nil, nil, errUnauthorized
} }
@ -266,17 +267,22 @@ func (c *coreContext) resolvePreparedSQL(gql []byte) ([]byte, *preparedItem, err
return nil, nil, err return nil, nil, err
} }
fmt.Printf("PRE: %#v %#v\n", ps.stmt, vars) fmt.Printf("PRE: %v\n", ps.stmt)
return []byte(root), ps, nil return []byte(root), ps, nil
} }
func (c *coreContext) resolveSQL(qc *qcode.QCode) ( func (c *coreContext) resolveSQL(qc *qcode.QCode) (
[]byte, uint32, error) { []byte, uint32, error) {
stmt := &bytes.Buffer{} stmt := &bytes.Buffer{}
skipped, err := pcompile.Compile(qc, stmt) vars := make(map[string]json.RawMessage)
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
return nil, 0, err
}
skipped, err := pcompile.Compile(qc, stmt, psql.Variables(vars))
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
@ -284,7 +290,7 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode) (
t := fasttemplate.New(stmt.String(), openVar, closeVar) t := fasttemplate.New(stmt.String(), openVar, closeVar)
stmt.Reset() stmt.Reset()
_, err = t.Execute(stmt, varMap(c)) _, err = t.ExecuteFunc(stmt, varMap(c))
if err == errNoUserID && if err == errNoUserID &&
authFailBlock == authFailBlockPerQuery && authFailBlock == authFailBlockPerQuery &&
@ -317,10 +323,10 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode) (
return nil, 0, err return nil, 0, err
} }
if conf.EnableTracing && len(qc.Query.Selects) != 0 { if conf.EnableTracing && len(qc.Selects) != 0 {
c.addTrace( c.addTrace(
qc.Query.Selects, qc.Selects,
qc.Query.Selects[0].ID, qc.Selects[0].ID,
st) st)
} }

35
serv/core_test.go Normal file
View File

@ -0,0 +1,35 @@
package serv
/*
func simpleMutation(t *testing.T) {
gql := `mutation {
product(id: 15, insert: { name: "Test", price: 20.5 }) {
id
name
}
}`
sql := `test`
backgroundCtx := context.Background()
ctx := &coreContext{Context: backgroundCtx}
resSQL, err := compileGQLToPSQL(gql)
if err != nil {
t.Fatal(err)
}
fmt.Println(">", string(resSQL))
if string(resSQL) != sql {
t.Fatal(errNotExpected)
}
}
func TestCompileGQL(t *testing.T) {
t.Run("withComplexArgs", withComplexArgs)
t.Run("simpleMutation", simpleMutation)
}
*/

View File

@ -28,11 +28,11 @@ var (
type gqlReq struct { type gqlReq struct {
OpName string `json:"operationName"` OpName string `json:"operationName"`
Query string `json:"query"` Query string `json:"query"`
Vars variables `json:"variables"` Vars json.RawMessage `json:"variables"`
ref string ref string
} }
type variables map[string]interface{} type variables map[string]json.RawMessage
type gqlResp struct { type gqlResp struct {
Error string `json:"error,omitempty"` Error string `json:"error,omitempty"`

View File

@ -2,9 +2,11 @@ package serv
import ( import (
"bytes" "bytes"
"encoding/json"
"fmt" "fmt"
"io" "io"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode" "github.com/dosco/super-graph/qcode"
"github.com/go-pg/pg" "github.com/go-pg/pg"
"github.com/valyala/fasttemplate" "github.com/valyala/fasttemplate"
@ -12,7 +14,7 @@ import (
type preparedItem struct { type preparedItem struct {
stmt *pg.Stmt stmt *pg.Stmt
args []string args [][]byte
skipped uint32 skipped uint32
qc *qcode.QCode qc *qcode.QCode
} }
@ -25,36 +27,46 @@ func initPreparedList() {
_preparedList = make(map[string]*preparedItem) _preparedList = make(map[string]*preparedItem)
for k, v := range _allowList.list { for k, v := range _allowList.list {
err := prepareStmt(k, v.gql) err := prepareStmt(k, v.gql, v.vars)
if err != nil { if err != nil {
panic(err) panic(err)
} }
} }
} }
func prepareStmt(key, gql string) error { func prepareStmt(key, gql string, varBytes json.RawMessage) error {
if len(gql) == 0 || len(key) == 0 { if len(gql) == 0 || len(key) == 0 {
return nil return nil
} }
qc, err := qcompile.CompileQuery([]byte(gql)) qc, err := qcompile.Compile([]byte(gql))
if err != nil { if err != nil {
return err return err
} }
var vars map[string]json.RawMessage
if len(varBytes) != 0 {
vars = make(map[string]json.RawMessage)
if err := json.Unmarshal(varBytes, &vars); err != nil {
return err
}
}
buf := &bytes.Buffer{} buf := &bytes.Buffer{}
skipped, err := pcompile.Compile(qc, buf) skipped, err := pcompile.Compile(qc, buf, psql.Variables(vars))
if err != nil { if err != nil {
return err return err
} }
t := fasttemplate.New(buf.String(), `('{{`, `}}')`) t := fasttemplate.New(buf.String(), `{{`, `}}`)
am := make([]string, 0, 5) am := make([][]byte, 0, 5)
i := 0 i := 0
finalSQL := t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) { finalSQL := t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) {
am = append(am, tag) am = append(am, []byte(tag))
i++ i++
return w.Write([]byte(fmt.Sprintf("$%d", i))) return w.Write([]byte(fmt.Sprintf("$%d", i)))
}) })

View File

@ -4,8 +4,12 @@ import (
"bytes" "bytes"
"crypto/sha1" "crypto/sha1"
"encoding/hex" "encoding/hex"
"io"
"sort"
"strings"
"github.com/cespare/xxhash/v2" "github.com/cespare/xxhash/v2"
"github.com/dosco/super-graph/jsn"
) )
func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 { func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
@ -17,8 +21,8 @@ func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
return v return v
} }
func gqlHash(b []byte) string { func gqlHash(b string, vars []byte) string {
b = bytes.TrimSpace(b) b = strings.TrimSpace(b)
h := sha1.New() h := sha1.New()
s, e := 0, 0 s, e := 0, 0
@ -45,13 +49,27 @@ func gqlHash(b []byte) string {
if e != 0 { if e != 0 {
b0 = b[(e - 1)] b0 = b[(e - 1)]
} }
h.Write(bytes.ToLower(b[s:e])) io.WriteString(h, strings.ToLower(b[s:e]))
} }
if e >= len(b) { if e >= len(b) {
break break
} }
} }
if vars == nil {
return hex.EncodeToString(h.Sum(nil))
}
fields := jsn.Keys([]byte(vars))
sort.Slice(fields, func(i, j int) bool {
return bytes.Compare(fields[i], fields[j]) == -1
})
for i := range fields {
h.Write(fields[i])
}
return hex.EncodeToString(h.Sum(nil)) return hex.EncodeToString(h.Sum(nil))
} }

View File

@ -6,7 +6,7 @@ import (
) )
func TestRelaxHash1(t *testing.T) { func TestRelaxHash1(t *testing.T) {
var v1 = []byte(` var v1 = `
products( products(
limit: 30, limit: 30,
@ -14,18 +14,18 @@ func TestRelaxHash1(t *testing.T) {
id id
name name
price price
}`) }`
var v2 = []byte(` var v2 = `
products products
(limit: 30, where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) { (limit: 30, where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
id id
name name
price price
} `) } `
h1 := gqlHash(v1) h1 := gqlHash(v1, nil)
h2 := gqlHash(v2) h2 := gqlHash(v2, nil)
if strings.Compare(h1, h2) != 0 { if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should") t.Fatal("Hashes don't match they should")
@ -33,7 +33,7 @@ func TestRelaxHash1(t *testing.T) {
} }
func TestRelaxHash2(t *testing.T) { func TestRelaxHash2(t *testing.T) {
var v1 = []byte(` var v1 = `
{ {
products( products(
limit: 30 limit: 30
@ -49,12 +49,119 @@ func TestRelaxHash2(t *testing.T) {
email email
} }
} }
}`) }`
var v2 = []byte(` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } `) var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } `
h1 := gqlHash(v1) h1 := gqlHash(v1, nil)
h2 := gqlHash(v2) h2 := gqlHash(v2, nil)
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")
}
}
func TestRelaxHashWithVars1(t *testing.T) {
var q1 = `
products(
limit: 30,
where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
id
name
price
}`
var v1 = `
{
"insert": {
"name": "Hello",
"description": "World",
"created_at": "now",
"updated_at": "now",
"test": { "type2": "b", "type1": "a" }
},
"user": 123
}`
var q2 = `
products
(limit: 30, where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
id
name
price
} `
var v2 = `{
"insert": {
"created_at": "now",
"test": { "type1": "a", "type2": "b" },
"name": "Hello",
"updated_at": "now",
"description": "World"
},
"user": 123
}`
h1 := gqlHash(q1, []byte(va1))
h2 := gqlHash(q2, []byte(va2))
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")
}
}
func TestRelaxHashWithVars2(t *testing.T) {
var q1 = `
products(
limit: 30,
where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
id
name
price
}`
var v1 = `
{
"insert": [{
"name": "Hello",
"description": "World",
"created_at": "now",
"updated_at": "now",
"test": { "type2": "b", "type1": "a" }
},
{
"name": "Hello",
"description": "World",
"created_at": "now",
"updated_at": "now",
"test": { "type2": "b", "type1": "a" }
}],
"user": 123
}`
var q2 = `
products
(limit: 30, where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
id
name
price
} `
var v2 = `{
"insert": {
"created_at": "now",
"test": { "type1": "a", "type2": "b" },
"name": "Hello",
"updated_at": "now",
"description": "World"
},
"user": 123
}`
h1 := gqlHash(q1, []byte(va1))
h2 := gqlHash(q2, []byte(va2))
if strings.Compare(h1, h2) != 0 { if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should") t.Fatal("Hashes don't match they should")

View File

@ -1,95 +1,107 @@
package serv package serv
import ( import (
"bytes"
"fmt"
"io" "io"
"strconv"
"strings"
"github.com/valyala/fasttemplate" "github.com/dosco/super-graph/jsn"
) )
func varMap(ctx *coreContext) variables { func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
userIDFn := func(w io.Writer, _ string) (int, error) { return func(w io.Writer, tag string) (int, error) {
switch tag {
case "user_id":
if v := ctx.Value(userIDKey); v != nil { if v := ctx.Value(userIDKey); v != nil {
return w.Write([]byte(v.(string))) return stringVar(w, v.(string))
} }
return 0, errNoUserID return 0, errNoUserID
}
userIDProviderFn := func(w io.Writer, _ string) (int, error) { case "user_id_provider":
if v := ctx.Value(userIDProviderKey); v != nil { if v := ctx.Value(userIDProviderKey); v != nil {
return w.Write([]byte(v.(string))) return stringVar(w, v.(string))
} }
return 0, errNoUserID return 0, errNoUserID
} }
userIDTag := fasttemplate.TagFunc(userIDFn) fields := jsn.Get(ctx.req.Vars, [][]byte{[]byte(tag)})
userIDProviderTag := fasttemplate.TagFunc(userIDProviderFn) if len(fields) == 0 {
return 0, fmt.Errorf("variable '%s' not found", tag)
vm := variables{
"user_id": userIDTag,
"user_id_provider": userIDProviderTag,
"USER_ID": userIDTag,
"USER_ID_PROVIDER": userIDProviderTag,
} }
for k, v := range ctx.req.Vars { is := false
var buf []byte
k = strings.ToLower(k)
if _, ok := vm[k]; ok { for i := range fields[0].Value {
continue c := fields[0].Value[i]
if c != ' ' {
is = (c == '"') || (c == '{') || (c == '[')
break
}
} }
switch val := v.(type) { if is {
case string: return stringVarB(w, fields[0].Value)
vm[k] = val
case int:
vm[k] = strconv.AppendInt(buf, int64(val), 10)
case int64:
vm[k] = strconv.AppendInt(buf, val, 10)
case float64:
vm[k] = strconv.AppendFloat(buf, val, 'f', -1, 64)
} }
w.Write(fields[0].Value)
return 0, nil
} }
return vm
} }
func varList(ctx *coreContext, args []string) []interface{} { func varList(ctx *coreContext, args [][]byte) []interface{} {
vars := make([]interface{}, 0, len(args)) vars := make([]interface{}, len(args))
for k, v := range ctx.req.Vars { var fields map[string]interface{}
ctx.req.Vars[strings.ToLower(k)] = v var err error
if len(ctx.req.Vars) != 0 {
fields, _, err = jsn.Tree(ctx.req.Vars)
if err != nil {
logger.Warn().Err(err).Msg("Failed to parse variables")
}
} }
for i := range args { for i := range args {
arg := strings.ToLower(args[i]) av := args[i]
if arg == "user_id" { switch {
case bytes.Equal(av, []byte("user_id")):
if v := ctx.Value(userIDKey); v != nil { if v := ctx.Value(userIDKey); v != nil {
vars = append(vars, v.(string)) vars[i] = v.(string)
}
} }
if arg == "user_id_provider" { case bytes.Equal(av, []byte("user_id_provider")):
if v := ctx.Value(userIDProviderKey); v != nil { if v := ctx.Value(userIDProviderKey); v != nil {
vars = append(vars, v.(string)) vars[i] = v.(string)
}
} }
if v, ok := ctx.req.Vars[arg]; ok { default:
switch val := v.(type) { if v, ok := fields[string(av)]; ok {
case string: vars[i] = v
vars = append(vars, val)
case int:
vars = append(vars, strconv.FormatInt(int64(val), 10))
case int64:
vars = append(vars, strconv.FormatInt(int64(val), 10))
case float64:
vars = append(vars, strconv.FormatFloat(val, 'f', -1, 64))
} }
} }
} }
return vars return vars
} }
func stringVar(w io.Writer, v string) (int, error) {
if n, err := w.Write([]byte(`'`)); err != nil {
return n, err
}
if n, err := w.Write([]byte(v)); err != nil {
return n, err
}
return w.Write([]byte(`'`))
}
func stringVarB(w io.Writer, v []byte) (int, error) {
if n, err := w.Write([]byte(`'`)); err != nil {
return n, err
}
if n, err := w.Write(v); err != nil {
return n, err
}
return w.Write([]byte(`'`))
}