From 6536b12858bb63f6002ed787ba4281b42f0ca5d5 Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Sat, 6 Apr 2019 02:35:08 -0400 Subject: [PATCH] Add query support for ts_rank and ts_headline --- psql/psql.go | 106 ++++++++++++++++++++++++++-------------------- psql/psql_test.go | 3 +- qcode/qcode.go | 17 +++++--- serv/serv.go | 2 +- 4 files changed, 74 insertions(+), 54 deletions(-) diff --git a/psql/psql.go b/psql/psql.go index 6095c04..891199b 100644 --- a/psql/psql.go +++ b/psql/psql.go @@ -23,9 +23,10 @@ func NewCompiler(schema *DBSchema, vars Variables) *Compiler { func (c *Compiler) Compile(w io.Writer, qc *qcode.QCode) error { st := util.NewStack() + ti, _ := c.schema.GetTable(qc.Query.Select.Table) st.Push(&selectBlockClose{nil, qc.Query.Select}) - st.Push(&selectBlock{nil, qc.Query.Select, c}) + st.Push(&selectBlock{nil, qc.Query.Select, ti, c}) fmt.Fprintf(w, `SELECT json_object_agg('%s', %s) FROM (`, qc.Query.Select.FieldName, qc.Query.Select.Table) @@ -43,10 +44,15 @@ func (c *Compiler) Compile(w io.Writer, qc *qcode.QCode) error { v.render(w, c.schema, childCols, childIDs) for i := range childIDs { + ti, err := c.schema.GetTable(v.sel.Table) + if err != nil { + continue + } sub := v.sel.Joins[childIDs[i]] + st.Push(&joinClose{sub}) st.Push(&selectBlockClose{v.sel, sub}) - st.Push(&selectBlock{v.sel, sub, c}) + st.Push(&selectBlock{v.sel, sub, ti, c}) st.Push(&joinOpen{sub}) } case *selectBlockClose: @@ -103,6 +109,7 @@ func (c *Compiler) relationshipColumns(parent *qcode.Select) ( type selectBlock struct { parent *qcode.Select sel *qcode.Select + ti *DBTableInfo *Compiler } @@ -267,24 +274,37 @@ func (v *selectBlock) renderBaseSelect(w io.Writer, schema *DBSchema, childCols isFil := v.sel.Where != nil isAgg := false + searchVal := findArgVal(v.sel, "search") + io.WriteString(w, " FROM (SELECT ") for i, col := range v.sel.Cols { cn := col.Name - fn := "" + _, isRealCol := v.schema.ColMap[TCKey{v.sel.Table, cn}] - if _, ok := v.schema.ColMap[TCKey{v.sel.Table, cn}]; !ok { - pl := funcPrefixLen(cn) - if pl == 0 { - continue + if !isRealCol { + switch { + case searchVal != nil && cn == "search_rank": + cn = v.ti.TSVCol + fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`, + v.sel.Table, cn, searchVal.Val, col.Name) + + case searchVal != nil && strings.HasPrefix(cn, "search_headline_"): + cn = cn[16:] + fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`, + v.sel.Table, cn, searchVal.Val, col.Name) + + default: + pl := funcPrefixLen(cn) + if pl == 0 { + fmt.Fprintf(w, `'%s not defined' AS %s`, cn, col.Name) + } else { + isAgg = true + fn := cn[0 : pl-1] + cn := cn[pl:] + fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, v.sel.Table, cn, col.Name) + } } - isAgg = true - fn = cn[0 : pl-1] - cn = cn[pl:] - } - - if len(fn) != 0 { - fmt.Fprintf(w, `%s("%s"."%s") AS %s`, fn, v.sel.Table, cn, col.Name) } else { groupBy = append(groupBy, i) fmt.Fprintf(w, `"%s"."%s"`, v.sel.Table, cn) @@ -396,32 +416,6 @@ func (v *selectBlock) renderRelationship(w io.Writer, schema *DBSchema) { } func (v *selectBlock) renderWhere(w io.Writer) error { - if v.sel.Where.Op == qcode.OpEqID { - t, err := v.schema.GetTable(v.sel.Table) - if err != nil { - return err - } - if len(t.PrimaryCol) == 0 { - return fmt.Errorf("no primary key column defined for %s", v.sel.Table) - } - - fmt.Fprintf(w, `(("%s") = ('%s'))`, t.PrimaryCol, v.sel.Where.Val) - return nil - } - - if v.sel.Where.Op == qcode.OpTsQuery { - t, err := v.schema.GetTable(v.sel.Table) - if err != nil { - return err - } - if len(t.TSVCol) == 0 { - return fmt.Errorf("no tsv column defined for %s", v.sel.Table) - } - - fmt.Fprintf(w, `(("%s") @@ to_tsquery('%s'))`, t.TSVCol, v.sel.Where.Val) - return nil - } - st := util.NewStack() if v.sel.Where != nil { @@ -465,7 +459,7 @@ func (v *selectBlock) renderWhere(w io.Writer) error { if val.NestedCol { fmt.Fprintf(w, `(("%s") `, val.Col) - } else { + } else if len(val.Col) != 0 { fmt.Fprintf(w, `(("%s"."%s") `, v.sel.Table, val.Col) } valExists := true @@ -511,11 +505,25 @@ func (v *selectBlock) renderWhere(w io.Writer) error { io.WriteString(w, `?&`) case qcode.OpIsNull: if strings.EqualFold(val.Val, "true") { - io.WriteString(w, `IS NULL`) + io.WriteString(w, `IS NULL)`) } else { - io.WriteString(w, `IS NOT NULL`) + io.WriteString(w, `IS NOT NULL)`) } valExists = false + case qcode.OpEqID: + if len(v.ti.PrimaryCol) == 0 { + return fmt.Errorf("no primary key column defined for %s", v.sel.Table) + } + fmt.Fprintf(w, `(("%s") = ('%s'))`, v.ti.PrimaryCol, val.Val) + valExists = false + case qcode.OpTsQuery: + if len(v.ti.TSVCol) == 0 { + return fmt.Errorf("no tsv column defined for %s", v.sel.Table) + } + + fmt.Fprintf(w, `(("%s") @@ to_tsquery('%s'))`, v.ti.TSVCol, val.Val) + valExists = false + default: return fmt.Errorf("[Where] unexpected op code %d", val.Op) } @@ -526,10 +534,9 @@ func (v *selectBlock) renderWhere(w io.Writer) error { } else { renderVal(w, val, v.vars) } + io.WriteString(w, `)`) } - io.WriteString(w, `)`) - default: return fmt.Errorf("[Where] unexpected value encountered %v", intf) } @@ -640,3 +647,12 @@ func funcPrefixLen(fn string) int { } return 0 } + +func findArgVal(sel *qcode.Select, name string) *qcode.Node { + for i := range sel.Args { + if sel.Args[i].Name == name { + return sel.Args[i].Val + } + } + return nil +} diff --git a/psql/psql_test.go b/psql/psql_test.go index 054243d..945f205 100644 --- a/psql/psql_test.go +++ b/psql/psql_test.go @@ -69,7 +69,8 @@ func TestMain(m *testing.M) { &DBColumn{ID: 4, Name: "price", Type: "numeric(7,2)", NotNull: false, PrimaryKey: false, Uniquekey: false, FKeyTable: "", FKeyColID: []int(nil)}, &DBColumn{ID: 5, Name: "user_id", Type: "bigint", NotNull: false, PrimaryKey: false, Uniquekey: false, FKeyTable: "users", FKeyColID: []int{1}}, &DBColumn{ID: 6, Name: "created_at", Type: "timestamp without time zone", NotNull: true, PrimaryKey: false, Uniquekey: false, FKeyTable: "", FKeyColID: []int(nil)}, - &DBColumn{ID: 7, Name: "updated_at", Type: "timestamp without time zone", NotNull: true, PrimaryKey: false, Uniquekey: false, FKeyTable: "", FKeyColID: []int(nil)}}, + &DBColumn{ID: 7, Name: "updated_at", Type: "timestamp without time zone", NotNull: true, PrimaryKey: false, Uniquekey: false, FKeyTable: "", FKeyColID: []int(nil)}, + &DBColumn{ID: 8, Name: "tsv", Type: "tsvector", NotNull: false, PrimaryKey: false, Uniquekey: false, FKeyTable: "", FKeyColID: []int(nil)}}, []*DBColumn{ &DBColumn{ID: 1, Name: "id", Type: "bigint", NotNull: true, PrimaryKey: true, Uniquekey: false, FKeyTable: "", FKeyColID: []int(nil)}, &DBColumn{ID: 2, Name: "customer_id", Type: "bigint", NotNull: false, PrimaryKey: false, Uniquekey: false, FKeyTable: "customers", FKeyColID: []int{1}}, diff --git a/qcode/qcode.go b/qcode/qcode.go index fefd6ba..84c1e13 100644 --- a/qcode/qcode.go +++ b/qcode/qcode.go @@ -25,6 +25,7 @@ type Column struct { type Select struct { ID int16 + Args []*Arg AsList bool Table string Singular string @@ -332,12 +333,16 @@ func (com *Compiler) compileQuery(op *Operation) (*Query, error) { func (com *Compiler) compileArgs(sel *Select, args []*Arg) error { var err error + ad := make(map[string]struct{}) for i := range args { if args[i] == nil { return fmt.Errorf("[Args] unexpected nil argument found") } an := strings.ToLower(args[i].Name) + if _, ok := ad[an]; ok { + continue + } switch an { case "id": @@ -345,9 +350,7 @@ func (com *Compiler) compileArgs(sel *Select, args []*Arg) error { err = com.compileArgID(sel, args[i]) } case "search": - if sel.ID == int16(0) { - err = com.compileArgSearch(sel, args[i]) - } + err = com.compileArgSearch(sel, args[i]) case "where": err = com.compileArgWhere(sel, args[i]) case "orderby", "order_by", "order": @@ -363,6 +366,8 @@ func (com *Compiler) compileArgs(sel *Select, args []*Arg) error { if err != nil { return err } + + ad[an] = struct{}{} } return nil @@ -446,16 +451,14 @@ func (com *Compiler) compileArgID(sel *Select, arg *Arg) error { } func (com *Compiler) compileArgSearch(sel *Select, arg *Arg) error { - if sel.Where != nil && sel.Where.Op == OpTsQuery { - return nil - } - ex := &Exp{ Op: OpTsQuery, Type: ValStr, Val: arg.Val.Val, } + sel.Args = append(sel.Args, arg) + if sel.Where != nil { sel.Where = &Exp{Op: OpAnd, Children: []*Exp{ex, sel.Where}} } else { diff --git a/serv/serv.go b/serv/serv.go index fda1f97..38bef61 100644 --- a/serv/serv.go +++ b/serv/serv.go @@ -16,7 +16,7 @@ import ( "github.com/spf13/viper" ) -//go:generate esc -o static.go -prefix ../web/build -private -pkg serv ../web/build +//go:generate esc -o static.go -ignore \\.DS_Store -prefix ../web/build -private -pkg serv ../web/build const ( authFailBlockAlways = iota + 1