Add support for websearch_to_tsquery in PG 11

This commit is contained in:
Vikram Rangnekar
2019-12-02 10:52:22 -05:00
parent 5593c66996
commit 5da79d91bf
15 changed files with 249 additions and 196 deletions

View File

@ -110,7 +110,7 @@ func graphQLFunc(query string, data interface{}, opt map[string]string) map[stri
var root []byte
if err = tx.QueryRow(c, finalSQL).Scan(&root); err != nil {
if err = tx.QueryRow(context.Background(), finalSQL).Scan(&root); err != nil {
errlog.Fatal().Err(err).Msg("sql query failed")
}

View File

@ -81,7 +81,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) {
useTx := useRoleQuery || conf.DB.SetUserID
if useTx {
if tx, err = db.Begin(c); err != nil {
if tx, err = db.Begin(context.Background()); err != nil {
return nil, nil, err
}
defer tx.Rollback(c) //nolint: errcheck
@ -122,9 +122,9 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) {
}
if useTx {
row = tx.QueryRow(c, ps.sd.SQL, vars...)
row = tx.QueryRow(context.Background(), ps.sd.SQL, vars...)
} else {
row = db.QueryRow(c, ps.sd.SQL, vars...)
row = db.QueryRow(context.Background(), ps.sd.SQL, vars...)
}
if mutation || anonQuery {
@ -146,7 +146,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) {
c.req.role = role
if useTx {
if err := tx.Commit(c); err != nil {
if err := tx.Commit(context.Background()); err != nil {
return nil, nil, err
}
}
@ -166,10 +166,10 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
useTx := useRoleQuery || conf.DB.SetUserID
if useTx {
if tx, err = db.Begin(c); err != nil {
if tx, err = db.Begin(context.Background()); err != nil {
return nil, nil, err
}
defer tx.Rollback(c) //nolint: errcheck
defer tx.Rollback(context.Background()) //nolint: errcheck
}
if conf.DB.SetUserID {
@ -215,9 +215,9 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
defaultRole := c.req.role
if useTx {
row = tx.QueryRow(c, finalSQL)
row = tx.QueryRow(context.Background(), finalSQL)
} else {
row = db.QueryRow(c, finalSQL)
row = db.QueryRow(context.Background(), finalSQL)
}
if len(stmts) == 1 {
@ -237,7 +237,7 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
}
if useTx {
if err := tx.Commit(c); err != nil {
if err := tx.Commit(context.Background()); err != nil {
return nil, nil, err
}
}
@ -263,7 +263,7 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
var role string
row := tx.QueryRow(c, "_sg_get_role", c.req.role, 1)
row := tx.QueryRow(context.Background(), "_sg_get_role", c.req.role, 1)
if err := row.Scan(&role); err != nil {
return "", err
@ -320,6 +320,15 @@ func (c *coreContext) addTrace(sel []qcode.Select, id int32, st time.Time) {
append(c.res.Extensions.Tracing.Execution.Resolvers, tr)
}
func setLocalUserID(c context.Context, tx pgx.Tx) error {
var err error
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(context.Background(), fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
}
return err
}
func parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) (
[][]byte,
map[uint64]*qcode.Select) {
@ -363,15 +372,6 @@ func parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) (
return fm, sm
}
func setLocalUserID(c context.Context, tx pgx.Tx) error {
var err error
if v := c.Value(userIDKey); v != nil {
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
}
return err
}
func isSkipped(n uint32, pos uint32) bool {
return ((n & (1 << pos)) != 0)
}

View File

@ -97,6 +97,10 @@ func buildMultiStmt(gql, vars []byte) ([]stmt, error) {
for i := 0; i < len(conf.Roles); i++ {
role := &conf.Roles[i]
if role.Name == "anon" {
continue
}
qc, err := qcompile.Compile(gql, role.Name)
if err != nil {
return nil, err
@ -127,8 +131,6 @@ func buildMultiStmt(gql, vars []byte) ([]stmt, error) {
//nolint: errcheck
func renderUserQuery(
stmts []stmt, vars map[string]json.RawMessage) (string, error) {
var err error
w := &bytes.Buffer{}
io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
@ -141,11 +143,7 @@ func renderUserQuery(
io.WriteString(w, `WHEN '`)
io.WriteString(w, s.role.Name)
io.WriteString(w, `' THEN (`)
s.skipped, err = pcompile.Compile(s.qc, w, psql.Variables(vars))
if err != nil {
return "", err
}
io.WriteString(w, s.sql)
io.WriteString(w, `) `)
}

View File

@ -23,21 +23,20 @@ var (
)
func initPreparedList() {
c := context.Background()
_preparedList = make(map[string]*preparedItem)
tx, err := db.Begin(c)
tx, err := db.Begin(context.Background())
if err != nil {
errlog.Fatal().Err(err).Send()
}
defer tx.Rollback(c) //nolint: errcheck
defer tx.Rollback(context.Background()) //nolint: errcheck
err = prepareRoleStmt(c, tx)
err = prepareRoleStmt(tx)
if err != nil {
errlog.Fatal().Err(err).Msg("failed to prepare get role statement")
}
if err := tx.Commit(c); err != nil {
if err := tx.Commit(context.Background()); err != nil {
errlog.Fatal().Err(err).Send()
}
@ -48,7 +47,7 @@ func initPreparedList() {
continue
}
err := prepareStmt(c, v.gql, v.vars)
err := prepareStmt(v.gql, v.vars)
if err == nil {
success++
continue
@ -66,15 +65,15 @@ func initPreparedList() {
success, len(_allowList.list))
}
func prepareStmt(c context.Context, gql string, vars []byte) error {
func prepareStmt(gql string, vars []byte) error {
qt := qcode.GetQType(gql)
q := []byte(gql)
tx, err := db.Begin(c)
tx, err := db.Begin(context.Background())
if err != nil {
return err
}
defer tx.Rollback(c) //nolint: errcheck
defer tx.Rollback(context.Background()) //nolint: errcheck
switch qt {
case qcode.QTQuery:
@ -83,7 +82,7 @@ func prepareStmt(c context.Context, gql string, vars []byte) error {
return err
}
err = prepare(c, tx, &stmts1[0], gqlHash(gql, vars, "user"))
err = prepare(tx, &stmts1[0], gqlHash(gql, vars, "user"))
if err != nil {
return err
}
@ -93,7 +92,7 @@ func prepareStmt(c context.Context, gql string, vars []byte) error {
return err
}
err = prepare(c, tx, &stmts2[0], gqlHash(gql, vars, "anon"))
err = prepare(tx, &stmts2[0], gqlHash(gql, vars, "anon"))
if err != nil {
return err
}
@ -105,7 +104,7 @@ func prepareStmt(c context.Context, gql string, vars []byte) error {
return err
}
err = prepare(c, tx, &stmts[0], gqlHash(gql, vars, role.Name))
err = prepare(tx, &stmts[0], gqlHash(gql, vars, role.Name))
if err != nil {
return err
}
@ -118,17 +117,17 @@ func prepareStmt(c context.Context, gql string, vars []byte) error {
logger.Debug().Msgf("Building prepared statement:\n %s\n%s", vars, gql)
}
if err := tx.Commit(c); err != nil {
if err := tx.Commit(context.Background()); err != nil {
return err
}
return nil
}
func prepare(c context.Context, tx pgx.Tx, st *stmt, key string) error {
func prepare(tx pgx.Tx, st *stmt, key string) error {
finalSQL, am := processTemplate(st.sql)
sd, err := tx.Prepare(c, "", finalSQL)
sd, err := tx.Prepare(context.Background(), "", finalSQL)
if err != nil {
return err
}
@ -142,7 +141,7 @@ func prepare(c context.Context, tx pgx.Tx, st *stmt, key string) error {
}
// nolint: errcheck
func prepareRoleStmt(c context.Context, tx pgx.Tx) error {
func prepareRoleStmt(tx pgx.Tx) error {
if len(conf.RolesQuery) == 0 {
return nil
}
@ -167,7 +166,7 @@ func prepareRoleStmt(c context.Context, tx pgx.Tx) error {
roleSQL, _ := processTemplate(w.String())
_, err := tx.Prepare(c, "_sg_get_role", roleSQL)
_, err := tx.Prepare(context.Background(), "_sg_get_role", roleSQL)
if err != nil {
return err
}

View File

@ -24,7 +24,6 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
conf := qcode.Config{
Blocklist: c.DB.Blocklist,
KeepArgs: false,
}
qc, err := qcode.NewCompiler(conf)