338 lines
7.3 KiB
Go
338 lines
7.3 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"log"
|
|
"net"
|
|
"net/http"
|
|
"reflect"
|
|
"strings"
|
|
"time"
|
|
|
|
"forge.cadoles.com/Cadoles/emissary/internal/auth/agent"
|
|
"forge.cadoles.com/Cadoles/emissary/internal/auth/thirdparty"
|
|
"forge.cadoles.com/Cadoles/emissary/internal/config"
|
|
"forge.cadoles.com/Cadoles/emissary/internal/datastore"
|
|
"forge.cadoles.com/Cadoles/emissary/internal/jwk"
|
|
"forge.cadoles.com/Cadoles/emissary/internal/server/api"
|
|
"github.com/antonmedv/expr"
|
|
"github.com/antonmedv/expr/vm"
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/go-chi/chi/v5/middleware"
|
|
"github.com/go-chi/cors"
|
|
"github.com/lestrrat-go/jwx/v2/jwa"
|
|
"github.com/lestrrat-go/jwx/v2/jwt"
|
|
"github.com/pkg/errors"
|
|
"gitlab.com/wpetit/goweb/logger"
|
|
)
|
|
|
|
type Server struct {
|
|
conf config.ServerConfig
|
|
agentRepo datastore.AgentRepository
|
|
}
|
|
|
|
func (s *Server) Start(ctx context.Context) (<-chan net.Addr, <-chan error) {
|
|
errs := make(chan error)
|
|
addrs := make(chan net.Addr)
|
|
|
|
go s.run(ctx, addrs, errs)
|
|
|
|
return addrs, errs
|
|
}
|
|
|
|
func (s *Server) run(parentCtx context.Context, addrs chan net.Addr, errs chan error) {
|
|
defer func() {
|
|
close(errs)
|
|
close(addrs)
|
|
}()
|
|
|
|
ctx, cancel := context.WithCancel(parentCtx)
|
|
defer cancel()
|
|
|
|
if err := s.initRepositories(ctx); err != nil {
|
|
errs <- errors.WithStack(err)
|
|
|
|
return
|
|
}
|
|
|
|
listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", s.conf.HTTP.Host, s.conf.HTTP.Port))
|
|
if err != nil {
|
|
errs <- errors.WithStack(err)
|
|
|
|
return
|
|
}
|
|
|
|
addrs <- listener.Addr()
|
|
|
|
defer func() {
|
|
if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
|
errs <- errors.WithStack(err)
|
|
}
|
|
}()
|
|
|
|
go func() {
|
|
<-ctx.Done()
|
|
|
|
if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) {
|
|
log.Printf("%+v", errors.WithStack(err))
|
|
}
|
|
}()
|
|
|
|
router := chi.NewRouter()
|
|
|
|
router.Use(middleware.Logger)
|
|
|
|
corsMiddleware := cors.New(cors.Options{
|
|
AllowedOrigins: s.conf.CORS.AllowedOrigins,
|
|
AllowedMethods: s.conf.CORS.AllowedMethods,
|
|
AllowCredentials: bool(s.conf.CORS.AllowCredentials),
|
|
AllowedHeaders: s.conf.CORS.AllowedHeaders,
|
|
Debug: bool(s.conf.CORS.Debug),
|
|
})
|
|
|
|
router.Use(corsMiddleware.Handler)
|
|
|
|
thirdPartyAuth, err := s.getThirdPartyAuthenticator()
|
|
if err != nil {
|
|
errs <- errors.WithStack(err)
|
|
|
|
return
|
|
}
|
|
|
|
router.Route("/api/v1", func(r chi.Router) {
|
|
apiMount := api.NewMount(
|
|
s.agentRepo,
|
|
thirdPartyAuth,
|
|
agent.NewAuthenticator(s.agentRepo, agent.DefaultAcceptableSkew),
|
|
)
|
|
|
|
apiMount.Mount(r)
|
|
})
|
|
|
|
logger.Info(ctx, "http server listening")
|
|
|
|
if err := http.Serve(listener, router); err != nil && !errors.Is(err, net.ErrClosed) {
|
|
errs <- errors.WithStack(err)
|
|
}
|
|
|
|
logger.Info(ctx, "http server exiting")
|
|
}
|
|
|
|
func (s *Server) getThirdPartyAuthenticator() (*thirdparty.Authenticator, error) {
|
|
var localPublicKey jwk.Key
|
|
|
|
localAuth := s.conf.Auth.Local
|
|
if localAuth != nil {
|
|
key, err := jwk.LoadOrGenerate(string(localAuth.PrivateKeyPath), jwk.DefaultKeySize)
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
publicKey, err := key.PublicKey()
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
if err := publicKey.Set(jwk.AlgorithmKey, jwa.RS256); err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
localPublicKey = publicKey
|
|
}
|
|
|
|
var getRemoteKeySet thirdparty.GetKeySet
|
|
|
|
remoteAuth := s.conf.Auth.Remote
|
|
if remoteAuth != nil {
|
|
refreshInterval := time.Minute * 15
|
|
if remoteAuth.RefreshInterval != nil {
|
|
refreshInterval = time.Duration(*remoteAuth.RefreshInterval)
|
|
}
|
|
|
|
fn, err := jwk.CreateCachedRemoteKeySet(context.Background(), string(remoteAuth.JsonWebKeySetURL), refreshInterval)
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
getRemoteKeySet = fn
|
|
}
|
|
|
|
getKeySet := func(ctx context.Context) (jwk.Set, error) {
|
|
keySet := jwk.NewSet()
|
|
|
|
if localPublicKey != nil {
|
|
if err := keySet.AddKey(localPublicKey); err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
}
|
|
|
|
if getRemoteKeySet != nil {
|
|
remoteKeySet, err := getRemoteKeySet(ctx)
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
for idx := 0; idx < remoteKeySet.Len(); idx++ {
|
|
key, ok := remoteKeySet.Key(idx)
|
|
if !ok {
|
|
break
|
|
}
|
|
|
|
if err := keySet.AddKey(key); err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
}
|
|
}
|
|
|
|
return keySet, nil
|
|
}
|
|
|
|
getTokenRole, err := s.createGetTokenRoleFunc()
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
getTenantRole, err := s.createGetTokenTenantFunc()
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
return thirdparty.NewAuthenticator(getKeySet, getTokenRole, getTenantRole, thirdparty.DefaultAcceptableSkew), nil
|
|
}
|
|
|
|
var ruleFuncs = []expr.Option{
|
|
expr.Function(
|
|
"str",
|
|
func(params ...any) (any, error) {
|
|
var builder strings.Builder
|
|
|
|
for _, p := range params {
|
|
if _, err := builder.WriteString(fmt.Sprintf("%v", p)); err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
}
|
|
|
|
return builder.String(), nil
|
|
},
|
|
new(func(any) string),
|
|
),
|
|
}
|
|
|
|
func (s *Server) createGetTokenRoleFunc() (func(ctx context.Context, token jwt.Token) (string, error), error) {
|
|
rawRules := s.conf.Auth.RoleExtractionRules
|
|
rules := make([]*vm.Program, 0, len(rawRules))
|
|
|
|
type Env struct {
|
|
JWT map[string]any `expr:"jwt"`
|
|
}
|
|
|
|
opts := append([]expr.Option{
|
|
expr.Env(Env{}),
|
|
expr.AsKind(reflect.String),
|
|
}, ruleFuncs...)
|
|
|
|
for _, rr := range rawRules {
|
|
r, err := expr.Compile(rr, opts...)
|
|
if err != nil {
|
|
return nil, errors.Wrapf(err, "could not compile role extraction rule '%s'", rr)
|
|
}
|
|
|
|
rules = append(rules, r)
|
|
}
|
|
|
|
return func(ctx context.Context, token jwt.Token) (string, error) {
|
|
jwt, err := token.AsMap(ctx)
|
|
if err != nil {
|
|
return "", errors.WithStack(err)
|
|
}
|
|
|
|
vm := vm.VM{}
|
|
|
|
for _, r := range rules {
|
|
result, err := vm.Run(r, Env{
|
|
JWT: jwt,
|
|
})
|
|
if err != nil {
|
|
return "", errors.WithStack(err)
|
|
}
|
|
|
|
role, ok := result.(string)
|
|
if !ok {
|
|
logger.Debug(ctx, "ignoring unexpected role extraction result", logger.F("result", result))
|
|
continue
|
|
}
|
|
|
|
if role != "" {
|
|
return role, nil
|
|
}
|
|
}
|
|
|
|
return "", errors.New("could not extract role from token")
|
|
}, nil
|
|
}
|
|
|
|
func (s *Server) createGetTokenTenantFunc() (func(ctx context.Context, token jwt.Token) (string, error), error) {
|
|
rawRules := s.conf.Auth.TenantExtractionRules
|
|
rules := make([]*vm.Program, 0, len(rawRules))
|
|
|
|
type Env struct {
|
|
JWT map[string]any `expr:"jwt"`
|
|
}
|
|
|
|
opts := append([]expr.Option{
|
|
expr.Env(Env{}),
|
|
expr.AsKind(reflect.String),
|
|
}, ruleFuncs...)
|
|
|
|
for _, rr := range rawRules {
|
|
r, err := expr.Compile(rr, opts...)
|
|
if err != nil {
|
|
return nil, errors.Wrapf(err, "could not compile role extraction rule '%s'", rr)
|
|
}
|
|
|
|
rules = append(rules, r)
|
|
}
|
|
|
|
return func(ctx context.Context, token jwt.Token) (string, error) {
|
|
jwt, err := token.AsMap(ctx)
|
|
if err != nil {
|
|
return "", errors.WithStack(err)
|
|
}
|
|
|
|
vm := vm.VM{}
|
|
|
|
for _, r := range rules {
|
|
result, err := vm.Run(r, Env{
|
|
JWT: jwt,
|
|
})
|
|
if err != nil {
|
|
return "", errors.WithStack(err)
|
|
}
|
|
|
|
tenant, ok := result.(string)
|
|
if !ok {
|
|
logger.Debug(ctx, "ignoring unexpected tenant extraction result", logger.F("result", result))
|
|
continue
|
|
}
|
|
|
|
if tenant != "" {
|
|
return tenant, nil
|
|
}
|
|
}
|
|
|
|
return "", errors.New("could not extract tenant from token")
|
|
}, nil
|
|
}
|
|
|
|
func New(funcs ...OptionFunc) *Server {
|
|
opt := defaultOption()
|
|
for _, fn := range funcs {
|
|
fn(opt)
|
|
}
|
|
|
|
return &Server{
|
|
conf: opt.Config,
|
|
}
|
|
}
|