package server import ( "context" "fmt" "log" "net" "net/http" "reflect" "strings" "time" "forge.cadoles.com/Cadoles/emissary/internal/auth/agent" "forge.cadoles.com/Cadoles/emissary/internal/auth/thirdparty" "forge.cadoles.com/Cadoles/emissary/internal/config" "forge.cadoles.com/Cadoles/emissary/internal/datastore" "forge.cadoles.com/Cadoles/emissary/internal/jwk" "forge.cadoles.com/Cadoles/emissary/internal/server/api" "github.com/antonmedv/expr" "github.com/antonmedv/expr/vm" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/go-chi/cors" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" ) type Server struct { conf config.ServerConfig agentRepo datastore.AgentRepository } func (s *Server) Start(ctx context.Context) (<-chan net.Addr, <-chan error) { errs := make(chan error) addrs := make(chan net.Addr) go s.run(ctx, addrs, errs) return addrs, errs } func (s *Server) run(parentCtx context.Context, addrs chan net.Addr, errs chan error) { defer func() { close(errs) close(addrs) }() ctx, cancel := context.WithCancel(parentCtx) defer cancel() if err := s.initRepositories(ctx); err != nil { errs <- errors.WithStack(err) return } listener, err := net.Listen("tcp", fmt.Sprintf("%s:%d", s.conf.HTTP.Host, s.conf.HTTP.Port)) if err != nil { errs <- errors.WithStack(err) return } addrs <- listener.Addr() defer func() { if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) { errs <- errors.WithStack(err) } }() go func() { <-ctx.Done() if err := listener.Close(); err != nil && !errors.Is(err, net.ErrClosed) { log.Printf("%+v", errors.WithStack(err)) } }() router := chi.NewRouter() router.Use(middleware.Logger) corsMiddleware := cors.New(cors.Options{ AllowedOrigins: s.conf.CORS.AllowedOrigins, AllowedMethods: s.conf.CORS.AllowedMethods, AllowCredentials: bool(s.conf.CORS.AllowCredentials), AllowedHeaders: s.conf.CORS.AllowedHeaders, Debug: bool(s.conf.CORS.Debug), }) router.Use(corsMiddleware.Handler) thirdPartyAuth, err := s.getThirdPartyAuthenticator() if err != nil { errs <- errors.WithStack(err) return } router.Route("/api/v1", func(r chi.Router) { apiMount := api.NewMount( s.agentRepo, thirdPartyAuth, agent.NewAuthenticator(s.agentRepo, agent.DefaultAcceptableSkew), ) apiMount.Mount(r) }) logger.Info(ctx, "http server listening") if err := http.Serve(listener, router); err != nil && !errors.Is(err, net.ErrClosed) { errs <- errors.WithStack(err) } logger.Info(ctx, "http server exiting") } func (s *Server) getThirdPartyAuthenticator() (*thirdparty.Authenticator, error) { var localPublicKey jwk.Key localAuth := s.conf.Auth.Local if localAuth != nil { key, err := jwk.LoadOrGenerate(string(localAuth.PrivateKeyPath), jwk.DefaultKeySize) if err != nil { return nil, errors.WithStack(err) } publicKey, err := key.PublicKey() if err != nil { return nil, errors.WithStack(err) } if err := publicKey.Set(jwk.AlgorithmKey, jwa.RS256); err != nil { return nil, errors.WithStack(err) } localPublicKey = publicKey } var getRemoteKeySet thirdparty.GetKeySet remoteAuth := s.conf.Auth.Remote if remoteAuth != nil { refreshInterval := time.Minute * 15 if remoteAuth.RefreshInterval != nil { refreshInterval = time.Duration(*remoteAuth.RefreshInterval) } fn, err := jwk.CreateCachedRemoteKeySet(context.Background(), string(remoteAuth.JsonWebKeySetURL), refreshInterval) if err != nil { return nil, errors.WithStack(err) } getRemoteKeySet = fn } getKeySet := func(ctx context.Context) (jwk.Set, error) { keySet := jwk.NewSet() if localPublicKey != nil { if err := keySet.AddKey(localPublicKey); err != nil { return nil, errors.WithStack(err) } } if getRemoteKeySet != nil { remoteKeySet, err := getRemoteKeySet(ctx) if err != nil { return nil, errors.WithStack(err) } for idx := 0; idx < remoteKeySet.Len(); idx++ { key, ok := remoteKeySet.Key(idx) if !ok { break } if err := keySet.AddKey(key); err != nil { return nil, errors.WithStack(err) } } } return keySet, nil } getTokenRole, err := s.createGetTokenRoleFunc() if err != nil { return nil, errors.WithStack(err) } getTenantRole, err := s.createGetTokenTenantFunc() if err != nil { return nil, errors.WithStack(err) } return thirdparty.NewAuthenticator(getKeySet, getTokenRole, getTenantRole, thirdparty.DefaultAcceptableSkew), nil } var ruleFuncs = []expr.Option{ expr.Function( "str", func(params ...any) (any, error) { var builder strings.Builder for _, p := range params { if _, err := builder.WriteString(fmt.Sprintf("%v", p)); err != nil { return nil, errors.WithStack(err) } } return builder.String(), nil }, new(func(any) string), ), } func (s *Server) createGetTokenRoleFunc() (func(ctx context.Context, token jwt.Token) (string, error), error) { rawRules := s.conf.Auth.RoleExtractionRules rules := make([]*vm.Program, 0, len(rawRules)) type Env struct { JWT map[string]any `expr:"jwt"` } opts := append([]expr.Option{ expr.Env(Env{}), expr.AsKind(reflect.String), }, ruleFuncs...) for _, rr := range rawRules { r, err := expr.Compile(rr, opts...) if err != nil { return nil, errors.Wrapf(err, "could not compile role extraction rule '%s'", rr) } rules = append(rules, r) } return func(ctx context.Context, token jwt.Token) (string, error) { jwt, err := token.AsMap(ctx) if err != nil { return "", errors.WithStack(err) } vm := vm.VM{} for _, r := range rules { result, err := vm.Run(r, Env{ JWT: jwt, }) if err != nil { return "", errors.WithStack(err) } role, ok := result.(string) if !ok { logger.Debug(ctx, "ignoring unexpected role extraction result", logger.F("result", result)) continue } if role != "" { return role, nil } } return "", errors.New("could not extract role from token") }, nil } func (s *Server) createGetTokenTenantFunc() (func(ctx context.Context, token jwt.Token) (string, error), error) { rawRules := s.conf.Auth.TenantExtractionRules rules := make([]*vm.Program, 0, len(rawRules)) type Env struct { JWT map[string]any `expr:"jwt"` } opts := append([]expr.Option{ expr.Env(Env{}), expr.AsKind(reflect.String), }, ruleFuncs...) for _, rr := range rawRules { r, err := expr.Compile(rr, opts...) if err != nil { return nil, errors.Wrapf(err, "could not compile role extraction rule '%s'", rr) } rules = append(rules, r) } return func(ctx context.Context, token jwt.Token) (string, error) { jwt, err := token.AsMap(ctx) if err != nil { return "", errors.WithStack(err) } vm := vm.VM{} for _, r := range rules { result, err := vm.Run(r, Env{ JWT: jwt, }) if err != nil { return "", errors.WithStack(err) } tenant, ok := result.(string) if !ok { logger.Debug(ctx, "ignoring unexpected tenant extraction result", logger.F("result", result)) continue } if tenant != "" { return tenant, nil } } return "", errors.New("could not extract tenant from token") }, nil } func New(funcs ...OptionFunc) *Server { opt := defaultOption() for _, fn := range funcs { fn(opt) } return &Server{ conf: opt.Config, } }