diff --git a/middleware.go b/middleware.go index b1efb0f..d61723e 100644 --- a/middleware.go +++ b/middleware.go @@ -1,33 +1,11 @@ package proxy -import "net/http" +import ( + "net/http" +) -type Middleware func(h http.Handler) http.Handler - -type ProxyResponseTransformer interface { - TransformResponse(*http.Response) error -} - -type defaultProxyResponseTransformer struct{} - -// TransformResponse implements ProxyResponseTransformer -func (*defaultProxyResponseTransformer) TransformResponse(*http.Response) error { - return nil -} - -var _ ProxyResponseTransformer = &defaultProxyResponseTransformer{} - -type ProxyResponseMiddleware func(ProxyResponseTransformer) ProxyResponseTransformer - -type ProxyRequestTransformer interface { - TransformRequest(*http.Request) -} - -type ProxyRequestMiddleware func(ProxyRequestTransformer) ProxyRequestTransformer - -type defaultProxyRequestTransformer struct{} - -// TransformRequest implements ProxyRequestTransformer -func (*defaultProxyRequestTransformer) TransformRequest(*http.Request) {} - -var _ ProxyRequestTransformer = &defaultProxyRequestTransformer{} +type ( + Middleware func(h http.Handler) http.Handler + RequestTransformer func(r *http.Request) + ResponseTransformer func(r *http.Response) error +) diff --git a/options.go b/options.go index 22a6945..3dd168f 100644 --- a/options.go +++ b/options.go @@ -1,29 +1,35 @@ package proxy type Options struct { - Middlewares []Middleware - ProxyRequestMiddlewares []ProxyRequestMiddleware - ProxyResponseMiddlewares []ProxyResponseMiddleware + Middlewares []Middleware + RequestTransformers []RequestTransformer + ResponseTransformers []ResponseTransformer } func defaultOptions() *Options { return &Options{ - Middlewares: make([]Middleware, 0), - ProxyRequestMiddlewares: make([]ProxyRequestMiddleware, 0), - ProxyResponseMiddlewares: make([]ProxyResponseMiddleware, 0), + Middlewares: make([]Middleware, 0), + RequestTransformers: make([]RequestTransformer, 0), + ResponseTransformers: make([]ResponseTransformer, 0), } } type OptionFunc func(*Options) -func WithProxyRequestMiddlewares(middlewares ...ProxyRequestMiddleware) OptionFunc { +func WithMiddlewares(middlewares ...Middleware) OptionFunc { return func(o *Options) { - o.ProxyRequestMiddlewares = middlewares + o.Middlewares = middlewares } } -func WithproxyResponseMiddlewares(middlewares ...ProxyResponseMiddleware) OptionFunc { +func WithRequestTransformers(transformers ...RequestTransformer) OptionFunc { return func(o *Options) { - o.ProxyResponseMiddlewares = middlewares + o.RequestTransformers = transformers + } +} + +func WithResponseTransformers(transformers ...ResponseTransformer) OptionFunc { + return func(o *Options) { + o.ResponseTransformers = transformers } } diff --git a/proxy.go b/proxy.go index c4ea2b7..04b518a 100644 --- a/proxy.go +++ b/proxy.go @@ -12,10 +12,10 @@ import ( ) type Proxy struct { - reversers sync.Map - handler http.Handler - proxyResponseTransformer ProxyResponseTransformer - proxyRequestTransformer ProxyRequestTransformer + reversers sync.Map + handler http.Handler + responseTransformers []ResponseTransformer + requestTransformers []RequestTransformer } // ServeHTTP implements http.Handler @@ -34,21 +34,29 @@ func (p *Proxy) proxyRequest(w http.ResponseWriter, r *http.Request) { Host: r.URL.Host, } + if target.Host == "" || target.Scheme == "" { + return + } + reverser = httputil.NewSingleHostReverseProxy(target) originalDirector := reverser.Director - if p.proxyRequestTransformer != nil { + if len(p.requestTransformers) > 0 { reverser.Director = func(r *http.Request) { originalDirector(r) - p.proxyRequestTransformer.TransformRequest(r) + for _, t := range p.requestTransformers { + t(r) + } } } - if p.proxyResponseTransformer != nil { + if len(p.responseTransformers) > 0 { reverser.ModifyResponse = func(r *http.Response) error { - if err := p.proxyResponseTransformer.TransformResponse(r); err != nil { - return errors.WithStack(err) + for _, t := range p.responseTransformers { + if err := t(r); err != nil { + return errors.WithStack(err) + } } return nil @@ -68,6 +76,12 @@ func (p *Proxy) proxyRequest(w http.ResponseWriter, r *http.Request) { createAndStore() } + if reverser == nil { + http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway) + + return + } + reverser.ServeHTTP(w, r) } @@ -81,9 +95,8 @@ func New(funcs ...OptionFunc) *Proxy { handler := http.HandlerFunc(proxy.proxyRequest) proxy.handler = createMiddlewareChain(handler, opts.Middlewares) - - proxy.proxyRequestTransformer = createProxyRequestChain(&defaultProxyRequestTransformer{}, opts.ProxyRequestMiddlewares) - proxy.proxyResponseTransformer = createProxyResponseChain(&defaultProxyResponseTransformer{}, opts.ProxyResponseMiddlewares) + proxy.requestTransformers = opts.RequestTransformers + proxy.responseTransformers = opts.ResponseTransformers return proxy } @@ -99,23 +112,3 @@ func createMiddlewareChain(handler http.Handler, middlewares []Middleware) http. return handler } - -func createProxyResponseChain(transformer ProxyResponseTransformer, middlewares []ProxyResponseMiddleware) ProxyResponseTransformer { - util.Reverse(middlewares) - - for _, m := range middlewares { - transformer = m(transformer) - } - - return transformer -} - -func createProxyRequestChain(transformer ProxyRequestTransformer, middlewares []ProxyRequestMiddleware) ProxyRequestTransformer { - util.Reverse(middlewares) - - for _, m := range middlewares { - transformer = m(transformer) - } - - return transformer -} diff --git a/proxy_test.go b/proxy_test.go new file mode 100644 index 0000000..3f4d4b5 --- /dev/null +++ b/proxy_test.go @@ -0,0 +1,38 @@ +package proxy + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/pkg/errors" +) + +func TestProxy(t *testing.T) { + res := httptest.NewRecorder() + req, err := http.NewRequest("GET", "http://example.com", nil) + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + proxy := New() + + proxy.ServeHTTP(res, req) + + if e, g := http.StatusOK, res.Code; e != g { + t.Errorf("res.Code: expected '%v', got '%v'", e, g) + } + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + t.Errorf("%+v", errors.WithStack(err)) + } + + marker := "Example Domain" + + if !strings.Contains(string(body), marker) { + t.Errorf("could not find expected marker in response body") + } +}