Add role based access control

This commit is contained in:
Vikram Rangnekar
2019-10-14 02:51:36 -04:00
parent 85a74ed30c
commit deb5b93c81
13 changed files with 645 additions and 350 deletions

99
qcode/config.go Normal file
View File

@ -0,0 +1,99 @@
package qcode
type Config struct {
Blocklist []string
KeepArgs bool
}
type QueryConfig struct {
Limit int
Filter []string
Columns []string
DisableFunctions bool
}
type InsertConfig struct {
Filter []string
Columns []string
Set map[string]string
}
type UpdateConfig struct {
Filter []string
Columns []string
Set map[string]string
}
type DeleteConfig struct {
Filter []string
Columns []string
}
type TRConfig struct {
Query QueryConfig
Insert InsertConfig
Update UpdateConfig
Delete DeleteConfig
}
type trval struct {
query struct {
limit string
fil *Exp
cols map[string]struct{}
disable struct {
funcs bool
}
}
insert struct {
fil *Exp
cols map[string]struct{}
set map[string]string
}
update struct {
fil *Exp
cols map[string]struct{}
set map[string]string
}
delete struct {
fil *Exp
cols map[string]struct{}
}
}
func (trv *trval) allowedColumns(qt QType) map[string]struct{} {
switch qt {
case QTQuery:
return trv.query.cols
case QTInsert:
return trv.insert.cols
case QTUpdate:
return trv.update.cols
case QTDelete:
return trv.insert.cols
case QTUpsert:
return trv.insert.cols
}
return nil
}
func (trv *trval) filter(qt QType) *Exp {
switch qt {
case QTQuery:
return trv.query.fil
case QTInsert:
return trv.insert.fil
case QTUpdate:
return trv.update.fil
case QTDelete:
return trv.delete.fil
case QTUpsert:
return trv.insert.fil
}
return nil
}

View File

@ -5,7 +5,7 @@ func FuzzerEntrypoint(data []byte) int {
//testData := string(data)
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile(data)
_, err := qcompile.Compile(data, "user")
if err != nil {
return -1
}

View File

@ -46,13 +46,18 @@ func compareOp(op1, op2 Operation) error {
*/
func TestCompile1(t *testing.T) {
qcompile, _ := NewCompiler(Config{})
qc, _ := NewCompiler(Config{})
qc.AddRole("user", "product", TRConfig{
Query: QueryConfig{
Columns: []string{"id", "Name"},
},
})
_, err := qcompile.Compile([]byte(`
_, err := qc.Compile([]byte(`
product(id: 15) {
id
name
}`))
}`), "user")
if err != nil {
t.Fatal(err)
@ -60,13 +65,18 @@ func TestCompile1(t *testing.T) {
}
func TestCompile2(t *testing.T) {
qcompile, _ := NewCompiler(Config{})
qc, _ := NewCompiler(Config{})
qc.AddRole("user", "product", TRConfig{
Query: QueryConfig{
Columns: []string{"ID"},
},
})
_, err := qcompile.Compile([]byte(`
_, err := qc.Compile([]byte(`
query { product(id: 15) {
id
name
} }`))
} }`), "user")
if err != nil {
t.Fatal(err)
@ -74,15 +84,20 @@ func TestCompile2(t *testing.T) {
}
func TestCompile3(t *testing.T) {
qcompile, _ := NewCompiler(Config{})
qc, _ := NewCompiler(Config{})
qc.AddRole("user", "product", TRConfig{
Query: QueryConfig{
Columns: []string{"ID"},
},
})
_, err := qcompile.Compile([]byte(`
_, err := qc.Compile([]byte(`
mutation {
product(id: 15, name: "Test") {
id
name
}
}`))
}`), "user")
if err != nil {
t.Fatal(err)
@ -91,7 +106,7 @@ func TestCompile3(t *testing.T) {
func TestInvalidCompile1(t *testing.T) {
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile([]byte(`#`))
_, err := qcompile.Compile([]byte(`#`), "user")
if err == nil {
t.Fatal(errors.New("expecting an error"))
@ -100,7 +115,7 @@ func TestInvalidCompile1(t *testing.T) {
func TestInvalidCompile2(t *testing.T) {
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile([]byte(`{u(where:{not:0})}`))
_, err := qcompile.Compile([]byte(`{u(where:{not:0})}`), "user")
if err == nil {
t.Fatal(errors.New("expecting an error"))
@ -109,7 +124,7 @@ func TestInvalidCompile2(t *testing.T) {
func TestEmptyCompile(t *testing.T) {
qcompile, _ := NewCompiler(Config{})
_, err := qcompile.Compile([]byte(``))
_, err := qcompile.Compile([]byte(``), "user")
if err == nil {
t.Fatal(errors.New("expecting an error"))
@ -144,7 +159,7 @@ func BenchmarkQCompile(b *testing.B) {
b.ReportAllocs()
for n := 0; n < b.N; n++ {
_, err := qcompile.Compile(gql)
_, err := qcompile.Compile(gql, "user")
if err != nil {
b.Fatal(err)
@ -160,7 +175,7 @@ func BenchmarkQCompileP(b *testing.B) {
b.RunParallel(func(pb *testing.PB) {
for pb.Next() {
_, err := qcompile.Compile(gql)
_, err := qcompile.Compile(gql, "user")
if err != nil {
b.Fatal(err)

View File

@ -3,6 +3,7 @@ package qcode
import (
"errors"
"fmt"
"strconv"
"strings"
"sync"
@ -17,23 +18,16 @@ const (
maxSelectors = 30
QTQuery QType = iota + 1
QTMutation
ActionInsert Action = iota + 1
ActionUpdate
ActionDelete
ActionUpsert
QTInsert
QTUpdate
QTDelete
QTUpsert
)
type QCode struct {
Type QType
Selects []Select
}
type Column struct {
Table string
Name string
FieldName string
Type QType
ActionVar string
Selects []Select
}
type Select struct {
@ -47,9 +41,15 @@ type Select struct {
OrderBy []*OrderBy
DistinctOn []string
Paging Paging
Action Action
ActionVar string
Children []int32
Functions bool
Allowed map[string]struct{}
}
type Column struct {
Table string
Name string
FieldName string
}
type Exp struct {
@ -77,8 +77,9 @@ type OrderBy struct {
}
type Paging struct {
Limit string
Offset string
Limit string
Offset string
NoLimit bool
}
type ExpOp int
@ -145,81 +146,23 @@ const (
OrderDescNullsLast
)
type Filters struct {
All map[string][]string
Query map[string][]string
Insert map[string][]string
Update map[string][]string
Delete map[string][]string
}
type Config struct {
DefaultFilter []string
FilterMap Filters
Blocklist []string
KeepArgs bool
}
type Compiler struct {
df *Exp
fm struct {
all map[string]*Exp
query map[string]*Exp
insert map[string]*Exp
update map[string]*Exp
delete map[string]*Exp
}
tr map[string]map[string]*trval
bl map[string]struct{}
ka bool
}
var opMap = map[parserType]QType{
opQuery: QTQuery,
opMutate: QTMutation,
}
var expPool = sync.Pool{
New: func() interface{} { return &Exp{doFree: true} },
}
func NewCompiler(c Config) (*Compiler, error) {
var err error
co := &Compiler{ka: c.KeepArgs}
co.tr = make(map[string]map[string]*trval)
co.bl = make(map[string]struct{}, len(c.Blocklist))
for i := range c.Blocklist {
co.bl[c.Blocklist[i]] = struct{}{}
}
co.df, err = compileFilter(c.DefaultFilter)
if err != nil {
return nil, err
}
co.fm.all, err = buildFilters(c.FilterMap.All)
if err != nil {
return nil, err
}
co.fm.query, err = buildFilters(c.FilterMap.Query)
if err != nil {
return nil, err
}
co.fm.insert, err = buildFilters(c.FilterMap.Insert)
if err != nil {
return nil, err
}
co.fm.update, err = buildFilters(c.FilterMap.Update)
if err != nil {
return nil, err
}
co.fm.delete, err = buildFilters(c.FilterMap.Delete)
if err != nil {
return nil, err
co.bl[strings.ToLower(c.Blocklist[i])] = struct{}{}
}
seedExp := [100]Exp{}
@ -232,58 +175,99 @@ func NewCompiler(c Config) (*Compiler, error) {
return co, nil
}
func buildFilters(filMap map[string][]string) (map[string]*Exp, error) {
fm := make(map[string]*Exp, len(filMap))
func (com *Compiler) AddRole(role, table string, trc TRConfig) error {
var err error
trv := &trval{}
for k, v := range filMap {
fil, err := compileFilter(v)
if err != nil {
return nil, err
toMap := func(cols []string) map[string]struct{} {
m := make(map[string]struct{}, len(cols))
for i := range cols {
m[strings.ToLower(cols[i])] = struct{}{}
}
singular := flect.Singularize(k)
plural := flect.Pluralize(k)
fm[singular] = fil
fm[plural] = fil
return m
}
return fm, nil
// query config
trv.query.fil, err = compileFilter(trc.Query.Filter)
if err != nil {
return err
}
if trc.Query.Limit > 0 {
trv.query.limit = strconv.Itoa(trc.Query.Limit)
}
trv.query.cols = toMap(trc.Query.Columns)
trv.query.disable.funcs = trc.Query.DisableFunctions
// insert config
if trv.insert.fil, err = compileFilter(trc.Insert.Filter); err != nil {
return err
}
trv.insert.cols = toMap(trc.Insert.Columns)
// update config
if trv.update.fil, err = compileFilter(trc.Update.Filter); err != nil {
return err
}
trv.insert.cols = toMap(trc.Insert.Columns)
trv.insert.set = trc.Insert.Set
// delete config
if trv.delete.fil, err = compileFilter(trc.Delete.Filter); err != nil {
return err
}
trv.delete.cols = toMap(trc.Delete.Columns)
singular := flect.Singularize(table)
plural := flect.Pluralize(table)
if _, ok := com.tr[role]; !ok {
com.tr[role] = make(map[string]*trval)
}
com.tr[role][singular] = trv
com.tr[role][plural] = trv
return nil
}
func (com *Compiler) Compile(query []byte) (*QCode, error) {
var qc QCode
func (com *Compiler) Compile(query []byte, role string) (*QCode, error) {
var err error
qc := QCode{Type: QTQuery}
op, err := Parse(query)
if err != nil {
return nil, err
}
qc.Selects, err = com.compileQuery(op)
if err != nil {
if err = com.compileQuery(&qc, op, role); err != nil {
return nil, err
}
if t, ok := opMap[op.Type]; ok {
qc.Type = t
} else {
return nil, fmt.Errorf("Unknown operation type %d", op.Type)
}
opPool.Put(op)
return &qc, nil
}
func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
func (com *Compiler) compileQuery(qc *QCode, op *Operation, role string) error {
id := int32(0)
parentID := int32(0)
if len(op.Fields) == 0 {
return errors.New("invalid graphql no query found")
}
if op.Type == opMutate {
if err := com.setMutationType(qc, op.Fields[0].Args); err != nil {
return err
}
}
selects := make([]Select, 0, 5)
st := NewStack()
action := qc.Type
if len(op.Fields) == 0 {
return nil, errors.New("empty query")
return errors.New("empty query")
}
st.Push(op.Fields[0].ID)
@ -293,7 +277,7 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
}
if id >= maxSelectors {
return nil, fmt.Errorf("selector limit reached (%d)", maxSelectors)
return fmt.Errorf("selector limit reached (%d)", maxSelectors)
}
fid := st.Pop()
@ -303,14 +287,28 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
continue
}
trv, ok := com.tr[role][field.Name]
if !ok {
continue
}
selects = append(selects, Select{
ID: id,
ParentID: parentID,
Table: field.Name,
Children: make([]int32, 0, 5),
Allowed: trv.allowedColumns(action),
})
s := &selects[(len(selects) - 1)]
if action == QTQuery {
s.Functions = !trv.query.disable.funcs
if len(trv.query.limit) != 0 {
s.Paging.Limit = trv.query.limit
}
}
if s.ID != 0 {
p := &selects[s.ParentID]
p.Children = append(p.Children, s.ID)
@ -322,12 +320,13 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
s.FieldName = s.Table
}
err := com.compileArgs(s, field.Args)
err := com.compileArgs(qc, s, field.Args)
if err != nil {
return nil, err
return err
}
s.Cols = make([]Column, 0, len(field.Children))
action = QTQuery
for _, cid := range field.Children {
f := op.Fields[cid]
@ -356,36 +355,14 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
}
if id == 0 {
return nil, errors.New("invalid query")
return errors.New("invalid query")
}
var fil *Exp
root := &selects[0]
switch op.Type {
case opQuery:
fil, _ = com.fm.query[root.Table]
case opMutate:
switch root.Action {
case ActionInsert:
fil, _ = com.fm.insert[root.Table]
case ActionUpdate:
fil, _ = com.fm.update[root.Table]
case ActionDelete:
fil, _ = com.fm.delete[root.Table]
case ActionUpsert:
fil, _ = com.fm.insert[root.Table]
}
}
if fil == nil {
fil, _ = com.fm.all[root.Table]
}
if fil == nil {
fil = com.df
if trv, ok := com.tr[role][op.Fields[0].Name]; ok {
fil = trv.filter(qc.Type)
}
if fil != nil && fil.Op != OpNop {
@ -403,10 +380,11 @@ func (com *Compiler) compileQuery(op *Operation) ([]Select, error) {
}
}
return selects[:id], nil
qc.Selects = selects[:id]
return nil
}
func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
func (com *Compiler) compileArgs(qc *QCode, sel *Select, args []Arg) error {
var err error
if com.ka {
@ -418,9 +396,7 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
switch arg.Name {
case "id":
if sel.ID == 0 {
err = com.compileArgID(sel, arg)
}
err = com.compileArgID(sel, arg)
case "search":
err = com.compileArgSearch(sel, arg)
case "where":
@ -433,18 +409,6 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
err = com.compileArgLimit(sel, arg)
case "offset":
err = com.compileArgOffset(sel, arg)
case "insert":
sel.Action = ActionInsert
err = com.compileArgAction(sel, arg)
case "update":
sel.Action = ActionUpdate
err = com.compileArgAction(sel, arg)
case "upsert":
sel.Action = ActionUpsert
err = com.compileArgAction(sel, arg)
case "delete":
sel.Action = ActionDelete
err = com.compileArgAction(sel, arg)
}
if err != nil {
@ -461,6 +425,45 @@ func (com *Compiler) compileArgs(sel *Select, args []Arg) error {
return nil
}
func (com *Compiler) setMutationType(qc *QCode, args []Arg) error {
setActionVar := func(arg *Arg) error {
if arg.Val.Type != nodeVar {
return fmt.Errorf("value for argument '%s' must be a variable", arg.Name)
}
qc.ActionVar = arg.Val.Val
return nil
}
for i := range args {
arg := &args[i]
switch arg.Name {
case "insert":
qc.Type = QTInsert
return setActionVar(arg)
case "update":
qc.Type = QTUpdate
return setActionVar(arg)
case "upsert":
qc.Type = QTUpsert
return setActionVar(arg)
case "delete":
qc.Type = QTDelete
if arg.Val.Type != nodeBool {
return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name)
}
if arg.Val.Val == "false" {
qc.Type = QTQuery
}
return nil
}
}
return nil
}
func (com *Compiler) compileArgObj(st *util.Stack, arg *Arg) (*Exp, error) {
if arg.Val.Type != nodeObj {
return nil, fmt.Errorf("expecting an object")
@ -540,6 +543,10 @@ func (com *Compiler) compileArgNode(st *util.Stack, node *Node, usePool bool) (*
}
func (com *Compiler) compileArgID(sel *Select, arg *Arg) error {
if sel.ID != 0 {
return nil
}
if sel.Where != nil && sel.Where.Op == OpEqID {
return nil
}
@ -732,26 +739,6 @@ func (com *Compiler) compileArgOffset(sel *Select, arg *Arg) error {
return nil
}
func (com *Compiler) compileArgAction(sel *Select, arg *Arg) error {
switch sel.Action {
case ActionDelete:
if arg.Val.Type != nodeBool {
return fmt.Errorf("value for argument '%s' must be a boolean", arg.Name)
}
if arg.Val.Val == "false" {
sel.Action = 0
}
default:
if arg.Val.Type != nodeVar {
return fmt.Errorf("value for argument '%s' must be a variable", arg.Name)
}
sel.ActionVar = arg.Val.Val
}
return nil
}
func newExp(st *util.Stack, node *Node, usePool bool) (*Exp, error) {
name := node.Name
if name[0] == '_' {