209 lines
4.1 KiB
Go
209 lines
4.1 KiB
Go
package core
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"crypto/sha256"
|
|
"database/sql"
|
|
"encoding/hex"
|
|
"fmt"
|
|
"io"
|
|
"strings"
|
|
|
|
"github.com/dosco/super-graph/core/internal/allow"
|
|
"github.com/dosco/super-graph/core/internal/qcode"
|
|
)
|
|
|
|
type preparedItem struct {
|
|
sd *sql.Stmt
|
|
st stmt
|
|
roleArg bool
|
|
}
|
|
|
|
func (sg *SuperGraph) initPrepared() error {
|
|
ct := context.Background()
|
|
|
|
if sg.allowList.IsPersist() {
|
|
return nil
|
|
}
|
|
sg.prepared = make(map[string]*preparedItem)
|
|
|
|
tx, err := sg.db.BeginTx(ct, nil)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
defer tx.Rollback() //nolint: errcheck
|
|
|
|
if err = sg.prepareRoleStmt(tx); err != nil {
|
|
return fmt.Errorf("prepareRoleStmt: %w", err)
|
|
}
|
|
|
|
if err := tx.Commit(); err != nil {
|
|
return err
|
|
}
|
|
|
|
success := 0
|
|
|
|
list, err := sg.allowList.Load()
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
for _, v := range list {
|
|
if len(v.Query) == 0 {
|
|
continue
|
|
}
|
|
|
|
err := sg.prepareStmt(v)
|
|
if err != nil {
|
|
return err
|
|
} else {
|
|
success++
|
|
}
|
|
}
|
|
|
|
sg.log.Printf("INF allow list: prepared %d / %d queries", success, len(list))
|
|
|
|
return nil
|
|
}
|
|
|
|
func (sg *SuperGraph) prepareStmt(item allow.Item) error {
|
|
query := item.Query
|
|
qb := []byte(query)
|
|
vars := item.Vars
|
|
|
|
qt := qcode.GetQType(query)
|
|
ct := context.Background()
|
|
|
|
switch qt {
|
|
case qcode.QTQuery:
|
|
var stmts1 []stmt
|
|
var err error
|
|
|
|
if sg.abacEnabled {
|
|
stmts1, err = sg.buildMultiStmt(qb, vars)
|
|
} else {
|
|
stmts1, err = sg.buildRoleStmt(qb, vars, "user")
|
|
}
|
|
|
|
if err == nil {
|
|
if err = sg.prepare(ct, stmts1, stmtHash(item.Name, "user")); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
sg.log.Printf("WRN query %s: %v", item.Name, err)
|
|
}
|
|
|
|
if sg.anonExists {
|
|
stmts2, err := sg.buildRoleStmt(qb, vars, "anon")
|
|
|
|
if err == nil {
|
|
if err = sg.prepare(ct, stmts2, stmtHash(item.Name, "anon")); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
sg.log.Printf("WRN query %s: %v", item.Name, err)
|
|
}
|
|
}
|
|
|
|
case qcode.QTMutation:
|
|
for _, role := range sg.conf.Roles {
|
|
stmts, err := sg.buildRoleStmt(qb, vars, role.Name)
|
|
|
|
if err == nil {
|
|
if err = sg.prepare(ct, stmts, stmtHash(item.Name, role.Name)); err != nil {
|
|
return err
|
|
}
|
|
} else {
|
|
sg.log.Printf("WRN mutation %s: %v", item.Name, err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (sg *SuperGraph) prepare(ct context.Context, st []stmt, key string) error {
|
|
sd, err := sg.db.PrepareContext(ct, st[0].sql)
|
|
if err != nil {
|
|
return fmt.Errorf("prepare failed: %v: %s", err, st[0].sql)
|
|
}
|
|
|
|
sg.prepared[key] = &preparedItem{
|
|
sd: sd,
|
|
st: st[0],
|
|
roleArg: len(st) > 1,
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// nolint: errcheck
|
|
func (sg *SuperGraph) prepareRoleStmt(tx *sql.Tx) error {
|
|
var err error
|
|
|
|
if !sg.abacEnabled {
|
|
return nil
|
|
}
|
|
|
|
rq := strings.ReplaceAll(sg.conf.RolesQuery, "$user_id", "$1")
|
|
w := &bytes.Buffer{}
|
|
|
|
io.WriteString(w, `SELECT (CASE WHEN EXISTS (`)
|
|
io.WriteString(w, rq)
|
|
io.WriteString(w, `) THEN `)
|
|
|
|
io.WriteString(w, `(SELECT (CASE`)
|
|
for _, role := range sg.conf.Roles {
|
|
if len(role.Match) == 0 {
|
|
continue
|
|
}
|
|
io.WriteString(w, ` WHEN `)
|
|
io.WriteString(w, role.Match)
|
|
io.WriteString(w, ` THEN '`)
|
|
io.WriteString(w, role.Name)
|
|
io.WriteString(w, `'`)
|
|
}
|
|
|
|
io.WriteString(w, ` ELSE $2 END) FROM (`)
|
|
io.WriteString(w, sg.conf.RolesQuery)
|
|
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `)
|
|
io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler" LIMIT 1; `)
|
|
|
|
sg.getRole, err = tx.Prepare(w.String())
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (sg *SuperGraph) initAllowList() error {
|
|
var ac allow.Config
|
|
var err error
|
|
|
|
if sg.conf.AllowListFile == "" {
|
|
sg.conf.AllowListFile = "allow.list"
|
|
}
|
|
|
|
// When list is not eabled it is still created and
|
|
// and new queries are saved to it.
|
|
if !sg.conf.UseAllowList {
|
|
ac = allow.Config{CreateIfNotExists: true, Persist: true, Log: sg.log}
|
|
}
|
|
|
|
sg.allowList, err = allow.New(sg.conf.AllowListFile, ac)
|
|
if err != nil {
|
|
return fmt.Errorf("failed to initialize allow list: %w", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// stmtHash returns the key under which a prepared statement is cached: the
// hex-encoded SHA-256 digest of the lower-cased name immediately followed by
// the role. The name is matched case-insensitively, the role is not.
//
// nolint: errcheck
func stmtHash(name string, role string) string {
	digest := sha256.New()
	// Writing to a hash never fails, so the result is not checked.
	io.WriteString(digest, strings.ToLower(name)+role)
	sum := digest.Sum(nil)
	return hex.EncodeToString(sum)
}
|