feat: rewrite bus to prevent deadlocks
Some checks failed
arcad/edge/pipeline/head There was a failure building this commit

This commit is contained in:
wpetit 2023-11-28 16:35:49 +01:00
parent 02c74b6f8d
commit 17e2418af6
12 changed files with 276 additions and 75 deletions

View File

@ -47,6 +47,10 @@ func NewPromiseProxyFrom(rt *goja.Runtime) *PromiseProxy {
} }
func IsPromise(v goja.Value) (*goja.Promise, bool) { func IsPromise(v goja.Value) (*goja.Promise, bool) {
if v == nil {
return nil, false
}
promise, ok := v.Export().(*goja.Promise) promise, ok := v.Export().(*goja.Promise)
return promise, ok return promise, ok
} }

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"math/rand" "math/rand"
"sync" "sync"
"time"
"github.com/dop251/goja" "github.com/dop251/goja"
"github.com/dop251/goja_nodejs/eventloop" "github.com/dop251/goja_nodejs/eventloop"
@ -51,14 +52,15 @@ func (s *Server) ExecFuncByName(ctx context.Context, funcName string, args ...in
func (s *Server) Exec(ctx context.Context, callableOrFuncname any, args ...interface{}) (goja.Value, error) { func (s *Server) Exec(ctx context.Context, callableOrFuncname any, args ...interface{}) (goja.Value, error) {
var ( var (
wg sync.WaitGroup
value goja.Value value goja.Value
err error err error
) )
wg.Add(1) done := make(chan struct{})
s.loop.RunOnLoop(func(rt *goja.Runtime) { s.loop.RunOnLoop(func(rt *goja.Runtime) {
defer close(done)
var callable goja.Callable var callable goja.Callable
switch typ := callableOrFuncname.(type) { switch typ := callableOrFuncname.(type) {
case goja.Callable: case goja.Callable:
@ -80,23 +82,18 @@ func (s *Server) Exec(ctx context.Context, callableOrFuncname any, args ...inter
return return
} }
logger.Debug(ctx, "executing callable")
defer wg.Done()
defer func() { defer func() {
if recovered := recover(); recovered != nil { recovered := recover()
revoveredErr, ok := recovered.(error) if recovered == nil {
if ok { return
logger.Error(ctx, "recovered runtime error", logger.CapturedE(errors.WithStack(revoveredErr))) }
err = errors.WithStack(ErrUnknownError)
return
}
recoveredErr, ok := recovered.(error)
if !ok {
panic(recovered) panic(recovered)
} }
err = recoveredErr
}() }()
jsArgs := make([]goja.Value, 0, len(args)) jsArgs := make([]goja.Value, 0, len(args))
@ -104,16 +101,30 @@ func (s *Server) Exec(ctx context.Context, callableOrFuncname any, args ...inter
jsArgs = append(jsArgs, rt.ToValue(a)) jsArgs = append(jsArgs, rt.ToValue(a))
} }
logger.Debug(ctx, "executing callable", logger.F("callable", callableOrFuncname))
start := time.Now()
value, err = callable(nil, jsArgs...) value, err = callable(nil, jsArgs...)
if err != nil { if err != nil {
err = errors.WithStack(err) err = errors.WithStack(err)
} }
logger.Debug(ctx, "executed callable", logger.F("callable", callableOrFuncname), logger.F("duration", time.Since(start).String()))
}) })
wg.Wait() select {
case <-ctx.Done():
if err := ctx.Err(); err != nil {
err = errors.WithStack(err)
return nil, err
}
if err != nil { return nil, nil
return nil, errors.WithStack(err)
case <-done:
if err != nil {
return nil, errors.WithStack(err)
}
} }
return value, nil return value, nil

View File

@ -41,12 +41,12 @@ func (b *Bus) Unsubscribe(ctx context.Context, ns bus.MessageNamespace, ch <-cha
dispatchers.RemoveByOutChannel(ch) dispatchers.RemoveByOutChannel(ch)
} }
func (b *Bus) Publish(ctx context.Context, msg bus.Message) error { func (b *Bus) Publish(msg bus.Message) error {
dispatchers := b.getDispatchers(msg.MessageNamespace()) dispatchers := b.getDispatchers(msg.MessageNamespace())
dispatchersList := dispatchers.List() dispatchersList := dispatchers.List()
logger.Debug( logger.Debug(
ctx, "publishing message", msg.Context(), "publishing message",
logger.F("dispatchers", len(dispatchersList)), logger.F("dispatchers", len(dispatchersList)),
logger.F("messageNamespace", msg.MessageNamespace()), logger.F("messageNamespace", msg.MessageNamespace()),
) )

View File

@ -3,7 +3,6 @@ package memory
import ( import (
"context" "context"
"sync" "sync"
"time"
"forge.cadoles.com/arcad/edge/pkg/bus" "forge.cadoles.com/arcad/edge/pkg/bus"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -99,7 +98,17 @@ func (d *eventDispatcher) In(msg bus.Message) (err error) {
return return
} }
d.in <- msg ctx := msg.Context()
select {
case <-ctx.Done():
if err := ctx.Err(); err != nil {
return errors.WithStack(err)
}
return nil
case d.in <- msg:
}
return nil return nil
} }
@ -130,29 +139,29 @@ func (d *eventDispatcher) Run(ctx context.Context) {
}() }()
for { for {
msg, ok := <-d.in
if !ok {
return
}
timeout := time.After(time.Second)
select { select {
case d.out <- msg: case msg, ok := <-d.in:
case <-timeout: if !ok {
logger.Error( return
ctx, }
"out message channel timeout",
logger.F("message", msg),
)
return select {
case d.out <- msg:
case <-ctx.Done():
logger.Error(
ctx,
"message subscription context canceled",
logger.F("message", msg),
logger.CapturedE(errors.WithStack(ctx.Err())),
)
return
}
case <-ctx.Done(): case <-ctx.Done():
logger.Error( logger.Error(
ctx, ctx,
"message subscription context canceled", "message subscription context canceled",
logger.F("message", msg),
logger.CapturedE(errors.WithStack(ctx.Err())), logger.CapturedE(errors.WithStack(ctx.Err())),
) )

View File

@ -20,25 +20,35 @@ type RequestMessage struct {
Message bus.Message Message bus.Message
ns bus.MessageNamespace ns bus.MessageNamespace
ctx context.Context
} }
func (m *RequestMessage) MessageNamespace() bus.MessageNamespace { func (m *RequestMessage) MessageNamespace() bus.MessageNamespace {
return m.ns return m.ns
} }
func (m *RequestMessage) Context() context.Context {
return m.ctx
}
type ReplyMessage struct { type ReplyMessage struct {
RequestID uint64 RequestID uint64
Message bus.Message Message bus.Message
Error error Error error
ns bus.MessageNamespace ns bus.MessageNamespace
ctx context.Context
} }
func (m *ReplyMessage) MessageNamespace() bus.MessageNamespace { func (m *ReplyMessage) MessageNamespace() bus.MessageNamespace {
return m.ns return m.ns
} }
func (m *ReplyMessage) Context() context.Context {
return m.ctx
}
func (b *Bus) Request(ctx context.Context, msg bus.Message) (bus.Message, error) { func (b *Bus) Request(ctx context.Context, msg bus.Message) (bus.Message, error) {
requestID := atomic.AddUint64(&b.nextRequestID, 1) requestID := atomic.AddUint64(&b.nextRequestID, 1)
@ -46,6 +56,7 @@ func (b *Bus) Request(ctx context.Context, msg bus.Message) (bus.Message, error)
RequestID: requestID, RequestID: requestID,
Message: msg, Message: msg,
ns: msg.MessageNamespace(), ns: msg.MessageNamespace(),
ctx: ctx,
} }
replyNamespace := createReplyNamespace(requestID) replyNamespace := createReplyNamespace(requestID)
@ -61,7 +72,7 @@ func (b *Bus) Request(ctx context.Context, msg bus.Message) (bus.Message, error)
logger.Debug(ctx, "publishing request", logger.F("request", req)) logger.Debug(ctx, "publishing request", logger.F("request", req))
if err := b.Publish(ctx, req); err != nil { if err := b.Publish(req); err != nil {
return nil, errors.WithStack(err) return nil, errors.WithStack(err)
} }
@ -125,7 +136,8 @@ func (b *Bus) Reply(ctx context.Context, msgNamespace bus.MessageNamespace, h bu
Message: nil, Message: nil,
Error: nil, Error: nil,
ns: createReplyNamespace(request.RequestID), ns: createReplyNamespace(request.RequestID),
ctx: ctx,
} }
if err != nil { if err != nil {
@ -136,7 +148,7 @@ func (b *Bus) Reply(ctx context.Context, msgNamespace bus.MessageNamespace, h bu
logger.Debug(ctx, "publishing reply", logger.F("reply", reply)) logger.Debug(ctx, "publishing reply", logger.F("reply", reply))
if err := b.Publish(ctx, reply); err != nil { if err := b.Publish(reply); err != nil {
return errors.WithStack(err) return errors.WithStack(err)
} }
} }

View File

@ -1,6 +1,7 @@
package bus package bus
import ( import (
"context"
"strings" "strings"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -8,10 +9,28 @@ import (
type ( type (
MessageNamespace string MessageNamespace string
Address string
Message any
) )
type Message interface { type Envelope interface {
MessageNamespace() MessageNamespace Message() any
Context() context.Context
Address() Address
}
type BaseEnvelope struct {
msg Message
ctx context.Context
addr Address
}
func NewEnvelope(ctx context.Context, addr Address, msg Message) *BaseEnvelope {
return &BaseEnvelope{
ctx: ctx,
addr: addr,
msg: msg,
}
} }
func NewMessageNamespace(namespaces ...MessageNamespace) MessageNamespace { func NewMessageNamespace(namespaces ...MessageNamespace) MessageNamespace {

View File

@ -15,12 +15,18 @@ const (
testNamespace bus.MessageNamespace = "testNamespace" testNamespace bus.MessageNamespace = "testNamespace"
) )
type testMessage struct{} type testMessage struct {
ctx context.Context
}
func (e *testMessage) MessageNamespace() bus.MessageNamespace { func (m *testMessage) MessageNamespace() bus.MessageNamespace {
return testNamespace return testNamespace
} }
func (m *testMessage) Context() context.Context {
return m.ctx
}
func TestPublishSubscribe(t *testing.T, b bus.Bus) { func TestPublishSubscribe(t *testing.T, b bus.Bus) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute) ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel() defer cancel()

View File

@ -15,13 +15,18 @@ const (
) )
type testReqResMessage struct { type testReqResMessage struct {
i int i int
ctx context.Context
} }
func (m *testReqResMessage) MessageNamespace() bus.MessageNamespace { func (m *testReqResMessage) MessageNamespace() bus.MessageNamespace {
return testNamespace return testNamespace
} }
func (m *testReqResMessage) Context() context.Context {
return m.ctx
}
func TestRequestReply(t *testing.T, b bus.Bus) { func TestRequestReply(t *testing.T, b bus.Bus) {
expectedRoundTrips := 256 expectedRoundTrips := 256
timeout := time.Now().Add(time.Duration(expectedRoundTrips) * time.Second) timeout := time.Now().Add(time.Duration(expectedRoundTrips) * time.Second)
@ -47,7 +52,7 @@ func TestRequestReply(t *testing.T, b bus.Bus) {
return nil, errors.WithStack(bus.ErrUnexpectedMessage) return nil, errors.WithStack(bus.ErrUnexpectedMessage)
} }
result := &testReqResMessage{req.i} result := &testReqResMessage{req.i, context.Background()}
// Simulate random work // Simulate random work
time.Sleep(time.Millisecond * 100) time.Sleep(time.Millisecond * 100)
@ -75,7 +80,7 @@ func TestRequestReply(t *testing.T, b bus.Bus) {
requestCtx, cancelRequest := context.WithDeadline(context.Background(), timeout) requestCtx, cancelRequest := context.WithDeadline(context.Background(), timeout)
defer cancelRequest() defer cancelRequest()
req := &testReqResMessage{i} req := &testReqResMessage{i, context.Background()}
t.Logf("[REQ] sending req #%d", i) t.Logf("[REQ] sending req #%d", i)

View File

@ -5,7 +5,6 @@ import (
"forge.cadoles.com/arcad/edge/pkg/app" "forge.cadoles.com/arcad/edge/pkg/app"
"github.com/dop251/goja" "github.com/dop251/goja"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger" "gitlab.com/wpetit/goweb/logger"
) )
@ -19,16 +18,28 @@ func (m *LifecycleModule) Export(export *goja.Object) {
} }
func (m *LifecycleModule) OnInit(ctx context.Context, rt *goja.Runtime) (err error) { func (m *LifecycleModule) OnInit(ctx context.Context, rt *goja.Runtime) (err error) {
_, ok := goja.AssertFunction(rt.Get("onInit")) call, ok := goja.AssertFunction(rt.Get("onInit"))
if !ok { if !ok {
logger.Warn(ctx, "could not find onInit() function") logger.Warn(ctx, "could not find onInit() function")
return nil return nil
} }
if _, err := rt.RunString("setTimeout(onInit, 0)"); err != nil { defer func() {
return errors.WithStack(err) recovered := recover()
} if recovered == nil {
return
}
recoveredErr, ok := recovered.(error)
if !ok {
panic(recovered)
}
err = recoveredErr
}()
call(nil, rt.ToValue(ctx))
return nil return nil
} }

View File

@ -2,7 +2,6 @@ package module
import ( import (
"context" "context"
"fmt"
"sync" "sync"
"forge.cadoles.com/arcad/edge/pkg/app" "forge.cadoles.com/arcad/edge/pkg/app"
@ -52,7 +51,12 @@ func (m *RPCModule) Export(export *goja.Object) {
} }
func (m *RPCModule) OnInit(ctx context.Context, rt *goja.Runtime) error { func (m *RPCModule) OnInit(ctx context.Context, rt *goja.Runtime) error {
go m.handleMessages(ctx) clientMessages, err := m.bus.Subscribe(ctx, MessageNamespaceClient)
if err != nil {
return errors.WithStack(err)
}
go m.handleMessages(ctx, clientMessages)
return nil return nil
} }
@ -92,25 +96,25 @@ func (m *RPCModule) unregister(call goja.FunctionCall, rt *goja.Runtime) goja.Va
return nil return nil
} }
func (m *RPCModule) handleMessages(ctx context.Context) { func (m *RPCModule) handleMessages(ctx context.Context, clientMessages <-chan bus.Message) {
clientMessages, err := m.bus.Subscribe(ctx, MessageNamespaceClient)
if err != nil {
panic(errors.WithStack(err))
}
defer func() { defer func() {
m.bus.Unsubscribe(ctx, MessageNamespaceClient, clientMessages) m.bus.Unsubscribe(ctx, MessageNamespaceClient, clientMessages)
}() }()
sendRes := func(ctx context.Context, req *RPCRequest, result goja.Value) { sendRes := func(ctx context.Context, req *RPCRequest, result goja.Value) {
var rawResult any
if result != nil {
rawResult = result.Export()
}
res := &RPCResponse{ res := &RPCResponse{
ID: req.ID, ID: req.ID,
Result: result.Export(), Result: rawResult,
} }
logger.Debug(ctx, "sending rpc response", logger.F("response", res)) logger.Debug(ctx, "sending rpc response", logger.F("response", res))
if err := m.sendResponse(ctx, res); err != nil { if err := m.sendResponse(res); err != nil {
logger.Error( logger.Error(
ctx, "could not send response", ctx, "could not send response",
logger.CapturedE(errors.WithStack(err)), logger.CapturedE(errors.WithStack(err)),
@ -144,7 +148,7 @@ func (m *RPCModule) handleMessage(ctx context.Context, msg bus.Message, sendRes
if !exists { if !exists {
logger.Debug(ctx, "method not found", logger.F("req", req)) logger.Debug(ctx, "method not found", logger.F("req", req))
if err := m.sendMethodNotFoundResponse(clientMessage.Context, req); err != nil { if err := m.sendMethodNotFoundResponse(req); err != nil {
logger.Error( logger.Error(
ctx, "could not send method not found response", ctx, "could not send method not found response",
logger.CapturedE(errors.WithStack(err)), logger.CapturedE(errors.WithStack(err)),
@ -159,7 +163,7 @@ func (m *RPCModule) handleMessage(ctx context.Context, msg bus.Message, sendRes
if !ok { if !ok {
logger.Debug(ctx, "invalid method", logger.F("req", req)) logger.Debug(ctx, "invalid method", logger.F("req", req))
if err := m.sendMethodNotFoundResponse(clientMessage.Context, req); err != nil { if err := m.sendMethodNotFoundResponse(req); err != nil {
logger.Error( logger.Error(
ctx, "could not send method not found response", ctx, "could not send method not found response",
logger.CapturedE(errors.WithStack(err)), logger.CapturedE(errors.WithStack(err)),
@ -178,7 +182,7 @@ func (m *RPCModule) handleMessage(ctx context.Context, msg bus.Message, sendRes
logger.F("request", req), logger.F("request", req),
) )
if err := m.sendErrorResponse(clientMessage.Context, req, err); err != nil { if err := m.sendErrorResponse(context.Background(), req, err); err != nil {
logger.Error( logger.Error(
ctx, "could not send error response", ctx, "could not send error response",
logger.CapturedE(errors.WithStack(err)), logger.CapturedE(errors.WithStack(err)),
@ -202,7 +206,7 @@ func (m *RPCModule) handleMessage(ctx context.Context, msg bus.Message, sendRes
} }
func (m *RPCModule) sendErrorResponse(ctx context.Context, req *RPCRequest, err error) error { func (m *RPCModule) sendErrorResponse(ctx context.Context, req *RPCRequest, err error) error {
return m.sendResponse(ctx, &RPCResponse{ return m.sendResponse(&RPCResponse{
ID: req.ID, ID: req.ID,
Result: nil, Result: nil,
Error: &RPCError{ Error: &RPCError{
@ -212,26 +216,26 @@ func (m *RPCModule) sendErrorResponse(ctx context.Context, req *RPCRequest, err
}) })
} }
func (m *RPCModule) sendMethodNotFoundResponse(ctx context.Context, req *RPCRequest) error { func (m *RPCModule) sendMethodNotFoundResponse(req *RPCRequest) error {
return m.sendResponse(ctx, &RPCResponse{ return m.sendResponse(&RPCResponse{
ID: req.ID, ID: req.ID,
Result: nil, Result: nil,
Error: &RPCError{ Error: &RPCError{
Code: -32601, Code: -32601,
Message: fmt.Sprintf("method not found"), Message: "method not found",
}, },
}) })
} }
func (m *RPCModule) sendResponse(ctx context.Context, res *RPCResponse) error { func (m *RPCModule) sendResponse(res *RPCResponse) error {
msg := NewServerMessage(ctx, map[string]interface{}{ msg := NewServerMessage(context.Background(), map[string]interface{}{
"jsonrpc": "2.0", "jsonrpc": "2.0",
"id": res.ID, "id": res.ID,
"error": res.Error, "error": res.Error,
"result": res.Result, "result": res.Result,
}) })
if err := m.bus.Publish(ctx, msg); err != nil { if err := m.bus.Publish(context.Background(), msg); err != nil {
return errors.WithStack(err) return errors.WithStack(err)
} }

106
pkg/module/rpc_test.go Normal file
View File

@ -0,0 +1,106 @@
package module
import (
"context"
"fmt"
"os"
"sync"
"testing"
"time"
"forge.cadoles.com/arcad/edge/pkg/app"
"forge.cadoles.com/arcad/edge/pkg/bus/memory"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
)
func TestServerExecDeadlock(t *testing.T) {
if testing.Verbose() {
logger.SetLevel(logger.LevelDebug)
}
bus := memory.NewBus(memory.WithBufferSize(1))
server := app.NewServer(
ConsoleModuleFactory(),
RPCModuleFactory(bus),
LifecycleModuleFactory(),
)
data, err := os.ReadFile("testdata/deadlock.js")
if err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
if err := server.Load("deadlock.js", string(data)); err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
ctx := context.Background()
if err := server.Start(ctx); err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
defer server.Stop()
messages, err := bus.Subscribe(ctx, MessageNamespaceServer)
if err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
count := 100
delay := 500
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
received := 0
for range messages {
received++
t.Logf("received %d", received)
if received == count {
bus.Unsubscribe(ctx, MessageNamespaceServer, messages)
return
}
}
}()
wg.Add(count)
for i := 0; i < count; i++ {
go func(i int) {
defer wg.Done()
timeout := time.Duration(delay*5) * time.Millisecond
t.Logf("sending message %d with timeout %s", i, timeout.String())
// ctx, cancel := context.WithTimeout(context.Background(), timeout)
// defer cancel()
msg := NewClientMessage(context.Background(), map[string]any{
"jsonrpc": "2.0",
"method": "doSomethingLong",
"id": fmt.Sprintf("msg-%d", i),
"params": map[string]any{
"i": i,
"delay": delay,
},
})
t.Logf("publishing message #%d", i)
if err := bus.Publish(ctx, msg); err != nil {
t.Errorf("%+v", errors.WithStack(err))
}
}(i)
}
wg.Wait()
}

14
pkg/module/testdata/deadlock.js vendored Normal file
View File

@ -0,0 +1,14 @@
function onInit() {
rpc.register("doSomethingLong", doSomethingLong)
}
function doSomethingLong(ctx, params) {
var start = Date.now()
while (true) {
var now = Date.now()
if (now - start >= params.delay) break
}
return params.i;
}