From c0a21e448f6a281cd297a9aa6b0b663f263eaa6d Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Thu, 5 Sep 2019 00:09:56 -0400 Subject: [PATCH] Add insert mutation with bulk insert --- config/allow.list | 77 +++++++++++++- config/dev.yml | 15 +-- config/prod.yml | 3 +- docker-compose.yml | 2 + jsn/README.md | 39 +++++++ jsn/get.go | 4 +- jsn/json_test.go | 79 +++++++++++++++ jsn/keys.go | 122 ++++++++++++++++++++++ jsn/stack.go | 51 ++++++++++ jsn/tree.go | 37 +++++++ psql/bench.4 | 7 -- psql/bench.5 | 7 -- psql/bench.6 | 7 -- psql/bench.7 | 7 -- psql/insert.go | 90 +++++++++++++++++ psql/insert_test.go | 65 ++++++++++++ psql/{psql.go => select.go} | 36 +++++-- psql/{psql_test.go => select_test.go} | 38 +++---- psql/tables.go | 30 +++--- qcode/parse.go | 42 ++++---- qcode/parse_test.go | 44 ++++++-- qcode/qcode.go | 62 ++++-------- serv/allow.go | 87 +++++++++++++--- serv/core.go | 30 +++--- serv/core_test.go | 35 +++++++ serv/http.go | 8 +- serv/prepare.go | 28 ++++-- serv/utils.go | 24 ++++- serv/utils_test.go | 129 ++++++++++++++++++++++-- serv/vars.go | 140 ++++++++++++++------------ 30 files changed, 1080 insertions(+), 265 deletions(-) create mode 100644 jsn/README.md create mode 100644 jsn/keys.go create mode 100644 jsn/stack.go create mode 100644 jsn/tree.go delete mode 100644 psql/bench.4 delete mode 100644 psql/bench.5 delete mode 100644 psql/bench.6 delete mode 100644 psql/bench.7 create mode 100644 psql/insert.go create mode 100644 psql/insert_test.go rename psql/{psql.go => select.go} (96%) rename psql/{psql_test.go => select_test.go} (96%) create mode 100644 serv/core_test.go diff --git a/config/allow.list b/config/allow.list index ed61b4f..a0befc5 100644 --- a/config/allow.list +++ b/config/allow.list @@ -47,4 +47,79 @@ query { email } } -} \ No newline at end of file + +variables { + "insert": { + "name": "Hello", + "description": "World", + "created_at": "now", + "updated_at": "now" + }, + "user": 123 +} + +mutation { + products(insert: $insert) { + id + name + description + } +} + +variables { + "insert": { + "name": "Hello", + "description": "World", + "created_at": "now", + "updated_at": "now" + }, + "user": 123 +} + +mutation { + products(insert: $insert) { + id + } +} + +variables { + "insert": { + "description": "World3", + "name": "Hello3", + "created_at": "now", + "updated_at": "now" + }, + "user": 123 +} + +{ + customers { + id + email + payments { + customer_id + amount + billing_details + } + } +} + + +variables { + "insert": { + "description": "World3", + "name": "Hello3", + "created_at": "now", + "updated_at": "now" + }, + "user": 123 +} + +{ + me { + id + full_name + } +} + + diff --git a/config/dev.yml b/config/dev.yml index aae2586..f8296ca 100644 --- a/config/dev.yml +++ b/config/dev.yml @@ -84,8 +84,9 @@ database: #log_level: "debug" # Define variables here that you want to use in filters + # sub-queries must be wrapped in () variables: - account_id: "select account_id from users where id = $user_id" + account_id: "(select account_id from users where id = $user_id)" # Define defaults to for the field key and values below defaults: @@ -105,12 +106,12 @@ database: # This filter will overwrite defaults.filter # filter: ["{ id: { eq: $user_id } }"] - - name: products - # Multiple filters are AND'd together - filter: [ - "{ price: { gt: 0 } }", - "{ price: { lt: 8 } }" - ] + # - name: products + # # Multiple filters are AND'd together + # filter: [ + # "{ price: { gt: 0 } }", + # "{ price: { lt: 8 } }" + # ] - name: customers # No filter is used for this field not diff --git a/config/prod.yml b/config/prod.yml index a6cd51e..2336bc2 100644 --- a/config/prod.yml +++ b/config/prod.yml @@ -82,8 +82,9 @@ database: #log_level: "debug" # Define variables here that you want to use in filters + # sub-queries must be wrapped in () variables: - account_id: "select account_id from users where id = $user_id" + account_id: "(select account_id from users where id = $user_id)" # Define defaults to for the field key and values below defaults: diff --git a/docker-compose.yml b/docker-compose.yml index fdf3b4e..a81ff1f 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -2,6 +2,8 @@ version: '3.4' services: db: image: postgres + ports: + - "5432:5432" # redis: # image: redis:alpine diff --git a/jsn/README.md b/jsn/README.md new file mode 100644 index 0000000..96a7d36 --- /dev/null +++ b/jsn/README.md @@ -0,0 +1,39 @@ +# JSN - Fast low allocation JSON library +## Design + +This libary is designed as a set of seperate functions to extract data and mutate +JSON. All functions are focused on keeping allocations to a minimum and be as fast +as possible. The functions don't validate the JSON a seperate `Validate` function +does that. + +The JSON parsing algo processes each object `{}` or array `[]` in a bottom up way +where once the end of the array or object is found only then the keys within it are +parsed from the top down. + +``` +{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}], "full_name":"FN1","email":"E1" } + +id: 1 + +posts: [{"title":"PT1-1","description":"PD1-1"}] + +[{"title":"PT1-1","description":"PD1-1"}] + +{"title":"PT1-1","description":"PD1-1"} + +title: "PT1-1" + +description: "PD1-1 + +full_name: "FN1" + +email: "E1" +``` + +## Functions + +- Strip: Strip a path from the root to a child node and return the rest +- Replace: Replace values by key +- Get: Get all keys +- Filter: Extract specific keys from an object +- Tree: Fetch unique keys from an array or object diff --git a/jsn/get.go b/jsn/get.go index 584ce84..29c37ac 100644 --- a/jsn/get.go +++ b/jsn/get.go @@ -43,7 +43,7 @@ func Get(b []byte, keys [][]byte) []Field { kmap[xxhash.Sum64(keys[i])] = struct{}{} } - res := make([]Field, 20) + res := make([]Field, 0, 20) s, e, d := 0, 0, 0 @@ -127,7 +127,7 @@ func Get(b []byte, keys [][]byte) []Field { _, ok := kmap[xxhash.Sum64(k)] if ok { - res[n] = Field{k, b[s:(e + 1)]} + res = append(res, Field{k, b[s:(e + 1)]}) n++ } diff --git a/jsn/json_test.go b/jsn/json_test.go index 5a1fbec..7b08a37 100644 --- a/jsn/json_test.go +++ b/jsn/json_test.go @@ -21,6 +21,10 @@ var ( "full_name": "Caroll Orn Sr.", "email": "joannarau@hegmann.io", "__twitter_id": "ABC123" + "more": [{ + "__twitter_id": "more123", + "hello: "world + }] } }, { @@ -163,6 +167,7 @@ func TestGet(t *testing.T) { {[]byte("__twitter_id"), []byte(`"ABCD"`)}, {[]byte("__twitter_id"), []byte(`"2048666903444506956"`)}, {[]byte("__twitter_id"), []byte(`"ABC123"`)}, + {[]byte("__twitter_id"), []byte(`"more123"`)}, {[]byte("__twitter_id"), []byte(`[{ "name": "hello" }, { "name": "world"}]`)}, {[]byte("__twitter_id"), @@ -340,6 +345,80 @@ func TestReplaceEmpty(t *testing.T) { } } +func TestKeys1(t *testing.T) { + json := `[{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}, {"title":"PT1-2","description":"PD1-2"}], "full_name":"FN1","email":"E1","books": [{"name":"BN1-1","description":"BD1-1"},{"name":"BN1-2","description":"BD1-2"},{"name":"BN1-2","description":"BD1-2"}]},{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}, {"title":"PT1-2","description":"PD1-2"}], "full_name":"FN1","email":"E1","books": [{"name":"BN1-1","description":"BD1-1"},{"name":"BN1-2","description":"BD1-2"},{"name":"BN1-2","description":"BD1-2"}]},{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}, {"title":"PT1-2","description":"PD1-2"}], "full_name":"FN1","email":"E1","books": [{"name":"BN1-1","description":"BD1-1"},{"name":"BN1-2","description":"BD1-2"},{"name":"BN1-2","description":"BD1-2"}]}]` + + fields := Keys([]byte(json)) + + exp := []string{ + "id", "posts", "title", "description", "full_name", "email", "books", "name", "description", + } + + if len(exp) != len(fields) { + t.Errorf("Expected %d fields %d", len(exp), len(fields)) + } + + for i := range exp { + if string(fields[i]) != exp[i] { + t.Errorf("Expected field '%s' got '%s'", string(exp[i]), fields[i]) + } + } +} + +func TestKeys2(t *testing.T) { + json := `{"id":1,"posts": [{"title":"PT1-1","description":"PD1-1"}, {"title":"PT1-2","description":"PD1-2"}], "full_name":"FN1","email":"E1","books": [{"name":"BN1-1","description":"BD1-1"},{"name":"BN1-2","description":"BD1-2"},{"name":"BN1-2","description":"BD1-2"}]}` + + fields := Keys([]byte(json)) + + exp := []string{ + "id", "posts", "title", "description", "full_name", "email", "books", "name", "description", + } + + // for i := range fields { + // fmt.Println("-->", string(fields[i])) + // } + + if len(exp) != len(fields) { + t.Errorf("Expected %d fields %d", len(exp), len(fields)) + } + + for i := range exp { + if string(fields[i]) != exp[i] { + t.Errorf("Expected field '%s' got '%s'", string(exp[i]), fields[i]) + } + } +} + +func TestKeys3(t *testing.T) { + json := `{ + "insert": { + "created_at": "now", + "test": { "type1": "a", "type2": "b" }, + "name": "Hello", + "updated_at": "now", + "description": "World" + }, + "user": 123 + }` + + fields := Keys([]byte(json)) + + exp := []string{ + "insert", "created_at", "test", "type1", "type2", "name", "updated_at", "description", + "user", + } + + if len(exp) != len(fields) { + t.Errorf("Expected %d fields %d", len(exp), len(fields)) + } + + for i := range exp { + if string(fields[i]) != exp[i] { + t.Errorf("Expected field '%s' got '%s'", string(exp[i]), fields[i]) + } + } +} + func BenchmarkGet(b *testing.B) { b.ReportAllocs() diff --git a/jsn/keys.go b/jsn/keys.go new file mode 100644 index 0000000..e1c0e1b --- /dev/null +++ b/jsn/keys.go @@ -0,0 +1,122 @@ +package jsn + +func Keys(b []byte) [][]byte { + res := make([][]byte, 0, 20) + + s, e, d := 0, 0, 0 + + var k []byte + state := expectValue + + st := NewStack() + ae := 0 + + for i := 0; i < len(b); i++ { + if state == expectObjClose || state == expectListClose { + switch b[i] { + case '{', '[': + d++ + case '}', ']': + d-- + } + } + + si := st.Peek() + + switch { + case state == expectKey && si != nil && i >= si.ss: + i = si.se + 1 + st.Pop() + + case state == expectKey && b[i] == '{': + state = expectObjClose + s = i + d++ + + case state == expectObjClose && d == 0 && b[i] == '}': + state = expectKey + if ae != 0 { + st.Push(skipInfo{i, ae}) + ae = 0 + } + e = i + i = s + + case state == expectKey && b[i] == '"': + state = expectKeyClose + s = i + + case state == expectKeyClose && b[i] == '"': + state = expectColon + k = b[(s + 1):i] + + case state == expectColon && b[i] == ':': + state = expectValue + + case state == expectValue && b[i] == '"': + state = expectString + s = i + + case state == expectString && b[i] == '"': + e = i + + case state == expectValue && b[i] == '{': + state = expectObjClose + s = i + d++ + + case state == expectObjClose && d == 0 && b[i] == '}': + state = expectKey + e = i + i = s + + case state == expectValue && b[i] == '[': + state = expectListClose + s = i + d++ + + case state == expectListClose && d == 0 && b[i] == ']': + state = expectKey + ae = i + e = i + i = s + + case state == expectValue && (b[i] >= '0' && b[i] <= '9'): + state = expectNumClose + s = i + + case state == expectNumClose && + ((b[i] < '0' || b[i] > '9') && + (b[i] != '.' && b[i] != 'e' && b[i] != 'E' && b[i] != '+' && b[i] != '-')): + i-- + e = i + + case state == expectValue && + (b[i] == 'f' || b[i] == 'F' || b[i] == 't' || b[i] == 'T'): + state = expectBoolClose + s = i + + case state == expectBoolClose && (b[i] == 'e' || b[i] == 'E'): + e = i + + case state == expectValue && b[i] == 'n': + state = expectNull + + case state == expectNull && b[i] == 'l': + e = i + } + + if e != 0 { + if k != nil { + res = append(res, k) + } + + state = expectKey + k = nil + e = 0 + } + + } + + return res +} diff --git a/jsn/stack.go b/jsn/stack.go new file mode 100644 index 0000000..f49484e --- /dev/null +++ b/jsn/stack.go @@ -0,0 +1,51 @@ +package jsn + +type skipInfo struct { + ss, se int +} + +type Stack struct { + stA [20]skipInfo + st []skipInfo + top int +} + +// Create a new Stack +func NewStack() *Stack { + s := &Stack{top: -1} + s.st = s.stA[:0] + return s +} + +// Return the number of items in the Stack +func (s *Stack) Len() int { + return (s.top + 1) +} + +// View the top item on the Stack +func (s *Stack) Peek() *skipInfo { + if s.top == -1 { + return nil + } + return &s.st[s.top] +} + +// Pop the top item of the Stack and return it +func (s *Stack) Pop() *skipInfo { + if s.top == -1 { + return nil + } + + s.top-- + return &s.st[(s.top + 1)] +} + +// Push a value onto the top of the Stack +func (s *Stack) Push(value skipInfo) { + s.top++ + if len(s.st) <= s.top { + s.st = append(s.st, value) + } else { + s.st[s.top] = value + } +} diff --git a/jsn/tree.go b/jsn/tree.go new file mode 100644 index 0000000..121500e --- /dev/null +++ b/jsn/tree.go @@ -0,0 +1,37 @@ +package jsn + +import ( + "bytes" + "encoding/json" +) + +func Tree(v []byte) (map[string]interface{}, bool, error) { + dec := json.NewDecoder(bytes.NewReader(v)) + array := false + + // read open bracket + + for i := range v { + if v[i] != ' ' { + array = (v[i] == '[') + break + } + } + + if array { + if _, err := dec.Token(); err != nil { + return nil, false, err + } + } + + // while the array contains values + var m map[string]interface{} + + // decode an array value (Message) + err := dec.Decode(&m) + if err != nil { + return nil, false, err + } + + return m, array, nil +} diff --git a/psql/bench.4 b/psql/bench.4 deleted file mode 100644 index a3e431f..0000000 --- a/psql/bench.4 +++ /dev/null @@ -1,7 +0,0 @@ -goos: darwin -goarch: amd64 -pkg: github.com/dosco/super-graph/psql -BenchmarkCompile-8 100000 16476 ns/op 3282 B/op 66 allocs/op -BenchmarkCompileParallel-8 300000 4639 ns/op 3324 B/op 66 allocs/op -PASS -ok github.com/dosco/super-graph/psql 3.274s diff --git a/psql/bench.5 b/psql/bench.5 deleted file mode 100644 index 9e9eaf1..0000000 --- a/psql/bench.5 +++ /dev/null @@ -1,7 +0,0 @@ -goos: darwin -goarch: amd64 -pkg: github.com/dosco/super-graph/psql -BenchmarkCompile-8 100000 15728 ns/op 3000 B/op 60 allocs/op -BenchmarkCompileParallel-8 300000 5077 ns/op 3023 B/op 60 allocs/op -PASS -ok github.com/dosco/super-graph/psql 3.318s diff --git a/psql/bench.6 b/psql/bench.6 deleted file mode 100644 index 9724f6f..0000000 --- a/psql/bench.6 +++ /dev/null @@ -1,7 +0,0 @@ -goos: darwin -goarch: amd64 -pkg: github.com/dosco/super-graph/psql -BenchmarkCompile-8 1000000 15997 ns/op 3048 B/op 58 allocs/op -BenchmarkCompileParallel-8 3000000 4722 ns/op 3073 B/op 58 allocs/op -PASS -ok github.com/dosco/super-graph/psql 35.024s diff --git a/psql/bench.7 b/psql/bench.7 deleted file mode 100644 index 1deee4b..0000000 --- a/psql/bench.7 +++ /dev/null @@ -1,7 +0,0 @@ -goos: darwin -goarch: amd64 -pkg: github.com/dosco/super-graph/psql -BenchmarkCompile-8 100000 16829 ns/op 2887 B/op 57 allocs/op -BenchmarkCompileParallel-8 300000 5450 ns/op 2911 B/op 57 allocs/op -PASS -ok github.com/dosco/super-graph/psql 3.561s diff --git a/psql/insert.go b/psql/insert.go new file mode 100644 index 0000000..9bc9d0c --- /dev/null +++ b/psql/insert.go @@ -0,0 +1,90 @@ +package psql + +import ( + "bytes" + "errors" + "io" + + "github.com/dosco/super-graph/jsn" + "github.com/dosco/super-graph/qcode" +) + +func (co *Compiler) compileMutation(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) { + if len(qc.Selects) == 0 { + return 0, errors.New("empty query") + } + + c := &compilerContext{w, qc.Selects, co} + root := &qc.Selects[0] + + c.w.WriteString(`WITH `) + c.w.WriteString(root.Table) + c.w.WriteString(` AS (`) + + if _, err := c.renderInsert(qc, w, vars); err != nil { + return 0, err + } + + c.w.WriteString(`) `) + + return c.compileQuery(qc, w) +} + +func (c *compilerContext) renderInsert(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) { + root := &qc.Selects[0] + + insert, ok := vars["insert"] + if !ok { + return 0, errors.New("Variable 'insert' not defined") + } + + jt, array, err := jsn.Tree(insert) + if err != nil { + return 0, err + } + + c.w.WriteString(`WITH input AS (SELECT {{insert}}::json AS j) INSERT INTO `) + c.w.WriteString(root.Table) + io.WriteString(c.w, " (") + c.renderInsertColumns(qc, w, jt) + io.WriteString(c.w, ")") + + c.w.WriteString(` SELECT `) + c.renderInsertColumns(qc, w, jt) + c.w.WriteString(` FROM input i, `) + + if array { + c.w.WriteString(`json_populate_recordset`) + } else { + c.w.WriteString(`json_populate_record`) + } + + c.w.WriteString(`(NULL::`) + c.w.WriteString(root.Table) + c.w.WriteString(`, i.j) t RETURNING * `) + + return 0, nil +} + +func (c *compilerContext) renderInsertColumns(qc *qcode.QCode, w *bytes.Buffer, + jt map[string]interface{}) (uint32, error) { + + ti, err := c.schema.GetTable(qc.Selects[0].Table) + if err != nil { + return 0, err + } + + i := 0 + for _, cn := range ti.ColumnNames { + if _, ok := jt[cn]; !ok { + continue + } + if i != 0 { + io.WriteString(c.w, ", ") + } + c.w.WriteString(cn) + i++ + } + + return 0, nil +} diff --git a/psql/insert_test.go b/psql/insert_test.go new file mode 100644 index 0000000..a5d1a25 --- /dev/null +++ b/psql/insert_test.go @@ -0,0 +1,65 @@ +package psql + +import ( + "encoding/json" + "fmt" + "testing" +) + +func singleInsert(t *testing.T) { + gql := `mutation { + product(id: 15, insert: $insert) { + id + name + } + }` + + sql := `test` + + vars := map[string]json.RawMessage{ + "insert": json.RawMessage(` { "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }`), + } + + resSQL, err := compileGQLToPSQL(gql, vars) + if err != nil { + t.Fatal(err) + } + + fmt.Println(">", string(resSQL)) + + if string(resSQL) != sql { + t.Fatal(errNotExpected) + } +} + +func bulkInsert(t *testing.T) { + gql := `mutation { + product(id: 15, insert: $insert) { + id + name + } + }` + + sql := `test` + + vars := map[string]json.RawMessage{ + "insert": json.RawMessage(` [{ "name": "my_name", "woo": { "hoo": "goo" }, "description": "my_desc" }]`), + } + + resSQL, err := compileGQLToPSQL(gql, vars) + if err != nil { + t.Fatal(err) + } + + fmt.Println(">", string(resSQL)) + + if string(resSQL) != sql { + t.Fatal(errNotExpected) + } +} + +func TestCompileInsert(t *testing.T) { + t.Run("singleInsert", singleInsert) + t.Run("bulkInsert", bulkInsert) + +} diff --git a/psql/psql.go b/psql/select.go similarity index 96% rename from psql/psql.go rename to psql/select.go index 00213c0..1442fc8 100644 --- a/psql/psql.go +++ b/psql/select.go @@ -2,6 +2,7 @@ package psql import ( "bytes" + "encoding/json" "errors" "fmt" "io" @@ -18,6 +19,8 @@ const ( closeBlock = 500 ) +type Variables map[string]json.RawMessage + type Config struct { Schema *DBSchema Vars map[string]string @@ -51,19 +54,30 @@ type compilerContext struct { *Compiler } -func (co *Compiler) CompileEx(qc *qcode.QCode) (uint32, []byte, error) { +func (co *Compiler) CompileEx(qc *qcode.QCode, vars Variables) (uint32, []byte, error) { w := &bytes.Buffer{} - skipped, err := co.Compile(qc, w) + skipped, err := co.Compile(qc, w, vars) return skipped, w.Bytes(), err } -func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) { - if len(qc.Query.Selects) == 0 { +func (co *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer, vars Variables) (uint32, error) { + switch qc.Type { + case qcode.QTQuery: + return co.compileQuery(qc, w) + case qcode.QTMutation: + return co.compileMutation(qc, w, vars) + } + + return 0, errors.New("unknown operation") +} + +func (co *Compiler) compileQuery(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) { + if len(qc.Selects) == 0 { return 0, errors.New("empty query") } - c := &compilerContext{w, qc.Query.Selects, co} - root := &qc.Query.Selects[0] + c := &compilerContext{w, qc.Selects, co} + root := &qc.Selects[0] st := NewStack() st.Push(root.ID + closeBlock) @@ -844,7 +858,7 @@ func (c *compilerContext) renderList(ex *qcode.Exp) { func (c *compilerContext) renderVal(ex *qcode.Exp, vars map[string]string) { - io.WriteString(c.w, ` (`) + //io.WriteString(c.w, ` (`) switch ex.Type { case qcode.ValBool, qcode.ValInt, qcode.ValFloat: if len(ex.Val) != 0 { @@ -852,21 +866,23 @@ func (c *compilerContext) renderVal(ex *qcode.Exp, } else { c.w.WriteString(`''`) } + case qcode.ValStr: c.w.WriteString(`'`) c.w.WriteString(ex.Val) c.w.WriteString(`'`) + case qcode.ValVar: if val, ok := vars[ex.Val]; ok { c.w.WriteString(val) } else { //fmt.Fprintf(w, `'{{%s}}'`, ex.Val) - c.w.WriteString(`'{{`) + c.w.WriteString(`{{`) c.w.WriteString(ex.Val) - c.w.WriteString(`}}'`) + c.w.WriteString(`}}`) } } - c.w.WriteString(`)`) + //c.w.WriteString(`)`) } func funcPrefixLen(fn string) int { diff --git a/psql/psql_test.go b/psql/select_test.go similarity index 96% rename from psql/psql_test.go rename to psql/select_test.go index 591c581..2e9e571 100644 --- a/psql/psql_test.go +++ b/psql/select_test.go @@ -125,13 +125,13 @@ func TestMain(m *testing.M) { os.Exit(m.Run()) } -func compileGQLToPSQL(gql string) ([]byte, error) { +func compileGQLToPSQL(gql string, vars Variables) ([]byte, error) { qc, err := qcompile.Compile([]byte(gql)) if err != nil { return nil, err } - _, sqlStmt, err := pcompile.CompileEx(qc) + _, sqlStmt, err := pcompile.CompileEx(qc, vars) if err != nil { return nil, err } @@ -164,7 +164,7 @@ func withComplexArgs(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products" ORDER BY "products_0_price_ob" DESC), '[]') AS "products" FROM (SELECT DISTINCT ON ("products_0_price_ob") row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products", "products_0"."price" AS "products_0_price_ob" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."id") < (28)) AND (("products"."id") >= (20))) LIMIT ('30') :: integer) AS "products_0" ORDER BY "products_0_price_ob" DESC LIMIT ('30') :: integer) AS "products_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql) + resSQL, err := compileGQLToPSQL(gql, nil) if err != nil { t.Fatal(err) } @@ -192,7 +192,7 @@ func withWhereMultiOr(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") < (20)) OR (("products"."price") > (10)) OR NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql) + resSQL, err := compileGQLToPSQL(gql, nil) if err != nil { t.Fatal(err) } @@ -218,7 +218,7 @@ func withWhereIsNull(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql) + resSQL, err := compileGQLToPSQL(gql, nil) if err != nil { t.Fatal(err) } @@ -244,7 +244,7 @@ func withWhereAndList(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."price" AS "price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name", "products"."price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."price") > (10)) AND NOT (("products"."id") IS NULL)) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql) + resSQL, err := compileGQLToPSQL(gql, nil) if err != nil { t.Fatal(err) } @@ -264,7 +264,7 @@ func fetchByID(t *testing.T) { sql := `SELECT json_object_agg('product', product) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "product_0"."id" AS "id", "product_0"."name" AS "name") AS "sel_0")) AS "product" FROM (SELECT "product"."id", "product"."name" FROM "products" AS "product" WHERE ((("product"."price") > (0)) AND (("product"."price") < (8)) AND (("id") = (15))) LIMIT ('1') :: integer) AS "product_0" LIMIT ('1') :: integer) AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql) + resSQL, err := compileGQLToPSQL(gql, nil) if err != nil { t.Fatal(err) } @@ -284,7 +284,7 @@ func searchQuery(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "sel_0")) AS "products" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql) + resSQL, err := compileGQLToPSQL(gql, nil) if err != nil { t.Fatal(err) } @@ -307,7 +307,7 @@ func oneToMany(t *testing.T) { sql := `SELECT json_object_agg('users', users) FROM (SELECT coalesce(json_agg("users"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "users_0"."email" AS "email", "products_1_join"."products" AS "products") AS "sel_0")) AS "users" FROM (SELECT "users"."email", "users"."id" FROM "users" WHERE ((("users"."id") = ('{{user_id}}'))) LIMIT ('20') :: integer) AS "users_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name", "products_1"."price" AS "price") AS "sel_1")) AS "products" FROM (SELECT "products"."name", "products"."price" FROM "products" WHERE ((("products"."user_id") = ("users_0"."id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "products_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "users_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql) + resSQL, err := compileGQLToPSQL(gql, nil) if err != nil { t.Fatal(err) } @@ -330,7 +330,7 @@ func belongsTo(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."price" AS "price", "users_1_join"."users" AS "users") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."price", "products"."user_id" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("users"), '[]') AS "users" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "users_1"."email" AS "email") AS "sel_1")) AS "users" FROM (SELECT "users"."email" FROM "users" WHERE ((("users"."id") = ("products_0"."user_id"))) LIMIT ('20') :: integer) AS "users_1" LIMIT ('20') :: integer) AS "users_1") AS "users_1_join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql) + resSQL, err := compileGQLToPSQL(gql, nil) if err != nil { t.Fatal(err) } @@ -353,7 +353,7 @@ func manyToMany(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "customers_1_join"."customers" AS "customers") AS "sel_0")) AS "products" FROM (SELECT "products"."name", "products"."id" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("customers"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "customers_1"."email" AS "email", "customers_1"."full_name" AS "full_name") AS "sel_1")) AS "customers" FROM (SELECT "customers"."email", "customers"."full_name" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_0"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_1" LIMIT ('20') :: integer) AS "customers_1") AS "customers_1_join" ON ('true') LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql) + resSQL, err := compileGQLToPSQL(gql, nil) if err != nil { t.Fatal(err) } @@ -376,7 +376,7 @@ func manyToManyReverse(t *testing.T) { sql := `SELECT json_object_agg('customers', customers) FROM (SELECT coalesce(json_agg("customers"), '[]') AS "customers" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "customers_0"."email" AS "email", "customers_0"."full_name" AS "full_name", "products_1_join"."products" AS "products") AS "sel_0")) AS "customers" FROM (SELECT "customers"."email", "customers"."full_name", "customers"."id" FROM "customers" LIMIT ('20') :: integer) AS "customers_0" LEFT OUTER JOIN LATERAL (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_1" FROM (SELECT "products_1"."name" AS "name") AS "sel_1")) AS "products" FROM (SELECT "products"."name" FROM "products" LEFT OUTER JOIN "purchases" ON (("purchases"."customer_id") = ("customers_0"."id")) WHERE ((("products"."id") = ("purchases"."product_id"))) LIMIT ('20') :: integer) AS "products_1" LIMIT ('20') :: integer) AS "products_1") AS "products_1_join" ON ('true') LIMIT ('20') :: integer) AS "customers_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql) + resSQL, err := compileGQLToPSQL(gql, nil) if err != nil { t.Fatal(err) } @@ -396,7 +396,7 @@ func aggFunction(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."name" AS "name", "products_0"."count_price" AS "count_price") AS "sel_0")) AS "products" FROM (SELECT "products"."name", count("products"."price") AS "count_price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8))) GROUP BY "products"."name" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql) + resSQL, err := compileGQLToPSQL(gql, nil) if err != nil { t.Fatal(err) } @@ -416,7 +416,7 @@ func aggFunctionWithFilter(t *testing.T) { sql := `SELECT json_object_agg('products', products) FROM (SELECT coalesce(json_agg("products"), '[]') AS "products" FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."max_price" AS "max_price") AS "sel_0")) AS "products" FROM (SELECT "products"."id", max("products"."price") AS "max_price" FROM "products" WHERE ((("products"."price") > (0)) AND (("products"."price") < (8)) AND (("products"."id") > (10))) GROUP BY "products"."id" LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "products_0") AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql) + resSQL, err := compileGQLToPSQL(gql, nil) if err != nil { t.Fatal(err) } @@ -436,7 +436,7 @@ func queryWithVariables(t *testing.T) { sql := `SELECT json_object_agg('product', product) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "product_0"."id" AS "id", "product_0"."name" AS "name") AS "sel_0")) AS "product" FROM (SELECT "product"."id", "product"."name" FROM "products" AS "product" WHERE ((("product"."price") > (0)) AND (("product"."price") < (8)) AND (("product"."price") = ('{{product_price}}')) AND (("id") = ('{{product_id}}'))) LIMIT ('1') :: integer) AS "product_0" LIMIT ('1') :: integer) AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql) + resSQL, err := compileGQLToPSQL(gql, nil) if err != nil { t.Fatal(err) } @@ -455,7 +455,7 @@ func syntheticTables(t *testing.T) { sql := `SELECT json_object_agg('me', me) FROM (SELECT row_to_json((SELECT "sel_0" FROM (SELECT "me_0"."email" AS "email") AS "sel_0")) AS "me" FROM (SELECT "me"."email" FROM "users" AS "me" WHERE ((("me"."id") = ('{{user_id}}'))) LIMIT ('1') :: integer) AS "me_0" LIMIT ('1') :: integer) AS "done_1337";` - resSQL, err := compileGQLToPSQL(gql) + resSQL, err := compileGQLToPSQL(gql, nil) if err != nil { t.Fatal(err) } @@ -465,7 +465,7 @@ func syntheticTables(t *testing.T) { } } -func TestCompileGQL(t *testing.T) { +func TestCompileSelect(t *testing.T) { t.Run("withComplexArgs", withComplexArgs) t.Run("withWhereAndList", withWhereAndList) t.Run("withWhereIsNull", withWhereIsNull) @@ -519,7 +519,7 @@ func BenchmarkCompile(b *testing.B) { b.Fatal(err) } - _, err = pcompile.Compile(qc, w) + _, err = pcompile.Compile(qc, w, nil) if err != nil { b.Fatal(err) } @@ -540,7 +540,7 @@ func BenchmarkCompileParallel(b *testing.B) { b.Fatal(err) } - _, err = pcompile.Compile(qc, w) + _, err = pcompile.Compile(qc, w, nil) if err != nil { b.Fatal(err) } diff --git a/psql/tables.go b/psql/tables.go index 84e11d5..2c94230 100644 --- a/psql/tables.go +++ b/psql/tables.go @@ -106,11 +106,12 @@ type DBSchema struct { } type DBTableInfo struct { - Name string - Singular bool - PrimaryCol string - TSVCol string - Columns map[string]*DBColumn + Name string + Singular bool + PrimaryCol string + TSVCol string + Columns map[string]*DBColumn + ColumnNames []string } type RelType int @@ -162,25 +163,30 @@ func (s *DBSchema) updateSchema( // Foreign key columns in current table colByID := make(map[int]*DBColumn) columns := make(map[string]*DBColumn, len(cols)) + colNames := make([]string, len(cols)) for i := range cols { c := cols[i] - columns[strings.ToLower(c.Name)] = cols[i] + name := strings.ToLower(c.Name) + columns[name] = cols[i] + colNames = append(colNames, name) colByID[c.ID] = cols[i] } singular := strings.ToLower(flect.Singularize(t.Name)) s.t[singular] = &DBTableInfo{ - Name: t.Name, - Singular: true, - Columns: columns, + Name: t.Name, + Singular: true, + Columns: columns, + ColumnNames: colNames, } plural := strings.ToLower(flect.Pluralize(t.Name)) s.t[plural] = &DBTableInfo{ - Name: t.Name, - Singular: false, - Columns: columns, + Name: t.Name, + Singular: false, + Columns: columns, + ColumnNames: colNames, } ct := strings.ToLower(t.Name) diff --git a/qcode/parse.go b/qcode/parse.go index 7bd10fa..c07ab45 100644 --- a/qcode/parse.go +++ b/qcode/parse.go @@ -100,19 +100,7 @@ var lexPool = sync.Pool{ } func Parse(gql []byte) (*Operation, error) { - return parseSelectionSet(nil, gql) -} - -func ParseQuery(gql []byte) (*Operation, error) { - op := opPool.Get().(*Operation) - op.Reset() - - op.Type = opQuery - op.Name = "" - op.Fields = op.fieldsA[:0] - op.Args = op.argsA[:0] - - return parseSelectionSet(op, gql) + return parseSelectionSet(gql) } func ParseArgValue(argVal string) (*Node, error) { @@ -134,7 +122,7 @@ func ParseArgValue(argVal string) (*Node, error) { return op, err } -func parseSelectionSet(op *Operation, gql []byte) (*Operation, error) { +func parseSelectionSet(gql []byte) (*Operation, error) { var err error if len(gql) == 0 { @@ -154,14 +142,28 @@ func parseSelectionSet(op *Operation, gql []byte) (*Operation, error) { items: l.items, } - if op == nil { - op, err = p.parseOp() - } else { - if p.peek(itemObjOpen) { - p.ignore() - } + var op *Operation + if p.peek(itemObjOpen) { + p.ignore() + } + + if p.peek(itemName) { + op = opPool.Get().(*Operation) + op.Reset() + + op.Type = opQuery + op.Name = "" + op.Fields = op.fieldsA[:0] + op.Args = op.argsA[:0] op.Fields, err = p.parseFields(op.Fields) + + } else { + op, err = p.parseOp() + + if err != nil { + return nil, err + } } lexPool.Put(l) diff --git a/qcode/parse_test.go b/qcode/parse_test.go index 6edd9ca..dba397f 100644 --- a/qcode/parse_test.go +++ b/qcode/parse_test.go @@ -45,10 +45,10 @@ func compareOp(op1, op2 Operation) error { } */ -func TestCompile(t *testing.T) { +func TestCompile1(t *testing.T) { qcompile, _ := NewCompiler(Config{}) - _, err := qcompile.CompileQuery([]byte(` + _, err := qcompile.Compile([]byte(` product(id: 15) { id name @@ -59,9 +59,39 @@ func TestCompile(t *testing.T) { } } +func TestCompile2(t *testing.T) { + qcompile, _ := NewCompiler(Config{}) + + _, err := qcompile.Compile([]byte(` + query { product(id: 15) { + id + name + } }`)) + + if err != nil { + t.Fatal(err) + } +} + +func TestCompile3(t *testing.T) { + qcompile, _ := NewCompiler(Config{}) + + _, err := qcompile.Compile([]byte(` + mutation { + product(id: 15, name: "Test") { + id + name + } + }`)) + + if err != nil { + t.Fatal(err) + } +} + func TestInvalidCompile1(t *testing.T) { qcompile, _ := NewCompiler(Config{}) - _, err := qcompile.CompileQuery([]byte(`#`)) + _, err := qcompile.Compile([]byte(`#`)) if err == nil { t.Fatal(errors.New("expecting an error")) @@ -70,7 +100,7 @@ func TestInvalidCompile1(t *testing.T) { func TestInvalidCompile2(t *testing.T) { qcompile, _ := NewCompiler(Config{}) - _, err := qcompile.CompileQuery([]byte(`{u(where:{not:0})}`)) + _, err := qcompile.Compile([]byte(`{u(where:{not:0})}`)) if err == nil { t.Fatal(errors.New("expecting an error")) @@ -79,7 +109,7 @@ func TestInvalidCompile2(t *testing.T) { func TestEmptyCompile(t *testing.T) { qcompile, _ := NewCompiler(Config{}) - _, err := qcompile.CompileQuery([]byte(``)) + _, err := qcompile.Compile([]byte(``)) if err == nil { t.Fatal(errors.New("expecting an error")) @@ -114,7 +144,7 @@ func BenchmarkQCompile(b *testing.B) { b.ReportAllocs() for n := 0; n < b.N; n++ { - _, err := qcompile.CompileQuery(gql) + _, err := qcompile.Compile(gql) if err != nil { b.Fatal(err) @@ -130,7 +160,7 @@ func BenchmarkQCompileP(b *testing.B) { b.RunParallel(func(pb *testing.PB) { for pb.Next() { - _, err := qcompile.CompileQuery(gql) + _, err := qcompile.Compile(gql) if err != nil { b.Fatal(err) diff --git a/qcode/qcode.go b/qcode/qcode.go index 0335482..306f745 100644 --- a/qcode/qcode.go +++ b/qcode/qcode.go @@ -10,15 +10,17 @@ import ( "github.com/gobuffalo/flect" ) +type QType int + const ( maxSelectors = 30 + + QTQuery QType = iota + 1 + QTMutation ) type QCode struct { - Query *Query -} - -type Query struct { + Type QType Selects []Select } @@ -149,6 +151,11 @@ type Compiler struct { ka bool } +var opMap = map[parserType]QType{ + opQuery: QTQuery, + opMutate: QTMutation, +} + var expPool = sync.Pool{ New: func() interface{} { return new(Exp) }, } @@ -196,44 +203,23 @@ func (com *Compiler) Compile(query []byte) (*QCode, error) { return nil, err } - switch op.Type { - case opQuery: - qc.Query, err = com.compileQuery(op) - case opMutate: - case opSub: - default: - err = fmt.Errorf("Unknown operation type %d", op.Type) - } - + qc.Selects, err = com.compileQuery(op) if err != nil { return nil, err } + if t, ok := opMap[op.Type]; ok { + qc.Type = t + } else { + return nil, fmt.Errorf("Unknown operation type %d", op.Type) + } + opPool.Put(op) return &qc, nil } -func (com *Compiler) CompileQuery(query []byte) (*QCode, error) { - var err error - - op, err := ParseQuery(query) - if err != nil { - return nil, err - } - - qc := &QCode{} - qc.Query, err = com.compileQuery(op) - opPool.Put(op) - - if err != nil { - return nil, err - } - - return qc, nil -} - -func (com *Compiler) compileQuery(op *Operation) (*Query, error) { +func (com *Compiler) compileQuery(op *Operation) ([]Select, error) { id := int32(0) parentID := int32(0) @@ -344,7 +330,7 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) { return nil, errors.New("invalid query") } - return &Query{selects[:id]}, nil + return selects[:id], nil } func (com *Compiler) compileArgs(sel *Select, args []Arg) error { @@ -661,14 +647,6 @@ func (com *Compiler) compileArgOffset(sel *Select, arg *Arg) error { return nil } -func compileMutate() (*Query, error) { - return nil, nil -} - -func compileSub() (*Query, error) { - return nil, nil -} - func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) { name := node.Name if name[0] == '_' { diff --git a/serv/allow.go b/serv/allow.go index 0816d96..0fec2ca 100644 --- a/serv/allow.go +++ b/serv/allow.go @@ -1,6 +1,7 @@ package serv import ( + "encoding/json" "fmt" "io/ioutil" "log" @@ -9,9 +10,15 @@ import ( "strings" ) +const ( + AL_QUERY int = iota + 1 + AL_VARS +) + type allowItem struct { - uri string - gql string + uri string + gql string + vars json.RawMessage } var _allowList allowList @@ -77,8 +84,9 @@ func (al *allowList) add(req *gqlReq) { } al.saveChan <- &allowItem{ - uri: req.ref, - gql: req.Query, + uri: req.ref, + gql: req.Query, + vars: req.Vars, } } @@ -93,32 +101,62 @@ func (al *allowList) load() { } var uri string + var varBytes []byte s, e, c := 0, 0, 0 + ty := 0 for { if c == 0 && b[e] == '#' { s = e - for b[e] != '\n' && e < len(b) { + for e < len(b) && b[e] != '\n' { e++ } if (e - s) > 2 { uri = strings.TrimSpace(string(b[(s + 1):e])) } } - if b[e] == '{' { + + if e >= len(b) { + break + } + + if matchPrefix(b, e, "query") || matchPrefix(b, e, "mutation") { if c == 0 { s = e } + ty = AL_QUERY + } else if matchPrefix(b, e, "variables") { + if c == 0 { + s = e + len("variables") + 1 + } + ty = AL_VARS + } else if b[e] == '{' { c++ + } else if b[e] == '}' { c-- + if c == 0 { - q := b[s:(e + 1)] - al.list[gqlHash(q)] = &allowItem{ - uri: uri, - gql: string(q), + if ty == AL_QUERY { + q := string(b[s:(e + 1)]) + + item := &allowItem{ + uri: uri, + gql: q, + } + + if len(varBytes) != 0 { + item.vars = varBytes + } + + al.list[gqlHash(q, varBytes)] = item + varBytes = nil + + } else if ty == AL_VARS { + varBytes = b[s:(e + 1)] } + ty = 0 } } @@ -130,7 +168,7 @@ func (al *allowList) load() { } func (al *allowList) save(item *allowItem) { - al.list[gqlHash([]byte(item.gql))] = item + al.list[gqlHash(item.gql, item.vars)] = item f, err := os.Create(al.filepath) if err != nil { @@ -141,10 +179,10 @@ func (al *allowList) save(item *allowItem) { defer f.Close() keys := []string{} - urlMap := make(map[string][]string) + urlMap := make(map[string][]*allowItem) for _, v := range al.list { - urlMap[v.uri] = append(urlMap[v.uri], v.gql) + urlMap[v.uri] = append(urlMap[v.uri], v) } for k := range urlMap { @@ -159,7 +197,28 @@ func (al *allowList) save(item *allowItem) { f.WriteString(fmt.Sprintf("# %s\n\n", k)) for i := range v { - f.WriteString(fmt.Sprintf("query %s\n\n", v[i])) + if len(v[i].vars) != 0 { + vj, err := json.MarshalIndent(v[i].vars, "", "\t") + if err != nil { + logger.Warn().Err(err).Msg("Failed to write allow list 'vars' to file") + continue + } + f.WriteString(fmt.Sprintf("variables %s\n\n", vj)) + } + + f.WriteString(fmt.Sprintf("%s\n\n", v[i].gql)) } } } + +func matchPrefix(b []byte, i int, s string) bool { + if (len(b) - i) < len(s) { + return false + } + for n := 0; n < len(s); n++ { + if b[(i+n)] != s[n] { + return false + } + } + return true +} diff --git a/serv/core.go b/serv/core.go index c5bc3ab..a8e114a 100644 --- a/serv/core.go +++ b/serv/core.go @@ -14,6 +14,7 @@ import ( "github.com/cespare/xxhash/v2" "github.com/dosco/super-graph/jsn" + "github.com/dosco/super-graph/psql" "github.com/dosco/super-graph/qcode" "github.com/go-pg/pg" "github.com/valyala/fasttemplate" @@ -42,7 +43,7 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error { if conf.UseAllowList { var ps *preparedItem - data, ps, err = c.resolvePreparedSQL([]byte(c.req.Query)) + data, ps, err = c.resolvePreparedSQL(c.req.Query) if err != nil { return err } @@ -52,7 +53,7 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error { } else { - qc, err = qcompile.CompileQuery([]byte(c.req.Query)) + qc, err = qcompile.Compile([]byte(c.req.Query)) if err != nil { return err } @@ -67,7 +68,7 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error { return c.render(w, data) } - sel := qc.Query.Selects + sel := qc.Selects h := xxhash.New() // fetch the field name used within the db response json @@ -252,8 +253,8 @@ func (c *coreContext) resolveRemotes( return to, cerr } -func (c *coreContext) resolvePreparedSQL(gql []byte) ([]byte, *preparedItem, error) { - ps, ok := _preparedList[gqlHash(gql)] +func (c *coreContext) resolvePreparedSQL(gql string) ([]byte, *preparedItem, error) { + ps, ok := _preparedList[gqlHash(gql, c.req.Vars)] if !ok { return nil, nil, errUnauthorized } @@ -266,17 +267,22 @@ func (c *coreContext) resolvePreparedSQL(gql []byte) ([]byte, *preparedItem, err return nil, nil, err } - fmt.Printf("PRE: %#v %#v\n", ps.stmt, vars) + fmt.Printf("PRE: %v\n", ps.stmt) return []byte(root), ps, nil } func (c *coreContext) resolveSQL(qc *qcode.QCode) ( []byte, uint32, error) { - stmt := &bytes.Buffer{} - skipped, err := pcompile.Compile(qc, stmt) + vars := make(map[string]json.RawMessage) + + if err := json.Unmarshal(c.req.Vars, &vars); err != nil { + return nil, 0, err + } + + skipped, err := pcompile.Compile(qc, stmt, psql.Variables(vars)) if err != nil { return nil, 0, err } @@ -284,7 +290,7 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode) ( t := fasttemplate.New(stmt.String(), openVar, closeVar) stmt.Reset() - _, err = t.Execute(stmt, varMap(c)) + _, err = t.ExecuteFunc(stmt, varMap(c)) if err == errNoUserID && authFailBlock == authFailBlockPerQuery && @@ -317,10 +323,10 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode) ( return nil, 0, err } - if conf.EnableTracing && len(qc.Query.Selects) != 0 { + if conf.EnableTracing && len(qc.Selects) != 0 { c.addTrace( - qc.Query.Selects, - qc.Query.Selects[0].ID, + qc.Selects, + qc.Selects[0].ID, st) } diff --git a/serv/core_test.go b/serv/core_test.go new file mode 100644 index 0000000..922b39f --- /dev/null +++ b/serv/core_test.go @@ -0,0 +1,35 @@ +package serv + +/* + +func simpleMutation(t *testing.T) { + gql := `mutation { + product(id: 15, insert: { name: "Test", price: 20.5 }) { + id + name + } + }` + + sql := `test` + + backgroundCtx := context.Background() + ctx := &coreContext{Context: backgroundCtx} + + resSQL, err := compileGQLToPSQL(gql) + if err != nil { + t.Fatal(err) + } + + fmt.Println(">", string(resSQL)) + + if string(resSQL) != sql { + t.Fatal(errNotExpected) + } +} + +func TestCompileGQL(t *testing.T) { + t.Run("withComplexArgs", withComplexArgs) + t.Run("simpleMutation", simpleMutation) +} + +*/ diff --git a/serv/http.go b/serv/http.go index ca0c061..270e563 100644 --- a/serv/http.go +++ b/serv/http.go @@ -26,13 +26,13 @@ var ( ) type gqlReq struct { - OpName string `json:"operationName"` - Query string `json:"query"` - Vars variables `json:"variables"` + OpName string `json:"operationName"` + Query string `json:"query"` + Vars json.RawMessage `json:"variables"` ref string } -type variables map[string]interface{} +type variables map[string]json.RawMessage type gqlResp struct { Error string `json:"error,omitempty"` diff --git a/serv/prepare.go b/serv/prepare.go index 6fa23bb..906b55d 100644 --- a/serv/prepare.go +++ b/serv/prepare.go @@ -2,9 +2,11 @@ package serv import ( "bytes" + "encoding/json" "fmt" "io" + "github.com/dosco/super-graph/psql" "github.com/dosco/super-graph/qcode" "github.com/go-pg/pg" "github.com/valyala/fasttemplate" @@ -12,7 +14,7 @@ import ( type preparedItem struct { stmt *pg.Stmt - args []string + args [][]byte skipped uint32 qc *qcode.QCode } @@ -25,36 +27,46 @@ func initPreparedList() { _preparedList = make(map[string]*preparedItem) for k, v := range _allowList.list { - err := prepareStmt(k, v.gql) + err := prepareStmt(k, v.gql, v.vars) if err != nil { panic(err) } } } -func prepareStmt(key, gql string) error { +func prepareStmt(key, gql string, varBytes json.RawMessage) error { if len(gql) == 0 || len(key) == 0 { return nil } - qc, err := qcompile.CompileQuery([]byte(gql)) + qc, err := qcompile.Compile([]byte(gql)) if err != nil { return err } + var vars map[string]json.RawMessage + + if len(varBytes) != 0 { + vars = make(map[string]json.RawMessage) + + if err := json.Unmarshal(varBytes, &vars); err != nil { + return err + } + } + buf := &bytes.Buffer{} - skipped, err := pcompile.Compile(qc, buf) + skipped, err := pcompile.Compile(qc, buf, psql.Variables(vars)) if err != nil { return err } - t := fasttemplate.New(buf.String(), `('{{`, `}}')`) - am := make([]string, 0, 5) + t := fasttemplate.New(buf.String(), `{{`, `}}`) + am := make([][]byte, 0, 5) i := 0 finalSQL := t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) { - am = append(am, tag) + am = append(am, []byte(tag)) i++ return w.Write([]byte(fmt.Sprintf("$%d", i))) }) diff --git a/serv/utils.go b/serv/utils.go index dd8e9d0..1bbcc16 100644 --- a/serv/utils.go +++ b/serv/utils.go @@ -4,8 +4,12 @@ import ( "bytes" "crypto/sha1" "encoding/hex" + "io" + "sort" + "strings" "github.com/cespare/xxhash/v2" + "github.com/dosco/super-graph/jsn" ) func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 { @@ -17,8 +21,8 @@ func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 { return v } -func gqlHash(b []byte) string { - b = bytes.TrimSpace(b) +func gqlHash(b string, vars []byte) string { + b = strings.TrimSpace(b) h := sha1.New() s, e := 0, 0 @@ -45,13 +49,27 @@ func gqlHash(b []byte) string { if e != 0 { b0 = b[(e - 1)] } - h.Write(bytes.ToLower(b[s:e])) + io.WriteString(h, strings.ToLower(b[s:e])) } if e >= len(b) { break } } + if vars == nil { + return hex.EncodeToString(h.Sum(nil)) + } + + fields := jsn.Keys([]byte(vars)) + + sort.Slice(fields, func(i, j int) bool { + return bytes.Compare(fields[i], fields[j]) == -1 + }) + + for i := range fields { + h.Write(fields[i]) + } + return hex.EncodeToString(h.Sum(nil)) } diff --git a/serv/utils_test.go b/serv/utils_test.go index d382005..bb628af 100644 --- a/serv/utils_test.go +++ b/serv/utils_test.go @@ -6,7 +6,7 @@ import ( ) func TestRelaxHash1(t *testing.T) { - var v1 = []byte(` + var v1 = ` products( limit: 30, @@ -14,18 +14,18 @@ func TestRelaxHash1(t *testing.T) { id name price - }`) + }` - var v2 = []byte(` + var v2 = ` products (limit: 30, where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) { id name price - } `) + } ` - h1 := gqlHash(v1) - h2 := gqlHash(v2) + h1 := gqlHash(v1, nil) + h2 := gqlHash(v2, nil) if strings.Compare(h1, h2) != 0 { t.Fatal("Hashes don't match they should") @@ -33,7 +33,7 @@ func TestRelaxHash1(t *testing.T) { } func TestRelaxHash2(t *testing.T) { - var v1 = []byte(` + var v1 = ` { products( limit: 30 @@ -49,12 +49,119 @@ func TestRelaxHash2(t *testing.T) { email } } - }`) + }` - var v2 = []byte(` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } `) + var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } ` - h1 := gqlHash(v1) - h2 := gqlHash(v2) + h1 := gqlHash(v1, nil) + h2 := gqlHash(v2, nil) + + if strings.Compare(h1, h2) != 0 { + t.Fatal("Hashes don't match they should") + } +} + +func TestRelaxHashWithVars1(t *testing.T) { + var q1 = ` + products( + limit: 30, + + where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) { + id + name + price + }` + + var v1 = ` + { + "insert": { + "name": "Hello", + "description": "World", + "created_at": "now", + "updated_at": "now", + "test": { "type2": "b", "type1": "a" } + }, + "user": 123 + }` + + var q2 = ` + products + (limit: 30, where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) { + id + name + price + } ` + + var v2 = `{ + "insert": { + "created_at": "now", + "test": { "type1": "a", "type2": "b" }, + "name": "Hello", + "updated_at": "now", + "description": "World" + }, + "user": 123 + }` + + h1 := gqlHash(q1, []byte(va1)) + h2 := gqlHash(q2, []byte(va2)) + + if strings.Compare(h1, h2) != 0 { + t.Fatal("Hashes don't match they should") + } +} + +func TestRelaxHashWithVars2(t *testing.T) { + var q1 = ` + products( + limit: 30, + + where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) { + id + name + price + }` + + var v1 = ` + { + "insert": [{ + "name": "Hello", + "description": "World", + "created_at": "now", + "updated_at": "now", + "test": { "type2": "b", "type1": "a" } + }, + { + "name": "Hello", + "description": "World", + "created_at": "now", + "updated_at": "now", + "test": { "type2": "b", "type1": "a" } + }], + "user": 123 + }` + + var q2 = ` + products + (limit: 30, where: { id: { AND: { greater_or_equals: 20, lt: 28 } } }) { + id + name + price + } ` + + var v2 = `{ + "insert": { + "created_at": "now", + "test": { "type1": "a", "type2": "b" }, + "name": "Hello", + "updated_at": "now", + "description": "World" + }, + "user": 123 + }` + + h1 := gqlHash(q1, []byte(va1)) + h2 := gqlHash(q2, []byte(va2)) if strings.Compare(h1, h2) != 0 { t.Fatal("Hashes don't match they should") diff --git a/serv/vars.go b/serv/vars.go index 01be528..6ad9da6 100644 --- a/serv/vars.go +++ b/serv/vars.go @@ -1,95 +1,107 @@ package serv import ( + "bytes" + "fmt" "io" - "strconv" - "strings" - "github.com/valyala/fasttemplate" + "github.com/dosco/super-graph/jsn" ) -func varMap(ctx *coreContext) variables { - userIDFn := func(w io.Writer, _ string) (int, error) { - if v := ctx.Value(userIDKey); v != nil { - return w.Write([]byte(v.(string))) - } - return 0, errNoUserID - } +func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) { + return func(w io.Writer, tag string) (int, error) { + switch tag { + case "user_id": + if v := ctx.Value(userIDKey); v != nil { + return stringVar(w, v.(string)) + } + return 0, errNoUserID - userIDProviderFn := func(w io.Writer, _ string) (int, error) { - if v := ctx.Value(userIDProviderKey); v != nil { - return w.Write([]byte(v.(string))) - } - return 0, errNoUserID - } - - userIDTag := fasttemplate.TagFunc(userIDFn) - userIDProviderTag := fasttemplate.TagFunc(userIDProviderFn) - - vm := variables{ - "user_id": userIDTag, - "user_id_provider": userIDProviderTag, - "USER_ID": userIDTag, - "USER_ID_PROVIDER": userIDProviderTag, - } - - for k, v := range ctx.req.Vars { - var buf []byte - k = strings.ToLower(k) - - if _, ok := vm[k]; ok { - continue + case "user_id_provider": + if v := ctx.Value(userIDProviderKey); v != nil { + return stringVar(w, v.(string)) + } + return 0, errNoUserID } - switch val := v.(type) { - case string: - vm[k] = val - case int: - vm[k] = strconv.AppendInt(buf, int64(val), 10) - case int64: - vm[k] = strconv.AppendInt(buf, val, 10) - case float64: - vm[k] = strconv.AppendFloat(buf, val, 'f', -1, 64) + fields := jsn.Get(ctx.req.Vars, [][]byte{[]byte(tag)}) + if len(fields) == 0 { + return 0, fmt.Errorf("variable '%s' not found", tag) } + + is := false + + for i := range fields[0].Value { + c := fields[0].Value[i] + if c != ' ' { + is = (c == '"') || (c == '{') || (c == '[') + break + } + } + + if is { + return stringVarB(w, fields[0].Value) + } + + w.Write(fields[0].Value) + return 0, nil } - return vm } -func varList(ctx *coreContext, args []string) []interface{} { - vars := make([]interface{}, 0, len(args)) +func varList(ctx *coreContext, args [][]byte) []interface{} { + vars := make([]interface{}, len(args)) - for k, v := range ctx.req.Vars { - ctx.req.Vars[strings.ToLower(k)] = v + var fields map[string]interface{} + var err error + + if len(ctx.req.Vars) != 0 { + fields, _, err = jsn.Tree(ctx.req.Vars) + + if err != nil { + logger.Warn().Err(err).Msg("Failed to parse variables") + } } for i := range args { - arg := strings.ToLower(args[i]) + av := args[i] - if arg == "user_id" { + switch { + case bytes.Equal(av, []byte("user_id")): if v := ctx.Value(userIDKey); v != nil { - vars = append(vars, v.(string)) + vars[i] = v.(string) } - } - if arg == "user_id_provider" { + case bytes.Equal(av, []byte("user_id_provider")): if v := ctx.Value(userIDProviderKey); v != nil { - vars = append(vars, v.(string)) + vars[i] = v.(string) } - } - if v, ok := ctx.req.Vars[arg]; ok { - switch val := v.(type) { - case string: - vars = append(vars, val) - case int: - vars = append(vars, strconv.FormatInt(int64(val), 10)) - case int64: - vars = append(vars, strconv.FormatInt(int64(val), 10)) - case float64: - vars = append(vars, strconv.FormatFloat(val, 'f', -1, 64)) + default: + if v, ok := fields[string(av)]; ok { + vars[i] = v } } } return vars } + +func stringVar(w io.Writer, v string) (int, error) { + if n, err := w.Write([]byte(`'`)); err != nil { + return n, err + } + if n, err := w.Write([]byte(v)); err != nil { + return n, err + } + return w.Write([]byte(`'`)) +} + +func stringVarB(w io.Writer, v []byte) (int, error) { + if n, err := w.Write([]byte(`'`)); err != nil { + return n, err + } + if n, err := w.Write(v); err != nil { + return n, err + } + return w.Write([]byte(`'`)) +}