diff --git a/pkg/bus/memory/bus.go b/pkg/bus/memory/bus.go index 9ac7f3a..d6b3088 100644 --- a/pkg/bus/memory/bus.go +++ b/pkg/bus/memory/bus.go @@ -22,13 +22,13 @@ func (b *Bus) Subscribe(ctx context.Context, ns bus.MessageNamespace) (<-chan bu ) dispatchers := b.getDispatchers(ns) - d := newEventDispatcher(b.opt.BufferSize) + disp := newEventDispatcher(b.opt.BufferSize) - go d.Run() + go disp.Run(ctx) - dispatchers.Add(d) + dispatchers.Add(disp) - return d.Out(), nil + return disp.Out(), nil } func (b *Bus) Unsubscribe(ctx context.Context, ns bus.MessageNamespace, ch <-chan bus.Message) { @@ -52,6 +52,12 @@ func (b *Bus) Publish(ctx context.Context, msg bus.Message) error { ) for _, d := range dispatchersList { + if d.Closed() { + dispatchers.Remove(d) + + continue + } + if err := d.In(msg); err != nil { return errors.WithStack(err) } diff --git a/pkg/bus/memory/event_dispatcher.go b/pkg/bus/memory/event_dispatcher.go index 5fc028d..a424da3 100644 --- a/pkg/bus/memory/event_dispatcher.go +++ b/pkg/bus/memory/event_dispatcher.go @@ -1,9 +1,13 @@ package memory import ( + "context" "sync" + "time" "forge.cadoles.com/arcad/edge/pkg/bus" + "github.com/pkg/errors" + "gitlab.com/wpetit/goweb/logger" ) type eventDispatcherSet struct { @@ -18,13 +22,21 @@ func (s *eventDispatcherSet) Add(d *eventDispatcher) { s.items[d] = struct{}{} } +func (s *eventDispatcherSet) Remove(d *eventDispatcher) { + s.mutex.Lock() + defer s.mutex.Unlock() + + d.close() + delete(s.items, d) +} + func (s *eventDispatcherSet) RemoveByOutChannel(out <-chan bus.Message) { s.mutex.Lock() defer s.mutex.Unlock() for d := range s.items { if d.IsOut(out) { - d.Close() + d.close() delete(s.items, d) } } @@ -56,10 +68,21 @@ type eventDispatcher struct { closed bool } +func (d *eventDispatcher) Closed() bool { + d.mutex.RLock() + defer d.mutex.RUnlock() + + return d.closed +} + func (d *eventDispatcher) Close() { d.mutex.Lock() defer d.mutex.Unlock() + d.close() +} + +func (d *eventDispatcher) close() { d.closed = true close(d.in) } @@ -85,16 +108,52 @@ func (d *eventDispatcher) IsOut(out <-chan bus.Message) bool { return d.out == out } -func (d *eventDispatcher) Run() { +func (d *eventDispatcher) Run(ctx context.Context) { + defer func() { + for { + logger.Debug(ctx, "closing dispatcher, flushing out incoming messages") + + close(d.out) + + // Flush all incoming messages + for { + _, ok := <-d.in + if !ok { + return + } + } + } + }() + for { msg, ok := <-d.in if !ok { - close(d.out) - return } - d.out <- msg + timeout := time.After(time.Second) + + select { + case d.out <- msg: + case <-timeout: + logger.Error( + ctx, + "out message channel timeout", + logger.F("message", msg), + ) + + return + + case <-ctx.Done(): + logger.Error( + ctx, + "message subscription context canceled", + logger.F("message", msg), + logger.E(errors.WithStack(ctx.Err())), + ) + + return + } } } diff --git a/pkg/module/rpc.go b/pkg/module/rpc.go index 1ced983..c88741e 100644 --- a/pkg/module/rpc.go +++ b/pkg/module/rpc.go @@ -123,79 +123,83 @@ func (m *RPCModule) handleMessages() { } for msg := range clientMessages { - clientMessage, ok := msg.(*ClientMessage) - if !ok { - logger.Warn(ctx, "unexpected bus message", logger.F("message", msg)) + go m.handleMessage(ctx, msg, sendRes) + } +} - continue - } +func (m *RPCModule) handleMessage(ctx context.Context, msg bus.Message, sendRes func(ctx context.Context, req *RPCRequest, result goja.Value)) { + clientMessage, ok := msg.(*ClientMessage) + if !ok { + logger.Warn(ctx, "unexpected bus message", logger.F("message", msg)) - ok, req := m.isRPCRequest(clientMessage) - if !ok { - continue - } + return + } - logger.Debug(ctx, "received rpc request", logger.F("request", req)) + ok, req := m.isRPCRequest(clientMessage) + if !ok { + return + } - rawCallable, exists := m.callbacks.Load(req.Method) - if !exists { - logger.Debug(ctx, "method not found", logger.F("req", req)) + logger.Debug(ctx, "received rpc request", logger.F("request", req)) - if err := m.sendMethodNotFoundResponse(clientMessage.Context, req); err != nil { - logger.Error( - ctx, "could not send method not found response", - logger.E(errors.WithStack(err)), - logger.F("request", req), - ) - } + rawCallable, exists := m.callbacks.Load(req.Method) + if !exists { + logger.Debug(ctx, "method not found", logger.F("req", req)) - continue - } - - callable, ok := rawCallable.(goja.Callable) - if !ok { - logger.Debug(ctx, "invalid method", logger.F("req", req)) - - if err := m.sendMethodNotFoundResponse(clientMessage.Context, req); err != nil { - logger.Error( - ctx, "could not send method not found response", - logger.E(errors.WithStack(err)), - logger.F("request", req), - ) - } - - continue - } - - result, err := m.server.Exec(clientMessage.Context, callable, clientMessage.Context, req.Params) - if err != nil { + if err := m.sendMethodNotFoundResponse(clientMessage.Context, req); err != nil { logger.Error( - ctx, "rpc call error", + ctx, "could not send method not found response", logger.E(errors.WithStack(err)), logger.F("request", req), ) - - if err := m.sendErrorResponse(clientMessage.Context, req, err); err != nil { - logger.Error( - ctx, "could not send error response", - logger.E(errors.WithStack(err)), - logger.F("originalError", err), - logger.F("request", req), - ) - } - - continue } - promise, ok := app.IsPromise(result) - if ok { - go func(ctx context.Context, req *RPCRequest, promise *goja.Promise) { - result := m.server.WaitForPromise(promise) - sendRes(ctx, req, result) - }(clientMessage.Context, req, promise) - } else { - sendRes(clientMessage.Context, req, result) + return + } + + callable, ok := rawCallable.(goja.Callable) + if !ok { + logger.Debug(ctx, "invalid method", logger.F("req", req)) + + if err := m.sendMethodNotFoundResponse(clientMessage.Context, req); err != nil { + logger.Error( + ctx, "could not send method not found response", + logger.E(errors.WithStack(err)), + logger.F("request", req), + ) } + + return + } + + result, err := m.server.Exec(clientMessage.Context, callable, clientMessage.Context, req.Params) + if err != nil { + logger.Error( + ctx, "rpc call error", + logger.E(errors.WithStack(err)), + logger.F("request", req), + ) + + if err := m.sendErrorResponse(clientMessage.Context, req, err); err != nil { + logger.Error( + ctx, "could not send error response", + logger.E(errors.WithStack(err)), + logger.F("originalError", err), + logger.F("request", req), + ) + } + + return + } + + promise, ok := app.IsPromise(result) + if ok { + go func(ctx context.Context, req *RPCRequest, promise *goja.Promise) { + result := m.server.WaitForPromise(promise) + sendRes(ctx, req, result) + }(clientMessage.Context, req, promise) + } else { + sendRes(clientMessage.Context, req, result) } }