Add insert mutation with bulk insert

This commit is contained in:
Vikram Rangnekar 2019-09-05 00:09:56 -04:00
parent 5b9105ff0c
commit c0a21e448f
30 changed files with 1080 additions and 265 deletions

View File

@ -47,4 +47,79 @@ query {
email
}
}
}
variables {
"insert": {
"name": "Hello",
"description": "World",
"created_at": "now",
"updated_at": "now"
},
"user": 123
}
mutation {
products(insert: $insert) {
id
name
description
}
}
variables {
"insert": {
"name": "Hello",
"description": "World",
"created_at": "now",
"updated_at": "now"
},
"user": 123
}
mutation {
products(insert: $insert) {
id
}
}
variables {
"insert": {
"description": "World3",
"name": "Hello3",
"created_at": "now",
"updated_at": "now"
},
"user": 123
}
{
customers {
id
email
payments {
customer_id
amount
billing_details
}
}
}
variables {
"insert": {
"description": "World3",
"name": "Hello3",
"created_at": "now",
"updated_at": "now"
},
"user": 123
}
{
me {
id
full_name
}
}

View File

@ -84,8 +84,9 @@ database:
#log_level: "debug"
# Define variables here that you want to use in filters
# sub-queries must be wrapped in ()
variables:
account_id: "select account_id from users where id = $user_id"
account_id: "(select account_id from users where id = $user_id)"
# Define defaults to for the field key and values below
defaults:
@ -105,12 +106,12 @@ database:
# This filter will overwrite defaults.filter
# filter: ["{ id: { eq: $user_id } }"]
- name: products
# Multiple filters are AND'd together
filter: [
"{ price: { gt: 0 } }",
"{ price: { lt: 8 } }"
]
# - name: products
# # Multiple filters are AND'd together
# filter: [
# "{ price: { gt: 0 } }",
# "{ price: { lt: 8 } }"
# ]
- name: customers
# No filter is used for this field not

View File

@ -82,8 +82,9 @@ database:
#log_level: "debug"
# Define variables here that you want to use in filters
# sub-queries must be wrapped in ()
variables:
account_id: "select account_id from users where id = $user_id"
account_id: "(select account_id from users where id = $user_id)"
# Define defaults to for the field key and values below
defaults:

View File

@ -2,6 +2,8 @@ version: '3.4'
services:
db:
image: postgres
ports:
- "5432:5432"
# redis:
# image: redis:alpine

39
jsn/README.md Normal file
View File

@ -0,0 +1,39 @@
# JSN - Fast low allocation JSON library
## Design
This libary is designed as a set of seperate functions to extract data and mutate
JSON. All functions are focused on keeping allocations to a minimum and be as fast
as possible. The functions don't validate the JSON a seperate `Validate` function
does that.
The JSON parsing algo processes each object `{}` or array `[]` in a bottom up way
where once the end of the array or object is found only then the keys within it are
parsed from the top down.
```
{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}], "full_name":"FN1","email":"E1" }
id: 1
posts: [{"title":"PT1-1","description":"PD1-1"}]
[{"title":"PT1-1","description":"PD1-1"}]
{"title":"PT1-1","description":"PD1-1"}
title: "PT1-1"
description: "PD1-1
full_name: "FN1"
email: "E1"
```
## Functions
- Strip: Strip a path from the root to a child node and return the rest
- Replace: Replace values by key
- Get: Get all keys
- Filter: Extract specific keys from an object
- Tree: Fetch unique keys from an array or object

View File

@ -43,7 +43,7 @@ func Get(b []byte, keys [][]byte) []Field {
kmap[xxhash.Sum64(keys[i])] = struct{}{}
}
res := make([]Field, 20)
res := make([]Field, 0, 20)
s, e, d := 0, 0, 0
@ -127,7 +127,7 @@ func Get(b []byte, keys [][]byte) []Field {
_, ok := kmap[xxhash.Sum64(k)]
if ok {
res[n] = Field{k, b[s:(e + 1)]}
res = append(res, Field{k, b[s:(e + 1)]})
n++
}

View File

@ -21,6 +21,10 @@ var (
"full_name": "Caroll Orn Sr.",
"email": "joannarau@hegmann.io",
"__twitter_id": "ABC123"
"more": [{
"__twitter_id": "more123",
"hello: "world
}]
}
},
{
@ -163,6 +167,7 @@ func TestGet(t *testing.T) {
{[]byte("__twitter_id"), []byte(`"ABCD"`)},
{[]byte("__twitter_id"), []byte(`"2048666903444506956"`)},
{[]byte("__twitter_id"), []byte(`"ABC123"`)},
{[]byte("__twitter_id"), []byte(`"more123"`)},
{[]byte("__twitter_id"),
[]byte(`[{ "name": "hello" }, { "name": "world"}]`)},
{[]byte("__twitter_id"),
@ -340,6 +345,80 @@ func TestReplaceEmpty(t *testing.T) {
}
}
func TestKeys1(t *testing.T) {
json := `[{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}, {"title":"PT1-2","description":"PD1-2"}], "full_name":"FN1","email":"E1","books": [{"name":"BN1-1","description":"BD1-1"},{"name":"BN1-2","description":"BD1-2"},{"name":"BN1-2","description":"BD1-2"}]},{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}, {"title":"PT1-2","description":"PD1-2"}], "full_name":"FN1","email":"E1","books": [{"name":"BN1-1","description":"BD1-1"},{"name":"BN1-2","description":"BD1-2"},{"name":"BN1-2","description":"BD1-2"}]},{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}, {"title":"PT1-2","description":"PD1-2"}], "full_name":"FN1","email":"E1","books": [{"name":"BN1-1","description":"BD1-1"},{"name":"BN1-2","description":"BD1-2"},{"name":"BN1-2","description":"BD1-2"}]}]`
fields := Keys([]byte(json))
exp := []string{
"id", "posts", "title", "description", "full_name", "email", "books", "name", "description",
}
if len(exp) != len(fields) {
t.Errorf("Expected %d fields %d", len(exp), len(fields))
}
for i := range exp {
if string(fields[i]) != exp[i] {
t.Errorf("Expected field '%s' got '%s'", string(exp[i]), fields[i])
}
}
}
func TestKeys2(t *testing.T) {
json := `{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}, {"title":"PT1-2","description":"PD1-2"}], "full_name":"FN1","email":"E1","books": [{"name":"BN1-1","description":"BD1-1"},{"name":"BN1-2","description":"BD1-2"},{"name":"BN1-2","description":"BD1-2"}]}`
fields := Keys([]byte(json))
exp := []string{
"id", "posts", "title", "description", "full_name", "email", "books", "name", "description",
}
// for i := range fields {
// fmt.Println("-->", string(fields[i]))
// }
if len(exp) != len(fields) {
t.Errorf("Expected %d fields %d", len(exp), len(fields))
}
for i := range exp {
if string(fields[i]) != exp[i] {
t.Errorf("Expected field '%s' got '%s'", string(exp[i]), fields[i])
}
}
}
func TestKeys3(t *testing.T) {
json := `{
"insert": {
"created_at": "now",
"test": { "type1": "a", "type2": "b" },
"name": "Hello",
"updated_at": "now",
"description": "World"
},
"user": 123
}`
fields := Keys([]byte(json))
exp := []string{
"insert", "created_at", "test", "type1", "type2", "name", "updated_at", "description",
"user",
}
if len(exp) != len(fields) {
t.Errorf("Expected %d fields %d", len(exp), len(fields))
}
for i := range exp {
if string(fields[i]) != exp[i] {
t.Errorf("Expected field '%s' got '%s'", string(exp[i]), fields[i])
}
}
}
func BenchmarkGet(b *testing.B) {
b.ReportAllocs()

122
jsn/keys.go Normal file
View File

@ -0,0 +1,122 @@
package jsn
func Keys(b []byte) [][]byte {
res := make([][]byte, 0, 20)
s, e, d := 0, 0, 0
var k []byte
state := expectValue
st := NewStack()
ae := 0
for i := 0; i < len(b); i++ {
if state == expectObjClose || state == expectListClose {
switch b[i] {
case '{', '[':
d++
case '}', ']':
d--
}
}
si := st.Peek()
switch {
case state == expectKey && si != nil && i >= si.ss:
i = si.se + 1
st.Pop()
case state == expectKey && b[i] == '{':
state = expectObjClose
s = i
d++
case state == expectObjClose && d == 0 && b[i] == '}':
state = expectKey
if ae != 0 {
st.Push(skipInfo{i, ae})
ae = 0
}
e = i
i = s
case state == expectKey && b[i] == '"':
state = expectKeyClose
s = i
case state == expectKeyClose && b[i] == '"':
state = expectColon
k = b[(s + 1):i]
case state == expectColon && b[i] == ':':
state = expectValue
case state == expectValue && b[i] == '"':
state = expectString
s = i
case state == expectString && b[i] == '"':
e = i
case state == expectValue && b[i] == '{':
state = expectObjClose
s = i
d++
case state == expectObjClose && d == 0 && b[i] == '}':
state = expectKey
e = i
i = s
case state == expectValue && b[i] == '[':
state = expectListClose
s = i
d++
case state == expectListClose && d == 0 && b[i] == ']':
state = expectKey
ae = i
e = i
i = s
case state == expectValue && (b[i] >= '0' && b[i] <= '9'):
state = expectNumClose
s = i
case state == expectNumClose &&
((b[i] < '0' || b[i] > '9') &&
(b[i] != '.' && b[i] != 'e' && b[i] != 'E' && b[i] != '+' && b[i] != '-')):
i--
e = i
case state == expectValue &&
(b[i] == 'f' || b[i] == 'F' || b[i] == 't' || b[i] == 'T'):
state = expectBoolClose
s = i
case state == expectBoolClose && (b[i] == 'e' || b[i] == 'E'):
e = i
case state == expectValue && b[i] == 'n':
state = expectNull
case state == expectNull && b[i] == 'l':
e = i
}
if e != 0 {
if k != nil {
res = append(res, k)
}
state = expectKey
k = nil
e = 0
}
}
return res
}

51
jsn/stack.go Normal file
View File

@ -0,0 +1,51 @@
package jsn
type skipInfo struct {
ss, se int
}
type Stack struct {
stA [20]skipInfo
st []skipInfo
top int
}
// Create a new Stack
func NewStack() *Stack {
s := &Stack{top: -1}
s.st = s.stA[:0]
return s
}
// Return the number of items in the Stack
func (s *Stack) Len() int {
return (s.top + 1)
}
// View the top item on the Stack
func (s *Stack) Peek() *skipInfo {
if s.top == -1 {
return nil
}
return &s.st[s.top]
}
// Pop the top item of the Stack and return it
func (s *Stack) Pop() *skipInfo {
if s.top == -1 {
return nil
}
s.top--
return &s.st[(s.top + 1)]
}
// Push a value onto the top of the Stack
func (s *Stack) Push(value skipInfo) {
s.top++
if len(s.st) <= s.top {
s.st = append(s.st, value)
} else {
s.st[s.top] = value
}
}

37
jsn/tree.go Normal file
View File

@ -0,0 +1,37 @@
package jsn
import (
"bytes"
"encoding/json"
)
func Tree(v []byte) (map[string]interface{}, bool, error) {
dec := json.NewDecoder(bytes.NewReader(v))
array := false
// read open bracket
for i := range v {
if v[i] != ' ' {
array = (v[i] == '[')
break
}
}
if array {
if _, err := dec.Token(); err != nil {
return nil, false, err
}
}
// while the array contains values
var m map[string]interface{}
// decode an array value (Message)
err := dec.Decode(&m)
if err != nil {
return nil, false, err
}
return m, array, nil
}

View File

@ -1,7 +0,0 @@
goos: darwin
goarch: amd64
pkg: github.com/dosco/super-graph/psql
BenchmarkCompile-8 100000 16476 ns/op 3282 B/op 66 allocs/op
BenchmarkCompileParallel-8 300000 4639 ns/op 3324 B/op 66 allocs/op
PASS
ok github.com/dosco/super-graph/psql 3.274s

View File

@ -1,7 +0,0 @@
goos: darwin
goarch: amd64
pkg: github.com/dosco/super-graph/psql
BenchmarkCompile-8 100000 15728 ns/op 3000 B/op 60 allocs/op
BenchmarkCompileParallel-8 300000 5077 ns/op 3023 B/op 60 allocs/op
PASS
ok github.com/dosco/super-graph/psql 3.318s

View File

@ -1,7 +0,0 @@
goos: darwin
goarch: amd64
pkg: github.com/dosco/super-graph/psql
BenchmarkCompile-8 1000000 15997 ns/op 3048 B/op 58 allocs/op
BenchmarkCompileParallel-8 3000000 4722 ns/op 3073 B/op 58 allocs/op
PASS
ok github.com/dosco/super-graph/psql 35.024s

View File

@ -1,7 +0,0 @@
goos: darwin
goarch: amd64
pkg: github.com/dosco/super-graph/psql
BenchmarkCompile-8 100000 16829 ns/op 2887 B/op 57 allocs/op
BenchmarkCompileParallel-8 300000 5450 ns/op 2911 B/op 57 allocs/op
PASS
ok github.com/dosco/super-graph/psql 3.561s

90
psql/insert.go Normal file
View File

@ -0,0 +1,90 @@
package psql
import (
"bytes"
"errors"
"io"
"github.com/dosco/super-graph/jsn"
"github.com/dosco/super-graph/qcode"
)
func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) {
if len(qc.Selects) == 0 {
return 0, errors.New("empty query")
}
c := &compilerContext{w, qc.Selects, co}
root := &qc.Selects[0]
c.w.WriteString(`WITH `)
c.w.WriteString(root.Table)
c.w.WriteString(` AS (`)
if _, err := c.renderInsert(qc, w, vars); err != nil {
return 0, err
}
c.w.WriteString(`) `)
return c.compileQuery(qc, w)
}
func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) {
root := &qc.Selects[0]
insert, ok := vars["insert"]
if !ok {
return 0, errors.New("Variable 'insert' not defined")
}
jt, array, err := jsn.Tree(insert)
if err != nil {
return 0, err
}
c.w.WriteString(`WITH input AS (SELECT {{insert}}::json AS j) INSERT INTO `)
c.w.WriteString(root.Table)
io.WriteString(c.w, " (")
c.renderInsertColumns(qc, w, jt)
io.WriteString(c.w, ")")
c.w.WriteString(` SELECT `)
c.renderInsertColumns(qc, w, jt)
c.w.WriteString(` FROM input i, `)
if array {
c.w.WriteString(`json_populate_recordset`)
} else {
c.w.WriteString(`json_populate_record`)
}
c.w.WriteString(`(NULL::`)
c.w.WriteString(root.Table)
c.w.WriteString(`, i.j) t RETURNING * `)
return 0, nil
}
func (c *compilerContext) renderInsertColumns(qc *qcode.QCode, w *bytes.Buffer,
jt map[string]interface{}) (uint32, error) {
ti, err := c.schema.GetTable(qc.Selects[0].Table)
if err != nil {
return 0, err
}
i := 0
for _, cn := range ti.ColumnNames {
if _, ok := jt[cn]; !ok {
continue
}
if i != 0 {
io.WriteString(c.w, ", ")
}
c.w.WriteString(cn)
i++
}
return 0, nil
}

65
psql/insert_test.go Normal file
View File

@ -0,0 +1,65 @@
package psql
import (
"encoding/json"
"fmt"
"testing"
)
func singleInsert(t *testing.T) {
gql := `mutation {
product(id: 15, insert: $insert) {
id
name
}
}`
sql := `test`
vars := map[string]json.RawMessage{
"insert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`),
}
resSQL, err := compileGQLToPSQL(gql, vars)
if err != nil {
t.Fatal(err)
}
fmt.Println(">", string(resSQL))
if string(resSQL) != sql {
t.Fatal(errNotExpected)
}
}
func bulkInsert(t *testing.T) {
gql := `mutation {
product(id: 15, insert: $insert) {
id
name
}
}`
sql := `test`
vars := map[string]json.RawMessage{
"insert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`),
}
resSQL, err := compileGQLToPSQL(gql, vars)
if err != nil {
t.Fatal(err)
}
fmt.Println(">", string(resSQL))
if string(resSQL) != sql {
t.Fatal(errNotExpected)
}
}
func TestCompileInsert(t *testing.T) {
t.Run("singleInsert", singleInsert)
t.Run("bulkInsert", bulkInsert)
}

View File

@ -2,6 +2,7 @@ package psql
import (
"bytes"
"encoding/json"
"errors"
"fmt"
"io"
@ -18,6 +19,8 @@ const (
closeBlock = 500
)
type Variables map[string]json.RawMessage
type Config struct {
Schema *DBSchema
Vars map[string]string
@ -51,19 +54,30 @@ type compilerContext struct {
*Compiler
}
func (co *Compiler) CompileEx(qc *qcode.QCode) (uint32, []byte, error) {
func (co *Compiler) CompileEx(qc *qcode.QCode, vars Variables) (uint32, []byte, error) {
w := &bytes.Buffer{}
skipped, err := co.Compile(qc, w)
skipped, err := co.Compile(qc, w, vars)
return skipped, w.Bytes(), err
}
func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) {
if len(qc.Query.Selects) == 0 {
func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) {
switch qc.Type {
case qcode.QTQuery:
return co.compileQuery(qc, w)
case qcode.QTMutation:
return co.compileMutation(qc, w, vars)
}
return 0, errors.New("unknown operation")
}
func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) {
if len(qc.Selects) == 0 {
return 0, errors.New("empty query")
}
c := &compilerContext{w, qc.Query.Selects, co}
root := &qc.Query.Selects[0]
c := &compilerContext{w, qc.Selects, co}
root := &qc.Selects[0]
st := NewStack()
st.Push(root.ID + closeBlock)
@ -844,7 +858,7 @@ func (c *compilerContext) renderList(ex *qcode.Exp) {
func (c *compilerContext) renderVal(ex *qcode.Exp,
vars map[string]string) {
io.WriteString(c.w, ` (`)
//io.WriteString(c.w, ` (`)
switch ex.Type {
case qcode.ValBool, qcode.ValInt, qcode.ValFloat:
if len(ex.Val) != 0 {
@ -852,21 +866,23 @@ func (c *compilerContext) renderVal(ex *qcode.Exp,
} else {
c.w.WriteString(`''`)
}
case qcode.ValStr:
c.w.WriteString(`'`)
c.w.WriteString(ex.Val)
c.w.WriteString(`'`)
case qcode.ValVar:
if val, ok := vars[ex.Val]; ok {
c.w.WriteString(val)
} else {
//fmt.Fprintf(w, `'{{%s}}'`, ex.Val)
c.w.WriteString(`'{{`)
c.w.WriteString(`{{`)
c.w.WriteString(ex.Val)
c.w.WriteString(`}}'`)
c.w.WriteString(`}}`)
}
}
c.w.WriteString(`)`)
//c.w.WriteString(`)`)
}
func funcPrefixLen(fn string) int {

View File

@ -125,13 +125,13 @@ func TestMain(m *testing.M) {
os.Exit(m.Run())
}
func compileGQLToPSQL(gql string) ([]byte, error) {
func compileGQLToPSQL(gql string, vars Variables) ([]byte, error) {
qc, err := qcompile.Compile([]byte(gql))
if err != nil {
return nil, err
}
_, sqlStmt, err := pcompile.CompileEx(qc)
_, sqlStmt, err := pcompile.CompileEx(qc, vars)
if err != nil {
return nil, err
}
@ -164,7 +164,7 @@ func withComplexArgs(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."id") < (28)) AND (("products"."id") >= (20))) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil {
t.Fatal(err)
}
@ -192,7 +192,7 @@ func withWhereMultiOr(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") < (20)) OR (("products"."price") > (10)) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil {
t.Fatal(err)
}
@ -218,7 +218,7 @@ func withWhereIsNull(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil {
t.Fatal(err)
}
@ -244,7 +244,7 @@ func withWhereAndList(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil {
t.Fatal(err)
}
@ -264,7 +264,7 @@ func fetchByID(t *testing.T) {
sql := `SELECT json_object_agg('product', product) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "product_0"."id" AS "id", "product_0"."name" AS "name") AS "sel_0")) AS "product" FROM (SELECT "product"."id", "product"."name" FROM "products" AS "product" WHERE ((("product"."price") > (0)) AND (("product"."price") < (8)) AND (("id") = (15))) LIMIT ('1') :: integer) AS "product_0" LIMIT ('1') :: integer) AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil {
t.Fatal(err)
}
@ -284,7 +284,7 @@ func searchQuery(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil {
t.Fatal(err)
}
@ -307,7 +307,7 @@ func oneToMany(t *testing.T) {
sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("users"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "users" FROM (SELECT "users"."email", "users"."id" FROM "users" WHERE ((("users"."id") = ('{{user_id}}'))) LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "products" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "products_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "users_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil {
t.Fatal(err)
}
@ -330,7 +330,7 @@ func belongsTo(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("users"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "users" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "users_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil {
t.Fatal(err)
}
@ -353,7 +353,7 @@ func manyToMany(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("customers"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "customers" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "customers_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil {
t.Fatal(err)
}
@ -376,7 +376,7 @@ func manyToManyReverse(t *testing.T) {
sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("customers"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "customers" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "products" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "products_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "customers_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil {
t.Fatal(err)
}
@ -396,7 +396,7 @@ func aggFunction(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "products" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil {
t.Fatal(err)
}
@ -416,7 +416,7 @@ func aggFunctionWithFilter(t *testing.T) {
sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."id") > (10))) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil {
t.Fatal(err)
}
@ -436,7 +436,7 @@ func queryWithVariables(t *testing.T) {
sql := `SELECT json_object_agg('product', product) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "product_0"."id" AS "id", "product_0"."name" AS "name") AS "sel_0")) AS "product" FROM (SELECT "product"."id", "product"."name" FROM "products" AS "product" WHERE ((("product"."price") > (0)) AND (("product"."price") < (8)) AND (("product"."price") = ('{{product_price}}')) AND (("id") = ('{{product_id}}'))) LIMIT ('1') :: integer) AS "product_0" LIMIT ('1') :: integer) AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil {
t.Fatal(err)
}
@ -455,7 +455,7 @@ func syntheticTables(t *testing.T) {
sql := `SELECT json_object_agg('me', me) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "me_0"."email" AS "email") AS "sel_0")) AS "me" FROM (SELECT "me"."email" FROM "users" AS "me" WHERE ((("me"."id") = ('{{user_id}}'))) LIMIT ('1') :: integer) AS "me_0" LIMIT ('1') :: integer) AS "done_1337";`
resSQL, err := compileGQLToPSQL(gql)
resSQL, err := compileGQLToPSQL(gql, nil)
if err != nil {
t.Fatal(err)
}
@ -465,7 +465,7 @@ func syntheticTables(t *testing.T) {
}
}
func TestCompileGQL(t *testing.T) {
func TestCompileSelect(t *testing.T) {
t.Run("withComplexArgs", withComplexArgs)
t.Run("withWhereAndList", withWhereAndList)
t.Run("withWhereIsNull", withWhereIsNull)
@ -519,7 +519,7 @@ func BenchmarkCompile(b *testing.B) {
b.Fatal(err)
}
_, err = pcompile.Compile(qc, w)
_, err = pcompile.Compile(qc, w, nil)
if err != nil {
b.Fatal(err)
}
@ -540,7 +540,7 @@ func BenchmarkCompileParallel(b *testing.B) {
b.Fatal(err)
}
_, err = pcompile.Compile(qc, w)
_, err = pcompile.Compile(qc, w, nil)
if err != nil {
b.Fatal(err)
}

View File

@ -106,11 +106,12 @@ type DBSchema struct {
}
type DBTableInfo struct {
Name string
Singular bool
PrimaryCol string
TSVCol string
Columns map[string]*DBColumn
Name string
Singular bool
PrimaryCol string
TSVCol string
Columns map[string]*DBColumn
ColumnNames []string
}
type RelType int
@ -162,25 +163,30 @@ func (s *DBSchema) updateSchema(
// Foreign key columns in current table
colByID := make(map[int]*DBColumn)
columns := make(map[string]*DBColumn, len(cols))
colNames := make([]string, len(cols))
for i := range cols {
c := cols[i]
columns[strings.ToLower(c.Name)] = cols[i]
name := strings.ToLower(c.Name)
columns[name] = cols[i]
colNames = append(colNames, name)
colByID[c.ID] = cols[i]
}
singular := strings.ToLower(flect.Singularize(t.Name))
s.t[singular] = &DBTableInfo{
Name: t.Name,
Singular: true,
Columns: columns,
Name: t.Name,
Singular: true,
Columns: columns,
ColumnNames: colNames,
}
plural := strings.ToLower(flect.Pluralize(t.Name))
s.t[plural] = &DBTableInfo{
Name: t.Name,
Singular: false,
Columns: columns,
Name: t.Name,
Singular: false,
Columns: columns,
ColumnNames: colNames,
}
ct := strings.ToLower(t.Name)

View File

@ -100,19 +100,7 @@ var lexPool = sync.Pool{
}
func Parse(gql []byte) (*Operation, error) {
return parseSelectionSet(nil, gql)
}
func ParseQuery(gql []byte) (*Operation, error) {
op := opPool.Get().(*Operation)
op.Reset()
op.Type = opQuery
op.Name = ""
op.Fields = op.fieldsA[:0]
op.Args = op.argsA[:0]
return parseSelectionSet(op, gql)
return parseSelectionSet(gql)
}
func ParseArgValue(argVal string) (*Node, error) {
@ -134,7 +122,7 @@ func ParseArgValue(argVal string) (*Node, error) {
return op, err
}
func parseSelectionSet(op *Operation, gql []byte) (*Operation, error) {
func parseSelectionSet(gql []byte) (*Operation, error) {
var err error
if len(gql) == 0 {
@ -154,14 +142,28 @@ func parseSelectionSet(op *Operation, gql []byte) (*Operation, error) {
items: l.items,
}
if op == nil {
op, err = p.parseOp()
} else {
if p.peek(itemObjOpen) {
p.ignore()
}
var op *Operation
if p.peek(itemObjOpen) {
p.ignore()
}
if p.peek(itemName) {
op = opPool.Get().(*Operation)
op.Reset()
op.Type = opQuery
op.Name = ""
op.Fields = op.fieldsA[:0]
op.Args = op.argsA[:0]
op.Fields, err = p.parseFields(op.Fields)
} else {
op, err = p.parseOp()
if err != nil {
return nil, err
}
}
lexPool.Put(l)

View File

@ -45,10 +45,10 @@ func compareOp(op1, op2 Operation) error {
}
*/
func TestCompile(t *testing.T) {
func TestCompile1(t *testing.T) {
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.CompileQuery([]byte(`
_, err := qcompile.Compile([]byte(`
product(id: 15) {
id
name
@ -59,9 +59,39 @@ func TestCompile(t *testing.T) {
}
}
func TestCompile2(t *testing.T) {
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile([]byte(`
query { product(id: 15) {
id
name
} }`))
if err != nil {
t.Fatal(err)
}
}
func TestCompile3(t *testing.T) {
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile([]byte(`
mutation {
product(id: 15, name: "Test") {
id
name
}
}`))
if err != nil {
t.Fatal(err)
}
}
func TestInvalidCompile1(t *testing.T) {
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.CompileQuery([]byte(`#`))
_, err := qcompile.Compile([]byte(`#`))
if err == nil {
t.Fatal(errors.New("expecting an error"))
@ -70,7 +100,7 @@ func TestInvalidCompile1(t *testing.T) {
func TestInvalidCompile2(t *testing.T) {
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.CompileQuery([]byte(`{u(where:{not:0})}`))
_, err := qcompile.Compile([]byte(`{u(where:{not:0})}`))
if err == nil {
t.Fatal(errors.New("expecting an error"))
@ -79,7 +109,7 @@ func TestInvalidCompile2(t *testing.T) {
func TestEmptyCompile(t *testing.T) {
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.CompileQuery([]byte(``))
_, err := qcompile.Compile([]byte(``))
if err == nil {
t.Fatal(errors.New("expecting an error"))
@ -114,7 +144,7 @@ func BenchmarkQCompile(b *testing.B) {
b.ReportAllocs()
for n := 0; n < b.N; n++ {
_, err := qcompile.CompileQuery(gql)
_, err := qcompile.Compile(gql)
if err != nil {
b.Fatal(err)
@ -130,7 +160,7 @@ func BenchmarkQCompileP(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := qcompile.CompileQuery(gql)
_, err := qcompile.Compile(gql)
if err != nil {
b.Fatal(err)

View File

@ -10,15 +10,17 @@ import (
"github.com/gobuffalo/flect"
)
type QType int
const (
maxSelectors = 30
QTQuery QType = iota + 1
QTMutation
)
type QCode struct {
Query *Query
}
type Query struct {
Type QType
Selects []Select
}
@ -149,6 +151,11 @@ type Compiler struct {
ka bool
}
var opMap = map[parserType]QType{
opQuery: QTQuery,
opMutate: QTMutation,
}
var expPool = sync.Pool{
New: func() interface{} { return new(Exp) },
}
@ -196,44 +203,23 @@ func (com *Compiler) Compile(query []byte) (*QCode, error) {
return nil, err
}
switch op.Type {
case opQuery:
qc.Query, err = com.compileQuery(op)
case opMutate:
case opSub:
default:
err = fmt.Errorf("Unknown operation type %d", op.Type)
}
qc.Selects, err = com.compileQuery(op)
if err != nil {
return nil, err
}
if t, ok := opMap[op.Type]; ok {
qc.Type = t
} else {
return nil, fmt.Errorf("Unknown operation type %d", op.Type)
}
opPool.Put(op)
return &qc, nil
}
func (com *Compiler) CompileQuery(query []byte) (*QCode, error) {
var err error
op, err := ParseQuery(query)
if err != nil {
return nil, err
}
qc := &QCode{}
qc.Query, err = com.compileQuery(op)
opPool.Put(op)
if err != nil {
return nil, err
}
return qc, nil
}
func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
id := int32(0)
parentID := int32(0)
@ -344,7 +330,7 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
return nil, errors.New("invalid query")
}
return &Query{selects[:id]}, nil
return selects[:id], nil
}
func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
@ -661,14 +647,6 @@ func (com *Compiler) compileArgOffset(sel *Select, arg *Arg) error {
return nil
}
func compileMutate() (*Query, error) {
return nil, nil
}
func compileSub() (*Query, error) {
return nil, nil
}
func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {
name := node.Name
if name[0] == '_' {

View File

@ -1,6 +1,7 @@
package serv
import (
"encoding/json"
"fmt"
"io/ioutil"
"log"
@ -9,9 +10,15 @@ import (
"strings"
)
const (
AL_QUERY int = iota + 1
AL_VARS
)
type allowItem struct {
uri string
gql string
uri string
gql string
vars json.RawMessage
}
var _allowList allowList
@ -77,8 +84,9 @@ func (al *allowList) add(req *gqlReq) {
}
al.saveChan <- &allowItem{
uri: req.ref,
gql: req.Query,
uri: req.ref,
gql: req.Query,
vars: req.Vars,
}
}
@ -93,32 +101,62 @@ func (al *allowList) load() {
}
var uri string
var varBytes []byte
s, e, c := 0, 0, 0
ty := 0
for {
if c == 0 && b[e] == '#' {
s = e
for b[e] != '\n' && e < len(b) {
for e < len(b) && b[e] != '\n' {
e++
}
if (e - s) > 2 {
uri = strings.TrimSpace(string(b[(s + 1):e]))
}
}
if b[e] == '{' {
if e >= len(b) {
break
}
if matchPrefix(b, e, "query") || matchPrefix(b, e, "mutation") {
if c == 0 {
s = e
}
ty = AL_QUERY
} else if matchPrefix(b, e, "variables") {
if c == 0 {
s = e + len("variables") + 1
}
ty = AL_VARS
} else if b[e] == '{' {
c++
} else if b[e] == '}' {
c--
if c == 0 {
q := b[s:(e + 1)]
al.list[gqlHash(q)] = &allowItem{
uri: uri,
gql: string(q),
if ty == AL_QUERY {
q := string(b[s:(e + 1)])
item := &allowItem{
uri: uri,
gql: q,
}
if len(varBytes) != 0 {
item.vars = varBytes
}
al.list[gqlHash(q, varBytes)] = item
varBytes = nil
} else if ty == AL_VARS {
varBytes = b[s:(e + 1)]
}
ty = 0
}
}
@ -130,7 +168,7 @@ func (al *allowList) load() {
}
func (al *allowList) save(item *allowItem) {
al.list[gqlHash([]byte(item.gql))] = item
al.list[gqlHash(item.gql, item.vars)] = item
f, err := os.Create(al.filepath)
if err != nil {
@ -141,10 +179,10 @@ func (al *allowList) save(item *allowItem) {
defer f.Close()
keys := []string{}
urlMap := make(map[string][]string)
urlMap := make(map[string][]*allowItem)
for _, v := range al.list {
urlMap[v.uri] = append(urlMap[v.uri], v.gql)
urlMap[v.uri] = append(urlMap[v.uri], v)
}
for k := range urlMap {
@ -159,7 +197,28 @@ func (al *allowList) save(item *allowItem) {
f.WriteString(fmt.Sprintf("# %s\n\n", k))
for i := range v {
f.WriteString(fmt.Sprintf("query %s\n\n", v[i]))
if len(v[i].vars) != 0 {
vj, err := json.MarshalIndent(v[i].vars, "", "\t")
if err != nil {
logger.Warn().Err(err).Msg("Failed to write allow list 'vars' to file")
continue
}
f.WriteString(fmt.Sprintf("variables %s\n\n", vj))
}
f.WriteString(fmt.Sprintf("%s\n\n", v[i].gql))
}
}
}
func matchPrefix(b []byte, i int, s string) bool {
if (len(b) - i) < len(s) {
return false
}
for n := 0; n < len(s); n++ {
if b[(i+n)] != s[n] {
return false
}
}
return true
}

View File

@ -14,6 +14,7 @@ import (
"github.com/cespare/xxhash/v2"
"github.com/dosco/super-graph/jsn"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode"
"github.com/go-pg/pg"
"github.com/valyala/fasttemplate"
@ -42,7 +43,7 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
if conf.UseAllowList {
var ps *preparedItem
data, ps, err = c.resolvePreparedSQL([]byte(c.req.Query))
data, ps, err = c.resolvePreparedSQL(c.req.Query)
if err != nil {
return err
}
@ -52,7 +53,7 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
} else {
qc, err = qcompile.CompileQuery([]byte(c.req.Query))
qc, err = qcompile.Compile([]byte(c.req.Query))
if err != nil {
return err
}
@ -67,7 +68,7 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
return c.render(w, data)
}
sel := qc.Query.Selects
sel := qc.Selects
h := xxhash.New()
// fetch the field name used within the db response json
@ -252,8 +253,8 @@ func (c *coreContext) resolveRemotes(
return to, cerr
}
func (c *coreContext) resolvePreparedSQL(gql []byte) ([]byte, *preparedItem, error) {
ps, ok := _preparedList[gqlHash(gql)]
func (c *coreContext) resolvePreparedSQL(gql string) ([]byte, *preparedItem, error) {
ps, ok := _preparedList[gqlHash(gql, c.req.Vars)]
if !ok {
return nil, nil, errUnauthorized
}
@ -266,17 +267,22 @@ func (c *coreContext) resolvePreparedSQL(gql []byte) ([]byte, *preparedItem, err
return nil, nil, err
}
fmt.Printf("PRE: %#v %#v\n", ps.stmt, vars)
fmt.Printf("PRE: %v\n", ps.stmt)
return []byte(root), ps, nil
}
func (c *coreContext) resolveSQL(qc *qcode.QCode) (
[]byte, uint32, error) {
stmt := &bytes.Buffer{}
skipped, err := pcompile.Compile(qc, stmt)
vars := make(map[string]json.RawMessage)
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
return nil, 0, err
}
skipped, err := pcompile.Compile(qc, stmt, psql.Variables(vars))
if err != nil {
return nil, 0, err
}
@ -284,7 +290,7 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode) (
t := fasttemplate.New(stmt.String(), openVar, closeVar)
stmt.Reset()
_, err = t.Execute(stmt, varMap(c))
_, err = t.ExecuteFunc(stmt, varMap(c))
if err == errNoUserID &&
authFailBlock == authFailBlockPerQuery &&
@ -317,10 +323,10 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode) (
return nil, 0, err
}
if conf.EnableTracing && len(qc.Query.Selects) != 0 {
if conf.EnableTracing && len(qc.Selects) != 0 {
c.addTrace(
qc.Query.Selects,
qc.Query.Selects[0].ID,
qc.Selects,
qc.Selects[0].ID,
st)
}

35
serv/core_test.go Normal file
View File

@ -0,0 +1,35 @@
package serv
/*
func simpleMutation(t *testing.T) {
gql := `mutation {
product(id: 15, insert: { name: "Test", price: 20.5 }) {
id
name
}
}`
sql := `test`
backgroundCtx := context.Background()
ctx := &coreContext{Context: backgroundCtx}
resSQL, err := compileGQLToPSQL(gql)
if err != nil {
t.Fatal(err)
}
fmt.Println(">", string(resSQL))
if string(resSQL) != sql {
t.Fatal(errNotExpected)
}
}
func TestCompileGQL(t *testing.T) {
t.Run("withComplexArgs", withComplexArgs)
t.Run("simpleMutation", simpleMutation)
}
*/

View File

@ -26,13 +26,13 @@ var (
)
type gqlReq struct {
OpName string `json:"operationName"`
Query string `json:"query"`
Vars variables `json:"variables"`
OpName string `json:"operationName"`
Query string `json:"query"`
Vars json.RawMessage `json:"variables"`
ref string
}
type variables map[string]interface{}
type variables map[string]json.RawMessage
type gqlResp struct {
Error string `json:"error,omitempty"`

View File

@ -2,9 +2,11 @@ package serv
import (
"bytes"
"encoding/json"
"fmt"
"io"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode"
"github.com/go-pg/pg"
"github.com/valyala/fasttemplate"
@ -12,7 +14,7 @@ import (
type preparedItem struct {
stmt *pg.Stmt
args []string
args [][]byte
skipped uint32
qc *qcode.QCode
}
@ -25,36 +27,46 @@ func initPreparedList() {
_preparedList = make(map[string]*preparedItem)
for k, v := range _allowList.list {
err := prepareStmt(k, v.gql)
err := prepareStmt(k, v.gql, v.vars)
if err != nil {
panic(err)
}
}
}
func prepareStmt(key, gql string) error {
func prepareStmt(key, gql string, varBytes json.RawMessage) error {
if len(gql) == 0 || len(key) == 0 {
return nil
}
qc, err := qcompile.CompileQuery([]byte(gql))
qc, err := qcompile.Compile([]byte(gql))
if err != nil {
return err
}
var vars map[string]json.RawMessage
if len(varBytes) != 0 {
vars = make(map[string]json.RawMessage)
if err := json.Unmarshal(varBytes, &vars); err != nil {
return err
}
}
buf := &bytes.Buffer{}
skipped, err := pcompile.Compile(qc, buf)
skipped, err := pcompile.Compile(qc, buf, psql.Variables(vars))
if err != nil {
return err
}
t := fasttemplate.New(buf.String(), `('{{`, `}}')`)
am := make([]string, 0, 5)
t := fasttemplate.New(buf.String(), `{{`, `}}`)
am := make([][]byte, 0, 5)
i := 0
finalSQL := t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) {
am = append(am, tag)
am = append(am, []byte(tag))
i++
return w.Write([]byte(fmt.Sprintf("$%d", i)))
})

View File

@ -4,8 +4,12 @@ import (
"bytes"
"crypto/sha1"
"encoding/hex"
"io"
"sort"
"strings"
"github.com/cespare/xxhash/v2"
"github.com/dosco/super-graph/jsn"
)
func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
@ -17,8 +21,8 @@ func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
return v
}
func gqlHash(b []byte) string {
b = bytes.TrimSpace(b)
func gqlHash(b string, vars []byte) string {
b = strings.TrimSpace(b)
h := sha1.New()
s, e := 0, 0
@ -45,13 +49,27 @@ func gqlHash(b []byte) string {
if e != 0 {
b0 = b[(e - 1)]
}
h.Write(bytes.ToLower(b[s:e]))
io.WriteString(h, strings.ToLower(b[s:e]))
}
if e >= len(b) {
break
}
}
if vars == nil {
return hex.EncodeToString(h.Sum(nil))
}
fields := jsn.Keys([]byte(vars))
sort.Slice(fields, func(i, j int) bool {
return bytes.Compare(fields[i], fields[j]) == -1
})
for i := range fields {
h.Write(fields[i])
}
return hex.EncodeToString(h.Sum(nil))
}

View File

@ -6,7 +6,7 @@ import (
)
func TestRelaxHash1(t *testing.T) {
var v1 = []byte(`
var v1 = `
products(
limit: 30,
@ -14,18 +14,18 @@ func TestRelaxHash1(t *testing.T) {
id
name
price
}`)
}`
var v2 = []byte(`
var v2 = `
products
(limit: 30, where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
id
name
price
} `)
} `
h1 := gqlHash(v1)
h2 := gqlHash(v2)
h1 := gqlHash(v1, nil)
h2 := gqlHash(v2, nil)
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")
@ -33,7 +33,7 @@ func TestRelaxHash1(t *testing.T) {
}
func TestRelaxHash2(t *testing.T) {
var v1 = []byte(`
var v1 = `
{
products(
limit: 30
@ -49,12 +49,119 @@ func TestRelaxHash2(t *testing.T) {
email
}
}
}`)
}`
var v2 = []byte(` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } `)
var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } `
h1 := gqlHash(v1)
h2 := gqlHash(v2)
h1 := gqlHash(v1, nil)
h2 := gqlHash(v2, nil)
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")
}
}
func TestRelaxHashWithVars1(t *testing.T) {
var q1 = `
products(
limit: 30,
where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
id
name
price
}`
var v1 = `
{
"insert": {
"name": "Hello",
"description": "World",
"created_at": "now",
"updated_at": "now",
"test": { "type2": "b", "type1": "a" }
},
"user": 123
}`
var q2 = `
products
(limit: 30, where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
id
name
price
} `
var v2 = `{
"insert": {
"created_at": "now",
"test": { "type1": "a", "type2": "b" },
"name": "Hello",
"updated_at": "now",
"description": "World"
},
"user": 123
}`
h1 := gqlHash(q1, []byte(va1))
h2 := gqlHash(q2, []byte(va2))
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")
}
}
func TestRelaxHashWithVars2(t *testing.T) {
var q1 = `
products(
limit: 30,
where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
id
name
price
}`
var v1 = `
{
"insert": [{
"name": "Hello",
"description": "World",
"created_at": "now",
"updated_at": "now",
"test": { "type2": "b", "type1": "a" }
},
{
"name": "Hello",
"description": "World",
"created_at": "now",
"updated_at": "now",
"test": { "type2": "b", "type1": "a" }
}],
"user": 123
}`
var q2 = `
products
(limit: 30, where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) {
id
name
price
} `
var v2 = `{
"insert": {
"created_at": "now",
"test": { "type1": "a", "type2": "b" },
"name": "Hello",
"updated_at": "now",
"description": "World"
},
"user": 123
}`
h1 := gqlHash(q1, []byte(va1))
h2 := gqlHash(q2, []byte(va2))
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")

View File

@ -1,95 +1,107 @@
package serv
import (
"bytes"
"fmt"
"io"
"strconv"
"strings"
"github.com/valyala/fasttemplate"
"github.com/dosco/super-graph/jsn"
)
func varMap(ctx *coreContext) variables {
userIDFn := func(w io.Writer, _ string) (int, error) {
if v := ctx.Value(userIDKey); v != nil {
return w.Write([]byte(v.(string)))
}
return 0, errNoUserID
}
func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
return func(w io.Writer, tag string) (int, error) {
switch tag {
case "user_id":
if v := ctx.Value(userIDKey); v != nil {
return stringVar(w, v.(string))
}
return 0, errNoUserID
userIDProviderFn := func(w io.Writer, _ string) (int, error) {
if v := ctx.Value(userIDProviderKey); v != nil {
return w.Write([]byte(v.(string)))
}
return 0, errNoUserID
}
userIDTag := fasttemplate.TagFunc(userIDFn)
userIDProviderTag := fasttemplate.TagFunc(userIDProviderFn)
vm := variables{
"user_id": userIDTag,
"user_id_provider": userIDProviderTag,
"USER_ID": userIDTag,
"USER_ID_PROVIDER": userIDProviderTag,
}
for k, v := range ctx.req.Vars {
var buf []byte
k = strings.ToLower(k)
if _, ok := vm[k]; ok {
continue
case "user_id_provider":
if v := ctx.Value(userIDProviderKey); v != nil {
return stringVar(w, v.(string))
}
return 0, errNoUserID
}
switch val := v.(type) {
case string:
vm[k] = val
case int:
vm[k] = strconv.AppendInt(buf, int64(val), 10)
case int64:
vm[k] = strconv.AppendInt(buf, val, 10)
case float64:
vm[k] = strconv.AppendFloat(buf, val, 'f', -1, 64)
fields := jsn.Get(ctx.req.Vars, [][]byte{[]byte(tag)})
if len(fields) == 0 {
return 0, fmt.Errorf("variable '%s' not found", tag)
}
is := false
for i := range fields[0].Value {
c := fields[0].Value[i]
if c != ' ' {
is = (c == '"') || (c == '{') || (c == '[')
break
}
}
if is {
return stringVarB(w, fields[0].Value)
}
w.Write(fields[0].Value)
return 0, nil
}
return vm
}
func varList(ctx *coreContext, args []string) []interface{} {
vars := make([]interface{}, 0, len(args))
func varList(ctx *coreContext, args [][]byte) []interface{} {
vars := make([]interface{}, len(args))
for k, v := range ctx.req.Vars {
ctx.req.Vars[strings.ToLower(k)] = v
var fields map[string]interface{}
var err error
if len(ctx.req.Vars) != 0 {
fields, _, err = jsn.Tree(ctx.req.Vars)
if err != nil {
logger.Warn().Err(err).Msg("Failed to parse variables")
}
}
for i := range args {
arg := strings.ToLower(args[i])
av := args[i]
if arg == "user_id" {
switch {
case bytes.Equal(av, []byte("user_id")):
if v := ctx.Value(userIDKey); v != nil {
vars = append(vars, v.(string))
vars[i] = v.(string)
}
}
if arg == "user_id_provider" {
case bytes.Equal(av, []byte("user_id_provider")):
if v := ctx.Value(userIDProviderKey); v != nil {
vars = append(vars, v.(string))
vars[i] = v.(string)
}
}
if v, ok := ctx.req.Vars[arg]; ok {
switch val := v.(type) {
case string:
vars = append(vars, val)
case int:
vars = append(vars, strconv.FormatInt(int64(val), 10))
case int64:
vars = append(vars, strconv.FormatInt(int64(val), 10))
case float64:
vars = append(vars, strconv.FormatFloat(val, 'f', -1, 64))
default:
if v, ok := fields[string(av)]; ok {
vars[i] = v
}
}
}
return vars
}
func stringVar(w io.Writer, v string) (int, error) {
if n, err := w.Write([]byte(`'`)); err != nil {
return n, err
}
if n, err := w.Write([]byte(v)); err != nil {
return n, err
}
return w.Write([]byte(`'`))
}
func stringVarB(w io.Writer, v []byte) (int, error) {
if n, err := w.Write([]byte(`'`)); err != nil {
return n, err
}
if n, err := w.Write(v); err != nil {
return n, err
}
return w.Write([]byte(`'`))
}