Add support for websearch_to_tsquery in PG 11

This commit is contained in:
Vikram Rangnekar
2019-12-02 10:52:22 -05:00
parent 9140e597e1
commit 6029c5e05c
15 changed files with 249 additions and 196 deletions

View File

@ -178,9 +178,10 @@ func TestMain(m *testing.M) {
}
schema := &DBSchema{
t: make(map[string]*DBTableInfo),
rm: make(map[string]map[string]*DBRel),
al: make(map[string]struct{}),
ver: 110000,
t: make(map[string]*DBTableInfo),
rm: make(map[string]map[string]*DBRel),
al: make(map[string]struct{}),
}
aliases := map[string][]string{

View File

@ -189,6 +189,11 @@ func (co *Compiler) compileQuery(qc *qcode.QCode, w io.Writer) (uint32, error) {
}
}
if len(sel.Args) != 0 {
for _, v := range sel.Args {
qcode.FreeNode(v)
}
}
}
}
@ -515,36 +520,54 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
if isSearch {
switch {
case cn == "search_rank":
if len(sel.Allowed) != 0 {
if _, ok := sel.Allowed[cn]; !ok {
continue
}
}
cn = ti.TSVCol
arg := sel.Args["search"]
if i != 0 {
io.WriteString(c.w, `, `)
}
//fmt.Fprintf(w, `ts_rank("%s"."%s", to_tsquery('%s')) AS %s`,
//fmt.Fprintf(w, `ts_rank("%s"."%s", websearch_to_tsquery('%s')) AS %s`,
//c.sel.Name, cn, arg.Val, col.Name)
io.WriteString(c.w, `ts_rank(`)
colWithTable(c.w, ti.Name, cn)
io.WriteString(c.w, `, to_tsquery('`)
if c.schema.ver >= 110000 {
io.WriteString(c.w, `, websearch_to_tsquery('`)
} else {
io.WriteString(c.w, `, to_tsquery('`)
}
io.WriteString(c.w, arg.Val)
io.WriteString(c.w, `')`)
io.WriteString(c.w, `'))`)
alias(c.w, col.Name)
i++
case strings.HasPrefix(cn, "search_headline_"):
cn = cn[16:]
cn1 := cn[16:]
if len(sel.Allowed) != 0 {
if _, ok := sel.Allowed[cn1]; !ok {
continue
}
}
arg := sel.Args["search"]
if i != 0 {
io.WriteString(c.w, `, `)
}
//fmt.Fprintf(w, `ts_headline("%s"."%s", to_tsquery('%s')) AS %s`,
//fmt.Fprintf(w, `ts_headline("%s"."%s", websearch_to_tsquery('%s')) AS %s`,
//c.sel.Name, cn, arg.Val, col.Name)
io.WriteString(c.w, `ts_headlinek(`)
colWithTable(c.w, ti.Name, cn)
io.WriteString(c.w, `, to_tsquery('`)
io.WriteString(c.w, `ts_headline(`)
colWithTable(c.w, ti.Name, cn1)
if c.schema.ver >= 110000 {
io.WriteString(c.w, `, websearch_to_tsquery('`)
} else {
io.WriteString(c.w, `, to_tsquery('`)
}
io.WriteString(c.w, arg.Val)
io.WriteString(c.w, `')`)
io.WriteString(c.w, `'))`)
alias(c.w, col.Name)
i++
@ -693,6 +716,7 @@ func (c *compilerContext) renderBaseSelect(sel *qcode.Select, ti *DBTableInfo,
//fmt.Fprintf(w, `) AS "%s_%d"`, c.sel.Name, c.sel.ID)
io.WriteString(c.w, `)`)
aliasWithID(c.w, ti.Name, sel.ID)
return nil
}
@ -939,6 +963,7 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
io.WriteString(c.w, `IS NOT NULL)`)
}
return nil
case qcode.OpEqID:
if len(ti.PrimaryCol) == 0 {
return fmt.Errorf("no primary key column defined for %s", ti.Name)
@ -951,6 +976,7 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
colWithTable(c.w, ti.Name, ti.PrimaryCol)
//io.WriteString(c.w, ti.PrimaryCol)
io.WriteString(c.w, `) =`)
case qcode.OpTsQuery:
if len(ti.TSVCol) == 0 {
return fmt.Errorf("no tsv column defined for %s", ti.Name)
@ -958,10 +984,14 @@ func (c *compilerContext) renderOp(ex *qcode.Exp, sel *qcode.Select, ti *DBTable
if _, ok = ti.Columns[ti.TSVCol]; !ok {
return fmt.Errorf("no tsv column '%s' found ", ti.TSVCol)
}
//fmt.Fprintf(w, `(("%s") @@ to_tsquery('%s'))`, c.ti.TSVCol, val.Val)
io.WriteString(c.w, `(("`)
io.WriteString(c.w, ti.TSVCol)
io.WriteString(c.w, `") @@ to_tsquery('`)
//fmt.Fprintf(w, `(("%s") @@ websearch_to_tsquery('%s'))`, c.ti.TSVCol, val.Val)
io.WriteString(c.w, `((`)
colWithTable(c.w, ti.Name, ti.TSVCol)
if c.schema.ver >= 110000 {
io.WriteString(c.w, `) @@ websearch_to_tsquery('`)
} else {
io.WriteString(c.w, `) @@ to_tsquery('`)
}
io.WriteString(c.w, ex.Val)
io.WriteString(c.w, `'))`)
return nil

View File

@ -142,15 +142,17 @@ func fetchByID(t *testing.T) {
func searchQuery(t *testing.T) {
gql := `query {
products(search: "Imperial") {
products(search: "ale") {
id
name
search_rank
search_headline_description
}
}`
sql := `SELECT json_object_agg('products', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name" FROM "products" WHERE ((("products"."price") > 0) AND (("products"."price") < 8) AND (("tsv") @@ to_tsquery('Imperial'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"`
sql := `SELECT json_object_agg('products', json_0) FROM (SELECT coalesce(json_agg("json_0"), '[]') AS "json_0" FROM (SELECT row_to_json((SELECT "json_row_0" FROM (SELECT "products_0"."id" AS "id", "products_0"."name" AS "name", "products_0"."search_rank" AS "search_rank", "products_0"."search_headline_description" AS "search_headline_description") AS "json_row_0")) AS "json_0" FROM (SELECT "products"."id", "products"."name", ts_rank("products"."tsv", websearch_to_tsquery('ale')) AS "search_rank", ts_headline("products"."description", websearch_to_tsquery('ale')) AS "search_headline_description" FROM "products" WHERE ((("products"."tsv") @@ websearch_to_tsquery('ale'))) LIMIT ('20') :: integer) AS "products_0" LIMIT ('20') :: integer) AS "json_agg_0") AS "sel_0"`
resSQL, err := compileGQLToPSQL(gql, nil, "user")
resSQL, err := compileGQLToPSQL(gql, nil, "admin")
if err != nil {
t.Fatal(err)
}

View File

@ -3,6 +3,7 @@ package psql
import (
"context"
"fmt"
"strconv"
"strings"
"github.com/gobuffalo/flect"
@ -144,9 +145,10 @@ ORDER BY id;`
}
type DBSchema struct {
t map[string]*DBTableInfo
rm map[string]map[string]*DBRel
al map[string]struct{}
ver int
t map[string]*DBTableInfo
rm map[string]map[string]*DBRel
al map[string]struct{}
}
type DBTableInfo struct {
@ -184,10 +186,22 @@ func NewDBSchema(db *pgxpool.Pool, aliases map[string][]string) (*DBSchema, erro
dbc, err := db.Acquire(context.Background())
if err != nil {
return nil, fmt.Errorf("error acquiring connection from pool")
return nil, fmt.Errorf("error acquiring connection from pool: %w", err)
}
defer dbc.Release()
var version string
err = dbc.QueryRow(context.Background(), `SHOW server_version_num`).Scan(&version)
if err != nil {
return nil, fmt.Errorf("error fetching version: %w", err)
}
schema.ver, err = strconv.Atoi(version)
if err != nil {
return nil, err
}
tables, err := GetTables(dbc)
if err != nil {
return nil, err