First commit
This commit is contained in:
71
serv/auth.go
Normal file
71
serv/auth.go
Normal file
@ -0,0 +1,71 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
const (
|
||||
salt = "encrypted cookie"
|
||||
signSalt = "signed encrypted cookie"
|
||||
emptySecret = ""
|
||||
authHeader = "Authorization"
|
||||
)
|
||||
|
||||
var (
|
||||
userIDKey = struct{}{}
|
||||
errSessionData = errors.New("error decoding session data")
|
||||
)
|
||||
|
||||
func headerHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
fn := conf.GetString("auth.field_name")
|
||||
if len(fn) == 0 {
|
||||
panic(errors.New("no auth.field_name defined"))
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
userID := r.Header.Get(fn)
|
||||
if len(userID) == 0 {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), userIDKey, userID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
func withAuth(next http.HandlerFunc) http.HandlerFunc {
|
||||
atype := strings.ToLower(conf.GetString("auth.type"))
|
||||
if len(atype) == 0 {
|
||||
return next
|
||||
}
|
||||
store := strings.ToLower(conf.GetString("auth.store"))
|
||||
|
||||
switch atype {
|
||||
case "header":
|
||||
return headerHandler(next)
|
||||
|
||||
case "rails":
|
||||
switch store {
|
||||
case "memcache":
|
||||
return railsMemcacheHandler(next)
|
||||
|
||||
case "redis":
|
||||
return railsRedisHandler(next)
|
||||
|
||||
default:
|
||||
return railsCookieHandler(next)
|
||||
}
|
||||
|
||||
case "jwt":
|
||||
return jwtHandler(next)
|
||||
|
||||
default:
|
||||
panic(errors.New("unknown auth.type"))
|
||||
}
|
||||
|
||||
return next
|
||||
}
|
84
serv/auth_jwt.go
Normal file
84
serv/auth_jwt.go
Normal file
@ -0,0 +1,84 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
)
|
||||
|
||||
func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
var key interface{}
|
||||
|
||||
cookie := conf.GetString("auth.cookie")
|
||||
|
||||
conf.BindEnv("auth.secret", "SG_AUTH_SECRET")
|
||||
secret := conf.GetString("auth.secret")
|
||||
|
||||
conf.BindEnv("auth.public_key_file", "SG_AUTH_PUBLIC_KEY_FILE")
|
||||
publicKeyFile := conf.GetString("auth.public_key_file")
|
||||
|
||||
switch {
|
||||
case len(secret) != 0:
|
||||
key = []byte(secret)
|
||||
|
||||
case len(publicKeyFile) != 0:
|
||||
kd, err := ioutil.ReadFile(publicKeyFile)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
switch conf.GetString("auth.public_key_type") {
|
||||
case "ecdsa":
|
||||
key, err = jwt.ParseECPublicKeyFromPEM(kd)
|
||||
|
||||
case "rsa":
|
||||
key, err = jwt.ParseRSAPublicKeyFromPEM(kd)
|
||||
|
||||
default:
|
||||
key, err = jwt.ParseECPublicKeyFromPEM(kd)
|
||||
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
var tok string
|
||||
|
||||
if len(cookie) != 0 {
|
||||
ck, err := r.Cookie(cookie)
|
||||
if err != nil {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
tok = ck.Value
|
||||
} else {
|
||||
ah := r.Header.Get(authHeader)
|
||||
if len(ah) < 10 {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
tok = ah[7:]
|
||||
}
|
||||
|
||||
token, err := jwt.ParseWithClaims(tok, &jwt.StandardClaims{}, func(token *jwt.Token) (interface{}, error) {
|
||||
return key, nil
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
if claims, ok := token.Claims.(*jwt.StandardClaims); ok {
|
||||
ctx := context.WithValue(r.Context(), userIDKey, claims.Id)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
}
|
||||
}
|
240
serv/auth_rails.go
Normal file
240
serv/auth_rails.go
Normal file
@ -0,0 +1,240 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"github.com/adjust/gorails/marshal"
|
||||
"github.com/adjust/gorails/session"
|
||||
"github.com/bradfitz/gomemcache/memcache"
|
||||
"github.com/garyburd/redigo/redis"
|
||||
)
|
||||
|
||||
func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
cookie := conf.GetString("auth.cookie")
|
||||
if len(cookie) == 0 {
|
||||
panic(errors.New("no auth.cookie defined"))
|
||||
}
|
||||
|
||||
conf.BindEnv("auth.url", "SG_AUTH_URL")
|
||||
authURL := conf.GetString("auth.url")
|
||||
if len(authURL) == 0 {
|
||||
panic(errors.New("no auth.url defined"))
|
||||
}
|
||||
|
||||
conf.SetDefault("auth.max_idle", 80)
|
||||
conf.SetDefault("auth.max_active", 12000)
|
||||
|
||||
rp := &redis.Pool{
|
||||
MaxIdle: conf.GetInt("auth.max_idle"),
|
||||
MaxActive: conf.GetInt("auth.max_active"),
|
||||
Dial: func() (redis.Conn, error) {
|
||||
c, err := redis.DialURL(authURL)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
conf.BindEnv("auth.password", "SG_AUTH_PASSWORD")
|
||||
pwd := conf.GetString("auth.password")
|
||||
if len(pwd) != 0 {
|
||||
if _, err := c.Do("AUTH", pwd); err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
return c, err
|
||||
},
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ck, err := r.Cookie(cookie)
|
||||
if err != nil {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("session:%s", ck.Value)
|
||||
sessionData, err := redis.Bytes(rp.Get().Do("GET", key))
|
||||
if err != nil {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := railsAuth(string(sessionData), emptySecret)
|
||||
if err != nil {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), userIDKey, userID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
cookie := conf.GetString("auth.cookie")
|
||||
if len(cookie) == 0 {
|
||||
panic(errors.New("no auth.cookie defined"))
|
||||
}
|
||||
|
||||
host := conf.GetString("auth.host")
|
||||
if len(host) == 0 {
|
||||
panic(errors.New("no auth.host defined"))
|
||||
}
|
||||
|
||||
mc := memcache.New(host)
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ck, err := r.Cookie(cookie)
|
||||
if err != nil {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
key := fmt.Sprintf("session:%s", ck.Value)
|
||||
item, err := mc.Get(key)
|
||||
if err != nil {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := railsAuth(string(item.Value), emptySecret)
|
||||
if err != nil {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), userIDKey, userID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc {
|
||||
cookie := conf.GetString("auth.cookie")
|
||||
if len(cookie) == 0 {
|
||||
panic(errors.New("no auth.cookie defined"))
|
||||
}
|
||||
|
||||
conf.BindEnv("auth.secret_key_base", "SG_AUTH_SECRET_KEY_BASE")
|
||||
secret := conf.GetString("auth.secret_key_base")
|
||||
if len(secret) == 0 {
|
||||
panic(errors.New("no auth.secret_key_base defined"))
|
||||
}
|
||||
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
ck, err := r.Cookie(cookie)
|
||||
if err != nil {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
userID, err := railsAuth(ck.Value, secret)
|
||||
if err != nil {
|
||||
next.ServeHTTP(w, r)
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), userIDKey, userID)
|
||||
next.ServeHTTP(w, r.WithContext(ctx))
|
||||
}
|
||||
}
|
||||
|
||||
func railsAuth(cookie, secret string) (userID string, err error) {
|
||||
var dcookie []byte
|
||||
|
||||
if len(secret) != 0 {
|
||||
dcookie, err = session.DecryptSignedCookie(cookie, secret, salt, signSalt)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
if dcookie[0] != '{' {
|
||||
userID, err = getUserId4(dcookie)
|
||||
} else {
|
||||
userID, err = getUserId(dcookie)
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func getUserId(data []byte) (userID string, err error) {
|
||||
var sessionData map[string]interface{}
|
||||
|
||||
err = json.Unmarshal(data, &sessionData)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
userKey, ok := sessionData["warden.user.user.key"]
|
||||
if !ok {
|
||||
err = errors.New("key 'warden.user.user.key' not found in session data")
|
||||
}
|
||||
|
||||
items, ok := userKey.([]interface{})
|
||||
if !ok {
|
||||
err = errSessionData
|
||||
return
|
||||
}
|
||||
|
||||
if len(items) != 2 {
|
||||
err = errSessionData
|
||||
return
|
||||
}
|
||||
|
||||
uids, ok := items[0].([]interface{})
|
||||
if !ok {
|
||||
err = errSessionData
|
||||
return
|
||||
}
|
||||
|
||||
uid, ok := uids[0].(float64)
|
||||
if !ok {
|
||||
err = errSessionData
|
||||
return
|
||||
}
|
||||
userID = fmt.Sprintf("%d", int64(uid))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
func getUserId4(data []byte) (userID string, err error) {
|
||||
sessionData, err := marshal.CreateMarshalledObject(data).GetAsMap()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
wardenData, ok := sessionData["warden.user.user.key"]
|
||||
if !ok {
|
||||
err = errSessionData
|
||||
return
|
||||
}
|
||||
|
||||
wardenUserKey, err := wardenData.GetAsArray()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(wardenUserKey) < 1 {
|
||||
err = errSessionData
|
||||
return
|
||||
}
|
||||
|
||||
userData, err := wardenUserKey[0].GetAsArray()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(userData) < 1 {
|
||||
err = errSessionData
|
||||
return
|
||||
}
|
||||
|
||||
uid, err := userData[0].GetAsInteger()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
userID = fmt.Sprintf("%d", uid)
|
||||
|
||||
return
|
||||
}
|
49
serv/auth_test.go
Normal file
49
serv/auth_test.go
Normal file
@ -0,0 +1,49 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRailsEncryptedSession(t *testing.T) {
|
||||
cookie := "dDdjMW5jYUNYaFpBT1BSdFgwQkk4ZWNlT214L1FnM0pyZzZ1d21nSnVTTm9zS0ljN000S1JmT3cxcTNtRld2Ny0tQUFBQUFBQUFBQUFBQUFBQUFBQUFBQT09--75d8323b0f0e41cf4d5aabee1b229b1be76b83b6"
|
||||
|
||||
secret := "development_secret"
|
||||
|
||||
userID, err := railsAuth(cookie, secret)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if userID != "1" {
|
||||
t.Errorf("Expecting userID 1 got %s", userID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRailsJsonSession(t *testing.T) {
|
||||
sessionData := `{"warden.user.user.key":[[1],"secret"]}`
|
||||
|
||||
userID, err := getUserId([]byte(sessionData))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if userID != "1" {
|
||||
t.Errorf("Expecting userID 1 got %s", userID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestRailsMarshaledSession(t *testing.T) {
|
||||
sessionData := "\x04\b{\bI\"\x15member_return_to\x06:\x06ETI\"\x06/\x06;\x00TI\"\x19warden.user.user.key\x06;\x00T[\a[\x06i\aI\"\"$2a$11$6SgXdvO9hld82kQAvpEY3e\x06;\x00TI\"\x10_csrf_token\x06;\x00FI\"17lqwj1UsTTgbXBQKH4ipCNW32uLusvfSPds1txppMec=\x06;\x00F"
|
||||
|
||||
userID, err := getUserId4([]byte(sessionData))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
if userID != "2" {
|
||||
t.Errorf("Expecting userID 2 got %s", userID)
|
||||
}
|
||||
}
|
167
serv/http.go
Normal file
167
serv/http.go
Normal file
@ -0,0 +1,167 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/go-pg/pg"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/valyala/fasttemplate"
|
||||
)
|
||||
|
||||
const (
|
||||
introspectionQuery = "IntrospectionQuery"
|
||||
openVar = "{{"
|
||||
closeVar = "}}"
|
||||
)
|
||||
|
||||
var (
|
||||
upgrader = websocket.Upgrader{}
|
||||
errNoUserID = errors.New("no user_id available")
|
||||
)
|
||||
|
||||
type gqlReq struct {
|
||||
OpName string `json:"operationName"`
|
||||
Query string `json:"query"`
|
||||
Variables map[string]string `json:"variables"`
|
||||
}
|
||||
|
||||
type gqlResp struct {
|
||||
Error string `json:"error,omitempty"`
|
||||
Data json.RawMessage `json:"data,omitempty"`
|
||||
}
|
||||
|
||||
func apiv1Http(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
if authFailBlock == authFailBlockAlways && authCheck(ctx) == false {
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
|
||||
b, err := ioutil.ReadAll(r.Body)
|
||||
defer r.Body.Close()
|
||||
if err != nil {
|
||||
errorResp(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
req := &gqlReq{}
|
||||
if err := json.Unmarshal(b, req); err != nil {
|
||||
errorResp(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
if strings.EqualFold(req.OpName, introspectionQuery) {
|
||||
dat, err := ioutil.ReadFile("test.schema")
|
||||
if err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
return
|
||||
}
|
||||
w.Write(dat)
|
||||
return
|
||||
}
|
||||
|
||||
qc, err := qcompile.CompileQuery(req.Query)
|
||||
if err != nil {
|
||||
errorResp(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
var sqlStmt strings.Builder
|
||||
|
||||
if err := pcompile.Compile(&sqlStmt, qc); err != nil {
|
||||
errorResp(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
t := fasttemplate.New(sqlStmt.String(), openVar, closeVar)
|
||||
sqlStmt.Reset()
|
||||
|
||||
_, err = t.Execute(&sqlStmt, varValues(ctx))
|
||||
if err == errNoUserID &&
|
||||
authFailBlock == authFailBlockPerQuery &&
|
||||
authCheck(ctx) == false {
|
||||
http.Error(w, "Not authorized", 401)
|
||||
return
|
||||
}
|
||||
if err != nil {
|
||||
errorResp(w, err)
|
||||
return
|
||||
}
|
||||
finalSQL := sqlStmt.String()
|
||||
if debug > 0 {
|
||||
fmt.Println(finalSQL)
|
||||
}
|
||||
|
||||
var root json.RawMessage
|
||||
_, err = db.Query(pg.Scan(&root), finalSQL)
|
||||
|
||||
if err != nil {
|
||||
errorResp(w, err)
|
||||
return
|
||||
}
|
||||
|
||||
json.NewEncoder(w).Encode(gqlResp{Data: json.RawMessage(root)})
|
||||
}
|
||||
|
||||
/*
|
||||
func apiv1Ws(w http.ResponseWriter, r *http.Request) {
|
||||
c, err := upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
return
|
||||
}
|
||||
defer c.Close()
|
||||
for {
|
||||
mt, message, err := c.ReadMessage()
|
||||
if err != nil {
|
||||
fmt.Println("read:", err)
|
||||
break
|
||||
}
|
||||
fmt.Printf("recv: %s", message)
|
||||
err = c.WriteMessage(mt, message)
|
||||
if err != nil {
|
||||
fmt.Println("write:", err)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func serve(w http.ResponseWriter, r *http.Request) {
|
||||
// if websocket.IsWebSocketUpgrade(r) {
|
||||
// apiv1Ws(w, r)
|
||||
// return
|
||||
// }
|
||||
apiv1Http(w, r)
|
||||
}
|
||||
*/
|
||||
|
||||
func errorResp(w http.ResponseWriter, err error) {
|
||||
b, _ := json.Marshal(gqlResp{Error: err.Error()})
|
||||
http.Error(w, string(b), http.StatusBadRequest)
|
||||
}
|
||||
|
||||
func authCheck(ctx context.Context) bool {
|
||||
return (ctx.Value(userIDKey) != nil)
|
||||
}
|
||||
|
||||
func varValues(ctx context.Context) map[string]interface{} {
|
||||
userIDFn := fasttemplate.TagFunc(func(w io.Writer, _ string) (int, error) {
|
||||
if v := ctx.Value(userIDKey); v != nil {
|
||||
return w.Write([]byte(v.(string)))
|
||||
}
|
||||
return 0, errNoUserID
|
||||
})
|
||||
|
||||
return map[string]interface{}{
|
||||
"USER_ID": userIDFn,
|
||||
"user_id": userIDFn,
|
||||
}
|
||||
}
|
180
serv/serv.go
Normal file
180
serv/serv.go
Normal file
@ -0,0 +1,180 @@
|
||||
package serv
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"os"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/dosco/super-graph/psql"
|
||||
"github.com/dosco/super-graph/qcode"
|
||||
"github.com/go-pg/pg"
|
||||
"github.com/jinzhu/inflection"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/spf13/viper"
|
||||
)
|
||||
|
||||
const (
|
||||
authFailBlockAlways = iota + 1
|
||||
authFailBlockPerQuery
|
||||
authFailBlockNever
|
||||
)
|
||||
|
||||
var (
|
||||
logger *logrus.Logger
|
||||
debug int
|
||||
conf *viper.Viper
|
||||
db *pg.DB
|
||||
pcompile *psql.Compiler
|
||||
qcompile *qcode.Compiler
|
||||
authFailBlock int
|
||||
)
|
||||
|
||||
func initLog() {
|
||||
logger = logrus.New()
|
||||
logger.Formatter = new(logrus.TextFormatter)
|
||||
logger.Formatter.(*logrus.TextFormatter).DisableColors = false
|
||||
logger.Formatter.(*logrus.TextFormatter).DisableTimestamp = true
|
||||
logger.Level = logrus.TraceLevel
|
||||
logger.Out = os.Stdout
|
||||
}
|
||||
|
||||
func initConf() {
|
||||
conf = viper.New()
|
||||
|
||||
cPath := flag.String("path", ".", "Path to folder that contains config files")
|
||||
flag.Parse()
|
||||
|
||||
conf.AddConfigPath(*cPath)
|
||||
|
||||
switch os.Getenv("GO_ENV") {
|
||||
case "production", "prod":
|
||||
conf.SetConfigName("prod")
|
||||
case "staging", "stage":
|
||||
conf.SetConfigName("stage")
|
||||
default:
|
||||
conf.SetConfigName("dev")
|
||||
}
|
||||
|
||||
err := conf.ReadInConfig()
|
||||
if err != nil {
|
||||
logger.Fatal(err)
|
||||
}
|
||||
|
||||
debug = conf.GetInt("debug_level")
|
||||
|
||||
for k, v := range conf.GetStringMapString("inflections") {
|
||||
inflection.AddIrregular(k, v)
|
||||
}
|
||||
|
||||
conf.SetDefault("host_port", "0.0.0.0:8080")
|
||||
conf.SetDefault("web_ui", false)
|
||||
conf.SetDefault("debug_level", 0)
|
||||
|
||||
conf.SetDefault("database.type", "postgres")
|
||||
conf.SetDefault("database.host", "localhost")
|
||||
conf.SetDefault("database.port", 5432)
|
||||
conf.SetDefault("database.user", "postgres")
|
||||
conf.SetDefault("database.password", "")
|
||||
|
||||
conf.SetDefault("env", "development")
|
||||
conf.BindEnv("env", "GO_ENV")
|
||||
|
||||
switch conf.GetString("auth_fail_block") {
|
||||
case "always":
|
||||
authFailBlock = authFailBlockAlways
|
||||
case "per_query", "perquery", "query":
|
||||
authFailBlock = authFailBlockPerQuery
|
||||
case "never", "false":
|
||||
authFailBlock = authFailBlockNever
|
||||
default:
|
||||
authFailBlock = authFailBlockAlways
|
||||
}
|
||||
}
|
||||
|
||||
func initDB() {
|
||||
conf.BindEnv("database.host", "SG_DATABASE_HOST")
|
||||
conf.BindEnv("database.port", "SG_DATABASE_PORT")
|
||||
conf.BindEnv("database.user", "SG_DATABASE_USER")
|
||||
conf.BindEnv("database.password", "SG_DATABASE_PASSWORD")
|
||||
|
||||
hostport := strings.Join([]string{
|
||||
conf.GetString("database.host"), conf.GetString("database.port")}, ":")
|
||||
|
||||
opt := &pg.Options{
|
||||
Addr: hostport,
|
||||
User: conf.GetString("database.user"),
|
||||
Password: conf.GetString("database.password"),
|
||||
Database: conf.GetString("database.dbname"),
|
||||
}
|
||||
|
||||
if conf.IsSet("database.pool_size") {
|
||||
opt.PoolSize = conf.GetInt("database.pool_size")
|
||||
}
|
||||
|
||||
if conf.IsSet("database.max_retries") {
|
||||
opt.MaxRetries = conf.GetInt("database.max_retries")
|
||||
}
|
||||
|
||||
if db = pg.Connect(opt); db == nil {
|
||||
logger.Fatal(errors.New("failed to connect to postgres db"))
|
||||
}
|
||||
}
|
||||
|
||||
func initCompilers() {
|
||||
fv := conf.GetStringMapString("database.filters")
|
||||
fm := make(qcode.FilterMap)
|
||||
for k, v := range fv {
|
||||
fil, err := qcode.CompileFilter(v)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
key := strings.ToLower(k)
|
||||
fm[key] = fil
|
||||
}
|
||||
|
||||
bv := conf.GetStringSlice("database.blacklist")
|
||||
var bl *regexp.Regexp
|
||||
if len(bv) != 0 {
|
||||
re := fmt.Sprintf("(?i)%s", strings.Join(bv, "|"))
|
||||
bl = regexp.MustCompile(re)
|
||||
}
|
||||
qcompile = qcode.NewCompiler(fm, bl)
|
||||
|
||||
schema, err := psql.NewDBSchema(db)
|
||||
if err != nil {
|
||||
logger.Fatal(err)
|
||||
}
|
||||
|
||||
re := regexp.MustCompile(`(?mi)\$([a-zA-Z0-9_.]+)`)
|
||||
vl := conf.GetStringMapString("database.variables")
|
||||
vars := make(map[string]string)
|
||||
|
||||
for k, v := range vl {
|
||||
vars[k] = re.ReplaceAllString(v, `{{$1}}`)
|
||||
}
|
||||
|
||||
pcompile = psql.NewCompiler(schema, vars)
|
||||
}
|
||||
|
||||
func InitAndListen() {
|
||||
initLog()
|
||||
initConf()
|
||||
initDB()
|
||||
initCompilers()
|
||||
|
||||
http.HandleFunc("/api/v1/graphql", withAuth(apiv1Http))
|
||||
|
||||
if conf.GetBool("web_ui") {
|
||||
fs := http.FileServer(http.Dir("web/build"))
|
||||
http.Handle("/", fs)
|
||||
}
|
||||
|
||||
hp := conf.GetString("host_port")
|
||||
fmt.Printf("Super-Graph listening on %s (%s)\n", hp, conf.GetString("env"))
|
||||
|
||||
logger.Fatal(http.ListenAndServe(hp, nil))
|
||||
}
|
Reference in New Issue
Block a user