diff --git a/pkg/proxy/host_filter.go b/pkg/proxy/host_filter.go new file mode 100644 index 0000000..713968a --- /dev/null +++ b/pkg/proxy/host_filter.go @@ -0,0 +1,29 @@ +package proxy + +import ( + "net/http" + + "forge.cadoles.com/arcad/edge/pkg/proxy/wildcard" +) + +func FilterHosts(allowedHostPatterns ...string) Middleware { + return func(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + if matches := wildcard.MatchAny(r.Host, allowedHostPatterns...); !matches { + http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden) + + return + } + + h.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) + } +} + +func WithAllowedHosts(allowedHostPatterns ...string) OptionFunc { + return func(o *Options) { + o.Middlewares = append(o.Middlewares, FilterHosts(allowedHostPatterns...)) + } +} diff --git a/pkg/proxy/host_rewrite.go b/pkg/proxy/host_rewrite.go new file mode 100644 index 0000000..2a55bd8 --- /dev/null +++ b/pkg/proxy/host_rewrite.go @@ -0,0 +1,65 @@ +package proxy + +import ( + "net/http" + "net/url" + "sort" + + "forge.cadoles.com/arcad/edge/pkg/proxy/wildcard" + "gitlab.com/wpetit/goweb/logger" +) + +func RewriteHosts(mappings map[string]*url.URL) Middleware { + patterns := make([]string, len(mappings)) + + for p := range mappings { + patterns = append(patterns, p) + } + + sort.Strings(patterns) + + return func(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var match *url.URL + + for _, p := range patterns { + logger.Debug(ctx, "matching host to pattern", logger.F("host", r.Host), logger.F("pattern", p)) + + if matches := wildcard.Match(r.Host, p); !matches { + continue + } + + match = mappings[p] + break + } + + if match == nil { + h.ServeHTTP(w, r) + + return + } + + ctx = logger.With(ctx, logger.F("originalHost", r.Host)) + r = r.WithContext(ctx) + + originalURL := r.URL.String() + + r.URL.Host = match.Host + r.URL.Scheme = match.Scheme + + logger.Debug(ctx, "rewriting url", logger.F("from", originalURL), logger.F("to", r.URL.String())) + + h.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) + } +} + +func WithRewriteHosts(mappings map[string]*url.URL) OptionFunc { + return func(o *Options) { + o.Middlewares = append(o.Middlewares, RewriteHosts(mappings)) + } +} diff --git a/pkg/proxy/middleware.go b/pkg/proxy/middleware.go new file mode 100644 index 0000000..b1efb0f --- /dev/null +++ b/pkg/proxy/middleware.go @@ -0,0 +1,33 @@ +package proxy + +import "net/http" + +type Middleware func(h http.Handler) http.Handler + +type ProxyResponseTransformer interface { + TransformResponse(*http.Response) error +} + +type defaultProxyResponseTransformer struct{} + +// TransformResponse implements ProxyResponseTransformer +func (*defaultProxyResponseTransformer) TransformResponse(*http.Response) error { + return nil +} + +var _ ProxyResponseTransformer = &defaultProxyResponseTransformer{} + +type ProxyResponseMiddleware func(ProxyResponseTransformer) ProxyResponseTransformer + +type ProxyRequestTransformer interface { + TransformRequest(*http.Request) +} + +type ProxyRequestMiddleware func(ProxyRequestTransformer) ProxyRequestTransformer + +type defaultProxyRequestTransformer struct{} + +// TransformRequest implements ProxyRequestTransformer +func (*defaultProxyRequestTransformer) TransformRequest(*http.Request) {} + +var _ ProxyRequestTransformer = &defaultProxyRequestTransformer{} diff --git a/pkg/proxy/options.go b/pkg/proxy/options.go new file mode 100644 index 0000000..22a6945 --- /dev/null +++ b/pkg/proxy/options.go @@ -0,0 +1,29 @@ +package proxy + +type Options struct { + Middlewares []Middleware + ProxyRequestMiddlewares []ProxyRequestMiddleware + ProxyResponseMiddlewares []ProxyResponseMiddleware +} + +func defaultOptions() *Options { + return &Options{ + Middlewares: make([]Middleware, 0), + ProxyRequestMiddlewares: make([]ProxyRequestMiddleware, 0), + ProxyResponseMiddlewares: make([]ProxyResponseMiddleware, 0), + } +} + +type OptionFunc func(*Options) + +func WithProxyRequestMiddlewares(middlewares ...ProxyRequestMiddleware) OptionFunc { + return func(o *Options) { + o.ProxyRequestMiddlewares = middlewares + } +} + +func WithproxyResponseMiddlewares(middlewares ...ProxyResponseMiddleware) OptionFunc { + return func(o *Options) { + o.ProxyResponseMiddlewares = middlewares + } +} diff --git a/pkg/proxy/proxy.go b/pkg/proxy/proxy.go new file mode 100644 index 0000000..e932199 --- /dev/null +++ b/pkg/proxy/proxy.go @@ -0,0 +1,131 @@ +package proxy + +import ( + "fmt" + "net/http" + "net/http/httputil" + "net/url" + "sync" + + "github.com/pkg/errors" + "gitlab.com/wpetit/goweb/logger" +) + +type Proxy struct { + reversers sync.Map + handler http.Handler + proxyResponseTransformer ProxyResponseTransformer + proxyRequestTransformer ProxyRequestTransformer +} + +// ServeHTTP implements http.Handler +func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + p.handler.ServeHTTP(w, r) +} + +func (p *Proxy) proxyRequest(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + var reverser *httputil.ReverseProxy + + key := fmt.Sprintf("%s://%s", r.URL.Scheme, r.URL.Host) + + createAndStore := func() { + target := &url.URL{ + Scheme: r.URL.Scheme, + Host: r.URL.Host, + } + + reverser = httputil.NewSingleHostReverseProxy(target) + + originalDirector := reverser.Director + + if p.proxyRequestTransformer != nil { + reverser.Director = func(r *http.Request) { + originalURL := r.URL.String() + originalDirector(r) + p.proxyRequestTransformer.TransformRequest(r) + logger.Debug(ctx, "proxying request", logger.F("targetURL", r.URL.String()), logger.F("originalURL", originalURL)) + } + } + + if p.proxyResponseTransformer != nil { + reverser.ModifyResponse = func(r *http.Response) error { + if err := p.proxyResponseTransformer.TransformResponse(r); err != nil { + return errors.WithStack(err) + } + + return nil + } + } + + p.reversers.Store(key, reverser) + } + + raw, exists := p.reversers.Load(key) + if !exists { + createAndStore() + } + + reverser, ok := raw.(*httputil.ReverseProxy) + if !ok { + createAndStore() + } + + reverser.ServeHTTP(w, r) +} + +func New(funcs ...OptionFunc) *Proxy { + opts := defaultOptions() + for _, fn := range funcs { + fn(opts) + } + + proxy := &Proxy{} + + handler := http.HandlerFunc(proxy.proxyRequest) + proxy.handler = createMiddlewareChain(handler, opts.Middlewares) + + proxy.proxyRequestTransformer = createProxyRequestChain(&defaultProxyRequestTransformer{}, opts.ProxyRequestMiddlewares) + proxy.proxyResponseTransformer = createProxyResponseChain(&defaultProxyResponseTransformer{}, opts.ProxyResponseMiddlewares) + + return proxy +} + +var _ http.Handler = &Proxy{} + +func createMiddlewareChain(handler http.Handler, middlewares []Middleware) http.Handler { + reverse(middlewares) + + for _, m := range middlewares { + handler = m(handler) + } + + return handler +} + +func createProxyResponseChain(transformer ProxyResponseTransformer, middlewares []ProxyResponseMiddleware) ProxyResponseTransformer { + reverse(middlewares) + + for _, m := range middlewares { + transformer = m(transformer) + } + + return transformer +} + +func createProxyRequestChain(transformer ProxyRequestTransformer, middlewares []ProxyRequestMiddleware) ProxyRequestTransformer { + reverse(middlewares) + + for _, m := range middlewares { + transformer = m(transformer) + } + + return transformer +} + +func reverse[S ~[]E, E any](s S) { + for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 { + s[i], s[j] = s[j], s[i] + } +} diff --git a/pkg/proxy/wildcard/match.go b/pkg/proxy/wildcard/match.go new file mode 100644 index 0000000..3c4ce51 --- /dev/null +++ b/pkg/proxy/wildcard/match.go @@ -0,0 +1,44 @@ +package wildcard + +const wildcard = '*' + +func Match(str, pattern string) bool { + if pattern == "" { + return str == pattern + } + + if pattern == string(wildcard) { + return true + } + + return deepMatchRune([]rune(str), []rune(pattern)) +} + +func MatchAny(str string, patterns ...string) bool { + for _, p := range patterns { + if matches := Match(str, p); matches { + return matches + } + } + + return false +} + +func deepMatchRune(str, pattern []rune) bool { + for len(pattern) > 0 { + switch pattern[0] { + default: + if len(str) == 0 || str[0] != pattern[0] { + return false + } + case wildcard: + return deepMatchRune(str, pattern[1:]) || + (len(str) > 0 && deepMatchRune(str[1:], pattern)) + } + + str = str[1:] + pattern = pattern[1:] + } + + return len(str) == 0 && len(pattern) == 0 +}