diff --git a/allow/allow.go b/allow/allow.go index ef66a99..5d84af6 100644 --- a/allow/allow.go +++ b/allow/allow.go @@ -18,11 +18,11 @@ const ( ) type Item struct { - Name string - key string - URI string - Query string - Vars json.RawMessage + Name string + key string + Query string + Vars json.RawMessage + Comment string } type List struct { @@ -105,7 +105,7 @@ func (al *List) IsPersist() bool { return al.saveChan != nil } -func (al *List) Add(vars []byte, query, uri string) error { +func (al *List) Set(vars []byte, query, comment string) error { if al.saveChan == nil { return errors.New("allow.list is read-only") } @@ -129,9 +129,9 @@ func (al *List) Add(vars []byte, query, uri string) error { } al.saveChan <- Item{ - URI: uri, - Query: q, - Vars: vars, + Comment: comment, + Query: q, + Vars: vars, } return nil @@ -149,7 +149,7 @@ func (al *List) Load() ([]Item, error) { return list, nil } - var uri string + var comment bytes.Buffer var varBytes []byte itemMap := make(map[string]struct{}) @@ -166,7 +166,7 @@ func (al *List) Load() ([]Item, error) { e++ } if (e - s) > 2 { - uri = strings.TrimSpace(string(b[(s + 1):e])) + comment.Write(b[(s + 1):(e + 1)]) } } @@ -207,13 +207,14 @@ func (al *List) Load() ([]Item, error) { if _, ok := itemMap[key]; !ok { v := Item{ - Name: name, - key: key, - URI: uri, - Query: query, - Vars: varBytes, + Name: name, + key: key, + Query: query, + Vars: varBytes, + Comment: comment.String(), } list = append(list, v) + comment.Reset() } varBytes = nil @@ -252,6 +253,9 @@ func (al *List) save(item Item) error { } if index != -1 { + if len(list[index].Comment) != 0 { + item.Comment = list[index].Comment + } list[index] = item } else { list = append(list, item) @@ -269,9 +273,29 @@ func (al *List) save(item Item) error { }) for _, v := range list { - _, err := f.WriteString(fmt.Sprintf("# %s\n\n", v.URI)) - if err != nil { - return err + cmtLines := strings.Split(v.Comment, "\n") + + i := 0 + for _, c := range cmtLines { + if c = strings.TrimSpace(c); len(c) == 0 { + continue + } + + _, err := f.WriteString(fmt.Sprintf("# %s\n", c)) + if err != nil { + return err + } + i++ + } + + if i != 0 { + if _, err := f.WriteString("\n"); err != nil { + return err + } + } else { + if _, err := f.WriteString(fmt.Sprintf("# Query named %s\n\n", v.Name)); err != nil { + return err + } } if len(v.Vars) != 0 && !bytes.Equal(v.Vars, []byte("{}")) { @@ -317,17 +341,13 @@ func QueryName(b string) string { for i := 0; i < len(b); i++ { switch { - case state == 2 && b[i] == '{': - return b[s:i] - case state == 2 && b[i] == ' ': + case state == 2 && !isValidNameChar(b[i]): return b[s:i] case state == 1 && b[i] == '{': return "" - case state == 1 && b[i] != ' ': + case state == 1 && isValidNameChar(b[i]): s = i state = 2 - case state == 1 && b[i] == ' ': - continue case i != 0 && b[i] == ' ' && (b[i-1] == 'n' || b[i-1] == 'y'): state = 1 } @@ -335,3 +355,7 @@ func QueryName(b string) string { return "" } + +func isValidNameChar(c byte) bool { + return (c >= 'a' && c <= 'z') || (c >= 'A' && c <= 'Z') || (c >= '0' && c <= '9') || c == '_' +} diff --git a/allow/allow_test.go b/allow/allow_test.go index f92dc76..07f81b6 100644 --- a/allow/allow_test.go +++ b/allow/allow_test.go @@ -21,7 +21,9 @@ func TestGQLName1(t *testing.T) { func TestGQLName2(t *testing.T) { var q = ` - query hakuna_matata { + query hakuna_matata + + { products( distinct: [price] where: { id: { and: { greater_or_equals: 20, lt: 28 } } } diff --git a/serv/core.go b/serv/core.go index 5302a91..2fdfe4b 100644 --- a/serv/core.go +++ b/serv/core.go @@ -242,7 +242,7 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) { } if allowList.IsPersist() { - if err := allowList.Add(c.req.Vars, c.req.Query, c.req.ref); err != nil { + if err := allowList.Set(c.req.Vars, c.req.Query, c.req.ref); err != nil { return nil, nil, err } }