package auth

import (
	"context"
	"net/http"
	"strings"

	"github.com/lestrrat-go/jwx/v2/jwk"
	"github.com/lestrrat-go/jwx/v2/jws"
	"github.com/lestrrat-go/jwx/v2/jwt"
	"github.com/pkg/errors"
)

const (
	// CookieName is the name of both the cookie and the URL query parameter
	// that FindRawToken inspects when looking for a raw token.
	CookieName string = "edge-auth"
)

// GetKeySetFunc returns the key set passed to the JWT parser when verifying
// a token found on a request.
type GetKeySetFunc func() (jwk.Set, error)

// WithJWT returns an option that sets Option.GetClaims to a function which
// reads the requested claims, as strings, from the JWT carried by the request.
// The token is verified against the key set returned by getKeySet.
func WithJWT(getKeySet GetKeySetFunc) OptionFunc {
	return func(o *Option) {
		o.GetClaims = func(ctx context.Context, r *http.Request, names ...string) ([]string, error) {
			claim, err := getClaims[string](r, getKeySet, names...)
			if err != nil {
				return nil, errors.WithStack(err)
			}

			return claim, nil
		}
	}
}

// FindRawToken extracts the raw (still serialized) token from the request.
//
// Sources are tried in order, the first non-empty value wins:
//  1. the Authorization header, with an optional "Bearer " prefix removed;
//  2. the CookieName URL query parameter;
//  3. the CookieName cookie.
//
// It returns ErrUnauthenticated when none of them carries a value.
func FindRawToken(r *http.Request) (string, error) {
	// Retrieve token from Authorization header.
	rawToken := r.Header.Get("Authorization")

	// HTTP authentication scheme names are case-insensitive, so "bearer" and
	// "BEARER" must be accepted as well as "Bearer".
	//
	// NOTE(review): a header without the Bearer prefix is still used as the
	// raw token in its entirety (historical behavior); this means a non-Bearer
	// scheme such as Basic prevents the query/cookie fallback. Confirm whether
	// that is intended before tightening it.
	const bearerPrefix = "Bearer "
	if len(rawToken) >= len(bearerPrefix) && strings.EqualFold(rawToken[:len(bearerPrefix)], bearerPrefix) {
		rawToken = rawToken[len(bearerPrefix):]
	}

	// Retrieve token from ?edge-auth=<token>.
	if rawToken == "" {
		rawToken = r.URL.Query().Get(CookieName)
	}

	// Retrieve token from the edge-auth cookie.
	if rawToken == "" {
		cookie, err := r.Cookie(CookieName)
		// A missing cookie is not an error: it only means this source is empty.
		if err != nil && !errors.Is(err, http.ErrNoCookie) {
			return "", errors.WithStack(err)
		}

		if cookie != nil {
			rawToken = cookie.Value
		}
	}

	if rawToken == "" {
		return "", errors.WithStack(ErrUnauthenticated)
	}

	return rawToken, nil
}

// FindToken locates the raw token on the request (see FindRawToken) and parses
// it with the key set returned by getKeySet, with validation enabled.
//
// It returns an error when no token is found, when the key set cannot be
// obtained or is nil, or when parsing fails.
func FindToken(r *http.Request, getKeySet GetKeySetFunc) (jwt.Token, error) {
	rawToken, err := FindRawToken(r)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	keySet, err := getKeySet()
	if err != nil {
		return nil, errors.WithStack(err)
	}

	if keySet == nil {
		return nil, errors.New("no keyset")
	}

	token, err := jwt.Parse(
		[]byte(rawToken),
		// Tokens are not required to carry a "kid" header.
		jwt.WithKeySet(keySet, jws.WithRequireKid(false)),
		jwt.WithValidate(true),
	)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	return token, nil
}

// getClaims returns the values of the named claims from the request's token,
// in the same order as names.
//
// A claim that is absent from the token leaves the zero value of T at its
// index. A claim that is present but not of type T yields an error.
func getClaims[T any](r *http.Request, getKeySet GetKeySetFunc, names ...string) ([]T, error) {
	token, err := FindToken(r, getKeySet)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	ctx := r.Context()

	mapClaims, err := token.AsMap(ctx)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	// zero is only used to report the expected type in error messages;
	// formatting new(T) would print the pointer type (*T) instead of T.
	var zero T

	claims := make([]T, len(names))

	for idx, n := range names {
		rawClaim, exists := mapClaims[n]
		if !exists {
			continue
		}

		claim, ok := rawClaim.(T)
		if !ok {
			return nil, errors.Errorf("expected claim '%s' to be of type '%T', got '%T'", n, zero, rawClaim)
		}

		claims[idx] = claim
	}

	return claims, nil
}