Fix to ensure only named queries are saved to the allow list
This commit is contained in:
320
serv/allow.go
320
serv/allow.go
@ -1,320 +0,0 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/ioutil"
|
||||
"log"
|
||||
"os"
|
||||
"path"
|
||||
"sort"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
AL_QUERY int = iota + 1
|
||||
AL_VARS
|
||||
)
|
||||
|
||||
type allowItem struct {
|
||||
name string
|
||||
hash string
|
||||
uri string
|
||||
gql string
|
||||
vars json.RawMessage
|
||||
}
|
||||
|
||||
var _allowList allowList
|
||||
|
||||
type allowList struct {
|
||||
list []*allowItem
|
||||
index map[string]int
|
||||
filepath string
|
||||
saveChan chan *allowItem
|
||||
active bool
|
||||
}
|
||||
|
||||
func initAllowList(cpath string) {
|
||||
_allowList = allowList{
|
||||
index: make(map[string]int),
|
||||
saveChan: make(chan *allowItem),
|
||||
active: true,
|
||||
}
|
||||
|
||||
if len(cpath) != 0 {
|
||||
fp := path.Join(cpath, "allow.list")
|
||||
|
||||
if _, err := os.Stat(fp); err == nil {
|
||||
_allowList.filepath = fp
|
||||
} else if !os.IsNotExist(err) {
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
}
|
||||
|
||||
if len(_allowList.filepath) == 0 {
|
||||
fp := "./allow.list"
|
||||
|
||||
if _, err := os.Stat(fp); err == nil {
|
||||
_allowList.filepath = fp
|
||||
} else if !os.IsNotExist(err) {
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
}
|
||||
|
||||
if len(_allowList.filepath) == 0 {
|
||||
fp := "./config/allow.list"
|
||||
|
||||
if _, err := os.Stat(fp); err == nil {
|
||||
_allowList.filepath = fp
|
||||
} else if !os.IsNotExist(err) {
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
}
|
||||
|
||||
if len(_allowList.filepath) == 0 {
|
||||
if conf.Production {
|
||||
errlog.Fatal().Msg("allow.list not found")
|
||||
}
|
||||
|
||||
if len(cpath) == 0 {
|
||||
_allowList.filepath = "./config/allow.list"
|
||||
} else {
|
||||
_allowList.filepath = path.Join(cpath, "allow.list")
|
||||
}
|
||||
|
||||
logger.Warn().Msg("allow.list not found")
|
||||
} else {
|
||||
_allowList.load()
|
||||
}
|
||||
|
||||
go func() {
|
||||
for v := range _allowList.saveChan {
|
||||
_allowList.save(v)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (al *allowList) add(req *gqlReq) {
|
||||
if al.saveChan == nil || len(req.ref) == 0 || len(req.Query) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var query string
|
||||
|
||||
for i := 0; i < len(req.Query); i++ {
|
||||
c := req.Query[i]
|
||||
if c >= 'a' && c <= 'z' || c >= 'A' && c <= 'Z' {
|
||||
query = req.Query
|
||||
break
|
||||
|
||||
} else if c == '{' {
|
||||
query = "query " + req.Query
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
al.saveChan <- &allowItem{
|
||||
uri: req.ref,
|
||||
gql: query,
|
||||
vars: req.Vars,
|
||||
}
|
||||
}
|
||||
|
||||
func (al *allowList) upsert(query, vars []byte, uri string) {
|
||||
q := string(query)
|
||||
hash := gqlHash(q, vars, "")
|
||||
name := gqlName(q)
|
||||
|
||||
var key string
|
||||
|
||||
if len(name) != 0 {
|
||||
key = name
|
||||
} else {
|
||||
key = hash
|
||||
}
|
||||
|
||||
if i, ok := al.index[key]; !ok {
|
||||
al.list = append(al.list, &allowItem{
|
||||
name: name,
|
||||
hash: hash,
|
||||
uri: uri,
|
||||
gql: q,
|
||||
vars: vars,
|
||||
})
|
||||
al.index[key] = len(al.list) - 1
|
||||
} else {
|
||||
item := al.list[i]
|
||||
item.name = name
|
||||
item.hash = hash
|
||||
item.gql = q
|
||||
item.vars = vars
|
||||
|
||||
}
|
||||
}
|
||||
|
||||
func (al *allowList) load() {
|
||||
b, err := ioutil.ReadFile(al.filepath)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
if len(b) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
var uri string
|
||||
var varBytes []byte
|
||||
|
||||
s, e, c := 0, 0, 0
|
||||
ty := 0
|
||||
|
||||
for {
|
||||
if c == 0 && b[e] == '#' {
|
||||
s = e
|
||||
for e < len(b) && b[e] != '\n' {
|
||||
e++
|
||||
}
|
||||
if (e - s) > 2 {
|
||||
uri = strings.TrimSpace(string(b[(s + 1):e]))
|
||||
}
|
||||
}
|
||||
|
||||
if e >= len(b) {
|
||||
break
|
||||
}
|
||||
|
||||
if matchPrefix(b, e, "query") || matchPrefix(b, e, "mutation") {
|
||||
if c == 0 {
|
||||
s = e
|
||||
}
|
||||
ty = AL_QUERY
|
||||
} else if matchPrefix(b, e, "variables") {
|
||||
if c == 0 {
|
||||
s = e + len("variables") + 1
|
||||
}
|
||||
ty = AL_VARS
|
||||
} else if b[e] == '{' {
|
||||
c++
|
||||
|
||||
} else if b[e] == '}' {
|
||||
c--
|
||||
|
||||
if c == 0 {
|
||||
if ty == AL_QUERY {
|
||||
al.upsert(b[s:(e+1)], varBytes, uri)
|
||||
varBytes = nil
|
||||
|
||||
} else if ty == AL_VARS {
|
||||
varBytes = b[s:(e + 1)]
|
||||
}
|
||||
ty = 0
|
||||
}
|
||||
}
|
||||
|
||||
e++
|
||||
if e >= len(b) {
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (al *allowList) save(item *allowItem) {
|
||||
var err error
|
||||
|
||||
item.hash = gqlHash(item.gql, item.vars, "")
|
||||
item.name = gqlName(item.gql)
|
||||
|
||||
if len(item.name) == 0 {
|
||||
key := item.hash
|
||||
|
||||
if _, ok := al.index[key]; ok {
|
||||
return
|
||||
}
|
||||
|
||||
al.list = append(al.list, item)
|
||||
al.index[key] = len(al.list) - 1
|
||||
|
||||
} else {
|
||||
key := item.name
|
||||
|
||||
if i, ok := al.index[key]; ok {
|
||||
if al.list[i].hash == item.hash {
|
||||
return
|
||||
}
|
||||
al.list[i] = item
|
||||
} else {
|
||||
al.list = append(al.list, item)
|
||||
al.index[key] = len(al.list) - 1
|
||||
}
|
||||
}
|
||||
|
||||
f, err := os.Create(al.filepath)
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msgf("Failed to write allow list: %s", al.filepath)
|
||||
return
|
||||
}
|
||||
|
||||
defer f.Close()
|
||||
|
||||
keys := []string{}
|
||||
urlMap := make(map[string][]*allowItem)
|
||||
|
||||
for _, v := range al.list {
|
||||
urlMap[v.uri] = append(urlMap[v.uri], v)
|
||||
}
|
||||
|
||||
for k := range urlMap {
|
||||
keys = append(keys, k)
|
||||
}
|
||||
sort.Strings(keys)
|
||||
|
||||
for i := range keys {
|
||||
k := keys[i]
|
||||
v := urlMap[k]
|
||||
|
||||
if _, err := f.WriteString(fmt.Sprintf("# %s\n\n", k)); err != nil {
|
||||
logger.Error().Err(err).Send()
|
||||
return
|
||||
}
|
||||
|
||||
for i := range v {
|
||||
if len(v[i].vars) != 0 && !bytes.Equal(v[i].vars, []byte("{}")) {
|
||||
vj, err := json.MarshalIndent(v[i].vars, "", " ")
|
||||
if err != nil {
|
||||
logger.Warn().Err(err).Msg("Failed to write allow list 'vars' to file")
|
||||
continue
|
||||
}
|
||||
|
||||
_, err = f.WriteString(fmt.Sprintf("variables %s\n\n", vj))
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Send()
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if v[i].gql[0] == '{' {
|
||||
_, err = f.WriteString(fmt.Sprintf("query %s\n\n", v[i].gql))
|
||||
} else {
|
||||
_, err = f.WriteString(fmt.Sprintf("%s\n\n", v[i].gql))
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Error().Err(err).Send()
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func matchPrefix(b []byte, i int, s string) bool {
|
||||
if (len(b) - i) < len(s) {
|
||||
return false
|
||||
}
|
||||
for n := 0; n < len(s); n++ {
|
||||
if b[(i+n)] != s[n] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
176
serv/cmd.go
176
serv/cmd.go
@ -1,15 +1,13 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
"github.com/dosco/super-graph/allow"
|
||||
"github.com/dosco/super-graph/psql"
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/spf13/cobra"
|
||||
@ -31,17 +29,18 @@ var (
|
||||
)
|
||||
|
||||
var (
|
||||
logger zerolog.Logger // logger for everything but errors
|
||||
errlog zerolog.Logger // logger for errors includes line numbers
|
||||
conf *config // parsed config
|
||||
confPath string // path to the config file
|
||||
db *pgxpool.Pool // database connection pool
|
||||
schema *psql.DBSchema // database tables, columns and relationships
|
||||
qcompile *qcode.Compiler // qcode compiler
|
||||
pcompile *psql.Compiler // postgres sql compiler
|
||||
logger zerolog.Logger // logger for everything but errors
|
||||
errlog zerolog.Logger // logger for errors includes line numbers
|
||||
conf *config // parsed config
|
||||
confPath string // path to the config file
|
||||
db *pgxpool.Pool // database connection pool
|
||||
schema *psql.DBSchema // database tables, columns and relationships
|
||||
allowList *allow.List // allow.list is contains queries allowed in production
|
||||
qcompile *qcode.Compiler // qcode compiler
|
||||
pcompile *psql.Compiler // postgres sql compiler
|
||||
)
|
||||
|
||||
func Init() {
|
||||
func Cmd() {
|
||||
initLog()
|
||||
|
||||
rootCmd := &cobra.Command{
|
||||
@ -156,159 +155,6 @@ e.g. db:migrate -+1
|
||||
}
|
||||
}
|
||||
|
||||
func initLog() {
|
||||
out := zerolog.ConsoleWriter{Out: os.Stderr}
|
||||
logger = zerolog.New(out).With().Timestamp().Logger()
|
||||
errlog = logger.With().Caller().Logger()
|
||||
}
|
||||
|
||||
func initConf() (*config, error) {
|
||||
vi := newConfig(getConfigName())
|
||||
|
||||
if err := vi.ReadInConfig(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inherits := vi.GetString("inherits")
|
||||
if len(inherits) != 0 {
|
||||
vi = newConfig(inherits)
|
||||
|
||||
if err := vi.ReadInConfig(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if vi.IsSet("inherits") {
|
||||
errlog.Fatal().Msgf("inherited config (%s) cannot itself inherit (%s)",
|
||||
inherits,
|
||||
vi.GetString("inherits"))
|
||||
}
|
||||
|
||||
vi.SetConfigName(getConfigName())
|
||||
|
||||
if err := vi.MergeInConfig(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
c := &config{}
|
||||
|
||||
if err := c.Init(vi); err != nil {
|
||||
return nil, fmt.Errorf("unable to decode config, %v", err)
|
||||
}
|
||||
|
||||
logLevel, err := zerolog.ParseLevel(c.LogLevel)
|
||||
if err != nil {
|
||||
errlog.Error().Err(err).Msg("error setting log_level")
|
||||
}
|
||||
zerolog.SetGlobalLevel(logLevel)
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func initDB(c *config, useDB bool) (*pgx.Conn, error) {
|
||||
config, _ := pgx.ParseConfig("")
|
||||
config.Host = c.DB.Host
|
||||
config.Port = c.DB.Port
|
||||
config.User = c.DB.User
|
||||
config.Password = c.DB.Password
|
||||
config.RuntimeParams = map[string]string{
|
||||
"application_name": c.AppName,
|
||||
"search_path": c.DB.Schema,
|
||||
}
|
||||
|
||||
if useDB {
|
||||
config.Database = c.DB.DBName
|
||||
}
|
||||
|
||||
switch c.LogLevel {
|
||||
case "debug":
|
||||
config.LogLevel = pgx.LogLevelDebug
|
||||
case "info":
|
||||
config.LogLevel = pgx.LogLevelInfo
|
||||
case "warn":
|
||||
config.LogLevel = pgx.LogLevelWarn
|
||||
case "error":
|
||||
config.LogLevel = pgx.LogLevelError
|
||||
default:
|
||||
config.LogLevel = pgx.LogLevelNone
|
||||
}
|
||||
|
||||
config.Logger = NewSQLLogger(logger)
|
||||
|
||||
db, err := pgx.ConnectConfig(context.Background(), config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func initDBPool(c *config) (*pgxpool.Pool, error) {
|
||||
config, _ := pgxpool.ParseConfig("")
|
||||
config.ConnConfig.Host = c.DB.Host
|
||||
config.ConnConfig.Port = c.DB.Port
|
||||
config.ConnConfig.Database = c.DB.DBName
|
||||
config.ConnConfig.User = c.DB.User
|
||||
config.ConnConfig.Password = c.DB.Password
|
||||
config.ConnConfig.RuntimeParams = map[string]string{
|
||||
"application_name": c.AppName,
|
||||
"search_path": c.DB.Schema,
|
||||
}
|
||||
|
||||
switch c.LogLevel {
|
||||
case "debug":
|
||||
config.ConnConfig.LogLevel = pgx.LogLevelDebug
|
||||
case "info":
|
||||
config.ConnConfig.LogLevel = pgx.LogLevelInfo
|
||||
case "warn":
|
||||
config.ConnConfig.LogLevel = pgx.LogLevelWarn
|
||||
case "error":
|
||||
config.ConnConfig.LogLevel = pgx.LogLevelError
|
||||
default:
|
||||
config.ConnConfig.LogLevel = pgx.LogLevelNone
|
||||
}
|
||||
|
||||
config.ConnConfig.Logger = NewSQLLogger(logger)
|
||||
|
||||
// if c.DB.MaxRetries != 0 {
|
||||
// opt.MaxRetries = c.DB.MaxRetries
|
||||
// }
|
||||
|
||||
if c.DB.PoolSize != 0 {
|
||||
config.MaxConns = conf.DB.PoolSize
|
||||
}
|
||||
|
||||
db, err := pgxpool.ConnectConfig(context.Background(), config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func initCompiler() {
|
||||
var err error
|
||||
|
||||
qcompile, pcompile, err = initCompilers(conf)
|
||||
if err != nil {
|
||||
errlog.Fatal().Err(err).Msg("failed to initialize compilers")
|
||||
}
|
||||
|
||||
if err := initResolvers(); err != nil {
|
||||
errlog.Fatal().Err(err).Msg("failed to initialized resolvers")
|
||||
}
|
||||
}
|
||||
|
||||
func initConfOnce() {
|
||||
var err error
|
||||
|
||||
if conf == nil {
|
||||
if conf, err = initConf(); err != nil {
|
||||
errlog.Fatal().Err(err).Msg("failed to read config")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func cmdVersion(cmd *cobra.Command, args []string) {
|
||||
fmt.Printf("%s\n", BuildDetails())
|
||||
}
|
||||
|
@ -17,7 +17,7 @@ func cmdServ(cmd *cobra.Command, args []string) {
|
||||
if err == nil {
|
||||
initCompiler()
|
||||
initAllowList(confPath)
|
||||
initPreparedList()
|
||||
initPreparedList(confPath)
|
||||
} else {
|
||||
fatalInProd(err, "failed to connect to database")
|
||||
}
|
||||
|
@ -11,6 +11,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/cespare/xxhash/v2"
|
||||
"github.com/dosco/super-graph/allow"
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/valyala/fasttemplate"
|
||||
@ -107,7 +108,7 @@ func (c *coreContext) resolvePreparedSQL() ([]byte, *stmt, error) {
|
||||
|
||||
}
|
||||
|
||||
ps, ok := _preparedList[gqlHash(c.req.Query, c.req.Vars, role)]
|
||||
ps, ok := _preparedList[stmtHash(allow.QueryName(c.req.Query), role)]
|
||||
if !ok {
|
||||
return nil, nil, errUnauthorized
|
||||
}
|
||||
@ -240,8 +241,10 @@ func (c *coreContext) resolveSQL() ([]byte, *stmt, error) {
|
||||
}
|
||||
}
|
||||
|
||||
if !conf.Production {
|
||||
_allowList.add(&c.req)
|
||||
if allowList.IsPersist() {
|
||||
if err := allowList.Add(c.req.Vars, c.req.Query, c.req.ref); err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if len(stmts) > 1 {
|
||||
|
@ -4,7 +4,7 @@ package serv
|
||||
|
||||
func Fuzz(data []byte) int {
|
||||
gql := string(data)
|
||||
gqlName(gql)
|
||||
QueryName(gql)
|
||||
gqlHash(gql, nil, "")
|
||||
|
||||
return 1
|
||||
|
@ -10,7 +10,6 @@ func TestFuzzCrashers(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, f := range crashers {
|
||||
_ = gqlName(f)
|
||||
gqlHash(f, nil, "")
|
||||
}
|
||||
}
|
||||
|
179
serv/init.go
Normal file
179
serv/init.go
Normal file
@ -0,0 +1,179 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/dosco/super-graph/allow"
|
||||
"github.com/jackc/pgx/v4"
|
||||
"github.com/jackc/pgx/v4/pgxpool"
|
||||
"github.com/rs/zerolog"
|
||||
)
|
||||
|
||||
func initLog() {
|
||||
out := zerolog.ConsoleWriter{Out: os.Stderr}
|
||||
logger = zerolog.New(out).With().Timestamp().Logger()
|
||||
errlog = logger.With().Caller().Logger()
|
||||
}
|
||||
|
||||
func initConf() (*config, error) {
|
||||
vi := newConfig(getConfigName())
|
||||
|
||||
if err := vi.ReadInConfig(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
inherits := vi.GetString("inherits")
|
||||
if len(inherits) != 0 {
|
||||
vi = newConfig(inherits)
|
||||
|
||||
if err := vi.ReadInConfig(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if vi.IsSet("inherits") {
|
||||
errlog.Fatal().Msgf("inherited config (%s) cannot itself inherit (%s)",
|
||||
inherits,
|
||||
vi.GetString("inherits"))
|
||||
}
|
||||
|
||||
vi.SetConfigName(getConfigName())
|
||||
|
||||
if err := vi.MergeInConfig(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
c := &config{}
|
||||
|
||||
if err := c.Init(vi); err != nil {
|
||||
return nil, fmt.Errorf("unable to decode config, %v", err)
|
||||
}
|
||||
|
||||
logLevel, err := zerolog.ParseLevel(c.LogLevel)
|
||||
if err != nil {
|
||||
errlog.Error().Err(err).Msg("error setting log_level")
|
||||
}
|
||||
zerolog.SetGlobalLevel(logLevel)
|
||||
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func initDB(c *config, useDB bool) (*pgx.Conn, error) {
|
||||
config, _ := pgx.ParseConfig("")
|
||||
config.Host = c.DB.Host
|
||||
config.Port = c.DB.Port
|
||||
config.User = c.DB.User
|
||||
config.Password = c.DB.Password
|
||||
config.RuntimeParams = map[string]string{
|
||||
"application_name": c.AppName,
|
||||
"search_path": c.DB.Schema,
|
||||
}
|
||||
|
||||
if useDB {
|
||||
config.Database = c.DB.DBName
|
||||
}
|
||||
|
||||
switch c.LogLevel {
|
||||
case "debug":
|
||||
config.LogLevel = pgx.LogLevelDebug
|
||||
case "info":
|
||||
config.LogLevel = pgx.LogLevelInfo
|
||||
case "warn":
|
||||
config.LogLevel = pgx.LogLevelWarn
|
||||
case "error":
|
||||
config.LogLevel = pgx.LogLevelError
|
||||
default:
|
||||
config.LogLevel = pgx.LogLevelNone
|
||||
}
|
||||
|
||||
config.Logger = NewSQLLogger(logger)
|
||||
|
||||
db, err := pgx.ConnectConfig(context.Background(), config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func initDBPool(c *config) (*pgxpool.Pool, error) {
|
||||
config, _ := pgxpool.ParseConfig("")
|
||||
config.ConnConfig.Host = c.DB.Host
|
||||
config.ConnConfig.Port = c.DB.Port
|
||||
config.ConnConfig.Database = c.DB.DBName
|
||||
config.ConnConfig.User = c.DB.User
|
||||
config.ConnConfig.Password = c.DB.Password
|
||||
config.ConnConfig.RuntimeParams = map[string]string{
|
||||
"application_name": c.AppName,
|
||||
"search_path": c.DB.Schema,
|
||||
}
|
||||
|
||||
switch c.LogLevel {
|
||||
case "debug":
|
||||
config.ConnConfig.LogLevel = pgx.LogLevelDebug
|
||||
case "info":
|
||||
config.ConnConfig.LogLevel = pgx.LogLevelInfo
|
||||
case "warn":
|
||||
config.ConnConfig.LogLevel = pgx.LogLevelWarn
|
||||
case "error":
|
||||
config.ConnConfig.LogLevel = pgx.LogLevelError
|
||||
default:
|
||||
config.ConnConfig.LogLevel = pgx.LogLevelNone
|
||||
}
|
||||
|
||||
config.ConnConfig.Logger = NewSQLLogger(logger)
|
||||
|
||||
// if c.DB.MaxRetries != 0 {
|
||||
// opt.MaxRetries = c.DB.MaxRetries
|
||||
// }
|
||||
|
||||
if c.DB.PoolSize != 0 {
|
||||
config.MaxConns = conf.DB.PoolSize
|
||||
}
|
||||
|
||||
db, err := pgxpool.ConnectConfig(context.Background(), config)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func initCompiler() {
|
||||
var err error
|
||||
|
||||
qcompile, pcompile, err = initCompilers(conf)
|
||||
if err != nil {
|
||||
errlog.Fatal().Err(err).Msg("failed to initialize compilers")
|
||||
}
|
||||
|
||||
if err := initResolvers(); err != nil {
|
||||
errlog.Fatal().Err(err).Msg("failed to initialized resolvers")
|
||||
}
|
||||
}
|
||||
|
||||
func initConfOnce() {
|
||||
var err error
|
||||
|
||||
if conf == nil {
|
||||
if conf, err = initConf(); err != nil {
|
||||
errlog.Fatal().Err(err).Msg("failed to read config")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func initAllowList(cpath string) {
|
||||
var ac allow.Config
|
||||
var err error
|
||||
|
||||
if !conf.Production {
|
||||
ac = allow.Config{CreateIfNotExists: true, Persist: true}
|
||||
}
|
||||
|
||||
allowList, err = allow.New(cpath, ac)
|
||||
if err != nil {
|
||||
errlog.Fatal().Err(err).Msg("failed to initialize allow list")
|
||||
}
|
||||
}
|
@ -6,6 +6,7 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/dosco/super-graph/allow"
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
"github.com/jackc/pgconn"
|
||||
"github.com/jackc/pgx/v4"
|
||||
@ -23,7 +24,10 @@ var (
|
||||
_preparedList map[string]*preparedItem
|
||||
)
|
||||
|
||||
func initPreparedList() {
|
||||
func initPreparedList(cpath string) {
|
||||
if allowList.IsPersist() {
|
||||
return
|
||||
}
|
||||
_preparedList = make(map[string]*preparedItem)
|
||||
|
||||
tx, err := db.Begin(context.Background())
|
||||
@ -43,30 +47,38 @@ func initPreparedList() {
|
||||
|
||||
success := 0
|
||||
|
||||
for _, v := range _allowList.list {
|
||||
if len(v.gql) == 0 {
|
||||
list, err := allowList.Load()
|
||||
if err != nil {
|
||||
errlog.Fatal().Err(err).Send()
|
||||
}
|
||||
|
||||
for _, v := range list {
|
||||
if len(v.Query) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
err := prepareStmt(v.gql, v.vars)
|
||||
err := prepareStmt(v)
|
||||
if err == nil {
|
||||
success++
|
||||
continue
|
||||
}
|
||||
|
||||
if len(v.vars) == 0 {
|
||||
logger.Warn().Err(err).Msg(v.gql)
|
||||
if len(v.Vars) == 0 {
|
||||
logger.Warn().Err(err).Msg(v.Query)
|
||||
} else {
|
||||
logger.Warn().Err(err).Msgf("%s %s", v.vars, v.gql)
|
||||
logger.Warn().Err(err).Msgf("%s %s", v.Vars, v.Query)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Info().
|
||||
Msgf("Registered %d of %d queries from allow.list as prepared statements",
|
||||
success, len(_allowList.list))
|
||||
success, len(list))
|
||||
}
|
||||
|
||||
func prepareStmt(gql string, vars []byte) error {
|
||||
func prepareStmt(item allow.Item) error {
|
||||
gql := item.Query
|
||||
vars := item.Vars
|
||||
|
||||
qt := qcode.GetQType(gql)
|
||||
q := []byte(gql)
|
||||
|
||||
@ -99,7 +111,7 @@ func prepareStmt(gql string, vars []byte) error {
|
||||
|
||||
logger.Debug().Msg("Prepared statement role: user")
|
||||
|
||||
err = prepare(tx, stmts1, gqlHash(gql, vars, "user"))
|
||||
err = prepare(tx, stmts1, stmtHash(item.Name, "user"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -112,7 +124,7 @@ func prepareStmt(gql string, vars []byte) error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = prepare(tx, stmts2, gqlHash(gql, vars, "anon"))
|
||||
err = prepare(tx, stmts2, stmtHash(item.Name, "anon"))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -127,7 +139,7 @@ func prepareStmt(gql string, vars []byte) error {
|
||||
return err
|
||||
}
|
||||
|
||||
err = prepare(tx, stmts, gqlHash(gql, vars, role.Name))
|
||||
err = prepare(tx, stmts, stmtHash(item.Name, role.Name))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
@ -22,6 +22,14 @@ func mkkey(h *xxhash.Digest, k1 string, k2 string) uint64 {
|
||||
return v
|
||||
}
|
||||
|
||||
// nolint: errcheck
|
||||
func stmtHash(name string, role string) string {
|
||||
h := sha1.New()
|
||||
io.WriteString(h, strings.ToLower(name))
|
||||
io.WriteString(h, role)
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
// nolint: errcheck
|
||||
func gqlHash(b string, vars []byte, role string) string {
|
||||
b = strings.TrimSpace(b)
|
||||
@ -108,30 +116,6 @@ func al(b byte) bool {
|
||||
return (b >= 'a' && b <= 'z') || (b >= 'A' && b <= 'Z') || (b >= '0' && b <= '9')
|
||||
}
|
||||
|
||||
func gqlName(b string) string {
|
||||
state, s := 0, 0
|
||||
|
||||
for i := 0; i < len(b); i++ {
|
||||
switch {
|
||||
case state == 2 && b[i] == '{':
|
||||
return b[s:i]
|
||||
case state == 2 && b[i] == ' ':
|
||||
return b[s:i]
|
||||
case state == 1 && b[i] == '{':
|
||||
return ""
|
||||
case state == 1 && b[i] != ' ':
|
||||
s = i
|
||||
state = 2
|
||||
case state == 1 && b[i] == ' ':
|
||||
continue
|
||||
case i != 0 && b[i] == ' ' && (b[i-1] == 'n' || b[i-1] == 'y'):
|
||||
state = 1
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
func findStmt(role string, stmts []stmt) *stmt {
|
||||
for i := range stmts {
|
||||
if stmts[i].role.Name != role {
|
||||
|
@ -229,80 +229,3 @@ func TestGQLHashWithVars2(t *testing.T) {
|
||||
t.Fatal("Hashes don't match they should")
|
||||
}
|
||||
}
|
||||
|
||||
func TestGQLName1(t *testing.T) {
|
||||
var q = `
|
||||
query {
|
||||
products(
|
||||
distinct: [price]
|
||||
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }
|
||||
) { id name } }`
|
||||
|
||||
name := gqlName(q)
|
||||
|
||||
if len(name) != 0 {
|
||||
t.Fatal("Name should be empty, not ", name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGQLName2(t *testing.T) {
|
||||
var q = `
|
||||
query hakuna_matata {
|
||||
products(
|
||||
distinct: [price]
|
||||
where: { id: { and: { greater_or_equals: 20, lt: 28 } } }
|
||||
) {
|
||||
id
|
||||
name
|
||||
}
|
||||
}`
|
||||
|
||||
name := gqlName(q)
|
||||
|
||||
if name != "hakuna_matata" {
|
||||
t.Fatal("Name should be 'hakuna_matata', not ", name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGQLName3(t *testing.T) {
|
||||
var q = `
|
||||
mutation means{ users { id } }`
|
||||
|
||||
// var v2 = ` { products( limit: 30, order_by: { price: desc }, distinct: [ price ] where: { id: { and: { greater_or_equals: 20, lt: 28 } } }) { id name price user { id email } } } `
|
||||
|
||||
name := gqlName(q)
|
||||
|
||||
if name != "means" {
|
||||
t.Fatal("Name should be 'means', not ", name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGQLName4(t *testing.T) {
|
||||
var q = `
|
||||
query no_worries
|
||||
users {
|
||||
id
|
||||
}
|
||||
}`
|
||||
|
||||
name := gqlName(q)
|
||||
|
||||
if name != "no_worries" {
|
||||
t.Fatal("Name should be 'no_worries', not ", name)
|
||||
}
|
||||
}
|
||||
|
||||
func TestGQLName5(t *testing.T) {
|
||||
var q = `
|
||||
{
|
||||
users {
|
||||
id
|
||||
}
|
||||
}`
|
||||
|
||||
name := gqlName(q)
|
||||
|
||||
if len(name) != 0 {
|
||||
t.Fatal("Name should be empty, not ", name)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user