diff --git a/internal/serv/internal/auth/jwt.go b/internal/serv/internal/auth/jwt.go index 97fb114..3656c56 100644 --- a/internal/serv/internal/auth/jwt.go +++ b/internal/serv/internal/auth/jwt.go @@ -3,6 +3,8 @@ package auth import ( "context" "encoding/json" + "errors" + "fmt" "io/ioutil" "net/http" "strconv" @@ -28,6 +30,30 @@ type firebasePKCache struct { var firebasePublicKeys firebasePKCache +type standardClaims struct { + jwt.StandardClaims + Audience []string `json:"aud,omitempty"` +} + +func (c *standardClaims) MatchAudience(audience string) bool { + matchLegacy := c.StandardClaims.Audience == audience + if matchLegacy { + return true + } + + if c.Audience == nil { + return false + } + + for _, tokenAudience := range c.Audience { + if audience == tokenAudience { + return true + } + } + + return false +} + func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) { var key interface{} var jwtProvider int @@ -99,17 +125,16 @@ func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) { } } - token, err := jwt.ParseWithClaims(tok, &jwt.StandardClaims{}, keyFunc) - + token, err := jwt.ParseWithClaims(tok, &standardClaims{}, keyFunc) if err != nil { next.ServeHTTP(w, r) return } - if claims, ok := token.Claims.(*jwt.StandardClaims); ok { + if claims, ok := token.Claims.(*standardClaims); ok { ctx := r.Context() - if ac.JWT.Audience != "" && claims.Audience != ac.JWT.Audience { + if ac.JWT.Audience != "" && !claims.MatchAudience(ac.JWT.Audience) { next.ServeHTTP(w, r) return }