2019-07-29 07:13:33 +02:00
|
|
|
package serv
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
2019-09-26 06:35:31 +02:00
|
|
|
"context"
|
2019-09-05 06:09:56 +02:00
|
|
|
"encoding/json"
|
2019-07-29 07:13:33 +02:00
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
|
|
|
|
"github.com/dosco/super-graph/qcode"
|
2019-09-26 06:35:31 +02:00
|
|
|
"github.com/jackc/pgconn"
|
2019-07-29 07:13:33 +02:00
|
|
|
"github.com/valyala/fasttemplate"
|
|
|
|
)
|
|
|
|
|
|
|
|
type preparedItem struct {
|
2019-09-26 06:35:31 +02:00
|
|
|
stmt *pgconn.StatementDescription
|
2019-09-05 06:09:56 +02:00
|
|
|
args [][]byte
|
2019-07-29 07:13:33 +02:00
|
|
|
skipped uint32
|
|
|
|
qc *qcode.QCode
|
|
|
|
}
|
|
|
|
|
|
|
|
var (
|
|
|
|
_preparedList map[string]*preparedItem
|
|
|
|
)
|
|
|
|
|
|
|
|
func initPreparedList() {
|
|
|
|
_preparedList = make(map[string]*preparedItem)
|
|
|
|
|
2019-10-24 08:07:42 +02:00
|
|
|
if err := prepareRoleStmt(); err != nil {
|
|
|
|
logger.Fatal().Err(err).Msg("failed to prepare get role statement")
|
|
|
|
}
|
|
|
|
|
|
|
|
for _, v := range _allowList.list {
|
2019-10-25 06:01:22 +02:00
|
|
|
|
2019-10-24 08:07:42 +02:00
|
|
|
err := prepareStmt(v.gql, v.vars)
|
2019-07-29 07:13:33 +02:00
|
|
|
if err != nil {
|
2019-10-15 08:30:19 +02:00
|
|
|
logger.Warn().Str("gql", v.gql).Err(err).Send()
|
2019-07-29 07:13:33 +02:00
|
|
|
}
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2019-10-24 08:07:42 +02:00
|
|
|
func prepareStmt(gql string, varBytes json.RawMessage) error {
|
|
|
|
if len(gql) == 0 {
|
2019-07-29 07:13:33 +02:00
|
|
|
return nil
|
|
|
|
}
|
|
|
|
|
2019-10-24 08:07:42 +02:00
|
|
|
c := &coreContext{Context: context.Background()}
|
|
|
|
c.req.Query = gql
|
|
|
|
c.req.Vars = varBytes
|
|
|
|
|
|
|
|
stmts, err := c.buildStmt()
|
2019-07-29 07:13:33 +02:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2019-10-25 06:01:22 +02:00
|
|
|
if len(stmts) != 0 && stmts[0].qc.Type == qcode.QTQuery {
|
|
|
|
c.req.Vars = nil
|
|
|
|
}
|
|
|
|
|
2019-10-24 08:07:42 +02:00
|
|
|
for _, s := range stmts {
|
|
|
|
if len(s.sql) == 0 {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
|
|
|
finalSQL, am := processTemplate(s.sql)
|
2019-09-05 06:09:56 +02:00
|
|
|
|
2019-10-24 08:07:42 +02:00
|
|
|
ctx := context.Background()
|
2019-09-05 06:09:56 +02:00
|
|
|
|
2019-10-24 08:07:42 +02:00
|
|
|
tx, err := db.Begin(ctx)
|
|
|
|
if err != nil {
|
2019-09-05 06:09:56 +02:00
|
|
|
return err
|
|
|
|
}
|
2019-10-24 08:07:42 +02:00
|
|
|
defer tx.Rollback(ctx)
|
2019-09-05 06:09:56 +02:00
|
|
|
|
2019-10-24 08:07:42 +02:00
|
|
|
pstmt, err := tx.Prepare(ctx, "", finalSQL)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2019-07-29 07:13:33 +02:00
|
|
|
|
2019-10-24 08:07:42 +02:00
|
|
|
var key string
|
|
|
|
|
|
|
|
if s.role == nil {
|
2019-10-25 06:01:22 +02:00
|
|
|
key = gqlHash(gql, c.req.Vars, "")
|
2019-10-24 08:07:42 +02:00
|
|
|
} else {
|
2019-10-25 06:01:22 +02:00
|
|
|
key = gqlHash(gql, c.req.Vars, s.role.Name)
|
2019-10-24 08:07:42 +02:00
|
|
|
}
|
|
|
|
|
|
|
|
_preparedList[key] = &preparedItem{
|
|
|
|
stmt: pstmt,
|
|
|
|
args: am,
|
|
|
|
skipped: s.skipped,
|
|
|
|
qc: s.qc,
|
|
|
|
}
|
|
|
|
|
|
|
|
if err := tx.Commit(ctx); err != nil {
|
|
|
|
return err
|
|
|
|
}
|
2019-07-29 07:13:33 +02:00
|
|
|
}
|
|
|
|
|
2019-10-24 08:07:42 +02:00
|
|
|
return nil
|
|
|
|
}
|
2019-07-29 07:13:33 +02:00
|
|
|
|
2019-10-24 08:07:42 +02:00
|
|
|
func prepareRoleStmt() error {
|
|
|
|
if len(conf.RolesQuery) == 0 {
|
|
|
|
return nil
|
|
|
|
}
|
2019-07-29 07:13:33 +02:00
|
|
|
|
2019-10-24 08:07:42 +02:00
|
|
|
w := &bytes.Buffer{}
|
|
|
|
|
|
|
|
io.WriteString(w, `SELECT (CASE`)
|
|
|
|
for _, role := range conf.Roles {
|
|
|
|
if len(role.Match) == 0 {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
io.WriteString(w, ` WHEN `)
|
|
|
|
io.WriteString(w, role.Match)
|
|
|
|
io.WriteString(w, ` THEN '`)
|
|
|
|
io.WriteString(w, role.Name)
|
|
|
|
io.WriteString(w, `'`)
|
2019-07-29 07:13:33 +02:00
|
|
|
}
|
|
|
|
|
2019-10-24 08:07:42 +02:00
|
|
|
io.WriteString(w, ` ELSE {{role}} END) FROM (`)
|
|
|
|
io.WriteString(w, conf.RolesQuery)
|
|
|
|
io.WriteString(w, `) AS "_sg_auth_roles_query"`)
|
|
|
|
|
|
|
|
roleSQL, _ := processTemplate(w.String())
|
|
|
|
|
2019-09-26 06:35:31 +02:00
|
|
|
ctx := context.Background()
|
|
|
|
|
|
|
|
tx, err := db.Begin(ctx)
|
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
defer tx.Rollback(ctx)
|
|
|
|
|
2019-10-24 08:07:42 +02:00
|
|
|
_, err = tx.Prepare(ctx, "_sg_get_role", roleSQL)
|
2019-07-29 07:13:33 +02:00
|
|
|
if err != nil {
|
|
|
|
return err
|
|
|
|
}
|
|
|
|
|
2019-10-24 08:07:42 +02:00
|
|
|
return nil
|
|
|
|
}
|
2019-07-29 07:13:33 +02:00
|
|
|
|
2019-10-24 08:07:42 +02:00
|
|
|
func processTemplate(tmpl string) (string, [][]byte) {
|
|
|
|
t := fasttemplate.New(tmpl, `{{`, `}}`)
|
|
|
|
am := make([][]byte, 0, 5)
|
|
|
|
i := 0
|
2019-09-26 06:35:31 +02:00
|
|
|
|
2019-10-24 08:07:42 +02:00
|
|
|
vmap := make(map[string]int)
|
|
|
|
|
|
|
|
return t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) {
|
|
|
|
if n, ok := vmap[tag]; ok {
|
|
|
|
return w.Write([]byte(fmt.Sprintf("$%d", n)))
|
|
|
|
}
|
|
|
|
am = append(am, []byte(tag))
|
|
|
|
i++
|
|
|
|
vmap[tag] = i
|
|
|
|
return w.Write([]byte(fmt.Sprintf("$%d", i)))
|
|
|
|
}), am
|
2019-07-29 07:13:33 +02:00
|
|
|
}
|