super-graph/serv/prepare.go

229 lines
4.5 KiB
Go
Raw Permalink Normal View History

2019-07-29 07:13:33 +02:00
package serv
import (
"bytes"
2019-09-26 06:35:31 +02:00
"context"
2019-07-29 07:13:33 +02:00
"fmt"
"io"
"github.com/dosco/super-graph/qcode"
2019-09-26 06:35:31 +02:00
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v4"
2019-07-29 07:13:33 +02:00
"github.com/valyala/fasttemplate"
)
type preparedItem struct {
2019-12-10 06:03:44 +01:00
sd *pgconn.StatementDescription
args [][]byte
st stmt
roleArg bool
2019-07-29 07:13:33 +02:00
}
// _preparedList holds every prepared statement registered by prepare, keyed
// by the value gqlHash produces for a query, its variables and a role name.
// It is (re)allocated by initPreparedList.
var _preparedList map[string]*preparedItem
func initPreparedList() {
2019-11-25 08:22:33 +01:00
_preparedList = make(map[string]*preparedItem)
tx, err := db.Begin(context.Background())
if err != nil {
2019-11-25 08:22:33 +01:00
errlog.Fatal().Err(err).Send()
}
defer tx.Rollback(context.Background()) //nolint: errcheck
err = prepareRoleStmt(tx)
2019-11-25 08:22:33 +01:00
if err != nil {
errlog.Fatal().Err(err).Msg("failed to prepare get role statement")
}
2019-07-29 07:13:33 +02:00
if err := tx.Commit(context.Background()); err != nil {
2019-11-25 08:22:33 +01:00
errlog.Fatal().Err(err).Send()
}
2019-11-25 08:22:33 +01:00
success := 0
for _, v := range _allowList.list {
2019-11-25 08:22:33 +01:00
if len(v.gql) == 0 {
continue
}
err := prepareStmt(v.gql, v.vars)
2019-11-25 08:22:33 +01:00
if err == nil {
success++
continue
2019-07-29 07:13:33 +02:00
}
2019-11-25 08:22:33 +01:00
if len(v.vars) == 0 {
logger.Warn().Err(err).Msg(v.gql)
} else {
logger.Warn().Err(err).Msgf("%s %s", v.vars, v.gql)
}
}
2019-11-25 08:22:33 +01:00
logger.Info().
Msgf("Registered %d of %d queries from allow.list as prepared statements",
success, len(_allowList.list))
2019-07-29 07:13:33 +02:00
}
func prepareStmt(gql string, vars []byte) error {
2019-11-25 08:22:33 +01:00
qt := qcode.GetQType(gql)
q := []byte(gql)
2019-12-25 07:24:30 +01:00
if len(vars) == 0 {
logger.Debug().Msgf("Prepared statement:\n%s\n", gql)
} else {
logger.Debug().Msgf("Prepared statement:\n%s\n%s\n", vars, gql)
}
tx, err := db.Begin(context.Background())
2019-07-29 07:13:33 +02:00
if err != nil {
return err
}
defer tx.Rollback(context.Background()) //nolint: errcheck
2019-07-29 07:13:33 +02:00
2019-11-25 08:22:33 +01:00
switch qt {
case qcode.QTQuery:
var stmts1 []stmt
var err error
2019-12-10 06:03:44 +01:00
if conf.isABACEnabled() {
stmts1, err = buildMultiStmt(q, vars)
} else {
stmts1, err = buildRoleStmt(q, vars, "user")
}
if err != nil {
return err
}
2019-07-29 07:13:33 +02:00
2019-12-25 07:24:30 +01:00
logger.Debug().Msg("Prepared statement role: user")
2019-12-10 06:03:44 +01:00
err = prepare(tx, stmts1, gqlHash(gql, vars, "user"))
2019-11-25 08:22:33 +01:00
if err != nil {
return err
}
if conf.isAnonRoleDefined() {
2019-12-25 07:24:30 +01:00
logger.Debug().Msg("Prepared statement for role: anon")
stmts2, err := buildRoleStmt(q, vars, "anon")
if err != nil {
return err
}
2019-12-10 06:03:44 +01:00
err = prepare(tx, stmts2, gqlHash(gql, vars, "anon"))
if err != nil {
return err
}
}
2019-11-25 08:22:33 +01:00
case qcode.QTMutation:
for _, role := range conf.Roles {
2019-12-25 07:24:30 +01:00
logger.Debug().Msgf("Prepared statement for role: %s", role.Name)
2019-11-25 08:22:33 +01:00
stmts, err := buildRoleStmt(q, vars, role.Name)
if err != nil {
return err
}
2019-12-10 06:03:44 +01:00
err = prepare(tx, stmts, gqlHash(gql, vars, role.Name))
2019-11-25 08:22:33 +01:00
if err != nil {
return err
}
}
2019-11-25 08:22:33 +01:00
}
if err := tx.Commit(context.Background()); err != nil {
2019-11-25 08:22:33 +01:00
return err
2019-07-29 07:13:33 +02:00
}
return nil
}
2019-07-29 07:13:33 +02:00
2019-12-10 06:03:44 +01:00
func prepare(tx pgx.Tx, st []stmt, key string) error {
finalSQL, am := processTemplate(st[0].sql)
2019-11-25 08:22:33 +01:00
sd, err := tx.Prepare(context.Background(), "", finalSQL)
2019-11-25 08:22:33 +01:00
if err != nil {
return err
}
_preparedList[key] = &preparedItem{
2019-12-10 06:03:44 +01:00
sd: sd,
args: am,
st: st[0],
roleArg: len(st) > 1,
2019-11-25 08:22:33 +01:00
}
return nil
}
// nolint: errcheck
func prepareRoleStmt(tx pgx.Tx) error {
2019-12-10 06:03:44 +01:00
if !conf.isABACEnabled() {
return nil
}
2019-07-29 07:13:33 +02:00
w := &bytes.Buffer{}
2019-12-25 07:24:30 +01:00
io.WriteString(w, `SELECT (CASE WHEN EXISTS (`)
io.WriteString(w, conf.RolesQuery)
io.WriteString(w, `) THEN `)
io.WriteString(w, `(SELECT (CASE`)
for _, role := range conf.Roles {
if len(role.Match) == 0 {
continue
}
io.WriteString(w, ` WHEN `)
io.WriteString(w, role.Match)
io.WriteString(w, ` THEN '`)
io.WriteString(w, role.Name)
io.WriteString(w, `'`)
2019-07-29 07:13:33 +02:00
}
io.WriteString(w, ` ELSE {{role}} END) FROM (`)
io.WriteString(w, conf.RolesQuery)
2019-12-25 07:24:30 +01:00
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `)
io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler" LIMIT 1; `)
roleSQL, _ := processTemplate(w.String())
_, err := tx.Prepare(context.Background(), "_sg_get_role", roleSQL)
2019-07-29 07:13:33 +02:00
if err != nil {
return err
}
return nil
}
2019-07-29 07:13:33 +02:00
func processTemplate(tmpl string) (string, [][]byte) {
st := struct {
vmap map[string]int
am [][]byte
i int
}{
vmap: make(map[string]int),
am: make([][]byte, 0, 5),
i: 0,
}
execFunc := func(w io.Writer, tag string) (int, error) {
if n, ok := st.vmap[tag]; ok {
return w.Write([]byte(fmt.Sprintf("$%d", n)))
}
st.am = append(st.am, []byte(tag))
st.i++
st.vmap[tag] = st.i
return w.Write([]byte(fmt.Sprintf("$%d", st.i)))
}
t1 := fasttemplate.New(tmpl, `'{{`, `}}'`)
ts1 := t1.ExecuteFuncString(execFunc)
t2 := fasttemplate.New(ts1, `{{`, `}}`)
ts2 := t2.ExecuteFuncString(execFunc)
return ts2, st.am
2019-07-29 07:13:33 +02:00
}