feat: simplify middleware stack
All checks were successful
Cadoles/go-proxy/pipeline/head This commit looks good
All checks were successful
Cadoles/go-proxy/pipeline/head This commit looks good
This commit is contained in:
parent
56b7434498
commit
e2dc3e1a03
@ -1,33 +1,11 @@
|
||||
package proxy
|
||||
|
||||
import "net/http"
|
||||
import (
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type Middleware func(h http.Handler) http.Handler
|
||||
|
||||
type ProxyResponseTransformer interface {
|
||||
TransformResponse(*http.Response) error
|
||||
}
|
||||
|
||||
type defaultProxyResponseTransformer struct{}
|
||||
|
||||
// TransformResponse implements ProxyResponseTransformer
|
||||
func (*defaultProxyResponseTransformer) TransformResponse(*http.Response) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ ProxyResponseTransformer = &defaultProxyResponseTransformer{}
|
||||
|
||||
type ProxyResponseMiddleware func(ProxyResponseTransformer) ProxyResponseTransformer
|
||||
|
||||
type ProxyRequestTransformer interface {
|
||||
TransformRequest(*http.Request)
|
||||
}
|
||||
|
||||
type ProxyRequestMiddleware func(ProxyRequestTransformer) ProxyRequestTransformer
|
||||
|
||||
type defaultProxyRequestTransformer struct{}
|
||||
|
||||
// TransformRequest implements ProxyRequestTransformer
|
||||
func (*defaultProxyRequestTransformer) TransformRequest(*http.Request) {}
|
||||
|
||||
var _ ProxyRequestTransformer = &defaultProxyRequestTransformer{}
|
||||
type (
|
||||
Middleware func(h http.Handler) http.Handler
|
||||
RequestTransformer func(r *http.Request)
|
||||
ResponseTransformer func(r *http.Response) error
|
||||
)
|
||||
|
26
options.go
26
options.go
@ -1,29 +1,35 @@
|
||||
package proxy
|
||||
|
||||
type Options struct {
|
||||
Middlewares []Middleware
|
||||
ProxyRequestMiddlewares []ProxyRequestMiddleware
|
||||
ProxyResponseMiddlewares []ProxyResponseMiddleware
|
||||
Middlewares []Middleware
|
||||
RequestTransformers []RequestTransformer
|
||||
ResponseTransformers []ResponseTransformer
|
||||
}
|
||||
|
||||
func defaultOptions() *Options {
|
||||
return &Options{
|
||||
Middlewares: make([]Middleware, 0),
|
||||
ProxyRequestMiddlewares: make([]ProxyRequestMiddleware, 0),
|
||||
ProxyResponseMiddlewares: make([]ProxyResponseMiddleware, 0),
|
||||
Middlewares: make([]Middleware, 0),
|
||||
RequestTransformers: make([]RequestTransformer, 0),
|
||||
ResponseTransformers: make([]ResponseTransformer, 0),
|
||||
}
|
||||
}
|
||||
|
||||
type OptionFunc func(*Options)
|
||||
|
||||
func WithProxyRequestMiddlewares(middlewares ...ProxyRequestMiddleware) OptionFunc {
|
||||
func WithMiddlewares(middlewares ...Middleware) OptionFunc {
|
||||
return func(o *Options) {
|
||||
o.ProxyRequestMiddlewares = middlewares
|
||||
o.Middlewares = middlewares
|
||||
}
|
||||
}
|
||||
|
||||
func WithproxyResponseMiddlewares(middlewares ...ProxyResponseMiddleware) OptionFunc {
|
||||
func WithRequestTransformers(transformers ...RequestTransformer) OptionFunc {
|
||||
return func(o *Options) {
|
||||
o.ProxyResponseMiddlewares = middlewares
|
||||
o.RequestTransformers = transformers
|
||||
}
|
||||
}
|
||||
|
||||
func WithResponseTransformers(transformers ...ResponseTransformer) OptionFunc {
|
||||
return func(o *Options) {
|
||||
o.ResponseTransformers = transformers
|
||||
}
|
||||
}
|
||||
|
57
proxy.go
57
proxy.go
@ -12,10 +12,10 @@ import (
|
||||
)
|
||||
|
||||
type Proxy struct {
|
||||
reversers sync.Map
|
||||
handler http.Handler
|
||||
proxyResponseTransformer ProxyResponseTransformer
|
||||
proxyRequestTransformer ProxyRequestTransformer
|
||||
reversers sync.Map
|
||||
handler http.Handler
|
||||
responseTransformers []ResponseTransformer
|
||||
requestTransformers []RequestTransformer
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler
|
||||
@ -34,21 +34,29 @@ func (p *Proxy) proxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
Host: r.URL.Host,
|
||||
}
|
||||
|
||||
if target.Host == "" || target.Scheme == "" {
|
||||
return
|
||||
}
|
||||
|
||||
reverser = httputil.NewSingleHostReverseProxy(target)
|
||||
|
||||
originalDirector := reverser.Director
|
||||
|
||||
if p.proxyRequestTransformer != nil {
|
||||
if len(p.requestTransformers) > 0 {
|
||||
reverser.Director = func(r *http.Request) {
|
||||
originalDirector(r)
|
||||
p.proxyRequestTransformer.TransformRequest(r)
|
||||
for _, t := range p.requestTransformers {
|
||||
t(r)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if p.proxyResponseTransformer != nil {
|
||||
if len(p.responseTransformers) > 0 {
|
||||
reverser.ModifyResponse = func(r *http.Response) error {
|
||||
if err := p.proxyResponseTransformer.TransformResponse(r); err != nil {
|
||||
return errors.WithStack(err)
|
||||
for _, t := range p.responseTransformers {
|
||||
if err := t(r); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
@ -68,6 +76,12 @@ func (p *Proxy) proxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||
createAndStore()
|
||||
}
|
||||
|
||||
if reverser == nil {
|
||||
http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
reverser.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
@ -81,9 +95,8 @@ func New(funcs ...OptionFunc) *Proxy {
|
||||
|
||||
handler := http.HandlerFunc(proxy.proxyRequest)
|
||||
proxy.handler = createMiddlewareChain(handler, opts.Middlewares)
|
||||
|
||||
proxy.proxyRequestTransformer = createProxyRequestChain(&defaultProxyRequestTransformer{}, opts.ProxyRequestMiddlewares)
|
||||
proxy.proxyResponseTransformer = createProxyResponseChain(&defaultProxyResponseTransformer{}, opts.ProxyResponseMiddlewares)
|
||||
proxy.requestTransformers = opts.RequestTransformers
|
||||
proxy.responseTransformers = opts.ResponseTransformers
|
||||
|
||||
return proxy
|
||||
}
|
||||
@ -99,23 +112,3 @@ func createMiddlewareChain(handler http.Handler, middlewares []Middleware) http.
|
||||
|
||||
return handler
|
||||
}
|
||||
|
||||
func createProxyResponseChain(transformer ProxyResponseTransformer, middlewares []ProxyResponseMiddleware) ProxyResponseTransformer {
|
||||
util.Reverse(middlewares)
|
||||
|
||||
for _, m := range middlewares {
|
||||
transformer = m(transformer)
|
||||
}
|
||||
|
||||
return transformer
|
||||
}
|
||||
|
||||
func createProxyRequestChain(transformer ProxyRequestTransformer, middlewares []ProxyRequestMiddleware) ProxyRequestTransformer {
|
||||
util.Reverse(middlewares)
|
||||
|
||||
for _, m := range middlewares {
|
||||
transformer = m(transformer)
|
||||
}
|
||||
|
||||
return transformer
|
||||
}
|
||||
|
38
proxy_test.go
Normal file
38
proxy_test.go
Normal file
@ -0,0 +1,38 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func TestProxy(t *testing.T) {
|
||||
res := httptest.NewRecorder()
|
||||
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||
if err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
proxy := New()
|
||||
|
||||
proxy.ServeHTTP(res, req)
|
||||
|
||||
if e, g := http.StatusOK, res.Code; e != g {
|
||||
t.Errorf("res.Code: expected '%v', got '%v'", e, g)
|
||||
}
|
||||
|
||||
body, err := ioutil.ReadAll(res.Body)
|
||||
if err != nil {
|
||||
t.Errorf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
marker := "<title>Example Domain</title>"
|
||||
|
||||
if !strings.Contains(string(body), marker) {
|
||||
t.Errorf("could not find expected marker in response body")
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user