feat: reusable rule engine to prevent memory reallocation
All checks were successful
Cadoles/bouncer/pipeline/pr-develop This commit looks good

This commit is contained in:
2024-09-24 15:46:42 +02:00
parent f37425018b
commit fea0610346
23 changed files with 885 additions and 198 deletions

View File

@ -1,109 +1,155 @@
package http
import (
"context"
"fmt"
"net/http"
"net/url"
"strconv"
"strings"
"time"
"forge.cadoles.com/Cadoles/go-proxy/wildcard"
"forge.cadoles.com/cadoles/bouncer/internal/rule"
"github.com/expr-lang/expr"
"github.com/pkg/errors"
)
func setRequestHostFunc(r *http.Request) expr.Option {
func setRequestHostFunc() expr.Option {
return expr.Function(
"set_host",
func(params ...any) (any, error) {
host := params[0].(string)
ctx, err := rule.Assert[context.Context](params[0])
if err != nil {
return nil, errors.WithStack(err)
}
host, err := rule.Assert[string](params[1])
if err != nil {
return nil, errors.WithStack(err)
}
r, ok := ctxRequest(ctx)
if !ok {
return nil, errors.New("could not find http request in context")
}
r.Host = host
return true, nil
},
new(func(string) bool),
new(func(context.Context, string) bool),
)
}
func setRequestURL(r *http.Request) expr.Option {
func setRequestURLFunc() expr.Option {
return expr.Function(
"set_url",
func(params ...any) (any, error) {
rawURL := params[0].(string)
ctx, err := rule.Assert[context.Context](params[0])
if err != nil {
return nil, errors.WithStack(err)
}
rawURL, err := rule.Assert[string](params[1])
if err != nil {
return nil, errors.WithStack(err)
}
url, err := url.Parse(rawURL)
if err != nil {
return false, errors.WithStack(err)
}
r, ok := ctxRequest(ctx)
if !ok {
return nil, errors.New("could not find http request in context")
}
r.URL = url
return true, nil
},
new(func(string) bool),
new(func(context.Context, string) bool),
)
}
func addRequestHeaderFunc(r *http.Request) expr.Option {
func addRequestHeaderFunc() expr.Option {
return expr.Function(
"add_header",
func(params ...any) (any, error) {
name := params[0].(string)
rawValue := params[1]
ctx, err := rule.Assert[context.Context](params[0])
if err != nil {
return nil, errors.WithStack(err)
}
var value string
switch v := rawValue.(type) {
case []string:
value = strings.Join(v, ",")
case time.Time:
value = strconv.FormatInt(v.UTC().Unix(), 10)
case time.Duration:
value = strconv.FormatInt(int64(v.Seconds()), 10)
default:
value = fmt.Sprintf("%v", rawValue)
name, err := rule.Assert[string](params[1])
if err != nil {
return nil, errors.WithStack(err)
}
value := formatValue(params[2])
r, ok := ctxRequest(ctx)
if !ok {
return nil, errors.New("could not find http request in context")
}
r.Header.Add(name, value)
return true, nil
},
new(func(string, string) bool),
new(func(context.Context, string, string) bool),
)
}
func setRequestHeaderFunc(r *http.Request) expr.Option {
func setRequestHeaderFunc() expr.Option {
return expr.Function(
"set_header",
func(params ...any) (any, error) {
name := params[0].(string)
rawValue := params[1]
ctx, err := rule.Assert[context.Context](params[0])
if err != nil {
return nil, errors.WithStack(err)
}
var value string
switch v := rawValue.(type) {
case []string:
value = strings.Join(v, ",")
case time.Time:
value = strconv.FormatInt(v.UTC().Unix(), 10)
case time.Duration:
value = strconv.FormatInt(int64(v.Seconds()), 10)
default:
value = fmt.Sprintf("%v", rawValue)
name, err := rule.Assert[string](params[1])
if err != nil {
return nil, errors.WithStack(err)
}
value := formatValue(params[2])
r, ok := ctxRequest(ctx)
if !ok {
return nil, errors.New("could not find http request in context")
}
r.Header.Set(name, value)
return true, nil
},
new(func(string, string) bool),
new(func(context.Context, string, string) bool),
)
}
func delRequestHeadersFunc(r *http.Request) expr.Option {
func delRequestHeadersFunc() expr.Option {
return expr.Function(
"del_headers",
func(params ...any) (any, error) {
pattern := params[0].(string)
ctx, err := rule.Assert[context.Context](params[0])
if err != nil {
return nil, errors.WithStack(err)
}
pattern, err := rule.Assert[string](params[1])
if err != nil {
return nil, errors.WithStack(err)
}
r, ok := ctxRequest(ctx)
if !ok {
return nil, errors.New("could not find http request in context")
}
deleted := false
for key := range r.Header {
@ -117,6 +163,21 @@ func delRequestHeadersFunc(r *http.Request) expr.Option {
return deleted, nil
},
new(func(string) bool),
new(func(context.Context, string) bool),
)
}
func formatValue(v any) string {
var value string
switch v := v.(type) {
case []string:
value = strings.Join(v, ",")
case time.Time:
value = strconv.FormatInt(v.UTC().Unix(), 10)
case time.Duration:
value = strconv.FormatInt(int64(v.Seconds()), 10)
default:
value = fmt.Sprintf("%v", v)
}
return value
}