diff --git a/Makefile b/Makefile index 9e29d03..659c846 100644 --- a/Makefile +++ b/Makefile @@ -132,7 +132,7 @@ tools/grafterm/bin/grafterm: GOBIN=$(PWD)/tools/grafterm/bin go install github.com/slok/grafterm/cmd/grafterm@v0.2.0 bench: - go test -bench=. -run '^$$' ./internal/bench + go test -bench=. -run '^$$' -benchtime=10s ./internal/bench tools/benchstat/bin/benchstat: mkdir -p tools/benchstat/bin diff --git a/doc/fr/references/layers/authn/README.md b/doc/fr/references/layers/authn/README.md index e63e835..55087cd 100644 --- a/doc/fr/references/layers/authn/README.md +++ b/doc/fr/references/layers/authn/README.md @@ -27,8 +27,8 @@ Bouncer utilise le projet [`expr`](https://expr-lang.org/) comme DSL. En plus de Le comportement des règles par défaut est le suivant: 1. L'ensemble des entêtes HTTP correspondant au patron `Remote-*` sont supprimés ; -2. L'identifiant de l'utilisateur identifié (`user.subject`) est exporté sous la forme de l'entête HTTP `Remote-User` ; -3. L'ensemble des attributs de l'utilisateur identifié (`user.attrs`) sont exportés sous la forme `Remote-User-Attr-` où `` est le nom de l'attribut en minuscule, avec les `_` transformés en `-`. +2. L'identifiant de l'utilisateur identifié (`vars.user.subject`) est exporté sous la forme de l'entête HTTP `Remote-User` ; +3. L'ensemble des attributs de l'utilisateur identifié (`vars.user.attrs`) sont exportés sous la forme `Remote-User-Attr-` où `` est le nom de l'attribut en minuscule, avec les `_` transformés en `-`. ### Fonctions @@ -36,25 +36,25 @@ Le comportement des règles par défaut est le suivant: Interdire l'accès à l'utilisateur. -##### `add_header(name string, value string)` +##### `add_header(ctx, name string, value string)` Ajouter une valeur à un entête HTTP via son nom `name` et sa valeur `value`. -##### `set_header(name string, value string)` +##### `set_header(ctx, name string, value string)` Définir la valeur d'un entête HTTP via son nom `name` et sa valeur `value`. La valeur précédente est écrasée. -##### `del_headers(pattern string)` +##### `del_headers(ctx, pattern string)` Supprimer un ou plusieurs entêtes HTTP dont le nom correspond au patron `pattern`. Le patron est défini par une chaîne comprenant un ou plusieurs caractères `*`, signifiant un ou plusieurs caractères arbitraires. -##### `set_host(host string)` +##### `set_host(ctx, host string)` Modifier la valeur de l'entête `Host` de la requête. -##### `set_url(url string)` +##### `set_url(ctx, url string)` Modifier l'URL du serveur cible. @@ -62,7 +62,7 @@ Modifier l'URL du serveur cible. Les règles ont accès aux variables suivantes pendant leur exécution. -#### `user` +#### `vars.user` L'utilisateur identifié par le layer. diff --git a/doc/fr/references/layers/authn/basic.md b/doc/fr/references/layers/authn/basic.md index 89a55f4..e4e2875 100644 --- a/doc/fr/references/layers/authn/basic.md +++ b/doc/fr/references/layers/authn/basic.md @@ -14,12 +14,12 @@ Les options disponibles pour le layer sont décrites via un [schéma JSON](https En plus de ces options spécifiques le layer peut également être configuré via [les options communes aux layers `authn-*`](../../../../../internal/proxy/director/layer/authn/layer-options.json). -## Objet `user` et attributs +## Objet `vars.user` et attributs L'objet `user` exposé au moteur de règles sera construit de la manière suivante: -- `user.subject` sera initialisé avec le nom d'utilisateur identifié ; -- `user.attrs` sera composé des attributs associés à l'utilisation (voir les options). +- `vars.user.subject` sera initialisé avec le nom d'utilisateur identifié ; +- `vars.user.attrs` sera composé des attributs associés à l'utilisation (voir les options). ## Métriques diff --git a/doc/fr/references/layers/authn/network.md b/doc/fr/references/layers/authn/network.md index 7db27b9..434169c 100644 --- a/doc/fr/references/layers/authn/network.md +++ b/doc/fr/references/layers/authn/network.md @@ -14,12 +14,12 @@ Les options disponibles pour le layer sont décrites via un [schéma JSON](https En plus de ces options spécifiques le layer peut également être configuré via [les options communes aux layers `authn-*`](../../../../../internal/proxy/director/layer/authn/layer-options.json). -## Objet `user` et attributs +## Objet `vars.user` et attributs -L'objet `user` exposé au moteur de règles sera construit de la manière suivante: +L'objet `vars.user` exposé au moteur de règles sera construit de la manière suivante: -- `user.subject` sera initialisé avec le couple `:` ; -- `user.attrs` sera vide. +- `vars.user.subject` sera initialisé avec le couple `:` ; +- `vars.user.attrs` sera vide. ## Métriques diff --git a/doc/fr/references/layers/authn/oidc.md b/doc/fr/references/layers/authn/oidc.md index 498e8f6..d7c7808 100644 --- a/doc/fr/references/layers/authn/oidc.md +++ b/doc/fr/references/layers/authn/oidc.md @@ -16,18 +16,18 @@ Les options disponibles pour le layer sont décrites via un [schéma JSON](https En plus de ces options spécifiques le layer peut également être configuré via [les options communes aux layers `authn-*`](../../../../../internal/proxy/director/layer/authn/layer-options.json). -## Objet `user` et attributs +## Objet `vars.user` et attributs -L'objet `user` exposé au moteur de règles sera construit de la manière suivante: +L'objet `vars.user` exposé au moteur de règles sera construit de la manière suivante: -- `user.subject` sera initialisé avec la valeur du [claim](https://openid.net/specs/openid-connect-core-1_0.html#Claims) `sub` extrait de l'`idToken` récupéré lors de l'authentification ; -- `user.attrs` comportera les propriétés suivantes: +- `vars.user.subject` sera initialisé avec la valeur du [claim](https://openid.net/specs/openid-connect-core-1_0.html#Claims) `sub` extrait de l'`idToken` récupéré lors de l'authentification ; +- `vars.user.attrs` comportera les propriétés suivantes: - - L'ensemble des `claims` provenant de l'`idToken` seront transposés en `claim_` (ex: `idToken.iss` sera transposé en `user.attrs.claim_iss`) ; - - `user.attrs.access_token`: le jeton d'accès associé à l'authentification ; - - `user.attrs.refresh_token`: le jeton de rafraîchissement associé à l'authentification (si disponible, en fonction des `scopes` demandés par le client) ; - - `user.attrs.token_expiry`: Horodatage Unix (en secondes) associé à la date d'expiration du jeton d'accès ; - - `user.attrs.logout_url`: URL de déconnexion pour la suppression de la session Bouncer. + - L'ensemble des `claims` provenant de l'`idToken` seront transposés en `claim_` (ex: `idToken.iss` sera transposé en `vars.user.attrs.claim_iss`) ; + - `vars.user.attrs.access_token`: le jeton d'accès associé à l'authentification ; + - `vars.user.attrs.refresh_token`: le jeton de rafraîchissement associé à l'authentification (si disponible, en fonction des `scopes` demandés par le client) ; + - `vars.user.attrs.token_expiry`: Horodatage Unix (en secondes) associé à la date d'expiration du jeton d'accès ; + - `vars.user.attrs.logout_url`: URL de déconnexion pour la suppression de la session Bouncer. **Attention** Cette URL ne permet dans la plupart des cas que de supprimer la session côté Bouncer. La suppression de la session côté fournisseur d'identité est conditionné à la présence ou non de l'attribut [`end_session_endpoint`](https://openid.net/specs/openid-connect-session-1_0-17.html#OPMetadata) dans les données du point d'entrée de découverte de service (`.wellknown/openid-configuration`). diff --git a/doc/fr/references/layers/rewriter.md b/doc/fr/references/layers/rewriter.md index 9f4afdc..049f412 100644 --- a/doc/fr/references/layers/rewriter.md +++ b/doc/fr/references/layers/rewriter.md @@ -24,15 +24,15 @@ Bouncer utilise le projet [`expr`](https://expr-lang.org/) comme DSL. En plus de #### Communes -##### `add_header(name string, value string)` +##### `add_header(ctx, name string, value string)` Ajouter une valeur à un entête HTTP via son nom `name` et sa valeur `value`. -##### `set_header(name string, value string)` +##### `set_header(ctx, name string, value string)` Définir la valeur d'un entête HTTP via son nom `name` et sa valeur `value`. La valeur précédente est écrasée. -##### `del_headers(pattern string)` +##### `del_headers(ctx, pattern string)` Supprimer un ou plusieurs entêtes HTTP dont le nom correspond au patron `pattern`. @@ -40,11 +40,11 @@ Le patron est défini par une chaîne comprenant un ou plusieurs caractères `*` #### Requête -##### `set_host(host string)` +##### `set_host(ctx, host string)` Modifier la valeur de l'entête `Host` de la requête. -##### `set_url(url string)` +##### `set_url(ctx, url string)` Modifier l'URL du serveur cible. @@ -58,7 +58,28 @@ Les règles ont accès aux variables suivantes pendant leur exécution. **Ces do #### Requête -##### `request` +##### `vars.original_url` + +L'URL originale, avant réécriture du `Host` par Bouncer. + +```js +{ + scheme: "string", // Schéma HTTP de l'URL + opaque: "string", // Données opaque de l'URL + user: { // Identifiants d'URL (Basic Auth) + username: "", + password: "" + }, + host: "string", // Nom d'hôte (:) de l'URL + path: "string", // Chemin de l'URL (format assaini) + rawPath: "string", // Chemin de l'URL (format brut) + raw_query: "string", // Variables d'URL (format brut) + fragment : "string", // Fragment d'URL (format assaini) + raw_fragment : "string" // Fragment d'URL (format brut) +} +``` + +##### `vars.request` La requête en cours de traitement. @@ -67,61 +88,65 @@ La requête en cours de traitement. method: "string", // Méthode HTTP host: "string", // Nom d'hôte (`Host`) associé à la requête url: { // URL associée à la requête sous sa forme structurée - "scheme": "string", // Schéma HTTP de l'URL - "opaque": "string", // Données opaque de l'URL - "user": { // Identifiants d'URL (Basic Auth) - "username": "", - "password": "" + scheme: "string", // Schéma HTTP de l'URL + opaque: "string", // Données opaque de l'URL + user: { // Identifiants d'URL (Basic Auth) + username: "", + password: "" }, - "host": "string", // Nom d'hôte (:) de l'URL - "path": "string", // Chemin de l'URL (format assaini) - "rawPath": "string", // Chemin de l'URL (format brut) - "rawQuery": "string", // Variables d'URL (format brut) - "fragment" : "string", // Fragment d'URL (format assaini) - "rawFragment" : "string" // Fragment d'URL (format brut) + host: "string", // Nom d'hôte (:) de l'URL + path: "string", // Chemin de l'URL (format assaini) + rawPath: "string", // Chemin de l'URL (format brut) + raw_query: "string", // Variables d'URL (format brut) + fragment : "string", // Fragment d'URL (format assaini) + raw_fragment : "string" // Fragment d'URL (format brut) }, - rawUrl: "string", // URL associée à la requête (format assaini) + raw_url: "string", // URL associée à la requête (format assaini) proto: "string", // Numéro de version du protocole utilisé - protoMajor: "int", // Numéro de version majeure du protocole utilisé - protoMinor: "int", // Numéro de version mineur du protocole utilisé + proto_major: "int", // Numéro de version majeure du protocole utilisé + proto_minor: "int", // Numéro de version mineur du protocole utilisé header: { // Table associative des entêtes HTTP associés à la requête "string": ["string"] }, - contentLength: "int", // Taille du corps de la requête - transferEncoding: ["string"], // MIME-Type(s) d'encodage du corps de la requête + content_length: "int", // Taille du corps de la requête + transfer_encoding: ["string"], // MIME-Type(s) d'encodage du corps de la requête trailer: { // Table associative des entêtes HTTP associés à la requête, transmises après le corps de la requête "string": ["string"] }, - remoteAddr: "string", // Adresse du client HTTP à l'origine de la requête - requestUri: "string" // URL "brute" associée à la requêtes (avant opérations d'assainissement, utiliser "url" plutôt) + remote_addr: "string", // Adresse du client HTTP à l'origine de la requête + request_uri: "string" // URL "brute" associée à la requêtes (avant opérations d'assainissement, utiliser "url" plutôt) } ``` #### Réponse -##### `response` +##### `vars.response` La réponse en cours de traitement. ```js { - statusCode: "int", // Code de statut de la réponse + status_code: "int", // Code de statut de la réponse status: "string", // Message associé au code de statut proto: "string", // Numéro de version du protocole utilisé - protoMajor: "int", // Numéro de version majeure du protocole utilisé - protoMinor: "int", // Numéro de version mineur du protocole utilisé + proto_major: "int", // Numéro de version majeure du protocole utilisé + proto_minor: "int", // Numéro de version mineur du protocole utilisé header: { // Table associative des entêtes HTTP associés à la requête "string": ["string"] }, - contentLength: "int", // Taille du corps de la réponse - transferEncoding: ["string"], // MIME-Type(s) d'encodage du corps de la requête + content_length: "int", // Taille du corps de la réponse + transfer_encoding: ["string"], // MIME-Type(s) d'encodage du corps de la requête trailer: { // Table associative des entêtes HTTP associés à la requête, transmises après le corps de la requête "string": ["string"] }, } ``` -##### `request` +##### `vars.request` + +_Voir section précédente._ + +##### `vars.original_url` _Voir section précédente._ diff --git a/internal/bench/proxy_test.go b/internal/bench/proxy_test.go index 54d7b23..1f20b4f 100644 --- a/internal/bench/proxy_test.go +++ b/internal/bench/proxy_test.go @@ -3,7 +3,6 @@ package proxy_test import ( "context" "io" - "log" "net/http" "net/http/httptest" "net/http/httputil" @@ -24,6 +23,7 @@ import ( redisStore "forge.cadoles.com/cadoles/bouncer/internal/store/redis" "github.com/pkg/errors" "github.com/redis/go-redis/v9" + "gitlab.com/wpetit/goweb/logger" "gopkg.in/yaml.v3" "forge.cadoles.com/cadoles/bouncer/internal/setup" @@ -39,6 +39,19 @@ func BenchmarkProxies(b *testing.B) { name := strings.TrimSuffix(filepath.Base(f), filepath.Ext(f)) b.Run(name, func(b *testing.B) { + heap, err := os.Create(filepath.Join("testdata", "proxies", name+"_heap.prof")) + if err != nil { + b.Fatalf("%+v", errors.Wrapf(err, "could not create heap profile")) + } + + defer func() { + defer heap.Close() + + if err := pprof.WriteHeapProfile(heap); err != nil { + b.Fatalf("%+v", errors.WithStack(err)) + } + }() + conf, err := loadProxyBenchConfig(f) if err != nil { b.Fatalf("%+v", errors.Wrapf(err, "could notre load bench config")) @@ -78,7 +91,7 @@ func BenchmarkProxies(b *testing.B) { b.Logf("fetching url '%s'", rawProxyURL) - profile, err := os.Create(filepath.Join("testdata", "proxies", name+".prof")) + profile, err := os.Create(filepath.Join("testdata", "proxies", name+"_cpu.prof")) if err != nil { b.Fatalf("%+v", errors.Wrapf(err, "could not create cpu profile")) } @@ -86,7 +99,7 @@ func BenchmarkProxies(b *testing.B) { defer profile.Close() if err := pprof.StartCPUProfile(profile); err != nil { - log.Fatal(err) + b.Fatalf("%+v", errors.WithStack(err)) } defer pprof.StopCPUProfile() @@ -227,7 +240,12 @@ func createProxy(name string, conf *proxyBenchConfig, logf func(format string, a } - layers, err := setup.GetLayers(context.Background(), config.NewDefault()) + appConf := config.NewDefault() + appConf.Logger.Level = config.InterpolatedInt(logger.LevelError) + appConf.Layers.Authn.TemplateDir = "../../layers/authn/templates" + appConf.Layers.Queue.TemplateDir = "../../layers/queue/templates" + + layers, err := setup.GetLayers(context.Background(), appConf) if err != nil { return nil, nil, errors.WithStack(err) } diff --git a/internal/bench/testdata/proxies/basic-auth.yml b/internal/bench/testdata/proxies/basic-auth.yml index dad03c6..5f51ee4 100644 --- a/internal/bench/testdata/proxies/basic-auth.yml +++ b/internal/bench/testdata/proxies/basic-auth.yml @@ -12,7 +12,7 @@ proxy: attributes: email: foo@bar.com rules: - - set_header("Remote-User-Attr-Email", user.attrs.email) + - set_header(ctx, "Remote-User-Attr-Email", vars.user.attrs.email) fetch: url: user: diff --git a/internal/bench/testdata/proxies/queue.yml b/internal/bench/testdata/proxies/queue.yml new file mode 100644 index 0000000..0772395 --- /dev/null +++ b/internal/bench/testdata/proxies/queue.yml @@ -0,0 +1,10 @@ +proxy: + from: ["*"] + to: "" + layers: + queue: + type: queue + enabled: true + options: + capacity: 100 + keepAlive: 10s diff --git a/internal/bench/testdata/proxies/rewriter.yml b/internal/bench/testdata/proxies/rewriter.yml index bbaa8b2..fe84696 100644 --- a/internal/bench/testdata/proxies/rewriter.yml +++ b/internal/bench/testdata/proxies/rewriter.yml @@ -8,5 +8,5 @@ proxy: options: rules: request: - - set_host(request.url.host) - - set_header("X-Proxied-With", "bouncer") + - set_host(ctx, vars.request.url.host) + - set_header(ctx, "X-Proxied-With", "bouncer") diff --git a/internal/command/server/dummy/index.gohtml b/internal/command/server/dummy/index.gohtml index 53a2b5e..24aeb63 100644 --- a/internal/command/server/dummy/index.gohtml +++ b/internal/command/server/dummy/index.gohtml @@ -4,14 +4,14 @@

Incoming headers

- + {{ range $key, $val := .Request.Header }} - + @@ -27,7 +27,7 @@

Incoming cookies

Key Value
{{ $key }}
- + @@ -41,7 +41,7 @@ {{ range $cookie := .Request.Cookies }} - + diff --git a/internal/proxy/director/layer/authn/layer.go b/internal/proxy/director/layer/authn/layer.go index dd15a32..0532ef1 100644 --- a/internal/proxy/director/layer/authn/layer.go +++ b/internal/proxy/director/layer/authn/layer.go @@ -74,7 +74,7 @@ func (l *Layer) Middleware(layer *store.Layer) proxy.Middleware { return } - if err := l.applyRules(r, options, user); err != nil { + if err := l.applyRules(ctx, r, options, user); err != nil { if errors.Is(err, ErrForbidden) { l.renderForbiddenPage(w, r, layer, options, user) return diff --git a/internal/proxy/director/layer/authn/rules.go b/internal/proxy/director/layer/authn/rules.go index 80c4837..78e318c 100644 --- a/internal/proxy/director/layer/authn/rules.go +++ b/internal/proxy/director/layer/authn/rules.go @@ -1,6 +1,7 @@ package authn import ( + "context" "net/http" "forge.cadoles.com/cadoles/bouncer/internal/rule" @@ -9,30 +10,32 @@ import ( "github.com/pkg/errors" ) -type Env struct { +type Vars struct { User *User `expr:"user"` } -func (l *Layer) applyRules(r *http.Request, options *LayerOptions, user *User) error { +func (l *Layer) applyRules(ctx context.Context, r *http.Request, options *LayerOptions, user *User) error { rules := options.Rules if len(rules) == 0 { return nil } - engine, err := rule.NewEngine[*Env]( + engine, err := rule.NewEngine[*Vars]( rule.WithRules(options.Rules...), rule.WithExpr(getAuthnAPI()...), - ruleHTTP.WithRequestFuncs(r), + ruleHTTP.WithRequestFuncs(), ) if err != nil { return errors.WithStack(err) } - env := &Env{ + vars := &Vars{ User: user, } - if _, err := engine.Apply(env); err != nil { + ctx = ruleHTTP.WithRequest(ctx, r) + + if _, err := engine.Apply(ctx, vars); err != nil { return errors.WithStack(err) } diff --git a/internal/proxy/director/layer/rewriter/layer.go b/internal/proxy/director/layer/rewriter/layer.go index fa74431..8982339 100644 --- a/internal/proxy/director/layer/rewriter/layer.go +++ b/internal/proxy/director/layer/rewriter/layer.go @@ -6,6 +6,9 @@ import ( proxy "forge.cadoles.com/Cadoles/go-proxy" "forge.cadoles.com/Cadoles/go-proxy/wildcard" "forge.cadoles.com/cadoles/bouncer/internal/proxy/director" + "forge.cadoles.com/cadoles/bouncer/internal/proxy/director/layer/util" + "forge.cadoles.com/cadoles/bouncer/internal/rule" + ruleHTTP "forge.cadoles.com/cadoles/bouncer/internal/rule/http" "forge.cadoles.com/cadoles/bouncer/internal/store" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" @@ -13,7 +16,10 @@ import ( const LayerType store.LayerType = "rewriter" -type Layer struct{} +type Layer struct { + requestRuleEngine *util.RevisionedRuleEngine[*RequestVars, *LayerOptions] + responseRuleEngine *util.RevisionedRuleEngine[*ResponseVars, *LayerOptions] +} func (l *Layer) LayerType() store.LayerType { return LayerType @@ -39,7 +45,7 @@ func (l *Layer) Middleware(layer *store.Layer) proxy.Middleware { return } - if err := l.applyRequestRules(r, options); err != nil { + if err := l.applyRequestRules(ctx, r, layer.Revision, options); err != nil { logger.Error(ctx, "could not apply request rules", logger.E(errors.WithStack(err))) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -66,7 +72,9 @@ func (l *Layer) ResponseTransformer(layer *store.Layer) proxy.ResponseTransforme return nil } - if err := l.applyResponseRules(r, options); err != nil { + ctx := r.Request.Context() + + if err := l.applyResponseRules(ctx, r, layer.Revision, options); err != nil { return errors.WithStack(err) } @@ -75,7 +83,30 @@ func (l *Layer) ResponseTransformer(layer *store.Layer) proxy.ResponseTransforme } func New(funcs ...OptionFunc) *Layer { - return &Layer{} + return &Layer{ + requestRuleEngine: util.NewRevisionedRuleEngine(func(options *LayerOptions) (*rule.Engine[*RequestVars], error) { + engine, err := rule.NewEngine[*RequestVars]( + rule.WithRules(options.Rules.Request...), + ruleHTTP.WithRequestFuncs(), + ) + if err != nil { + return nil, errors.WithStack(err) + } + + return engine, nil + }), + responseRuleEngine: util.NewRevisionedRuleEngine(func(options *LayerOptions) (*rule.Engine[*ResponseVars], error) { + engine, err := rule.NewEngine[*ResponseVars]( + rule.WithRules(options.Rules.Response...), + ruleHTTP.WithResponseFuncs(), + ) + if err != nil { + return nil, errors.WithStack(err) + } + + return engine, nil + }), + } } var ( diff --git a/internal/proxy/director/layer/rewriter/rules.go b/internal/proxy/director/layer/rewriter/rules.go index 466a04a..e622c48 100644 --- a/internal/proxy/director/layer/rewriter/rules.go +++ b/internal/proxy/director/layer/rewriter/rules.go @@ -1,68 +1,93 @@ package rewriter import ( + "context" "net/http" + "forge.cadoles.com/cadoles/bouncer/internal/proxy/director" "forge.cadoles.com/cadoles/bouncer/internal/rule" ruleHTTP "forge.cadoles.com/cadoles/bouncer/internal/rule/http" "github.com/pkg/errors" ) -type RequestEnv struct { - Request RequestInfo `expr:"request"` +type RequestVars struct { + Request RequestVar `expr:"request"` + OriginalURL URLVar `expr:"original_url"` } -type URLEnv struct { - Scheme string `expr:"scheme"` - Opaque string `expr:"opaque"` - User UserInfoEnv `expr:"user"` - Host string `expr:"host"` - Path string `expr:"path"` - RawPath string `expr:"rawPath"` - RawQuery string `expr:"rawQuery"` - Fragment string `expr:"fragment"` - RawFragment string `expr:"rawFragment"` +type URLVar struct { + Scheme string `expr:"scheme"` + Opaque string `expr:"opaque"` + User UserVar `expr:"user"` + Host string `expr:"host"` + Path string `expr:"path"` + RawPath string `expr:"raw_path"` + RawQuery string `expr:"raw_query"` + Fragment string `expr:"fragment"` + RawFragment string `expr:"raw_fragment"` } -type UserInfoEnv struct { +type UserVar struct { Username string `expr:"username"` Password string `expr:"password"` } -type RequestInfo struct { +type RequestVar struct { Method string `expr:"method"` - URL URLEnv `expr:"url"` - RawURL string `expr:"rawUrl"` + URL URLVar `expr:"url"` + RawURL string `expr:"raw_url"` Proto string `expr:"proto"` - ProtoMajor int `expr:"protoMajor"` - ProtoMinor int `expr:"protoMinor"` + ProtoMajor int `expr:"proto_major"` + ProtoMinor int `expr:"proto_minor"` Header map[string][]string `expr:"header"` - ContentLength int64 `expr:"contentLength"` - TransferEncoding []string `expr:"transferEncoding"` + ContentLength int64 `expr:"content_length"` + TransferEncoding []string `expr:"transfer_encoding"` Host string `expr:"host"` Trailer map[string][]string `expr:"trailer"` - RemoteAddr string `expr:"remoteAddr"` - RequestURI string `expr:"requestUri"` + RemoteAddr string `expr:"remote_addr"` + RequestURI string `expr:"request_uri"` } -func (l *Layer) applyRequestRules(r *http.Request, options *LayerOptions) error { +func (l *Layer) applyRequestRules(ctx context.Context, r *http.Request, layerRevision int, options *LayerOptions) error { rules := options.Rules.Request if len(rules) == 0 { return nil } - engine, err := l.getRequestRuleEngine(r, options) + engine, err := l.getRequestRuleEngine(ctx, layerRevision, options) if err != nil { return errors.WithStack(err) } - env := &RequestEnv{ - Request: RequestInfo{ + originalURL, err := director.OriginalURL(ctx) + if err != nil { + return errors.WithStack(err) + } + + vars := &RequestVars{ + OriginalURL: URLVar{ + Scheme: originalURL.Scheme, + Opaque: originalURL.Opaque, + User: UserVar{ + Username: originalURL.User.Username(), + Password: func() string { + passwd, _ := originalURL.User.Password() + return passwd + }(), + }, + Host: originalURL.Host, + Path: originalURL.Path, + RawPath: originalURL.RawPath, + RawQuery: originalURL.RawQuery, + Fragment: originalURL.Fragment, + RawFragment: originalURL.RawFragment, + }, + Request: RequestVar{ Method: r.Method, - URL: URLEnv{ + URL: URLVar{ Scheme: r.URL.Scheme, Opaque: r.URL.Opaque, - User: UserInfoEnv{ + User: UserVar{ Username: r.URL.User.Username(), Password: func() string { passwd, _ := r.URL.User.Password() @@ -90,18 +115,17 @@ func (l *Layer) applyRequestRules(r *http.Request, options *LayerOptions) error }, } - if _, err := engine.Apply(env); err != nil { + ctx = ruleHTTP.WithRequest(ctx, r) + + if _, err := engine.Apply(ctx, vars); err != nil { return errors.WithStack(err) } return nil } -func (l *Layer) getRequestRuleEngine(r *http.Request, options *LayerOptions) (*rule.Engine[*RequestEnv], error) { - engine, err := rule.NewEngine[*RequestEnv]( - rule.WithRules(options.Rules.Request...), - ruleHTTP.WithRequestFuncs(r), - ) +func (l *Layer) getRequestRuleEngine(ctx context.Context, layerRevision int, options *LayerOptions) (*rule.Engine[*RequestVars], error) { + engine, err := l.requestRuleEngine.Get(ctx, layerRevision, options) if err != nil { return nil, errors.WithStack(err) } @@ -109,42 +133,65 @@ func (l *Layer) getRequestRuleEngine(r *http.Request, options *LayerOptions) (*r return engine, nil } -type ResponseEnv struct { - Request RequestInfo `expr:"request"` - Response ResponseInfo `expr:"response"` +type ResponseVars struct { + OriginalURL URLVar `expr:"original_url"` + Request RequestVar `expr:"request"` + Response ResponseVar `expr:"response"` } -type ResponseInfo struct { +type ResponseVar struct { Status string `expr:"status"` - StatusCode int `expr:"statusCode"` + StatusCode int `expr:"status_code"` Proto string `expr:"proto"` - ProtoMajor int `expr:"protoMajor"` - ProtoMinor int `expr:"protoMinor"` + ProtoMajor int `expr:"proto_major"` + ProtoMinor int `expr:"proto_minor"` Header map[string][]string `expr:"header"` - ContentLength int64 `expr:"contentLength"` - TransferEncoding []string `expr:"transferEncoding"` + ContentLength int64 `expr:"content_length"` + TransferEncoding []string `expr:"transfer_encoding"` Uncompressed bool `expr:"uncompressed"` Trailer map[string][]string `expr:"trailer"` } -func (l *Layer) applyResponseRules(r *http.Response, options *LayerOptions) error { +func (l *Layer) applyResponseRules(ctx context.Context, r *http.Response, layerRevision int, options *LayerOptions) error { rules := options.Rules.Response if len(rules) == 0 { return nil } - engine, err := l.getResponseRuleEngine(r, options) + engine, err := l.getResponseRuleEngine(ctx, layerRevision, options) if err != nil { return errors.WithStack(err) } - env := &ResponseEnv{ - Request: RequestInfo{ + originalURL, err := director.OriginalURL(ctx) + if err != nil { + return errors.WithStack(err) + } + + vars := &ResponseVars{ + OriginalURL: URLVar{ + Scheme: originalURL.Scheme, + Opaque: originalURL.Opaque, + User: UserVar{ + Username: originalURL.User.Username(), + Password: func() string { + passwd, _ := originalURL.User.Password() + return passwd + }(), + }, + Host: originalURL.Host, + Path: originalURL.Path, + RawPath: originalURL.RawPath, + RawQuery: originalURL.RawQuery, + Fragment: originalURL.Fragment, + RawFragment: originalURL.RawFragment, + }, + Request: RequestVar{ Method: r.Request.Method, - URL: URLEnv{ + URL: URLVar{ Scheme: r.Request.URL.Scheme, Opaque: r.Request.URL.Opaque, - User: UserInfoEnv{ + User: UserVar{ Username: r.Request.URL.User.Username(), Password: func() string { passwd, _ := r.Request.URL.User.Password() @@ -170,7 +217,7 @@ func (l *Layer) applyResponseRules(r *http.Response, options *LayerOptions) erro RemoteAddr: r.Request.RemoteAddr, RequestURI: r.Request.RequestURI, }, - Response: ResponseInfo{ + Response: ResponseVar{ Proto: r.Proto, ProtoMajor: r.ProtoMajor, ProtoMinor: r.ProtoMinor, @@ -183,18 +230,17 @@ func (l *Layer) applyResponseRules(r *http.Response, options *LayerOptions) erro }, } - if _, err := engine.Apply(env); err != nil { + ctx = ruleHTTP.WithResponse(ctx, r) + + if _, err := engine.Apply(ctx, vars); err != nil { return errors.WithStack(err) } return nil } -func (l *Layer) getResponseRuleEngine(r *http.Response, options *LayerOptions) (*rule.Engine[*ResponseEnv], error) { - engine, err := rule.NewEngine[*ResponseEnv]( - rule.WithRules(options.Rules.Response...), - ruleHTTP.WithResponseFuncs(r), - ) +func (l *Layer) getResponseRuleEngine(ctx context.Context, layerRevision int, options *LayerOptions) (*rule.Engine[*ResponseVars], error) { + engine, err := l.responseRuleEngine.Get(ctx, layerRevision, options) if err != nil { return nil, errors.WithStack(err) } diff --git a/internal/proxy/director/layer/util/revisioned_rule_engine.go b/internal/proxy/director/layer/util/revisioned_rule_engine.go new file mode 100644 index 0000000..0205573 --- /dev/null +++ b/internal/proxy/director/layer/util/revisioned_rule_engine.go @@ -0,0 +1,51 @@ +package util + +import ( + "context" + "sync" + + "forge.cadoles.com/cadoles/bouncer/internal/rule" + "github.com/pkg/errors" + "gitlab.com/wpetit/goweb/logger" +) + +type RuleEngineFactoryFunc[V any, O any] func(ops O) (*rule.Engine[V], error) + +type RevisionedRuleEngine[V any, O any] struct { + mutex sync.RWMutex + revision int + engine *rule.Engine[V] + factory RuleEngineFactoryFunc[V, O] +} + +func (e *RevisionedRuleEngine[V, O]) Get(ctx context.Context, revision int, opts O) (*rule.Engine[V], error) { + e.mutex.RLock() + if revision == e.revision { + logger.Debug(ctx, "using cached rule engine", logger.F("layerRevision", revision)) + + defer e.mutex.RUnlock() + return e.engine, nil + } + e.mutex.RUnlock() + + e.mutex.Lock() + defer e.mutex.Unlock() + + logger.Debug(ctx, "creating rule engine", logger.F("layerRevision", revision)) + + engine, err := e.factory(opts) + if err != nil { + return nil, errors.WithStack(err) + } + + e.engine = engine + e.revision = revision + + return engine, nil +} + +func NewRevisionedRuleEngine[V any, O any](factory RuleEngineFactoryFunc[V, O]) *RevisionedRuleEngine[V, O] { + return &RevisionedRuleEngine[V, O]{ + factory: factory, + } +} diff --git a/internal/rule/engine.go b/internal/rule/engine.go index d9c3e54..446a063 100644 --- a/internal/rule/engine.go +++ b/internal/rule/engine.go @@ -1,16 +1,28 @@ package rule import ( + "context" + "github.com/expr-lang/expr" "github.com/expr-lang/expr/vm" "github.com/pkg/errors" ) -type Engine[E any] struct { +type Engine[V any] struct { rules []*vm.Program } -func (e *Engine[E]) Apply(env E) ([]any, error) { +func (e *Engine[V]) Apply(ctx context.Context, vars V) ([]any, error) { + type Env[V any] struct { + Context context.Context `expr:"ctx"` + Vars V `expr:"vars"` + } + + env := Env[V]{ + Context: ctx, + Vars: vars, + } + results := make([]any, 0, len(e.rules)) for i, r := range e.rules { result, err := expr.Run(r, env) @@ -42,3 +54,26 @@ func NewEngine[E any](funcs ...OptionFunc) (*Engine[E], error) { return engine, nil } + +func Context[T any](ctx context.Context, key any) (T, bool) { + raw := ctx.Value(key) + if raw == nil { + return *new(T), false + } + + value, err := Assert[T](raw) + if err != nil { + return *new(T), false + } + + return value, true +} + +func Assert[T any](raw any) (T, error) { + value, ok := raw.(T) + if !ok { + return *new(T), errors.Errorf("unexpected value '%T'", value) + } + + return value, nil +} diff --git a/internal/rule/http/context.go b/internal/rule/http/context.go new file mode 100644 index 0000000..4eda60b --- /dev/null +++ b/internal/rule/http/context.go @@ -0,0 +1,31 @@ +package http + +import ( + "context" + "net/http" + + "forge.cadoles.com/cadoles/bouncer/internal/rule" +) + +type contextKey string + +const ( + contextKeyRequest contextKey = "request" + contextKeyResponse contextKey = "response" +) + +func WithRequest(ctx context.Context, r *http.Request) context.Context { + return context.WithValue(ctx, contextKeyRequest, r) +} + +func WithResponse(ctx context.Context, r *http.Response) context.Context { + return context.WithValue(ctx, contextKeyResponse, r) +} + +func ctxRequest(ctx context.Context) (*http.Request, bool) { + return rule.Context[*http.Request](ctx, contextKeyRequest) +} + +func ctxResponse(ctx context.Context) (*http.Response, bool) { + return rule.Context[*http.Response](ctx, contextKeyResponse) +} diff --git a/internal/rule/http/option.go b/internal/rule/http/option.go index d7a9c26..01de1a9 100644 --- a/internal/rule/http/option.go +++ b/internal/rule/http/option.go @@ -1,20 +1,18 @@ package http import ( - "net/http" - "forge.cadoles.com/cadoles/bouncer/internal/rule" "github.com/expr-lang/expr" ) -func WithRequestFuncs(r *http.Request) rule.OptionFunc { +func WithRequestFuncs() rule.OptionFunc { return func(opts *rule.Options) { funcs := []expr.Option{ - setRequestURL(r), - setRequestHeaderFunc(r), - addRequestHeaderFunc(r), - delRequestHeadersFunc(r), - setRequestHostFunc(r), + setRequestURLFunc(), + setRequestHeaderFunc(), + addRequestHeaderFunc(), + delRequestHeadersFunc(), + setRequestHostFunc(), } if len(opts.Expr) == 0 { @@ -25,12 +23,12 @@ func WithRequestFuncs(r *http.Request) rule.OptionFunc { } } -func WithResponseFuncs(r *http.Response) rule.OptionFunc { +func WithResponseFuncs() rule.OptionFunc { return func(opts *rule.Options) { funcs := []expr.Option{ - setResponseHeaderFunc(r), - addResponseHeaderFunc(r), - delResponseHeadersFunc(r), + setResponseHeaderFunc(), + addResponseHeaderFunc(), + delResponseHeadersFunc(), } if len(opts.Expr) == 0 { diff --git a/internal/rule/http/request.go b/internal/rule/http/request.go index 7452b82..81bcb90 100644 --- a/internal/rule/http/request.go +++ b/internal/rule/http/request.go @@ -1,109 +1,155 @@ package http import ( + "context" "fmt" - "net/http" "net/url" "strconv" "strings" "time" "forge.cadoles.com/Cadoles/go-proxy/wildcard" + "forge.cadoles.com/cadoles/bouncer/internal/rule" "github.com/expr-lang/expr" "github.com/pkg/errors" ) -func setRequestHostFunc(r *http.Request) expr.Option { +func setRequestHostFunc() expr.Option { return expr.Function( "set_host", func(params ...any) (any, error) { - host := params[0].(string) + ctx, err := rule.Assert[context.Context](params[0]) + if err != nil { + return nil, errors.WithStack(err) + } + + host, err := rule.Assert[string](params[1]) + if err != nil { + return nil, errors.WithStack(err) + } + + r, ok := ctxRequest(ctx) + if !ok { + return nil, errors.New("could not find http request in context") + } + r.Host = host return true, nil }, - new(func(string) bool), + new(func(context.Context, string) bool), ) } -func setRequestURL(r *http.Request) expr.Option { +func setRequestURLFunc() expr.Option { return expr.Function( "set_url", func(params ...any) (any, error) { - rawURL := params[0].(string) + ctx, err := rule.Assert[context.Context](params[0]) + if err != nil { + return nil, errors.WithStack(err) + } + + rawURL, err := rule.Assert[string](params[1]) + if err != nil { + return nil, errors.WithStack(err) + } url, err := url.Parse(rawURL) if err != nil { return false, errors.WithStack(err) } + r, ok := ctxRequest(ctx) + if !ok { + return nil, errors.New("could not find http request in context") + } + r.URL = url return true, nil }, - new(func(string) bool), + new(func(context.Context, string) bool), ) } -func addRequestHeaderFunc(r *http.Request) expr.Option { +func addRequestHeaderFunc() expr.Option { return expr.Function( "add_header", func(params ...any) (any, error) { - name := params[0].(string) - rawValue := params[1] + ctx, err := rule.Assert[context.Context](params[0]) + if err != nil { + return nil, errors.WithStack(err) + } - var value string - switch v := rawValue.(type) { - case []string: - value = strings.Join(v, ",") - case time.Time: - value = strconv.FormatInt(v.UTC().Unix(), 10) - case time.Duration: - value = strconv.FormatInt(int64(v.Seconds()), 10) - default: - value = fmt.Sprintf("%v", rawValue) + name, err := rule.Assert[string](params[1]) + if err != nil { + return nil, errors.WithStack(err) + } + + value := formatValue(params[2]) + + r, ok := ctxRequest(ctx) + if !ok { + return nil, errors.New("could not find http request in context") } r.Header.Add(name, value) return true, nil }, - new(func(string, string) bool), + new(func(context.Context, string, string) bool), ) } -func setRequestHeaderFunc(r *http.Request) expr.Option { +func setRequestHeaderFunc() expr.Option { return expr.Function( "set_header", func(params ...any) (any, error) { - name := params[0].(string) - rawValue := params[1] + ctx, err := rule.Assert[context.Context](params[0]) + if err != nil { + return nil, errors.WithStack(err) + } - var value string - switch v := rawValue.(type) { - case []string: - value = strings.Join(v, ",") - case time.Time: - value = strconv.FormatInt(v.UTC().Unix(), 10) - case time.Duration: - value = strconv.FormatInt(int64(v.Seconds()), 10) - default: - value = fmt.Sprintf("%v", rawValue) + name, err := rule.Assert[string](params[1]) + if err != nil { + return nil, errors.WithStack(err) + } + + value := formatValue(params[2]) + + r, ok := ctxRequest(ctx) + if !ok { + return nil, errors.New("could not find http request in context") } r.Header.Set(name, value) return true, nil }, - new(func(string, string) bool), + new(func(context.Context, string, string) bool), ) } -func delRequestHeadersFunc(r *http.Request) expr.Option { +func delRequestHeadersFunc() expr.Option { return expr.Function( "del_headers", func(params ...any) (any, error) { - pattern := params[0].(string) + ctx, err := rule.Assert[context.Context](params[0]) + if err != nil { + return nil, errors.WithStack(err) + } + + pattern, err := rule.Assert[string](params[1]) + if err != nil { + return nil, errors.WithStack(err) + } + + r, ok := ctxRequest(ctx) + if !ok { + return nil, errors.New("could not find http request in context") + } + deleted := false for key := range r.Header { @@ -117,6 +163,21 @@ func delRequestHeadersFunc(r *http.Request) expr.Option { return deleted, nil }, - new(func(string) bool), + new(func(context.Context, string) bool), ) } + +func formatValue(v any) string { + var value string + switch v := v.(type) { + case []string: + value = strings.Join(v, ",") + case time.Time: + value = strconv.FormatInt(v.UTC().Unix(), 10) + case time.Duration: + value = strconv.FormatInt(int64(v.Seconds()), 10) + default: + value = fmt.Sprintf("%v", v) + } + return value +} diff --git a/internal/rule/http/request_test.go b/internal/rule/http/request_test.go new file mode 100644 index 0000000..1022f27 --- /dev/null +++ b/internal/rule/http/request_test.go @@ -0,0 +1,195 @@ +package http + +import ( + "context" + "net/http" + "testing" + + "forge.cadoles.com/cadoles/bouncer/internal/rule" + "github.com/pkg/errors" +) + +func TestSetRequestHost(t *testing.T) { + type Vars struct { + NewHost string `expr:"newHost"` + } + + engine := createRuleEngine[Vars](t, + rule.WithExpr(setRequestHostFunc()), + rule.WithRules( + "set_host(ctx, vars.newHost)", + ), + ) + + req, err := http.NewRequest("GET", "http://example.net", nil) + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + ctx := context.Background() + + ctx = WithRequest(ctx, req) + + vars := Vars{ + NewHost: "foobar", + } + + if _, err := engine.Apply(ctx, vars); err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + if e, g := vars.NewHost, req.Host; e != g { + t.Errorf("req.Host: expected '%v', got '%v'", e, g) + } +} + +func TestSetRequestURL(t *testing.T) { + type Vars struct { + NewURL string `expr:"newURL"` + } + + engine := createRuleEngine[Vars](t, + rule.WithExpr(setRequestURLFunc()), + rule.WithRules( + "set_url(ctx, vars.newURL)", + ), + ) + + req, err := http.NewRequest("GET", "http://example.net", nil) + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + ctx := context.Background() + + ctx = WithRequest(ctx, req) + + vars := Vars{ + NewURL: "http://localhost", + } + + if _, err := engine.Apply(ctx, vars); err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + if e, g := vars.NewURL, req.URL.String(); e != g { + t.Errorf("req.URL.String(): expected '%v', got '%v'", e, g) + } +} + +func TestAddRequestHeader(t *testing.T) { + type Vars struct { + NewHeaderKey string `expr:"newHeaderKey"` + NewHeaderValue string `expr:"newHeaderValue"` + } + + engine := createRuleEngine[Vars](t, + rule.WithExpr(addRequestHeaderFunc()), + rule.WithRules( + "add_header(ctx, vars.newHeaderKey, vars.newHeaderValue)", + ), + ) + + req, err := http.NewRequest("GET", "http://example.net", nil) + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + ctx := context.Background() + + ctx = WithRequest(ctx, req) + + vars := Vars{ + NewHeaderKey: "X-My-Header", + NewHeaderValue: "foobar", + } + + if _, err := engine.Apply(ctx, vars); err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + if e, g := vars.NewHeaderValue, req.Header.Get(vars.NewHeaderKey); e != g { + t.Errorf("req.Header.Get(vars.NewHeaderKey): expected '%v', got '%v'", e, g) + } +} + +func TestSetRequestHeader(t *testing.T) { + type Vars struct { + HeaderKey string `expr:"headerKey"` + HeaderValue string `expr:"headerValue"` + } + + engine := createRuleEngine[Vars](t, + rule.WithExpr(setRequestHeaderFunc()), + rule.WithRules( + "set_header(ctx, vars.headerKey, vars.headerValue)", + ), + ) + + req, err := http.NewRequest("GET", "http://example.net", nil) + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + vars := Vars{ + HeaderKey: "X-My-Header", + HeaderValue: "foobar", + } + + req.Header.Set(vars.HeaderKey, "test") + + ctx := context.Background() + ctx = WithRequest(ctx, req) + + if _, err := engine.Apply(ctx, vars); err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + if e, g := vars.HeaderValue, req.Header.Get(vars.HeaderKey); e != g { + t.Errorf("req.Header.Get(vars.HeaderKey): expected '%v', got '%v'", e, g) + } +} + +func TestDelRequestHeaders(t *testing.T) { + type Vars struct { + HeaderPattern string `expr:"headerPattern"` + } + + engine := createRuleEngine[Vars](t, + rule.WithExpr(delRequestHeadersFunc()), + rule.WithRules( + "del_headers(ctx, vars.headerPattern)", + ), + ) + + req, err := http.NewRequest("GET", "http://example.net", nil) + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + vars := Vars{ + HeaderPattern: "X-My-*", + } + + req.Header.Set("X-My-Header", "test") + + ctx := context.Background() + ctx = WithRequest(ctx, req) + + if _, err := engine.Apply(ctx, vars); err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + if val := req.Header.Get("X-My-Header"); val != "" { + t.Errorf("req.Header.Get(\"X-My-Header\") should be empty, got '%v'", val) + } +} + +func createRuleEngine[V any](t *testing.T, funcs ...rule.OptionFunc) *rule.Engine[V] { + engine, err := rule.NewEngine[V](funcs...) + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + return engine +} diff --git a/internal/rule/http/response.go b/internal/rule/http/response.go index be7c9ee..36f60d2 100644 --- a/internal/rule/http/response.go +++ b/internal/rule/http/response.go @@ -1,22 +1,33 @@ package http import ( + "context" "fmt" - "net/http" "strconv" "strings" "time" "forge.cadoles.com/Cadoles/go-proxy/wildcard" + "forge.cadoles.com/cadoles/bouncer/internal/rule" "github.com/expr-lang/expr" + "github.com/pkg/errors" ) -func addResponseHeaderFunc(r *http.Response) expr.Option { +func addResponseHeaderFunc() expr.Option { return expr.Function( "add_header", func(params ...any) (any, error) { - name := params[0].(string) - rawValue := params[1] + ctx, err := rule.Assert[context.Context](params[0]) + if err != nil { + return nil, errors.WithStack(err) + } + + name, err := rule.Assert[string](params[1]) + if err != nil { + return nil, errors.WithStack(err) + } + + rawValue := params[2] var value string switch v := rawValue.(type) { @@ -30,20 +41,34 @@ func addResponseHeaderFunc(r *http.Response) expr.Option { value = fmt.Sprintf("%v", rawValue) } + r, ok := ctxResponse(ctx) + if !ok { + return nil, errors.New("could not find http response in context") + } + r.Header.Add(name, value) return true, nil }, - new(func(string, string) bool), + new(func(context.Context, string, string) bool), ) } -func setResponseHeaderFunc(r *http.Response) expr.Option { +func setResponseHeaderFunc() expr.Option { return expr.Function( "set_header", func(params ...any) (any, error) { - name := params[0].(string) - rawValue := params[1] + ctx, err := rule.Assert[context.Context](params[0]) + if err != nil { + return nil, errors.WithStack(err) + } + + name, err := rule.Assert[string](params[1]) + if err != nil { + return nil, errors.WithStack(err) + } + + rawValue := params[2] var value string switch v := rawValue.(type) { @@ -57,19 +82,38 @@ func setResponseHeaderFunc(r *http.Response) expr.Option { value = fmt.Sprintf("%v", rawValue) } + r, ok := ctxResponse(ctx) + if !ok { + return nil, errors.New("could not find http response in context") + } + r.Header.Set(name, value) return true, nil }, - new(func(string, string) bool), + new(func(context.Context, string, string) bool), ) } -func delResponseHeadersFunc(r *http.Response) expr.Option { +func delResponseHeadersFunc() expr.Option { return expr.Function( "del_headers", func(params ...any) (any, error) { - pattern := params[0].(string) + ctx, err := rule.Assert[context.Context](params[0]) + if err != nil { + return nil, errors.WithStack(err) + } + + pattern, err := rule.Assert[string](params[1]) + if err != nil { + return nil, errors.WithStack(err) + } + + r, ok := ctxResponse(ctx) + if !ok { + return nil, errors.New("could not find http response in context") + } + deleted := false for key := range r.Header { @@ -83,6 +127,6 @@ func delResponseHeadersFunc(r *http.Response) expr.Option { return deleted, nil }, - new(func(string) bool), + new(func(context.Context, string) bool), ) } diff --git a/internal/rule/http/response_test.go b/internal/rule/http/response_test.go new file mode 100644 index 0000000..9528d95 --- /dev/null +++ b/internal/rule/http/response_test.go @@ -0,0 +1,139 @@ +package http + +import ( + "context" + "io" + "net/http" + "testing" + + "forge.cadoles.com/cadoles/bouncer/internal/rule" + "github.com/pkg/errors" +) + +func TestAddResponseHeader(t *testing.T) { + type Vars struct { + NewHeaderKey string `expr:"newHeaderKey"` + NewHeaderValue string `expr:"newHeaderValue"` + } + + engine := createRuleEngine[Vars](t, + rule.WithExpr(addResponseHeaderFunc()), + rule.WithRules( + "add_header(ctx, vars.newHeaderKey, vars.newHeaderValue)", + ), + ) + + req, err := http.NewRequest("GET", "http://example.net", nil) + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + resp := createResponse(req, http.StatusOK, nil) + + ctx := context.Background() + + ctx = WithResponse(ctx, resp) + + vars := Vars{ + NewHeaderKey: "X-My-Header", + NewHeaderValue: "foobar", + } + + if _, err := engine.Apply(ctx, vars); err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + if e, g := vars.NewHeaderValue, resp.Header.Get(vars.NewHeaderKey); e != g { + t.Errorf("resp.Header.Get(vars.NewHeaderKey): expected '%v', got '%v'", e, g) + } +} + +func TestResponseSetHeader(t *testing.T) { + type Vars struct { + HeaderKey string `expr:"headerKey"` + HeaderValue string `expr:"headerValue"` + } + + engine := createRuleEngine[Vars](t, + rule.WithExpr(setResponseHeaderFunc()), + rule.WithRules( + "set_header(ctx, vars.headerKey, vars.headerValue)", + ), + ) + + req, err := http.NewRequest("GET", "http://example.net", nil) + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + resp := createResponse(req, http.StatusOK, nil) + + vars := Vars{ + HeaderKey: "X-My-Header", + HeaderValue: "foobar", + } + + resp.Header.Set(vars.HeaderKey, "test") + + ctx := context.Background() + ctx = WithResponse(ctx, resp) + + if _, err := engine.Apply(ctx, vars); err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + if e, g := vars.HeaderValue, resp.Header.Get(vars.HeaderKey); e != g { + t.Errorf("resp.Header.Get(vars.HeaderKey): expected '%v', got '%v'", e, g) + } +} + +func TestResponseDelHeaders(t *testing.T) { + type Vars struct { + HeaderPattern string `expr:"headerPattern"` + } + + engine := createRuleEngine[Vars](t, + rule.WithExpr(delResponseHeadersFunc()), + rule.WithRules( + "del_headers(ctx, vars.headerPattern)", + ), + ) + + req, err := http.NewRequest("GET", "http://example.net", nil) + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + resp := createResponse(req, http.StatusOK, nil) + + vars := Vars{ + HeaderPattern: "X-My-*", + } + + resp.Header.Set("X-My-Header", "test") + + ctx := context.Background() + ctx = WithResponse(ctx, resp) + + if _, err := engine.Apply(ctx, vars); err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + if val := resp.Header.Get("X-My-Header"); val != "" { + t.Errorf("resp.Header.Get(\"X-My-Header\") should be empty, got '%v'", val) + } +} + +func createResponse(req *http.Request, statusCode int, body io.Reader) *http.Response { + return &http.Response{ + Status: http.StatusText(statusCode), + StatusCode: statusCode, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Body: io.NopCloser(body), + ContentLength: -1, + Request: req, + Header: make(http.Header, 0), + } +}
Name Domain Path
{{ $cookie.Name }}