package proxy import ( "fmt" "net/http" "net/http/httputil" "net/url" "sync" "forge.cadoles.com/Cadoles/go-proxy/util" "github.com/pkg/errors" ) type Proxy struct { reversers sync.Map handler http.Handler responseTransformers []ResponseTransformer requestTransformers []RequestTransformer reverseProxyFactory ReverseProxyFactory defaultHandler http.Handler } // ServeHTTP implements http.Handler func (p *Proxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { p.handler.ServeHTTP(w, r) } func (p *Proxy) proxyRequest(w http.ResponseWriter, r *http.Request) { var reverser *httputil.ReverseProxy key := fmt.Sprintf("%s://%s", r.URL.Scheme, r.URL.Host) ctx := r.Context() createAndStore := func() { target := &url.URL{ Scheme: r.URL.Scheme, Host: r.URL.Host, } if target.Host == "" || target.Scheme == "" { return } reverser = p.reverseProxyFactory(ctx, target) if reverser == nil { return } originalDirector := reverser.Director if len(p.requestTransformers) > 0 { reverser.Director = func(r *http.Request) { originalDirector(r) for _, t := range p.requestTransformers { t(r) } } } if len(p.responseTransformers) > 0 { reverser.ModifyResponse = func(r *http.Response) error { for _, t := range p.responseTransformers { if err := t(r); err != nil { return errors.WithStack(err) } } return nil } } p.reversers.Store(key, reverser) } raw, exists := p.reversers.Load(key) if !exists { createAndStore() } reverser, ok := raw.(*httputil.ReverseProxy) if !ok { createAndStore() } if reverser == nil { p.defaultHandler.ServeHTTP(w, r) return } reverser.ServeHTTP(w, r) } func New(funcs ...OptionFunc) *Proxy { opts := defaultOptions() for _, fn := range funcs { fn(opts) } proxy := &Proxy{} handler := http.HandlerFunc(proxy.proxyRequest) proxy.handler = createMiddlewareChain(handler, opts.Middlewares) proxy.requestTransformers = opts.RequestTransformers proxy.responseTransformers = opts.ResponseTransformers proxy.reverseProxyFactory = opts.ReverseProxyFactory proxy.defaultHandler = opts.DefaultHandler return proxy } var _ http.Handler = &Proxy{} func createMiddlewareChain(handler http.Handler, middlewares []Middleware) http.Handler { util.Reverse(middlewares) for _, m := range middlewares { handler = m(handler) } return handler }