// super-graph/internal/serv/internal/auth/jwt.go

package auth
import (
	"context"
	"encoding/json"
	"io/ioutil"
	"net/http"
	"strconv"
	"strings"
	"sync"
	"time"

	jwt "github.com/dgrijalva/jwt-go"
	"github.com/dosco/super-graph/core"
)

const (
	// authHeader is the request header JwtHandler reads the token from when
	// no cookie name is configured.
	authHeader = "Authorization"

	// firebasePKEndpoint serves the Firebase signing certificates as a JSON
	// object mapping key ID ("kid") to a PEM-encoded certificate.
	firebasePKEndpoint = "https://www.googleapis.com/robot/v1/metadata/x509/securetoken@system.gserviceaccount.com"

	// firebaseIssuerPrefix followed by the configured audience (the Firebase
	// project ID) is the issuer JwtHandler expects on Firebase tokens.
	firebaseIssuerPrefix = "https://securetoken.google.com/"
)

// JWT providers that need provider-specific claim handling in JwtHandler.
// The zero value means "generic provider". These live in their own block so
// iota counts only the providers (the previous shared block made them 2 and 4).
const (
	jwtAuth0 int = iota + 1
	jwtFirebase
)

// firebasePKCache holds the downloaded Firebase certificates (PEM strings
// keyed by key ID) and the time after which they must be fetched again.
type firebasePKCache struct {
	PublicKeys map[string]string
	Expiration time.Time
}

// firebasePublicKeys is the process-wide certificate cache used by
// firebaseKeyFunction. Its zero Expiration makes the first lookup fetch.
var firebasePublicKeys firebasePKCache

func JwtHandler(ac *Auth, next http.Handler) (http.HandlerFunc, error) {
2019-03-24 14:57:29 +01:00
var key interface{}
2019-03-29 03:34:42 +01:00
var jwtProvider int
2019-03-24 14:57:29 +01:00
cookie := ac.Cookie
2019-03-24 14:57:29 +01:00
if ac.JWT.Provider == "auth0" {
2019-03-29 03:34:42 +01:00
jwtProvider = jwtAuth0
} else if ac.JWT.Provider == "firebase" {
jwtProvider = jwtFirebase
2019-03-29 03:34:42 +01:00
}
secret := ac.JWT.Secret
publicKeyFile := ac.JWT.PubKeyFile
2019-03-24 14:57:29 +01:00
switch {
case secret != "":
2019-03-24 14:57:29 +01:00
key = []byte(secret)
case publicKeyFile != "":
2019-03-24 14:57:29 +01:00
kd, err := ioutil.ReadFile(publicKeyFile)
if err != nil {
return nil, err
2019-03-24 14:57:29 +01:00
}
switch ac.JWT.PubKeyType {
2019-03-24 14:57:29 +01:00
case "ecdsa":
key, err = jwt.ParseECPublicKeyFromPEM(kd)
case "rsa":
key, err = jwt.ParseRSAPublicKeyFromPEM(kd)
default:
key, err = jwt.ParseECPublicKeyFromPEM(kd)
}
if err != nil {
return nil, err
2019-03-24 14:57:29 +01:00
}
}
return func(w http.ResponseWriter, r *http.Request) {
2019-03-24 14:57:29 +01:00
var tok string
if cookie != "" {
2019-03-24 14:57:29 +01:00
ck, err := r.Cookie(cookie)
if err != nil {
next.ServeHTTP(w, r)
return
}
tok = ck.Value
} else {
ah := r.Header.Get(authHeader)
if len(ah) < 10 {
next.ServeHTTP(w, r)
return
}
tok = ah[7:]
}
var keyFunc jwt.Keyfunc
if jwtProvider == jwtFirebase {
keyFunc = firebaseKeyFunction
} else {
keyFunc = func(token *jwt.Token) (interface{}, error) {
return key, nil
}
}
token, err := jwt.ParseWithClaims(tok, &jwt.StandardClaims{}, keyFunc)
2019-03-24 14:57:29 +01:00
if err != nil {
next.ServeHTTP(w, r)
return
}
if claims, ok := token.Claims.(*jwt.StandardClaims); ok {
2019-03-29 03:34:42 +01:00
ctx := r.Context()
if ac.JWT.Audience != "" && claims.Audience != ac.JWT.Audience {
next.ServeHTTP(w, r)
return
}
2019-03-29 03:34:42 +01:00
if jwtProvider == jwtAuth0 {
sub := strings.Split(claims.Subject, "|")
if len(sub) != 2 {
ctx = context.WithValue(ctx, core.UserIDProviderKey, sub[0])
ctx = context.WithValue(ctx, core.UserIDKey, sub[1])
2019-03-29 03:34:42 +01:00
}
} else if jwtProvider == jwtFirebase &&
claims.Issuer == firebaseIssuerPrefix+ac.JWT.Audience {
ctx = context.WithValue(ctx, core.UserIDKey, claims.Subject)
2019-03-29 03:34:42 +01:00
} else {
ctx = context.WithValue(ctx, core.UserIDKey, claims.Subject)
2019-03-29 03:34:42 +01:00
}
2019-11-15 07:35:19 +01:00
2019-03-24 14:57:29 +01:00
next.ServeHTTP(w, r.WithContext(ctx))
2019-11-15 07:35:19 +01:00
return
2019-03-24 14:57:29 +01:00
}
2019-11-15 07:35:19 +01:00
2019-03-24 14:57:29 +01:00
next.ServeHTTP(w, r)
}, nil
2019-03-24 14:57:29 +01:00
}
// firebaseKeyError reports a failure to obtain the public key for a Firebase
// ID token, such as a missing "kid" header, a problem fetching or decoding
// the certificate list, or an unknown key ID.
type firebaseKeyError struct {
	// Err is the underlying cause; it is nil when there is none.
	Err error
	// Message describes what failed.
	Message string
}

// Error returns Message, followed by a space and the cause's text when Err is
// set. Several construction sites set only Message, so Err must not be
// dereferenced unconditionally (doing so panicked).
func (e *firebaseKeyError) Error() string {
	if e.Err == nil {
		return e.Message
	}
	return e.Message + " " + e.Err.Error()
}

// Unwrap returns the underlying cause so errors.Is and errors.As can see it.
func (e *firebaseKeyError) Unwrap() error {
	return e.Err
}

// firebasePKMu guards firebasePublicKeys. firebaseKeyFunction runs inside the
// handler returned by JwtHandler, and net/http serves requests concurrently,
// so every read and refresh of the shared cache happens under this lock.
var firebasePKMu sync.Mutex

// firebaseHTTPClient downloads the Firebase certificates. Unlike
// http.DefaultClient (used by http.Get) it has a timeout, so an unresponsive
// certificate server cannot stall requests indefinitely.
var firebaseHTTPClient = &http.Client{Timeout: 10 * time.Second}

// firebaseKeyFunction is a jwt.Keyfunc for Firebase ID tokens. It requires an
// RSA signing method and a string "kid" header. It refreshes the cached
// certificates when they have expired. It returns the RSA public key parsed
// from the certificate whose ID matches "kid".
//
// Failures are reported as *firebaseKeyError, except a malformed certificate,
// whose error comes straight from jwt.ParseRSAPublicKeyFromPEM.
func firebaseKeyFunction(token *jwt.Token) (interface{}, error) {
	// The keys are parsed as RSA below, so refuse any other algorithm up
	// front rather than relying on downstream key-type checks.
	if _, ok := token.Method.(*jwt.SigningMethodRSA); !ok {
		return nil, &firebaseKeyError{
			Message: "Error unexpected signing method in token",
		}
	}

	// The header is attacker-controlled: a non-string "kid" must be an
	// error, not a panicking type assertion.
	kid, ok := token.Header["kid"].(string)
	if !ok {
		return nil, &firebaseKeyError{
			Message: "Error 'kid' header not found in token",
		}
	}

	certPEM, found, err := firebaseCachedPublicKey(kid)
	if err != nil {
		return nil, err
	}
	if !found {
		return nil, &firebaseKeyError{
			Message: "Error no matching public key for kid supplied in jwt",
		}
	}
	return jwt.ParseRSAPublicKeyFromPEM([]byte(certPEM))
}

// firebaseCachedPublicKey returns the certificate PEM for kid, refreshing the
// cache first when it has expired; the bool is false if kid is unknown.
func firebaseCachedPublicKey(kid string) (string, bool, error) {
	firebasePKMu.Lock()
	defer firebasePKMu.Unlock()

	if firebasePublicKeys.Expiration.Before(time.Now()) {
		if err := refreshFirebasePublicKeys(); err != nil {
			return "", false, err
		}
	}
	certPEM, found := firebasePublicKeys.PublicKeys[kid]
	return certPEM, found, nil
}

// refreshFirebasePublicKeys downloads the current certificates and replaces
// the cache, expiring it after the response's Cache-Control max-age. On a
// download, status, read or max-age error the old cache is left as is; on a
// JSON error it is cleared. The caller must hold firebasePKMu.
func refreshFirebasePublicKeys() error {
	resp, err := firebaseHTTPClient.Get(firebasePKEndpoint)
	if err != nil {
		return &firebaseKeyError{
			Message: "Error connecting to firebase certificate server",
			Err:     err,
		}
	}
	defer resp.Body.Close()

	if resp.StatusCode != http.StatusOK {
		return &firebaseKeyError{
			Message: "Error unexpected response from firebase certificate server: " + resp.Status,
		}
	}

	data, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return &firebaseKeyError{
			Message: "Error reading firebase certificate server response",
			Err:     err,
		}
	}

	maxAge, err := firebaseMaxAge(resp.Header.Get("Cache-Control"))
	if err != nil {
		return err
	}
	expiration := time.Now().Add(maxAge)

	// Decode into a fresh map: json.Unmarshal merges into an existing map,
	// which would keep rotated-out keys trusted.
	keys := make(map[string]string)
	if err := json.Unmarshal(data, &keys); err != nil {
		firebasePublicKeys = firebasePKCache{}
		return &firebaseKeyError{
			Message: "Error unmarshalling firebase public key json",
			Err:     err,
		}
	}

	firebasePublicKeys = firebasePKCache{PublicKeys: keys, Expiration: expiration}
	return nil
}

// firebaseMaxAge returns the max-age directive of a Cache-Control header value
// as a duration; e.g. "public, max-age=19204, must-revalidate" yields 19204s.
// It accepts max-age as the last directive, which the old index arithmetic
// truncated by one digit.
func firebaseMaxAge(cacheControl string) (time.Duration, error) {
	const prefix = "max-age="
	for _, part := range strings.Split(cacheControl, ",") {
		part = strings.TrimSpace(part)
		if len(part) < len(prefix) || !strings.EqualFold(part[:len(prefix)], prefix) {
			continue
		}
		age, err := strconv.ParseInt(strings.TrimSpace(part[len(prefix):]), 10, 64)
		if err != nil {
			return 0, &firebaseKeyError{
				Message: "Error parsing max-age cache policy",
				Err:     err,
			}
		}
		return time.Duration(age) * time.Second, nil
	}
	return 0, &firebaseKeyError{
		Message: "Error parsing cache-control header: 'max-age=' not found",
	}
}