Optimize db queries limit use of transactions

This commit is contained in:
Vikram Rangnekar
2019-11-21 02:14:12 -05:00
parent 176514b5f1
commit a4c09dedd5
18 changed files with 298 additions and 196 deletions

View File

@ -13,25 +13,21 @@ func argMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
switch tag {
case "user_id_provider":
if v := ctx.Value(userIDProviderKey); v != nil {
return stringArg(w, v.(string))
return io.WriteString(w, v.(string))
}
io.WriteString(w, "null")
return 0, nil
return io.WriteString(w, "null")
case "user_id":
if v := ctx.Value(userIDKey); v != nil {
return stringArg(w, v.(string))
return io.WriteString(w, v.(string))
}
io.WriteString(w, "null")
return 0, nil
return io.WriteString(w, "null")
case "user_role":
if v := ctx.Value(userRoleKey); v != nil {
return stringArg(w, v.(string))
return io.WriteString(w, v.(string))
}
io.WriteString(w, "null")
return 0, nil
return io.WriteString(w, "null")
}
fields := jsn.Get(ctx.req.Vars, [][]byte{[]byte(tag)})
@ -39,22 +35,7 @@ func argMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
return 0, fmt.Errorf("variable '%s' not found", tag)
}
is := false
for i := range fields[0].Value {
c := fields[0].Value[i]
if c != ' ' {
is = (c == '"') || (c == '{') || (c == '[')
break
}
}
if is {
return stringArgB(w, fields[0].Value)
}
w.Write(fields[0].Value)
return 0, nil
return w.Write(fields[0].Value)
}
}
@ -100,23 +81,3 @@ func argList(ctx *coreContext, args [][]byte) []interface{} {
return vars
}
func stringArg(w io.Writer, v string) (int, error) {
if n, err := w.Write([]byte(`'`)); err != nil {
return n, err
}
if n, err := w.Write([]byte(v)); err != nil {
return n, err
}
return w.Write([]byte(`'`))
}
func stringArgB(w io.Writer, v []byte) (int, error) {
if n, err := w.Write([]byte(`'`)); err != nil {
return n, err
}
if n, err := w.Write(v); err != nil {
return n, err
}
return w.Write([]byte(`'`))
}

View File

@ -1,6 +1,7 @@
package serv
import (
"bytes"
"context"
"encoding/json"
"fmt"
@ -12,6 +13,7 @@ import (
"github.com/brianvoe/gofakeit"
"github.com/dop251/goja"
"github.com/spf13/cobra"
"github.com/valyala/fasttemplate"
)
func cmdDBSeed(cmd *cobra.Command, args []string) {
@ -57,22 +59,75 @@ func cmdDBSeed(cmd *cobra.Command, args []string) {
}
//func runFunc(call goja.FunctionCall) {
func graphQLFunc(query string, data interface{}) map[string]interface{} {
func graphQLFunc(query string, data interface{}, opt map[string]string) map[string]interface{} {
b, err := json.Marshal(data)
if err != nil {
logger.Fatal().Err(err).Msg("failed to json serialize")
}
c := &coreContext{Context: context.Background()}
ctx := context.Background()
if v, ok := opt["user_id"]; ok && len(v) != 0 {
ctx = context.WithValue(ctx, userIDKey, v)
}
var role string
if v, ok := opt["role"]; ok && len(v) != 0 {
role = v
} else {
role = "user"
}
c := &coreContext{Context: ctx}
c.req.Query = query
c.req.Vars = b
c.req.role = "user"
res, err := c.execQuery()
st, err := c.buildStmtByRole(role)
if err != nil {
logger.Fatal().Err(err).Msg("graphql query failed")
}
buf := &bytes.Buffer{}
t := fasttemplate.New(st.sql, openVar, closeVar)
_, err = t.ExecuteFunc(buf, argMap(c))
if err == errNoUserID {
logger.Fatal().Msg("query requires a user_id")
}
if err != nil {
logger.Fatal().Err(err).Send()
}
finalSQL := buf.String()
tx, err := db.Begin(c)
if err != nil {
logger.Fatal().Err(err).Send()
}
defer tx.Rollback(c)
// if err := c.setLocalUserID(tx); err != nil {
// return nil, 0, err
// }
var root []byte
if err = tx.QueryRow(c, finalSQL).Scan(&root); err != nil {
logger.Fatal().Err(err).Msg("sql query failed")
}
if err := tx.Commit(c); err != nil {
logger.Fatal().Err(err).Send()
}
res, err := c.execRemoteJoin(st.qc, st.skipped, root)
if err != nil {
logger.Fatal().Err(err).Msg("remote join failed")
}
val := make(map[string]interface{})
err = json.Unmarshal(res, &val)
@ -156,10 +211,9 @@ func setFakeFuncs(f *goja.Object) {
f.Set("transmission_gear_type", gofakeit.TransmissionGearType)
// Text
f.Set("word", gofakeit.Word)
f.Set("sentence", gofakeit.Sentence)
f.Set("paragrph", gofakeit.Paragraph)
f.Set("paragraph", gofakeit.Paragraph)
f.Set("question", gofakeit.Question)
f.Set("quote", gofakeit.Quote)

View File

@ -71,6 +71,12 @@ func (c *coreContext) execQuery() ([]byte, error) {
}
}
return c.execRemoteJoin(qc, skipped, data)
}
func (c *coreContext) execRemoteJoin(qc *qcode.QCode, skipped uint32, data []byte) ([]byte, error) {
var err error
if len(data) == 0 || skipped == 0 {
return data, nil
}
@ -114,11 +120,19 @@ func (c *coreContext) execQuery() ([]byte, error) {
}
func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
tx, err := db.Begin(c)
if err != nil {
return nil, nil, err
var tx pgx.Tx
var err error
mutation := isMutation(c.req.Query)
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
useTx := useRoleQuery || conf.DB.SetUserID
if useTx {
if tx, err = db.Begin(c); err != nil {
return nil, nil, err
}
defer tx.Rollback(c)
}
defer tx.Rollback(c)
if conf.DB.SetUserID {
if err := c.setLocalUserID(tx); err != nil {
@ -127,8 +141,6 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
}
var role string
mutation := isMutation(c.req.Query)
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
if useRoleQuery {
if role, err = c.executeRoleQuery(tx); err != nil {
@ -149,12 +161,19 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
}
var root []byte
var row pgx.Row
vars := argList(c, ps.args)
if mutation {
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
if useTx {
row = tx.QueryRow(c, ps.stmt.SQL, vars...)
} else {
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&role, &root)
row = db.QueryRow(c, ps.stmt.SQL, vars...)
}
if mutation {
err = row.Scan(&root)
} else {
err = row.Scan(&role, &root)
}
logger.Debug().Str("default_role", c.req.role).Str("role", role).Msg(c.req.Query)
@ -165,22 +184,35 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
c.req.role = role
if err := tx.Commit(c); err != nil {
return nil, nil, err
if useTx {
if err := tx.Commit(c); err != nil {
return nil, nil, err
}
}
return root, ps, nil
}
func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
tx, err := db.Begin(c)
if err != nil {
return nil, 0, err
}
defer tx.Rollback(c)
var tx pgx.Tx
var err error
mutation := isMutation(c.req.Query)
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
useTx := useRoleQuery || conf.DB.SetUserID
if useTx {
if tx, err = db.Begin(c); err != nil {
return nil, 0, err
}
defer tx.Rollback(c)
}
if conf.DB.SetUserID {
if err := c.setLocalUserID(tx); err != nil {
return nil, 0, err
}
}
if useRoleQuery {
if c.req.role, err = c.executeRoleQuery(tx); err != nil {
@ -225,42 +257,36 @@ func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
stime = time.Now()
}
if conf.DB.SetUserID {
if err := c.setLocalUserID(tx); err != nil {
return nil, 0, err
}
var root []byte
var role, defaultRole string
var row pgx.Row
if useTx {
row = tx.QueryRow(c, finalSQL)
} else {
row = db.QueryRow(c, finalSQL)
}
var root []byte
var role string
log := logger.Debug()
if mutation {
err = tx.QueryRow(c, finalSQL).Scan(&root)
log = log.Str("role", role)
err = row.Scan(&root)
} else {
err = tx.QueryRow(c, finalSQL).Scan(&role, &root)
log = log.Str("default_role", c.req.role).Str("role", role)
err = row.Scan(&role, &root)
defaultRole = c.req.role
c.req.role = role
}
log.Msg(c.req.Query)
logger.Debug().Str("default_role", defaultRole).Str("role", role).Msg(c.req.Query)
if err != nil {
return nil, 0, err
}
if err := tx.Commit(c); err != nil {
return nil, 0, err
}
if mutation {
st = findStmt(c.req.role, stmts)
} else {
st = &stmts[0]
if useTx {
if err := tx.Commit(c); err != nil {
return nil, 0, err
}
}
if conf.EnableTracing && len(st.qc.Selects) != 0 {

View File

@ -150,3 +150,39 @@ func (c *coreContext) buildStmt() ([]stmt, error) {
return stmts, nil
}
func (c *coreContext) buildStmtByRole(role string) (stmt, error) {
var st stmt
var err error
if len(role) == 0 {
return st, errors.New(`no role defined`)
}
var vars map[string]json.RawMessage
if len(c.req.Vars) != 0 {
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
return st, err
}
}
gql := []byte(c.req.Query)
st.qc, err = qcompile.Compile(gql, role)
if err != nil {
return st, err
}
w := &bytes.Buffer{}
st.skipped, err = pcompile.Compile(st.qc, w, psql.Variables(vars))
if err != nil {
return st, err
}
st.sql = w.String()
return st, nil
}

View File

@ -9,6 +9,7 @@ import (
"github.com/dosco/super-graph/qcode"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v4"
"github.com/valyala/fasttemplate"
)
@ -24,22 +25,35 @@ var (
)
func initPreparedList() {
ctx := context.Background()
tx, err := db.Begin(ctx)
if err != nil {
logger.Fatal().Err(err).Send()
}
defer tx.Rollback(ctx)
_preparedList = make(map[string]*preparedItem)
if err := prepareRoleStmt(); err != nil {
if err := prepareRoleStmt(ctx, tx); err != nil {
logger.Fatal().Err(err).Msg("failed to prepare get role statement")
}
for _, v := range _allowList.list {
err := prepareStmt(v.gql, v.vars)
err := prepareStmt(ctx, tx, v.gql, v.vars)
if err != nil {
logger.Warn().Str("gql", v.gql).Err(err).Send()
}
}
if err := tx.Commit(ctx); err != nil {
logger.Fatal().Err(err).Send()
}
logger.Info().Msgf("Registered %d queries from allow.list as prepared statements", len(_allowList.list))
}
func prepareStmt(gql string, varBytes json.RawMessage) error {
func prepareStmt(ctx context.Context, tx pgx.Tx, gql string, varBytes json.RawMessage) error {
if len(gql) == 0 {
return nil
}
@ -64,15 +78,7 @@ func prepareStmt(gql string, varBytes json.RawMessage) error {
finalSQL, am := processTemplate(s.sql)
ctx := context.Background()
tx, err := db.Begin(ctx)
if err != nil {
return err
}
defer tx.Rollback(ctx)
pstmt, err := tx.Prepare(ctx, "", finalSQL)
pstmt, err := tx.Prepare(c.Context, "", finalSQL)
if err != nil {
return err
}
@ -92,15 +98,12 @@ func prepareStmt(gql string, varBytes json.RawMessage) error {
qc: s.qc,
}
if err := tx.Commit(ctx); err != nil {
return err
}
}
return nil
}
func prepareRoleStmt() error {
func prepareRoleStmt(ctx context.Context, tx pgx.Tx) error {
if len(conf.RolesQuery) == 0 {
return nil
}
@ -125,15 +128,7 @@ func prepareRoleStmt() error {
roleSQL, _ := processTemplate(w.String())
ctx := context.Background()
tx, err := db.Begin(ctx)
if err != nil {
return err
}
defer tx.Rollback(ctx)
_, err = tx.Prepare(ctx, "_sg_get_role", roleSQL)
_, err := tx.Prepare(ctx, "_sg_get_role", roleSQL)
if err != nil {
return err
}
@ -142,19 +137,31 @@ func prepareRoleStmt() error {
}
func processTemplate(tmpl string) (string, [][]byte) {
t := fasttemplate.New(tmpl, `{{`, `}}`)
am := make([][]byte, 0, 5)
i := 0
st := struct {
vmap map[string]int
am [][]byte
i int
}{
vmap: make(map[string]int),
am: make([][]byte, 0, 5),
i: 0,
}
vmap := make(map[string]int)
return t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) {
if n, ok := vmap[tag]; ok {
execFunc := func(w io.Writer, tag string) (int, error) {
if n, ok := st.vmap[tag]; ok {
return w.Write([]byte(fmt.Sprintf("$%d", n)))
}
am = append(am, []byte(tag))
i++
vmap[tag] = i
return w.Write([]byte(fmt.Sprintf("$%d", i)))
}), am
st.am = append(st.am, []byte(tag))
st.i++
st.vmap[tag] = st.i
return w.Write([]byte(fmt.Sprintf("$%d", st.i)))
}
t1 := fasttemplate.New(tmpl, `'{{`, `}}'`)
ts1 := t1.ExecuteFuncString(execFunc)
t2 := fasttemplate.New(ts1, `{{`, `}}`)
ts2 := t2.ExecuteFuncString(execFunc)
return ts2, st.am
}