273 lines
6.0 KiB
Go
273 lines
6.0 KiB
Go
package director
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"sort"
|
|
|
|
"forge.cadoles.com/Cadoles/go-proxy"
|
|
"forge.cadoles.com/Cadoles/go-proxy/wildcard"
|
|
"forge.cadoles.com/cadoles/bouncer/internal/cache"
|
|
"forge.cadoles.com/cadoles/bouncer/internal/store"
|
|
"github.com/pkg/errors"
|
|
"github.com/prometheus/client_golang/prometheus"
|
|
"gitlab.com/wpetit/goweb/logger"
|
|
)
|
|
|
|
type Director struct {
|
|
proxyRepository store.ProxyRepository
|
|
layerRepository store.LayerRepository
|
|
layerRegistry *LayerRegistry
|
|
|
|
proxyCache cache.Cache[string, []*store.Proxy]
|
|
layerCache cache.Cache[string, []*store.Layer]
|
|
|
|
handleError HandleErrorFunc
|
|
}
|
|
|
|
func (d *Director) rewriteRequest(r *http.Request) (*http.Request, error) {
|
|
ctx := r.Context()
|
|
|
|
proxies, err := d.getProxies(ctx)
|
|
if err != nil {
|
|
return r, errors.WithStack(err)
|
|
}
|
|
|
|
url := getRequestURL(r)
|
|
ctx = withOriginalURL(ctx, url)
|
|
ctx = logger.With(ctx, logger.F("url", url.String()))
|
|
|
|
layers := make([]*store.Layer, 0)
|
|
|
|
for _, p := range proxies {
|
|
for _, from := range p.From {
|
|
if matches := wildcard.Match(url.String(), from); !matches {
|
|
continue
|
|
}
|
|
|
|
proxyCtx := logger.With(ctx,
|
|
logger.F("proxy", p.Name),
|
|
logger.F("host", r.Host),
|
|
logger.F("remoteAddr", r.RemoteAddr),
|
|
)
|
|
|
|
metricProxyRequestsTotal.With(prometheus.Labels{metricLabelProxy: string(p.Name)}).Add(1)
|
|
|
|
proxyLayers, err := d.getLayers(proxyCtx, p.Name)
|
|
if err != nil {
|
|
return r, errors.WithStack(err)
|
|
}
|
|
|
|
layers = append(layers, proxyLayers...)
|
|
|
|
if p.To == "" {
|
|
continue
|
|
}
|
|
|
|
toURL, err := url.Parse(p.To)
|
|
if err != nil {
|
|
return r, errors.WithStack(err)
|
|
}
|
|
|
|
r.URL.Host = toURL.Host
|
|
r.URL.Scheme = toURL.Scheme
|
|
r.URL.Path = toURL.JoinPath(r.URL.Path).Path
|
|
|
|
proxyCtx = withLayers(proxyCtx, layers)
|
|
r = r.WithContext(proxyCtx)
|
|
|
|
return r, nil
|
|
}
|
|
}
|
|
|
|
ctx = withLayers(ctx, layers)
|
|
r = r.WithContext(ctx)
|
|
|
|
return r, nil
|
|
}
|
|
|
|
const proxiesCacheKey = "proxies"
|
|
|
|
func (d *Director) getProxies(ctx context.Context) ([]*store.Proxy, error) {
|
|
proxies, exists := d.proxyCache.Get(proxiesCacheKey)
|
|
if exists {
|
|
return proxies, nil
|
|
}
|
|
|
|
headers, err := d.proxyRepository.QueryProxy(ctx, store.WithProxyQueryEnabled(true))
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
sort.Sort(store.ByProxyWeight(headers))
|
|
|
|
proxies = make([]*store.Proxy, 0, len(headers))
|
|
|
|
for _, h := range headers {
|
|
if !h.Enabled {
|
|
continue
|
|
}
|
|
|
|
proxy, err := d.proxyRepository.GetProxy(ctx, h.Name)
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
proxies = append(proxies, proxy)
|
|
}
|
|
|
|
d.proxyCache.Set(proxiesCacheKey, proxies)
|
|
|
|
return proxies, nil
|
|
}
|
|
|
|
func (d *Director) getLayers(ctx context.Context, proxyName store.ProxyName) ([]*store.Layer, error) {
|
|
cacheKey := "layers-" + string(proxyName)
|
|
|
|
layers, exists := d.layerCache.Get(cacheKey)
|
|
if exists {
|
|
return layers, nil
|
|
}
|
|
|
|
headers, err := d.layerRepository.QueryLayers(ctx, proxyName, store.WithLayerQueryEnabled(true))
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
sort.Sort(store.ByLayerWeight(headers))
|
|
|
|
layers = make([]*store.Layer, 0, len(headers))
|
|
|
|
for _, h := range headers {
|
|
if !h.Enabled {
|
|
continue
|
|
}
|
|
|
|
layer, err := d.layerRepository.GetLayer(ctx, proxyName, h.Name)
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
layers = append(layers, layer)
|
|
}
|
|
|
|
d.layerCache.Set(cacheKey, layers)
|
|
|
|
return layers, nil
|
|
}
|
|
|
|
func (d *Director) RequestTransformer() proxy.RequestTransformer {
|
|
return func(r *http.Request) {
|
|
ctx := r.Context()
|
|
|
|
layers, err := ctxLayers(ctx)
|
|
if err != nil {
|
|
if errors.Is(err, errContextKeyNotFound) {
|
|
return
|
|
}
|
|
|
|
logger.Error(ctx, "could not retrieve layers from context", logger.CapturedE(errors.WithStack(err)))
|
|
|
|
return
|
|
}
|
|
|
|
for _, layer := range layers {
|
|
transformerLayer, ok := d.layerRegistry.GetRequestTransformer(layer.Type)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
transformer := transformerLayer.RequestTransformer(layer)
|
|
|
|
transformer(r)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (d *Director) ResponseTransformer() proxy.ResponseTransformer {
|
|
return func(r *http.Response) error {
|
|
ctx := r.Request.Context()
|
|
layers, err := ctxLayers(ctx)
|
|
if err != nil {
|
|
if errors.Is(err, errContextKeyNotFound) {
|
|
return nil
|
|
}
|
|
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
for _, layer := range layers {
|
|
transformerLayer, ok := d.layerRegistry.GetResponseTransformer(layer.Type)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
transformer := transformerLayer.ResponseTransformer(layer)
|
|
|
|
if err := transformer(r); err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
}
|
|
|
|
func (d *Director) Middleware() proxy.Middleware {
|
|
return func(next http.Handler) http.Handler {
|
|
fn := func(w http.ResponseWriter, r *http.Request) {
|
|
ctx := withHandleError(r.Context(), d.handleError)
|
|
r = r.WithContext(ctx)
|
|
|
|
r, err := d.rewriteRequest(r)
|
|
if err != nil {
|
|
HandleError(ctx, w, r, http.StatusInternalServerError, errors.Wrap(err, "could not rewrite request"))
|
|
return
|
|
}
|
|
|
|
ctx = r.Context()
|
|
|
|
layers, err := ctxLayers(ctx)
|
|
if err != nil {
|
|
if errors.Is(err, errContextKeyNotFound) {
|
|
return
|
|
}
|
|
|
|
HandleError(ctx, w, r, http.StatusInternalServerError, errors.Wrap(err, "could not retrieve proxy and layers from context"))
|
|
return
|
|
}
|
|
|
|
httpMiddlewares := make([]proxy.Middleware, 0)
|
|
for _, layer := range layers {
|
|
middleware, ok := d.layerRegistry.GetMiddleware(layer.Type)
|
|
if !ok {
|
|
continue
|
|
}
|
|
|
|
httpMiddlewares = append(httpMiddlewares, middleware.Middleware(layer))
|
|
}
|
|
|
|
handler := createMiddlewareChain(next, httpMiddlewares)
|
|
|
|
handler.ServeHTTP(w, r)
|
|
}
|
|
|
|
return http.HandlerFunc(fn)
|
|
}
|
|
}
|
|
|
|
func New(proxyRepository store.ProxyRepository, layerRepository store.LayerRepository, funcs ...OptionFunc) *Director {
|
|
opts := NewOptions(funcs...)
|
|
|
|
registry := NewLayerRegistry(opts.Layers...)
|
|
|
|
return &Director{
|
|
proxyRepository: proxyRepository,
|
|
layerRepository: layerRepository,
|
|
layerRegistry: registry,
|
|
proxyCache: opts.ProxyCache,
|
|
layerCache: opts.LayerCache,
|
|
handleError: opts.HandleError,
|
|
}
|
|
}
|