Get RBAC working for queries and mutations
This commit is contained in:
@ -182,7 +182,7 @@ func (al *allowList) load() {
|
||||
item.vars = varBytes
|
||||
}
|
||||
|
||||
al.list[gqlHash(q, varBytes)] = item
|
||||
al.list[gqlHash(q, varBytes, "")] = item
|
||||
varBytes = nil
|
||||
|
||||
} else if ty == AL_VARS {
|
||||
@ -203,7 +203,11 @@ func (al *allowList) save(item *allowItem) {
|
||||
if al.active == false {
|
||||
return
|
||||
}
|
||||
al.list[gqlHash(item.gql, item.vars)] = item
|
||||
h := gqlHash(item.gql, item.vars, "")
|
||||
if _, ok := al.list[h]; ok {
|
||||
return
|
||||
}
|
||||
al.list[gqlHash(item.gql, item.vars, "")] = item
|
||||
|
||||
f, err := os.Create(al.filepath)
|
||||
if err != nil {
|
||||
|
34
serv/auth.go
34
serv/auth.go
@ -9,26 +9,40 @@ import (
|
||||
var (
|
||||
userIDProviderKey = struct{}{}
|
||||
userIDKey = struct{}{}
|
||||
userRoleKey = struct{}{}
|
||||
)
|
||||
|
||||
func headerAuth(r *http.Request, c *config) *http.Request {
|
||||
if len(c.Auth.Header) == 0 {
|
||||
return nil
|
||||
}
|
||||
func headerAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
userID := r.Header.Get(c.Auth.Header)
|
||||
if len(userID) != 0 {
|
||||
ctx := context.WithValue(r.Context(), userIDKey, userID)
|
||||
return r.WithContext(ctx)
|
||||
}
|
||||
userIDProvider := r.Header.Get("X-User-ID-Provider")
|
||||
if len(userIDProvider) != 0 {
|
||||
ctx = context.WithValue(ctx, userIDProviderKey, userIDProvider)
|
||||
}
|
||||
|
||||
return nil
|
||||
userID := r.Header.Get("X-User-ID")
|
||||
if len(userID) != 0 {
|
||||
ctx = context.WithValue(ctx, userIDKey, userID)
|
||||
}
|
||||
|
||||
userRole := r.Header.Get("X-User-Role")
|
||||
if len(userRole) != 0 {
|
||||
ctx = context.WithValue(ctx, userRoleKey, userRole)
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
func withAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
at := conf.Auth.Type
|
||||
ru := conf.Auth.Rails.URL
|
||||
|
||||
if conf.Auth.CredsInHeader {
|
||||
next = headerAuth(next)
|
||||
}
|
||||
|
||||
switch at {
|
||||
case "rails":
|
||||
if strings.HasPrefix(ru, "memcache:") {
|
||||
|
@ -58,11 +58,6 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var tok string
|
||||
|
||||
if rn := headerAuth(r, conf); rn != nil {
|
||||
next.ServeHTTP(w, rn)
|
||||
return
|
||||
}
|
||||
|
||||
if len(cookie) != 0 {
|
||||
ck, err := r.Cookie(cookie)
|
||||
if err != nil {
|
||||
@ -102,7 +97,6 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
}
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
||||
|
@ -42,11 +42,6 @@ func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if rn := headerAuth(r, conf); rn != nil {
|
||||
next.ServeHTTP(w, rn)
|
||||
return
|
||||
}
|
||||
|
||||
ck, err := r.Cookie(cookie)
|
||||
if err != nil {
|
||||
next.ServeHTTP(w, r)
|
||||
@ -83,17 +78,12 @@ func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
|
||||
rURL, err := url.Parse(conf.Auth.Rails.URL)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err)
|
||||
logger.Fatal().Err(err).Send()
|
||||
}
|
||||
|
||||
mc := memcache.New(rURL.Host)
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if rn := headerAuth(r, conf); rn != nil {
|
||||
next.ServeHTTP(w, rn)
|
||||
return
|
||||
}
|
||||
|
||||
ck, err := r.Cookie(cookie)
|
||||
if err != nil {
|
||||
next.ServeHTTP(w, r)
|
||||
@ -126,25 +116,20 @@ func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
|
||||
ra, err := railsAuth(conf)
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err)
|
||||
logger.Fatal().Err(err).Send()
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
if rn := headerAuth(r, conf); rn != nil {
|
||||
next.ServeHTTP(w, rn)
|
||||
return
|
||||
}
|
||||
|
||||
ck, err := r.Cookie(cookie)
|
||||
if err != nil {
|
||||
logger.Error().Err(err)
|
||||
logger.Warn().Err(err).Send()
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := ra.ParseCookie(ck.Value)
|
||||
if err != nil {
|
||||
logger.Error().Err(err)
|
||||
logger.Warn().Err(err).Send()
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
27
serv/cmd.go
27
serv/cmd.go
@ -183,7 +183,32 @@ func initConf() (*config, error) {
|
||||
}
|
||||
zerolog.SetGlobalLevel(logLevel)
|
||||
|
||||
//fmt.Printf("%#v", c)
|
||||
for k, v := range c.DB.Vars {
|
||||
c.DB.Vars[k] = sanitize(v)
|
||||
}
|
||||
|
||||
c.RolesQuery = sanitize(c.RolesQuery)
|
||||
|
||||
rolesMap := make(map[string]struct{})
|
||||
|
||||
for i := range c.Roles {
|
||||
role := &c.Roles[i]
|
||||
|
||||
if _, ok := rolesMap[role.Name]; ok {
|
||||
logger.Fatal().Msgf("duplicate role '%s' found", role.Name)
|
||||
}
|
||||
role.Name = sanitize(role.Name)
|
||||
role.Match = sanitize(role.Match)
|
||||
rolesMap[role.Name] = struct{}{}
|
||||
}
|
||||
|
||||
if _, ok := rolesMap["user"]; !ok {
|
||||
c.Roles = append(c.Roles, configRole{Name: "user"})
|
||||
}
|
||||
|
||||
if _, ok := rolesMap["anon"]; !ok {
|
||||
c.Roles = append(c.Roles, configRole{Name: "anon"})
|
||||
}
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
@ -66,8 +66,9 @@ func graphQLFunc(query string, data interface{}) map[string]interface{} {
|
||||
c := &coreContext{Context: context.Background()}
|
||||
c.req.Query = query
|
||||
c.req.Vars = b
|
||||
c.req.role = "user"
|
||||
|
||||
res, err := c.execQuery("user")
|
||||
res, err := c.execQuery()
|
||||
if err != nil {
|
||||
logger.Fatal().Err(err).Msg("graphql query failed")
|
||||
}
|
||||
|
@ -1,7 +1,9 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
@ -24,9 +26,9 @@ type config struct {
|
||||
Inflections map[string]string
|
||||
|
||||
Auth struct {
|
||||
Type string
|
||||
Cookie string
|
||||
Header string
|
||||
Type string
|
||||
Cookie string
|
||||
CredsInHeader bool `mapstructure:"creds_in_header"`
|
||||
|
||||
Rails struct {
|
||||
Version string
|
||||
@ -60,7 +62,7 @@ type config struct {
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
LogLevel string `mapstructure:"log_level"`
|
||||
|
||||
vars map[string][]byte `mapstructure:"variables"`
|
||||
Vars map[string]string `mapstructure:"variables"`
|
||||
|
||||
Defaults struct {
|
||||
Filter []string
|
||||
@ -71,7 +73,9 @@ type config struct {
|
||||
} `mapstructure:"database"`
|
||||
|
||||
Tables []configTable
|
||||
Roles []configRoles
|
||||
|
||||
RolesQuery string `mapstructure:"roles_query"`
|
||||
Roles []configRole
|
||||
}
|
||||
|
||||
type configTable struct {
|
||||
@ -94,8 +98,9 @@ type configRemote struct {
|
||||
} `mapstructure:"set_headers"`
|
||||
}
|
||||
|
||||
type configRoles struct {
|
||||
type configRole struct {
|
||||
Name string
|
||||
Match string
|
||||
Tables []struct {
|
||||
Name string
|
||||
|
||||
@ -163,26 +168,6 @@ func newConfig() *viper.Viper {
|
||||
return vi
|
||||
}
|
||||
|
||||
func (c *config) getVariables() map[string]string {
|
||||
vars := make(map[string]string, len(c.DB.vars))
|
||||
|
||||
for k, v := range c.DB.vars {
|
||||
isVar := false
|
||||
|
||||
for i := range v {
|
||||
if v[i] == '$' {
|
||||
isVar = true
|
||||
} else if v[i] == ' ' {
|
||||
isVar = false
|
||||
} else if isVar && v[i] >= 'a' && v[i] <= 'z' {
|
||||
v[i] = 'A' + (v[i] - 'a')
|
||||
}
|
||||
}
|
||||
vars[k] = string(v)
|
||||
}
|
||||
return vars
|
||||
}
|
||||
|
||||
func (c *config) getAliasMap() map[string][]string {
|
||||
m := make(map[string][]string, len(c.Tables))
|
||||
|
||||
@ -198,3 +183,21 @@ func (c *config) getAliasMap() map[string][]string {
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
var varRe1 = regexp.MustCompile(`(?mi)\$([a-zA-Z0-9_.]+)`)
|
||||
var varRe2 = regexp.MustCompile(`\{\{([a-zA-Z0-9_.]+)\}\}`)
|
||||
|
||||
func sanitize(s string) string {
|
||||
s0 := varRe1.ReplaceAllString(s, `{{$1}}`)
|
||||
|
||||
s1 := strings.Map(func(r rune) rune {
|
||||
if unicode.IsSpace(r) {
|
||||
return ' '
|
||||
}
|
||||
return r
|
||||
}, s0)
|
||||
|
||||
return varRe2.ReplaceAllStringFunc(s1, func(m string) string {
|
||||
return strings.ToLower(m)
|
||||
})
|
||||
}
|
||||
|
293
serv/core.go
293
serv/core.go
@ -13,8 +13,8 @@ import (
|
||||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/dosco/super-graph/jsn"
|
||||
"github.com/dosco/super-graph/psql"
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/valyala/fasttemplate"
|
||||
)
|
||||
|
||||
@ -32,15 +32,13 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
|
||||
c.req.ref = req.Referer()
|
||||
c.req.hdr = req.Header
|
||||
|
||||
var role string
|
||||
|
||||
if authCheck(c) {
|
||||
role = "user"
|
||||
c.req.role = "user"
|
||||
} else {
|
||||
role = "anon"
|
||||
c.req.role = "anon"
|
||||
}
|
||||
|
||||
b, err := c.execQuery(role)
|
||||
b, err := c.execQuery()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -48,18 +46,18 @@ func (c *coreContext) handleReq(w io.Writer, req *http.Request) error {
|
||||
return c.render(w, b)
|
||||
}
|
||||
|
||||
func (c *coreContext) execQuery(role string) ([]byte, error) {
|
||||
func (c *coreContext) execQuery() ([]byte, error) {
|
||||
var err error
|
||||
var skipped uint32
|
||||
var qc *qcode.QCode
|
||||
var data []byte
|
||||
|
||||
logger.Debug().Str("role", role).Msg(c.req.Query)
|
||||
logger.Debug().Str("role", c.req.role).Msg(c.req.Query)
|
||||
|
||||
if conf.UseAllowList {
|
||||
var ps *preparedItem
|
||||
|
||||
data, ps, err = c.resolvePreparedSQL(c.req.Query)
|
||||
data, ps, err = c.resolvePreparedSQL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -69,12 +67,7 @@ func (c *coreContext) execQuery(role string) ([]byte, error) {
|
||||
|
||||
} else {
|
||||
|
||||
qc, err = qcompile.Compile([]byte(c.req.Query), role)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
data, skipped, err = c.resolveSQL(qc)
|
||||
data, skipped, err = c.resolveSQL()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@ -122,6 +115,152 @@ func (c *coreContext) execQuery(role string) ([]byte, error) {
|
||||
return ob.Bytes(), nil
|
||||
}
|
||||
|
||||
func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
|
||||
tx, err := db.Begin(c)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer tx.Rollback(c)
|
||||
|
||||
if v := c.Value(userIDKey); v != nil {
|
||||
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
var role string
|
||||
useRoleQuery := len(conf.RolesQuery) != 0 && isMutation(c.req.Query)
|
||||
|
||||
if useRoleQuery {
|
||||
if role, err = c.executeRoleQuery(tx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
} else if v := c.Value(userRoleKey); v != nil {
|
||||
role = v.(string)
|
||||
} else {
|
||||
role = c.req.role
|
||||
}
|
||||
|
||||
ps, ok := _preparedList[gqlHash(c.req.Query, c.req.Vars, role)]
|
||||
if !ok {
|
||||
return nil, nil, errUnauthorized
|
||||
}
|
||||
|
||||
var root []byte
|
||||
vars := varList(c, ps.args)
|
||||
|
||||
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(c); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
return root, ps, nil
|
||||
}
|
||||
|
||||
func (c *coreContext) resolveSQL() ([]byte, uint32, error) {
|
||||
tx, err := db.Begin(c)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer tx.Rollback(c)
|
||||
|
||||
mutation := isMutation(c.req.Query)
|
||||
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
|
||||
|
||||
if useRoleQuery {
|
||||
if c.req.role, err = c.executeRoleQuery(tx); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
} else if v := c.Value(userRoleKey); v != nil {
|
||||
c.req.role = v.(string)
|
||||
}
|
||||
|
||||
stmts, err := c.buildStmt()
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
var st *stmt
|
||||
|
||||
if mutation {
|
||||
st = findStmt(c.req.role, stmts)
|
||||
} else {
|
||||
st = &stmts[0]
|
||||
}
|
||||
|
||||
t := fasttemplate.New(st.sql, openVar, closeVar)
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
_, err = t.ExecuteFunc(buf, varMap(c))
|
||||
|
||||
if err == errNoUserID &&
|
||||
authFailBlock == authFailBlockPerQuery &&
|
||||
authCheck(c) == false {
|
||||
return nil, 0, errUnauthorized
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
finalSQL := buf.String()
|
||||
|
||||
var stime time.Time
|
||||
|
||||
if conf.EnableTracing {
|
||||
stime = time.Now()
|
||||
}
|
||||
|
||||
if v := c.Value(userIDKey); v != nil {
|
||||
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
var root []byte
|
||||
|
||||
if mutation {
|
||||
err = tx.QueryRow(c, finalSQL).Scan(&root)
|
||||
} else {
|
||||
err = tx.QueryRow(c, finalSQL).Scan(&c.req.role, &root)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(c); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if mutation {
|
||||
st = findStmt(c.req.role, stmts)
|
||||
} else {
|
||||
st = &stmts[0]
|
||||
}
|
||||
|
||||
if conf.EnableTracing && len(st.qc.Selects) != 0 {
|
||||
c.addTrace(
|
||||
st.qc.Selects,
|
||||
st.qc.Selects[0].ID,
|
||||
stime)
|
||||
}
|
||||
|
||||
if conf.UseAllowList == false {
|
||||
_allowList.add(&c.req)
|
||||
}
|
||||
|
||||
return root, st.skipped, nil
|
||||
}
|
||||
|
||||
func (c *coreContext) resolveRemote(
|
||||
hdr http.Header,
|
||||
h *xxhash.Digest,
|
||||
@ -269,125 +408,15 @@ func (c *coreContext) resolveRemotes(
|
||||
return to, cerr
|
||||
}
|
||||
|
||||
func (c *coreContext) resolvePreparedSQL(gql string) ([]byte, *preparedItem, error) {
|
||||
ps, ok := _preparedList[gqlHash(gql, c.req.Vars)]
|
||||
if !ok {
|
||||
return nil, nil, errUnauthorized
|
||||
func (c *coreContext) executeRoleQuery(tx pgx.Tx) (string, error) {
|
||||
var role string
|
||||
row := tx.QueryRow(c, "_sg_get_role", c.req.role, 1)
|
||||
|
||||
if err := row.Scan(&role); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
var root []byte
|
||||
vars := varList(c, ps.args)
|
||||
|
||||
tx, err := db.Begin(c)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
defer tx.Rollback(c)
|
||||
|
||||
if v := c.Value(userIDKey); v != nil {
|
||||
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
|
||||
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(c); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
fmt.Printf("PRE: %v\n", ps.stmt)
|
||||
|
||||
return root, ps, nil
|
||||
}
|
||||
|
||||
func (c *coreContext) resolveSQL(qc *qcode.QCode) ([]byte, uint32, error) {
|
||||
var vars map[string]json.RawMessage
|
||||
stmt := &bytes.Buffer{}
|
||||
|
||||
if len(c.req.Vars) != 0 {
|
||||
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
skipped, err := pcompile.Compile(qc, stmt, psql.Variables(vars))
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
t := fasttemplate.New(stmt.String(), openVar, closeVar)
|
||||
|
||||
stmt.Reset()
|
||||
_, err = t.ExecuteFunc(stmt, varMap(c))
|
||||
|
||||
if err == errNoUserID &&
|
||||
authFailBlock == authFailBlockPerQuery &&
|
||||
authCheck(c) == false {
|
||||
return nil, 0, errUnauthorized
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
finalSQL := stmt.String()
|
||||
|
||||
// if conf.LogLevel == "debug" {
|
||||
// os.Stdout.WriteString(finalSQL)
|
||||
// os.Stdout.WriteString("\n\n")
|
||||
// }
|
||||
|
||||
var st time.Time
|
||||
|
||||
if conf.EnableTracing {
|
||||
st = time.Now()
|
||||
}
|
||||
|
||||
tx, err := db.Begin(c)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
defer tx.Rollback(c)
|
||||
|
||||
if v := c.Value(userIDKey); v != nil {
|
||||
_, err = tx.Exec(c, fmt.Sprintf(`SET LOCAL "user.id" = %s;`, v))
|
||||
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
}
|
||||
|
||||
//fmt.Printf("\nRAW: %#v\n", finalSQL)
|
||||
|
||||
var root []byte
|
||||
|
||||
err = tx.QueryRow(c, finalSQL).Scan(&root)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if err := tx.Commit(c); err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
if conf.EnableTracing && len(qc.Selects) != 0 {
|
||||
c.addTrace(
|
||||
qc.Selects,
|
||||
qc.Selects[0].ID,
|
||||
st)
|
||||
}
|
||||
|
||||
if conf.UseAllowList == false {
|
||||
_allowList.add(&c.req)
|
||||
}
|
||||
|
||||
return root, skipped, nil
|
||||
return role, nil
|
||||
}
|
||||
|
||||
func (c *coreContext) render(w io.Writer, data []byte) error {
|
||||
|
144
serv/core_build.go
Normal file
144
serv/core_build.go
Normal file
@ -0,0 +1,144 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"github.com/dosco/super-graph/psql"
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
)
|
||||
|
||||
type stmt struct {
|
||||
role *configRole
|
||||
qc *qcode.QCode
|
||||
skipped uint32
|
||||
sql string
|
||||
}
|
||||
|
||||
func (c *coreContext) buildStmt() ([]stmt, error) {
|
||||
var vars map[string]json.RawMessage
|
||||
|
||||
if len(c.req.Vars) != 0 {
|
||||
if err := json.Unmarshal(c.req.Vars, &vars); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
gql := []byte(c.req.Query)
|
||||
|
||||
if len(conf.Roles) == 0 {
|
||||
return nil, errors.New(`no roles found ('user' and 'anon' required)`)
|
||||
}
|
||||
|
||||
qc, err := qcompile.Compile(gql, conf.Roles[0].Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
stmts := make([]stmt, 0, len(conf.Roles))
|
||||
mutation := (qc.Type != qcode.QTQuery)
|
||||
w := &bytes.Buffer{}
|
||||
|
||||
for i := range conf.Roles {
|
||||
role := &conf.Roles[i]
|
||||
|
||||
if mutation && len(c.req.role) != 0 && role.Name != c.req.role {
|
||||
continue
|
||||
}
|
||||
|
||||
if i > 0 {
|
||||
qc, err = qcompile.Compile(gql, role.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
stmts = append(stmts, stmt{role: role, qc: qc})
|
||||
|
||||
if mutation {
|
||||
skipped, err := pcompile.Compile(qc, w, psql.Variables(vars))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
s := &stmts[len(stmts)-1]
|
||||
s.skipped = skipped
|
||||
s.sql = w.String()
|
||||
w.Reset()
|
||||
}
|
||||
}
|
||||
|
||||
if mutation {
|
||||
return stmts, nil
|
||||
}
|
||||
|
||||
io.WriteString(w, `SELECT "_sg_auth_info"."role", (CASE "_sg_auth_info"."role" `)
|
||||
|
||||
for _, s := range stmts {
|
||||
io.WriteString(w, `WHEN '`)
|
||||
io.WriteString(w, s.role.Name)
|
||||
io.WriteString(w, `' THEN (`)
|
||||
|
||||
s.skipped, err = pcompile.Compile(s.qc, w, psql.Variables(vars))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
io.WriteString(w, `) `)
|
||||
}
|
||||
io.WriteString(w, `END) FROM (`)
|
||||
|
||||
if len(conf.RolesQuery) == 0 {
|
||||
v := c.Value(userRoleKey)
|
||||
|
||||
io.WriteString(w, `VALUES ("`)
|
||||
if v != nil {
|
||||
io.WriteString(w, v.(string))
|
||||
} else {
|
||||
io.WriteString(w, c.req.role)
|
||||
}
|
||||
io.WriteString(w, `")) AS "_sg_auth_info"(role) LIMIT 1;`)
|
||||
|
||||
} else {
|
||||
|
||||
io.WriteString(w, `SELECT (CASE WHEN EXISTS (`)
|
||||
io.WriteString(w, conf.RolesQuery)
|
||||
io.WriteString(w, `) THEN `)
|
||||
|
||||
io.WriteString(w, `(SELECT (CASE`)
|
||||
for _, s := range stmts {
|
||||
if len(s.role.Match) == 0 {
|
||||
continue
|
||||
}
|
||||
io.WriteString(w, ` WHEN `)
|
||||
io.WriteString(w, s.role.Match)
|
||||
io.WriteString(w, ` THEN '`)
|
||||
io.WriteString(w, s.role.Name)
|
||||
io.WriteString(w, `'`)
|
||||
}
|
||||
|
||||
if len(c.req.role) == 0 {
|
||||
io.WriteString(w, ` ELSE 'anon' END) FROM (`)
|
||||
} else {
|
||||
io.WriteString(w, ` ELSE '`)
|
||||
io.WriteString(w, c.req.role)
|
||||
io.WriteString(w, `' END) FROM (`)
|
||||
}
|
||||
|
||||
io.WriteString(w, conf.RolesQuery)
|
||||
io.WriteString(w, `) AS "_sg_auth_roles_query" LIMIT 1) ELSE '`)
|
||||
if len(c.req.role) == 0 {
|
||||
io.WriteString(w, `anon`)
|
||||
} else {
|
||||
io.WriteString(w, c.req.role)
|
||||
}
|
||||
io.WriteString(w, `' END) FROM (VALUES (1)) AS "_sg_auth_filler") AS "_sg_auth_info"(role) LIMIT 1; `)
|
||||
}
|
||||
|
||||
stmts[0].sql = w.String()
|
||||
stmts[0].role = nil
|
||||
|
||||
return stmts, nil
|
||||
}
|
@ -30,6 +30,7 @@ type gqlReq struct {
|
||||
Query string `json:"query"`
|
||||
Vars json.RawMessage `json:"variables"`
|
||||
ref string
|
||||
role string
|
||||
hdr http.Header
|
||||
}
|
||||
|
||||
@ -101,13 +102,11 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
|
||||
err = ctx.handleReq(w, r)
|
||||
|
||||
if err == errUnauthorized {
|
||||
err := "Not authorized"
|
||||
logger.Debug().Msg(err)
|
||||
http.Error(w, err, 401)
|
||||
http.Error(w, "Not authorized", 401)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Err(err).Msg("Failed to handle request")
|
||||
logger.Err(err).Msg("failed to handle request")
|
||||
errorResp(w, err)
|
||||
}
|
||||
}
|
||||
|
125
serv/prepare.go
125
serv/prepare.go
@ -7,7 +7,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/dosco/super-graph/psql"
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/valyala/fasttemplate"
|
||||
@ -27,55 +26,100 @@ var (
|
||||
func initPreparedList() {
|
||||
_preparedList = make(map[string]*preparedItem)
|
||||
|
||||
for k, v := range _allowList.list {
|
||||
err := prepareStmt(k, v.gql, v.vars)
|
||||
if err := prepareRoleStmt(); err != nil {
|
||||
logger.Fatal().Err(err).Msg("failed to prepare get role statement")
|
||||
}
|
||||
|
||||
for _, v := range _allowList.list {
|
||||
err := prepareStmt(v.gql, v.vars)
|
||||
if err != nil {
|
||||
logger.Warn().Str("gql", v.gql).Err(err).Send()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func prepareStmt(key, gql string, varBytes json.RawMessage) error {
|
||||
if len(gql) == 0 || len(key) == 0 {
|
||||
func prepareStmt(gql string, varBytes json.RawMessage) error {
|
||||
if len(gql) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
qc, err := qcompile.Compile([]byte(gql), "user")
|
||||
c := &coreContext{Context: context.Background()}
|
||||
c.req.Query = gql
|
||||
c.req.Vars = varBytes
|
||||
|
||||
stmts, err := c.buildStmt()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var vars map[string]json.RawMessage
|
||||
for _, s := range stmts {
|
||||
if len(s.sql) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if len(varBytes) != 0 {
|
||||
vars = make(map[string]json.RawMessage)
|
||||
finalSQL, am := processTemplate(s.sql)
|
||||
|
||||
if err := json.Unmarshal(varBytes, &vars); err != nil {
|
||||
ctx := context.Background()
|
||||
|
||||
tx, err := db.Begin(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
pstmt, err := tx.Prepare(ctx, "", finalSQL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var key string
|
||||
|
||||
if s.role == nil {
|
||||
key = gqlHash(gql, varBytes, "")
|
||||
} else {
|
||||
key = gqlHash(gql, varBytes, s.role.Name)
|
||||
}
|
||||
|
||||
_preparedList[key] = &preparedItem{
|
||||
stmt: pstmt,
|
||||
args: am,
|
||||
skipped: s.skipped,
|
||||
qc: s.qc,
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
buf := &bytes.Buffer{}
|
||||
return nil
|
||||
}
|
||||
|
||||
skipped, err := pcompile.Compile(qc, buf, psql.Variables(vars))
|
||||
if err != nil {
|
||||
return err
|
||||
func prepareRoleStmt() error {
|
||||
if len(conf.RolesQuery) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
t := fasttemplate.New(buf.String(), `{{`, `}}`)
|
||||
am := make([][]byte, 0, 5)
|
||||
i := 0
|
||||
w := &bytes.Buffer{}
|
||||
|
||||
finalSQL := t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) {
|
||||
am = append(am, []byte(tag))
|
||||
i++
|
||||
return w.Write([]byte(fmt.Sprintf("$%d", i)))
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
io.WriteString(w, `SELECT (CASE`)
|
||||
for _, role := range conf.Roles {
|
||||
if len(role.Match) == 0 {
|
||||
continue
|
||||
}
|
||||
io.WriteString(w, ` WHEN `)
|
||||
io.WriteString(w, role.Match)
|
||||
io.WriteString(w, ` THEN '`)
|
||||
io.WriteString(w, role.Name)
|
||||
io.WriteString(w, `'`)
|
||||
}
|
||||
|
||||
io.WriteString(w, ` ELSE {{role}} END) FROM (`)
|
||||
io.WriteString(w, conf.RolesQuery)
|
||||
io.WriteString(w, `) AS "_sg_auth_roles_query"`)
|
||||
|
||||
roleSQL, _ := processTemplate(w.String())
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
tx, err := db.Begin(ctx)
|
||||
@ -84,21 +128,28 @@ func prepareStmt(key, gql string, varBytes json.RawMessage) error {
|
||||
}
|
||||
defer tx.Rollback(ctx)
|
||||
|
||||
pstmt, err := tx.Prepare(ctx, "", finalSQL)
|
||||
_, err = tx.Prepare(ctx, "_sg_get_role", roleSQL)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_preparedList[key] = &preparedItem{
|
||||
stmt: pstmt,
|
||||
args: am,
|
||||
skipped: skipped,
|
||||
qc: qc,
|
||||
}
|
||||
|
||||
if err := tx.Commit(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func processTemplate(tmpl string) (string, [][]byte) {
|
||||
t := fasttemplate.New(tmpl, `{{`, `}}`)
|
||||
am := make([][]byte, 0, 5)
|
||||
i := 0
|
||||
|
||||
vmap := make(map[string]int)
|
||||
|
||||
return t.ExecuteFuncString(func(w io.Writer, tag string) (int, error) {
|
||||
if n, ok := vmap[tag]; ok {
|
||||
return w.Write([]byte(fmt.Sprintf("$%d", n)))
|
||||
}
|
||||
am = append(am, []byte(tag))
|
||||
i++
|
||||
vmap[tag] = i
|
||||
return w.Write([]byte(fmt.Sprintf("$%d", i)))
|
||||
}), am
|
||||
}
|
||||
|
@ -67,7 +67,7 @@ func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
|
||||
|
||||
pc := psql.NewCompiler(psql.Config{
|
||||
Schema: schema,
|
||||
Vars: c.getVariables(),
|
||||
Vars: c.DB.Vars,
|
||||
})
|
||||
|
||||
return qc, pc, nil
|
||||
|
@ -21,7 +21,7 @@ func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
|
||||
return v
|
||||
}
|
||||
|
||||
func gqlHash(b string, vars []byte) string {
|
||||
func gqlHash(b string, vars []byte, role string) string {
|
||||
b = strings.TrimSpace(b)
|
||||
h := sha1.New()
|
||||
|
||||
@ -56,6 +56,10 @@ func gqlHash(b string, vars []byte) string {
|
||||
}
|
||||
}
|
||||
|
||||
if len(role) != 0 {
|
||||
io.WriteString(h, role)
|
||||
}
|
||||
|
||||
if vars == nil || len(vars) == 0 {
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
@ -80,3 +84,26 @@ func ws(b byte) bool {
|
||||
func al(b byte) bool {
|
||||
return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9')
|
||||
}
|
||||
|
||||
func isMutation(sql string) bool {
|
||||
for i := range sql {
|
||||
b := sql[i]
|
||||
if b == '{' {
|
||||
return false
|
||||
}
|
||||
if al(b) {
|
||||
return (b == 'm' || b == 'M')
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func findStmt(role string, stmts []stmt) *stmt {
|
||||
for i := range stmts {
|
||||
if stmts[i].role.Name != role {
|
||||
continue
|
||||
}
|
||||
return &stmts[i]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
24
serv/vars.go
24
serv/vars.go
@ -11,17 +11,27 @@ import (
|
||||
func varMap(ctx *coreContext) func(w io.Writer, tag string) (int, error) {
|
||||
return func(w io.Writer, tag string) (int, error) {
|
||||
switch tag {
|
||||
case "user_id":
|
||||
if v := ctx.Value(userIDKey); v != nil {
|
||||
return stringVar(w, v.(string))
|
||||
}
|
||||
return 0, errNoUserID
|
||||
|
||||
case "user_id_provider":
|
||||
if v := ctx.Value(userIDProviderKey); v != nil {
|
||||
return stringVar(w, v.(string))
|
||||
}
|
||||
return 0, errNoUserID
|
||||
io.WriteString(w, "null")
|
||||
return 0, nil
|
||||
|
||||
case "user_id":
|
||||
if v := ctx.Value(userIDKey); v != nil {
|
||||
return stringVar(w, v.(string))
|
||||
}
|
||||
|
||||
io.WriteString(w, "null")
|
||||
return 0, nil
|
||||
|
||||
case "user_role":
|
||||
if v := ctx.Value(userRoleKey); v != nil {
|
||||
return stringVar(w, v.(string))
|
||||
}
|
||||
io.WriteString(w, "null")
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
fields := jsn.Get(ctx.req.Vars, [][]byte{[]byte(tag)})
|
||||
|
Reference in New Issue
Block a user