From 340dea242d0ab1cc79879fa29536c0da022862c9 Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Fri, 14 Jun 2019 22:17:21 -0400 Subject: [PATCH] Optimize lexer and fix bugs --- go.mod | 1 + go.sum | 1 + psql/psql_test.go | 18 ++-- qcode/bench.3 | 7 ++ qcode/fuzz.go | 4 +- qcode/lex.go | 223 +++++++++++++++++++++++++++----------------- qcode/parse.go | 137 +++++++++++++++------------ qcode/parse_test.go | 20 ++-- qcode/qcode.go | 85 +++++++++++------ serv/core.go | 2 +- 10 files changed, 301 insertions(+), 197 deletions(-) create mode 100644 qcode/bench.3 diff --git a/go.mod b/go.mod index 96cadbf..4209118 100644 --- a/go.mod +++ b/go.mod @@ -27,6 +27,7 @@ require ( github.com/spf13/viper v1.3.1 github.com/valyala/fasttemplate v1.0.1 golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 + golang.org/x/net v0.0.0-20190311183353-d8887717615a golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223 mellium.im/sasl v0.2.1 // indirect ) diff --git a/go.sum b/go.sum index 23aa80e..19ef946 100644 --- a/go.sum +++ b/go.sum @@ -111,6 +111,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90Pveol golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd h1:nTDtHvHSdCn1m6ITfMRqtOd/9+7a3s8RBNOZ3eYZzJA= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= +golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628= golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA= golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= diff --git a/psql/psql_test.go b/psql/psql_test.go index 6d2bef4..158c1a8 100644 --- a/psql/psql_test.go +++ b/psql/psql_test.go @@ -126,7 +126,7 @@ func TestMain(m *testing.M) { } func compileGQLToPSQL(gql string) ([]byte, error) { - qc, err := qcompile.CompileQuery(gql) + qc, err := qcompile.Compile([]byte(gql)) if err != nil { return nil, err } @@ -141,7 +141,7 @@ func compileGQLToPSQL(gql string) ([]byte, error) { func withComplexArgs(t *testing.T) { gql := `query { - products( + proDUcts( # returns only 30 items limit: 30, @@ -157,7 +157,7 @@ func withComplexArgs(t *testing.T) { # only items with an id >= 30 and < 30 are returned where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id - name + NAME price } }` @@ -482,8 +482,8 @@ func TestCompileGQL(t *testing.T) { t.Run("syntheticTables", syntheticTables) } -var benchGQL = `query { - products( +var benchGQL = []byte(`query { + proDUcts( # returns only 30 items limit: 30, @@ -496,14 +496,14 @@ var benchGQL = `query { # only items with an id >= 30 and < 30 are returned where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id - name + NAME price user { full_name picture : avatar } } -}` +}`) func BenchmarkCompile(b *testing.B) { w := &bytes.Buffer{} @@ -514,7 +514,7 @@ func BenchmarkCompile(b *testing.B) { for n := 0; n < b.N; n++ { w.Reset() - qc, err := qcompile.CompileQuery(benchGQL) + qc, err := qcompile.Compile(benchGQL) if err != nil { b.Fatal(err) } @@ -535,7 +535,7 @@ func BenchmarkCompileParallel(b *testing.B) { for pb.Next() { w.Reset() - qc, err := qcompile.CompileQuery(benchGQL) + qc, err := qcompile.Compile(benchGQL) if err != nil { b.Fatal(err) } diff --git a/qcode/bench.3 b/qcode/bench.3 new file mode 100644 index 0000000..39ab100 --- /dev/null +++ b/qcode/bench.3 @@ -0,0 +1,7 @@ +goos: darwin +goarch: amd64 +pkg: github.com/dosco/super-graph/qcode +BenchmarkQCompile-8 200000 10029 ns/op 2291 B/op 38 allocs/op +BenchmarkQCompileP-8 500000 2925 ns/op 2298 B/op 38 allocs/op +PASS +ok github.com/dosco/super-graph/qcode 3.616s diff --git a/qcode/fuzz.go b/qcode/fuzz.go index 1201eb7..db8f3c8 100644 --- a/qcode/fuzz.go +++ b/qcode/fuzz.go @@ -2,10 +2,10 @@ package qcode // FuzzerEntrypoint for Fuzzbuzz func FuzzerEntrypoint(data []byte) int { - testData := string(data) + //testData := string(data) qcompile, _ := NewCompiler(Config{}) - _, err := qcompile.CompileQuery(testData) + _, err := qcompile.Compile(data) if err != nil { return -1 } diff --git a/qcode/lex.go b/qcode/lex.go index 25ed9e8..14ece16 100644 --- a/qcode/lex.go +++ b/qcode/lex.go @@ -1,59 +1,37 @@ package qcode import ( + "bytes" "errors" "fmt" - "strings" "unicode" "unicode/utf8" ) +var ( + queryToken = []byte("query") + mutationToken = []byte("mutation") + subscriptionToken = []byte("subscription") + trueToken = []byte("true") + falseToken = []byte("true") + quotesToken = []byte(`'"`) + signsToken = []byte(`+-`) + punctuatorToken = []byte(`!():=[]{|}`) + spreadToken = []byte(`...`) + digitToken = []byte(`0123456789`) + dotToken = []byte(`.`) +) + // Pos represents a byte position in the original input text from which // this template was parsed. type Pos int -func (p Pos) Position() Pos { - return p -} - // item represents a token or text string returned from the scanner. type item struct { typ itemType // The type of this item. pos Pos // The starting position, in bytes, of this item in the input string. - val string // The value of this item. - line int // The line number at the start of this item. -} - -func (i *item) String() string { - var v string - - switch i.typ { - case itemEOF: - v = "EOF" - case itemError: - v = "error" - case itemName: - v = "name" - case itemQuery: - v = "query" - case itemMutation: - v = "mutation" - case itemSub: - v = "subscription" - case itemPunctuator: - v = "punctuator" - case itemDirective: - v = "directive" - case itemVariable: - v = "variable" - case itemIntVal: - v = "int" - case itemFloatVal: - v = "float" - case itemStringVal: - v = "string" - } - return fmt.Sprintf("%s %q", v, i.val) + end Pos // The ending position, in bytes, of this item in the input string. + line uint16 // The line number at the start of this item. } // itemType identifies the type of lex items. @@ -103,14 +81,14 @@ type stateFn func(*lexer) stateFn // lexer holds the state of the scanner. type lexer struct { - name string // the name of the input; used only for error reports - input string // the string being scanned + input []byte // the string being scanned pos Pos // current position in the input start Pos // start position of this item width Pos // width of last rune read from input items []item // array of scanned items - itemsA [100]item - line int // 1+number of newlines seen + itemsA [50]item + line uint16 // 1+number of newlines seen + err error } var zeroLex = lexer{} @@ -119,13 +97,13 @@ func (l *lexer) Reset() { *l = zeroLex } -// next returns the next rune in the input. +// next returns the next byte in the input. func (l *lexer) next() rune { if int(l.pos) >= len(l.input) { l.width = 0 return eof } - r, w := utf8.DecodeRuneInString(l.input[l.pos:]) + r, w := utf8.DecodeRune(l.input[l.pos:]) l.width = Pos(w) l.pos += l.width if r == '\n' { @@ -150,30 +128,38 @@ func (l *lexer) backup() { } } -func (l *lexer) current() string { - return l.input[l.start:l.pos] +func (l *lexer) current() (Pos, Pos) { + return l.start, l.pos } // emit passes an item back to the client. func (l *lexer) emit(t itemType) { - l.items = append(l.items, item{t, l.start, l.input[l.start:l.pos], l.line}) + l.items = append(l.items, item{t, l.start, l.pos, l.line}) // Some items contain text internally. If so, count their newlines. switch t { case itemName: - l.line += strings.Count(l.input[l.start:l.pos], "\n") + for i := l.start; i < l.pos; i++ { + if l.input[i] == '\n' { + l.line++ + } + } } l.start = l.pos } // ignore skips over the pending input before this point. func (l *lexer) ignore() { - l.line += strings.Count(l.input[l.start:l.pos], "\n") + for i := l.start; i < l.pos; i++ { + if l.input[i] == '\n' { + l.line++ + } + } l.start = l.pos } // accept consumes the next rune if it's from the valid set. -func (l *lexer) accept(valid string) bool { - if strings.ContainsRune(valid, l.next()) { +func (l *lexer) accept(valid []byte) bool { + if bytes.ContainsRune(valid, l.next()) { return true } l.backup() @@ -199,8 +185,8 @@ func (l *lexer) acceptComment() { } // acceptRun consumes a run of runes from the valid set. -func (l *lexer) acceptRun(valid string) { - for strings.ContainsRune(valid, l.next()) { +func (l *lexer) acceptRun(valid []byte) { + for bytes.ContainsRune(valid, l.next()) { } l.backup() } @@ -208,23 +194,25 @@ func (l *lexer) acceptRun(valid string) { // errorf returns an error token and terminates the scan by passing // back a nil pointer that will be the next state, terminating l.nextItem. func (l *lexer) errorf(format string, args ...interface{}) stateFn { - l.items = append(l.items, item{itemError, l.start, - fmt.Sprintf(format, args...), l.line}) + l.err = fmt.Errorf(format, args...) + l.items = append(l.items, item{itemError, l.start, l.pos, l.line}) return nil } // lex creates a new scanner for the input string. -func lex(l *lexer, input string) error { +func lex(l *lexer, input []byte) error { if len(input) == 0 { return errors.New("empty query") } + l.input = input - l.line = 1 l.items = l.itemsA[:0] + l.line = 1 + l.run() if last := l.items[len(l.items)-1]; last.typ == itemError { - return fmt.Errorf(last.val) + return l.err } return nil } @@ -262,7 +250,7 @@ func lexRoot(l *lexer) stateFn { if l.acceptAlphaNum() { l.emit(itemVariable) } - case strings.ContainsRune("!():=[]{|}", r): + case contains(l.input, l.start, l.pos, punctuatorToken): if item, ok := punctuators[r]; ok { l.emit(item) } else { @@ -273,7 +261,7 @@ func lexRoot(l *lexer) stateFn { return lexString case r == '.': if len(l.input) >= 3 { - if strings.HasSuffix(l.input[:l.pos], "...") { + if equals(l.input, 0, 3, spreadToken) { l.emit(itemSpread) return lexRoot } @@ -295,34 +283,28 @@ func lexRoot(l *lexer) stateFn { func lexName(l *lexer) stateFn { for { r := l.next() + if r == eof { l.emit(itemEOF) return nil } + if !isAlphaNumeric(r) { l.backup() - v := l.current() + s, e := l.current() lowercase(l.input, s, e) - if len(v) == 0 { - switch { - case strings.EqualFold(v, "query"): - l.emit(itemQuery) - break - case strings.EqualFold(v, "mutation"): - l.emit(itemMutation) - break - case strings.EqualFold(v, "subscription"): - l.emit(itemSub) - break - } - } - switch { - case strings.EqualFold(v, "true"): + case equals(l.input, s, e, queryToken): + l.emit(itemQuery) + case equals(l.input, s, e, mutationToken): + l.emit(itemMutation) + case equals(l.input, s, e, subscriptionToken): + l.emit(itemSub) + case equals(l.input, s, e, trueToken): l.emit(itemBoolVal) - case strings.EqualFold(v, "false"): + case equals(l.input, s, e, falseToken): l.emit(itemBoolVal) default: l.emit(itemName) @@ -335,7 +317,7 @@ func lexName(l *lexer) stateFn { // lexString scans a string. func lexString(l *lexer) stateFn { - if l.accept("\"'") { + if l.accept([]byte(quotesToken)) { l.ignore() for { @@ -347,7 +329,7 @@ func lexString(l *lexer) stateFn { if r == '\'' || r == '"' { l.backup() l.emit(itemStringVal) - if l.accept("\"'") { + if l.accept(quotesToken) { l.ignore() } break @@ -364,20 +346,19 @@ func lexString(l *lexer) stateFn { func lexNumber(l *lexer) stateFn { var it itemType // Optional leading sign. - l.accept("+-") + l.accept(signsToken) // Is it integer - digits := "0123456789" - if l.accept(digits) { - l.acceptRun(digits) + if l.accept(digitToken) { + l.acceptRun(digitToken) it = itemIntVal } // Is it float if l.peek() == '.' { - if l.accept(".") { - if l.accept(digits) { - l.acceptRun(digits) + if l.accept(dotToken) { + if l.accept(digitToken) { + l.acceptRun(digitToken) it = itemFloatVal } } else { @@ -413,6 +394,74 @@ func isAlphaNumeric(r rune) bool { return r == '_' || unicode.IsLetter(r) || unicode.IsDigit(r) } +func equals(b []byte, s Pos, e Pos, val []byte) bool { + n := 0 + for i := s; i < e; i++ { + if n >= len(val) { + return true + } + switch { + case b[i] >= 'A' && b[i] <= 'Z' && ('a'+(b[i]-'A')) != val[n]: + return false + case b[i] != val[n]: + return false + } + n++ + } + return true +} + +func contains(b []byte, s Pos, e Pos, val []byte) bool { + for i := s; i < e; i++ { + for n := 0; n < len(val); n++ { + if b[i] == val[n] { + return true + } + } + } + return false +} + +func lowercase(b []byte, s Pos, e Pos) { + for i := s; i < e; i++ { + if b[i] >= 'A' && b[i] <= 'Z' { + b[i] = ('a' + (b[i] - 'A')) + } + } +} + +func (i *item) String() string { + var v string + + switch i.typ { + case itemEOF: + v = "EOF" + case itemError: + v = "error" + case itemName: + v = "name" + case itemQuery: + v = "query" + case itemMutation: + v = "mutation" + case itemSub: + v = "subscription" + case itemPunctuator: + v = "punctuator" + case itemDirective: + v = "directive" + case itemVariable: + v = "variable" + case itemIntVal: + v = "int" + case itemFloatVal: + v = "float" + case itemStringVal: + v = "string" + } + return fmt.Sprintf("%s", v) +} + /* Copyright (c) 2009 The Go Authors. All rights reserved. diff --git a/qcode/parse.go b/qcode/parse.go index c4dc8c0..578cb7b 100644 --- a/qcode/parse.go +++ b/qcode/parse.go @@ -4,6 +4,7 @@ import ( "errors" "fmt" "sync" + "unsafe" "github.com/dosco/super-graph/util" ) @@ -53,9 +54,9 @@ type Field struct { Name string Alias string Args []Arg - argsA [10]Arg + argsA [5]Arg Children []int32 - childrenA [10]int32 + childrenA [5]int32 } type Arg struct { @@ -78,6 +79,7 @@ func (n *Node) Reset() { } type Parser struct { + input []byte // the string being scanned pos int items []item depth int @@ -96,38 +98,32 @@ var lexPool = sync.Pool{ New: func() interface{} { return new(lexer) }, } -func Parse(gql string) (*Operation, error) { - if len(gql) == 0 { - return nil, errors.New("blank query") - } - l := lexPool.Get().(*lexer) - l.Reset() - - if err := lex(l, gql); err != nil { - return nil, err - } - p := &Parser{ - pos: -1, - items: l.items, - } - op, err := p.parseOp() - lexPool.Put(l) - - return op, err +func Parse(gql []byte) (*Operation, error) { + return parseSelectionSet(nil, gql) } -func ParseQuery(gql string) (*Operation, error) { - return parseByType(gql, opQuery) +func ParseQuery(gql []byte) (*Operation, error) { + op := opPool.Get().(*Operation) + op.Reset() + + op.Type = opQuery + op.Name = "" + op.Fields = op.fieldsA[:0] + op.Args = op.argsA[:0] + + return parseSelectionSet(op, gql) } func ParseArgValue(argVal string) (*Node, error) { l := lexPool.Get().(*lexer) l.Reset() - if err := lex(l, argVal); err != nil { + if err := lex(l, []byte(argVal)); err != nil { return nil, err } + p := &Parser{ + input: l.input, pos: -1, items: l.items, } @@ -137,20 +133,42 @@ func ParseArgValue(argVal string) (*Node, error) { return op, err } -func parseByType(gql string, ty parserType) (*Operation, error) { +func parseSelectionSet(op *Operation, gql []byte) (*Operation, error) { + var err error + + if len(gql) == 0 { + return nil, errors.New("blank query") + } + l := lexPool.Get().(*lexer) l.Reset() - if err := lex(l, gql); err != nil { + if err = lex(l, gql); err != nil { return nil, err } + p := &Parser{ + input: l.input, pos: -1, items: l.items, } - op, err := p.parseOpByType(ty) + + if op == nil { + op, err = p.parseOp() + } else { + if p.peek(itemObjOpen) { + p.ignore() + } + + op.Fields, err = p.parseFields(op.Fields) + } + lexPool.Put(l) + if err != nil { + return nil, err + } + return op, err } @@ -198,18 +216,34 @@ func (p *Parser) peek(types ...itemType) bool { return false } -func (p *Parser) parseOpByType(ty parserType) (*Operation, error) { +func (p *Parser) parseOp() (*Operation, error) { + if p.peek(itemQuery, itemMutation, itemSub) == false { + err := fmt.Errorf( + "expecting a query, mutation or subscription (not '%s')", + p.val(p.next())) + return nil, err + } + item := p.next() + op := opPool.Get().(*Operation) op.Reset() - op.Type = ty + switch item.typ { + case itemQuery: + op.Type = opQuery + case itemMutation: + op.Type = opMutate + case itemSub: + op.Type = opSub + } + op.Fields = op.fieldsA[:0] op.Args = op.argsA[:0] var err error if p.peek(itemName) { - op.Name = p.next().val + op.Name = p.val(p.next()) } if p.peek(itemArgsOpen) { @@ -228,33 +262,9 @@ func (p *Parser) parseOpByType(ty parserType) (*Operation, error) { } } - if p.peek(itemObjClose) { - p.ignore() - } - return op, nil } -func (p *Parser) parseOp() (*Operation, error) { - if p.peek(itemQuery, itemMutation, itemSub) == false { - err := fmt.Errorf("expecting a query, mutation or subscription (not '%s')", p.next().val) - return nil, err - } - - item := p.next() - - switch item.typ { - case itemQuery: - return p.parseOpByType(opQuery) - case itemMutation: - return p.parseOpByType(opMutate) - case itemSub: - return p.parseOpByType(opSub) - } - - return nil, errors.New("unknown operation type") -} - func (p *Parser) parseFields(fields []Field) ([]Field, error) { st := util.NewStack() @@ -278,6 +288,7 @@ func (p *Parser) parseFields(fields []Field) ([]Field, error) { } fields = append(fields, Field{ID: int32(len(fields))}) + f := &fields[(len(fields) - 1)] f.Args = f.argsA[:0] f.Children = f.childrenA[:0] @@ -309,14 +320,14 @@ func (p *Parser) parseFields(fields []Field) ([]Field, error) { func (p *Parser) parseField(f *Field) error { var err error - f.Name = p.next().val + f.Name = p.val(p.next()) if p.peek(itemColon) { p.ignore() if p.peek(itemName) { f.Alias = f.Name - f.Name = p.next().val + f.Name = p.val(p.next()) } else { return errors.New("expecting an aliased field name") } @@ -347,7 +358,7 @@ func (p *Parser) parseArgs(args []Arg) ([]Arg, error) { if p.peek(itemName) == false { return nil, errors.New("expecting an argument name") } - args = append(args, Arg{Name: p.next().val}) + args = append(args, Arg{Name: p.val(p.next())}) arg := &args[(len(args) - 1)] if p.peek(itemColon) == false { @@ -414,7 +425,7 @@ func (p *Parser) parseObj() (*Node, error) { if p.peek(itemName) == false { return nil, errors.New("expecting an argument name") } - nodeName := p.next().val + nodeName := p.val(p.next()) if p.peek(itemColon) == false { return nil, errors.New("missing ':' after Field argument name") @@ -465,13 +476,21 @@ func (p *Parser) parseValue() (*Node, error) { case itemVariable: node.Type = nodeVar default: - return nil, fmt.Errorf("expecting a number, string, object, list or variable as an argument value (not %s)", p.next().val) + return nil, fmt.Errorf("expecting a number, string, object, list or variable as an argument value (not %s)", p.val(p.next())) } - node.Val = item.val + node.Val = p.val(item) return node, nil } +func (p *Parser) val(v item) string { + return b2s(p.input[v.pos:v.end]) +} + +func b2s(b []byte) string { + return *(*string)(unsafe.Pointer(&b)) +} + func (t parserType) String() string { var v string diff --git a/qcode/parse_test.go b/qcode/parse_test.go index c2e9834..3fef81c 100644 --- a/qcode/parse_test.go +++ b/qcode/parse_test.go @@ -47,12 +47,13 @@ func compareOp(op1, op2 Operation) error { func TestCompile(t *testing.T) { qcompile, _ := NewCompiler(Config{}) - _, err := qcompile.CompileQuery(`query { - product(id: 15) { + + _, err := qcompile.CompileQuery([]byte(` + product(id: 15) { id name - } - }`) + }`)) + if err != nil { t.Fatal(err) } @@ -60,7 +61,8 @@ func TestCompile(t *testing.T) { func TestInvalidCompile(t *testing.T) { qcompile, _ := NewCompiler(Config{}) - _, err := qcompile.CompileQuery(`#`) + _, err := qcompile.CompileQuery([]byte(`#`)) + if err == nil { t.Fatal(errors.New("expecting an error")) } @@ -68,13 +70,14 @@ func TestInvalidCompile(t *testing.T) { func TestEmptyCompile(t *testing.T) { qcompile, _ := NewCompiler(Config{}) - _, err := qcompile.CompileQuery(``) + _, err := qcompile.CompileQuery([]byte(``)) + if err == nil { t.Fatal(errors.New("expecting an error")) } } -var gql = `query { +var gql = []byte(` products( # returns only 30 items limit: 30, @@ -93,8 +96,7 @@ var gql = `query { id name price - } -}` + }`) func BenchmarkQCompile(b *testing.B) { qcompile, _ := NewCompiler(Config{}) diff --git a/qcode/qcode.go b/qcode/qcode.go index 01eb6ab..3c098db 100644 --- a/qcode/qcode.go +++ b/qcode/qcode.go @@ -50,6 +50,7 @@ type Exp struct { ListType ValType ListVal []string Children []*Exp + childrenA [5]*Exp } type OrderBy struct { @@ -144,7 +145,7 @@ func NewCompiler(c Config) (*Compiler, error) { bl := make(map[string]struct{}, len(c.Blacklist)) for i := range c.Blacklist { - bl[strings.ToLower(c.Blacklist[i])] = struct{}{} + bl[c.Blacklist[i]] = struct{}{} } fl, err := compileFilter(c.DefaultFilter) @@ -159,9 +160,8 @@ func NewCompiler(c Config) (*Compiler, error) { if err != nil { return nil, err } - k1 := strings.ToLower(k) - singular := flect.Singularize(k1) - plural := flect.Pluralize(k1) + singular := flect.Singularize(k) + plural := flect.Pluralize(k) fm[singular] = fil fm[plural] = fil @@ -170,11 +170,11 @@ func NewCompiler(c Config) (*Compiler, error) { return &Compiler{fl, fm, bl, c.KeepArgs}, nil } -func (com *Compiler) CompileQuery(query string) (*QCode, error) { +func (com *Compiler) Compile(query []byte) (*QCode, error) { var qc QCode var err error - op, err := ParseQuery(query) + op, err := Parse(query) if err != nil { return nil, err } @@ -197,6 +197,25 @@ func (com *Compiler) CompileQuery(query string) (*QCode, error) { return &qc, nil } +func (com *Compiler) CompileQuery(query []byte) (*QCode, error) { + var err error + + op, err := ParseQuery(query) + if err != nil { + return nil, err + } + + qc := &QCode{} + qc.Query, err = com.compileQuery(op) + opPool.Put(op) + + if err != nil { + return nil, err + } + + return qc, nil +} + func (com *Compiler) compileQuery(op *Operation) (*Query, error) { id := int32(0) parentID := int32(0) @@ -226,15 +245,14 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) { } field := &op.Fields[fid] - tn := strings.ToLower(field.Name) - if _, ok := com.bl[tn]; ok { + if _, ok := com.bl[field.Name]; ok { continue } selects = append(selects, Select{ ID: id, ParentID: parentID, - Table: tn, + Table: field.Name, Children: make([]int32, 0, 5), }) s := &selects[(len(selects) - 1)] @@ -259,9 +277,8 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) { for _, cid := range field.Children { f := op.Fields[cid] - fn := strings.ToLower(f.Name) - if _, ok := com.bl[fn]; ok { + if _, ok := com.bl[f.Name]; ok { continue } @@ -271,7 +288,7 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) { continue } - col := Column{Name: fn} + col := Column{Name: f.Name} if len(f.Alias) != 0 { col.FieldName = f.Alias @@ -298,8 +315,11 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) { if fil != nil && fil.Op != OpNop { if root.Where != nil { - ex := &Exp{Op: OpAnd, Children: []*Exp{fil, root.Where}} - root.Where = ex + ow := root.Where + root.Where = &Exp{Op: OpAnd} + root.Where.Children = root.Where.childrenA[:2] + root.Where.Children[0] = fil + root.Where.Children[1] = ow } else { root.Where = fil } @@ -322,9 +342,7 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error { for i := range args { arg := &args[i] - an := strings.ToLower(arg.Name) - - switch an { + switch arg.Name { case "id": if sel.ID == 0 { err = com.compileArgID(sel, arg) @@ -348,7 +366,7 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error { } if sel.Args != nil { - sel.Args[an] = arg.Val + sel.Args[arg.Name] = arg.Val } else { nodePool.Put(arg.Val) } @@ -392,7 +410,7 @@ func (com *Compiler) compileArgNode(node *Node) (*Exp, error) { } if len(eT.node.Name) != 0 { - if _, ok := com.bl[strings.ToLower(eT.node.Name)]; ok { + if _, ok := com.bl[eT.node.Name]; ok { continue } } @@ -468,7 +486,11 @@ func (com *Compiler) compileArgSearch(sel *Select, arg *Arg) error { } if sel.Where != nil { - sel.Where = &Exp{Op: OpAnd, Children: []*Exp{ex, sel.Where}} + ow := sel.Where + sel.Where = &Exp{Op: OpAnd} + sel.Where.Children = sel.Where.childrenA[:2] + sel.Where.Children[0] = ex + sel.Where.Children[1] = ow } else { sel.Where = ex } @@ -484,7 +506,11 @@ func (com *Compiler) compileArgWhere(sel *Select, arg *Arg) error { } if sel.Where != nil { - sel.Where = &Exp{Op: OpAnd, Children: []*Exp{ex, sel.Where}} + ow := sel.Where + sel.Where = &Exp{Op: OpAnd} + sel.Where.Children = sel.Where.childrenA[:2] + sel.Where.Children[0] = ex + sel.Where.Children[1] = ow } else { sel.Where = ex } @@ -515,7 +541,7 @@ func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error { return fmt.Errorf("17: unexpected value %v (%t)", intf, intf) } - if _, ok := com.bl[strings.ToLower(node.Name)]; ok { + if _, ok := com.bl[node.Name]; ok { if !com.ka { nodePool.Put(node) } @@ -534,8 +560,7 @@ func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error { ob := &OrderBy{} - val := strings.ToLower(node.Val) - switch val { + switch node.Val { case "asc": ob.Order = OrderAsc case "desc": @@ -565,7 +590,7 @@ func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error { func (com *Compiler) compileArgDistinctOn(sel *Select, arg *Arg) error { node := arg.Val - if _, ok := com.bl[strings.ToLower(node.Name)]; ok { + if _, ok := com.bl[node.Name]; ok { return nil } @@ -619,7 +644,6 @@ func compileSub() (*Query, error) { } func newExp(st *util.Stack, eT *expT) (*Exp, error) { - ex := &Exp{} node := eT.node if len(node.Name) == 0 { @@ -627,11 +651,13 @@ func newExp(st *util.Stack, eT *expT) (*Exp, error) { return nil, nil } - name := strings.ToLower(node.Name) + name := node.Name if name[0] == '_' { name = name[1:] } + ex := &Exp{} + switch name { case "and": ex.Op = OpAnd @@ -756,7 +782,7 @@ func setWhereColName(ex *Exp, node *Node) { continue } if len(n.Name) != 0 { - k := strings.ToLower(n.Name) + k := n.Name if k == "and" || k == "or" || k == "not" || k == "_and" || k == "_or" || k == "_not" { continue @@ -778,8 +804,7 @@ func setOrderByColName(ob *OrderBy, node *Node) { for n := node; n != nil; n = n.Parent { if len(n.Name) != 0 { - k := strings.ToLower(n.Name) - list = append([]string{k}, list...) + list = append([]string{n.Name}, list...) } } if len(list) != 0 { diff --git a/serv/core.go b/serv/core.go index c069902..458097a 100644 --- a/serv/core.go +++ b/serv/core.go @@ -36,7 +36,7 @@ type coreContext struct { func (c *coreContext) handleReq(w io.Writer, req *http.Request) error { var err error - qc, err := qcompile.CompileQuery(c.req.Query) + qc, err := qcompile.CompileQuery([]byte(c.req.Query)) if err != nil { return err }