fix: prepared statements not working in prod mode
This commit is contained in:
parent
a6691de1b7
commit
c400461835
|
@ -94,16 +94,13 @@ func (c *scontext) execQuery() ([]byte, error) {
|
||||||
|
|
||||||
if c.sg.conf.UseAllowList {
|
if c.sg.conf.UseAllowList {
|
||||||
data, st, err = c.resolvePreparedSQL()
|
data, st, err = c.resolvePreparedSQL()
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
} else {
|
} else {
|
||||||
data, st, err = c.resolveSQL()
|
data, st, err = c.resolveSQL()
|
||||||
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
}
|
|
||||||
|
|
||||||
if len(data) == 0 || st.skipped == 0 {
|
if len(data) == 0 || st.skipped == 0 {
|
||||||
return data, nil
|
return data, nil
|
||||||
|
|
|
@ -58,21 +58,14 @@ func (sg *SuperGraph) initPrepared() error {
|
||||||
}
|
}
|
||||||
|
|
||||||
err := sg.prepareStmt(v)
|
err := sg.prepareStmt(v)
|
||||||
if err == nil {
|
if err != nil {
|
||||||
|
sg.log.Printf("WRN %s: %v", v.Name, err)
|
||||||
|
} else {
|
||||||
success++
|
success++
|
||||||
continue
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// if len(v.Vars) == 0 {
|
sg.log.Printf("INF allow list: prepared %d / %d queries", success, len(list))
|
||||||
// logger.Warn().Err(err).Msg(v.Query)
|
|
||||||
// } else {
|
|
||||||
// logger.Warn().Err(err).Msgf("%s %s", v.Vars, v.Query)
|
|
||||||
// }
|
|
||||||
}
|
|
||||||
|
|
||||||
// logger.Info().
|
|
||||||
// Msgf("Registered %d of %d queries from allow.list as prepared statements",
|
|
||||||
// success, len(list))
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -84,13 +77,6 @@ func (sg *SuperGraph) prepareStmt(item allow.Item) error {
|
||||||
|
|
||||||
qt := qcode.GetQType(query)
|
qt := qcode.GetQType(query)
|
||||||
ct := context.Background()
|
ct := context.Background()
|
||||||
|
|
||||||
tx, err := sg.db.BeginTx(ct, nil)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer tx.Rollback() //nolint: errcheck
|
|
||||||
|
|
||||||
switch qt {
|
switch qt {
|
||||||
case qcode.QTQuery:
|
case qcode.QTQuery:
|
||||||
var stmts1 []stmt
|
var stmts1 []stmt
|
||||||
|
@ -108,7 +94,7 @@ func (sg *SuperGraph) prepareStmt(item allow.Item) error {
|
||||||
|
|
||||||
//logger.Debug().Msgf("Prepared statement 'query %s' (user)", item.Name)
|
//logger.Debug().Msgf("Prepared statement 'query %s' (user)", item.Name)
|
||||||
|
|
||||||
err = sg.prepare(ct, tx, stmts1, stmtHash(item.Name, "user"))
|
err = sg.prepare(ct, stmts1, stmtHash(item.Name, "user"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -124,7 +110,7 @@ func (sg *SuperGraph) prepareStmt(item allow.Item) error {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
err = sg.prepare(ct, tx, stmts2, stmtHash(item.Name, "anon"))
|
err = sg.prepare(ct, stmts2, stmtHash(item.Name, "anon"))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
@ -135,36 +121,26 @@ func (sg *SuperGraph) prepareStmt(item allow.Item) error {
|
||||||
// logger.Debug().Msgf("Prepared statement 'mutation %s' (%s)", item.Name, role.Name)
|
// logger.Debug().Msgf("Prepared statement 'mutation %s' (%s)", item.Name, role.Name)
|
||||||
|
|
||||||
stmts, err := sg.buildRoleStmt(qb, vars, role.Name)
|
stmts, err := sg.buildRoleStmt(qb, vars, role.Name)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// if len(item.Vars) == 0 {
|
return err
|
||||||
// logger.Warn().Err(err).Msg(item.Query)
|
|
||||||
// } else {
|
|
||||||
// logger.Warn().Err(err).Msgf("%s %s", item.Vars, item.Query)
|
|
||||||
// }
|
|
||||||
continue
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = sg.prepare(ct, tx, stmts, stmtHash(item.Name, role.Name))
|
err = sg.prepare(ct, stmts, stmtHash(item.Name, role.Name))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := tx.Commit(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (sg *SuperGraph) prepare(ct context.Context, tx *sql.Tx, st []stmt, key string) error {
|
func (sg *SuperGraph) prepare(ct context.Context, st []stmt, key string) error {
|
||||||
finalSQL, am := processTemplate(st[0].sql)
|
finalSQL, am := processTemplate(st[0].sql)
|
||||||
|
|
||||||
sd, err := tx.Prepare(finalSQL)
|
sd, err := sg.db.Prepare(finalSQL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return fmt.Errorf("prepare failed: %v: %s", err, finalSQL)
|
||||||
}
|
}
|
||||||
|
|
||||||
sg.prepared[key] = &preparedItem{
|
sg.prepared[key] = &preparedItem{
|
||||||
|
@ -256,6 +232,8 @@ func (sg *SuperGraph) initAllowList() error {
|
||||||
sg.log.Printf("WRN allow list disabled no file specified")
|
sg.log.Printf("WRN allow list disabled no file specified")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// When list is not eabled it is still created and
|
||||||
|
// and new queries are saved to it.
|
||||||
if !sg.conf.UseAllowList {
|
if !sg.conf.UseAllowList {
|
||||||
ac = allow.Config{CreateIfNotExists: true, Persist: true}
|
ac = allow.Config{CreateIfNotExists: true, Persist: true}
|
||||||
}
|
}
|
||||||
|
|
|
@ -24,10 +24,6 @@ func cmdServ(cmd *cobra.Command, args []string) {
|
||||||
fatalInProd(err, "failed to connect to database")
|
fatalInProd(err, "failed to connect to database")
|
||||||
}
|
}
|
||||||
|
|
||||||
// if conf != nil && db != nil {
|
|
||||||
// initResolvers()
|
|
||||||
// }
|
|
||||||
|
|
||||||
sg, err = core.NewSuperGraph(&conf.Core, db)
|
sg, err = core.NewSuperGraph(&conf.Core, db)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
fatalInProd(err, "failed to initialize Super Graph")
|
fatalInProd(err, "failed to initialize Super Graph")
|
||||||
|
|
|
@ -49,10 +49,6 @@ func ReadInConfig(configFile string) (*Config, error) {
|
||||||
return nil, fmt.Errorf("failed to decode config, %v", err)
|
return nil, fmt.Errorf("failed to decode config, %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(c.Core.AllowListFile) == 0 {
|
|
||||||
c.Core.AllowListFile = path.Join(cpath, "allow.list")
|
|
||||||
}
|
|
||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -8,6 +8,7 @@ import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"path"
|
"path"
|
||||||
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -21,7 +22,12 @@ const (
|
||||||
)
|
)
|
||||||
|
|
||||||
func initConf() (*Config, error) {
|
func initConf() (*Config, error) {
|
||||||
c, err := ReadInConfig(path.Join(confPath, GetConfigName()))
|
cp, err := filepath.Abs(confPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
c, err := ReadInConfig(path.Join(cp, GetConfigName()))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@ -86,6 +92,14 @@ func initConf() (*Config, error) {
|
||||||
c.AuthFailBlock = false
|
c.AuthFailBlock = false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if len(c.AllowListFile) == 0 {
|
||||||
|
c.AllowListFile = c.relPath("./allow.list")
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.Production {
|
||||||
|
c.UseAllowList = true
|
||||||
|
}
|
||||||
|
|
||||||
return c, nil
|
return c, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -190,17 +190,3 @@ func self() (string, error) {
|
||||||
}
|
}
|
||||||
return bin, nil
|
return bin, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get path relative to cwd
|
|
||||||
func relpath(p string) string {
|
|
||||||
cwd, err := os.Getwd()
|
|
||||||
if err != nil {
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
||||||
if strings.HasPrefix(p, cwd) {
|
|
||||||
return "./" + strings.TrimLeft(p[len(cwd):], "/")
|
|
||||||
}
|
|
||||||
|
|
||||||
return p
|
|
||||||
}
|
|
||||||
|
|
|
@ -119,3 +119,17 @@ func isDev() bool {
|
||||||
func sanitize(value string) string {
|
func sanitize(value string) string {
|
||||||
return strings.ToLower(strings.TrimSpace(value))
|
return strings.ToLower(strings.TrimSpace(value))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Get path relative to cwd
|
||||||
|
func relpath(p string) string {
|
||||||
|
cwd, err := os.Getwd()
|
||||||
|
if err != nil {
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.HasPrefix(p, cwd) {
|
||||||
|
return "./" + strings.TrimLeft(p[len(cwd):], "/")
|
||||||
|
}
|
||||||
|
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue