feat(module,auth): authentication module with arbitrary claims support
This commit is contained in:
8
pkg/module/auth/error.go
Normal file
8
pkg/module/auth/error.go
Normal file
@ -0,0 +1,8 @@
|
||||
package auth
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrUnauthenticated = errors.New("unauthenticated")
|
||||
ErrClaimNotFound = errors.New("claim not found")
|
||||
)
|
60
pkg/module/auth/jwt.go
Normal file
60
pkg/module/auth/jwt.go
Normal file
@ -0,0 +1,60 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func WithJWT(keyFunc jwt.Keyfunc) OptionFunc {
|
||||
return func(o *Option) {
|
||||
o.GetClaim = func(ctx context.Context, r *http.Request, claimName string) (string, error) {
|
||||
claim, err := getClaim[string](r, claimName, keyFunc)
|
||||
if err != nil {
|
||||
return "", errors.WithStack(err)
|
||||
}
|
||||
|
||||
return claim, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func getClaim[T any](r *http.Request, claimAttr string, keyFunc jwt.Keyfunc) (T, error) {
|
||||
rawToken := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ")
|
||||
if rawToken == "" {
|
||||
rawToken = r.URL.Query().Get("token")
|
||||
}
|
||||
|
||||
if rawToken == "" {
|
||||
return *new(T), errors.WithStack(ErrUnauthenticated)
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(rawToken, keyFunc)
|
||||
if err != nil {
|
||||
return *new(T), errors.WithStack(err)
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
return *new(T), errors.Errorf("invalid jwt token: '%v'", token.Raw)
|
||||
}
|
||||
|
||||
mapClaims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
return *new(T), errors.Errorf("unexpected claims type '%T'", token.Claims)
|
||||
}
|
||||
|
||||
rawClaim, exists := mapClaims[claimAttr]
|
||||
if !exists {
|
||||
return *new(T), errors.WithStack(ErrClaimNotFound)
|
||||
}
|
||||
|
||||
claim, ok := rawClaim.(T)
|
||||
if !ok {
|
||||
return *new(T), errors.Errorf("unexpected claim '%s' to be of type '%T', got '%T'", claimAttr, new(T), rawClaim)
|
||||
}
|
||||
|
||||
return claim, nil
|
||||
}
|
@ -2,23 +2,21 @@ package auth
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"forge.cadoles.com/arcad/edge/pkg/app"
|
||||
"forge.cadoles.com/arcad/edge/pkg/module"
|
||||
edgeHTTP "forge.cadoles.com/arcad/edge/pkg/http"
|
||||
"forge.cadoles.com/arcad/edge/pkg/module/util"
|
||||
"github.com/dop251/goja"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
AnonymousSubject = "anonymous"
|
||||
ClaimSubject = "sub"
|
||||
)
|
||||
|
||||
type Module struct {
|
||||
server *app.Server
|
||||
keyFunc jwt.Keyfunc
|
||||
server *app.Server
|
||||
getClaimFunc GetClaimFunc
|
||||
}
|
||||
|
||||
func (m *Module) Name() string {
|
||||
@ -26,59 +24,46 @@ func (m *Module) Name() string {
|
||||
}
|
||||
|
||||
func (m *Module) Export(export *goja.Object) {
|
||||
if err := export.Set("getSubject", m.getSubject); err != nil {
|
||||
panic(errors.Wrap(err, "could not set 'getSubject' function"))
|
||||
if err := export.Set("getClaim", m.getClaim); err != nil {
|
||||
panic(errors.Wrap(err, "could not set 'getClaim' function"))
|
||||
}
|
||||
|
||||
if err := export.Set("ANONYMOUS", AnonymousSubject); err != nil {
|
||||
panic(errors.Wrap(err, "could not set 'ANONYMOUS_USER' property"))
|
||||
if err := export.Set("CLAIM_SUBJECT", ClaimSubject); err != nil {
|
||||
panic(errors.Wrap(err, "could not set 'CLAIM_SUBJECT' property"))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Module) getSubject(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
|
||||
func (m *Module) getClaim(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
|
||||
ctx := util.AssertContext(call.Argument(0), rt)
|
||||
claimName := util.AssertString(call.Argument(1), rt)
|
||||
|
||||
req, ok := ctx.Value(module.ContextKeyOriginRequest).(*http.Request)
|
||||
req, ok := ctx.Value(edgeHTTP.ContextKeyOriginRequest).(*http.Request)
|
||||
if !ok {
|
||||
panic(errors.New("could not find http request in context"))
|
||||
panic(rt.ToValue(errors.New("could not find http request in context")))
|
||||
}
|
||||
|
||||
rawToken := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ")
|
||||
if rawToken == "" {
|
||||
rawToken = req.URL.Query().Get("token")
|
||||
}
|
||||
|
||||
if rawToken == "" {
|
||||
return rt.ToValue(AnonymousSubject)
|
||||
}
|
||||
|
||||
token, err := jwt.Parse(rawToken, m.keyFunc)
|
||||
claim, err := m.getClaimFunc(ctx, req, claimName)
|
||||
if err != nil {
|
||||
panic(errors.WithStack(err))
|
||||
if errors.Is(err, ErrUnauthenticated) || errors.Is(err, ErrClaimNotFound) {
|
||||
return nil
|
||||
}
|
||||
|
||||
panic(rt.ToValue(errors.WithStack(err)))
|
||||
}
|
||||
|
||||
if !token.Valid {
|
||||
panic(errors.Errorf("invalid jwt token: '%v'", token.Raw))
|
||||
}
|
||||
|
||||
mapClaims, ok := token.Claims.(jwt.MapClaims)
|
||||
if !ok {
|
||||
panic(errors.Errorf("unexpected claims type '%T'", token.Claims))
|
||||
}
|
||||
|
||||
subject, exists := mapClaims["sub"]
|
||||
if !exists {
|
||||
return rt.ToValue(AnonymousSubject)
|
||||
}
|
||||
|
||||
return rt.ToValue(subject)
|
||||
return rt.ToValue(claim)
|
||||
}
|
||||
|
||||
func ModuleFactory(keyFunc jwt.Keyfunc) app.ServerModuleFactory {
|
||||
func ModuleFactory(funcs ...OptionFunc) app.ServerModuleFactory {
|
||||
opt := &Option{}
|
||||
for _, fn := range funcs {
|
||||
fn(opt)
|
||||
}
|
||||
|
||||
return func(server *app.Server) app.ServerModule {
|
||||
return &Module{
|
||||
server: server,
|
||||
keyFunc: keyFunc,
|
||||
server: server,
|
||||
getClaimFunc: opt.GetClaim,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
@ -10,6 +10,7 @@ import (
|
||||
|
||||
"cdr.dev/slog"
|
||||
"forge.cadoles.com/arcad/edge/pkg/app"
|
||||
edgeHTTP "forge.cadoles.com/arcad/edge/pkg/http"
|
||||
"forge.cadoles.com/arcad/edge/pkg/module"
|
||||
"github.com/golang-jwt/jwt"
|
||||
"github.com/pkg/errors"
|
||||
@ -25,7 +26,9 @@ func TestAuthModule(t *testing.T) {
|
||||
|
||||
server := app.NewServer(
|
||||
module.ConsoleModuleFactory(),
|
||||
ModuleFactory(keyFunc),
|
||||
ModuleFactory(
|
||||
WithJWT(keyFunc),
|
||||
),
|
||||
)
|
||||
|
||||
data, err := ioutil.ReadFile("testdata/auth.js")
|
||||
@ -60,7 +63,7 @@ func TestAuthModule(t *testing.T) {
|
||||
|
||||
req.Header.Add("Authorization", "Bearer "+rawToken)
|
||||
|
||||
ctx := context.WithValue(context.Background(), module.ContextKeyOriginRequest, req)
|
||||
ctx := context.WithValue(context.Background(), edgeHTTP.ContextKeyOriginRequest, req)
|
||||
|
||||
if _, err := server.ExecFuncByName("testAuth", ctx); err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
@ -76,7 +79,7 @@ func TestAuthAnonymousModule(t *testing.T) {
|
||||
|
||||
server := app.NewServer(
|
||||
module.ConsoleModuleFactory(),
|
||||
ModuleFactory(keyFunc),
|
||||
ModuleFactory(WithJWT(keyFunc)),
|
||||
)
|
||||
|
||||
data, err := ioutil.ReadFile("testdata/auth_anonymous.js")
|
||||
@ -99,7 +102,7 @@ func TestAuthAnonymousModule(t *testing.T) {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
ctx := context.WithValue(context.Background(), module.ContextKeyOriginRequest, req)
|
||||
ctx := context.WithValue(context.Background(), edgeHTTP.ContextKeyOriginRequest, req)
|
||||
|
||||
if _, err := server.ExecFuncByName("testAuth", ctx); err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
|
20
pkg/module/auth/option.go
Normal file
20
pkg/module/auth/option.go
Normal file
@ -0,0 +1,20 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
type GetClaimFunc func(ctx context.Context, r *http.Request, claimName string) (string, error)
|
||||
|
||||
type Option struct {
|
||||
GetClaim GetClaimFunc
|
||||
}
|
||||
|
||||
type OptionFunc func(*Option)
|
||||
|
||||
func WithGetClaim(fn GetClaimFunc) OptionFunc {
|
||||
return func(o *Option) {
|
||||
o.GetClaim = fn
|
||||
}
|
||||
}
|
2
pkg/module/auth/testdata/auth.js
vendored
2
pkg/module/auth/testdata/auth.js
vendored
@ -1,7 +1,7 @@
|
||||
|
||||
|
||||
function testAuth(ctx) {
|
||||
var subject = auth.getSubject(ctx);
|
||||
var subject = auth.getClaim(ctx, auth.CLAIM_SUBJECT);
|
||||
|
||||
if (subject !== "jdoe") {
|
||||
throw new Error("subject: expected 'jdoe', got '"+subject+"'");
|
||||
|
6
pkg/module/auth/testdata/auth_anonymous.js
vendored
6
pkg/module/auth/testdata/auth_anonymous.js
vendored
@ -1,9 +1,9 @@
|
||||
|
||||
|
||||
function testAuth(ctx) {
|
||||
var subject = auth.getSubject(ctx);
|
||||
var subject = auth.getClaim(ctx, auth.CLAIM_SUBJECT);
|
||||
|
||||
if (subject !== auth.ANONYMOUS) {
|
||||
throw new Error("subject: expected '"+auth.ANONYMOUS+"', got '"+subject+"'");
|
||||
if (subject !== undefined) {
|
||||
throw new Error("subject: expected undefined, got '"+subject+"'");
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user