Optimize prepared statement flow for RBAC

This commit is contained in:
Vikram Rangnekar
2019-10-25 00:01:22 -04:00
parent 6bc66d28bc
commit 4edc15eb98
14 changed files with 136 additions and 46 deletions

View File

@ -182,6 +182,8 @@ func (al *allowList) load() {
item.vars = varBytes
}
//fmt.Println("%%", item.gql, string(item.vars))
al.list[gqlHash(q, varBytes, "")] = item
varBytes = nil

View File

@ -7,9 +7,9 @@ import (
)
var (
userIDProviderKey = struct{}{}
userIDKey = struct{}{}
userRoleKey = struct{}{}
userIDProviderKey = "user_id_provider"
userIDKey = "user_id"
userRoleKey = "user_role"
)
func headerAuth(next http.HandlerFunc) http.HandlerFunc {

View File

@ -122,14 +122,14 @@ func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
ck, err := r.Cookie(cookie)
if err != nil {
logger.Warn().Err(err).Send()
logger.Warn().Err(err).Msg("rails cookie missing")
next.ServeHTTP(w, r)
return
}
userID, err := ra.ParseCookie(ck.Value)
if err != nil {
logger.Warn().Err(err).Send()
logger.Warn().Err(err).Msg("failed to parse rails cookie")
next.ServeHTTP(w, r)
return
}

View File

@ -210,6 +210,8 @@ func initConf() (*config, error) {
c.Roles = append(c.Roles, configRole{Name: "anon"})
}
c.Validate()
return c, nil
}

View File

@ -168,6 +168,32 @@ func newConfig() *viper.Viper {
return vi
}
func (c *config) Validate() {
rm := make(map[string]struct{})
for i := range c.Roles {
name := strings.ToLower(c.Roles[i].Name)
if _, ok := rm[name]; ok {
logger.Fatal().Msgf("duplicate config for role '%s'", c.Roles[i].Name)
}
rm[name] = struct{}{}
}
tm := make(map[string]struct{})
for i := range c.Tables {
name := strings.ToLower(c.Tables[i].Name)
if _, ok := tm[name]; ok {
logger.Fatal().Msgf("duplicate config for table '%s'", c.Tables[i].Name)
}
tm[name] = struct{}{}
}
if len(c.RolesQuery) == 0 {
logger.Warn().Msgf("no 'roles_query' defined.")
}
}
func (c *config) getAliasMap() map[string][]string {
m := make(map[string][]string, len(c.Tables))

View File

@ -131,16 +131,20 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
}
var role string
useRoleQuery := len(conf.RolesQuery) != 0 && isMutation(c.req.Query)
mutation := isMutation(c.req.Query)
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
if useRoleQuery {
if role, err = c.executeRoleQuery(tx); err != nil {
return nil, nil, err
}
} else if v := c.Value(userRoleKey); v != nil {
role = v.(string)
} else {
} else if mutation {
role = c.req.role
}
ps, ok := _preparedList[gqlHash(c.req.Query, c.req.Vars, role)]
@ -151,7 +155,11 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
var root []byte
vars := varList(c, ps.args)
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
if mutation {
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
} else {
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&c.req.role, &root)
}
if err != nil {
return nil, nil, err
}

View File

@ -37,8 +37,8 @@ type gqlReq struct {
type variables map[string]json.RawMessage
type gqlResp struct {
Error string `json:"error,omitempty"`
Data json.RawMessage `json:"data"`
Error string `json:"message,omitempty"`
Data json.RawMessage `json:"data,omitempty"`
Extensions *extensions `json:"extensions,omitempty"`
}
@ -102,7 +102,9 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
err = ctx.handleReq(w, r)
if err == errUnauthorized {
http.Error(w, "Not authorized", 401)
w.WriteHeader(http.StatusUnauthorized)
json.NewEncoder(w).Encode(gqlResp{Error: err.Error()})
return
}
if err != nil {

View File

@ -31,6 +31,7 @@ func initPreparedList() {
}
for _, v := range _allowList.list {
err := prepareStmt(v.gql, v.vars)
if err != nil {
logger.Warn().Str("gql", v.gql).Err(err).Send()
@ -52,6 +53,10 @@ func prepareStmt(gql string, varBytes json.RawMessage) error {
return err
}
if len(stmts) != 0 && stmts[0].qc.Type == qcode.QTQuery {
c.req.Vars = nil
}
for _, s := range stmts {
if len(s.sql) == 0 {
continue
@ -75,9 +80,9 @@ func prepareStmt(gql string, varBytes json.RawMessage) error {
var key string
if s.role == nil {
key = gqlHash(gql, varBytes, "")
key = gqlHash(gql, c.req.Vars, "")
} else {
key = gqlHash(gql, varBytes, s.role.Name)
key = gqlHash(gql, c.req.Vars, s.role.Name)
}
_preparedList[key] = &preparedItem{

View File

@ -24,13 +24,26 @@ func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
func gqlHash(b string, vars []byte, role string) string {
b = strings.TrimSpace(b)
h := sha1.New()
query := "query"
s, e := 0, 0
space := []byte{' '}
starting := true
var b0, b1 byte
for {
if starting && b[e] == 'q' {
n := 0
se := e
for e < len(b) && n < len(query) && b[e] == query[n] {
n++
e++
}
if n != len(query) {
io.WriteString(h, strings.ToLower(b[se:e]))
}
}
if ws(b[e]) {
for e < len(b) && ws(b[e]) {
e++
@ -42,6 +55,7 @@ func gqlHash(b string, vars []byte, role string) string {
h.Write(space)
}
} else {
starting = false
s = e
for e < len(b) && ws(b[e]) == false {
e++

View File

@ -5,7 +5,7 @@ import (
"testing"
)
func TestRelaxHash1(t *testing.T) {
func TestGQLHash1(t *testing.T) {
var v1 = `
products(
limit: 30,
@ -24,15 +24,15 @@ func TestRelaxHash1(t *testing.T) {
price
} `
h1 := gqlHash(v1, nil)
h2 := gqlHash(v2, nil)
h1 := gqlHash(v1, nil, "")
h2 := gqlHash(v2, nil, "")
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")
}
}
func TestRelaxHash2(t *testing.T) {
func TestGQLHash2(t *testing.T) {
var v1 = `
{
products(
@ -53,15 +53,15 @@ func TestRelaxHash2(t *testing.T) {
var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } `
h1 := gqlHash(v1, nil)
h2 := gqlHash(v2, nil)
h1 := gqlHash(v1, nil, "")
h2 := gqlHash(v2, nil, "")
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")
}
}
func TestRelaxHash3(t *testing.T) {
func TestGQLHash3(t *testing.T) {
var v1 = `users {
id
email
@ -86,15 +86,44 @@ func TestRelaxHash3(t *testing.T) {
}
`
h1 := gqlHash(v1, nil)
h2 := gqlHash(v2, nil)
h1 := gqlHash(v1, nil, "")
h2 := gqlHash(v2, nil, "")
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")
}
}
func TestRelaxHashWithVars1(t *testing.T) {
func TestGQLHash4(t *testing.T) {
var v1 = `
query {
products(
limit: 30
order_by: { price: desc }
distinct: [price]
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }
) {
id
name
price
user {
id
email
}
}
}`
var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } `
h1 := gqlHash(v1, nil, "")
h2 := gqlHash(v2, nil, "")
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")
}
}
func TestGQLHashWithVars1(t *testing.T) {
var q1 = `
products(
limit: 30,
@ -136,15 +165,15 @@ func TestRelaxHashWithVars1(t *testing.T) {
"user": 123
}`
h1 := gqlHash(q1, []byte(v1))
h2 := gqlHash(q2, []byte(v2))
h1 := gqlHash(q1, []byte(v1), "user")
h2 := gqlHash(q2, []byte(v2), "user")
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")
}
}
func TestRelaxHashWithVars2(t *testing.T) {
func TestGQLHashWithVars2(t *testing.T) {
var q1 = `
products(
limit: 30,
@ -193,8 +222,8 @@ func TestRelaxHashWithVars2(t *testing.T) {
"user": 123
}`
h1 := gqlHash(q1, []byte(v1))
h2 := gqlHash(q2, []byte(v2))
h1 := gqlHash(q1, []byte(v1), "user")
h2 := gqlHash(q2, []byte(v2), "user")
if strings.Compare(h1, h2) != 0 {
t.Fatal("Hashes don't match they should")