From d2472623f2d40f09cc5f3b9300ffc0826b4513f2 Mon Sep 17 00:00:00 2001 From: William Petit Date: Thu, 28 Sep 2023 23:41:01 -0600 Subject: [PATCH] feat(storage-server): jwt based authentication --- .env.dist | 3 +- .gitignore | 3 +- cmd/cli/command/app/run.go | 58 +++------ cmd/storage-server/command/auth/new_token.go | 53 ++++++++ cmd/storage-server/command/auth/root.go | 8 +- cmd/storage-server/command/flag/flag.go | 43 ++++++ cmd/storage-server/command/run.go | 119 +++++++++++++++-- .../src/public/index.html | 3 + .../src/public/test/share-module.js | 29 +++++ .../src/public/test/util.js | 7 + misc/client-sdk-testsuite/src/server/main.js | 7 + modd.conf | 11 +- pkg/{module/auth => jwtutil}/error.go | 3 +- pkg/jwtutil/io.go | 71 ++++++++++ pkg/jwtutil/key.go | 52 ++++++++ pkg/jwtutil/request.go | 123 ++++++++++++++++++ .../auth/jwt/jwt.go => jwtutil/token.go} | 8 +- pkg/module/auth/http/local_handler.go | 28 ++-- pkg/module/auth/jwt.go | 118 ----------------- pkg/module/auth/middleware/anonymous_user.go | 12 +- pkg/module/auth/module.go | 23 ++-- pkg/module/auth/module_test.go | 3 +- pkg/module/auth/mount.go | 43 +++--- pkg/module/auth/option.go | 47 ++++++- pkg/storage/driver/rpc/driver.go | 6 + 25 files changed, 646 insertions(+), 235 deletions(-) create mode 100644 cmd/storage-server/command/auth/new_token.go create mode 100644 cmd/storage-server/command/flag/flag.go create mode 100644 misc/client-sdk-testsuite/src/public/test/share-module.js create mode 100644 misc/client-sdk-testsuite/src/public/test/util.js rename pkg/{module/auth => jwtutil}/error.go (55%) create mode 100644 pkg/jwtutil/io.go create mode 100644 pkg/jwtutil/key.go create mode 100644 pkg/jwtutil/request.go rename pkg/{module/auth/jwt/jwt.go => jwtutil/token.go} (68%) delete mode 100644 pkg/module/auth/jwt.go diff --git a/.env.dist b/.env.dist index 6f09b89..4a9a61a 100644 --- a/.env.dist +++ b/.env.dist @@ -1,3 +1,4 @@ RUN_APP_ARGS="" #EDGE_DOCUMENTSTORE_DSN="rpc://localhost:3001/documentstore?tenant=local&appId=%APPID%" -#EDGE_BLOBSTORE_DSN="rpc://localhost:3001/blobstore?tenant=local&appId=%APPID%" \ No newline at end of file +#EDGE_BLOBSTORE_DSN="rpc://localhost:3001/blobstore?tenant=local&appId=%APPID%" +#EDGE_SHARESTORE_DSN="rpc://localhost:3001/sharestore?tenant=local" \ No newline at end of file diff --git a/.gitignore b/.gitignore index 942f5b3..c86f74d 100644 --- a/.gitignore +++ b/.gitignore @@ -9,4 +9,5 @@ .mktools/ /dist /.chglog -/CHANGELOG.md \ No newline at end of file +/CHANGELOG.md +/storage-server.key \ No newline at end of file diff --git a/cmd/cli/command/app/run.go b/cmd/cli/command/app/run.go index 05003e5..8ef5674 100644 --- a/cmd/cli/command/app/run.go +++ b/cmd/cli/command/app/run.go @@ -16,6 +16,7 @@ import ( "forge.cadoles.com/arcad/edge/pkg/bus" "forge.cadoles.com/arcad/edge/pkg/bus/memory" appHTTP "forge.cadoles.com/arcad/edge/pkg/http" + "forge.cadoles.com/arcad/edge/pkg/jwtutil" "forge.cadoles.com/arcad/edge/pkg/module" appModule "forge.cadoles.com/arcad/edge/pkg/module/app" appModuleMemory "forge.cadoles.com/arcad/edge/pkg/module/app/memory" @@ -50,6 +51,8 @@ import ( "forge.cadoles.com/arcad/edge/pkg/storage/share" ) +var dummySecret = []byte("not_so_secret") + func RunCommand() *cli.Command { return &cli.Command{ Name: "run", @@ -194,13 +197,14 @@ func runApp(ctx context.Context, path, address, documentStoreDSN, blobStoreDSN, ctx = logger.With(ctx, logger.F("appID", manifest.ID)) // Add auth handler - key, err := dummyKey() + key, err := jwtutil.NewSymmetricKey(dummySecret) if err != nil { return errors.WithStack(err) } deps := &moduleDeps{} funcs := []ModuleDepFunc{ + initAppID(manifest), initMemoryBus, initDatastores(documentStoreDSN, blobStoreDSN, shareStoreDSN, manifest.ID), initAccounts(accountsFile, manifest.ID), @@ -220,17 +224,18 @@ func runApp(ctx context.Context, path, address, documentStoreDSN, blobStoreDSN, appModule.Mount(appRepository), authModule.Mount( authHTTP.NewLocalHandler( - jwa.HS256, key, + key, + jwa.HS256, authHTTP.WithRoutePrefix("/auth"), authHTTP.WithAccounts(deps.Accounts...), ), - authModule.WithJWT(dummyKeySet), + authModule.WithJWT(func() (jwk.Set, error) { + return jwtutil.NewSymmetricKeySet(dummySecret) + }), ), ), appHTTP.WithHTTPMiddlewares( - authModuleMiddleware.AnonymousUser( - jwa.HS256, key, - ), + authModuleMiddleware.AnonymousUser(key, jwa.HS256), ), ) if err := handler.Load(bundle); err != nil { @@ -276,7 +281,9 @@ func getServerModules(deps *moduleDeps) []app.ServerModuleFactory { module.StoreModuleFactory(deps.DocumentStore), blob.ModuleFactory(deps.Bus, deps.BlobStore), authModule.ModuleFactory( - authModule.WithJWT(dummyKeySet), + authModule.WithJWT(func() (jwk.Set, error) { + return jwtutil.NewSymmetricKeySet(dummySecret) + }), ), appModule.ModuleFactory(deps.AppRepository), fetch.ModuleFactory(deps.Bus), @@ -284,36 +291,6 @@ func getServerModules(deps *moduleDeps) []app.ServerModuleFactory { } } -var dummySecret = []byte("not_so_secret") - -func dummyKey() (jwk.Key, error) { - key, err := jwk.FromRaw(dummySecret) - if err != nil { - return nil, errors.WithStack(err) - } - - return key, nil -} - -func dummyKeySet() (jwk.Set, error) { - key, err := dummyKey() - if err != nil { - return nil, errors.WithStack(err) - } - - if err := key.Set(jwk.AlgorithmKey, jwa.HS256); err != nil { - return nil, errors.WithStack(err) - } - - set := jwk.NewSet() - - if err := set.AddKey(key); err != nil { - return nil, errors.WithStack(err) - } - - return set, nil -} - func ensureDir(path string) error { if err := os.MkdirAll(filepath.Dir(path), os.ModePerm); err != nil { return errors.WithStack(err) @@ -435,6 +412,13 @@ func newAppRepository(host string, basePort uint64, manifests ...*app.Manifest) ) } +func initAppID(manifest *app.Manifest) ModuleDepFunc { + return func(deps *moduleDeps) error { + deps.AppID = manifest.ID + return nil + } +} + func initAppRepository(repo appModule.Repository) ModuleDepFunc { return func(deps *moduleDeps) error { deps.AppRepository = repo diff --git a/cmd/storage-server/command/auth/new_token.go b/cmd/storage-server/command/auth/new_token.go new file mode 100644 index 0000000..ecd3520 --- /dev/null +++ b/cmd/storage-server/command/auth/new_token.go @@ -0,0 +1,53 @@ +package auth + +import ( + "fmt" + + "forge.cadoles.com/arcad/edge/cmd/storage-server/command/flag" + "forge.cadoles.com/arcad/edge/pkg/jwtutil" + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/pkg/errors" + "github.com/urfave/cli/v2" +) + +func NewToken() *cli.Command { + return &cli.Command{ + Name: "new-token", + Usage: "Generate new authentication token", + Flags: []cli.Flag{ + &cli.StringFlag{ + Name: "tenant", + }, + flag.PrivateKey, + flag.PrivateKeySigningAlgorithm, + flag.PrivateKeyDefaultSize, + }, + Action: func(ctx *cli.Context) error { + privateKeyFile := flag.GetPrivateKey(ctx) + signingAlgorithm := flag.GetSigningAlgorithm(ctx) + privateKeyDefaultSize := flag.GetPrivateKeyDefaultSize(ctx) + tenant := ctx.String("tenant") + + privateKey, err := jwtutil.LoadOrGenerateKey( + privateKeyFile, + privateKeyDefaultSize, + ) + if err != nil { + return errors.WithStack(err) + } + + claims := map[string]any{ + "tenant": tenant, + } + + token, err := jwtutil.SignedToken(privateKey, jwa.SignatureAlgorithm(signingAlgorithm), claims) + if err != nil { + return errors.Wrap(err, "could not generate signed token") + } + + fmt.Println(string(token)) + + return nil + }, + } +} diff --git a/cmd/storage-server/command/auth/root.go b/cmd/storage-server/command/auth/root.go index bdf044a..1dc3ddc 100644 --- a/cmd/storage-server/command/auth/root.go +++ b/cmd/storage-server/command/auth/root.go @@ -6,8 +6,10 @@ import ( func Root() *cli.Command { return &cli.Command{ - Name: "auth", - Usage: "Auth related command", - Subcommands: []*cli.Command{}, + Name: "auth", + Usage: "Auth related command", + Subcommands: []*cli.Command{ + NewToken(), + }, } } diff --git a/cmd/storage-server/command/flag/flag.go b/cmd/storage-server/command/flag/flag.go new file mode 100644 index 0000000..99ede47 --- /dev/null +++ b/cmd/storage-server/command/flag/flag.go @@ -0,0 +1,43 @@ +package flag + +import ( + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/urfave/cli/v2" +) + +const PrivateKeyFlagName = "private-key" + +var PrivateKey = &cli.StringFlag{ + Name: PrivateKeyFlagName, + EnvVars: []string{"STORAGE_SERVER_PRIVATE_KEY"}, + Value: "storage-server.key", + TakesFile: true, +} + +func GetPrivateKey(ctx *cli.Context) string { + return ctx.String(PrivateKeyFlagName) +} + +const SigningAlgorithmFlagName = "signing-algorithm" + +var PrivateKeySigningAlgorithm = &cli.StringFlag{ + Name: SigningAlgorithmFlagName, + EnvVars: []string{"STORAGE_SERVER_PRIVATE_KEY_SIGNING_ALGORITHM"}, + Value: jwa.RS256.String(), +} + +func GetSigningAlgorithm(ctx *cli.Context) string { + return ctx.String(SigningAlgorithmFlagName) +} + +const PrivateKeyDefaultSizeFlagName = "private-key-default-size" + +var PrivateKeyDefaultSize = &cli.IntFlag{ + Name: PrivateKeyDefaultSizeFlagName, + EnvVars: []string{"STORAGE_SERVER_PRIVATE_KEY_DEFAULT_SIZE"}, + Value: 2048, +} + +func GetPrivateKeyDefaultSize(ctx *cli.Context) int { + return ctx.Int(PrivateKeyDefaultSizeFlagName) +} diff --git a/cmd/storage-server/command/run.go b/cmd/storage-server/command/run.go index 829745e..17681d7 100644 --- a/cmd/storage-server/command/run.go +++ b/cmd/storage-server/command/run.go @@ -1,6 +1,7 @@ package command import ( + "context" "fmt" "net/http" "strings" @@ -9,6 +10,8 @@ import ( "github.com/hashicorp/golang-lru/v2/expirable" "github.com/keegancsmith/rpc" + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" "gitlab.com/wpetit/goweb/logger" "github.com/go-chi/chi/v5" @@ -17,6 +20,8 @@ import ( "github.com/urfave/cli/v2" // Register storage drivers + "forge.cadoles.com/arcad/edge/cmd/storage-server/command/flag" + "forge.cadoles.com/arcad/edge/pkg/jwtutil" "forge.cadoles.com/arcad/edge/pkg/storage" "forge.cadoles.com/arcad/edge/pkg/storage/driver" _ "forge.cadoles.com/arcad/edge/pkg/storage/driver/rpc" @@ -50,6 +55,9 @@ func Run() *cli.Command { EnvVars: []string{"STORAGE_SERVER_SHARESTORE_DSN_PATTERN"}, Value: fmt.Sprintf("sqlite://data/%%TENANT%%/sharestore.sqlite?_pragma=foreign_keys(1)&_pragma=busy_timeout=%d", (60 * time.Second).Milliseconds()), }, + flag.PrivateKey, + flag.PrivateKeySigningAlgorithm, + flag.PrivateKeyDefaultSize, &cli.DurationFlag{ Name: "cache-ttl", EnvVars: []string{"STORAGE_SERVER_CACHE_TTL"}, @@ -68,9 +76,25 @@ func Run() *cli.Command { shareStoreDSNPattern := ctx.String("sharestore-dsn-pattern") cacheSize := ctx.Int("cache-size") cacheTTL := ctx.Duration("cache-ttl") + privateKeyFile := flag.GetPrivateKey(ctx) + signingAlgorithm := flag.GetSigningAlgorithm(ctx) + privateKeyDefaultSize := flag.GetPrivateKeyDefaultSize(ctx) router := chi.NewRouter() + privateKey, err := jwtutil.LoadOrGenerateKey( + privateKeyFile, + privateKeyDefaultSize, + ) + if err != nil { + return errors.WithStack(err) + } + + publicKey, err := privateKey.PublicKey() + if err != nil { + return errors.WithStack(err) + } + getBlobStoreServer := createGetCachedStoreServer( func(dsn string) (storage.BlobStore, error) { return driver.NewBlobStore(dsn) @@ -100,10 +124,11 @@ func Run() *cli.Command { router.Use(middleware.RealIP) router.Use(middleware.Logger) + router.Use(authenticate(publicKey, jwa.SignatureAlgorithm(signingAlgorithm))) - router.Handle("/blobstore", createStoreHandler(getBlobStoreServer, blobStoreDSNPattern, cacheSize, cacheTTL)) - router.Handle("/documentstore", createStoreHandler(getDocumentStoreServer, documentStoreDSNPattern, cacheSize, cacheTTL)) - router.Handle("/sharestore", createStoreHandler(getShareStoreServer, shareStoreDSNPattern, cacheSize, cacheTTL)) + router.Handle("/blobstore", createStoreHandler(getBlobStoreServer, blobStoreDSNPattern, true, cacheSize, cacheTTL)) + router.Handle("/documentstore", createStoreHandler(getDocumentStoreServer, documentStoreDSNPattern, true, cacheSize, cacheTTL)) + router.Handle("/sharestore", createStoreHandler(getShareStoreServer, shareStoreDSNPattern, false, cacheSize, cacheTTL)) if err := http.ListenAndServe(addr, router); err != nil { return errors.WithStack(err) @@ -150,17 +175,19 @@ func createGetCachedStoreServer[T any](storeFactory func(dsn string) (T, error), } } -func createStoreHandler(getStoreServer getRPCServerFunc, dsnPattern string, cacheSize int, cacheTTL time.Duration) http.Handler { +func createStoreHandler(getStoreServer getRPCServerFunc, dsnPattern string, appIDRequired bool, cacheSize int, cacheTTL time.Duration) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - tenant := r.URL.Query().Get("tenant") - if tenant == "" { - http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) + ctx := r.Context() + + tenant, ok := ctx.Value("tenant").(string) + if !ok || tenant == "" { + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } appID := r.URL.Query().Get("appId") - if tenant == "" { + if appIDRequired && appID == "" { http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return @@ -177,3 +204,79 @@ func createStoreHandler(getStoreServer getRPCServerFunc, dsnPattern string, cach server.ServeHTTP(w, r) }) } + +func authenticate(privateKey jwk.Key, signingAlgorithm jwa.SignatureAlgorithm) func(http.Handler) http.Handler { + var ( + createKeySet sync.Once + err error + getKeySet jwtutil.GetKeySetFunc + ) + + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + createKeySet.Do(func() { + err = privateKey.Set(jwk.AlgorithmKey, signingAlgorithm) + if err != nil { + return + } + + var keySet jwk.Set + + keySet, err = jwtutil.NewKeySet(privateKey) + if err != nil { + return + } + + getKeySet = func() (jwk.Set, error) { + return keySet, nil + } + }) + if err != nil { + logger.Error(ctx, "could not create keyset accessor", logger.E(errors.WithStack(err))) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + + return + } + + token, err := jwtutil.FindToken(r, getKeySet, jwtutil.WithFinders( + jwtutil.FindTokenFromQueryString("token"), + )) + if err != nil { + logger.Error(ctx, "could not find jwt token", logger.E(errors.WithStack(err))) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + + return + } + + tokenMap, err := token.AsMap(ctx) + if err != nil { + logger.Error(ctx, "could not transform token to map", logger.E(errors.WithStack(err))) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + + return + } + + rawTenant, exists := tokenMap["tenant"] + if !exists { + logger.Warn(ctx, "could not find tenant claim", logger.F("token", token)) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + + return + } + + tenant, ok := rawTenant.(string) + if !ok { + logger.Warn(ctx, "unexpected tenant claim value", logger.F("token", token)) + http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) + + return + } + + r = r.WithContext(context.WithValue(ctx, "tenant", tenant)) + + next.ServeHTTP(w, r) + }) + } +} diff --git a/misc/client-sdk-testsuite/src/public/index.html b/misc/client-sdk-testsuite/src/public/index.html index 0511898..26c3dc9 100644 --- a/misc/client-sdk-testsuite/src/public/index.html +++ b/misc/client-sdk-testsuite/src/public/index.html @@ -24,6 +24,7 @@ mocha.checkLeaks(); + @@ -31,6 +32,7 @@ + \ No newline at end of file diff --git a/misc/client-sdk-testsuite/src/public/test/share-module.js b/misc/client-sdk-testsuite/src/public/test/share-module.js new file mode 100644 index 0000000..5b155f8 --- /dev/null +++ b/misc/client-sdk-testsuite/src/public/test/share-module.js @@ -0,0 +1,29 @@ +describe('Share Module', function() { + + before(() => { + return Edge.Client.connect(); + }); + + after(() => { + Edge.Client.disconnect(); + }); + + it('should create a new resource and find it', async () => { + const resource = await TestUtil.serverSideCall('share', 'upsertResource', 'my-resource', { name: "color", type: "text", value: "red" }); + chai.assert.isNotNull(resource); + chai.assert.equal(resource.origin, 'edge.sdk.client.test') + + + const results = await TestUtil.serverSideCall('share', 'findResources', 'color', 'text'); + chai.assert.isAbove(results.length, 0); + + const createdResource = results.find(res => { + return res.origin === 'edge.sdk.client.test' && + res.attributes.find(attr => attr.name === 'color' && attr.type === 'text') + }) + + chai.assert.isNotNull(createdResource) + + console.log(createdResource) + }); +}); \ No newline at end of file diff --git a/misc/client-sdk-testsuite/src/public/test/util.js b/misc/client-sdk-testsuite/src/public/test/util.js new file mode 100644 index 0000000..2c70867 --- /dev/null +++ b/misc/client-sdk-testsuite/src/public/test/util.js @@ -0,0 +1,7 @@ +(function(TestUtil) { + TestUtil.serverSideCall = (module, func, ...args) => { + return Edge.Client.rpc('serverSideCall', { module, func, args }) + } + console.log(TestUtil) + +}(globalThis.TestUtil = globalThis.TestUtil || {})); \ No newline at end of file diff --git a/misc/client-sdk-testsuite/src/server/main.js b/misc/client-sdk-testsuite/src/server/main.js index df6f82a..d492dec 100644 --- a/misc/client-sdk-testsuite/src/server/main.js +++ b/misc/client-sdk-testsuite/src/server/main.js @@ -15,6 +15,8 @@ function onInit() { rpc.register("listApps"); rpc.register("getApp"); rpc.register("getAppUrl"); + + rpc.register("serverSideCall", serverSideCall) } // Called for each client message @@ -103,4 +105,9 @@ function getAppUrl(ctx, params) { function onClientFetch(ctx, url, remoteAddr) { return { allow: url === 'http://example.com' }; +} + +function serverSideCall(ctx, params) { + console.log("Calling %s.%s(args...)", params.module, params.func) + return globalThis[params.module][params.func].call(null, ctx, ...params.args); } \ No newline at end of file diff --git a/modd.conf b/modd.conf index 5f1a776..ec80e1c 100644 --- a/modd.conf +++ b/modd.conf @@ -2,16 +2,19 @@ **/*.tmpl pkg/sdk/client/src/**/*.js pkg/sdk/client/src/**/*.ts -misc/client-sdk-testsuite/src/**/* +misc/client-sdk-testsuite/dist/server/*.js modd.conf { - prep: make build-sdk - prep: make build-client-sdk-test-app - prep: make build + prep: make build-sdk build-cli build-storage-server daemon: make run-app daemon: make run-storage-server } +misc/client-sdk-testsuite/src/**/* +{ + prep: make build-client-sdk-test-app +} + **/*.go { prep: make GOTEST_ARGS="-short" test } \ No newline at end of file diff --git a/pkg/module/auth/error.go b/pkg/jwtutil/error.go similarity index 55% rename from pkg/module/auth/error.go rename to pkg/jwtutil/error.go index 96618f9..9a789f8 100644 --- a/pkg/module/auth/error.go +++ b/pkg/jwtutil/error.go @@ -1,5 +1,6 @@ -package auth +package jwtutil import "errors" var ErrUnauthenticated = errors.New("unauthenticated") +var ErrNoKeySet = errors.New("no keyset") diff --git a/pkg/jwtutil/io.go b/pkg/jwtutil/io.go new file mode 100644 index 0000000..213b5f1 --- /dev/null +++ b/pkg/jwtutil/io.go @@ -0,0 +1,71 @@ +package jwtutil + +import ( + "crypto/rand" + "crypto/rsa" + "os" + + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/pkg/errors" +) + +func LoadOrGenerateKey(path string, defaultKeySize int) (jwk.Key, error) { + key, err := LoadKey(path) + if err != nil { + if !errors.Is(err, os.ErrNotExist) { + return nil, errors.WithStack(err) + } + + key, err = GenerateKey(defaultKeySize) + if err != nil { + return nil, errors.WithStack(err) + } + + if err := SaveKey(path, key); err != nil { + return nil, errors.WithStack(err) + } + } + + return key, nil +} + +func LoadKey(path string) (jwk.Key, error) { + data, err := os.ReadFile(path) + if err != nil { + return nil, errors.WithStack(err) + } + + key, err := jwk.ParseKey(data, jwk.WithPEM(true)) + if err != nil { + return nil, errors.WithStack(err) + } + + return key, nil +} + +func SaveKey(path string, key jwk.Key) error { + data, err := jwk.Pem(key) + if err != nil { + return errors.WithStack(err) + } + + if err := os.WriteFile(path, data, os.FileMode(0600)); err != nil { + return errors.WithStack(err) + } + + return nil +} + +func GenerateKey(keySize int) (jwk.Key, error) { + rsaKey, err := rsa.GenerateKey(rand.Reader, keySize) + if err != nil { + return nil, errors.WithStack(err) + } + + key, err := jwk.FromRaw(rsaKey) + if err != nil { + return nil, errors.WithStack(err) + } + + return key, nil +} diff --git a/pkg/jwtutil/key.go b/pkg/jwtutil/key.go new file mode 100644 index 0000000..9a23627 --- /dev/null +++ b/pkg/jwtutil/key.go @@ -0,0 +1,52 @@ +package jwtutil + +import ( + "github.com/lestrrat-go/jwx/v2/jwa" + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/pkg/errors" +) + +func NewKeySet(keys ...jwk.Key) (jwk.Set, error) { + set := jwk.NewSet() + + for _, k := range keys { + if err := set.AddKey(k); err != nil { + return nil, errors.WithStack(err) + } + } + + return set, nil +} + +func NewSymmetricKey(secret []byte) (jwk.Key, error) { + key, err := jwk.FromRaw(secret) + if err != nil { + return nil, errors.WithStack(err) + } + + if err := key.Set(jwk.AlgorithmKey, jwa.HS256); err != nil { + return nil, errors.WithStack(err) + } + + return key, nil +} + +func NewSymmetricKeySet(secrets ...[]byte) (jwk.Set, error) { + keys := make([]jwk.Key, len(secrets)) + + for idx, sec := range secrets { + key, err := NewSymmetricKey(sec) + if err != nil { + return nil, errors.WithStack(err) + } + + keys[idx] = key + } + + keySet, err := NewKeySet(keys...) + if err != nil { + return nil, errors.WithStack(err) + } + + return keySet, nil +} diff --git a/pkg/jwtutil/request.go b/pkg/jwtutil/request.go new file mode 100644 index 0000000..0d8602f --- /dev/null +++ b/pkg/jwtutil/request.go @@ -0,0 +1,123 @@ +package jwtutil + +import ( + "net/http" + "strings" + + "github.com/lestrrat-go/jwx/v2/jwk" + "github.com/lestrrat-go/jwx/v2/jws" + "github.com/lestrrat-go/jwx/v2/jwt" + "github.com/pkg/errors" +) + +type TokenFinderFunc func(r *http.Request) (string, error) + +type FindTokenOptions struct { + Finders []TokenFinderFunc +} + +type FindTokenOptionFunc func(*FindTokenOptions) + +type GetKeySetFunc func() (jwk.Set, error) + +func WithFinders(finders ...TokenFinderFunc) FindTokenOptionFunc { + return func(opts *FindTokenOptions) { + opts.Finders = finders + } +} + +func NewFindTokenOptions(funcs ...FindTokenOptionFunc) *FindTokenOptions { + opts := &FindTokenOptions{ + Finders: []TokenFinderFunc{ + FindTokenFromAuthorizationHeader, + }, + } + + for _, fn := range funcs { + fn(opts) + } + + return opts +} + +func FindTokenFromAuthorizationHeader(r *http.Request) (string, error) { + authorization := r.Header.Get("Authorization") + + // Retrieve token from Authorization header + rawToken := strings.TrimPrefix(authorization, "Bearer ") + + return rawToken, nil +} + +func FindTokenFromQueryString(name string) TokenFinderFunc { + return func(r *http.Request) (string, error) { + return r.URL.Query().Get(name), nil + } +} + +func FindTokenFromCookie(cookieName string) TokenFinderFunc { + return func(r *http.Request) (string, error) { + cookie, err := r.Cookie(cookieName) + if err != nil && !errors.Is(err, http.ErrNoCookie) { + return "", errors.WithStack(err) + } + + if cookie == nil { + return "", nil + } + + return cookie.Value, nil + } +} + +func FindRawToken(r *http.Request, funcs ...FindTokenOptionFunc) (string, error) { + opts := NewFindTokenOptions(funcs...) + + var rawToken string + var err error + + for _, find := range opts.Finders { + rawToken, err = find(r) + if err != nil { + return "", errors.WithStack(err) + } + + if rawToken == "" { + continue + } + + break + } + + if rawToken == "" { + return "", errors.WithStack(ErrUnauthenticated) + } + + return rawToken, nil +} + +func FindToken(r *http.Request, getKeySet GetKeySetFunc, funcs ...FindTokenOptionFunc) (jwt.Token, error) { + rawToken, err := FindRawToken(r, funcs...) + if err != nil { + return nil, errors.WithStack(err) + } + + keySet, err := getKeySet() + if err != nil { + return nil, errors.WithStack(err) + } + + if keySet == nil { + return nil, errors.WithStack(ErrNoKeySet) + } + + token, err := jwt.Parse([]byte(rawToken), + jwt.WithKeySet(keySet, jws.WithRequireKid(false)), + jwt.WithValidate(true), + ) + if err != nil { + return nil, errors.WithStack(err) + } + + return token, nil +} diff --git a/pkg/module/auth/jwt/jwt.go b/pkg/jwtutil/token.go similarity index 68% rename from pkg/module/auth/jwt/jwt.go rename to pkg/jwtutil/token.go index 0d128b9..da86988 100644 --- a/pkg/module/auth/jwt/jwt.go +++ b/pkg/jwtutil/token.go @@ -1,4 +1,4 @@ -package jwt +package jwtutil import ( "time" @@ -9,7 +9,7 @@ import ( "github.com/pkg/errors" ) -func GenerateSignedToken(algo jwa.KeyAlgorithm, key jwk.Key, claims map[string]any) ([]byte, error) { +func SignedToken(key jwk.Key, signingAlgorithm jwa.SignatureAlgorithm, claims map[string]any) ([]byte, error) { token := jwt.New() if err := token.Set(jwt.NotBeforeKey, time.Now()); err != nil { @@ -22,11 +22,11 @@ func GenerateSignedToken(algo jwa.KeyAlgorithm, key jwk.Key, claims map[string]a } } - if err := token.Set(jwk.AlgorithmKey, jwa.HS256); err != nil { + if err := token.Set(jwk.AlgorithmKey, signingAlgorithm); err != nil { return nil, errors.WithStack(err) } - rawToken, err := jwt.Sign(token, jwt.WithKey(algo, key)) + rawToken, err := jwt.Sign(token, jwt.WithKey(signingAlgorithm, key)) if err != nil { return nil, errors.WithStack(err) } diff --git a/pkg/module/auth/http/local_handler.go b/pkg/module/auth/http/local_handler.go index 302da25..91999a4 100644 --- a/pkg/module/auth/http/local_handler.go +++ b/pkg/module/auth/http/local_handler.go @@ -7,9 +7,9 @@ import ( _ "embed" + "forge.cadoles.com/arcad/edge/pkg/jwtutil" "forge.cadoles.com/arcad/edge/pkg/module/auth" "forge.cadoles.com/arcad/edge/pkg/module/auth/http/passwd" - "forge.cadoles.com/arcad/edge/pkg/module/auth/jwt" "github.com/go-chi/chi/v5" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwk" @@ -31,12 +31,12 @@ func init() { } type LocalHandler struct { - router chi.Router - algo jwa.KeyAlgorithm - key jwk.Key - getCookieDomain GetCookieDomainFunc - cookieDuration time.Duration - accounts map[string]LocalAccount + router chi.Router + key jwk.Key + signingAlgorithm jwa.SignatureAlgorithm + getCookieDomain GetCookieDomainFunc + cookieDuration time.Duration + accounts map[string]LocalAccount } func (h *LocalHandler) initRouter(prefix string) { @@ -113,7 +113,7 @@ func (h *LocalHandler) handleForm(w http.ResponseWriter, r *http.Request) { account.Claims[auth.ClaimIssuer] = "local" - token, err := jwt.GenerateSignedToken(h.algo, h.key, account.Claims) + token, err := jwtutil.SignedToken(h.key, h.signingAlgorithm, account.Claims) if err != nil { logger.Error(ctx, "could not generate signed token", logger.E(errors.WithStack(err))) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -182,18 +182,18 @@ func (h *LocalHandler) authenticate(username, password string) (*LocalAccount, e return &account, nil } -func NewLocalHandler(algo jwa.KeyAlgorithm, key jwk.Key, funcs ...LocalHandlerOptionFunc) *LocalHandler { +func NewLocalHandler(key jwk.Key, signingAlgorithm jwa.SignatureAlgorithm, funcs ...LocalHandlerOptionFunc) *LocalHandler { opts := defaultLocalHandlerOptions() for _, fn := range funcs { fn(opts) } handler := &LocalHandler{ - algo: algo, - key: key, - accounts: toAccountsMap(opts.Accounts), - getCookieDomain: opts.GetCookieDomain, - cookieDuration: opts.CookieDuration, + key: key, + signingAlgorithm: signingAlgorithm, + accounts: toAccountsMap(opts.Accounts), + getCookieDomain: opts.GetCookieDomain, + cookieDuration: opts.CookieDuration, } handler.initRouter(opts.RoutePrefix) diff --git a/pkg/module/auth/jwt.go b/pkg/module/auth/jwt.go deleted file mode 100644 index b96075c..0000000 --- a/pkg/module/auth/jwt.go +++ /dev/null @@ -1,118 +0,0 @@ -package auth - -import ( - "context" - "net/http" - "strings" - - "github.com/lestrrat-go/jwx/v2/jwk" - "github.com/lestrrat-go/jwx/v2/jws" - "github.com/lestrrat-go/jwx/v2/jwt" - "github.com/pkg/errors" -) - -const ( - CookieName string = "edge-auth" -) - -type GetKeySetFunc func() (jwk.Set, error) - -func WithJWT(getKeySet GetKeySetFunc) OptionFunc { - return func(o *Option) { - o.GetClaims = func(ctx context.Context, r *http.Request, names ...string) ([]string, error) { - claim, err := getClaims[string](r, getKeySet, names...) - if err != nil { - return nil, errors.WithStack(err) - } - - return claim, nil - } - } -} - -func FindRawToken(r *http.Request) (string, error) { - authorization := r.Header.Get("Authorization") - - // Retrieve token from Authorization header - rawToken := strings.TrimPrefix(authorization, "Bearer ") - - // Retrieve token from ?edge-auth= - if rawToken == "" { - rawToken = r.URL.Query().Get(CookieName) - } - - if rawToken == "" { - cookie, err := r.Cookie(CookieName) - if err != nil && !errors.Is(err, http.ErrNoCookie) { - return "", errors.WithStack(err) - } - - if cookie != nil { - rawToken = cookie.Value - } - } - - if rawToken == "" { - return "", errors.WithStack(ErrUnauthenticated) - } - - return rawToken, nil -} - -func FindToken(r *http.Request, getKeySet GetKeySetFunc) (jwt.Token, error) { - rawToken, err := FindRawToken(r) - if err != nil { - return nil, errors.WithStack(err) - } - - keySet, err := getKeySet() - if err != nil { - return nil, errors.WithStack(err) - } - - if keySet == nil { - return nil, errors.New("no keyset") - } - - token, err := jwt.Parse([]byte(rawToken), - jwt.WithKeySet(keySet, jws.WithRequireKid(false)), - jwt.WithValidate(true), - ) - if err != nil { - return nil, errors.WithStack(err) - } - - return token, nil -} - -func getClaims[T any](r *http.Request, getKeySet GetKeySetFunc, names ...string) ([]T, error) { - token, err := FindToken(r, getKeySet) - if err != nil { - return nil, errors.WithStack(err) - } - - ctx := r.Context() - - mapClaims, err := token.AsMap(ctx) - if err != nil { - return nil, errors.WithStack(err) - } - - claims := make([]T, len(names)) - - for idx, n := range names { - rawClaim, exists := mapClaims[n] - if !exists { - continue - } - - claim, ok := rawClaim.(T) - if !ok { - return nil, errors.Errorf("unexpected claim '%s' to be of type '%T', got '%T'", n, new(T), rawClaim) - } - - claims[idx] = claim - } - - return claims, nil -} diff --git a/pkg/module/auth/middleware/anonymous_user.go b/pkg/module/auth/middleware/anonymous_user.go index c594b9a..234372c 100644 --- a/pkg/module/auth/middleware/anonymous_user.go +++ b/pkg/module/auth/middleware/anonymous_user.go @@ -7,8 +7,8 @@ import ( "net/http" "time" + "forge.cadoles.com/arcad/edge/pkg/jwtutil" "forge.cadoles.com/arcad/edge/pkg/module/auth" - "forge.cadoles.com/arcad/edge/pkg/module/auth/jwt" "github.com/google/uuid" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwk" @@ -18,7 +18,7 @@ import ( const AnonIssuer = "anon" -func AnonymousUser(algo jwa.KeyAlgorithm, key jwk.Key, funcs ...AnonymousUserOptionFunc) func(next http.Handler) http.Handler { +func AnonymousUser(key jwk.Key, signingAlgorithm jwa.SignatureAlgorithm, funcs ...AnonymousUserOptionFunc) func(next http.Handler) http.Handler { opts := defaultAnonymousUserOptions() for _, fn := range funcs { fn(opts) @@ -26,7 +26,11 @@ func AnonymousUser(algo jwa.KeyAlgorithm, key jwk.Key, funcs ...AnonymousUserOpt return func(next http.Handler) http.Handler { handler := func(w http.ResponseWriter, r *http.Request) { - rawToken, err := auth.FindRawToken(r) + rawToken, err := jwtutil.FindRawToken(r, jwtutil.WithFinders( + jwtutil.FindTokenFromAuthorizationHeader, + jwtutil.FindTokenFromQueryString(auth.CookieName), + jwtutil.FindTokenFromCookie(auth.CookieName), + )) // If request already has a raw token, we do nothing if rawToken != "" && err == nil { @@ -62,7 +66,7 @@ func AnonymousUser(algo jwa.KeyAlgorithm, key jwk.Key, funcs ...AnonymousUserOpt auth.ClaimEdgeTenant: opts.Tenant, } - token, err := jwt.GenerateSignedToken(algo, key, claims) + token, err := jwtutil.SignedToken(key, signingAlgorithm, claims) if err != nil { logger.Error(ctx, "could not generate signed token", logger.E(errors.WithStack(err))) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) diff --git a/pkg/module/auth/module.go b/pkg/module/auth/module.go index 495a729..90e2cb8 100644 --- a/pkg/module/auth/module.go +++ b/pkg/module/auth/module.go @@ -5,12 +5,17 @@ import ( "forge.cadoles.com/arcad/edge/pkg/app" edgeHTTP "forge.cadoles.com/arcad/edge/pkg/http" + "forge.cadoles.com/arcad/edge/pkg/jwtutil" "forge.cadoles.com/arcad/edge/pkg/module/util" "github.com/dop251/goja" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" ) +const ( + CookieName string = "edge-auth" +) + const ( ClaimSubject = "sub" ClaimIssuer = "iss" @@ -21,8 +26,8 @@ const ( ) type Module struct { - server *app.Server - getClaims GetClaimsFunc + server *app.Server + getClaimFn GetClaimFunc } func (m *Module) Name() string { @@ -68,9 +73,9 @@ func (m *Module) getClaim(call goja.FunctionCall, rt *goja.Runtime) goja.Value { panic(rt.ToValue(errors.New("could not find http request in context"))) } - claim, err := m.getClaims(ctx, req, claimName) + claim, err := m.getClaimFn(ctx, req, claimName) if err != nil { - if errors.Is(err, ErrUnauthenticated) { + if errors.Is(err, jwtutil.ErrUnauthenticated) { return nil } @@ -78,11 +83,7 @@ func (m *Module) getClaim(call goja.FunctionCall, rt *goja.Runtime) goja.Value { return nil } - if len(claim) == 0 || claim[0] == "" { - return nil - } - - return rt.ToValue(claim[0]) + return rt.ToValue(claim) } func ModuleFactory(funcs ...OptionFunc) app.ServerModuleFactory { @@ -93,8 +94,8 @@ func ModuleFactory(funcs ...OptionFunc) app.ServerModuleFactory { return func(server *app.Server) app.ServerModule { return &Module{ - server: server, - getClaims: opt.GetClaims, + server: server, + getClaimFn: opt.GetClaim, } } } diff --git a/pkg/module/auth/module_test.go b/pkg/module/auth/module_test.go index 17cf933..a7fb233 100644 --- a/pkg/module/auth/module_test.go +++ b/pkg/module/auth/module_test.go @@ -10,6 +10,7 @@ import ( "cdr.dev/slog" "forge.cadoles.com/arcad/edge/pkg/app" edgeHTTP "forge.cadoles.com/arcad/edge/pkg/http" + "forge.cadoles.com/arcad/edge/pkg/jwtutil" "forge.cadoles.com/arcad/edge/pkg/module" "github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwk" @@ -130,7 +131,7 @@ func getDummyKey() jwk.Key { return key } -func getDummyKeySet(key jwk.Key) GetKeySetFunc { +func getDummyKeySet(key jwk.Key) jwtutil.GetKeySetFunc { return func() (jwk.Set, error) { set := jwk.NewSet() diff --git a/pkg/module/auth/mount.go b/pkg/module/auth/mount.go index f193150..ea3cd3a 100644 --- a/pkg/module/auth/mount.go +++ b/pkg/module/auth/mount.go @@ -3,6 +3,7 @@ package auth import ( "net/http" + "forge.cadoles.com/arcad/edge/pkg/jwtutil" "github.com/go-chi/chi/v5" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/api" @@ -12,39 +13,39 @@ import ( type MountFunc func(r chi.Router) type Handler struct { - getClaims GetClaimsFunc + getClaim GetClaimFunc profileClaims []string } func (h *Handler) serveProfile(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - claims, err := h.getClaims(ctx, r, h.profileClaims...) - if err != nil { - if errors.Is(err, ErrUnauthenticated) { + profile := make(map[string]any) + + for _, name := range h.profileClaims { + value, err := h.getClaim(ctx, r, name) + if err != nil { + if errors.Is(err, jwtutil.ErrUnauthenticated) { + api.ErrorResponse( + w, http.StatusUnauthorized, + api.ErrCodeUnauthorized, + nil, + ) + + return + } + + logger.Error(ctx, "could not retrieve claims", logger.E(errors.WithStack(err))) api.ErrorResponse( - w, http.StatusUnauthorized, - api.ErrCodeUnauthorized, + w, http.StatusInternalServerError, + api.ErrCodeUnknownError, nil, ) return } - logger.Error(ctx, "could not retrieve claims", logger.E(errors.WithStack(err))) - api.ErrorResponse( - w, http.StatusInternalServerError, - api.ErrCodeUnknownError, - nil, - ) - - return - } - - profile := make(map[string]any) - - for idx, cl := range h.profileClaims { - profile[cl] = claims[idx] + profile[name] = value } api.DataResponse(w, http.StatusOK, struct { @@ -62,7 +63,7 @@ func Mount(authHandler http.Handler, funcs ...OptionFunc) MountFunc { handler := &Handler{ profileClaims: opt.ProfileClaims, - getClaims: opt.GetClaims, + getClaim: opt.GetClaim, } return func(r chi.Router) { diff --git a/pkg/module/auth/option.go b/pkg/module/auth/option.go index a10bbdb..d527290 100644 --- a/pkg/module/auth/option.go +++ b/pkg/module/auth/option.go @@ -2,15 +2,17 @@ package auth import ( "context" + "fmt" "net/http" + "forge.cadoles.com/arcad/edge/pkg/jwtutil" "github.com/pkg/errors" ) -type GetClaimsFunc func(ctx context.Context, r *http.Request, claims ...string) ([]string, error) +type GetClaimFunc func(ctx context.Context, r *http.Request, name string) (string, error) type Option struct { - GetClaims GetClaimsFunc + GetClaim GetClaimFunc ProfileClaims []string } @@ -18,7 +20,7 @@ type OptionFunc func(*Option) func defaultOptions() *Option { return &Option{ - GetClaims: dummyGetClaims, + GetClaim: dummyGetClaim, ProfileClaims: []string{ ClaimSubject, ClaimIssuer, @@ -30,13 +32,13 @@ func defaultOptions() *Option { } } -func dummyGetClaims(ctx context.Context, r *http.Request, claims ...string) ([]string, error) { - return nil, errors.Errorf("dummy getclaim func cannot retrieve claims '%s'", claims) +func dummyGetClaim(ctx context.Context, r *http.Request, name string) (string, error) { + return "", errors.Errorf("dummy getclaim func cannot retrieve claim '%s'", name) } -func WithGetClaims(fn GetClaimsFunc) OptionFunc { +func WithGetClaims(fn GetClaimFunc) OptionFunc { return func(o *Option) { - o.GetClaims = fn + o.GetClaim = fn } } @@ -45,3 +47,34 @@ func WithProfileClaims(claims ...string) OptionFunc { o.ProfileClaims = claims } } + +func WithJWT(getKeySet jwtutil.GetKeySetFunc) OptionFunc { + funcs := []jwtutil.FindTokenOptionFunc{ + jwtutil.WithFinders( + jwtutil.FindTokenFromAuthorizationHeader, + jwtutil.FindTokenFromQueryString(CookieName), + jwtutil.FindTokenFromCookie(CookieName), + ), + } + + return func(o *Option) { + o.GetClaim = func(ctx context.Context, r *http.Request, name string) (string, error) { + token, err := jwtutil.FindToken(r, getKeySet, funcs...) + if err != nil { + return "", errors.WithStack(err) + } + + tokenMap, err := token.AsMap(ctx) + if err != nil { + return "", errors.WithStack(err) + } + + value, exists := tokenMap[name] + if !exists { + return "", nil + } + + return fmt.Sprintf("%v", value), nil + } + } +} diff --git a/pkg/storage/driver/rpc/driver.go b/pkg/storage/driver/rpc/driver.go index fc8d961..b814925 100644 --- a/pkg/storage/driver/rpc/driver.go +++ b/pkg/storage/driver/rpc/driver.go @@ -6,11 +6,13 @@ import ( "forge.cadoles.com/arcad/edge/pkg/storage" "forge.cadoles.com/arcad/edge/pkg/storage/driver" "forge.cadoles.com/arcad/edge/pkg/storage/driver/rpc/client" + "forge.cadoles.com/arcad/edge/pkg/storage/share" ) func init() { driver.RegisterDocumentStoreFactory("rpc", documentStoreFactory) driver.RegisterBlobStoreFactory("rpc", blobStoreFactory) + driver.RegisterShareStoreFactory("rpc", shareStoreFactory) } func documentStoreFactory(url *url.URL) (storage.DocumentStore, error) { @@ -20,3 +22,7 @@ func documentStoreFactory(url *url.URL) (storage.DocumentStore, error) { func blobStoreFactory(url *url.URL) (storage.BlobStore, error) { return client.NewBlobStore(url), nil } + +func shareStoreFactory(url *url.URL) (share.Store, error) { + return client.NewShareStore(url), nil +}