From 17e2418af6752b25369e6730d8c54be74f5a5b59 Mon Sep 17 00:00:00 2001 From: William Petit Date: Tue, 28 Nov 2023 16:35:49 +0100 Subject: [PATCH] feat: rewrite bus to prevent deadlocks --- pkg/app/promise_proxy.go | 4 + pkg/app/server.go | 47 +++++++----- pkg/bus/memory/bus.go | 4 +- pkg/bus/memory/event_dispatcher.go | 45 +++++++----- pkg/bus/memory/request_reply.go | 22 ++++-- pkg/bus/message.go | 23 +++++- pkg/bus/testing/publish_subscribe.go | 10 ++- pkg/bus/testing/request_reply.go | 11 ++- pkg/module/lifecycle.go | 21 ++++-- pkg/module/rpc.go | 44 ++++++----- pkg/module/rpc_test.go | 106 +++++++++++++++++++++++++++ pkg/module/testdata/deadlock.js | 14 ++++ 12 files changed, 276 insertions(+), 75 deletions(-) create mode 100644 pkg/module/rpc_test.go create mode 100644 pkg/module/testdata/deadlock.js diff --git a/pkg/app/promise_proxy.go b/pkg/app/promise_proxy.go index 2031be1..dfdb4b4 100644 --- a/pkg/app/promise_proxy.go +++ b/pkg/app/promise_proxy.go @@ -47,6 +47,10 @@ func NewPromiseProxyFrom(rt *goja.Runtime) *PromiseProxy { } func IsPromise(v goja.Value) (*goja.Promise, bool) { + if v == nil { + return nil, false + } + promise, ok := v.Export().(*goja.Promise) return promise, ok } diff --git a/pkg/app/server.go b/pkg/app/server.go index 7f2c08e..9c96791 100644 --- a/pkg/app/server.go +++ b/pkg/app/server.go @@ -4,6 +4,7 @@ import ( "context" "math/rand" "sync" + "time" "github.com/dop251/goja" "github.com/dop251/goja_nodejs/eventloop" @@ -51,14 +52,15 @@ func (s *Server) ExecFuncByName(ctx context.Context, funcName string, args ...in func (s *Server) Exec(ctx context.Context, callableOrFuncname any, args ...interface{}) (goja.Value, error) { var ( - wg sync.WaitGroup value goja.Value err error ) - wg.Add(1) + done := make(chan struct{}) s.loop.RunOnLoop(func(rt *goja.Runtime) { + defer close(done) + var callable goja.Callable switch typ := callableOrFuncname.(type) { case goja.Callable: @@ -80,23 +82,18 @@ func (s *Server) Exec(ctx context.Context, callableOrFuncname any, args ...inter return } - logger.Debug(ctx, "executing callable") - - defer wg.Done() - defer func() { - if recovered := recover(); recovered != nil { - revoveredErr, ok := recovered.(error) - if ok { - logger.Error(ctx, "recovered runtime error", logger.CapturedE(errors.WithStack(revoveredErr))) - - err = errors.WithStack(ErrUnknownError) - - return - } + recovered := recover() + if recovered == nil { + return + } + recoveredErr, ok := recovered.(error) + if !ok { panic(recovered) } + + err = recoveredErr }() jsArgs := make([]goja.Value, 0, len(args)) @@ -104,16 +101,30 @@ func (s *Server) Exec(ctx context.Context, callableOrFuncname any, args ...inter jsArgs = append(jsArgs, rt.ToValue(a)) } + logger.Debug(ctx, "executing callable", logger.F("callable", callableOrFuncname)) + + start := time.Now() value, err = callable(nil, jsArgs...) if err != nil { err = errors.WithStack(err) } + + logger.Debug(ctx, "executed callable", logger.F("callable", callableOrFuncname), logger.F("duration", time.Since(start).String())) }) - wg.Wait() + select { + case <-ctx.Done(): + if err := ctx.Err(); err != nil { + err = errors.WithStack(err) + return nil, err + } - if err != nil { - return nil, errors.WithStack(err) + return nil, nil + + case <-done: + if err != nil { + return nil, errors.WithStack(err) + } } return value, nil diff --git a/pkg/bus/memory/bus.go b/pkg/bus/memory/bus.go index d6b3088..052a51c 100644 --- a/pkg/bus/memory/bus.go +++ b/pkg/bus/memory/bus.go @@ -41,12 +41,12 @@ func (b *Bus) Unsubscribe(ctx context.Context, ns bus.MessageNamespace, ch <-cha dispatchers.RemoveByOutChannel(ch) } -func (b *Bus) Publish(ctx context.Context, msg bus.Message) error { +func (b *Bus) Publish(msg bus.Message) error { dispatchers := b.getDispatchers(msg.MessageNamespace()) dispatchersList := dispatchers.List() logger.Debug( - ctx, "publishing message", + msg.Context(), "publishing message", logger.F("dispatchers", len(dispatchersList)), logger.F("messageNamespace", msg.MessageNamespace()), ) diff --git a/pkg/bus/memory/event_dispatcher.go b/pkg/bus/memory/event_dispatcher.go index a078939..7710211 100644 --- a/pkg/bus/memory/event_dispatcher.go +++ b/pkg/bus/memory/event_dispatcher.go @@ -3,7 +3,6 @@ package memory import ( "context" "sync" - "time" "forge.cadoles.com/arcad/edge/pkg/bus" "github.com/pkg/errors" @@ -99,7 +98,17 @@ func (d *eventDispatcher) In(msg bus.Message) (err error) { return } - d.in <- msg + ctx := msg.Context() + + select { + case <-ctx.Done(): + if err := ctx.Err(); err != nil { + return errors.WithStack(err) + } + + return nil + case d.in <- msg: + } return nil } @@ -130,29 +139,29 @@ func (d *eventDispatcher) Run(ctx context.Context) { }() for { - msg, ok := <-d.in - if !ok { - return - } - - timeout := time.After(time.Second) - select { - case d.out <- msg: - case <-timeout: - logger.Error( - ctx, - "out message channel timeout", - logger.F("message", msg), - ) + case msg, ok := <-d.in: + if !ok { + return + } - return + select { + case d.out <- msg: + case <-ctx.Done(): + logger.Error( + ctx, + "message subscription context canceled", + logger.F("message", msg), + logger.CapturedE(errors.WithStack(ctx.Err())), + ) + + return + } case <-ctx.Done(): logger.Error( ctx, "message subscription context canceled", - logger.F("message", msg), logger.CapturedE(errors.WithStack(ctx.Err())), ) diff --git a/pkg/bus/memory/request_reply.go b/pkg/bus/memory/request_reply.go index aaa9390..3a4cc40 100644 --- a/pkg/bus/memory/request_reply.go +++ b/pkg/bus/memory/request_reply.go @@ -20,25 +20,35 @@ type RequestMessage struct { Message bus.Message - ns bus.MessageNamespace + ns bus.MessageNamespace + ctx context.Context } func (m *RequestMessage) MessageNamespace() bus.MessageNamespace { return m.ns } +func (m *RequestMessage) Context() context.Context { + return m.ctx +} + type ReplyMessage struct { RequestID uint64 Message bus.Message Error error - ns bus.MessageNamespace + ns bus.MessageNamespace + ctx context.Context } func (m *ReplyMessage) MessageNamespace() bus.MessageNamespace { return m.ns } +func (m *ReplyMessage) Context() context.Context { + return m.ctx +} + func (b *Bus) Request(ctx context.Context, msg bus.Message) (bus.Message, error) { requestID := atomic.AddUint64(&b.nextRequestID, 1) @@ -46,6 +56,7 @@ func (b *Bus) Request(ctx context.Context, msg bus.Message) (bus.Message, error) RequestID: requestID, Message: msg, ns: msg.MessageNamespace(), + ctx: ctx, } replyNamespace := createReplyNamespace(requestID) @@ -61,7 +72,7 @@ func (b *Bus) Request(ctx context.Context, msg bus.Message) (bus.Message, error) logger.Debug(ctx, "publishing request", logger.F("request", req)) - if err := b.Publish(ctx, req); err != nil { + if err := b.Publish(req); err != nil { return nil, errors.WithStack(err) } @@ -125,7 +136,8 @@ func (b *Bus) Reply(ctx context.Context, msgNamespace bus.MessageNamespace, h bu Message: nil, Error: nil, - ns: createReplyNamespace(request.RequestID), + ns: createReplyNamespace(request.RequestID), + ctx: ctx, } if err != nil { @@ -136,7 +148,7 @@ func (b *Bus) Reply(ctx context.Context, msgNamespace bus.MessageNamespace, h bu logger.Debug(ctx, "publishing reply", logger.F("reply", reply)) - if err := b.Publish(ctx, reply); err != nil { + if err := b.Publish(reply); err != nil { return errors.WithStack(err) } } diff --git a/pkg/bus/message.go b/pkg/bus/message.go index 3a470d1..43034c2 100644 --- a/pkg/bus/message.go +++ b/pkg/bus/message.go @@ -1,6 +1,7 @@ package bus import ( + "context" "strings" "github.com/pkg/errors" @@ -8,10 +9,28 @@ import ( type ( MessageNamespace string + Address string + Message any ) -type Message interface { - MessageNamespace() MessageNamespace +type Envelope interface { + Message() any + Context() context.Context + Address() Address +} + +type BaseEnvelope struct { + msg Message + ctx context.Context + addr Address +} + +func NewEnvelope(ctx context.Context, addr Address, msg Message) *BaseEnvelope { + return &BaseEnvelope{ + ctx: ctx, + addr: addr, + msg: msg, + } } func NewMessageNamespace(namespaces ...MessageNamespace) MessageNamespace { diff --git a/pkg/bus/testing/publish_subscribe.go b/pkg/bus/testing/publish_subscribe.go index 6db69e3..12ae7f3 100644 --- a/pkg/bus/testing/publish_subscribe.go +++ b/pkg/bus/testing/publish_subscribe.go @@ -15,12 +15,18 @@ const ( testNamespace bus.MessageNamespace = "testNamespace" ) -type testMessage struct{} +type testMessage struct { + ctx context.Context +} -func (e *testMessage) MessageNamespace() bus.MessageNamespace { +func (m *testMessage) MessageNamespace() bus.MessageNamespace { return testNamespace } +func (m *testMessage) Context() context.Context { + return m.ctx +} + func TestPublishSubscribe(t *testing.T, b bus.Bus) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() diff --git a/pkg/bus/testing/request_reply.go b/pkg/bus/testing/request_reply.go index 22ceddd..39ce5c3 100644 --- a/pkg/bus/testing/request_reply.go +++ b/pkg/bus/testing/request_reply.go @@ -15,13 +15,18 @@ const ( ) type testReqResMessage struct { - i int + i int + ctx context.Context } func (m *testReqResMessage) MessageNamespace() bus.MessageNamespace { return testNamespace } +func (m *testReqResMessage) Context() context.Context { + return m.ctx +} + func TestRequestReply(t *testing.T, b bus.Bus) { expectedRoundTrips := 256 timeout := time.Now().Add(time.Duration(expectedRoundTrips) * time.Second) @@ -47,7 +52,7 @@ func TestRequestReply(t *testing.T, b bus.Bus) { return nil, errors.WithStack(bus.ErrUnexpectedMessage) } - result := &testReqResMessage{req.i} + result := &testReqResMessage{req.i, context.Background()} // Simulate random work time.Sleep(time.Millisecond * 100) @@ -75,7 +80,7 @@ func TestRequestReply(t *testing.T, b bus.Bus) { requestCtx, cancelRequest := context.WithDeadline(context.Background(), timeout) defer cancelRequest() - req := &testReqResMessage{i} + req := &testReqResMessage{i, context.Background()} t.Logf("[REQ] sending req #%d", i) diff --git a/pkg/module/lifecycle.go b/pkg/module/lifecycle.go index 983617f..af98219 100644 --- a/pkg/module/lifecycle.go +++ b/pkg/module/lifecycle.go @@ -5,7 +5,6 @@ import ( "forge.cadoles.com/arcad/edge/pkg/app" "github.com/dop251/goja" - "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" ) @@ -19,16 +18,28 @@ func (m *LifecycleModule) Export(export *goja.Object) { } func (m *LifecycleModule) OnInit(ctx context.Context, rt *goja.Runtime) (err error) { - _, ok := goja.AssertFunction(rt.Get("onInit")) + call, ok := goja.AssertFunction(rt.Get("onInit")) if !ok { logger.Warn(ctx, "could not find onInit() function") return nil } - if _, err := rt.RunString("setTimeout(onInit, 0)"); err != nil { - return errors.WithStack(err) - } + defer func() { + recovered := recover() + if recovered == nil { + return + } + + recoveredErr, ok := recovered.(error) + if !ok { + panic(recovered) + } + + err = recoveredErr + }() + + call(nil, rt.ToValue(ctx)) return nil } diff --git a/pkg/module/rpc.go b/pkg/module/rpc.go index 5877ffb..ab09b87 100644 --- a/pkg/module/rpc.go +++ b/pkg/module/rpc.go @@ -2,7 +2,6 @@ package module import ( "context" - "fmt" "sync" "forge.cadoles.com/arcad/edge/pkg/app" @@ -52,7 +51,12 @@ func (m *RPCModule) Export(export *goja.Object) { } func (m *RPCModule) OnInit(ctx context.Context, rt *goja.Runtime) error { - go m.handleMessages(ctx) + clientMessages, err := m.bus.Subscribe(ctx, MessageNamespaceClient) + if err != nil { + return errors.WithStack(err) + } + + go m.handleMessages(ctx, clientMessages) return nil } @@ -92,25 +96,25 @@ func (m *RPCModule) unregister(call goja.FunctionCall, rt *goja.Runtime) goja.Va return nil } -func (m *RPCModule) handleMessages(ctx context.Context) { - clientMessages, err := m.bus.Subscribe(ctx, MessageNamespaceClient) - if err != nil { - panic(errors.WithStack(err)) - } - +func (m *RPCModule) handleMessages(ctx context.Context, clientMessages <-chan bus.Message) { defer func() { m.bus.Unsubscribe(ctx, MessageNamespaceClient, clientMessages) }() sendRes := func(ctx context.Context, req *RPCRequest, result goja.Value) { + var rawResult any + if result != nil { + rawResult = result.Export() + } + res := &RPCResponse{ ID: req.ID, - Result: result.Export(), + Result: rawResult, } logger.Debug(ctx, "sending rpc response", logger.F("response", res)) - if err := m.sendResponse(ctx, res); err != nil { + if err := m.sendResponse(res); err != nil { logger.Error( ctx, "could not send response", logger.CapturedE(errors.WithStack(err)), @@ -144,7 +148,7 @@ func (m *RPCModule) handleMessage(ctx context.Context, msg bus.Message, sendRes if !exists { logger.Debug(ctx, "method not found", logger.F("req", req)) - if err := m.sendMethodNotFoundResponse(clientMessage.Context, req); err != nil { + if err := m.sendMethodNotFoundResponse(req); err != nil { logger.Error( ctx, "could not send method not found response", logger.CapturedE(errors.WithStack(err)), @@ -159,7 +163,7 @@ func (m *RPCModule) handleMessage(ctx context.Context, msg bus.Message, sendRes if !ok { logger.Debug(ctx, "invalid method", logger.F("req", req)) - if err := m.sendMethodNotFoundResponse(clientMessage.Context, req); err != nil { + if err := m.sendMethodNotFoundResponse(req); err != nil { logger.Error( ctx, "could not send method not found response", logger.CapturedE(errors.WithStack(err)), @@ -178,7 +182,7 @@ func (m *RPCModule) handleMessage(ctx context.Context, msg bus.Message, sendRes logger.F("request", req), ) - if err := m.sendErrorResponse(clientMessage.Context, req, err); err != nil { + if err := m.sendErrorResponse(context.Background(), req, err); err != nil { logger.Error( ctx, "could not send error response", logger.CapturedE(errors.WithStack(err)), @@ -202,7 +206,7 @@ func (m *RPCModule) handleMessage(ctx context.Context, msg bus.Message, sendRes } func (m *RPCModule) sendErrorResponse(ctx context.Context, req *RPCRequest, err error) error { - return m.sendResponse(ctx, &RPCResponse{ + return m.sendResponse(&RPCResponse{ ID: req.ID, Result: nil, Error: &RPCError{ @@ -212,26 +216,26 @@ func (m *RPCModule) sendErrorResponse(ctx context.Context, req *RPCRequest, err }) } -func (m *RPCModule) sendMethodNotFoundResponse(ctx context.Context, req *RPCRequest) error { - return m.sendResponse(ctx, &RPCResponse{ +func (m *RPCModule) sendMethodNotFoundResponse(req *RPCRequest) error { + return m.sendResponse(&RPCResponse{ ID: req.ID, Result: nil, Error: &RPCError{ Code: -32601, - Message: fmt.Sprintf("method not found"), + Message: "method not found", }, }) } -func (m *RPCModule) sendResponse(ctx context.Context, res *RPCResponse) error { - msg := NewServerMessage(ctx, map[string]interface{}{ +func (m *RPCModule) sendResponse(res *RPCResponse) error { + msg := NewServerMessage(context.Background(), map[string]interface{}{ "jsonrpc": "2.0", "id": res.ID, "error": res.Error, "result": res.Result, }) - if err := m.bus.Publish(ctx, msg); err != nil { + if err := m.bus.Publish(context.Background(), msg); err != nil { return errors.WithStack(err) } diff --git a/pkg/module/rpc_test.go b/pkg/module/rpc_test.go new file mode 100644 index 0000000..8e8c3fb --- /dev/null +++ b/pkg/module/rpc_test.go @@ -0,0 +1,106 @@ +package module + +import ( + "context" + "fmt" + "os" + "sync" + "testing" + "time" + + "forge.cadoles.com/arcad/edge/pkg/app" + "forge.cadoles.com/arcad/edge/pkg/bus/memory" + "github.com/pkg/errors" + "gitlab.com/wpetit/goweb/logger" +) + +func TestServerExecDeadlock(t *testing.T) { + if testing.Verbose() { + logger.SetLevel(logger.LevelDebug) + } + + bus := memory.NewBus(memory.WithBufferSize(1)) + + server := app.NewServer( + ConsoleModuleFactory(), + RPCModuleFactory(bus), + LifecycleModuleFactory(), + ) + + data, err := os.ReadFile("testdata/deadlock.js") + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + if err := server.Load("deadlock.js", string(data)); err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + ctx := context.Background() + + if err := server.Start(ctx); err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + defer server.Stop() + + messages, err := bus.Subscribe(ctx, MessageNamespaceServer) + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + count := 100 + delay := 500 + + var wg sync.WaitGroup + + wg.Add(1) + + go func() { + defer wg.Done() + + received := 0 + for range messages { + received++ + t.Logf("received %d", received) + if received == count { + bus.Unsubscribe(ctx, MessageNamespaceServer, messages) + + return + } + } + }() + + wg.Add(count) + + for i := 0; i < count; i++ { + go func(i int) { + defer wg.Done() + + timeout := time.Duration(delay*5) * time.Millisecond + + t.Logf("sending message %d with timeout %s", i, timeout.String()) + + // ctx, cancel := context.WithTimeout(context.Background(), timeout) + // defer cancel() + + msg := NewClientMessage(context.Background(), map[string]any{ + "jsonrpc": "2.0", + "method": "doSomethingLong", + "id": fmt.Sprintf("msg-%d", i), + "params": map[string]any{ + "i": i, + "delay": delay, + }, + }) + + t.Logf("publishing message #%d", i) + + if err := bus.Publish(ctx, msg); err != nil { + t.Errorf("%+v", errors.WithStack(err)) + } + }(i) + } + + wg.Wait() +} diff --git a/pkg/module/testdata/deadlock.js b/pkg/module/testdata/deadlock.js new file mode 100644 index 0000000..8a667a6 --- /dev/null +++ b/pkg/module/testdata/deadlock.js @@ -0,0 +1,14 @@ +function onInit() { + rpc.register("doSomethingLong", doSomethingLong) +} + +function doSomethingLong(ctx, params) { + var start = Date.now() + + while (true) { + var now = Date.now() + if (now - start >= params.delay) break + } + + return params.i; +}