super-graph/serv/serv.go

168 lines
3.8 KiB
Go

package serv
import (
"errors"
"flag"
"fmt"
"net/http"
"os"
"strings"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode"
"github.com/go-pg/pg"
"github.com/gobuffalo/flect"
"github.com/sirupsen/logrus"
"github.com/spf13/viper"
)
//go:generate esc -o static.go -prefix ../web/build -private -pkg serv ../web/build
const (
authFailBlockAlways = iota + 1
authFailBlockPerQuery
authFailBlockNever
)
var (
logger *logrus.Logger
debug int
conf *viper.Viper
db *pg.DB
pcompile *psql.Compiler
qcompile *qcode.Compiler
authFailBlock int
tracing bool
)
func initLog() {
logger = logrus.New()
logger.Formatter = new(logrus.TextFormatter)
logger.Formatter.(*logrus.TextFormatter).DisableColors = false
logger.Formatter.(*logrus.TextFormatter).DisableTimestamp = true
logger.Level = logrus.TraceLevel
logger.Out = os.Stdout
}
func initConf() {
conf = viper.New()
cPath := flag.String("path", ".", "Path to folder that contains config files")
flag.Parse()
conf.AddConfigPath(*cPath)
switch os.Getenv("GO_ENV") {
case "production", "prod":
conf.SetConfigName("prod")
case "staging", "stage":
conf.SetConfigName("stage")
default:
conf.SetConfigName("dev")
}
err := conf.ReadInConfig()
if err != nil {
logger.Fatal(err)
}
debug = conf.GetInt("debug_level")
for k, v := range conf.GetStringMapString("inflections") {
flect.AddPlural(k, v)
}
conf.SetDefault("host_port", "0.0.0.0:8080")
conf.SetDefault("web_ui", false)
conf.SetDefault("debug_level", 0)
conf.SetDefault("enable_tracing", false)
conf.SetDefault("database.type", "postgres")
conf.SetDefault("database.host", "localhost")
conf.SetDefault("database.port", 5432)
conf.SetDefault("database.user", "postgres")
conf.SetDefault("database.password", "")
conf.SetDefault("env", "development")
conf.BindEnv("env", "GO_ENV")
tracing = conf.GetBool("enable_tracing")
switch conf.GetString("auth_fail_block") {
case "always":
authFailBlock = authFailBlockAlways
case "per_query", "perquery", "query":
authFailBlock = authFailBlockPerQuery
case "never", "false":
authFailBlock = authFailBlockNever
default:
authFailBlock = authFailBlockAlways
}
}
func initDB() {
conf.BindEnv("database.host", "SG_DATABASE_HOST")
conf.BindEnv("database.port", "SG_DATABASE_PORT")
conf.BindEnv("database.user", "SG_DATABASE_USER")
conf.BindEnv("database.password", "SG_DATABASE_PASSWORD")
hostport := strings.Join([]string{
conf.GetString("database.host"), conf.GetString("database.port")}, ":")
opt := &pg.Options{
Addr: hostport,
User: conf.GetString("database.user"),
Password: conf.GetString("database.password"),
Database: conf.GetString("database.dbname"),
}
if conf.IsSet("database.pool_size") {
opt.PoolSize = conf.GetInt("database.pool_size")
}
if conf.IsSet("database.max_retries") {
opt.MaxRetries = conf.GetInt("database.max_retries")
}
if db = pg.Connect(opt); db == nil {
logger.Fatal(errors.New("failed to connect to postgres db"))
}
}
func initCompilers() {
filters := conf.GetStringMapString("database.filters")
blacklist := conf.GetStringSlice("database.blacklist")
fm := qcode.NewFilterMap(filters)
bl := qcode.NewBlacklist(blacklist)
qcompile = qcode.NewCompiler(fm, bl)
schema, err := psql.NewDBSchema(db)
if err != nil {
logger.Fatal(err)
}
varlist := conf.GetStringMapString("database.variables")
vars := psql.NewVariables(varlist)
pcompile = psql.NewCompiler(schema, vars)
}
func InitAndListen() {
initLog()
initConf()
initDB()
initCompilers()
http.HandleFunc("/api/v1/graphql", withAuth(apiv1Http))
if conf.GetBool("web_ui") {
http.Handle("/", http.FileServer(_escFS(false)))
}
hp := conf.GetString("host_port")
fmt.Printf("Super-Graph listening on %s (%s)\n", hp, conf.GetString("env"))
logger.Fatal(http.ListenAndServe(hp, nil))
}