feat: simplify middleware stack
Cadoles/go-proxy/pipeline/head This commit looks good
Details
Cadoles/go-proxy/pipeline/head This commit looks good
Details
This commit is contained in:
parent
56b7434498
commit
e2dc3e1a03
|
@ -1,33 +1,11 @@
|
||||||
package proxy
|
package proxy
|
||||||
|
|
||||||
import "net/http"
|
import (
|
||||||
|
"net/http"
|
||||||
|
)
|
||||||
|
|
||||||
type Middleware func(h http.Handler) http.Handler
|
type (
|
||||||
|
Middleware func(h http.Handler) http.Handler
|
||||||
type ProxyResponseTransformer interface {
|
RequestTransformer func(r *http.Request)
|
||||||
TransformResponse(*http.Response) error
|
ResponseTransformer func(r *http.Response) error
|
||||||
}
|
)
|
||||||
|
|
||||||
type defaultProxyResponseTransformer struct{}
|
|
||||||
|
|
||||||
// TransformResponse implements ProxyResponseTransformer
|
|
||||||
func (*defaultProxyResponseTransformer) TransformResponse(*http.Response) error {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
var _ ProxyResponseTransformer = &defaultProxyResponseTransformer{}
|
|
||||||
|
|
||||||
type ProxyResponseMiddleware func(ProxyResponseTransformer) ProxyResponseTransformer
|
|
||||||
|
|
||||||
type ProxyRequestTransformer interface {
|
|
||||||
TransformRequest(*http.Request)
|
|
||||||
}
|
|
||||||
|
|
||||||
type ProxyRequestMiddleware func(ProxyRequestTransformer) ProxyRequestTransformer
|
|
||||||
|
|
||||||
type defaultProxyRequestTransformer struct{}
|
|
||||||
|
|
||||||
// TransformRequest implements ProxyRequestTransformer
|
|
||||||
func (*defaultProxyRequestTransformer) TransformRequest(*http.Request) {}
|
|
||||||
|
|
||||||
var _ ProxyRequestTransformer = &defaultProxyRequestTransformer{}
|
|
||||||
|
|
22
options.go
22
options.go
|
@ -2,28 +2,34 @@ package proxy
|
||||||
|
|
||||||
type Options struct {
|
type Options struct {
|
||||||
Middlewares []Middleware
|
Middlewares []Middleware
|
||||||
ProxyRequestMiddlewares []ProxyRequestMiddleware
|
RequestTransformers []RequestTransformer
|
||||||
ProxyResponseMiddlewares []ProxyResponseMiddleware
|
ResponseTransformers []ResponseTransformer
|
||||||
}
|
}
|
||||||
|
|
||||||
func defaultOptions() *Options {
|
func defaultOptions() *Options {
|
||||||
return &Options{
|
return &Options{
|
||||||
Middlewares: make([]Middleware, 0),
|
Middlewares: make([]Middleware, 0),
|
||||||
ProxyRequestMiddlewares: make([]ProxyRequestMiddleware, 0),
|
RequestTransformers: make([]RequestTransformer, 0),
|
||||||
ProxyResponseMiddlewares: make([]ProxyResponseMiddleware, 0),
|
ResponseTransformers: make([]ResponseTransformer, 0),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type OptionFunc func(*Options)
|
type OptionFunc func(*Options)
|
||||||
|
|
||||||
func WithProxyRequestMiddlewares(middlewares ...ProxyRequestMiddleware) OptionFunc {
|
func WithMiddlewares(middlewares ...Middleware) OptionFunc {
|
||||||
return func(o *Options) {
|
return func(o *Options) {
|
||||||
o.ProxyRequestMiddlewares = middlewares
|
o.Middlewares = middlewares
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithproxyResponseMiddlewares(middlewares ...ProxyResponseMiddleware) OptionFunc {
|
func WithRequestTransformers(transformers ...RequestTransformer) OptionFunc {
|
||||||
return func(o *Options) {
|
return func(o *Options) {
|
||||||
o.ProxyResponseMiddlewares = middlewares
|
o.RequestTransformers = transformers
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithResponseTransformers(transformers ...ResponseTransformer) OptionFunc {
|
||||||
|
return func(o *Options) {
|
||||||
|
o.ResponseTransformers = transformers
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
51
proxy.go
51
proxy.go
|
@ -14,8 +14,8 @@ import (
|
||||||
type Proxy struct {
|
type Proxy struct {
|
||||||
reversers sync.Map
|
reversers sync.Map
|
||||||
handler http.Handler
|
handler http.Handler
|
||||||
proxyResponseTransformer ProxyResponseTransformer
|
responseTransformers []ResponseTransformer
|
||||||
proxyRequestTransformer ProxyRequestTransformer
|
requestTransformers []RequestTransformer
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeHTTP implements http.Handler
|
// ServeHTTP implements http.Handler
|
||||||
|
@ -34,22 +34,30 @@ func (p *Proxy) proxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||||
Host: r.URL.Host,
|
Host: r.URL.Host,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if target.Host == "" || target.Scheme == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
reverser = httputil.NewSingleHostReverseProxy(target)
|
reverser = httputil.NewSingleHostReverseProxy(target)
|
||||||
|
|
||||||
originalDirector := reverser.Director
|
originalDirector := reverser.Director
|
||||||
|
|
||||||
if p.proxyRequestTransformer != nil {
|
if len(p.requestTransformers) > 0 {
|
||||||
reverser.Director = func(r *http.Request) {
|
reverser.Director = func(r *http.Request) {
|
||||||
originalDirector(r)
|
originalDirector(r)
|
||||||
p.proxyRequestTransformer.TransformRequest(r)
|
for _, t := range p.requestTransformers {
|
||||||
|
t(r)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if p.proxyResponseTransformer != nil {
|
if len(p.responseTransformers) > 0 {
|
||||||
reverser.ModifyResponse = func(r *http.Response) error {
|
reverser.ModifyResponse = func(r *http.Response) error {
|
||||||
if err := p.proxyResponseTransformer.TransformResponse(r); err != nil {
|
for _, t := range p.responseTransformers {
|
||||||
|
if err := t(r); err != nil {
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
@ -68,6 +76,12 @@ func (p *Proxy) proxyRequest(w http.ResponseWriter, r *http.Request) {
|
||||||
createAndStore()
|
createAndStore()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if reverser == nil {
|
||||||
|
http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
reverser.ServeHTTP(w, r)
|
reverser.ServeHTTP(w, r)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -81,9 +95,8 @@ func New(funcs ...OptionFunc) *Proxy {
|
||||||
|
|
||||||
handler := http.HandlerFunc(proxy.proxyRequest)
|
handler := http.HandlerFunc(proxy.proxyRequest)
|
||||||
proxy.handler = createMiddlewareChain(handler, opts.Middlewares)
|
proxy.handler = createMiddlewareChain(handler, opts.Middlewares)
|
||||||
|
proxy.requestTransformers = opts.RequestTransformers
|
||||||
proxy.proxyRequestTransformer = createProxyRequestChain(&defaultProxyRequestTransformer{}, opts.ProxyRequestMiddlewares)
|
proxy.responseTransformers = opts.ResponseTransformers
|
||||||
proxy.proxyResponseTransformer = createProxyResponseChain(&defaultProxyResponseTransformer{}, opts.ProxyResponseMiddlewares)
|
|
||||||
|
|
||||||
return proxy
|
return proxy
|
||||||
}
|
}
|
||||||
|
@ -99,23 +112,3 @@ func createMiddlewareChain(handler http.Handler, middlewares []Middleware) http.
|
||||||
|
|
||||||
return handler
|
return handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func createProxyResponseChain(transformer ProxyResponseTransformer, middlewares []ProxyResponseMiddleware) ProxyResponseTransformer {
|
|
||||||
util.Reverse(middlewares)
|
|
||||||
|
|
||||||
for _, m := range middlewares {
|
|
||||||
transformer = m(transformer)
|
|
||||||
}
|
|
||||||
|
|
||||||
return transformer
|
|
||||||
}
|
|
||||||
|
|
||||||
func createProxyRequestChain(transformer ProxyRequestTransformer, middlewares []ProxyRequestMiddleware) ProxyRequestTransformer {
|
|
||||||
util.Reverse(middlewares)
|
|
||||||
|
|
||||||
for _, m := range middlewares {
|
|
||||||
transformer = m(transformer)
|
|
||||||
}
|
|
||||||
|
|
||||||
return transformer
|
|
||||||
}
|
|
||||||
|
|
|
@ -0,0 +1,38 @@
|
||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io/ioutil"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestProxy(t *testing.T) {
|
||||||
|
res := httptest.NewRecorder()
|
||||||
|
req, err := http.NewRequest("GET", "http://example.com", nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("%+v", errors.WithStack(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy := New()
|
||||||
|
|
||||||
|
proxy.ServeHTTP(res, req)
|
||||||
|
|
||||||
|
if e, g := http.StatusOK, res.Code; e != g {
|
||||||
|
t.Errorf("res.Code: expected '%v', got '%v'", e, g)
|
||||||
|
}
|
||||||
|
|
||||||
|
body, err := ioutil.ReadAll(res.Body)
|
||||||
|
if err != nil {
|
||||||
|
t.Errorf("%+v", errors.WithStack(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
marker := "<title>Example Domain</title>"
|
||||||
|
|
||||||
|
if !strings.Contains(string(body), marker) {
|
||||||
|
t.Errorf("could not find expected marker in response body")
|
||||||
|
}
|
||||||
|
}
|
Loading…
Reference in New Issue