super-graph/serv/core_build.go

183 lines
3.7 KiB
Go
Raw Normal View History

package serv
import (
"bytes"
"encoding/json"
"errors"
2019-11-25 08:22:33 +01:00
"fmt"
"io"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode"
)
// stmt is one compiled form of a GraphQL request: the role it was compiled
// for, the compiled qcode, and the SQL produced from it.
type stmt struct {
	// role is the config role this statement was compiled for
	// (an entry of conf.roles / conf.Roles).
	role *configRole
	// qc is the output of qcompile.Compile for this role.
	qc *qcode.QCode
	// skipped is the value returned by pcompile.Compile alongside the SQL.
	// NOTE(review): its meaning is defined in the psql package (presumably
	// a record of skipped selections) — confirm there.
	skipped uint32
	// sql is the generated SQL text. In the multi-role (ABAC) path,
	// stmts[0].sql is overwritten with the combined role-dispatching query.
	sql string
}
2019-11-25 08:22:33 +01:00
func buildStmt(qt qcode.QType, gql, vars []byte, role string) ([]stmt, error) {
switch qt {
case qcode.QTMutation:
return buildRoleStmt(gql, vars, role)
2019-11-25 08:22:33 +01:00
case qcode.QTQuery:
if role == "anon" {
return buildRoleStmt(gql, vars, "anon")
}
2019-11-25 08:22:33 +01:00
2019-12-10 06:03:44 +01:00
if conf.isABACEnabled() {
2019-11-25 08:22:33 +01:00
return buildMultiStmt(gql, vars)
}
return buildRoleStmt(gql, vars, "user")
2019-11-25 08:22:33 +01:00
default:
return nil, fmt.Errorf("unknown query type '%d'", qt)
}
}
func buildRoleStmt(gql, vars []byte, role string) ([]stmt, error) {
ro, ok := conf.roles[role]
if !ok {
return nil, fmt.Errorf(`roles '%s' not defined in config`, role)
}
var vm map[string]json.RawMessage
var err error
if len(vars) != 0 {
if err := json.Unmarshal(vars, &vm); err != nil {
return nil, err
}
}
2019-11-25 08:22:33 +01:00
qc, err := qcompile.Compile(gql, ro.Name)
if err != nil {
return nil, err
}
2019-11-25 08:22:33 +01:00
stmts := []stmt{stmt{role: ro, qc: qc}}
w := &bytes.Buffer{}
skipped, err := pcompile.Compile(qc, w, psql.Variables(vm))
if err != nil {
return nil, err
}
2019-11-25 08:22:33 +01:00
stmts[0].skipped = skipped
stmts[0].sql = w.String()
return stmts, nil
}
func buildMultiStmt(gql, vars []byte) ([]stmt, error) {
var vm map[string]json.RawMessage
var err error
if len(vars) != 0 {
if err := json.Unmarshal(vars, &vm); err != nil {
return nil, err
}
}
if len(conf.RolesQuery) == 0 {
return nil, errors.New("roles_query not defined")
2019-11-25 08:22:33 +01:00
}
stmts := make([]stmt, 0, len(conf.Roles))
w := &bytes.Buffer{}
2019-11-25 08:22:33 +01:00
for i := 0; i < len(conf.Roles); i++ {
role := &conf.Roles[i]
// skip anon as it's not included in the combined multi-statement
if role.Name == "anon" {
continue
}
2019-11-25 08:22:33 +01:00
qc, err := qcompile.Compile(gql, role.Name)
2019-11-07 08:37:24 +01:00
if err != nil {
return nil, err
}
stmts = append(stmts, stmt{role: role, qc: qc})
2019-11-25 08:22:33 +01:00
skipped, err := pcompile.Compile(qc, w, psql.Variables(vm))
if err != nil {
return nil, err
}
2019-11-25 08:22:33 +01:00
s := &stmts[len(stmts)-1]
s.skipped = skipped
s.sql = w.String()
w.Reset()
}
2019-11-25 08:22:33 +01:00
sql, err := renderUserQuery(stmts, vm)
if err != nil {
return nil, err
}
2019-11-25 08:22:33 +01:00
stmts[0].sql = sql
return stmts, nil
}
//nolint: errcheck
2019-11-25 08:22:33 +01:00
func renderUserQuery(
stmts []stmt, vars map[string]json.RawMessage) (string, error) {
w := &bytes.Buffer{}
io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
for _, s := range stmts {
2019-11-25 08:22:33 +01:00
if len(s.role.Match) == 0 &&
s.role.Name != "user" && s.role.Name != "anon" {
continue
}
io.WriteString(w, `WHEN '`)
io.WriteString(w, s.role.Name)
io.WriteString(w, `' THEN (`)
io.WriteString(w, s.sql)
io.WriteString(w, `) `)
}
2019-11-25 08:22:33 +01:00
io.WriteString(w, `END) FROM (SELECT (CASE WHEN EXISTS (`)
io.WriteString(w, conf.RolesQuery)
io.WriteString(w, `) THEN `)
2019-11-25 08:22:33 +01:00
io.WriteString(w, `(SELECT (CASE`)
for _, s := range stmts {
if len(s.role.Match) == 0 {
continue
}
2019-11-25 08:22:33 +01:00
io.WriteString(w, ` WHEN `)
io.WriteString(w, s.role.Match)
io.WriteString(w, ` THEN '`)
io.WriteString(w, s.role.Name)
io.WriteString(w, `'`)
}
2019-11-25 08:22:33 +01:00
io.WriteString(w, ` ELSE 'user' END) FROM (`)
io.WriteString(w, conf.RolesQuery)
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `)
io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `)
2019-11-25 08:22:33 +01:00
return w.String(), nil
}
2019-11-25 08:22:33 +01:00
func hasTablesWithConfig(qc *qcode.QCode, role *configRole) bool {
for _, id := range qc.Roots {
t, err := schema.GetTable(qc.Selects[id].Name)
2019-11-25 08:22:33 +01:00
if err != nil {
return false
}
if _, ok := role.tablesMap[t.Name]; !ok {
return false
}
}
2019-11-25 08:22:33 +01:00
return true
}