From f9fc5dd7de90ea4c65259a6fc42d3be596fccefd Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Sat, 1 Jun 2019 19:48:42 -0400 Subject: [PATCH] Remove other allocations in psql --- psql/psql.go | 22 ++++--- psql/psql_test.go | 13 +++- serv/auth_rails.go | 9 ++- serv/core.go | 144 +++++++++++++++++++++++++++++++++++++-------- serv/serv.go | 12 ++-- 5 files changed, 152 insertions(+), 48 deletions(-) diff --git a/psql/psql.go b/psql/psql.go index 581705e..eb74e8d 100644 --- a/psql/psql.go +++ b/psql/psql.go @@ -44,20 +44,24 @@ func (c *Compiler) IDColumn(table string) string { return t.PrimaryCol } -func (c *Compiler) Compile(qc *qcode.QCode) (uint32, []byte, error) { +func (c *Compiler) CompileEx(qc *qcode.QCode) (uint32, []byte, error) { + w := &bytes.Buffer{} + skipped, err := c.Compile(qc, w) + return skipped, w.Bytes(), err +} + +func (c *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) { if len(qc.Query.Selects) == 0 { - return 0, nil, errors.New("empty query") + return 0, errors.New("empty query") } root := &qc.Query.Selects[0] st := util.NewStack() ti, err := c.getTable(root) if err != nil { - return 0, nil, err + return 0, err } - w := &bytes.Buffer{} - st.Push(&selectBlockClose{nil, root}) st.Push(&selectBlock{nil, root, qc, ti, c}) @@ -82,7 +86,7 @@ func (c *Compiler) Compile(qc *qcode.QCode) (uint32, []byte, error) { case *selectBlock: skipped, err := v.render(w) if err != nil { - return 0, nil, err + return 0, err } ignored |= skipped @@ -94,7 +98,7 @@ func (c *Compiler) Compile(qc *qcode.QCode) (uint32, []byte, error) { ti, err := c.getTable(child) if err != nil { - return 0, nil, err + return 0, err } st.Push(&joinClose{child}) @@ -113,7 +117,7 @@ func (c *Compiler) Compile(qc *qcode.QCode) (uint32, []byte, error) { } if err != nil { - return 0, nil, err + return 0, err } } @@ -121,7 +125,7 @@ func (c *Compiler) Compile(qc *qcode.QCode) (uint32, []byte, error) { alias(w, `done_1337`) w.WriteString(`;`) - return ignored, w.Bytes(), nil + return ignored, nil } func (c *Compiler) getTable(sel *qcode.Select) (*DBTableInfo, error) { diff --git a/psql/psql_test.go b/psql/psql_test.go index 052581e..50d3a32 100644 --- a/psql/psql_test.go +++ b/psql/psql_test.go @@ -1,6 +1,7 @@ package psql import ( + "bytes" "log" "os" "testing" @@ -128,7 +129,7 @@ func compileGQLToPSQL(gql string) ([]byte, error) { return nil, err } - _, sqlStmt, err := pcompile.Compile(qc) + _, sqlStmt, err := pcompile.CompileEx(qc) if err != nil { return nil, err } @@ -503,13 +504,21 @@ func BenchmarkCompileGQLToSQL(b *testing.B) { } }` + w := &bytes.Buffer() + b.ResetTimer() b.ReportAllocs() for n := 0; n < b.N; n++ { - _, err := compileGQLToPSQL(gql) + qc, err := qcompile.CompileQuery(gql) if err != nil { b.Fatal(err) } + + _, sqlStmt, err := pcompile.Compile(qc, w) + if err != nil { + b.Fatal(err) + } + w.Reset() } } diff --git a/serv/auth_rails.go b/serv/auth_rails.go index 99b5b19..8baf77f 100644 --- a/serv/auth_rails.go +++ b/serv/auth_rails.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "log" "net/http" "net/url" @@ -20,7 +19,7 @@ func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc { } if len(conf.Auth.Rails.URL) == 0 { - log.Fatal(errors.New("no auth.rails.url defined")) + logger.Fatal(errors.New("no auth.rails.url defined")) } rp := &redis.Pool{ @@ -79,12 +78,12 @@ func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc { } if len(conf.Auth.Rails.URL) == 0 { - log.Fatal(errors.New("no auth.rails.url defined")) + logger.Fatal(errors.New("no auth.rails.url defined")) } rURL, err := url.Parse(conf.Auth.Rails.URL) if err != nil { - log.Fatal(err) + logger.Fatal(err) } mc := memcache.New(rURL.Host) @@ -127,7 +126,7 @@ func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc { ra, err := railsAuth(conf) if err != nil { - log.Fatal(err) + logger.Fatal(err) } return func(w http.ResponseWriter, r *http.Request) { diff --git a/serv/core.go b/serv/core.go index 51eaa90..802b5b3 100644 --- a/serv/core.go +++ b/serv/core.go @@ -4,10 +4,10 @@ import ( "bytes" "context" "encoding/json" - "fmt" + "errors" "io" "net/http" - "strings" + "os" "time" "github.com/cespare/xxhash/v2" @@ -62,17 +62,117 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error { // these values contain the id to be used with fetching remote data from := jsn.Get(data, fids) + var to []jsn.Field + switch { + case len(from) == 1: + to, err = c.resolveRemote(req, h, from[0], sel, sfmap) + + case len(from) > 1: + to, err = c.resolveRemotes(req, h, from, sel, sfmap) + + default: + return errors.New("something wrong no remote ids found in db response") + } + + if err != nil { + return err + } + + var ob bytes.Buffer + + err = jsn.Replace(&ob, data, from, to) + if err != nil { + return err + } + + return c.render(w, ob.Bytes()) +} + +func (c *coreContext) resolveRemote( + req *http.Request, + h *xxhash.Digest, + field jsn.Field, + sel []qcode.Select, + sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) { + + // replacement data for the marked insertion points + // key and value will be replaced by whats below + toA := [1]jsn.Field{} + to := toA[:0] + + // use the json key to find the related Select object + k1 := xxhash.Sum64(field.Key) + + s, ok := sfmap[k1] + if !ok { + return nil, nil + } + p := sel[s.ParentID] + + // then use the Table nme in the Select and it's parent + // to find the resolver to use for this relationship + k2 := mkkey(h, s.Table, p.Table) + + r, ok := rmap[k2] + if !ok { + return nil, nil + } + + id := jsn.Value(field.Value) + if len(id) == 0 { + return nil, nil + } + + st := time.Now() + + b, err := r.Fn(req, id) + if err != nil { + return nil, err + } + + if conf.EnableTracing { + c.addTrace(s, st) + } + + if len(r.Path) != 0 { + b = jsn.Strip(b, r.Path) + } + + var ob bytes.Buffer + + if len(s.Cols) != 0 { + err = jsn.Filter(&ob, b, colsToList(s.Cols)) + if err != nil { + return nil, err + } + + } else { + ob.WriteString("null") + } + + to[0] = jsn.Field{[]byte(s.FieldName), ob.Bytes()} + return to, nil +} + +func (c *coreContext) resolveRemotes( + req *http.Request, + h *xxhash.Digest, + from []jsn.Field, + sel []qcode.Select, + sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) { + // replacement data for the marked insertion points // key and value will be replaced by whats below to := make([]jsn.Field, 0, len(from)) for _, id := range from { + // use the json key to find the related Select object k1 := xxhash.Sum64(id.Key) s, ok := sfmap[k1] if !ok { - continue + return nil, nil } p := sel[s.ParentID] @@ -82,19 +182,19 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error { r, ok := rmap[k2] if !ok { - continue + return nil, nil } id := jsn.Value(id.Value) if len(id) == 0 { - continue + return nil, nil } st := time.Now() b, err := r.Fn(req, id) if err != nil { - return err + return nil, err } if conf.EnableTracing { @@ -110,29 +210,19 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error { if len(s.Cols) != 0 { err = jsn.Filter(&ob, b, colsToList(s.Cols)) if err != nil { - return err + return nil, err } } else { ob.WriteString("null") } - f := jsn.Field{[]byte(s.FieldName), ob.Bytes()} - to = append(to, f) + to = append(to, jsn.Field{[]byte(s.FieldName), ob.Bytes()}) } - - var ob bytes.Buffer - - err = jsn.Replace(&ob, data, from, to) - if err != nil { - return err - } - - return c.render(w, ob.Bytes()) + return to, nil } -func (c *coreContext) resolveSQL(qc *qcode.QCode, vars variables) ( - []byte, uint32, error) { +func (c *coreContext) resolveSQL(qc *qcode.QCode, vars variables) ([]byte, uint32, error) { // var entry []byte // var key string @@ -153,15 +243,17 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode, vars variables) ( // } // } - skipped, stmts, err := pcompile.Compile(qc) + stmt := &bytes.Buffer{} + + skipped, err := pcompile.Compile(qc, stmt) if err != nil { return nil, 0, err } - t := fasttemplate.New(stmts[0], openVar, closeVar) + t := fasttemplate.New(stmt.String(), openVar, closeVar) - var sqlStmt strings.Builder - _, err = t.Execute(&sqlStmt, vars) + stmt.Reset() + _, err = t.Execute(stmt, vars) if err == errNoUserID && authFailBlock == authFailBlockPerQuery && @@ -173,10 +265,10 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode, vars variables) ( return nil, 0, err } - finalSQL := sqlStmt.Bytes() + finalSQL := stmt.String() if conf.DebugLevel > 0 { - fmt.Println(finalSQL) + os.Stdout.WriteString(finalSQL) } // if cacheEnabled { diff --git a/serv/serv.go b/serv/serv.go index b8adb30..0c76ced 100644 --- a/serv/serv.go +++ b/serv/serv.go @@ -5,7 +5,6 @@ import ( "errors" "flag" "fmt" - "log" "net/http" "os" "os/signal" @@ -156,6 +155,7 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) { DefaultFilter: c.DB.Defaults.Filter, FilterMap: c.getFilterMap(), Blacklist: c.DB.Defaults.Blacklist, + KeepArgs: false, }) if err != nil { @@ -178,17 +178,17 @@ func Init() { conf, err = initConf() if err != nil { - log.Fatal(err) + logger.Fatal(err) } db, err = initDB(conf) if err != nil { - log.Fatal(err) + logger.Fatal(err) } qcompile, pcompile, err = initCompilers(conf) if err != nil { - log.Fatal(err) + logger.Fatal(err) } initResolvers() @@ -212,14 +212,14 @@ func startHTTP() { <-sigint if err := srv.Shutdown(context.Background()); err != nil { - log.Printf("http: %v", err) + logger.Printf("http: %v", err) } close(idleConnsClosed) }() srv.RegisterOnShutdown(func() { if err := db.Close(); err != nil { - log.Println(err) + logger.Println(err) } })