Cleanup and redesign config files

This commit is contained in:
Vikram Rangnekar
2019-04-08 02:47:59 -04:00
parent 8acc3ed08d
commit e3660473cc
11 changed files with 447 additions and 327 deletions

View File

@ -4,7 +4,6 @@ import (
"context"
"errors"
"net/http"
"strings"
)
const (
@ -21,9 +20,9 @@ var (
)
func headerHandler(next http.HandlerFunc) http.HandlerFunc {
fn := conf.GetString("auth.field_name")
fn := conf.Auth.Header
if len(fn) == 0 {
panic(errors.New("no auth.field_name defined"))
panic(errors.New("no auth.header defined"))
}
return func(w http.ResponseWriter, r *http.Request) {
@ -39,33 +38,26 @@ func headerHandler(next http.HandlerFunc) http.HandlerFunc {
}
func withAuth(next http.HandlerFunc) http.HandlerFunc {
atype := strings.ToLower(conf.GetString("auth.type"))
if len(atype) == 0 {
return next
}
store := strings.ToLower(conf.GetString("auth.store"))
at := conf.Auth.Type
switch atype {
switch at {
case "header":
return headerHandler(next)
case "rails":
switch store {
case "memcache":
return railsMemcacheHandler(next)
case "rails_cookie":
return railsCookieHandler(next)
case "redis":
return railsRedisHandler(next)
case "rails_memcache":
return railsMemcacheHandler(next)
default:
return railsCookieHandler(next)
}
case "rails_redis":
return railsRedisHandler(next)
case "jwt":
return jwtHandler(next)
default:
panic(errors.New("unknown auth.type"))
return next
}
return next

View File

@ -18,18 +18,14 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
var key interface{}
var jwtProvider int
cookie := conf.GetString("auth.cookie")
cookie := conf.Auth.Cookie
provider := conf.GetString("auth.provider")
if provider == "auth0" {
if conf.Auth.JWT.Provider == "auth0" {
jwtProvider = jwtAuth0
}
conf.BindEnv("auth.secret", "SG_AUTH_SECRET")
secret := conf.GetString("auth.secret")
conf.BindEnv("auth.public_key_file", "SG_AUTH_PUBLIC_KEY_FILE")
publicKeyFile := conf.GetString("auth.public_key_file")
secret := conf.Auth.JWT.Secret
publicKeyFile := conf.Auth.JWT.PubKeyFile
switch {
case len(secret) != 0:
@ -41,7 +37,7 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
panic(err)
}
switch conf.GetString("auth.public_key_type") {
switch conf.Auth.JWT.PubKeyType {
case "ecdsa":
key, err = jwt.ParseECPublicKeyFromPEM(kd)

View File

@ -14,31 +14,26 @@ import (
)
func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc {
cookie := conf.GetString("auth.cookie")
cookie := conf.Auth.Cookie
if len(cookie) == 0 {
panic(errors.New("no auth.cookie defined"))
}
conf.BindEnv("auth.url", "SG_AUTH_URL")
authURL := conf.GetString("auth.url")
authURL := conf.Auth.RailsRedis.URL
if len(authURL) == 0 {
panic(errors.New("no auth.url defined"))
panic(errors.New("no auth.rails_redis.url defined"))
}
conf.SetDefault("auth.max_idle", 80)
conf.SetDefault("auth.max_active", 12000)
rp := &redis.Pool{
MaxIdle: conf.GetInt("auth.max_idle"),
MaxActive: conf.GetInt("auth.max_active"),
MaxIdle: conf.Auth.RailsRedis.MaxIdle,
MaxActive: conf.Auth.RailsRedis.MaxActive,
Dial: func() (redis.Conn, error) {
c, err := redis.DialURL(authURL)
if err != nil {
panic(err)
}
conf.BindEnv("auth.password", "SG_AUTH_PASSWORD")
pwd := conf.GetString("auth.password")
pwd := conf.Auth.RailsRedis.Password
if len(pwd) != 0 {
if _, err := c.Do("AUTH", pwd); err != nil {
panic(err)
@ -74,14 +69,14 @@ func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc {
}
func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc {
cookie := conf.GetString("auth.cookie")
cookie := conf.Auth.Cookie
if len(cookie) == 0 {
panic(errors.New("no auth.cookie defined"))
}
host := conf.GetString("auth.host")
host := conf.Auth.RailsMemcache.Host
if len(host) == 0 {
panic(errors.New("no auth.host defined"))
panic(errors.New("no auth.rails_memcache.host defined"))
}
mc := memcache.New(host)
@ -112,15 +107,14 @@ func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc {
}
func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc {
cookie := conf.GetString("auth.cookie")
cookie := conf.Auth.Cookie
if len(cookie) == 0 {
panic(errors.New("no auth.cookie defined"))
}
conf.BindEnv("auth.secret_key_base", "SG_AUTH_SECRET_KEY_BASE")
secret := conf.GetString("auth.secret_key_base")
secret := conf.Auth.RailsCookie.SecretKeyBase
if len(secret) == 0 {
panic(errors.New("no auth.secret_key_base defined"))
panic(errors.New("no auth.rails_cookie.secret_key_base defined"))
}
return func(w http.ResponseWriter, r *http.Request) {

View File

@ -123,8 +123,9 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
errorResp(w, err)
return
}
finalSQL := sqlStmt.String()
if debug > 0 {
if conf.DebugLevel > 0 {
fmt.Println(finalSQL)
}
st := time.Now()
@ -140,7 +141,7 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
et := time.Now()
resp := gqlResp{}
if tracing {
if conf.EnableTracing {
resp.Extensions = &extensions{newTrace(st, et, qc)}
}

View File

@ -4,6 +4,7 @@ import (
"errors"
"flag"
"fmt"
"log"
"net/http"
"os"
"strings"
@ -26,142 +27,253 @@ const (
var (
logger *logrus.Logger
debug int
conf *viper.Viper
conf *config
db *pg.DB
pcompile *psql.Compiler
qcompile *qcode.Compiler
authFailBlock int
tracing bool
)
func initLog() {
logger = logrus.New()
logger.Formatter = new(logrus.TextFormatter)
logger.Formatter.(*logrus.TextFormatter).DisableColors = false
logger.Formatter.(*logrus.TextFormatter).DisableTimestamp = true
logger.Level = logrus.TraceLevel
logger.Out = os.Stdout
type config struct {
Env string
HostPort string `mapstructure:"host_port"`
WebUI bool `mapstructure:"web_ui"`
DebugLevel int `mapstructure:"debug_level"`
EnableTracing bool `mapstructure:"enable_tracing"`
AuthFailBlock string `mapstructure:"auth_fail_block"`
Inflections map[string]string
Auth struct {
Type string
Cookie string
Header string
RailsCookie struct {
SecretKeyBase string `mapstructure:"secret_key_base"`
}
RailsMemcache struct {
Host string
}
RailsRedis struct {
URL string
Password string
MaxIdle int `mapstructure:"max_idle"`
MaxActive int `mapstructure:"max_active"`
}
JWT struct {
Provider string
Secret string
PubKeyFile string `mapstructure:"public_key_file"`
PubKeyType string `mapstructure:"public_key_type"`
}
}
DB struct {
Type string
Host string
Port string
DBName string
User string
Password string
PoolSize int `mapstructure:"pool_size"`
MaxRetries int `mapstructure:"max_retries"`
LogLevel string `mapstructure:"log_level"`
Variables map[string]string
Defaults struct {
Filter []string
Blacklist []string
}
Fields []struct {
Name string
Filter []string
Table string
Blacklist []string
}
} `mapstructure:"database"`
}
func initConf() {
conf = viper.New()
func initLog() *logrus.Logger {
log := logrus.New()
log.Formatter = new(logrus.TextFormatter)
log.Formatter.(*logrus.TextFormatter).DisableColors = false
log.Formatter.(*logrus.TextFormatter).DisableTimestamp = true
log.Level = logrus.TraceLevel
log.Out = os.Stdout
cPath := flag.String("path", ".", "Path to folder that contains config files")
return log
}
func initConf() (*config, error) {
vi := viper.New()
path := flag.String("path", "./", "Path to config files")
flag.Parse()
conf.AddConfigPath(*cPath)
vi.SetEnvPrefix("SG")
vi.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
vi.AutomaticEnv()
switch os.Getenv("GO_ENV") {
case "production", "prod":
conf.SetConfigName("prod")
case "staging", "stage":
conf.SetConfigName("stage")
default:
conf.SetConfigName("dev")
vi.AddConfigPath(*path)
vi.AddConfigPath("./conf")
vi.SetConfigName(getConfigName())
vi.SetDefault("host_port", "0.0.0.0:8080")
vi.SetDefault("web_ui", false)
vi.SetDefault("debug_level", 0)
vi.SetDefault("enable_tracing", false)
vi.SetDefault("database.type", "postgres")
vi.SetDefault("database.host", "localhost")
vi.SetDefault("database.port", 5432)
vi.SetDefault("database.user", "postgres")
vi.SetDefault("database.password", "")
vi.SetDefault("env", "development")
vi.BindEnv("env", "GO_ENV")
vi.SetDefault("auth.rails_redis.max_idle", 80)
vi.SetDefault("auth.rails_redis.max_active", 12000)
if err := vi.ReadInConfig(); err != nil {
return nil, err
}
err := conf.ReadInConfig()
if err != nil {
logger.Fatal(err)
c := &config{}
if err := vi.Unmarshal(c); err != nil {
return nil, fmt.Errorf("unable to decode config, %v", err)
}
debug = conf.GetInt("debug_level")
for k, v := range conf.GetStringMapString("inflections") {
for k, v := range c.Inflections {
flect.AddPlural(k, v)
}
conf.SetDefault("host_port", "0.0.0.0:8080")
conf.SetDefault("web_ui", false)
conf.SetDefault("debug_level", 0)
conf.SetDefault("enable_tracing", false)
authFailBlock = getAuthFailBlock(c)
conf.SetDefault("database.type", "postgres")
conf.SetDefault("database.host", "localhost")
conf.SetDefault("database.port", 5432)
conf.SetDefault("database.user", "postgres")
conf.SetDefault("database.password", "")
//fmt.Printf("%#v", c)
conf.SetDefault("env", "development")
conf.BindEnv("env", "GO_ENV")
tracing = conf.GetBool("enable_tracing")
switch conf.GetString("auth_fail_block") {
case "always":
authFailBlock = authFailBlockAlways
case "per_query", "perquery", "query":
authFailBlock = authFailBlockPerQuery
case "never", "false":
authFailBlock = authFailBlockNever
default:
authFailBlock = authFailBlockAlways
}
return c, nil
}
func initDB() {
conf.BindEnv("database.host", "SG_DATABASE_HOST")
conf.BindEnv("database.port", "SG_DATABASE_PORT")
conf.BindEnv("database.user", "SG_DATABASE_USER")
conf.BindEnv("database.password", "SG_DATABASE_PASSWORD")
hostport := strings.Join([]string{
conf.GetString("database.host"), conf.GetString("database.port")}, ":")
func initDB(c *config) (*pg.DB, error) {
opt := &pg.Options{
Addr: hostport,
User: conf.GetString("database.user"),
Password: conf.GetString("database.password"),
Database: conf.GetString("database.dbname"),
Addr: strings.Join([]string{c.DB.Host, c.DB.Port}, ":"),
User: c.DB.User,
Password: c.DB.Password,
Database: c.DB.DBName,
}
if conf.IsSet("database.pool_size") {
opt.PoolSize = conf.GetInt("database.pool_size")
if c.DB.PoolSize != 0 {
opt.PoolSize = conf.DB.PoolSize
}
if conf.IsSet("database.max_retries") {
opt.MaxRetries = conf.GetInt("database.max_retries")
if c.DB.MaxRetries != 0 {
opt.MaxRetries = c.DB.MaxRetries
}
if db = pg.Connect(opt); db == nil {
logger.Fatal(errors.New("failed to connect to postgres db"))
db := pg.Connect(opt)
if db == nil {
return nil, errors.New("failed to connect to postgres db")
}
return db, nil
}
func initCompilers() {
filters := conf.GetStringMapString("database.filters")
blacklist := conf.GetStringSlice("database.blacklist")
func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
cdb := c.DB
fm := qcode.NewFilterMap(filters)
bl := qcode.NewBlacklist(blacklist)
qcompile = qcode.NewCompiler(fm, bl)
fm := make(map[string][]string, len(cdb.Fields))
for i := range cdb.Fields {
f := cdb.Fields[i]
fm[strings.ToLower(f.Name)] = f.Filter
}
qc, err := qcode.NewCompiler(qcode.Config{
Filter: cdb.Defaults.Filter,
FilterMap: fm,
Blacklist: cdb.Defaults.Blacklist,
})
if err != nil {
return nil, nil, err
}
schema, err := psql.NewDBSchema(db)
if err != nil {
logger.Fatal(err)
return nil, nil, err
}
varlist := conf.GetStringMapString("database.variables")
vars := psql.NewVariables(varlist)
pc := psql.NewCompiler(psql.Config{
Schema: schema,
Vars: cdb.Variables,
})
pcompile = psql.NewCompiler(schema, vars)
return qc, pc, nil
}
func InitAndListen() {
initLog()
initConf()
initDB()
initCompilers()
var err error
logger = initLog()
conf, err = initConf()
if err != nil {
log.Fatal(err)
}
db, err = initDB(conf)
if err != nil {
log.Fatal(err)
}
qcompile, pcompile, err = initCompilers(conf)
if err != nil {
log.Fatal(err)
}
http.HandleFunc("/api/v1/graphql", withAuth(apiv1Http))
if conf.GetBool("web_ui") {
if conf.WebUI {
http.Handle("/", http.FileServer(_escFS(false)))
}
hp := conf.GetString("host_port")
fmt.Printf("Super-Graph listening on %s (%s)\n", hp, conf.GetString("env"))
fmt.Printf("Super-Graph listening on %s (%s)\n",
conf.HostPort, conf.Env)
logger.Fatal(http.ListenAndServe(hp, nil))
logger.Fatal(http.ListenAndServe(conf.HostPort, nil))
}
func getConfigName() string {
ge := strings.ToLower(os.Getenv("GO_ENV"))
switch {
case strings.HasPrefix(ge, "pro"):
return "prod"
case strings.HasPrefix(ge, "sta"):
return "stage"
case strings.HasPrefix(ge, "tes"):
return "test"
}
return "dev"
}
func getAuthFailBlock(c *config) int {
switch c.AuthFailBlock {
case "always":
return authFailBlockAlways
case "per_query", "perquery", "query":
return authFailBlockPerQuery
case "never", "false":
return authFailBlockNever
}
return authFailBlockAlways
}