package auth import ( "context" "net/http" "strings" "github.com/golang-jwt/jwt" "github.com/pkg/errors" ) func WithJWT(keyFunc jwt.Keyfunc) OptionFunc { return func(o *Option) { o.GetClaim = func(ctx context.Context, r *http.Request, claimName string) (string, error) { claim, err := getClaim[string](r, claimName, keyFunc) if err != nil { return "", errors.WithStack(err) } return claim, nil } } } func getClaim[T any](r *http.Request, claimAttr string, keyFunc jwt.Keyfunc) (T, error) { rawToken := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") if rawToken == "" { rawToken = r.URL.Query().Get("token") } if rawToken == "" { return *new(T), errors.WithStack(ErrUnauthenticated) } token, err := jwt.Parse(rawToken, keyFunc) if err != nil { return *new(T), errors.WithStack(err) } if !token.Valid { return *new(T), errors.Errorf("invalid jwt token: '%v'", token.Raw) } mapClaims, ok := token.Claims.(jwt.MapClaims) if !ok { return *new(T), errors.Errorf("unexpected claims type '%T'", token.Claims) } rawClaim, exists := mapClaims[claimAttr] if !exists { return *new(T), errors.WithStack(ErrClaimNotFound) } claim, ok := rawClaim.(T) if !ok { return *new(T), errors.Errorf("unexpected claim '%s' to be of type '%T', got '%T'", claimAttr, new(T), rawClaim) } return claim, nil }