edge/pkg/module/rpc.go

264 lines
5.6 KiB
Go
Raw Permalink Normal View History

2023-02-09 12:16:36 +01:00
package module
import (
"context"
"fmt"
"sync"
"forge.cadoles.com/arcad/edge/pkg/app"
"forge.cadoles.com/arcad/edge/pkg/bus"
"github.com/dop251/goja"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
)
// RPCRequest is a JSON-RPC 2.0 request extracted from a client message
// (see isRPCRequest).
type RPCRequest struct {
// Method is the name used to look up the registered callback.
Method string
// Params is the raw "params" value of the message; it is handed as-is to
// the callback when the method is executed.
Params interface{}
// ID is the raw "id" value of the message; it is copied into the response
// so the client can correlate it with the request.
ID interface{}
}
// RPCError is the error object attached to a failed RPCResponse.
type RPCError struct {
// Code is the numeric error code; this file emits -32603 for execution
// failures and -32601 for unknown methods.
Code int `json:"code"`
// Message is a human readable description of the failure.
Message string `json:"message"`
// Data is an optional extra payload; nothing in this file sets it.
Data interface{} `json:"data"`
}
// RPCResponse is the outcome of an RPC call. sendResponse publishes its
// fields under the "result", "error" and "id" keys of a server message.
type RPCResponse struct {
// Result is the exported value returned by the callback, nil on error.
Result interface{}
// Error describes the failure, nil on success.
Error *RPCError
// ID is the identifier copied from the originating RPCRequest.
ID interface{}
}
// RPCModule lets app scripts register functions that clients can invoke
// through JSON-RPC 2.0 messages exchanged over the bus.
type RPCModule struct {
// server executes the registered callbacks and resolves their promises.
server *app.Server
// bus is where client messages are received and responses are published.
bus bus.Bus
// callbacks maps a method name (string) to its goja.Callable.
callbacks sync.Map
}
// Name returns the identifier of this module, "rpc".
func (m *RPCModule) Name() string {
	const moduleName = "rpc"

	return moduleName
}
// Export binds the module's script-facing functions onto the given object.
//
// "register" and "unregister" are set in that order. If an assignment fails,
// the method panics with the wrapped error.
func (m *RPCModule) Export(export *goja.Object) {
	bindings := []struct {
		name    string
		fn      func(goja.FunctionCall, *goja.Runtime) goja.Value
		failure string
	}{
		{"register", m.register, "could not set 'register' function"},
		{"unregister", m.unregister, "could not set 'unregister' function"},
	}

	for _, b := range bindings {
		err := export.Set(b.name, b.fn)
		if err == nil {
			continue
		}

		panic(errors.Wrap(err, b.failure))
	}
}
// register exposes a script function as an RPC method.
//
// The first argument is the method name (validated by assertString). When a
// second argument is given it is used as the handler; otherwise the handler
// is looked up in the runtime under the method name itself. A TypeError is
// raised in the runtime when the resolved value is not a function. Always
// returns nil.
func (m *RPCModule) register(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
	methodName := assertString(call.Argument(0), rt)

	// Resolve the candidate handler first, then validate it in one place.
	var candidate goja.Value
	if len(call.Arguments) > 1 {
		candidate = call.Argument(1)
	} else {
		candidate = rt.Get(methodName)
	}

	handler, isFunc := goja.AssertFunction(candidate)
	if !isFunc {
		panic(rt.NewTypeError("method should be a valid function"))
	}

	logger.Debug(context.Background(), "registering method", logger.F("method", methodName))

	m.callbacks.Store(methodName, handler)

	return nil
}
// unregister removes the RPC method whose name is given as first argument
// (validated by assertString). Removing an unknown name is a no-op. Always
// returns nil.
func (m *RPCModule) unregister(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
	m.callbacks.Delete(assertString(call.Argument(0), rt))

	return nil
}
// handleMessages consumes client messages from the bus and dispatches those
// that are JSON-RPC 2.0 requests to the callbacks stored in m.callbacks.
//
// It subscribes to MessageNamespaceClient with a background context and loops
// until the subscription channel is closed, at which point the deferred
// Unsubscribe runs. A failed subscription panics. For every request:
//   - an unknown (or non-callable) method gets a "method not found" response;
//   - an execution error gets an error response;
//   - a promise result is awaited in its own goroutine before responding;
//   - any other result is sent back immediately.
//
// Failures to publish a response are logged, never returned.
func (m *RPCModule) handleMessages() {
	// Long-lived context used for the subscription and for logging that is
	// not tied to a specific client request.
	ctx := context.Background()

	clientMessages, err := m.bus.Subscribe(ctx, MessageNamespaceClient)
	if err != nil {
		panic(errors.WithStack(err))
	}

	defer func() {
		m.bus.Unsubscribe(ctx, MessageNamespaceClient, clientMessages)
	}()

	// sendRes publishes a successful response carrying the exported result.
	sendRes := func(ctx context.Context, req *RPCRequest, result goja.Value) {
		res := &RPCResponse{
			ID:     req.ID,
			Result: result.Export(),
		}

		logger.Debug(ctx, "sending rpc response", logger.F("response", res))

		if err := m.sendResponse(ctx, res); err != nil {
			logger.Error(
				ctx, "could not send response",
				logger.E(errors.WithStack(err)),
				logger.F("response", res),
				logger.F("request", req),
			)
		}
	}

	// sendNotFound logs why the method could not be resolved and answers the
	// client with a "method not found" error.
	sendNotFound := func(clientCtx context.Context, req *RPCRequest, reason string) {
		logger.Debug(ctx, reason, logger.F("req", req))

		if err := m.sendMethodNotFoundResponse(clientCtx, req); err != nil {
			logger.Error(
				ctx, "could not send method not found response",
				logger.E(errors.WithStack(err)),
				logger.F("request", req),
			)
		}
	}

	for msg := range clientMessages {
		clientMessage, ok := msg.(*ClientMessage)
		if !ok {
			logger.Warn(ctx, "unexpected bus message", logger.F("message", msg))

			continue
		}

		ok, req := m.isRPCRequest(clientMessage)
		if !ok {
			// Not a JSON-RPC request: some other consumer's business.
			continue
		}

		logger.Debug(ctx, "received rpc request", logger.F("request", req))

		rawCallable, exists := m.callbacks.Load(req.Method)
		if !exists {
			sendNotFound(clientMessage.Context, req, "method not found")

			continue
		}

		callable, ok := rawCallable.(goja.Callable)
		if !ok {
			sendNotFound(clientMessage.Context, req, "invalid method")

			continue
		}

		// The callback receives the context of the message being handled,
		// like every other per-request call below, rather than the
		// background context of this loop.
		// NOTE(review): confirm registered handlers expect the client context.
		result, err := m.server.Exec(callable, clientMessage.Context, req.Params)
		if err != nil {
			// Use a distinct name for the send failure so "originalError"
			// really reports the execution error and not the send error.
			if sendErr := m.sendErrorResponse(clientMessage.Context, req, err); sendErr != nil {
				logger.Error(
					ctx, "could not send error response",
					logger.E(errors.WithStack(sendErr)),
					logger.F("originalError", err),
					logger.F("request", req),
				)
			}

			continue
		}

		if promise, isPromise := m.server.IsPromise(result); isPromise {
			// Wait off-loop so a slow promise does not block other requests.
			go func(ctx context.Context, req *RPCRequest, promise *goja.Promise) {
				result := m.server.WaitForPromise(promise)
				sendRes(ctx, req, result)
			}(clientMessage.Context, req, promise)

			continue
		}

		sendRes(clientMessage.Context, req, result)
	}
}
// sendErrorResponse answers req with an error response built from err.
//
// The response carries the request ID, a nil result and an RPCError whose
// code is -32603 (JSON-RPC 2.0 "Internal error") and whose message is
// err.Error(). It returns the error reported by sendResponse, if any.
func (m *RPCModule) sendErrorResponse(ctx context.Context, req *RPCRequest, err error) error {
	const internalErrorCode = -32603

	res := &RPCResponse{ID: req.ID}
	res.Error = &RPCError{
		Code:    internalErrorCode,
		Message: err.Error(),
	}

	return m.sendResponse(ctx, res)
}
// sendMethodNotFoundResponse answers req with a "method not found" error.
//
// The response carries the request ID, a nil result and an RPCError whose
// code is -32601 (JSON-RPC 2.0 "Method not found"). It returns the error
// reported by sendResponse, if any.
func (m *RPCModule) sendMethodNotFoundResponse(ctx context.Context, req *RPCRequest) error {
	const methodNotFoundCode = -32601

	res := &RPCResponse{ID: req.ID}
	res.Error = &RPCError{
		Code: methodNotFoundCode,
		// The fmt call is kept as in the original: it is the only visible
		// use of the fmt import in this file.
		Message: fmt.Sprintf("method not found"),
	}

	return m.sendResponse(ctx, res)
}
// sendResponse wraps res into a JSON-RPC 2.0 envelope ("jsonrpc", "id",
// "error", "result") and publishes it on the bus as a server message.
//
// It returns the publish error with a stack trace attached, or nil on
// success.
func (m *RPCModule) sendResponse(ctx context.Context, res *RPCResponse) error {
	payload := map[string]interface{}{
		"jsonrpc": "2.0",
		"id":      res.ID,
		"error":   res.Error,
		"result":  res.Result,
	}

	err := m.bus.Publish(ctx, NewServerMessage(ctx, payload))
	if err == nil {
		return nil
	}

	return errors.WithStack(err)
}
// isRPCRequest reports whether msg carries a JSON-RPC 2.0 request and, if so,
// returns the decoded request.
//
// A message qualifies when its data has "jsonrpc" equal to "2.0" and a
// "method" entry holding a string. "id" and "params" are optional and copied
// as-is (nil when absent). For any other message it returns (false, nil).
func (m *RPCModule) isRPCRequest(msg *ClientMessage) (bool, *RPCRequest) {
	if version, found := msg.Data["jsonrpc"]; !found || version != "2.0" {
		return false, nil
	}

	// A missing "method" yields a nil value, which fails the assertion just
	// like a non-string one.
	method, isString := msg.Data["method"].(string)
	if !isString {
		return false, nil
	}

	req := &RPCRequest{
		Method: method,
		Params: msg.Data["params"],
		ID:     msg.Data["id"],
	}

	return true, req
}
// RPCModuleFactory returns a factory that builds an RPCModule bound to the
// given bus.
//
// Each module created by the factory immediately starts a goroutine running
// handleMessages, which processes client messages for the lifetime of the
// bus subscription.
func RPCModuleFactory(bus bus.Bus) app.ServerModuleFactory {
	factory := func(server *app.Server) app.ServerModule {
		rpc := new(RPCModule)
		rpc.server = server
		rpc.bus = bus

		go rpc.handleMessages()

		return rpc
	}

	return factory
}