Optimize prepared statement flow for RBAC
This commit is contained in:
@ -182,6 +182,8 @@ func (al *allowList) load() {
|
||||
item.vars = varBytes
|
||||
}
|
||||
|
||||
//fmt.Println("%%", item.gql, string(item.vars))
|
||||
|
||||
al.list[gqlHash(q, varBytes, "")] = item
|
||||
varBytes = nil
|
||||
|
||||
|
@ -7,9 +7,9 @@ import (
|
||||
)
|
||||
|
||||
var (
|
||||
userIDProviderKey = struct{}{}
|
||||
userIDKey = struct{}{}
|
||||
userRoleKey = struct{}{}
|
||||
userIDProviderKey = "user_id_provider"
|
||||
userIDKey = "user_id"
|
||||
userRoleKey = "user_role"
|
||||
)
|
||||
|
||||
func headerAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
|
@ -122,14 +122,14 @@ func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ck, err := r.Cookie(cookie)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Send()
|
||||
logger.Warn().Err(err).Msg("rails cookie missing")
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := ra.ParseCookie(ck.Value)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Send()
|
||||
logger.Warn().Err(err).Msg("failed to parse rails cookie")
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
@ -210,6 +210,8 @@ func initConf() (*config, error) {
|
||||
c.Roles = append(c.Roles, configRole{Name: "anon"})
|
||||
}
|
||||
|
||||
c.Validate()
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
|
@ -168,6 +168,32 @@ func newConfig() *viper.Viper {
|
||||
return vi
|
||||
}
|
||||
|
||||
func (c *config) Validate() {
|
||||
rm := make(map[string]struct{})
|
||||
|
||||
for i := range c.Roles {
|
||||
name := strings.ToLower(c.Roles[i].Name)
|
||||
if _, ok := rm[name]; ok {
|
||||
logger.Fatal().Msgf("duplicate config for role '%s'", c.Roles[i].Name)
|
||||
}
|
||||
rm[name] = struct{}{}
|
||||
}
|
||||
|
||||
tm := make(map[string]struct{})
|
||||
|
||||
for i := range c.Tables {
|
||||
name := strings.ToLower(c.Tables[i].Name)
|
||||
if _, ok := tm[name]; ok {
|
||||
logger.Fatal().Msgf("duplicate config for table '%s'", c.Tables[i].Name)
|
||||
}
|
||||
tm[name] = struct{}{}
|
||||
}
|
||||
|
||||
if len(c.RolesQuery) == 0 {
|
||||
logger.Warn().Msgf("no 'roles_query' defined.")
|
||||
}
|
||||
}
|
||||
|
||||
func (c *config) getAliasMap() map[string][]string {
|
||||
m := make(map[string][]string, len(c.Tables))
|
||||
|
||||
|
14
serv/core.go
14
serv/core.go
@ -131,16 +131,20 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
|
||||
}
|
||||
|
||||
var role string
|
||||
useRoleQuery := len(conf.RolesQuery) != 0 && isMutation(c.req.Query)
|
||||
mutation := isMutation(c.req.Query)
|
||||
useRoleQuery := len(conf.RolesQuery) != 0 && mutation
|
||||
|
||||
if useRoleQuery {
|
||||
if role, err = c.executeRoleQuery(tx); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
} else if v := c.Value(userRoleKey); v != nil {
|
||||
role = v.(string)
|
||||
} else {
|
||||
|
||||
} else if mutation {
|
||||
role = c.req.role
|
||||
|
||||
}
|
||||
|
||||
ps, ok := _preparedList[gqlHash(c.req.Query, c.req.Vars, role)]
|
||||
@ -151,7 +155,11 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *preparedItem, error) {
|
||||
var root []byte
|
||||
vars := varList(c, ps.args)
|
||||
|
||||
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
|
||||
if mutation {
|
||||
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&root)
|
||||
} else {
|
||||
err = tx.QueryRow(c, ps.stmt.SQL, vars...).Scan(&c.req.role, &root)
|
||||
}
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
@ -37,8 +37,8 @@ type gqlReq struct {
|
||||
type variables map[string]json.RawMessage
|
||||
|
||||
type gqlResp struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
Error string `json:"message,omitempty"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
Extensions *extensions `json:"extensions,omitempty"`
|
||||
}
|
||||
|
||||
@ -102,7 +102,9 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
|
||||
err = ctx.handleReq(w, r)
|
||||
|
||||
if err == errUnauthorized {
|
||||
http.Error(w, "Not authorized", 401)
|
||||
w.WriteHeader(http.StatusUnauthorized)
|
||||
json.NewEncoder(w).Encode(gqlResp{Error: err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
|
@ -31,6 +31,7 @@ func initPreparedList() {
|
||||
}
|
||||
|
||||
for _, v := range _allowList.list {
|
||||
|
||||
err := prepareStmt(v.gql, v.vars)
|
||||
if err != nil {
|
||||
logger.Warn().Str("gql", v.gql).Err(err).Send()
|
||||
@ -52,6 +53,10 @@ func prepareStmt(gql string, varBytes json.RawMessage) error {
|
||||
return err
|
||||
}
|
||||
|
||||
if len(stmts) != 0 && stmts[0].qc.Type == qcode.QTQuery {
|
||||
c.req.Vars = nil
|
||||
}
|
||||
|
||||
for _, s := range stmts {
|
||||
if len(s.sql) == 0 {
|
||||
continue
|
||||
@ -75,9 +80,9 @@ func prepareStmt(gql string, varBytes json.RawMessage) error {
|
||||
var key string
|
||||
|
||||
if s.role == nil {
|
||||
key = gqlHash(gql, varBytes, "")
|
||||
key = gqlHash(gql, c.req.Vars, "")
|
||||
} else {
|
||||
key = gqlHash(gql, varBytes, s.role.Name)
|
||||
key = gqlHash(gql, c.req.Vars, s.role.Name)
|
||||
}
|
||||
|
||||
_preparedList[key] = &preparedItem{
|
||||
|
@ -24,13 +24,26 @@ func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
|
||||
func gqlHash(b string, vars []byte, role string) string {
|
||||
b = strings.TrimSpace(b)
|
||||
h := sha1.New()
|
||||
query := "query"
|
||||
|
||||
s, e := 0, 0
|
||||
space := []byte{' '}
|
||||
starting := true
|
||||
|
||||
var b0, b1 byte
|
||||
|
||||
for {
|
||||
if starting && b[e] == 'q' {
|
||||
n := 0
|
||||
se := e
|
||||
for e < len(b) && n < len(query) && b[e] == query[n] {
|
||||
n++
|
||||
e++
|
||||
}
|
||||
if n != len(query) {
|
||||
io.WriteString(h, strings.ToLower(b[se:e]))
|
||||
}
|
||||
}
|
||||
if ws(b[e]) {
|
||||
for e < len(b) && ws(b[e]) {
|
||||
e++
|
||||
@ -42,6 +55,7 @@ func gqlHash(b string, vars []byte, role string) string {
|
||||
h.Write(space)
|
||||
}
|
||||
} else {
|
||||
starting = false
|
||||
s = e
|
||||
for e < len(b) && ws(b[e]) == false {
|
||||
e++
|
||||
|
@ -5,7 +5,7 @@ import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRelaxHash1(t *testing.T) {
|
||||
func TestGQLHash1(t *testing.T) {
|
||||
var v1 = `
|
||||
products(
|
||||
limit: 30,
|
||||
@ -24,15 +24,15 @@ func TestRelaxHash1(t *testing.T) {
|
||||
price
|
||||
} `
|
||||
|
||||
h1 := gqlHash(v1, nil)
|
||||
h2 := gqlHash(v2, nil)
|
||||
h1 := gqlHash(v1, nil, "")
|
||||
h2 := gqlHash(v2, nil, "")
|
||||
|
||||
if strings.Compare(h1, h2) != 0 {
|
||||
t.Fatal("Hashes don't match they should")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelaxHash2(t *testing.T) {
|
||||
func TestGQLHash2(t *testing.T) {
|
||||
var v1 = `
|
||||
{
|
||||
products(
|
||||
@ -53,15 +53,15 @@ func TestRelaxHash2(t *testing.T) {
|
||||
|
||||
var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } `
|
||||
|
||||
h1 := gqlHash(v1, nil)
|
||||
h2 := gqlHash(v2, nil)
|
||||
h1 := gqlHash(v1, nil, "")
|
||||
h2 := gqlHash(v2, nil, "")
|
||||
|
||||
if strings.Compare(h1, h2) != 0 {
|
||||
t.Fatal("Hashes don't match they should")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelaxHash3(t *testing.T) {
|
||||
func TestGQLHash3(t *testing.T) {
|
||||
var v1 = `users {
|
||||
id
|
||||
email
|
||||
@ -86,15 +86,44 @@ func TestRelaxHash3(t *testing.T) {
|
||||
}
|
||||
`
|
||||
|
||||
h1 := gqlHash(v1, nil)
|
||||
h2 := gqlHash(v2, nil)
|
||||
h1 := gqlHash(v1, nil, "")
|
||||
h2 := gqlHash(v2, nil, "")
|
||||
|
||||
if strings.Compare(h1, h2) != 0 {
|
||||
t.Fatal("Hashes don't match they should")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelaxHashWithVars1(t *testing.T) {
|
||||
func TestGQLHash4(t *testing.T) {
|
||||
var v1 = `
|
||||
query {
|
||||
products(
|
||||
limit: 30
|
||||
order_by: { price: desc }
|
||||
distinct: [price]
|
||||
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }
|
||||
) {
|
||||
id
|
||||
name
|
||||
price
|
||||
user {
|
||||
id
|
||||
email
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } `
|
||||
|
||||
h1 := gqlHash(v1, nil, "")
|
||||
h2 := gqlHash(v2, nil, "")
|
||||
|
||||
if strings.Compare(h1, h2) != 0 {
|
||||
t.Fatal("Hashes don't match they should")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGQLHashWithVars1(t *testing.T) {
|
||||
var q1 = `
|
||||
products(
|
||||
limit: 30,
|
||||
@ -136,15 +165,15 @@ func TestRelaxHashWithVars1(t *testing.T) {
|
||||
"user": 123
|
||||
}`
|
||||
|
||||
h1 := gqlHash(q1, []byte(v1))
|
||||
h2 := gqlHash(q2, []byte(v2))
|
||||
h1 := gqlHash(q1, []byte(v1), "user")
|
||||
h2 := gqlHash(q2, []byte(v2), "user")
|
||||
|
||||
if strings.Compare(h1, h2) != 0 {
|
||||
t.Fatal("Hashes don't match they should")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRelaxHashWithVars2(t *testing.T) {
|
||||
func TestGQLHashWithVars2(t *testing.T) {
|
||||
var q1 = `
|
||||
products(
|
||||
limit: 30,
|
||||
@ -193,8 +222,8 @@ func TestRelaxHashWithVars2(t *testing.T) {
|
||||
"user": 123
|
||||
}`
|
||||
|
||||
h1 := gqlHash(q1, []byte(v1))
|
||||
h2 := gqlHash(q2, []byte(v2))
|
||||
h1 := gqlHash(q1, []byte(v1), "user")
|
||||
h2 := gqlHash(q2, []byte(v2), "user")
|
||||
|
||||
if strings.Compare(h1, h2) != 0 {
|
||||
t.Fatal("Hashes don't match they should")
|
||||
|
Reference in New Issue
Block a user