feat: simplify middleware stack
Cadoles/go-proxy/pipeline/head This commit looks good Details

This commit is contained in:
wpetit 2023-04-24 20:14:31 +02:00
parent 56b7434498
commit e2dc3e1a03
4 changed files with 87 additions and 72 deletions

View File

@ -1,33 +1,11 @@
package proxy package proxy
import "net/http" import (
"net/http"
)
type Middleware func(h http.Handler) http.Handler type (
Middleware func(h http.Handler) http.Handler
type ProxyResponseTransformer interface { RequestTransformer func(r *http.Request)
TransformResponse(*http.Response) error ResponseTransformer func(r *http.Response) error
} )
type defaultProxyResponseTransformer struct{}
// TransformResponse implements ProxyResponseTransformer
func (*defaultProxyResponseTransformer) TransformResponse(*http.Response) error {
return nil
}
var _ ProxyResponseTransformer = &defaultProxyResponseTransformer{}
type ProxyResponseMiddleware func(ProxyResponseTransformer) ProxyResponseTransformer
type ProxyRequestTransformer interface {
TransformRequest(*http.Request)
}
type ProxyRequestMiddleware func(ProxyRequestTransformer) ProxyRequestTransformer
type defaultProxyRequestTransformer struct{}
// TransformRequest implements ProxyRequestTransformer
func (*defaultProxyRequestTransformer) TransformRequest(*http.Request) {}
var _ ProxyRequestTransformer = &defaultProxyRequestTransformer{}

View File

@ -1,29 +1,35 @@
package proxy package proxy
type Options struct { type Options struct {
Middlewares []Middleware Middlewares []Middleware
ProxyRequestMiddlewares []ProxyRequestMiddleware RequestTransformers []RequestTransformer
ProxyResponseMiddlewares []ProxyResponseMiddleware ResponseTransformers []ResponseTransformer
} }
func defaultOptions() *Options { func defaultOptions() *Options {
return &Options{ return &Options{
Middlewares: make([]Middleware, 0), Middlewares: make([]Middleware, 0),
ProxyRequestMiddlewares: make([]ProxyRequestMiddleware, 0), RequestTransformers: make([]RequestTransformer, 0),
ProxyResponseMiddlewares: make([]ProxyResponseMiddleware, 0), ResponseTransformers: make([]ResponseTransformer, 0),
} }
} }
type OptionFunc func(*Options) type OptionFunc func(*Options)
func WithProxyRequestMiddlewares(middlewares ...ProxyRequestMiddleware) OptionFunc { func WithMiddlewares(middlewares ...Middleware) OptionFunc {
return func(o *Options) { return func(o *Options) {
o.ProxyRequestMiddlewares = middlewares o.Middlewares = middlewares
} }
} }
func WithproxyResponseMiddlewares(middlewares ...ProxyResponseMiddleware) OptionFunc { func WithRequestTransformers(transformers ...RequestTransformer) OptionFunc {
return func(o *Options) { return func(o *Options) {
o.ProxyResponseMiddlewares = middlewares o.RequestTransformers = transformers
}
}
func WithResponseTransformers(transformers ...ResponseTransformer) OptionFunc {
return func(o *Options) {
o.ResponseTransformers = transformers
} }
} }

View File

@ -12,10 +12,10 @@ import (
) )
type Proxy struct { type Proxy struct {
reversers sync.Map reversers sync.Map
handler http.Handler handler http.Handler
proxyResponseTransformer ProxyResponseTransformer responseTransformers []ResponseTransformer
proxyRequestTransformer ProxyRequestTransformer requestTransformers []RequestTransformer
} }
// ServeHTTP implements http.Handler // ServeHTTP implements http.Handler
@ -34,21 +34,29 @@ func (p *Proxy) proxyRequest(w http.ResponseWriter, r *http.Request) {
Host: r.URL.Host, Host: r.URL.Host,
} }
if target.Host == "" || target.Scheme == "" {
return
}
reverser = httputil.NewSingleHostReverseProxy(target) reverser = httputil.NewSingleHostReverseProxy(target)
originalDirector := reverser.Director originalDirector := reverser.Director
if p.proxyRequestTransformer != nil { if len(p.requestTransformers) > 0 {
reverser.Director = func(r *http.Request) { reverser.Director = func(r *http.Request) {
originalDirector(r) originalDirector(r)
p.proxyRequestTransformer.TransformRequest(r) for _, t := range p.requestTransformers {
t(r)
}
} }
} }
if p.proxyResponseTransformer != nil { if len(p.responseTransformers) > 0 {
reverser.ModifyResponse = func(r *http.Response) error { reverser.ModifyResponse = func(r *http.Response) error {
if err := p.proxyResponseTransformer.TransformResponse(r); err != nil { for _, t := range p.responseTransformers {
return errors.WithStack(err) if err := t(r); err != nil {
return errors.WithStack(err)
}
} }
return nil return nil
@ -68,6 +76,12 @@ func (p *Proxy) proxyRequest(w http.ResponseWriter, r *http.Request) {
createAndStore() createAndStore()
} }
if reverser == nil {
http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
return
}
reverser.ServeHTTP(w, r) reverser.ServeHTTP(w, r)
} }
@ -81,9 +95,8 @@ func New(funcs ...OptionFunc) *Proxy {
handler := http.HandlerFunc(proxy.proxyRequest) handler := http.HandlerFunc(proxy.proxyRequest)
proxy.handler = createMiddlewareChain(handler, opts.Middlewares) proxy.handler = createMiddlewareChain(handler, opts.Middlewares)
proxy.requestTransformers = opts.RequestTransformers
proxy.proxyRequestTransformer = createProxyRequestChain(&defaultProxyRequestTransformer{}, opts.ProxyRequestMiddlewares) proxy.responseTransformers = opts.ResponseTransformers
proxy.proxyResponseTransformer = createProxyResponseChain(&defaultProxyResponseTransformer{}, opts.ProxyResponseMiddlewares)
return proxy return proxy
} }
@ -99,23 +112,3 @@ func createMiddlewareChain(handler http.Handler, middlewares []Middleware) http.
return handler return handler
} }
func createProxyResponseChain(transformer ProxyResponseTransformer, middlewares []ProxyResponseMiddleware) ProxyResponseTransformer {
util.Reverse(middlewares)
for _, m := range middlewares {
transformer = m(transformer)
}
return transformer
}
func createProxyRequestChain(transformer ProxyRequestTransformer, middlewares []ProxyRequestMiddleware) ProxyRequestTransformer {
util.Reverse(middlewares)
for _, m := range middlewares {
transformer = m(transformer)
}
return transformer
}

38
proxy_test.go Normal file
View File

@ -0,0 +1,38 @@
package proxy
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/pkg/errors"
)
func TestProxy(t *testing.T) {
res := httptest.NewRecorder()
req, err := http.NewRequest("GET", "http://example.com", nil)
if err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
proxy := New()
proxy.ServeHTTP(res, req)
if e, g := http.StatusOK, res.Code; e != g {
t.Errorf("res.Code: expected '%v', got '%v'", e, g)
}
body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Errorf("%+v", errors.WithStack(err))
}
marker := "<title>Example Domain</title>"
if !strings.Contains(string(body), marker) {
t.Errorf("could not find expected marker in response body")
}
}