324 lines
6.7 KiB
Go
324 lines
6.7 KiB
Go
package http
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"strconv"
|
|
"strings"
|
|
"time"
|
|
|
|
"forge.cadoles.com/Cadoles/go-proxy/wildcard"
|
|
"forge.cadoles.com/cadoles/bouncer/internal/rule"
|
|
"github.com/expr-lang/expr"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// setRequestHostFunc exposes the rule function set_host(ctx, host), which
// overwrites the Host field of the HTTP request carried by ctx.
//
// The function yields true on success. It fails when an argument has an
// unexpected type or when ctx does not carry a request.
func setRequestHostFunc() expr.Option {
	setHost := func(params ...any) (any, error) {
		ctx, err := rule.Assert[context.Context](params[0])
		if err != nil {
			return nil, errors.WithStack(err)
		}

		host, err := rule.Assert[string](params[1])
		if err != nil {
			return nil, errors.WithStack(err)
		}

		req, found := CtxRequest(ctx)
		if !found {
			return nil, errors.New("could not find http request in context")
		}

		req.Host = host

		return true, nil
	}

	return expr.Function("set_host", setHost, new(func(context.Context, string) bool))
}
|
|
|
|
// setRequestURLFunc exposes the rule function set_url(ctx, rawURL). It parses
// rawURL with url.Parse and replaces the URL of the HTTP request carried by
// ctx with the result.
//
// The function yields true on success. It fails when an argument has an
// unexpected type, when rawURL cannot be parsed, or when ctx does not carry
// a request.
func setRequestURLFunc() expr.Option {
	return expr.Function(
		"set_url",
		func(params ...any) (any, error) {
			ctx, err := rule.Assert[context.Context](params[0])
			if err != nil {
				return nil, errors.WithStack(err)
			}

			rawURL, err := rule.Assert[string](params[1])
			if err != nil {
				return nil, errors.WithStack(err)
			}

			// Named parsedURL rather than url so the net/url package is not
			// shadowed inside this closure.
			parsedURL, err := url.Parse(rawURL)
			if err != nil {
				// Same (nil, err) shape as every other failure path.
				return nil, errors.WithStack(err)
			}

			r, ok := CtxRequest(ctx)
			if !ok {
				return nil, errors.New("could not find http request in context")
			}

			r.URL = parsedURL

			return true, nil
		},
		new(func(context.Context, string) bool),
	)
}
|
|
|
|
// addRequestHeaderFunc exposes the rule function add_header(ctx, name, value).
// It renders value with formatValue and appends it to the name header of the
// HTTP request carried by ctx, keeping any values already present
// (http.Header.Add).
//
// The function yields true on success. It fails when ctx or name has an
// unexpected type or when ctx does not carry a request.
func addRequestHeaderFunc() expr.Option {
	addHeader := func(params ...any) (any, error) {
		ctx, err := rule.Assert[context.Context](params[0])
		if err != nil {
			return nil, errors.WithStack(err)
		}

		headerName, err := rule.Assert[string](params[1])
		if err != nil {
			return nil, errors.WithStack(err)
		}

		headerValue := formatValue(params[2])

		req, found := CtxRequest(ctx)
		if !found {
			return nil, errors.New("could not find http request in context")
		}

		req.Header.Add(headerName, headerValue)

		return true, nil
	}

	return expr.Function("add_header", addHeader, new(func(context.Context, string, string) bool))
}
|
|
|
|
// setRequestHeaderFunc exposes the rule function set_header(ctx, name, value).
// It renders value with formatValue and stores it as the only value of the
// name header on the HTTP request carried by ctx, replacing any previous
// values (http.Header.Set).
//
// The function yields true on success. It fails when ctx or name has an
// unexpected type or when ctx does not carry a request.
func setRequestHeaderFunc() expr.Option {
	setHeader := func(params ...any) (any, error) {
		ctx, err := rule.Assert[context.Context](params[0])
		if err != nil {
			return nil, errors.WithStack(err)
		}

		headerName, err := rule.Assert[string](params[1])
		if err != nil {
			return nil, errors.WithStack(err)
		}

		headerValue := formatValue(params[2])

		req, found := CtxRequest(ctx)
		if !found {
			return nil, errors.New("could not find http request in context")
		}

		req.Header.Set(headerName, headerValue)

		return true, nil
	}

	return expr.Function("set_header", setHeader, new(func(context.Context, string, string) bool))
}
|
|
|
|
// delRequestHeadersFunc exposes the rule function del_headers(ctx, pattern).
// It removes every header of the HTTP request carried by ctx whose name
// matches pattern according to wildcard.Match.
//
// The function yields true when at least one header was removed and false
// otherwise. It fails when an argument has an unexpected type or when ctx
// does not carry a request.
func delRequestHeadersFunc() expr.Option {
	return expr.Function(
		"del_headers",
		func(params ...any) (any, error) {
			ctx, err := rule.Assert[context.Context](params[0])
			if err != nil {
				return nil, errors.WithStack(err)
			}

			pattern, err := rule.Assert[string](params[1])
			if err != nil {
				return nil, errors.WithStack(err)
			}

			r, ok := CtxRequest(ctx)
			if !ok {
				return nil, errors.New("could not find http request in context")
			}

			deleted := false

			// Deleting map entries while ranging over the map is allowed in Go.
			for key := range r.Header {
				if !wildcard.Match(key, pattern) {
					continue
				}

				// Remove the exact key that matched. Header.Del canonicalizes
				// its argument first, so it would leave a header stored under
				// a non-canonical key (e.g. "x-foo") in place while this loop
				// still reported a deletion.
				delete(r.Header, key)
				deleted = true
			}

			return deleted, nil
		},
		new(func(context.Context, string) bool),
	)
}
|
|
|
|
// CookieVar is the cookie value handed to rule expressions; get_cookie
// returns one, copied field by field from the request's *http.Cookie.
// The expr struct tags set the snake_case names under which each field is
// visible in expressions.
type CookieVar struct {
	Name   string `expr:"name"`
	Value  string `expr:"value"`
	Path   string `expr:"path"`
	Domain string `expr:"domain"`
	// Expires, MaxAge, Secure, HttpOnly and SameSite mirror the fields of
	// the same name on http.Cookie.
	Expires  time.Time     `expr:"expires"`
	MaxAge   int           `expr:"max_age"`
	Secure   bool          `expr:"secure"`
	HttpOnly bool          `expr:"http_only"`
	SameSite http.SameSite `expr:"same_site"`
}
|
|
|
|
// getRequestCookieFunc exposes the rule function get_cookie(ctx, name), which
// looks up the cookie called name on the HTTP request carried by ctx.
//
// The function yields a CookieVar copy of the cookie, or nil when the request
// has no cookie with that name (http.ErrNoCookie is not treated as a
// failure). It fails when an argument has an unexpected type, when ctx does
// not carry a request, or when the cookie lookup reports any other error.
func getRequestCookieFunc() expr.Option {
	getCookie := func(params ...any) (any, error) {
		ctx, err := rule.Assert[context.Context](params[0])
		if err != nil {
			return nil, errors.WithStack(err)
		}

		cookieName, err := rule.Assert[string](params[1])
		if err != nil {
			return nil, errors.WithStack(err)
		}

		req, found := CtxRequest(ctx)
		if !found {
			return nil, errors.New("could not find http request in context")
		}

		cookie, err := req.Cookie(cookieName)
		if err != nil && !errors.Is(err, http.ErrNoCookie) {
			return nil, errors.WithStack(err)
		}

		// A missing cookie surfaces to the expression as nil.
		if cookie == nil {
			return nil, nil
		}

		cookieVar := CookieVar{
			Name:     cookie.Name,
			Value:    cookie.Value,
			Domain:   cookie.Domain,
			Path:     cookie.Path,
			Expires:  cookie.Expires,
			MaxAge:   cookie.MaxAge,
			HttpOnly: cookie.HttpOnly,
			Secure:   cookie.Secure,
			SameSite: cookie.SameSite,
		}

		return cookieVar, nil
	}

	return expr.Function("get_cookie", getCookie, new(func(context.Context, string) CookieVar))
}
|
|
|
|
// addRequestCookieFunc exposes the rule function add_cookie(ctx, attributes).
// It builds a cookie from the attribute map with cookieFrom and attaches it
// to the HTTP request carried by ctx using (*http.Request).AddCookie.
//
// The function yields true on success. It fails when an argument has an
// unexpected type, when cookieFrom rejects the attributes, or when ctx does
// not carry a request.
func addRequestCookieFunc() expr.Option {
	addCookie := func(params ...any) (any, error) {
		ctx, err := rule.Assert[context.Context](params[0])
		if err != nil {
			return nil, errors.WithStack(err)
		}

		attributes, err := rule.Assert[map[string]any](params[1])
		if err != nil {
			return nil, errors.WithStack(err)
		}

		// The cookie is built before the request lookup, so attribute errors
		// take precedence over a missing request.
		cookie, err := cookieFrom(attributes)
		if err != nil {
			return nil, errors.WithStack(err)
		}

		req, found := CtxRequest(ctx)
		if !found {
			return nil, errors.New("could not find http request in context")
		}

		req.AddCookie(cookie)

		return true, nil
	}

	return expr.Function("add_cookie", addCookie, new(func(context.Context, map[string]any) bool))
}
|
|
|
|
func cookieFrom(values map[string]any) (*http.Cookie, error) {
|
|
cookie := &http.Cookie{}
|
|
|
|
if name, ok := values["name"].(string); ok {
|
|
cookie.Name = name
|
|
}
|
|
|
|
if value, ok := values["value"].(string); ok {
|
|
cookie.Value = value
|
|
}
|
|
|
|
if domain, ok := values["domain"].(string); ok {
|
|
cookie.Domain = domain
|
|
}
|
|
|
|
if path, ok := values["path"].(string); ok {
|
|
cookie.Path = path
|
|
}
|
|
|
|
if httpOnly, ok := values["http_only"].(bool); ok {
|
|
cookie.HttpOnly = httpOnly
|
|
}
|
|
|
|
if maxAge, ok := values["max_age"].(int); ok {
|
|
cookie.MaxAge = maxAge
|
|
}
|
|
|
|
if secure, ok := values["secure"].(bool); ok {
|
|
cookie.Secure = secure
|
|
}
|
|
|
|
if sameSite, ok := values["same_site"].(http.SameSite); ok {
|
|
cookie.SameSite = sameSite
|
|
} else if sameSite, ok := values["same_site"].(int); ok {
|
|
cookie.SameSite = http.SameSite(sameSite)
|
|
}
|
|
|
|
if expires, ok := values["expires"].(time.Time); ok {
|
|
cookie.Expires = expires
|
|
} else if rawExpires, ok := values["expires"].(string); ok {
|
|
expires, err := time.Parse(http.TimeFormat, rawExpires)
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
cookie.Expires = expires
|
|
}
|
|
|
|
return cookie, nil
|
|
}
|
|
|
|
// formatValue renders an arbitrary rule value as a header value:
//   - []string values are joined with ","
//   - time.Time values become a Unix timestamp in seconds
//   - time.Duration values become a whole number of seconds (truncated)
//   - anything else is formatted with fmt's %v verb
func formatValue(v any) string {
	switch typed := v.(type) {
	case []string:
		return strings.Join(typed, ",")
	case time.Time:
		return strconv.FormatInt(typed.UTC().Unix(), 10)
	case time.Duration:
		seconds := int64(typed.Seconds())
		return strconv.FormatInt(seconds, 10)
	}

	return fmt.Sprintf("%v", v)
}
|