Added support for query names to the allow.list

This commit is contained in:
Vikram Rangnekar 2019-11-26 01:36:19 -05:00
parent f518d5fc69
commit aff2a13ba4
8 changed files with 173 additions and 34 deletions

View File

@ -152,7 +152,7 @@ mutation {
}
}
query {
query getProducts {
products {
id
name

View File

@ -18,6 +18,8 @@ const (
)
type allowItem struct {
name string
hash string
uri string
gql string
vars json.RawMessage
@ -94,7 +96,7 @@ func initAllowList(cpath string) {
}
func (al *allowList) add(req *gqlReq) {
if al.active == false || len(req.ref) == 0 || len(req.Query) == 0 {
if len(req.ref) == 0 || len(req.Query) == 0 {
return
}
@ -119,11 +121,39 @@ func (al *allowList) add(req *gqlReq) {
}
}
func (al *allowList) load() {
if al.active == false {
return
func (al *allowList) upsert(query, vars []byte, uri string) {
q := string(query)
hash := gqlHash(q, vars, "")
name := gqlName(q)
var key string
if len(name) == 0 {
key = hash
} else {
key = name
}
if i, ok := al.index[key]; !ok {
al.list = append(al.list, &allowItem{
name: name,
hash: hash,
uri: uri,
gql: q,
vars: vars,
})
al.index[key] = len(al.list) - 1
} else {
item := al.list[i]
item.name = name
item.hash = hash
item.gql = q
item.vars = vars
}
}
func (al *allowList) load() {
b, err := ioutil.ReadFile(al.filepath)
if err != nil {
log.Fatal(err)
@ -172,21 +202,7 @@ func (al *allowList) load() {
if c == 0 {
if ty == AL_QUERY {
q := string(b[s:(e + 1)])
key := gqlHash(q, varBytes, "")
if idx, ok := al.index[key]; !ok {
al.list = append(al.list, &allowItem{
uri: uri,
gql: q,
vars: varBytes,
})
al.index[key] = len(al.list) - 1
} else {
item := al.list[idx]
item.gql = q
item.vars = varBytes
}
al.upsert(b[s:(e+1)], varBytes, uri)
varBytes = nil
} else if ty == AL_VARS {
@ -204,19 +220,33 @@ func (al *allowList) load() {
}
func (al *allowList) save(item *allowItem) {
if al.active == false {
return
item.hash = gqlHash(item.gql, item.vars, "")
item.name = gqlName(item.gql)
if len(item.name) == 0 {
key := item.hash
if _, ok := al.index[key]; ok {
return
}
al.list = append(al.list, item)
al.index[key] = len(al.list) - 1
} else {
key := item.name
if i, ok := al.index[key]; ok {
if al.list[i].hash == item.hash {
return
}
al.list[i] = item
} else {
al.list = append(al.list, item)
al.index[key] = len(al.list) - 1
}
}
key := gqlHash(item.gql, item.vars, "")
if _, ok := al.index[key]; ok {
return
}
al.list = append(al.list, item)
al.index[key] = len(al.list) - 1
f, err := os.Create(al.filepath)
if err != nil {
logger.Warn().Err(err).Msgf("Failed to write allow list: %s", al.filepath)

View File

@ -246,9 +246,9 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
}
}
// if conf.Production == false {
// _allowList.add(&c.req)
// }
if conf.Production == false {
_allowList.add(&c.req)
}
if len(stmts) > 1 {
if st = findStmt(role, stmts); st == nil {

View File

@ -4,6 +4,7 @@ package serv
func Fuzz(data []byte) int {
gql := string(data)
gqlName(gql)
gqlHash(gql, nil, "")
return 1

View File

@ -10,6 +10,7 @@ func TestFuzzCrashers(t *testing.T) {
}
for _, f := range crashers {
gqlName(f)
gqlHash(f, nil, "")
}
}

View File

@ -112,6 +112,12 @@ func prepareStmt(c context.Context, gql string, vars []byte) error {
}
}
if len(vars) == 0 {
logger.Debug().Msgf("Building prepared statement for:\n %s", gql)
} else {
logger.Debug().Msgf("Building prepared statement:\n %s\n%s", vars, gql)
}
if err := tx.Commit(c); err != nil {
return err
}

View File

@ -106,6 +106,30 @@ func al(b byte) bool {
return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9')
}
func gqlName(b string) string {
state, s := 0, 0
for i := 0; i < len(b); i++ {
switch {
case state == 2 && b[i] == '{':
return b[s:i]
case state == 2 && b[i] == ' ':
return b[s:i]
case state == 1 && b[i] == '{':
return ""
case state == 1 && b[i] != ' ':
s = i
state = 2
case state == 1 && b[i] == ' ':
continue
case i != 0 && b[i] == ' ' && (b[i-1] == 'n' || b[i-1] == 'y'):
state = 1
}
}
return ""
}
func findStmt(role string, stmts []stmt) *stmt {
for i := range stmts {
if stmts[i].role.Name != role {

View File

@ -229,3 +229,80 @@ func TestGQLHashWithVars2(t *testing.T) {
t.Fatal("Hashes don't match they should")
}
}
func TestGQLName1(t *testing.T) {
var q = `
query {
products(
distinct: [price]
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }
) { id name } }`
name := gqlName(q)
if len(name) != 0 {
t.Fatal("Name should be empty, not ", name)
}
}
func TestGQLName2(t *testing.T) {
var q = `
query hakuna_matata {
products(
distinct: [price]
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }
) {
id
name
}
}`
name := gqlName(q)
if name != "hakuna_matata" {
t.Fatal("Name should be 'hakuna_matata', not ", name)
}
}
func TestGQLName3(t *testing.T) {
var q = `
mutation means{ users { id } }`
// var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } `
name := gqlName(q)
if name != "means" {
t.Fatal("Name should be 'means', not ", name)
}
}
func TestGQLName4(t *testing.T) {
var q = `
query no_worries
users {
id
}
}`
name := gqlName(q)
if name != "no_worries" {
t.Fatal("Name should be 'no_worries', not ", name)
}
}
func TestGQLName5(t *testing.T) {
var q = `
{
users {
id
}
}`
name := gqlName(q)
if len(name) != 0 {
t.Fatal("Name should be empty, not ", name)
}
}