From f01b1ef3b24443fcb3aab270cedb5d5f0ff22e5b Mon Sep 17 00:00:00 2001 From: William Petit Date: Tue, 21 Feb 2023 12:14:29 +0100 Subject: [PATCH] feat(module,auth): auth based on jwt --- go.mod | 1 + go.sum | 2 + pkg/module/assert.go | 28 ----- pkg/module/auth/module.go | 84 ++++++++++++++ pkg/module/auth/module_test.go | 121 +++++++++++++++++++++ pkg/module/auth/testdata/auth.js | 9 ++ pkg/module/auth/testdata/auth_anonymous.js | 9 ++ pkg/module/context.go | 9 +- pkg/module/net.go | 3 +- pkg/module/rpc.go | 5 +- pkg/module/store.go | 9 +- pkg/module/util/assert.go | 28 +++++ 12 files changed, 269 insertions(+), 39 deletions(-) delete mode 100644 pkg/module/assert.go create mode 100644 pkg/module/auth/module.go create mode 100644 pkg/module/auth/module_test.go create mode 100644 pkg/module/auth/testdata/auth.js create mode 100644 pkg/module/auth/testdata/auth_anonymous.js create mode 100644 pkg/module/util/assert.go diff --git a/go.mod b/go.mod index a2ac329..01bcb42 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,7 @@ require modernc.org/sqlite v1.20.4 require ( github.com/gogo/protobuf v0.0.0-20161014173244-50d1bd39ce4e // indirect + github.com/golang-jwt/jwt v3.2.2+incompatible // indirect github.com/hashicorp/go.net v0.0.0-20151006203346-104dcad90073 // indirect github.com/hashicorp/mdns v0.0.0-20151206042412-9d85cf22f9f8 // indirect github.com/miekg/dns v0.0.0-20161006100029-fc4e1e2843d8 // indirect diff --git a/go.sum b/go.sum index 63ba143..5a34ffc 100644 --- a/go.sum +++ b/go.sum @@ -110,6 +110,8 @@ github.com/go-sourcemap/sourcemap v2.1.3+incompatible h1:W1iEw64niKVGogNgBN3ePyL github.com/go-sourcemap/sourcemap v2.1.3+incompatible/go.mod h1:F8jJfvm2KbVjc5NqelyYJmf/v5J0dwNLS2mL4sNA1Jg= github.com/gogo/protobuf v0.0.0-20161014173244-50d1bd39ce4e h1:eeyMpoxANuWNQ9O2auv4wXxJsrXzLUhdHaOmNWEGkRY= github.com/gogo/protobuf v0.0.0-20161014173244-50d1bd39ce4e/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= +github.com/golang-jwt/jwt v3.2.2+incompatible h1:IfV12K8xAKAnZqdXVzCZ+TOjboZ2keLg81eXfW3O+oY= +github.com/golang-jwt/jwt v3.2.2+incompatible/go.mod h1:8pz2t5EyA70fFQQSrl6XZXzqecmYZeUEB8OUGHkxJ+I= github.com/golang/glog v0.0.0-20160126235308-23def4e6c14b/go.mod h1:SBH7ygxi8pfUlaOkMMuAQtPIUF8ecWP5IEl/CR7VP2Q= github.com/golang/groupcache v0.0.0-20190702054246-869f871628b6/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= github.com/golang/groupcache v0.0.0-20191027212112-611e8accdfc9/go.mod h1:cIg4eruTrX1D+g88fzRXU5OdNfaM+9IcxsU14FzY7Hc= diff --git a/pkg/module/assert.go b/pkg/module/assert.go deleted file mode 100644 index 26236fe..0000000 --- a/pkg/module/assert.go +++ /dev/null @@ -1,28 +0,0 @@ -package module - -import ( - "context" - "fmt" - - "github.com/dop251/goja" -) - -func assertType[T any](v goja.Value, rt *goja.Runtime) T { - if c, ok := v.Export().(T); ok { - return c - } - - panic(rt.NewTypeError(fmt.Sprintf("expected value to be a '%T', got '%T'", new(T), v.Export()))) -} - -func assertContext(v goja.Value, r *goja.Runtime) context.Context { - return assertType[context.Context](v, r) -} - -func assertObject(v goja.Value, r *goja.Runtime) map[string]any { - return assertType[map[string]any](v, r) -} - -func assertString(v goja.Value, r *goja.Runtime) string { - return assertType[string](v, r) -} diff --git a/pkg/module/auth/module.go b/pkg/module/auth/module.go new file mode 100644 index 0000000..98bc6da --- /dev/null +++ b/pkg/module/auth/module.go @@ -0,0 +1,84 @@ +package auth + +import ( + "net/http" + "strings" + + "forge.cadoles.com/arcad/edge/pkg/app" + "forge.cadoles.com/arcad/edge/pkg/module" + "forge.cadoles.com/arcad/edge/pkg/module/util" + "github.com/dop251/goja" + "github.com/golang-jwt/jwt" + "github.com/pkg/errors" +) + +const ( + AnonymousSubject = "anonymous" +) + +type Module struct { + server *app.Server + keyFunc jwt.Keyfunc +} + +func (m *Module) Name() string { + return "auth" +} + +func (m *Module) Export(export *goja.Object) { + if err := export.Set("getSubject", m.getSubject); err != nil { + panic(errors.Wrap(err, "could not set 'getSubject' function")) + } + + if err := export.Set("ANONYMOUS", AnonymousSubject); err != nil { + panic(errors.Wrap(err, "could not set 'ANONYMOUS_USER' property")) + } +} + +func (m *Module) getSubject(call goja.FunctionCall, rt *goja.Runtime) goja.Value { + ctx := util.AssertContext(call.Argument(0), rt) + + req, ok := ctx.Value(module.ContextKeyOriginRequest).(*http.Request) + if !ok { + panic(errors.New("could not find http request in context")) + } + + rawToken := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ") + if rawToken == "" { + rawToken = req.URL.Query().Get("token") + } + + if rawToken == "" { + return rt.ToValue(AnonymousSubject) + } + + token, err := jwt.Parse(rawToken, m.keyFunc) + if err != nil { + panic(errors.WithStack(err)) + } + + if !token.Valid { + panic(errors.Errorf("invalid jwt token: '%v'", token.Raw)) + } + + mapClaims, ok := token.Claims.(jwt.MapClaims) + if !ok { + panic(errors.Errorf("unexpected claims type '%T'", token.Claims)) + } + + subject, exists := mapClaims["sub"] + if !exists { + return rt.ToValue(AnonymousSubject) + } + + return rt.ToValue(subject) +} + +func ModuleFactory(keyFunc jwt.Keyfunc) app.ServerModuleFactory { + return func(server *app.Server) app.ServerModule { + return &Module{ + server: server, + keyFunc: keyFunc, + } + } +} diff --git a/pkg/module/auth/module_test.go b/pkg/module/auth/module_test.go new file mode 100644 index 0000000..f138807 --- /dev/null +++ b/pkg/module/auth/module_test.go @@ -0,0 +1,121 @@ +package auth + +import ( + "context" + "fmt" + "io/ioutil" + "net/http" + "testing" + "time" + + "cdr.dev/slog" + "forge.cadoles.com/arcad/edge/pkg/app" + "forge.cadoles.com/arcad/edge/pkg/module" + "github.com/golang-jwt/jwt" + "github.com/pkg/errors" + "gitlab.com/wpetit/goweb/logger" +) + +func TestAuthModule(t *testing.T) { + t.Parallel() + + logger.SetLevel(slog.LevelDebug) + + keyFunc, secret := getKeyFunc() + + server := app.NewServer( + module.ConsoleModuleFactory(), + ModuleFactory(keyFunc), + ) + + data, err := ioutil.ReadFile("testdata/auth.js") + if err != nil { + t.Fatal(err) + } + + if err := server.Load("testdata/auth.js", string(data)); err != nil { + t.Fatal(err) + } + + if err := server.Start(); err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + defer server.Stop() + + req, err := http.NewRequest("GET", "/foo", nil) + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, jwt.MapClaims{ + "sub": "jdoe", + "nbf": time.Now().UTC().Unix(), + }) + + rawToken, err := token.SignedString(secret) + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + req.Header.Add("Authorization", "Bearer "+rawToken) + + ctx := context.WithValue(context.Background(), module.ContextKeyOriginRequest, req) + + if _, err := server.ExecFuncByName("testAuth", ctx); err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } +} + +func TestAuthAnonymousModule(t *testing.T) { + t.Parallel() + + logger.SetLevel(slog.LevelDebug) + + keyFunc, _ := getKeyFunc() + + server := app.NewServer( + module.ConsoleModuleFactory(), + ModuleFactory(keyFunc), + ) + + data, err := ioutil.ReadFile("testdata/auth_anonymous.js") + if err != nil { + t.Fatal(err) + } + + if err := server.Load("testdata/auth_anonymous.js", string(data)); err != nil { + t.Fatal(err) + } + + if err := server.Start(); err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + defer server.Stop() + + req, err := http.NewRequest("GET", "/foo", nil) + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + ctx := context.WithValue(context.Background(), module.ContextKeyOriginRequest, req) + + if _, err := server.ExecFuncByName("testAuth", ctx); err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } +} + +func getKeyFunc() (jwt.Keyfunc, []byte) { + secret := []byte("not_so_secret") + + keyFunc := func(t *jwt.Token) (interface{}, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("Unexpected signing method: %v", t.Header["alg"]) + } + + return secret, nil + } + + return keyFunc, secret +} diff --git a/pkg/module/auth/testdata/auth.js b/pkg/module/auth/testdata/auth.js new file mode 100644 index 0000000..65e6d5f --- /dev/null +++ b/pkg/module/auth/testdata/auth.js @@ -0,0 +1,9 @@ + + +function testAuth(ctx) { + var subject = auth.getSubject(ctx); + + if (subject !== "jdoe") { + throw new Error("subject: expected 'jdoe', got '"+subject+"'"); + } +} \ No newline at end of file diff --git a/pkg/module/auth/testdata/auth_anonymous.js b/pkg/module/auth/testdata/auth_anonymous.js new file mode 100644 index 0000000..edda6e7 --- /dev/null +++ b/pkg/module/auth/testdata/auth_anonymous.js @@ -0,0 +1,9 @@ + + +function testAuth(ctx) { + var subject = auth.getSubject(ctx); + + if (subject !== auth.ANONYMOUS) { + throw new Error("subject: expected '"+auth.ANONYMOUS+"', got '"+subject+"'"); + } +} \ No newline at end of file diff --git a/pkg/module/context.go b/pkg/module/context.go index 52ee3d9..67f8a57 100644 --- a/pkg/module/context.go +++ b/pkg/module/context.go @@ -4,6 +4,7 @@ import ( "context" "forge.cadoles.com/arcad/edge/pkg/app" + "forge.cadoles.com/arcad/edge/pkg/module/util" "github.com/dop251/goja" "github.com/pkg/errors" ) @@ -26,8 +27,8 @@ func (m *ContextModule) new(call goja.FunctionCall, rt *goja.Runtime) goja.Value } func (m *ContextModule) with(call goja.FunctionCall, rt *goja.Runtime) goja.Value { - ctx := assertContext(call.Argument(0), rt) - rawValues := assertObject(call.Argument(1), rt) + ctx := util.AssertContext(call.Argument(0), rt) + rawValues := util.AssertObject(call.Argument(1), rt) values := make(map[ContextKey]any) for k, v := range rawValues { @@ -40,8 +41,8 @@ func (m *ContextModule) with(call goja.FunctionCall, rt *goja.Runtime) goja.Valu } func (m *ContextModule) get(call goja.FunctionCall, rt *goja.Runtime) goja.Value { - ctx := assertContext(call.Argument(0), rt) - rawKey := assertString(call.Argument(1), rt) + ctx := util.AssertContext(call.Argument(0), rt) + rawKey := util.AssertString(call.Argument(1), rt) value := ctx.Value(ContextKey(rawKey)) diff --git a/pkg/module/net.go b/pkg/module/net.go index 15510af..be38efa 100644 --- a/pkg/module/net.go +++ b/pkg/module/net.go @@ -5,6 +5,7 @@ import ( "forge.cadoles.com/arcad/edge/pkg/app" "forge.cadoles.com/arcad/edge/pkg/bus" + "forge.cadoles.com/arcad/edge/pkg/module/util" "github.com/dop251/goja" "github.com/pkg/errors" ) @@ -58,7 +59,7 @@ func (m *NetModule) send(call goja.FunctionCall, rt *goja.Runtime) goja.Value { ContextKeySessionID: sessionID, }) } else { - ctx = assertContext(firstArg, rt) + ctx = util.AssertContext(firstArg, rt) } data := call.Argument(1).Export() diff --git a/pkg/module/rpc.go b/pkg/module/rpc.go index 4033834..d309581 100644 --- a/pkg/module/rpc.go +++ b/pkg/module/rpc.go @@ -7,6 +7,7 @@ import ( "forge.cadoles.com/arcad/edge/pkg/app" "forge.cadoles.com/arcad/edge/pkg/bus" + "forge.cadoles.com/arcad/edge/pkg/module/util" "github.com/dop251/goja" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" @@ -51,7 +52,7 @@ func (m *RPCModule) Export(export *goja.Object) { } func (m *RPCModule) register(call goja.FunctionCall, rt *goja.Runtime) goja.Value { - fnName := assertString(call.Argument(0), rt) + fnName := util.AssertString(call.Argument(0), rt) var ( callable goja.Callable @@ -78,7 +79,7 @@ func (m *RPCModule) register(call goja.FunctionCall, rt *goja.Runtime) goja.Valu } func (m *RPCModule) unregister(call goja.FunctionCall, rt *goja.Runtime) goja.Value { - fnName := assertString(call.Argument(0), rt) + fnName := util.AssertString(call.Argument(0), rt) m.callbacks.Delete(fnName) diff --git a/pkg/module/store.go b/pkg/module/store.go index bb30735..0a6df63 100644 --- a/pkg/module/store.go +++ b/pkg/module/store.go @@ -4,6 +4,7 @@ import ( "fmt" "forge.cadoles.com/arcad/edge/pkg/app" + "forge.cadoles.com/arcad/edge/pkg/module/util" "forge.cadoles.com/arcad/edge/pkg/storage" "forge.cadoles.com/arcad/edge/pkg/storage/filter" "github.com/dop251/goja" @@ -47,7 +48,7 @@ func (m *StoreModule) Export(export *goja.Object) { } func (m *StoreModule) upsert(call goja.FunctionCall, rt *goja.Runtime) goja.Value { - ctx := assertContext(call.Argument(0), rt) + ctx := util.AssertContext(call.Argument(0), rt) collection := m.assertCollection(call.Argument(1), rt) document := m.assertDocument(call.Argument(2), rt) @@ -60,7 +61,7 @@ func (m *StoreModule) upsert(call goja.FunctionCall, rt *goja.Runtime) goja.Valu } func (m *StoreModule) get(call goja.FunctionCall, rt *goja.Runtime) goja.Value { - ctx := assertContext(call.Argument(0), rt) + ctx := util.AssertContext(call.Argument(0), rt) collection := m.assertCollection(call.Argument(1), rt) documentID := m.assertDocumentID(call.Argument(2), rt) @@ -84,7 +85,7 @@ type queryOptions struct { } func (m *StoreModule) query(call goja.FunctionCall, rt *goja.Runtime) goja.Value { - ctx := assertContext(call.Argument(0), rt) + ctx := util.AssertContext(call.Argument(0), rt) collection := m.assertCollection(call.Argument(1), rt) filter := m.assertFilter(call.Argument(2), rt) queryOptions := m.assertQueryOptions(call.Argument(3), rt) @@ -125,7 +126,7 @@ func (m *StoreModule) query(call goja.FunctionCall, rt *goja.Runtime) goja.Value } func (m *StoreModule) delete(call goja.FunctionCall, rt *goja.Runtime) goja.Value { - ctx := assertContext(call.Argument(0), rt) + ctx := util.AssertContext(call.Argument(0), rt) collection := m.assertCollection(call.Argument(1), rt) documentID := m.assertDocumentID(call.Argument(2), rt) diff --git a/pkg/module/util/assert.go b/pkg/module/util/assert.go new file mode 100644 index 0000000..ffe7e69 --- /dev/null +++ b/pkg/module/util/assert.go @@ -0,0 +1,28 @@ +package util + +import ( + "context" + "fmt" + + "github.com/dop251/goja" +) + +func AssertType[T any](v goja.Value, rt *goja.Runtime) T { + if c, ok := v.Export().(T); ok { + return c + } + + panic(rt.NewTypeError(fmt.Sprintf("expected value to be a '%T', got '%T'", new(T), v.Export()))) +} + +func AssertContext(v goja.Value, r *goja.Runtime) context.Context { + return AssertType[context.Context](v, r) +} + +func AssertObject(v goja.Value, r *goja.Runtime) map[string]any { + return AssertType[map[string]any](v, r) +} + +func AssertString(v goja.Value, r *goja.Runtime) string { + return AssertType[string](v, r) +}