241 lines
4.7 KiB
Go
241 lines
4.7 KiB
Go
package serv
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
|
|
"github.com/adjust/gorails/marshal"
|
|
"github.com/adjust/gorails/session"
|
|
"github.com/bradfitz/gomemcache/memcache"
|
|
"github.com/garyburd/redigo/redis"
|
|
)
|
|
|
|
func railsRedisHandler(next http.HandlerFunc) http.HandlerFunc {
|
|
cookie := conf.GetString("auth.cookie")
|
|
if len(cookie) == 0 {
|
|
panic(errors.New("no auth.cookie defined"))
|
|
}
|
|
|
|
conf.BindEnv("auth.url", "SG_AUTH_URL")
|
|
authURL := conf.GetString("auth.url")
|
|
if len(authURL) == 0 {
|
|
panic(errors.New("no auth.url defined"))
|
|
}
|
|
|
|
conf.SetDefault("auth.max_idle", 80)
|
|
conf.SetDefault("auth.max_active", 12000)
|
|
|
|
rp := &redis.Pool{
|
|
MaxIdle: conf.GetInt("auth.max_idle"),
|
|
MaxActive: conf.GetInt("auth.max_active"),
|
|
Dial: func() (redis.Conn, error) {
|
|
c, err := redis.DialURL(authURL)
|
|
if err != nil {
|
|
panic(err)
|
|
}
|
|
|
|
conf.BindEnv("auth.password", "SG_AUTH_PASSWORD")
|
|
pwd := conf.GetString("auth.password")
|
|
if len(pwd) != 0 {
|
|
if _, err := c.Do("AUTH", pwd); err != nil {
|
|
panic(err)
|
|
}
|
|
}
|
|
return c, err
|
|
},
|
|
}
|
|
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
ck, err := r.Cookie(cookie)
|
|
if err != nil {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
key := fmt.Sprintf("session:%s", ck.Value)
|
|
sessionData, err := redis.Bytes(rp.Get().Do("GET", key))
|
|
if err != nil {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
userID, err := railsAuth(string(sessionData), emptySecret)
|
|
if err != nil {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), userIDKey, userID)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
}
|
|
}
|
|
|
|
func railsMemcacheHandler(next http.HandlerFunc) http.HandlerFunc {
|
|
cookie := conf.GetString("auth.cookie")
|
|
if len(cookie) == 0 {
|
|
panic(errors.New("no auth.cookie defined"))
|
|
}
|
|
|
|
host := conf.GetString("auth.host")
|
|
if len(host) == 0 {
|
|
panic(errors.New("no auth.host defined"))
|
|
}
|
|
|
|
mc := memcache.New(host)
|
|
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
ck, err := r.Cookie(cookie)
|
|
if err != nil {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
key := fmt.Sprintf("session:%s", ck.Value)
|
|
item, err := mc.Get(key)
|
|
if err != nil {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
userID, err := railsAuth(string(item.Value), emptySecret)
|
|
if err != nil {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), userIDKey, userID)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
}
|
|
}
|
|
|
|
func railsCookieHandler(next http.HandlerFunc) http.HandlerFunc {
|
|
cookie := conf.GetString("auth.cookie")
|
|
if len(cookie) == 0 {
|
|
panic(errors.New("no auth.cookie defined"))
|
|
}
|
|
|
|
conf.BindEnv("auth.secret_key_base", "SG_AUTH_SECRET_KEY_BASE")
|
|
secret := conf.GetString("auth.secret_key_base")
|
|
if len(secret) == 0 {
|
|
panic(errors.New("no auth.secret_key_base defined"))
|
|
}
|
|
|
|
return func(w http.ResponseWriter, r *http.Request) {
|
|
ck, err := r.Cookie(cookie)
|
|
if err != nil {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
userID, err := railsAuth(ck.Value, secret)
|
|
if err != nil {
|
|
next.ServeHTTP(w, r)
|
|
return
|
|
}
|
|
|
|
ctx := context.WithValue(r.Context(), userIDKey, userID)
|
|
next.ServeHTTP(w, r.WithContext(ctx))
|
|
}
|
|
}
|
|
|
|
func railsAuth(cookie, secret string) (userID string, err error) {
|
|
var dcookie []byte
|
|
|
|
if len(secret) != 0 {
|
|
dcookie, err = session.DecryptSignedCookie(cookie, secret, salt, signSalt)
|
|
if err != nil {
|
|
return
|
|
}
|
|
}
|
|
|
|
if dcookie[0] != '{' {
|
|
userID, err = getUserId4(dcookie)
|
|
} else {
|
|
userID, err = getUserId(dcookie)
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
func getUserId(data []byte) (userID string, err error) {
|
|
var sessionData map[string]interface{}
|
|
|
|
err = json.Unmarshal(data, &sessionData)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
userKey, ok := sessionData["warden.user.user.key"]
|
|
if !ok {
|
|
err = errors.New("key 'warden.user.user.key' not found in session data")
|
|
}
|
|
|
|
items, ok := userKey.([]interface{})
|
|
if !ok {
|
|
err = errSessionData
|
|
return
|
|
}
|
|
|
|
if len(items) != 2 {
|
|
err = errSessionData
|
|
return
|
|
}
|
|
|
|
uids, ok := items[0].([]interface{})
|
|
if !ok {
|
|
err = errSessionData
|
|
return
|
|
}
|
|
|
|
uid, ok := uids[0].(float64)
|
|
if !ok {
|
|
err = errSessionData
|
|
return
|
|
}
|
|
userID = fmt.Sprintf("%d", int64(uid))
|
|
|
|
return
|
|
}
|
|
|
|
func getUserId4(data []byte) (userID string, err error) {
|
|
sessionData, err := marshal.CreateMarshalledObject(data).GetAsMap()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
wardenData, ok := sessionData["warden.user.user.key"]
|
|
if !ok {
|
|
err = errSessionData
|
|
return
|
|
}
|
|
|
|
wardenUserKey, err := wardenData.GetAsArray()
|
|
if err != nil {
|
|
return
|
|
}
|
|
if len(wardenUserKey) < 1 {
|
|
err = errSessionData
|
|
return
|
|
}
|
|
|
|
userData, err := wardenUserKey[0].GetAsArray()
|
|
if err != nil {
|
|
return
|
|
}
|
|
if len(userData) < 1 {
|
|
err = errSessionData
|
|
return
|
|
}
|
|
|
|
uid, err := userData[0].GetAsInteger()
|
|
if err != nil {
|
|
return
|
|
}
|
|
userID = fmt.Sprintf("%d", uid)
|
|
|
|
return
|
|
}
|