edge/pkg/module/auth/jwt.go

110 lines
2.2 KiB
Go

package auth
import (
"context"
"net/http"
"strings"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/pkg/errors"
)
// CookieName is the name under which the auth token is looked up by
// FindToken, both as an HTTP cookie and as a URL query parameter.
const CookieName string = "edge-auth"
// GetKeySetFunc returns the JWK set that FindToken hands to the JWT parser
// as the key set for the incoming token. It is invoked on each FindToken
// call that reaches the parsing step.
type GetKeySetFunc func() (jwk.Set, error)
// WithJWT returns an OptionFunc that installs a JWT-backed GetClaims
// implementation on the Option.
//
// The installed function extracts the token from the request (see FindToken),
// parses it against the key set returned by getKeySet, and returns the
// requested claims as strings, one slot per requested name. The context
// argument it receives is not used; the request's own context is used instead.
func WithJWT(getKeySet GetKeySetFunc) OptionFunc {
	claimsFromJWT := func(_ context.Context, req *http.Request, names ...string) ([]string, error) {
		values, err := getClaims[string](req, getKeySet, names...)
		if err != nil {
			return nil, errors.WithStack(err)
		}

		return values, nil
	}

	return func(opt *Option) {
		opt.GetClaims = claimsFromJWT
	}
}
// FindToken locates the raw JWT carried by the request and parses it against
// the key set returned by getKeySet.
//
// The token is searched for, in order, in:
//  1. the Authorization header ("Bearer <token>", scheme matched
//     case-insensitively; a bare token without scheme is also accepted),
//  2. the "edge-auth" URL query parameter,
//  3. the "edge-auth" cookie.
//
// An Authorization header that uses another scheme (e.g. "Basic ...") is
// ignored so that the query parameter and cookie can still be consulted.
//
// It returns ErrUnauthenticated (wrapped with a stack) when no token is found,
// and a wrapped error when the key set cannot be obtained, is nil, or the
// token fails to parse.
func FindToken(r *http.Request, getKeySet GetKeySetFunc) (jwt.Token, error) {
	// Retrieve token from Authorization header
	rawToken := bearerTokenFromHeader(r.Header.Get("Authorization"))

	// Retrieve token from ?edge-auth=<value>
	if rawToken == "" {
		rawToken = r.URL.Query().Get(CookieName)
	}

	// Retrieve token from the edge-auth cookie
	if rawToken == "" {
		cookie, err := r.Cookie(CookieName)
		// A missing cookie is not an error here, it just means "no token yet".
		if err != nil && !errors.Is(err, http.ErrNoCookie) {
			return nil, errors.WithStack(err)
		}

		if cookie != nil {
			rawToken = cookie.Value
		}
	}

	if rawToken == "" {
		return nil, errors.WithStack(ErrUnauthenticated)
	}

	keySet, err := getKeySet()
	if err != nil {
		return nil, errors.WithStack(err)
	}

	if keySet == nil {
		return nil, errors.New("no keyset")
	}

	token, err := jwt.Parse([]byte(rawToken),
		// Do not require the token's "kid" to match a key in the set.
		jwt.WithKeySet(keySet, jws.WithRequireKid(false)),
		jwt.WithValidate(true),
	)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	return token, nil
}

// bearerTokenFromHeader extracts the token from an Authorization header value.
//
// It returns the credentials of a "Bearer" header (scheme compared
// case-insensitively), the value itself when it carries no scheme at all
// (bare token, kept for backward compatibility), and "" for an empty header
// or any other authentication scheme.
func bearerTokenFromHeader(authorization string) string {
	authorization = strings.TrimSpace(authorization)

	scheme, credentials, hasScheme := strings.Cut(authorization, " ")
	if !hasScheme {
		// No "<scheme> <credentials>" structure: treat the value as a bare token.
		return authorization
	}

	if !strings.EqualFold(scheme, "Bearer") {
		// Some other scheme (Basic, Digest, ...): not ours to interpret.
		return ""
	}

	return strings.TrimSpace(credentials)
}
// getClaims extracts the named claims from the JWT carried by the request and
// returns them converted to T.
//
// The result always has len(names) entries, positionally aligned with names.
// A claim that is absent from the token leaves the zero value of T in its
// slot. A claim that is present but whose dynamic type is not T aborts the
// call with an error naming the claim, the expected type and the actual type.
//
// Errors from FindToken (no token, key set failure, invalid token) and from
// converting the token to a map are returned wrapped with a stack.
func getClaims[T any](r *http.Request, getKeySet GetKeySetFunc, names ...string) ([]T, error) {
	token, err := FindToken(r, getKeySet)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	mapClaims, err := token.AsMap(r.Context())
	if err != nil {
		return nil, errors.WithStack(err)
	}

	claims := make([]T, len(names))

	for idx, n := range names {
		rawClaim, exists := mapClaims[n]
		if !exists {
			// Missing claims keep the zero value so indexes still match names.
			continue
		}

		claim, ok := rawClaim.(T)
		if !ok {
			// Format the zero value of T rather than new(T), which would
			// report the pointer type (e.g. "*string") instead of T itself.
			var expected T

			return nil, errors.Errorf("expected claim '%s' to be of type '%T', got '%T'", n, expected, rawClaim)
		}

		claims[idx] = claim
	}

	return claims, nil
}