package director

import (
	"context"
	"net/http"
	"sort"

	"forge.cadoles.com/Cadoles/go-proxy"
	"forge.cadoles.com/Cadoles/go-proxy/wildcard"
	"forge.cadoles.com/cadoles/bouncer/internal/store"
	"forge.cadoles.com/cadoles/bouncer/internal/syncx"
	"github.com/getsentry/sentry-go"
	"github.com/pkg/errors"
	"github.com/prometheus/client_golang/prometheus"
	"gitlab.com/wpetit/goweb/logger"
)

// Director matches incoming requests against the enabled proxies, rewrites
// them toward the matched proxy's target and applies the layers of the
// matched proxies (HTTP middlewares, request transformers and response
// transformers) registered in its LayerRegistry.
type Director struct {
	proxyRepository store.ProxyRepository
	layerRepository store.LayerRepository
	layerRegistry   *LayerRegistry

	cachedProxies *syncx.CachedResource[string, []*store.Proxy]
	cachedLayers  *syncx.CachedResource[string, []*store.Layer]

	handleError HandleErrorFunc
}

// proxiesCacheKey is the single key under which the proxy list is cached.
const proxiesCacheKey = "proxies"

// rewriteRequest matches r against the cached proxies, in their cached order.
//
// Each proxy having at least one From pattern matching the request URL
// contributes its layers, once, to the accumulated layer list. A matching
// proxy without a target (To == "") only contributes layers and matching
// continues with the next proxy. The first matching proxy with a target ends
// the search: the request URL is rewritten toward that target and the
// accumulated layers are attached to the returned request's context.
//
// If no proxy with a target matches, the request URL is left untouched and
// only the accumulated layers are attached to the context.
func (d *Director) rewriteRequest(r *http.Request) (*http.Request, error) {
	ctx := r.Context()

	proxies, _, err := d.cachedProxies.Get(ctx, proxiesCacheKey)
	if err != nil {
		return r, errors.WithStack(err)
	}

	requestURL := getRequestURL(r)
	// Computed once: the string form is used for logging and for every
	// wildcard match below.
	rawRequestURL := requestURL.String()

	ctx = withOriginalURL(ctx, requestURL)
	ctx = logger.With(ctx, logger.F("url", rawRequestURL))

	layers := make([]*store.Layer, 0)

	for _, p := range proxies {
		// A proxy is handled at most once, even when several of its From
		// patterns match; otherwise a layer-only proxy would append its layers
		// (and count its request metric) once per matching pattern.
		matched := false
		for _, from := range p.From {
			if wildcard.Match(rawRequestURL, from) {
				matched = true
				break
			}
		}

		if !matched {
			continue
		}

		proxyCtx := logger.With(ctx,
			logger.F("proxy", p.Name),
			logger.F("host", r.Host),
			logger.F("remoteAddr", r.RemoteAddr),
		)

		metricProxyRequestsTotal.With(prometheus.Labels{metricLabelProxy: string(p.Name)}).Add(1)

		proxyLayers, _, err := d.cachedLayers.Get(proxyCtx, string(p.Name))
		if err != nil {
			return r, errors.WithStack(err)
		}

		layers = append(layers, proxyLayers...)

		if p.To == "" {
			// Layer-only proxy: keep looking for a proxy with a target.
			continue
		}

		// This is the (*url.URL).Parse method, not the net/url function: p.To
		// is resolved as a reference relative to the request URL, so an
		// absolute target keeps its own scheme and host.
		toURL, err := requestURL.Parse(p.To)
		if err != nil {
			return r, errors.WithStack(err)
		}

		r.URL.Host = toURL.Host
		r.URL.Scheme = toURL.Scheme
		r.URL.Path = toURL.JoinPath(r.URL.Path).Path

		proxyCtx = withLayers(proxyCtx, layers)
		r = r.WithContext(proxyCtx)

		if sentryScope, _ := SentryScope(ctx); sentryScope != nil {
			sentryScope.SetTags(map[string]string{
				"bouncer.proxy.name":        string(p.Name),
				"bouncer.proxy.target.url":  r.URL.String(),
				"bouncer.proxy.target.host": r.URL.Host,
			})
		}

		return r, nil
	}

	ctx = withLayers(ctx, layers)
	r = r.WithContext(ctx)

	return r, nil
}

// getProxies loads the enabled proxies from the proxy repository, ordered
// with store.ByProxyWeight. It is the loader behind cachedProxies; key is the
// cache key and is not used.
func (d *Director) getProxies(ctx context.Context, key string) ([]*store.Proxy, error) {
	logger.Debug(ctx, "querying fresh proxies")

	headers, err := d.proxyRepository.QueryProxy(ctx, store.WithProxyQueryEnabled(true))
	if err != nil {
		return nil, errors.WithStack(err)
	}

	sort.Sort(store.ByProxyWeight(headers))

	proxies := make([]*store.Proxy, 0, len(headers))

	for _, h := range headers {
		// Defensive: the query already filters on enabled proxies.
		if !h.Enabled {
			continue
		}

		proxy, err := d.proxyRepository.GetProxy(ctx, h.Name)
		if err != nil {
			return nil, errors.WithStack(err)
		}

		proxies = append(proxies, proxy)
	}

	return proxies, nil
}

// getLayers loads the enabled layers of the proxy named rawProxyName from the
// layer repository, ordered with store.ByLayerWeight. It is the loader behind
// cachedLayers, which is keyed by proxy name.
func (d *Director) getLayers(ctx context.Context, rawProxyName string) ([]*store.Layer, error) {
	proxyName := store.ProxyName(rawProxyName)

	logger.Debug(ctx, "querying fresh layers")

	headers, err := d.layerRepository.QueryLayers(ctx, proxyName, store.WithLayerQueryEnabled(true))
	if err != nil {
		return nil, errors.WithStack(err)
	}

	sort.Sort(store.ByLayerWeight(headers))

	layers := make([]*store.Layer, 0, len(headers))

	for _, h := range headers {
		// Defensive: the query already filters on enabled layers.
		if !h.Enabled {
			continue
		}

		layer, err := d.layerRepository.GetLayer(ctx, proxyName, h.Name)
		if err != nil {
			return nil, errors.WithStack(err)
		}

		layers = append(layers, layer)
	}

	return layers, nil
}

// RequestTransformer returns a proxy.RequestTransformer that applies, in
// order, the request transformers of the layers attached to the request
// context. Layer types without a registered request transformer are skipped.
// If the context carries no layers the request is left untouched; any other
// error retrieving them is logged and the request is left untouched.
func (d *Director) RequestTransformer() proxy.RequestTransformer {
	return func(r *http.Request) {
		ctx := r.Context()

		layers, err := ctxLayers(ctx)
		if err != nil {
			if errors.Is(err, errContextKeyNotFound) {
				return
			}

			logger.Error(ctx, "could not retrieve layers from context", logger.CapturedE(errors.WithStack(err)))

			return
		}

		for _, layer := range layers {
			transformerLayer, ok := d.layerRegistry.GetRequestTransformer(layer.Type)
			if !ok {
				continue
			}

			transformer := transformerLayer.RequestTransformer(layer)

			transformer(r)
		}
	}
}

// ResponseTransformer returns a proxy.ResponseTransformer that applies, in
// order, the response transformers of the layers attached to the originating
// request's context. Layer types without a registered response transformer
// are skipped. It returns nil when the context carries no layers, and stops
// at (and returns) the first transformer error.
func (d *Director) ResponseTransformer() proxy.ResponseTransformer {
	return func(r *http.Response) error {
		ctx := r.Request.Context()
		layers, err := ctxLayers(ctx)
		if err != nil {
			if errors.Is(err, errContextKeyNotFound) {
				return nil
			}

			return errors.WithStack(err)
		}

		for _, layer := range layers {
			transformerLayer, ok := d.layerRegistry.GetResponseTransformer(layer.Type)
			if !ok {
				continue
			}

			transformer := transformerLayer.ResponseTransformer(layer)

			if err := transformer(r); err != nil {
				return errors.WithStack(err)
			}
		}

		return nil
	}
}

// Middleware returns a proxy.Middleware that rewrites each request toward its
// matching proxy (see rewriteRequest), then runs it through the HTTP
// middlewares of the matched layers before handing it to next. Failures are
// reported through HandleError with a 500 status.
func (d *Director) Middleware() proxy.Middleware {
	return func(next http.Handler) http.Handler {
		fn := func(w http.ResponseWriter, r *http.Request) {
			// NOTE(review): sentry.ConfigureScope operates on
			// sentry.CurrentHub(), the process-wide hub, so tags set on this
			// scope are presumably shared between concurrent requests.
			// Consider a per-request hub (sentry.GetHubFromContext /
			// Hub.Clone) — confirm how events are captured before changing.
			sentry.ConfigureScope(func(scope *sentry.Scope) {
				ctx := withHandleError(r.Context(), d.handleError)
				ctx = withSentryScope(ctx, scope)
				r = r.WithContext(ctx)

				r, err := d.rewriteRequest(r)
				if err != nil {
					HandleError(ctx, w, r, http.StatusInternalServerError, errors.Wrap(err, "could not rewrite request"))
					return
				}

				ctx = r.Context()

				layers, err := ctxLayers(ctx)
				if err != nil {
					// rewriteRequest attaches layers on every successful
					// return, so a missing key is not expected here.
					if errors.Is(err, errContextKeyNotFound) {
						return
					}

					HandleError(ctx, w, r, http.StatusInternalServerError, errors.Wrap(err, "could not retrieve proxy and layers from context"))
					return
				}

				httpMiddlewares := make([]proxy.Middleware, 0)
				for _, layer := range layers {
					middleware, ok := d.layerRegistry.GetMiddleware(layer.Type)
					if !ok {
						continue
					}

					httpMiddlewares = append(httpMiddlewares, middleware.Middleware(layer))
				}

				handler := createMiddlewareChain(next, httpMiddlewares)

				handler.ServeHTTP(w, r)
			})
		}

		return http.HandlerFunc(fn)
	}
}

// New creates a Director backed by the given repositories. Layers are
// resolved through a LayerRegistry built from the Layers option, and proxy and
// layer lookups go through syncx.CachedResource instances configured with the
// ProxyCache and LayerCache options.
func New(proxyRepository store.ProxyRepository, layerRepository store.LayerRepository, funcs ...OptionFunc) *Director {
	opts := NewOptions(funcs...)

	registry := NewLayerRegistry(opts.Layers...)

	director := &Director{
		proxyRepository: proxyRepository,
		layerRepository: layerRepository,
		layerRegistry:   registry,
		handleError:     opts.HandleError,
	}

	director.cachedProxies = syncx.NewCachedResource(opts.ProxyCache, director.getProxies)
	director.cachedLayers = syncx.NewCachedResource(opts.LayerCache, director.getLayers)

	return director
}