package circuitbreaker

import (
	"bytes"
	"context"
	"html/template"
	"net"
	"net/http"
	"path/filepath"
	"sync"

	"forge.cadoles.com/Cadoles/go-proxy"
	"forge.cadoles.com/Cadoles/go-proxy/wildcard"
	"forge.cadoles.com/cadoles/bouncer/internal/proxy/director"
	"forge.cadoles.com/cadoles/bouncer/internal/store"
	"github.com/Masterminds/sprig/v3"
	"github.com/pkg/errors"
	"gitlab.com/wpetit/goweb/logger"
)

// LayerType is the store layer type identifier handled by this package.
const LayerType store.LayerType = "circuitbreaker"

// Layer is a director middleware layer that answers matching requests with a
// rendered circuit breaker page instead of passing them to the next handler.
type Layer struct {
	// templateDir is the directory searched for "*.gohtml" page templates.
	templateDir string

	// tmplMutex guards tmpl.
	tmplMutex sync.Mutex
	// tmpl caches the parsed page templates once they have been loaded
	// successfully; it stays nil until then.
	tmpl *template.Template
}

// LayerType implements director.MiddlewareLayer.
func (l *Layer) LayerType() store.LayerType {
	return LayerType
}

// Middleware implements director.MiddlewareLayer.
//
// For each request, the layer options are decoded from layer.Options. The
// request is passed through to the next handler when the client address
// belongs to one of the authorized CIDRs, or when the request URL matches none
// of the MatchURLs patterns. Otherwise the circuit breaker page is rendered.
// Option decoding or CIDR matching failures are answered with a 500.
func (l *Layer) Middleware(layer *store.Layer) proxy.Middleware {
	return func(h http.Handler) http.Handler {
		fn := func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()

			options, err := fromStoreOptions(layer.Options)
			if err != nil {
				logger.Error(ctx, "could not parse layer options", logger.E(errors.WithStack(err)))
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)

				return
			}

			authorized, err := l.matchAnyAuthorizedCIDRs(ctx, r.RemoteAddr, options.AuthorizedCIDRs)
			if err != nil {
				logger.Error(ctx, "could not match authorized cidrs", logger.E(errors.WithStack(err)))
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)

				return
			}

			// Clients from authorized networks always bypass the breaker.
			if authorized {
				h.ServeHTTP(w, r)

				return
			}

			// Only URLs matching one of the configured patterns are broken.
			if !wildcard.MatchAny(r.URL.String(), options.MatchURLs...) {
				h.ServeHTTP(w, r)

				return
			}

			l.renderCircuitBreakerPage(w, r, layer, options)
		}

		return http.HandlerFunc(fn)
	}
}

// matchAnyAuthorizedCIDRs reports whether the IP contained in remoteHostPort
// belongs to at least one of the given CIDRs.
//
// remoteHostPort is expected to be an http.Request.RemoteAddr ("host:port");
// a bare IP address without a port is accepted as well. An error is returned
// when the remote host is not a valid IP address, or when a CIDR examined
// before a match is found cannot be parsed.
func (l *Layer) matchAnyAuthorizedCIDRs(ctx context.Context, remoteHostPort string, CIDRs []string) (bool, error) {
	remoteHost, _, err := net.SplitHostPort(remoteHostPort)
	if err != nil {
		// Tolerate a remote address given as a bare IP (no port): only fail
		// when the value is not an IP address either.
		if net.ParseIP(remoteHostPort) == nil {
			return false, errors.WithStack(err)
		}

		remoteHost = remoteHostPort
	}

	remoteAddr := net.ParseIP(remoteHost)
	if remoteAddr == nil {
		return false, errors.Errorf("remote host '%s' is not a valid ip address", remoteHost)
	}

	logger.Debug(
		ctx, "comparing remote host with authorized cidrs",
		logger.F("remoteAddr", remoteAddr),
		logger.F("authorizedCIDRs", CIDRs),
	)

	for _, rawCIDR := range CIDRs {
		_, ipNet, err := net.ParseCIDR(rawCIDR)
		if err != nil {
			return false, errors.WithStack(err)
		}

		if ipNet.Contains(remoteAddr) {
			return true, nil
		}
	}

	return false, nil
}

// templates returns the circuit breaker page templates, parsing every
// "*.gohtml" file of templateDir (with the sprig function map) on first use.
// A successfully parsed template set is cached for later calls; a failed
// attempt is not cached, so the next call retries loading.
func (l *Layer) templates(ctx context.Context) (*template.Template, error) {
	l.tmplMutex.Lock()
	defer l.tmplMutex.Unlock()

	if l.tmpl != nil {
		return l.tmpl, nil
	}

	pattern := filepath.Join(l.templateDir, "*.gohtml")

	logger.Info(ctx, "loading circuit breaker page templates", logger.F("pattern", pattern))

	tmpl, err := template.New("").Funcs(sprig.FuncMap()).ParseGlob(pattern)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	l.tmpl = tmpl

	return tmpl, nil
}

// renderCircuitBreakerPage writes the options.TemplateBlock template to w
// with a 200 status and a "Cache-Control: no-cache" header. Templates are
// executed with the store layer and its decoded options as data. Template
// loading or execution failures are answered with a 500.
func (l *Layer) renderCircuitBreakerPage(w http.ResponseWriter, r *http.Request, layer *store.Layer, options *LayerOptions) {
	ctx := r.Context()

	tmpl, err := l.templates(ctx)
	if err != nil {
		logger.Error(ctx, "could not load circuit breaker templates", logger.E(err))
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)

		return
	}

	templateData := struct {
		Layer        *store.Layer
		LayerOptions *LayerOptions
	}{
		Layer:        layer,
		LayerOptions: options,
	}

	// Render into a buffer first: once the status line has been sent it can
	// no longer be changed, so execution errors must be detected before
	// anything reaches the client.
	var buf bytes.Buffer
	if err := tmpl.ExecuteTemplate(&buf, options.TemplateBlock, templateData); err != nil {
		logger.Error(ctx, "could not render circuit breaker page", logger.E(errors.WithStack(err)))
		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)

		return
	}

	w.Header().Add("Cache-Control", "no-cache")
	w.WriteHeader(http.StatusOK)

	if _, err := buf.WriteTo(w); err != nil {
		// Headers are already sent; nothing more can be reported to the client.
		logger.Error(ctx, "could not write circuit breaker page", logger.E(errors.WithStack(err)))
	}
}

// New creates a circuit breaker Layer, applying the given option functions
// in order on top of defaultOptions().
func New(funcs ...OptionFunc) *Layer {
	opts := defaultOptions()
	for _, fn := range funcs {
		fn(opts)
	}

	return &Layer{
		templateDir: opts.TemplateDir,
	}
}

var _ director.MiddlewareLayer = &Layer{}