feat(auth): remote and local third-party authentication
Some checks reported warnings
arcad/emissary/pipeline/head This commit is unstable
Some checks reported warnings
arcad/emissary/pipeline/head This commit is unstable
This commit is contained in:
@ -6,6 +6,9 @@ import (
|
||||
"log"
|
||||
"net"
|
||||
"net/http"
|
||||
"reflect"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"forge.cadoles.com/Cadoles/emissary/internal/auth"
|
||||
"forge.cadoles.com/Cadoles/emissary/internal/auth/agent"
|
||||
@ -13,9 +16,13 @@ import (
|
||||
"forge.cadoles.com/Cadoles/emissary/internal/config"
|
||||
"forge.cadoles.com/Cadoles/emissary/internal/datastore"
|
||||
"forge.cadoles.com/Cadoles/emissary/internal/jwk"
|
||||
"github.com/antonmedv/expr"
|
||||
"github.com/antonmedv/expr/vm"
|
||||
"github.com/go-chi/chi"
|
||||
"github.com/go-chi/chi/middleware"
|
||||
"github.com/go-chi/cors"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/pkg/errors"
|
||||
"gitlab.com/wpetit/goweb/logger"
|
||||
)
|
||||
@ -72,20 +79,6 @@ func (s *Server) run(parentCtx context.Context, addrs chan net.Addr, errs chan e
|
||||
}
|
||||
}()
|
||||
|
||||
key, err := jwk.LoadOrGenerate(string(s.conf.PrivateKeyPath), jwk.DefaultKeySize)
|
||||
if err != nil {
|
||||
errs <- errors.WithStack(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
keys, err := jwk.PublicKeySet(key)
|
||||
if err != nil {
|
||||
errs <- errors.WithStack(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
router := chi.NewRouter()
|
||||
|
||||
router.Use(middleware.Logger)
|
||||
@ -100,12 +93,19 @@ func (s *Server) run(parentCtx context.Context, addrs chan net.Addr, errs chan e
|
||||
|
||||
router.Use(corsMiddleware.Handler)
|
||||
|
||||
thirdPartyAuth, err := s.getThirdPartyAuthenticator()
|
||||
if err != nil {
|
||||
errs <- errors.WithStack(err)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
router.Route("/api/v1", func(r chi.Router) {
|
||||
r.Post("/register", s.registerAgent)
|
||||
|
||||
r.Group(func(r chi.Router) {
|
||||
r.Use(auth.Middleware(
|
||||
thirdparty.NewAuthenticator(keys, string(s.conf.Issuer), thirdparty.DefaultAcceptableSkew),
|
||||
thirdPartyAuth,
|
||||
agent.NewAuthenticator(s.agentRepo, agent.DefaultAcceptableSkew),
|
||||
))
|
||||
|
||||
@ -131,6 +131,151 @@ func (s *Server) run(parentCtx context.Context, addrs chan net.Addr, errs chan e
|
||||
logger.Info(ctx, "http server exiting")
|
||||
}
|
||||
|
||||
func (s *Server) getThirdPartyAuthenticator() (*thirdparty.Authenticator, error) {
|
||||
var localPublicKey jwk.Key
|
||||
|
||||
localAuth := s.conf.Auth.Local
|
||||
if localAuth != nil {
|
||||
key, err := jwk.LoadOrGenerate(string(localAuth.PrivateKeyPath), jwk.DefaultKeySize)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
publicKey, err := key.PublicKey()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
if err := publicKey.Set(jwk.AlgorithmKey, jwa.RS256); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
localPublicKey = publicKey
|
||||
}
|
||||
|
||||
var getRemoteKeySet thirdparty.GetKeySet
|
||||
|
||||
remoteAuth := s.conf.Auth.Remote
|
||||
if remoteAuth != nil {
|
||||
refreshInterval := time.Minute * 15
|
||||
if remoteAuth.RefreshInterval != nil {
|
||||
refreshInterval = time.Duration(*remoteAuth.RefreshInterval)
|
||||
}
|
||||
|
||||
fn, err := jwk.CreateCachedRemoteKeySet(context.Background(), string(remoteAuth.JsonWebKeySetURL), refreshInterval)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
getRemoteKeySet = fn
|
||||
}
|
||||
|
||||
getKeySet := func(ctx context.Context) (jwk.Set, error) {
|
||||
keySet := jwk.NewSet()
|
||||
|
||||
if localPublicKey != nil {
|
||||
if err := keySet.AddKey(localPublicKey); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
if getRemoteKeySet != nil {
|
||||
remoteKeySet, err := getRemoteKeySet(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
for idx := 0; idx < remoteKeySet.Len(); idx++ {
|
||||
key, ok := remoteKeySet.Key(idx)
|
||||
if !ok {
|
||||
break
|
||||
}
|
||||
|
||||
if err := keySet.AddKey(key); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return keySet, nil
|
||||
}
|
||||
|
||||
getTokenRole, err := s.createGetTokenRoleFunc()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return thirdparty.NewAuthenticator(getKeySet, getTokenRole, thirdparty.DefaultAcceptableSkew), nil
|
||||
}
|
||||
|
||||
func (s *Server) createGetTokenRoleFunc() (func(ctx context.Context, token jwt.Token) (string, error), error) {
|
||||
rawRules := s.conf.Auth.RoleExtractionRules
|
||||
rules := make([]*vm.Program, 0, len(rawRules))
|
||||
|
||||
type Env struct {
|
||||
JWT map[string]any `expr:"jwt"`
|
||||
}
|
||||
|
||||
strFunc := expr.Function(
|
||||
"str",
|
||||
func(params ...any) (any, error) {
|
||||
var builder strings.Builder
|
||||
|
||||
for _, p := range params {
|
||||
if _, err := builder.WriteString(fmt.Sprintf("%v", p)); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
return builder.String(), nil
|
||||
},
|
||||
new(func(any) string),
|
||||
)
|
||||
|
||||
for _, rr := range rawRules {
|
||||
r, err := expr.Compile(rr,
|
||||
expr.Env(Env{}),
|
||||
expr.AsKind(reflect.String),
|
||||
strFunc,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "could not compile role extraction rule '%s'", rr)
|
||||
}
|
||||
|
||||
rules = append(rules, r)
|
||||
}
|
||||
|
||||
return func(ctx context.Context, token jwt.Token) (string, error) {
|
||||
jwt, err := token.AsMap(ctx)
|
||||
if err != nil {
|
||||
return "", errors.WithStack(err)
|
||||
}
|
||||
|
||||
vm := vm.VM{}
|
||||
|
||||
for _, r := range rules {
|
||||
result, err := vm.Run(r, Env{
|
||||
JWT: jwt,
|
||||
})
|
||||
if err != nil {
|
||||
return "", errors.WithStack(err)
|
||||
}
|
||||
|
||||
role, ok := result.(string)
|
||||
if !ok {
|
||||
logger.Debug(ctx, "ignoring unexpected role extraction result", logger.F("result", result))
|
||||
continue
|
||||
}
|
||||
|
||||
if role != "" {
|
||||
return role, nil
|
||||
}
|
||||
}
|
||||
|
||||
return "", errors.New("could not extract role from token")
|
||||
}, nil
|
||||
}
|
||||
|
||||
func New(funcs ...OptionFunc) *Server {
|
||||
opt := defaultOption()
|
||||
for _, fn := range funcs {
|
||||
|
Reference in New Issue
Block a user