From 33f3fefbf3439b1f2a3064a3808faa2c788674e5 Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Sat, 6 Jun 2020 17:52:21 -0400 Subject: [PATCH] feat: add support for graphql fragments --- core/core.go | 2 - core/internal/psql/query_test.go | 77 +++++ core/internal/psql/tests.sql | 11 +- core/internal/qcode/lex.go | 47 +-- core/internal/qcode/parse.go | 551 +++++++++++++++++++----------- core/internal/qcode/parse_test.go | 96 +++++- 6 files changed, 533 insertions(+), 251 deletions(-) diff --git a/core/core.go b/core/core.go index 31df8d2..fffb1a2 100644 --- a/core/core.go +++ b/core/core.go @@ -196,8 +196,6 @@ func (c *scontext) resolvePreparedSQL() ([]byte, *stmt, error) { return nil, nil, err } - fmt.Println(">>", varsList) - if useTx { row = tx.Stmt(q.sd).QueryRow(varsList...) } else { diff --git a/core/internal/psql/query_test.go b/core/internal/psql/query_test.go index 6960584..113f84f 100644 --- a/core/internal/psql/query_test.go +++ b/core/internal/psql/query_test.go @@ -307,6 +307,80 @@ func multiRoot(t *testing.T) { compileGQLToPSQL(t, gql, nil, "user") } +func withFragment1(t *testing.T) { + gql := ` + fragment userFields1 on user { + id + email + } + + query { + users { + ...userFields2 + + created_at + ...userFields1 + } + } + + fragment userFields2 on user { + first_name + last_name + }` + + compileGQLToPSQL(t, gql, nil, "anon") +} + +func withFragment2(t *testing.T) { + gql := ` + query { + users { + ...userFields2 + + created_at + ...userFields1 + } + } + + fragment userFields1 on user { + id + email + } + + fragment userFields2 on user { + first_name + last_name + }` + + compileGQLToPSQL(t, gql, nil, "anon") +} + +func withFragment3(t *testing.T) { + gql := ` + + fragment userFields1 on user { + id + email + } + + fragment userFields2 on user { + first_name + last_name + } + + query { + users { + ...userFields2 + + created_at + ...userFields1 + } + } +` + + compileGQLToPSQL(t, gql, nil, "anon") +} + func withCursor(t *testing.T) { gql := `query { Products( @@ -400,6 +474,9 @@ func TestCompileQuery(t *testing.T) { t.Run("queryWithVariables", queryWithVariables) t.Run("withWhereOnRelations", withWhereOnRelations) t.Run("multiRoot", multiRoot) + t.Run("withFragment1", withFragment1) + t.Run("withFragment2", withFragment2) + t.Run("withFragment3", withFragment3) t.Run("jsonColumnAsTable", jsonColumnAsTable) t.Run("withCursor", withCursor) t.Run("nullForAuthRequiredInAnon", nullForAuthRequiredInAnon) diff --git a/core/internal/psql/tests.sql b/core/internal/psql/tests.sql index c5feb4c..53dfbd9 100644 --- a/core/internal/psql/tests.sql +++ b/core/internal/psql/tests.sql @@ -86,6 +86,12 @@ SELECT jsonb_build_object('product', "__sj_0"."json") as "__root" FROM (SELECT t SELECT jsonb_build_object('users', "__sj_0"."json") as "__root" FROM (SELECT coalesce(jsonb_agg("__sj_0"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "users_0"."id" AS "id", "users_0"."email" AS "email" FROM (SELECT "users"."id", "users"."email" FROM "users" WHERE (NOT EXISTS (SELECT 1 FROM products WHERE (("products"."user_id") = ("users"."id")) AND ((("products"."price") > '3' :: numeric(7,2))))) LIMIT ('20') :: integer) AS "users_0") AS "__sr_0") AS "__sj_0") AS "__sj_0" === RUN TestCompileQuery/multiRoot SELECT jsonb_build_object('customer', "__sj_0"."json", 'user', "__sj_1"."json", 'product', "__sj_2"."json") as "__root" FROM (SELECT to_jsonb("__sr_2".*) AS "json"FROM (SELECT "products_2"."id" AS "id", "products_2"."name" AS "name", "__sj_3"."json" AS "customers", "__sj_4"."json" AS "customer" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE (((("products"."price") > '0' :: numeric(7,2)) AND (("products"."price") < '8' :: numeric(7,2)))) LIMIT ('1') :: integer) AS "products_2" LEFT OUTER JOIN LATERAL (SELECT to_jsonb("__sr_4".*) AS "json"FROM (SELECT "customers_4"."email" AS "email" FROM (SELECT "customers"."email" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_2"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('1') :: integer) AS "customers_4") AS "__sr_4") AS "__sj_4" ON ('true') LEFT OUTER JOIN LATERAL (SELECT coalesce(jsonb_agg("__sj_3"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_3".*) AS "json"FROM (SELECT "customers_3"."email" AS "email" FROM (SELECT "customers"."email" FROM "customers" LEFT OUTER JOIN "purchases" ON (("purchases"."product_id") = ("products_2"."id")) WHERE ((("customers"."id") = ("purchases"."customer_id"))) LIMIT ('20') :: integer) AS "customers_3") AS "__sr_3") AS "__sj_3") AS "__sj_3" ON ('true')) AS "__sr_2") AS "__sj_2", (SELECT to_jsonb("__sr_1".*) AS "json"FROM (SELECT "users_1"."id" AS "id", "users_1"."email" AS "email" FROM (SELECT "users"."id", "users"."email" FROM "users" LIMIT ('1') :: integer) AS "users_1") AS "__sr_1") AS "__sj_1", (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "customers_0"."id" AS "id" FROM (SELECT "customers"."id" FROM "customers" LIMIT ('1') :: integer) AS "customers_0") AS "__sr_0") AS "__sj_0" +=== RUN TestCompileQuery/withFragment1 +SELECT jsonb_build_object('users', "__sj_0"."json") as "__root" FROM (SELECT coalesce(jsonb_agg("__sj_0"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "users_0"."first_name" AS "first_name", "users_0"."last_name" AS "last_name", "users_0"."created_at" AS "created_at", "users_0"."id" AS "id", "users_0"."email" AS "email" FROM (SELECT , "users"."created_at", "users"."id", "users"."email" FROM "users" GROUP BY "users"."created_at", "users"."id", "users"."email" LIMIT ('20') :: integer) AS "users_0") AS "__sr_0") AS "__sj_0") AS "__sj_0" +=== RUN TestCompileQuery/withFragment2 +SELECT jsonb_build_object('users', "__sj_0"."json") as "__root" FROM (SELECT coalesce(jsonb_agg("__sj_0"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "users_0"."first_name" AS "first_name", "users_0"."last_name" AS "last_name", "users_0"."created_at" AS "created_at", "users_0"."id" AS "id", "users_0"."email" AS "email" FROM (SELECT , "users"."created_at", "users"."id", "users"."email" FROM "users" GROUP BY "users"."created_at", "users"."id", "users"."email" LIMIT ('20') :: integer) AS "users_0") AS "__sr_0") AS "__sj_0") AS "__sj_0" +=== RUN TestCompileQuery/withFragment3 +SELECT jsonb_build_object('users', "__sj_0"."json") as "__root" FROM (SELECT coalesce(jsonb_agg("__sj_0"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "users_0"."first_name" AS "first_name", "users_0"."last_name" AS "last_name", "users_0"."created_at" AS "created_at", "users_0"."id" AS "id", "users_0"."email" AS "email" FROM (SELECT , "users"."created_at", "users"."id", "users"."email" FROM "users" GROUP BY "users"."created_at", "users"."id", "users"."email" LIMIT ('20') :: integer) AS "users_0") AS "__sr_0") AS "__sj_0") AS "__sj_0" === RUN TestCompileQuery/jsonColumnAsTable SELECT jsonb_build_object('products', "__sj_0"."json") as "__root" FROM (SELECT coalesce(jsonb_agg("__sj_0"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_0".*) AS "json"FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "__sj_1"."json" AS "tag_count" FROM (SELECT "products"."id", "products"."name" FROM "products" LIMIT ('20') :: integer) AS "products_0" LEFT OUTER JOIN LATERAL (SELECT to_jsonb("__sr_1".*) AS "json"FROM (SELECT "tag_count_1"."count" AS "count", "__sj_2"."json" AS "tags" FROM (SELECT "tag_count"."count", "tag_count"."tag_id" FROM "products", json_to_recordset("products"."tag_count") AS "tag_count"(tag_id bigint, count int) WHERE ((("products"."id") = ("products_0"."id"))) LIMIT ('1') :: integer) AS "tag_count_1" LEFT OUTER JOIN LATERAL (SELECT coalesce(jsonb_agg("__sj_2"."json"), '[]') as "json" FROM (SELECT to_jsonb("__sr_2".*) AS "json"FROM (SELECT "tags_2"."name" AS "name" FROM (SELECT "tags"."name" FROM "tags" WHERE ((("tags"."id") = ("tag_count_1"."tag_id"))) LIMIT ('20') :: integer) AS "tags_2") AS "__sr_2") AS "__sj_2") AS "__sj_2" ON ('true')) AS "__sr_1") AS "__sj_1" ON ('true')) AS "__sr_0") AS "__sj_0") AS "__sj_0" === RUN TestCompileQuery/withCursor @@ -117,6 +123,9 @@ SELECT jsonb_build_object('users', "__sj_0"."json") as "__root" FROM (SELECT coa --- PASS: TestCompileQuery/queryWithVariables (0.00s) --- PASS: TestCompileQuery/withWhereOnRelations (0.00s) --- PASS: TestCompileQuery/multiRoot (0.00s) + --- PASS: TestCompileQuery/withFragment1 (0.00s) + --- PASS: TestCompileQuery/withFragment2 (0.00s) + --- PASS: TestCompileQuery/withFragment3 (0.00s) --- PASS: TestCompileQuery/jsonColumnAsTable (0.00s) --- PASS: TestCompileQuery/withCursor (0.00s) --- PASS: TestCompileQuery/nullForAuthRequiredInAnon (0.00s) @@ -151,4 +160,4 @@ WITH "_sg_input" AS (SELECT $1 :: json AS j), "_x_users" AS (SELECT * FROM (VALU --- PASS: TestCompileUpdate/nestedUpdateOneToOneWithConnect (0.00s) --- PASS: TestCompileUpdate/nestedUpdateOneToOneWithDisconnect (0.00s) PASS -ok github.com/dosco/super-graph/core/internal/psql (cached) +ok github.com/dosco/super-graph/core/internal/psql 0.374s diff --git a/core/internal/qcode/lex.go b/core/internal/qcode/lex.go index a4d8bf1..259ce2b 100644 --- a/core/internal/qcode/lex.go +++ b/core/internal/qcode/lex.go @@ -11,15 +11,18 @@ import ( var ( queryToken = []byte("query") mutationToken = []byte("mutation") + fragmentToken = []byte("fragment") subscriptionToken = []byte("subscription") + onToken = []byte("on") trueToken = []byte("true") falseToken = []byte("false") quotesToken = []byte(`'"`) signsToken = []byte(`+-`) - punctuatorToken = []byte(`!():=[]{|}`) spreadToken = []byte(`...`) digitToken = []byte(`0123456789`) dotToken = []byte(`.`) + + punctuatorToken = `!():=[]{|}` ) // Pos represents a byte position in the original input text from which @@ -43,6 +46,8 @@ const ( itemName itemQuery itemMutation + itemFragment + itemOn itemSub itemPunctuator itemArgsOpen @@ -263,11 +268,11 @@ func lexRoot(l *lexer) stateFn { l.backup() return lexString case r == '.': - if len(l.input) >= 3 { - if equals(l.input, 0, 3, spreadToken) { - l.emit(itemSpread) - return lexRoot - } + l.acceptRun(dotToken) + s, e := l.current() + if equals(l.input, s, e, spreadToken) { + l.emit(itemSpread) + return lexRoot } fallthrough // '.' can start a number. case r == '+' || r == '-' || ('0' <= r && r <= '9'): @@ -299,10 +304,14 @@ func lexName(l *lexer) stateFn { switch { case equals(l.input, s, e, queryToken): l.emitL(itemQuery) + case equals(l.input, s, e, fragmentToken): + l.emitL(itemFragment) case equals(l.input, s, e, mutationToken): l.emitL(itemMutation) case equals(l.input, s, e, subscriptionToken): l.emitL(itemSub) + case equals(l.input, s, e, onToken): + l.emitL(itemOn) case equals(l.input, s, e, trueToken): l.emitL(itemBoolVal) case equals(l.input, s, e, falseToken): @@ -396,31 +405,11 @@ func isAlphaNumeric(r rune) bool { } func equals(b []byte, s Pos, e Pos, val []byte) bool { - n := 0 - for i := s; i < e; i++ { - if n >= len(val) { - return true - } - switch { - case b[i] >= 'A' && b[i] <= 'Z' && ('a'+(b[i]-'A')) != val[n]: - return false - case b[i] != val[n]: - return false - } - n++ - } - return true + return bytes.EqualFold(b[s:e], val) } -func contains(b []byte, s Pos, e Pos, val []byte) bool { - for i := s; i < e; i++ { - for n := 0; n < len(val); n++ { - if b[i] == val[n] { - return true - } - } - } - return false +func contains(b []byte, s Pos, e Pos, chars string) bool { + return bytes.ContainsAny(b[s:e], chars) } func lowercase(b []byte, s Pos, e Pos) { diff --git a/core/internal/qcode/parse.go b/core/internal/qcode/parse.go index 6d88ac5..2dd1cd2 100644 --- a/core/internal/qcode/parse.go +++ b/core/internal/qcode/parse.go @@ -3,10 +3,9 @@ package qcode import ( "errors" "fmt" + "hash/maphash" "sync" "unsafe" - - "github.com/dosco/super-graph/core/internal/util" ) var ( @@ -35,8 +34,7 @@ const ( NodeVar ) -type Operation struct { - Type parserType +type SelectionSet struct { Name string Args []Arg argsA [10]Arg @@ -44,12 +42,29 @@ type Operation struct { fieldsA [10]Field } +type Operation struct { + Type parserType + SelectionSet +} + var zeroOperation = Operation{} func (o *Operation) Reset() { *o = zeroOperation } +type Fragment struct { + Name string + On string + SelectionSet +} + +var zeroFragment = Fragment{} + +func (f *Fragment) Reset() { + *f = zeroFragment +} + type Field struct { ID int32 ParentID int32 @@ -82,6 +97,8 @@ func (n *Node) Reset() { } type Parser struct { + frags map[uint64]*Fragment + h maphash.Hash input []byte // the string being scanned pos int items []item @@ -96,12 +113,194 @@ var opPool = sync.Pool{ New: func() interface{} { return new(Operation) }, } +var fragPool = sync.Pool{ + New: func() interface{} { return new(Fragment) }, +} + var lexPool = sync.Pool{ New: func() interface{} { return new(lexer) }, } func Parse(gql []byte) (*Operation, error) { - return parseSelectionSet(gql) + var err error + + if len(gql) == 0 { + return nil, errors.New("blank query") + } + + l := lexPool.Get().(*lexer) + l.Reset() + defer lexPool.Put(l) + + if err = lex(l, gql); err != nil { + return nil, err + } + + p := &Parser{ + input: l.input, + pos: -1, + items: l.items, + } + + op := opPool.Get().(*Operation) + op.Reset() + op.Fields = op.fieldsA[:0] + + s := -1 + qf := false + + for { + if p.peek(itemEOF) { + p.ignore() + break + } + + if p.peek(itemFragment) { + p.ignore() + if err = p.parseFragment(op); err != nil { + return nil, err + } + } else { + if !qf && p.peek(itemQuery, itemMutation, itemSub, itemObjOpen) { + s = p.pos + qf = true + } + p.ignore() + } + } + + p.reset(s) + if err := p.parseOp(op); err != nil { + return nil, err + } + + return op, nil +} + +func (p *Parser) parseFragment(op *Operation) error { + frag := fragPool.Get().(*Fragment) + frag.Reset() + + frag.Fields = frag.fieldsA[:0] + frag.Args = frag.argsA[:0] + + if p.peek(itemName) { + frag.Name = p.val(p.next()) + } + + if p.peek(itemOn) { + p.ignore() + } else { + return errors.New("fragment: missing 'on' keyword") + } + + if p.peek(itemName) { + frag.On = p.vall(p.next()) + } else { + return errors.New("fragment: missing table name after 'on' keyword") + } + + if p.peek(itemObjOpen) { + p.ignore() + } else { + return fmt.Errorf("fragment: expecting a '{', got: %s", p.next()) + } + + if err := p.parseSelectionSet(&frag.SelectionSet); err != nil { + return fmt.Errorf("fragment: %v", err) + } + + if p.frags == nil { + p.frags = make(map[uint64]*Fragment) + } + + _, _ = p.h.WriteString(frag.Name) + k := p.h.Sum64() + p.h.Reset() + + p.frags[k] = frag + + return nil +} + +func (p *Parser) parseOp(op *Operation) error { + var err error + var typeSet bool + + if p.peek(itemQuery, itemMutation, itemSub) { + err = p.parseOpTypeAndArgs(op) + + if err != nil { + return fmt.Errorf("%s: %v", op.Type, err) + } + typeSet = true + } + + if p.peek(itemObjOpen) { + p.ignore() + if !typeSet { + op.Type = opQuery + } + + for { + if p.peek(itemEOF, itemFragment) { + p.ignore() + break + } + + err = p.parseSelectionSet(&op.SelectionSet) + if err != nil { + return fmt.Errorf("%s: %v", op.Type, err) + } + } + } else { + return fmt.Errorf("expecting a query, mutation or subscription, got: %s", p.next()) + } + + return nil +} + +func (p *Parser) parseOpTypeAndArgs(op *Operation) error { + item := p.next() + + switch item._type { + case itemQuery: + op.Type = opQuery + case itemMutation: + op.Type = opMutate + case itemSub: + op.Type = opSub + } + + op.Args = op.argsA[:0] + + var err error + + if p.peek(itemName) { + op.Name = p.val(p.next()) + } + + if p.peek(itemArgsOpen) { + p.ignore() + + op.Args, err = p.parseOpParams(op.Args) + if err != nil { + return err + } + } + + return nil +} + +func (p *Parser) parseSelectionSet(selset *SelectionSet) error { + var err error + + selset.Fields, err = p.parseFields(selset.Fields) + if err != nil { + return err + } + + return nil } func ParseArgValue(argVal string) (*Node, error) { @@ -123,216 +322,107 @@ func ParseArgValue(argVal string) (*Node, error) { return op, err } -func parseSelectionSet(gql []byte) (*Operation, error) { - var err error - - if len(gql) == 0 { - return nil, errors.New("blank query") - } - - l := lexPool.Get().(*lexer) - l.Reset() - - if err = lex(l, gql); err != nil { - return nil, err - } - - p := &Parser{ - input: l.input, - pos: -1, - items: l.items, - } - - var op *Operation - - if p.peek(itemObjOpen) { - p.ignore() - op, err = p.parseQueryOp() - } else { - op, err = p.parseOp() - } - - if err != nil { - return nil, err - } - - if p.peek(itemObjClose) { - p.ignore() - } else { - return nil, fmt.Errorf("operation missing closing '}'") - } - - if !p.peek(itemEOF) { - p.ignore() - return nil, fmt.Errorf("invalid '%s' found after closing '}'", p.current()) - } - - lexPool.Put(l) - - return op, err -} - -func (p *Parser) next() item { - n := p.pos + 1 - if n >= len(p.items) { - p.err = errEOT - return item{_type: itemEOF} - } - p.pos = n - return p.items[p.pos] -} - -func (p *Parser) ignore() { - n := p.pos + 1 - if n >= len(p.items) { - p.err = errEOT - return - } - p.pos = n -} - -func (p *Parser) current() string { - item := p.items[p.pos] - return b2s(p.input[item.pos:item.end]) -} - -func (p *Parser) peek(types ...itemType) bool { - n := p.pos + 1 - // if p.items[n]._type == itemEOF { - // return false - // } - - if n >= len(p.items) { - return false - } - for i := 0; i < len(types); i++ { - if p.items[n]._type == types[i] { - return true - } - } - return false -} - -func (p *Parser) parseOp() (*Operation, error) { - if !p.peek(itemQuery, itemMutation, itemSub) { - err := errors.New("expecting a query, mutation or subscription") - return nil, err - } - item := p.next() - - op := opPool.Get().(*Operation) - op.Reset() - - switch item._type { - case itemQuery: - op.Type = opQuery - case itemMutation: - op.Type = opMutate - case itemSub: - op.Type = opSub - } - - op.Fields = op.fieldsA[:0] - op.Args = op.argsA[:0] - - var err error - - if p.peek(itemName) { - op.Name = p.val(p.next()) - } - - if p.peek(itemArgsOpen) { - p.ignore() - - op.Args, err = p.parseOpParams(op.Args) - if err != nil { - return nil, err - } - } - - if p.peek(itemObjOpen) { - p.ignore() - - for n := 0; n < 10; n++ { - if !p.peek(itemName) { - break - } - - op.Fields, err = p.parseFields(op.Fields) - if err != nil { - return nil, err - } - } - } - - return op, nil -} - -func (p *Parser) parseQueryOp() (*Operation, error) { - op := opPool.Get().(*Operation) - op.Reset() - - op.Type = opQuery - op.Fields = op.fieldsA[:0] - op.Args = op.argsA[:0] - - var err error - - for n := 0; n < 10; n++ { - if !p.peek(itemName) { - break - } - - op.Fields, err = p.parseFields(op.Fields) - if err != nil { - return nil, err - } - } - - return op, nil -} - func (p *Parser) parseFields(fields []Field) ([]Field, error) { - st := util.NewStack() + st := NewStack() + + if !p.peek(itemName, itemSpread) { + return nil, fmt.Errorf("unexpected token: %s", p.peekNext()) + } for { + if p.peek(itemEOF) { + p.ignore() + return nil, errors.New("invalid query") + } + + if p.peek(itemObjClose) { + p.ignore() + + if st.Len() != 0 { + st.Pop() + continue + } else { + break + } + } + if len(fields) >= maxFields { return nil, fmt.Errorf("too many fields (max %d)", maxFields) } - if p.peek(itemEOF, itemObjClose) { - p.ignore() - st.Pop() + isFrag := false - if st.Len() == 0 { - break - } else { - continue - } + if p.peek(itemSpread) { + p.ignore() + isFrag = true } if !p.peek(itemName) { - return nil, errors.New("expecting an alias or field name") + if isFrag { + return nil, fmt.Errorf("expecting a fragment name, got: %s", p.next()) + } else { + return nil, fmt.Errorf("expecting an alias or field name, got: %s", p.next()) + } } - fields = append(fields, Field{ID: int32(len(fields))}) + var f *Field - f := &fields[(len(fields) - 1)] - f.Args = f.argsA[:0] - f.Children = f.childrenA[:0] + if isFrag { + name := p.val(p.next()) + p.h.WriteString(name) + k := p.h.Sum64() + p.h.Reset() - // Parse the inside of the the fields () parentheses - // in short parse the args like id, where, etc - if err := p.parseField(f); err != nil { - return nil, err - } + fr, ok := p.frags[k] + if !ok { + return nil, fmt.Errorf("no fragment named '%s' defined", name) + } + + n := int32(len(fields)) + fields = append(fields, fr.Fields...) + + for i := int(n); i < len(fields); i++ { + f := &fields[i] + f.ID = int32(i) + + // If this is the top-level point the parent to the parent of the + // previous field. + if f.ParentID == -1 { + pid := st.Peek() + f.ParentID = pid + if f.ParentID != -1 { + fields[pid].Children = append(fields[f.ParentID].Children, f.ID) + } + // Update all the other parents id's by our new place in this new array + } else { + f.ParentID += n + } + + // Update all the children which is needed. + for j := range f.Children { + f.Children[j] += n + } + } - intf := st.Peek() - if pid, ok := intf.(int32); ok { - f.ParentID = pid - fields[pid].Children = append(fields[pid].Children, f.ID) } else { - f.ParentID = -1 + fields = append(fields, Field{ID: int32(len(fields))}) + + f = &fields[(len(fields) - 1)] + f.Args = f.argsA[:0] + f.Children = f.childrenA[:0] + + // Parse the field + if err := p.parseField(f); err != nil { + return nil, err + } + + if st.Len() == 0 { + f.ParentID = -1 + } else { + pid := st.Peek() + f.ParentID = pid + fields[pid].Children = append(fields[pid].Children, f.ID) + } } // The first opening curley brackets after this @@ -340,13 +430,6 @@ func (p *Parser) parseFields(fields []Field) ([]Field, error) { if p.peek(itemObjOpen) { p.ignore() st.Push(f.ID) - - } else if p.peek(itemObjClose) { - if st.Len() == 0 { - break - } else { - continue - } } } @@ -546,6 +629,62 @@ func (p *Parser) vall(v item) string { return b2s(p.input[v.pos:v.end]) } +func (p *Parser) peek(types ...itemType) bool { + n := p.pos + 1 + l := len(types) + // if p.items[n]._type == itemEOF { + // return false + // } + + if n >= len(p.items) { + return types[0] == itemEOF + } + + if l == 1 { + return p.items[n]._type == types[0] + } + + for i := 0; i < l; i++ { + if p.items[n]._type == types[i] { + return true + } + } + return false +} + +func (p *Parser) next() item { + n := p.pos + 1 + if n >= len(p.items) { + p.err = errEOT + return item{_type: itemEOF} + } + p.pos = n + return p.items[p.pos] +} + +func (p *Parser) ignore() { + n := p.pos + 1 + if n >= len(p.items) { + p.err = errEOT + return + } + p.pos = n +} + +func (p *Parser) peekCurrent() string { + item := p.items[p.pos] + return b2s(p.input[item.pos:item.end]) +} + +func (p *Parser) peekNext() string { + item := p.items[p.pos+1] + return b2s(p.input[item.pos:item.end]) +} + +func (p *Parser) reset(to int) { + p.pos = to +} + func b2s(b []byte) string { return *(*string)(unsafe.Pointer(&b)) } @@ -579,7 +718,7 @@ func (t parserType) String() string { case NodeList: v = "node-list" } - return fmt.Sprintf("<%s>", v) + return v } // type Frees struct { diff --git a/core/internal/qcode/parse_test.go b/core/internal/qcode/parse_test.go index c465e95..29d70e4 100644 --- a/core/internal/qcode/parse_test.go +++ b/core/internal/qcode/parse_test.go @@ -121,7 +121,7 @@ updateThread { } } } -}` +}}` qcompile, _ := NewCompiler(Config{}) _, err := qcompile.Compile([]byte(gql), "anon") @@ -131,19 +131,90 @@ updateThread { } -func TestFragmentsCompile(t *testing.T) { +func TestFragmentsCompile1(t *testing.T) { gql := ` -fragment userFields on user { - name - email -} - -query { users { ...userFields } }` - qcompile, _ := NewCompiler(Config{}) - _, err := qcompile.Compile([]byte(gql), "anon") + fragment userFields1 on user { + id + email + } - if err == nil { - t.Fatal(errors.New("expecting an error")) + query { + users { + ...userFields2 + + created_at + ...userFields1 + } + } + + fragment userFields2 on user { + first_name + last_name + } + ` + qcompile, _ := NewCompiler(Config{}) + _, err := qcompile.Compile([]byte(gql), "user") + + if err != nil { + t.Fatal(err) + } +} + +func TestFragmentsCompile2(t *testing.T) { + gql := ` + query { + users { + ...userFields2 + + created_at + ...userFields1 + } + } + + fragment userFields1 on user { + id + email + } + + fragment userFields2 on user { + first_name + last_name + }` + qcompile, _ := NewCompiler(Config{}) + _, err := qcompile.Compile([]byte(gql), "user") + + if err != nil { + t.Fatal(err) + } +} + +func TestFragmentsCompile3(t *testing.T) { + gql := ` + fragment userFields1 on user { + id + email + } + + fragment userFields2 on user { + first_name + last_name + } + + query { + users { + ...userFields2 + + created_at + ...userFields1 + } + } + + ` + qcompile, _ := NewCompiler(Config{}) + _, err := qcompile.Compile([]byte(gql), "user") + + if err != nil { + t.Fatal(err) } } @@ -201,7 +272,6 @@ func BenchmarkQCompileP(b *testing.B) { } func BenchmarkParse(b *testing.B) { - b.ResetTimer() b.ReportAllocs() for n := 0; n < b.N; n++ {