diff --git a/internal/serv/internal/auth/auth.go b/internal/serv/internal/auth/auth.go index 3dcbc26..a9c6198 100644 --- a/internal/serv/internal/auth/auth.go +++ b/internal/serv/internal/auth/auth.go @@ -32,6 +32,7 @@ type Auth struct { Secret string PubKeyFile string `mapstructure:"public_key_file"` PubKeyType string `mapstructure:"public_key_type"` + Audience string `mapstructure:"audience"` } Header struct { diff --git a/internal/serv/internal/auth/jwt.go b/internal/serv/internal/auth/jwt.go index b9df700..2c1361e 100644 --- a/internal/serv/internal/auth/jwt.go +++ b/internal/serv/internal/auth/jwt.go @@ -2,19 +2,32 @@ package auth import ( "context" + "encoding/json" "io/ioutil" "net/http" + "strconv" "strings" + "time" jwt "github.com/dgrijalva/jwt-go" "github.com/dosco/super-graph/core" ) const ( - authHeader = "Authorization" - jwtAuth0 int = iota + 1 + authHeader = "Authorization" + jwtAuth0 int = iota + 1 + jwtFirebase int = iota + 2 + firebasePKEndpoint = "https://www.googleapis.com/robot/v1/metadata/x509/securetoken@system.gserviceaccount.com" + firebaseIssuerPrefix = "https://securetoken.google.com/" ) +type firebasePKCache struct { + PublicKeys map[string]string + Expiration time.Time +} + +var firebasePublicKeys firebasePKCache + func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) { var key interface{} var jwtProvider int @@ -23,6 +36,8 @@ func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) { if ac.JWT.Provider == "auth0" { jwtProvider = jwtAuth0 + } else if ac.JWT.Provider == "firebase" { + jwtProvider = jwtFirebase } secret := ac.JWT.Secret @@ -56,6 +71,7 @@ func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) { } return func(w http.ResponseWriter, r *http.Request) { + var tok string if len(cookie) != 0 { @@ -74,9 +90,16 @@ func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) { tok = ah[7:] } - token, err := jwt.ParseWithClaims(tok, &jwt.StandardClaims{}, func(token *jwt.Token) (interface{}, error) { - return key, nil - }) + var keyFunc jwt.Keyfunc + if jwtProvider == jwtFirebase { + keyFunc = firebaseKeyFunction + } else { + keyFunc = func(token *jwt.Token) (interface{}, error) { + return key, nil + } + } + + token, err := jwt.ParseWithClaims(tok, &jwt.StandardClaims{}, keyFunc) if err != nil { next.ServeHTTP(w, r) @@ -86,12 +109,20 @@ func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) { if claims, ok := token.Claims.(*jwt.StandardClaims); ok { ctx := r.Context() + if ac.JWT.Audience != "" && claims.Audience != ac.JWT.Audience { + next.ServeHTTP(w, r) + return + } + if jwtProvider == jwtAuth0 { sub := strings.Split(claims.Subject, "|") if len(sub) != 2 { ctx = context.WithValue(ctx, core.UserIDProviderKey, sub[0]) ctx = context.WithValue(ctx, core.UserIDKey, sub[1]) } + } else if jwtProvider == jwtFirebase && + claims.Issuer == firebaseIssuerPrefix+ac.JWT.Audience { + ctx = context.WithValue(ctx, core.UserIDKey, claims.Subject) } else { ctx = context.WithValue(ctx, core.UserIDKey, claims.Subject) } @@ -103,3 +134,92 @@ func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) { next.ServeHTTP(w, r) }, nil } + +type firebaseKeyError struct { + Err error + Message string +} + +func (e *firebaseKeyError) Error() string { + return e.Message + " " + e.Err.Error() +} + +func firebaseKeyFunction(token *jwt.Token) (interface{}, error) { + kid, ok := token.Header["kid"] + + if !ok { + return nil, &firebaseKeyError{ + Message: "Error 'kid' header not found in token", + } + } + + if firebasePublicKeys.Expiration.Before(time.Now()) { + resp, err := http.Get(firebasePKEndpoint) + + if err != nil { + return nil, &firebaseKeyError{ + Message: "Error connecting to firebase certificate server", + Err: err, + } + } + + defer resp.Body.Close() + + data, err := ioutil.ReadAll(resp.Body) + + if err != nil { + return nil, &firebaseKeyError{ + Message: "Error reading firebase certificate server response", + Err: err, + } + } + + cachePolicy := resp.Header.Get("cache-control") + ageIndex := strings.Index(cachePolicy, "max-age=") + + if ageIndex < 0 { + return nil, &firebaseKeyError{ + Message: "Error parsing cache-control header: 'max-age=' not found", + } + } + + ageToEnd := cachePolicy[ageIndex+8:] + endIndex := strings.Index(ageToEnd, ",") + if endIndex < 0 { + endIndex = len(ageToEnd) - 1 + } + ageString := ageToEnd[:endIndex] + + age, err := strconv.ParseInt(ageString, 10, 64) + + if err != nil { + return nil, &firebaseKeyError{ + Message: "Error parsing max-age cache policy", + Err: err, + } + } + + expiration := time.Now().Add(time.Duration(time.Duration(age) * time.Second)) + + err = json.Unmarshal(data, &firebasePublicKeys.PublicKeys) + + if err != nil { + firebasePublicKeys = firebasePKCache{} + return nil, &firebaseKeyError{ + Message: "Error unmarshalling firebase public key json", + Err: err, + } + } + + firebasePublicKeys.Expiration = expiration + } + + if key, found := firebasePublicKeys.PublicKeys[kid.(string)]; found { + k, err := jwt.ParseRSAPublicKeyFromPEM([]byte(key)) + return k, err + } + + return nil, &firebaseKeyError{ + Message: "Error no matching public key for kid supplied in jwt", + } +}