super-graph/serv/prepare.go

156 lines
2.7 KiB
Go
Raw Normal View History

2019-07-29 07:13:33 +02:00
package serv
import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"strings"

	"github.com/dosco/super-graph/qcode"
	"github.com/jackc/pgconn"
	"github.com/valyala/fasttemplate"
)
type preparedItem struct {
2019-09-26 06:35:31 +02:00
stmt *pgconn.StatementDescription
2019-09-05 06:09:56 +02:00
args [][]byte
2019-07-29 07:13:33 +02:00
skipped uint32
qc *qcode.QCode
}
// _preparedList caches prepared statements. The key is
// gqlHash(gql, vars, roleName), where roleName is "" for statements that
// carry no role. initPreparedList (re)creates the map.
var _preparedList map[string]*preparedItem
func initPreparedList() {
_preparedList = make(map[string]*preparedItem)
if err := prepareRoleStmt(); err != nil {
logger.Fatal().Err(err).Msg("failed to prepare get role statement")
}
for _, v := range _allowList.list {
err := prepareStmt(v.gql, v.vars)
2019-07-29 07:13:33 +02:00
if err != nil {
2019-10-15 08:30:19 +02:00
logger.Warn().Str("gql", v.gql).Err(err).Send()
2019-07-29 07:13:33 +02:00
}
}
}
func prepareStmt(gql string, varBytes json.RawMessage) error {
if len(gql) == 0 {
2019-07-29 07:13:33 +02:00
return nil
}
c := &coreContext{Context: context.Background()}
c.req.Query = gql
c.req.Vars = varBytes
stmts, err := c.buildStmt()
2019-07-29 07:13:33 +02:00
if err != nil {
return err
}
for _, s := range stmts {
if len(s.sql) == 0 {
continue
}
finalSQL, am := processTemplate(s.sql)
2019-09-05 06:09:56 +02:00
ctx := context.Background()
2019-09-05 06:09:56 +02:00
tx, err := db.Begin(ctx)
if err != nil {
2019-09-05 06:09:56 +02:00
return err
}
defer tx.Rollback(ctx)
2019-09-05 06:09:56 +02:00
pstmt, err := tx.Prepare(ctx, "", finalSQL)
if err != nil {
return err
}
2019-07-29 07:13:33 +02:00
var key string
if s.role == nil {
key = gqlHash(gql, varBytes, "")
} else {
key = gqlHash(gql, varBytes, s.role.Name)
}
_preparedList[key] = &preparedItem{
stmt: pstmt,
args: am,
skipped: s.skipped,
qc: s.qc,
}
if err := tx.Commit(ctx); err != nil {
return err
}
2019-07-29 07:13:33 +02:00
}
return nil
}
2019-07-29 07:13:33 +02:00
func prepareRoleStmt() error {
if len(conf.RolesQuery) == 0 {
return nil
}
2019-07-29 07:13:33 +02:00
w := &bytes.Buffer{}
io.WriteString(w, `SELECT (CASE`)
for _, role := range conf.Roles {
if len(role.Match) == 0 {
continue
}
io.WriteString(w, ` WHEN `)
io.WriteString(w, role.Match)
io.WriteString(w, ` THEN '`)
io.WriteString(w, role.Name)
io.WriteString(w, `'`)
2019-07-29 07:13:33 +02:00
}
io.WriteString(w, ` ELSE {{role}} END) FROM (`)
io.WriteString(w, conf.RolesQuery)
io.WriteString(w, `) AS "_sg_auth_roles_query"`)
roleSQL, _ := processTemplate(w.String())
2019-09-26 06:35:31 +02:00
ctx := context.Background()
tx, err := db.Begin(ctx)
if err != nil {
return err
}
defer tx.Rollback(ctx)
_, err = tx.Prepare(ctx, "_sg_get_role", roleSQL)
2019-07-29 07:13:33 +02:00
if err != nil {
return err
}
return nil
}
2019-07-29 07:13:33 +02:00
func processTemplate(tmpl string) (string, [][]byte) {
t := fasttemplate.New(tmpl, `{{`, `}}`)
am := make([][]byte, 0, 5)
i := 0
2019-09-26 06:35:31 +02:00
vmap := make(map[string]int)
return t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) {
if n, ok := vmap[tag]; ok {
return w.Write([]byte(fmt.Sprintf("$%d", n)))
}
am = append(am, []byte(tag))
i++
vmap[tag] = i
return w.Write([]byte(fmt.Sprintf("$%d", i)))
}), am
2019-07-29 07:13:33 +02:00
}