bouncer/internal/rule/http/request.go

184 lines
3.6 KiB
Go

package http
import (
"context"
"fmt"
"net/url"
"strconv"
"strings"
"time"
"forge.cadoles.com/Cadoles/go-proxy/wildcard"
"forge.cadoles.com/cadoles/bouncer/internal/rule"
"github.com/expr-lang/expr"
"github.com/pkg/errors"
)
// setRequestHostFunc registers the `set_host(ctx, host)` rule function.
//
// When invoked, it overwrites the Host field of the HTTP request attached to
// ctx and returns true. It fails if the first argument is not a
// context.Context, if the second is not a string, or if ctx carries no request.
func setRequestHostFunc() expr.Option {
	setHost := func(params ...any) (any, error) {
		ctx, err := rule.Assert[context.Context](params[0])
		if err != nil {
			return nil, errors.WithStack(err)
		}

		hostname, err := rule.Assert[string](params[1])
		if err != nil {
			return nil, errors.WithStack(err)
		}

		if req, found := ctxRequest(ctx); found {
			req.Host = hostname
			return true, nil
		}

		return nil, errors.New("could not find http request in context")
	}

	return expr.Function("set_host", setHost, new(func(context.Context, string) bool))
}
// setRequestURLFunc registers the `set_url(ctx, rawURL)` rule function.
//
// When invoked, it parses rawURL with url.Parse and replaces the URL of the
// HTTP request attached to ctx, returning true on success. It fails if the
// first argument is not a context.Context, if the second is not a string, if
// rawURL cannot be parsed, or if ctx carries no request.
func setRequestURLFunc() expr.Option {
	return expr.Function(
		"set_url",
		func(params ...any) (any, error) {
			ctx, err := rule.Assert[context.Context](params[0])
			if err != nil {
				return nil, errors.WithStack(err)
			}

			rawURL, err := rule.Assert[string](params[1])
			if err != nil {
				return nil, errors.WithStack(err)
			}

			// Named parsedURL rather than url so the net/url package is not
			// shadowed inside this closure.
			parsedURL, err := url.Parse(rawURL)
			if err != nil {
				// Return nil like every other error path, and record which
				// value failed to parse.
				return nil, errors.Wrapf(err, "could not parse url %q", rawURL)
			}

			r, ok := ctxRequest(ctx)
			if !ok {
				return nil, errors.New("could not find http request in context")
			}

			r.URL = parsedURL

			return true, nil
		},
		new(func(context.Context, string) bool),
	)
}
// addRequestHeaderFunc registers the `add_header(ctx, name, value)` rule
// function.
//
// When invoked, it renders value with formatValue and appends it to the named
// header of the HTTP request attached to ctx via http.Header.Add, then returns
// true. It fails if the first argument is not a context.Context, if name is not
// a string, or if ctx carries no request.
func addRequestHeaderFunc() expr.Option {
	addHeader := func(params ...any) (any, error) {
		ctx, err := rule.Assert[context.Context](params[0])
		if err != nil {
			return nil, errors.WithStack(err)
		}

		headerName, err := rule.Assert[string](params[1])
		if err != nil {
			return nil, errors.WithStack(err)
		}

		headerValue := formatValue(params[2])

		if req, found := ctxRequest(ctx); found {
			req.Header.Add(headerName, headerValue)
			return true, nil
		}

		return nil, errors.New("could not find http request in context")
	}

	return expr.Function("add_header", addHeader, new(func(context.Context, string, string) bool))
}
// setRequestHeaderFunc registers the `set_header(ctx, name, value)` rule
// function.
//
// When invoked, it renders value with formatValue and stores it as the sole
// value of the named header on the HTTP request attached to ctx via
// http.Header.Set, then returns true. It fails if the first argument is not a
// context.Context, if name is not a string, or if ctx carries no request.
func setRequestHeaderFunc() expr.Option {
	setHeader := func(params ...any) (any, error) {
		ctx, err := rule.Assert[context.Context](params[0])
		if err != nil {
			return nil, errors.WithStack(err)
		}

		headerName, err := rule.Assert[string](params[1])
		if err != nil {
			return nil, errors.WithStack(err)
		}

		headerValue := formatValue(params[2])

		if req, found := ctxRequest(ctx); found {
			req.Header.Set(headerName, headerValue)
			return true, nil
		}

		return nil, errors.New("could not find http request in context")
	}

	return expr.Function("set_header", setHeader, new(func(context.Context, string, string) bool))
}
// delRequestHeadersFunc registers the `del_headers(ctx, pattern)` rule
// function.
//
// When invoked, it removes every header of the HTTP request attached to ctx
// whose key matches pattern according to wildcard.Match. It returns true if at
// least one header was removed and false otherwise. It fails if the first
// argument is not a context.Context, if pattern is not a string, or if ctx
// carries no request.
func delRequestHeadersFunc() expr.Option {
	return expr.Function(
		"del_headers",
		func(params ...any) (any, error) {
			ctx, err := rule.Assert[context.Context](params[0])
			if err != nil {
				return nil, errors.WithStack(err)
			}

			pattern, err := rule.Assert[string](params[1])
			if err != nil {
				return nil, errors.WithStack(err)
			}

			r, ok := ctxRequest(ctx)
			if !ok {
				return nil, errors.New("could not find http request in context")
			}

			deleted := false

			// Removing map entries while ranging over the map is permitted by
			// the Go spec.
			for key := range r.Header {
				if !wildcard.Match(key, pattern) {
					continue
				}

				// Delete the exact map key instead of calling Header.Del.
				// Header.Del canonicalizes its argument, so a non-canonical key
				// stored directly in the map (e.g. r.Header["x-foo"]) would
				// survive while still being reported as deleted.
				delete(r.Header, key)
				deleted = true
			}

			return deleted, nil
		},
		new(func(context.Context, string) bool),
	)
}
// formatValue renders a rule value as a header-friendly string.
//
// The value is rendered as follows:
//   - A []string is joined with commas.
//   - A time.Time becomes its Unix timestamp in seconds.
//   - A time.Duration becomes its length in whole seconds, with the fractional
//     part truncated.
//   - Anything else is formatted with fmt's %v verb.
func formatValue(v any) string {
	if parts, ok := v.([]string); ok {
		return strings.Join(parts, ",")
	}

	if instant, ok := v.(time.Time); ok {
		return strconv.FormatInt(instant.UTC().Unix(), 10)
	}

	if span, ok := v.(time.Duration); ok {
		seconds := int64(span.Seconds())
		return strconv.FormatInt(seconds, 10)
	}

	return fmt.Sprintf("%v", v)
}