2019-10-24 08:07:42 +02:00
|
|
|
package serv
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
|
|
|
"encoding/json"
|
|
|
|
"errors"
|
2019-11-25 08:22:33 +01:00
|
|
|
"fmt"
|
2019-10-24 08:07:42 +02:00
|
|
|
"io"
|
|
|
|
|
|
|
|
"github.com/dosco/super-graph/psql"
|
|
|
|
"github.com/dosco/super-graph/qcode"
|
|
|
|
)
|
|
|
|
|
|
|
|
// stmt holds the result of compiling one GraphQL request for a single role.
type stmt struct {
	// role is the config role the request was compiled for.
	role *configRole
	// qc is the compiled QCode form of the GraphQL request.
	qc *qcode.QCode
	// skipped is the value returned by pcompile.Compile together with the SQL.
	skipped uint32
	// sql is the SQL generated for this role. For statements built by
	// buildMultiStmt, the first entry instead holds the combined
	// role-dispatch query.
	sql string
}
|
|
|
|
|
2019-11-25 08:22:33 +01:00
|
|
|
func buildStmt(qt qcode.QType, gql, vars []byte, role string) ([]stmt, error) {
|
|
|
|
switch qt {
|
|
|
|
case qcode.QTMutation:
|
|
|
|
return buildRoleStmt(gql, vars, role)
|
2019-10-24 08:07:42 +02:00
|
|
|
|
2019-11-25 08:22:33 +01:00
|
|
|
case qcode.QTQuery:
|
2019-12-09 07:48:18 +01:00
|
|
|
if role == "anon" {
|
|
|
|
return buildRoleStmt(gql, vars, "anon")
|
|
|
|
}
|
2019-11-25 08:22:33 +01:00
|
|
|
|
2019-12-10 06:03:44 +01:00
|
|
|
if conf.isABACEnabled() {
|
2019-11-25 08:22:33 +01:00
|
|
|
return buildMultiStmt(gql, vars)
|
|
|
|
}
|
|
|
|
|
2019-12-09 07:48:18 +01:00
|
|
|
return buildRoleStmt(gql, vars, "user")
|
|
|
|
|
2019-11-25 08:22:33 +01:00
|
|
|
default:
|
|
|
|
return nil, fmt.Errorf("unknown query type '%d'", qt)
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
func buildRoleStmt(gql, vars []byte, role string) ([]stmt, error) {
|
|
|
|
ro, ok := conf.roles[role]
|
|
|
|
if !ok {
|
|
|
|
return nil, fmt.Errorf(`roles '%s' not defined in config`, role)
|
|
|
|
}
|
|
|
|
|
|
|
|
var vm map[string]json.RawMessage
|
|
|
|
var err error
|
|
|
|
|
|
|
|
if len(vars) != 0 {
|
|
|
|
if err := json.Unmarshal(vars, &vm); err != nil {
|
2019-10-24 08:07:42 +02:00
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2019-11-25 08:22:33 +01:00
|
|
|
qc, err := qcompile.Compile(gql, ro.Name)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
2019-10-24 08:07:42 +02:00
|
|
|
|
2019-11-25 08:22:33 +01:00
|
|
|
// For the 'anon' role in production only compile
|
|
|
|
// queries for tables defined in the config file.
|
2019-11-28 07:25:46 +01:00
|
|
|
if conf.Production && ro.Name == "anon" && !hasTablesWithConfig(qc, ro) {
|
2019-11-25 08:22:33 +01:00
|
|
|
return nil, errors.New("query contains tables with no 'anon' role config")
|
2019-10-24 08:07:42 +02:00
|
|
|
}
|
|
|
|
|
2019-11-25 08:22:33 +01:00
|
|
|
stmts := []stmt{stmt{role: ro, qc: qc}}
|
|
|
|
w := &bytes.Buffer{}
|
|
|
|
|
|
|
|
skipped, err := pcompile.Compile(qc, w, psql.Variables(vm))
|
2019-10-24 08:07:42 +02:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
2019-11-25 08:22:33 +01:00
|
|
|
stmts[0].skipped = skipped
|
|
|
|
stmts[0].sql = w.String()
|
|
|
|
|
|
|
|
return stmts, nil
|
|
|
|
}
|
|
|
|
|
|
|
|
func buildMultiStmt(gql, vars []byte) ([]stmt, error) {
|
|
|
|
var vm map[string]json.RawMessage
|
|
|
|
var err error
|
|
|
|
|
|
|
|
if len(vars) != 0 {
|
|
|
|
if err := json.Unmarshal(vars, &vm); err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
|
|
|
if len(conf.RolesQuery) == 0 {
|
2020-03-06 11:08:54 +01:00
|
|
|
return nil, errors.New("roles_query not defined")
|
2019-11-25 08:22:33 +01:00
|
|
|
}
|
|
|
|
|
2019-10-24 08:07:42 +02:00
|
|
|
stmts := make([]stmt, 0, len(conf.Roles))
|
|
|
|
w := &bytes.Buffer{}
|
|
|
|
|
2019-11-25 08:22:33 +01:00
|
|
|
for i := 0; i < len(conf.Roles); i++ {
|
2019-10-24 08:07:42 +02:00
|
|
|
role := &conf.Roles[i]
|
|
|
|
|
2020-03-06 11:08:54 +01:00
|
|
|
// skip anon as it's not included in the combined multi-statement
|
2019-12-02 16:52:22 +01:00
|
|
|
if role.Name == "anon" {
|
|
|
|
continue
|
|
|
|
}
|
|
|
|
|
2019-11-25 08:22:33 +01:00
|
|
|
qc, err := qcompile.Compile(gql, role.Name)
|
2019-11-07 08:37:24 +01:00
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
|
2019-10-24 08:07:42 +02:00
|
|
|
stmts = append(stmts, stmt{role: role, qc: qc})
|
|
|
|
|
2019-11-25 08:22:33 +01:00
|
|
|
skipped, err := pcompile.Compile(qc, w, psql.Variables(vm))
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
2019-10-24 08:07:42 +02:00
|
|
|
}
|
2019-11-25 08:22:33 +01:00
|
|
|
|
|
|
|
s := &stmts[len(stmts)-1]
|
|
|
|
s.skipped = skipped
|
|
|
|
s.sql = w.String()
|
|
|
|
w.Reset()
|
2019-10-24 08:07:42 +02:00
|
|
|
}
|
|
|
|
|
2019-11-25 08:22:33 +01:00
|
|
|
sql, err := renderUserQuery(stmts, vm)
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
2019-10-24 08:07:42 +02:00
|
|
|
}
|
|
|
|
|
2019-11-25 08:22:33 +01:00
|
|
|
stmts[0].sql = sql
|
|
|
|
return stmts, nil
|
|
|
|
}
|
|
|
|
|
2019-11-28 07:25:46 +01:00
|
|
|
//nolint: errcheck
|
2019-11-25 08:22:33 +01:00
|
|
|
func renderUserQuery(
|
|
|
|
stmts []stmt, vars map[string]json.RawMessage) (string, error) {
|
|
|
|
w := &bytes.Buffer{}
|
|
|
|
|
2019-10-24 08:07:42 +02:00
|
|
|
io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
|
|
|
|
|
|
|
|
for _, s := range stmts {
|
2019-11-25 08:22:33 +01:00
|
|
|
if len(s.role.Match) == 0 &&
|
|
|
|
s.role.Name != "user" && s.role.Name != "anon" {
|
|
|
|
continue
|
|
|
|
}
|
2019-10-24 08:07:42 +02:00
|
|
|
io.WriteString(w, `WHEN '`)
|
|
|
|
io.WriteString(w, s.role.Name)
|
|
|
|
io.WriteString(w, `' THEN (`)
|
2019-12-02 16:52:22 +01:00
|
|
|
io.WriteString(w, s.sql)
|
2019-10-24 08:07:42 +02:00
|
|
|
io.WriteString(w, `) `)
|
|
|
|
}
|
|
|
|
|
2019-11-25 08:22:33 +01:00
|
|
|
io.WriteString(w, `END) FROM (SELECT (CASE WHEN EXISTS (`)
|
|
|
|
io.WriteString(w, conf.RolesQuery)
|
|
|
|
io.WriteString(w, `) THEN `)
|
2019-10-24 08:07:42 +02:00
|
|
|
|
2019-11-25 08:22:33 +01:00
|
|
|
io.WriteString(w, `(SELECT (CASE`)
|
|
|
|
for _, s := range stmts {
|
|
|
|
if len(s.role.Match) == 0 {
|
|
|
|
continue
|
2019-10-24 08:07:42 +02:00
|
|
|
}
|
2019-11-25 08:22:33 +01:00
|
|
|
io.WriteString(w, ` WHEN `)
|
|
|
|
io.WriteString(w, s.role.Match)
|
|
|
|
io.WriteString(w, ` THEN '`)
|
|
|
|
io.WriteString(w, s.role.Name)
|
|
|
|
io.WriteString(w, `'`)
|
2019-10-24 08:07:42 +02:00
|
|
|
}
|
|
|
|
|
2019-11-25 08:22:33 +01:00
|
|
|
io.WriteString(w, ` ELSE 'user' END) FROM (`)
|
|
|
|
io.WriteString(w, conf.RolesQuery)
|
|
|
|
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) `)
|
|
|
|
io.WriteString(w, `ELSE 'anon' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `)
|
2019-10-24 08:07:42 +02:00
|
|
|
|
2019-11-25 08:22:33 +01:00
|
|
|
return w.String(), nil
|
2019-10-24 08:07:42 +02:00
|
|
|
}
|
2019-11-21 08:14:12 +01:00
|
|
|
|
2019-11-25 08:22:33 +01:00
|
|
|
func hasTablesWithConfig(qc *qcode.QCode, role *configRole) bool {
|
|
|
|
for _, id := range qc.Roots {
|
2019-11-29 07:38:23 +01:00
|
|
|
t, err := schema.GetTable(qc.Selects[id].Name)
|
2019-11-25 08:22:33 +01:00
|
|
|
if err != nil {
|
|
|
|
return false
|
|
|
|
}
|
|
|
|
if _, ok := role.tablesMap[t.Name]; !ok {
|
|
|
|
return false
|
2019-11-21 08:14:12 +01:00
|
|
|
}
|
|
|
|
}
|
2019-11-25 08:22:33 +01:00
|
|
|
return true
|
2019-11-21 08:14:12 +01:00
|
|
|
}
|