edge/cmd/storage-server/command/run.go
William Petit 09da1c6ce9
All checks were successful
arcad/edge/pipeline/pr-master This commit looks good
feat(storage-server): jwt based authentication
2023-09-28 23:41:01 -06:00

274 lines
8.0 KiB
Go

package command
import (
"context"
"fmt"
"net/http"
"strings"
"sync"
"time"
"github.com/hashicorp/golang-lru/v2/expirable"
"github.com/keegancsmith/rpc"
"github.com/lestrrat-go/jwx/v2/jwk"
"gitlab.com/wpetit/goweb/logger"
"github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware"
"github.com/pkg/errors"
"github.com/urfave/cli/v2"
// Register storage drivers
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
"forge.cadoles.com/arcad/edge/pkg/storage"
"forge.cadoles.com/arcad/edge/pkg/storage/driver"
_ "forge.cadoles.com/arcad/edge/pkg/storage/driver/rpc"
"forge.cadoles.com/arcad/edge/pkg/storage/driver/rpc/server"
_ "forge.cadoles.com/arcad/edge/pkg/storage/driver/sqlite"
"forge.cadoles.com/arcad/edge/pkg/storage/share"
)
// Run builds the "run" CLI command, which starts the storage server: an HTTP
// server exposing the blob, document and share stores as RPC endpoints, all
// of them placed behind the JWT authentication middleware.
func Run() *cli.Command {
	// Default value for the sqlite busy_timeout pragma, in milliseconds.
	busyTimeout := (60 * time.Second).Milliseconds()

	flags := []cli.Flag{
		&cli.StringFlag{
			Name:    "address",
			Aliases: []string{"addr"},
			Value:   ":3001",
		},
		&cli.StringFlag{
			Name:    "blobstore-dsn-pattern",
			EnvVars: []string{"STORAGE_SERVER_BLOBSTORE_DSN_PATTERN"},
			Value:   fmt.Sprintf("sqlite://data/%%TENANT%%/%%APPID%%/blobstore.sqlite?_pragma=foreign_keys(1)&_pragma=busy_timeout=%d", busyTimeout),
		},
		&cli.StringFlag{
			Name:    "documentstore-dsn-pattern",
			EnvVars: []string{"STORAGE_SERVER_DOCUMENTSTORE_DSN_PATTERN"},
			Value:   fmt.Sprintf("sqlite://data/%%TENANT%%/%%APPID%%/documentstore.sqlite?_pragma=foreign_keys(1)&_pragma=busy_timeout=%d", busyTimeout),
		},
		&cli.StringFlag{
			Name:    "sharestore-dsn-pattern",
			EnvVars: []string{"STORAGE_SERVER_SHARESTORE_DSN_PATTERN"},
			Value:   fmt.Sprintf("sqlite://data/%%TENANT%%/sharestore.sqlite?_pragma=foreign_keys(1)&_pragma=busy_timeout=%d", busyTimeout),
		},
		&cli.StringFlag{
			Name:      "private-key",
			EnvVars:   []string{"STORAGE_SERVER_PRIVATE_KEY"},
			Value:     "storage-server.key",
			TakesFile: true,
		},
		&cli.DurationFlag{
			Name:    "cache-ttl",
			EnvVars: []string{"STORAGE_SERVER_CACHE_TTL"},
			Value:   time.Hour,
		},
		&cli.IntFlag{
			Name:    "cache-size",
			EnvVars: []string{"STORAGE_SERVER_CACHE_SIZE"},
			Value:   32,
		},
	}

	action := func(ctx *cli.Context) error {
		// The RSA key is converted to a JWK and handed to the authentication
		// middleware below.
		rsaKey, err := jwtutil.LoadOrGenerateRSAKey(ctx.String("private-key"), 2048)
		if err != nil {
			return errors.WithStack(err)
		}

		privateKey, err := jwtutil.FromRSA(rsaKey)
		if err != nil {
			return errors.WithStack(err)
		}

		cacheSize := ctx.Int("cache-size")
		cacheTTL := ctx.Duration("cache-ttl")

		// One cached RPC server provider per kind of store.
		blobStoreServers := createGetCachedStoreServer(
			func(dsn string) (storage.BlobStore, error) {
				return driver.NewBlobStore(dsn)
			},
			func(blobStore storage.BlobStore) *rpc.Server {
				return server.NewBlobStoreServer(blobStore)
			},
		)

		shareStoreServers := createGetCachedStoreServer(
			func(dsn string) (share.Store, error) {
				return driver.NewShareStore(dsn)
			},
			func(shareStore share.Store) *rpc.Server {
				return server.NewShareStoreServer(shareStore)
			},
		)

		documentStoreServers := createGetCachedStoreServer(
			func(dsn string) (storage.DocumentStore, error) {
				return driver.NewDocumentStore(dsn)
			},
			func(documentStore storage.DocumentStore) *rpc.Server {
				return server.NewDocumentStoreServer(documentStore)
			},
		)

		router := chi.NewRouter()

		// Middlewares are registered before any route.
		router.Use(middleware.RealIP)
		router.Use(middleware.Logger)
		router.Use(authenticate(privateKey))

		routes := []struct {
			path       string
			getServer  getRPCServerFunc
			dsnPattern string
		}{
			{"/blobstore", blobStoreServers, ctx.String("blobstore-dsn-pattern")},
			{"/documentstore", documentStoreServers, ctx.String("documentstore-dsn-pattern")},
			{"/sharestore", shareStoreServers, ctx.String("sharestore-dsn-pattern")},
		}

		for _, route := range routes {
			router.Handle(route.path, createStoreHandler(route.getServer, route.dsnPattern, cacheSize, cacheTTL))
		}

		// errors.WithStack passes a nil error through unchanged.
		return errors.WithStack(http.ListenAndServe(ctx.String("address"), router))
	}

	return &cli.Command{
		Name:   "run",
		Usage:  "Run server",
		Flags:  flags,
		Action: action,
	}
}
// getRPCServerFunc returns the RPC server serving the store that belongs to
// the given tenant and appID. dsnPattern is the DSN template used to open the
// store, while cacheSize and cacheTTL carry the cache settings to apply.
type getRPCServerFunc func(cacheSize int, cacheTTL time.Duration, tenant, appID, dsnPattern string) (*rpc.Server, error)
// createGetCachedStoreServer returns a getRPCServerFunc that builds RPC
// servers on demand and keeps them in an expiring LRU cache keyed by
// "<tenant>:<appID>".
//
// storeFactory opens the store for a resolved DSN and serverFactory wraps that
// store into an RPC server. The cache is created lazily on the first call, so
// only the cacheSize and cacheTTL of that first call are taken into account.
// No eviction callback is registered on the cache.
func createGetCachedStoreServer[T any](storeFactory func(dsn string) (T, error), serverFactory func(store T) *rpc.Server) getRPCServerFunc {
	var (
		once    sync.Once
		servers *expirable.LRU[string, *rpc.Server]
	)

	return func(cacheSize int, cacheTTL time.Duration, tenant, appID, dsnPattern string) (*rpc.Server, error) {
		once.Do(func() {
			servers = expirable.NewLRU[string, *rpc.Server](cacheSize, nil, cacheTTL)
		})

		cacheKey := tenant + ":" + appID

		// A miss yields the zero value, so a nil check is enough here.
		if cached, _ := servers.Get(cacheKey); cached != nil {
			return cached, nil
		}

		// Placeholders are substituted sequentially: tenant first, then app ID.
		dsn := strings.ReplaceAll(
			strings.ReplaceAll(dsnPattern, "%TENANT%", tenant),
			"%APPID%", appID,
		)

		store, err := storeFactory(dsn)
		if err != nil {
			return nil, errors.WithStack(err)
		}

		created := serverFactory(store)
		servers.Add(cacheKey, created)

		return created, nil
	}
}
// createStoreHandler returns an HTTP handler that forwards each request to the
// RPC server of the store selected by the "tenant" and "appId" query
// parameters.
//
// A request missing either parameter is answered with 400 Bad Request. When
// the store server cannot be obtained through getStoreServer, the error is
// logged and the request is answered with 500 Internal Server Error.
//
// NOTE(review): tenant and appId are read from the query string, not from the
// request context — confirm whether they should be checked against the
// authenticated token claims.
func createStoreHandler(getStoreServer getRPCServerFunc, dsnPattern string, cacheSize int, cacheTTL time.Duration) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		query := r.URL.Query()

		tenant := query.Get("tenant")
		if tenant == "" {
			http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
			return
		}

		// The original guard tested tenant a second time, which let requests
		// without an app identifier through.
		appID := query.Get("appId")
		if appID == "" {
			http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
			return
		}

		// Named storeServer so the imported "server" package is not shadowed.
		storeServer, err := getStoreServer(cacheSize, cacheTTL, tenant, appID, dsnPattern)
		if err != nil {
			logger.Error(
				r.Context(), "could not retrieve store server",
				logger.E(errors.WithStack(err)),
				logger.F("tenant", tenant),
				logger.F("appId", appID),
			)
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			return
		}

		storeServer.ServeHTTP(w, r)
	})
}
// authenticate returns a middleware that only lets a request through when it
// carries a token, looked up in the "token" query string parameter, that
// jwtutil.FindToken accepts for the key set built from privateKey.
//
// The token must expose string "tenant" and "appId" claims. Both values are
// stored in the request context, under the "tenant" and "appId" keys, before
// the next handler is invoked. Any failure is logged and answered with
// 401 Unauthorized.
func authenticate(privateKey jwk.Key) func(http.Handler) http.Handler {
	// The key set is built once. A construction failure is not returned here;
	// it is surfaced through getKeySet on every request instead.
	keySet, keySetErr := jwtutil.NewKeySet(privateKey)

	getKeySet := func() (jwk.Set, error) {
		if keySetErr != nil {
			return nil, errors.WithStack(keySetErr)
		}

		return keySet, nil
	}

	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			ctx := r.Context()

			unauthorized := func() {
				http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
			}

			token, err := jwtutil.FindToken(r, getKeySet, jwtutil.WithFinders(
				jwtutil.FindTokenFromQueryString("token"),
			))
			if err != nil {
				logger.Error(ctx, "could not find jwt token", logger.E(errors.WithStack(err)))
				unauthorized()
				return
			}

			tokenMap, err := token.AsMap(ctx)
			if err != nil {
				logger.Error(ctx, "could not transform token to map", logger.E(errors.WithStack(err)))
				unauthorized()
				return
			}

			rawTenant, exists := tokenMap["tenant"]
			if !exists {
				logger.Warn(ctx, "could not find tenant claim", logger.F("token", token))
				unauthorized()
				return
			}

			tenant, ok := rawTenant.(string)
			if !ok {
				logger.Warn(ctx, "unexpected tenant claim value", logger.F("token", token))
				unauthorized()
				return
			}

			rawAppID, exists := tokenMap["appId"]
			if !exists {
				logger.Warn(ctx, "could not find appId claim", logger.F("token", token))
				unauthorized()
				return
			}

			appID, ok := rawAppID.(string)
			if !ok {
				logger.Warn(ctx, "unexpected appId claim value", logger.F("token", token))
				unauthorized()
				return
			}

			// Both values are chained onto the same context. Deriving the
			// second one from the original request context would drop the
			// tenant value. The plain string keys are kept as-is so existing
			// readers of these values keep working.
			ctx = context.WithValue(ctx, "tenant", tenant)
			ctx = context.WithValue(ctx, "appId", appID)

			next.ServeHTTP(w, r.WithContext(ctx))
		})
	}
}