feat(storage-server): jwt based authentication
All checks were successful
arcad/edge/pipeline/pr-master This commit looks good
All checks were successful
arcad/edge/pipeline/pr-master This commit looks good
This commit is contained in:
@ -1,5 +1,6 @@
|
||||
package auth
|
||||
package jwtutil
|
||||
|
||||
import "errors"
|
||||
|
||||
var ErrUnauthenticated = errors.New("unauthenticated")
|
||||
var ErrNoKeySet = errors.New("no keyset")
|
71
pkg/jwtutil/io.go
Normal file
71
pkg/jwtutil/io.go
Normal file
@ -0,0 +1,71 @@
|
||||
package jwtutil
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"os"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func LoadOrGenerateKey(path string, defaultKeySize int) (jwk.Key, error) {
|
||||
key, err := LoadKey(path)
|
||||
if err != nil {
|
||||
if !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
key, err = GenerateKey(defaultKeySize)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
if err := SaveKey(path, key); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func LoadKey(path string) (jwk.Key, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
key, err := jwk.ParseKey(data, jwk.WithPEM(true))
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func SaveKey(path string, key jwk.Key) error {
|
||||
data, err := jwk.Pem(key)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
if err := os.WriteFile(path, data, os.FileMode(0600)); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func GenerateKey(keySize int) (jwk.Key, error) {
|
||||
rsaKey, err := rsa.GenerateKey(rand.Reader, keySize)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
key, err := jwk.FromRaw(rsaKey)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
52
pkg/jwtutil/key.go
Normal file
52
pkg/jwtutil/key.go
Normal file
@ -0,0 +1,52 @@
|
||||
package jwtutil
|
||||
|
||||
import (
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func NewKeySet(keys ...jwk.Key) (jwk.Set, error) {
|
||||
set := jwk.NewSet()
|
||||
|
||||
for _, k := range keys {
|
||||
if err := set.AddKey(k); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
return set, nil
|
||||
}
|
||||
|
||||
func NewSymmetricKey(secret []byte) (jwk.Key, error) {
|
||||
key, err := jwk.FromRaw(secret)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
if err := key.Set(jwk.AlgorithmKey, jwa.HS256); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return key, nil
|
||||
}
|
||||
|
||||
func NewSymmetricKeySet(secrets ...[]byte) (jwk.Set, error) {
|
||||
keys := make([]jwk.Key, len(secrets))
|
||||
|
||||
for idx, sec := range secrets {
|
||||
key, err := NewSymmetricKey(sec)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
keys[idx] = key
|
||||
}
|
||||
|
||||
keySet, err := NewKeySet(keys...)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return keySet, nil
|
||||
}
|
123
pkg/jwtutil/request.go
Normal file
123
pkg/jwtutil/request.go
Normal file
@ -0,0 +1,123 @@
|
||||
package jwtutil
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jws"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type TokenFinderFunc func(r *http.Request) (string, error)
|
||||
|
||||
type FindTokenOptions struct {
|
||||
Finders []TokenFinderFunc
|
||||
}
|
||||
|
||||
type FindTokenOptionFunc func(*FindTokenOptions)
|
||||
|
||||
type GetKeySetFunc func() (jwk.Set, error)
|
||||
|
||||
func WithFinders(finders ...TokenFinderFunc) FindTokenOptionFunc {
|
||||
return func(opts *FindTokenOptions) {
|
||||
opts.Finders = finders
|
||||
}
|
||||
}
|
||||
|
||||
func NewFindTokenOptions(funcs ...FindTokenOptionFunc) *FindTokenOptions {
|
||||
opts := &FindTokenOptions{
|
||||
Finders: []TokenFinderFunc{
|
||||
FindTokenFromAuthorizationHeader,
|
||||
},
|
||||
}
|
||||
|
||||
for _, fn := range funcs {
|
||||
fn(opts)
|
||||
}
|
||||
|
||||
return opts
|
||||
}
|
||||
|
||||
func FindTokenFromAuthorizationHeader(r *http.Request) (string, error) {
|
||||
authorization := r.Header.Get("Authorization")
|
||||
|
||||
// Retrieve token from Authorization header
|
||||
rawToken := strings.TrimPrefix(authorization, "Bearer ")
|
||||
|
||||
return rawToken, nil
|
||||
}
|
||||
|
||||
func FindTokenFromQueryString(name string) TokenFinderFunc {
|
||||
return func(r *http.Request) (string, error) {
|
||||
return r.URL.Query().Get(name), nil
|
||||
}
|
||||
}
|
||||
|
||||
func FindTokenFromCookie(cookieName string) TokenFinderFunc {
|
||||
return func(r *http.Request) (string, error) {
|
||||
cookie, err := r.Cookie(cookieName)
|
||||
if err != nil && !errors.Is(err, http.ErrNoCookie) {
|
||||
return "", errors.WithStack(err)
|
||||
}
|
||||
|
||||
if cookie == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return cookie.Value, nil
|
||||
}
|
||||
}
|
||||
|
||||
func FindRawToken(r *http.Request, funcs ...FindTokenOptionFunc) (string, error) {
|
||||
opts := NewFindTokenOptions(funcs...)
|
||||
|
||||
var rawToken string
|
||||
var err error
|
||||
|
||||
for _, find := range opts.Finders {
|
||||
rawToken, err = find(r)
|
||||
if err != nil {
|
||||
return "", errors.WithStack(err)
|
||||
}
|
||||
|
||||
if rawToken == "" {
|
||||
continue
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
if rawToken == "" {
|
||||
return "", errors.WithStack(ErrUnauthenticated)
|
||||
}
|
||||
|
||||
return rawToken, nil
|
||||
}
|
||||
|
||||
func FindToken(r *http.Request, getKeySet GetKeySetFunc, funcs ...FindTokenOptionFunc) (jwt.Token, error) {
|
||||
rawToken, err := FindRawToken(r, funcs...)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
keySet, err := getKeySet()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
if keySet == nil {
|
||||
return nil, errors.WithStack(ErrNoKeySet)
|
||||
}
|
||||
|
||||
token, err := jwt.Parse([]byte(rawToken),
|
||||
jwt.WithKeySet(keySet, jws.WithRequireKid(false)),
|
||||
jwt.WithValidate(true),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
@ -1,4 +1,4 @@
|
||||
package jwt
|
||||
package jwtutil
|
||||
|
||||
import (
|
||||
"time"
|
||||
@ -9,7 +9,7 @@ import (
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func GenerateSignedToken(algo jwa.KeyAlgorithm, key jwk.Key, claims map[string]any) ([]byte, error) {
|
||||
func SignedToken(key jwk.Key, signingAlgorithm jwa.SignatureAlgorithm, claims map[string]any) ([]byte, error) {
|
||||
token := jwt.New()
|
||||
|
||||
if err := token.Set(jwt.NotBeforeKey, time.Now()); err != nil {
|
||||
@ -22,11 +22,11 @@ func GenerateSignedToken(algo jwa.KeyAlgorithm, key jwk.Key, claims map[string]a
|
||||
}
|
||||
}
|
||||
|
||||
if err := token.Set(jwk.AlgorithmKey, jwa.HS256); err != nil {
|
||||
if err := token.Set(jwk.AlgorithmKey, signingAlgorithm); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
rawToken, err := jwt.Sign(token, jwt.WithKey(algo, key))
|
||||
rawToken, err := jwt.Sign(token, jwt.WithKey(signingAlgorithm, key))
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
@ -7,9 +7,9 @@ import (
|
||||
|
||||
_ "embed"
|
||||
|
||||
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
|
||||
"forge.cadoles.com/arcad/edge/pkg/module/auth"
|
||||
"forge.cadoles.com/arcad/edge/pkg/module/auth/http/passwd"
|
||||
"forge.cadoles.com/arcad/edge/pkg/module/auth/jwt"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
@ -31,12 +31,12 @@ func init() {
|
||||
}
|
||||
|
||||
type LocalHandler struct {
|
||||
router chi.Router
|
||||
algo jwa.KeyAlgorithm
|
||||
key jwk.Key
|
||||
getCookieDomain GetCookieDomainFunc
|
||||
cookieDuration time.Duration
|
||||
accounts map[string]LocalAccount
|
||||
router chi.Router
|
||||
key jwk.Key
|
||||
signingAlgorithm jwa.SignatureAlgorithm
|
||||
getCookieDomain GetCookieDomainFunc
|
||||
cookieDuration time.Duration
|
||||
accounts map[string]LocalAccount
|
||||
}
|
||||
|
||||
func (h *LocalHandler) initRouter(prefix string) {
|
||||
@ -113,7 +113,7 @@ func (h *LocalHandler) handleForm(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
account.Claims[auth.ClaimIssuer] = "local"
|
||||
|
||||
token, err := jwt.GenerateSignedToken(h.algo, h.key, account.Claims)
|
||||
token, err := jwtutil.SignedToken(h.key, h.signingAlgorithm, account.Claims)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "could not generate signed token", logger.E(errors.WithStack(err)))
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
@ -182,18 +182,18 @@ func (h *LocalHandler) authenticate(username, password string) (*LocalAccount, e
|
||||
return &account, nil
|
||||
}
|
||||
|
||||
func NewLocalHandler(algo jwa.KeyAlgorithm, key jwk.Key, funcs ...LocalHandlerOptionFunc) *LocalHandler {
|
||||
func NewLocalHandler(key jwk.Key, signingAlgorithm jwa.SignatureAlgorithm, funcs ...LocalHandlerOptionFunc) *LocalHandler {
|
||||
opts := defaultLocalHandlerOptions()
|
||||
for _, fn := range funcs {
|
||||
fn(opts)
|
||||
}
|
||||
|
||||
handler := &LocalHandler{
|
||||
algo: algo,
|
||||
key: key,
|
||||
accounts: toAccountsMap(opts.Accounts),
|
||||
getCookieDomain: opts.GetCookieDomain,
|
||||
cookieDuration: opts.CookieDuration,
|
||||
key: key,
|
||||
signingAlgorithm: signingAlgorithm,
|
||||
accounts: toAccountsMap(opts.Accounts),
|
||||
getCookieDomain: opts.GetCookieDomain,
|
||||
cookieDuration: opts.CookieDuration,
|
||||
}
|
||||
|
||||
handler.initRouter(opts.RoutePrefix)
|
||||
|
@ -1,118 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
"github.com/lestrrat-go/jwx/v2/jws"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
CookieName string = "edge-auth"
|
||||
)
|
||||
|
||||
type GetKeySetFunc func() (jwk.Set, error)
|
||||
|
||||
func WithJWT(getKeySet GetKeySetFunc) OptionFunc {
|
||||
return func(o *Option) {
|
||||
o.GetClaims = func(ctx context.Context, r *http.Request, names ...string) ([]string, error) {
|
||||
claim, err := getClaims[string](r, getKeySet, names...)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return claim, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func FindRawToken(r *http.Request) (string, error) {
|
||||
authorization := r.Header.Get("Authorization")
|
||||
|
||||
// Retrieve token from Authorization header
|
||||
rawToken := strings.TrimPrefix(authorization, "Bearer ")
|
||||
|
||||
// Retrieve token from ?edge-auth=<value>
|
||||
if rawToken == "" {
|
||||
rawToken = r.URL.Query().Get(CookieName)
|
||||
}
|
||||
|
||||
if rawToken == "" {
|
||||
cookie, err := r.Cookie(CookieName)
|
||||
if err != nil && !errors.Is(err, http.ErrNoCookie) {
|
||||
return "", errors.WithStack(err)
|
||||
}
|
||||
|
||||
if cookie != nil {
|
||||
rawToken = cookie.Value
|
||||
}
|
||||
}
|
||||
|
||||
if rawToken == "" {
|
||||
return "", errors.WithStack(ErrUnauthenticated)
|
||||
}
|
||||
|
||||
return rawToken, nil
|
||||
}
|
||||
|
||||
func FindToken(r *http.Request, getKeySet GetKeySetFunc) (jwt.Token, error) {
|
||||
rawToken, err := FindRawToken(r)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
keySet, err := getKeySet()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
if keySet == nil {
|
||||
return nil, errors.New("no keyset")
|
||||
}
|
||||
|
||||
token, err := jwt.Parse([]byte(rawToken),
|
||||
jwt.WithKeySet(keySet, jws.WithRequireKid(false)),
|
||||
jwt.WithValidate(true),
|
||||
)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func getClaims[T any](r *http.Request, getKeySet GetKeySetFunc, names ...string) ([]T, error) {
|
||||
token, err := FindToken(r, getKeySet)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
ctx := r.Context()
|
||||
|
||||
mapClaims, err := token.AsMap(ctx)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
claims := make([]T, len(names))
|
||||
|
||||
for idx, n := range names {
|
||||
rawClaim, exists := mapClaims[n]
|
||||
if !exists {
|
||||
continue
|
||||
}
|
||||
|
||||
claim, ok := rawClaim.(T)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unexpected claim '%s' to be of type '%T', got '%T'", n, new(T), rawClaim)
|
||||
}
|
||||
|
||||
claims[idx] = claim
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
@ -7,8 +7,8 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
|
||||
"forge.cadoles.com/arcad/edge/pkg/module/auth"
|
||||
"forge.cadoles.com/arcad/edge/pkg/module/auth/jwt"
|
||||
"github.com/google/uuid"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
@ -18,7 +18,7 @@ import (
|
||||
|
||||
const AnonIssuer = "anon"
|
||||
|
||||
func AnonymousUser(algo jwa.KeyAlgorithm, key jwk.Key, funcs ...AnonymousUserOptionFunc) func(next http.Handler) http.Handler {
|
||||
func AnonymousUser(key jwk.Key, signingAlgorithm jwa.SignatureAlgorithm, funcs ...AnonymousUserOptionFunc) func(next http.Handler) http.Handler {
|
||||
opts := defaultAnonymousUserOptions()
|
||||
for _, fn := range funcs {
|
||||
fn(opts)
|
||||
@ -26,7 +26,11 @@ func AnonymousUser(algo jwa.KeyAlgorithm, key jwk.Key, funcs ...AnonymousUserOpt
|
||||
|
||||
return func(next http.Handler) http.Handler {
|
||||
handler := func(w http.ResponseWriter, r *http.Request) {
|
||||
rawToken, err := auth.FindRawToken(r)
|
||||
rawToken, err := jwtutil.FindRawToken(r, jwtutil.WithFinders(
|
||||
jwtutil.FindTokenFromAuthorizationHeader,
|
||||
jwtutil.FindTokenFromQueryString(auth.CookieName),
|
||||
jwtutil.FindTokenFromCookie(auth.CookieName),
|
||||
))
|
||||
|
||||
// If request already has a raw token, we do nothing
|
||||
if rawToken != "" && err == nil {
|
||||
@ -62,7 +66,7 @@ func AnonymousUser(algo jwa.KeyAlgorithm, key jwk.Key, funcs ...AnonymousUserOpt
|
||||
auth.ClaimEdgeTenant: opts.Tenant,
|
||||
}
|
||||
|
||||
token, err := jwt.GenerateSignedToken(algo, key, claims)
|
||||
token, err := jwtutil.SignedToken(key, signingAlgorithm, claims)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "could not generate signed token", logger.E(errors.WithStack(err)))
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
|
@ -5,12 +5,17 @@ import (
|
||||
|
||||
"forge.cadoles.com/arcad/edge/pkg/app"
|
||||
edgeHTTP "forge.cadoles.com/arcad/edge/pkg/http"
|
||||
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
|
||||
"forge.cadoles.com/arcad/edge/pkg/module/util"
|
||||
"github.com/dop251/goja"
|
||||
"github.com/pkg/errors"
|
||||
"gitlab.com/wpetit/goweb/logger"
|
||||
)
|
||||
|
||||
const (
|
||||
CookieName string = "edge-auth"
|
||||
)
|
||||
|
||||
const (
|
||||
ClaimSubject = "sub"
|
||||
ClaimIssuer = "iss"
|
||||
@ -21,8 +26,8 @@ const (
|
||||
)
|
||||
|
||||
type Module struct {
|
||||
server *app.Server
|
||||
getClaims GetClaimsFunc
|
||||
server *app.Server
|
||||
getClaimFn GetClaimFunc
|
||||
}
|
||||
|
||||
func (m *Module) Name() string {
|
||||
@ -68,9 +73,9 @@ func (m *Module) getClaim(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
|
||||
panic(rt.ToValue(errors.New("could not find http request in context")))
|
||||
}
|
||||
|
||||
claim, err := m.getClaims(ctx, req, claimName)
|
||||
claim, err := m.getClaimFn(ctx, req, claimName)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUnauthenticated) {
|
||||
if errors.Is(err, jwtutil.ErrUnauthenticated) {
|
||||
return nil
|
||||
}
|
||||
|
||||
@ -78,11 +83,7 @@ func (m *Module) getClaim(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(claim) == 0 || claim[0] == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
return rt.ToValue(claim[0])
|
||||
return rt.ToValue(claim)
|
||||
}
|
||||
|
||||
func ModuleFactory(funcs ...OptionFunc) app.ServerModuleFactory {
|
||||
@ -93,8 +94,8 @@ func ModuleFactory(funcs ...OptionFunc) app.ServerModuleFactory {
|
||||
|
||||
return func(server *app.Server) app.ServerModule {
|
||||
return &Module{
|
||||
server: server,
|
||||
getClaims: opt.GetClaims,
|
||||
server: server,
|
||||
getClaimFn: opt.GetClaim,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
"cdr.dev/slog"
|
||||
"forge.cadoles.com/arcad/edge/pkg/app"
|
||||
edgeHTTP "forge.cadoles.com/arcad/edge/pkg/http"
|
||||
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
|
||||
"forge.cadoles.com/arcad/edge/pkg/module"
|
||||
"github.com/lestrrat-go/jwx/v2/jwa"
|
||||
"github.com/lestrrat-go/jwx/v2/jwk"
|
||||
@ -130,7 +131,7 @@ func getDummyKey() jwk.Key {
|
||||
return key
|
||||
}
|
||||
|
||||
func getDummyKeySet(key jwk.Key) GetKeySetFunc {
|
||||
func getDummyKeySet(key jwk.Key) jwtutil.GetKeySetFunc {
|
||||
return func() (jwk.Set, error) {
|
||||
set := jwk.NewSet()
|
||||
|
||||
|
@ -3,6 +3,7 @@ package auth
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/pkg/errors"
|
||||
"gitlab.com/wpetit/goweb/api"
|
||||
@ -12,39 +13,39 @@ import (
|
||||
type MountFunc func(r chi.Router)
|
||||
|
||||
type Handler struct {
|
||||
getClaims GetClaimsFunc
|
||||
getClaim GetClaimFunc
|
||||
profileClaims []string
|
||||
}
|
||||
|
||||
func (h *Handler) serveProfile(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
claims, err := h.getClaims(ctx, r, h.profileClaims...)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrUnauthenticated) {
|
||||
profile := make(map[string]any)
|
||||
|
||||
for _, name := range h.profileClaims {
|
||||
value, err := h.getClaim(ctx, r, name)
|
||||
if err != nil {
|
||||
if errors.Is(err, jwtutil.ErrUnauthenticated) {
|
||||
api.ErrorResponse(
|
||||
w, http.StatusUnauthorized,
|
||||
api.ErrCodeUnauthorized,
|
||||
nil,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
logger.Error(ctx, "could not retrieve claims", logger.E(errors.WithStack(err)))
|
||||
api.ErrorResponse(
|
||||
w, http.StatusUnauthorized,
|
||||
api.ErrCodeUnauthorized,
|
||||
w, http.StatusInternalServerError,
|
||||
api.ErrCodeUnknownError,
|
||||
nil,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
logger.Error(ctx, "could not retrieve claims", logger.E(errors.WithStack(err)))
|
||||
api.ErrorResponse(
|
||||
w, http.StatusInternalServerError,
|
||||
api.ErrCodeUnknownError,
|
||||
nil,
|
||||
)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
profile := make(map[string]any)
|
||||
|
||||
for idx, cl := range h.profileClaims {
|
||||
profile[cl] = claims[idx]
|
||||
profile[name] = value
|
||||
}
|
||||
|
||||
api.DataResponse(w, http.StatusOK, struct {
|
||||
@ -62,7 +63,7 @@ func Mount(authHandler http.Handler, funcs ...OptionFunc) MountFunc {
|
||||
|
||||
handler := &Handler{
|
||||
profileClaims: opt.ProfileClaims,
|
||||
getClaims: opt.GetClaims,
|
||||
getClaim: opt.GetClaim,
|
||||
}
|
||||
|
||||
return func(r chi.Router) {
|
||||
|
@ -2,15 +2,17 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"forge.cadoles.com/arcad/edge/pkg/jwtutil"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type GetClaimsFunc func(ctx context.Context, r *http.Request, claims ...string) ([]string, error)
|
||||
type GetClaimFunc func(ctx context.Context, r *http.Request, name string) (string, error)
|
||||
|
||||
type Option struct {
|
||||
GetClaims GetClaimsFunc
|
||||
GetClaim GetClaimFunc
|
||||
ProfileClaims []string
|
||||
}
|
||||
|
||||
@ -18,7 +20,7 @@ type OptionFunc func(*Option)
|
||||
|
||||
func defaultOptions() *Option {
|
||||
return &Option{
|
||||
GetClaims: dummyGetClaims,
|
||||
GetClaim: dummyGetClaim,
|
||||
ProfileClaims: []string{
|
||||
ClaimSubject,
|
||||
ClaimIssuer,
|
||||
@ -30,13 +32,13 @@ func defaultOptions() *Option {
|
||||
}
|
||||
}
|
||||
|
||||
func dummyGetClaims(ctx context.Context, r *http.Request, claims ...string) ([]string, error) {
|
||||
return nil, errors.Errorf("dummy getclaim func cannot retrieve claims '%s'", claims)
|
||||
func dummyGetClaim(ctx context.Context, r *http.Request, name string) (string, error) {
|
||||
return "", errors.Errorf("dummy getclaim func cannot retrieve claim '%s'", name)
|
||||
}
|
||||
|
||||
func WithGetClaims(fn GetClaimsFunc) OptionFunc {
|
||||
func WithGetClaims(fn GetClaimFunc) OptionFunc {
|
||||
return func(o *Option) {
|
||||
o.GetClaims = fn
|
||||
o.GetClaim = fn
|
||||
}
|
||||
}
|
||||
|
||||
@ -45,3 +47,34 @@ func WithProfileClaims(claims ...string) OptionFunc {
|
||||
o.ProfileClaims = claims
|
||||
}
|
||||
}
|
||||
|
||||
func WithJWT(getKeySet jwtutil.GetKeySetFunc) OptionFunc {
|
||||
funcs := []jwtutil.FindTokenOptionFunc{
|
||||
jwtutil.WithFinders(
|
||||
jwtutil.FindTokenFromAuthorizationHeader,
|
||||
jwtutil.FindTokenFromQueryString(CookieName),
|
||||
jwtutil.FindTokenFromCookie(CookieName),
|
||||
),
|
||||
}
|
||||
|
||||
return func(o *Option) {
|
||||
o.GetClaim = func(ctx context.Context, r *http.Request, name string) (string, error) {
|
||||
token, err := jwtutil.FindToken(r, getKeySet, funcs...)
|
||||
if err != nil {
|
||||
return "", errors.WithStack(err)
|
||||
}
|
||||
|
||||
tokenMap, err := token.AsMap(ctx)
|
||||
if err != nil {
|
||||
return "", errors.WithStack(err)
|
||||
}
|
||||
|
||||
value, exists := tokenMap[name]
|
||||
if !exists {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%v", value), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -6,11 +6,13 @@ import (
|
||||
"forge.cadoles.com/arcad/edge/pkg/storage"
|
||||
"forge.cadoles.com/arcad/edge/pkg/storage/driver"
|
||||
"forge.cadoles.com/arcad/edge/pkg/storage/driver/rpc/client"
|
||||
"forge.cadoles.com/arcad/edge/pkg/storage/share"
|
||||
)
|
||||
|
||||
func init() {
|
||||
driver.RegisterDocumentStoreFactory("rpc", documentStoreFactory)
|
||||
driver.RegisterBlobStoreFactory("rpc", blobStoreFactory)
|
||||
driver.RegisterShareStoreFactory("rpc", shareStoreFactory)
|
||||
}
|
||||
|
||||
func documentStoreFactory(url *url.URL) (storage.DocumentStore, error) {
|
||||
@ -20,3 +22,7 @@ func documentStoreFactory(url *url.URL) (storage.DocumentStore, error) {
|
||||
func blobStoreFactory(url *url.URL) (storage.BlobStore, error) {
|
||||
return client.NewBlobStore(url), nil
|
||||
}
|
||||
|
||||
func shareStoreFactory(url *url.URL) (share.Store, error) {
|
||||
return client.NewShareStore(url), nil
|
||||
}
|
||||
|
Reference in New Issue
Block a user