feat: reusable rule engine to prevent memory reallocation
All checks were successful
Cadoles/bouncer/pipeline/pr-develop This commit looks good
All checks were successful
Cadoles/bouncer/pipeline/pr-develop This commit looks good
This commit is contained in:
@ -1,16 +1,28 @@
|
||||
package rule
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/expr-lang/expr"
|
||||
"github.com/expr-lang/expr/vm"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Engine[E any] struct {
|
||||
type Engine[V any] struct {
|
||||
rules []*vm.Program
|
||||
}
|
||||
|
||||
func (e *Engine[E]) Apply(env E) ([]any, error) {
|
||||
func (e *Engine[V]) Apply(ctx context.Context, vars V) ([]any, error) {
|
||||
type Env[V any] struct {
|
||||
Context context.Context `expr:"ctx"`
|
||||
Vars V `expr:"vars"`
|
||||
}
|
||||
|
||||
env := Env[V]{
|
||||
Context: ctx,
|
||||
Vars: vars,
|
||||
}
|
||||
|
||||
results := make([]any, 0, len(e.rules))
|
||||
for i, r := range e.rules {
|
||||
result, err := expr.Run(r, env)
|
||||
@ -42,3 +54,26 @@ func NewEngine[E any](funcs ...OptionFunc) (*Engine[E], error) {
|
||||
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
func Context[T any](ctx context.Context, key any) (T, bool) {
|
||||
raw := ctx.Value(key)
|
||||
if raw == nil {
|
||||
return *new(T), false
|
||||
}
|
||||
|
||||
value, err := Assert[T](raw)
|
||||
if err != nil {
|
||||
return *new(T), false
|
||||
}
|
||||
|
||||
return value, true
|
||||
}
|
||||
|
||||
func Assert[T any](raw any) (T, error) {
|
||||
value, ok := raw.(T)
|
||||
if !ok {
|
||||
return *new(T), errors.Errorf("unexpected value '%T'", value)
|
||||
}
|
||||
|
||||
return value, nil
|
||||
}
|
||||
|
31
internal/rule/http/context.go
Normal file
31
internal/rule/http/context.go
Normal file
@ -0,0 +1,31 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
|
||||
"forge.cadoles.com/cadoles/bouncer/internal/rule"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const (
|
||||
contextKeyRequest contextKey = "request"
|
||||
contextKeyResponse contextKey = "response"
|
||||
)
|
||||
|
||||
func WithRequest(ctx context.Context, r *http.Request) context.Context {
|
||||
return context.WithValue(ctx, contextKeyRequest, r)
|
||||
}
|
||||
|
||||
func WithResponse(ctx context.Context, r *http.Response) context.Context {
|
||||
return context.WithValue(ctx, contextKeyResponse, r)
|
||||
}
|
||||
|
||||
func ctxRequest(ctx context.Context) (*http.Request, bool) {
|
||||
return rule.Context[*http.Request](ctx, contextKeyRequest)
|
||||
}
|
||||
|
||||
func ctxResponse(ctx context.Context) (*http.Response, bool) {
|
||||
return rule.Context[*http.Response](ctx, contextKeyResponse)
|
||||
}
|
@ -1,20 +1,18 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"forge.cadoles.com/cadoles/bouncer/internal/rule"
|
||||
"github.com/expr-lang/expr"
|
||||
)
|
||||
|
||||
func WithRequestFuncs(r *http.Request) rule.OptionFunc {
|
||||
func WithRequestFuncs() rule.OptionFunc {
|
||||
return func(opts *rule.Options) {
|
||||
funcs := []expr.Option{
|
||||
setRequestURL(r),
|
||||
setRequestHeaderFunc(r),
|
||||
addRequestHeaderFunc(r),
|
||||
delRequestHeadersFunc(r),
|
||||
setRequestHostFunc(r),
|
||||
setRequestURLFunc(),
|
||||
setRequestHeaderFunc(),
|
||||
addRequestHeaderFunc(),
|
||||
delRequestHeadersFunc(),
|
||||
setRequestHostFunc(),
|
||||
}
|
||||
|
||||
if len(opts.Expr) == 0 {
|
||||
@ -25,12 +23,12 @@ func WithRequestFuncs(r *http.Request) rule.OptionFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func WithResponseFuncs(r *http.Response) rule.OptionFunc {
|
||||
func WithResponseFuncs() rule.OptionFunc {
|
||||
return func(opts *rule.Options) {
|
||||
funcs := []expr.Option{
|
||||
setResponseHeaderFunc(r),
|
||||
addResponseHeaderFunc(r),
|
||||
delResponseHeadersFunc(r),
|
||||
setResponseHeaderFunc(),
|
||||
addResponseHeaderFunc(),
|
||||
delResponseHeadersFunc(),
|
||||
}
|
||||
|
||||
if len(opts.Expr) == 0 {
|
||||
|
@ -1,109 +1,155 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"forge.cadoles.com/Cadoles/go-proxy/wildcard"
|
||||
"forge.cadoles.com/cadoles/bouncer/internal/rule"
|
||||
"github.com/expr-lang/expr"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func setRequestHostFunc(r *http.Request) expr.Option {
|
||||
func setRequestHostFunc() expr.Option {
|
||||
return expr.Function(
|
||||
"set_host",
|
||||
func(params ...any) (any, error) {
|
||||
host := params[0].(string)
|
||||
ctx, err := rule.Assert[context.Context](params[0])
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
host, err := rule.Assert[string](params[1])
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
r, ok := ctxRequest(ctx)
|
||||
if !ok {
|
||||
return nil, errors.New("could not find http request in context")
|
||||
}
|
||||
|
||||
r.Host = host
|
||||
|
||||
return true, nil
|
||||
},
|
||||
new(func(string) bool),
|
||||
new(func(context.Context, string) bool),
|
||||
)
|
||||
}
|
||||
|
||||
func setRequestURL(r *http.Request) expr.Option {
|
||||
func setRequestURLFunc() expr.Option {
|
||||
return expr.Function(
|
||||
"set_url",
|
||||
func(params ...any) (any, error) {
|
||||
rawURL := params[0].(string)
|
||||
ctx, err := rule.Assert[context.Context](params[0])
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
rawURL, err := rule.Assert[string](params[1])
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
url, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
return false, errors.WithStack(err)
|
||||
}
|
||||
|
||||
r, ok := ctxRequest(ctx)
|
||||
if !ok {
|
||||
return nil, errors.New("could not find http request in context")
|
||||
}
|
||||
|
||||
r.URL = url
|
||||
|
||||
return true, nil
|
||||
},
|
||||
new(func(string) bool),
|
||||
new(func(context.Context, string) bool),
|
||||
)
|
||||
}
|
||||
|
||||
func addRequestHeaderFunc(r *http.Request) expr.Option {
|
||||
func addRequestHeaderFunc() expr.Option {
|
||||
return expr.Function(
|
||||
"add_header",
|
||||
func(params ...any) (any, error) {
|
||||
name := params[0].(string)
|
||||
rawValue := params[1]
|
||||
ctx, err := rule.Assert[context.Context](params[0])
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
var value string
|
||||
switch v := rawValue.(type) {
|
||||
case []string:
|
||||
value = strings.Join(v, ",")
|
||||
case time.Time:
|
||||
value = strconv.FormatInt(v.UTC().Unix(), 10)
|
||||
case time.Duration:
|
||||
value = strconv.FormatInt(int64(v.Seconds()), 10)
|
||||
default:
|
||||
value = fmt.Sprintf("%v", rawValue)
|
||||
name, err := rule.Assert[string](params[1])
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
value := formatValue(params[2])
|
||||
|
||||
r, ok := ctxRequest(ctx)
|
||||
if !ok {
|
||||
return nil, errors.New("could not find http request in context")
|
||||
}
|
||||
|
||||
r.Header.Add(name, value)
|
||||
|
||||
return true, nil
|
||||
},
|
||||
new(func(string, string) bool),
|
||||
new(func(context.Context, string, string) bool),
|
||||
)
|
||||
}
|
||||
|
||||
func setRequestHeaderFunc(r *http.Request) expr.Option {
|
||||
func setRequestHeaderFunc() expr.Option {
|
||||
return expr.Function(
|
||||
"set_header",
|
||||
func(params ...any) (any, error) {
|
||||
name := params[0].(string)
|
||||
rawValue := params[1]
|
||||
ctx, err := rule.Assert[context.Context](params[0])
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
var value string
|
||||
switch v := rawValue.(type) {
|
||||
case []string:
|
||||
value = strings.Join(v, ",")
|
||||
case time.Time:
|
||||
value = strconv.FormatInt(v.UTC().Unix(), 10)
|
||||
case time.Duration:
|
||||
value = strconv.FormatInt(int64(v.Seconds()), 10)
|
||||
default:
|
||||
value = fmt.Sprintf("%v", rawValue)
|
||||
name, err := rule.Assert[string](params[1])
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
value := formatValue(params[2])
|
||||
|
||||
r, ok := ctxRequest(ctx)
|
||||
if !ok {
|
||||
return nil, errors.New("could not find http request in context")
|
||||
}
|
||||
|
||||
r.Header.Set(name, value)
|
||||
|
||||
return true, nil
|
||||
},
|
||||
new(func(string, string) bool),
|
||||
new(func(context.Context, string, string) bool),
|
||||
)
|
||||
}
|
||||
|
||||
func delRequestHeadersFunc(r *http.Request) expr.Option {
|
||||
func delRequestHeadersFunc() expr.Option {
|
||||
return expr.Function(
|
||||
"del_headers",
|
||||
func(params ...any) (any, error) {
|
||||
pattern := params[0].(string)
|
||||
ctx, err := rule.Assert[context.Context](params[0])
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
pattern, err := rule.Assert[string](params[1])
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
r, ok := ctxRequest(ctx)
|
||||
if !ok {
|
||||
return nil, errors.New("could not find http request in context")
|
||||
}
|
||||
|
||||
deleted := false
|
||||
|
||||
for key := range r.Header {
|
||||
@ -117,6 +163,21 @@ func delRequestHeadersFunc(r *http.Request) expr.Option {
|
||||
|
||||
return deleted, nil
|
||||
},
|
||||
new(func(string) bool),
|
||||
new(func(context.Context, string) bool),
|
||||
)
|
||||
}
|
||||
|
||||
func formatValue(v any) string {
|
||||
var value string
|
||||
switch v := v.(type) {
|
||||
case []string:
|
||||
value = strings.Join(v, ",")
|
||||
case time.Time:
|
||||
value = strconv.FormatInt(v.UTC().Unix(), 10)
|
||||
case time.Duration:
|
||||
value = strconv.FormatInt(int64(v.Seconds()), 10)
|
||||
default:
|
||||
value = fmt.Sprintf("%v", v)
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
195
internal/rule/http/request_test.go
Normal file
195
internal/rule/http/request_test.go
Normal file
@ -0,0 +1,195 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"forge.cadoles.com/cadoles/bouncer/internal/rule"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func TestSetRequestHost(t *testing.T) {
|
||||
type Vars struct {
|
||||
NewHost string `expr:"newHost"`
|
||||
}
|
||||
|
||||
engine := createRuleEngine[Vars](t,
|
||||
rule.WithExpr(setRequestHostFunc()),
|
||||
rule.WithRules(
|
||||
"set_host(ctx, vars.newHost)",
|
||||
),
|
||||
)
|
||||
|
||||
req, err := http.NewRequest("GET", "http://example.net", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
ctx = WithRequest(ctx, req)
|
||||
|
||||
vars := Vars{
|
||||
NewHost: "foobar",
|
||||
}
|
||||
|
||||
if _, err := engine.Apply(ctx, vars); err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
if e, g := vars.NewHost, req.Host; e != g {
|
||||
t.Errorf("req.Host: expected '%v', got '%v'", e, g)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetRequestURL(t *testing.T) {
|
||||
type Vars struct {
|
||||
NewURL string `expr:"newURL"`
|
||||
}
|
||||
|
||||
engine := createRuleEngine[Vars](t,
|
||||
rule.WithExpr(setRequestURLFunc()),
|
||||
rule.WithRules(
|
||||
"set_url(ctx, vars.newURL)",
|
||||
),
|
||||
)
|
||||
|
||||
req, err := http.NewRequest("GET", "http://example.net", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
ctx = WithRequest(ctx, req)
|
||||
|
||||
vars := Vars{
|
||||
NewURL: "http://localhost",
|
||||
}
|
||||
|
||||
if _, err := engine.Apply(ctx, vars); err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
if e, g := vars.NewURL, req.URL.String(); e != g {
|
||||
t.Errorf("req.URL.String(): expected '%v', got '%v'", e, g)
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddRequestHeader(t *testing.T) {
|
||||
type Vars struct {
|
||||
NewHeaderKey string `expr:"newHeaderKey"`
|
||||
NewHeaderValue string `expr:"newHeaderValue"`
|
||||
}
|
||||
|
||||
engine := createRuleEngine[Vars](t,
|
||||
rule.WithExpr(addRequestHeaderFunc()),
|
||||
rule.WithRules(
|
||||
"add_header(ctx, vars.newHeaderKey, vars.newHeaderValue)",
|
||||
),
|
||||
)
|
||||
|
||||
req, err := http.NewRequest("GET", "http://example.net", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
ctx = WithRequest(ctx, req)
|
||||
|
||||
vars := Vars{
|
||||
NewHeaderKey: "X-My-Header",
|
||||
NewHeaderValue: "foobar",
|
||||
}
|
||||
|
||||
if _, err := engine.Apply(ctx, vars); err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
if e, g := vars.NewHeaderValue, req.Header.Get(vars.NewHeaderKey); e != g {
|
||||
t.Errorf("req.Header.Get(vars.NewHeaderKey): expected '%v', got '%v'", e, g)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetRequestHeader(t *testing.T) {
|
||||
type Vars struct {
|
||||
HeaderKey string `expr:"headerKey"`
|
||||
HeaderValue string `expr:"headerValue"`
|
||||
}
|
||||
|
||||
engine := createRuleEngine[Vars](t,
|
||||
rule.WithExpr(setRequestHeaderFunc()),
|
||||
rule.WithRules(
|
||||
"set_header(ctx, vars.headerKey, vars.headerValue)",
|
||||
),
|
||||
)
|
||||
|
||||
req, err := http.NewRequest("GET", "http://example.net", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
vars := Vars{
|
||||
HeaderKey: "X-My-Header",
|
||||
HeaderValue: "foobar",
|
||||
}
|
||||
|
||||
req.Header.Set(vars.HeaderKey, "test")
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = WithRequest(ctx, req)
|
||||
|
||||
if _, err := engine.Apply(ctx, vars); err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
if e, g := vars.HeaderValue, req.Header.Get(vars.HeaderKey); e != g {
|
||||
t.Errorf("req.Header.Get(vars.HeaderKey): expected '%v', got '%v'", e, g)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDelRequestHeaders(t *testing.T) {
|
||||
type Vars struct {
|
||||
HeaderPattern string `expr:"headerPattern"`
|
||||
}
|
||||
|
||||
engine := createRuleEngine[Vars](t,
|
||||
rule.WithExpr(delRequestHeadersFunc()),
|
||||
rule.WithRules(
|
||||
"del_headers(ctx, vars.headerPattern)",
|
||||
),
|
||||
)
|
||||
|
||||
req, err := http.NewRequest("GET", "http://example.net", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
vars := Vars{
|
||||
HeaderPattern: "X-My-*",
|
||||
}
|
||||
|
||||
req.Header.Set("X-My-Header", "test")
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = WithRequest(ctx, req)
|
||||
|
||||
if _, err := engine.Apply(ctx, vars); err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
if val := req.Header.Get("X-My-Header"); val != "" {
|
||||
t.Errorf("req.Header.Get(\"X-My-Header\") should be empty, got '%v'", val)
|
||||
}
|
||||
}
|
||||
|
||||
func createRuleEngine[V any](t *testing.T, funcs ...rule.OptionFunc) *rule.Engine[V] {
|
||||
engine, err := rule.NewEngine[V](funcs...)
|
||||
if err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
return engine
|
||||
}
|
@ -1,22 +1,33 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"forge.cadoles.com/Cadoles/go-proxy/wildcard"
|
||||
"forge.cadoles.com/cadoles/bouncer/internal/rule"
|
||||
"github.com/expr-lang/expr"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func addResponseHeaderFunc(r *http.Response) expr.Option {
|
||||
func addResponseHeaderFunc() expr.Option {
|
||||
return expr.Function(
|
||||
"add_header",
|
||||
func(params ...any) (any, error) {
|
||||
name := params[0].(string)
|
||||
rawValue := params[1]
|
||||
ctx, err := rule.Assert[context.Context](params[0])
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
name, err := rule.Assert[string](params[1])
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
rawValue := params[2]
|
||||
|
||||
var value string
|
||||
switch v := rawValue.(type) {
|
||||
@ -30,20 +41,34 @@ func addResponseHeaderFunc(r *http.Response) expr.Option {
|
||||
value = fmt.Sprintf("%v", rawValue)
|
||||
}
|
||||
|
||||
r, ok := ctxResponse(ctx)
|
||||
if !ok {
|
||||
return nil, errors.New("could not find http response in context")
|
||||
}
|
||||
|
||||
r.Header.Add(name, value)
|
||||
|
||||
return true, nil
|
||||
},
|
||||
new(func(string, string) bool),
|
||||
new(func(context.Context, string, string) bool),
|
||||
)
|
||||
}
|
||||
|
||||
func setResponseHeaderFunc(r *http.Response) expr.Option {
|
||||
func setResponseHeaderFunc() expr.Option {
|
||||
return expr.Function(
|
||||
"set_header",
|
||||
func(params ...any) (any, error) {
|
||||
name := params[0].(string)
|
||||
rawValue := params[1]
|
||||
ctx, err := rule.Assert[context.Context](params[0])
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
name, err := rule.Assert[string](params[1])
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
rawValue := params[2]
|
||||
|
||||
var value string
|
||||
switch v := rawValue.(type) {
|
||||
@ -57,19 +82,38 @@ func setResponseHeaderFunc(r *http.Response) expr.Option {
|
||||
value = fmt.Sprintf("%v", rawValue)
|
||||
}
|
||||
|
||||
r, ok := ctxResponse(ctx)
|
||||
if !ok {
|
||||
return nil, errors.New("could not find http response in context")
|
||||
}
|
||||
|
||||
r.Header.Set(name, value)
|
||||
|
||||
return true, nil
|
||||
},
|
||||
new(func(string, string) bool),
|
||||
new(func(context.Context, string, string) bool),
|
||||
)
|
||||
}
|
||||
|
||||
func delResponseHeadersFunc(r *http.Response) expr.Option {
|
||||
func delResponseHeadersFunc() expr.Option {
|
||||
return expr.Function(
|
||||
"del_headers",
|
||||
func(params ...any) (any, error) {
|
||||
pattern := params[0].(string)
|
||||
ctx, err := rule.Assert[context.Context](params[0])
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
pattern, err := rule.Assert[string](params[1])
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
r, ok := ctxResponse(ctx)
|
||||
if !ok {
|
||||
return nil, errors.New("could not find http response in context")
|
||||
}
|
||||
|
||||
deleted := false
|
||||
|
||||
for key := range r.Header {
|
||||
@ -83,6 +127,6 @@ func delResponseHeadersFunc(r *http.Response) expr.Option {
|
||||
|
||||
return deleted, nil
|
||||
},
|
||||
new(func(string) bool),
|
||||
new(func(context.Context, string) bool),
|
||||
)
|
||||
}
|
||||
|
139
internal/rule/http/response_test.go
Normal file
139
internal/rule/http/response_test.go
Normal file
@ -0,0 +1,139 @@
|
||||
package http
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"forge.cadoles.com/cadoles/bouncer/internal/rule"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func TestAddResponseHeader(t *testing.T) {
|
||||
type Vars struct {
|
||||
NewHeaderKey string `expr:"newHeaderKey"`
|
||||
NewHeaderValue string `expr:"newHeaderValue"`
|
||||
}
|
||||
|
||||
engine := createRuleEngine[Vars](t,
|
||||
rule.WithExpr(addResponseHeaderFunc()),
|
||||
rule.WithRules(
|
||||
"add_header(ctx, vars.newHeaderKey, vars.newHeaderValue)",
|
||||
),
|
||||
)
|
||||
|
||||
req, err := http.NewRequest("GET", "http://example.net", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
resp := createResponse(req, http.StatusOK, nil)
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
ctx = WithResponse(ctx, resp)
|
||||
|
||||
vars := Vars{
|
||||
NewHeaderKey: "X-My-Header",
|
||||
NewHeaderValue: "foobar",
|
||||
}
|
||||
|
||||
if _, err := engine.Apply(ctx, vars); err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
if e, g := vars.NewHeaderValue, resp.Header.Get(vars.NewHeaderKey); e != g {
|
||||
t.Errorf("resp.Header.Get(vars.NewHeaderKey): expected '%v', got '%v'", e, g)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseSetHeader(t *testing.T) {
|
||||
type Vars struct {
|
||||
HeaderKey string `expr:"headerKey"`
|
||||
HeaderValue string `expr:"headerValue"`
|
||||
}
|
||||
|
||||
engine := createRuleEngine[Vars](t,
|
||||
rule.WithExpr(setResponseHeaderFunc()),
|
||||
rule.WithRules(
|
||||
"set_header(ctx, vars.headerKey, vars.headerValue)",
|
||||
),
|
||||
)
|
||||
|
||||
req, err := http.NewRequest("GET", "http://example.net", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
resp := createResponse(req, http.StatusOK, nil)
|
||||
|
||||
vars := Vars{
|
||||
HeaderKey: "X-My-Header",
|
||||
HeaderValue: "foobar",
|
||||
}
|
||||
|
||||
resp.Header.Set(vars.HeaderKey, "test")
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = WithResponse(ctx, resp)
|
||||
|
||||
if _, err := engine.Apply(ctx, vars); err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
if e, g := vars.HeaderValue, resp.Header.Get(vars.HeaderKey); e != g {
|
||||
t.Errorf("resp.Header.Get(vars.HeaderKey): expected '%v', got '%v'", e, g)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseDelHeaders(t *testing.T) {
|
||||
type Vars struct {
|
||||
HeaderPattern string `expr:"headerPattern"`
|
||||
}
|
||||
|
||||
engine := createRuleEngine[Vars](t,
|
||||
rule.WithExpr(delResponseHeadersFunc()),
|
||||
rule.WithRules(
|
||||
"del_headers(ctx, vars.headerPattern)",
|
||||
),
|
||||
)
|
||||
|
||||
req, err := http.NewRequest("GET", "http://example.net", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
resp := createResponse(req, http.StatusOK, nil)
|
||||
|
||||
vars := Vars{
|
||||
HeaderPattern: "X-My-*",
|
||||
}
|
||||
|
||||
resp.Header.Set("X-My-Header", "test")
|
||||
|
||||
ctx := context.Background()
|
||||
ctx = WithResponse(ctx, resp)
|
||||
|
||||
if _, err := engine.Apply(ctx, vars); err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
if val := resp.Header.Get("X-My-Header"); val != "" {
|
||||
t.Errorf("resp.Header.Get(\"X-My-Header\") should be empty, got '%v'", val)
|
||||
}
|
||||
}
|
||||
|
||||
func createResponse(req *http.Request, statusCode int, body io.Reader) *http.Response {
|
||||
return &http.Response{
|
||||
Status: http.StatusText(statusCode),
|
||||
StatusCode: statusCode,
|
||||
Proto: "HTTP/1.1",
|
||||
ProtoMajor: 1,
|
||||
ProtoMinor: 1,
|
||||
Body: io.NopCloser(body),
|
||||
ContentLength: -1,
|
||||
Request: req,
|
||||
Header: make(http.Header, 0),
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user