Add support for JWKS-based token validation
This commit is contained in:
@ -33,6 +33,7 @@ type Auth struct {
|
||||
PubKeyFile string `mapstructure:"public_key_file"`
|
||||
PubKeyType string `mapstructure:"public_key_type"`
|
||||
Audience string `mapstructure:"audience"`
|
||||
JWKSURL string `mapstructure:"jwks_url"`
|
||||
}
|
||||
|
||||
Header struct {
|
||||
|
@ -11,6 +11,8 @@ import (
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/jwk"
|
||||
|
||||
jwt "github.com/dgrijalva/jwt-go"
|
||||
"github.com/dosco/super-graph/core"
|
||||
)
|
||||
@ -68,6 +70,7 @@ func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
|
||||
|
||||
secret := ac.JWT.Secret
|
||||
publicKeyFile := ac.JWT.PubKeyFile
|
||||
jwksURL := ac.JWT.JWKSURL
|
||||
|
||||
switch {
|
||||
case secret != "":
|
||||
@ -120,8 +123,12 @@ func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
|
||||
if jwtProvider == jwtFirebase {
|
||||
keyFunc = firebaseKeyFunction
|
||||
} else {
|
||||
keyFunc = func(token *jwt.Token) (interface{}, error) {
|
||||
return key, nil
|
||||
if jwksURL != "" {
|
||||
keyFunc = createJWKSKeyFetchFunc(jwksURL)
|
||||
} else {
|
||||
keyFunc = func(token *jwt.Token) (interface{}, error) {
|
||||
return key, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -248,3 +255,28 @@ func firebaseKeyFunction(token *jwt.Token) (interface{}, error) {
|
||||
Message: "Error no matching public key for kid supplied in jwt",
|
||||
}
|
||||
}
|
||||
|
||||
func createJWKSKeyFetchFunc(jwksURL string) func(token *jwt.Token) (interface{}, error) {
|
||||
return func(token *jwt.Token) (interface{}, error) {
|
||||
set, err := jwk.FetchHTTP(jwksURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
keyID, ok := token.Header["kid"].(string)
|
||||
if !ok {
|
||||
return nil, errors.New("expecting JWT header to have string kid")
|
||||
}
|
||||
|
||||
if key := set.LookupKeyID(keyID); len(key) == 1 {
|
||||
var rawKey interface{}
|
||||
if err := key[0].Raw(&rawKey); err != nil {
|
||||
return nil, fmt.Errorf("unable to retrieve raw key %q", keyID)
|
||||
}
|
||||
|
||||
return rawKey, nil
|
||||
}
|
||||
|
||||
return nil, fmt.Errorf("unable to find key %q", keyID)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user