package command import ( "context" "fmt" "net/http" "strings" "sync" "time" "github.com/hashicorp/golang-lru/v2/expirable" "github.com/keegancsmith/rpc" "github.com/lestrrat-go/jwx/v2/jwk" "gitlab.com/wpetit/goweb/logger" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" "github.com/pkg/errors" "github.com/urfave/cli/v2" // Register storage drivers "forge.cadoles.com/arcad/edge/pkg/jwtutil" "forge.cadoles.com/arcad/edge/pkg/storage" "forge.cadoles.com/arcad/edge/pkg/storage/driver" _ "forge.cadoles.com/arcad/edge/pkg/storage/driver/rpc" "forge.cadoles.com/arcad/edge/pkg/storage/driver/rpc/server" _ "forge.cadoles.com/arcad/edge/pkg/storage/driver/sqlite" "forge.cadoles.com/arcad/edge/pkg/storage/share" ) func Run() *cli.Command { return &cli.Command{ Name: "run", Usage: "Run server", Flags: []cli.Flag{ &cli.StringFlag{ Name: "address", Aliases: []string{"addr"}, Value: ":3001", }, &cli.StringFlag{ Name: "blobstore-dsn-pattern", EnvVars: []string{"STORAGE_SERVER_BLOBSTORE_DSN_PATTERN"}, Value: fmt.Sprintf("sqlite://data/%%TENANT%%/%%APPID%%/blobstore.sqlite?_pragma=foreign_keys(1)&_pragma=busy_timeout=%d", (60 * time.Second).Milliseconds()), }, &cli.StringFlag{ Name: "documentstore-dsn-pattern", EnvVars: []string{"STORAGE_SERVER_DOCUMENTSTORE_DSN_PATTERN"}, Value: fmt.Sprintf("sqlite://data/%%TENANT%%/%%APPID%%/documentstore.sqlite?_pragma=foreign_keys(1)&_pragma=busy_timeout=%d", (60 * time.Second).Milliseconds()), }, &cli.StringFlag{ Name: "sharestore-dsn-pattern", EnvVars: []string{"STORAGE_SERVER_SHARESTORE_DSN_PATTERN"}, Value: fmt.Sprintf("sqlite://data/%%TENANT%%/sharestore.sqlite?_pragma=foreign_keys(1)&_pragma=busy_timeout=%d", (60 * time.Second).Milliseconds()), }, &cli.StringFlag{ Name: "private-key", EnvVars: []string{"STORAGE_SERVER_PRIVATE_KEY"}, Value: "storage-server.key", TakesFile: true, }, &cli.DurationFlag{ Name: "cache-ttl", EnvVars: []string{"STORAGE_SERVER_CACHE_TTL"}, Value: time.Hour, }, &cli.IntFlag{ Name: "cache-size", EnvVars: []string{"STORAGE_SERVER_CACHE_SIZE"}, Value: 32, }, }, Action: func(ctx *cli.Context) error { addr := ctx.String("address") blobStoreDSNPattern := ctx.String("blobstore-dsn-pattern") documentStoreDSNPattern := ctx.String("documentstore-dsn-pattern") shareStoreDSNPattern := ctx.String("sharestore-dsn-pattern") cacheSize := ctx.Int("cache-size") cacheTTL := ctx.Duration("cache-ttl") privateKeyFile := ctx.String("private-key") router := chi.NewRouter() rsaKey, err := jwtutil.LoadOrGenerateRSAKey(privateKeyFile, 2048) if err != nil { return errors.WithStack(err) } privateKey, err := jwtutil.FromRSA(rsaKey) if err != nil { return errors.WithStack(err) } getBlobStoreServer := createGetCachedStoreServer( func(dsn string) (storage.BlobStore, error) { return driver.NewBlobStore(dsn) }, func(store storage.BlobStore) *rpc.Server { return server.NewBlobStoreServer(store) }, ) getShareStoreServer := createGetCachedStoreServer( func(dsn string) (share.Store, error) { return driver.NewShareStore(dsn) }, func(store share.Store) *rpc.Server { return server.NewShareStoreServer(store) }, ) getDocumentStoreServer := createGetCachedStoreServer( func(dsn string) (storage.DocumentStore, error) { return driver.NewDocumentStore(dsn) }, func(store storage.DocumentStore) *rpc.Server { return server.NewDocumentStoreServer(store) }, ) router.Use(middleware.RealIP) router.Use(middleware.Logger) router.Use(authenticate(privateKey)) router.Handle("/blobstore", createStoreHandler(getBlobStoreServer, blobStoreDSNPattern, cacheSize, cacheTTL)) router.Handle("/documentstore", createStoreHandler(getDocumentStoreServer, documentStoreDSNPattern, cacheSize, cacheTTL)) router.Handle("/sharestore", createStoreHandler(getShareStoreServer, shareStoreDSNPattern, cacheSize, cacheTTL)) if err := http.ListenAndServe(addr, router); err != nil { return errors.WithStack(err) } return nil }, } } type getRPCServerFunc func(cacheSize int, cacheTTL time.Duration, tenant, appID, dsnPattern string) (*rpc.Server, error) func createGetCachedStoreServer[T any](storeFactory func(dsn string) (T, error), serverFactory func(store T) *rpc.Server) getRPCServerFunc { var ( cache *expirable.LRU[string, *rpc.Server] initCache sync.Once ) return func(cacheSize int, cacheTTL time.Duration, tenant, appID, dsnPattern string) (*rpc.Server, error) { initCache.Do(func() { cache = expirable.NewLRU[string, *rpc.Server](cacheSize, nil, cacheTTL) }) key := fmt.Sprintf("%s:%s", tenant, appID) storeServer, _ := cache.Get(key) if storeServer != nil { return storeServer, nil } dsn := strings.ReplaceAll(dsnPattern, "%TENANT%", tenant) dsn = strings.ReplaceAll(dsn, "%APPID%", appID) store, err := storeFactory(dsn) if err != nil { return nil, errors.WithStack(err) } storeServer = serverFactory(store) cache.Add(key, storeServer) return storeServer, nil } } func createStoreHandler(getStoreServer getRPCServerFunc, dsnPattern string, cacheSize int, cacheTTL time.Duration) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { tenant := r.URL.Query().Get("tenant") if tenant == "" { http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return } appID := r.URL.Query().Get("appId") if tenant == "" { http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest) return } server, err := getStoreServer(cacheSize, cacheTTL, tenant, appID, dsnPattern) if err != nil { logger.Error(r.Context(), "could not retrieve store server", logger.E(errors.WithStack(err)), logger.F("tenant", tenant)) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } server.ServeHTTP(w, r) }) } func authenticate(privateKey jwk.Key) func(http.Handler) http.Handler { keySet, err := jwtutil.NewKeySet(privateKey) getKeySet := func() (jwk.Set, error) { if err != nil { return nil, errors.WithStack(err) } return keySet, nil } return func(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() token, err := jwtutil.FindToken(r, getKeySet, jwtutil.WithFinders( jwtutil.FindTokenFromQueryString("token"), )) if err != nil { logger.Error(ctx, "could not find jwt token", logger.E(errors.WithStack(err))) http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } tokenMap, err := token.AsMap(ctx) if err != nil { logger.Error(ctx, "could not transform token to map", logger.E(errors.WithStack(err))) http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } rawTenant, exists := tokenMap["tenant"] if !exists { logger.Warn(ctx, "could not find tenant claim", logger.F("token", token)) http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } tenant, ok := rawTenant.(string) if !ok { logger.Warn(ctx, "unexpected tenant claim value", logger.F("token", token)) http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } r = r.WithContext(context.WithValue(ctx, "tenant", tenant)) rawAppId, exists := tokenMap["appId"] if !exists { logger.Warn(ctx, "could not find appId claim", logger.F("token", token)) http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } appId, ok := rawAppId.(string) if !ok { logger.Warn(ctx, "unexpected appId claim value", logger.F("token", token)) http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized) return } r = r.WithContext(context.WithValue(ctx, "appId", appId)) next.ServeHTTP(w, r) }) } }