From c40046183529c6676551f72c5e46f8c641eca92f Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Sun, 19 Apr 2020 12:54:37 -0400 Subject: [PATCH] fix: prepared statements not working in prod mode --- core/core.go | 11 ++++----- core/prepare.go | 48 +++++++++++---------------------------- internal/serv/cmd_serv.go | 4 ---- internal/serv/config.go | 4 ---- internal/serv/init.go | 16 ++++++++++++- internal/serv/reload.go | 14 ------------ internal/serv/utils.go | 14 ++++++++++++ 7 files changed, 46 insertions(+), 65 deletions(-) diff --git a/core/core.go b/core/core.go index 9ab0e04..9d544b2 100644 --- a/core/core.go +++ b/core/core.go @@ -94,15 +94,12 @@ func (c *scontext) execQuery() ([]byte, error) { if c.sg.conf.UseAllowList { data, st, err = c.resolvePreparedSQL() - if err != nil { - return nil, err - } - } else { data, st, err = c.resolveSQL() - if err != nil { - return nil, err - } + } + + if err != nil { + return nil, err } if len(data) == 0 || st.skipped == 0 { diff --git a/core/prepare.go b/core/prepare.go index c0c503a..f9fc9a0 100644 --- a/core/prepare.go +++ b/core/prepare.go @@ -58,21 +58,14 @@ func (sg *SuperGraph) initPrepared() error { } err := sg.prepareStmt(v) - if err == nil { + if err != nil { + sg.log.Printf("WRN %s: %v", v.Name, err) + } else { success++ - continue } - - // if len(v.Vars) == 0 { - // logger.Warn().Err(err).Msg(v.Query) - // } else { - // logger.Warn().Err(err).Msgf("%s %s", v.Vars, v.Query) - // } } - // logger.Info(). - // Msgf("Registered %d of %d queries from allow.list as prepared statements", - // success, len(list)) + sg.log.Printf("INF allow list: prepared %d / %d queries", success, len(list)) return nil } @@ -84,13 +77,6 @@ func (sg *SuperGraph) prepareStmt(item allow.Item) error { qt := qcode.GetQType(query) ct := context.Background() - - tx, err := sg.db.BeginTx(ct, nil) - if err != nil { - return err - } - defer tx.Rollback() //nolint: errcheck - switch qt { case qcode.QTQuery: var stmts1 []stmt @@ -108,7 +94,7 @@ func (sg *SuperGraph) prepareStmt(item allow.Item) error { //logger.Debug().Msgf("Prepared statement 'query %s' (user)", item.Name) - err = sg.prepare(ct, tx, stmts1, stmtHash(item.Name, "user")) + err = sg.prepare(ct, stmts1, stmtHash(item.Name, "user")) if err != nil { return err } @@ -124,7 +110,7 @@ func (sg *SuperGraph) prepareStmt(item allow.Item) error { return err } - err = sg.prepare(ct, tx, stmts2, stmtHash(item.Name, "anon")) + err = sg.prepare(ct, stmts2, stmtHash(item.Name, "anon")) if err != nil { return err } @@ -135,36 +121,26 @@ func (sg *SuperGraph) prepareStmt(item allow.Item) error { // logger.Debug().Msgf("Prepared statement 'mutation %s' (%s)", item.Name, role.Name) stmts, err := sg.buildRoleStmt(qb, vars, role.Name) - if err != nil { - // if len(item.Vars) == 0 { - // logger.Warn().Err(err).Msg(item.Query) - // } else { - // logger.Warn().Err(err).Msgf("%s %s", item.Vars, item.Query) - // } - continue + return err } - err = sg.prepare(ct, tx, stmts, stmtHash(item.Name, role.Name)) + err = sg.prepare(ct, stmts, stmtHash(item.Name, role.Name)) if err != nil { return err } } } - if err := tx.Commit(); err != nil { - return err - } - return nil } -func (sg *SuperGraph) prepare(ct context.Context, tx *sql.Tx, st []stmt, key string) error { +func (sg *SuperGraph) prepare(ct context.Context, st []stmt, key string) error { finalSQL, am := processTemplate(st[0].sql) - sd, err := tx.Prepare(finalSQL) + sd, err := sg.db.Prepare(finalSQL) if err != nil { - return err + return fmt.Errorf("prepare failed: %v: %s", err, finalSQL) } sg.prepared[key] = &preparedItem{ @@ -256,6 +232,8 @@ func (sg *SuperGraph) initAllowList() error { sg.log.Printf("WRN allow list disabled no file specified") } + // When list is not eabled it is still created and + // and new queries are saved to it. if !sg.conf.UseAllowList { ac = allow.Config{CreateIfNotExists: true, Persist: true} } diff --git a/internal/serv/cmd_serv.go b/internal/serv/cmd_serv.go index 254c416..2ce446e 100644 --- a/internal/serv/cmd_serv.go +++ b/internal/serv/cmd_serv.go @@ -24,10 +24,6 @@ func cmdServ(cmd *cobra.Command, args []string) { fatalInProd(err, "failed to connect to database") } - // if conf != nil && db != nil { - // initResolvers() - // } - sg, err = core.NewSuperGraph(&conf.Core, db) if err != nil { fatalInProd(err, "failed to initialize Super Graph") diff --git a/internal/serv/config.go b/internal/serv/config.go index 7af5cc5..28da0da 100644 --- a/internal/serv/config.go +++ b/internal/serv/config.go @@ -49,10 +49,6 @@ func ReadInConfig(configFile string) (*Config, error) { return nil, fmt.Errorf("failed to decode config, %v", err) } - if len(c.Core.AllowListFile) == 0 { - c.Core.AllowListFile = path.Join(cpath, "allow.list") - } - return c, nil } diff --git a/internal/serv/init.go b/internal/serv/init.go index 6a0cfac..fd4f89f 100644 --- a/internal/serv/init.go +++ b/internal/serv/init.go @@ -8,6 +8,7 @@ import ( "fmt" "io/ioutil" "path" + "path/filepath" "strings" "time" @@ -21,7 +22,12 @@ const ( ) func initConf() (*Config, error) { - c, err := ReadInConfig(path.Join(confPath, GetConfigName())) + cp, err := filepath.Abs(confPath) + if err != nil { + return nil, err + } + + c, err := ReadInConfig(path.Join(cp, GetConfigName())) if err != nil { return nil, err } @@ -86,6 +92,14 @@ func initConf() (*Config, error) { c.AuthFailBlock = false } + if len(c.AllowListFile) == 0 { + c.AllowListFile = c.relPath("./allow.list") + } + + if c.Production { + c.UseAllowList = true + } + return c, nil } diff --git a/internal/serv/reload.go b/internal/serv/reload.go index 850db72..553413b 100644 --- a/internal/serv/reload.go +++ b/internal/serv/reload.go @@ -190,17 +190,3 @@ func self() (string, error) { } return bin, nil } - -// Get path relative to cwd -func relpath(p string) string { - cwd, err := os.Getwd() - if err != nil { - return p - } - - if strings.HasPrefix(p, cwd) { - return "./" + strings.TrimLeft(p[len(cwd):], "/") - } - - return p -} diff --git a/internal/serv/utils.go b/internal/serv/utils.go index 2a88260..6c64649 100644 --- a/internal/serv/utils.go +++ b/internal/serv/utils.go @@ -119,3 +119,17 @@ func isDev() bool { func sanitize(value string) string { return strings.ToLower(strings.TrimSpace(value)) } + +// Get path relative to cwd +func relpath(p string) string { + cwd, err := os.Getwd() + if err != nil { + return p + } + + if strings.HasPrefix(p, cwd) { + return "./" + strings.TrimLeft(p[len(cwd):], "/") + } + + return p +}