332 lines
9.9 KiB
Go
332 lines
9.9 KiB
Go
package command
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net/http"
|
|
"os"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/getsentry/sentry-go"
|
|
"github.com/hashicorp/golang-lru/v2/expirable"
|
|
"github.com/keegancsmith/rpc"
|
|
"github.com/lestrrat-go/jwx/v2/jwa"
|
|
"github.com/lestrrat-go/jwx/v2/jwk"
|
|
"gitlab.com/wpetit/goweb/logger"
|
|
|
|
"github.com/go-chi/chi/v5"
|
|
"github.com/go-chi/chi/v5/middleware"
|
|
"github.com/pkg/errors"
|
|
"github.com/urfave/cli/v2"
|
|
|
|
// Register storage drivers
|
|
_ "forge.cadoles.com/arcad/edge/pkg/storage/driver/cache"
|
|
_ "forge.cadoles.com/arcad/edge/pkg/storage/driver/rpc"
|
|
_ "forge.cadoles.com/arcad/edge/pkg/storage/driver/sqlite"
|
|
|
|
"forge.cadoles.com/arcad/edge/cmd/storage-server/command/flag"
|
|
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
|
|
"forge.cadoles.com/arcad/edge/pkg/storage"
|
|
"forge.cadoles.com/arcad/edge/pkg/storage/driver"
|
|
"forge.cadoles.com/arcad/edge/pkg/storage/driver/rpc/server"
|
|
"forge.cadoles.com/arcad/edge/pkg/storage/share"
|
|
)
|
|
|
|
func Run() *cli.Command {
|
|
return &cli.Command{
|
|
Name: "run",
|
|
Usage: "Run server",
|
|
Flags: []cli.Flag{
|
|
&cli.StringFlag{
|
|
Name: "address",
|
|
EnvVars: []string{"STORAGE_SERVER_ADDRESS"},
|
|
Aliases: []string{"addr"},
|
|
Value: ":3001",
|
|
},
|
|
&cli.IntFlag{
|
|
Name: "log-level",
|
|
EnvVars: []string{"STORAGE_SERVER_LOG_LEVEL"},
|
|
Value: int(logger.LevelError),
|
|
},
|
|
&cli.StringFlag{
|
|
Name: "blobstore-dsn-pattern",
|
|
EnvVars: []string{"STORAGE_SERVER_BLOBSTORE_DSN_PATTERN"},
|
|
Value: fmt.Sprintf("sqlite://data/%%TENANT%%/%%APPID%%/blobstore.sqlite?_pragma=foreign_keys(1)&_pragma=busy_timeout=%d&_pragma=journal_mode=wal", (60 * time.Second).Milliseconds()),
|
|
},
|
|
&cli.StringFlag{
|
|
Name: "documentstore-dsn-pattern",
|
|
EnvVars: []string{"STORAGE_SERVER_DOCUMENTSTORE_DSN_PATTERN"},
|
|
Value: fmt.Sprintf("sqlite://data/%%TENANT%%/%%APPID%%/documentstore.sqlite?_pragma=foreign_keys(1)&_pragma=busy_timeout=%d&_pragma=journal_mode=wal", (60 * time.Second).Milliseconds()),
|
|
},
|
|
&cli.StringFlag{
|
|
Name: "sharestore-dsn-pattern",
|
|
EnvVars: []string{"STORAGE_SERVER_SHARESTORE_DSN_PATTERN"},
|
|
Value: fmt.Sprintf("sqlite://data/%%TENANT%%/sharestore.sqlite?_pragma=foreign_keys(1)&_pragma=busy_timeout=%d&_pragma=journal_mode=wal", (60 * time.Second).Milliseconds()),
|
|
},
|
|
&cli.StringFlag{
|
|
Name: "sentry-dsn",
|
|
EnvVars: []string{"STORAGE_SERVER_SENTRY_DSN"},
|
|
Value: "",
|
|
},
|
|
&cli.StringFlag{
|
|
Name: "sentry-environment",
|
|
EnvVars: []string{"STORAGE_SERVER_SENTRY_ENVIRONMENT"},
|
|
Value: "",
|
|
},
|
|
flag.PrivateKey,
|
|
flag.PrivateKeySigningAlgorithm,
|
|
flag.PrivateKeyDefaultSize,
|
|
&cli.DurationFlag{
|
|
Name: "cache-ttl",
|
|
EnvVars: []string{"STORAGE_SERVER_CACHE_TTL"},
|
|
Value: time.Hour,
|
|
},
|
|
&cli.IntFlag{
|
|
Name: "cache-size",
|
|
EnvVars: []string{"STORAGE_SERVER_CACHE_SIZE"},
|
|
Value: 32,
|
|
},
|
|
},
|
|
Action: func(ctx *cli.Context) error {
|
|
addr := ctx.String("address")
|
|
blobStoreDSNPattern := ctx.String("blobstore-dsn-pattern")
|
|
documentStoreDSNPattern := ctx.String("documentstore-dsn-pattern")
|
|
shareStoreDSNPattern := ctx.String("sharestore-dsn-pattern")
|
|
cacheSize := ctx.Int("cache-size")
|
|
cacheTTL := ctx.Duration("cache-ttl")
|
|
privateKeyFile := flag.GetPrivateKey(ctx)
|
|
signingAlgorithm := flag.GetSigningAlgorithm(ctx)
|
|
privateKeyDefaultSize := flag.GetPrivateKeyDefaultSize(ctx)
|
|
logLevel := ctx.Int("log-level")
|
|
|
|
logger.SetLevel(logger.Level(logLevel))
|
|
|
|
sentryDSN := ctx.String("sentry-dsn")
|
|
sentryEnvironment := ctx.String("sentry-environment")
|
|
if sentryDSN != "" {
|
|
if sentryEnvironment == "" {
|
|
sentryEnvironment, _ = os.Hostname()
|
|
}
|
|
|
|
err := sentry.Init(sentry.ClientOptions{
|
|
Dsn: sentryDSN,
|
|
Debug: logLevel == int(logger.LevelDebug),
|
|
AttachStacktrace: true,
|
|
Environment: sentryEnvironment,
|
|
})
|
|
if err != nil {
|
|
logger.Error(ctx.Context, "could not initialize sentry", logger.CapturedE(errors.WithStack(err)))
|
|
}
|
|
|
|
logger.SetCaptureFunc(func(err error) {
|
|
sentry.CaptureException(err)
|
|
})
|
|
|
|
defer sentry.Flush(2 * time.Second)
|
|
}
|
|
|
|
router := chi.NewRouter()
|
|
|
|
privateKey, err := jwtutil.LoadOrGenerateKey(
|
|
privateKeyFile,
|
|
privateKeyDefaultSize,
|
|
)
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
getBlobStoreServer := createGetCachedStoreServer(
|
|
func(dsn string) (storage.BlobStore, error) {
|
|
return driver.NewBlobStore(dsn)
|
|
},
|
|
func(store storage.BlobStore) *rpc.Server {
|
|
return server.NewBlobStoreServer(store)
|
|
},
|
|
)
|
|
|
|
getShareStoreServer := createGetCachedStoreServer(
|
|
func(dsn string) (share.Store, error) {
|
|
return driver.NewShareStore(dsn)
|
|
},
|
|
func(store share.Store) *rpc.Server {
|
|
return server.NewShareStoreServer(store)
|
|
},
|
|
)
|
|
|
|
getDocumentStoreServer := createGetCachedStoreServer(
|
|
func(dsn string) (storage.DocumentStore, error) {
|
|
return driver.NewDocumentStore(dsn)
|
|
},
|
|
func(store storage.DocumentStore) *rpc.Server {
|
|
return server.NewDocumentStoreServer(store)
|
|
},
|
|
)
|
|
|
|
router.Use(middleware.RealIP)
|
|
router.Use(middleware.Logger)
|
|
|
|
logger.Debug(ctx.Context, "using authentication", logger.F("privateKey", privateKeyFile), logger.F("signingAlgorithm", signingAlgorithm))
|
|
|
|
router.Use(authenticate(privateKey, jwa.SignatureAlgorithm(signingAlgorithm)))
|
|
|
|
router.Handle("/blobstore", createStoreHandler(getBlobStoreServer, blobStoreDSNPattern, true, cacheSize, cacheTTL))
|
|
router.Handle("/documentstore", createStoreHandler(getDocumentStoreServer, documentStoreDSNPattern, true, cacheSize, cacheTTL))
|
|
router.Handle("/sharestore", createStoreHandler(getShareStoreServer, shareStoreDSNPattern, false, cacheSize, cacheTTL))
|
|
|
|
logger.Info(ctx.Context, "listening", logger.F("addr", addr))
|
|
|
|
if err := http.ListenAndServe(addr, router); err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
return nil
|
|
},
|
|
}
|
|
}
|
|
|
|
// getRPCServerFunc resolves the RPC server backing a store for one tenant and
// application. dsnPattern is the store's DSN template; cacheSize and cacheTTL
// configure the implementation's server cache (as createGetCachedStoreServer
// does in this file).
type getRPCServerFunc func(
	cacheSize int,
	cacheTTL time.Duration,
	tenant string,
	appID string,
	dsnPattern string,
) (*rpc.Server, error)
|
|
|
|
func createGetCachedStoreServer[T any](storeFactory func(dsn string) (T, error), serverFactory func(store T) *rpc.Server) getRPCServerFunc {
|
|
var (
|
|
cache *expirable.LRU[string, *rpc.Server]
|
|
initCache sync.Once
|
|
)
|
|
|
|
return func(cacheSize int, cacheTTL time.Duration, tenant, appID, dsnPattern string) (*rpc.Server, error) {
|
|
initCache.Do(func() {
|
|
cache = expirable.NewLRU[string, *rpc.Server](cacheSize, nil, cacheTTL)
|
|
})
|
|
|
|
key := fmt.Sprintf("%s:%s", tenant, appID)
|
|
|
|
storeServer, _ := cache.Get(key)
|
|
if storeServer != nil {
|
|
return storeServer, nil
|
|
}
|
|
|
|
dsn := strings.ReplaceAll(dsnPattern, "%TENANT%", tenant)
|
|
dsn = strings.ReplaceAll(dsn, "%APPID%", appID)
|
|
|
|
store, err := storeFactory(dsn)
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
|
|
storeServer = serverFactory(store)
|
|
|
|
cache.Add(key, storeServer)
|
|
|
|
return storeServer, nil
|
|
}
|
|
}
|
|
|
|
func createStoreHandler(getStoreServer getRPCServerFunc, dsnPattern string, appIDRequired bool, cacheSize int, cacheTTL time.Duration) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
|
|
tenant, ok := ctx.Value("tenant").(string)
|
|
if !ok || tenant == "" {
|
|
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
|
|
|
return
|
|
}
|
|
|
|
appID := r.URL.Query().Get("appId")
|
|
if appIDRequired && appID == "" {
|
|
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
|
|
|
|
return
|
|
}
|
|
|
|
server, err := getStoreServer(cacheSize, cacheTTL, tenant, appID, dsnPattern)
|
|
if err != nil {
|
|
logger.Error(r.Context(), "could not retrieve store server", logger.CapturedE(errors.WithStack(err)), logger.F("tenant", tenant))
|
|
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
|
|
|
return
|
|
}
|
|
|
|
server.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
|
|
func authenticate(privateKey jwk.Key, signingAlgorithm jwa.SignatureAlgorithm) func(http.Handler) http.Handler {
|
|
var (
|
|
createKeySet sync.Once
|
|
err error
|
|
getKeySet jwtutil.GetKeySetFunc
|
|
)
|
|
|
|
return func(next http.Handler) http.Handler {
|
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
|
ctx := r.Context()
|
|
|
|
createKeySet.Do(func() {
|
|
var keySet jwk.Set
|
|
|
|
keySet, err = jwtutil.NewKeySet()
|
|
if err != nil {
|
|
err = errors.WithStack(err)
|
|
return
|
|
}
|
|
|
|
err = jwtutil.AddKeyWithSigningAlgo(keySet, privateKey, signingAlgorithm)
|
|
if err != nil {
|
|
err = errors.WithStack(err)
|
|
return
|
|
}
|
|
|
|
getKeySet = func() (jwk.Set, error) {
|
|
return keySet, nil
|
|
}
|
|
})
|
|
if err != nil {
|
|
logger.Error(ctx, "could not create keyset accessor", logger.CapturedE(errors.WithStack(err)))
|
|
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
|
|
|
return
|
|
}
|
|
|
|
token, err := jwtutil.FindToken(r, getKeySet, jwtutil.WithFinders(
|
|
jwtutil.FindTokenFromQueryString("token"),
|
|
))
|
|
if err != nil {
|
|
logger.Error(ctx, "could not find jwt token", logger.CapturedE(errors.WithStack(err)))
|
|
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
|
|
|
return
|
|
}
|
|
|
|
tokenMap, err := token.AsMap(ctx)
|
|
if err != nil {
|
|
logger.Error(ctx, "could not transform token to map", logger.CapturedE(errors.WithStack(err)))
|
|
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
|
|
|
return
|
|
}
|
|
|
|
rawTenant, exists := tokenMap["tenant"]
|
|
if !exists {
|
|
logger.Warn(ctx, "could not find tenant claim", logger.F("token", token))
|
|
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
|
|
|
return
|
|
}
|
|
|
|
tenant, ok := rawTenant.(string)
|
|
if !ok {
|
|
logger.Warn(ctx, "unexpected tenant claim value", logger.F("token", token))
|
|
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
|
|
|
return
|
|
}
|
|
|
|
r = r.WithContext(context.WithValue(ctx, "tenant", tenant))
|
|
|
|
next.ServeHTTP(w, r)
|
|
})
|
|
}
|
|
}
|