Optimize lexer and fix bugs

This commit is contained in:
Vikram Rangnekar 2019-06-14 22:17:21 -04:00
parent 9af320f396
commit 340dea242d
10 changed files with 301 additions and 197 deletions

1
go.mod
View File

@ -27,6 +27,7 @@ require (
github.com/spf13/viper v1.3.1
github.com/valyala/fasttemplate v1.0.1
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2
golang.org/x/net v0.0.0-20190311183353-d8887717615a
golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223
mellium.im/sasl v0.2.1 // indirect
)

1
go.sum
View File

@ -111,6 +111,7 @@ golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2 h1:VklqNMn3ovrHsnt90Pveol
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd h1:nTDtHvHSdCn1m6ITfMRqtOd/9+7a3s8RBNOZ3eYZzJA=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20190311183353-d8887717615a h1:oWX7TPOiFAMXLq8o0ikBYfCJVlRHBcsciT5bXOrH628=
golang.org/x/net v0.0.0-20190311183353-d8887717615a/go.mod h1:t9HGtf8HONx5eT2rtn7q6eTqICYqUVnKs3thJo3Qplg=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f h1:wMNYb4v58l5UBM7MYRLPG6ZhfOqbKu7X5eyFl8ZhKvA=
golang.org/x/sync v0.0.0-20180314180146-1d60e4601c6f/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM=

View File

@ -126,7 +126,7 @@ func TestMain(m *testing.M) {
}
func compileGQLToPSQL(gql string) ([]byte, error) {
qc, err := qcompile.CompileQuery(gql)
qc, err := qcompile.Compile([]byte(gql))
if err != nil {
return nil, err
}
@ -141,7 +141,7 @@ func compileGQLToPSQL(gql string) ([]byte, error) {
func withComplexArgs(t *testing.T) {
gql := `query {
products(
proDUcts(
# returns only 30 items
limit: 30,
@ -157,7 +157,7 @@ func withComplexArgs(t *testing.T) {
# only items with an id >= 30 and < 30 are returned
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) {
id
name
NAME
price
}
}`
@ -482,8 +482,8 @@ func TestCompileGQL(t *testing.T) {
t.Run("syntheticTables", syntheticTables)
}
var benchGQL = `query {
products(
var benchGQL = []byte(`query {
proDUcts(
# returns only 30 items
limit: 30,
@ -496,14 +496,14 @@ var benchGQL = `query {
# only items with an id >= 30 and < 30 are returned
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) {
id
name
NAME
price
user {
full_name
picture : avatar
}
}
}`
}`)
func BenchmarkCompile(b *testing.B) {
w := &bytes.Buffer{}
@ -514,7 +514,7 @@ func BenchmarkCompile(b *testing.B) {
for n := 0; n < b.N; n++ {
w.Reset()
qc, err := qcompile.CompileQuery(benchGQL)
qc, err := qcompile.Compile(benchGQL)
if err != nil {
b.Fatal(err)
}
@ -535,7 +535,7 @@ func BenchmarkCompileParallel(b *testing.B) {
for pb.Next() {
w.Reset()
qc, err := qcompile.CompileQuery(benchGQL)
qc, err := qcompile.Compile(benchGQL)
if err != nil {
b.Fatal(err)
}

7
qcode/bench.3 Normal file
View File

@ -0,0 +1,7 @@
goos: darwin
goarch: amd64
pkg: github.com/dosco/super-graph/qcode
BenchmarkQCompile-8 200000 10029 ns/op 2291 B/op 38 allocs/op
BenchmarkQCompileP-8 500000 2925 ns/op 2298 B/op 38 allocs/op
PASS
ok github.com/dosco/super-graph/qcode 3.616s

View File

@ -2,10 +2,10 @@ package qcode
// FuzzerEntrypoint for Fuzzbuzz
func FuzzerEntrypoint(data []byte) int {
testData := string(data)
//testData := string(data)
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.CompileQuery(testData)
_, err := qcompile.Compile(data)
if err != nil {
return -1
}

View File

@ -1,59 +1,37 @@
package qcode
import (
"bytes"
"errors"
"fmt"
"strings"
"unicode"
"unicode/utf8"
)
var (
queryToken = []byte("query")
mutationToken = []byte("mutation")
subscriptionToken = []byte("subscription")
trueToken = []byte("true")
falseToken = []byte("true")
quotesToken = []byte(`'"`)
signsToken = []byte(`+-`)
punctuatorToken = []byte(`!():=[]{|}`)
spreadToken = []byte(`...`)
digitToken = []byte(`0123456789`)
dotToken = []byte(`.`)
)
// Pos represents a byte position in the original input text from which
// this template was parsed.
type Pos int
func (p Pos) Position() Pos {
return p
}
// item represents a token or text string returned from the scanner.
type item struct {
typ itemType // The type of this item.
pos Pos // The starting position, in bytes, of this item in the input string.
val string // The value of this item.
line int // The line number at the start of this item.
}
func (i *item) String() string {
var v string
switch i.typ {
case itemEOF:
v = "EOF"
case itemError:
v = "error"
case itemName:
v = "name"
case itemQuery:
v = "query"
case itemMutation:
v = "mutation"
case itemSub:
v = "subscription"
case itemPunctuator:
v = "punctuator"
case itemDirective:
v = "directive"
case itemVariable:
v = "variable"
case itemIntVal:
v = "int"
case itemFloatVal:
v = "float"
case itemStringVal:
v = "string"
}
return fmt.Sprintf("%s %q", v, i.val)
end Pos // The ending position, in bytes, of this item in the input string.
line uint16 // The line number at the start of this item.
}
// itemType identifies the type of lex items.
@ -103,14 +81,14 @@ type stateFn func(*lexer) stateFn
// lexer holds the state of the scanner.
type lexer struct {
name string // the name of the input; used only for error reports
input string // the string being scanned
input []byte // the string being scanned
pos Pos // current position in the input
start Pos // start position of this item
width Pos // width of last rune read from input
items []item // array of scanned items
itemsA [100]item
line int // 1+number of newlines seen
itemsA [50]item
line uint16 // 1+number of newlines seen
err error
}
var zeroLex = lexer{}
@ -119,13 +97,13 @@ func (l *lexer) Reset() {
*l = zeroLex
}
// next returns the next rune in the input.
// next returns the next byte in the input.
func (l *lexer) next() rune {
if int(l.pos) >= len(l.input) {
l.width = 0
return eof
}
r, w := utf8.DecodeRuneInString(l.input[l.pos:])
r, w := utf8.DecodeRune(l.input[l.pos:])
l.width = Pos(w)
l.pos += l.width
if r == '\n' {
@ -150,30 +128,38 @@ func (l *lexer) backup() {
}
}
func (l *lexer) current() string {
return l.input[l.start:l.pos]
func (l *lexer) current() (Pos, Pos) {
return l.start, l.pos
}
// emit passes an item back to the client.
func (l *lexer) emit(t itemType) {
l.items = append(l.items, item{t, l.start, l.input[l.start:l.pos], l.line})
l.items = append(l.items, item{t, l.start, l.pos, l.line})
// Some items contain text internally. If so, count their newlines.
switch t {
case itemName:
l.line += strings.Count(l.input[l.start:l.pos], "\n")
for i := l.start; i < l.pos; i++ {
if l.input[i] == '\n' {
l.line++
}
}
}
l.start = l.pos
}
// ignore skips over the pending input before this point.
func (l *lexer) ignore() {
l.line += strings.Count(l.input[l.start:l.pos], "\n")
for i := l.start; i < l.pos; i++ {
if l.input[i] == '\n' {
l.line++
}
}
l.start = l.pos
}
// accept consumes the next rune if it's from the valid set.
func (l *lexer) accept(valid string) bool {
if strings.ContainsRune(valid, l.next()) {
func (l *lexer) accept(valid []byte) bool {
if bytes.ContainsRune(valid, l.next()) {
return true
}
l.backup()
@ -199,8 +185,8 @@ func (l *lexer) acceptComment() {
}
// acceptRun consumes a run of runes from the valid set.
func (l *lexer) acceptRun(valid string) {
for strings.ContainsRune(valid, l.next()) {
func (l *lexer) acceptRun(valid []byte) {
for bytes.ContainsRune(valid, l.next()) {
}
l.backup()
}
@ -208,23 +194,25 @@ func (l *lexer) acceptRun(valid string) {
// errorf returns an error token and terminates the scan by passing
// back a nil pointer that will be the next state, terminating l.nextItem.
func (l *lexer) errorf(format string, args ...interface{}) stateFn {
l.items = append(l.items, item{itemError, l.start,
fmt.Sprintf(format, args...), l.line})
l.err = fmt.Errorf(format, args...)
l.items = append(l.items, item{itemError, l.start, l.pos, l.line})
return nil
}
// lex creates a new scanner for the input string.
func lex(l *lexer, input string) error {
func lex(l *lexer, input []byte) error {
if len(input) == 0 {
return errors.New("empty query")
}
l.input = input
l.line = 1
l.items = l.itemsA[:0]
l.line = 1
l.run()
if last := l.items[len(l.items)-1]; last.typ == itemError {
return fmt.Errorf(last.val)
return l.err
}
return nil
}
@ -262,7 +250,7 @@ func lexRoot(l *lexer) stateFn {
if l.acceptAlphaNum() {
l.emit(itemVariable)
}
case strings.ContainsRune("!():=[]{|}", r):
case contains(l.input, l.start, l.pos, punctuatorToken):
if item, ok := punctuators[r]; ok {
l.emit(item)
} else {
@ -273,7 +261,7 @@ func lexRoot(l *lexer) stateFn {
return lexString
case r == '.':
if len(l.input) >= 3 {
if strings.HasSuffix(l.input[:l.pos], "...") {
if equals(l.input, 0, 3, spreadToken) {
l.emit(itemSpread)
return lexRoot
}
@ -295,34 +283,28 @@ func lexRoot(l *lexer) stateFn {
func lexName(l *lexer) stateFn {
for {
r := l.next()
if r == eof {
l.emit(itemEOF)
return nil
}
if !isAlphaNumeric(r) {
l.backup()
v := l.current()
s, e := l.current()
lowercase(l.input, s, e)
if len(v) == 0 {
switch {
case strings.EqualFold(v, "query"):
case equals(l.input, s, e, queryToken):
l.emit(itemQuery)
break
case strings.EqualFold(v, "mutation"):
case equals(l.input, s, e, mutationToken):
l.emit(itemMutation)
break
case strings.EqualFold(v, "subscription"):
case equals(l.input, s, e, subscriptionToken):
l.emit(itemSub)
break
}
}
switch {
case strings.EqualFold(v, "true"):
case equals(l.input, s, e, trueToken):
l.emit(itemBoolVal)
case strings.EqualFold(v, "false"):
case equals(l.input, s, e, falseToken):
l.emit(itemBoolVal)
default:
l.emit(itemName)
@ -335,7 +317,7 @@ func lexName(l *lexer) stateFn {
// lexString scans a string.
func lexString(l *lexer) stateFn {
if l.accept("\"'") {
if l.accept([]byte(quotesToken)) {
l.ignore()
for {
@ -347,7 +329,7 @@ func lexString(l *lexer) stateFn {
if r == '\'' || r == '"' {
l.backup()
l.emit(itemStringVal)
if l.accept("\"'") {
if l.accept(quotesToken) {
l.ignore()
}
break
@ -364,20 +346,19 @@ func lexString(l *lexer) stateFn {
func lexNumber(l *lexer) stateFn {
var it itemType
// Optional leading sign.
l.accept("+-")
l.accept(signsToken)
// Is it integer
digits := "0123456789"
if l.accept(digits) {
l.acceptRun(digits)
if l.accept(digitToken) {
l.acceptRun(digitToken)
it = itemIntVal
}
// Is it float
if l.peek() == '.' {
if l.accept(".") {
if l.accept(digits) {
l.acceptRun(digits)
if l.accept(dotToken) {
if l.accept(digitToken) {
l.acceptRun(digitToken)
it = itemFloatVal
}
} else {
@ -413,6 +394,74 @@ func isAlphaNumeric(r rune) bool {
return r == '_' || unicode.IsLetter(r) || unicode.IsDigit(r)
}
func equals(b []byte, s Pos, e Pos, val []byte) bool {
n := 0
for i := s; i < e; i++ {
if n >= len(val) {
return true
}
switch {
case b[i] >= 'A' && b[i] <= 'Z' && ('a'+(b[i]-'A')) != val[n]:
return false
case b[i] != val[n]:
return false
}
n++
}
return true
}
func contains(b []byte, s Pos, e Pos, val []byte) bool {
for i := s; i < e; i++ {
for n := 0; n < len(val); n++ {
if b[i] == val[n] {
return true
}
}
}
return false
}
func lowercase(b []byte, s Pos, e Pos) {
for i := s; i < e; i++ {
if b[i] >= 'A' && b[i] <= 'Z' {
b[i] = ('a' + (b[i] - 'A'))
}
}
}
func (i *item) String() string {
var v string
switch i.typ {
case itemEOF:
v = "EOF"
case itemError:
v = "error"
case itemName:
v = "name"
case itemQuery:
v = "query"
case itemMutation:
v = "mutation"
case itemSub:
v = "subscription"
case itemPunctuator:
v = "punctuator"
case itemDirective:
v = "directive"
case itemVariable:
v = "variable"
case itemIntVal:
v = "int"
case itemFloatVal:
v = "float"
case itemStringVal:
v = "string"
}
return fmt.Sprintf("%s", v)
}
/*
Copyright (c) 2009 The Go Authors. All rights reserved.

View File

@ -4,6 +4,7 @@ import (
"errors"
"fmt"
"sync"
"unsafe"
"github.com/dosco/super-graph/util"
)
@ -53,9 +54,9 @@ type Field struct {
Name string
Alias string
Args []Arg
argsA [10]Arg
argsA [5]Arg
Children []int32
childrenA [10]int32
childrenA [5]int32
}
type Arg struct {
@ -78,6 +79,7 @@ func (n *Node) Reset() {
}
type Parser struct {
input []byte // the string being scanned
pos int
items []item
depth int
@ -96,38 +98,32 @@ var lexPool = sync.Pool{
New: func() interface{} { return new(lexer) },
}
func Parse(gql string) (*Operation, error) {
if len(gql) == 0 {
return nil, errors.New("blank query")
}
l := lexPool.Get().(*lexer)
l.Reset()
if err := lex(l, gql); err != nil {
return nil, err
}
p := &Parser{
pos: -1,
items: l.items,
}
op, err := p.parseOp()
lexPool.Put(l)
return op, err
func Parse(gql []byte) (*Operation, error) {
return parseSelectionSet(nil, gql)
}
func ParseQuery(gql string) (*Operation, error) {
return parseByType(gql, opQuery)
func ParseQuery(gql []byte) (*Operation, error) {
op := opPool.Get().(*Operation)
op.Reset()
op.Type = opQuery
op.Name = ""
op.Fields = op.fieldsA[:0]
op.Args = op.argsA[:0]
return parseSelectionSet(op, gql)
}
func ParseArgValue(argVal string) (*Node, error) {
l := lexPool.Get().(*lexer)
l.Reset()
if err := lex(l, argVal); err != nil {
if err := lex(l, []byte(argVal)); err != nil {
return nil, err
}
p := &Parser{
input: l.input,
pos: -1,
items: l.items,
}
@ -137,20 +133,42 @@ func ParseArgValue(argVal string) (*Node, error) {
return op, err
}
func parseByType(gql string, ty parserType) (*Operation, error) {
func parseSelectionSet(op *Operation, gql []byte) (*Operation, error) {
var err error
if len(gql) == 0 {
return nil, errors.New("blank query")
}
l := lexPool.Get().(*lexer)
l.Reset()
if err := lex(l, gql); err != nil {
if err = lex(l, gql); err != nil {
return nil, err
}
p := &Parser{
input: l.input,
pos: -1,
items: l.items,
}
op, err := p.parseOpByType(ty)
if op == nil {
op, err = p.parseOp()
} else {
if p.peek(itemObjOpen) {
p.ignore()
}
op.Fields, err = p.parseFields(op.Fields)
}
lexPool.Put(l)
if err != nil {
return nil, err
}
return op, err
}
@ -198,18 +216,34 @@ func (p *Parser) peek(types ...itemType) bool {
return false
}
func (p *Parser) parseOpByType(ty parserType) (*Operation, error) {
func (p *Parser) parseOp() (*Operation, error) {
if p.peek(itemQuery, itemMutation, itemSub) == false {
err := fmt.Errorf(
"expecting a query, mutation or subscription (not '%s')",
p.val(p.next()))
return nil, err
}
item := p.next()
op := opPool.Get().(*Operation)
op.Reset()
op.Type = ty
switch item.typ {
case itemQuery:
op.Type = opQuery
case itemMutation:
op.Type = opMutate
case itemSub:
op.Type = opSub
}
op.Fields = op.fieldsA[:0]
op.Args = op.argsA[:0]
var err error
if p.peek(itemName) {
op.Name = p.next().val
op.Name = p.val(p.next())
}
if p.peek(itemArgsOpen) {
@ -228,33 +262,9 @@ func (p *Parser) parseOpByType(ty parserType) (*Operation, error) {
}
}
if p.peek(itemObjClose) {
p.ignore()
}
return op, nil
}
func (p *Parser) parseOp() (*Operation, error) {
if p.peek(itemQuery, itemMutation, itemSub) == false {
err := fmt.Errorf("expecting a query, mutation or subscription (not '%s')", p.next().val)
return nil, err
}
item := p.next()
switch item.typ {
case itemQuery:
return p.parseOpByType(opQuery)
case itemMutation:
return p.parseOpByType(opMutate)
case itemSub:
return p.parseOpByType(opSub)
}
return nil, errors.New("unknown operation type")
}
func (p *Parser) parseFields(fields []Field) ([]Field, error) {
st := util.NewStack()
@ -278,6 +288,7 @@ func (p *Parser) parseFields(fields []Field) ([]Field, error) {
}
fields = append(fields, Field{ID: int32(len(fields))})
f := &fields[(len(fields) - 1)]
f.Args = f.argsA[:0]
f.Children = f.childrenA[:0]
@ -309,14 +320,14 @@ func (p *Parser) parseFields(fields []Field) ([]Field, error) {
func (p *Parser) parseField(f *Field) error {
var err error
f.Name = p.next().val
f.Name = p.val(p.next())
if p.peek(itemColon) {
p.ignore()
if p.peek(itemName) {
f.Alias = f.Name
f.Name = p.next().val
f.Name = p.val(p.next())
} else {
return errors.New("expecting an aliased field name")
}
@ -347,7 +358,7 @@ func (p *Parser) parseArgs(args []Arg) ([]Arg, error) {
if p.peek(itemName) == false {
return nil, errors.New("expecting an argument name")
}
args = append(args, Arg{Name: p.next().val})
args = append(args, Arg{Name: p.val(p.next())})
arg := &args[(len(args) - 1)]
if p.peek(itemColon) == false {
@ -414,7 +425,7 @@ func (p *Parser) parseObj() (*Node, error) {
if p.peek(itemName) == false {
return nil, errors.New("expecting an argument name")
}
nodeName := p.next().val
nodeName := p.val(p.next())
if p.peek(itemColon) == false {
return nil, errors.New("missing ':' after Field argument name")
@ -465,13 +476,21 @@ func (p *Parser) parseValue() (*Node, error) {
case itemVariable:
node.Type = nodeVar
default:
return nil, fmt.Errorf("expecting a number, string, object, list or variable as an argument value (not %s)", p.next().val)
return nil, fmt.Errorf("expecting a number, string, object, list or variable as an argument value (not %s)", p.val(p.next()))
}
node.Val = item.val
node.Val = p.val(item)
return node, nil
}
func (p *Parser) val(v item) string {
return b2s(p.input[v.pos:v.end])
}
func b2s(b []byte) string {
return *(*string)(unsafe.Pointer(&b))
}
func (t parserType) String() string {
var v string

View File

@ -47,12 +47,13 @@ func compareOp(op1, op2 Operation) error {
func TestCompile(t *testing.T) {
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.CompileQuery(`query {
_, err := qcompile.CompileQuery([]byte(`
product(id: 15) {
id
name
}
}`)
}`))
if err != nil {
t.Fatal(err)
}
@ -60,7 +61,8 @@ func TestCompile(t *testing.T) {
func TestInvalidCompile(t *testing.T) {
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.CompileQuery(`#`)
_, err := qcompile.CompileQuery([]byte(`#`))
if err == nil {
t.Fatal(errors.New("expecting an error"))
}
@ -68,13 +70,14 @@ func TestInvalidCompile(t *testing.T) {
func TestEmptyCompile(t *testing.T) {
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.CompileQuery(``)
_, err := qcompile.CompileQuery([]byte(``))
if err == nil {
t.Fatal(errors.New("expecting an error"))
}
}
var gql = `query {
var gql = []byte(`
products(
# returns only 30 items
limit: 30,
@ -93,8 +96,7 @@ var gql = `query {
id
name
price
}
}`
}`)
func BenchmarkQCompile(b *testing.B) {
qcompile, _ := NewCompiler(Config{})

View File

@ -50,6 +50,7 @@ type Exp struct {
ListType ValType
ListVal []string
Children []*Exp
childrenA [5]*Exp
}
type OrderBy struct {
@ -144,7 +145,7 @@ func NewCompiler(c Config) (*Compiler, error) {
bl := make(map[string]struct{}, len(c.Blacklist))
for i := range c.Blacklist {
bl[strings.ToLower(c.Blacklist[i])] = struct{}{}
bl[c.Blacklist[i]] = struct{}{}
}
fl, err := compileFilter(c.DefaultFilter)
@ -159,9 +160,8 @@ func NewCompiler(c Config) (*Compiler, error) {
if err != nil {
return nil, err
}
k1 := strings.ToLower(k)
singular := flect.Singularize(k1)
plural := flect.Pluralize(k1)
singular := flect.Singularize(k)
plural := flect.Pluralize(k)
fm[singular] = fil
fm[plural] = fil
@ -170,11 +170,11 @@ func NewCompiler(c Config) (*Compiler, error) {
return &Compiler{fl, fm, bl, c.KeepArgs}, nil
}
func (com *Compiler) CompileQuery(query string) (*QCode, error) {
func (com *Compiler) Compile(query []byte) (*QCode, error) {
var qc QCode
var err error
op, err := ParseQuery(query)
op, err := Parse(query)
if err != nil {
return nil, err
}
@ -197,6 +197,25 @@ func (com *Compiler) CompileQuery(query string) (*QCode, error) {
return &qc, nil
}
func (com *Compiler) CompileQuery(query []byte) (*QCode, error) {
var err error
op, err := ParseQuery(query)
if err != nil {
return nil, err
}
qc := &QCode{}
qc.Query, err = com.compileQuery(op)
opPool.Put(op)
if err != nil {
return nil, err
}
return qc, nil
}
func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
id := int32(0)
parentID := int32(0)
@ -226,15 +245,14 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
}
field := &op.Fields[fid]
tn := strings.ToLower(field.Name)
if _, ok := com.bl[tn]; ok {
if _, ok := com.bl[field.Name]; ok {
continue
}
selects = append(selects, Select{
ID: id,
ParentID: parentID,
Table: tn,
Table: field.Name,
Children: make([]int32, 0, 5),
})
s := &selects[(len(selects) - 1)]
@ -259,9 +277,8 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
for _, cid := range field.Children {
f := op.Fields[cid]
fn := strings.ToLower(f.Name)
if _, ok := com.bl[fn]; ok {
if _, ok := com.bl[f.Name]; ok {
continue
}
@ -271,7 +288,7 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
continue
}
col := Column{Name: fn}
col := Column{Name: f.Name}
if len(f.Alias) != 0 {
col.FieldName = f.Alias
@ -298,8 +315,11 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) {
if fil != nil && fil.Op != OpNop {
if root.Where != nil {
ex := &Exp{Op: OpAnd, Children: []*Exp{fil, root.Where}}
root.Where = ex
ow := root.Where
root.Where = &Exp{Op: OpAnd}
root.Where.Children = root.Where.childrenA[:2]
root.Where.Children[0] = fil
root.Where.Children[1] = ow
} else {
root.Where = fil
}
@ -322,9 +342,7 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
for i := range args {
arg := &args[i]
an := strings.ToLower(arg.Name)
switch an {
switch arg.Name {
case "id":
if sel.ID == 0 {
err = com.compileArgID(sel, arg)
@ -348,7 +366,7 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
}
if sel.Args != nil {
sel.Args[an] = arg.Val
sel.Args[arg.Name] = arg.Val
} else {
nodePool.Put(arg.Val)
}
@ -392,7 +410,7 @@ func (com *Compiler) compileArgNode(node *Node) (*Exp, error) {
}
if len(eT.node.Name) != 0 {
if _, ok := com.bl[strings.ToLower(eT.node.Name)]; ok {
if _, ok := com.bl[eT.node.Name]; ok {
continue
}
}
@ -468,7 +486,11 @@ func (com *Compiler) compileArgSearch(sel *Select, arg *Arg) error {
}
if sel.Where != nil {
sel.Where = &Exp{Op: OpAnd, Children: []*Exp{ex, sel.Where}}
ow := sel.Where
sel.Where = &Exp{Op: OpAnd}
sel.Where.Children = sel.Where.childrenA[:2]
sel.Where.Children[0] = ex
sel.Where.Children[1] = ow
} else {
sel.Where = ex
}
@ -484,7 +506,11 @@ func (com *Compiler) compileArgWhere(sel *Select, arg *Arg) error {
}
if sel.Where != nil {
sel.Where = &Exp{Op: OpAnd, Children: []*Exp{ex, sel.Where}}
ow := sel.Where
sel.Where = &Exp{Op: OpAnd}
sel.Where.Children = sel.Where.childrenA[:2]
sel.Where.Children[0] = ex
sel.Where.Children[1] = ow
} else {
sel.Where = ex
}
@ -515,7 +541,7 @@ func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error {
return fmt.Errorf("17: unexpected value %v (%t)", intf, intf)
}
if _, ok := com.bl[strings.ToLower(node.Name)]; ok {
if _, ok := com.bl[node.Name]; ok {
if !com.ka {
nodePool.Put(node)
}
@ -534,8 +560,7 @@ func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error {
ob := &OrderBy{}
val := strings.ToLower(node.Val)
switch val {
switch node.Val {
case "asc":
ob.Order = OrderAsc
case "desc":
@ -565,7 +590,7 @@ func (com *Compiler) compileArgOrderBy(sel *Select, arg *Arg) error {
func (com *Compiler) compileArgDistinctOn(sel *Select, arg *Arg) error {
node := arg.Val
if _, ok := com.bl[strings.ToLower(node.Name)]; ok {
if _, ok := com.bl[node.Name]; ok {
return nil
}
@ -619,7 +644,6 @@ func compileSub() (*Query, error) {
}
func newExp(st *util.Stack, eT *expT) (*Exp, error) {
ex := &Exp{}
node := eT.node
if len(node.Name) == 0 {
@ -627,11 +651,13 @@ func newExp(st *util.Stack, eT *expT) (*Exp, error) {
return nil, nil
}
name := strings.ToLower(node.Name)
name := node.Name
if name[0] == '_' {
name = name[1:]
}
ex := &Exp{}
switch name {
case "and":
ex.Op = OpAnd
@ -756,7 +782,7 @@ func setWhereColName(ex *Exp, node *Node) {
continue
}
if len(n.Name) != 0 {
k := strings.ToLower(n.Name)
k := n.Name
if k == "and" || k == "or" || k == "not" ||
k == "_and" || k == "_or" || k == "_not" {
continue
@ -778,8 +804,7 @@ func setOrderByColName(ob *OrderBy, node *Node) {
for n := node; n != nil; n = n.Parent {
if len(n.Name) != 0 {
k := strings.ToLower(n.Name)
list = append([]string{k}, list...)
list = append([]string{n.Name}, list...)
}
}
if len(list) != 0 {

View File

@ -36,7 +36,7 @@ type coreContext struct {
func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
var err error
qc, err := qcompile.CompileQuery(c.req.Query)
qc, err := qcompile.CompileQuery([]byte(c.req.Query))
if err != nil {
return err
}