diff --git a/go.mod b/go.mod index 542de0a..a6f882c 100644 --- a/go.mod +++ b/go.mod @@ -33,6 +33,7 @@ require ( require ( github.com/Masterminds/goutils v1.1.1 // indirect github.com/Masterminds/semver/v3 v3.2.0 // indirect + github.com/antonmedv/expr v1.12.7 // indirect github.com/barnybug/go-cast v0.0.0-20201201064555-a87ccbc26692 // indirect github.com/dop251/goja_nodejs v0.0.0-20230320130059-dcf93ba651dd // indirect github.com/gabriel-vasile/mimetype v1.4.1 // indirect diff --git a/go.sum b/go.sum index fc5586b..8fd6a3b 100644 --- a/go.sum +++ b/go.sum @@ -149,6 +149,8 @@ github.com/alexflint/go-filemutex v1.1.0/go.mod h1:7P4iRhttt/nUvUOrYIhcpMzv2G6CY github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883 h1:bvNMNQO63//z+xNgfBlViaCIJKLlCJ6/fmUseuG0wVQ= github.com/andreyvit/diff v0.0.0-20170406064948-c7f18ee00883/go.mod h1:rCTlJbsFo29Kk6CurOXKm700vrz8f0KW0JNfpkRJY/8= github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kdvxnR2qWY= +github.com/antonmedv/expr v1.12.7 h1:jfV/l/+dHWAadLwAtESXNxXdfbK9bE4+FNMHYCMntwk= +github.com/antonmedv/expr v1.12.7/go.mod h1:FPC8iWArxls7axbVLsW+kpg1mz29A1b2M6jt+hZfDkU= github.com/apache/arrow/go/arrow v0.0.0-20210818145353-234c94e4ce64/go.mod h1:2qMFB56yOP3KzkB3PbYZ4AlUFg3a88F67TIx5lB/WwY= github.com/apache/arrow/go/arrow v0.0.0-20211013220434-5962184e7a30/go.mod h1:Q7yQnSMnLvcXlZ8RV+jwz/6y1rQTqbX6C82SndT52Zs= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= diff --git a/internal/auth/agent/authenticator.go b/internal/auth/agent/authenticator.go index 3892d1a..0cb6c70 100644 --- a/internal/auth/agent/authenticator.go +++ b/internal/auth/agent/authenticator.go @@ -11,7 +11,6 @@ import ( "github.com/lestrrat-go/jwx/v2/jws" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/pkg/errors" - "gitlab.com/wpetit/goweb/logger" ) const DefaultAcceptableSkew = 5 * time.Minute @@ -23,8 +22,6 @@ type Authenticator struct { // Authenticate implements auth.Authenticator. func (a *Authenticator) Authenticate(ctx context.Context, r *http.Request) (auth.User, error) { - ctx = logger.With(r.Context(), logger.F("remoteAddr", r.RemoteAddr)) - authorization := r.Header.Get("Authorization") if authorization == "" { return nil, errors.WithStack(auth.ErrUnauthenticated) diff --git a/internal/auth/thirdparty/authenticator.go b/internal/auth/thirdparty/authenticator.go index 4dae9e1..95af1e5 100644 --- a/internal/auth/thirdparty/authenticator.go +++ b/internal/auth/thirdparty/authenticator.go @@ -8,22 +8,25 @@ import ( "forge.cadoles.com/Cadoles/emissary/internal/auth" "forge.cadoles.com/Cadoles/emissary/internal/jwk" + "github.com/lestrrat-go/jwx/v2/jwt" "github.com/pkg/errors" - "gitlab.com/wpetit/goweb/logger" ) const DefaultAcceptableSkew = 5 * time.Minute +type ( + GetKeySet func(context.Context) (jwk.Set, error) + GetTokenRole func(context.Context, jwt.Token) (string, error) +) + type Authenticator struct { - keys jwk.Set - issuer string + getKeySet GetKeySet + getTokenRole GetTokenRole acceptableSkew time.Duration } // Authenticate implements auth.Authenticator. func (a *Authenticator) Authenticate(ctx context.Context, r *http.Request) (auth.User, error) { - ctx = logger.With(r.Context(), logger.F("remoteAddr", r.RemoteAddr)) - authorization := r.Header.Get("Authorization") if authorization == "" { return nil, errors.WithStack(auth.ErrUnauthenticated) @@ -34,37 +37,37 @@ func (a *Authenticator) Authenticate(ctx context.Context, r *http.Request) (auth return nil, errors.WithStack(auth.ErrUnauthenticated) } - token, err := parseToken(ctx, a.keys, a.issuer, rawToken, a.acceptableSkew) + keys, err := a.getKeySet(ctx) if err != nil { return nil, errors.WithStack(err) } - rawRole, exists := token.Get(keyRole) - if !exists { - return nil, errors.New("could not find 'thumbprint' claim") + token, err := parseToken(ctx, keys, rawToken, a.acceptableSkew) + if err != nil { + return nil, errors.WithStack(err) } - role, ok := rawRole.(string) - if !ok { - return nil, errors.Errorf("unexpected '%s' claim value: '%v'", keyRole, rawRole) + rawRole, err := a.getTokenRole(ctx, token) + if err != nil { + return nil, errors.WithStack(err) } - if !isValidRole(role) { - return nil, errors.Errorf("invalid role '%s'", role) + if !isValidRole(rawRole) { + return nil, errors.Errorf("invalid role '%s'", rawRole) } user := &User{ subject: token.Subject(), - role: Role(role), + role: Role(rawRole), } return user, nil } -func NewAuthenticator(keys jwk.Set, issuer string, acceptableSkew time.Duration) *Authenticator { +func NewAuthenticator(getKeySet GetKeySet, getTokenRole GetTokenRole, acceptableSkew time.Duration) *Authenticator { return &Authenticator{ - keys: keys, - issuer: issuer, + getTokenRole: getTokenRole, + getKeySet: getKeySet, acceptableSkew: acceptableSkew, } } diff --git a/internal/auth/thirdparty/jwt.go b/internal/auth/thirdparty/jwt.go index df7e445..1e76465 100644 --- a/internal/auth/thirdparty/jwt.go +++ b/internal/auth/thirdparty/jwt.go @@ -11,15 +11,13 @@ import ( "github.com/pkg/errors" ) -const keyRole = "role" - -func parseToken(ctx context.Context, keys jwk.Set, issuer string, rawToken string, acceptableSkew time.Duration) (jwt.Token, error) { +func parseToken(ctx context.Context, keys jwk.Set, rawToken string, acceptableSkew time.Duration) (jwt.Token, error) { token, err := jwt.Parse( []byte(rawToken), jwt.WithKeySet(keys, jws.WithRequireKid(false)), - jwt.WithIssuer(issuer), jwt.WithValidate(true), jwt.WithAcceptableSkew(acceptableSkew), + jwt.WithContext(ctx), ) if err != nil { return nil, errors.WithStack(err) @@ -28,18 +26,16 @@ func parseToken(ctx context.Context, keys jwk.Set, issuer string, rawToken strin return token, nil } -func GenerateToken(ctx context.Context, key jwk.Key, issuer, subject string, role Role) (string, error) { +const DefaultRoleKey string = "role" + +func GenerateToken(ctx context.Context, key jwk.Key, subject string, role Role) (string, error) { token := jwt.New() if err := token.Set(jwt.SubjectKey, subject); err != nil { return "", errors.WithStack(err) } - if err := token.Set(jwt.IssuerKey, issuer); err != nil { - return "", errors.WithStack(err) - } - - if err := token.Set(keyRole, role); err != nil { + if err := token.Set(DefaultRoleKey, role); err != nil { return "", errors.WithStack(err) } diff --git a/internal/command/server/auth/create_token.go b/internal/command/server/auth/create_token.go index a9a26d6..fdf2a8d 100644 --- a/internal/command/server/auth/create_token.go +++ b/internal/command/server/auth/create_token.go @@ -36,12 +36,17 @@ func CreateTokenCommand() *cli.Command { subject := ctx.String("subject") role := ctx.String("role") - key, err := jwk.LoadOrGenerate(string(conf.Server.PrivateKeyPath), jwk.DefaultKeySize) + localAuth := conf.Server.Auth.Local + if localAuth == nil { + return errors.New("local auth is disabled") + } + + key, err := jwk.LoadOrGenerate(string(localAuth.PrivateKeyPath), jwk.DefaultKeySize) if err != nil { return errors.WithStack(err) } - token, err := thirdparty.GenerateToken(ctx.Context, key, string(conf.Server.Issuer), subject, thirdparty.Role(role)) + token, err := thirdparty.GenerateToken(ctx.Context, key, subject, thirdparty.Role(role)) if err != nil { return errors.WithStack(err) } diff --git a/internal/config/environment.go b/internal/config/environment.go index 3b1b3d2..82470d8 100644 --- a/internal/config/environment.go +++ b/internal/config/environment.go @@ -4,6 +4,7 @@ import ( "os" "regexp" "strconv" + "time" "github.com/pkg/errors" "gopkg.in/yaml.v3" @@ -123,3 +124,37 @@ func (iss *InterpolatedStringSlice) UnmarshalYAML(value *yaml.Node) error { return nil } + +type InterpolatedDuration time.Duration + +func (id *InterpolatedDuration) UnmarshalYAML(value *yaml.Node) error { + var str string + + if err := value.Decode(&str); err != nil { + return errors.Wrapf(err, "could not decode value '%v' (line '%d') into string", value.Value, value.Line) + } + + if match := reVar.FindStringSubmatch(str); len(match) > 0 { + str = os.Getenv(match[1]) + } + + duration, err := time.ParseDuration(str) + if err != nil { + return errors.Wrapf(err, "could not parse duration '%v', line '%d'", str, value.Line) + } + + *id = InterpolatedDuration(duration) + + return nil +} + +func (id *InterpolatedDuration) MarshalYAML() (interface{}, error) { + duration := time.Duration(*id) + + return duration.String(), nil +} + +func NewInterpolatedDuration(d time.Duration) *InterpolatedDuration { + id := InterpolatedDuration(d) + return &id +} diff --git a/internal/config/server.go b/internal/config/server.go index b196d46..8b192fb 100644 --- a/internal/config/server.go +++ b/internal/config/server.go @@ -1,19 +1,50 @@ package config +import ( + "fmt" + + "forge.cadoles.com/Cadoles/emissary/internal/auth/thirdparty" +) + type ServerConfig struct { - PrivateKeyPath InterpolatedString `yaml:"privateKeyPath"` - Issuer InterpolatedString `yaml:"issuer"` - HTTP HTTPConfig `yaml:"http"` - Database DatabaseConfig `yaml:"database"` - CORS CORSConfig `yaml:"cors"` + HTTP HTTPConfig `yaml:"http"` + Database DatabaseConfig `yaml:"database"` + CORS CORSConfig `yaml:"cors"` + Auth AuthConfig `yaml:"auth"` } func NewDefaultServerConfig() ServerConfig { return ServerConfig{ - PrivateKeyPath: "server-key.json", - Issuer: "http://127.0.0.1:3000", - HTTP: NewDefaultHTTPConfig(), - Database: NewDefaultDatabaseConfig(), - CORS: NewDefaultCORSConfig(), + HTTP: NewDefaultHTTPConfig(), + Database: NewDefaultDatabaseConfig(), + CORS: NewDefaultCORSConfig(), + Auth: NewDefaultAuthConfig(), } } + +type AuthConfig struct { + Local *LocalAuthConfig `yaml:"local"` + Remote *RemoteAuthConfig `yaml:"remote"` + RoleExtractionRules []string `yaml:"roleExtractionRules"` +} + +func NewDefaultAuthConfig() AuthConfig { + return AuthConfig{ + Local: &LocalAuthConfig{ + PrivateKeyPath: "server-key.json", + }, + Remote: nil, + RoleExtractionRules: []string{ + fmt.Sprintf("jwt.%s != nil ? str(jwt.%s) : ''", thirdparty.DefaultRoleKey, thirdparty.DefaultRoleKey), + }, + } +} + +type LocalAuthConfig struct { + PrivateKeyPath InterpolatedString `yaml:"privateKeyPath"` +} + +type RemoteAuthConfig struct { + JsonWebKeySetURL InterpolatedString `yaml:"jwksUrl"` + RefreshInterval *InterpolatedDuration `yaml:"refreshInterval"` +} diff --git a/internal/jwk/jwk.go b/internal/jwk/jwk.go index afa0bdf..67745a1 100644 --- a/internal/jwk/jwk.go +++ b/internal/jwk/jwk.go @@ -1,11 +1,13 @@ package jwk import ( + "context" "crypto/rand" "crypto/rsa" "encoding/json" "io/ioutil" "os" + "time" "github.com/btcsuite/btcd/btcutil/base58" "github.com/lestrrat-go/jwx/v2/jwa" @@ -34,7 +36,7 @@ func Parse(src []byte, options ...jwk.ParseOption) (Set, error) { return jwk.Parse(src, options...) } -func PublicKeySet(keys ...jwk.Key) (jwk.Set, error) { +func RS256PublicKeySet(keys ...jwk.Key) (jwk.Set, error) { set := jwk.NewSet() for _, k := range keys { @@ -85,6 +87,27 @@ func LoadOrGenerate(path string, size int) (jwk.Key, error) { return key, nil } +func CreateCachedRemoteKeySet(ctx context.Context, url string, refreshInterval time.Duration) (func(context.Context) (jwk.Set, error), error) { + cache := jwk.NewCache(ctx) + + if err := cache.Register(url, jwk.WithMinRefreshInterval(refreshInterval)); err != nil { + return nil, errors.WithStack(err) + } + + if _, err := cache.Refresh(ctx, url); err != nil { + return nil, errors.WithStack(err) + } + + return func(ctx context.Context) (jwk.Set, error) { + keySet, err := cache.Get(ctx, url) + if err != nil { + return nil, errors.WithStack(err) + } + + return keySet, nil + }, nil +} + func Generate(size int) (jwk.Key, error) { privKey, err := rsa.GenerateKey(rand.Reader, size) if err != nil { diff --git a/internal/jwk/jwk_test.go b/internal/jwk/jwk_test.go index 6a748e8..0dd0c14 100644 --- a/internal/jwk/jwk_test.go +++ b/internal/jwk/jwk_test.go @@ -12,7 +12,7 @@ func TestJWK(t *testing.T) { t.Fatalf("%+v", errors.WithStack(err)) } - keySet, err := PublicKeySet(privateKey) + keySet, err := RS256PublicKeySet(privateKey) if err != nil { t.Fatalf("%+v", errors.WithStack(err)) } diff --git a/internal/server/server.go b/internal/server/server.go index 7d3721d..9cc9cee 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -6,6 +6,9 @@ import ( "log" "net" "net/http" + "reflect" + "strings" + "time" "forge.cadoles.com/Cadoles/emissary/internal/auth" "forge.cadoles.com/Cadoles/emissary/internal/auth/agent" @@ -13,9 +16,13 @@ import ( "forge.cadoles.com/Cadoles/emissary/internal/config" "forge.cadoles.com/Cadoles/emissary/internal/datastore" "forge.cadoles.com/Cadoles/emissary/internal/jwk" + "github.com/antonmedv/expr" + "github.com/antonmedv/expr/vm" "github.com/go-chi/chi" "github.com/go-chi/chi/middleware" "github.com/go-chi/cors" + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwt" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" ) @@ -72,20 +79,6 @@ func (s *Server) run(parentCtx context.Context, addrs chan net.Addr, errs chan e } }() - key, err := jwk.LoadOrGenerate(string(s.conf.PrivateKeyPath), jwk.DefaultKeySize) - if err != nil { - errs <- errors.WithStack(err) - - return - } - - keys, err := jwk.PublicKeySet(key) - if err != nil { - errs <- errors.WithStack(err) - - return - } - router := chi.NewRouter() router.Use(middleware.Logger) @@ -100,12 +93,19 @@ func (s *Server) run(parentCtx context.Context, addrs chan net.Addr, errs chan e router.Use(corsMiddleware.Handler) + thirdPartyAuth, err := s.getThirdPartyAuthenticator() + if err != nil { + errs <- errors.WithStack(err) + + return + } + router.Route("/api/v1", func(r chi.Router) { r.Post("/register", s.registerAgent) r.Group(func(r chi.Router) { r.Use(auth.Middleware( - thirdparty.NewAuthenticator(keys, string(s.conf.Issuer), thirdparty.DefaultAcceptableSkew), + thirdPartyAuth, agent.NewAuthenticator(s.agentRepo, agent.DefaultAcceptableSkew), )) @@ -131,6 +131,151 @@ func (s *Server) run(parentCtx context.Context, addrs chan net.Addr, errs chan e logger.Info(ctx, "http server exiting") } +func (s *Server) getThirdPartyAuthenticator() (*thirdparty.Authenticator, error) { + var localPublicKey jwk.Key + + localAuth := s.conf.Auth.Local + if localAuth != nil { + key, err := jwk.LoadOrGenerate(string(localAuth.PrivateKeyPath), jwk.DefaultKeySize) + if err != nil { + return nil, errors.WithStack(err) + } + + publicKey, err := key.PublicKey() + if err != nil { + return nil, errors.WithStack(err) + } + + if err := publicKey.Set(jwk.AlgorithmKey, jwa.RS256); err != nil { + return nil, errors.WithStack(err) + } + + localPublicKey = publicKey + } + + var getRemoteKeySet thirdparty.GetKeySet + + remoteAuth := s.conf.Auth.Remote + if remoteAuth != nil { + refreshInterval := time.Minute * 15 + if remoteAuth.RefreshInterval != nil { + refreshInterval = time.Duration(*remoteAuth.RefreshInterval) + } + + fn, err := jwk.CreateCachedRemoteKeySet(context.Background(), string(remoteAuth.JsonWebKeySetURL), refreshInterval) + if err != nil { + return nil, errors.WithStack(err) + } + + getRemoteKeySet = fn + } + + getKeySet := func(ctx context.Context) (jwk.Set, error) { + keySet := jwk.NewSet() + + if localPublicKey != nil { + if err := keySet.AddKey(localPublicKey); err != nil { + return nil, errors.WithStack(err) + } + } + + if getRemoteKeySet != nil { + remoteKeySet, err := getRemoteKeySet(ctx) + if err != nil { + return nil, errors.WithStack(err) + } + + for idx := 0; idx < remoteKeySet.Len(); idx++ { + key, ok := remoteKeySet.Key(idx) + if !ok { + break + } + + if err := keySet.AddKey(key); err != nil { + return nil, errors.WithStack(err) + } + } + } + + return keySet, nil + } + + getTokenRole, err := s.createGetTokenRoleFunc() + if err != nil { + return nil, errors.WithStack(err) + } + + return thirdparty.NewAuthenticator(getKeySet, getTokenRole, thirdparty.DefaultAcceptableSkew), nil +} + +func (s *Server) createGetTokenRoleFunc() (func(ctx context.Context, token jwt.Token) (string, error), error) { + rawRules := s.conf.Auth.RoleExtractionRules + rules := make([]*vm.Program, 0, len(rawRules)) + + type Env struct { + JWT map[string]any `expr:"jwt"` + } + + strFunc := expr.Function( + "str", + func(params ...any) (any, error) { + var builder strings.Builder + + for _, p := range params { + if _, err := builder.WriteString(fmt.Sprintf("%v", p)); err != nil { + return nil, errors.WithStack(err) + } + } + + return builder.String(), nil + }, + new(func(any) string), + ) + + for _, rr := range rawRules { + r, err := expr.Compile(rr, + expr.Env(Env{}), + expr.AsKind(reflect.String), + strFunc, + ) + if err != nil { + return nil, errors.Wrapf(err, "could not compile role extraction rule '%s'", rr) + } + + rules = append(rules, r) + } + + return func(ctx context.Context, token jwt.Token) (string, error) { + jwt, err := token.AsMap(ctx) + if err != nil { + return "", errors.WithStack(err) + } + + vm := vm.VM{} + + for _, r := range rules { + result, err := vm.Run(r, Env{ + JWT: jwt, + }) + if err != nil { + return "", errors.WithStack(err) + } + + role, ok := result.(string) + if !ok { + logger.Debug(ctx, "ignoring unexpected role extraction result", logger.F("result", result)) + continue + } + + if role != "" { + return role, nil + } + } + + return "", errors.New("could not extract role from token") + }, nil +} + func New(funcs ...OptionFunc) *Server { opt := defaultOption() for _, fn := range funcs { diff --git a/misc/packaging/common/config-server.yml b/misc/packaging/common/config-server.yml index ee6a542..bf28e04 100644 --- a/misc/packaging/common/config-server.yml +++ b/misc/packaging/common/config-server.yml @@ -2,8 +2,6 @@ logger: level: 1 format: human server: - privateKeyPath: /var/lib/emissary/server-key.json - issuer: http://127.0.0.1:3000 http: host: 0.0.0.0 port: 3000 @@ -25,3 +23,11 @@ server: - Authorization - Sentry-Trace debug: false + auth: + local: + privateKeyPath: /var/lib/emissary/server-key.json + roleExtractionRules: + - "jwt.role != nil ? str(jwt.role) : ''" + remote: ~ + # jwksUrl: https://my-server/.well-known/jwks.json + diff --git a/pkg/client/register_agent.go b/pkg/client/register_agent.go index 1a41922..b35dc8a 100644 --- a/pkg/client/register_agent.go +++ b/pkg/client/register_agent.go @@ -10,7 +10,7 @@ import ( ) func (c *Client) RegisterAgent(ctx context.Context, key Key, thumbprint string, meta []MetadataTuple, funcs ...OptionFunc) (*Agent, error) { - keySet, err := jwk.PublicKeySet(key) + keySet, err := jwk.RS256PublicKeySet(key) if err != nil { return nil, errors.WithStack(err) }