feat: simplify middleware stack
Cadoles/go-proxy/pipeline/head This commit looks good Details

This commit is contained in:
wpetit 2023-04-24 20:14:31 +02:00
parent 56b7434498
commit e2dc3e1a03
4 changed files with 87 additions and 72 deletions

View File

@ -1,33 +1,11 @@
package proxy
import "net/http"
import (
"net/http"
)
type Middleware func(h http.Handler) http.Handler
type ProxyResponseTransformer interface {
TransformResponse(*http.Response) error
}
type defaultProxyResponseTransformer struct{}
// TransformResponse implements ProxyResponseTransformer
func (*defaultProxyResponseTransformer) TransformResponse(*http.Response) error {
return nil
}
var _ ProxyResponseTransformer = &defaultProxyResponseTransformer{}
type ProxyResponseMiddleware func(ProxyResponseTransformer) ProxyResponseTransformer
type ProxyRequestTransformer interface {
TransformRequest(*http.Request)
}
type ProxyRequestMiddleware func(ProxyRequestTransformer) ProxyRequestTransformer
type defaultProxyRequestTransformer struct{}
// TransformRequest implements ProxyRequestTransformer
func (*defaultProxyRequestTransformer) TransformRequest(*http.Request) {}
var _ ProxyRequestTransformer = &defaultProxyRequestTransformer{}
type (
Middleware func(h http.Handler) http.Handler
RequestTransformer func(r *http.Request)
ResponseTransformer func(r *http.Response) error
)

View File

@ -1,29 +1,35 @@
package proxy
type Options struct {
Middlewares []Middleware
ProxyRequestMiddlewares []ProxyRequestMiddleware
ProxyResponseMiddlewares []ProxyResponseMiddleware
Middlewares []Middleware
RequestTransformers []RequestTransformer
ResponseTransformers []ResponseTransformer
}
func defaultOptions() *Options {
return &Options{
Middlewares: make([]Middleware, 0),
ProxyRequestMiddlewares: make([]ProxyRequestMiddleware, 0),
ProxyResponseMiddlewares: make([]ProxyResponseMiddleware, 0),
Middlewares: make([]Middleware, 0),
RequestTransformers: make([]RequestTransformer, 0),
ResponseTransformers: make([]ResponseTransformer, 0),
}
}
type OptionFunc func(*Options)
func WithProxyRequestMiddlewares(middlewares ...ProxyRequestMiddleware) OptionFunc {
func WithMiddlewares(middlewares ...Middleware) OptionFunc {
return func(o *Options) {
o.ProxyRequestMiddlewares = middlewares
o.Middlewares = middlewares
}
}
func WithproxyResponseMiddlewares(middlewares ...ProxyResponseMiddleware) OptionFunc {
func WithRequestTransformers(transformers ...RequestTransformer) OptionFunc {
return func(o *Options) {
o.ProxyResponseMiddlewares = middlewares
o.RequestTransformers = transformers
}
}
func WithResponseTransformers(transformers ...ResponseTransformer) OptionFunc {
return func(o *Options) {
o.ResponseTransformers = transformers
}
}

View File

@ -12,10 +12,10 @@ import (
)
type Proxy struct {
reversers sync.Map
handler http.Handler
proxyResponseTransformer ProxyResponseTransformer
proxyRequestTransformer ProxyRequestTransformer
reversers sync.Map
handler http.Handler
responseTransformers []ResponseTransformer
requestTransformers []RequestTransformer
}
// ServeHTTP implements http.Handler
@ -34,21 +34,29 @@ func (p *Proxy) proxyRequest(w http.ResponseWriter, r *http.Request) {
Host: r.URL.Host,
}
if target.Host == "" || target.Scheme == "" {
return
}
reverser = httputil.NewSingleHostReverseProxy(target)
originalDirector := reverser.Director
if p.proxyRequestTransformer != nil {
if len(p.requestTransformers) > 0 {
reverser.Director = func(r *http.Request) {
originalDirector(r)
p.proxyRequestTransformer.TransformRequest(r)
for _, t := range p.requestTransformers {
t(r)
}
}
}
if p.proxyResponseTransformer != nil {
if len(p.responseTransformers) > 0 {
reverser.ModifyResponse = func(r *http.Response) error {
if err := p.proxyResponseTransformer.TransformResponse(r); err != nil {
return errors.WithStack(err)
for _, t := range p.responseTransformers {
if err := t(r); err != nil {
return errors.WithStack(err)
}
}
return nil
@ -68,6 +76,12 @@ func (p *Proxy) proxyRequest(w http.ResponseWriter, r *http.Request) {
createAndStore()
}
if reverser == nil {
http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
return
}
reverser.ServeHTTP(w, r)
}
@ -81,9 +95,8 @@ func New(funcs ...OptionFunc) *Proxy {
handler := http.HandlerFunc(proxy.proxyRequest)
proxy.handler = createMiddlewareChain(handler, opts.Middlewares)
proxy.proxyRequestTransformer = createProxyRequestChain(&defaultProxyRequestTransformer{}, opts.ProxyRequestMiddlewares)
proxy.proxyResponseTransformer = createProxyResponseChain(&defaultProxyResponseTransformer{}, opts.ProxyResponseMiddlewares)
proxy.requestTransformers = opts.RequestTransformers
proxy.responseTransformers = opts.ResponseTransformers
return proxy
}
@ -99,23 +112,3 @@ func createMiddlewareChain(handler http.Handler, middlewares []Middleware) http.
return handler
}
func createProxyResponseChain(transformer ProxyResponseTransformer, middlewares []ProxyResponseMiddleware) ProxyResponseTransformer {
util.Reverse(middlewares)
for _, m := range middlewares {
transformer = m(transformer)
}
return transformer
}
func createProxyRequestChain(transformer ProxyRequestTransformer, middlewares []ProxyRequestMiddleware) ProxyRequestTransformer {
util.Reverse(middlewares)
for _, m := range middlewares {
transformer = m(transformer)
}
return transformer
}

38
proxy_test.go Normal file
View File

@ -0,0 +1,38 @@
package proxy
import (
"io/ioutil"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/pkg/errors"
)
func TestProxy(t *testing.T) {
res := httptest.NewRecorder()
req, err := http.NewRequest("GET", "http://example.com", nil)
if err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
proxy := New()
proxy.ServeHTTP(res, req)
if e, g := http.StatusOK, res.Code; e != g {
t.Errorf("res.Code: expected '%v', got '%v'", e, g)
}
body, err := ioutil.ReadAll(res.Body)
if err != nil {
t.Errorf("%+v", errors.WithStack(err))
}
marker := "<title>Example Domain</title>"
if !strings.Contains(string(body), marker) {
t.Errorf("could not find expected marker in response body")
}
}