bouncer/internal/admin/server.go

235 lines
6.0 KiB
Go

package admin
import (
"context"
"fmt"
"log"
"net"
"net/http"
"net/http/pprof"
"forge.cadoles.com/cadoles/bouncer/internal/auth"
"forge.cadoles.com/cadoles/bouncer/internal/auth/jwt"
bouncerChi "forge.cadoles.com/cadoles/bouncer/internal/chi"
"forge.cadoles.com/cadoles/bouncer/internal/config"
"forge.cadoles.com/cadoles/bouncer/internal/integration"
"forge.cadoles.com/cadoles/bouncer/internal/jwk"
"forge.cadoles.com/cadoles/bouncer/internal/store"
sentryhttp "github.com/getsentry/sentry-go/http"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/go-chi/cors"
"github.com/pkg/errors"
"github.com/prometheus/client_golang/prometheus/promhttp"
"github.com/redis/go-redis/v9"
"gitlab.com/wpetit/goweb/logger"
)
// Server is the bouncer administration HTTP server. It is built with
// NewServer and launched with Start.
type Server struct {
	// Static configuration, copied from the options given to NewServer.
	serverConfig    config.AdminServerConfig // HTTP, CORS, Sentry, metrics, profiling and auth settings read by run
	redisConfig     config.RedisConfig
	bootstrapConfig config.BootstrapConfig

	// Integrations handed to integration.RunOnStartup before the listener
	// is opened.
	integrations []integration.Integration

	// Runtime dependencies. NewServer leaves these at their zero value.
	// NOTE(review): presumably populated by initRepositories — confirm.
	redisClient     redis.UniversalClient
	proxyRepository store.ProxyRepository
	layerRepository store.LayerRepository

	// Key material. Both are injected into the startup context for the
	// integrations; publicKeys also backs the JWT authenticator that guards
	// the /api/v1 routes.
	// NOTE(review): presumably populated by initPrivateKey — confirm.
	privateKey jwk.Key
	publicKeys jwk.Set
}
// Start launches the admin server in a background goroutine and returns
// immediately.
//
// The first returned channel receives the address of the listener once it has
// been bound; the second receives any error raised while starting or serving.
// Both channels are unbuffered, so the caller must keep receiving from them,
// and both are closed when the server goroutine exits. Cancelling ctx stops
// the server.
func (s *Server) Start(ctx context.Context) (<-chan net.Addr, <-chan error) {
	var (
		addrCh = make(chan net.Addr)
		errCh  = make(chan error)
	)

	go s.run(ctx, addrCh, errCh)

	return addrCh, errCh
}
// run is the body of the server goroutine started by Start.
//
// In order, it initialises the repositories, bootstraps the proxies, loads
// the private key, runs the startup integrations, binds the TCP listener,
// publishes the bound address on addrs, builds the HTTP router and finally
// serves requests until the listener is closed (which happens when parentCtx
// is cancelled).
//
// Any failure is reported on errs. Sends on addrs and errs block until the
// receiver reads them. Both channels are closed when run returns.
func (s *Server) run(parentCtx context.Context, addrs chan net.Addr, errs chan error) {
	// Registered first so that it runs last: every other deferred function
	// may still need to send on errs.
	defer func() {
		close(errs)
		close(addrs)
	}()

	ctx, cancel := context.WithCancel(parentCtx)
	defer cancel()

	if err := s.initRepositories(ctx); err != nil {
		errs <- errors.WithStack(err)

		return
	}

	if err := s.bootstrapProxies(ctx); err != nil {
		errs <- errors.WithStack(err)

		return
	}

	if err := s.initPrivateKey(ctx); err != nil {
		errs <- errors.WithStack(err)

		return
	}

	// Expose the key material to the integrations through the context.
	ctx = integration.WithPrivateKey(ctx, s.privateKey)
	ctx = integration.WithPublicKeySet(ctx, s.publicKeys)

	if err := integration.RunOnStartup(ctx, s.integrations); err != nil {
		errs <- errors.WithStack(err)

		return
	}

	listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", s.serverConfig.HTTP.Host, s.serverConfig.HTTP.Port))
	if err != nil {
		errs <- errors.WithStack(err)

		return
	}

	addrs <- listener.Addr()

	// Make sure the listener is released on every exit path. A listener that
	// was already closed (by the goroutine below) is not an error.
	defer func() {
		if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
			errs <- errors.WithStack(err)
		}
	}()

	// Closing the listener is what makes http.Serve return once the context
	// is cancelled. The deferred cancel() above also releases this goroutine
	// when run returns for any other reason.
	go func() {
		<-ctx.Done()

		if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
			log.Printf("%+v", errors.WithStack(err))
		}
	}()

	router := chi.NewRouter()

	// Global middlewares.
	if s.serverConfig.HTTP.UseRealIP {
		router.Use(middleware.RealIP)
	}

	router.Use(middleware.RequestLogger(bouncerChi.NewLogFormatter()))

	if s.serverConfig.Sentry.DSN != "" {
		logger.Info(ctx, "enabling sentry http middleware")

		sentryMiddleware := sentryhttp.New(sentryhttp.Options{
			Repanic: true,
		})

		router.Use(sentryMiddleware.Handle)
	}

	corsMiddleware := cors.New(cors.Options{
		AllowedOrigins:   s.serverConfig.CORS.AllowedOrigins,
		AllowedMethods:   s.serverConfig.CORS.AllowedMethods,
		AllowCredentials: bool(s.serverConfig.CORS.AllowCredentials),
		AllowedHeaders:   s.serverConfig.CORS.AllowedHeaders,
		Debug:            bool(s.serverConfig.CORS.Debug),
	})

	router.Use(corsMiddleware.Handler)

	// Optional Prometheus metrics endpoint, protected by basic auth when
	// credentials are configured.
	if s.serverConfig.Metrics.Enabled {
		metrics := s.serverConfig.Metrics

		logger.Info(ctx, "enabling metrics", logger.F("endpoint", metrics.Endpoint))

		router.Group(func(r chi.Router) {
			if metrics.BasicAuth != nil {
				logger.Info(ctx, "enabling authentication on metrics endpoint")

				r.Use(middleware.BasicAuth(
					"metrics",
					metrics.BasicAuth.CredentialsMap(),
				))
			}

			r.Handle(string(metrics.Endpoint), promhttp.Handler())
		})
	}

	// Optional pprof endpoints, protected by basic auth when credentials are
	// configured.
	if s.serverConfig.Profiling.Enabled {
		profiling := s.serverConfig.Profiling

		logger.Info(ctx, "enabling profiling", logger.F("endpoint", profiling.Endpoint))

		router.Group(func(r chi.Router) {
			if profiling.BasicAuth != nil {
				// This message previously named the metrics endpoint
				// (copy-paste slip), which was misleading in the logs.
				logger.Info(ctx, "enabling authentication on profiling endpoint")

				r.Use(middleware.BasicAuth(
					"profiling",
					profiling.BasicAuth.CredentialsMap(),
				))
			}

			r.Route(string(profiling.Endpoint), func(r chi.Router) {
				r.HandleFunc("/", pprof.Index)
				r.HandleFunc("/cmdline", pprof.Cmdline)
				r.HandleFunc("/profile", pprof.Profile)
				r.HandleFunc("/symbol", pprof.Symbol)
				r.HandleFunc("/trace", pprof.Trace)
				// Any other name is resolved as a named runtime profile.
				r.HandleFunc("/{name}", func(w http.ResponseWriter, r *http.Request) {
					name := chi.URLParam(r, "name")
					pprof.Handler(name).ServeHTTP(w, r)
				})
			})
		})
	}

	// Administration API. Every route goes through auth.Middleware with a JWT
	// authenticator built from the server public keys and configured issuer;
	// read routes are additionally guarded by assertReadAccess and mutating
	// routes by assertWriteAccess.
	router.Route("/api/v1", func(r chi.Router) {
		r.Group(func(r chi.Router) {
			r.Use(auth.Middleware(
				jwt.NewAuthenticator(s.publicKeys, string(s.serverConfig.Auth.Issuer), jwt.DefaultAcceptableSkew),
			))

			r.Route("/definitions", func(r chi.Router) {
				r.With(assertReadAccess).Get("/layers", s.queryLayerDefinition)
			})

			r.Route("/proxies", func(r chi.Router) {
				r.With(assertReadAccess).Get("/", s.queryProxy)
				r.With(assertWriteAccess).Post("/", s.createProxy)
				r.With(assertReadAccess).Get("/{proxyName}", s.getProxy)
				r.With(assertWriteAccess).Put("/{proxyName}", s.updateProxy)
				r.With(assertWriteAccess).Delete("/{proxyName}", s.deleteProxy)
				r.With(assertReadAccess).Get("/{proxyName}/layers", s.queryLayer)
				r.With(assertWriteAccess).Post("/{proxyName}/layers", s.createLayer)
				r.With(assertReadAccess).Get("/{proxyName}/layers/{layerName}", s.getLayer)
				r.With(assertWriteAccess).Put("/{proxyName}/layers/{layerName}", s.updateLayer)
				r.With(assertWriteAccess).Delete("/{proxyName}/layers/{layerName}", s.deleteLayer)
			})
		})
	})

	logger.Info(ctx, "http server listening")

	// Serve blocks until the listener is closed; that expected shutdown
	// error is not reported to the caller.
	if err := http.Serve(listener, router); err != nil && !errors.Is(err, net.ErrClosed) {
		errs <- errors.WithStack(err)
	}

	logger.Info(ctx, "http server exiting")
}
// NewServer creates an admin Server. It starts from the default options and
// applies each given option function in order.
//
// Only the configuration and integrations are copied into the returned
// server; every other field is left at its zero value.
func NewServer(funcs ...OptionFunc) *Server {
	options := defaultOption()

	for i := range funcs {
		funcs[i](options)
	}

	server := &Server{
		integrations:    options.Integrations,
		bootstrapConfig: options.BootstrapConfig,
		redisConfig:     options.RedisConfig,
		serverConfig:    options.ServerConfig,
	}

	return server
}