Fix to ensure only named queries are saved to the allow list

This commit is contained in:
Vikram Rangnekar
2020-02-01 10:54:19 -05:00
parent 3bd9b199dd
commit 3a4d885987
15 changed files with 723 additions and 636 deletions

View File

@ -1,320 +0,0 @@
package serv
import (
"bytes"
"encoding/json"
"fmt"
"io/ioutil"
"log"
"os"
"path"
"sort"
"strings"
)
const (
AL_QUERY int = iota + 1
AL_VARS
)
type allowItem struct {
name string
hash string
uri string
gql string
vars json.RawMessage
}
var _allowList allowList
type allowList struct {
list []*allowItem
index map[string]int
filepath string
saveChan chan *allowItem
active bool
}
func initAllowList(cpath string) {
_allowList = allowList{
index: make(map[string]int),
saveChan: make(chan *allowItem),
active: true,
}
if len(cpath) != 0 {
fp := path.Join(cpath, "allow.list")
if _, err := os.Stat(fp); err == nil {
_allowList.filepath = fp
} else if !os.IsNotExist(err) {
errlog.Fatal().Err(err).Send()
}
}
if len(_allowList.filepath) == 0 {
fp := "./allow.list"
if _, err := os.Stat(fp); err == nil {
_allowList.filepath = fp
} else if !os.IsNotExist(err) {
errlog.Fatal().Err(err).Send()
}
}
if len(_allowList.filepath) == 0 {
fp := "./config/allow.list"
if _, err := os.Stat(fp); err == nil {
_allowList.filepath = fp
} else if !os.IsNotExist(err) {
errlog.Fatal().Err(err).Send()
}
}
if len(_allowList.filepath) == 0 {
if conf.Production {
errlog.Fatal().Msg("allow.list not found")
}
if len(cpath) == 0 {
_allowList.filepath = "./config/allow.list"
} else {
_allowList.filepath = path.Join(cpath, "allow.list")
}
logger.Warn().Msg("allow.list not found")
} else {
_allowList.load()
}
go func() {
for v := range _allowList.saveChan {
_allowList.save(v)
}
}()
}
func (al *allowList) add(req *gqlReq) {
if al.saveChan == nil || len(req.ref) == 0 || len(req.Query) == 0 {
return
}
var query string
for i := 0; i < len(req.Query); i++ {
c := req.Query[i]
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' {
query = req.Query
break
} else if c == '{' {
query = "query " + req.Query
break
}
}
al.saveChan <- &allowItem{
uri: req.ref,
gql: query,
vars: req.Vars,
}
}
func (al *allowList) upsert(query, vars []byte, uri string) {
q := string(query)
hash := gqlHash(q, vars, "")
name := gqlName(q)
var key string
if len(name) != 0 {
key = name
} else {
key = hash
}
if i, ok := al.index[key]; !ok {
al.list = append(al.list, &allowItem{
name: name,
hash: hash,
uri: uri,
gql: q,
vars: vars,
})
al.index[key] = len(al.list) - 1
} else {
item := al.list[i]
item.name = name
item.hash = hash
item.gql = q
item.vars = vars
}
}
func (al *allowList) load() {
b, err := ioutil.ReadFile(al.filepath)
if err != nil {
log.Fatal(err)
}
if len(b) == 0 {
return
}
var uri string
var varBytes []byte
s, e, c := 0, 0, 0
ty := 0
for {
if c == 0 && b[e] == '#' {
s = e
for e < len(b) && b[e] != '\n' {
e++
}
if (e - s) > 2 {
uri = strings.TrimSpace(string(b[(s + 1):e]))
}
}
if e >= len(b) {
break
}
if matchPrefix(b, e, "query") || matchPrefix(b, e, "mutation") {
if c == 0 {
s = e
}
ty = AL_QUERY
} else if matchPrefix(b, e, "variables") {
if c == 0 {
s = e + len("variables") + 1
}
ty = AL_VARS
} else if b[e] == '{' {
c++
} else if b[e] == '}' {
c--
if c == 0 {
if ty == AL_QUERY {
al.upsert(b[s:(e+1)], varBytes, uri)
varBytes = nil
} else if ty == AL_VARS {
varBytes = b[s:(e + 1)]
}
ty = 0
}
}
e++
if e >= len(b) {
break
}
}
}
func (al *allowList) save(item *allowItem) {
var err error
item.hash = gqlHash(item.gql, item.vars, "")
item.name = gqlName(item.gql)
if len(item.name) == 0 {
key := item.hash
if _, ok := al.index[key]; ok {
return
}
al.list = append(al.list, item)
al.index[key] = len(al.list) - 1
} else {
key := item.name
if i, ok := al.index[key]; ok {
if al.list[i].hash == item.hash {
return
}
al.list[i] = item
} else {
al.list = append(al.list, item)
al.index[key] = len(al.list) - 1
}
}
f, err := os.Create(al.filepath)
if err != nil {
logger.Warn().Err(err).Msgf("Failed to write allow list: %s", al.filepath)
return
}
defer f.Close()
keys := []string{}
urlMap := make(map[string][]*allowItem)
for _, v := range al.list {
urlMap[v.uri] = append(urlMap[v.uri], v)
}
for k := range urlMap {
keys = append(keys, k)
}
sort.Strings(keys)
for i := range keys {
k := keys[i]
v := urlMap[k]
if _, err := f.WriteString(fmt.Sprintf("# %s\n\n", k)); err != nil {
logger.Error().Err(err).Send()
return
}
for i := range v {
if len(v[i].vars) != 0 && !bytes.Equal(v[i].vars, []byte("{}")) {
vj, err := json.MarshalIndent(v[i].vars, "", " ")
if err != nil {
logger.Warn().Err(err).Msg("Failed to write allow list 'vars' to file")
continue
}
_, err = f.WriteString(fmt.Sprintf("variables %s\n\n", vj))
if err != nil {
logger.Error().Err(err).Send()
return
}
}
if v[i].gql[0] == '{' {
_, err = f.WriteString(fmt.Sprintf("query %s\n\n", v[i].gql))
} else {
_, err = f.WriteString(fmt.Sprintf("%s\n\n", v[i].gql))
}
if err != nil {
logger.Error().Err(err).Send()
return
}
}
}
}
func matchPrefix(b []byte, i int, s string) bool {
if (len(b) - i) < len(s) {
return false
}
for n := 0; n < len(s); n++ {
if b[(i+n)] != s[n] {
return false
}
}
return true
}

View File

@ -1,15 +1,13 @@
package serv
import (
"context"
"fmt"
"os"
"runtime"
"strings"
"github.com/dosco/super-graph/allow"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/rs/zerolog"
"github.com/spf13/cobra"
@ -31,17 +29,18 @@ var (
)
var (
logger zerolog.Logger // logger for everything but errors
errlog zerolog.Logger // logger for errors includes line numbers
conf *config // parsed config
confPath string // path to the config file
db *pgxpool.Pool // database connection pool
schema *psql.DBSchema // database tables, columns and relationships
qcompile *qcode.Compiler // qcode compiler
pcompile *psql.Compiler // postgres sql compiler
logger zerolog.Logger // logger for everything but errors
errlog zerolog.Logger // logger for errors includes line numbers
conf *config // parsed config
confPath string // path to the config file
db *pgxpool.Pool // database connection pool
schema *psql.DBSchema // database tables, columns and relationships
allowList *allow.List // allow.list is contains queries allowed in production
qcompile *qcode.Compiler // qcode compiler
pcompile *psql.Compiler // postgres sql compiler
)
func Init() {
func Cmd() {
initLog()
rootCmd := &cobra.Command{
@ -156,159 +155,6 @@ e.g. db:migrate -+1
}
}
func initLog() {
out := zerolog.ConsoleWriter{Out: os.Stderr}
logger = zerolog.New(out).With().Timestamp().Logger()
errlog = logger.With().Caller().Logger()
}
func initConf() (*config, error) {
vi := newConfig(getConfigName())
if err := vi.ReadInConfig(); err != nil {
return nil, err
}
inherits := vi.GetString("inherits")
if len(inherits) != 0 {
vi = newConfig(inherits)
if err := vi.ReadInConfig(); err != nil {
return nil, err
}
if vi.IsSet("inherits") {
errlog.Fatal().Msgf("inherited config (%s) cannot itself inherit (%s)",
inherits,
vi.GetString("inherits"))
}
vi.SetConfigName(getConfigName())
if err := vi.MergeInConfig(); err != nil {
return nil, err
}
}
c := &config{}
if err := c.Init(vi); err != nil {
return nil, fmt.Errorf("unable to decode config, %v", err)
}
logLevel, err := zerolog.ParseLevel(c.LogLevel)
if err != nil {
errlog.Error().Err(err).Msg("error setting log_level")
}
zerolog.SetGlobalLevel(logLevel)
return c, nil
}
func initDB(c *config, useDB bool) (*pgx.Conn, error) {
config, _ := pgx.ParseConfig("")
config.Host = c.DB.Host
config.Port = c.DB.Port
config.User = c.DB.User
config.Password = c.DB.Password
config.RuntimeParams = map[string]string{
"application_name": c.AppName,
"search_path": c.DB.Schema,
}
if useDB {
config.Database = c.DB.DBName
}
switch c.LogLevel {
case "debug":
config.LogLevel = pgx.LogLevelDebug
case "info":
config.LogLevel = pgx.LogLevelInfo
case "warn":
config.LogLevel = pgx.LogLevelWarn
case "error":
config.LogLevel = pgx.LogLevelError
default:
config.LogLevel = pgx.LogLevelNone
}
config.Logger = NewSQLLogger(logger)
db, err := pgx.ConnectConfig(context.Background(), config)
if err != nil {
return nil, err
}
return db, nil
}
func initDBPool(c *config) (*pgxpool.Pool, error) {
config, _ := pgxpool.ParseConfig("")
config.ConnConfig.Host = c.DB.Host
config.ConnConfig.Port = c.DB.Port
config.ConnConfig.Database = c.DB.DBName
config.ConnConfig.User = c.DB.User
config.ConnConfig.Password = c.DB.Password
config.ConnConfig.RuntimeParams = map[string]string{
"application_name": c.AppName,
"search_path": c.DB.Schema,
}
switch c.LogLevel {
case "debug":
config.ConnConfig.LogLevel = pgx.LogLevelDebug
case "info":
config.ConnConfig.LogLevel = pgx.LogLevelInfo
case "warn":
config.ConnConfig.LogLevel = pgx.LogLevelWarn
case "error":
config.ConnConfig.LogLevel = pgx.LogLevelError
default:
config.ConnConfig.LogLevel = pgx.LogLevelNone
}
config.ConnConfig.Logger = NewSQLLogger(logger)
// if c.DB.MaxRetries != 0 {
// opt.MaxRetries = c.DB.MaxRetries
// }
if c.DB.PoolSize != 0 {
config.MaxConns = conf.DB.PoolSize
}
db, err := pgxpool.ConnectConfig(context.Background(), config)
if err != nil {
return nil, err
}
return db, nil
}
func initCompiler() {
var err error
qcompile, pcompile, err = initCompilers(conf)
if err != nil {
errlog.Fatal().Err(err).Msg("failed to initialize compilers")
}
if err := initResolvers(); err != nil {
errlog.Fatal().Err(err).Msg("failed to initialized resolvers")
}
}
func initConfOnce() {
var err error
if conf == nil {
if conf, err = initConf(); err != nil {
errlog.Fatal().Err(err).Msg("failed to read config")
}
}
}
func cmdVersion(cmd *cobra.Command, args []string) {
fmt.Printf("%s\n", BuildDetails())
}

View File

@ -17,7 +17,7 @@ func cmdServ(cmd *cobra.Command, args []string) {
if err == nil {
initCompiler()
initAllowList(confPath)
initPreparedList()
initPreparedList(confPath)
} else {
fatalInProd(err, "failed to connect to database")
}

View File

@ -11,6 +11,7 @@ import (
"time"
"github.com/cespare/xxhash/v2"
"github.com/dosco/super-graph/allow"
"github.com/dosco/super-graph/qcode"
"github.com/jackc/pgx/v4"
"github.com/valyala/fasttemplate"
@ -107,7 +108,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) {
}
ps, ok := _preparedList[gqlHash(c.req.Query, c.req.Vars, role)]
ps, ok := _preparedList[stmtHash(allow.QueryName(c.req.Query), role)]
if !ok {
return nil, nil, errUnauthorized
}
@ -240,8 +241,10 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
}
}
if !conf.Production {
_allowList.add(&c.req)
if allowList.IsPersist() {
if err := allowList.Add(c.req.Vars, c.req.Query, c.req.ref); err != nil {
return nil, nil, err
}
}
if len(stmts) > 1 {

View File

@ -4,7 +4,7 @@ package serv
func Fuzz(data []byte) int {
gql := string(data)
gqlName(gql)
QueryName(gql)
gqlHash(gql, nil, "")
return 1

View File

@ -10,7 +10,6 @@ func TestFuzzCrashers(t *testing.T) {
}
for _, f := range crashers {
_ = gqlName(f)
gqlHash(f, nil, "")
}
}

179
serv/init.go Normal file
View File

@ -0,0 +1,179 @@
package serv
import (
"context"
"fmt"
"os"
"github.com/dosco/super-graph/allow"
"github.com/jackc/pgx/v4"
"github.com/jackc/pgx/v4/pgxpool"
"github.com/rs/zerolog"
)
func initLog() {
out := zerolog.ConsoleWriter{Out: os.Stderr}
logger = zerolog.New(out).With().Timestamp().Logger()
errlog = logger.With().Caller().Logger()
}
func initConf() (*config, error) {
vi := newConfig(getConfigName())
if err := vi.ReadInConfig(); err != nil {
return nil, err
}
inherits := vi.GetString("inherits")
if len(inherits) != 0 {
vi = newConfig(inherits)
if err := vi.ReadInConfig(); err != nil {
return nil, err
}
if vi.IsSet("inherits") {
errlog.Fatal().Msgf("inherited config (%s) cannot itself inherit (%s)",
inherits,
vi.GetString("inherits"))
}
vi.SetConfigName(getConfigName())
if err := vi.MergeInConfig(); err != nil {
return nil, err
}
}
c := &config{}
if err := c.Init(vi); err != nil {
return nil, fmt.Errorf("unable to decode config, %v", err)
}
logLevel, err := zerolog.ParseLevel(c.LogLevel)
if err != nil {
errlog.Error().Err(err).Msg("error setting log_level")
}
zerolog.SetGlobalLevel(logLevel)
return c, nil
}
func initDB(c *config, useDB bool) (*pgx.Conn, error) {
config, _ := pgx.ParseConfig("")
config.Host = c.DB.Host
config.Port = c.DB.Port
config.User = c.DB.User
config.Password = c.DB.Password
config.RuntimeParams = map[string]string{
"application_name": c.AppName,
"search_path": c.DB.Schema,
}
if useDB {
config.Database = c.DB.DBName
}
switch c.LogLevel {
case "debug":
config.LogLevel = pgx.LogLevelDebug
case "info":
config.LogLevel = pgx.LogLevelInfo
case "warn":
config.LogLevel = pgx.LogLevelWarn
case "error":
config.LogLevel = pgx.LogLevelError
default:
config.LogLevel = pgx.LogLevelNone
}
config.Logger = NewSQLLogger(logger)
db, err := pgx.ConnectConfig(context.Background(), config)
if err != nil {
return nil, err
}
return db, nil
}
func initDBPool(c *config) (*pgxpool.Pool, error) {
config, _ := pgxpool.ParseConfig("")
config.ConnConfig.Host = c.DB.Host
config.ConnConfig.Port = c.DB.Port
config.ConnConfig.Database = c.DB.DBName
config.ConnConfig.User = c.DB.User
config.ConnConfig.Password = c.DB.Password
config.ConnConfig.RuntimeParams = map[string]string{
"application_name": c.AppName,
"search_path": c.DB.Schema,
}
switch c.LogLevel {
case "debug":
config.ConnConfig.LogLevel = pgx.LogLevelDebug
case "info":
config.ConnConfig.LogLevel = pgx.LogLevelInfo
case "warn":
config.ConnConfig.LogLevel = pgx.LogLevelWarn
case "error":
config.ConnConfig.LogLevel = pgx.LogLevelError
default:
config.ConnConfig.LogLevel = pgx.LogLevelNone
}
config.ConnConfig.Logger = NewSQLLogger(logger)
// if c.DB.MaxRetries != 0 {
// opt.MaxRetries = c.DB.MaxRetries
// }
if c.DB.PoolSize != 0 {
config.MaxConns = conf.DB.PoolSize
}
db, err := pgxpool.ConnectConfig(context.Background(), config)
if err != nil {
return nil, err
}
return db, nil
}
func initCompiler() {
var err error
qcompile, pcompile, err = initCompilers(conf)
if err != nil {
errlog.Fatal().Err(err).Msg("failed to initialize compilers")
}
if err := initResolvers(); err != nil {
errlog.Fatal().Err(err).Msg("failed to initialized resolvers")
}
}
func initConfOnce() {
var err error
if conf == nil {
if conf, err = initConf(); err != nil {
errlog.Fatal().Err(err).Msg("failed to read config")
}
}
}
func initAllowList(cpath string) {
var ac allow.Config
var err error
if !conf.Production {
ac = allow.Config{CreateIfNotExists: true, Persist: true}
}
allowList, err = allow.New(cpath, ac)
if err != nil {
errlog.Fatal().Err(err).Msg("failed to initialize allow list")
}
}

View File

@ -6,6 +6,7 @@ import (
"fmt"
"io"
"github.com/dosco/super-graph/allow"
"github.com/dosco/super-graph/qcode"
"github.com/jackc/pgconn"
"github.com/jackc/pgx/v4"
@ -23,7 +24,10 @@ var (
_preparedList map[string]*preparedItem
)
func initPreparedList() {
func initPreparedList(cpath string) {
if allowList.IsPersist() {
return
}
_preparedList = make(map[string]*preparedItem)
tx, err := db.Begin(context.Background())
@ -43,30 +47,38 @@ func initPreparedList() {
success := 0
for _, v := range _allowList.list {
if len(v.gql) == 0 {
list, err := allowList.Load()
if err != nil {
errlog.Fatal().Err(err).Send()
}
for _, v := range list {
if len(v.Query) == 0 {
continue
}
err := prepareStmt(v.gql, v.vars)
err := prepareStmt(v)
if err == nil {
success++
continue
}
if len(v.vars) == 0 {
logger.Warn().Err(err).Msg(v.gql)
if len(v.Vars) == 0 {
logger.Warn().Err(err).Msg(v.Query)
} else {
logger.Warn().Err(err).Msgf("%s %s", v.vars, v.gql)
logger.Warn().Err(err).Msgf("%s %s", v.Vars, v.Query)
}
}
logger.Info().
Msgf("Registered %d of %d queries from allow.list as prepared statements",
success, len(_allowList.list))
success, len(list))
}
func prepareStmt(gql string, vars []byte) error {
func prepareStmt(item allow.Item) error {
gql := item.Query
vars := item.Vars
qt := qcode.GetQType(gql)
q := []byte(gql)
@ -99,7 +111,7 @@ func prepareStmt(gql string, vars []byte) error {
logger.Debug().Msg("Prepared statement role: user")
err = prepare(tx, stmts1, gqlHash(gql, vars, "user"))
err = prepare(tx, stmts1, stmtHash(item.Name, "user"))
if err != nil {
return err
}
@ -112,7 +124,7 @@ func prepareStmt(gql string, vars []byte) error {
return err
}
err = prepare(tx, stmts2, gqlHash(gql, vars, "anon"))
err = prepare(tx, stmts2, stmtHash(item.Name, "anon"))
if err != nil {
return err
}
@ -127,7 +139,7 @@ func prepareStmt(gql string, vars []byte) error {
return err
}
err = prepare(tx, stmts, gqlHash(gql, vars, role.Name))
err = prepare(tx, stmts, stmtHash(item.Name, role.Name))
if err != nil {
return err
}

View File

@ -22,6 +22,14 @@ func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
return v
}
// nolint: errcheck
func stmtHash(name string, role string) string {
h := sha1.New()
io.WriteString(h, strings.ToLower(name))
io.WriteString(h, role)
return hex.EncodeToString(h.Sum(nil))
}
// nolint: errcheck
func gqlHash(b string, vars []byte, role string) string {
b = strings.TrimSpace(b)
@ -108,30 +116,6 @@ func al(b byte) bool {
return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9')
}
func gqlName(b string) string {
state, s := 0, 0
for i := 0; i < len(b); i++ {
switch {
case state == 2 && b[i] == '{':
return b[s:i]
case state == 2 && b[i] == ' ':
return b[s:i]
case state == 1 && b[i] == '{':
return ""
case state == 1 && b[i] != ' ':
s = i
state = 2
case state == 1 && b[i] == ' ':
continue
case i != 0 && b[i] == ' ' && (b[i-1] == 'n' || b[i-1] == 'y'):
state = 1
}
}
return ""
}
func findStmt(role string, stmts []stmt) *stmt {
for i := range stmts {
if stmts[i].role.Name != role {

View File

@ -229,80 +229,3 @@ func TestGQLHashWithVars2(t *testing.T) {
t.Fatal("Hashes don't match they should")
}
}
func TestGQLName1(t *testing.T) {
var q = `
query {
products(
distinct: [price]
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }
) { id name } }`
name := gqlName(q)
if len(name) != 0 {
t.Fatal("Name should be empty, not ", name)
}
}
func TestGQLName2(t *testing.T) {
var q = `
query hakuna_matata {
products(
distinct: [price]
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }
) {
id
name
}
}`
name := gqlName(q)
if name != "hakuna_matata" {
t.Fatal("Name should be 'hakuna_matata', not ", name)
}
}
func TestGQLName3(t *testing.T) {
var q = `
mutation means{ users { id } }`
// var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } `
name := gqlName(q)
if name != "means" {
t.Fatal("Name should be 'means', not ", name)
}
}
func TestGQLName4(t *testing.T) {
var q = `
query no_worries
users {
id
}
}`
name := gqlName(q)
if name != "no_worries" {
t.Fatal("Name should be 'no_worries', not ", name)
}
}
func TestGQLName5(t *testing.T) {
var q = `
{
users {
id
}
}`
name := gqlName(q)
if len(name) != 0 {
t.Fatal("Name should be empty, not ", name)
}
}