package serv

import (
	"context"
	"io/ioutil"
	"net/http"
	"strings"

	jwt "github.com/dgrijalva/jwt-go"
)

// Supported JWT providers. jwtBase is the zero value and is used whenever
// "auth.provider" is anything other than "auth0".
const (
	jwtBase int = iota
	jwtAuth0
)

// jwtHandler wraps next with JWT authentication.
//
// Configuration is read once, when the middleware is built:
//   - auth.cookie: when non-empty, the token is read from the cookie with this
//     name; otherwise it is read from the authHeader request header.
//   - auth.provider: "auth0" enables Auth0-style subject parsing.
//   - auth.secret (env SG_AUTH_SECRET): when non-empty, used as the
//     verification key as raw bytes. Takes precedence over the public key file.
//   - auth.public_key_file (env SG_AUTH_PUBLIC_KEY_FILE): PEM file holding the
//     verification key.
//   - auth.public_key_type: "rsa" parses the PEM as an RSA public key; any
//     other value (including "ecdsa" and unset) parses it as an EC public key.
//
// It panics if the public key file cannot be read or parsed.
//
// Per request, a missing or invalid token is not rejected here: next is called
// with the request unchanged. When the token is valid, next is called exactly
// once with the user ID (and, for Auth0, the provider) stored in the request
// context under userIDKey and userIDProviderKey.
func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
	var key interface{}
	var jwtProvider int

	cookie := conf.GetString("auth.cookie")

	provider := conf.GetString("auth.provider")
	if provider == "auth0" {
		jwtProvider = jwtAuth0
	}

	conf.BindEnv("auth.secret", "SG_AUTH_SECRET")
	secret := conf.GetString("auth.secret")

	conf.BindEnv("auth.public_key_file", "SG_AUTH_PUBLIC_KEY_FILE")
	publicKeyFile := conf.GetString("auth.public_key_file")

	switch {
	case len(secret) != 0:
		key = []byte(secret)

	case len(publicKeyFile) != 0:
		kd, err := ioutil.ReadFile(publicKeyFile)
		if err != nil {
			panic(err)
		}

		switch conf.GetString("auth.public_key_type") {
		case "rsa":
			key, err = jwt.ParseRSAPublicKeyFromPEM(kd)
		default:
			// "ecdsa" and every unrecognized value are treated as EC keys.
			key, err = jwt.ParseECPublicKeyFromPEM(kd)
		}
		if err != nil {
			panic(err)
		}
	}

	return func(w http.ResponseWriter, r *http.Request) {
		var tok string

		if len(cookie) != 0 {
			ck, err := r.Cookie(cookie)
			if err != nil {
				next.ServeHTTP(w, r)
				return
			}
			tok = ck.Value
		} else {
			ah := r.Header.Get(authHeader)
			if len(ah) < 10 {
				next.ServeHTTP(w, r)
				return
			}
			// Drop the first 7 bytes of the header value.
			// NOTE(review): presumably the "Bearer " scheme prefix; the
			// prefix itself is not verified here.
			tok = ah[7:]
		}

		token, err := jwt.ParseWithClaims(tok, &jwt.StandardClaims{},
			func(token *jwt.Token) (interface{}, error) {
				return key, nil
			})
		if err != nil {
			next.ServeHTTP(w, r)
			return
		}

		claims, ok := token.Claims.(*jwt.StandardClaims)
		if !ok {
			next.ServeHTTP(w, r)
			return
		}

		ctx := r.Context()

		if jwtProvider == jwtAuth0 {
			// The subject is expected to look like "<provider>|<id>". Only
			// index into the parts when both are present; a malformed subject
			// leaves the context without a user ID instead of panicking.
			sub := strings.Split(claims.Subject, "|")
			if len(sub) == 2 {
				ctx = context.WithValue(ctx, userIDProviderKey, sub[0])
				ctx = context.WithValue(ctx, userIDKey, sub[1])
			}
		} else {
			ctx = context.WithValue(ctx, userIDKey, claims.Subject)
		}

		// Serve exactly once with the enriched context, then stop, so the
		// downstream handler is not invoked a second time.
		next.ServeHTTP(w, r.WithContext(ctx))
	}
}