super-graph/serv/auth_jwt.go

106 lines
1.9 KiB
Go
Raw Permalink Normal View History

2019-03-24 14:57:29 +01:00
package serv
import (
"context"
"io/ioutil"
"net/http"
2019-03-29 03:34:42 +01:00
"strings"
2019-03-24 14:57:29 +01:00
jwt "github.com/dgrijalva/jwt-go"
)
2019-03-29 03:34:42 +01:00
// authHeader is the HTTP request header that carries the bearer token when
// cookie-based auth is not configured.
const authHeader = "Authorization"

// JWT provider flavours understood by jwtHandler. The enum lives in its own
// const block so that iota starts at 0: jwtBase is therefore the zero value of
// an int, and an unset provider variable means "generic JWT".
const (
	jwtBase int = iota
	jwtAuth0
)
2019-03-24 14:57:29 +01:00
func jwtHandler(next http.HandlerFunc) http.HandlerFunc {
var key interface{}
2019-03-29 03:34:42 +01:00
var jwtProvider int
2019-03-24 14:57:29 +01:00
2019-04-08 08:47:59 +02:00
cookie := conf.Auth.Cookie
2019-03-24 14:57:29 +01:00
2019-04-08 08:47:59 +02:00
if conf.Auth.JWT.Provider == "auth0" {
2019-03-29 03:34:42 +01:00
jwtProvider = jwtAuth0
}
2019-04-08 08:47:59 +02:00
secret := conf.Auth.JWT.Secret
publicKeyFile := conf.Auth.JWT.PubKeyFile
2019-03-24 14:57:29 +01:00
switch {
case len(secret) != 0:
key = []byte(secret)
case len(publicKeyFile) != 0:
kd, err := ioutil.ReadFile(publicKeyFile)
if err != nil {
logger.Fatal().Err(err).Send()
2019-03-24 14:57:29 +01:00
}
2019-04-08 08:47:59 +02:00
switch conf.Auth.JWT.PubKeyType {
2019-03-24 14:57:29 +01:00
case "ecdsa":
key, err = jwt.ParseECPublicKeyFromPEM(kd)
case "rsa":
key, err = jwt.ParseRSAPublicKeyFromPEM(kd)
default:
key, err = jwt.ParseECPublicKeyFromPEM(kd)
}
if err != nil {
logger.Fatal().Err(err).Send()
2019-03-24 14:57:29 +01:00
}
}
return func(w http.ResponseWriter, r *http.Request) {
var tok string
if len(cookie) != 0 {
ck, err := r.Cookie(cookie)
if err != nil {
next.ServeHTTP(w, r)
return
}
tok = ck.Value
} else {
ah := r.Header.Get(authHeader)
if len(ah) < 10 {
next.ServeHTTP(w, r)
return
}
tok = ah[7:]
}
token, err := jwt.ParseWithClaims(tok, &jwt.StandardClaims{}, func(token *jwt.Token) (interface{}, error) {
return key, nil
})
if err != nil {
next.ServeHTTP(w, r)
return
}
if claims, ok := token.Claims.(*jwt.StandardClaims); ok {
2019-03-29 03:34:42 +01:00
ctx := r.Context()
if jwtProvider == jwtAuth0 {
sub := strings.Split(claims.Subject, "|")
if len(sub) != 2 {
ctx = context.WithValue(ctx, userIDProviderKey, sub[0])
ctx = context.WithValue(ctx, userIDKey, sub[1])
}
} else {
ctx = context.WithValue(ctx, userIDKey, claims.Subject)
}
2019-11-15 07:35:19 +01:00
2019-03-24 14:57:29 +01:00
next.ServeHTTP(w, r.WithContext(ctx))
2019-11-15 07:35:19 +01:00
return
2019-03-24 14:57:29 +01:00
}
2019-11-15 07:35:19 +01:00
2019-03-24 14:57:29 +01:00
next.ServeHTTP(w, r)
}
}