181 lines
4.0 KiB
Go
181 lines
4.0 KiB
Go
|
package serv
|
||
|
|
||
|
import (
	"errors"
	"flag"
	"fmt"
	"net"
	"net/http"
	"os"
	"regexp"
	"strings"
	"time"

	"github.com/dosco/super-graph/psql"
	"github.com/dosco/super-graph/qcode"
	"github.com/go-pg/pg"
	"github.com/jinzhu/inflection"
	"github.com/sirupsen/logrus"
	"github.com/spf13/viper"
)
|
||
|
|
||
|
// Modes stored in authFailBlock, selected from the "auth_fail_block" config
// key by initConf. Numbering starts at 1 so that 0 (the zero value of
// authFailBlock) is not any valid mode.
//
// NOTE(review): presumably these control how requests are handled when
// authentication fails — confirm against withAuth.
const (
	authFailBlockAlways   = 1 // "always" (also the fallback)
	authFailBlockPerQuery = 2 // "per_query", "perquery", "query"
	authFailBlockNever    = 3 // "never", "false"
)
|
||
|
|
||
|
var (
|
||
|
logger *logrus.Logger
|
||
|
debug int
|
||
|
conf *viper.Viper
|
||
|
db *pg.DB
|
||
|
pcompile *psql.Compiler
|
||
|
qcompile *qcode.Compiler
|
||
|
authFailBlock int
|
||
|
)
|
||
|
|
||
|
func initLog() {
|
||
|
logger = logrus.New()
|
||
|
logger.Formatter = new(logrus.TextFormatter)
|
||
|
logger.Formatter.(*logrus.TextFormatter).DisableColors = false
|
||
|
logger.Formatter.(*logrus.TextFormatter).DisableTimestamp = true
|
||
|
logger.Level = logrus.TraceLevel
|
||
|
logger.Out = os.Stdout
|
||
|
}
|
||
|
|
||
|
func initConf() {
|
||
|
conf = viper.New()
|
||
|
|
||
|
cPath := flag.String("path", ".", "Path to folder that contains config files")
|
||
|
flag.Parse()
|
||
|
|
||
|
conf.AddConfigPath(*cPath)
|
||
|
|
||
|
switch os.Getenv("GO_ENV") {
|
||
|
case "production", "prod":
|
||
|
conf.SetConfigName("prod")
|
||
|
case "staging", "stage":
|
||
|
conf.SetConfigName("stage")
|
||
|
default:
|
||
|
conf.SetConfigName("dev")
|
||
|
}
|
||
|
|
||
|
err := conf.ReadInConfig()
|
||
|
if err != nil {
|
||
|
logger.Fatal(err)
|
||
|
}
|
||
|
|
||
|
debug = conf.GetInt("debug_level")
|
||
|
|
||
|
for k, v := range conf.GetStringMapString("inflections") {
|
||
|
inflection.AddIrregular(k, v)
|
||
|
}
|
||
|
|
||
|
conf.SetDefault("host_port", "0.0.0.0:8080")
|
||
|
conf.SetDefault("web_ui", false)
|
||
|
conf.SetDefault("debug_level", 0)
|
||
|
|
||
|
conf.SetDefault("database.type", "postgres")
|
||
|
conf.SetDefault("database.host", "localhost")
|
||
|
conf.SetDefault("database.port", 5432)
|
||
|
conf.SetDefault("database.user", "postgres")
|
||
|
conf.SetDefault("database.password", "")
|
||
|
|
||
|
conf.SetDefault("env", "development")
|
||
|
conf.BindEnv("env", "GO_ENV")
|
||
|
|
||
|
switch conf.GetString("auth_fail_block") {
|
||
|
case "always":
|
||
|
authFailBlock = authFailBlockAlways
|
||
|
case "per_query", "perquery", "query":
|
||
|
authFailBlock = authFailBlockPerQuery
|
||
|
case "never", "false":
|
||
|
authFailBlock = authFailBlockNever
|
||
|
default:
|
||
|
authFailBlock = authFailBlockAlways
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func initDB() {
|
||
|
conf.BindEnv("database.host", "SG_DATABASE_HOST")
|
||
|
conf.BindEnv("database.port", "SG_DATABASE_PORT")
|
||
|
conf.BindEnv("database.user", "SG_DATABASE_USER")
|
||
|
conf.BindEnv("database.password", "SG_DATABASE_PASSWORD")
|
||
|
|
||
|
hostport := strings.Join([]string{
|
||
|
conf.GetString("database.host"), conf.GetString("database.port")}, ":")
|
||
|
|
||
|
opt := &pg.Options{
|
||
|
Addr: hostport,
|
||
|
User: conf.GetString("database.user"),
|
||
|
Password: conf.GetString("database.password"),
|
||
|
Database: conf.GetString("database.dbname"),
|
||
|
}
|
||
|
|
||
|
if conf.IsSet("database.pool_size") {
|
||
|
opt.PoolSize = conf.GetInt("database.pool_size")
|
||
|
}
|
||
|
|
||
|
if conf.IsSet("database.max_retries") {
|
||
|
opt.MaxRetries = conf.GetInt("database.max_retries")
|
||
|
}
|
||
|
|
||
|
if db = pg.Connect(opt); db == nil {
|
||
|
logger.Fatal(errors.New("failed to connect to postgres db"))
|
||
|
}
|
||
|
}
|
||
|
|
||
|
func initCompilers() {
|
||
|
fv := conf.GetStringMapString("database.filters")
|
||
|
fm := make(qcode.FilterMap)
|
||
|
for k, v := range fv {
|
||
|
fil, err := qcode.CompileFilter(v)
|
||
|
if err != nil {
|
||
|
panic(err)
|
||
|
}
|
||
|
key := strings.ToLower(k)
|
||
|
fm[key] = fil
|
||
|
}
|
||
|
|
||
|
bv := conf.GetStringSlice("database.blacklist")
|
||
|
var bl *regexp.Regexp
|
||
|
if len(bv) != 0 {
|
||
|
re := fmt.Sprintf("(?i)%s", strings.Join(bv, "|"))
|
||
|
bl = regexp.MustCompile(re)
|
||
|
}
|
||
|
qcompile = qcode.NewCompiler(fm, bl)
|
||
|
|
||
|
schema, err := psql.NewDBSchema(db)
|
||
|
if err != nil {
|
||
|
logger.Fatal(err)
|
||
|
}
|
||
|
|
||
|
re := regexp.MustCompile(`(?mi)\$([a-zA-Z0-9_.]+)`)
|
||
|
vl := conf.GetStringMapString("database.variables")
|
||
|
vars := make(map[string]string)
|
||
|
|
||
|
for k, v := range vl {
|
||
|
vars[k] = re.ReplaceAllString(v, `{{$1}}`)
|
||
|
}
|
||
|
|
||
|
pcompile = psql.NewCompiler(schema, vars)
|
||
|
}
|
||
|
|
||
|
func InitAndListen() {
|
||
|
initLog()
|
||
|
initConf()
|
||
|
initDB()
|
||
|
initCompilers()
|
||
|
|
||
|
http.HandleFunc("/api/v1/graphql", withAuth(apiv1Http))
|
||
|
|
||
|
if conf.GetBool("web_ui") {
|
||
|
fs := http.FileServer(http.Dir("web/build"))
|
||
|
http.Handle("/", fs)
|
||
|
}
|
||
|
|
||
|
hp := conf.GetString("host_port")
|
||
|
fmt.Printf("Super-Graph listening on %s (%s)\n", hp, conf.GetString("env"))
|
||
|
|
||
|
logger.Fatal(http.ListenAndServe(hp, nil))
|
||
|
}
|