feat(storage-server): jwt based authentication
All checks were successful
arcad/edge/pipeline/pr-master This commit looks good

This commit is contained in:
2023-09-28 23:41:01 -06:00
parent c63af872ea
commit d2472623f2
25 changed files with 646 additions and 235 deletions

View File

@ -0,0 +1,53 @@
package auth
import (
"fmt"
"forge.cadoles.com/arcad/edge/cmd/storage-server/command/flag"
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/pkg/errors"
"github.com/urfave/cli/v2"
)
func NewToken() *cli.Command {
return &cli.Command{
Name: "new-token",
Usage: "Generate new authentication token",
Flags: []cli.Flag{
&cli.StringFlag{
Name: "tenant",
},
flag.PrivateKey,
flag.PrivateKeySigningAlgorithm,
flag.PrivateKeyDefaultSize,
},
Action: func(ctx *cli.Context) error {
privateKeyFile := flag.GetPrivateKey(ctx)
signingAlgorithm := flag.GetSigningAlgorithm(ctx)
privateKeyDefaultSize := flag.GetPrivateKeyDefaultSize(ctx)
tenant := ctx.String("tenant")
privateKey, err := jwtutil.LoadOrGenerateKey(
privateKeyFile,
privateKeyDefaultSize,
)
if err != nil {
return errors.WithStack(err)
}
claims := map[string]any{
"tenant": tenant,
}
token, err := jwtutil.SignedToken(privateKey, jwa.SignatureAlgorithm(signingAlgorithm), claims)
if err != nil {
return errors.Wrap(err, "could not generate signed token")
}
fmt.Println(string(token))
return nil
},
}
}

View File

@ -6,8 +6,10 @@ import (
func Root() *cli.Command {
return &cli.Command{
Name: "auth",
Usage: "Auth related command",
Subcommands: []*cli.Command{},
Name: "auth",
Usage: "Auth related command",
Subcommands: []*cli.Command{
NewToken(),
},
}
}

View File

@ -0,0 +1,43 @@
package flag
import (
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/urfave/cli/v2"
)
const PrivateKeyFlagName = "private-key"
var PrivateKey = &cli.StringFlag{
Name: PrivateKeyFlagName,
EnvVars: []string{"STORAGE_SERVER_PRIVATE_KEY"},
Value: "storage-server.key",
TakesFile: true,
}
func GetPrivateKey(ctx *cli.Context) string {
return ctx.String(PrivateKeyFlagName)
}
const SigningAlgorithmFlagName = "signing-algorithm"
var PrivateKeySigningAlgorithm = &cli.StringFlag{
Name: SigningAlgorithmFlagName,
EnvVars: []string{"STORAGE_SERVER_PRIVATE_KEY_SIGNING_ALGORITHM"},
Value: jwa.RS256.String(),
}
func GetSigningAlgorithm(ctx *cli.Context) string {
return ctx.String(SigningAlgorithmFlagName)
}
const PrivateKeyDefaultSizeFlagName = "private-key-default-size"
var PrivateKeyDefaultSize = &cli.IntFlag{
Name: PrivateKeyDefaultSizeFlagName,
EnvVars: []string{"STORAGE_SERVER_PRIVATE_KEY_DEFAULT_SIZE"},
Value: 2048,
}
func GetPrivateKeyDefaultSize(ctx *cli.Context) int {
return ctx.Int(PrivateKeyDefaultSizeFlagName)
}

View File

@ -1,6 +1,7 @@
package command
import (
"context"
"fmt"
"net/http"
"strings"
@ -9,6 +10,8 @@ import (
"github.com/hashicorp/golang-lru/v2/expirable"
"github.com/keegancsmith/rpc"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"gitlab.com/wpetit/goweb/logger"
"github.com/go-chi/chi/v5"
@ -17,6 +20,8 @@ import (
"github.com/urfave/cli/v2"
// Register storage drivers
"forge.cadoles.com/arcad/edge/cmd/storage-server/command/flag"
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
"forge.cadoles.com/arcad/edge/pkg/storage"
"forge.cadoles.com/arcad/edge/pkg/storage/driver"
_ "forge.cadoles.com/arcad/edge/pkg/storage/driver/rpc"
@ -50,6 +55,9 @@ func Run() *cli.Command {
EnvVars: []string{"STORAGE_SERVER_SHARESTORE_DSN_PATTERN"},
Value: fmt.Sprintf("sqlite://data/%%TENANT%%/sharestore.sqlite?_pragma=foreign_keys(1)&_pragma=busy_timeout=%d", (60 * time.Second).Milliseconds()),
},
flag.PrivateKey,
flag.PrivateKeySigningAlgorithm,
flag.PrivateKeyDefaultSize,
&cli.DurationFlag{
Name: "cache-ttl",
EnvVars: []string{"STORAGE_SERVER_CACHE_TTL"},
@ -68,9 +76,25 @@ func Run() *cli.Command {
shareStoreDSNPattern := ctx.String("sharestore-dsn-pattern")
cacheSize := ctx.Int("cache-size")
cacheTTL := ctx.Duration("cache-ttl")
privateKeyFile := flag.GetPrivateKey(ctx)
signingAlgorithm := flag.GetSigningAlgorithm(ctx)
privateKeyDefaultSize := flag.GetPrivateKeyDefaultSize(ctx)
router := chi.NewRouter()
privateKey, err := jwtutil.LoadOrGenerateKey(
privateKeyFile,
privateKeyDefaultSize,
)
if err != nil {
return errors.WithStack(err)
}
publicKey, err := privateKey.PublicKey()
if err != nil {
return errors.WithStack(err)
}
getBlobStoreServer := createGetCachedStoreServer(
func(dsn string) (storage.BlobStore, error) {
return driver.NewBlobStore(dsn)
@ -100,10 +124,11 @@ func Run() *cli.Command {
router.Use(middleware.RealIP)
router.Use(middleware.Logger)
router.Use(authenticate(publicKey, jwa.SignatureAlgorithm(signingAlgorithm)))
router.Handle("/blobstore", createStoreHandler(getBlobStoreServer, blobStoreDSNPattern, cacheSize, cacheTTL))
router.Handle("/documentstore", createStoreHandler(getDocumentStoreServer, documentStoreDSNPattern, cacheSize, cacheTTL))
router.Handle("/sharestore", createStoreHandler(getShareStoreServer, shareStoreDSNPattern, cacheSize, cacheTTL))
router.Handle("/blobstore", createStoreHandler(getBlobStoreServer, blobStoreDSNPattern, true, cacheSize, cacheTTL))
router.Handle("/documentstore", createStoreHandler(getDocumentStoreServer, documentStoreDSNPattern, true, cacheSize, cacheTTL))
router.Handle("/sharestore", createStoreHandler(getShareStoreServer, shareStoreDSNPattern, false, cacheSize, cacheTTL))
if err := http.ListenAndServe(addr, router); err != nil {
return errors.WithStack(err)
@ -150,17 +175,19 @@ func createGetCachedStoreServer[T any](storeFactory func(dsn string) (T, error),
}
}
func createStoreHandler(getStoreServer getRPCServerFunc, dsnPattern string, cacheSize int, cacheTTL time.Duration) http.Handler {
func createStoreHandler(getStoreServer getRPCServerFunc, dsnPattern string, appIDRequired bool, cacheSize int, cacheTTL time.Duration) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
tenant := r.URL.Query().Get("tenant")
if tenant == "" {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
ctx := r.Context()
tenant, ok := ctx.Value("tenant").(string)
if !ok || tenant == "" {
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
appID := r.URL.Query().Get("appId")
if tenant == "" {
if appIDRequired && appID == "" {
http.Error(w, http.StatusText(http.StatusBadRequest), http.StatusBadRequest)
return
@ -177,3 +204,79 @@ func createStoreHandler(getStoreServer getRPCServerFunc, dsnPattern string, cach
server.ServeHTTP(w, r)
})
}
func authenticate(privateKey jwk.Key, signingAlgorithm jwa.SignatureAlgorithm) func(http.Handler) http.Handler {
var (
createKeySet sync.Once
err error
getKeySet jwtutil.GetKeySetFunc
)
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
createKeySet.Do(func() {
err = privateKey.Set(jwk.AlgorithmKey, signingAlgorithm)
if err != nil {
return
}
var keySet jwk.Set
keySet, err = jwtutil.NewKeySet(privateKey)
if err != nil {
return
}
getKeySet = func() (jwk.Set, error) {
return keySet, nil
}
})
if err != nil {
logger.Error(ctx, "could not create keyset accessor", logger.E(errors.WithStack(err)))
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
token, err := jwtutil.FindToken(r, getKeySet, jwtutil.WithFinders(
jwtutil.FindTokenFromQueryString("token"),
))
if err != nil {
logger.Error(ctx, "could not find jwt token", logger.E(errors.WithStack(err)))
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
tokenMap, err := token.AsMap(ctx)
if err != nil {
logger.Error(ctx, "could not transform token to map", logger.E(errors.WithStack(err)))
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
rawTenant, exists := tokenMap["tenant"]
if !exists {
logger.Warn(ctx, "could not find tenant claim", logger.F("token", token))
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
tenant, ok := rawTenant.(string)
if !ok {
logger.Warn(ctx, "unexpected tenant claim value", logger.F("token", token))
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
r = r.WithContext(context.WithValue(ctx, "tenant", tenant))
next.ServeHTTP(w, r)
})
}
}