Remove other allocations in psql

This commit is contained in:
Vikram Rangnekar 2019-06-01 19:48:42 -04:00
parent 77e56643c5
commit f9fc5dd7de
5 changed files with 152 additions and 48 deletions

View File

@ -44,20 +44,24 @@ func (c *Compiler) IDColumn(table string) string {
return t.PrimaryCol return t.PrimaryCol
} }
func (c *Compiler) Compile(qc *qcode.QCode) (uint32, []byte, error) { func (c *Compiler) CompileEx(qc *qcode.QCode) (uint32, []byte, error) {
w := &bytes.Buffer{}
skipped, err := c.Compile(qc, w)
return skipped, w.Bytes(), err
}
func (c *Compiler) Compile(qc *qcode.QCode, w *bytes.Buffer) (uint32, error) {
if len(qc.Query.Selects) == 0 { if len(qc.Query.Selects) == 0 {
return 0, nil, errors.New("empty query") return 0, errors.New("empty query")
} }
root := &qc.Query.Selects[0] root := &qc.Query.Selects[0]
st := util.NewStack() st := util.NewStack()
ti, err := c.getTable(root) ti, err := c.getTable(root)
if err != nil { if err != nil {
return 0, nil, err return 0, err
} }
w := &bytes.Buffer{}
st.Push(&selectBlockClose{nil, root}) st.Push(&selectBlockClose{nil, root})
st.Push(&selectBlock{nil, root, qc, ti, c}) st.Push(&selectBlock{nil, root, qc, ti, c})
@ -82,7 +86,7 @@ func (c *Compiler) Compile(qc *qcode.QCode) (uint32, []byte, error) {
case *selectBlock: case *selectBlock:
skipped, err := v.render(w) skipped, err := v.render(w)
if err != nil { if err != nil {
return 0, nil, err return 0, err
} }
ignored |= skipped ignored |= skipped
@ -94,7 +98,7 @@ func (c *Compiler) Compile(qc *qcode.QCode) (uint32, []byte, error) {
ti, err := c.getTable(child) ti, err := c.getTable(child)
if err != nil { if err != nil {
return 0, nil, err return 0, err
} }
st.Push(&joinClose{child}) st.Push(&joinClose{child})
@ -113,7 +117,7 @@ func (c *Compiler) Compile(qc *qcode.QCode) (uint32, []byte, error) {
} }
if err != nil { if err != nil {
return 0, nil, err return 0, err
} }
} }
@ -121,7 +125,7 @@ func (c *Compiler) Compile(qc *qcode.QCode) (uint32, []byte, error) {
alias(w, `done_1337`) alias(w, `done_1337`)
w.WriteString(`;`) w.WriteString(`;`)
return ignored, w.Bytes(), nil return ignored, nil
} }
func (c *Compiler) getTable(sel *qcode.Select) (*DBTableInfo, error) { func (c *Compiler) getTable(sel *qcode.Select) (*DBTableInfo, error) {

View File

@ -1,6 +1,7 @@
package psql package psql
import ( import (
"bytes"
"log" "log"
"os" "os"
"testing" "testing"
@ -128,7 +129,7 @@ func compileGQLToPSQL(gql string) ([]byte, error) {
return nil, err return nil, err
} }
_, sqlStmt, err := pcompile.Compile(qc) _, sqlStmt, err := pcompile.CompileEx(qc)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@ -503,13 +504,21 @@ func BenchmarkCompileGQLToSQL(b *testing.B) {
} }
}` }`
w := &bytes.Buffer()
b.ResetTimer() b.ResetTimer()
b.ReportAllocs() b.ReportAllocs()
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
_, err := compileGQLToPSQL(gql) qc, err := qcompile.CompileQuery(gql)
if err != nil { if err != nil {
b.Fatal(err) b.Fatal(err)
} }
_, sqlStmt, err := pcompile.Compile(qc, w)
if err != nil {
b.Fatal(err)
}
w.Reset()
} }
} }

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"log"
"net/http" "net/http"
"net/url" "net/url"
@ -20,7 +19,7 @@ func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc {
} }
if len(conf.Auth.Rails.URL) == 0 { if len(conf.Auth.Rails.URL) == 0 {
log.Fatal(errors.New("no auth.rails.url defined")) logger.Fatal(errors.New("no auth.rails.url defined"))
} }
rp := &redis.Pool{ rp := &redis.Pool{
@ -79,12 +78,12 @@ func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc {
} }
if len(conf.Auth.Rails.URL) == 0 { if len(conf.Auth.Rails.URL) == 0 {
log.Fatal(errors.New("no auth.rails.url defined")) logger.Fatal(errors.New("no auth.rails.url defined"))
} }
rURL, err := url.Parse(conf.Auth.Rails.URL) rURL, err := url.Parse(conf.Auth.Rails.URL)
if err != nil { if err != nil {
log.Fatal(err) logger.Fatal(err)
} }
mc := memcache.New(rURL.Host) mc := memcache.New(rURL.Host)
@ -127,7 +126,7 @@ func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc {
ra, err := railsAuth(conf) ra, err := railsAuth(conf)
if err != nil { if err != nil {
log.Fatal(err) logger.Fatal(err)
} }
return func(w http.ResponseWriter, r *http.Request) { return func(w http.ResponseWriter, r *http.Request) {

View File

@ -4,10 +4,10 @@ import (
"bytes" "bytes"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "errors"
"io" "io"
"net/http" "net/http"
"strings" "os"
"time" "time"
"github.com/cespare/xxhash/v2" "github.com/cespare/xxhash/v2"
@ -62,17 +62,117 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
// these values contain the id to be used with fetching remote data // these values contain the id to be used with fetching remote data
from := jsn.Get(data, fids) from := jsn.Get(data, fids)
var to []jsn.Field
switch {
case len(from) == 1:
to, err = c.resolveRemote(req, h, from[0], sel, sfmap)
case len(from) > 1:
to, err = c.resolveRemotes(req, h, from, sel, sfmap)
default:
return errors.New("something wrong no remote ids found in db response")
}
if err != nil {
return err
}
var ob bytes.Buffer
err = jsn.Replace(&ob, data, from, to)
if err != nil {
return err
}
return c.render(w, ob.Bytes())
}
func (c *coreContext) resolveRemote(
req *http.Request,
h *xxhash.Digest,
field jsn.Field,
sel []qcode.Select,
sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) {
// replacement data for the marked insertion points
// key and value will be replaced by whats below
toA := [1]jsn.Field{}
to := toA[:0]
// use the json key to find the related Select object
k1 := xxhash.Sum64(field.Key)
s, ok := sfmap[k1]
if !ok {
return nil, nil
}
p := sel[s.ParentID]
// then use the Table nme in the Select and it's parent
// to find the resolver to use for this relationship
k2 := mkkey(h, s.Table, p.Table)
r, ok := rmap[k2]
if !ok {
return nil, nil
}
id := jsn.Value(field.Value)
if len(id) == 0 {
return nil, nil
}
st := time.Now()
b, err := r.Fn(req, id)
if err != nil {
return nil, err
}
if conf.EnableTracing {
c.addTrace(s, st)
}
if len(r.Path) != 0 {
b = jsn.Strip(b, r.Path)
}
var ob bytes.Buffer
if len(s.Cols) != 0 {
err = jsn.Filter(&ob, b, colsToList(s.Cols))
if err != nil {
return nil, err
}
} else {
ob.WriteString("null")
}
to[0] = jsn.Field{[]byte(s.FieldName), ob.Bytes()}
return to, nil
}
func (c *coreContext) resolveRemotes(
req *http.Request,
h *xxhash.Digest,
from []jsn.Field,
sel []qcode.Select,
sfmap map[uint64]*qcode.Select) ([]jsn.Field, error) {
// replacement data for the marked insertion points // replacement data for the marked insertion points
// key and value will be replaced by whats below // key and value will be replaced by whats below
to := make([]jsn.Field, 0, len(from)) to := make([]jsn.Field, 0, len(from))
for _, id := range from { for _, id := range from {
// use the json key to find the related Select object // use the json key to find the related Select object
k1 := xxhash.Sum64(id.Key) k1 := xxhash.Sum64(id.Key)
s, ok := sfmap[k1] s, ok := sfmap[k1]
if !ok { if !ok {
continue return nil, nil
} }
p := sel[s.ParentID] p := sel[s.ParentID]
@ -82,19 +182,19 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
r, ok := rmap[k2] r, ok := rmap[k2]
if !ok { if !ok {
continue return nil, nil
} }
id := jsn.Value(id.Value) id := jsn.Value(id.Value)
if len(id) == 0 { if len(id) == 0 {
continue return nil, nil
} }
st := time.Now() st := time.Now()
b, err := r.Fn(req, id) b, err := r.Fn(req, id)
if err != nil { if err != nil {
return err return nil, err
} }
if conf.EnableTracing { if conf.EnableTracing {
@ -110,29 +210,19 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
if len(s.Cols) != 0 { if len(s.Cols) != 0 {
err = jsn.Filter(&ob, b, colsToList(s.Cols)) err = jsn.Filter(&ob, b, colsToList(s.Cols))
if err != nil { if err != nil {
return err return nil, err
} }
} else { } else {
ob.WriteString("null") ob.WriteString("null")
} }
f := jsn.Field{[]byte(s.FieldName), ob.Bytes()} to = append(to, jsn.Field{[]byte(s.FieldName), ob.Bytes()})
to = append(to, f)
} }
return to, nil
var ob bytes.Buffer
err = jsn.Replace(&ob, data, from, to)
if err != nil {
return err
}
return c.render(w, ob.Bytes())
} }
func (c *coreContext) resolveSQL(qc *qcode.QCode, vars variables) ( func (c *coreContext) resolveSQL(qc *qcode.QCode, vars variables) ([]byte, uint32, error) {
[]byte, uint32, error) {
// var entry []byte // var entry []byte
// var key string // var key string
@ -153,15 +243,17 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode, vars variables) (
// } // }
// } // }
skipped, stmts, err := pcompile.Compile(qc) stmt := &bytes.Buffer{}
skipped, err := pcompile.Compile(qc, stmt)
if err != nil { if err != nil {
return nil, 0, err return nil, 0, err
} }
t := fasttemplate.New(stmts[0], openVar, closeVar) t := fasttemplate.New(stmt.String(), openVar, closeVar)
var sqlStmt strings.Builder stmt.Reset()
_, err = t.Execute(&sqlStmt, vars) _, err = t.Execute(stmt, vars)
if err == errNoUserID && if err == errNoUserID &&
authFailBlock == authFailBlockPerQuery && authFailBlock == authFailBlockPerQuery &&
@ -173,10 +265,10 @@ func (c *coreContext) resolveSQL(qc *qcode.QCode, vars variables) (
return nil, 0, err return nil, 0, err
} }
finalSQL := sqlStmt.Bytes() finalSQL := stmt.String()
if conf.DebugLevel > 0 { if conf.DebugLevel > 0 {
fmt.Println(finalSQL) os.Stdout.WriteString(finalSQL)
} }
// if cacheEnabled { // if cacheEnabled {

View File

@ -5,7 +5,6 @@ import (
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"log"
"net/http" "net/http"
"os" "os"
"os/signal" "os/signal"
@ -156,6 +155,7 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
DefaultFilter: c.DB.Defaults.Filter, DefaultFilter: c.DB.Defaults.Filter,
FilterMap: c.getFilterMap(), FilterMap: c.getFilterMap(),
Blacklist: c.DB.Defaults.Blacklist, Blacklist: c.DB.Defaults.Blacklist,
KeepArgs: false,
}) })
if err != nil { if err != nil {
@ -178,17 +178,17 @@ func Init() {
conf, err = initConf() conf, err = initConf()
if err != nil { if err != nil {
log.Fatal(err) logger.Fatal(err)
} }
db, err = initDB(conf) db, err = initDB(conf)
if err != nil { if err != nil {
log.Fatal(err) logger.Fatal(err)
} }
qcompile, pcompile, err = initCompilers(conf) qcompile, pcompile, err = initCompilers(conf)
if err != nil { if err != nil {
log.Fatal(err) logger.Fatal(err)
} }
initResolvers() initResolvers()
@ -212,14 +212,14 @@ func startHTTP() {
<-sigint <-sigint
if err := srv.Shutdown(context.Background()); err != nil { if err := srv.Shutdown(context.Background()); err != nil {
log.Printf("http: %v", err) logger.Printf("http: %v", err)
} }
close(idleConnsClosed) close(idleConnsClosed)
}() }()
srv.RegisterOnShutdown(func() { srv.RegisterOnShutdown(func() {
if err := db.Close(); err != nil { if err := db.Close(); err != nil {
log.Println(err) logger.Println(err)
} }
}) })