super-graph/serv/allow.go

277 lines
4.6 KiB
Go
Raw Normal View History

2019-07-29 07:13:33 +02:00
package serv
import (
2019-09-26 06:35:31 +02:00
"bytes"
2019-09-05 06:09:56 +02:00
"encoding/json"
2019-07-29 07:13:33 +02:00
"fmt"
"io/ioutil"
"log"
"os"
"path"
2019-07-29 07:13:33 +02:00
"sort"
"strings"
)
2019-09-05 06:09:56 +02:00
// Section kinds recognized while parsing an allow.list file.
const (
	AL_QUERY int = 1 + iota // a "query" or "mutation" GraphQL document
	AL_VARS                 // a "variables" JSON object for the next document
)
2019-07-29 07:13:33 +02:00
// allowItem is a single allow-list entry.
type allowItem struct {
	// uri is the "# <uri>" section header the entry is written under;
	// gql is the GraphQL document text.
	uri, gql string
	// vars holds the raw JSON variables stored with the document, if any.
	vars json.RawMessage
}
var _allowList allowList
type allowList struct {
2019-10-25 07:39:59 +02:00
list []*allowItem
index map[string]int
2019-08-02 16:07:50 +02:00
filepath string
2019-07-29 07:13:33 +02:00
saveChan chan *allowItem
2019-09-20 06:19:11 +02:00
active bool
2019-07-29 07:13:33 +02:00
}
func initAllowList(cpath string) {
2019-07-29 07:13:33 +02:00
_allowList = allowList{
2019-10-25 07:39:59 +02:00
index: make(map[string]int),
2019-07-29 07:13:33 +02:00
saveChan: make(chan *allowItem),
2019-09-20 06:19:11 +02:00
active: true,
2019-07-29 07:13:33 +02:00
}
2019-08-02 16:07:50 +02:00
if len(cpath) != 0 {
fp := path.Join(cpath, "allow.list")
2019-08-02 16:07:50 +02:00
if _, err := os.Stat(fp); err == nil {
_allowList.filepath = fp
} else if !os.IsNotExist(err) {
logger.Fatal().Err(err).Send()
2019-08-02 16:07:50 +02:00
}
}
if len(_allowList.filepath) == 0 {
fp := "./allow.list"
if _, err := os.Stat(fp); err == nil {
_allowList.filepath = fp
} else if !os.IsNotExist(err) {
logger.Fatal().Err(err).Send()
2019-08-02 16:07:50 +02:00
}
}
if len(_allowList.filepath) == 0 {
fp := "./config/allow.list"
if _, err := os.Stat(fp); err == nil {
_allowList.filepath = fp
} else if !os.IsNotExist(err) {
logger.Fatal().Err(err).Send()
2019-08-02 16:07:50 +02:00
}
}
if len(_allowList.filepath) == 0 {
if conf.UseAllowList {
logger.Fatal().Msg("allow.list not found")
}
2019-08-02 16:07:50 +02:00
if len(cpath) == 0 {
_allowList.filepath = "./config/allow.list"
} else {
_allowList.filepath = path.Join(cpath, "allow.list")
}
logger.Warn().Msg("allow.list not found")
} else {
_allowList.load()
}
2019-07-29 07:13:33 +02:00
go func() {
for v := range _allowList.saveChan {
_allowList.save(v)
}
}()
}
func (al *allowList) add(req *gqlReq) {
2019-09-20 06:19:11 +02:00
if al.active == false || len(req.ref) == 0 || len(req.Query) == 0 {
2019-07-29 07:13:33 +02:00
return
}
2019-10-03 09:08:01 +02:00
var query string
for i := 0; i < len(req.Query); i++ {
c := req.Query[i]
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' {
query = req.Query
break
} else if c == '{' {
query = "query " + req.Query
break
}
}
2019-07-29 07:13:33 +02:00
al.saveChan <- &allowItem{
2019-09-05 06:09:56 +02:00
uri: req.ref,
2019-10-03 09:08:01 +02:00
gql: query,
2019-09-05 06:09:56 +02:00
vars: req.Vars,
2019-07-29 07:13:33 +02:00
}
}
func (al *allowList) load() {
2019-09-20 06:19:11 +02:00
if al.active == false {
return
}
2019-08-02 16:07:50 +02:00
b, err := ioutil.ReadFile(al.filepath)
2019-07-29 07:13:33 +02:00
if err != nil {
log.Fatal(err)
}
if len(b) == 0 {
return
}
var uri string
2019-09-05 06:09:56 +02:00
var varBytes []byte
2019-07-29 07:13:33 +02:00
s, e, c := 0, 0, 0
2019-09-05 06:09:56 +02:00
ty := 0
2019-07-29 07:13:33 +02:00
for {
if c == 0 && b[e] == '#' {
s = e
2019-09-05 06:09:56 +02:00
for e < len(b) && b[e] != '\n' {
2019-07-29 07:49:48 +02:00
e++
}
if (e - s) > 2 {
uri = strings.TrimSpace(string(b[(s + 1):e]))
2019-07-29 07:13:33 +02:00
}
}
2019-09-05 06:09:56 +02:00
if e >= len(b) {
break
}
if matchPrefix(b, e, "query") || matchPrefix(b, e, "mutation") {
2019-07-29 07:13:33 +02:00
if c == 0 {
s = e
}
2019-09-05 06:09:56 +02:00
ty = AL_QUERY
} else if matchPrefix(b, e, "variables") {
if c == 0 {
s = e + len("variables") + 1
}
ty = AL_VARS
} else if b[e] == '{' {
2019-07-29 07:13:33 +02:00
c++
2019-09-05 06:09:56 +02:00
2019-07-29 07:13:33 +02:00
} else if b[e] == '}' {
c--
2019-09-05 06:09:56 +02:00
2019-07-29 07:13:33 +02:00
if c == 0 {
2019-09-05 06:09:56 +02:00
if ty == AL_QUERY {
q := string(b[s:(e + 1)])
2019-10-25 07:39:59 +02:00
key := gqlHash(q, varBytes, "")
if idx, ok := al.index[key]; !ok {
al.list = append(al.list, &allowItem{
uri: uri,
gql: q,
vars: varBytes,
})
al.index[key] = len(al.list) - 1
} else {
item := al.list[idx]
item.gql = q
2019-09-05 06:09:56 +02:00
item.vars = varBytes
}
varBytes = nil
} else if ty == AL_VARS {
varBytes = b[s:(e + 1)]
2019-07-29 07:13:33 +02:00
}
2019-09-05 06:09:56 +02:00
ty = 0
2019-07-29 07:13:33 +02:00
}
}
e++
if e >= len(b) {
break
}
}
}
func (al *allowList) save(item *allowItem) {
2019-09-20 06:19:11 +02:00
if al.active == false {
return
}
2019-10-25 07:39:59 +02:00
key := gqlHash(item.gql, item.vars, "")
if _, ok := al.index[key]; ok {
return
}
2019-07-29 07:13:33 +02:00
al.list = append(al.list, item)
al.index[key] = len(al.list) - 1
2019-08-02 16:07:50 +02:00
f, err := os.Create(al.filepath)
2019-07-29 07:13:33 +02:00
if err != nil {
logger.Warn().Err(err).Msgf("Failed to write allow list: %s", al.filepath)
2019-08-03 17:08:16 +02:00
return
2019-07-29 07:13:33 +02:00
}
defer f.Close()
keys := []string{}
2019-09-05 06:09:56 +02:00
urlMap := make(map[string][]*allowItem)
2019-07-29 07:13:33 +02:00
for _, v := range al.list {
2019-09-05 06:09:56 +02:00
urlMap[v.uri] = append(urlMap[v.uri], v)
2019-07-29 07:13:33 +02:00
}
for k := range urlMap {
keys = append(keys, k)
}
sort.Strings(keys)
for i := range keys {
k := keys[i]
v := urlMap[k]
f.WriteString(fmt.Sprintf("# %s\n\n", k))
for i := range v {
2019-09-26 06:35:31 +02:00
if len(v[i].vars) != 0 && bytes.Equal(v[i].vars, []byte("{}")) == false {
2019-09-05 06:09:56 +02:00
vj, err := json.MarshalIndent(v[i].vars, "", "\t")
if err != nil {
logger.Warn().Err(err).Msg("Failed to write allow list 'vars' to file")
continue
}
f.WriteString(fmt.Sprintf("variables %s\n\n", vj))
}
2019-09-26 06:35:31 +02:00
if v[i].gql[0] == '{' {
f.WriteString(fmt.Sprintf("query %s\n\n", v[i].gql))
} else {
f.WriteString(fmt.Sprintf("%s\n\n", v[i].gql))
}
2019-09-05 06:09:56 +02:00
}
}
}
// matchPrefix reports whether the bytes of b starting at offset i begin with
// s. It returns false when fewer than len(s) bytes remain after i.
func matchPrefix(b []byte, i int, s string) bool {
	if len(s) > len(b)-i {
		return false
	}
	return string(b[i:i+len(s)]) == s
}