From aff2a13ba402630ffd6e8198ae367778f3b2789b Mon Sep 17 00:00:00 2001 From: Vikram Rangnekar Date: Tue, 26 Nov 2019 01:36:19 -0500 Subject: [PATCH] Added support for query names to the allow.list --- config/allow.list | 2 +- serv/allow.go | 90 ++++++++++++++++++++++++++++++---------------- serv/core.go | 6 ++-- serv/fuzz.go | 1 + serv/fuzz_test.go | 1 + serv/prepare.go | 6 ++++ serv/utils.go | 24 +++++++++++++ serv/utils_test.go | 77 +++++++++++++++++++++++++++++++++++++++ 8 files changed, 173 insertions(+), 34 deletions(-) diff --git a/config/allow.list b/config/allow.list index 5eec9c2..92413d8 100644 --- a/config/allow.list +++ b/config/allow.list @@ -152,7 +152,7 @@ mutation { } } -query { +query getProducts { products { id name diff --git a/serv/allow.go b/serv/allow.go index 3ce2b83..0232b5b 100644 --- a/serv/allow.go +++ b/serv/allow.go @@ -18,6 +18,8 @@ const ( ) type allowItem struct { + name string + hash string uri string gql string vars json.RawMessage @@ -94,7 +96,7 @@ func initAllowList(cpath string) { } func (al *allowList) add(req *gqlReq) { - if al.active == false || len(req.ref) == 0 || len(req.Query) == 0 { + if len(req.ref) == 0 || len(req.Query) == 0 { return } @@ -119,11 +121,39 @@ func (al *allowList) add(req *gqlReq) { } } -func (al *allowList) load() { - if al.active == false { - return +func (al *allowList) upsert(query, vars []byte, uri string) { + q := string(query) + hash := gqlHash(q, vars, "") + name := gqlName(q) + + var key string + + if len(name) == 0 { + key = hash + } else { + key = name } + if i, ok := al.index[key]; !ok { + al.list = append(al.list, &allowItem{ + name: name, + hash: hash, + uri: uri, + gql: q, + vars: vars, + }) + al.index[key] = len(al.list) - 1 + } else { + item := al.list[i] + item.name = name + item.hash = hash + item.gql = q + item.vars = vars + + } +} + +func (al *allowList) load() { b, err := ioutil.ReadFile(al.filepath) if err != nil { log.Fatal(err) @@ -172,21 +202,7 @@ func (al *allowList) load() { if c == 0 { if ty == AL_QUERY { - q := string(b[s:(e + 1)]) - key := gqlHash(q, varBytes, "") - - if idx, ok := al.index[key]; !ok { - al.list = append(al.list, &allowItem{ - uri: uri, - gql: q, - vars: varBytes, - }) - al.index[key] = len(al.list) - 1 - } else { - item := al.list[idx] - item.gql = q - item.vars = varBytes - } + al.upsert(b[s:(e+1)], varBytes, uri) varBytes = nil } else if ty == AL_VARS { @@ -204,19 +220,33 @@ func (al *allowList) load() { } func (al *allowList) save(item *allowItem) { - if al.active == false { - return + item.hash = gqlHash(item.gql, item.vars, "") + item.name = gqlName(item.gql) + + if len(item.name) == 0 { + key := item.hash + + if _, ok := al.index[key]; ok { + return + } + + al.list = append(al.list, item) + al.index[key] = len(al.list) - 1 + + } else { + key := item.name + + if i, ok := al.index[key]; ok { + if al.list[i].hash == item.hash { + return + } + al.list[i] = item + } else { + al.list = append(al.list, item) + al.index[key] = len(al.list) - 1 + } } - key := gqlHash(item.gql, item.vars, "") - - if _, ok := al.index[key]; ok { - return - } - - al.list = append(al.list, item) - al.index[key] = len(al.list) - 1 - f, err := os.Create(al.filepath) if err != nil { logger.Warn().Err(err).Msgf("Failed to write allow list: %s", al.filepath) diff --git a/serv/core.go b/serv/core.go index 88a4206..d83da60 100644 --- a/serv/core.go +++ b/serv/core.go @@ -246,9 +246,9 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) { } } - // if conf.Production == false { - // _allowList.add(&c.req) - // } + if conf.Production == false { + _allowList.add(&c.req) + } if len(stmts) > 1 { if st = findStmt(role, stmts); st == nil { diff --git a/serv/fuzz.go b/serv/fuzz.go index f00eed5..34ef656 100644 --- a/serv/fuzz.go +++ b/serv/fuzz.go @@ -4,6 +4,7 @@ package serv func Fuzz(data []byte) int { gql := string(data) + gqlName(gql) gqlHash(gql, nil, "") return 1 diff --git a/serv/fuzz_test.go b/serv/fuzz_test.go index 68fe2c6..2c543bb 100644 --- a/serv/fuzz_test.go +++ b/serv/fuzz_test.go @@ -10,6 +10,7 @@ func TestFuzzCrashers(t *testing.T) { } for _, f := range crashers { + gqlName(f) gqlHash(f, nil, "") } } diff --git a/serv/prepare.go b/serv/prepare.go index 75b3c55..00e48de 100644 --- a/serv/prepare.go +++ b/serv/prepare.go @@ -112,6 +112,12 @@ func prepareStmt(c context.Context, gql string, vars []byte) error { } } + if len(vars) == 0 { + logger.Debug().Msgf("Building prepared statement for:\n %s", gql) + } else { + logger.Debug().Msgf("Building prepared statement:\n %s\n%s", vars, gql) + } + if err := tx.Commit(c); err != nil { return err } diff --git a/serv/utils.go b/serv/utils.go index 2452ed0..bd82acf 100644 --- a/serv/utils.go +++ b/serv/utils.go @@ -106,6 +106,30 @@ func al(b byte) bool { return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9') } +func gqlName(b string) string { + state, s := 0, 0 + + for i := 0; i < len(b); i++ { + switch { + case state == 2 && b[i] == '{': + return b[s:i] + case state == 2 && b[i] == ' ': + return b[s:i] + case state == 1 && b[i] == '{': + return "" + case state == 1 && b[i] != ' ': + s = i + state = 2 + case state == 1 && b[i] == ' ': + continue + case i != 0 && b[i] == ' ' && (b[i-1] == 'n' || b[i-1] == 'y'): + state = 1 + } + } + + return "" +} + func findStmt(role string, stmts []stmt) *stmt { for i := range stmts { if stmts[i].role.Name != role { diff --git a/serv/utils_test.go b/serv/utils_test.go index b8babeb..952eb61 100644 --- a/serv/utils_test.go +++ b/serv/utils_test.go @@ -229,3 +229,80 @@ func TestGQLHashWithVars2(t *testing.T) { t.Fatal("Hashes don't match they should") } } + +func TestGQLName1(t *testing.T) { + var q = ` + query { + products( + distinct: [price] + where: { id: { and: { greater_or_equals: 20, lt: 28 } } } + ) { id name } }` + + name := gqlName(q) + + if len(name) != 0 { + t.Fatal("Name should be empty, not ", name) + } +} + +func TestGQLName2(t *testing.T) { + var q = ` + query hakuna_matata { + products( + distinct: [price] + where: { id: { and: { greater_or_equals: 20, lt: 28 } } } + ) { + id + name + } + }` + + name := gqlName(q) + + if name != "hakuna_matata" { + t.Fatal("Name should be 'hakuna_matata', not ", name) + } +} + +func TestGQLName3(t *testing.T) { + var q = ` + mutation means{ users { id } }` + + // var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } ` + + name := gqlName(q) + + if name != "means" { + t.Fatal("Name should be 'means', not ", name) + } +} + +func TestGQLName4(t *testing.T) { + var q = ` + query no_worries + users { + id + } + }` + + name := gqlName(q) + + if name != "no_worries" { + t.Fatal("Name should be 'no_worries', not ", name) + } +} + +func TestGQLName5(t *testing.T) { + var q = ` + { + users { + id + } + }` + + name := gqlName(q) + + if len(name) != 0 { + t.Fatal("Name should be empty, not ", name) + } +}