feat(storage-server): jwt based authentication
arcad/edge/pipeline/pr-master This commit looks good Details

This commit is contained in:
wpetit 2023-09-28 23:41:01 -06:00
parent c63af872ea
commit 09da1c6ce9
17 changed files with 474 additions and 215 deletions

View File

@ -1,3 +1,4 @@
RUN_APP_ARGS="" RUN_APP_ARGS=""
#EDGE_DOCUMENTSTORE_DSN="rpc://localhost:3001/documentstore?tenant=local&appId=%APPID%" #EDGE_DOCUMENTSTORE_DSN="rpc://localhost:3001/documentstore?tenant=local&appId=%APPID%"
#EDGE_BLOBSTORE_DSN="rpc://localhost:3001/blobstore?tenant=local&appId=%APPID%" #EDGE_BLOBSTORE_DSN="rpc://localhost:3001/blobstore?tenant=local&appId=%APPID%"
#EDGE_SHARESTORE_DSN="rpc://localhost:3001/sharestore?tenant=local"

1
.gitignore vendored
View File

@ -10,3 +10,4 @@
/dist /dist
/.chglog /.chglog
/CHANGELOG.md /CHANGELOG.md
/storage-server.key

View File

@ -16,6 +16,7 @@ import (
"forge.cadoles.com/arcad/edge/pkg/bus" "forge.cadoles.com/arcad/edge/pkg/bus"
"forge.cadoles.com/arcad/edge/pkg/bus/memory" "forge.cadoles.com/arcad/edge/pkg/bus/memory"
appHTTP "forge.cadoles.com/arcad/edge/pkg/http" appHTTP "forge.cadoles.com/arcad/edge/pkg/http"
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
"forge.cadoles.com/arcad/edge/pkg/module" "forge.cadoles.com/arcad/edge/pkg/module"
appModule "forge.cadoles.com/arcad/edge/pkg/module/app" appModule "forge.cadoles.com/arcad/edge/pkg/module/app"
appModuleMemory "forge.cadoles.com/arcad/edge/pkg/module/app/memory" appModuleMemory "forge.cadoles.com/arcad/edge/pkg/module/app/memory"
@ -33,7 +34,6 @@ import (
"forge.cadoles.com/arcad/edge/pkg/bundle" "forge.cadoles.com/arcad/edge/pkg/bundle"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/go-chi/chi/v5/middleware" "github.com/go-chi/chi/v5/middleware"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwk"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
@ -50,6 +50,8 @@ import (
"forge.cadoles.com/arcad/edge/pkg/storage/share" "forge.cadoles.com/arcad/edge/pkg/storage/share"
) )
var dummySecret = []byte("not_so_secret")
func RunCommand() *cli.Command { func RunCommand() *cli.Command {
return &cli.Command{ return &cli.Command{
Name: "run", Name: "run",
@ -194,7 +196,7 @@ func runApp(ctx context.Context, path, address, documentStoreDSN, blobStoreDSN,
ctx = logger.With(ctx, logger.F("appID", manifest.ID)) ctx = logger.With(ctx, logger.F("appID", manifest.ID))
// Add auth handler // Add auth handler
key, err := dummyKey() key, err := jwtutil.NewSymmetricKey(dummySecret)
if err != nil { if err != nil {
return errors.WithStack(err) return errors.WithStack(err)
} }
@ -220,17 +222,17 @@ func runApp(ctx context.Context, path, address, documentStoreDSN, blobStoreDSN,
appModule.Mount(appRepository), appModule.Mount(appRepository),
authModule.Mount( authModule.Mount(
authHTTP.NewLocalHandler( authHTTP.NewLocalHandler(
jwa.HS256, key, key,
authHTTP.WithRoutePrefix("/auth"), authHTTP.WithRoutePrefix("/auth"),
authHTTP.WithAccounts(deps.Accounts...), authHTTP.WithAccounts(deps.Accounts...),
), ),
authModule.WithJWT(dummyKeySet), authModule.WithJWT(func() (jwk.Set, error) {
return jwtutil.NewSymmetricKeySet(dummySecret)
}),
), ),
), ),
appHTTP.WithHTTPMiddlewares( appHTTP.WithHTTPMiddlewares(
authModuleMiddleware.AnonymousUser( authModuleMiddleware.AnonymousUser(key),
jwa.HS256, key,
),
), ),
) )
if err := handler.Load(bundle); err != nil { if err := handler.Load(bundle); err != nil {
@ -276,7 +278,9 @@ func getServerModules(deps *moduleDeps) []app.ServerModuleFactory {
module.StoreModuleFactory(deps.DocumentStore), module.StoreModuleFactory(deps.DocumentStore),
blob.ModuleFactory(deps.Bus, deps.BlobStore), blob.ModuleFactory(deps.Bus, deps.BlobStore),
authModule.ModuleFactory( authModule.ModuleFactory(
authModule.WithJWT(dummyKeySet), authModule.WithJWT(func() (jwk.Set, error) {
return jwtutil.NewSymmetricKeySet(dummySecret)
}),
), ),
appModule.ModuleFactory(deps.AppRepository), appModule.ModuleFactory(deps.AppRepository),
fetch.ModuleFactory(deps.Bus), fetch.ModuleFactory(deps.Bus),
@ -284,36 +288,6 @@ func getServerModules(deps *moduleDeps) []app.ServerModuleFactory {
} }
} }
var dummySecret = []byte("not_so_secret")
func dummyKey() (jwk.Key, error) {
key, err := jwk.FromRaw(dummySecret)
if err != nil {
return nil, errors.WithStack(err)
}
return key, nil
}
func dummyKeySet() (jwk.Set, error) {
key, err := dummyKey()
if err != nil {
return nil, errors.WithStack(err)
}
if err := key.Set(jwk.AlgorithmKey, jwa.HS256); err != nil {
return nil, errors.WithStack(err)
}
set := jwk.NewSet()
if err := set.AddKey(key); err != nil {
return nil, errors.WithStack(err)
}
return set, nil
}
func ensureDir(path string) error { func ensureDir(path string) error {
if err := os.MkdirAll(filepath.Dir(path), os.ModePerm); err != nil { if err := os.MkdirAll(filepath.Dir(path), os.ModePerm); err != nil {
return errors.WithStack(err) return errors.WithStack(err)

View File

@ -1,6 +1,7 @@
package command package command
import ( import (
"context"
"fmt" "fmt"
"net/http" "net/http"
"strings" "strings"
@ -9,6 +10,7 @@ import (
"github.com/hashicorp/golang-lru/v2/expirable" "github.com/hashicorp/golang-lru/v2/expirable"
"github.com/keegancsmith/rpc" "github.com/keegancsmith/rpc"
"github.com/lestrrat-go/jwx/v2/jwk"
"gitlab.com/wpetit/goweb/logger" "gitlab.com/wpetit/goweb/logger"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
@ -17,6 +19,7 @@ import (
"github.com/urfave/cli/v2" "github.com/urfave/cli/v2"
// Register storage drivers // Register storage drivers
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
"forge.cadoles.com/arcad/edge/pkg/storage" "forge.cadoles.com/arcad/edge/pkg/storage"
"forge.cadoles.com/arcad/edge/pkg/storage/driver" "forge.cadoles.com/arcad/edge/pkg/storage/driver"
_ "forge.cadoles.com/arcad/edge/pkg/storage/driver/rpc" _ "forge.cadoles.com/arcad/edge/pkg/storage/driver/rpc"
@ -50,6 +53,12 @@ func Run() *cli.Command {
EnvVars: []string{"STORAGE_SERVER_SHARESTORE_DSN_PATTERN"}, EnvVars: []string{"STORAGE_SERVER_SHARESTORE_DSN_PATTERN"},
Value: fmt.Sprintf("sqlite://data/%%TENANT%%/sharestore.sqlite?_pragma=foreign_keys(1)&_pragma=busy_timeout=%d", (60 * time.Second).Milliseconds()), Value: fmt.Sprintf("sqlite://data/%%TENANT%%/sharestore.sqlite?_pragma=foreign_keys(1)&_pragma=busy_timeout=%d", (60 * time.Second).Milliseconds()),
}, },
&cli.StringFlag{
Name: "private-key",
EnvVars: []string{"STORAGE_SERVER_PRIVATE_KEY"},
Value: "storage-server.key",
TakesFile: true,
},
&cli.DurationFlag{ &cli.DurationFlag{
Name: "cache-ttl", Name: "cache-ttl",
EnvVars: []string{"STORAGE_SERVER_CACHE_TTL"}, EnvVars: []string{"STORAGE_SERVER_CACHE_TTL"},
@ -68,9 +77,20 @@ func Run() *cli.Command {
shareStoreDSNPattern := ctx.String("sharestore-dsn-pattern") shareStoreDSNPattern := ctx.String("sharestore-dsn-pattern")
cacheSize := ctx.Int("cache-size") cacheSize := ctx.Int("cache-size")
cacheTTL := ctx.Duration("cache-ttl") cacheTTL := ctx.Duration("cache-ttl")
privateKeyFile := ctx.String("private-key")
router := chi.NewRouter() router := chi.NewRouter()
rsaKey, err := jwtutil.LoadOrGenerateRSAKey(privateKeyFile, 2048)
if err != nil {
return errors.WithStack(err)
}
privateKey, err := jwtutil.FromRSA(rsaKey)
if err != nil {
return errors.WithStack(err)
}
getBlobStoreServer := createGetCachedStoreServer( getBlobStoreServer := createGetCachedStoreServer(
func(dsn string) (storage.BlobStore, error) { func(dsn string) (storage.BlobStore, error) {
return driver.NewBlobStore(dsn) return driver.NewBlobStore(dsn)
@ -100,6 +120,7 @@ func Run() *cli.Command {
router.Use(middleware.RealIP) router.Use(middleware.RealIP)
router.Use(middleware.Logger) router.Use(middleware.Logger)
router.Use(authenticate(privateKey))
router.Handle("/blobstore", createStoreHandler(getBlobStoreServer, blobStoreDSNPattern, cacheSize, cacheTTL)) router.Handle("/blobstore", createStoreHandler(getBlobStoreServer, blobStoreDSNPattern, cacheSize, cacheTTL))
router.Handle("/documentstore", createStoreHandler(getDocumentStoreServer, documentStoreDSNPattern, cacheSize, cacheTTL)) router.Handle("/documentstore", createStoreHandler(getDocumentStoreServer, documentStoreDSNPattern, cacheSize, cacheTTL))
@ -177,3 +198,76 @@ func createStoreHandler(getStoreServer getRPCServerFunc, dsnPattern string, cach
server.ServeHTTP(w, r) server.ServeHTTP(w, r)
}) })
} }
func authenticate(privateKey jwk.Key) func(http.Handler) http.Handler {
keySet, err := jwtutil.NewKeySet(privateKey)
getKeySet := func() (jwk.Set, error) {
if err != nil {
return nil, errors.WithStack(err)
}
return keySet, nil
}
return func(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
token, err := jwtutil.FindToken(r, getKeySet, jwtutil.WithFinders(
jwtutil.FindTokenFromQueryString("token"),
))
if err != nil {
logger.Error(ctx, "could not find jwt token", logger.E(errors.WithStack(err)))
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
tokenMap, err := token.AsMap(ctx)
if err != nil {
logger.Error(ctx, "could not transform token to map", logger.E(errors.WithStack(err)))
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
rawTenant, exists := tokenMap["tenant"]
if !exists {
logger.Warn(ctx, "could not find tenant claim", logger.F("token", token))
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
tenant, ok := rawTenant.(string)
if !ok {
logger.Warn(ctx, "unexpected tenant claim value", logger.F("token", token))
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
r = r.WithContext(context.WithValue(ctx, "tenant", tenant))
rawAppId, exists := tokenMap["appId"]
if !exists {
logger.Warn(ctx, "could not find appId claim", logger.F("token", token))
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
appId, ok := rawAppId.(string)
if !ok {
logger.Warn(ctx, "unexpected appId claim value", logger.F("token", token))
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
return
}
r = r.WithContext(context.WithValue(ctx, "appId", appId))
next.ServeHTTP(w, r)
})
}
}

View File

@ -1,5 +1,6 @@
package auth package jwtutil
import "errors" import "errors"
var ErrUnauthenticated = errors.New("unauthenticated") var ErrUnauthenticated = errors.New("unauthenticated")
var ErrNoKeySet = errors.New("no keyset")

123
pkg/jwtutil/jwt.go Normal file
View File

@ -0,0 +1,123 @@
package jwtutil
import (
"net/http"
"strings"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/pkg/errors"
)
type TokenFinderFunc func(r *http.Request) (string, error)
type FindTokenOptions struct {
Finders []TokenFinderFunc
}
type FindTokenOptionFunc func(*FindTokenOptions)
type GetKeySetFunc func() (jwk.Set, error)
func WithFinders(finders ...TokenFinderFunc) FindTokenOptionFunc {
return func(opts *FindTokenOptions) {
opts.Finders = finders
}
}
func NewFindTokenOptions(funcs ...FindTokenOptionFunc) *FindTokenOptions {
opts := &FindTokenOptions{
Finders: []TokenFinderFunc{
FindTokenFromAuthorizationHeader,
},
}
for _, fn := range funcs {
fn(opts)
}
return opts
}
func FindTokenFromAuthorizationHeader(r *http.Request) (string, error) {
authorization := r.Header.Get("Authorization")
// Retrieve token from Authorization header
rawToken := strings.TrimPrefix(authorization, "Bearer ")
return rawToken, nil
}
func FindTokenFromQueryString(name string) TokenFinderFunc {
return func(r *http.Request) (string, error) {
return r.URL.Query().Get(name), nil
}
}
func FindTokenFromCookie(cookieName string) TokenFinderFunc {
return func(r *http.Request) (string, error) {
cookie, err := r.Cookie(cookieName)
if err != nil && !errors.Is(err, http.ErrNoCookie) {
return "", errors.WithStack(err)
}
if cookie == nil {
return "", nil
}
return cookie.Value, nil
}
}
func FindRawToken(r *http.Request, funcs ...FindTokenOptionFunc) (string, error) {
opts := NewFindTokenOptions(funcs...)
var rawToken string
var err error
for _, find := range opts.Finders {
rawToken, err = find(r)
if err != nil {
return "", errors.WithStack(err)
}
if rawToken == "" {
continue
}
break
}
if rawToken == "" {
return "", errors.WithStack(ErrUnauthenticated)
}
return rawToken, nil
}
func FindToken(r *http.Request, getKeySet GetKeySetFunc, funcs ...FindTokenOptionFunc) (jwt.Token, error) {
rawToken, err := FindRawToken(r, funcs...)
if err != nil {
return nil, errors.WithStack(err)
}
keySet, err := getKeySet()
if err != nil {
return nil, errors.WithStack(err)
}
if keySet == nil {
return nil, errors.WithStack(ErrNoKeySet)
}
token, err := jwt.Parse([]byte(rawToken),
jwt.WithKeySet(keySet, jws.WithRequireKid(false)),
jwt.WithValidate(true),
)
if err != nil {
return nil, errors.WithStack(err)
}
return token, nil
}

52
pkg/jwtutil/key.go Normal file
View File

@ -0,0 +1,52 @@
package jwtutil
import (
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/pkg/errors"
)
func NewKeySet(keys ...jwk.Key) (jwk.Set, error) {
set := jwk.NewSet()
for _, k := range keys {
if err := set.AddKey(k); err != nil {
return nil, errors.WithStack(err)
}
}
return set, nil
}
func NewSymmetricKey(secret []byte) (jwk.Key, error) {
key, err := jwk.FromRaw(secret)
if err != nil {
return nil, errors.WithStack(err)
}
if err := key.Set(jwk.AlgorithmKey, jwa.HS256); err != nil {
return nil, errors.WithStack(err)
}
return key, nil
}
func NewSymmetricKeySet(secrets ...[]byte) (jwk.Set, error) {
keys := make([]jwk.Key, len(secrets))
for idx, sec := range secrets {
key, err := NewSymmetricKey(sec)
if err != nil {
return nil, errors.WithStack(err)
}
keys[idx] = key
}
keySet, err := NewKeySet(keys...)
if err != nil {
return nil, errors.WithStack(err)
}
return keySet, nil
}

90
pkg/jwtutil/rsa.go Normal file
View File

@ -0,0 +1,90 @@
package jwtutil
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"os"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/pkg/errors"
)
func FromRSA(privateKey *rsa.PrivateKey) (jwk.Key, error) {
key, err := jwk.FromRaw(privateKey)
if err != nil {
return nil, errors.WithStack(err)
}
if err := key.Set(jwk.AlgorithmKey, jwa.RS256); err != nil {
return nil, errors.WithStack(err)
}
return key, nil
}
func LoadOrGenerateRSAKey(path string, size int) (*rsa.PrivateKey, error) {
privateKey, err := LoadRSAKey(path)
if err != nil {
if !errors.Is(err, os.ErrNotExist) {
return nil, errors.WithStack(err)
}
privateKey, err = GenerateRSAKey(size)
if err != nil {
return nil, errors.WithStack(err)
}
if err := SaveRSAKey(path, privateKey); err != nil {
return nil, errors.WithStack(err)
}
}
return privateKey, nil
}
func LoadRSAKey(path string) (*rsa.PrivateKey, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, errors.WithStack(err)
}
block, _ := pem.Decode(data)
if block == nil {
return nil, errors.New("failed to parse pem block")
}
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
if err != nil {
return nil, errors.WithStack(err)
}
return privateKey, nil
}
func SaveRSAKey(path string, privateKey *rsa.PrivateKey) error {
data := x509.MarshalPKCS1PrivateKey(privateKey)
pem := pem.EncodeToMemory(
&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: data,
},
)
if err := os.WriteFile(path, pem, os.FileMode(0600)); err != nil {
return errors.WithStack(err)
}
return nil
}
func GenerateRSAKey(size int) (*rsa.PrivateKey, error) {
privateKey, err := rsa.GenerateKey(rand.Reader, size)
if err != nil {
return nil, errors.WithStack(err)
}
return privateKey, nil
}

View File

@ -1,15 +1,14 @@
package jwt package jwtutil
import ( import (
"time" "time"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jwt" "github.com/lestrrat-go/jwx/v2/jwt"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
func GenerateSignedToken(algo jwa.KeyAlgorithm, key jwk.Key, claims map[string]any) ([]byte, error) { func GenerateSignedToken(key jwk.Key, claims map[string]any) ([]byte, error) {
token := jwt.New() token := jwt.New()
if err := token.Set(jwt.NotBeforeKey, time.Now()); err != nil { if err := token.Set(jwt.NotBeforeKey, time.Now()); err != nil {
@ -22,11 +21,11 @@ func GenerateSignedToken(algo jwa.KeyAlgorithm, key jwk.Key, claims map[string]a
} }
} }
if err := token.Set(jwk.AlgorithmKey, jwa.HS256); err != nil { if err := token.Set(jwk.AlgorithmKey, key.Algorithm()); err != nil {
return nil, errors.WithStack(err) return nil, errors.WithStack(err)
} }
rawToken, err := jwt.Sign(token, jwt.WithKey(algo, key)) rawToken, err := jwt.Sign(token, jwt.WithKey(key.Algorithm(), key))
if err != nil { if err != nil {
return nil, errors.WithStack(err) return nil, errors.WithStack(err)
} }

View File

@ -7,11 +7,10 @@ import (
_ "embed" _ "embed"
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
"forge.cadoles.com/arcad/edge/pkg/module/auth" "forge.cadoles.com/arcad/edge/pkg/module/auth"
"forge.cadoles.com/arcad/edge/pkg/module/auth/http/passwd" "forge.cadoles.com/arcad/edge/pkg/module/auth/http/passwd"
"forge.cadoles.com/arcad/edge/pkg/module/auth/jwt"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwk"
"github.com/pkg/errors" "github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger" "gitlab.com/wpetit/goweb/logger"
@ -32,7 +31,6 @@ func init() {
type LocalHandler struct { type LocalHandler struct {
router chi.Router router chi.Router
algo jwa.KeyAlgorithm
key jwk.Key key jwk.Key
getCookieDomain GetCookieDomainFunc getCookieDomain GetCookieDomainFunc
cookieDuration time.Duration cookieDuration time.Duration
@ -113,7 +111,7 @@ func (h *LocalHandler) handleForm(w http.ResponseWriter, r *http.Request) {
account.Claims[auth.ClaimIssuer] = "local" account.Claims[auth.ClaimIssuer] = "local"
token, err := jwt.GenerateSignedToken(h.algo, h.key, account.Claims) token, err := jwtutil.GenerateSignedToken(h.key, account.Claims)
if err != nil { if err != nil {
logger.Error(ctx, "could not generate signed token", logger.E(errors.WithStack(err))) logger.Error(ctx, "could not generate signed token", logger.E(errors.WithStack(err)))
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
@ -182,14 +180,13 @@ func (h *LocalHandler) authenticate(username, password string) (*LocalAccount, e
return &account, nil return &account, nil
} }
func NewLocalHandler(algo jwa.KeyAlgorithm, key jwk.Key, funcs ...LocalHandlerOptionFunc) *LocalHandler { func NewLocalHandler(key jwk.Key, funcs ...LocalHandlerOptionFunc) *LocalHandler {
opts := defaultLocalHandlerOptions() opts := defaultLocalHandlerOptions()
for _, fn := range funcs { for _, fn := range funcs {
fn(opts) fn(opts)
} }
handler := &LocalHandler{ handler := &LocalHandler{
algo: algo,
key: key, key: key,
accounts: toAccountsMap(opts.Accounts), accounts: toAccountsMap(opts.Accounts),
getCookieDomain: opts.GetCookieDomain, getCookieDomain: opts.GetCookieDomain,

View File

@ -1,118 +0,0 @@
package auth
import (
"context"
"net/http"
"strings"
"github.com/lestrrat-go/jwx/v2/jwk"
"github.com/lestrrat-go/jwx/v2/jws"
"github.com/lestrrat-go/jwx/v2/jwt"
"github.com/pkg/errors"
)
const (
CookieName string = "edge-auth"
)
type GetKeySetFunc func() (jwk.Set, error)
func WithJWT(getKeySet GetKeySetFunc) OptionFunc {
return func(o *Option) {
o.GetClaims = func(ctx context.Context, r *http.Request, names ...string) ([]string, error) {
claim, err := getClaims[string](r, getKeySet, names...)
if err != nil {
return nil, errors.WithStack(err)
}
return claim, nil
}
}
}
func FindRawToken(r *http.Request) (string, error) {
authorization := r.Header.Get("Authorization")
// Retrieve token from Authorization header
rawToken := strings.TrimPrefix(authorization, "Bearer ")
// Retrieve token from ?edge-auth=<value>
if rawToken == "" {
rawToken = r.URL.Query().Get(CookieName)
}
if rawToken == "" {
cookie, err := r.Cookie(CookieName)
if err != nil && !errors.Is(err, http.ErrNoCookie) {
return "", errors.WithStack(err)
}
if cookie != nil {
rawToken = cookie.Value
}
}
if rawToken == "" {
return "", errors.WithStack(ErrUnauthenticated)
}
return rawToken, nil
}
func FindToken(r *http.Request, getKeySet GetKeySetFunc) (jwt.Token, error) {
rawToken, err := FindRawToken(r)
if err != nil {
return nil, errors.WithStack(err)
}
keySet, err := getKeySet()
if err != nil {
return nil, errors.WithStack(err)
}
if keySet == nil {
return nil, errors.New("no keyset")
}
token, err := jwt.Parse([]byte(rawToken),
jwt.WithKeySet(keySet, jws.WithRequireKid(false)),
jwt.WithValidate(true),
)
if err != nil {
return nil, errors.WithStack(err)
}
return token, nil
}
func getClaims[T any](r *http.Request, getKeySet GetKeySetFunc, names ...string) ([]T, error) {
token, err := FindToken(r, getKeySet)
if err != nil {
return nil, errors.WithStack(err)
}
ctx := r.Context()
mapClaims, err := token.AsMap(ctx)
if err != nil {
return nil, errors.WithStack(err)
}
claims := make([]T, len(names))
for idx, n := range names {
rawClaim, exists := mapClaims[n]
if !exists {
continue
}
claim, ok := rawClaim.(T)
if !ok {
return nil, errors.Errorf("unexpected claim '%s' to be of type '%T', got '%T'", n, new(T), rawClaim)
}
claims[idx] = claim
}
return claims, nil
}

View File

@ -7,10 +7,9 @@ import (
"net/http" "net/http"
"time" "time"
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
"forge.cadoles.com/arcad/edge/pkg/module/auth" "forge.cadoles.com/arcad/edge/pkg/module/auth"
"forge.cadoles.com/arcad/edge/pkg/module/auth/jwt"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwk"
"github.com/pkg/errors" "github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger" "gitlab.com/wpetit/goweb/logger"
@ -18,7 +17,7 @@ import (
const AnonIssuer = "anon" const AnonIssuer = "anon"
func AnonymousUser(algo jwa.KeyAlgorithm, key jwk.Key, funcs ...AnonymousUserOptionFunc) func(next http.Handler) http.Handler { func AnonymousUser(key jwk.Key, funcs ...AnonymousUserOptionFunc) func(next http.Handler) http.Handler {
opts := defaultAnonymousUserOptions() opts := defaultAnonymousUserOptions()
for _, fn := range funcs { for _, fn := range funcs {
fn(opts) fn(opts)
@ -26,7 +25,11 @@ func AnonymousUser(algo jwa.KeyAlgorithm, key jwk.Key, funcs ...AnonymousUserOpt
return func(next http.Handler) http.Handler { return func(next http.Handler) http.Handler {
handler := func(w http.ResponseWriter, r *http.Request) { handler := func(w http.ResponseWriter, r *http.Request) {
rawToken, err := auth.FindRawToken(r) rawToken, err := jwtutil.FindRawToken(r, jwtutil.WithFinders(
jwtutil.FindTokenFromAuthorizationHeader,
jwtutil.FindTokenFromQueryString(auth.CookieName),
jwtutil.FindTokenFromCookie(auth.CookieName),
))
// If request already has a raw token, we do nothing // If request already has a raw token, we do nothing
if rawToken != "" && err == nil { if rawToken != "" && err == nil {
@ -62,7 +65,7 @@ func AnonymousUser(algo jwa.KeyAlgorithm, key jwk.Key, funcs ...AnonymousUserOpt
auth.ClaimEdgeTenant: opts.Tenant, auth.ClaimEdgeTenant: opts.Tenant,
} }
token, err := jwt.GenerateSignedToken(algo, key, claims) token, err := jwtutil.GenerateSignedToken(key, claims)
if err != nil { if err != nil {
logger.Error(ctx, "could not generate signed token", logger.E(errors.WithStack(err))) logger.Error(ctx, "could not generate signed token", logger.E(errors.WithStack(err)))
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)

View File

@ -5,12 +5,17 @@ import (
"forge.cadoles.com/arcad/edge/pkg/app" "forge.cadoles.com/arcad/edge/pkg/app"
edgeHTTP "forge.cadoles.com/arcad/edge/pkg/http" edgeHTTP "forge.cadoles.com/arcad/edge/pkg/http"
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
"forge.cadoles.com/arcad/edge/pkg/module/util" "forge.cadoles.com/arcad/edge/pkg/module/util"
"github.com/dop251/goja" "github.com/dop251/goja"
"github.com/pkg/errors" "github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger" "gitlab.com/wpetit/goweb/logger"
) )
const (
CookieName string = "edge-auth"
)
const ( const (
ClaimSubject = "sub" ClaimSubject = "sub"
ClaimIssuer = "iss" ClaimIssuer = "iss"
@ -22,7 +27,7 @@ const (
type Module struct { type Module struct {
server *app.Server server *app.Server
getClaims GetClaimsFunc getClaimFn GetClaimFunc
} }
func (m *Module) Name() string { func (m *Module) Name() string {
@ -68,9 +73,9 @@ func (m *Module) getClaim(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
panic(rt.ToValue(errors.New("could not find http request in context"))) panic(rt.ToValue(errors.New("could not find http request in context")))
} }
claim, err := m.getClaims(ctx, req, claimName) claim, err := m.getClaimFn(ctx, req, claimName)
if err != nil { if err != nil {
if errors.Is(err, ErrUnauthenticated) { if errors.Is(err, jwtutil.ErrUnauthenticated) {
return nil return nil
} }
@ -78,11 +83,7 @@ func (m *Module) getClaim(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
return nil return nil
} }
if len(claim) == 0 || claim[0] == "" { return rt.ToValue(claim)
return nil
}
return rt.ToValue(claim[0])
} }
func ModuleFactory(funcs ...OptionFunc) app.ServerModuleFactory { func ModuleFactory(funcs ...OptionFunc) app.ServerModuleFactory {
@ -94,7 +95,7 @@ func ModuleFactory(funcs ...OptionFunc) app.ServerModuleFactory {
return func(server *app.Server) app.ServerModule { return func(server *app.Server) app.ServerModule {
return &Module{ return &Module{
server: server, server: server,
getClaims: opt.GetClaims, getClaimFn: opt.GetClaim,
} }
} }
} }

View File

@ -10,6 +10,7 @@ import (
"cdr.dev/slog" "cdr.dev/slog"
"forge.cadoles.com/arcad/edge/pkg/app" "forge.cadoles.com/arcad/edge/pkg/app"
edgeHTTP "forge.cadoles.com/arcad/edge/pkg/http" edgeHTTP "forge.cadoles.com/arcad/edge/pkg/http"
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
"forge.cadoles.com/arcad/edge/pkg/module" "forge.cadoles.com/arcad/edge/pkg/module"
"github.com/lestrrat-go/jwx/v2/jwa" "github.com/lestrrat-go/jwx/v2/jwa"
"github.com/lestrrat-go/jwx/v2/jwk" "github.com/lestrrat-go/jwx/v2/jwk"
@ -130,7 +131,7 @@ func getDummyKey() jwk.Key {
return key return key
} }
func getDummyKeySet(key jwk.Key) GetKeySetFunc { func getDummyKeySet(key jwk.Key) jwtutil.GetKeySetFunc {
return func() (jwk.Set, error) { return func() (jwk.Set, error) {
set := jwk.NewSet() set := jwk.NewSet()

View File

@ -3,6 +3,7 @@ package auth
import ( import (
"net/http" "net/http"
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
"github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5"
"github.com/pkg/errors" "github.com/pkg/errors"
"gitlab.com/wpetit/goweb/api" "gitlab.com/wpetit/goweb/api"
@ -12,16 +13,19 @@ import (
type MountFunc func(r chi.Router) type MountFunc func(r chi.Router)
type Handler struct { type Handler struct {
getClaims GetClaimsFunc getClaim GetClaimFunc
profileClaims []string profileClaims []string
} }
func (h *Handler) serveProfile(w http.ResponseWriter, r *http.Request) { func (h *Handler) serveProfile(w http.ResponseWriter, r *http.Request) {
ctx := r.Context() ctx := r.Context()
claims, err := h.getClaims(ctx, r, h.profileClaims...) profile := make(map[string]any)
for _, name := range h.profileClaims {
value, err := h.getClaim(ctx, r, name)
if err != nil { if err != nil {
if errors.Is(err, ErrUnauthenticated) { if errors.Is(err, jwtutil.ErrUnauthenticated) {
api.ErrorResponse( api.ErrorResponse(
w, http.StatusUnauthorized, w, http.StatusUnauthorized,
api.ErrCodeUnauthorized, api.ErrCodeUnauthorized,
@ -41,10 +45,7 @@ func (h *Handler) serveProfile(w http.ResponseWriter, r *http.Request) {
return return
} }
profile := make(map[string]any) profile[name] = value
for idx, cl := range h.profileClaims {
profile[cl] = claims[idx]
} }
api.DataResponse(w, http.StatusOK, struct { api.DataResponse(w, http.StatusOK, struct {
@ -62,7 +63,7 @@ func Mount(authHandler http.Handler, funcs ...OptionFunc) MountFunc {
handler := &Handler{ handler := &Handler{
profileClaims: opt.ProfileClaims, profileClaims: opt.ProfileClaims,
getClaims: opt.GetClaims, getClaim: opt.GetClaim,
} }
return func(r chi.Router) { return func(r chi.Router) {

View File

@ -2,15 +2,17 @@ package auth
import ( import (
"context" "context"
"fmt"
"net/http" "net/http"
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
"github.com/pkg/errors" "github.com/pkg/errors"
) )
type GetClaimsFunc func(ctx context.Context, r *http.Request, claims ...string) ([]string, error) type GetClaimFunc func(ctx context.Context, r *http.Request, name string) (string, error)
type Option struct { type Option struct {
GetClaims GetClaimsFunc GetClaim GetClaimFunc
ProfileClaims []string ProfileClaims []string
} }
@ -18,7 +20,7 @@ type OptionFunc func(*Option)
func defaultOptions() *Option { func defaultOptions() *Option {
return &Option{ return &Option{
GetClaims: dummyGetClaims, GetClaim: dummyGetClaim,
ProfileClaims: []string{ ProfileClaims: []string{
ClaimSubject, ClaimSubject,
ClaimIssuer, ClaimIssuer,
@ -30,13 +32,13 @@ func defaultOptions() *Option {
} }
} }
func dummyGetClaims(ctx context.Context, r *http.Request, claims ...string) ([]string, error) { func dummyGetClaim(ctx context.Context, r *http.Request, name string) (string, error) {
return nil, errors.Errorf("dummy getclaim func cannot retrieve claims '%s'", claims) return "", errors.Errorf("dummy getclaim func cannot retrieve claim '%s'", name)
} }
func WithGetClaims(fn GetClaimsFunc) OptionFunc { func WithGetClaims(fn GetClaimFunc) OptionFunc {
return func(o *Option) { return func(o *Option) {
o.GetClaims = fn o.GetClaim = fn
} }
} }
@ -45,3 +47,34 @@ func WithProfileClaims(claims ...string) OptionFunc {
o.ProfileClaims = claims o.ProfileClaims = claims
} }
} }
func WithJWT(getKeySet jwtutil.GetKeySetFunc) OptionFunc {
funcs := []jwtutil.FindTokenOptionFunc{
jwtutil.WithFinders(
jwtutil.FindTokenFromAuthorizationHeader,
jwtutil.FindTokenFromQueryString(CookieName),
jwtutil.FindTokenFromCookie(CookieName),
),
}
return func(o *Option) {
o.GetClaim = func(ctx context.Context, r *http.Request, name string) (string, error) {
token, err := jwtutil.FindToken(r, getKeySet, funcs...)
if err != nil {
return "", errors.WithStack(err)
}
tokenMap, err := token.AsMap(ctx)
if err != nil {
return "", errors.WithStack(err)
}
value, exists := tokenMap[name]
if !exists {
return "", nil
}
return fmt.Sprintf("%v", value), nil
}
}
}

View File

@ -6,11 +6,13 @@ import (
"forge.cadoles.com/arcad/edge/pkg/storage" "forge.cadoles.com/arcad/edge/pkg/storage"
"forge.cadoles.com/arcad/edge/pkg/storage/driver" "forge.cadoles.com/arcad/edge/pkg/storage/driver"
"forge.cadoles.com/arcad/edge/pkg/storage/driver/rpc/client" "forge.cadoles.com/arcad/edge/pkg/storage/driver/rpc/client"
"forge.cadoles.com/arcad/edge/pkg/storage/share"
) )
func init() { func init() {
driver.RegisterDocumentStoreFactory("rpc", documentStoreFactory) driver.RegisterDocumentStoreFactory("rpc", documentStoreFactory)
driver.RegisterBlobStoreFactory("rpc", blobStoreFactory) driver.RegisterBlobStoreFactory("rpc", blobStoreFactory)
driver.RegisterShareStoreFactory("rpc", shareStoreFactory)
} }
func documentStoreFactory(url *url.URL) (storage.DocumentStore, error) { func documentStoreFactory(url *url.URL) (storage.DocumentStore, error) {
@ -20,3 +22,7 @@ func documentStoreFactory(url *url.URL) (storage.DocumentStore, error) {
func blobStoreFactory(url *url.URL) (storage.BlobStore, error) { func blobStoreFactory(url *url.URL) (storage.BlobStore, error) {
return client.NewBlobStore(url), nil return client.NewBlobStore(url), nil
} }
func shareStoreFactory(url *url.URL) (share.Store, error) {
return client.NewShareStore(url), nil
}