package command

import (
	"context"
	"fmt"
	"net/http"
	"strings"
	"sync"
	"time"

	"github.com/go-chi/chi/v5"
	"github.com/go-chi/chi/v5/middleware"
	"github.com/hashicorp/golang-lru/v2/expirable"
	"github.com/keegancsmith/rpc"
	"github.com/lestrrat-go/jwx/v2/jwa"
	"github.com/lestrrat-go/jwx/v2/jwk"
	"github.com/pkg/errors"
	"github.com/urfave/cli/v2"
	"gitlab.com/wpetit/goweb/logger"

	"forge.cadoles.com/arcad/edge/cmd/storage-server/command/flag"
	"forge.cadoles.com/arcad/edge/pkg/jwtutil"
	"forge.cadoles.com/arcad/edge/pkg/storage"
	"forge.cadoles.com/arcad/edge/pkg/storage/driver"
	"forge.cadoles.com/arcad/edge/pkg/storage/driver/rpc/server"
	"forge.cadoles.com/arcad/edge/pkg/storage/share"

	// Register storage drivers
	_ "forge.cadoles.com/arcad/edge/pkg/storage/driver/rpc"
	_ "forge.cadoles.com/arcad/edge/pkg/storage/driver/sqlite"
)

// Run returns the "run" command. It serves the blob, document and share
// stores as RPC servers mounted on /blobstore, /documentstore and /sharestore.
// Every route sits behind the JWT authentication middleware.
func Run() *cli.Command {
	return &cli.Command{
		Name:  "run",
		Usage: "Run server",
		Flags: []cli.Flag{
			&cli.StringFlag{
				Name:    "address",
				Aliases: []string{"addr"},
				Value:   ":3001",
			},
			&cli.StringFlag{
				Name:    "blobstore-dsn-pattern",
				EnvVars: []string{"STORAGE_SERVER_BLOBSTORE_DSN_PATTERN"},
				Value:   fmt.Sprintf("sqlite://data/%%TENANT%%/%%APPID%%/blobstore.sqlite?_pragma=foreign_keys(1)&_pragma=busy_timeout=%d", (60 * time.Second).Milliseconds()),
			},
			&cli.StringFlag{
				Name:    "documentstore-dsn-pattern",
				EnvVars: []string{"STORAGE_SERVER_DOCUMENTSTORE_DSN_PATTERN"},
				Value:   fmt.Sprintf("sqlite://data/%%TENANT%%/%%APPID%%/documentstore.sqlite?_pragma=foreign_keys(1)&_pragma=busy_timeout=%d", (60 * time.Second).Milliseconds()),
			},
			&cli.StringFlag{
				Name:    "sharestore-dsn-pattern",
				EnvVars: []string{"STORAGE_SERVER_SHARESTORE_DSN_PATTERN"},
				Value:   fmt.Sprintf("sqlite://data/%%TENANT%%/sharestore.sqlite?_pragma=foreign_keys(1)&_pragma=busy_timeout=%d", (60 * time.Second).Milliseconds()),
			},
			flag.PrivateKey,
			flag.PrivateKeySigningAlgorithm,
			flag.PrivateKeyDefaultSize,
			&cli.DurationFlag{
				Name:    "cache-ttl",
				EnvVars: []string{"STORAGE_SERVER_CACHE_TTL"},
				Value:   time.Hour,
			},
			&cli.IntFlag{
				Name:    "cache-size",
				EnvVars: []string{"STORAGE_SERVER_CACHE_SIZE"},
				Value:   32,
			},
		},
		Action: func(ctx *cli.Context) error {
			addr := ctx.String("address")
			blobStoreDSNPattern := ctx.String("blobstore-dsn-pattern")
			documentStoreDSNPattern := ctx.String("documentstore-dsn-pattern")
			shareStoreDSNPattern := ctx.String("sharestore-dsn-pattern")
			cacheSize := ctx.Int("cache-size")
			cacheTTL := ctx.Duration("cache-ttl")

			privateKeyFile := flag.GetPrivateKey(ctx)
			signingAlgorithm := flag.GetSigningAlgorithm(ctx)
			privateKeyDefaultSize := flag.GetPrivateKeyDefaultSize(ctx)

			router := chi.NewRouter()

			privateKey, err := jwtutil.LoadOrGenerateKey(
				privateKeyFile,
				privateKeyDefaultSize,
			)
			if err != nil {
				return errors.WithStack(err)
			}

			// Incoming tokens are verified with the public half of the key.
			publicKey, err := privateKey.PublicKey()
			if err != nil {
				return errors.WithStack(err)
			}

			getBlobStoreServer := createGetCachedStoreServer(
				func(dsn string) (storage.BlobStore, error) {
					return driver.NewBlobStore(dsn)
				},
				func(store storage.BlobStore) *rpc.Server {
					return server.NewBlobStoreServer(store)
				},
			)

			getShareStoreServer := createGetCachedStoreServer(
				func(dsn string) (share.Store, error) {
					return driver.NewShareStore(dsn)
				},
				func(store share.Store) *rpc.Server {
					return server.NewShareStoreServer(store)
				},
			)

			getDocumentStoreServer := createGetCachedStoreServer(
				func(dsn string) (storage.DocumentStore, error) {
					return driver.NewDocumentStore(dsn)
				},
				func(store storage.DocumentStore) *rpc.Server {
					return server.NewDocumentStoreServer(store)
				},
			)

			router.Use(middleware.RealIP)
			router.Use(middleware.Logger)
			router.Use(authenticate(publicKey, jwa.SignatureAlgorithm(signingAlgorithm)))

			router.Handle("/blobstore", createStoreHandler(getBlobStoreServer, blobStoreDSNPattern, true, cacheSize, cacheTTL))
			router.Handle("/documentstore", createStoreHandler(getDocumentStoreServer, documentStoreDSNPattern, true, cacheSize, cacheTTL))
			router.Handle("/sharestore", createStoreHandler(getShareStoreServer, shareStoreDSNPattern, false, cacheSize, cacheTTL))

			if err := http.ListenAndServe(addr, router); err != nil {
				return errors.WithStack(err)
			}

			return nil
		},
	}
}

// getRPCServerFunc resolves the RPC server for a tenant/app pair. It
// substitutes %TENANT% and %APPID% in dsnPattern and uses an LRU cache
// bounded by cacheSize entries that expire after cacheTTL.
type getRPCServerFunc func(cacheSize int, cacheTTL time.Duration, tenant, appID, dsnPattern string) (*rpc.Server, error)

// createGetCachedStoreServer returns a getRPCServerFunc that caches one RPC
// server per resolved DSN.
//
// storeFactory opens a store for a DSN. serverFactory wraps that store in an
// RPC server. The cache is created on the first call, using the cacheSize and
// cacheTTL of that call; later values are ignored.
func createGetCachedStoreServer[T any](storeFactory func(dsn string) (T, error), serverFactory func(store T) *rpc.Server) getRPCServerFunc {
	var (
		cache     *expirable.LRU[string, *rpc.Server]
		initCache sync.Once

		// createMutex serializes store creation, so concurrent cache misses for
		// the same DSN open the underlying store only once.
		createMutex sync.Mutex
	)

	return func(cacheSize int, cacheTTL time.Duration, tenant, appID, dsnPattern string) (*rpc.Server, error) {
		initCache.Do(func() {
			// NOTE(review): the eviction callback is nil, so expired or evicted
			// servers are dropped without closing their store. Whether the store
			// holds resources that must be released cannot be determined here.
			cache = expirable.NewLRU[string, *rpc.Server](cacheSize, nil, cacheTTL)
		})

		dsn := strings.ReplaceAll(dsnPattern, "%TENANT%", tenant)
		dsn = strings.ReplaceAll(dsn, "%APPID%", appID)

		// Key by the resolved DSN rather than by tenant/appID. When the pattern
		// ignores %APPID% (as the default share store pattern does), all app IDs
		// then share one store instead of each opening the same database again.
		if storeServer, ok := cache.Get(dsn); ok {
			return storeServer, nil
		}

		createMutex.Lock()
		defer createMutex.Unlock()

		// Another goroutine may have created the entry while we waited for the lock.
		if storeServer, ok := cache.Get(dsn); ok {
			return storeServer, nil
		}

		store, err := storeFactory(dsn)
		if err != nil {
			return nil, errors.WithStack(err)
		}

		storeServer := serverFactory(store)
		cache.Add(dsn, storeServer)

		return storeServer, nil
	}
}

// createStoreHandler returns an HTTP handler that serves the RPC server
// resolved by getStoreServer.
//
// The tenant is read from the request context, where authenticate stores it;
// a missing or empty tenant yields 401. The app ID is read from the "appId"
// query parameter. It yields 400 when it is required but missing, or when it
// is not safe to substitute into the DSN pattern. A failure to resolve the
// store server is logged and yields 500.
func createStoreHandler(getStoreServer getRPCServerFunc, dsnPattern string, appIDRequired bool, cacheSize int, cacheTTL time.Duration) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		ctx := r.Context()

		tenant, ok := ctx.Value("tenant").(string)
		if !ok || tenant == "" {
			http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
			return
		}

		appID := r.URL.Query().Get("appId")
		if appIDRequired && appID == "" {
			http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
			return
		}

		// appId is client-controlled and is substituted into a filesystem-backed
		// DSN. Without this check, values like ".." or "../other" would reach
		// another tenant's data.
		if appID != "" && !isSafePathSegment(appID) {
			http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
			return
		}

		storeServer, err := getStoreServer(cacheSize, cacheTTL, tenant, appID, dsnPattern)
		if err != nil {
			logger.Error(r.Context(), "could not retrieve store server", logger.E(errors.WithStack(err)), logger.F("tenant", tenant))
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			return
		}

		storeServer.ServeHTTP(w, r)
	})
}

// isSafePathSegment reports whether value can be substituted into a DSN
// pattern as a single path segment.
//
// It rejects the empty string, "." and "..", path separators, ASCII control
// characters, and characters that are significant in a URL-shaped DSN:
// "?" starts a query string, "#" starts a fragment, and "%" starts a
// percent-escape that a URL parser may decode into "/" or ".".
func isSafePathSegment(value string) bool {
	if value == "" || value == "." || value == ".." {
		return false
	}

	for _, r := range value {
		if r < 0x20 || r == 0x7f {
			return false
		}

		if strings.ContainsRune(`/\?#%`, r) {
			return false
		}
	}

	return true
}

// authenticate returns a middleware that requires a JWT in the "token" query
// string parameter.
//
// The token is located and checked by jwtutil.FindToken against a key set
// built from publicKey, with its algorithm set to signingAlgorithm. The
// token's "tenant" claim must be present and be a string. It is then stored
// in the request context under the "tenant" key before next is called.
// Token failures yield 401; a failure to build the key set yields 500.
func authenticate(publicKey jwk.Key, signingAlgorithm jwa.SignatureAlgorithm) func(http.Handler) http.Handler {
	var (
		createKeySet sync.Once
		keySetErr    error
		getKeySet    jwtutil.GetKeySetFunc
	)

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()

			// The key set is built lazily on the first request. A failure is kept
			// in keySetErr and reported on every subsequent request.
			createKeySet.Do(func() {
				keySetErr = publicKey.Set(jwk.AlgorithmKey, signingAlgorithm)
				if keySetErr != nil {
					return
				}

				var keySet jwk.Set
				keySet, keySetErr = jwtutil.NewKeySet(publicKey)
				if keySetErr != nil {
					return
				}

				getKeySet = func() (jwk.Set, error) {
					return keySet, nil
				}
			})

			if keySetErr != nil {
				logger.Error(ctx, "could not create keyset accessor", logger.E(errors.WithStack(keySetErr)))
				http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
				return
			}

			token, err := jwtutil.FindToken(r, getKeySet, jwtutil.WithFinders(
				jwtutil.FindTokenFromQueryString("token"),
			))
			if err != nil {
				logger.Error(ctx, "could not find jwt token", logger.E(errors.WithStack(err)))
				http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
				return
			}

			tokenMap, err := token.AsMap(ctx)
			if err != nil {
				logger.Error(ctx, "could not transform token to map", logger.E(errors.WithStack(err)))
				http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
				return
			}

			rawTenant, exists := tokenMap["tenant"]
			if !exists {
				logger.Warn(ctx, "could not find tenant claim", logger.F("token", token))
				http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
				return
			}

			tenant, ok := rawTenant.(string)
			if !ok {
				logger.Warn(ctx, "unexpected tenant claim value", logger.F("token", token))
				http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
				return
			}

			// NOTE(review): staticcheck SA1029 flags this plain string key. It is
			// kept because createStoreHandler, and possibly code outside this file,
			// reads ctx.Value("tenant"). Confirm all readers before switching to an
			// unexported key type.
			r = r.WithContext(context.WithValue(ctx, "tenant", tenant))

			next.ServeHTTP(w, r)
		})
	}
}