Add support for websearch_to_tsquery
in PG 11
This commit is contained in:
@ -110,7 +110,7 @@ func graphQLFunc(query string, data interface{}, opt map[string]string) map[stri
|
||||
|
||||
var root []byte
|
||||
|
||||
if err = tx.QueryRow(c, finalSQL).Scan(&root); err != nil {
|
||||
if err = tx.QueryRow(context.Background(), finalSQL).Scan(&root); err != nil {
|
||||
errlog.Fatal().Err(err).Msg("sql query failed")
|
||||
}
|
||||
|
||||
|
38
serv/core.go
38
serv/core.go
@ -81,7 +81,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) {
|
||||
useTx := useRoleQuery || conf.DB.SetUserID
|
||||
|
||||
if useTx {
|
||||
if tx, err = db.Begin(c); err != nil {
|
||||
if tx, err = db.Begin(context.Background()); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer tx.Rollback(c) //nolint: errcheck
|
||||
@ -122,9 +122,9 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) {
|
||||
}
|
||||
|
||||
if useTx {
|
||||
row = tx.QueryRow(c, ps.sd.SQL, vars...)
|
||||
row = tx.QueryRow(context.Background(), ps.sd.SQL, vars...)
|
||||
} else {
|
||||
row = db.QueryRow(c, ps.sd.SQL, vars...)
|
||||
row = db.QueryRow(context.Background(), ps.sd.SQL, vars...)
|
||||
}
|
||||
|
||||
if mutation || anonQuery {
|
||||
@ -146,7 +146,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) {
|
||||
c.req.role = role
|
||||
|
||||
if useTx {
|
||||
if err := tx.Commit(c); err != nil {
|
||||
if err := tx.Commit(context.Background()); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
@ -166,10 +166,10 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
|
||||
useTx := useRoleQuery || conf.DB.SetUserID
|
||||
|
||||
if useTx {
|
||||
if tx, err = db.Begin(c); err != nil {
|
||||
if tx, err = db.Begin(context.Background()); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer tx.Rollback(c) //nolint: errcheck
|
||||
defer tx.Rollback(context.Background()) //nolint: errcheck
|
||||
}
|
||||
|
||||
if conf.DB.SetUserID {
|
||||
@ -215,9 +215,9 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
|
||||
defaultRole := c.req.role
|
||||
|
||||
if useTx {
|
||||
row = tx.QueryRow(c, finalSQL)
|
||||
row = tx.QueryRow(context.Background(), finalSQL)
|
||||
} else {
|
||||
row = db.QueryRow(c, finalSQL)
|
||||
row = db.QueryRow(context.Background(), finalSQL)
|
||||
}
|
||||
|
||||
if len(stmts) == 1 {
|
||||
@ -237,7 +237,7 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
|
||||
}
|
||||
|
||||
if useTx {
|
||||
if err := tx.Commit(c); err != nil {
|
||||
if err := tx.Commit(context.Background()); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
@ -263,7 +263,7 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
|
||||
|
||||
func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
|
||||
var role string
|
||||
row := tx.QueryRow(c, "_sg_get_role", c.req.role, 1)
|
||||
row := tx.QueryRow(context.Background(), "_sg_get_role", c.req.role, 1)
|
||||
|
||||
if err := row.Scan(&role); err != nil {
|
||||
return "", err
|
||||
@ -320,6 +320,15 @@ func (c *coreContext) addTrace(sel []qcode.Select, id int32, st time.Time) {
|
||||
append(c.res.Extensions.Tracing.Execution.Resolvers, tr)
|
||||
}
|
||||
|
||||
func setLocalUserID(c context.Context, tx pgx.Tx) error {
|
||||
var err error
|
||||
if v := c.Value(userIDKey); v != nil {
|
||||
_, err = tx.Exec(context.Background(), fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) (
|
||||
[][]byte,
|
||||
map[uint64]*qcode.Select) {
|
||||
@ -363,15 +372,6 @@ func parentFieldIds(h *xxhash.Digest, sel []qcode.Select, skipped uint32) (
|
||||
return fm, sm
|
||||
}
|
||||
|
||||
func setLocalUserID(c context.Context, tx pgx.Tx) error {
|
||||
var err error
|
||||
if v := c.Value(userIDKey); v != nil {
|
||||
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func isSkipped(n uint32, pos uint32) bool {
|
||||
return ((n & (1 << pos)) != 0)
|
||||
}
|
||||
|
@ -97,6 +97,10 @@ func buildMultiStmt(gql, vars []byte) ([]stmt, error) {
|
||||
for i := 0; i < len(conf.Roles); i++ {
|
||||
role := &conf.Roles[i]
|
||||
|
||||
if role.Name == "anon" {
|
||||
continue
|
||||
}
|
||||
|
||||
qc, err := qcompile.Compile(gql, role.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@ -127,8 +131,6 @@ func buildMultiStmt(gql, vars []byte) ([]stmt, error) {
|
||||
//nolint: errcheck
|
||||
func renderUserQuery(
|
||||
stmts []stmt, vars map[string]json.RawMessage) (string, error) {
|
||||
|
||||
var err error
|
||||
w := &bytes.Buffer{}
|
||||
|
||||
io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
|
||||
@ -141,11 +143,7 @@ func renderUserQuery(
|
||||
io.WriteString(w, `WHEN '`)
|
||||
io.WriteString(w, s.role.Name)
|
||||
io.WriteString(w, `' THEN (`)
|
||||
|
||||
s.skipped, err = pcompile.Compile(s.qc, w, psql.Variables(vars))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
io.WriteString(w, s.sql)
|
||||
io.WriteString(w, `) `)
|
||||
}
|
||||
|
||||
|
@ -23,21 +23,20 @@ var (
|
||||
)
|
||||
|
||||
func initPreparedList() {
|
||||
c := context.Background()
|
||||
_preparedList = make(map[string]*preparedItem)
|
||||
|
||||
tx, err := db.Begin(c)
|
||||
tx, err := db.Begin(context.Background())
|
||||
if err != nil {
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
defer tx.Rollback(c) //nolint: errcheck
|
||||
defer tx.Rollback(context.Background()) //nolint: errcheck
|
||||
|
||||
err = prepareRoleStmt(c, tx)
|
||||
err = prepareRoleStmt(tx)
|
||||
if err != nil {
|
||||
errlog.Fatal().Err(err).Msg("failed to prepare get role statement")
|
||||
}
|
||||
|
||||
if err := tx.Commit(c); err != nil {
|
||||
if err := tx.Commit(context.Background()); err != nil {
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
|
||||
@ -48,7 +47,7 @@ func initPreparedList() {
|
||||
continue
|
||||
}
|
||||
|
||||
err := prepareStmt(c, v.gql, v.vars)
|
||||
err := prepareStmt(v.gql, v.vars)
|
||||
if err == nil {
|
||||
success++
|
||||
continue
|
||||
@ -66,15 +65,15 @@ func initPreparedList() {
|
||||
success, len(_allowList.list))
|
||||
}
|
||||
|
||||
func prepareStmt(c context.Context, gql string, vars []byte) error {
|
||||
func prepareStmt(gql string, vars []byte) error {
|
||||
qt := qcode.GetQType(gql)
|
||||
q := []byte(gql)
|
||||
|
||||
tx, err := db.Begin(c)
|
||||
tx, err := db.Begin(context.Background())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(c) //nolint: errcheck
|
||||
defer tx.Rollback(context.Background()) //nolint: errcheck
|
||||
|
||||
switch qt {
|
||||
case qcode.QTQuery:
|
||||
@ -83,7 +82,7 @@ func prepareStmt(c context.Context, gql string, vars []byte) error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = prepare(c, tx, &stmts1[0], gqlHash(gql, vars, "user"))
|
||||
err = prepare(tx, &stmts1[0], gqlHash(gql, vars, "user"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -93,7 +92,7 @@ func prepareStmt(c context.Context, gql string, vars []byte) error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = prepare(c, tx, &stmts2[0], gqlHash(gql, vars, "anon"))
|
||||
err = prepare(tx, &stmts2[0], gqlHash(gql, vars, "anon"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -105,7 +104,7 @@ func prepareStmt(c context.Context, gql string, vars []byte) error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = prepare(c, tx, &stmts[0], gqlHash(gql, vars, role.Name))
|
||||
err = prepare(tx, &stmts[0], gqlHash(gql, vars, role.Name))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -118,17 +117,17 @@ func prepareStmt(c context.Context, gql string, vars []byte) error {
|
||||
logger.Debug().Msgf("Building prepared statement:\n %s\n%s", vars, gql)
|
||||
}
|
||||
|
||||
if err := tx.Commit(c); err != nil {
|
||||
if err := tx.Commit(context.Background()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func prepare(c context.Context, tx pgx.Tx, st *stmt, key string) error {
|
||||
func prepare(tx pgx.Tx, st *stmt, key string) error {
|
||||
finalSQL, am := processTemplate(st.sql)
|
||||
|
||||
sd, err := tx.Prepare(c, "", finalSQL)
|
||||
sd, err := tx.Prepare(context.Background(), "", finalSQL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -142,7 +141,7 @@ func prepare(c context.Context, tx pgx.Tx, st *stmt, key string) error {
|
||||
}
|
||||
|
||||
// nolint: errcheck
|
||||
func prepareRoleStmt(c context.Context, tx pgx.Tx) error {
|
||||
func prepareRoleStmt(tx pgx.Tx) error {
|
||||
if len(conf.RolesQuery) == 0 {
|
||||
return nil
|
||||
}
|
||||
@ -167,7 +166,7 @@ func prepareRoleStmt(c context.Context, tx pgx.Tx) error {
|
||||
|
||||
roleSQL, _ := processTemplate(w.String())
|
||||
|
||||
_, err := tx.Prepare(c, "_sg_get_role", roleSQL)
|
||||
_, err := tx.Prepare(context.Background(), "_sg_get_role", roleSQL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -24,7 +24,6 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
|
||||
|
||||
conf := qcode.Config{
|
||||
Blocklist: c.DB.Blocklist,
|
||||
KeepArgs: false,
|
||||
}
|
||||
|
||||
qc, err := qcode.NewCompiler(conf)
|
||||
|
Reference in New Issue
Block a user