From e60c441d438ab23bbfc7fa9a0d2d1b22806a35d4 Mon Sep 17 00:00:00 2001 From: William Petit Date: Tue, 28 Mar 2023 11:02:53 +0200 Subject: [PATCH] feat: move proxy package from arcad/edge --- internal/agent/controller/proxy/controller.go | 22 +-- .../agent/controller/proxy/reverse_proxy.go | 3 +- internal/proxy/host_filter.go | 29 ++++ internal/proxy/host_rewrite.go | 65 +++++++++ internal/proxy/middleware.go | 33 +++++ internal/proxy/options.go | 29 ++++ internal/proxy/proxy.go | 131 ++++++++++++++++++ internal/proxy/wildcard/match.go | 44 ++++++ 8 files changed, 343 insertions(+), 13 deletions(-) create mode 100644 internal/proxy/host_filter.go create mode 100644 internal/proxy/host_rewrite.go create mode 100644 internal/proxy/middleware.go create mode 100644 internal/proxy/options.go create mode 100644 internal/proxy/proxy.go create mode 100644 internal/proxy/wildcard/match.go diff --git a/internal/agent/controller/proxy/controller.go b/internal/agent/controller/proxy/controller.go index 046b0dc..eb5cc95 100644 --- a/internal/agent/controller/proxy/controller.go +++ b/internal/agent/controller/proxy/controller.go @@ -5,8 +5,8 @@ import ( "net/url" "forge.cadoles.com/Cadoles/emissary/internal/agent" - "forge.cadoles.com/Cadoles/emissary/internal/spec/proxy" - edgeProxy "forge.cadoles.com/arcad/edge/pkg/proxy" + "forge.cadoles.com/Cadoles/emissary/internal/proxy" + spec "forge.cadoles.com/Cadoles/emissary/internal/spec/proxy" "github.com/mitchellh/hashstructure/v2" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" @@ -18,7 +18,7 @@ type proxyEntry struct { } type Controller struct { - proxies map[proxy.ID]*proxyEntry + proxies map[spec.ID]*proxyEntry } // Name implements node.Controller. @@ -28,9 +28,9 @@ func (c *Controller) Name() string { // Reconcile implements node.Controller. func (c *Controller) Reconcile(ctx context.Context, state *agent.State) error { - proxySpec := proxy.NewSpec() + proxySpec := spec.NewSpec() - if err := state.GetSpec(proxy.NameProxy, proxySpec); err != nil { + if err := state.GetSpec(spec.NameProxy, proxySpec); err != nil { if errors.Is(err, agent.ErrSpecNotFound) { logger.Info(ctx, "could not find proxy spec") @@ -69,7 +69,7 @@ func (c *Controller) stopAllProxies(ctx context.Context) { } } -func (c *Controller) updateProxies(ctx context.Context, spec *proxy.Spec) { +func (c *Controller) updateProxies(ctx context.Context, spec *spec.Spec) { // Stop and remove obsolete proxys for proxyID, entry := range c.proxies { if _, exists := spec.Proxies[proxyID]; exists { @@ -100,7 +100,7 @@ func (c *Controller) updateProxies(ctx context.Context, spec *proxy.Spec) { } } -func (c *Controller) updateProxy(ctx context.Context, proxyID proxy.ID, proxySpec proxy.ProxyEntry) (err error) { +func (c *Controller) updateProxy(ctx context.Context, proxyID spec.ID, proxySpec spec.ProxyEntry) (err error) { newProxySpecHash, err := hashstructure.Hash(proxySpec, hashstructure.FormatV2, nil) if err != nil { return errors.WithStack(err) @@ -140,7 +140,7 @@ func (c *Controller) updateProxy(ctx context.Context, proxyID proxy.ID, proxySpe ) } - options := make([]edgeProxy.OptionFunc, 0) + options := make([]proxy.OptionFunc, 0) allowedHosts := make([]string, len(proxySpec.Mappings)) mappings := make(map[string]*url.URL, len(proxySpec.Mappings)) @@ -156,8 +156,8 @@ func (c *Controller) updateProxy(ctx context.Context, proxyID proxy.ID, proxySpe options = append( options, - edgeProxy.WithAllowedHosts(allowedHosts...), - edgeProxy.WithRewriteHosts(mappings), + proxy.WithAllowedHosts(allowedHosts...), + proxy.WithRewriteHosts(mappings), ) if err := entry.Proxy.Start(ctx, proxySpec.Address, options...); err != nil { @@ -173,7 +173,7 @@ func (c *Controller) updateProxy(ctx context.Context, proxyID proxy.ID, proxySpe func NewController() *Controller { return &Controller{ - proxies: make(map[proxy.ID]*proxyEntry), + proxies: make(map[spec.ID]*proxyEntry), } } diff --git a/internal/agent/controller/proxy/reverse_proxy.go b/internal/agent/controller/proxy/reverse_proxy.go index ee79b9c..1aedaa6 100644 --- a/internal/agent/controller/proxy/reverse_proxy.go +++ b/internal/agent/controller/proxy/reverse_proxy.go @@ -5,10 +5,9 @@ import ( "net/http" "sync" + "forge.cadoles.com/Cadoles/emissary/internal/proxy" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" - - "forge.cadoles.com/arcad/edge/pkg/proxy" ) type ReverseProxy struct { diff --git a/internal/proxy/host_filter.go b/internal/proxy/host_filter.go new file mode 100644 index 0000000..713968a --- /dev/null +++ b/internal/proxy/host_filter.go @@ -0,0 +1,29 @@ +package proxy + +import ( + "net/http" + + "forge.cadoles.com/arcad/edge/pkg/proxy/wildcard" +) + +func FilterHosts(allowedHostPatterns ...string) Middleware { + return func(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + if matches := wildcard.MatchAny(r.Host, allowedHostPatterns...); !matches { + http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) + + return + } + + h.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) + } +} + +func WithAllowedHosts(allowedHostPatterns ...string) OptionFunc { + return func(o *Options) { + o.Middlewares = append(o.Middlewares, FilterHosts(allowedHostPatterns...)) + } +} diff --git a/internal/proxy/host_rewrite.go b/internal/proxy/host_rewrite.go new file mode 100644 index 0000000..2a55bd8 --- /dev/null +++ b/internal/proxy/host_rewrite.go @@ -0,0 +1,65 @@ +package proxy + +import ( + "net/http" + "net/url" + "sort" + + "forge.cadoles.com/arcad/edge/pkg/proxy/wildcard" + "gitlab.com/wpetit/goweb/logger" +) + +func RewriteHosts(mappings map[string]*url.URL) Middleware { + patterns := make([]string, len(mappings)) + + for p := range mappings { + patterns = append(patterns, p) + } + + sort.Strings(patterns) + + return func(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var match *url.URL + + for _, p := range patterns { + logger.Debug(ctx, "matching host to pattern", logger.F("host", r.Host), logger.F("pattern", p)) + + if matches := wildcard.Match(r.Host, p); !matches { + continue + } + + match = mappings[p] + break + } + + if match == nil { + h.ServeHTTP(w, r) + + return + } + + ctx = logger.With(ctx, logger.F("originalHost", r.Host)) + r = r.WithContext(ctx) + + originalURL := r.URL.String() + + r.URL.Host = match.Host + r.URL.Scheme = match.Scheme + + logger.Debug(ctx, "rewriting url", logger.F("from", originalURL), logger.F("to", r.URL.String())) + + h.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) + } +} + +func WithRewriteHosts(mappings map[string]*url.URL) OptionFunc { + return func(o *Options) { + o.Middlewares = append(o.Middlewares, RewriteHosts(mappings)) + } +} diff --git a/internal/proxy/middleware.go b/internal/proxy/middleware.go new file mode 100644 index 0000000..b1efb0f --- /dev/null +++ b/internal/proxy/middleware.go @@ -0,0 +1,33 @@ +package proxy + +import "net/http" + +type Middleware func(h http.Handler) http.Handler + +type ProxyResponseTransformer interface { + TransformResponse(*http.Response) error +} + +type defaultProxyResponseTransformer struct{} + +// TransformResponse implements ProxyResponseTransformer +func (*defaultProxyResponseTransformer) TransformResponse(*http.Response) error { + return nil +} + +var _ ProxyResponseTransformer = &defaultProxyResponseTransformer{} + +type ProxyResponseMiddleware func(ProxyResponseTransformer) ProxyResponseTransformer + +type ProxyRequestTransformer interface { + TransformRequest(*http.Request) +} + +type ProxyRequestMiddleware func(ProxyRequestTransformer) ProxyRequestTransformer + +type defaultProxyRequestTransformer struct{} + +// TransformRequest implements ProxyRequestTransformer +func (*defaultProxyRequestTransformer) TransformRequest(*http.Request) {} + +var _ ProxyRequestTransformer = &defaultProxyRequestTransformer{} diff --git a/internal/proxy/options.go b/internal/proxy/options.go new file mode 100644 index 0000000..22a6945 --- /dev/null +++ b/internal/proxy/options.go @@ -0,0 +1,29 @@ +package proxy + +type Options struct { + Middlewares []Middleware + ProxyRequestMiddlewares []ProxyRequestMiddleware + ProxyResponseMiddlewares []ProxyResponseMiddleware +} + +func defaultOptions() *Options { + return &Options{ + Middlewares: make([]Middleware, 0), + ProxyRequestMiddlewares: make([]ProxyRequestMiddleware, 0), + ProxyResponseMiddlewares: make([]ProxyResponseMiddleware, 0), + } +} + +type OptionFunc func(*Options) + +func WithProxyRequestMiddlewares(middlewares ...ProxyRequestMiddleware) OptionFunc { + return func(o *Options) { + o.ProxyRequestMiddlewares = middlewares + } +} + +func WithproxyResponseMiddlewares(middlewares ...ProxyResponseMiddleware) OptionFunc { + return func(o *Options) { + o.ProxyResponseMiddlewares = middlewares + } +} diff --git a/internal/proxy/proxy.go b/internal/proxy/proxy.go new file mode 100644 index 0000000..e932199 --- /dev/null +++ b/internal/proxy/proxy.go @@ -0,0 +1,131 @@ +package proxy + +import ( + "fmt" + "net/http" + "net/http/httputil" + "net/url" + "sync" + + "github.com/pkg/errors" + "gitlab.com/wpetit/goweb/logger" +) + +type Proxy struct { + reversers sync.Map + handler http.Handler + proxyResponseTransformer ProxyResponseTransformer + proxyRequestTransformer ProxyRequestTransformer +} + +// ServeHTTP implements http.Handler +func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + p.handler.ServeHTTP(w, r) +} + +func (p *Proxy) proxyRequest(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var reverser *httputil.ReverseProxy + + key := fmt.Sprintf("%s://%s", r.URL.Scheme, r.URL.Host) + + createAndStore := func() { + target := &url.URL{ + Scheme: r.URL.Scheme, + Host: r.URL.Host, + } + + reverser = httputil.NewSingleHostReverseProxy(target) + + originalDirector := reverser.Director + + if p.proxyRequestTransformer != nil { + reverser.Director = func(r *http.Request) { + originalURL := r.URL.String() + originalDirector(r) + p.proxyRequestTransformer.TransformRequest(r) + logger.Debug(ctx, "proxying request", logger.F("targetURL", r.URL.String()), logger.F("originalURL", originalURL)) + } + } + + if p.proxyResponseTransformer != nil { + reverser.ModifyResponse = func(r *http.Response) error { + if err := p.proxyResponseTransformer.TransformResponse(r); err != nil { + return errors.WithStack(err) + } + + return nil + } + } + + p.reversers.Store(key, reverser) + } + + raw, exists := p.reversers.Load(key) + if !exists { + createAndStore() + } + + reverser, ok := raw.(*httputil.ReverseProxy) + if !ok { + createAndStore() + } + + reverser.ServeHTTP(w, r) +} + +func New(funcs ...OptionFunc) *Proxy { + opts := defaultOptions() + for _, fn := range funcs { + fn(opts) + } + + proxy := &Proxy{} + + handler := http.HandlerFunc(proxy.proxyRequest) + proxy.handler = createMiddlewareChain(handler, opts.Middlewares) + + proxy.proxyRequestTransformer = createProxyRequestChain(&defaultProxyRequestTransformer{}, opts.ProxyRequestMiddlewares) + proxy.proxyResponseTransformer = createProxyResponseChain(&defaultProxyResponseTransformer{}, opts.ProxyResponseMiddlewares) + + return proxy +} + +var _ http.Handler = &Proxy{} + +func createMiddlewareChain(handler http.Handler, middlewares []Middleware) http.Handler { + reverse(middlewares) + + for _, m := range middlewares { + handler = m(handler) + } + + return handler +} + +func createProxyResponseChain(transformer ProxyResponseTransformer, middlewares []ProxyResponseMiddleware) ProxyResponseTransformer { + reverse(middlewares) + + for _, m := range middlewares { + transformer = m(transformer) + } + + return transformer +} + +func createProxyRequestChain(transformer ProxyRequestTransformer, middlewares []ProxyRequestMiddleware) ProxyRequestTransformer { + reverse(middlewares) + + for _, m := range middlewares { + transformer = m(transformer) + } + + return transformer +} + +func reverse[S ~[]E, E any](s S) { + for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { + s[i], s[j] = s[j], s[i] + } +} diff --git a/internal/proxy/wildcard/match.go b/internal/proxy/wildcard/match.go new file mode 100644 index 0000000..3c4ce51 --- /dev/null +++ b/internal/proxy/wildcard/match.go @@ -0,0 +1,44 @@ +package wildcard + +const wildcard = '*' + +func Match(str, pattern string) bool { + if pattern == "" { + return str == pattern + } + + if pattern == string(wildcard) { + return true + } + + return deepMatchRune([]rune(str), []rune(pattern)) +} + +func MatchAny(str string, patterns ...string) bool { + for _, p := range patterns { + if matches := Match(str, p); matches { + return matches + } + } + + return false +} + +func deepMatchRune(str, pattern []rune) bool { + for len(pattern) > 0 { + switch pattern[0] { + default: + if len(str) == 0 || str[0] != pattern[0] { + return false + } + case wildcard: + return deepMatchRune(str, pattern[1:]) || + (len(str) > 0 && deepMatchRune(str[1:], pattern)) + } + + str = str[1:] + pattern = pattern[1:] + } + + return len(str) == 0 && len(pattern) == 0 +}