diff --git a/cmd/cli/command/app/run.go b/cmd/cli/command/app/run.go index 59e9f11..89eb0a6 100644 --- a/cmd/cli/command/app/run.go +++ b/cmd/cli/command/app/run.go @@ -2,19 +2,28 @@ package app import ( "database/sql" + "fmt" "net/http" "path/filepath" + "time" + "forge.cadoles.com/arcad/edge/pkg/app" + "forge.cadoles.com/arcad/edge/pkg/bus" "forge.cadoles.com/arcad/edge/pkg/bus/memory" appHTTP "forge.cadoles.com/arcad/edge/pkg/http" "forge.cadoles.com/arcad/edge/pkg/module" + "forge.cadoles.com/arcad/edge/pkg/module/auth" "forge.cadoles.com/arcad/edge/pkg/module/cast" + "forge.cadoles.com/arcad/edge/pkg/module/net" + "forge.cadoles.com/arcad/edge/pkg/storage" "forge.cadoles.com/arcad/edge/pkg/storage/sqlite" "gitlab.com/wpetit/goweb/logger" "forge.cadoles.com/arcad/edge/pkg/bundle" + "github.com/dop251/goja" "github.com/go-chi/chi/v5" "github.com/go-chi/chi/v5/middleware" + "github.com/golang-jwt/jwt" "github.com/pkg/errors" "github.com/urfave/cli/v2" @@ -53,14 +62,35 @@ func RunCommand() *cli.Command { Usage: "use `FILE` for SQLite storage database", Value: "data.sqlite", }, + &cli.StringFlag{ + Name: "auth-subject", + Usage: "set the `SUBJECT` associated with the simulated connected user", + Value: "jdoe", + }, + &cli.StringFlag{ + Name: "auth-role", + Usage: "set the `ROLE` associated with the simulated connected user", + Value: "user", + }, + &cli.StringFlag{ + Name: "auth-preferred-username", + Usage: "set the `PREFERRED_USERNAME` associated with the simulated connected user", + Value: "Jane Doe", + }, }, Action: func(ctx *cli.Context) error { address := ctx.String("address") path := ctx.String("path") + logFormat := ctx.String("log-format") logLevel := ctx.Int("log-level") + storageFile := ctx.String("storage-file") + authSubject := ctx.String("auth-subject") + authRole := ctx.String("auth-role") + authPreferredUsername := ctx.String("auth-preferred-username") + logger.SetFormat(logger.Format(logFormat)) logger.SetLevel(logger.Level(logLevel)) @@ -81,6 +111,7 @@ func RunCommand() *cli.Command { mux := chi.NewMux() mux.Use(middleware.Logger) + mux.Use(dummyAuthMiddleware(authSubject, authRole, authPreferredUsername)) bus := memory.NewBus() @@ -89,21 +120,12 @@ func RunCommand() *cli.Command { return errors.Wrapf(err, "could not open database with path '%s'", storageFile) } - documentStore := sqlite.NewDocumentStoreWithDB(db) - blobStore := sqlite.NewBlobStoreWithDB(db) + ds := sqlite.NewDocumentStoreWithDB(db) + bs := sqlite.NewBlobStoreWithDB(db) handler := appHTTP.NewHandler( appHTTP.WithBus(bus), - appHTTP.WithServerModules( - module.ContextModuleFactory(), - module.ConsoleModuleFactory(), - cast.CastModuleFactory(), - module.LifecycleModuleFactory(bus), - module.NetModuleFactory(bus), - module.RPCModuleFactory(bus), - module.StoreModuleFactory(documentStore), - module.BlobModuleFactory(bus, blobStore), - ), + appHTTP.WithServerModules(getServerModules(bus, ds, bs)...), ) if err := handler.Load(bundle); err != nil { return errors.Wrap(err, "could not load app bundle") @@ -121,3 +143,89 @@ func RunCommand() *cli.Command { }, } } + +func getServerModules(bus bus.Bus, ds storage.DocumentStore, bs storage.BlobStore) []app.ServerModuleFactory { + return []app.ServerModuleFactory{ + module.ContextModuleFactory(), + module.ConsoleModuleFactory(), + cast.CastModuleFactory(), + module.LifecycleModuleFactory(), + net.ModuleFactory(bus), + module.RPCModuleFactory(bus), + module.StoreModuleFactory(ds), + module.BlobModuleFactory(bus, bs), + module.Extends( + auth.ModuleFactory( + auth.WithJWT(dummyKeyFunc), + ), + func(o *goja.Object) { + if err := o.Set("CLAIM_ROLE", "role"); err != nil { + panic(errors.New("could not set 'CLAIM_ROLE' property")) + } + + if err := o.Set("CLAIM_PREFERRED_USERNAME", "preferred_username"); err != nil { + panic(errors.New("could not set 'CLAIM_PREFERRED_USERNAME' property")) + } + }, + ), + } +} + +var dummySecret = []byte("not_so_secret") + +func dummyKeyFunc(t *jwt.Token) (interface{}, error) { + if _, ok := t.Method.(*jwt.SigningMethodHMAC); !ok { + return nil, fmt.Errorf("Unexpected signing method: %v", t.Header["alg"]) + } + + return dummySecret, nil +} + +func dummyAuthMiddleware(subject, role, username string) func(http.Handler) http.Handler { + return func(h http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + unauthenticated := subject == "" && role == "" && username == "" + + if unauthenticated { + h.ServeHTTP(w, r) + + return + } + + claims := jwt.MapClaims{ + "nbf": time.Now().UTC().Unix(), + } + + if subject != "" { + claims["sub"] = subject + } + + if role != "" { + claims["role"] = role + } + + if username != "" { + claims["preferred_username"] = username + } + + token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims) + + ctx := r.Context() + + rawToken, err := token.SignedString(dummySecret) + if err != nil { + logger.Error(ctx, "could not sign token", logger.E(errors.WithStack(err))) + + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + + return + } + + r.Header.Add("Authorization", "Bearer "+rawToken) + + h.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) + } +} diff --git a/doc/apps/server-api/README.md b/doc/apps/server-api/README.md index febff10..5c986c7 100644 --- a/doc/apps/server-api/README.md +++ b/doc/apps/server-api/README.md @@ -20,10 +20,11 @@ function onInit() { Listes des modules disponibles côté serveur. +- [`auth`](./auth.md) +- [`blob`](./blob.md) +- [`cast`](./cast.md) - [`console`](./console.md) - [`context`](./context.md) - [`net`](./net.md) - [`rpc`](./rpc.md) - [`store`](./store.md) -- [`blob`](./blob.md) -- [`cast`](./cast.md) \ No newline at end of file diff --git a/doc/apps/server-api/auth.md b/doc/apps/server-api/auth.md new file mode 100644 index 0000000..fbfe9b0 --- /dev/null +++ b/doc/apps/server-api/auth.md @@ -0,0 +1,41 @@ +# Module `auth` + +Ce module permet de récupérer des informations concernant l'utilisateur connecté et ses attributs. + +## Méthodes + +### `auth.getClaim(ctx: Context, name: string): string` + +Récupère un attribut associé à l'utilisateur. + +#### Arguments + +- `ctx` **Context** Le contexte d'exécution. Voir la documentation du module [`context`](./context.md) +- `name` **string** Nom de l'attribut à retrouver + +#### Valeur de retour + +Valeur de l'attribut associé ou vide si la requête est non authentifiée ou que l'attribut n'a pas été trouvé. + +#### Usage + +```js +function onClientMessage(ctx, message) { + var subject = auth.getClaim(ctx, auth.CLAIM_SUBJECT); + console.log("Connected user is", subject); +} +``` + +## Propriétés + +### `auth.CLAIM_SUBJECT` + +Cette propriété identifie l'utilisateur connecté. Si la valeur retournée par la méthode `getClaim()` est vide, alors l'utilisateur n'est pas connecté. + +### `auth.CLAIM_ROLE` + +Cette propriété retourne le rôle de l'utilisateur connecté au sein du "tenant" courant. Si la valeur retournée par la méthode `getClaim()` est vide, alors l'utilisateur n'est pas connecté. + +### `auth.PREFERRED_USERNAME` + +Cette propriété retourne le nom "préféré pour l'affichage" de l'utilisateur connecté. Si la valeur retournée par la méthode `getClaim()` est vide, alors l'utilisateur n'est pas connecté ou l'utilisateur n'a pas défini de nom d'utilisateur particulier. \ No newline at end of file diff --git a/misc/client-sdk-testsuite/src/public/index.html b/misc/client-sdk-testsuite/src/public/index.html index f4e87d4..38034f9 100644 --- a/misc/client-sdk-testsuite/src/public/index.html +++ b/misc/client-sdk-testsuite/src/public/index.html @@ -21,6 +21,7 @@ + diff --git a/misc/client-sdk-testsuite/src/public/test/auth-module.js b/misc/client-sdk-testsuite/src/public/test/auth-module.js new file mode 100644 index 0000000..856d748 --- /dev/null +++ b/misc/client-sdk-testsuite/src/public/test/auth-module.js @@ -0,0 +1,20 @@ +describe('Auth Module', function() { + + before(() => { + return Edge.connect(); + }); + + after(() => { + Edge.disconnect(); + }); + + it('should retrieve user informations', function() { + return Edge.rpc("getUserInfo") + .then(userInfo => { + chai.assert.isNotNull(userInfo.subject); + chai.assert.isNotNull(userInfo.role); + chai.assert.isNotNull(userInfo.preferredUsername); + }) + }); + +}); \ No newline at end of file diff --git a/misc/client-sdk-testsuite/src/server/main.js b/misc/client-sdk-testsuite/src/server/main.js index 5e4e235..b7e8a22 100644 --- a/misc/client-sdk-testsuite/src/server/main.js +++ b/misc/client-sdk-testsuite/src/server/main.js @@ -10,12 +10,12 @@ function onInit() { rpc.register("add", add); rpc.register("reset", reset); rpc.register("total", total); + rpc.register("getUserInfo", getUserInfo); } // Called for each client message function onClientMessage(ctx, data) { - var sessionId = context.get(ctx, context.SESSION_ID); - console.log("onClientMessage", sessionId, data.now); + console.log("onClientMessage", data.now); net.send(ctx, { now: data.now }); } @@ -60,4 +60,16 @@ function reset(ctx, params) { function total(ctx, params) { return count; +} + +function getUserInfo(ctx, params) { + var subject = auth.getClaim(ctx, auth.CLAIM_SUBJECT); + var role = auth.getClaim(ctx, auth.CLAIM_ROLE); + var preferredUsername = auth.getClaim(ctx, auth.CLAIM_PREFERRED_USERNAME); + + return { + subject: subject, + role: role, + preferredUsername: preferredUsername, + }; } \ No newline at end of file diff --git a/pkg/http/blob.go b/pkg/http/blob.go index ecf31dd..5e5bb7b 100644 --- a/pkg/http/blob.go +++ b/pkg/http/blob.go @@ -65,7 +65,7 @@ func (h *Handler) handleAppUpload(w http.ResponseWriter, r *http.Request) { } ctx = module.WithContext(ctx, map[module.ContextKey]any{ - module.ContextKeyOriginRequest: r, + ContextKeyOriginRequest: r, }) requestMsg := module.NewMessageUploadRequest(ctx, fileHeader, metadata) @@ -117,7 +117,7 @@ func (h *Handler) handleAppDownload(w http.ResponseWriter, r *http.Request) { ctx := logger.With(r.Context(), logger.F("blobID", blobID), logger.F("bucket", bucket)) ctx = module.WithContext(ctx, map[module.ContextKey]any{ - module.ContextKeyOriginRequest: r, + ContextKeyOriginRequest: r, }) requestMsg := module.NewMessageDownloadRequest(ctx, bucket, storage.BlobID(blobID)) diff --git a/pkg/http/sockjs.go b/pkg/http/sockjs.go index 3a29d35..4ce2bbf 100644 --- a/pkg/http/sockjs.go +++ b/pkg/http/sockjs.go @@ -15,6 +15,11 @@ const ( statusChannelClosed = iota ) +const ( + ContextKeySessionID module.ContextKey = "sessionId" + ContextKeyOriginRequest module.ContextKey = "originRequest" +) + func (h *Handler) handleSockJS(w http.ResponseWriter, r *http.Request) { h.mutex.RLock() defer h.mutex.RUnlock() @@ -79,7 +84,7 @@ func (h *Handler) handleServerMessages(ctx context.Context, sess sockjs.Session) continue } - sessionID := module.ContextValue[string](serverMessage.Context, module.ContextKeySessionID) + sessionID := module.ContextValue[string](serverMessage.Context, ContextKeySessionID) isDest := sessionID == "" || sessionID == sess.ID() if !isDest { @@ -182,8 +187,8 @@ func (h *Handler) handleClientMessages(ctx context.Context, sess sockjs.Session) ctx := logger.With(ctx, logger.F("payload", payload)) ctx = module.WithContext(ctx, map[module.ContextKey]any{ - module.ContextKeySessionID: sess.ID(), - module.ContextKeyOriginRequest: sess.Request(), + ContextKeySessionID: sess.ID(), + ContextKeyOriginRequest: sess.Request(), }) clientMessage := module.NewClientMessage(ctx, payload) diff --git a/pkg/module/auth/error.go b/pkg/module/auth/error.go new file mode 100644 index 0000000..046dd02 --- /dev/null +++ b/pkg/module/auth/error.go @@ -0,0 +1,8 @@ +package auth + +import "errors" + +var ( + ErrUnauthenticated = errors.New("unauthenticated") + ErrClaimNotFound = errors.New("claim not found") +) diff --git a/pkg/module/auth/jwt.go b/pkg/module/auth/jwt.go new file mode 100644 index 0000000..ab01af7 --- /dev/null +++ b/pkg/module/auth/jwt.go @@ -0,0 +1,60 @@ +package auth + +import ( + "context" + "net/http" + "strings" + + "github.com/golang-jwt/jwt" + "github.com/pkg/errors" +) + +func WithJWT(keyFunc jwt.Keyfunc) OptionFunc { + return func(o *Option) { + o.GetClaim = func(ctx context.Context, r *http.Request, claimName string) (string, error) { + claim, err := getClaim[string](r, claimName, keyFunc) + if err != nil { + return "", errors.WithStack(err) + } + + return claim, nil + } + } +} + +func getClaim[T any](r *http.Request, claimAttr string, keyFunc jwt.Keyfunc) (T, error) { + rawToken := strings.TrimPrefix(r.Header.Get("Authorization"), "Bearer ") + if rawToken == "" { + rawToken = r.URL.Query().Get("token") + } + + if rawToken == "" { + return *new(T), errors.WithStack(ErrUnauthenticated) + } + + token, err := jwt.Parse(rawToken, keyFunc) + if err != nil { + return *new(T), errors.WithStack(err) + } + + if !token.Valid { + return *new(T), errors.Errorf("invalid jwt token: '%v'", token.Raw) + } + + mapClaims, ok := token.Claims.(jwt.MapClaims) + if !ok { + return *new(T), errors.Errorf("unexpected claims type '%T'", token.Claims) + } + + rawClaim, exists := mapClaims[claimAttr] + if !exists { + return *new(T), errors.WithStack(ErrClaimNotFound) + } + + claim, ok := rawClaim.(T) + if !ok { + return *new(T), errors.Errorf("unexpected claim '%s' to be of type '%T', got '%T'", claimAttr, new(T), rawClaim) + } + + return claim, nil +} diff --git a/pkg/module/auth/module.go b/pkg/module/auth/module.go index 98bc6da..8705f93 100644 --- a/pkg/module/auth/module.go +++ b/pkg/module/auth/module.go @@ -2,23 +2,21 @@ package auth import ( "net/http" - "strings" "forge.cadoles.com/arcad/edge/pkg/app" - "forge.cadoles.com/arcad/edge/pkg/module" + edgeHTTP "forge.cadoles.com/arcad/edge/pkg/http" "forge.cadoles.com/arcad/edge/pkg/module/util" "github.com/dop251/goja" - "github.com/golang-jwt/jwt" "github.com/pkg/errors" ) const ( - AnonymousSubject = "anonymous" + ClaimSubject = "sub" ) type Module struct { - server *app.Server - keyFunc jwt.Keyfunc + server *app.Server + getClaimFunc GetClaimFunc } func (m *Module) Name() string { @@ -26,59 +24,46 @@ func (m *Module) Name() string { } func (m *Module) Export(export *goja.Object) { - if err := export.Set("getSubject", m.getSubject); err != nil { - panic(errors.Wrap(err, "could not set 'getSubject' function")) + if err := export.Set("getClaim", m.getClaim); err != nil { + panic(errors.Wrap(err, "could not set 'getClaim' function")) } - if err := export.Set("ANONYMOUS", AnonymousSubject); err != nil { - panic(errors.Wrap(err, "could not set 'ANONYMOUS_USER' property")) + if err := export.Set("CLAIM_SUBJECT", ClaimSubject); err != nil { + panic(errors.Wrap(err, "could not set 'CLAIM_SUBJECT' property")) } } -func (m *Module) getSubject(call goja.FunctionCall, rt *goja.Runtime) goja.Value { +func (m *Module) getClaim(call goja.FunctionCall, rt *goja.Runtime) goja.Value { ctx := util.AssertContext(call.Argument(0), rt) + claimName := util.AssertString(call.Argument(1), rt) - req, ok := ctx.Value(module.ContextKeyOriginRequest).(*http.Request) + req, ok := ctx.Value(edgeHTTP.ContextKeyOriginRequest).(*http.Request) if !ok { - panic(errors.New("could not find http request in context")) + panic(rt.ToValue(errors.New("could not find http request in context"))) } - rawToken := strings.TrimPrefix(req.Header.Get("Authorization"), "Bearer ") - if rawToken == "" { - rawToken = req.URL.Query().Get("token") - } - - if rawToken == "" { - return rt.ToValue(AnonymousSubject) - } - - token, err := jwt.Parse(rawToken, m.keyFunc) + claim, err := m.getClaimFunc(ctx, req, claimName) if err != nil { - panic(errors.WithStack(err)) + if errors.Is(err, ErrUnauthenticated) || errors.Is(err, ErrClaimNotFound) { + return nil + } + + panic(rt.ToValue(errors.WithStack(err))) } - if !token.Valid { - panic(errors.Errorf("invalid jwt token: '%v'", token.Raw)) - } - - mapClaims, ok := token.Claims.(jwt.MapClaims) - if !ok { - panic(errors.Errorf("unexpected claims type '%T'", token.Claims)) - } - - subject, exists := mapClaims["sub"] - if !exists { - return rt.ToValue(AnonymousSubject) - } - - return rt.ToValue(subject) + return rt.ToValue(claim) } -func ModuleFactory(keyFunc jwt.Keyfunc) app.ServerModuleFactory { +func ModuleFactory(funcs ...OptionFunc) app.ServerModuleFactory { + opt := &Option{} + for _, fn := range funcs { + fn(opt) + } + return func(server *app.Server) app.ServerModule { return &Module{ - server: server, - keyFunc: keyFunc, + server: server, + getClaimFunc: opt.GetClaim, } } } diff --git a/pkg/module/auth/module_test.go b/pkg/module/auth/module_test.go index f138807..0a7d908 100644 --- a/pkg/module/auth/module_test.go +++ b/pkg/module/auth/module_test.go @@ -10,6 +10,7 @@ import ( "cdr.dev/slog" "forge.cadoles.com/arcad/edge/pkg/app" + edgeHTTP "forge.cadoles.com/arcad/edge/pkg/http" "forge.cadoles.com/arcad/edge/pkg/module" "github.com/golang-jwt/jwt" "github.com/pkg/errors" @@ -25,7 +26,9 @@ func TestAuthModule(t *testing.T) { server := app.NewServer( module.ConsoleModuleFactory(), - ModuleFactory(keyFunc), + ModuleFactory( + WithJWT(keyFunc), + ), ) data, err := ioutil.ReadFile("testdata/auth.js") @@ -60,7 +63,7 @@ func TestAuthModule(t *testing.T) { req.Header.Add("Authorization", "Bearer "+rawToken) - ctx := context.WithValue(context.Background(), module.ContextKeyOriginRequest, req) + ctx := context.WithValue(context.Background(), edgeHTTP.ContextKeyOriginRequest, req) if _, err := server.ExecFuncByName("testAuth", ctx); err != nil { t.Fatalf("%+v", errors.WithStack(err)) @@ -76,7 +79,7 @@ func TestAuthAnonymousModule(t *testing.T) { server := app.NewServer( module.ConsoleModuleFactory(), - ModuleFactory(keyFunc), + ModuleFactory(WithJWT(keyFunc)), ) data, err := ioutil.ReadFile("testdata/auth_anonymous.js") @@ -99,7 +102,7 @@ func TestAuthAnonymousModule(t *testing.T) { t.Fatalf("%+v", errors.WithStack(err)) } - ctx := context.WithValue(context.Background(), module.ContextKeyOriginRequest, req) + ctx := context.WithValue(context.Background(), edgeHTTP.ContextKeyOriginRequest, req) if _, err := server.ExecFuncByName("testAuth", ctx); err != nil { t.Fatalf("%+v", errors.WithStack(err)) diff --git a/pkg/module/auth/option.go b/pkg/module/auth/option.go new file mode 100644 index 0000000..fa2a375 --- /dev/null +++ b/pkg/module/auth/option.go @@ -0,0 +1,20 @@ +package auth + +import ( + "context" + "net/http" +) + +type GetClaimFunc func(ctx context.Context, r *http.Request, claimName string) (string, error) + +type Option struct { + GetClaim GetClaimFunc +} + +type OptionFunc func(*Option) + +func WithGetClaim(fn GetClaimFunc) OptionFunc { + return func(o *Option) { + o.GetClaim = fn + } +} diff --git a/pkg/module/auth/testdata/auth.js b/pkg/module/auth/testdata/auth.js index 65e6d5f..fb49306 100644 --- a/pkg/module/auth/testdata/auth.js +++ b/pkg/module/auth/testdata/auth.js @@ -1,7 +1,7 @@ function testAuth(ctx) { - var subject = auth.getSubject(ctx); + var subject = auth.getClaim(ctx, auth.CLAIM_SUBJECT); if (subject !== "jdoe") { throw new Error("subject: expected 'jdoe', got '"+subject+"'"); diff --git a/pkg/module/auth/testdata/auth_anonymous.js b/pkg/module/auth/testdata/auth_anonymous.js index edda6e7..ff705da 100644 --- a/pkg/module/auth/testdata/auth_anonymous.js +++ b/pkg/module/auth/testdata/auth_anonymous.js @@ -1,9 +1,9 @@ function testAuth(ctx) { - var subject = auth.getSubject(ctx); + var subject = auth.getClaim(ctx, auth.CLAIM_SUBJECT); - if (subject !== auth.ANONYMOUS) { - throw new Error("subject: expected '"+auth.ANONYMOUS+"', got '"+subject+"'"); + if (subject !== undefined) { + throw new Error("subject: expected undefined, got '"+subject+"'"); } } \ No newline at end of file diff --git a/pkg/module/cast/module.go b/pkg/module/cast/module.go index a23c4f1..8707499 100644 --- a/pkg/module/cast/module.go +++ b/pkg/module/cast/module.go @@ -60,7 +60,7 @@ func (m *Module) refreshDevices(call goja.FunctionCall, rt *goja.Runtime) goja.V timeout, err := m.parseTimeout(rawTimeout) if err != nil { - panic(errors.WithStack(err)) + panic(rt.ToValue(errors.WithStack(err))) } promise := m.server.NewPromise() @@ -106,7 +106,7 @@ func (m *Module) getDevices(call goja.FunctionCall, rt *goja.Runtime) goja.Value func (m *Module) loadUrl(call goja.FunctionCall, rt *goja.Runtime) goja.Value { if len(call.Arguments) < 2 { - panic(errors.WithStack(module.ErrUnexpectedArgumentsNumber)) + panic(rt.ToValue(errors.WithStack(module.ErrUnexpectedArgumentsNumber))) } deviceUUID := call.Argument(0).String() @@ -116,7 +116,7 @@ func (m *Module) loadUrl(call goja.FunctionCall, rt *goja.Runtime) goja.Value { timeout, err := m.parseTimeout(rawTimeout) if err != nil { - panic(errors.WithStack(err)) + panic(rt.ToValue(errors.WithStack(err))) } promise := m.server.NewPromise() @@ -144,9 +144,9 @@ func (m *Module) loadUrl(call goja.FunctionCall, rt *goja.Runtime) goja.Value { return m.server.ToValue(promise) } -func (m *Module) stopCast(call goja.FunctionCall) goja.Value { +func (m *Module) stopCast(call goja.FunctionCall, rt *goja.Runtime) goja.Value { if len(call.Arguments) < 1 { - panic(errors.WithStack(module.ErrUnexpectedArgumentsNumber)) + panic(rt.ToValue(errors.WithStack(module.ErrUnexpectedArgumentsNumber))) } deviceUUID := call.Argument(0).String() @@ -154,7 +154,7 @@ func (m *Module) stopCast(call goja.FunctionCall) goja.Value { timeout, err := m.parseTimeout(rawTimeout) if err != nil { - panic(errors.WithStack(err)) + panic(rt.ToValue(errors.WithStack(err))) } promise := m.server.NewPromise() @@ -182,9 +182,9 @@ func (m *Module) stopCast(call goja.FunctionCall) goja.Value { return m.server.ToValue(promise) } -func (m *Module) getStatus(call goja.FunctionCall) goja.Value { +func (m *Module) getStatus(call goja.FunctionCall, rt *goja.Runtime) goja.Value { if len(call.Arguments) < 1 { - panic(errors.WithStack(module.ErrUnexpectedArgumentsNumber)) + panic(rt.ToValue(errors.WithStack(module.ErrUnexpectedArgumentsNumber))) } deviceUUID := call.Argument(0).String() @@ -192,7 +192,7 @@ func (m *Module) getStatus(call goja.FunctionCall) goja.Value { timeout, err := m.parseTimeout(rawTimeout) if err != nil { - panic(errors.WithStack(err)) + panic(rt.ToValue(errors.WithStack(err))) } promise := m.server.NewPromise() diff --git a/pkg/module/context.go b/pkg/module/context.go index 67f8a57..1db19e6 100644 --- a/pkg/module/context.go +++ b/pkg/module/context.go @@ -11,11 +11,6 @@ import ( type ContextKey string -const ( - ContextKeySessionID ContextKey = "sessionId" - ContextKeyOriginRequest ContextKey = "originRequest" -) - type ContextModule struct{} func (m *ContextModule) Name() string { @@ -61,14 +56,6 @@ func (m *ContextModule) Export(export *goja.Object) { if err := export.Set("with", m.with); err != nil { panic(errors.Wrap(err, "could not set 'with' function")) } - - if err := export.Set("ORIGIN_REQUEST", string(ContextKeyOriginRequest)); err != nil { - panic(errors.Wrap(err, "could not set 'ORIGIN_REQUEST' property")) - } - - if err := export.Set("SESSION_ID", string(ContextKeySessionID)); err != nil { - panic(errors.Wrap(err, "could not set 'SESSION_ID' property")) - } } func ContextModuleFactory() app.ServerModuleFactory { diff --git a/pkg/module/extension.go b/pkg/module/extension.go new file mode 100644 index 0000000..1fbe53a --- /dev/null +++ b/pkg/module/extension.go @@ -0,0 +1,40 @@ +package module + +import ( + "forge.cadoles.com/arcad/edge/pkg/app" + "github.com/dop251/goja" +) + +type ExtensionFunc func(*goja.Object) + +type ExtendedModule struct { + module app.ServerModule + extensions []ExtensionFunc +} + +// Export implements app.ServerModule. +func (m *ExtendedModule) Export(exports *goja.Object) { + m.module.Export(exports) + + for _, ext := range m.extensions { + ext(exports) + } +} + +// Name implements app.ServerModule. +func (m *ExtendedModule) Name() string { + return m.module.Name() +} + +func Extends(factory app.ServerModuleFactory, extensions ...ExtensionFunc) app.ServerModuleFactory { + return func(s *app.Server) app.ServerModule { + module := factory(s) + + return &ExtendedModule{ + module: module, + extensions: extensions, + } + } +} + +var _ app.ServerModule = &ExtendedModule{} diff --git a/pkg/module/lifecycle.go b/pkg/module/lifecycle.go index 0e9da6c..febd02b 100644 --- a/pkg/module/lifecycle.go +++ b/pkg/module/lifecycle.go @@ -4,7 +4,6 @@ import ( "context" "forge.cadoles.com/arcad/edge/pkg/app" - "forge.cadoles.com/arcad/edge/pkg/bus" "github.com/dop251/goja" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" @@ -12,7 +11,6 @@ import ( type LifecycleModule struct { server *app.Server - bus bus.Bus } func (m *LifecycleModule) Name() string { @@ -36,84 +34,12 @@ func (m *LifecycleModule) OnInit() error { return nil } -func (m *LifecycleModule) handleMessages() { - ctx := context.Background() - - logger.Debug( - ctx, - "subscribing to bus messages", - ) - - clientMessages, err := m.bus.Subscribe(ctx, MessageNamespaceClient) - if err != nil { - panic(errors.WithStack(err)) - } - - defer func() { - logger.Debug( - ctx, - "unsubscribing from bus messages", - ) - - m.bus.Unsubscribe(ctx, MessageNamespaceClient, clientMessages) - }() - - for { - logger.Debug( - ctx, - "waiting for next message", - ) - select { - case <-ctx.Done(): - logger.Debug( - ctx, - "context done", - ) - - return - - case msg := <-clientMessages: - clientMessage, ok := msg.(*ClientMessage) - if !ok { - logger.Error( - ctx, - "unexpected message type", - logger.F("message", msg), - ) - - continue - } - - logger.Debug( - ctx, - "received client message", - logger.F("message", clientMessage), - ) - - if _, err := m.server.ExecFuncByName("onClientMessage", clientMessage.Context, clientMessage.Data); err != nil { - if errors.Is(err, app.ErrFuncDoesNotExist) { - continue - } - - logger.Error( - ctx, - "on client message error", - logger.E(err), - ) - } - } - } -} - -func LifecycleModuleFactory(bus bus.Bus) app.ServerModuleFactory { +func LifecycleModuleFactory() app.ServerModuleFactory { return func(server *app.Server) app.ServerModule { module := &LifecycleModule{ server: server, - bus: bus, } - go module.handleMessages() - return module } } diff --git a/pkg/module/net.go b/pkg/module/net.go deleted file mode 100644 index be38efa..0000000 --- a/pkg/module/net.go +++ /dev/null @@ -1,82 +0,0 @@ -package module - -import ( - "context" - - "forge.cadoles.com/arcad/edge/pkg/app" - "forge.cadoles.com/arcad/edge/pkg/bus" - "forge.cadoles.com/arcad/edge/pkg/module/util" - "github.com/dop251/goja" - "github.com/pkg/errors" -) - -type NetModule struct { - server *app.Server - bus bus.Bus -} - -func (m *NetModule) Name() string { - return "net" -} - -func (m *NetModule) Export(export *goja.Object) { - if err := export.Set("broadcast", m.broadcast); err != nil { - panic(errors.Wrap(err, "could not set 'broadcast' function")) - } - - if err := export.Set("send", m.send); err != nil { - panic(errors.Wrap(err, "could not set 'send' function")) - } -} - -func (m *NetModule) broadcast(call goja.FunctionCall) goja.Value { - if len(call.Arguments) < 1 { - panic(m.server.ToValue("invalid number of argument")) - } - - data := call.Argument(0).Export() - - msg := NewServerMessage(nil, data) - if err := m.bus.Publish(context.Background(), msg); err != nil { - panic(errors.WithStack(err)) - } - - return nil -} - -func (m *NetModule) send(call goja.FunctionCall, rt *goja.Runtime) goja.Value { - if len(call.Arguments) < 2 { - panic(m.server.ToValue("invalid number of argument")) - } - - var ctx context.Context - - firstArg := call.Argument(0) - - sessionID, ok := firstArg.Export().(string) - if ok { - ctx = WithContext(context.Background(), map[ContextKey]any{ - ContextKeySessionID: sessionID, - }) - } else { - ctx = util.AssertContext(firstArg, rt) - } - - data := call.Argument(1).Export() - - msg := NewServerMessage(ctx, data) - if err := m.bus.Publish(ctx, msg); err != nil { - panic(errors.WithStack(err)) - } - - return nil -} - -func NetModuleFactory(bus bus.Bus) app.ServerModuleFactory { - return func(server *app.Server) app.ServerModule { - return &NetModule{ - server: server, - bus: bus, - } - } -} diff --git a/pkg/module/net/module.go b/pkg/module/net/module.go new file mode 100644 index 0000000..d1f8a03 --- /dev/null +++ b/pkg/module/net/module.go @@ -0,0 +1,158 @@ +package net + +import ( + "context" + + "forge.cadoles.com/arcad/edge/pkg/app" + "forge.cadoles.com/arcad/edge/pkg/bus" + edgeHTTP "forge.cadoles.com/arcad/edge/pkg/http" + "forge.cadoles.com/arcad/edge/pkg/module" + "forge.cadoles.com/arcad/edge/pkg/module/util" + "github.com/dop251/goja" + "github.com/pkg/errors" + "gitlab.com/wpetit/goweb/logger" +) + +type Module struct { + server *app.Server + bus bus.Bus +} + +func (m *Module) Name() string { + return "net" +} + +func (m *Module) Export(export *goja.Object) { + if err := export.Set("broadcast", m.broadcast); err != nil { + panic(errors.Wrap(err, "could not set 'broadcast' function")) + } + + if err := export.Set("send", m.send); err != nil { + panic(errors.Wrap(err, "could not set 'send' function")) + } +} + +func (m *Module) broadcast(call goja.FunctionCall, rt *goja.Runtime) goja.Value { + if len(call.Arguments) < 1 { + panic(rt.ToValue(errors.New("invalid number of argument"))) + } + + data := call.Argument(0).Export() + + msg := module.NewServerMessage(nil, data) + if err := m.bus.Publish(context.Background(), msg); err != nil { + panic(rt.ToValue(errors.WithStack(err))) + } + + return nil +} + +func (m *Module) send(call goja.FunctionCall, rt *goja.Runtime) goja.Value { + if len(call.Arguments) < 2 { + panic(rt.ToValue(errors.New("invalid number of argument"))) + } + + var ctx context.Context + + firstArg := call.Argument(0) + + sessionID, ok := firstArg.Export().(string) + if ok { + ctx = module.WithContext(context.Background(), map[module.ContextKey]any{ + edgeHTTP.ContextKeySessionID: sessionID, + }) + } else { + ctx = util.AssertContext(firstArg, rt) + } + + data := call.Argument(1).Export() + + msg := module.NewServerMessage(ctx, data) + if err := m.bus.Publish(ctx, msg); err != nil { + panic(rt.ToValue(errors.WithStack(err))) + } + + return nil +} + +func (m *Module) handleClientMessages() { + ctx := context.Background() + + logger.Debug( + ctx, + "subscribing to bus messages", + ) + + clientMessages, err := m.bus.Subscribe(ctx, module.MessageNamespaceClient) + if err != nil { + panic(errors.WithStack(err)) + } + + defer func() { + logger.Debug( + ctx, + "unsubscribing from bus messages", + ) + + m.bus.Unsubscribe(ctx, module.MessageNamespaceClient, clientMessages) + }() + + for { + logger.Debug( + ctx, + "waiting for next message", + ) + select { + case <-ctx.Done(): + logger.Debug( + ctx, + "context done", + ) + + return + + case msg := <-clientMessages: + clientMessage, ok := msg.(*module.ClientMessage) + if !ok { + logger.Error( + ctx, + "unexpected message type", + logger.F("message", msg), + ) + + continue + } + + logger.Debug( + ctx, + "received client message", + logger.F("message", clientMessage), + ) + + if _, err := m.server.ExecFuncByName("onClientMessage", clientMessage.Context, clientMessage.Data); err != nil { + if errors.Is(err, app.ErrFuncDoesNotExist) { + continue + } + + logger.Error( + ctx, + "on client message error", + logger.E(err), + ) + } + } + } +} + +func ModuleFactory(bus bus.Bus) app.ServerModuleFactory { + return func(server *app.Server) app.ServerModule { + module := &Module{ + server: server, + bus: bus, + } + + go module.handleClientMessages() + + return module + } +} diff --git a/pkg/module/rpc.go b/pkg/module/rpc.go index d309581..dd3108e 100644 --- a/pkg/module/rpc.go +++ b/pkg/module/rpc.go @@ -161,8 +161,14 @@ func (m *RPCModule) handleMessages() { continue } - result, err := m.server.Exec(callable, ctx, req.Params) + result, err := m.server.Exec(callable, clientMessage.Context, req.Params) if err != nil { + logger.Error( + ctx, "rpc call error", + logger.E(errors.WithStack(err)), + logger.F("request", req), + ) + if err := m.sendErrorResponse(clientMessage.Context, req, err); err != nil { logger.Error( ctx, "could not send error response", diff --git a/pkg/module/store.go b/pkg/module/store.go index 0a6df63..1fa5c51 100644 --- a/pkg/module/store.go +++ b/pkg/module/store.go @@ -54,7 +54,7 @@ func (m *StoreModule) upsert(call goja.FunctionCall, rt *goja.Runtime) goja.Valu document, err := m.store.Upsert(ctx, collection, document) if err != nil { - panic(errors.Wrapf(err, "error while upserting document in collection '%s'", collection)) + panic(rt.ToValue(errors.Wrapf(err, "error while upserting document in collection '%s'", collection))) } return rt.ToValue(map[string]interface{}(document)) @@ -71,7 +71,7 @@ func (m *StoreModule) get(call goja.FunctionCall, rt *goja.Runtime) goja.Value { return nil } - panic(errors.Wrapf(err, "error while getting document '%s' in collection '%s'", documentID, collection)) + panic(rt.ToValue(errors.Wrapf(err, "error while getting document '%s' in collection '%s'", documentID, collection))) } return rt.ToValue(map[string]interface{}(document)) @@ -114,7 +114,7 @@ func (m *StoreModule) query(call goja.FunctionCall, rt *goja.Runtime) goja.Value documents, err := m.store.Query(ctx, collection, filter, queryOptionsFuncs...) if err != nil { - panic(errors.Wrapf(err, "error while querying documents in collection '%s'", collection)) + panic(rt.ToValue(errors.Wrapf(err, "error while querying documents in collection '%s'", collection))) } rawDocuments := make([]map[string]interface{}, len(documents)) @@ -131,7 +131,7 @@ func (m *StoreModule) delete(call goja.FunctionCall, rt *goja.Runtime) goja.Valu documentID := m.assertDocumentID(call.Argument(2), rt) if err := m.store.Delete(ctx, collection, documentID); err != nil { - panic(errors.Wrapf(err, "error while deleting document '%s' in collection '%s'", documentID, collection)) + panic(rt.ToValue(errors.Wrapf(err, "error while deleting document '%s' in collection '%s'", documentID, collection))) } return nil @@ -158,7 +158,7 @@ func (m *StoreModule) assertFilter(value goja.Value, rt *goja.Runtime) *filter.F filter, err := filter.NewFrom(rawFilter) if err != nil { - panic(errors.Wrap(err, "could not convert object to filter")) + panic(rt.ToValue(errors.Wrap(err, "could not convert object to filter"))) } return filter @@ -191,7 +191,7 @@ func (m *StoreModule) assertQueryOptions(value goja.Value, rt *goja.Runtime) *qu queryOptions := &queryOptions{} if err := mapstructure.Decode(rawQueryOptions, queryOptions); err != nil { - panic(errors.Wrap(err, "could not convert object to query options")) + panic(rt.ToValue(errors.Wrap(err, "could not convert object to query options"))) } return queryOptions diff --git a/pkg/module/util/assert.go b/pkg/module/util/assert.go index ffe7e69..a7f6fd9 100644 --- a/pkg/module/util/assert.go +++ b/pkg/module/util/assert.go @@ -2,9 +2,9 @@ package util import ( "context" - "fmt" "github.com/dop251/goja" + "github.com/pkg/errors" ) func AssertType[T any](v goja.Value, rt *goja.Runtime) T { @@ -12,7 +12,7 @@ func AssertType[T any](v goja.Value, rt *goja.Runtime) T { return c } - panic(rt.NewTypeError(fmt.Sprintf("expected value to be a '%T', got '%T'", new(T), v.Export()))) + panic(rt.ToValue(errors.Errorf("expected value to be a '%T', got '%T'", *new(T), v.Export()))) } func AssertContext(v goja.Value, r *goja.Runtime) context.Context { diff --git a/pkg/sdk/client/src/client.ts b/pkg/sdk/client/src/client.ts index 01bd7a3..0be1b58 100644 --- a/pkg/sdk/client/src/client.ts +++ b/pkg/sdk/client/src/client.ts @@ -90,8 +90,6 @@ export class Client extends EventTarget { } _handleRPCResponse(evt) { - console.log(evt); - const { jsonrpc, id, error, result } = evt.detail; if (jsonrpc !== '2.0' || id === undefined) return;