Optimize lexer and fix bugs

This commit is contained in:
Vikram Rangnekar 2019-06-14 22:17:21 -04:00
parent 9af320f396
commit 340dea242d
10 changed files with 301 additions and 197 deletions

1
go.mod
View File

@ -27,6 +27,7 @@ require (
github.com/spf13/viper v1.3.1 github.com/spf13/viper v1.3.1
github.com/valyala/fasttemplate v1.0.1 github.com/valyala/fasttemplate v1.0.1
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2
golang.org/x/net v0.0.0-20190311183353-d8887717615a
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223 golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223
mellium.im/sasl v0.2.1 // indirect mellium.im/sasl v0.2.1 // indirect
) )

1
go.sum
View File

@ -111,6 +111,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90Pveol
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd h1:nTDtHvHSdCn1m6ITfMRqtOd/9+7a3s8RBNOZ3eYZzJA= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd h1:nTDtHvHSdCn1m6ITfMRqtOd/9+7a3s8RBNOZ3eYZzJA=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

View File

@ -126,7 +126,7 @@ func TestMain(m *testing.M) {
} }
func compileGQLToPSQL(gql string) ([]byte, error) { func compileGQLToPSQL(gql string) ([]byte, error) {
qc, err := qcompile.CompileQuery(gql) qc, err := qcompile.Compile([]byte(gql))
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -141,7 +141,7 @@ func compileGQLToPSQL(gql string) ([]byte, error) {
func withComplexArgs(t *testing.T) { func withComplexArgs(t *testing.T) {
gql := `query { gql := `query {
products( proDUcts(
# returns only 30 items # returns only 30 items
limit: 30, limit: 30,
@ -157,7 +157,7 @@ func withComplexArgs(t *testing.T) {
# only items with an id >= 30 and < 30 are returned # only items with an id >= 30 and < 30 are returned
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) {
id id
name NAME
price price
} }
}` }`
@ -482,8 +482,8 @@ func TestCompileGQL(t *testing.T) {
t.Run("syntheticTables", syntheticTables) t.Run("syntheticTables", syntheticTables)
} }
var benchGQL = `query { var benchGQL = []byte(`query {
products( proDUcts(
# returns only 30 items # returns only 30 items
limit: 30, limit: 30,
@ -496,14 +496,14 @@ var benchGQL = `query {
# only items with an id >= 30 and < 30 are returned # only items with an id >= 30 and < 30 are returned
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) {
id id
name NAME
price price
user { user {
full_name full_name
picture : avatar picture : avatar
} }
} }
}` }`)
func BenchmarkCompile(b *testing.B) { func BenchmarkCompile(b *testing.B) {
w := &bytes.Buffer{} w := &bytes.Buffer{}
@ -514,7 +514,7 @@ func BenchmarkCompile(b *testing.B) {
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
w.Reset() w.Reset()
qc, err := qcompile.CompileQuery(benchGQL) qc, err := qcompile.Compile(benchGQL)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
@ -535,7 +535,7 @@ func BenchmarkCompileParallel(b *testing.B) {
for pb.Next() { for pb.Next() {
w.Reset() w.Reset()
qc, err := qcompile.CompileQuery(benchGQL) qc, err := qcompile.Compile(benchGQL)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }

7
qcode/bench.3 Normal file
View File

@ -0,0 +1,7 @@
goos: darwin
goarch: amd64
pkg: github.com/dosco/super-graph/qcode
BenchmarkQCompile-8 200000 10029 ns/op 2291 B/op 38 allocs/op
BenchmarkQCompileP-8 500000 2925 ns/op 2298 B/op 38 allocs/op
PASS
ok github.com/dosco/super-graph/qcode 3.616s

View File

@ -2,10 +2,10 @@ package qcode
// FuzzerEntrypoint for Fuzzbuzz // FuzzerEntrypoint for Fuzzbuzz
func FuzzerEntrypoint(data []byte) int { func FuzzerEntrypoint(data []byte) int {
testData := string(data) //testData := string(data)
qcompile, _ := NewCompiler(Config{}) qcompile, _ := NewCompiler(Config{})
_, err := qcompile.CompileQuery(testData) _, err := qcompile.Compile(data)
if err != nil { if err != nil {
return -1 return -1
} }

View File

@ -1,59 +1,37 @@
package qcode package qcode
import ( import (
"bytes"
"errors" "errors"
"fmt" "fmt"
"strings"
"unicode" "unicode"
"unicode/utf8" "unicode/utf8"
) )
var (
queryToken = []byte("query")
mutationToken = []byte("mutation")
subscriptionToken = []byte("subscription")
trueToken = []byte("true")
falseToken = []byte("true")
quotesToken = []byte(`'"`)
signsToken = []byte(`+-`)
punctuatorToken = []byte(`!():=[]{|}`)
spreadToken = []byte(`...`)
digitToken = []byte(`0123456789`)
dotToken = []byte(`.`)
)
// Pos represents a byte position in the original input text from which // Pos represents a byte position in the original input text from which
// this template was parsed. // this template was parsed.
type Pos int type Pos int
func (p Pos) Position() Pos {
return p
}
// item represents a token or text string returned from the scanner. // item represents a token or text string returned from the scanner.
type item struct { type item struct {
typ itemType // The type of this item. typ itemType // The type of this item.
pos Pos // The starting position, in bytes, of this item in the input string. pos Pos // The starting position, in bytes, of this item in the input string.
val string // The value of this item. end Pos // The ending position, in bytes, of this item in the input string.
line int // The line number at the start of this item. line uint16 // The line number at the start of this item.
}
func (i *item) String() string {
var v string
switch i.typ {
case itemEOF:
v = "EOF"
case itemError:
v = "error"
case itemName:
v = "name"
case itemQuery:
v = "query"
case itemMutation:
v = "mutation"
case itemSub:
v = "subscription"
case itemPunctuator:
v = "punctuator"
case itemDirective:
v = "directive"
case itemVariable:
v = "variable"
case itemIntVal:
v = "int"
case itemFloatVal:
v = "float"
case itemStringVal:
v = "string"
}
return fmt.Sprintf("%s %q", v, i.val)
} }
// itemType identifies the type of lex items. // itemType identifies the type of lex items.
@ -103,14 +81,14 @@ type stateFn func(*lexer) stateFn
// lexer holds the state of the scanner. // lexer holds the state of the scanner.
type lexer struct { type lexer struct {
name string // the name of the input; used only for error reports input []byte // the string being scanned
input string // the string being scanned
pos Pos // current position in the input pos Pos // current position in the input
start Pos // start position of this item start Pos // start position of this item
width Pos // width of last rune read from input width Pos // width of last rune read from input
items []item // array of scanned items items []item // array of scanned items
itemsA [100]item itemsA [50]item
line int // 1+number of newlines seen line uint16 // 1+number of newlines seen
err error
} }
var zeroLex = lexer{} var zeroLex = lexer{}
@ -119,13 +97,13 @@ func (l *lexer) Reset() {
*l = zeroLex *l = zeroLex
} }
// next returns the next rune in the input. // next returns the next byte in the input.
func (l *lexer) next() rune { func (l *lexer) next() rune {
if int(l.pos) >= len(l.input) { if int(l.pos) >= len(l.input) {
l.width = 0 l.width = 0
return eof return eof
} }
r, w := utf8.DecodeRuneInString(l.input[l.pos:]) r, w := utf8.DecodeRune(l.input[l.pos:])
l.width = Pos(w) l.width = Pos(w)
l.pos += l.width l.pos += l.width
if r == '\n' { if r == '\n' {
@ -150,30 +128,38 @@ func (l *lexer) backup() {
} }
} }
func (l *lexer) current() string { func (l *lexer) current() (Pos, Pos) {
return l.input[l.start:l.pos] return l.start, l.pos
} }
// emit passes an item back to the client. // emit passes an item back to the client.
func (l *lexer) emit(t itemType) { func (l *lexer) emit(t itemType) {
l.items = append(l.items, item{t, l.start, l.input[l.start:l.pos], l.line}) l.items = append(l.items, item{t, l.start, l.pos, l.line})
// Some items contain text internally. If so, count their newlines. // Some items contain text internally. If so, count their newlines.
switch t { switch t {
case itemName: case itemName:
l.line += strings.Count(l.input[l.start:l.pos], "\n") for i := l.start; i < l.pos; i++ {
if l.input[i] == '\n' {
l.line++
}
}
} }
l.start = l.pos l.start = l.pos
} }
// ignore skips over the pending input before this point. // ignore skips over the pending input before this point.
func (l *lexer) ignore() { func (l *lexer) ignore() {
l.line += strings.Count(l.input[l.start:l.pos], "\n") for i := l.start; i < l.pos; i++ {
if l.input[i] == '\n' {
l.line++
}
}
l.start = l.pos l.start = l.pos
} }
// accept consumes the next rune if it's from the valid set. // accept consumes the next rune if it's from the valid set.
func (l *lexer) accept(valid string) bool { func (l *lexer) accept(valid []byte) bool {
if strings.ContainsRune(valid, l.next()) { if bytes.ContainsRune(valid, l.next()) {
return true return true
} }
l.backup() l.backup()
@ -199,8 +185,8 @@ func (l *lexer) acceptComment() {
} }
// acceptRun consumes a run of runes from the valid set. // acceptRun consumes a run of runes from the valid set.
func (l *lexer) acceptRun(valid string) { func (l *lexer) acceptRun(valid []byte) {
for strings.ContainsRune(valid, l.next()) { for bytes.ContainsRune(valid, l.next()) {
} }
l.backup() l.backup()
} }
@ -208,23 +194,25 @@ func (l *lexer) acceptRun(valid string) {
// errorf returns an error token and terminates the scan by passing // errorf returns an error token and terminates the scan by passing
// back a nil pointer that will be the next state, terminating l.nextItem. // back a nil pointer that will be the next state, terminating l.nextItem.
func (l *lexer) errorf(format string, args ...interface{}) stateFn { func (l *lexer) errorf(format string, args ...interface{}) stateFn {
l.items = append(l.items, item{itemError, l.start, l.err = fmt.Errorf(format, args...)
fmt.Sprintf(format, args...), l.line}) l.items = append(l.items, item{itemError, l.start, l.pos, l.line})
return nil return nil
} }
// lex creates a new scanner for the input string. // lex creates a new scanner for the input string.
func lex(l *lexer, input string) error { func lex(l *lexer, input []byte) error {
if len(input) == 0 { if len(input) == 0 {
return errors.New("empty query") return errors.New("empty query")
} }
l.input = input l.input = input
l.line = 1
l.items = l.itemsA[:0] l.items = l.itemsA[:0]
l.line = 1
l.run() l.run()
if last := l.items[len(l.items)-1]; last.typ == itemError { if last := l.items[len(l.items)-1]; last.typ == itemError {
return fmt.Errorf(last.val) return l.err
} }
return nil return nil
} }
@ -262,7 +250,7 @@ func lexRoot(l *lexer) stateFn {
if l.acceptAlphaNum() { if l.acceptAlphaNum() {
l.emit(itemVariable) l.emit(itemVariable)
} }
case strings.ContainsRune("!():=[]{|}", r): case contains(l.input, l.start, l.pos, punctuatorToken):
if item, ok := punctuators[r]; ok { if item, ok := punctuators[r]; ok {
l.emit(item) l.emit(item)
} else { } else {
@ -273,7 +261,7 @@ func lexRoot(l *lexer) stateFn {
return lexString return lexString
case r == '.': case r == '.':
if len(l.input) >= 3 { if len(l.input) >= 3 {
if strings.HasSuffix(l.input[:l.pos], "...") { if equals(l.input, 0, 3, spreadToken) {
l.emit(itemSpread) l.emit(itemSpread)
return lexRoot return lexRoot
} }
@ -295,34 +283,28 @@ func lexRoot(l *lexer) stateFn {
func lexName(l *lexer) stateFn { func lexName(l *lexer) stateFn {
for { for {
r := l.next() r := l.next()
if r == eof { if r == eof {
l.emit(itemEOF) l.emit(itemEOF)
return nil return nil
} }
if !isAlphaNumeric(r) { if !isAlphaNumeric(r) {
l.backup() l.backup()
v := l.current() s, e := l.current()
lowercase(l.input, s, e) lowercase(l.input, s, e)
if len(v) == 0 {
switch { switch {
case strings.EqualFold(v, "query"): case equals(l.input, s, e, queryToken):
l.emit(itemQuery) l.emit(itemQuery)
break case equals(l.input, s, e, mutationToken):
case strings.EqualFold(v, "mutation"):
l.emit(itemMutation) l.emit(itemMutation)
break case equals(l.input, s, e, subscriptionToken):
case strings.EqualFold(v, "subscription"):
l.emit(itemSub) l.emit(itemSub)
break case equals(l.input, s, e, trueToken):
}
}
switch {
case strings.EqualFold(v, "true"):
l.emit(itemBoolVal) l.emit(itemBoolVal)
case strings.EqualFold(v, "false"): case equals(l.input, s, e, falseToken):
l.emit(itemBoolVal) l.emit(itemBoolVal)
default: default:
l.emit(itemName) l.emit(itemName)
@ -335,7 +317,7 @@ func lexName(l *lexer) stateFn {
// lexString scans a string. // lexString scans a string.
func lexString(l *lexer) stateFn { func lexString(l *lexer) stateFn {
if l.accept("\"'") { if l.accept([]byte(quotesToken)) {
l.ignore() l.ignore()
for { for {
@ -347,7 +329,7 @@ func lexString(l *lexer) stateFn {
if r == '\'' || r == '"' { if r == '\'' || r == '"' {
l.backup() l.backup()
l.emit(itemStringVal) l.emit(itemStringVal)
if l.accept("\"'") { if l.accept(quotesToken) {
l.ignore() l.ignore()
} }
break break
@ -364,20 +346,19 @@ func lexString(l *lexer) stateFn {
func lexNumber(l *lexer) stateFn { func lexNumber(l *lexer) stateFn {
var it itemType var it itemType
// Optional leading sign. // Optional leading sign.
l.accept("+-") l.accept(signsToken)
// Is it integer // Is it integer
digits := "0123456789" if l.accept(digitToken) {
if l.accept(digits) { l.acceptRun(digitToken)
l.acceptRun(digits)
it = itemIntVal it = itemIntVal
} }
// Is it float // Is it float
if l.peek() == '.' { if l.peek() == '.' {
if l.accept(".") { if l.accept(dotToken) {
if l.accept(digits) { if l.accept(digitToken) {
l.acceptRun(digits) l.acceptRun(digitToken)
it = itemFloatVal it = itemFloatVal
} }
} else { } else {
@ -413,6 +394,74 @@ func isAlphaNumeric(r rune) bool {
return r == '_' || unicode.IsLetter(r) || unicode.IsDigit(r) return r == '_' || unicode.IsLetter(r) || unicode.IsDigit(r)
} }
func equals(b []byte, s Pos, e Pos, val []byte) bool {
n := 0
for i := s; i < e; i++ {
if n >= len(val) {
return true
}
switch {
case b[i] >= 'A' && b[i] <= 'Z' && ('a'+(b[i]-'A')) != val[n]:
return false
case b[i] != val[n]:
return false
}
n++
}
return true
}
func contains(b []byte, s Pos, e Pos, val []byte) bool {
for i := s; i < e; i++ {
for n := 0; n < len(val); n++ {
if b[i] == val[n] {
return true
}
}
}
return false
}
func lowercase(b []byte, s Pos, e Pos) {
for i := s; i < e; i++ {
if b[i] >= 'A' && b[i] <= 'Z' {
b[i] = ('a' + (b[i] - 'A'))
}
}
}
func (i *item) String() string {
var v string
switch i.typ {
case itemEOF:
v = "EOF"
case itemError:
v = "error"
case itemName:
v = "name"
case itemQuery:
v = "query"
case itemMutation:
v = "mutation"
case itemSub:
v = "subscription"
case itemPunctuator:
v = "punctuator"
case itemDirective:
v = "directive"
case itemVariable:
v = "variable"
case itemIntVal:
v = "int"
case itemFloatVal:
v = "float"
case itemStringVal:
v = "string"
}
return fmt.Sprintf("%s", v)
}
/* /*
Copyright (c) 2009 The Go Authors. All rights reserved. Copyright (c) 2009 The Go Authors. All rights reserved.

View File

@ -4,6 +4,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"sync" "sync"
"unsafe"
"github.com/dosco/super-graph/util" "github.com/dosco/super-graph/util"
) )
@ -53,9 +54,9 @@ type Field struct {
Name string Name string
Alias string Alias string
Args []Arg Args []Arg
argsA [10]Arg argsA [5]Arg
Children []int32 Children []int32
childrenA [10]int32 childrenA [5]int32
} }
type Arg struct { type Arg struct {
@ -78,6 +79,7 @@ func (n *Node) Reset() {
} }
type Parser struct { type Parser struct {
input []byte // the string being scanned
pos int pos int
items []item items []item
depth int depth int
@ -96,38 +98,32 @@ var lexPool = sync.Pool{
New: func() interface{} { return new(lexer) }, New: func() interface{} { return new(lexer) },
} }
func Parse(gql string) (*Operation, error) { func Parse(gql []byte) (*Operation, error) {
if len(gql) == 0 { return parseSelectionSet(nil, gql)
return nil, errors.New("blank query")
}
l := lexPool.Get().(*lexer)
l.Reset()
if err := lex(l, gql); err != nil {
return nil, err
}
p := &Parser{
pos: -1,
items: l.items,
}
op, err := p.parseOp()
lexPool.Put(l)
return op, err
} }
func ParseQuery(gql string) (*Operation, error) { func ParseQuery(gql []byte) (*Operation, error) {
return parseByType(gql, opQuery) op := opPool.Get().(*Operation)
op.Reset()
op.Type = opQuery
op.Name = ""
op.Fields = op.fieldsA[:0]
op.Args = op.argsA[:0]
return parseSelectionSet(op, gql)
} }
func ParseArgValue(argVal string) (*Node, error) { func ParseArgValue(argVal string) (*Node, error) {
l := lexPool.Get().(*lexer) l := lexPool.Get().(*lexer)
l.Reset() l.Reset()
if err := lex(l, argVal); err != nil { if err := lex(l, []byte(argVal)); err != nil {
return nil, err return nil, err
} }
p := &Parser{ p := &Parser{
input: l.input,
pos: -1, pos: -1,
items: l.items, items: l.items,
} }
@ -137,20 +133,42 @@ func ParseArgValue(argVal string) (*Node, error) {
return op, err return op, err
} }
func parseByType(gql string, ty parserType) (*Operation, error) { func parseSelectionSet(op *Operation, gql []byte) (*Operation, error) {
var err error
if len(gql) == 0 {
return nil, errors.New("blank query")
}
l := lexPool.Get().(*lexer) l := lexPool.Get().(*lexer)
l.Reset() l.Reset()
if err := lex(l, gql); err != nil { if err = lex(l, gql); err != nil {
return nil, err return nil, err
} }
p := &Parser{ p := &Parser{
input: l.input,
pos: -1, pos: -1,
items: l.items, items: l.items,
} }
op, err := p.parseOpByType(ty)
if op == nil {
op, err = p.parseOp()
} else {
if p.peek(itemObjOpen) {
p.ignore()
}
op.Fields, err = p.parseFields(op.Fields)
}
lexPool.Put(l) lexPool.Put(l)
if err != nil {
return nil, err
}
return op, err return op, err
} }
@ -198,18 +216,34 @@ func (p *Parser) peek(types ...itemType) bool {
return false return false
} }
func (p *Parser) parseOpByType(ty parserType) (*Operation, error) { func (p *Parser) parseOp() (*Operation, error) {
if p.peek(itemQuery, itemMutation, itemSub) == false {
err := fmt.Errorf(
"expecting a query, mutation or subscription (not '%s')",
p.val(p.next()))
return nil, err
}
item := p.next()
op := opPool.Get().(*Operation) op := opPool.Get().(*Operation)
op.Reset() op.Reset()
op.Type = ty switch item.typ {
case itemQuery:
op.Type = opQuery
case itemMutation:
op.Type = opMutate
case itemSub:
op.Type = opSub
}
op.Fields = op.fieldsA[:0] op.Fields = op.fieldsA[:0]
op.Args = op.argsA[:0] op.Args = op.argsA[:0]
var err error var err error
if p.peek(itemName) { if p.peek(itemName) {
op.Name = p.next().val op.Name = p.val(p.next())
} }
if p.peek(itemArgsOpen) { if p.peek(itemArgsOpen) {
@ -228,33 +262,9 @@ func (p *Parser) parseOpByType(ty parserType) (*Operation, error) {
} }
} }
if p.peek(itemObjClose) {
p.ignore()
}
return op, nil return op, nil
} }
func (p *Parser) parseOp() (*Operation, error) {
if p.peek(itemQuery, itemMutation, itemSub) == false {
err := fmt.Errorf("expecting a query, mutation or subscription (not '%s')", p.next().val)
return nil, err
}
item := p.next()
switch item.typ {
case itemQuery:
return p.parseOpByType(opQuery)
case itemMutation:
return p.parseOpByType(opMutate)
case itemSub:
return p.parseOpByType(opSub)
}
return nil, errors.New("unknown operation type")
}
func (p *Parser) parseFields(fields []Field) ([]Field, error) { func (p *Parser) parseFields(fields []Field) ([]Field, error) {
st := util.NewStack() st := util.NewStack()
@ -278,6 +288,7 @@ func (p *Parser) parseFields(fields []Field) ([]Field, error) {
} }
fields = append(fields, Field{ID: int32(len(fields))}) fields = append(fields, Field{ID: int32(len(fields))})
f := &fields[(len(fields) - 1)] f := &fields[(len(fields) - 1)]
f.Args = f.argsA[:0] f.Args = f.argsA[:0]
f.Children = f.childrenA[:0] f.Children = f.childrenA[:0]
@ -309,14 +320,14 @@ func (p *Parser) parseFields(fields []Field) ([]Field, error) {
func (p *Parser) parseField(f *Field) error { func (p *Parser) parseField(f *Field) error {
var err error var err error
f.Name = p.next().val f.Name = p.val(p.next())
if p.peek(itemColon) { if p.peek(itemColon) {
p.ignore() p.ignore()
if p.peek(itemName) { if p.peek(itemName) {
f.Alias = f.Name f.Alias = f.Name
f.Name = p.next().val f.Name = p.val(p.next())
} else { } else {
return errors.New("expecting an aliased field name") return errors.New("expecting an aliased field name")
} }
@ -347,7 +358,7 @@ func (p *Parser) parseArgs(args []Arg) ([]Arg, error) {
if p.peek(itemName) == false { if p.peek(itemName) == false {
return nil, errors.New("expecting an argument name") return nil, errors.New("expecting an argument name")
} }
args = append(args, Arg{Name: p.next().val}) args = append(args, Arg{Name: p.val(p.next())})
arg := &args[(len(args) - 1)] arg := &args[(len(args) - 1)]
if p.peek(itemColon) == false { if p.peek(itemColon) == false {
@ -414,7 +425,7 @@ func (p *Parser) parseObj() (*Node, error) {
if p.peek(itemName) == false { if p.peek(itemName) == false {
return nil, errors.New("expecting an argument name") return nil, errors.New("expecting an argument name")
} }
nodeName := p.next().val nodeName := p.val(p.next())
if p.peek(itemColon) == false { if p.peek(itemColon) == false {
return nil, errors.New("missing ':' after Field argument name") return nil, errors.New("missing ':' after Field argument name")
@ -465,13 +476,21 @@ func (p *Parser) parseValue() (*Node, error) {
case itemVariable: case itemVariable:
node.Type = nodeVar node.Type = nodeVar
default: default:
return nil, fmt.Errorf("expecting a number, string, object, list or variable as an argument value (not %s)", p.next().val) return nil, fmt.Errorf("expecting a number, string, object, list or variable as an argument value (not %s)", p.val(p.next()))
} }
node.Val = item.val node.Val = p.val(item)
return node, nil return node, nil
} }
func (p *Parser) val(v item) string {
return b2s(p.input[v.pos:v.end])
}
func b2s(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}
func (t parserType) String() string { func (t parserType) String() string {
var v string var v string

View File

@ -47,12 +47,13 @@ func compareOp(op1, op2 Operation) error {
func TestCompile(t *testing.T) { func TestCompile(t *testing.T) {
qcompile, _ := NewCompiler(Config{}) qcompile, _ := NewCompiler(Config{})
_, err := qcompile.CompileQuery(`query {
_, err := qcompile.CompileQuery([]byte(`
product(id: 15) { product(id: 15) {
id id
name name
} }`))
}`)
if err != nil { if err != nil {
t.Fatal(err) t.Fatal(err)
} }
@ -60,7 +61,8 @@ func TestCompile(t *testing.T) {
func TestInvalidCompile(t *testing.T) { func TestInvalidCompile(t *testing.T) {
qcompile, _ := NewCompiler(Config{}) qcompile, _ := NewCompiler(Config{})
_, err := qcompile.CompileQuery(`#`) _, err := qcompile.CompileQuery([]byte(`#`))
if err == nil { if err == nil {
t.Fatal(errors.New("expecting an error")) t.Fatal(errors.New("expecting an error"))
} }
@ -68,13 +70,14 @@ func TestInvalidCompile(t *testing.T) {
func TestEmptyCompile(t *testing.T) { func TestEmptyCompile(t *testing.T) {
qcompile, _ := NewCompiler(Config{}) qcompile, _ := NewCompiler(Config{})
_, err := qcompile.CompileQuery(``) _, err := qcompile.CompileQuery([]byte(``))
if err == nil { if err == nil {
t.Fatal(errors.New("expecting an error")) t.Fatal(errors.New("expecting an error"))
} }
} }
var gql = `query { var gql = []byte(`
products( products(
# returns only 30 items # returns only 30 items
limit: 30, limit: 30,
@ -93,8 +96,7 @@ var gql = `query {
id id
name name
price price
} }`)
}`
func BenchmarkQCompile(b *testing.B) { func BenchmarkQCompile(b *testing.B) {
qcompile, _ := NewCompiler(Config{}) qcompile, _ := NewCompiler(Config{})

View File

@ -50,6 +50,7 @@ type Exp struct {
ListType ValType ListType ValType
ListVal []string ListVal []string
Children []*Exp Children []*Exp
childrenA [5]*Exp
} }
type OrderBy struct { type OrderBy struct {
@ -144,7 +145,7 @@ func NewCompiler(c Config) (*Compiler, error) {
bl := make(map[string]struct{}, len(c.Blacklist)) bl := make(map[string]struct{}, len(c.Blacklist))
for i := range c.Blacklist { for i := range c.Blacklist {
bl[strings.ToLower(c.Blacklist[i])] = struct{}{} bl[c.Blacklist[i]] = struct{}{}
} }
fl, err := compileFilter(c.DefaultFilter) fl, err := compileFilter(c.DefaultFilter)
@ -159,9 +160,8 @@ func NewCompiler(c Config) (*Compiler, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
k1 := strings.ToLower(k) singular := flect.Singularize(k)
singular := flect.Singularize(k1) plural := flect.Pluralize(k)
plural := flect.Pluralize(k1)
fm[singular] = fil fm[singular] = fil
fm[plural] = fil fm[plural] = fil
@ -170,11 +170,11 @@ func NewCompiler(c Config) (*Compiler, error) {
return &Compiler{fl, fm, bl, c.KeepArgs}, nil return &Compiler{fl, fm, bl, c.KeepArgs}, nil
} }
func (com *Compiler) CompileQuery(query string) (*QCode, error) { func (com *Compiler) Compile(query []byte) (*QCode, error) {
var qc QCode var qc QCode
var err error var err error
op, err := ParseQuery(query) op, err := Parse(query)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -197,6 +197,25 @@ func (com *Compiler) CompileQuery(query string) (*QCode, error) {
return &qc, nil return &qc, nil
} }
func (com *Compiler) CompileQuery(query []byte) (*QCode, error) {
var err error
op, err := ParseQuery(query)
if err != nil {
return nil, err
}
qc := &QCode{}
qc.Query, err = com.compileQuery(op)
opPool.Put(op)
if err != nil {
return nil, err
}
return qc, nil
}
func (com *Compiler) compileQuery(op *Operation) (*Query, error) { func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
id := int32(0) id := int32(0)
parentID := int32(0) parentID := int32(0)
@ -226,15 +245,14 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
} }
field := &op.Fields[fid] field := &op.Fields[fid]
tn := strings.ToLower(field.Name) if _, ok := com.bl[field.Name]; ok {
if _, ok := com.bl[tn]; ok {
continue continue
} }
selects = append(selects, Select{ selects = append(selects, Select{
ID: id, ID: id,
ParentID: parentID, ParentID: parentID,
Table: tn, Table: field.Name,
Children: make([]int32, 0, 5), Children: make([]int32, 0, 5),
}) })
s := &selects[(len(selects) - 1)] s := &selects[(len(selects) - 1)]
@ -259,9 +277,8 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
for _, cid := range field.Children { for _, cid := range field.Children {
f := op.Fields[cid] f := op.Fields[cid]
fn := strings.ToLower(f.Name)
if _, ok := com.bl[fn]; ok { if _, ok := com.bl[f.Name]; ok {
continue continue
} }
@ -271,7 +288,7 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
continue continue
} }
col := Column{Name: fn} col := Column{Name: f.Name}
if len(f.Alias) != 0 { if len(f.Alias) != 0 {
col.FieldName = f.Alias col.FieldName = f.Alias
@ -298,8 +315,11 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
if fil != nil && fil.Op != OpNop { if fil != nil && fil.Op != OpNop {
if root.Where != nil { if root.Where != nil {
ex := &Exp{Op: OpAnd, Children: []*Exp{fil, root.Where}} ow := root.Where
root.Where = ex root.Where = &Exp{Op: OpAnd}
root.Where.Children = root.Where.childrenA[:2]
root.Where.Children[0] = fil
root.Where.Children[1] = ow
} else { } else {
root.Where = fil root.Where = fil
} }
@ -322,9 +342,7 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
for i := range args { for i := range args {
arg := &args[i] arg := &args[i]
an := strings.ToLower(arg.Name) switch arg.Name {
switch an {
case "id": case "id":
if sel.ID == 0 { if sel.ID == 0 {
err = com.compileArgID(sel, arg) err = com.compileArgID(sel, arg)
@ -348,7 +366,7 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
} }
if sel.Args != nil { if sel.Args != nil {
sel.Args[an] = arg.Val sel.Args[arg.Name] = arg.Val
} else { } else {
nodePool.Put(arg.Val) nodePool.Put(arg.Val)
} }
@ -392,7 +410,7 @@ func (com *Compiler) compileArgNode(node *Node) (*Exp, error) {
} }
if len(eT.node.Name) != 0 { if len(eT.node.Name) != 0 {
if _, ok := com.bl[strings.ToLower(eT.node.Name)]; ok { if _, ok := com.bl[eT.node.Name]; ok {
continue continue
} }
} }
@ -468,7 +486,11 @@ func (com *Compiler) compileArgSearch(sel *Select, arg *Arg) error {
} }
if sel.Where != nil { if sel.Where != nil {
sel.Where = &Exp{Op: OpAnd, Children: []*Exp{ex, sel.Where}} ow := sel.Where
sel.Where = &Exp{Op: OpAnd}
sel.Where.Children = sel.Where.childrenA[:2]
sel.Where.Children[0] = ex
sel.Where.Children[1] = ow
} else { } else {
sel.Where = ex sel.Where = ex
} }
@ -484,7 +506,11 @@ func (com *Compiler) compileArgWhere(sel *Select, arg *Arg) error {
} }
if sel.Where != nil { if sel.Where != nil {
sel.Where = &Exp{Op: OpAnd, Children: []*Exp{ex, sel.Where}} ow := sel.Where
sel.Where = &Exp{Op: OpAnd}
sel.Where.Children = sel.Where.childrenA[:2]
sel.Where.Children[0] = ex
sel.Where.Children[1] = ow
} else { } else {
sel.Where = ex sel.Where = ex
} }
@ -515,7 +541,7 @@ func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error {
return fmt.Errorf("17: unexpected value %v (%t)", intf, intf) return fmt.Errorf("17: unexpected value %v (%t)", intf, intf)
} }
if _, ok := com.bl[strings.ToLower(node.Name)]; ok { if _, ok := com.bl[node.Name]; ok {
if !com.ka { if !com.ka {
nodePool.Put(node) nodePool.Put(node)
} }
@ -534,8 +560,7 @@ func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error {
ob := &OrderBy{} ob := &OrderBy{}
val := strings.ToLower(node.Val) switch node.Val {
switch val {
case "asc": case "asc":
ob.Order = OrderAsc ob.Order = OrderAsc
case "desc": case "desc":
@ -565,7 +590,7 @@ func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error {
func (com *Compiler) compileArgDistinctOn(sel *Select, arg *Arg) error { func (com *Compiler) compileArgDistinctOn(sel *Select, arg *Arg) error {
node := arg.Val node := arg.Val
if _, ok := com.bl[strings.ToLower(node.Name)]; ok { if _, ok := com.bl[node.Name]; ok {
return nil return nil
} }
@ -619,7 +644,6 @@ func compileSub() (*Query, error) {
} }
func newExp(st *util.Stack, eT *expT) (*Exp, error) { func newExp(st *util.Stack, eT *expT) (*Exp, error) {
ex := &Exp{}
node := eT.node node := eT.node
if len(node.Name) == 0 { if len(node.Name) == 0 {
@ -627,11 +651,13 @@ func newExp(st *util.Stack, eT *expT) (*Exp, error) {
return nil, nil return nil, nil
} }
name := strings.ToLower(node.Name) name := node.Name
if name[0] == '_' { if name[0] == '_' {
name = name[1:] name = name[1:]
} }
ex := &Exp{}
switch name { switch name {
case "and": case "and":
ex.Op = OpAnd ex.Op = OpAnd
@ -756,7 +782,7 @@ func setWhereColName(ex *Exp, node *Node) {
continue continue
} }
if len(n.Name) != 0 { if len(n.Name) != 0 {
k := strings.ToLower(n.Name) k := n.Name
if k == "and" || k == "or" || k == "not" || if k == "and" || k == "or" || k == "not" ||
k == "_and" || k == "_or" || k == "_not" { k == "_and" || k == "_or" || k == "_not" {
continue continue
@ -778,8 +804,7 @@ func setOrderByColName(ob *OrderBy, node *Node) {
for n := node; n != nil; n = n.Parent { for n := node; n != nil; n = n.Parent {
if len(n.Name) != 0 { if len(n.Name) != 0 {
k := strings.ToLower(n.Name) list = append([]string{n.Name}, list...)
list = append([]string{k}, list...)
} }
} }
if len(list) != 0 { if len(list) != 0 {

View File

@ -36,7 +36,7 @@ type coreContext struct {
func (c *coreContext) handleReq(w io.Writer, req *http.Request) error { func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
var err error var err error
qc, err := qcompile.CompileQuery(c.req.Query) qc, err := qcompile.CompileQuery([]byte(c.req.Query))
if err != nil { if err != nil {
return err return err
} }