package module

import (
	"context"
	"fmt"
	"sync"

	"forge.cadoles.com/arcad/edge/pkg/app"
	"forge.cadoles.com/arcad/edge/pkg/bus"
	"forge.cadoles.com/arcad/edge/pkg/module/util"
	"github.com/dop251/goja"
	"github.com/pkg/errors"
	"gitlab.com/wpetit/goweb/logger"
)

// RPCRequest holds the fields extracted from a client message that was
// recognised as a JSON-RPC 2.0 request (see isRPCRequest).
type RPCRequest struct {
	Method string
	Params interface{}
	ID     interface{}
}

// RPCError describes a failed call. It is published under the "error" key
// of the response payload.
type RPCError struct {
	Code    int         `json:"code"`
	Message string      `json:"message"`
	Data    interface{} `json:"data"`
}

// RPCResponse is the outcome of a request; ID mirrors the ID of the request
// it answers.
type RPCResponse struct {
	Result interface{}
	Error  *RPCError
	ID     interface{}
}

// RPCModule lets scripts register named functions and dispatches incoming
// JSON-RPC 2.0 client messages to them.
type RPCModule struct {
	server *app.Server
	bus    bus.Bus
	// callbacks maps a method name (string) to its goja.Callable.
	callbacks sync.Map
}

// Name returns the module identifier, "rpc".
func (m *RPCModule) Name() string {
	return "rpc"
}

// Export attaches the "register" and "unregister" functions to the given
// object. It panics if either function cannot be set.
func (m *RPCModule) Export(export *goja.Object) {
	if err := export.Set("register", m.register); err != nil {
		panic(errors.Wrap(err, "could not set 'register' function"))
	}

	if err := export.Set("unregister", m.unregister); err != nil {
		panic(errors.Wrap(err, "could not set 'unregister' function"))
	}
}

// register stores a callable under the name given as first argument.
// The callable is the second argument when provided, otherwise the runtime
// global bearing that same name. A TypeError is raised in the runtime when
// the resolved value is not a function.
func (m *RPCModule) register(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
	fnName := util.AssertString(call.Argument(0), rt)

	var (
		callable goja.Callable
		ok       bool
	)

	if len(call.Arguments) > 1 {
		callable, ok = goja.AssertFunction(call.Argument(1))
	} else {
		callable, ok = goja.AssertFunction(rt.Get(fnName))
	}

	if !ok {
		panic(rt.NewTypeError("method should be a valid function"))
	}

	ctx := context.Background()

	logger.Debug(ctx, "registering method", logger.F("method", fnName))

	m.callbacks.Store(fnName, callable)

	return nil
}

// unregister removes the callable stored under the name given as first
// argument, if any.
func (m *RPCModule) unregister(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
	fnName := util.AssertString(call.Argument(0), rt)

	m.callbacks.Delete(fnName)

	return nil
}

// handleMessages subscribes to the client message namespace and serves every
// JSON-RPC request received until the subscription channel is closed.
// It panics if the subscription cannot be established.
func (m *RPCModule) handleMessages() {
	ctx := context.Background()

	clientMessages, err := m.bus.Subscribe(ctx, MessageNamespaceClient)
	if err != nil {
		panic(errors.WithStack(err))
	}

	defer func() {
		m.bus.Unsubscribe(ctx, MessageNamespaceClient, clientMessages)
	}()

	// sendRes publishes a successful response carrying the exported result.
	sendRes := func(msgCtx context.Context, req *RPCRequest, result goja.Value) {
		// A nil goja.Value would panic on Export(); report it as a nil result
		// instead of crashing the handler (or the promise goroutine).
		var exported interface{}
		if result != nil {
			exported = result.Export()
		}

		res := &RPCResponse{
			ID:     req.ID,
			Result: exported,
		}

		logger.Debug(msgCtx, "sending rpc response", logger.F("response", res))

		if err := m.sendResponse(msgCtx, res); err != nil {
			logger.Error(
				msgCtx, "could not send response",
				logger.E(errors.WithStack(err)),
				logger.F("response", res),
				logger.F("request", req),
			)
		}
	}

	// sendMethodNotFound answers with a "method not found" error and logs any
	// publishing failure.
	sendMethodNotFound := func(msgCtx context.Context, req *RPCRequest) {
		if err := m.sendMethodNotFoundResponse(msgCtx, req); err != nil {
			logger.Error(
				ctx, "could not send method not found response",
				logger.E(errors.WithStack(err)),
				logger.F("request", req),
			)
		}
	}

	for msg := range clientMessages {
		clientMessage, ok := msg.(*ClientMessage)
		if !ok {
			logger.Warn(ctx, "unexpected bus message", logger.F("message", msg))

			continue
		}

		ok, req := m.isRPCRequest(clientMessage)
		if !ok {
			continue
		}

		logger.Debug(ctx, "received rpc request", logger.F("request", req))

		rawCallable, exists := m.callbacks.Load(req.Method)
		if !exists {
			logger.Debug(ctx, "method not found", logger.F("req", req))
			sendMethodNotFound(clientMessage.Context, req)

			continue
		}

		callable, ok := rawCallable.(goja.Callable)
		if !ok {
			logger.Debug(ctx, "invalid method", logger.F("req", req))
			sendMethodNotFound(clientMessage.Context, req)

			continue
		}

		// The message context is also handed to the callable as its first
		// argument, followed by the request params.
		result, err := m.server.Exec(clientMessage.Context, callable, clientMessage.Context, req.Params)
		if err != nil {
			logger.Error(
				ctx, "rpc call error",
				logger.E(errors.WithStack(err)),
				logger.F("request", req),
			)

			// Use a distinct name so the execution error is not shadowed and
			// "originalError" really reports it.
			if sendErr := m.sendErrorResponse(clientMessage.Context, req, err); sendErr != nil {
				logger.Error(
					ctx, "could not send error response",
					logger.E(errors.WithStack(sendErr)),
					logger.F("originalError", err),
					logger.F("request", req),
				)
			}

			continue
		}

		promise, ok := m.server.IsPromise(result)
		if !ok {
			sendRes(clientMessage.Context, req, result)

			continue
		}

		// Wait for the promise outside of the loop so that other requests
		// keep being served in the meantime.
		go func(msgCtx context.Context, req *RPCRequest, promise *goja.Promise) {
			sendRes(msgCtx, req, m.server.WaitForPromise(promise))
		}(clientMessage.Context, req, promise)
	}
}

// sendErrorResponse publishes an error response for req whose message is the
// text of err.
func (m *RPCModule) sendErrorResponse(ctx context.Context, req *RPCRequest, err error) error {
	return m.sendResponse(ctx, &RPCResponse{
		ID:     req.ID,
		Result: nil,
		Error: &RPCError{
			Code:    -32603, // JSON-RPC 2.0 "Internal error"
			Message: err.Error(),
		},
	})
}

// sendMethodNotFoundResponse publishes a "method not found" error response
// for req.
func (m *RPCModule) sendMethodNotFoundResponse(ctx context.Context, req *RPCRequest) error {
	return m.sendResponse(ctx, &RPCResponse{
		ID:     req.ID,
		Result: nil,
		Error: &RPCError{
			Code:    -32601, // JSON-RPC 2.0 "Method not found"
			Message: fmt.Sprintf("method not found"),
		},
	})
}

// sendResponse wraps res into a server message and publishes it on the bus.
// The payload always carries the "jsonrpc", "id", "error" and "result" keys.
func (m *RPCModule) sendResponse(ctx context.Context, res *RPCResponse) error {
	msg := NewServerMessage(ctx, map[string]interface{}{
		"jsonrpc": "2.0",
		"id":      res.ID,
		"error":   res.Error,
		"result":  res.Result,
	})

	if err := m.bus.Publish(ctx, msg); err != nil {
		return errors.WithStack(err)
	}

	return nil
}

// isRPCRequest reports whether msg looks like a JSON-RPC 2.0 request, i.e.
// its data has "jsonrpc" equal to "2.0" and a string "method". When it does,
// the decoded request is returned; "id" and "params" are copied as found and
// may be nil.
func (m *RPCModule) isRPCRequest(msg *ClientMessage) (bool, *RPCRequest) {
	if version, exists := msg.Data["jsonrpc"]; !exists || version != "2.0" {
		return false, nil
	}

	// A missing key yields a nil value, which fails the string assertion too.
	method, ok := msg.Data["method"].(string)
	if !ok {
		return false, nil
	}

	return true, &RPCRequest{
		ID:     msg.Data["id"],
		Method: method,
		Params: msg.Data["params"],
	}
}

// RPCModuleFactory returns a factory building an RPCModule bound to the given
// bus. Each created module starts handling client messages in its own
// goroutine.
func RPCModuleFactory(bus bus.Bus) app.ServerModuleFactory {
	return func(server *app.Server) app.ServerModule {
		mod := &RPCModule{
			server: server,
			bus:    bus,
		}

		go mod.handleMessages()

		return mod
	}
}