feat(storage-server): jwt based authentication
arcad/edge/pipeline/pr-master This commit looks good
Details
arcad/edge/pipeline/pr-master This commit looks good
Details
This commit is contained in:
parent
c63af872ea
commit
09da1c6ce9
|
@ -1,3 +1,4 @@
|
||||||
RUN_APP_ARGS=""
|
RUN_APP_ARGS=""
|
||||||
#EDGE_DOCUMENTSTORE_DSN="rpc://localhost:3001/documentstore?tenant=local&appId=%APPID%"
|
#EDGE_DOCUMENTSTORE_DSN="rpc://localhost:3001/documentstore?tenant=local&appId=%APPID%"
|
||||||
#EDGE_BLOBSTORE_DSN="rpc://localhost:3001/blobstore?tenant=local&appId=%APPID%"
|
#EDGE_BLOBSTORE_DSN="rpc://localhost:3001/blobstore?tenant=local&appId=%APPID%"
|
||||||
|
#EDGE_SHARESTORE_DSN="rpc://localhost:3001/sharestore?tenant=local"
|
|
@ -10,3 +10,4 @@
|
||||||
/dist
|
/dist
|
||||||
/.chglog
|
/.chglog
|
||||||
/CHANGELOG.md
|
/CHANGELOG.md
|
||||||
|
/storage-server.key
|
|
@ -16,6 +16,7 @@ import (
|
||||||
"forge.cadoles.com/arcad/edge/pkg/bus"
|
"forge.cadoles.com/arcad/edge/pkg/bus"
|
||||||
"forge.cadoles.com/arcad/edge/pkg/bus/memory"
|
"forge.cadoles.com/arcad/edge/pkg/bus/memory"
|
||||||
appHTTP "forge.cadoles.com/arcad/edge/pkg/http"
|
appHTTP "forge.cadoles.com/arcad/edge/pkg/http"
|
||||||
|
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
|
||||||
"forge.cadoles.com/arcad/edge/pkg/module"
|
"forge.cadoles.com/arcad/edge/pkg/module"
|
||||||
appModule "forge.cadoles.com/arcad/edge/pkg/module/app"
|
appModule "forge.cadoles.com/arcad/edge/pkg/module/app"
|
||||||
appModuleMemory "forge.cadoles.com/arcad/edge/pkg/module/app/memory"
|
appModuleMemory "forge.cadoles.com/arcad/edge/pkg/module/app/memory"
|
||||||
|
@ -33,7 +34,6 @@ import (
|
||||||
"forge.cadoles.com/arcad/edge/pkg/bundle"
|
"forge.cadoles.com/arcad/edge/pkg/bundle"
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/go-chi/chi/v5/middleware"
|
"github.com/go-chi/chi/v5/middleware"
|
||||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
|
||||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
|
@ -50,6 +50,8 @@ import (
|
||||||
"forge.cadoles.com/arcad/edge/pkg/storage/share"
|
"forge.cadoles.com/arcad/edge/pkg/storage/share"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var dummySecret = []byte("not_so_secret")
|
||||||
|
|
||||||
func RunCommand() *cli.Command {
|
func RunCommand() *cli.Command {
|
||||||
return &cli.Command{
|
return &cli.Command{
|
||||||
Name: "run",
|
Name: "run",
|
||||||
|
@ -194,7 +196,7 @@ func runApp(ctx context.Context, path, address, documentStoreDSN, blobStoreDSN,
|
||||||
ctx = logger.With(ctx, logger.F("appID", manifest.ID))
|
ctx = logger.With(ctx, logger.F("appID", manifest.ID))
|
||||||
|
|
||||||
// Add auth handler
|
// Add auth handler
|
||||||
key, err := dummyKey()
|
key, err := jwtutil.NewSymmetricKey(dummySecret)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
@ -220,17 +222,17 @@ func runApp(ctx context.Context, path, address, documentStoreDSN, blobStoreDSN,
|
||||||
appModule.Mount(appRepository),
|
appModule.Mount(appRepository),
|
||||||
authModule.Mount(
|
authModule.Mount(
|
||||||
authHTTP.NewLocalHandler(
|
authHTTP.NewLocalHandler(
|
||||||
jwa.HS256, key,
|
key,
|
||||||
authHTTP.WithRoutePrefix("/auth"),
|
authHTTP.WithRoutePrefix("/auth"),
|
||||||
authHTTP.WithAccounts(deps.Accounts...),
|
authHTTP.WithAccounts(deps.Accounts...),
|
||||||
),
|
),
|
||||||
authModule.WithJWT(dummyKeySet),
|
authModule.WithJWT(func() (jwk.Set, error) {
|
||||||
|
return jwtutil.NewSymmetricKeySet(dummySecret)
|
||||||
|
}),
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
appHTTP.WithHTTPMiddlewares(
|
appHTTP.WithHTTPMiddlewares(
|
||||||
authModuleMiddleware.AnonymousUser(
|
authModuleMiddleware.AnonymousUser(key),
|
||||||
jwa.HS256, key,
|
|
||||||
),
|
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
if err := handler.Load(bundle); err != nil {
|
if err := handler.Load(bundle); err != nil {
|
||||||
|
@ -276,7 +278,9 @@ func getServerModules(deps *moduleDeps) []app.ServerModuleFactory {
|
||||||
module.StoreModuleFactory(deps.DocumentStore),
|
module.StoreModuleFactory(deps.DocumentStore),
|
||||||
blob.ModuleFactory(deps.Bus, deps.BlobStore),
|
blob.ModuleFactory(deps.Bus, deps.BlobStore),
|
||||||
authModule.ModuleFactory(
|
authModule.ModuleFactory(
|
||||||
authModule.WithJWT(dummyKeySet),
|
authModule.WithJWT(func() (jwk.Set, error) {
|
||||||
|
return jwtutil.NewSymmetricKeySet(dummySecret)
|
||||||
|
}),
|
||||||
),
|
),
|
||||||
appModule.ModuleFactory(deps.AppRepository),
|
appModule.ModuleFactory(deps.AppRepository),
|
||||||
fetch.ModuleFactory(deps.Bus),
|
fetch.ModuleFactory(deps.Bus),
|
||||||
|
@ -284,36 +288,6 @@ func getServerModules(deps *moduleDeps) []app.ServerModuleFactory {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
var dummySecret = []byte("not_so_secret")
|
|
||||||
|
|
||||||
func dummyKey() (jwk.Key, error) {
|
|
||||||
key, err := jwk.FromRaw(dummySecret)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return key, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func dummyKeySet() (jwk.Set, error) {
|
|
||||||
key, err := dummyKey()
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := key.Set(jwk.AlgorithmKey, jwa.HS256); err != nil {
|
|
||||||
return nil, errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
set := jwk.NewSet()
|
|
||||||
|
|
||||||
if err := set.AddKey(key); err != nil {
|
|
||||||
return nil, errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return set, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func ensureDir(path string) error {
|
func ensureDir(path string) error {
|
||||||
if err := os.MkdirAll(filepath.Dir(path), os.ModePerm); err != nil {
|
if err := os.MkdirAll(filepath.Dir(path), os.ModePerm); err != nil {
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
|
|
|
@ -1,6 +1,7 @@
|
||||||
package command
|
package command
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -9,6 +10,7 @@ import (
|
||||||
|
|
||||||
"github.com/hashicorp/golang-lru/v2/expirable"
|
"github.com/hashicorp/golang-lru/v2/expirable"
|
||||||
"github.com/keegancsmith/rpc"
|
"github.com/keegancsmith/rpc"
|
||||||
|
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||||
"gitlab.com/wpetit/goweb/logger"
|
"gitlab.com/wpetit/goweb/logger"
|
||||||
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
|
@ -17,6 +19,7 @@ import (
|
||||||
"github.com/urfave/cli/v2"
|
"github.com/urfave/cli/v2"
|
||||||
|
|
||||||
// Register storage drivers
|
// Register storage drivers
|
||||||
|
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
|
||||||
"forge.cadoles.com/arcad/edge/pkg/storage"
|
"forge.cadoles.com/arcad/edge/pkg/storage"
|
||||||
"forge.cadoles.com/arcad/edge/pkg/storage/driver"
|
"forge.cadoles.com/arcad/edge/pkg/storage/driver"
|
||||||
_ "forge.cadoles.com/arcad/edge/pkg/storage/driver/rpc"
|
_ "forge.cadoles.com/arcad/edge/pkg/storage/driver/rpc"
|
||||||
|
@ -50,6 +53,12 @@ func Run() *cli.Command {
|
||||||
EnvVars: []string{"STORAGE_SERVER_SHARESTORE_DSN_PATTERN"},
|
EnvVars: []string{"STORAGE_SERVER_SHARESTORE_DSN_PATTERN"},
|
||||||
Value: fmt.Sprintf("sqlite://data/%%TENANT%%/sharestore.sqlite?_pragma=foreign_keys(1)&_pragma=busy_timeout=%d", (60 * time.Second).Milliseconds()),
|
Value: fmt.Sprintf("sqlite://data/%%TENANT%%/sharestore.sqlite?_pragma=foreign_keys(1)&_pragma=busy_timeout=%d", (60 * time.Second).Milliseconds()),
|
||||||
},
|
},
|
||||||
|
&cli.StringFlag{
|
||||||
|
Name: "private-key",
|
||||||
|
EnvVars: []string{"STORAGE_SERVER_PRIVATE_KEY"},
|
||||||
|
Value: "storage-server.key",
|
||||||
|
TakesFile: true,
|
||||||
|
},
|
||||||
&cli.DurationFlag{
|
&cli.DurationFlag{
|
||||||
Name: "cache-ttl",
|
Name: "cache-ttl",
|
||||||
EnvVars: []string{"STORAGE_SERVER_CACHE_TTL"},
|
EnvVars: []string{"STORAGE_SERVER_CACHE_TTL"},
|
||||||
|
@ -68,9 +77,20 @@ func Run() *cli.Command {
|
||||||
shareStoreDSNPattern := ctx.String("sharestore-dsn-pattern")
|
shareStoreDSNPattern := ctx.String("sharestore-dsn-pattern")
|
||||||
cacheSize := ctx.Int("cache-size")
|
cacheSize := ctx.Int("cache-size")
|
||||||
cacheTTL := ctx.Duration("cache-ttl")
|
cacheTTL := ctx.Duration("cache-ttl")
|
||||||
|
privateKeyFile := ctx.String("private-key")
|
||||||
|
|
||||||
router := chi.NewRouter()
|
router := chi.NewRouter()
|
||||||
|
|
||||||
|
rsaKey, err := jwtutil.LoadOrGenerateRSAKey(privateKeyFile, 2048)
|
||||||
|
if err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKey, err := jwtutil.FromRSA(rsaKey)
|
||||||
|
if err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
getBlobStoreServer := createGetCachedStoreServer(
|
getBlobStoreServer := createGetCachedStoreServer(
|
||||||
func(dsn string) (storage.BlobStore, error) {
|
func(dsn string) (storage.BlobStore, error) {
|
||||||
return driver.NewBlobStore(dsn)
|
return driver.NewBlobStore(dsn)
|
||||||
|
@ -100,6 +120,7 @@ func Run() *cli.Command {
|
||||||
|
|
||||||
router.Use(middleware.RealIP)
|
router.Use(middleware.RealIP)
|
||||||
router.Use(middleware.Logger)
|
router.Use(middleware.Logger)
|
||||||
|
router.Use(authenticate(privateKey))
|
||||||
|
|
||||||
router.Handle("/blobstore", createStoreHandler(getBlobStoreServer, blobStoreDSNPattern, cacheSize, cacheTTL))
|
router.Handle("/blobstore", createStoreHandler(getBlobStoreServer, blobStoreDSNPattern, cacheSize, cacheTTL))
|
||||||
router.Handle("/documentstore", createStoreHandler(getDocumentStoreServer, documentStoreDSNPattern, cacheSize, cacheTTL))
|
router.Handle("/documentstore", createStoreHandler(getDocumentStoreServer, documentStoreDSNPattern, cacheSize, cacheTTL))
|
||||||
|
@ -177,3 +198,76 @@ func createStoreHandler(getStoreServer getRPCServerFunc, dsnPattern string, cach
|
||||||
server.ServeHTTP(w, r)
|
server.ServeHTTP(w, r)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func authenticate(privateKey jwk.Key) func(http.Handler) http.Handler {
|
||||||
|
keySet, err := jwtutil.NewKeySet(privateKey)
|
||||||
|
getKeySet := func() (jwk.Set, error) {
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return keySet, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(next http.Handler) http.Handler {
|
||||||
|
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
ctx := r.Context()
|
||||||
|
|
||||||
|
token, err := jwtutil.FindToken(r, getKeySet, jwtutil.WithFinders(
|
||||||
|
jwtutil.FindTokenFromQueryString("token"),
|
||||||
|
))
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(ctx, "could not find jwt token", logger.E(errors.WithStack(err)))
|
||||||
|
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenMap, err := token.AsMap(ctx)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(ctx, "could not transform token to map", logger.E(errors.WithStack(err)))
|
||||||
|
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
rawTenant, exists := tokenMap["tenant"]
|
||||||
|
if !exists {
|
||||||
|
logger.Warn(ctx, "could not find tenant claim", logger.F("token", token))
|
||||||
|
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
tenant, ok := rawTenant.(string)
|
||||||
|
if !ok {
|
||||||
|
logger.Warn(ctx, "unexpected tenant claim value", logger.F("token", token))
|
||||||
|
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r = r.WithContext(context.WithValue(ctx, "tenant", tenant))
|
||||||
|
|
||||||
|
rawAppId, exists := tokenMap["appId"]
|
||||||
|
if !exists {
|
||||||
|
logger.Warn(ctx, "could not find appId claim", logger.F("token", token))
|
||||||
|
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
appId, ok := rawAppId.(string)
|
||||||
|
if !ok {
|
||||||
|
logger.Warn(ctx, "unexpected appId claim value", logger.F("token", token))
|
||||||
|
http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
r = r.WithContext(context.WithValue(ctx, "appId", appId))
|
||||||
|
|
||||||
|
next.ServeHTTP(w, r)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -1,5 +1,6 @@
|
||||||
package auth
|
package jwtutil
|
||||||
|
|
||||||
import "errors"
|
import "errors"
|
||||||
|
|
||||||
var ErrUnauthenticated = errors.New("unauthenticated")
|
var ErrUnauthenticated = errors.New("unauthenticated")
|
||||||
|
var ErrNoKeySet = errors.New("no keyset")
|
|
@ -0,0 +1,123 @@
|
||||||
|
package jwtutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||||
|
"github.com/lestrrat-go/jwx/v2/jws"
|
||||||
|
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
type TokenFinderFunc func(r *http.Request) (string, error)
|
||||||
|
|
||||||
|
type FindTokenOptions struct {
|
||||||
|
Finders []TokenFinderFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
type FindTokenOptionFunc func(*FindTokenOptions)
|
||||||
|
|
||||||
|
type GetKeySetFunc func() (jwk.Set, error)
|
||||||
|
|
||||||
|
func WithFinders(finders ...TokenFinderFunc) FindTokenOptionFunc {
|
||||||
|
return func(opts *FindTokenOptions) {
|
||||||
|
opts.Finders = finders
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewFindTokenOptions(funcs ...FindTokenOptionFunc) *FindTokenOptions {
|
||||||
|
opts := &FindTokenOptions{
|
||||||
|
Finders: []TokenFinderFunc{
|
||||||
|
FindTokenFromAuthorizationHeader,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, fn := range funcs {
|
||||||
|
fn(opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
return opts
|
||||||
|
}
|
||||||
|
|
||||||
|
func FindTokenFromAuthorizationHeader(r *http.Request) (string, error) {
|
||||||
|
authorization := r.Header.Get("Authorization")
|
||||||
|
|
||||||
|
// Retrieve token from Authorization header
|
||||||
|
rawToken := strings.TrimPrefix(authorization, "Bearer ")
|
||||||
|
|
||||||
|
return rawToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func FindTokenFromQueryString(name string) TokenFinderFunc {
|
||||||
|
return func(r *http.Request) (string, error) {
|
||||||
|
return r.URL.Query().Get(name), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func FindTokenFromCookie(cookieName string) TokenFinderFunc {
|
||||||
|
return func(r *http.Request) (string, error) {
|
||||||
|
cookie, err := r.Cookie(cookieName)
|
||||||
|
if err != nil && !errors.Is(err, http.ErrNoCookie) {
|
||||||
|
return "", errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if cookie == nil {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return cookie.Value, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func FindRawToken(r *http.Request, funcs ...FindTokenOptionFunc) (string, error) {
|
||||||
|
opts := NewFindTokenOptions(funcs...)
|
||||||
|
|
||||||
|
var rawToken string
|
||||||
|
var err error
|
||||||
|
|
||||||
|
for _, find := range opts.Finders {
|
||||||
|
rawToken, err = find(r)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rawToken == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
break
|
||||||
|
}
|
||||||
|
|
||||||
|
if rawToken == "" {
|
||||||
|
return "", errors.WithStack(ErrUnauthenticated)
|
||||||
|
}
|
||||||
|
|
||||||
|
return rawToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func FindToken(r *http.Request, getKeySet GetKeySetFunc, funcs ...FindTokenOptionFunc) (jwt.Token, error) {
|
||||||
|
rawToken, err := FindRawToken(r, funcs...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
keySet, err := getKeySet()
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if keySet == nil {
|
||||||
|
return nil, errors.WithStack(ErrNoKeySet)
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := jwt.Parse([]byte(rawToken),
|
||||||
|
jwt.WithKeySet(keySet, jws.WithRequireKid(false)),
|
||||||
|
jwt.WithValidate(true),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return token, nil
|
||||||
|
}
|
|
@ -0,0 +1,52 @@
|
||||||
|
package jwtutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||||
|
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func NewKeySet(keys ...jwk.Key) (jwk.Set, error) {
|
||||||
|
set := jwk.NewSet()
|
||||||
|
|
||||||
|
for _, k := range keys {
|
||||||
|
if err := set.AddKey(k); err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return set, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSymmetricKey(secret []byte) (jwk.Key, error) {
|
||||||
|
key, err := jwk.FromRaw(secret)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := key.Set(jwk.AlgorithmKey, jwa.HS256); err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewSymmetricKeySet(secrets ...[]byte) (jwk.Set, error) {
|
||||||
|
keys := make([]jwk.Key, len(secrets))
|
||||||
|
|
||||||
|
for idx, sec := range secrets {
|
||||||
|
key, err := NewSymmetricKey(sec)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
keys[idx] = key
|
||||||
|
}
|
||||||
|
|
||||||
|
keySet, err := NewKeySet(keys...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return keySet, nil
|
||||||
|
}
|
|
@ -0,0 +1,90 @@
|
||||||
|
package jwtutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"encoding/pem"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||||
|
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
)
|
||||||
|
|
||||||
|
func FromRSA(privateKey *rsa.PrivateKey) (jwk.Key, error) {
|
||||||
|
key, err := jwk.FromRaw(privateKey)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := key.Set(jwk.AlgorithmKey, jwa.RS256); err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadOrGenerateRSAKey(path string, size int) (*rsa.PrivateKey, error) {
|
||||||
|
privateKey, err := LoadRSAKey(path)
|
||||||
|
if err != nil {
|
||||||
|
if !errors.Is(err, os.ErrNotExist) {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKey, err = GenerateRSAKey(size)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := SaveRSAKey(path, privateKey); err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return privateKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func LoadRSAKey(path string) (*rsa.PrivateKey, error) {
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
block, _ := pem.Decode(data)
|
||||||
|
if block == nil {
|
||||||
|
return nil, errors.New("failed to parse pem block")
|
||||||
|
}
|
||||||
|
|
||||||
|
privateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return privateKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func SaveRSAKey(path string, privateKey *rsa.PrivateKey) error {
|
||||||
|
data := x509.MarshalPKCS1PrivateKey(privateKey)
|
||||||
|
pem := pem.EncodeToMemory(
|
||||||
|
&pem.Block{
|
||||||
|
Type: "RSA PRIVATE KEY",
|
||||||
|
Bytes: data,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
if err := os.WriteFile(path, pem, os.FileMode(0600)); err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func GenerateRSAKey(size int) (*rsa.PrivateKey, error) {
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, size)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return privateKey, nil
|
||||||
|
}
|
|
@ -1,15 +1,14 @@
|
||||||
package jwt
|
package jwtutil
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
|
||||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
func GenerateSignedToken(algo jwa.KeyAlgorithm, key jwk.Key, claims map[string]any) ([]byte, error) {
|
func GenerateSignedToken(key jwk.Key, claims map[string]any) ([]byte, error) {
|
||||||
token := jwt.New()
|
token := jwt.New()
|
||||||
|
|
||||||
if err := token.Set(jwt.NotBeforeKey, time.Now()); err != nil {
|
if err := token.Set(jwt.NotBeforeKey, time.Now()); err != nil {
|
||||||
|
@ -22,11 +21,11 @@ func GenerateSignedToken(algo jwa.KeyAlgorithm, key jwk.Key, claims map[string]a
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := token.Set(jwk.AlgorithmKey, jwa.HS256); err != nil {
|
if err := token.Set(jwk.AlgorithmKey, key.Algorithm()); err != nil {
|
||||||
return nil, errors.WithStack(err)
|
return nil, errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rawToken, err := jwt.Sign(token, jwt.WithKey(algo, key))
|
rawToken, err := jwt.Sign(token, jwt.WithKey(key.Algorithm(), key))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.WithStack(err)
|
return nil, errors.WithStack(err)
|
||||||
}
|
}
|
|
@ -7,11 +7,10 @@ import (
|
||||||
|
|
||||||
_ "embed"
|
_ "embed"
|
||||||
|
|
||||||
|
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
|
||||||
"forge.cadoles.com/arcad/edge/pkg/module/auth"
|
"forge.cadoles.com/arcad/edge/pkg/module/auth"
|
||||||
"forge.cadoles.com/arcad/edge/pkg/module/auth/http/passwd"
|
"forge.cadoles.com/arcad/edge/pkg/module/auth/http/passwd"
|
||||||
"forge.cadoles.com/arcad/edge/pkg/module/auth/jwt"
|
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
|
||||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"gitlab.com/wpetit/goweb/logger"
|
"gitlab.com/wpetit/goweb/logger"
|
||||||
|
@ -32,7 +31,6 @@ func init() {
|
||||||
|
|
||||||
type LocalHandler struct {
|
type LocalHandler struct {
|
||||||
router chi.Router
|
router chi.Router
|
||||||
algo jwa.KeyAlgorithm
|
|
||||||
key jwk.Key
|
key jwk.Key
|
||||||
getCookieDomain GetCookieDomainFunc
|
getCookieDomain GetCookieDomainFunc
|
||||||
cookieDuration time.Duration
|
cookieDuration time.Duration
|
||||||
|
@ -113,7 +111,7 @@ func (h *LocalHandler) handleForm(w http.ResponseWriter, r *http.Request) {
|
||||||
|
|
||||||
account.Claims[auth.ClaimIssuer] = "local"
|
account.Claims[auth.ClaimIssuer] = "local"
|
||||||
|
|
||||||
token, err := jwt.GenerateSignedToken(h.algo, h.key, account.Claims)
|
token, err := jwtutil.GenerateSignedToken(h.key, account.Claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(ctx, "could not generate signed token", logger.E(errors.WithStack(err)))
|
logger.Error(ctx, "could not generate signed token", logger.E(errors.WithStack(err)))
|
||||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||||
|
@ -182,14 +180,13 @@ func (h *LocalHandler) authenticate(username, password string) (*LocalAccount, e
|
||||||
return &account, nil
|
return &account, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewLocalHandler(algo jwa.KeyAlgorithm, key jwk.Key, funcs ...LocalHandlerOptionFunc) *LocalHandler {
|
func NewLocalHandler(key jwk.Key, funcs ...LocalHandlerOptionFunc) *LocalHandler {
|
||||||
opts := defaultLocalHandlerOptions()
|
opts := defaultLocalHandlerOptions()
|
||||||
for _, fn := range funcs {
|
for _, fn := range funcs {
|
||||||
fn(opts)
|
fn(opts)
|
||||||
}
|
}
|
||||||
|
|
||||||
handler := &LocalHandler{
|
handler := &LocalHandler{
|
||||||
algo: algo,
|
|
||||||
key: key,
|
key: key,
|
||||||
accounts: toAccountsMap(opts.Accounts),
|
accounts: toAccountsMap(opts.Accounts),
|
||||||
getCookieDomain: opts.GetCookieDomain,
|
getCookieDomain: opts.GetCookieDomain,
|
||||||
|
|
|
@ -1,118 +0,0 @@
|
||||||
package auth
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"net/http"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
|
||||||
"github.com/lestrrat-go/jwx/v2/jws"
|
|
||||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
CookieName string = "edge-auth"
|
|
||||||
)
|
|
||||||
|
|
||||||
type GetKeySetFunc func() (jwk.Set, error)
|
|
||||||
|
|
||||||
func WithJWT(getKeySet GetKeySetFunc) OptionFunc {
|
|
||||||
return func(o *Option) {
|
|
||||||
o.GetClaims = func(ctx context.Context, r *http.Request, names ...string) ([]string, error) {
|
|
||||||
claim, err := getClaims[string](r, getKeySet, names...)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return claim, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func FindRawToken(r *http.Request) (string, error) {
|
|
||||||
authorization := r.Header.Get("Authorization")
|
|
||||||
|
|
||||||
// Retrieve token from Authorization header
|
|
||||||
rawToken := strings.TrimPrefix(authorization, "Bearer ")
|
|
||||||
|
|
||||||
// Retrieve token from ?edge-auth=<value>
|
|
||||||
if rawToken == "" {
|
|
||||||
rawToken = r.URL.Query().Get(CookieName)
|
|
||||||
}
|
|
||||||
|
|
||||||
if rawToken == "" {
|
|
||||||
cookie, err := r.Cookie(CookieName)
|
|
||||||
if err != nil && !errors.Is(err, http.ErrNoCookie) {
|
|
||||||
return "", errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if cookie != nil {
|
|
||||||
rawToken = cookie.Value
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if rawToken == "" {
|
|
||||||
return "", errors.WithStack(ErrUnauthenticated)
|
|
||||||
}
|
|
||||||
|
|
||||||
return rawToken, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func FindToken(r *http.Request, getKeySet GetKeySetFunc) (jwt.Token, error) {
|
|
||||||
rawToken, err := FindRawToken(r)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
keySet, err := getKeySet()
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if keySet == nil {
|
|
||||||
return nil, errors.New("no keyset")
|
|
||||||
}
|
|
||||||
|
|
||||||
token, err := jwt.Parse([]byte(rawToken),
|
|
||||||
jwt.WithKeySet(keySet, jws.WithRequireKid(false)),
|
|
||||||
jwt.WithValidate(true),
|
|
||||||
)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return token, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getClaims[T any](r *http.Request, getKeySet GetKeySetFunc, names ...string) ([]T, error) {
|
|
||||||
token, err := FindToken(r, getKeySet)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := r.Context()
|
|
||||||
|
|
||||||
mapClaims, err := token.AsMap(ctx)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
claims := make([]T, len(names))
|
|
||||||
|
|
||||||
for idx, n := range names {
|
|
||||||
rawClaim, exists := mapClaims[n]
|
|
||||||
if !exists {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
claim, ok := rawClaim.(T)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.Errorf("unexpected claim '%s' to be of type '%T', got '%T'", n, new(T), rawClaim)
|
|
||||||
}
|
|
||||||
|
|
||||||
claims[idx] = claim
|
|
||||||
}
|
|
||||||
|
|
||||||
return claims, nil
|
|
||||||
}
|
|
|
@ -7,10 +7,9 @@ import (
|
||||||
"net/http"
|
"net/http"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
|
||||||
"forge.cadoles.com/arcad/edge/pkg/module/auth"
|
"forge.cadoles.com/arcad/edge/pkg/module/auth"
|
||||||
"forge.cadoles.com/arcad/edge/pkg/module/auth/jwt"
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
|
||||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"gitlab.com/wpetit/goweb/logger"
|
"gitlab.com/wpetit/goweb/logger"
|
||||||
|
@ -18,7 +17,7 @@ import (
|
||||||
|
|
||||||
const AnonIssuer = "anon"
|
const AnonIssuer = "anon"
|
||||||
|
|
||||||
func AnonymousUser(algo jwa.KeyAlgorithm, key jwk.Key, funcs ...AnonymousUserOptionFunc) func(next http.Handler) http.Handler {
|
func AnonymousUser(key jwk.Key, funcs ...AnonymousUserOptionFunc) func(next http.Handler) http.Handler {
|
||||||
opts := defaultAnonymousUserOptions()
|
opts := defaultAnonymousUserOptions()
|
||||||
for _, fn := range funcs {
|
for _, fn := range funcs {
|
||||||
fn(opts)
|
fn(opts)
|
||||||
|
@ -26,7 +25,11 @@ func AnonymousUser(algo jwa.KeyAlgorithm, key jwk.Key, funcs ...AnonymousUserOpt
|
||||||
|
|
||||||
return func(next http.Handler) http.Handler {
|
return func(next http.Handler) http.Handler {
|
||||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||||
rawToken, err := auth.FindRawToken(r)
|
rawToken, err := jwtutil.FindRawToken(r, jwtutil.WithFinders(
|
||||||
|
jwtutil.FindTokenFromAuthorizationHeader,
|
||||||
|
jwtutil.FindTokenFromQueryString(auth.CookieName),
|
||||||
|
jwtutil.FindTokenFromCookie(auth.CookieName),
|
||||||
|
))
|
||||||
|
|
||||||
// If request already has a raw token, we do nothing
|
// If request already has a raw token, we do nothing
|
||||||
if rawToken != "" && err == nil {
|
if rawToken != "" && err == nil {
|
||||||
|
@ -62,7 +65,7 @@ func AnonymousUser(algo jwa.KeyAlgorithm, key jwk.Key, funcs ...AnonymousUserOpt
|
||||||
auth.ClaimEdgeTenant: opts.Tenant,
|
auth.ClaimEdgeTenant: opts.Tenant,
|
||||||
}
|
}
|
||||||
|
|
||||||
token, err := jwt.GenerateSignedToken(algo, key, claims)
|
token, err := jwtutil.GenerateSignedToken(key, claims)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(ctx, "could not generate signed token", logger.E(errors.WithStack(err)))
|
logger.Error(ctx, "could not generate signed token", logger.E(errors.WithStack(err)))
|
||||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||||
|
|
|
@ -5,12 +5,17 @@ import (
|
||||||
|
|
||||||
"forge.cadoles.com/arcad/edge/pkg/app"
|
"forge.cadoles.com/arcad/edge/pkg/app"
|
||||||
edgeHTTP "forge.cadoles.com/arcad/edge/pkg/http"
|
edgeHTTP "forge.cadoles.com/arcad/edge/pkg/http"
|
||||||
|
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
|
||||||
"forge.cadoles.com/arcad/edge/pkg/module/util"
|
"forge.cadoles.com/arcad/edge/pkg/module/util"
|
||||||
"github.com/dop251/goja"
|
"github.com/dop251/goja"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"gitlab.com/wpetit/goweb/logger"
|
"gitlab.com/wpetit/goweb/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
CookieName string = "edge-auth"
|
||||||
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ClaimSubject = "sub"
|
ClaimSubject = "sub"
|
||||||
ClaimIssuer = "iss"
|
ClaimIssuer = "iss"
|
||||||
|
@ -22,7 +27,7 @@ const (
|
||||||
|
|
||||||
type Module struct {
|
type Module struct {
|
||||||
server *app.Server
|
server *app.Server
|
||||||
getClaims GetClaimsFunc
|
getClaimFn GetClaimFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Module) Name() string {
|
func (m *Module) Name() string {
|
||||||
|
@ -68,9 +73,9 @@ func (m *Module) getClaim(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
|
||||||
panic(rt.ToValue(errors.New("could not find http request in context")))
|
panic(rt.ToValue(errors.New("could not find http request in context")))
|
||||||
}
|
}
|
||||||
|
|
||||||
claim, err := m.getClaims(ctx, req, claimName)
|
claim, err := m.getClaimFn(ctx, req, claimName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrUnauthenticated) {
|
if errors.Is(err, jwtutil.ErrUnauthenticated) {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -78,11 +83,7 @@ func (m *Module) getClaim(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(claim) == 0 || claim[0] == "" {
|
return rt.ToValue(claim)
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
return rt.ToValue(claim[0])
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func ModuleFactory(funcs ...OptionFunc) app.ServerModuleFactory {
|
func ModuleFactory(funcs ...OptionFunc) app.ServerModuleFactory {
|
||||||
|
@ -94,7 +95,7 @@ func ModuleFactory(funcs ...OptionFunc) app.ServerModuleFactory {
|
||||||
return func(server *app.Server) app.ServerModule {
|
return func(server *app.Server) app.ServerModule {
|
||||||
return &Module{
|
return &Module{
|
||||||
server: server,
|
server: server,
|
||||||
getClaims: opt.GetClaims,
|
getClaimFn: opt.GetClaim,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -10,6 +10,7 @@ import (
|
||||||
"cdr.dev/slog"
|
"cdr.dev/slog"
|
||||||
"forge.cadoles.com/arcad/edge/pkg/app"
|
"forge.cadoles.com/arcad/edge/pkg/app"
|
||||||
edgeHTTP "forge.cadoles.com/arcad/edge/pkg/http"
|
edgeHTTP "forge.cadoles.com/arcad/edge/pkg/http"
|
||||||
|
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
|
||||||
"forge.cadoles.com/arcad/edge/pkg/module"
|
"forge.cadoles.com/arcad/edge/pkg/module"
|
||||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||||
|
@ -130,7 +131,7 @@ func getDummyKey() jwk.Key {
|
||||||
return key
|
return key
|
||||||
}
|
}
|
||||||
|
|
||||||
func getDummyKeySet(key jwk.Key) GetKeySetFunc {
|
func getDummyKeySet(key jwk.Key) jwtutil.GetKeySetFunc {
|
||||||
return func() (jwk.Set, error) {
|
return func() (jwk.Set, error) {
|
||||||
set := jwk.NewSet()
|
set := jwk.NewSet()
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ package auth
|
||||||
import (
|
import (
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
|
||||||
"github.com/go-chi/chi/v5"
|
"github.com/go-chi/chi/v5"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"gitlab.com/wpetit/goweb/api"
|
"gitlab.com/wpetit/goweb/api"
|
||||||
|
@ -12,16 +13,19 @@ import (
|
||||||
type MountFunc func(r chi.Router)
|
type MountFunc func(r chi.Router)
|
||||||
|
|
||||||
type Handler struct {
|
type Handler struct {
|
||||||
getClaims GetClaimsFunc
|
getClaim GetClaimFunc
|
||||||
profileClaims []string
|
profileClaims []string
|
||||||
}
|
}
|
||||||
|
|
||||||
func (h *Handler) serveProfile(w http.ResponseWriter, r *http.Request) {
|
func (h *Handler) serveProfile(w http.ResponseWriter, r *http.Request) {
|
||||||
ctx := r.Context()
|
ctx := r.Context()
|
||||||
|
|
||||||
claims, err := h.getClaims(ctx, r, h.profileClaims...)
|
profile := make(map[string]any)
|
||||||
|
|
||||||
|
for _, name := range h.profileClaims {
|
||||||
|
value, err := h.getClaim(ctx, r, name)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrUnauthenticated) {
|
if errors.Is(err, jwtutil.ErrUnauthenticated) {
|
||||||
api.ErrorResponse(
|
api.ErrorResponse(
|
||||||
w, http.StatusUnauthorized,
|
w, http.StatusUnauthorized,
|
||||||
api.ErrCodeUnauthorized,
|
api.ErrCodeUnauthorized,
|
||||||
|
@ -41,10 +45,7 @@ func (h *Handler) serveProfile(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
profile := make(map[string]any)
|
profile[name] = value
|
||||||
|
|
||||||
for idx, cl := range h.profileClaims {
|
|
||||||
profile[cl] = claims[idx]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
api.DataResponse(w, http.StatusOK, struct {
|
api.DataResponse(w, http.StatusOK, struct {
|
||||||
|
@ -62,7 +63,7 @@ func Mount(authHandler http.Handler, funcs ...OptionFunc) MountFunc {
|
||||||
|
|
||||||
handler := &Handler{
|
handler := &Handler{
|
||||||
profileClaims: opt.ProfileClaims,
|
profileClaims: opt.ProfileClaims,
|
||||||
getClaims: opt.GetClaims,
|
getClaim: opt.GetClaim,
|
||||||
}
|
}
|
||||||
|
|
||||||
return func(r chi.Router) {
|
return func(r chi.Router) {
|
||||||
|
|
|
@ -2,15 +2,17 @@ package auth
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
|
||||||
|
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
)
|
)
|
||||||
|
|
||||||
type GetClaimsFunc func(ctx context.Context, r *http.Request, claims ...string) ([]string, error)
|
type GetClaimFunc func(ctx context.Context, r *http.Request, name string) (string, error)
|
||||||
|
|
||||||
type Option struct {
|
type Option struct {
|
||||||
GetClaims GetClaimsFunc
|
GetClaim GetClaimFunc
|
||||||
ProfileClaims []string
|
ProfileClaims []string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -18,7 +20,7 @@ type OptionFunc func(*Option)
|
||||||
|
|
||||||
func defaultOptions() *Option {
|
func defaultOptions() *Option {
|
||||||
return &Option{
|
return &Option{
|
||||||
GetClaims: dummyGetClaims,
|
GetClaim: dummyGetClaim,
|
||||||
ProfileClaims: []string{
|
ProfileClaims: []string{
|
||||||
ClaimSubject,
|
ClaimSubject,
|
||||||
ClaimIssuer,
|
ClaimIssuer,
|
||||||
|
@ -30,13 +32,13 @@ func defaultOptions() *Option {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func dummyGetClaims(ctx context.Context, r *http.Request, claims ...string) ([]string, error) {
|
func dummyGetClaim(ctx context.Context, r *http.Request, name string) (string, error) {
|
||||||
return nil, errors.Errorf("dummy getclaim func cannot retrieve claims '%s'", claims)
|
return "", errors.Errorf("dummy getclaim func cannot retrieve claim '%s'", name)
|
||||||
}
|
}
|
||||||
|
|
||||||
func WithGetClaims(fn GetClaimsFunc) OptionFunc {
|
func WithGetClaims(fn GetClaimFunc) OptionFunc {
|
||||||
return func(o *Option) {
|
return func(o *Option) {
|
||||||
o.GetClaims = fn
|
o.GetClaim = fn
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -45,3 +47,34 @@ func WithProfileClaims(claims ...string) OptionFunc {
|
||||||
o.ProfileClaims = claims
|
o.ProfileClaims = claims
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithJWT(getKeySet jwtutil.GetKeySetFunc) OptionFunc {
|
||||||
|
funcs := []jwtutil.FindTokenOptionFunc{
|
||||||
|
jwtutil.WithFinders(
|
||||||
|
jwtutil.FindTokenFromAuthorizationHeader,
|
||||||
|
jwtutil.FindTokenFromQueryString(CookieName),
|
||||||
|
jwtutil.FindTokenFromCookie(CookieName),
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
return func(o *Option) {
|
||||||
|
o.GetClaim = func(ctx context.Context, r *http.Request, name string) (string, error) {
|
||||||
|
token, err := jwtutil.FindToken(r, getKeySet, funcs...)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenMap, err := token.AsMap(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return "", errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
value, exists := tokenMap[name]
|
||||||
|
if !exists {
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%v", value), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
|
@ -6,11 +6,13 @@ import (
|
||||||
"forge.cadoles.com/arcad/edge/pkg/storage"
|
"forge.cadoles.com/arcad/edge/pkg/storage"
|
||||||
"forge.cadoles.com/arcad/edge/pkg/storage/driver"
|
"forge.cadoles.com/arcad/edge/pkg/storage/driver"
|
||||||
"forge.cadoles.com/arcad/edge/pkg/storage/driver/rpc/client"
|
"forge.cadoles.com/arcad/edge/pkg/storage/driver/rpc/client"
|
||||||
|
"forge.cadoles.com/arcad/edge/pkg/storage/share"
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
driver.RegisterDocumentStoreFactory("rpc", documentStoreFactory)
|
driver.RegisterDocumentStoreFactory("rpc", documentStoreFactory)
|
||||||
driver.RegisterBlobStoreFactory("rpc", blobStoreFactory)
|
driver.RegisterBlobStoreFactory("rpc", blobStoreFactory)
|
||||||
|
driver.RegisterShareStoreFactory("rpc", shareStoreFactory)
|
||||||
}
|
}
|
||||||
|
|
||||||
func documentStoreFactory(url *url.URL) (storage.DocumentStore, error) {
|
func documentStoreFactory(url *url.URL) (storage.DocumentStore, error) {
|
||||||
|
@ -20,3 +22,7 @@ func documentStoreFactory(url *url.URL) (storage.DocumentStore, error) {
|
||||||
func blobStoreFactory(url *url.URL) (storage.BlobStore, error) {
|
func blobStoreFactory(url *url.URL) (storage.BlobStore, error) {
|
||||||
return client.NewBlobStore(url), nil
|
return client.NewBlobStore(url), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shareStoreFactory(url *url.URL) (share.Store, error) {
|
||||||
|
return client.NewShareStore(url), nil
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue