First commit

This commit is contained in:
Vikram Rangnekar
2019-03-24 09:57:29 -04:00
commit b9d38a5e9d
153 changed files with 18120 additions and 0 deletions

71
serv/auth.go Normal file
View File

@ -0,0 +1,71 @@
package serv
import (
"context"
"errors"
"net/http"
"strings"
)
const (
salt = "encrypted cookie"
signSalt = "signed encrypted cookie"
emptySecret = ""
authHeader = "Authorization"
)
var (
userIDKey = struct{}{}
errSessionData = errors.New("error decoding session data")
)
func headerHandler(next http.HandlerFunc) http.HandlerFunc {
fn := conf.GetString("auth.field_name")
if len(fn) == 0 {
panic(errors.New("no auth.field_name defined"))
}
return func(w http.ResponseWriter, r *http.Request) {
userID := r.Header.Get(fn)
if len(userID) == 0 {
next.ServeHTTP(w, r)
return
}
ctx := context.WithValue(r.Context(), userIDKey, userID)
next.ServeHTTP(w, r.WithContext(ctx))
}
}
func withAuth(next http.HandlerFunc) http.HandlerFunc {
atype := strings.ToLower(conf.GetString("auth.type"))
if len(atype) == 0 {
return next
}
store := strings.ToLower(conf.GetString("auth.store"))
switch atype {
case "header":
return headerHandler(next)
case "rails":
switch store {
case "memcache":
return railsMemcacheHandler(next)
case "redis":
return railsRedisHandler(next)
default:
return railsCookieHandler(next)
}
case "jwt":
return jwtHandler(next)
default:
panic(errors.New("unknown auth.type"))
}
return next
}

84
serv/auth_jwt.go Normal file
View File

@ -0,0 +1,84 @@
package serv
import (
"context"
"io/ioutil"
"net/http"
jwt "github.com/dgrijalva/jwt-go"
)
func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
var key interface{}
cookie := conf.GetString("auth.cookie")
conf.BindEnv("auth.secret", "SG_AUTH_SECRET")
secret := conf.GetString("auth.secret")
conf.BindEnv("auth.public_key_file", "SG_AUTH_PUBLIC_KEY_FILE")
publicKeyFile := conf.GetString("auth.public_key_file")
switch {
case len(secret) != 0:
key = []byte(secret)
case len(publicKeyFile) != 0:
kd, err := ioutil.ReadFile(publicKeyFile)
if err != nil {
panic(err)
}
switch conf.GetString("auth.public_key_type") {
case "ecdsa":
key, err = jwt.ParseECPublicKeyFromPEM(kd)
case "rsa":
key, err = jwt.ParseRSAPublicKeyFromPEM(kd)
default:
key, err = jwt.ParseECPublicKeyFromPEM(kd)
}
if err != nil {
panic(err)
}
}
return func(w http.ResponseWriter, r *http.Request) {
var tok string
if len(cookie) != 0 {
ck, err := r.Cookie(cookie)
if err != nil {
next.ServeHTTP(w, r)
return
}
tok = ck.Value
} else {
ah := r.Header.Get(authHeader)
if len(ah) < 10 {
next.ServeHTTP(w, r)
return
}
tok = ah[7:]
}
token, err := jwt.ParseWithClaims(tok, &jwt.StandardClaims{}, func(token *jwt.Token) (interface{}, error) {
return key, nil
})
if err != nil {
next.ServeHTTP(w, r)
return
}
if claims, ok := token.Claims.(*jwt.StandardClaims); ok {
ctx := context.WithValue(r.Context(), userIDKey, claims.Id)
next.ServeHTTP(w, r.WithContext(ctx))
}
next.ServeHTTP(w, r)
}
}

240
serv/auth_rails.go Normal file
View File

@ -0,0 +1,240 @@
package serv
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"github.com/adjust/gorails/marshal"
"github.com/adjust/gorails/session"
"github.com/bradfitz/gomemcache/memcache"
"github.com/garyburd/redigo/redis"
)
func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc {
cookie := conf.GetString("auth.cookie")
if len(cookie) == 0 {
panic(errors.New("no auth.cookie defined"))
}
conf.BindEnv("auth.url", "SG_AUTH_URL")
authURL := conf.GetString("auth.url")
if len(authURL) == 0 {
panic(errors.New("no auth.url defined"))
}
conf.SetDefault("auth.max_idle", 80)
conf.SetDefault("auth.max_active", 12000)
rp := &redis.Pool{
MaxIdle: conf.GetInt("auth.max_idle"),
MaxActive: conf.GetInt("auth.max_active"),
Dial: func() (redis.Conn, error) {
c, err := redis.DialURL(authURL)
if err != nil {
panic(err)
}
conf.BindEnv("auth.password", "SG_AUTH_PASSWORD")
pwd := conf.GetString("auth.password")
if len(pwd) != 0 {
if _, err := c.Do("AUTH", pwd); err != nil {
panic(err)
}
}
return c, err
},
}
return func(w http.ResponseWriter, r *http.Request) {
ck, err := r.Cookie(cookie)
if err != nil {
next.ServeHTTP(w, r)
return
}
key := fmt.Sprintf("session:%s", ck.Value)
sessionData, err := redis.Bytes(rp.Get().Do("GET", key))
if err != nil {
next.ServeHTTP(w, r)
return
}
userID, err := railsAuth(string(sessionData), emptySecret)
if err != nil {
next.ServeHTTP(w, r)
return
}
ctx := context.WithValue(r.Context(), userIDKey, userID)
next.ServeHTTP(w, r.WithContext(ctx))
}
}
func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc {
cookie := conf.GetString("auth.cookie")
if len(cookie) == 0 {
panic(errors.New("no auth.cookie defined"))
}
host := conf.GetString("auth.host")
if len(host) == 0 {
panic(errors.New("no auth.host defined"))
}
mc := memcache.New(host)
return func(w http.ResponseWriter, r *http.Request) {
ck, err := r.Cookie(cookie)
if err != nil {
next.ServeHTTP(w, r)
return
}
key := fmt.Sprintf("session:%s", ck.Value)
item, err := mc.Get(key)
if err != nil {
next.ServeHTTP(w, r)
return
}
userID, err := railsAuth(string(item.Value), emptySecret)
if err != nil {
next.ServeHTTP(w, r)
return
}
ctx := context.WithValue(r.Context(), userIDKey, userID)
next.ServeHTTP(w, r.WithContext(ctx))
}
}
func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc {
cookie := conf.GetString("auth.cookie")
if len(cookie) == 0 {
panic(errors.New("no auth.cookie defined"))
}
conf.BindEnv("auth.secret_key_base", "SG_AUTH_SECRET_KEY_BASE")
secret := conf.GetString("auth.secret_key_base")
if len(secret) == 0 {
panic(errors.New("no auth.secret_key_base defined"))
}
return func(w http.ResponseWriter, r *http.Request) {
ck, err := r.Cookie(cookie)
if err != nil {
next.ServeHTTP(w, r)
return
}
userID, err := railsAuth(ck.Value, secret)
if err != nil {
next.ServeHTTP(w, r)
return
}
ctx := context.WithValue(r.Context(), userIDKey, userID)
next.ServeHTTP(w, r.WithContext(ctx))
}
}
func railsAuth(cookie, secret string) (userID string, err error) {
var dcookie []byte
if len(secret) != 0 {
dcookie, err = session.DecryptSignedCookie(cookie, secret, salt, signSalt)
if err != nil {
return
}
}
if dcookie[0] != '{' {
userID, err = getUserId4(dcookie)
} else {
userID, err = getUserId(dcookie)
}
return
}
func getUserId(data []byte) (userID string, err error) {
var sessionData map[string]interface{}
err = json.Unmarshal(data, &sessionData)
if err != nil {
return
}
userKey, ok := sessionData["warden.user.user.key"]
if !ok {
err = errors.New("key 'warden.user.user.key' not found in session data")
}
items, ok := userKey.([]interface{})
if !ok {
err = errSessionData
return
}
if len(items) != 2 {
err = errSessionData
return
}
uids, ok := items[0].([]interface{})
if !ok {
err = errSessionData
return
}
uid, ok := uids[0].(float64)
if !ok {
err = errSessionData
return
}
userID = fmt.Sprintf("%d", int64(uid))
return
}
func getUserId4(data []byte) (userID string, err error) {
sessionData, err := marshal.CreateMarshalledObject(data).GetAsMap()
if err != nil {
return
}
wardenData, ok := sessionData["warden.user.user.key"]
if !ok {
err = errSessionData
return
}
wardenUserKey, err := wardenData.GetAsArray()
if err != nil {
return
}
if len(wardenUserKey) < 1 {
err = errSessionData
return
}
userData, err := wardenUserKey[0].GetAsArray()
if err != nil {
return
}
if len(userData) < 1 {
err = errSessionData
return
}
uid, err := userData[0].GetAsInteger()
if err != nil {
return
}
userID = fmt.Sprintf("%d", uid)
return
}

49
serv/auth_test.go Normal file
View File

@ -0,0 +1,49 @@
package serv
import (
"testing"
)
func TestRailsEncryptedSession(t *testing.T) {
cookie := "dDdjMW5jYUNYaFpBT1BSdFgwQkk4ZWNlT214L1FnM0pyZzZ1d21nSnVTTm9zS0ljN000S1JmT3cxcTNtRld2Ny0tQUFBQUFBQUFBQUFBQUFBQUFBQUFBQT09--75d8323b0f0e41cf4d5aabee1b229b1be76b83b6"
secret := "development_secret"
userID, err := railsAuth(cookie, secret)
if err != nil {
t.Error(err)
return
}
if userID != "1" {
t.Errorf("Expecting userID 1 got %s", userID)
}
}
func TestRailsJsonSession(t *testing.T) {
sessionData := `{"warden.user.user.key":[[1],"secret"]}`
userID, err := getUserId([]byte(sessionData))
if err != nil {
t.Error(err)
return
}
if userID != "1" {
t.Errorf("Expecting userID 1 got %s", userID)
}
}
func TestRailsMarshaledSession(t *testing.T) {
sessionData := "\x04\b{\bI\"\x15member_return_to\x06:\x06ETI\"\x06/\x06;\x00TI\"\x19warden.user.user.key\x06;\x00T[\a[\x06i\aI\"\"$2a$11$6SgXdvO9hld82kQAvpEY3e\x06;\x00TI\"\x10_csrf_token\x06;\x00FI\"17lqwj1UsTTgbXBQKH4ipCNW32uLusvfSPds1txppMec=\x06;\x00F"
userID, err := getUserId4([]byte(sessionData))
if err != nil {
t.Error(err)
return
}
if userID != "2" {
t.Errorf("Expecting userID 2 got %s", userID)
}
}

167
serv/http.go Normal file
View File

@ -0,0 +1,167 @@
package serv
import (
"context"
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
"strings"
"github.com/go-pg/pg"
"github.com/gorilla/websocket"
"github.com/valyala/fasttemplate"
)
const (
introspectionQuery = "IntrospectionQuery"
openVar = "{{"
closeVar = "}}"
)
var (
upgrader = websocket.Upgrader{}
errNoUserID = errors.New("no user_id available")
)
type gqlReq struct {
OpName string `json:"operationName"`
Query string `json:"query"`
Variables map[string]string `json:"variables"`
}
type gqlResp struct {
Error string `json:"error,omitempty"`
Data json.RawMessage `json:"data,omitempty"`
}
func apiv1Http(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
if authFailBlock == authFailBlockAlways && authCheck(ctx) == false {
http.Error(w, "Not authorized", 401)
return
}
b, err := ioutil.ReadAll(r.Body)
defer r.Body.Close()
if err != nil {
errorResp(w, err)
return
}
req := &gqlReq{}
if err := json.Unmarshal(b, req); err != nil {
errorResp(w, err)
return
}
if strings.EqualFold(req.OpName, introspectionQuery) {
dat, err := ioutil.ReadFile("test.schema")
if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
return
}
w.Write(dat)
return
}
qc, err := qcompile.CompileQuery(req.Query)
if err != nil {
errorResp(w, err)
return
}
var sqlStmt strings.Builder
if err := pcompile.Compile(&sqlStmt, qc); err != nil {
errorResp(w, err)
return
}
t := fasttemplate.New(sqlStmt.String(), openVar, closeVar)
sqlStmt.Reset()
_, err = t.Execute(&sqlStmt, varValues(ctx))
if err == errNoUserID &&
authFailBlock == authFailBlockPerQuery &&
authCheck(ctx) == false {
http.Error(w, "Not authorized", 401)
return
}
if err != nil {
errorResp(w, err)
return
}
finalSQL := sqlStmt.String()
if debug > 0 {
fmt.Println(finalSQL)
}
var root json.RawMessage
_, err = db.Query(pg.Scan(&root), finalSQL)
if err != nil {
errorResp(w, err)
return
}
json.NewEncoder(w).Encode(gqlResp{Data: json.RawMessage(root)})
}
/*
func apiv1Ws(w http.ResponseWriter, r *http.Request) {
c, err := upgrader.Upgrade(w, r, nil)
if err != nil {
fmt.Println(err)
return
}
defer c.Close()
for {
mt, message, err := c.ReadMessage()
if err != nil {
fmt.Println("read:", err)
break
}
fmt.Printf("recv: %s", message)
err = c.WriteMessage(mt, message)
if err != nil {
fmt.Println("write:", err)
break
}
}
}
func serve(w http.ResponseWriter, r *http.Request) {
// if websocket.IsWebSocketUpgrade(r) {
// apiv1Ws(w, r)
// return
// }
apiv1Http(w, r)
}
*/
func errorResp(w http.ResponseWriter, err error) {
b, _ := json.Marshal(gqlResp{Error: err.Error()})
http.Error(w, string(b), http.StatusBadRequest)
}
func authCheck(ctx context.Context) bool {
return (ctx.Value(userIDKey) != nil)
}
func varValues(ctx context.Context) map[string]interface{} {
userIDFn := fasttemplate.TagFunc(func(w io.Writer, _ string) (int, error) {
if v := ctx.Value(userIDKey); v != nil {
return w.Write([]byte(v.(string)))
}
return 0, errNoUserID
})
return map[string]interface{}{
"USER_ID": userIDFn,
"user_id": userIDFn,
}
}

180
serv/serv.go Normal file
View File

@ -0,0 +1,180 @@
package serv
import (
"errors"
"flag"
"fmt"
"net/http"
"os"
"regexp"
"strings"
"github.com/dosco/super-graph/psql"
"github.com/dosco/super-graph/qcode"
"github.com/go-pg/pg"
"github.com/jinzhu/inflection"
"github.com/sirupsen/logrus"
"github.com/spf13/viper"
)
const (
authFailBlockAlways = iota + 1
authFailBlockPerQuery
authFailBlockNever
)
var (
logger *logrus.Logger
debug int
conf *viper.Viper
db *pg.DB
pcompile *psql.Compiler
qcompile *qcode.Compiler
authFailBlock int
)
func initLog() {
logger = logrus.New()
logger.Formatter = new(logrus.TextFormatter)
logger.Formatter.(*logrus.TextFormatter).DisableColors = false
logger.Formatter.(*logrus.TextFormatter).DisableTimestamp = true
logger.Level = logrus.TraceLevel
logger.Out = os.Stdout
}
func initConf() {
conf = viper.New()
cPath := flag.String("path", ".", "Path to folder that contains config files")
flag.Parse()
conf.AddConfigPath(*cPath)
switch os.Getenv("GO_ENV") {
case "production", "prod":
conf.SetConfigName("prod")
case "staging", "stage":
conf.SetConfigName("stage")
default:
conf.SetConfigName("dev")
}
err := conf.ReadInConfig()
if err != nil {
logger.Fatal(err)
}
debug = conf.GetInt("debug_level")
for k, v := range conf.GetStringMapString("inflections") {
inflection.AddIrregular(k, v)
}
conf.SetDefault("host_port", "0.0.0.0:8080")
conf.SetDefault("web_ui", false)
conf.SetDefault("debug_level", 0)
conf.SetDefault("database.type", "postgres")
conf.SetDefault("database.host", "localhost")
conf.SetDefault("database.port", 5432)
conf.SetDefault("database.user", "postgres")
conf.SetDefault("database.password", "")
conf.SetDefault("env", "development")
conf.BindEnv("env", "GO_ENV")
switch conf.GetString("auth_fail_block") {
case "always":
authFailBlock = authFailBlockAlways
case "per_query", "perquery", "query":
authFailBlock = authFailBlockPerQuery
case "never", "false":
authFailBlock = authFailBlockNever
default:
authFailBlock = authFailBlockAlways
}
}
func initDB() {
conf.BindEnv("database.host", "SG_DATABASE_HOST")
conf.BindEnv("database.port", "SG_DATABASE_PORT")
conf.BindEnv("database.user", "SG_DATABASE_USER")
conf.BindEnv("database.password", "SG_DATABASE_PASSWORD")
hostport := strings.Join([]string{
conf.GetString("database.host"), conf.GetString("database.port")}, ":")
opt := &pg.Options{
Addr: hostport,
User: conf.GetString("database.user"),
Password: conf.GetString("database.password"),
Database: conf.GetString("database.dbname"),
}
if conf.IsSet("database.pool_size") {
opt.PoolSize = conf.GetInt("database.pool_size")
}
if conf.IsSet("database.max_retries") {
opt.MaxRetries = conf.GetInt("database.max_retries")
}
if db = pg.Connect(opt); db == nil {
logger.Fatal(errors.New("failed to connect to postgres db"))
}
}
func initCompilers() {
fv := conf.GetStringMapString("database.filters")
fm := make(qcode.FilterMap)
for k, v := range fv {
fil, err := qcode.CompileFilter(v)
if err != nil {
panic(err)
}
key := strings.ToLower(k)
fm[key] = fil
}
bv := conf.GetStringSlice("database.blacklist")
var bl *regexp.Regexp
if len(bv) != 0 {
re := fmt.Sprintf("(?i)%s", strings.Join(bv, "|"))
bl = regexp.MustCompile(re)
}
qcompile = qcode.NewCompiler(fm, bl)
schema, err := psql.NewDBSchema(db)
if err != nil {
logger.Fatal(err)
}
re := regexp.MustCompile(`(?mi)\$([a-zA-Z0-9_.]+)`)
vl := conf.GetStringMapString("database.variables")
vars := make(map[string]string)
for k, v := range vl {
vars[k] = re.ReplaceAllString(v, `{{$1}}`)
}
pcompile = psql.NewCompiler(schema, vars)
}
func InitAndListen() {
initLog()
initConf()
initDB()
initCompilers()
http.HandleFunc("/api/v1/graphql", withAuth(apiv1Http))
if conf.GetBool("web_ui") {
fs := http.FileServer(http.Dir("web/build"))
http.Handle("/", fs)
}
hp := conf.GetString("host_port")
fmt.Printf("Super-Graph listening on %s (%s)\n", hp, conf.GetString("env"))
logger.Fatal(http.ListenAndServe(hp, nil))
}