package auth

import (
	"context"
	"net/http"
	"strings"

	"github.com/lestrrat-go/jwx/v2/jwk"
	"github.com/lestrrat-go/jwx/v2/jws"
	"github.com/lestrrat-go/jwx/v2/jwt"
	"github.com/pkg/errors"
)

const (
	// CookieName is the name of both the cookie and the URL query parameter
	// that FindToken inspects when no token is found in the Authorization
	// header.
	CookieName string = "edge-auth"
)

// GetKeySetFunc returns the key set used to verify token signatures.
type GetKeySetFunc func() (jwk.Set, error)

// WithJWT returns an OptionFunc that sets Option.GetClaim to a function
// reading a string claim from the JWT carried by the request. Tokens are
// located and verified by FindToken using the key set returned by getKeySet.
func WithJWT(getKeySet GetKeySetFunc) OptionFunc {
	return func(o *Option) {
		o.GetClaim = func(ctx context.Context, r *http.Request, claimName string) (string, error) {
			claim, err := getClaim[string](r, claimName, getKeySet)
			if err != nil {
				return "", errors.WithStack(err)
			}

			return claim, nil
		}
	}
}

// FindToken extracts a raw JWT from the request and parses it against the key
// set returned by getKeySet.
//
// The token is looked up, in order, in:
//  1. the Authorization header (see bearerToken),
//  2. the "edge-auth" URL query parameter,
//  3. the "edge-auth" cookie.
//
// It returns ErrUnauthenticated when none of these yields a token. Errors
// returned by getKeySet or jwt.Parse are returned wrapped with a stack trace.
func FindToken(r *http.Request, getKeySet GetKeySetFunc) (jwt.Token, error) {
	// Retrieve token from Authorization header
	rawToken := bearerToken(r.Header.Get("Authorization"))

	// Retrieve token from ?edge-auth=
	if rawToken == "" {
		rawToken = r.URL.Query().Get(CookieName)
	}

	// Retrieve token from the edge-auth cookie
	if rawToken == "" {
		cookie, err := r.Cookie(CookieName)
		if err != nil && !errors.Is(err, http.ErrNoCookie) {
			return nil, errors.WithStack(err)
		}

		if cookie != nil {
			rawToken = cookie.Value
		}
	}

	if rawToken == "" {
		return nil, errors.WithStack(ErrUnauthenticated)
	}

	keySet, err := getKeySet()
	if err != nil {
		return nil, errors.WithStack(err)
	}

	token, err := jwt.Parse([]byte(rawToken),
		// Tokens are not required to carry a "kid" header.
		jwt.WithKeySet(keySet, jws.WithRequireKid(false)),
		jwt.WithValidate(true),
	)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	return token, nil
}

// bearerToken returns the token carried by an Authorization header value, or
// "" when the header does not carry one.
//
// "Bearer <token>" is accepted with a case-insensitive scheme. A header made
// of a single word is returned as-is, which keeps accepting bare tokens sent
// without a scheme. Any other shape (for instance "Basic <credentials>")
// yields "" so that the caller falls back to the query parameter and cookie
// instead of trying to parse foreign credentials as a JWT.
func bearerToken(authorization string) string {
	fields := strings.Fields(authorization)

	switch len(fields) {
	case 1:
		// A lone "Bearer" is a scheme without credentials, not a token.
		if strings.EqualFold(fields[0], "Bearer") {
			return ""
		}

		return fields[0]
	case 2:
		if strings.EqualFold(fields[0], "Bearer") {
			return fields[1]
		}
	}

	return ""
}

// getClaim returns the claim named claimAttr from the request's token,
// asserted to type T.
//
// It returns ErrClaimNotFound when the token has no such claim, and a
// descriptive error when the claim exists but is not of type T. Errors from
// FindToken and from converting the token to a map are returned wrapped with
// a stack trace. The returned value is the zero value of T on error.
func getClaim[T any](r *http.Request, claimAttr string, getKeySet GetKeySetFunc) (T, error) {
	var zero T

	token, err := FindToken(r, getKeySet)
	if err != nil {
		return zero, errors.WithStack(err)
	}

	ctx := r.Context()

	mapClaims, err := token.AsMap(ctx)
	if err != nil {
		return zero, errors.WithStack(err)
	}

	rawClaim, exists := mapClaims[claimAttr]
	if !exists {
		return zero, errors.WithStack(ErrClaimNotFound)
	}

	claim, ok := rawClaim.(T)
	if !ok {
		// Format the zero value (not new(T)) so the expected type is reported
		// as T rather than *T.
		return zero, errors.Errorf("expected claim '%s' to be of type '%T', got '%T'", claimAttr, zero, rawClaim)
	}

	return claim, nil
}