package proxy

import (
	"net/http"
	"net/http/httputil"
	"net/url"
	"sync"

	"github.com/pkg/errors"
	"gitlab.com/wpetit/goweb/logger"
)

// Proxy is an http.Handler that forwards each incoming request to the
// upstream named by the request URL's scheme and host, passing it through
// the configured handler middlewares and request/response transformers.
//
// A Proxy must be created with New: the zero value has no handler and
// ServeHTTP would panic on it.
type Proxy struct {
	// reversers caches one *httputil.ReverseProxy per "scheme://host" key.
	reversers                sync.Map
	handler                  http.Handler
	proxyResponseTransformer ProxyResponseTransformer
	proxyRequestTransformer  ProxyRequestTransformer
}

// ServeHTTP implements http.Handler by delegating to the middleware-wrapped
// handler built in New.
func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	p.handler.ServeHTTP(w, r)
}

// proxyRequest forwards r to r.URL.Scheme://r.URL.Host using a cached
// reverse proxy for that upstream.
func (p *Proxy) proxyRequest(w http.ResponseWriter, r *http.Request) {
	p.reverserFor(r.URL.Scheme, r.URL.Host).ServeHTTP(w, r)
}

// reverserFor returns the cached reverse proxy for scheme://host, creating
// and caching one if needed. LoadOrStore guarantees that concurrent first
// requests for the same upstream end up sharing a single instance.
func (p *Proxy) reverserFor(scheme, host string) *httputil.ReverseProxy {
	key := scheme + "://" + host

	if raw, exists := p.reversers.Load(key); exists {
		if reverser, ok := raw.(*httputil.ReverseProxy); ok {
			return reverser
		}
	}

	reverser := p.newReverser(scheme, host)

	if actual, loaded := p.reversers.LoadOrStore(key, reverser); loaded {
		if existing, ok := actual.(*httputil.ReverseProxy); ok {
			return existing
		}
		// Unexpected value under this key: replace it with a valid proxy.
		p.reversers.Store(key, reverser)
	}

	return reverser
}

// newReverser builds a single-host reverse proxy targeting scheme://host and
// wires it to the Proxy's request and response transformers (when set).
func (p *Proxy) newReverser(scheme, host string) *httputil.ReverseProxy {
	target := &url.URL{
		Scheme: scheme,
		Host:   host,
	}

	reverser := httputil.NewSingleHostReverseProxy(target)

	if p.proxyRequestTransformer != nil {
		originalDirector := reverser.Director

		reverser.Director = func(outReq *http.Request) {
			originalURL := outReq.URL.String()

			originalDirector(outReq)
			p.proxyRequestTransformer.TransformRequest(outReq)

			// The reverse proxy is cached and shared across requests, so log
			// with the context of the request being proxied right now rather
			// than one captured when the proxy was created.
			logger.Debug(
				outReq.Context(), "proxying request",
				logger.F("targetURL", outReq.URL.String()),
				logger.F("originalURL", originalURL),
			)
		}
	}

	if p.proxyResponseTransformer != nil {
		reverser.ModifyResponse = func(res *http.Response) error {
			if err := p.proxyResponseTransformer.TransformResponse(res); err != nil {
				return errors.WithStack(err)
			}

			return nil
		}
	}

	return reverser
}

// New creates a Proxy configured by the given option functions.
//
// For each middleware list in the options, the first entry becomes the
// outermost wrapper.
func New(funcs ...OptionFunc) *Proxy {
	opts := defaultOptions()
	for _, fn := range funcs {
		fn(opts)
	}

	proxy := &Proxy{}

	handler := http.HandlerFunc(proxy.proxyRequest)

	proxy.handler = createMiddlewareChain(handler, opts.Middlewares)
	proxy.proxyRequestTransformer = createProxyRequestChain(&defaultProxyRequestTransformer{}, opts.ProxyRequestMiddlewares)
	proxy.proxyResponseTransformer = createProxyResponseChain(&defaultProxyResponseTransformer{}, opts.ProxyResponseMiddlewares)

	return proxy
}

var _ http.Handler = (*Proxy)(nil)

// createMiddlewareChain wraps handler with middlewares so that middlewares[0]
// is outermost. The slice is walked backwards rather than reversed in place,
// so the caller's slice is never mutated.
func createMiddlewareChain(handler http.Handler, middlewares []Middleware) http.Handler {
	for i := len(middlewares) - 1; i >= 0; i-- {
		handler = middlewares[i](handler)
	}

	return handler
}

// createProxyResponseChain wraps transformer with middlewares so that
// middlewares[0] is outermost, without mutating the caller's slice.
func createProxyResponseChain(transformer ProxyResponseTransformer, middlewares []ProxyResponseMiddleware) ProxyResponseTransformer {
	for i := len(middlewares) - 1; i >= 0; i-- {
		transformer = middlewares[i](transformer)
	}

	return transformer
}

// createProxyRequestChain wraps transformer with middlewares so that
// middlewares[0] is outermost, without mutating the caller's slice.
func createProxyRequestChain(transformer ProxyRequestTransformer, middlewares []ProxyRequestMiddleware) ProxyRequestTransformer {
	for i := len(middlewares) - 1; i >= 0; i-- {
		transformer = middlewares[i](transformer)
	}

	return transformer
}

// reverse reverses s in place.
// NOTE(review): no longer used by the chain builders in this file; kept in
// case other files in the package rely on it.
func reverse[S ~[]E, E any](s S) {
	for i, j := 0, len(s)-1; i < j; i, j = i+1, j-1 {
		s[i], s[j] = s[j], s[i]
	}
}