package http

import (
	"context"
	"fmt"
	"io"
	"net/http"
	"testing"
	"time"

	"forge.cadoles.com/cadoles/bouncer/internal/rule"
	"github.com/pkg/errors"
)

// TestAddResponseHeader applies an add_header rule and asserts that the
// response stored in the context carries the requested header afterwards.
func TestAddResponseHeader(t *testing.T) {
	type Vars struct {
		NewHeaderKey   string `expr:"newHeaderKey"`
		NewHeaderValue string `expr:"newHeaderValue"`
	}

	engine := createRuleEngine[Vars](t,
		rule.WithExpr(addResponseHeaderFunc()),
		rule.WithRules(
			"add_header(ctx, vars.newHeaderKey, vars.newHeaderValue)",
		),
	)

	req, err := http.NewRequest("GET", "http://example.net", nil)
	if err != nil {
		t.Fatalf("%+v", errors.WithStack(err))
	}

	resp := createResponse(req, http.StatusOK, nil)

	ctx := context.Background()
	ctx = WithResponse(ctx, resp)

	vars := Vars{
		NewHeaderKey:   "X-My-Header",
		NewHeaderValue: "foobar",
	}

	if _, err := engine.Apply(ctx, vars); err != nil {
		t.Fatalf("%+v", errors.WithStack(err))
	}

	if e, g := vars.NewHeaderValue, resp.Header.Get(vars.NewHeaderKey); e != g {
		t.Errorf("resp.Header.Get(vars.NewHeaderKey): expected '%v', got '%v'", e, g)
	}
}

// TestResponseSetHeader pre-populates a header with a placeholder value,
// applies a set_header rule and asserts that the header now holds the value
// supplied through the rule variables.
func TestResponseSetHeader(t *testing.T) {
	type Vars struct {
		HeaderKey   string `expr:"headerKey"`
		HeaderValue string `expr:"headerValue"`
	}

	engine := createRuleEngine[Vars](t,
		rule.WithExpr(setResponseHeaderFunc()),
		rule.WithRules(
			"set_header(ctx, vars.headerKey, vars.headerValue)",
		),
	)

	req, err := http.NewRequest("GET", "http://example.net", nil)
	if err != nil {
		t.Fatalf("%+v", errors.WithStack(err))
	}

	resp := createResponse(req, http.StatusOK, nil)

	vars := Vars{
		HeaderKey:   "X-My-Header",
		HeaderValue: "foobar",
	}

	// Seed a different value so the assertion below only passes if the rule
	// actually changed the header.
	resp.Header.Set(vars.HeaderKey, "test")

	ctx := context.Background()
	ctx = WithResponse(ctx, resp)

	if _, err := engine.Apply(ctx, vars); err != nil {
		t.Fatalf("%+v", errors.WithStack(err))
	}

	if e, g := vars.HeaderValue, resp.Header.Get(vars.HeaderKey); e != g {
		t.Errorf("resp.Header.Get(vars.HeaderKey): expected '%v', got '%v'", e, g)
	}
}

// TestResponseDelHeaders sets an "X-My-Header" header, applies a del_headers
// rule with the pattern "X-My-*" and asserts that the header is gone.
func TestResponseDelHeaders(t *testing.T) {
	type Vars struct {
		HeaderPattern string `expr:"headerPattern"`
	}

	engine := createRuleEngine[Vars](t,
		rule.WithExpr(delResponseHeadersFunc()),
		rule.WithRules(
			"del_headers(ctx, vars.headerPattern)",
		),
	)

	req, err := http.NewRequest("GET", "http://example.net", nil)
	if err != nil {
		t.Fatalf("%+v", errors.WithStack(err))
	}

	resp := createResponse(req, http.StatusOK, nil)

	vars := Vars{
		HeaderPattern: "X-My-*",
	}

	resp.Header.Set("X-My-Header", "test")

	ctx := context.Background()
	ctx = WithResponse(ctx, resp)

	if _, err := engine.Apply(ctx, vars); err != nil {
		t.Fatalf("%+v", errors.WithStack(err))
	}

	if val := resp.Header.Get("X-My-Header"); val != "" {
		t.Errorf("resp.Header.Get(\"X-My-Header\") should be empty, got '%v'", val)
	}
}

// TestAddResponseCookie is a table-driven test of the add_cookie rule. Each
// case provides the cookie description handed to the rule, an optional Check
// callback inspecting the resulting response, and a ShouldFail flag for cases
// where engine.Apply is expected to return an error.
func TestAddResponseCookie(t *testing.T) {
	type TestCase struct {
		Cookie     map[string]any
		Check      func(t *testing.T, tc TestCase, res *http.Response)
		ShouldFail bool
	}

	// findCookie returns the first cookie of res whose name equals name, or
	// nil when there is none. name is typed any because it comes straight
	// from the test case's cookie map.
	findCookie := func(res *http.Response, name any) *http.Cookie {
		for _, c := range res.Cookies() {
			if c.Name == name {
				return c
			}
		}

		return nil
	}

	testCases := []TestCase{
		{
			Cookie: map[string]any{
				"name":      "foo",
				"value":     "test",
				"domain":    "example.net",
				"path":      "/custom",
				"same_site": http.SameSiteStrictMode,
				"http_only": true,
				"secure":    false,
				"expires":   time.Now().UTC().Truncate(time.Second),
			},
			Check: func(t *testing.T, tc TestCase, res *http.Response) {
				cookie := findCookie(res, tc.Cookie["name"])
				if cookie == nil {
					t.Errorf("could not find cookie '%s'", tc.Cookie["name"])
					return
				}

				if e, g := tc.Cookie["name"], cookie.Name; e != g {
					t.Errorf("cookie.Name: expected '%v', got '%v'", e, g)
				}

				if e, g := tc.Cookie["value"], cookie.Value; e != g {
					t.Errorf("cookie.Value: expected '%v', got '%v'", e, g)
				}

				if e, g := tc.Cookie["domain"], cookie.Domain; e != g {
					t.Errorf("cookie.Domain: expected '%v', got '%v'", e, g)
				}

				if e, g := tc.Cookie["path"], cookie.Path; e != g {
					t.Errorf("cookie.Path: expected '%v', got '%v'", e, g)
				}

				if e, g := tc.Cookie["secure"], cookie.Secure; e != g {
					t.Errorf("cookie.Secure: expected '%v', got '%v'", e, g)
				}

				if e, g := tc.Cookie["http_only"], cookie.HttpOnly; e != g {
					t.Errorf("cookie.HttpOnly: expected '%v', got '%v'", e, g)
				}

				if e, g := tc.Cookie["same_site"], cookie.SameSite; e != g {
					t.Errorf("cookie.SameSite: expected '%v', got '%v'", e, g)
				}

				// Compare instants with Equal rather than ==, which would
				// also compare the internal representation of time.Time.
				expires, ok := tc.Cookie["expires"].(time.Time)
				if !ok {
					t.Fatalf("tc.Cookie[\"expires\"]: expected a time.Time, got %T", tc.Cookie["expires"])
				}

				if e, g := expires, cookie.Expires; !e.Equal(g) {
					t.Errorf("cookie.Expires: expected '%v', got '%v'", e, g)
				}
			},
		},
		{
			// Same rule, but the expiration is given as a string formatted
			// with http.TimeFormat instead of a time.Time.
			Cookie: map[string]any{
				"name":    "foo",
				"expires": time.Now().UTC().Format(http.TimeFormat),
			},
			Check: func(t *testing.T, tc TestCase, res *http.Response) {
				cookie := findCookie(res, tc.Cookie["name"])
				if cookie == nil {
					t.Errorf("could not find cookie '%s'", tc.Cookie["name"])
					return
				}

				if e, g := tc.Cookie["expires"], cookie.Expires.Format(http.TimeFormat); e != g {
					t.Errorf("cookie.Expires: expected '%v', got '%v'", e, g)
				}
			},
		},
	}

	for idx, tc := range testCases {
		t.Run(fmt.Sprintf("Case_%d", idx), func(t *testing.T) {
			type Vars struct {
				NewCookie map[string]any `expr:"new_cookie"`
			}

			engine := createRuleEngine[Vars](t,
				rule.WithExpr(addResponseCookieFunc()),
				rule.WithRules(
					`add_cookie(ctx, vars.new_cookie)`,
				),
			)

			req, err := http.NewRequest("GET", "http://example.net", nil)
			if err != nil {
				t.Fatalf("%+v", errors.WithStack(err))
			}

			resp := createResponse(req, http.StatusOK, nil)

			vars := Vars{
				NewCookie: tc.Cookie,
			}

			ctx := context.Background()
			ctx = WithRequest(ctx, req)
			ctx = WithResponse(ctx, resp)

			_, err = engine.Apply(ctx, vars)

			// A case flagged ShouldFail passes only when Apply returns an
			// error; there is then no response state worth checking.
			if tc.ShouldFail {
				if err == nil {
					t.Error("engine.Apply() should have failed")
				}

				return
			}

			if err != nil {
				t.Fatalf("%+v", errors.WithStack(err))
			}

			if tc.Check != nil {
				tc.Check(t, tc, resp)
			}
		})
	}
}

// TestGetResponseCookie adds a Set-Cookie header to the response, evaluates a
// rule reading that cookie through get_cookie and asserts that the first rule
// result equals the cookie's value.
func TestGetResponseCookie(t *testing.T) {
	type Vars struct {
		CookieName string `expr:"cookieName"`
	}

	engine := createRuleEngine[Vars](t,
		rule.WithExpr(getResponseCookieFunc()),
		rule.WithRules(
			"let cookie = get_cookie(ctx, vars.cookieName); cookie.value",
		),
	)

	req, err := http.NewRequest("GET", "http://example.net", nil)
	if err != nil {
		t.Fatalf("%+v", errors.WithStack(err))
	}

	resp := createResponse(req, http.StatusOK, nil)

	vars := Vars{
		CookieName: "foo",
	}

	cookie := &http.Cookie{
		Name:  vars.CookieName,
		Value: "bar",
	}

	resp.Header.Add("Set-Cookie", cookie.String())

	ctx := context.Background()
	ctx = WithResponse(ctx, resp)

	results, err := engine.Apply(ctx, vars)
	if err != nil {
		t.Fatalf("%+v", errors.WithStack(err))
	}

	// Fail cleanly instead of panicking on results[0] when nothing came back.
	if len(results) == 0 {
		t.Fatal("engine.Apply() returned no result")
	}

	if e, g := cookie.Value, results[0]; e != g {
		t.Errorf("result[0]: expected '%v', got '%v'", e, g)
	}
}

// createResponse builds a minimal HTTP/1.1 *http.Response for use in tests.
//
// The response references req, has the given status code with a matching
// status line (e.g. "200 OK"), an empty header set and an unknown content
// length (-1). body is wrapped with io.NopCloser; a nil body is replaced by
// http.NoBody so that reading or closing the response body is always safe.
func createResponse(req *http.Request, statusCode int, body io.Reader) *http.Response {
	var rc io.ReadCloser = http.NoBody
	if body != nil {
		rc = io.NopCloser(body)
	}

	return &http.Response{
		Status:        fmt.Sprintf("%d %s", statusCode, http.StatusText(statusCode)),
		StatusCode:    statusCode,
		Proto:         "HTTP/1.1",
		ProtoMajor:    1,
		ProtoMinor:    1,
		Body:          rc,
		ContentLength: -1,
		Request:       req,
		Header:        make(http.Header),
	}
}