emissary/internal/server/server.go

144 lines
3.4 KiB
Go

package server
import (
	"context"
	"fmt"
	"log"
	"net"
	"net/http"
	"time"

	"forge.cadoles.com/Cadoles/emissary/internal/auth"
	"forge.cadoles.com/Cadoles/emissary/internal/auth/agent"
	"forge.cadoles.com/Cadoles/emissary/internal/auth/thirdparty"
	"forge.cadoles.com/Cadoles/emissary/internal/config"
	"forge.cadoles.com/Cadoles/emissary/internal/datastore"
	"forge.cadoles.com/Cadoles/emissary/internal/jwk"
	"github.com/go-chi/chi"
	"github.com/go-chi/chi/middleware"
	"github.com/go-chi/cors"
	"github.com/pkg/errors"
	"gitlab.com/wpetit/goweb/logger"
)
// Server is the HTTP server exposing the "/api/v1" routes (agent
// registration and agent/spec management).
type Server struct {
	// conf is the server configuration. It provides the HTTP bind host and
	// port, the CORS options, the private key path and the token issuer.
	conf config.ServerConfig

	// agentRepo is the agent repository handed to the agent authenticator.
	// New does not set it; presumably initRepositories populates it before
	// the routes are wired — confirm against that method.
	agentRepo datastore.AgentRepository
}
// Start launches the server loop in a new goroutine and returns right away.
//
// The first returned channel carries the address the listener is bound to;
// the second carries errors raised by the server loop. Both channels are
// unbuffered, so the server loop blocks on each send until the caller
// receives the value.
func (s *Server) Start(ctx context.Context) (<-chan net.Addr, <-chan error) {
	var (
		errCh  = make(chan error)
		addrCh = make(chan net.Addr)
	)

	go s.run(ctx, addrCh, errCh)

	return addrCh, errCh
}
// run is the server loop. It blocks until the listener is closed or a fatal
// error occurs.
//
// Steps, in order: initialize the repositories, bind a TCP listener on the
// configured host and port, publish the bound address on addrs, load (or
// generate) the private key and derive its public key set, wire the HTTP
// routes, then serve requests.
//
// Fatal errors are sent on errs. Cancelling parentCtx closes the listener,
// which makes the serve loop return. Both addrs and errs are closed when run
// returns.
func (s *Server) run(parentCtx context.Context, addrs chan net.Addr, errs chan error) {
	// Upper bound on the time a client may take to send its request headers.
	// Without it, a client that opens a connection and never completes its
	// headers would hold a goroutine and a socket indefinitely.
	const readHeaderTimeout = 30 * time.Second

	// Deferred calls run last-in first-out: the listener is closed first,
	// then the context is cancelled, and the channels are closed last so no
	// send can ever hit a closed channel.
	defer func() {
		close(errs)
		close(addrs)
	}()

	ctx, cancel := context.WithCancel(parentCtx)
	defer cancel()

	if err := s.initRepositories(ctx); err != nil {
		errs <- errors.WithStack(err)

		return
	}

	listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", s.conf.HTTP.Host, s.conf.HTTP.Port))
	if err != nil {
		errs <- errors.WithStack(err)

		return
	}

	// Publish the effective address (useful when the configured port is 0).
	addrs <- listener.Addr()

	defer func() {
		// The listener may already have been closed by the goroutine below;
		// that case is not an error.
		if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
			errs <- errors.WithStack(err)
		}
	}()

	// Close the listener as soon as the context is done so the serve loop
	// below returns. The deferred cancel above guarantees this goroutine
	// terminates when run returns.
	go func() {
		<-ctx.Done()

		if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
			log.Printf("%+v", errors.WithStack(err))
		}
	}()

	key, err := jwk.LoadOrGenerate(string(s.conf.PrivateKeyPath), jwk.DefaultKeySize)
	if err != nil {
		errs <- errors.WithStack(err)

		return
	}

	keys, err := jwk.PublicKeySet(key)
	if err != nil {
		errs <- errors.WithStack(err)

		return
	}

	router := chi.NewRouter()

	router.Use(middleware.Logger)

	corsMiddleware := cors.New(cors.Options{
		AllowedOrigins:   s.conf.CORS.AllowedOrigins,
		AllowedMethods:   s.conf.CORS.AllowedMethods,
		AllowCredentials: bool(s.conf.CORS.AllowCredentials),
		AllowedHeaders:   s.conf.CORS.AllowedHeaders,
		Debug:            bool(s.conf.CORS.Debug),
	})

	router.Use(corsMiddleware.Handler)

	router.Route("/api/v1", func(r chi.Router) {
		// Registration is the only route outside the authenticated group.
		r.Post("/register", s.registerAgent)

		r.Group(func(r chi.Router) {
			r.Use(auth.Middleware(
				thirdparty.NewAuthenticator(keys, string(s.conf.Issuer), thirdparty.DefaultAcceptableSkew),
				agent.NewAuthenticator(s.agentRepo, agent.DefaultAcceptableSkew),
			))

			r.Route("/agents", func(r chi.Router) {
				r.With(assertGlobalReadAccess).Get("/", s.queryAgents)
				r.With(assertAgentReadAccess).Get("/{agentID}", s.getAgent)
				r.With(assertAgentWriteAccess).Put("/{agentID}", s.updateAgent)
				r.With(assertAgentWriteAccess).Delete("/{agentID}", s.deleteAgent)

				r.With(assertAgentReadAccess).Get("/{agentID}/specs", s.getAgentSpecs)
				r.With(assertAgentWriteAccess).Post("/{agentID}/specs", s.updateSpec)
				r.With(assertAgentWriteAccess).Delete("/{agentID}/specs", s.deleteSpec)
			})
		})
	})

	logger.Info(ctx, "http server listening")

	// http.Serve builds a Server without any timeout; use an explicit one so
	// the header read deadline applies. Shutdown still happens by closing the
	// listener, so Serve returns the accept error, which wraps net.ErrClosed.
	httpServer := &http.Server{
		Handler:           router,
		ReadHeaderTimeout: readHeaderTimeout,
	}

	if err := httpServer.Serve(listener); err != nil && !errors.Is(err, net.ErrClosed) {
		errs <- errors.WithStack(err)
	}

	logger.Info(ctx, "http server exiting")
}
// New builds a Server. It starts from the default options and applies every
// given OptionFunc in the order received. Only the resulting configuration
// is stored on the returned Server.
func New(funcs ...OptionFunc) *Server {
	opt := defaultOption()

	for i := range funcs {
		funcs[i](opt)
	}

	srv := &Server{conf: opt.Config}

	return srv
}