bouncer/internal/proxy/server.go
William Petit 82228fd115
All checks were successful
Cadoles/bouncer/pipeline/head This commit looks good
feat: allow customization of proxy transport configuration
2023-07-01 13:43:18 -06:00

189 lines
4.9 KiB
Go

package proxy
import (
"context"
"fmt"
"log"
"net"
"net/http"
"net/http/httputil"
"net/url"
"time"
"forge.cadoles.com/Cadoles/go-proxy"
bouncerChi "forge.cadoles.com/cadoles/bouncer/internal/chi"
"forge.cadoles.com/cadoles/bouncer/internal/config"
"forge.cadoles.com/cadoles/bouncer/internal/proxy/director"
"forge.cadoles.com/cadoles/bouncer/internal/store"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus/promhttp"
"gitlab.com/wpetit/goweb/logger"
)
// Server is the bouncer reverse-proxy HTTP server.
//
// Instances are created with NewServer and launched with Start.
type Server struct {
	// serverConfig holds the HTTP listener, metrics, dial and transport
	// settings read while the server runs.
	serverConfig config.ProxyServerConfig
	// redisConfig is copied from the options given to NewServer.
	redisConfig config.RedisConfig

	// directorLayers are forwarded to director.New when the server runs.
	directorLayers []director.Layer

	// proxyRepository and layerRepository are forwarded to director.New.
	// NewServer leaves them unset.
	// NOTE(review): presumably populated by initRepositories - confirm.
	proxyRepository store.ProxyRepository
	layerRepository store.LayerRepository
}
// Start launches the server in a background goroutine and returns
// immediately.
//
// The first channel receives the address of the listener once it is bound;
// the second one receives the errors reported by the server. Both channels
// are unbuffered and are closed when the background goroutine exits.
func (s *Server) Start(ctx context.Context) (<-chan net.Addr, <-chan error) {
	var (
		addrs = make(chan net.Addr)
		errs  = make(chan error)
	)

	go s.run(ctx, addrs, errs)

	return addrs, errs
}
// run is the body of the goroutine spawned by Start.
//
// It initializes the repositories, binds the TCP listener, publishes the
// bound address on addrs, assembles the router (request logging, optional
// metrics endpoint, proxy catch-all) and serves requests on the listener.
// When the context derived from parentCtx is done, the listener is closed.
// Failures are sent on errs, and both channels are closed when run returns.
func (s *Server) run(parentCtx context.Context, addrs chan net.Addr, errs chan error) {
	// Registered first so that it runs last: the deferred send on errs
	// further down must happen before the channels are closed.
	defer func() {
		close(errs)
		close(addrs)
	}()

	ctx, cancel := context.WithCancel(parentCtx)
	defer cancel()

	if err := s.initRepositories(ctx); err != nil {
		errs <- errors.WithStack(err)

		return
	}

	address := fmt.Sprintf("%s:%d", s.serverConfig.HTTP.Host, s.serverConfig.HTTP.Port)

	listener, err := net.Listen("tcp", address)
	if err != nil {
		errs <- errors.WithStack(err)

		return
	}

	addrs <- listener.Addr()

	// closeListener closes the listener and swallows the error returned
	// when it has already been closed, since both the deferred cleanup and
	// the cancellation watcher below may try to close it.
	closeListener := func() error {
		closeErr := listener.Close()
		if closeErr == nil || errors.Is(closeErr, net.ErrClosed) {
			return nil
		}

		return closeErr
	}

	defer func() {
		if closeErr := closeListener(); closeErr != nil {
			errs <- errors.WithStack(closeErr)
		}
	}()

	// Close the listener as soon as the context is done. The deferred
	// cancel above guarantees this goroutine ends when run returns.
	go func() {
		<-ctx.Done()

		if closeErr := closeListener(); closeErr != nil {
			log.Printf("%+v", errors.WithStack(closeErr))
		}
	}()

	router := chi.NewRouter()

	logger.Info(ctx, "http server listening")

	proxyDirector := director.New(
		s.proxyRepository,
		s.layerRepository,
		s.directorLayers...,
	)

	router.Use(middleware.RequestLogger(bouncerChi.NewLogFormatter()))

	// Optional metrics endpoint, registered in its own group so that basic
	// authentication only applies to it.
	if metricsConf := s.serverConfig.Metrics; metricsConf.Enabled {
		logger.Info(ctx, "enabling metrics", logger.F("endpoint", metricsConf.Endpoint))

		router.Group(func(metricsRouter chi.Router) {
			if metricsConf.BasicAuth != nil {
				logger.Info(ctx, "enabling authentication on metrics endpoint")

				credentials := metricsConf.BasicAuth.CredentialsMap()
				metricsRouter.Use(middleware.BasicAuth("metrics", credentials))
			}

			metricsRouter.Handle(string(metricsConf.Endpoint), promhttp.Handler())
		})
	}

	// Every other path goes through the director middleware and is handed
	// to the proxy handler.
	router.Group(func(proxyRouter chi.Router) {
		proxyRouter.Use(proxyDirector.Middleware())

		requestTransformers := proxy.WithRequestTransformers(
			proxyDirector.RequestTransformer(),
		)
		responseTransformers := proxy.WithResponseTransformers(
			proxyDirector.ResponseTransformer(),
		)
		reverseProxyFactory := proxy.WithReverseProxyFactory(s.createReverseProxy)

		proxyHandler := proxy.New(
			requestTransformers,
			responseTransformers,
			reverseProxyFactory,
		)

		proxyRouter.Handle("/*", proxyHandler)
	})

	// An error caused by the listener being closed is the normal shutdown
	// path and is not reported.
	serveErr := http.Serve(listener, router)
	if serveErr != nil && !errors.Is(serveErr, net.ErrClosed) {
		errs <- errors.WithStack(serveErr)
	}

	logger.Info(ctx, "http server exiting")
}
// createReverseProxy returns a single-host reverse proxy for target whose
// dialer and HTTP transport are configured from the server's Dial and
// Transport settings, and whose errors are handled by s.errorHandler.
//
// It is registered as the proxy's reverse-proxy factory; ctx is part of the
// factory signature and is not used here. The pointer-typed duration
// settings are dereferenced without a nil check.
// NOTE(review): presumably the configuration defaults guarantee they are
// set - confirm.
func (s *Server) createReverseProxy(ctx context.Context, target *url.URL) *httputil.ReverseProxy {
	reverseProxy := httputil.NewSingleHostReverseProxy(target)

	// Outgoing connection dialer.
	dialConf := s.serverConfig.Dial

	dialer := new(net.Dialer)
	dialer.Timeout = time.Duration(*dialConf.Timeout)
	dialer.KeepAlive = time.Duration(*dialConf.KeepAlive)
	dialer.FallbackDelay = time.Duration(*dialConf.FallbackDelay)
	dialer.DualStack = bool(dialConf.DualStack)

	transportConf := s.serverConfig.Transport

	transport := new(http.Transport)
	transport.Proxy = http.ProxyFromEnvironment
	transport.DialContext = dialer.DialContext

	// Protocol and connection behaviour.
	transport.ForceAttemptHTTP2 = bool(transportConf.ForceAttemptHTTP2)
	transport.DisableKeepAlives = bool(transportConf.DisableKeepAlives)
	transport.DisableCompression = bool(transportConf.DisableCompression)

	// Connection pool limits.
	transport.MaxIdleConns = int(transportConf.MaxIdleConns)
	transport.MaxIdleConnsPerHost = int(transportConf.MaxIdleConnsPerHost)
	transport.MaxConnsPerHost = int(transportConf.MaxConnsPerHost)

	// Timeouts.
	transport.IdleConnTimeout = time.Duration(*transportConf.IdleConnTimeout)
	transport.TLSHandshakeTimeout = time.Duration(*transportConf.TLSHandshakeTimeout)
	transport.ExpectContinueTimeout = time.Duration(*transportConf.ExpectContinueTimeout)
	transport.ResponseHeaderTimeout = time.Duration(*transportConf.ResponseHeaderTimeout)

	// Buffer and header sizes.
	transport.WriteBufferSize = int(transportConf.WriteBufferSize)
	transport.ReadBufferSize = int(transportConf.ReadBufferSize)
	transport.MaxResponseHeaderBytes = int64(transportConf.MaxResponseHeaderBytes)

	reverseProxy.Transport = transport
	reverseProxy.ErrorHandler = s.errorHandler

	return reverseProxy
}
// errorHandler is installed as the ErrorHandler of the reverse proxies
// created by createReverseProxy. It logs err with the request context and
// answers the client with a plain-text 502 Bad Gateway.
//
// The reverse proxy calls this handler when the backend cannot be reached
// (or its response cannot be processed), so the failure is reported as a
// gateway error rather than as an internal error of this server. This is
// also the status httputil.ReverseProxy uses when no handler is set.
func (s *Server) errorHandler(w http.ResponseWriter, r *http.Request, err error) {
	logger.Error(r.Context(), "proxy error", logger.E(errors.WithStack(err)))

	http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
}
// NewServer returns a Server configured from the default options, modified
// by the given option functions applied in the order they are passed.
//
// Only the server configuration, the redis configuration and the director
// layers are taken from the options; the repositories are left unset.
func NewServer(funcs ...OptionFunc) *Server {
	options := defaultOption()
	for i := range funcs {
		funcs[i](options)
	}

	server := new(Server)
	server.serverConfig = options.ServerConfig
	server.redisConfig = options.RedisConfig
	server.directorLayers = options.DirectorLayers

	return server
}