bouncer/internal/proxy/director/director.go

245 lines
5.3 KiB
Go
Raw Normal View History

2023-04-24 20:52:12 +02:00
package director
import (
"context"
"net/http"
"sort"
"forge.cadoles.com/Cadoles/go-proxy"
"forge.cadoles.com/Cadoles/go-proxy/wildcard"
"forge.cadoles.com/cadoles/bouncer/internal/store"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus"
2023-04-24 20:52:12 +02:00
"gitlab.com/wpetit/goweb/logger"
)
// Director routes incoming requests to a matching proxy and applies the
// middlewares and request/response transformers provided by that proxy's
// layers.
type Director struct {
	// proxyRepository is used to list and load the enabled proxies.
	proxyRepository store.ProxyRepository
	// layerRepository is used to list and load the enabled layers of a proxy.
	layerRepository store.LayerRepository
	// layerRegistry maps layer types to their middleware and
	// request/response transformer implementations.
	layerRegistry *LayerRegistry
}
func (d *Director) rewriteRequest(r *http.Request) (*http.Request, error) {
ctx := r.Context()
proxies, err := d.getProxies(ctx)
if err != nil {
return r, errors.WithStack(err)
}
url := getRequestURL(r)
ctx = logger.With(r.Context(), logger.F("url", url.String()))
2023-04-24 20:52:12 +02:00
var match *store.Proxy
MAIN:
for _, p := range proxies {
for _, from := range p.From {
logger.Debug(
ctx, "matching request with proxy's from",
logger.F("from", from),
)
if matches := wildcard.Match(url.String(), from); !matches {
logger.Debug(
ctx, "proxy's from matched",
logger.F("from", from),
)
2023-04-24 20:52:12 +02:00
continue
}
match = p
break MAIN
}
}
if match == nil {
return r, nil
}
toURL, err := url.Parse(match.To)
if err != nil {
return r, errors.WithStack(err)
}
r.URL.Host = toURL.Host
r.URL.Scheme = toURL.Scheme
ctx = logger.With(ctx,
logger.F("proxy", match.Name),
logger.F("host", r.Host),
logger.F("remoteAddr", r.RemoteAddr),
)
metricProxyRequestsTotal.With(prometheus.Labels{metricLabelProxy: string(match.Name)}).Add(1)
2023-04-24 20:52:12 +02:00
ctx = withProxy(ctx, match)
layers, err := d.getLayers(ctx, match.Name)
if err != nil {
return r, errors.WithStack(err)
}
ctx = withLayers(ctx, layers)
r = r.WithContext(ctx)
return r, nil
}
// getProxies loads every enabled proxy, sorted with store.ByProxyWeight.
// Any repository error aborts the whole lookup and is returned with a stack
// trace.
func (d *Director) getProxies(ctx context.Context) ([]*store.Proxy, error) {
	summaries, err := d.proxyRepository.QueryProxy(ctx, store.WithProxyQueryEnabled(true))
	if err != nil {
		return nil, errors.WithStack(err)
	}

	sort.Sort(store.ByProxyWeight(summaries))

	result := make([]*store.Proxy, 0, len(summaries))

	for _, summary := range summaries {
		// The query is already restricted with WithProxyQueryEnabled(true);
		// this re-check is kept as a safeguard.
		if summary.Enabled {
			loaded, loadErr := d.proxyRepository.GetProxy(ctx, summary.Name)
			if loadErr != nil {
				return nil, errors.WithStack(loadErr)
			}

			result = append(result, loaded)
		}
	}

	return result, nil
}
// getLayers loads the enabled layers of the proxy named proxyName, sorted
// with store.ByLayerWeight. Any repository error aborts the whole lookup and
// is returned with a stack trace.
func (d *Director) getLayers(ctx context.Context, proxyName store.ProxyName) ([]*store.Layer, error) {
	summaries, err := d.layerRepository.QueryLayers(ctx, proxyName, store.WithLayerQueryEnabled(true))
	if err != nil {
		return nil, errors.WithStack(err)
	}

	sort.Sort(store.ByLayerWeight(summaries))

	result := make([]*store.Layer, 0, len(summaries))

	for _, summary := range summaries {
		// The query is already restricted with WithLayerQueryEnabled(true);
		// this re-check is kept as a safeguard.
		if summary.Enabled {
			loaded, loadErr := d.layerRepository.GetLayer(ctx, proxyName, summary.Name)
			if loadErr != nil {
				return nil, errors.WithStack(loadErr)
			}

			result = append(result, loaded)
		}
	}

	return result, nil
}
// RequestTransformer returns a proxy.RequestTransformer that runs the request
// transformer of each layer stored in the request context, in context order.
// Layers whose type has no request transformer registered are skipped.
//
// A request without layers in its context is left untouched. Any other error
// while reading the layers is logged, and the request is also left untouched.
func (d *Director) RequestTransformer() proxy.RequestTransformer {
	return func(r *http.Request) {
		ctx := r.Context()

		layers, err := ctxLayers(ctx)
		switch {
		case errors.Is(err, errContextKeyNotFound):
			return
		case err != nil:
			logger.Error(ctx, "could not retrieve layers from context", logger.E(errors.WithStack(err)))
			return
		}

		for _, l := range layers {
			registered, found := d.layerRegistry.GetRequestTransformer(l.Type)
			if found {
				registered.RequestTransformer(l)(r)
			}
		}
	}
}
// ResponseTransformer returns a proxy.ResponseTransformer that runs the
// response transformer of each layer stored in the originating request's
// context, in context order. Layers whose type has no response transformer
// registered are skipped.
//
// A response whose request has no layers in its context is left untouched.
// The first error, whether from reading the layers or from a transformer,
// stops processing and is returned with a stack trace.
func (d *Director) ResponseTransformer() proxy.ResponseTransformer {
	return func(r *http.Response) error {
		ctx := r.Request.Context()

		layers, err := ctxLayers(ctx)
		switch {
		case errors.Is(err, errContextKeyNotFound):
			return nil
		case err != nil:
			return errors.WithStack(err)
		}

		for _, l := range layers {
			registered, found := d.layerRegistry.GetResponseTransformer(l.Type)
			if !found {
				continue
			}

			if transformErr := registered.ResponseTransformer(l)(r); transformErr != nil {
				return errors.WithStack(transformErr)
			}
		}

		return nil
	}
}
// Middleware returns a proxy.Middleware that routes each request through the
// director.
//
// The request is first rewritten by rewriteRequest. The response depends on
// the outcome:
//   - If rewriting fails, the error is logged and a 500 is returned.
//   - If no proxy matched (no layers stored in the context), a 404 is returned
//     and the request is not forwarded.
//   - Otherwise, the middlewares registered for the matched proxy's layers are
//     chained with createMiddlewareChain in front of next, and the request is
//     served through that chain.
func (d *Director) Middleware() proxy.Middleware {
	return func(next http.Handler) http.Handler {
		fn := func(w http.ResponseWriter, r *http.Request) {
			r, err := d.rewriteRequest(r)
			if err != nil {
				logger.Error(r.Context(), "could not rewrite request", logger.E(errors.WithStack(err)))
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
				return
			}

			ctx := r.Context()

			layers, err := ctxLayers(ctx)
			if err != nil {
				if errors.Is(err, errContextKeyNotFound) {
					// rewriteRequest stores layers whenever a proxy matched, so a
					// missing key means no proxy matched. Previously the handler
					// returned without writing anything, which net/http turns into
					// an empty 200 OK. Answer explicitly instead, and do not forward
					// the request.
					logger.Debug(ctx, "no proxy matched request", logger.F("host", r.Host))
					http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
					return
				}

				logger.Error(ctx, "could not retrieve proxy and layers from context", logger.E(errors.WithStack(err)))
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
				return
			}

			// At most one middleware per layer.
			httpMiddlewares := make([]proxy.Middleware, 0, len(layers))
			for _, layer := range layers {
				middleware, ok := d.layerRegistry.GetMiddleware(layer.Type)
				if !ok {
					continue
				}

				httpMiddlewares = append(httpMiddlewares, middleware.Middleware(layer))
			}

			handler := createMiddlewareChain(next, httpMiddlewares)

			handler.ServeHTTP(w, r)
		}

		return http.HandlerFunc(fn)
	}
}
// New creates a Director that reads proxies from proxyRepository and layers
// from layerRepository. The given layers are registered in a new
// LayerRegistry, which the Director uses to find each layer type's
// middleware and transformers.
func New(proxyRepository store.ProxyRepository, layerRepository store.LayerRepository, layers ...Layer) *Director {
	return &Director{
		proxyRepository: proxyRepository,
		layerRepository: layerRepository,
		layerRegistry:   NewLayerRegistry(layers...),
	}
}