feat: rewrite bus to prevent deadlocks
All checks were successful
arcad/edge/pipeline/head This commit looks good
arcad/edge/pipeline/pr-master This commit looks good

This commit is contained in:
2023-11-28 16:35:49 +01:00
parent f4a7366aad
commit ad49c1718c
50 changed files with 1621 additions and 1336 deletions

View File

@ -0,0 +1,21 @@
package rpc
import (
"context"
"forge.cadoles.com/arcad/edge/pkg/bus"
)
const (
Address bus.Address = "module/rpc"
)
type Request struct {
Context context.Context
Method string
Params any
}
func NewRequestEnvelope(ctx context.Context, method string, params any) bus.Envelope {
return bus.NewEnvelope(Address, &Request{ctx, method, params})
}

7
pkg/module/rpc/error.go Normal file
View File

@ -0,0 +1,7 @@
package rpc
import "errors"
var (
ErrMethodNotFound = errors.New("method not found")
)

19
pkg/module/rpc/jsonrpc.go Normal file
View File

@ -0,0 +1,19 @@
package rpc
import "fmt"
type JSONRPCRequest struct {
ID any
Method string
Params any
}
type JSONRPCError struct {
Code int `json:"code"`
Message string `json:"message"`
Data interface{} `json:"data"`
}
func (e *JSONRPCError) Error() string {
return fmt.Sprintf("json-rpc error: %d - %s", e.Code, e.Message)
}

256
pkg/module/rpc/module.go Normal file
View File

@ -0,0 +1,256 @@
package rpc
import (
"context"
"sync"
"forge.cadoles.com/arcad/edge/pkg/app"
"forge.cadoles.com/arcad/edge/pkg/bus"
edgehttp "forge.cadoles.com/arcad/edge/pkg/http"
"forge.cadoles.com/arcad/edge/pkg/module"
"forge.cadoles.com/arcad/edge/pkg/module/util"
"github.com/dop251/goja"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
)
type Module struct {
server *app.Server
bus bus.Bus
callbacks sync.Map
}
func (m *Module) Name() string {
return "rpc"
}
func (m *Module) Export(export *goja.Object) {
if err := export.Set("register", m.register); err != nil {
panic(errors.Wrap(err, "could not set 'register' function"))
}
if err := export.Set("unregister", m.unregister); err != nil {
panic(errors.Wrap(err, "could not set 'unregister' function"))
}
}
func (m *Module) OnInit(ctx context.Context, rt *goja.Runtime) error {
requestErrs := m.bus.Reply(ctx, Address, m.handleRequest)
go func() {
for err := range requestErrs {
logger.Error(ctx, "error while replying to rpc requests", logger.CapturedE(errors.WithStack(err)))
}
}()
httpIncomingMessages, err := m.bus.Subscribe(ctx, edgehttp.AddressIncomingMessage)
if err != nil {
return errors.WithStack(err)
}
go m.handleIncomingHTTPMessages(ctx, httpIncomingMessages)
return nil
}
func (m *Module) register(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
fnName := util.AssertString(call.Argument(0), rt)
var (
callable goja.Callable
ok bool
)
if len(call.Arguments) > 1 {
callable, ok = goja.AssertFunction(call.Argument(1))
} else {
callable, ok = goja.AssertFunction(rt.Get(fnName))
}
if !ok {
panic(rt.NewTypeError("method should be a valid function"))
}
ctx := context.Background()
logger.Debug(ctx, "registering method", logger.F("method", fnName))
m.callbacks.Store(fnName, callable)
return nil
}
func (m *Module) unregister(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
fnName := util.AssertString(call.Argument(0), rt)
m.callbacks.Delete(fnName)
return nil
}
func (m *Module) handleRequest(env bus.Envelope) (any, error) {
request, ok := env.Message().(*Request)
if !ok {
logger.Warn(context.Background(), "unexpected bus message", logger.F("message", env.Message()))
return nil, errors.WithStack(bus.ErrUnexpectedMessage)
}
ctx := logger.With(request.Context, logger.F("request", request))
logger.Debug(ctx, "received rpc request")
rawCallable, exists := m.callbacks.Load(request.Method)
if !exists {
logger.Debug(ctx, "method not found")
return nil, errors.WithStack(ErrMethodNotFound)
}
callable, ok := rawCallable.(goja.Callable)
if !ok {
logger.Debug(ctx, "invalid method")
return nil, errors.WithStack(ErrMethodNotFound)
}
result, err := m.server.Exec(ctx, callable, request.Context, request.Params)
if err != nil {
logger.Error(
ctx, "rpc call error",
logger.CapturedE(errors.WithStack(err)),
)
return nil, errors.WithStack(err)
}
return result, nil
}
func (m *Module) handleIncomingHTTPMessages(ctx context.Context, incoming <-chan bus.Envelope) {
defer func() {
m.bus.Unsubscribe(edgehttp.AddressIncomingMessage, incoming)
}()
for env := range incoming {
msg, ok := env.Message().(*edgehttp.IncomingMessage)
if !ok {
logger.Error(ctx, "unexpected incoming http message type", logger.F("message", env.Message()))
continue
}
jsonReq, ok := m.isRPCRequest(msg.Payload)
if !ok {
continue
}
requestCtx := logger.With(msg.Context, logger.F("rpcRequestMethod", jsonReq.Method), logger.F("rpcRequestID", jsonReq.ID))
request := NewRequestEnvelope(msg.Context, jsonReq.Method, jsonReq.Params)
sessionID := module.ContextValue[string](msg.Context, edgehttp.ContextKeySessionID)
reply, err := m.bus.Request(requestCtx, request)
if err != nil {
err = errors.WithStack(err)
logger.Error(
ctx, "could not execute rpc request",
logger.CapturedE(err),
)
if errors.Is(err, ErrMethodNotFound) {
if err := m.sendMethodNotFoundResponse(sessionID, jsonReq.ID); err != nil {
logger.Error(
ctx, "could not send json rpc error response",
logger.CapturedE(errors.WithStack(err)),
)
}
continue
}
if err := m.sendErrorResponse(sessionID, jsonReq.ID, err); err != nil {
logger.Error(
ctx, "could not send json rpc error response",
logger.CapturedE(errors.WithStack(err)),
)
}
continue
}
if err := m.sendResponse(sessionID, jsonReq.ID, reply.Message(), nil); err != nil {
logger.Error(
ctx, "could not send json rpc result response",
logger.CapturedE(err),
)
}
}
}
func (m *Module) sendErrorResponse(sessionID string, requestID any, err error) error {
return m.sendResponse(sessionID, requestID, nil, &JSONRPCError{
Code: -32603,
Message: err.Error(),
})
}
func (m *Module) sendMethodNotFoundResponse(sessionID string, requestID any) error {
return m.sendResponse(sessionID, requestID, nil, &JSONRPCError{
Code: -32601,
Message: "method not found",
})
}
func (m *Module) sendResponse(sessionID string, requestID any, result any, err error) error {
env := edgehttp.NewOutgoingMessageEnvelope(sessionID, map[string]interface{}{
"jsonrpc": "2.0",
"id": requestID,
"error": err,
"result": result,
})
if err := m.bus.Publish(env); err != nil {
return errors.WithStack(err)
}
return nil
}
func (m *Module) isRPCRequest(payload map[string]any) (*JSONRPCRequest, bool) {
jsonRPC, exists := payload["jsonrpc"]
if !exists || jsonRPC != "2.0" {
return nil, false
}
rawMethod, exists := payload["method"]
if !exists {
return nil, false
}
method, ok := rawMethod.(string)
if !ok {
return nil, false
}
id := payload["id"]
params := payload["params"]
return &JSONRPCRequest{
ID: id,
Method: method,
Params: params,
}, true
}
func ModuleFactory(bus bus.Bus) app.ServerModuleFactory {
return func(server *app.Server) app.ServerModule {
mod := &Module{
server: server,
bus: bus,
}
return mod
}
}
var _ app.InitializableModule = &Module{}

View File

@ -0,0 +1,109 @@
package rpc
import (
"context"
"os"
"sync"
"testing"
"forge.cadoles.com/arcad/edge/pkg/app"
"forge.cadoles.com/arcad/edge/pkg/bus"
"forge.cadoles.com/arcad/edge/pkg/bus/memory"
"forge.cadoles.com/arcad/edge/pkg/module"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
)
func TestServerExecDeadlock(t *testing.T) {
if testing.Verbose() {
logger.SetLevel(logger.LevelDebug)
}
b := memory.NewBus(memory.WithBufferSize(1))
server := app.NewServer(
module.ConsoleModuleFactory(),
ModuleFactory(b),
module.LifecycleModuleFactory(),
)
data, err := os.ReadFile("testdata/deadlock.js")
if err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
ctx := context.Background()
t.Log("starting server")
if err := server.Start(ctx, "deadlock.js", string(data)); err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
defer server.Stop()
t.Log("server started")
count := 100
delay := 100
var wg sync.WaitGroup
wg.Add(count)
for i := 0; i < count; i++ {
go func(i int) {
defer wg.Done()
t.Logf("calling %d", i)
isCanceled := i%2 == 0
var ctx context.Context
if isCanceled {
canceledCtx, cancel := context.WithCancel(context.Background())
cancel()
ctx = canceledCtx
} else {
ctx = context.Background()
}
env := NewRequestEnvelope(ctx, "doSomethingLong", map[string]any{
"i": i,
"delay": delay,
})
t.Logf("publishing envelope #%d", i)
reply, err := b.Request(ctx, env)
if err != nil {
if errors.Is(err, context.Canceled) && isCanceled {
return
}
if errors.Is(err, bus.ErrNoResponse) && isCanceled {
return
}
t.Errorf("%+v", errors.WithStack(err))
return
}
result, ok := reply.Message().(int64)
if !ok {
t.Errorf("response.Result: expected type '%T', got '%T'", int64(0), reply.Message())
return
}
if e, g := i, int(result); e != g {
t.Errorf("response.Result: expected '%v', got '%v'", e, g)
return
}
}(i)
}
wg.Wait()
}

14
pkg/module/rpc/testdata/deadlock.js vendored Normal file
View File

@ -0,0 +1,14 @@
function onInit() {
rpc.register("doSomethingLong", doSomethingLong)
}
function doSomethingLong(ctx, params) {
var start = Date.now()
while (true) {
var now = Date.now()
if (now - start >= params.delay) break
}
return params.i;
}