diff --git a/.env.dist b/.env.dist index 6f09b89..4a9a61a 100644 --- a/.env.dist +++ b/.env.dist @@ -1,3 +1,4 @@ RUN_APP_ARGS="" #EDGE_DOCUMENTSTORE_DSN="rpc://localhost:3001/documentstore?tenant=local&appId=%APPID%" -#EDGE_BLOBSTORE_DSN="rpc://localhost:3001/blobstore?tenant=local&appId=%APPID%" \ No newline at end of file +#EDGE_BLOBSTORE_DSN="rpc://localhost:3001/blobstore?tenant=local&appId=%APPID%" +#EDGE_SHARESTORE_DSN="rpc://localhost:3001/sharestore?tenant=local" \ No newline at end of file diff --git a/.gitignore b/.gitignore index 942f5b3..c86f74d 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,5 @@ .mktools/ /dist /.chglog -/CHANGELOG.md \ No newline at end of file +/CHANGELOG.md +/storage-server.key \ No newline at end of file diff --git a/cmd/cli/command/app/run.go b/cmd/cli/command/app/run.go index 05003e5..e2ad78c 100644 --- a/cmd/cli/command/app/run.go +++ b/cmd/cli/command/app/run.go @@ -16,6 +16,7 @@ import ( "forge.cadoles.com/arcad/edge/pkg/bus" "forge.cadoles.com/arcad/edge/pkg/bus/memory" appHTTP "forge.cadoles.com/arcad/edge/pkg/http" + "forge.cadoles.com/arcad/edge/pkg/jwtutil" "forge.cadoles.com/arcad/edge/pkg/module" appModule "forge.cadoles.com/arcad/edge/pkg/module/app" appModuleMemory "forge.cadoles.com/arcad/edge/pkg/module/app/memory" @@ -33,7 +34,6 @@ import ( "forge.cadoles.com/arcad/edge/pkg/bundle" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" - "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/pkg/errors" "github.com/urfave/cli/v2" @@ -50,6 +50,8 @@ import ( "forge.cadoles.com/arcad/edge/pkg/storage/share" ) +var dummySecret = []byte("not_so_secret") + func RunCommand() *cli.Command { return &cli.Command{ Name: "run", @@ -194,7 +196,7 @@ func runApp(ctx context.Context, path, address, documentStoreDSN, blobStoreDSN, ctx = logger.With(ctx, logger.F("appID", manifest.ID)) // Add auth handler - key, err := dummyKey() + key, err := jwtutil.NewSymmetricKey(dummySecret) if err != nil { return errors.WithStack(err) } @@ -220,17 +222,17 @@ func runApp(ctx context.Context, path, address, documentStoreDSN, blobStoreDSN, appModule.Mount(appRepository), authModule.Mount( authHTTP.NewLocalHandler( - jwa.HS256, key, + key, authHTTP.WithRoutePrefix("/auth"), authHTTP.WithAccounts(deps.Accounts...), ), - authModule.WithJWT(dummyKeySet), + authModule.WithJWT(func() (jwk.Set, error) { + return jwtutil.NewSymmetricKeySet(dummySecret) + }), ), ), appHTTP.WithHTTPMiddlewares( - authModuleMiddleware.AnonymousUser( - jwa.HS256, key, - ), + authModuleMiddleware.AnonymousUser(key), ), ) if err := handler.Load(bundle); err != nil { @@ -276,7 +278,9 @@ func getServerModules(deps *moduleDeps) []app.ServerModuleFactory { module.StoreModuleFactory(deps.DocumentStore), blob.ModuleFactory(deps.Bus, deps.BlobStore), authModule.ModuleFactory( - authModule.WithJWT(dummyKeySet), + authModule.WithJWT(func() (jwk.Set, error) { + return jwtutil.NewSymmetricKeySet(dummySecret) + }), ), appModule.ModuleFactory(deps.AppRepository), fetch.ModuleFactory(deps.Bus), @@ -284,36 +288,6 @@ func getServerModules(deps *moduleDeps) []app.ServerModuleFactory { } } -var dummySecret = []byte("not_so_secret") - -func dummyKey() (jwk.Key, error) { - key, err := jwk.FromRaw(dummySecret) - if err != nil { - return nil, errors.WithStack(err) - } - - return key, nil -} - -func dummyKeySet() (jwk.Set, error) { - key, err := dummyKey() - if err != nil { - return nil, errors.WithStack(err) - } - - if err := key.Set(jwk.AlgorithmKey, jwa.HS256); err != nil { - return nil, errors.WithStack(err) - } - - set := jwk.NewSet() - - if err := set.AddKey(key); err != nil { - return nil, errors.WithStack(err) - } - - return set, nil -} - func ensureDir(path string) error { if err := os.MkdirAll(filepath.Dir(path), os.ModePerm); err != nil { return errors.WithStack(err) diff --git a/cmd/storage-server/command/run.go b/cmd/storage-server/command/run.go index 829745e..aa0cad2 100644 --- a/cmd/storage-server/command/run.go +++ b/cmd/storage-server/command/run.go @@ -1,6 +1,7 @@ package command import ( + "context" "fmt" "net/http" "strings" @@ -9,6 +10,7 @@ import ( "github.com/hashicorp/golang-lru/v2/expirable" "github.com/keegancsmith/rpc" + "github.com/lestrrat-go/jwx/v2/jwk" "gitlab.com/wpetit/goweb/logger" "github.com/go-chi/chi/v5" @@ -17,6 +19,7 @@ import ( "github.com/urfave/cli/v2" // Register storage drivers + "forge.cadoles.com/arcad/edge/pkg/jwtutil" "forge.cadoles.com/arcad/edge/pkg/storage" "forge.cadoles.com/arcad/edge/pkg/storage/driver" _ "forge.cadoles.com/arcad/edge/pkg/storage/driver/rpc" @@ -50,6 +53,12 @@ func Run() *cli.Command { EnvVars: []string{"STORAGE_SERVER_SHARESTORE_DSN_PATTERN"}, Value: fmt.Sprintf("sqlite://data/%%TENANT%%/sharestore.sqlite?_pragma=foreign_keys(1)&_pragma=busy_timeout=%d", (60 * time.Second).Milliseconds()), }, + &cli.StringFlag{ + Name: "private-key", + EnvVars: []string{"STORAGE_SERVER_PRIVATE_KEY"}, + Value: "storage-server.key", + TakesFile: true, + }, &cli.DurationFlag{ Name: "cache-ttl", EnvVars: []string{"STORAGE_SERVER_CACHE_TTL"}, @@ -68,9 +77,20 @@ func Run() *cli.Command { shareStoreDSNPattern := ctx.String("sharestore-dsn-pattern") cacheSize := ctx.Int("cache-size") cacheTTL := ctx.Duration("cache-ttl") + privateKeyFile := ctx.String("private-key") router := chi.NewRouter() + rsaKey, err := jwtutil.LoadOrGenerateRSAKey(privateKeyFile, 2048) + if err != nil { + return errors.WithStack(err) + } + + privateKey, err := jwtutil.FromRSA(rsaKey) + if err != nil { + return errors.WithStack(err) + } + getBlobStoreServer := createGetCachedStoreServer( func(dsn string) (storage.BlobStore, error) { return driver.NewBlobStore(dsn) @@ -100,6 +120,7 @@ func Run() *cli.Command { router.Use(middleware.RealIP) router.Use(middleware.Logger) + router.Use(authenticate(privateKey)) router.Handle("/blobstore", createStoreHandler(getBlobStoreServer, blobStoreDSNPattern, cacheSize, cacheTTL)) router.Handle("/documentstore", createStoreHandler(getDocumentStoreServer, documentStoreDSNPattern, cacheSize, cacheTTL)) @@ -177,3 +198,76 @@ func createStoreHandler(getStoreServer getRPCServerFunc, dsnPattern string, cach server.ServeHTTP(w, r) }) } + +func authenticate(privateKey jwk.Key) func(http.Handler) http.Handler { + keySet, err := jwtutil.NewKeySet(privateKey) + getKeySet := func() (jwk.Set, error) { + if err != nil { + return nil, errors.WithStack(err) + } + + return keySet, nil + } + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + token, err := jwtutil.FindToken(r, getKeySet, jwtutil.WithFinders( + jwtutil.FindTokenFromQueryString("token"), + )) + if err != nil { + logger.Error(ctx, "could not find jwt token", logger.E(errors.WithStack(err))) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + + return + } + + tokenMap, err := token.AsMap(ctx) + if err != nil { + logger.Error(ctx, "could not transform token to map", logger.E(errors.WithStack(err))) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + + return + } + + rawTenant, exists := tokenMap["tenant"] + if !exists { + logger.Warn(ctx, "could not find tenant claim", logger.F("token", token)) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + + return + } + + tenant, ok := rawTenant.(string) + if !ok { + logger.Warn(ctx, "unexpected tenant claim value", logger.F("token", token)) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + + return + } + + r = r.WithContext(context.WithValue(ctx, "tenant", tenant)) + + rawAppId, exists := tokenMap["appId"] + if !exists { + logger.Warn(ctx, "could not find appId claim", logger.F("token", token)) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + + return + } + + appId, ok := rawAppId.(string) + if !ok { + logger.Warn(ctx, "unexpected appId claim value", logger.F("token", token)) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + + return + } + + r = r.WithContext(context.WithValue(ctx, "appId", appId)) + + next.ServeHTTP(w, r) + }) + } +} diff --git a/pkg/module/auth/error.go b/pkg/jwtutil/error.go similarity index 55% rename from pkg/module/auth/error.go rename to pkg/jwtutil/error.go index 96618f9..9a789f8 100644 --- a/pkg/module/auth/error.go +++ b/pkg/jwtutil/error.go @@ -1,5 +1,6 @@ -package auth +package jwtutil import "errors" var ErrUnauthenticated = errors.New("unauthenticated") +var ErrNoKeySet = errors.New("no keyset") diff --git a/pkg/jwtutil/jwt.go b/pkg/jwtutil/jwt.go new file mode 100644 index 0000000..0d8602f --- /dev/null +++ b/pkg/jwtutil/jwt.go @@ -0,0 +1,123 @@ +package jwtutil + +import ( + "net/http" + "strings" + + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jws" + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/pkg/errors" +) + +type TokenFinderFunc func(r *http.Request) (string, error) + +type FindTokenOptions struct { + Finders []TokenFinderFunc +} + +type FindTokenOptionFunc func(*FindTokenOptions) + +type GetKeySetFunc func() (jwk.Set, error) + +func WithFinders(finders ...TokenFinderFunc) FindTokenOptionFunc { + return func(opts *FindTokenOptions) { + opts.Finders = finders + } +} + +func NewFindTokenOptions(funcs ...FindTokenOptionFunc) *FindTokenOptions { + opts := &FindTokenOptions{ + Finders: []TokenFinderFunc{ + FindTokenFromAuthorizationHeader, + }, + } + + for _, fn := range funcs { + fn(opts) + } + + return opts +} + +func FindTokenFromAuthorizationHeader(r *http.Request) (string, error) { + authorization := r.Header.Get("Authorization") + + // Retrieve token from Authorization header + rawToken := strings.TrimPrefix(authorization, "Bearer ") + + return rawToken, nil +} + +func FindTokenFromQueryString(name string) TokenFinderFunc { + return func(r *http.Request) (string, error) { + return r.URL.Query().Get(name), nil + } +} + +func FindTokenFromCookie(cookieName string) TokenFinderFunc { + return func(r *http.Request) (string, error) { + cookie, err := r.Cookie(cookieName) + if err != nil && !errors.Is(err, http.ErrNoCookie) { + return "", errors.WithStack(err) + } + + if cookie == nil { + return "", nil + } + + return cookie.Value, nil + } +} + +func FindRawToken(r *http.Request, funcs ...FindTokenOptionFunc) (string, error) { + opts := NewFindTokenOptions(funcs...) + + var rawToken string + var err error + + for _, find := range opts.Finders { + rawToken, err = find(r) + if err != nil { + return "", errors.WithStack(err) + } + + if rawToken == "" { + continue + } + + break + } + + if rawToken == "" { + return "", errors.WithStack(ErrUnauthenticated) + } + + return rawToken, nil +} + +func FindToken(r *http.Request, getKeySet GetKeySetFunc, funcs ...FindTokenOptionFunc) (jwt.Token, error) { + rawToken, err := FindRawToken(r, funcs...) + if err != nil { + return nil, errors.WithStack(err) + } + + keySet, err := getKeySet() + if err != nil { + return nil, errors.WithStack(err) + } + + if keySet == nil { + return nil, errors.WithStack(ErrNoKeySet) + } + + token, err := jwt.Parse([]byte(rawToken), + jwt.WithKeySet(keySet, jws.WithRequireKid(false)), + jwt.WithValidate(true), + ) + if err != nil { + return nil, errors.WithStack(err) + } + + return token, nil +} diff --git a/pkg/jwtutil/key.go b/pkg/jwtutil/key.go new file mode 100644 index 0000000..9a23627 --- /dev/null +++ b/pkg/jwtutil/key.go @@ -0,0 +1,52 @@ +package jwtutil + +import ( + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/pkg/errors" +) + +func NewKeySet(keys ...jwk.Key) (jwk.Set, error) { + set := jwk.NewSet() + + for _, k := range keys { + if err := set.AddKey(k); err != nil { + return nil, errors.WithStack(err) + } + } + + return set, nil +} + +func NewSymmetricKey(secret []byte) (jwk.Key, error) { + key, err := jwk.FromRaw(secret) + if err != nil { + return nil, errors.WithStack(err) + } + + if err := key.Set(jwk.AlgorithmKey, jwa.HS256); err != nil { + return nil, errors.WithStack(err) + } + + return key, nil +} + +func NewSymmetricKeySet(secrets ...[]byte) (jwk.Set, error) { + keys := make([]jwk.Key, len(secrets)) + + for idx, sec := range secrets { + key, err := NewSymmetricKey(sec) + if err != nil { + return nil, errors.WithStack(err) + } + + keys[idx] = key + } + + keySet, err := NewKeySet(keys...) + if err != nil { + return nil, errors.WithStack(err) + } + + return keySet, nil +} diff --git a/pkg/jwtutil/rsa.go b/pkg/jwtutil/rsa.go new file mode 100644 index 0000000..25c951e --- /dev/null +++ b/pkg/jwtutil/rsa.go @@ -0,0 +1,90 @@ +package jwtutil + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "os" + + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/pkg/errors" +) + +func FromRSA(privateKey *rsa.PrivateKey) (jwk.Key, error) { + key, err := jwk.FromRaw(privateKey) + if err != nil { + return nil, errors.WithStack(err) + } + + if err := key.Set(jwk.AlgorithmKey, jwa.RS256); err != nil { + return nil, errors.WithStack(err) + } + + return key, nil +} + +func LoadOrGenerateRSAKey(path string, size int) (*rsa.PrivateKey, error) { + privateKey, err := LoadRSAKey(path) + if err != nil { + if !errors.Is(err, os.ErrNotExist) { + return nil, errors.WithStack(err) + } + + privateKey, err = GenerateRSAKey(size) + if err != nil { + return nil, errors.WithStack(err) + } + + if err := SaveRSAKey(path, privateKey); err != nil { + return nil, errors.WithStack(err) + } + } + + return privateKey, nil +} + +func LoadRSAKey(path string) (*rsa.PrivateKey, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, errors.WithStack(err) + } + + block, _ := pem.Decode(data) + if block == nil { + return nil, errors.New("failed to parse pem block") + } + + privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes) + if err != nil { + return nil, errors.WithStack(err) + } + + return privateKey, nil +} + +func SaveRSAKey(path string, privateKey *rsa.PrivateKey) error { + data := x509.MarshalPKCS1PrivateKey(privateKey) + pem := pem.EncodeToMemory( + &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: data, + }, + ) + + if err := os.WriteFile(path, pem, os.FileMode(0600)); err != nil { + return errors.WithStack(err) + } + + return nil +} + +func GenerateRSAKey(size int) (*rsa.PrivateKey, error) { + privateKey, err := rsa.GenerateKey(rand.Reader, size) + if err != nil { + return nil, errors.WithStack(err) + } + + return privateKey, nil +} diff --git a/pkg/module/auth/jwt/jwt.go b/pkg/jwtutil/token.go similarity index 66% rename from pkg/module/auth/jwt/jwt.go rename to pkg/jwtutil/token.go index 0d128b9..748c186 100644 --- a/pkg/module/auth/jwt/jwt.go +++ b/pkg/jwtutil/token.go @@ -1,15 +1,14 @@ -package jwt +package jwtutil import ( "time" - "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwt" "github.com/pkg/errors" ) -func GenerateSignedToken(algo jwa.KeyAlgorithm, key jwk.Key, claims map[string]any) ([]byte, error) { +func GenerateSignedToken(key jwk.Key, claims map[string]any) ([]byte, error) { token := jwt.New() if err := token.Set(jwt.NotBeforeKey, time.Now()); err != nil { @@ -22,11 +21,11 @@ func GenerateSignedToken(algo jwa.KeyAlgorithm, key jwk.Key, claims map[string]a } } - if err := token.Set(jwk.AlgorithmKey, jwa.HS256); err != nil { + if err := token.Set(jwk.AlgorithmKey, key.Algorithm()); err != nil { return nil, errors.WithStack(err) } - rawToken, err := jwt.Sign(token, jwt.WithKey(algo, key)) + rawToken, err := jwt.Sign(token, jwt.WithKey(key.Algorithm(), key)) if err != nil { return nil, errors.WithStack(err) } diff --git a/pkg/module/auth/http/local_handler.go b/pkg/module/auth/http/local_handler.go index 302da25..d8d494f 100644 --- a/pkg/module/auth/http/local_handler.go +++ b/pkg/module/auth/http/local_handler.go @@ -7,11 +7,10 @@ import ( _ "embed" + "forge.cadoles.com/arcad/edge/pkg/jwtutil" "forge.cadoles.com/arcad/edge/pkg/module/auth" "forge.cadoles.com/arcad/edge/pkg/module/auth/http/passwd" - "forge.cadoles.com/arcad/edge/pkg/module/auth/jwt" "github.com/go-chi/chi/v5" - "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" @@ -32,7 +31,6 @@ func init() { type LocalHandler struct { router chi.Router - algo jwa.KeyAlgorithm key jwk.Key getCookieDomain GetCookieDomainFunc cookieDuration time.Duration @@ -113,7 +111,7 @@ func (h *LocalHandler) handleForm(w http.ResponseWriter, r *http.Request) { account.Claims[auth.ClaimIssuer] = "local" - token, err := jwt.GenerateSignedToken(h.algo, h.key, account.Claims) + token, err := jwtutil.GenerateSignedToken(h.key, account.Claims) if err != nil { logger.Error(ctx, "could not generate signed token", logger.E(errors.WithStack(err))) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -182,14 +180,13 @@ func (h *LocalHandler) authenticate(username, password string) (*LocalAccount, e return &account, nil } -func NewLocalHandler(algo jwa.KeyAlgorithm, key jwk.Key, funcs ...LocalHandlerOptionFunc) *LocalHandler { +func NewLocalHandler(key jwk.Key, funcs ...LocalHandlerOptionFunc) *LocalHandler { opts := defaultLocalHandlerOptions() for _, fn := range funcs { fn(opts) } handler := &LocalHandler{ - algo: algo, key: key, accounts: toAccountsMap(opts.Accounts), getCookieDomain: opts.GetCookieDomain, diff --git a/pkg/module/auth/jwt.go b/pkg/module/auth/jwt.go deleted file mode 100644 index b96075c..0000000 --- a/pkg/module/auth/jwt.go +++ /dev/null @@ -1,118 +0,0 @@ -package auth - -import ( - "context" - "net/http" - "strings" - - "github.com/lestrrat-go/jwx/v2/jwk" - "github.com/lestrrat-go/jwx/v2/jws" - "github.com/lestrrat-go/jwx/v2/jwt" - "github.com/pkg/errors" -) - -const ( - CookieName string = "edge-auth" -) - -type GetKeySetFunc func() (jwk.Set, error) - -func WithJWT(getKeySet GetKeySetFunc) OptionFunc { - return func(o *Option) { - o.GetClaims = func(ctx context.Context, r *http.Request, names ...string) ([]string, error) { - claim, err := getClaims[string](r, getKeySet, names...) - if err != nil { - return nil, errors.WithStack(err) - } - - return claim, nil - } - } -} - -func FindRawToken(r *http.Request) (string, error) { - authorization := r.Header.Get("Authorization") - - // Retrieve token from Authorization header - rawToken := strings.TrimPrefix(authorization, "Bearer ") - - // Retrieve token from ?edge-auth= - if rawToken == "" { - rawToken = r.URL.Query().Get(CookieName) - } - - if rawToken == "" { - cookie, err := r.Cookie(CookieName) - if err != nil && !errors.Is(err, http.ErrNoCookie) { - return "", errors.WithStack(err) - } - - if cookie != nil { - rawToken = cookie.Value - } - } - - if rawToken == "" { - return "", errors.WithStack(ErrUnauthenticated) - } - - return rawToken, nil -} - -func FindToken(r *http.Request, getKeySet GetKeySetFunc) (jwt.Token, error) { - rawToken, err := FindRawToken(r) - if err != nil { - return nil, errors.WithStack(err) - } - - keySet, err := getKeySet() - if err != nil { - return nil, errors.WithStack(err) - } - - if keySet == nil { - return nil, errors.New("no keyset") - } - - token, err := jwt.Parse([]byte(rawToken), - jwt.WithKeySet(keySet, jws.WithRequireKid(false)), - jwt.WithValidate(true), - ) - if err != nil { - return nil, errors.WithStack(err) - } - - return token, nil -} - -func getClaims[T any](r *http.Request, getKeySet GetKeySetFunc, names ...string) ([]T, error) { - token, err := FindToken(r, getKeySet) - if err != nil { - return nil, errors.WithStack(err) - } - - ctx := r.Context() - - mapClaims, err := token.AsMap(ctx) - if err != nil { - return nil, errors.WithStack(err) - } - - claims := make([]T, len(names)) - - for idx, n := range names { - rawClaim, exists := mapClaims[n] - if !exists { - continue - } - - claim, ok := rawClaim.(T) - if !ok { - return nil, errors.Errorf("unexpected claim '%s' to be of type '%T', got '%T'", n, new(T), rawClaim) - } - - claims[idx] = claim - } - - return claims, nil -} diff --git a/pkg/module/auth/middleware/anonymous_user.go b/pkg/module/auth/middleware/anonymous_user.go index c594b9a..65bc383 100644 --- a/pkg/module/auth/middleware/anonymous_user.go +++ b/pkg/module/auth/middleware/anonymous_user.go @@ -7,10 +7,9 @@ import ( "net/http" "time" + "forge.cadoles.com/arcad/edge/pkg/jwtutil" "forge.cadoles.com/arcad/edge/pkg/module/auth" - "forge.cadoles.com/arcad/edge/pkg/module/auth/jwt" "github.com/google/uuid" - "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwk" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" @@ -18,7 +17,7 @@ import ( const AnonIssuer = "anon" -func AnonymousUser(algo jwa.KeyAlgorithm, key jwk.Key, funcs ...AnonymousUserOptionFunc) func(next http.Handler) http.Handler { +func AnonymousUser(key jwk.Key, funcs ...AnonymousUserOptionFunc) func(next http.Handler) http.Handler { opts := defaultAnonymousUserOptions() for _, fn := range funcs { fn(opts) @@ -26,7 +25,11 @@ func AnonymousUser(algo jwa.KeyAlgorithm, key jwk.Key, funcs ...AnonymousUserOpt return func(next http.Handler) http.Handler { handler := func(w http.ResponseWriter, r *http.Request) { - rawToken, err := auth.FindRawToken(r) + rawToken, err := jwtutil.FindRawToken(r, jwtutil.WithFinders( + jwtutil.FindTokenFromAuthorizationHeader, + jwtutil.FindTokenFromQueryString(auth.CookieName), + jwtutil.FindTokenFromCookie(auth.CookieName), + )) // If request already has a raw token, we do nothing if rawToken != "" && err == nil { @@ -62,7 +65,7 @@ func AnonymousUser(algo jwa.KeyAlgorithm, key jwk.Key, funcs ...AnonymousUserOpt auth.ClaimEdgeTenant: opts.Tenant, } - token, err := jwt.GenerateSignedToken(algo, key, claims) + token, err := jwtutil.GenerateSignedToken(key, claims) if err != nil { logger.Error(ctx, "could not generate signed token", logger.E(errors.WithStack(err))) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) diff --git a/pkg/module/auth/module.go b/pkg/module/auth/module.go index 495a729..90e2cb8 100644 --- a/pkg/module/auth/module.go +++ b/pkg/module/auth/module.go @@ -5,12 +5,17 @@ import ( "forge.cadoles.com/arcad/edge/pkg/app" edgeHTTP "forge.cadoles.com/arcad/edge/pkg/http" + "forge.cadoles.com/arcad/edge/pkg/jwtutil" "forge.cadoles.com/arcad/edge/pkg/module/util" "github.com/dop251/goja" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" ) +const ( + CookieName string = "edge-auth" +) + const ( ClaimSubject = "sub" ClaimIssuer = "iss" @@ -21,8 +26,8 @@ const ( ) type Module struct { - server *app.Server - getClaims GetClaimsFunc + server *app.Server + getClaimFn GetClaimFunc } func (m *Module) Name() string { @@ -68,9 +73,9 @@ func (m *Module) getClaim(call goja.FunctionCall, rt *goja.Runtime) goja.Value { panic(rt.ToValue(errors.New("could not find http request in context"))) } - claim, err := m.getClaims(ctx, req, claimName) + claim, err := m.getClaimFn(ctx, req, claimName) if err != nil { - if errors.Is(err, ErrUnauthenticated) { + if errors.Is(err, jwtutil.ErrUnauthenticated) { return nil } @@ -78,11 +83,7 @@ func (m *Module) getClaim(call goja.FunctionCall, rt *goja.Runtime) goja.Value { return nil } - if len(claim) == 0 || claim[0] == "" { - return nil - } - - return rt.ToValue(claim[0]) + return rt.ToValue(claim) } func ModuleFactory(funcs ...OptionFunc) app.ServerModuleFactory { @@ -93,8 +94,8 @@ func ModuleFactory(funcs ...OptionFunc) app.ServerModuleFactory { return func(server *app.Server) app.ServerModule { return &Module{ - server: server, - getClaims: opt.GetClaims, + server: server, + getClaimFn: opt.GetClaim, } } } diff --git a/pkg/module/auth/module_test.go b/pkg/module/auth/module_test.go index 17cf933..a7fb233 100644 --- a/pkg/module/auth/module_test.go +++ b/pkg/module/auth/module_test.go @@ -10,6 +10,7 @@ import ( "cdr.dev/slog" "forge.cadoles.com/arcad/edge/pkg/app" edgeHTTP "forge.cadoles.com/arcad/edge/pkg/http" + "forge.cadoles.com/arcad/edge/pkg/jwtutil" "forge.cadoles.com/arcad/edge/pkg/module" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwk" @@ -130,7 +131,7 @@ func getDummyKey() jwk.Key { return key } -func getDummyKeySet(key jwk.Key) GetKeySetFunc { +func getDummyKeySet(key jwk.Key) jwtutil.GetKeySetFunc { return func() (jwk.Set, error) { set := jwk.NewSet() diff --git a/pkg/module/auth/mount.go b/pkg/module/auth/mount.go index f193150..ea3cd3a 100644 --- a/pkg/module/auth/mount.go +++ b/pkg/module/auth/mount.go @@ -3,6 +3,7 @@ package auth import ( "net/http" + "forge.cadoles.com/arcad/edge/pkg/jwtutil" "github.com/go-chi/chi/v5" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/api" @@ -12,39 +13,39 @@ import ( type MountFunc func(r chi.Router) type Handler struct { - getClaims GetClaimsFunc + getClaim GetClaimFunc profileClaims []string } func (h *Handler) serveProfile(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - claims, err := h.getClaims(ctx, r, h.profileClaims...) - if err != nil { - if errors.Is(err, ErrUnauthenticated) { + profile := make(map[string]any) + + for _, name := range h.profileClaims { + value, err := h.getClaim(ctx, r, name) + if err != nil { + if errors.Is(err, jwtutil.ErrUnauthenticated) { + api.ErrorResponse( + w, http.StatusUnauthorized, + api.ErrCodeUnauthorized, + nil, + ) + + return + } + + logger.Error(ctx, "could not retrieve claims", logger.E(errors.WithStack(err))) api.ErrorResponse( - w, http.StatusUnauthorized, - api.ErrCodeUnauthorized, + w, http.StatusInternalServerError, + api.ErrCodeUnknownError, nil, ) return } - logger.Error(ctx, "could not retrieve claims", logger.E(errors.WithStack(err))) - api.ErrorResponse( - w, http.StatusInternalServerError, - api.ErrCodeUnknownError, - nil, - ) - - return - } - - profile := make(map[string]any) - - for idx, cl := range h.profileClaims { - profile[cl] = claims[idx] + profile[name] = value } api.DataResponse(w, http.StatusOK, struct { @@ -62,7 +63,7 @@ func Mount(authHandler http.Handler, funcs ...OptionFunc) MountFunc { handler := &Handler{ profileClaims: opt.ProfileClaims, - getClaims: opt.GetClaims, + getClaim: opt.GetClaim, } return func(r chi.Router) { diff --git a/pkg/module/auth/option.go b/pkg/module/auth/option.go index a10bbdb..d527290 100644 --- a/pkg/module/auth/option.go +++ b/pkg/module/auth/option.go @@ -2,15 +2,17 @@ package auth import ( "context" + "fmt" "net/http" + "forge.cadoles.com/arcad/edge/pkg/jwtutil" "github.com/pkg/errors" ) -type GetClaimsFunc func(ctx context.Context, r *http.Request, claims ...string) ([]string, error) +type GetClaimFunc func(ctx context.Context, r *http.Request, name string) (string, error) type Option struct { - GetClaims GetClaimsFunc + GetClaim GetClaimFunc ProfileClaims []string } @@ -18,7 +20,7 @@ type OptionFunc func(*Option) func defaultOptions() *Option { return &Option{ - GetClaims: dummyGetClaims, + GetClaim: dummyGetClaim, ProfileClaims: []string{ ClaimSubject, ClaimIssuer, @@ -30,13 +32,13 @@ func defaultOptions() *Option { } } -func dummyGetClaims(ctx context.Context, r *http.Request, claims ...string) ([]string, error) { - return nil, errors.Errorf("dummy getclaim func cannot retrieve claims '%s'", claims) +func dummyGetClaim(ctx context.Context, r *http.Request, name string) (string, error) { + return "", errors.Errorf("dummy getclaim func cannot retrieve claim '%s'", name) } -func WithGetClaims(fn GetClaimsFunc) OptionFunc { +func WithGetClaims(fn GetClaimFunc) OptionFunc { return func(o *Option) { - o.GetClaims = fn + o.GetClaim = fn } } @@ -45,3 +47,34 @@ func WithProfileClaims(claims ...string) OptionFunc { o.ProfileClaims = claims } } + +func WithJWT(getKeySet jwtutil.GetKeySetFunc) OptionFunc { + funcs := []jwtutil.FindTokenOptionFunc{ + jwtutil.WithFinders( + jwtutil.FindTokenFromAuthorizationHeader, + jwtutil.FindTokenFromQueryString(CookieName), + jwtutil.FindTokenFromCookie(CookieName), + ), + } + + return func(o *Option) { + o.GetClaim = func(ctx context.Context, r *http.Request, name string) (string, error) { + token, err := jwtutil.FindToken(r, getKeySet, funcs...) + if err != nil { + return "", errors.WithStack(err) + } + + tokenMap, err := token.AsMap(ctx) + if err != nil { + return "", errors.WithStack(err) + } + + value, exists := tokenMap[name] + if !exists { + return "", nil + } + + return fmt.Sprintf("%v", value), nil + } + } +} diff --git a/pkg/storage/driver/rpc/driver.go b/pkg/storage/driver/rpc/driver.go index fc8d961..b814925 100644 --- a/pkg/storage/driver/rpc/driver.go +++ b/pkg/storage/driver/rpc/driver.go @@ -6,11 +6,13 @@ import ( "forge.cadoles.com/arcad/edge/pkg/storage" "forge.cadoles.com/arcad/edge/pkg/storage/driver" "forge.cadoles.com/arcad/edge/pkg/storage/driver/rpc/client" + "forge.cadoles.com/arcad/edge/pkg/storage/share" ) func init() { driver.RegisterDocumentStoreFactory("rpc", documentStoreFactory) driver.RegisterBlobStoreFactory("rpc", blobStoreFactory) + driver.RegisterShareStoreFactory("rpc", shareStoreFactory) } func documentStoreFactory(url *url.URL) (storage.DocumentStore, error) { @@ -20,3 +22,7 @@ func documentStoreFactory(url *url.URL) (storage.DocumentStore, error) { func blobStoreFactory(url *url.URL) (storage.BlobStore, error) { return client.NewBlobStore(url), nil } + +func shareStoreFactory(url *url.URL) (share.Store, error) { + return client.NewShareStore(url), nil +}