Cleanup and redesign config files
This commit is contained in:
30
serv/auth.go
30
serv/auth.go
@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
@ -21,9 +20,9 @@ var (
|
||||
)
|
||||
|
||||
func headerHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
fn := conf.GetString("auth.field_name")
|
||||
fn := conf.Auth.Header
|
||||
if len(fn) == 0 {
|
||||
panic(errors.New("no auth.field_name defined"))
|
||||
panic(errors.New("no auth.header defined"))
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
@ -39,33 +38,26 @@ func headerHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
}
|
||||
|
||||
func withAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
atype := strings.ToLower(conf.GetString("auth.type"))
|
||||
if len(atype) == 0 {
|
||||
return next
|
||||
}
|
||||
store := strings.ToLower(conf.GetString("auth.store"))
|
||||
at := conf.Auth.Type
|
||||
|
||||
switch atype {
|
||||
switch at {
|
||||
case "header":
|
||||
return headerHandler(next)
|
||||
|
||||
case "rails":
|
||||
switch store {
|
||||
case "memcache":
|
||||
return railsMemcacheHandler(next)
|
||||
case "rails_cookie":
|
||||
return railsCookieHandler(next)
|
||||
|
||||
case "redis":
|
||||
return railsRedisHandler(next)
|
||||
case "rails_memcache":
|
||||
return railsMemcacheHandler(next)
|
||||
|
||||
default:
|
||||
return railsCookieHandler(next)
|
||||
}
|
||||
case "rails_redis":
|
||||
return railsRedisHandler(next)
|
||||
|
||||
case "jwt":
|
||||
return jwtHandler(next)
|
||||
|
||||
default:
|
||||
panic(errors.New("unknown auth.type"))
|
||||
return next
|
||||
}
|
||||
|
||||
return next
|
||||
|
@ -18,18 +18,14 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
var key interface{}
|
||||
var jwtProvider int
|
||||
|
||||
cookie := conf.GetString("auth.cookie")
|
||||
cookie := conf.Auth.Cookie
|
||||
|
||||
provider := conf.GetString("auth.provider")
|
||||
if provider == "auth0" {
|
||||
if conf.Auth.JWT.Provider == "auth0" {
|
||||
jwtProvider = jwtAuth0
|
||||
}
|
||||
|
||||
conf.BindEnv("auth.secret", "SG_AUTH_SECRET")
|
||||
secret := conf.GetString("auth.secret")
|
||||
|
||||
conf.BindEnv("auth.public_key_file", "SG_AUTH_PUBLIC_KEY_FILE")
|
||||
publicKeyFile := conf.GetString("auth.public_key_file")
|
||||
secret := conf.Auth.JWT.Secret
|
||||
publicKeyFile := conf.Auth.JWT.PubKeyFile
|
||||
|
||||
switch {
|
||||
case len(secret) != 0:
|
||||
@ -41,7 +37,7 @@ func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
switch conf.GetString("auth.public_key_type") {
|
||||
switch conf.Auth.JWT.PubKeyType {
|
||||
case "ecdsa":
|
||||
key, err = jwt.ParseECPublicKeyFromPEM(kd)
|
||||
|
||||
|
@ -14,31 +14,26 @@ import (
|
||||
)
|
||||
|
||||
func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
cookie := conf.GetString("auth.cookie")
|
||||
cookie := conf.Auth.Cookie
|
||||
if len(cookie) == 0 {
|
||||
panic(errors.New("no auth.cookie defined"))
|
||||
}
|
||||
|
||||
conf.BindEnv("auth.url", "SG_AUTH_URL")
|
||||
authURL := conf.GetString("auth.url")
|
||||
authURL := conf.Auth.RailsRedis.URL
|
||||
if len(authURL) == 0 {
|
||||
panic(errors.New("no auth.url defined"))
|
||||
panic(errors.New("no auth.rails_redis.url defined"))
|
||||
}
|
||||
|
||||
conf.SetDefault("auth.max_idle", 80)
|
||||
conf.SetDefault("auth.max_active", 12000)
|
||||
|
||||
rp := &redis.Pool{
|
||||
MaxIdle: conf.GetInt("auth.max_idle"),
|
||||
MaxActive: conf.GetInt("auth.max_active"),
|
||||
MaxIdle: conf.Auth.RailsRedis.MaxIdle,
|
||||
MaxActive: conf.Auth.RailsRedis.MaxActive,
|
||||
Dial: func() (redis.Conn, error) {
|
||||
c, err := redis.DialURL(authURL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
conf.BindEnv("auth.password", "SG_AUTH_PASSWORD")
|
||||
pwd := conf.GetString("auth.password")
|
||||
pwd := conf.Auth.RailsRedis.Password
|
||||
if len(pwd) != 0 {
|
||||
if _, err := c.Do("AUTH", pwd); err != nil {
|
||||
panic(err)
|
||||
@ -74,14 +69,14 @@ func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
}
|
||||
|
||||
func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
cookie := conf.GetString("auth.cookie")
|
||||
cookie := conf.Auth.Cookie
|
||||
if len(cookie) == 0 {
|
||||
panic(errors.New("no auth.cookie defined"))
|
||||
}
|
||||
|
||||
host := conf.GetString("auth.host")
|
||||
host := conf.Auth.RailsMemcache.Host
|
||||
if len(host) == 0 {
|
||||
panic(errors.New("no auth.host defined"))
|
||||
panic(errors.New("no auth.rails_memcache.host defined"))
|
||||
}
|
||||
|
||||
mc := memcache.New(host)
|
||||
@ -112,15 +107,14 @@ func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
}
|
||||
|
||||
func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
cookie := conf.GetString("auth.cookie")
|
||||
cookie := conf.Auth.Cookie
|
||||
if len(cookie) == 0 {
|
||||
panic(errors.New("no auth.cookie defined"))
|
||||
}
|
||||
|
||||
conf.BindEnv("auth.secret_key_base", "SG_AUTH_SECRET_KEY_BASE")
|
||||
secret := conf.GetString("auth.secret_key_base")
|
||||
secret := conf.Auth.RailsCookie.SecretKeyBase
|
||||
if len(secret) == 0 {
|
||||
panic(errors.New("no auth.secret_key_base defined"))
|
||||
panic(errors.New("no auth.rails_cookie.secret_key_base defined"))
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
|
@ -123,8 +123,9 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
|
||||
errorResp(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
finalSQL := sqlStmt.String()
|
||||
if debug > 0 {
|
||||
if conf.DebugLevel > 0 {
|
||||
fmt.Println(finalSQL)
|
||||
}
|
||||
st := time.Now()
|
||||
@ -140,7 +141,7 @@ func apiv1Http(w http.ResponseWriter, r *http.Request) {
|
||||
et := time.Now()
|
||||
resp := gqlResp{}
|
||||
|
||||
if tracing {
|
||||
if conf.EnableTracing {
|
||||
resp.Extensions = &extensions{newTrace(st, et, qc)}
|
||||
}
|
||||
|
||||
|
288
serv/serv.go
288
serv/serv.go
@ -4,6 +4,7 @@ import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"os"
|
||||
"strings"
|
||||
@ -26,142 +27,253 @@ const (
|
||||
|
||||
var (
|
||||
logger *logrus.Logger
|
||||
debug int
|
||||
conf *viper.Viper
|
||||
conf *config
|
||||
db *pg.DB
|
||||
pcompile *psql.Compiler
|
||||
qcompile *qcode.Compiler
|
||||
authFailBlock int
|
||||
tracing bool
|
||||
)
|
||||
|
||||
func initLog() {
|
||||
logger = logrus.New()
|
||||
logger.Formatter = new(logrus.TextFormatter)
|
||||
logger.Formatter.(*logrus.TextFormatter).DisableColors = false
|
||||
logger.Formatter.(*logrus.TextFormatter).DisableTimestamp = true
|
||||
logger.Level = logrus.TraceLevel
|
||||
logger.Out = os.Stdout
|
||||
type config struct {
|
||||
Env string
|
||||
HostPort string `mapstructure:"host_port"`
|
||||
WebUI bool `mapstructure:"web_ui"`
|
||||
DebugLevel int `mapstructure:"debug_level"`
|
||||
EnableTracing bool `mapstructure:"enable_tracing"`
|
||||
AuthFailBlock string `mapstructure:"auth_fail_block"`
|
||||
Inflections map[string]string
|
||||
|
||||
Auth struct {
|
||||
Type string
|
||||
Cookie string
|
||||
Header string
|
||||
|
||||
RailsCookie struct {
|
||||
SecretKeyBase string `mapstructure:"secret_key_base"`
|
||||
}
|
||||
|
||||
RailsMemcache struct {
|
||||
Host string
|
||||
}
|
||||
|
||||
RailsRedis struct {
|
||||
URL string
|
||||
Password string
|
||||
MaxIdle int `mapstructure:"max_idle"`
|
||||
MaxActive int `mapstructure:"max_active"`
|
||||
}
|
||||
|
||||
JWT struct {
|
||||
Provider string
|
||||
Secret string
|
||||
PubKeyFile string `mapstructure:"public_key_file"`
|
||||
PubKeyType string `mapstructure:"public_key_type"`
|
||||
}
|
||||
}
|
||||
|
||||
DB struct {
|
||||
Type string
|
||||
Host string
|
||||
Port string
|
||||
DBName string
|
||||
User string
|
||||
Password string
|
||||
PoolSize int `mapstructure:"pool_size"`
|
||||
MaxRetries int `mapstructure:"max_retries"`
|
||||
LogLevel string `mapstructure:"log_level"`
|
||||
|
||||
Variables map[string]string
|
||||
|
||||
Defaults struct {
|
||||
Filter []string
|
||||
Blacklist []string
|
||||
}
|
||||
|
||||
Fields []struct {
|
||||
Name string
|
||||
Filter []string
|
||||
Table string
|
||||
Blacklist []string
|
||||
}
|
||||
} `mapstructure:"database"`
|
||||
}
|
||||
|
||||
func initConf() {
|
||||
conf = viper.New()
|
||||
func initLog() *logrus.Logger {
|
||||
log := logrus.New()
|
||||
log.Formatter = new(logrus.TextFormatter)
|
||||
log.Formatter.(*logrus.TextFormatter).DisableColors = false
|
||||
log.Formatter.(*logrus.TextFormatter).DisableTimestamp = true
|
||||
log.Level = logrus.TraceLevel
|
||||
log.Out = os.Stdout
|
||||
|
||||
cPath := flag.String("path", ".", "Path to folder that contains config files")
|
||||
return log
|
||||
}
|
||||
|
||||
func initConf() (*config, error) {
|
||||
vi := viper.New()
|
||||
|
||||
path := flag.String("path", "./", "Path to config files")
|
||||
flag.Parse()
|
||||
|
||||
conf.AddConfigPath(*cPath)
|
||||
vi.SetEnvPrefix("SG")
|
||||
vi.SetEnvKeyReplacer(strings.NewReplacer(".", "_"))
|
||||
vi.AutomaticEnv()
|
||||
|
||||
switch os.Getenv("GO_ENV") {
|
||||
case "production", "prod":
|
||||
conf.SetConfigName("prod")
|
||||
case "staging", "stage":
|
||||
conf.SetConfigName("stage")
|
||||
default:
|
||||
conf.SetConfigName("dev")
|
||||
vi.AddConfigPath(*path)
|
||||
vi.AddConfigPath("./conf")
|
||||
vi.SetConfigName(getConfigName())
|
||||
|
||||
vi.SetDefault("host_port", "0.0.0.0:8080")
|
||||
vi.SetDefault("web_ui", false)
|
||||
vi.SetDefault("debug_level", 0)
|
||||
vi.SetDefault("enable_tracing", false)
|
||||
|
||||
vi.SetDefault("database.type", "postgres")
|
||||
vi.SetDefault("database.host", "localhost")
|
||||
vi.SetDefault("database.port", 5432)
|
||||
vi.SetDefault("database.user", "postgres")
|
||||
vi.SetDefault("database.password", "")
|
||||
|
||||
vi.SetDefault("env", "development")
|
||||
vi.BindEnv("env", "GO_ENV")
|
||||
|
||||
vi.SetDefault("auth.rails_redis.max_idle", 80)
|
||||
vi.SetDefault("auth.rails_redis.max_active", 12000)
|
||||
|
||||
if err := vi.ReadInConfig(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err := conf.ReadInConfig()
|
||||
if err != nil {
|
||||
logger.Fatal(err)
|
||||
c := &config{}
|
||||
|
||||
if err := vi.Unmarshal(c); err != nil {
|
||||
return nil, fmt.Errorf("unable to decode config, %v", err)
|
||||
}
|
||||
|
||||
debug = conf.GetInt("debug_level")
|
||||
|
||||
for k, v := range conf.GetStringMapString("inflections") {
|
||||
for k, v := range c.Inflections {
|
||||
flect.AddPlural(k, v)
|
||||
}
|
||||
|
||||
conf.SetDefault("host_port", "0.0.0.0:8080")
|
||||
conf.SetDefault("web_ui", false)
|
||||
conf.SetDefault("debug_level", 0)
|
||||
conf.SetDefault("enable_tracing", false)
|
||||
authFailBlock = getAuthFailBlock(c)
|
||||
|
||||
conf.SetDefault("database.type", "postgres")
|
||||
conf.SetDefault("database.host", "localhost")
|
||||
conf.SetDefault("database.port", 5432)
|
||||
conf.SetDefault("database.user", "postgres")
|
||||
conf.SetDefault("database.password", "")
|
||||
//fmt.Printf("%#v", c)
|
||||
|
||||
conf.SetDefault("env", "development")
|
||||
conf.BindEnv("env", "GO_ENV")
|
||||
|
||||
tracing = conf.GetBool("enable_tracing")
|
||||
|
||||
switch conf.GetString("auth_fail_block") {
|
||||
case "always":
|
||||
authFailBlock = authFailBlockAlways
|
||||
case "per_query", "perquery", "query":
|
||||
authFailBlock = authFailBlockPerQuery
|
||||
case "never", "false":
|
||||
authFailBlock = authFailBlockNever
|
||||
default:
|
||||
authFailBlock = authFailBlockAlways
|
||||
}
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func initDB() {
|
||||
conf.BindEnv("database.host", "SG_DATABASE_HOST")
|
||||
conf.BindEnv("database.port", "SG_DATABASE_PORT")
|
||||
conf.BindEnv("database.user", "SG_DATABASE_USER")
|
||||
conf.BindEnv("database.password", "SG_DATABASE_PASSWORD")
|
||||
|
||||
hostport := strings.Join([]string{
|
||||
conf.GetString("database.host"), conf.GetString("database.port")}, ":")
|
||||
|
||||
func initDB(c *config) (*pg.DB, error) {
|
||||
opt := &pg.Options{
|
||||
Addr: hostport,
|
||||
User: conf.GetString("database.user"),
|
||||
Password: conf.GetString("database.password"),
|
||||
Database: conf.GetString("database.dbname"),
|
||||
Addr: strings.Join([]string{c.DB.Host, c.DB.Port}, ":"),
|
||||
User: c.DB.User,
|
||||
Password: c.DB.Password,
|
||||
Database: c.DB.DBName,
|
||||
}
|
||||
|
||||
if conf.IsSet("database.pool_size") {
|
||||
opt.PoolSize = conf.GetInt("database.pool_size")
|
||||
if c.DB.PoolSize != 0 {
|
||||
opt.PoolSize = conf.DB.PoolSize
|
||||
}
|
||||
|
||||
if conf.IsSet("database.max_retries") {
|
||||
opt.MaxRetries = conf.GetInt("database.max_retries")
|
||||
if c.DB.MaxRetries != 0 {
|
||||
opt.MaxRetries = c.DB.MaxRetries
|
||||
}
|
||||
|
||||
if db = pg.Connect(opt); db == nil {
|
||||
logger.Fatal(errors.New("failed to connect to postgres db"))
|
||||
db := pg.Connect(opt)
|
||||
if db == nil {
|
||||
return nil, errors.New("failed to connect to postgres db")
|
||||
}
|
||||
|
||||
return db, nil
|
||||
}
|
||||
|
||||
func initCompilers() {
|
||||
filters := conf.GetStringMapString("database.filters")
|
||||
blacklist := conf.GetStringSlice("database.blacklist")
|
||||
func initCompilers(c *config) (*qcode.Compiler, *psql.Compiler, error) {
|
||||
cdb := c.DB
|
||||
|
||||
fm := qcode.NewFilterMap(filters)
|
||||
bl := qcode.NewBlacklist(blacklist)
|
||||
qcompile = qcode.NewCompiler(fm, bl)
|
||||
fm := make(map[string][]string, len(cdb.Fields))
|
||||
for i := range cdb.Fields {
|
||||
f := cdb.Fields[i]
|
||||
fm[strings.ToLower(f.Name)] = f.Filter
|
||||
}
|
||||
|
||||
qc, err := qcode.NewCompiler(qcode.Config{
|
||||
Filter: cdb.Defaults.Filter,
|
||||
FilterMap: fm,
|
||||
Blacklist: cdb.Defaults.Blacklist,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
schema, err := psql.NewDBSchema(db)
|
||||
if err != nil {
|
||||
logger.Fatal(err)
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
varlist := conf.GetStringMapString("database.variables")
|
||||
vars := psql.NewVariables(varlist)
|
||||
pc := psql.NewCompiler(psql.Config{
|
||||
Schema: schema,
|
||||
Vars: cdb.Variables,
|
||||
})
|
||||
|
||||
pcompile = psql.NewCompiler(schema, vars)
|
||||
return qc, pc, nil
|
||||
}
|
||||
|
||||
func InitAndListen() {
|
||||
initLog()
|
||||
initConf()
|
||||
initDB()
|
||||
initCompilers()
|
||||
var err error
|
||||
|
||||
logger = initLog()
|
||||
|
||||
conf, err = initConf()
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
db, err = initDB(conf)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
qcompile, pcompile, err = initCompilers(conf)
|
||||
if err != nil {
|
||||
log.Fatal(err)
|
||||
}
|
||||
|
||||
http.HandleFunc("/api/v1/graphql", withAuth(apiv1Http))
|
||||
|
||||
if conf.GetBool("web_ui") {
|
||||
if conf.WebUI {
|
||||
http.Handle("/", http.FileServer(_escFS(false)))
|
||||
}
|
||||
|
||||
hp := conf.GetString("host_port")
|
||||
fmt.Printf("Super-Graph listening on %s (%s)\n", hp, conf.GetString("env"))
|
||||
fmt.Printf("Super-Graph listening on %s (%s)\n",
|
||||
conf.HostPort, conf.Env)
|
||||
|
||||
logger.Fatal(http.ListenAndServe(hp, nil))
|
||||
logger.Fatal(http.ListenAndServe(conf.HostPort, nil))
|
||||
}
|
||||
|
||||
func getConfigName() string {
|
||||
ge := strings.ToLower(os.Getenv("GO_ENV"))
|
||||
|
||||
switch {
|
||||
case strings.HasPrefix(ge, "pro"):
|
||||
return "prod"
|
||||
|
||||
case strings.HasPrefix(ge, "sta"):
|
||||
return "stage"
|
||||
|
||||
case strings.HasPrefix(ge, "tes"):
|
||||
return "test"
|
||||
}
|
||||
|
||||
return "dev"
|
||||
}
|
||||
|
||||
func getAuthFailBlock(c *config) int {
|
||||
switch c.AuthFailBlock {
|
||||
case "always":
|
||||
return authFailBlockAlways
|
||||
case "per_query", "perquery", "query":
|
||||
return authFailBlockPerQuery
|
||||
case "never", "false":
|
||||
return authFailBlockNever
|
||||
}
|
||||
|
||||
return authFailBlockAlways
|
||||
}
|
||||
|
Reference in New Issue
Block a user