edge/pkg/module/auth/jwt.go

104 lines
2.1 KiB
Go

package auth
import (
"context"
"net/http"
"strings"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/pkg/errors"
)
// CookieName names both the cookie and the URL query parameter from which
// the JWT is read when the Authorization header does not provide one.
const CookieName string = "edge-auth"
// GetKeySetFunc supplies the key set that FindToken verifies a token against.
// FindToken calls it once per invocation. If it returns a non-nil error, that
// error is propagated. If it returns a nil set with a nil error, FindToken
// fails with "no keyset".
type GetKeySetFunc func() (jwk.Set, error)
// WithJWT returns an OptionFunc that sets Option.GetClaim to a lookup reading
// string-valued claims from the request's JWT.
//
// Each lookup locates and verifies the token against the key set returned by
// getKeySet. The context argument handed to GetClaim is ignored; the request's
// own context is used instead.
func WithJWT(getKeySet GetKeySetFunc) OptionFunc {
	lookupClaim := func(_ context.Context, req *http.Request, name string) (string, error) {
		value, err := getClaim[string](req, name, getKeySet)
		if err == nil {
			return value, nil
		}
		return "", errors.WithStack(err)
	}

	return func(opt *Option) {
		opt.GetClaim = lookupClaim
	}
}
// FindToken extracts the raw JWT carried by r and parses it with jwt.Parse.
// It verifies the signature against the key set returned by getKeySet, with
// validation enabled (jwt.WithValidate(true)).
//
// The raw token is taken from the first of these sources that yields one:
//  1. the Authorization header. The "Bearer" scheme is matched
//     case-insensitively. A value with no scheme is accepted as a bare token,
//     for backward compatibility. Headers using any other scheme (e.g.
//     "Basic") are ignored.
//  2. the CookieName URL query parameter (?edge-auth=<value>);
//  3. the CookieName cookie.
//
// Errors:
//   - no token found: ErrUnauthenticated (wrapped);
//   - getKeySet fails: its error (wrapped);
//   - nil key set: an error;
//   - cookie read or token parse/verify/validate fails: the underlying error
//     (wrapped).
func FindToken(r *http.Request, getKeySet GetKeySetFunc) (jwt.Token, error) {
	rawToken, err := rawTokenFromRequest(r)
	if err != nil {
		return nil, errors.WithStack(err)
	}
	if rawToken == "" {
		return nil, errors.WithStack(ErrUnauthenticated)
	}

	keySet, err := getKeySet()
	if err != nil {
		return nil, errors.WithStack(err)
	}
	if keySet == nil {
		return nil, errors.New("no keyset")
	}

	// NOTE(review): WithRequireKid(false) presumably allows tokens without a
	// "kid" header to be verified against the set; confirm against jwx docs.
	token, err := jwt.Parse([]byte(rawToken),
		jwt.WithKeySet(keySet, jws.WithRequireKid(false)),
		jwt.WithValidate(true),
	)
	if err != nil {
		return nil, errors.WithStack(err)
	}
	return token, nil
}

// rawTokenFromRequest returns the raw token from the first supported source
// on r that holds one, or "" when none does. Only an unexpected cookie error
// is reported.
func rawTokenFromRequest(r *http.Request) (string, error) {
	if token := tokenFromAuthorization(r.Header.Get("Authorization")); token != "" {
		return token, nil
	}
	// Query parameter ?edge-auth=<value>.
	if token := r.URL.Query().Get(CookieName); token != "" {
		return token, nil
	}
	cookie, err := r.Cookie(CookieName)
	if err != nil {
		if errors.Is(err, http.ErrNoCookie) {
			return "", nil
		}
		return "", err
	}
	return cookie.Value, nil
}

// tokenFromAuthorization extracts a token from an Authorization header value,
// returning "" when the header is empty or uses a non-Bearer scheme.
func tokenFromAuthorization(header string) string {
	value := strings.TrimSpace(header)
	scheme, credentials, hasScheme := strings.Cut(value, " ")
	if !hasScheme {
		// No scheme at all: the whole value has always been treated as the token.
		return value
	}
	// Authentication schemes are case-insensitive (RFC 7235 §2.1).
	if strings.EqualFold(scheme, "Bearer") {
		return strings.TrimSpace(credentials)
	}
	// Some other scheme (e.g. Basic) does not carry our JWT; let the query
	// parameter or cookie be tried instead of failing to parse it.
	return ""
}
func getClaim[T any](r *http.Request, claimAttr string, getKeySet GetKeySetFunc) (T, error) {
token, err := FindToken(r, getKeySet)
if err != nil {
return *new(T), errors.WithStack(err)
}
ctx := r.Context()
mapClaims, err := token.AsMap(ctx)
if err != nil {
return *new(T), errors.WithStack(err)
}
rawClaim, exists := mapClaims[claimAttr]
if !exists {
return *new(T), errors.WithStack(ErrClaimNotFound)
}
claim, ok := rawClaim.(T)
if !ok {
return *new(T), errors.Errorf("unexpected claim '%s' to be of type '%T', got '%T'", claimAttr, new(T), rawClaim)
}
return claim, nil
}