diff --git a/doc/fr/references/layers/rewriter.md b/doc/fr/references/layers/rewriter.md index 049f412..32a5669 100644 --- a/doc/fr/references/layers/rewriter.md +++ b/doc/fr/references/layers/rewriter.md @@ -38,6 +38,33 @@ Supprimer un ou plusieurs entêtes HTTP dont le nom correspond au patron `patter Le patron est défini par une chaîne comprenant un ou plusieurs caractères `*`, signifiant un ou plusieurs caractères arbitraires. +##### `get_cookie(ctx, name string) Cookie` + +Récupère un cookie depuis la requête/réponse (en fonction du contexte d'utilisation). +Retourne `nil` si le cookie n'existe pas. + +**Cookie** + +```js +// Plus d'informations sur https://pkg.go.dev/net/http#Cookie +{ + name: "string", // Nom du cookie + value: "string", // Valeur associée au cookie + path: "string", // Chemin associé au cookie (présent uniquement dans un contexte de réponse) + domain: "string", // Domaine associé au cookie (présent uniquement dans un contexte de réponse) + expires: "string", // Date d'expiration du cookie (présent uniquement dans un contexte de réponse) + max_age: "string", // Age maximum du cookie (présent uniquement dans un contexte de réponse) + secure: "boolean", // Le cookie doit-il être présent uniquement en HTTPS ? (présent uniquement dans un contexte de réponse) + http_only: "boolean", // Le cookie est il accessible en Javascript ? (présent uniquement dans un contexte de réponse) + same_site: "int" // Voir https://pkg.go.dev/net/http#SameSite (présent uniquement dans un contexte de réponse) +} +``` + +##### `set_cookie(ctx, cookie Cookie)` + +Définit un cookie sur la requête/réponse (en fonction du contexte d'utilisation). +Voir la méthode `get_cookie()` pour voir les attributs potentiels. + #### Requête ##### `set_host(ctx, host string)` @@ -48,6 +75,12 @@ Modifier la valeur de l'entête `Host` de la requête. Modifier l'URL du serveur cible. +##### `redirect(ctx, statusCode int, url string)` + +Interrompt la requête et retourne une redirection HTTP au client. + +Le code HTTP utilisé doit être supérieur ou égale à `300` et inférieur à `400` (non inclus). + #### Réponse _Pas de fonctions spécifiques._ @@ -72,7 +105,7 @@ L'URL originale, avant réécriture du `Host` par Bouncer. }, host: "string", // Nom d'hôte (:) de l'URL path: "string", // Chemin de l'URL (format assaini) - rawPath: "string", // Chemin de l'URL (format brut) + raw_path: "string", // Chemin de l'URL (format brut) raw_query: "string", // Variables d'URL (format brut) fragment : "string", // Fragment d'URL (format assaini) raw_fragment : "string" // Fragment d'URL (format brut) @@ -96,7 +129,7 @@ La requête en cours de traitement. }, host: "string", // Nom d'hôte (:) de l'URL path: "string", // Chemin de l'URL (format assaini) - rawPath: "string", // Chemin de l'URL (format brut) + raw_path: "string", // Chemin de l'URL (format brut) raw_query: "string", // Variables d'URL (format brut) fragment : "string", // Fragment d'URL (format assaini) raw_fragment : "string" // Fragment d'URL (format brut) diff --git a/go.mod b/go.mod index b7af9e8..08dff22 100644 --- a/go.mod +++ b/go.mod @@ -1,8 +1,8 @@ module forge.cadoles.com/cadoles/bouncer -go 1.22 +go 1.23 -toolchain go1.22.0 +toolchain go1.23.0 require ( forge.cadoles.com/Cadoles/go-proxy v0.0.0-20240626132607-e1db6466a926 diff --git a/internal/proxy/director/layer/rewriter/api.go b/internal/proxy/director/layer/rewriter/api.go new file mode 100644 index 0000000..ba8e313 --- /dev/null +++ b/internal/proxy/director/layer/rewriter/api.go @@ -0,0 +1,79 @@ +package rewriter + +import ( + "context" + "fmt" + + "forge.cadoles.com/cadoles/bouncer/internal/rule" + "github.com/expr-lang/expr" + "github.com/pkg/errors" +) + +type errRedirect struct { + statusCode int + url string +} + +func (e *errRedirect) StatusCode() int { + return e.statusCode +} + +func (e *errRedirect) URL() string { + return e.url +} + +func (e *errRedirect) Error() string { + return fmt.Sprintf("redirect %d %s", e.statusCode, e.url) +} + +func newErrRedirect(statusCode int, url string) *errRedirect { + return &errRedirect{ + url: url, + statusCode: statusCode, + } +} + +var _ error = &errRedirect{} + +func redirectFunc() expr.Option { + return expr.Function( + "redirect", + func(params ...any) (any, error) { + _, err := rule.Assert[context.Context](params[0]) + if err != nil { + return nil, errors.WithStack(err) + } + + statusCode, err := rule.Assert[int](params[1]) + if err != nil { + return nil, errors.WithStack(err) + } + + if statusCode < 300 || statusCode >= 400 { + return nil, errors.Errorf("unexpected redirect status code '%d'", statusCode) + } + + url, err := rule.Assert[string](params[2]) + if err != nil { + return nil, errors.WithStack(err) + } + + return nil, newErrRedirect(statusCode, url) + }, + new(func(context.Context, int, string) bool), + ) +} + +func WithRewriterFuncs() rule.OptionFunc { + return func(opts *rule.Options) { + funcs := []expr.Option{ + redirectFunc(), + } + + if len(opts.Expr) == 0 { + opts.Expr = make([]expr.Option, 0) + } + + opts.Expr = append(opts.Expr, funcs...) + } +} diff --git a/internal/proxy/director/layer/rewriter/layer.go b/internal/proxy/director/layer/rewriter/layer.go index 8982339..7f44b8a 100644 --- a/internal/proxy/director/layer/rewriter/layer.go +++ b/internal/proxy/director/layer/rewriter/layer.go @@ -46,6 +46,12 @@ func (l *Layer) Middleware(layer *store.Layer) proxy.Middleware { } if err := l.applyRequestRules(ctx, r, layer.Revision, options); err != nil { + var redirect *errRedirect + if errors.As(err, &redirect) { + http.Redirect(w, r, redirect.URL(), redirect.StatusCode()) + return + } + logger.Error(ctx, "could not apply request rules", logger.E(errors.WithStack(err))) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -88,6 +94,7 @@ func New(funcs ...OptionFunc) *Layer { engine, err := rule.NewEngine[*RequestVars]( rule.WithRules(options.Rules.Request...), ruleHTTP.WithRequestFuncs(), + WithRewriterFuncs(), ) if err != nil { return nil, errors.WithStack(err) diff --git a/internal/proxy/director/layer/rewriter/rules.go b/internal/proxy/director/layer/rewriter/rules.go index e622c48..e4d6148 100644 --- a/internal/proxy/director/layer/rewriter/rules.go +++ b/internal/proxy/director/layer/rewriter/rules.go @@ -3,6 +3,7 @@ package rewriter import ( "context" "net/http" + "net/url" "forge.cadoles.com/cadoles/bouncer/internal/proxy/director" "forge.cadoles.com/cadoles/bouncer/internal/rule" @@ -27,6 +28,26 @@ type URLVar struct { RawFragment string `expr:"raw_fragment"` } +func fromURL(url *url.URL) URLVar { + return URLVar{ + Scheme: url.Scheme, + Opaque: url.Opaque, + User: UserVar{ + Username: url.User.Username(), + Password: func() string { + passwd, _ := url.User.Password() + return passwd + }(), + }, + Host: url.Host, + Path: url.Path, + RawPath: url.RawPath, + RawQuery: url.RawQuery, + Fragment: url.Fragment, + RawFragment: url.RawFragment, + } +} + type UserVar struct { Username string `expr:"username"` Password string `expr:"password"` @@ -48,6 +69,24 @@ type RequestVar struct { RequestURI string `expr:"request_uri"` } +func fromRequest(r *http.Request) RequestVar { + return RequestVar{ + Method: r.Method, + URL: fromURL(r.URL), + RawURL: r.URL.String(), + Proto: r.Proto, + ProtoMajor: r.ProtoMajor, + ProtoMinor: r.ProtoMinor, + Header: r.Header, + ContentLength: r.ContentLength, + TransferEncoding: r.TransferEncoding, + Host: r.Host, + Trailer: r.Trailer, + RemoteAddr: r.RemoteAddr, + RequestURI: r.RequestURI, + } +} + func (l *Layer) applyRequestRules(ctx context.Context, r *http.Request, layerRevision int, options *LayerOptions) error { rules := options.Rules.Request if len(rules) == 0 { @@ -65,54 +104,8 @@ func (l *Layer) applyRequestRules(ctx context.Context, r *http.Request, layerRev } vars := &RequestVars{ - OriginalURL: URLVar{ - Scheme: originalURL.Scheme, - Opaque: originalURL.Opaque, - User: UserVar{ - Username: originalURL.User.Username(), - Password: func() string { - passwd, _ := originalURL.User.Password() - return passwd - }(), - }, - Host: originalURL.Host, - Path: originalURL.Path, - RawPath: originalURL.RawPath, - RawQuery: originalURL.RawQuery, - Fragment: originalURL.Fragment, - RawFragment: originalURL.RawFragment, - }, - Request: RequestVar{ - Method: r.Method, - URL: URLVar{ - Scheme: r.URL.Scheme, - Opaque: r.URL.Opaque, - User: UserVar{ - Username: r.URL.User.Username(), - Password: func() string { - passwd, _ := r.URL.User.Password() - return passwd - }(), - }, - Host: r.URL.Host, - Path: r.URL.Path, - RawPath: r.URL.RawPath, - RawQuery: r.URL.RawQuery, - Fragment: r.URL.Fragment, - RawFragment: r.URL.RawFragment, - }, - RawURL: r.URL.String(), - Proto: r.Proto, - ProtoMajor: r.ProtoMajor, - ProtoMinor: r.ProtoMinor, - Header: r.Header, - ContentLength: r.ContentLength, - TransferEncoding: r.TransferEncoding, - Host: r.Host, - Trailer: r.Trailer, - RemoteAddr: r.RemoteAddr, - RequestURI: r.RequestURI, - }, + OriginalURL: fromURL(originalURL), + Request: fromRequest(r), } ctx = ruleHTTP.WithRequest(ctx, r) @@ -169,54 +162,8 @@ func (l *Layer) applyResponseRules(ctx context.Context, r *http.Response, layerR } vars := &ResponseVars{ - OriginalURL: URLVar{ - Scheme: originalURL.Scheme, - Opaque: originalURL.Opaque, - User: UserVar{ - Username: originalURL.User.Username(), - Password: func() string { - passwd, _ := originalURL.User.Password() - return passwd - }(), - }, - Host: originalURL.Host, - Path: originalURL.Path, - RawPath: originalURL.RawPath, - RawQuery: originalURL.RawQuery, - Fragment: originalURL.Fragment, - RawFragment: originalURL.RawFragment, - }, - Request: RequestVar{ - Method: r.Request.Method, - URL: URLVar{ - Scheme: r.Request.URL.Scheme, - Opaque: r.Request.URL.Opaque, - User: UserVar{ - Username: r.Request.URL.User.Username(), - Password: func() string { - passwd, _ := r.Request.URL.User.Password() - return passwd - }(), - }, - Host: r.Request.URL.Host, - Path: r.Request.URL.Path, - RawPath: r.Request.URL.RawPath, - RawQuery: r.Request.URL.RawQuery, - Fragment: r.Request.URL.Fragment, - RawFragment: r.Request.URL.RawFragment, - }, - RawURL: r.Request.URL.String(), - Proto: r.Request.Proto, - ProtoMajor: r.Request.ProtoMajor, - ProtoMinor: r.Request.ProtoMinor, - Header: r.Request.Header, - ContentLength: r.Request.ContentLength, - TransferEncoding: r.Request.TransferEncoding, - Host: r.Request.Host, - Trailer: r.Request.Trailer, - RemoteAddr: r.Request.RemoteAddr, - RequestURI: r.Request.RequestURI, - }, + OriginalURL: fromURL(originalURL), + Request: fromRequest(r.Request), Response: ResponseVar{ Proto: r.Proto, ProtoMajor: r.ProtoMajor, @@ -231,6 +178,7 @@ func (l *Layer) applyResponseRules(ctx context.Context, r *http.Response, layerR } ctx = ruleHTTP.WithResponse(ctx, r) + ctx = ruleHTTP.WithRequest(ctx, r.Request) if _, err := engine.Apply(ctx, vars); err != nil { return errors.WithStack(err) diff --git a/internal/rule/http/context.go b/internal/rule/http/context.go index 4eda60b..9de9991 100644 --- a/internal/rule/http/context.go +++ b/internal/rule/http/context.go @@ -22,10 +22,10 @@ func WithResponse(ctx context.Context, r *http.Response) context.Context { return context.WithValue(ctx, contextKeyResponse, r) } -func ctxRequest(ctx context.Context) (*http.Request, bool) { +func CtxRequest(ctx context.Context) (*http.Request, bool) { return rule.Context[*http.Request](ctx, contextKeyRequest) } -func ctxResponse(ctx context.Context) (*http.Response, bool) { +func CtxResponse(ctx context.Context) (*http.Response, bool) { return rule.Context[*http.Response](ctx, contextKeyResponse) } diff --git a/internal/rule/http/option.go b/internal/rule/http/option.go index 01de1a9..15e813e 100644 --- a/internal/rule/http/option.go +++ b/internal/rule/http/option.go @@ -13,6 +13,8 @@ func WithRequestFuncs() rule.OptionFunc { addRequestHeaderFunc(), delRequestHeadersFunc(), setRequestHostFunc(), + getRequestCookieFunc(), + addRequestCookieFunc(), } if len(opts.Expr) == 0 { @@ -29,6 +31,8 @@ func WithResponseFuncs() rule.OptionFunc { setResponseHeaderFunc(), addResponseHeaderFunc(), delResponseHeadersFunc(), + addResponseCookieFunc(), + getResponseCookieFunc(), } if len(opts.Expr) == 0 { diff --git a/internal/rule/http/request.go b/internal/rule/http/request.go index 81bcb90..80c59ee 100644 --- a/internal/rule/http/request.go +++ b/internal/rule/http/request.go @@ -3,6 +3,7 @@ package http import ( "context" "fmt" + "net/http" "net/url" "strconv" "strings" @@ -28,7 +29,7 @@ func setRequestHostFunc() expr.Option { return nil, errors.WithStack(err) } - r, ok := ctxRequest(ctx) + r, ok := CtxRequest(ctx) if !ok { return nil, errors.New("could not find http request in context") } @@ -60,7 +61,7 @@ func setRequestURLFunc() expr.Option { return false, errors.WithStack(err) } - r, ok := ctxRequest(ctx) + r, ok := CtxRequest(ctx) if !ok { return nil, errors.New("could not find http request in context") } @@ -89,7 +90,7 @@ func addRequestHeaderFunc() expr.Option { value := formatValue(params[2]) - r, ok := ctxRequest(ctx) + r, ok := CtxRequest(ctx) if !ok { return nil, errors.New("could not find http request in context") } @@ -118,7 +119,7 @@ func setRequestHeaderFunc() expr.Option { value := formatValue(params[2]) - r, ok := ctxRequest(ctx) + r, ok := CtxRequest(ctx) if !ok { return nil, errors.New("could not find http request in context") } @@ -145,7 +146,7 @@ func delRequestHeadersFunc() expr.Option { return nil, errors.WithStack(err) } - r, ok := ctxRequest(ctx) + r, ok := CtxRequest(ctx) if !ok { return nil, errors.New("could not find http request in context") } @@ -167,6 +168,145 @@ func delRequestHeadersFunc() expr.Option { ) } +type CookieVar struct { + Name string `expr:"name"` + Value string `expr:"value"` + Path string `expr:"path"` + Domain string `expr:"domain"` + Expires time.Time `expr:"expires"` + MaxAge int `expr:"max_age"` + Secure bool `expr:"secure"` + HttpOnly bool `expr:"http_only"` + SameSite http.SameSite `expr:"same_site"` +} + +func getRequestCookieFunc() expr.Option { + return expr.Function( + "get_cookie", + func(params ...any) (any, error) { + ctx, err := rule.Assert[context.Context](params[0]) + if err != nil { + return nil, errors.WithStack(err) + } + + name, err := rule.Assert[string](params[1]) + if err != nil { + return nil, errors.WithStack(err) + } + + r, ok := CtxRequest(ctx) + if !ok { + return nil, errors.New("could not find http request in context") + } + + cookie, err := r.Cookie(name) + if err != nil && !errors.Is(err, http.ErrNoCookie) { + return nil, errors.WithStack(err) + } + + if cookie == nil { + return nil, nil + } + + return CookieVar{ + Name: cookie.Name, + Value: cookie.Value, + Path: cookie.Path, + Domain: cookie.Domain, + Expires: cookie.Expires, + MaxAge: cookie.MaxAge, + Secure: cookie.Secure, + HttpOnly: cookie.HttpOnly, + SameSite: cookie.SameSite, + }, nil + }, + new(func(context.Context, string) CookieVar), + ) +} + +func addRequestCookieFunc() expr.Option { + return expr.Function( + "add_cookie", + func(params ...any) (any, error) { + ctx, err := rule.Assert[context.Context](params[0]) + if err != nil { + return nil, errors.WithStack(err) + } + + values, err := rule.Assert[map[string]any](params[1]) + if err != nil { + return nil, errors.WithStack(err) + } + + cookie, err := cookieFrom(values) + if err != nil { + return nil, errors.WithStack(err) + } + + r, ok := CtxRequest(ctx) + if !ok { + return nil, errors.New("could not find http request in context") + } + + r.AddCookie(cookie) + + return true, nil + }, + new(func(context.Context, map[string]any) bool), + ) +} + +func cookieFrom(values map[string]any) (*http.Cookie, error) { + cookie := &http.Cookie{} + + if name, ok := values["name"].(string); ok { + cookie.Name = name + } + + if value, ok := values["value"].(string); ok { + cookie.Value = value + } + + if domain, ok := values["domain"].(string); ok { + cookie.Domain = domain + } + + if path, ok := values["path"].(string); ok { + cookie.Path = path + } + + if httpOnly, ok := values["http_only"].(bool); ok { + cookie.HttpOnly = httpOnly + } + + if maxAge, ok := values["max_age"].(int); ok { + cookie.MaxAge = maxAge + } + + if secure, ok := values["secure"].(bool); ok { + cookie.Secure = secure + } + + if sameSite, ok := values["same_site"].(http.SameSite); ok { + cookie.SameSite = sameSite + } else if sameSite, ok := values["same_site"].(int); ok { + cookie.SameSite = http.SameSite(sameSite) + } + + if expires, ok := values["expires"].(time.Time); ok { + cookie.Expires = expires + } else if rawExpires, ok := values["expires"].(string); ok { + expires, err := time.Parse(http.TimeFormat, rawExpires) + if err != nil { + return nil, errors.WithStack(err) + } + + cookie.Expires = expires + } + + return cookie, nil +} + func formatValue(v any) string { var value string switch v := v.(type) { diff --git a/internal/rule/http/request_test.go b/internal/rule/http/request_test.go index 1022f27..a4848bf 100644 --- a/internal/rule/http/request_test.go +++ b/internal/rule/http/request_test.go @@ -2,6 +2,7 @@ package http import ( "context" + "fmt" "net/http" "testing" @@ -185,6 +186,134 @@ func TestDelRequestHeaders(t *testing.T) { } } +func TestAddRequestCookie(t *testing.T) { + type TestCase struct { + Cookie map[string]any + Check func(t *testing.T, tc TestCase, req *http.Request) + ShouldFail bool + } + + testCases := []TestCase{ + { + Cookie: map[string]any{ + "name": "test", + }, + Check: func(t *testing.T, tc TestCase, req *http.Request) { + cookie, err := req.Cookie(tc.Cookie["name"].(string)) + if err != nil { + t.Errorf("%+v", errors.WithStack(err)) + return + } + + if e, g := tc.Cookie["name"], cookie.Name; e != g { + t.Errorf("cookie.Name: expected '%v', got '%v'", e, g) + } + }, + }, + { + Cookie: map[string]any{ + "name": "foo", + "value": "test", + }, + Check: func(t *testing.T, tc TestCase, req *http.Request) { + cookie, err := req.Cookie(tc.Cookie["name"].(string)) + if err != nil { + t.Errorf("%+v", errors.WithStack(err)) + return + } + + if e, g := tc.Cookie["name"], cookie.Name; e != g { + t.Errorf("cookie.Name: expected '%v', got '%v'", e, g) + } + + if e, g := tc.Cookie["value"], cookie.Value; e != g { + t.Errorf("cookie.Value: expected '%v', got '%v'", e, g) + } + }, + }, + } + + for idx, tc := range testCases { + t.Run(fmt.Sprintf("Case_%d", idx), func(t *testing.T) { + type Vars struct { + NewCookie map[string]any `expr:"new_cookie"` + } + + engine := createRuleEngine[Vars](t, + rule.WithExpr(addRequestCookieFunc()), + rule.WithRules( + `add_cookie(ctx, vars.new_cookie)`, + ), + ) + + req, err := http.NewRequest("GET", "http://example.net", nil) + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + vars := Vars{ + NewCookie: tc.Cookie, + } + + ctx := context.Background() + ctx = WithRequest(ctx, req) + + if _, err := engine.Apply(ctx, vars); err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + if tc.ShouldFail { + t.Error("engine.Apply() should have failed") + } + + if tc.Check != nil { + tc.Check(t, tc, req) + } + }) + } +} + +func TestGetRequestCookie(t *testing.T) { + type Vars struct { + CookieName string `expr:"cookieName"` + } + + engine := createRuleEngine[Vars](t, + rule.WithExpr(getRequestCookieFunc()), + rule.WithRules( + "let cookie = get_cookie(ctx, vars.cookieName); cookie.value", + ), + ) + + req, err := http.NewRequest("GET", "http://example.net", nil) + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + vars := Vars{ + CookieName: "foo", + } + + cookie := &http.Cookie{ + Name: vars.CookieName, + Value: "bar", + } + + req.AddCookie(cookie) + + ctx := context.Background() + ctx = WithRequest(ctx, req) + + results, err := engine.Apply(ctx, vars) + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + if e, g := cookie.Value, results[0]; e != g { + t.Errorf("result[0]: expected '%v', got '%v'", e, g) + } +} + func createRuleEngine[V any](t *testing.T, funcs ...rule.OptionFunc) *rule.Engine[V] { engine, err := rule.NewEngine[V](funcs...) if err != nil { diff --git a/internal/rule/http/response.go b/internal/rule/http/response.go index 36f60d2..c065399 100644 --- a/internal/rule/http/response.go +++ b/internal/rule/http/response.go @@ -3,6 +3,7 @@ package http import ( "context" "fmt" + "net/http" "strconv" "strings" "time" @@ -41,7 +42,7 @@ func addResponseHeaderFunc() expr.Option { value = fmt.Sprintf("%v", rawValue) } - r, ok := ctxResponse(ctx) + r, ok := CtxResponse(ctx) if !ok { return nil, errors.New("could not find http response in context") } @@ -82,7 +83,7 @@ func setResponseHeaderFunc() expr.Option { value = fmt.Sprintf("%v", rawValue) } - r, ok := ctxResponse(ctx) + r, ok := CtxResponse(ctx) if !ok { return nil, errors.New("could not find http response in context") } @@ -109,7 +110,7 @@ func delResponseHeadersFunc() expr.Option { return nil, errors.WithStack(err) } - r, ok := ctxResponse(ctx) + r, ok := CtxResponse(ctx) if !ok { return nil, errors.New("could not find http response in context") } @@ -130,3 +131,84 @@ func delResponseHeadersFunc() expr.Option { new(func(context.Context, string) bool), ) } + +func addResponseCookieFunc() expr.Option { + return expr.Function( + "add_cookie", + func(params ...any) (any, error) { + ctx, err := rule.Assert[context.Context](params[0]) + if err != nil { + return nil, errors.WithStack(err) + } + + values, err := rule.Assert[map[string]any](params[1]) + if err != nil { + return nil, errors.WithStack(err) + } + + cookie, err := cookieFrom(values) + if err != nil { + return nil, errors.WithStack(err) + } + + r, ok := CtxResponse(ctx) + if !ok { + return nil, errors.New("could not find http request in context") + } + + r.Header.Add("Set-Cookie", cookie.String()) + + return true, nil + }, + new(func(context.Context, map[string]any) bool), + ) +} + +func getResponseCookieFunc() expr.Option { + return expr.Function( + "get_cookie", + func(params ...any) (any, error) { + ctx, err := rule.Assert[context.Context](params[0]) + if err != nil { + return nil, errors.WithStack(err) + } + + name, err := rule.Assert[string](params[1]) + if err != nil { + return nil, errors.WithStack(err) + } + + res, ok := CtxResponse(ctx) + if !ok { + return nil, errors.New("could not find http response in context") + } + + var cookie *http.Cookie + for _, c := range res.Cookies() { + if c.Name != name { + continue + } + + cookie = c + break + } + + if cookie == nil { + return nil, nil + } + + return CookieVar{ + Name: cookie.Name, + Value: cookie.Value, + Path: cookie.Path, + Domain: cookie.Domain, + Expires: cookie.Expires, + MaxAge: cookie.MaxAge, + Secure: cookie.Secure, + HttpOnly: cookie.HttpOnly, + SameSite: cookie.SameSite, + }, nil + }, + new(func(context.Context, string) CookieVar), + ) +} diff --git a/internal/rule/http/response_test.go b/internal/rule/http/response_test.go index 9528d95..f600ae6 100644 --- a/internal/rule/http/response_test.go +++ b/internal/rule/http/response_test.go @@ -2,9 +2,11 @@ package http import ( "context" + "fmt" "io" "net/http" "testing" + "time" "forge.cadoles.com/cadoles/bouncer/internal/rule" "github.com/pkg/errors" @@ -124,6 +126,182 @@ func TestResponseDelHeaders(t *testing.T) { } } +func TestAddResponseCookie(t *testing.T) { + type TestCase struct { + Cookie map[string]any + Check func(t *testing.T, tc TestCase, res *http.Response) + ShouldFail bool + } + + testCases := []TestCase{ + { + Cookie: map[string]any{ + "name": "foo", + "value": "test", + "domain": "example.net", + "path": "/custom", + "same_site": http.SameSiteStrictMode, + "http_only": true, + "secure": false, + "expires": time.Now().UTC().Truncate(time.Second), + }, + Check: func(t *testing.T, tc TestCase, res *http.Response) { + var cookie *http.Cookie + for _, c := range res.Cookies() { + if c.Name == tc.Cookie["name"] { + cookie = c + break + } + } + if cookie == nil { + t.Errorf("could not find cookie '%s'", tc.Cookie["name"]) + return + } + + if e, g := tc.Cookie["name"], cookie.Name; e != g { + t.Errorf("cookie.Name: expected '%v', got '%v'", e, g) + } + + if e, g := tc.Cookie["value"], cookie.Value; e != g { + t.Errorf("cookie.Value: expected '%v', got '%v'", e, g) + } + + if e, g := tc.Cookie["domain"], cookie.Domain; e != g { + t.Errorf("cookie.Domain: expected '%v', got '%v'", e, g) + } + + if e, g := tc.Cookie["path"], cookie.Path; e != g { + t.Errorf("cookie.Path: expected '%v', got '%v'", e, g) + } + + if e, g := tc.Cookie["secure"], cookie.Secure; e != g { + t.Errorf("cookie.Secure: expected '%v', got '%v'", e, g) + } + + if e, g := tc.Cookie["http_only"], cookie.HttpOnly; e != g { + t.Errorf("cookie.HttpOnly: expected '%v', got '%v'", e, g) + } + + if e, g := tc.Cookie["same_site"], cookie.SameSite; e != g { + t.Errorf("cookie.SameSite: expected '%v', got '%v'", e, g) + } + + if e, g := tc.Cookie["expires"], cookie.Expires; e != g { + t.Errorf("cookie.Expires: expected '%v', got '%v'", e, g) + } + }, + }, + { + Cookie: map[string]any{ + "name": "foo", + "expires": time.Now().UTC().Format(http.TimeFormat), + }, + Check: func(t *testing.T, tc TestCase, res *http.Response) { + var cookie *http.Cookie + for _, c := range res.Cookies() { + if c.Name == tc.Cookie["name"] { + cookie = c + break + } + } + if cookie == nil { + t.Errorf("could not find cookie '%s'", tc.Cookie["name"]) + return + } + + if e, g := tc.Cookie["expires"], cookie.Expires.Format(http.TimeFormat); e != g { + t.Errorf("cookie.Expires: expected '%v', got '%v'", e, g) + } + }, + }, + } + + for idx, tc := range testCases { + t.Run(fmt.Sprintf("Case_%d", idx), func(t *testing.T) { + type Vars struct { + NewCookie map[string]any `expr:"new_cookie"` + } + + engine := createRuleEngine[Vars](t, + rule.WithExpr(addResponseCookieFunc()), + rule.WithRules( + `add_cookie(ctx, vars.new_cookie)`, + ), + ) + + req, err := http.NewRequest("GET", "http://example.net", nil) + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + resp := createResponse(req, http.StatusOK, nil) + + vars := Vars{ + NewCookie: tc.Cookie, + } + + ctx := context.Background() + ctx = WithRequest(ctx, req) + ctx = WithResponse(ctx, resp) + + if _, err := engine.Apply(ctx, vars); err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + if tc.ShouldFail { + t.Error("engine.Apply() should have failed") + } + + if tc.Check != nil { + tc.Check(t, tc, resp) + } + }) + } +} + +func TestGetResponseCookie(t *testing.T) { + type Vars struct { + CookieName string `expr:"cookieName"` + } + + engine := createRuleEngine[Vars](t, + rule.WithExpr(getResponseCookieFunc()), + rule.WithRules( + "let cookie = get_cookie(ctx, vars.cookieName); cookie.value", + ), + ) + + req, err := http.NewRequest("GET", "http://example.net", nil) + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + resp := createResponse(req, http.StatusOK, nil) + + vars := Vars{ + CookieName: "foo", + } + + cookie := &http.Cookie{ + Name: vars.CookieName, + Value: "bar", + } + + resp.Header.Add("Set-Cookie", cookie.String()) + + ctx := context.Background() + ctx = WithResponse(ctx, resp) + + results, err := engine.Apply(ctx, vars) + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + if e, g := cookie.Value, results[0]; e != g { + t.Errorf("result[0]: expected '%v', got '%v'", e, g) + } +} + func createResponse(req *http.Request, statusCode int, body io.Reader) *http.Response { return &http.Response{ Status: http.StatusText(statusCode),