feat: Add firebase auth support and JWT audience check (#71)
This commit is contained in:
parent
bd157290f6
commit
a775f9475b
|
@ -32,6 +32,7 @@ type Auth struct {
|
||||||
Secret string
|
Secret string
|
||||||
PubKeyFile string `mapstructure:"public_key_file"`
|
PubKeyFile string `mapstructure:"public_key_file"`
|
||||||
PubKeyType string `mapstructure:"public_key_type"`
|
PubKeyType string `mapstructure:"public_key_type"`
|
||||||
|
Audience string `mapstructure:"audience"`
|
||||||
}
|
}
|
||||||
|
|
||||||
Header struct {
|
Header struct {
|
||||||
|
|
|
@ -2,9 +2,12 @@ package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"io/ioutil"
|
"io/ioutil"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
jwt "github.com/dgrijalva/jwt-go"
|
jwt "github.com/dgrijalva/jwt-go"
|
||||||
"github.com/dosco/super-graph/core"
|
"github.com/dosco/super-graph/core"
|
||||||
|
@ -13,8 +16,18 @@ import (
|
||||||
const (
|
const (
|
||||||
authHeader = "Authorization"
|
authHeader = "Authorization"
|
||||||
jwtAuth0 int = iota + 1
|
jwtAuth0 int = iota + 1
|
||||||
|
jwtFirebase int = iota + 2
|
||||||
|
firebasePKEndpoint = "https://www.googleapis.com/robot/v1/metadata/x509/securetoken@system.gserviceaccount.com"
|
||||||
|
firebaseIssuerPrefix = "https://securetoken.google.com/"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type firebasePKCache struct {
|
||||||
|
PublicKeys map[string]string
|
||||||
|
Expiration time.Time
|
||||||
|
}
|
||||||
|
|
||||||
|
var firebasePublicKeys firebasePKCache
|
||||||
|
|
||||||
func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
|
func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
|
||||||
var key interface{}
|
var key interface{}
|
||||||
var jwtProvider int
|
var jwtProvider int
|
||||||
|
@ -23,6 +36,8 @@ func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
|
||||||
|
|
||||||
if ac.JWT.Provider == "auth0" {
|
if ac.JWT.Provider == "auth0" {
|
||||||
jwtProvider = jwtAuth0
|
jwtProvider = jwtAuth0
|
||||||
|
} else if ac.JWT.Provider == "firebase" {
|
||||||
|
jwtProvider = jwtFirebase
|
||||||
}
|
}
|
||||||
|
|
||||||
secret := ac.JWT.Secret
|
secret := ac.JWT.Secret
|
||||||
|
@ -56,6 +71,7 @@ func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(w http.ResponseWriter, r *http.Request) {
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
var tok string
|
var tok string
|
||||||
|
|
||||||
if len(cookie) != 0 {
|
if len(cookie) != 0 {
|
||||||
|
@ -74,9 +90,16 @@ func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
|
||||||
tok = ah[7:]
|
tok = ah[7:]
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := jwt.ParseWithClaims(tok, &jwt.StandardClaims{}, func(token *jwt.Token) (interface{}, error) {
|
var keyFunc jwt.Keyfunc
|
||||||
|
if jwtProvider == jwtFirebase {
|
||||||
|
keyFunc = firebaseKeyFunction
|
||||||
|
} else {
|
||||||
|
keyFunc = func(token *jwt.Token) (interface{}, error) {
|
||||||
return key, nil
|
return key, nil
|
||||||
})
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := jwt.ParseWithClaims(tok, &jwt.StandardClaims{}, keyFunc)
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
|
@ -86,12 +109,20 @@ func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
|
||||||
if claims, ok := token.Claims.(*jwt.StandardClaims); ok {
|
if claims, ok := token.Claims.(*jwt.StandardClaims); ok {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
|
||||||
|
if ac.JWT.Audience != "" && claims.Audience != ac.JWT.Audience {
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if jwtProvider == jwtAuth0 {
|
if jwtProvider == jwtAuth0 {
|
||||||
sub := strings.Split(claims.Subject, "|")
|
sub := strings.Split(claims.Subject, "|")
|
||||||
if len(sub) != 2 {
|
if len(sub) != 2 {
|
||||||
ctx = context.WithValue(ctx, core.UserIDProviderKey, sub[0])
|
ctx = context.WithValue(ctx, core.UserIDProviderKey, sub[0])
|
||||||
ctx = context.WithValue(ctx, core.UserIDKey, sub[1])
|
ctx = context.WithValue(ctx, core.UserIDKey, sub[1])
|
||||||
}
|
}
|
||||||
|
} else if jwtProvider == jwtFirebase &&
|
||||||
|
claims.Issuer == firebaseIssuerPrefix+ac.JWT.Audience {
|
||||||
|
ctx = context.WithValue(ctx, core.UserIDKey, claims.Subject)
|
||||||
} else {
|
} else {
|
||||||
ctx = context.WithValue(ctx, core.UserIDKey, claims.Subject)
|
ctx = context.WithValue(ctx, core.UserIDKey, claims.Subject)
|
||||||
}
|
}
|
||||||
|
@ -103,3 +134,92 @@ func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
|
||||||
next.ServeHTTP(w, r)
|
next.ServeHTTP(w, r)
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type firebaseKeyError struct {
|
||||||
|
Err error
|
||||||
|
Message string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *firebaseKeyError) Error() string {
|
||||||
|
return e.Message + " " + e.Err.Error()
|
||||||
|
}
|
||||||
|
|
||||||
|
func firebaseKeyFunction(token *jwt.Token) (interface{}, error) {
|
||||||
|
kid, ok := token.Header["kid"]
|
||||||
|
|
||||||
|
if !ok {
|
||||||
|
return nil, &firebaseKeyError{
|
||||||
|
Message: "Error 'kid' header not found in token",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if firebasePublicKeys.Expiration.Before(time.Now()) {
|
||||||
|
resp, err := http.Get(firebasePKEndpoint)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, &firebaseKeyError{
|
||||||
|
Message: "Error connecting to firebase certificate server",
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
defer resp.Body.Close()
|
||||||
|
|
||||||
|
data, err := ioutil.ReadAll(resp.Body)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, &firebaseKeyError{
|
||||||
|
Message: "Error reading firebase certificate server response",
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
cachePolicy := resp.Header.Get("cache-control")
|
||||||
|
ageIndex := strings.Index(cachePolicy, "max-age=")
|
||||||
|
|
||||||
|
if ageIndex < 0 {
|
||||||
|
return nil, &firebaseKeyError{
|
||||||
|
Message: "Error parsing cache-control header: 'max-age=' not found",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ageToEnd := cachePolicy[ageIndex+8:]
|
||||||
|
endIndex := strings.Index(ageToEnd, ",")
|
||||||
|
if endIndex < 0 {
|
||||||
|
endIndex = len(ageToEnd) - 1
|
||||||
|
}
|
||||||
|
ageString := ageToEnd[:endIndex]
|
||||||
|
|
||||||
|
age, err := strconv.ParseInt(ageString, 10, 64)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, &firebaseKeyError{
|
||||||
|
Message: "Error parsing max-age cache policy",
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
expiration := time.Now().Add(time.Duration(time.Duration(age) * time.Second))
|
||||||
|
|
||||||
|
err = json.Unmarshal(data, &firebasePublicKeys.PublicKeys)
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
firebasePublicKeys = firebasePKCache{}
|
||||||
|
return nil, &firebaseKeyError{
|
||||||
|
Message: "Error unmarshalling firebase public key json",
|
||||||
|
Err: err,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
firebasePublicKeys.Expiration = expiration
|
||||||
|
}
|
||||||
|
|
||||||
|
if key, found := firebasePublicKeys.PublicKeys[kid.(string)]; found {
|
||||||
|
k, err := jwt.ParseRSAPublicKeyFromPEM([]byte(key))
|
||||||
|
return k, err
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, &firebaseKeyError{
|
||||||
|
Message: "Error no matching public key for kid supplied in jwt",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue