super-graph/serv/serv.go

181 lines
4.0 KiB
Go

package serv
import (
"errors"
"flag"
"fmt"
"net/http"
"os"
"regexp"
"strings"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode"
"github.com/go-pg/pg"
"github.com/jinzhu/inflection"
"github.com/sirupsen/logrus"
"github.com/spf13/viper"
)
// Policies for handling authentication failures. initConf picks one of these
// from the "auth_fail_block" config key and stores it in authFailBlock.
// Numbering starts at 1, so the int zero value matches no policy.
const (
	authFailBlockAlways   = 1
	authFailBlockPerQuery = 2
	authFailBlockNever    = 3
)
// Package-level server state, grouped by the init function that assigns it.
var (
	// Assigned by initLog.
	logger *logrus.Logger

	// Assigned by initConf.
	conf          *viper.Viper
	debug         int
	authFailBlock int

	// Assigned by initDB.
	db *pg.DB

	// Assigned by initCompilers.
	qcompile *qcode.Compiler
	pcompile *psql.Compiler
)
// initLog creates the package logger. It writes text-formatted output to
// standard output, with colours left enabled and timestamps suppressed.
// Every level down to trace is emitted.
func initLog() {
	textFmt := &logrus.TextFormatter{
		DisableColors:    false,
		DisableTimestamp: true,
	}

	l := logrus.New()
	l.Formatter = textFmt
	l.Level = logrus.TraceLevel
	l.Out = os.Stdout

	logger = l
}
// initConf loads the server configuration into conf.
//
// The config directory comes from the -path command-line flag (default ".").
// The file name is chosen from the GO_ENV environment variable:
//   - "prod" for production/prod,
//   - "stage" for staging/stage,
//   - "dev" otherwise.
//
// A read error is fatal. Afterwards it sets debug from "debug_level" and
// passes every entry of the "inflections" map to inflection.AddIrregular.
// It also resolves authFailBlock from "auth_fail_block", falling back to
// authFailBlockAlways for empty or unknown values.
func initConf() {
	conf = viper.New()

	// Register defaults and env bindings before anything reads a value,
	// so every lookup below sees them.
	conf.SetDefault("host_port", "0.0.0.0:8080")
	conf.SetDefault("web_ui", false)
	conf.SetDefault("debug_level", 0)
	conf.SetDefault("database.type", "postgres")
	conf.SetDefault("database.host", "localhost")
	conf.SetDefault("database.port", 5432)
	conf.SetDefault("database.user", "postgres")
	conf.SetDefault("database.password", "")
	conf.SetDefault("env", "development")
	conf.BindEnv("env", "GO_ENV")

	cPath := flag.String("path", ".", "Path to folder that contains config files")
	flag.Parse()
	conf.AddConfigPath(*cPath)

	switch os.Getenv("GO_ENV") {
	case "production", "prod":
		conf.SetConfigName("prod")
	case "staging", "stage":
		conf.SetConfigName("stage")
	default:
		conf.SetConfigName("dev")
	}

	if err := conf.ReadInConfig(); err != nil {
		logger.Fatal(err)
	}

	debug = conf.GetInt("debug_level")

	for k, v := range conf.GetStringMapString("inflections") {
		inflection.AddIrregular(k, v)
	}

	switch v := conf.GetString("auth_fail_block"); v {
	case "always", "":
		authFailBlock = authFailBlockAlways
	case "per_query", "perquery", "query":
		authFailBlock = authFailBlockPerQuery
	case "never", "false":
		authFailBlock = authFailBlockNever
	default:
		// Keep the strictest policy for unrecognised values, but report
		// it instead of silently ignoring what is likely a typo.
		logger.Warnf("unknown auth_fail_block value %q, using \"always\"", v)
		authFailBlock = authFailBlockAlways
	}
}
// initDB opens the Postgres handle stored in db.
//
// Connection settings come from the "database.*" config keys. Host, port,
// user and password can be overridden by the SG_DATABASE_HOST,
// SG_DATABASE_PORT, SG_DATABASE_USER and SG_DATABASE_PASSWORD environment
// variables. "database.pool_size" and "database.max_retries" are applied
// only when they are set.
//
// A trivial query is issued before returning. If the database cannot be
// reached, the failure is passed to logger.Fatal.
func initDB() {
	conf.BindEnv("database.host", "SG_DATABASE_HOST")
	conf.BindEnv("database.port", "SG_DATABASE_PORT")
	conf.BindEnv("database.user", "SG_DATABASE_USER")
	conf.BindEnv("database.password", "SG_DATABASE_PASSWORD")

	// Bracket bare IPv6 literals, the same rule net.JoinHostPort uses.
	// Otherwise "::1" would become the ambiguous "::1:5432" instead of
	// "[::1]:5432". Hosts the user already bracketed are left alone.
	host := conf.GetString("database.host")
	if strings.Contains(host, ":") && !strings.HasPrefix(host, "[") {
		host = "[" + host + "]"
	}

	opt := &pg.Options{
		Addr:     host + ":" + conf.GetString("database.port"),
		User:     conf.GetString("database.user"),
		Password: conf.GetString("database.password"),
		Database: conf.GetString("database.dbname"),
	}

	if conf.IsSet("database.pool_size") {
		opt.PoolSize = conf.GetInt("database.pool_size")
	}

	if conf.IsSet("database.max_retries") {
		opt.MaxRetries = conf.GetInt("database.max_retries")
	}

	db = pg.Connect(opt)

	// pg.Connect only builds a pooled handle. It never returns nil and does
	// not dial, so a nil check tells us nothing. Run a trivial query so a
	// bad address or bad credentials stop startup here with a clear message.
	if _, err := db.Exec("SELECT 1"); err != nil {
		logger.WithError(err).Fatal(errors.New("failed to connect to postgres db"))
	}
}
// initCompilers builds the two query compilers used by the server.
//
//   - qcompile is built from the "database.filters" map and the optional
//     "database.blacklist" list. Each filter value is compiled with
//     qcode.CompileFilter and stored under its lower-cased key. The
//     blacklist entries are joined into one case-insensitive regular
//     expression.
//   - pcompile is built from the schema read from db and the
//     "database.variables" map. Every "$name" reference in a variable value
//     is rewritten to "{{name}}".
//
// An invalid filter, an invalid blacklist pattern or a schema-read error is
// passed to logger.Fatal.
func initCompilers() {
	fv := conf.GetStringMapString("database.filters")
	fm := make(qcode.FilterMap, len(fv))

	for k, v := range fv {
		fil, err := qcode.CompileFilter(v)
		if err != nil {
			logger.Fatalf("invalid database filter %q: %s", k, err)
		}
		fm[strings.ToLower(k)] = fil
	}

	// Skip blank entries. An empty alternative matches any input, which
	// would presumably blacklist every name.
	var parts []string
	for _, b := range conf.GetStringSlice("database.blacklist") {
		if strings.TrimSpace(b) != "" {
			parts = append(parts, b)
		}
	}

	var bl *regexp.Regexp
	if len(parts) != 0 {
		var err error
		// The entries come from user config, so report a bad pattern
		// instead of letting MustCompile panic.
		bl, err = regexp.Compile("(?i)" + strings.Join(parts, "|"))
		if err != nil {
			logger.Fatalf("invalid database.blacklist: %s", err)
		}
	}

	qcompile = qcode.NewCompiler(fm, bl)

	schema, err := psql.NewDBSchema(db)
	if err != nil {
		logger.Fatal(err)
	}

	// Rewrite "$name" (letters, digits, '_' and '.') to "{{name}}".
	varRe := regexp.MustCompile(`(?mi)\$([a-zA-Z0-9_.]+)`)
	vl := conf.GetStringMapString("database.variables")
	vars := make(map[string]string, len(vl))
	for k, v := range vl {
		vars[k] = varRe.ReplaceAllString(v, `{{$1}}`)
	}

	pcompile = psql.NewCompiler(schema, vars)
}
// InitAndListen is the server entry point. It runs, in order, initLog,
// initConf, initDB and initCompilers.
//
// It registers the GraphQL handler at /api/v1/graphql, wrapped by withAuth.
// When "web_ui" is true, it also serves the files in web/build at "/". It
// then listens on the "host_port" address, and the error returned by the
// server is passed to logger.Fatal.
//
// The time allowed to read request headers comes from "read_header_timeout"
// (default "10s"). A value of "0" disables it.
func InitAndListen() {
	initLog()
	initConf()
	initDB()
	initCompilers()

	http.HandleFunc("/api/v1/graphql", withAuth(apiv1Http))

	if conf.GetBool("web_ui") {
		http.Handle("/", http.FileServer(http.Dir("web/build")))
	}

	// A bare http.ListenAndServe has no timeouts, so a client that trickles
	// request headers can hold a connection open indefinitely. Bound header
	// reads only. Body and response timeouts stay unset so slow GraphQL
	// queries are not cut short.
	conf.SetDefault("read_header_timeout", "10s")

	hp := conf.GetString("host_port")
	srv := &http.Server{
		Addr:              hp,
		ReadHeaderTimeout: conf.GetDuration("read_header_timeout"),
		// Handler is nil, so the DefaultServeMux populated above is used.
	}

	fmt.Printf("Super-Graph listening on %s (%s)\n", hp, conf.GetString("env"))
	logger.Fatal(srv.ListenAndServe())
}