Optimize lexer and fix bugs
This commit is contained in:
parent
9af320f396
commit
340dea242d
1
go.mod
1
go.mod
|
@ -27,6 +27,7 @@ require (
|
|||
github.com/spf13/viper v1.3.1
|
||||
github.com/valyala/fasttemplate v1.0.1
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a
|
||||
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223
|
||||
mellium.im/sasl v0.2.1 // indirect
|
||||
)
|
||||
|
|
1
go.sum
1
go.sum
|
@ -111,6 +111,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90Pveol
|
|||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd h1:nTDtHvHSdCn1m6ITfMRqtOd/9+7a3s8RBNOZ3eYZzJA=
|
||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628=
|
||||
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
|
||||
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=
|
||||
|
|
|
@ -126,7 +126,7 @@ func TestMain(m *testing.M) {
|
|||
}
|
||||
|
||||
func compileGQLToPSQL(gql string) ([]byte, error) {
|
||||
qc, err := qcompile.CompileQuery(gql)
|
||||
qc, err := qcompile.Compile([]byte(gql))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -141,7 +141,7 @@ func compileGQLToPSQL(gql string) ([]byte, error) {
|
|||
|
||||
func withComplexArgs(t *testing.T) {
|
||||
gql := `query {
|
||||
products(
|
||||
proDUcts(
|
||||
# returns only 30 items
|
||||
limit: 30,
|
||||
|
||||
|
@ -157,7 +157,7 @@ func withComplexArgs(t *testing.T) {
|
|||
# only items with an id >= 30 and < 30 are returned
|
||||
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) {
|
||||
id
|
||||
name
|
||||
NAME
|
||||
price
|
||||
}
|
||||
}`
|
||||
|
@ -482,8 +482,8 @@ func TestCompileGQL(t *testing.T) {
|
|||
t.Run("syntheticTables", syntheticTables)
|
||||
}
|
||||
|
||||
var benchGQL = `query {
|
||||
products(
|
||||
var benchGQL = []byte(`query {
|
||||
proDUcts(
|
||||
# returns only 30 items
|
||||
limit: 30,
|
||||
|
||||
|
@ -496,14 +496,14 @@ var benchGQL = `query {
|
|||
# only items with an id >= 30 and < 30 are returned
|
||||
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) {
|
||||
id
|
||||
name
|
||||
NAME
|
||||
price
|
||||
user {
|
||||
full_name
|
||||
picture : avatar
|
||||
}
|
||||
}
|
||||
}`
|
||||
}`)
|
||||
|
||||
func BenchmarkCompile(b *testing.B) {
|
||||
w := &bytes.Buffer{}
|
||||
|
@ -514,7 +514,7 @@ func BenchmarkCompile(b *testing.B) {
|
|||
for n := 0; n < b.N; n++ {
|
||||
w.Reset()
|
||||
|
||||
qc, err := qcompile.CompileQuery(benchGQL)
|
||||
qc, err := qcompile.Compile(benchGQL)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
@ -535,7 +535,7 @@ func BenchmarkCompileParallel(b *testing.B) {
|
|||
for pb.Next() {
|
||||
w.Reset()
|
||||
|
||||
qc, err := qcompile.CompileQuery(benchGQL)
|
||||
qc, err := qcompile.Compile(benchGQL)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,7 @@
|
|||
goos: darwin
|
||||
goarch: amd64
|
||||
pkg: github.com/dosco/super-graph/qcode
|
||||
BenchmarkQCompile-8 200000 10029 ns/op 2291 B/op 38 allocs/op
|
||||
BenchmarkQCompileP-8 500000 2925 ns/op 2298 B/op 38 allocs/op
|
||||
PASS
|
||||
ok github.com/dosco/super-graph/qcode 3.616s
|
|
@ -2,10 +2,10 @@ package qcode
|
|||
|
||||
// FuzzerEntrypoint for Fuzzbuzz
|
||||
func FuzzerEntrypoint(data []byte) int {
|
||||
testData := string(data)
|
||||
//testData := string(data)
|
||||
|
||||
qcompile, _ := NewCompiler(Config{})
|
||||
_, err := qcompile.CompileQuery(testData)
|
||||
_, err := qcompile.Compile(data)
|
||||
if err != nil {
|
||||
return -1
|
||||
}
|
||||
|
|
223
qcode/lex.go
223
qcode/lex.go
|
@ -1,59 +1,37 @@
|
|||
package qcode
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"unicode"
|
||||
"unicode/utf8"
|
||||
)
|
||||
|
||||
var (
|
||||
queryToken = []byte("query")
|
||||
mutationToken = []byte("mutation")
|
||||
subscriptionToken = []byte("subscription")
|
||||
trueToken = []byte("true")
|
||||
falseToken = []byte("true")
|
||||
quotesToken = []byte(`'"`)
|
||||
signsToken = []byte(`+-`)
|
||||
punctuatorToken = []byte(`!():=[]{|}`)
|
||||
spreadToken = []byte(`...`)
|
||||
digitToken = []byte(`0123456789`)
|
||||
dotToken = []byte(`.`)
|
||||
)
|
||||
|
||||
// Pos represents a byte position in the original input text from which
|
||||
// this template was parsed.
|
||||
type Pos int
|
||||
|
||||
func (p Pos) Position() Pos {
|
||||
return p
|
||||
}
|
||||
|
||||
// item represents a token or text string returned from the scanner.
|
||||
type item struct {
|
||||
typ itemType // The type of this item.
|
||||
pos Pos // The starting position, in bytes, of this item in the input string.
|
||||
val string // The value of this item.
|
||||
line int // The line number at the start of this item.
|
||||
}
|
||||
|
||||
func (i *item) String() string {
|
||||
var v string
|
||||
|
||||
switch i.typ {
|
||||
case itemEOF:
|
||||
v = "EOF"
|
||||
case itemError:
|
||||
v = "error"
|
||||
case itemName:
|
||||
v = "name"
|
||||
case itemQuery:
|
||||
v = "query"
|
||||
case itemMutation:
|
||||
v = "mutation"
|
||||
case itemSub:
|
||||
v = "subscription"
|
||||
case itemPunctuator:
|
||||
v = "punctuator"
|
||||
case itemDirective:
|
||||
v = "directive"
|
||||
case itemVariable:
|
||||
v = "variable"
|
||||
case itemIntVal:
|
||||
v = "int"
|
||||
case itemFloatVal:
|
||||
v = "float"
|
||||
case itemStringVal:
|
||||
v = "string"
|
||||
}
|
||||
return fmt.Sprintf("%s %q", v, i.val)
|
||||
end Pos // The ending position, in bytes, of this item in the input string.
|
||||
line uint16 // The line number at the start of this item.
|
||||
}
|
||||
|
||||
// itemType identifies the type of lex items.
|
||||
|
@ -103,14 +81,14 @@ type stateFn func(*lexer) stateFn
|
|||
|
||||
// lexer holds the state of the scanner.
|
||||
type lexer struct {
|
||||
name string // the name of the input; used only for error reports
|
||||
input string // the string being scanned
|
||||
input []byte // the string being scanned
|
||||
pos Pos // current position in the input
|
||||
start Pos // start position of this item
|
||||
width Pos // width of last rune read from input
|
||||
items []item // array of scanned items
|
||||
itemsA [100]item
|
||||
line int // 1+number of newlines seen
|
||||
itemsA [50]item
|
||||
line uint16 // 1+number of newlines seen
|
||||
err error
|
||||
}
|
||||
|
||||
var zeroLex = lexer{}
|
||||
|
@ -119,13 +97,13 @@ func (l *lexer) Reset() {
|
|||
*l = zeroLex
|
||||
}
|
||||
|
||||
// next returns the next rune in the input.
|
||||
// next returns the next byte in the input.
|
||||
func (l *lexer) next() rune {
|
||||
if int(l.pos) >= len(l.input) {
|
||||
l.width = 0
|
||||
return eof
|
||||
}
|
||||
r, w := utf8.DecodeRuneInString(l.input[l.pos:])
|
||||
r, w := utf8.DecodeRune(l.input[l.pos:])
|
||||
l.width = Pos(w)
|
||||
l.pos += l.width
|
||||
if r == '\n' {
|
||||
|
@ -150,30 +128,38 @@ func (l *lexer) backup() {
|
|||
}
|
||||
}
|
||||
|
||||
func (l *lexer) current() string {
|
||||
return l.input[l.start:l.pos]
|
||||
func (l *lexer) current() (Pos, Pos) {
|
||||
return l.start, l.pos
|
||||
}
|
||||
|
||||
// emit passes an item back to the client.
|
||||
func (l *lexer) emit(t itemType) {
|
||||
l.items = append(l.items, item{t, l.start, l.input[l.start:l.pos], l.line})
|
||||
l.items = append(l.items, item{t, l.start, l.pos, l.line})
|
||||
// Some items contain text internally. If so, count their newlines.
|
||||
switch t {
|
||||
case itemName:
|
||||
l.line += strings.Count(l.input[l.start:l.pos], "\n")
|
||||
for i := l.start; i < l.pos; i++ {
|
||||
if l.input[i] == '\n' {
|
||||
l.line++
|
||||
}
|
||||
}
|
||||
}
|
||||
l.start = l.pos
|
||||
}
|
||||
|
||||
// ignore skips over the pending input before this point.
|
||||
func (l *lexer) ignore() {
|
||||
l.line += strings.Count(l.input[l.start:l.pos], "\n")
|
||||
for i := l.start; i < l.pos; i++ {
|
||||
if l.input[i] == '\n' {
|
||||
l.line++
|
||||
}
|
||||
}
|
||||
l.start = l.pos
|
||||
}
|
||||
|
||||
// accept consumes the next rune if it's from the valid set.
|
||||
func (l *lexer) accept(valid string) bool {
|
||||
if strings.ContainsRune(valid, l.next()) {
|
||||
func (l *lexer) accept(valid []byte) bool {
|
||||
if bytes.ContainsRune(valid, l.next()) {
|
||||
return true
|
||||
}
|
||||
l.backup()
|
||||
|
@ -199,8 +185,8 @@ func (l *lexer) acceptComment() {
|
|||
}
|
||||
|
||||
// acceptRun consumes a run of runes from the valid set.
|
||||
func (l *lexer) acceptRun(valid string) {
|
||||
for strings.ContainsRune(valid, l.next()) {
|
||||
func (l *lexer) acceptRun(valid []byte) {
|
||||
for bytes.ContainsRune(valid, l.next()) {
|
||||
}
|
||||
l.backup()
|
||||
}
|
||||
|
@ -208,23 +194,25 @@ func (l *lexer) acceptRun(valid string) {
|
|||
// errorf returns an error token and terminates the scan by passing
|
||||
// back a nil pointer that will be the next state, terminating l.nextItem.
|
||||
func (l *lexer) errorf(format string, args ...interface{}) stateFn {
|
||||
l.items = append(l.items, item{itemError, l.start,
|
||||
fmt.Sprintf(format, args...), l.line})
|
||||
l.err = fmt.Errorf(format, args...)
|
||||
l.items = append(l.items, item{itemError, l.start, l.pos, l.line})
|
||||
return nil
|
||||
}
|
||||
|
||||
// lex creates a new scanner for the input string.
|
||||
func lex(l *lexer, input string) error {
|
||||
func lex(l *lexer, input []byte) error {
|
||||
if len(input) == 0 {
|
||||
return errors.New("empty query")
|
||||
}
|
||||
|
||||
l.input = input
|
||||
l.line = 1
|
||||
l.items = l.itemsA[:0]
|
||||
l.line = 1
|
||||
|
||||
l.run()
|
||||
|
||||
if last := l.items[len(l.items)-1]; last.typ == itemError {
|
||||
return fmt.Errorf(last.val)
|
||||
return l.err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
@ -262,7 +250,7 @@ func lexRoot(l *lexer) stateFn {
|
|||
if l.acceptAlphaNum() {
|
||||
l.emit(itemVariable)
|
||||
}
|
||||
case strings.ContainsRune("!():=[]{|}", r):
|
||||
case contains(l.input, l.start, l.pos, punctuatorToken):
|
||||
if item, ok := punctuators[r]; ok {
|
||||
l.emit(item)
|
||||
} else {
|
||||
|
@ -273,7 +261,7 @@ func lexRoot(l *lexer) stateFn {
|
|||
return lexString
|
||||
case r == '.':
|
||||
if len(l.input) >= 3 {
|
||||
if strings.HasSuffix(l.input[:l.pos], "...") {
|
||||
if equals(l.input, 0, 3, spreadToken) {
|
||||
l.emit(itemSpread)
|
||||
return lexRoot
|
||||
}
|
||||
|
@ -295,34 +283,28 @@ func lexRoot(l *lexer) stateFn {
|
|||
func lexName(l *lexer) stateFn {
|
||||
for {
|
||||
r := l.next()
|
||||
|
||||
if r == eof {
|
||||
l.emit(itemEOF)
|
||||
return nil
|
||||
}
|
||||
|
||||
if !isAlphaNumeric(r) {
|
||||
l.backup()
|
||||
v := l.current()
|
||||
s, e := l.current()
|
||||
|
||||
lowercase(l.input, s, e)
|
||||
|
||||
if len(v) == 0 {
|
||||
switch {
|
||||
case strings.EqualFold(v, "query"):
|
||||
l.emit(itemQuery)
|
||||
break
|
||||
case strings.EqualFold(v, "mutation"):
|
||||
l.emit(itemMutation)
|
||||
break
|
||||
case strings.EqualFold(v, "subscription"):
|
||||
l.emit(itemSub)
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
switch {
|
||||
case strings.EqualFold(v, "true"):
|
||||
case equals(l.input, s, e, queryToken):
|
||||
l.emit(itemQuery)
|
||||
case equals(l.input, s, e, mutationToken):
|
||||
l.emit(itemMutation)
|
||||
case equals(l.input, s, e, subscriptionToken):
|
||||
l.emit(itemSub)
|
||||
case equals(l.input, s, e, trueToken):
|
||||
l.emit(itemBoolVal)
|
||||
case strings.EqualFold(v, "false"):
|
||||
case equals(l.input, s, e, falseToken):
|
||||
l.emit(itemBoolVal)
|
||||
default:
|
||||
l.emit(itemName)
|
||||
|
@ -335,7 +317,7 @@ func lexName(l *lexer) stateFn {
|
|||
|
||||
// lexString scans a string.
|
||||
func lexString(l *lexer) stateFn {
|
||||
if l.accept("\"'") {
|
||||
if l.accept([]byte(quotesToken)) {
|
||||
l.ignore()
|
||||
|
||||
for {
|
||||
|
@ -347,7 +329,7 @@ func lexString(l *lexer) stateFn {
|
|||
if r == '\'' || r == '"' {
|
||||
l.backup()
|
||||
l.emit(itemStringVal)
|
||||
if l.accept("\"'") {
|
||||
if l.accept(quotesToken) {
|
||||
l.ignore()
|
||||
}
|
||||
break
|
||||
|
@ -364,20 +346,19 @@ func lexString(l *lexer) stateFn {
|
|||
func lexNumber(l *lexer) stateFn {
|
||||
var it itemType
|
||||
// Optional leading sign.
|
||||
l.accept("+-")
|
||||
l.accept(signsToken)
|
||||
|
||||
// Is it integer
|
||||
digits := "0123456789"
|
||||
if l.accept(digits) {
|
||||
l.acceptRun(digits)
|
||||
if l.accept(digitToken) {
|
||||
l.acceptRun(digitToken)
|
||||
it = itemIntVal
|
||||
}
|
||||
|
||||
// Is it float
|
||||
if l.peek() == '.' {
|
||||
if l.accept(".") {
|
||||
if l.accept(digits) {
|
||||
l.acceptRun(digits)
|
||||
if l.accept(dotToken) {
|
||||
if l.accept(digitToken) {
|
||||
l.acceptRun(digitToken)
|
||||
it = itemFloatVal
|
||||
}
|
||||
} else {
|
||||
|
@ -413,6 +394,74 @@ func isAlphaNumeric(r rune) bool {
|
|||
return r == '_' || unicode.IsLetter(r) || unicode.IsDigit(r)
|
||||
}
|
||||
|
||||
func equals(b []byte, s Pos, e Pos, val []byte) bool {
|
||||
n := 0
|
||||
for i := s; i < e; i++ {
|
||||
if n >= len(val) {
|
||||
return true
|
||||
}
|
||||
switch {
|
||||
case b[i] >= 'A' && b[i] <= 'Z' && ('a'+(b[i]-'A')) != val[n]:
|
||||
return false
|
||||
case b[i] != val[n]:
|
||||
return false
|
||||
}
|
||||
n++
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func contains(b []byte, s Pos, e Pos, val []byte) bool {
|
||||
for i := s; i < e; i++ {
|
||||
for n := 0; n < len(val); n++ {
|
||||
if b[i] == val[n] {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func lowercase(b []byte, s Pos, e Pos) {
|
||||
for i := s; i < e; i++ {
|
||||
if b[i] >= 'A' && b[i] <= 'Z' {
|
||||
b[i] = ('a' + (b[i] - 'A'))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (i *item) String() string {
|
||||
var v string
|
||||
|
||||
switch i.typ {
|
||||
case itemEOF:
|
||||
v = "EOF"
|
||||
case itemError:
|
||||
v = "error"
|
||||
case itemName:
|
||||
v = "name"
|
||||
case itemQuery:
|
||||
v = "query"
|
||||
case itemMutation:
|
||||
v = "mutation"
|
||||
case itemSub:
|
||||
v = "subscription"
|
||||
case itemPunctuator:
|
||||
v = "punctuator"
|
||||
case itemDirective:
|
||||
v = "directive"
|
||||
case itemVariable:
|
||||
v = "variable"
|
||||
case itemIntVal:
|
||||
v = "int"
|
||||
case itemFloatVal:
|
||||
v = "float"
|
||||
case itemStringVal:
|
||||
v = "string"
|
||||
}
|
||||
return fmt.Sprintf("%s", v)
|
||||
}
|
||||
|
||||
/*
|
||||
|
||||
Copyright (c) 2009 The Go Authors. All rights reserved.
|
||||
|
|
137
qcode/parse.go
137
qcode/parse.go
|
@ -4,6 +4,7 @@ import (
|
|||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"unsafe"
|
||||
|
||||
"github.com/dosco/super-graph/util"
|
||||
)
|
||||
|
@ -53,9 +54,9 @@ type Field struct {
|
|||
Name string
|
||||
Alias string
|
||||
Args []Arg
|
||||
argsA [10]Arg
|
||||
argsA [5]Arg
|
||||
Children []int32
|
||||
childrenA [10]int32
|
||||
childrenA [5]int32
|
||||
}
|
||||
|
||||
type Arg struct {
|
||||
|
@ -78,6 +79,7 @@ func (n *Node) Reset() {
|
|||
}
|
||||
|
||||
type Parser struct {
|
||||
input []byte // the string being scanned
|
||||
pos int
|
||||
items []item
|
||||
depth int
|
||||
|
@ -96,38 +98,32 @@ var lexPool = sync.Pool{
|
|||
New: func() interface{} { return new(lexer) },
|
||||
}
|
||||
|
||||
func Parse(gql string) (*Operation, error) {
|
||||
if len(gql) == 0 {
|
||||
return nil, errors.New("blank query")
|
||||
}
|
||||
l := lexPool.Get().(*lexer)
|
||||
l.Reset()
|
||||
|
||||
if err := lex(l, gql); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p := &Parser{
|
||||
pos: -1,
|
||||
items: l.items,
|
||||
}
|
||||
op, err := p.parseOp()
|
||||
lexPool.Put(l)
|
||||
|
||||
return op, err
|
||||
func Parse(gql []byte) (*Operation, error) {
|
||||
return parseSelectionSet(nil, gql)
|
||||
}
|
||||
|
||||
func ParseQuery(gql string) (*Operation, error) {
|
||||
return parseByType(gql, opQuery)
|
||||
func ParseQuery(gql []byte) (*Operation, error) {
|
||||
op := opPool.Get().(*Operation)
|
||||
op.Reset()
|
||||
|
||||
op.Type = opQuery
|
||||
op.Name = ""
|
||||
op.Fields = op.fieldsA[:0]
|
||||
op.Args = op.argsA[:0]
|
||||
|
||||
return parseSelectionSet(op, gql)
|
||||
}
|
||||
|
||||
func ParseArgValue(argVal string) (*Node, error) {
|
||||
l := lexPool.Get().(*lexer)
|
||||
l.Reset()
|
||||
|
||||
if err := lex(l, argVal); err != nil {
|
||||
if err := lex(l, []byte(argVal)); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p := &Parser{
|
||||
input: l.input,
|
||||
pos: -1,
|
||||
items: l.items,
|
||||
}
|
||||
|
@ -137,20 +133,42 @@ func ParseArgValue(argVal string) (*Node, error) {
|
|||
return op, err
|
||||
}
|
||||
|
||||
func parseByType(gql string, ty parserType) (*Operation, error) {
|
||||
func parseSelectionSet(op *Operation, gql []byte) (*Operation, error) {
|
||||
var err error
|
||||
|
||||
if len(gql) == 0 {
|
||||
return nil, errors.New("blank query")
|
||||
}
|
||||
|
||||
l := lexPool.Get().(*lexer)
|
||||
l.Reset()
|
||||
|
||||
if err := lex(l, gql); err != nil {
|
||||
if err = lex(l, gql); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
p := &Parser{
|
||||
input: l.input,
|
||||
pos: -1,
|
||||
items: l.items,
|
||||
}
|
||||
op, err := p.parseOpByType(ty)
|
||||
|
||||
if op == nil {
|
||||
op, err = p.parseOp()
|
||||
} else {
|
||||
if p.peek(itemObjOpen) {
|
||||
p.ignore()
|
||||
}
|
||||
|
||||
op.Fields, err = p.parseFields(op.Fields)
|
||||
}
|
||||
|
||||
lexPool.Put(l)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return op, err
|
||||
}
|
||||
|
||||
|
@ -198,18 +216,34 @@ func (p *Parser) peek(types ...itemType) bool {
|
|||
return false
|
||||
}
|
||||
|
||||
func (p *Parser) parseOpByType(ty parserType) (*Operation, error) {
|
||||
func (p *Parser) parseOp() (*Operation, error) {
|
||||
if p.peek(itemQuery, itemMutation, itemSub) == false {
|
||||
err := fmt.Errorf(
|
||||
"expecting a query, mutation or subscription (not '%s')",
|
||||
p.val(p.next()))
|
||||
return nil, err
|
||||
}
|
||||
item := p.next()
|
||||
|
||||
op := opPool.Get().(*Operation)
|
||||
op.Reset()
|
||||
|
||||
op.Type = ty
|
||||
switch item.typ {
|
||||
case itemQuery:
|
||||
op.Type = opQuery
|
||||
case itemMutation:
|
||||
op.Type = opMutate
|
||||
case itemSub:
|
||||
op.Type = opSub
|
||||
}
|
||||
|
||||
op.Fields = op.fieldsA[:0]
|
||||
op.Args = op.argsA[:0]
|
||||
|
||||
var err error
|
||||
|
||||
if p.peek(itemName) {
|
||||
op.Name = p.next().val
|
||||
op.Name = p.val(p.next())
|
||||
}
|
||||
|
||||
if p.peek(itemArgsOpen) {
|
||||
|
@ -228,33 +262,9 @@ func (p *Parser) parseOpByType(ty parserType) (*Operation, error) {
|
|||
}
|
||||
}
|
||||
|
||||
if p.peek(itemObjClose) {
|
||||
p.ignore()
|
||||
}
|
||||
|
||||
return op, nil
|
||||
}
|
||||
|
||||
func (p *Parser) parseOp() (*Operation, error) {
|
||||
if p.peek(itemQuery, itemMutation, itemSub) == false {
|
||||
err := fmt.Errorf("expecting a query, mutation or subscription (not '%s')", p.next().val)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
item := p.next()
|
||||
|
||||
switch item.typ {
|
||||
case itemQuery:
|
||||
return p.parseOpByType(opQuery)
|
||||
case itemMutation:
|
||||
return p.parseOpByType(opMutate)
|
||||
case itemSub:
|
||||
return p.parseOpByType(opSub)
|
||||
}
|
||||
|
||||
return nil, errors.New("unknown operation type")
|
||||
}
|
||||
|
||||
func (p *Parser) parseFields(fields []Field) ([]Field, error) {
|
||||
st := util.NewStack()
|
||||
|
||||
|
@ -278,6 +288,7 @@ func (p *Parser) parseFields(fields []Field) ([]Field, error) {
|
|||
}
|
||||
|
||||
fields = append(fields, Field{ID: int32(len(fields))})
|
||||
|
||||
f := &fields[(len(fields) - 1)]
|
||||
f.Args = f.argsA[:0]
|
||||
f.Children = f.childrenA[:0]
|
||||
|
@ -309,14 +320,14 @@ func (p *Parser) parseFields(fields []Field) ([]Field, error) {
|
|||
|
||||
func (p *Parser) parseField(f *Field) error {
|
||||
var err error
|
||||
f.Name = p.next().val
|
||||
f.Name = p.val(p.next())
|
||||
|
||||
if p.peek(itemColon) {
|
||||
p.ignore()
|
||||
|
||||
if p.peek(itemName) {
|
||||
f.Alias = f.Name
|
||||
f.Name = p.next().val
|
||||
f.Name = p.val(p.next())
|
||||
} else {
|
||||
return errors.New("expecting an aliased field name")
|
||||
}
|
||||
|
@ -347,7 +358,7 @@ func (p *Parser) parseArgs(args []Arg) ([]Arg, error) {
|
|||
if p.peek(itemName) == false {
|
||||
return nil, errors.New("expecting an argument name")
|
||||
}
|
||||
args = append(args, Arg{Name: p.next().val})
|
||||
args = append(args, Arg{Name: p.val(p.next())})
|
||||
arg := &args[(len(args) - 1)]
|
||||
|
||||
if p.peek(itemColon) == false {
|
||||
|
@ -414,7 +425,7 @@ func (p *Parser) parseObj() (*Node, error) {
|
|||
if p.peek(itemName) == false {
|
||||
return nil, errors.New("expecting an argument name")
|
||||
}
|
||||
nodeName := p.next().val
|
||||
nodeName := p.val(p.next())
|
||||
|
||||
if p.peek(itemColon) == false {
|
||||
return nil, errors.New("missing ':' after Field argument name")
|
||||
|
@ -465,13 +476,21 @@ func (p *Parser) parseValue() (*Node, error) {
|
|||
case itemVariable:
|
||||
node.Type = nodeVar
|
||||
default:
|
||||
return nil, fmt.Errorf("expecting a number, string, object, list or variable as an argument value (not %s)", p.next().val)
|
||||
return nil, fmt.Errorf("expecting a number, string, object, list or variable as an argument value (not %s)", p.val(p.next()))
|
||||
}
|
||||
node.Val = item.val
|
||||
node.Val = p.val(item)
|
||||
|
||||
return node, nil
|
||||
}
|
||||
|
||||
func (p *Parser) val(v item) string {
|
||||
return b2s(p.input[v.pos:v.end])
|
||||
}
|
||||
|
||||
func b2s(b []byte) string {
|
||||
return *(*string)(unsafe.Pointer(&b))
|
||||
}
|
||||
|
||||
func (t parserType) String() string {
|
||||
var v string
|
||||
|
||||
|
|
|
@ -47,12 +47,13 @@ func compareOp(op1, op2 Operation) error {
|
|||
|
||||
func TestCompile(t *testing.T) {
|
||||
qcompile, _ := NewCompiler(Config{})
|
||||
_, err := qcompile.CompileQuery(`query {
|
||||
product(id: 15) {
|
||||
|
||||
_, err := qcompile.CompileQuery([]byte(`
|
||||
product(id: 15) {
|
||||
id
|
||||
name
|
||||
}
|
||||
}`)
|
||||
}`))
|
||||
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
@ -60,7 +61,8 @@ func TestCompile(t *testing.T) {
|
|||
|
||||
func TestInvalidCompile(t *testing.T) {
|
||||
qcompile, _ := NewCompiler(Config{})
|
||||
_, err := qcompile.CompileQuery(`#`)
|
||||
_, err := qcompile.CompileQuery([]byte(`#`))
|
||||
|
||||
if err == nil {
|
||||
t.Fatal(errors.New("expecting an error"))
|
||||
}
|
||||
|
@ -68,13 +70,14 @@ func TestInvalidCompile(t *testing.T) {
|
|||
|
||||
func TestEmptyCompile(t *testing.T) {
|
||||
qcompile, _ := NewCompiler(Config{})
|
||||
_, err := qcompile.CompileQuery(``)
|
||||
_, err := qcompile.CompileQuery([]byte(``))
|
||||
|
||||
if err == nil {
|
||||
t.Fatal(errors.New("expecting an error"))
|
||||
}
|
||||
}
|
||||
|
||||
var gql = `query {
|
||||
var gql = []byte(`
|
||||
products(
|
||||
# returns only 30 items
|
||||
limit: 30,
|
||||
|
@ -93,8 +96,7 @@ var gql = `query {
|
|||
id
|
||||
name
|
||||
price
|
||||
}
|
||||
}`
|
||||
}`)
|
||||
|
||||
func BenchmarkQCompile(b *testing.B) {
|
||||
qcompile, _ := NewCompiler(Config{})
|
||||
|
|
|
@ -50,6 +50,7 @@ type Exp struct {
|
|||
ListType ValType
|
||||
ListVal []string
|
||||
Children []*Exp
|
||||
childrenA [5]*Exp
|
||||
}
|
||||
|
||||
type OrderBy struct {
|
||||
|
@ -144,7 +145,7 @@ func NewCompiler(c Config) (*Compiler, error) {
|
|||
bl := make(map[string]struct{}, len(c.Blacklist))
|
||||
|
||||
for i := range c.Blacklist {
|
||||
bl[strings.ToLower(c.Blacklist[i])] = struct{}{}
|
||||
bl[c.Blacklist[i]] = struct{}{}
|
||||
}
|
||||
|
||||
fl, err := compileFilter(c.DefaultFilter)
|
||||
|
@ -159,9 +160,8 @@ func NewCompiler(c Config) (*Compiler, error) {
|
|||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
k1 := strings.ToLower(k)
|
||||
singular := flect.Singularize(k1)
|
||||
plural := flect.Pluralize(k1)
|
||||
singular := flect.Singularize(k)
|
||||
plural := flect.Pluralize(k)
|
||||
|
||||
fm[singular] = fil
|
||||
fm[plural] = fil
|
||||
|
@ -170,11 +170,11 @@ func NewCompiler(c Config) (*Compiler, error) {
|
|||
return &Compiler{fl, fm, bl, c.KeepArgs}, nil
|
||||
}
|
||||
|
||||
func (com *Compiler) CompileQuery(query string) (*QCode, error) {
|
||||
func (com *Compiler) Compile(query []byte) (*QCode, error) {
|
||||
var qc QCode
|
||||
var err error
|
||||
|
||||
op, err := ParseQuery(query)
|
||||
op, err := Parse(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -197,6 +197,25 @@ func (com *Compiler) CompileQuery(query string) (*QCode, error) {
|
|||
return &qc, nil
|
||||
}
|
||||
|
||||
func (com *Compiler) CompileQuery(query []byte) (*QCode, error) {
|
||||
var err error
|
||||
|
||||
op, err := ParseQuery(query)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
qc := &QCode{}
|
||||
qc.Query, err = com.compileQuery(op)
|
||||
opPool.Put(op)
|
||||
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return qc, nil
|
||||
}
|
||||
|
||||
func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
|
||||
id := int32(0)
|
||||
parentID := int32(0)
|
||||
|
@ -226,15 +245,14 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
|
|||
}
|
||||
field := &op.Fields[fid]
|
||||
|
||||
tn := strings.ToLower(field.Name)
|
||||
if _, ok := com.bl[tn]; ok {
|
||||
if _, ok := com.bl[field.Name]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
selects = append(selects, Select{
|
||||
ID: id,
|
||||
ParentID: parentID,
|
||||
Table: tn,
|
||||
Table: field.Name,
|
||||
Children: make([]int32, 0, 5),
|
||||
})
|
||||
s := &selects[(len(selects) - 1)]
|
||||
|
@ -259,9 +277,8 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
|
|||
|
||||
for _, cid := range field.Children {
|
||||
f := op.Fields[cid]
|
||||
fn := strings.ToLower(f.Name)
|
||||
|
||||
if _, ok := com.bl[fn]; ok {
|
||||
if _, ok := com.bl[f.Name]; ok {
|
||||
continue
|
||||
}
|
||||
|
||||
|
@ -271,7 +288,7 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
|
|||
continue
|
||||
}
|
||||
|
||||
col := Column{Name: fn}
|
||||
col := Column{Name: f.Name}
|
||||
|
||||
if len(f.Alias) != 0 {
|
||||
col.FieldName = f.Alias
|
||||
|
@ -298,8 +315,11 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
|
|||
if fil != nil && fil.Op != OpNop {
|
||||
|
||||
if root.Where != nil {
|
||||
ex := &Exp{Op: OpAnd, Children: []*Exp{fil, root.Where}}
|
||||
root.Where = ex
|
||||
ow := root.Where
|
||||
root.Where = &Exp{Op: OpAnd}
|
||||
root.Where.Children = root.Where.childrenA[:2]
|
||||
root.Where.Children[0] = fil
|
||||
root.Where.Children[1] = ow
|
||||
} else {
|
||||
root.Where = fil
|
||||
}
|
||||
|
@ -322,9 +342,7 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
|
|||
for i := range args {
|
||||
arg := &args[i]
|
||||
|
||||
an := strings.ToLower(arg.Name)
|
||||
|
||||
switch an {
|
||||
switch arg.Name {
|
||||
case "id":
|
||||
if sel.ID == 0 {
|
||||
err = com.compileArgID(sel, arg)
|
||||
|
@ -348,7 +366,7 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
|
|||
}
|
||||
|
||||
if sel.Args != nil {
|
||||
sel.Args[an] = arg.Val
|
||||
sel.Args[arg.Name] = arg.Val
|
||||
} else {
|
||||
nodePool.Put(arg.Val)
|
||||
}
|
||||
|
@ -392,7 +410,7 @@ func (com *Compiler) compileArgNode(node *Node) (*Exp, error) {
|
|||
}
|
||||
|
||||
if len(eT.node.Name) != 0 {
|
||||
if _, ok := com.bl[strings.ToLower(eT.node.Name)]; ok {
|
||||
if _, ok := com.bl[eT.node.Name]; ok {
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
@ -468,7 +486,11 @@ func (com *Compiler) compileArgSearch(sel *Select, arg *Arg) error {
|
|||
}
|
||||
|
||||
if sel.Where != nil {
|
||||
sel.Where = &Exp{Op: OpAnd, Children: []*Exp{ex, sel.Where}}
|
||||
ow := sel.Where
|
||||
sel.Where = &Exp{Op: OpAnd}
|
||||
sel.Where.Children = sel.Where.childrenA[:2]
|
||||
sel.Where.Children[0] = ex
|
||||
sel.Where.Children[1] = ow
|
||||
} else {
|
||||
sel.Where = ex
|
||||
}
|
||||
|
@ -484,7 +506,11 @@ func (com *Compiler) compileArgWhere(sel *Select, arg *Arg) error {
|
|||
}
|
||||
|
||||
if sel.Where != nil {
|
||||
sel.Where = &Exp{Op: OpAnd, Children: []*Exp{ex, sel.Where}}
|
||||
ow := sel.Where
|
||||
sel.Where = &Exp{Op: OpAnd}
|
||||
sel.Where.Children = sel.Where.childrenA[:2]
|
||||
sel.Where.Children[0] = ex
|
||||
sel.Where.Children[1] = ow
|
||||
} else {
|
||||
sel.Where = ex
|
||||
}
|
||||
|
@ -515,7 +541,7 @@ func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error {
|
|||
return fmt.Errorf("17: unexpected value %v (%t)", intf, intf)
|
||||
}
|
||||
|
||||
if _, ok := com.bl[strings.ToLower(node.Name)]; ok {
|
||||
if _, ok := com.bl[node.Name]; ok {
|
||||
if !com.ka {
|
||||
nodePool.Put(node)
|
||||
}
|
||||
|
@ -534,8 +560,7 @@ func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error {
|
|||
|
||||
ob := &OrderBy{}
|
||||
|
||||
val := strings.ToLower(node.Val)
|
||||
switch val {
|
||||
switch node.Val {
|
||||
case "asc":
|
||||
ob.Order = OrderAsc
|
||||
case "desc":
|
||||
|
@ -565,7 +590,7 @@ func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error {
|
|||
func (com *Compiler) compileArgDistinctOn(sel *Select, arg *Arg) error {
|
||||
node := arg.Val
|
||||
|
||||
if _, ok := com.bl[strings.ToLower(node.Name)]; ok {
|
||||
if _, ok := com.bl[node.Name]; ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -619,7 +644,6 @@ func compileSub() (*Query, error) {
|
|||
}
|
||||
|
||||
func newExp(st *util.Stack, eT *expT) (*Exp, error) {
|
||||
ex := &Exp{}
|
||||
node := eT.node
|
||||
|
||||
if len(node.Name) == 0 {
|
||||
|
@ -627,11 +651,13 @@ func newExp(st *util.Stack, eT *expT) (*Exp, error) {
|
|||
return nil, nil
|
||||
}
|
||||
|
||||
name := strings.ToLower(node.Name)
|
||||
name := node.Name
|
||||
if name[0] == '_' {
|
||||
name = name[1:]
|
||||
}
|
||||
|
||||
ex := &Exp{}
|
||||
|
||||
switch name {
|
||||
case "and":
|
||||
ex.Op = OpAnd
|
||||
|
@ -756,7 +782,7 @@ func setWhereColName(ex *Exp, node *Node) {
|
|||
continue
|
||||
}
|
||||
if len(n.Name) != 0 {
|
||||
k := strings.ToLower(n.Name)
|
||||
k := n.Name
|
||||
if k == "and" || k == "or" || k == "not" ||
|
||||
k == "_and" || k == "_or" || k == "_not" {
|
||||
continue
|
||||
|
@ -778,8 +804,7 @@ func setOrderByColName(ob *OrderBy, node *Node) {
|
|||
|
||||
for n := node; n != nil; n = n.Parent {
|
||||
if len(n.Name) != 0 {
|
||||
k := strings.ToLower(n.Name)
|
||||
list = append([]string{k}, list...)
|
||||
list = append([]string{n.Name}, list...)
|
||||
}
|
||||
}
|
||||
if len(list) != 0 {
|
||||
|
|
|
@ -36,7 +36,7 @@ type coreContext struct {
|
|||
func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
|
||||
var err error
|
||||
|
||||
qc, err := qcompile.CompileQuery(c.req.Query)
|
||||
qc, err := qcompile.CompileQuery([]byte(c.req.Query))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue