package auth

import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"io/ioutil"
	"log"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	jwt "github.com/dgrijalva/jwt-go"
	"github.com/dosco/super-graph/core"
	"github.com/lestrrat-go/jwx/jwk"
)

const (
	// authHeader is the request header read when no cookie name is configured.
	authHeader = "Authorization"

	// Provider identifiers. Their values derive from iota's position in this
	// const block (2 and 4); this file only ever compares them for equality,
	// so do not reorder the entries above them without checking other users.
	jwtAuth0    int = iota + 1
	jwtFirebase int = iota + 2

	// firebasePKEndpoint serves the certificates used to verify Firebase
	// tokens, as a JSON object mapping key ID to PEM text.
	firebasePKEndpoint = "https://www.googleapis.com/robot/v1/metadata/x509/securetoken@system.gserviceaccount.com"
	// firebaseIssuerPrefix plus the configured audience is the issuer a
	// Firebase token is expected to carry.
	firebaseIssuerPrefix = "https://securetoken.google.com/"
)

// firebasePKCache holds the Firebase public keys (key ID -> PEM) together
// with the time after which they must be fetched again.
type firebasePKCache struct {
	PublicKeys map[string]string
	Expiration time.Time
}

var (
	// firebasePublicKeys is the process-wide key cache. It is shared by all
	// request goroutines and must only be touched with firebasePublicKeysMu
	// held.
	firebasePublicKeys   firebasePKCache
	firebasePublicKeysMu sync.Mutex

	// firebaseHTTPClient bounds the certificate fetch so that an unresponsive
	// endpoint cannot block token verification indefinitely.
	firebaseHTTPClient = &http.Client{Timeout: 10 * time.Second}
)

// standardClaims extends jwt.StandardClaims so that the "aud" claim may be
// either a single string or an array of strings.
type standardClaims struct {
	jwt.StandardClaims
	Audience []string `json:"aud,omitempty"`
}

// UnmarshalJSON decodes the registered claims and accepts "aud" in both of
// its forms. Without it the outer []string field shadows the embedded string
// field, and a token carrying a plain string audience fails to decode.
// A string audience is stored in StandardClaims.Audience, an array in
// Audience.
func (c *standardClaims) UnmarshalJSON(data []byte) error {
	var raw struct {
		jwt.StandardClaims
		// Shadows the embedded "aud" so the raw value can be inspected.
		Audience json.RawMessage `json:"aud,omitempty"`
	}
	if err := json.Unmarshal(data, &raw); err != nil {
		return err
	}

	c.StandardClaims = raw.StandardClaims
	c.StandardClaims.Audience = ""
	c.Audience = nil

	if len(raw.Audience) == 0 || string(raw.Audience) == "null" {
		return nil
	}

	var single string
	if err := json.Unmarshal(raw.Audience, &single); err == nil {
		c.StandardClaims.Audience = single
		return nil
	}

	var many []string
	if err := json.Unmarshal(raw.Audience, &many); err != nil {
		return err
	}
	c.Audience = many
	return nil
}

// MatchAudience reports whether audience equals the single-string audience
// claim or any entry of the array audience claim.
func (c *standardClaims) MatchAudience(audience string) bool {
	if c.StandardClaims.Audience == audience {
		return true
	}
	for _, tokenAudience := range c.Audience {
		if tokenAudience == audience {
			return true
		}
	}
	return false
}

// JwtHandler returns an HTTP handler that authenticates requests with a JWT
// and then delegates to next.
//
// The token is read from the cookie named by ac.Cookie when that is set,
// otherwise from the Authorization header. Verification uses, in order of
// precedence: the Firebase certificate endpoint (provider "firebase"), the
// JWKS at ac.JWT.JWKSURL, or the static key built from ac.JWT.Secret or
// ac.JWT.PubKeyFile.
//
// On success the user ID (and, for provider "auth0", the user ID provider) is
// stored in the request context. On any failure next is still called, with
// the request context left unchanged.
//
// An error is returned only when the public key file cannot be read or
// parsed.
func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
	var key interface{}
	var jwtProvider int

	cookie := ac.Cookie

	switch ac.JWT.Provider {
	case "auth0":
		jwtProvider = jwtAuth0
	case "firebase":
		jwtProvider = jwtFirebase
	}

	secret := ac.JWT.Secret
	publicKeyFile := ac.JWT.PubKeyFile
	jwksURL := ac.JWT.JWKSURL

	switch {
	case secret != "":
		key = []byte(secret)

	case publicKeyFile != "":
		kd, err := ioutil.ReadFile(publicKeyFile)
		if err != nil {
			return nil, err
		}

		switch ac.JWT.PubKeyType {
		case "rsa":
			key, err = jwt.ParseRSAPublicKeyFromPEM(kd)
		default:
			// "ecdsa" and any other value are parsed as an EC key.
			key, err = jwt.ParseECPublicKeyFromPEM(kd)
		}
		if err != nil {
			return nil, err
		}
	}

	// The key function depends only on configuration, so build it once
	// instead of on every request.
	var keyFunc jwt.Keyfunc
	switch {
	case jwtProvider == jwtFirebase:
		keyFunc = firebaseKeyFunction
	case jwksURL != "":
		keyFunc = createJWKSKeyFetchFunc(jwksURL)
	default:
		keyFunc = func(*jwt.Token) (interface{}, error) {
			return key, nil
		}
	}

	return func(w http.ResponseWriter, r *http.Request) {
		var tok string

		if cookie != "" {
			ck, err := r.Cookie(cookie)
			if err != nil {
				next.ServeHTTP(w, r)
				return
			}
			tok = ck.Value
		} else {
			ah := r.Header.Get(authHeader)
			if len(ah) < 10 {
				next.ServeHTTP(w, r)
				return
			}
			// Skip the first 7 bytes (the length of "Bearer "); the scheme
			// name itself is not checked.
			tok = ah[7:]
		}

		token, err := jwt.ParseWithClaims(tok, &standardClaims{}, keyFunc)
		if err != nil {
			log.Printf("[ERR] %s", err)
			next.ServeHTTP(w, r)
			return
		}

		claims, ok := token.Claims.(*standardClaims)
		if !ok {
			next.ServeHTTP(w, r)
			return
		}

		audience := ac.JWT.Audience
		if audience != "" && !claims.MatchAudience(audience) {
			next.ServeHTTP(w, r)
			return
		}

		ctx := r.Context()

		switch jwtProvider {
		case jwtAuth0:
			// Subjects look like "<provider>|<id>". Subjects without a
			// separator carry no user information and are left out of the
			// context rather than indexed out of range.
			sub := strings.Split(claims.Subject, "|")
			if len(sub) >= 2 {
				ctx = context.WithValue(ctx, core.UserIDProviderKey, sub[0])
				ctx = context.WithValue(ctx, core.UserIDKey, sub[1])
			}

		case jwtFirebase:
			// The signing keys are fetched from a fixed Google endpoint, so
			// the issuer is what ties a token to the configured project.
			// NOTE(review): with no audience configured the issuer cannot be
			// derived and the token is accepted as before — confirm whether
			// an audience should be mandatory for this provider.
			if audience != "" && claims.Issuer != firebaseIssuerPrefix+audience {
				next.ServeHTTP(w, r)
				return
			}
			ctx = context.WithValue(ctx, core.UserIDKey, claims.Subject)

		default:
			ctx = context.WithValue(ctx, core.UserIDKey, claims.Subject)
		}

		next.ServeHTTP(w, r.WithContext(ctx))
	}, nil
}

// firebaseKeyError describes a failure to obtain a Firebase public key.
// Err is the underlying cause and may be nil.
type firebaseKeyError struct {
	Err     error
	Message string
}

// Error returns the message, followed by the underlying cause when present.
func (e *firebaseKeyError) Error() string {
	if e.Err == nil {
		return e.Message
	}
	return e.Message + " " + e.Err.Error()
}

// firebaseKeyFunction is a jwt.Keyfunc that resolves the token's "kid" header
// to an RSA public key from the cached Firebase certificates, refreshing the
// cache first when it has expired.
func firebaseKeyFunction(token *jwt.Token) (interface{}, error) {
	kid, ok := token.Header["kid"].(string)
	if !ok {
		return nil, &firebaseKeyError{
			Message: "Error 'kid' header not found in token",
		}
	}

	firebasePublicKeysMu.Lock()
	defer firebasePublicKeysMu.Unlock()

	if firebasePublicKeys.Expiration.Before(time.Now()) {
		if err := refreshFirebasePublicKeys(); err != nil {
			return nil, err
		}
	}

	pem, found := firebasePublicKeys.PublicKeys[kid]
	if !found {
		return nil, &firebaseKeyError{
			Message: "Error no matching public key for kid supplied in jwt",
		}
	}
	return jwt.ParseRSAPublicKeyFromPEM([]byte(pem))
}

// refreshFirebasePublicKeys downloads the current Firebase certificates and
// replaces the cache, using the response's max-age as the cache lifetime.
// The caller must hold firebasePublicKeysMu.
func refreshFirebasePublicKeys() error {
	resp, err := firebaseHTTPClient.Get(firebasePKEndpoint)
	if err != nil {
		return &firebaseKeyError{
			Message: "Error connecting to firebase certificate server",
			Err:     err,
		}
	}
	defer resp.Body.Close()

	data, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return &firebaseKeyError{
			Message: "Error reading firebase certificate server response",
			Err:     err,
		}
	}

	if resp.StatusCode != http.StatusOK {
		return &firebaseKeyError{
			Message: "Error unexpected response from firebase certificate server: " + resp.Status,
		}
	}

	maxAge, err := firebaseMaxAge(resp.Header.Get("cache-control"))
	if err != nil {
		return err
	}
	expiration := time.Now().Add(maxAge)

	// Decode into a fresh map so keys that have been rotated out of the
	// response do not linger in the cache.
	keys := make(map[string]string)
	if err := json.Unmarshal(data, &keys); err != nil {
		firebasePublicKeys = firebasePKCache{}
		return &firebaseKeyError{
			Message: "Error unmarshalling firebase public key json",
			Err:     err,
		}
	}

	firebasePublicKeys = firebasePKCache{
		PublicKeys: keys,
		Expiration: expiration,
	}
	return nil
}

// firebaseMaxAge extracts the max-age directive from a Cache-Control header
// value and returns it as a duration.
func firebaseMaxAge(cachePolicy string) (time.Duration, error) {
	const directive = "max-age="

	ageIndex := strings.Index(cachePolicy, directive)
	if ageIndex < 0 {
		return 0, &firebaseKeyError{
			Message: "Error parsing cache-control header: 'max-age=' not found",
		}
	}

	// The value runs to the next comma, or to the end of the header when
	// max-age is the last directive.
	ageString := cachePolicy[ageIndex+len(directive):]
	if endIndex := strings.Index(ageString, ","); endIndex >= 0 {
		ageString = ageString[:endIndex]
	}

	age, err := strconv.ParseInt(strings.TrimSpace(ageString), 10, 64)
	if err != nil {
		return 0, &firebaseKeyError{
			Message: "Error parsing max-age cache policy",
			Err:     err,
		}
	}
	return time.Duration(age) * time.Second, nil
}

// createJWKSKeyFetchFunc returns a key function that looks up the token's
// "kid" header in the JWKS served at jwksURL and returns the raw key.
//
// NOTE(review): the key set is fetched on every call, i.e. once per verified
// token; consider caching it if this shows up as latency.
func createJWKSKeyFetchFunc(jwksURL string) func(token *jwt.Token) (interface{}, error) {
	return func(token *jwt.Token) (interface{}, error) {
		// Validate the header before doing any network I/O.
		keyID, ok := token.Header["kid"].(string)
		if !ok {
			return nil, errors.New("expecting JWT header to have string kid")
		}

		set, err := jwk.FetchHTTP(jwksURL)
		if err != nil {
			return nil, err
		}

		keys := set.LookupKeyID(keyID)
		if len(keys) != 1 {
			return nil, fmt.Errorf("unable to find key %q", keyID)
		}

		var rawKey interface{}
		if err := keys[0].Raw(&rawKey); err != nil {
			return nil, fmt.Errorf("unable to retrieve raw key %q: %w", keyID, err)
		}
		return rawKey, nil
	}
}