feat: rewrite bus to prevent deadlocks
All checks were successful
arcad/edge/pipeline/head This commit looks good
arcad/edge/pipeline/pr-master This commit looks good

This commit is contained in:
2023-11-28 16:35:49 +01:00
parent f4a7366aad
commit ad49c1718c
50 changed files with 1621 additions and 1336 deletions

View File

@ -15,13 +15,13 @@ type Bus struct {
nextRequestID uint64
}
func (b *Bus) Subscribe(ctx context.Context, ns bus.MessageNamespace) (<-chan bus.Message, error) {
func (b *Bus) Subscribe(ctx context.Context, address bus.Address) (<-chan bus.Envelope, error) {
logger.Debug(
ctx, "subscribing to messages",
logger.F("messageNamespace", ns),
ctx, "subscribing",
logger.F("address", address),
)
dispatchers := b.getDispatchers(ns)
dispatchers := b.getDispatchers(address)
disp := newEventDispatcher(b.opt.BufferSize)
go disp.Run(ctx)
@ -31,50 +31,41 @@ func (b *Bus) Subscribe(ctx context.Context, ns bus.MessageNamespace) (<-chan bu
return disp.Out(), nil
}
func (b *Bus) Unsubscribe(ctx context.Context, ns bus.MessageNamespace, ch <-chan bus.Message) {
// Unsubscribe detaches the subscription whose output channel is ch from the
// dispatcher set registered for address. The set is looked up with
// getDispatchers and the matching dispatcher is dropped through
// RemoveByOutChannel.
//
// The method takes no context, so the debug log entry is emitted with
// context.Background().
func (b *Bus) Unsubscribe(address bus.Address, ch <-chan bus.Envelope) {
logger.Debug(
ctx, "unsubscribing from messages",
logger.F("messageNamespace", ns),
context.Background(), "unsubscribing",
logger.F("address", address),
)
dispatchers := b.getDispatchers(ns)
dispatchers := b.getDispatchers(address)
dispatchers.RemoveByOutChannel(ch)
}
func (b *Bus) Publish(ctx context.Context, msg bus.Message) error {
dispatchers := b.getDispatchers(msg.MessageNamespace())
dispatchersList := dispatchers.List()
// Publish hands env to every dispatcher registered for env.Address().
//
// The dispatcher set is resolved with getDispatchers and walked with Range.
// A failing d.In(env) is logged (with context.Background(), since the method
// takes no context) and does not stop delivery to the remaining dispatchers.
func (b *Bus) Publish(env bus.Envelope) error {
dispatchers := b.getDispatchers(env.Address())
logger.Debug(
ctx, "publishing message",
logger.F("dispatchers", len(dispatchersList)),
logger.F("messageNamespace", msg.MessageNamespace()),
context.Background(), "publish",
logger.F("address", env.Address()),
)
for _, d := range dispatchersList {
if d.Closed() {
dispatchers.Remove(d)
continue
// The Range callback reports a delivery failure through the logger rather
// than returning it to the caller.
dispatchers.Range(func(d *eventDispatcher) {
if err := d.In(env); err != nil {
logger.Error(context.Background(), "could not publish message", logger.CapturedE(errors.WithStack(err)))
}
if err := d.In(msg); err != nil {
return errors.WithStack(err)
}
}
})
return nil
}
func (b *Bus) getDispatchers(namespace bus.MessageNamespace) *eventDispatcherSet {
strNamespace := string(namespace)
func (b *Bus) getDispatchers(address bus.Address) *eventDispatcherSet {
rawAddress := string(address)
rawDispatchers, exists := b.dispatchers.Get(strNamespace)
rawDispatchers, exists := b.dispatchers.Get(rawAddress)
dispatchers, ok := rawDispatchers.(*eventDispatcherSet)
if !exists || !ok {
dispatchers = newEventDispatcherSet()
b.dispatchers.Set(strNamespace, dispatchers)
b.dispatchers.Set(rawAddress, dispatchers)
}
return dispatchers

View File

@ -4,13 +4,23 @@ import (
"testing"
busTesting "forge.cadoles.com/arcad/edge/pkg/bus/testing"
"gitlab.com/wpetit/goweb/logger"
"go.uber.org/goleak"
)
// TestMain delegates the package's test run to goleak.VerifyTestMain, which
// (per the goleak documentation) runs the tests and then fails the run if
// unexpected goroutines are still alive.
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestMemoryBus(t *testing.T) {
if testing.Short() {
t.Skip("Test disabled when -short flag is set")
}
if testing.Verbose() {
logger.SetLevel(logger.LevelDebug)
}
t.Parallel()
t.Run("PublishSubscribe", func(t *testing.T) {
@ -26,4 +36,11 @@ func TestMemoryBus(t *testing.T) {
b := NewBus()
busTesting.TestRequestReply(t, b)
})
t.Run("CanceledRequestReply", func(t *testing.T) {
t.Parallel()
b := NewBus()
busTesting.TestCanceledRequest(t, b)
})
}

View File

@ -3,7 +3,6 @@ package memory
import (
"context"
"sync"
"time"
"forge.cadoles.com/arcad/edge/pkg/bus"
"github.com/pkg/errors"
@ -30,7 +29,7 @@ func (s *eventDispatcherSet) Remove(d *eventDispatcher) {
delete(s.items, d)
}
func (s *eventDispatcherSet) RemoveByOutChannel(out <-chan bus.Message) {
func (s *eventDispatcherSet) RemoveByOutChannel(out <-chan bus.Envelope) {
s.mutex.Lock()
defer s.mutex.Unlock()
@ -42,17 +41,18 @@ func (s *eventDispatcherSet) RemoveByOutChannel(out <-chan bus.Message) {
}
}
func (s *eventDispatcherSet) List() []*eventDispatcher {
// Range calls fn for each dispatcher in the set, skipping (and removing)
// dispatchers that report Closed().
//
// s.mutex is locked for the whole iteration, so fn runs with the lock held.
func (s *eventDispatcherSet) Range(fn func(d *eventDispatcher)) {
s.mutex.Lock()
defer s.mutex.Unlock()
dispatchers := make([]*eventDispatcher, 0, len(s.items))
for d := range s.items {
dispatchers = append(dispatchers, d)
}
if d.Closed() {
// NOTE(review): s.mutex is already held here. Remove's locking is not
// visible in this view; if it also locks s.mutex (as RemoveByOutChannel
// does) this call blocks forever, because Go mutexes are not reentrant.
// Confirm against Remove's body.
s.Remove(d)
continue
}
return dispatchers
fn(d)
}
}
func newEventDispatcherSet() *eventDispatcherSet {
@ -62,8 +62,8 @@ func newEventDispatcherSet() *eventDispatcherSet {
}
// eventDispatcher relays envelopes from an input channel to the output
// channel handed to a subscriber.
type eventDispatcher struct {
in chan bus.Message
out chan bus.Message
// in is the channel Run receives from.
in chan bus.Envelope
// out is the channel Run sends to; Out() exposes it receive-only.
out chan bus.Envelope
// mutex is read-locked by In(); presumably it guards closed — confirm
// against close()/Closed().
mutex sync.RWMutex
// closed is set to true by close().
closed bool
}
@ -91,7 +91,7 @@ func (d *eventDispatcher) close() {
d.closed = true
}
func (d *eventDispatcher) In(msg bus.Message) (err error) {
func (d *eventDispatcher) In(msg bus.Envelope) (err error) {
d.mutex.RLock()
defer d.mutex.RUnlock()
@ -104,67 +104,52 @@ func (d *eventDispatcher) In(msg bus.Message) (err error) {
return nil
}
func (d *eventDispatcher) Out() <-chan bus.Message {
// Out returns the dispatcher's output channel as a receive-only channel.
func (d *eventDispatcher) Out() <-chan bus.Envelope {
return d.out
}
func (d *eventDispatcher) IsOut(out <-chan bus.Message) bool {
// IsOut reports whether out is this dispatcher's own output channel
// (channel identity comparison).
func (d *eventDispatcher) IsOut(out <-chan bus.Envelope) bool {
return d.out == out
}
// Run relays envelopes from d.in to d.out until ctx is done or d.in is
// closed.
//
// On exit the deferred function closes d.out and then keeps receiving from
// d.in, discarding what it reads, until d.in is closed.
func (d *eventDispatcher) Run(ctx context.Context) {
defer func() {
for {
logger.Debug(ctx, "closing dispatcher, flushing out incoming messages")
logger.Debug(ctx, "closing dispatcher, flushing out incoming messages")
close(d.out)
close(d.out)
for range d.in {
// Flush all incoming messages
// NOTE(review): this drain loop only returns once d.in is closed, so the
// goroutine running Run stays alive until then — confirm d.in is always
// closed (e.g. by close()) for every dispatcher.
for {
_, ok := <-d.in
if !ok {
return
}
}
}
}()
for {
msg, ok := <-d.in
if !ok {
return
}
timeout := time.After(time.Second)
select {
case d.out <- msg:
case <-timeout:
logger.Error(
ctx,
"out message channel timeout",
logger.F("message", msg),
)
return
case <-ctx.Done():
logger.Error(
ctx,
"message subscription context canceled",
logger.F("message", msg),
logger.CapturedE(errors.WithStack(ctx.Err())),
)
// A plain context.Canceled is not logged as an error; any other ctx error is.
if err := ctx.Err(); !errors.Is(err, context.Canceled) {
logger.Error(
ctx,
"message subscription context canceled",
logger.CapturedE(errors.WithStack(err)),
)
}
return
case msg, ok := <-d.in:
if !ok {
return
}
// NOTE(review): plain blocking send that does not watch ctx. If d.out's
// buffer is full and nothing receives, Run cannot observe cancellation
// while blocked here — confirm subscribers always drain their channel.
d.out <- msg
}
}
}
// newEventDispatcher returns a dispatcher whose in and out channels are both
// buffered with bufferSize slots and whose closed flag starts false. The
// caller is responsible for starting Run.
func newEventDispatcher(bufferSize int64) *eventDispatcher {
return &eventDispatcher{
in: make(chan bus.Message, bufferSize),
out: make(chan bus.Message, bufferSize),
in: make(chan bus.Envelope, bufferSize),
out: make(chan bus.Envelope, bufferSize),
closed: false,
}
}

View File

@ -11,57 +11,78 @@ import (
)
const (
MessageNamespaceRequest bus.MessageNamespace = "reqrep/request"
MessageNamespaceReply bus.MessageNamespace = "reqrep/reply"
// AddressRequest is the prefix getRequestAddress puts in front of an
// address to form the address requests are published on.
AddressRequest bus.Address = "bus/memory/request"
// AddressReply is the prefix getReplyAddress puts in front of an address
// (followed by the request ID) to form a per-request reply address.
AddressReply bus.Address = "bus/memory/reply"
)
type RequestMessage struct {
RequestID uint64
Message bus.Message
ns bus.MessageNamespace
// RequestEnvelope is the envelope Bus.Request publishes: the caller's
// envelope tagged with a request ID.
type RequestEnvelope struct {
// requestID is taken from the bus's nextRequestID counter in Bus.Request.
requestID uint64
// wrapped is the envelope originally passed to Bus.Request.
wrapped bus.Envelope
}
func (m *RequestMessage) MessageNamespace() bus.MessageNamespace {
return m.ns
// Address returns the wrapped envelope's address mapped through
// getRequestAddress, i.e. placed under the AddressRequest prefix.
func (e *RequestEnvelope) Address() bus.Address {
return getRequestAddress(e.wrapped.Address())
}
type ReplyMessage struct {
RequestID uint64
Message bus.Message
Error error
ns bus.MessageNamespace
// Message returns the wrapped envelope's Message() value unchanged.
func (e *RequestEnvelope) Message() any {
return e.wrapped.Message()
}
func (m *ReplyMessage) MessageNamespace() bus.MessageNamespace {
return m.ns
// RequestID returns the ID assigned to this request; Bus.Reply copies it
// into the ReplyEnvelope it publishes.
func (e *RequestEnvelope) RequestID() uint64 {
return e.requestID
}
func (b *Bus) Request(ctx context.Context, msg bus.Message) (bus.Message, error) {
// Unwrap returns the original envelope carried by the request; this is the
// value Bus.Reply passes to the request handler.
func (e *RequestEnvelope) Unwrap() bus.Envelope {
return e.wrapped
}
// ReplyEnvelope is the envelope Bus.Reply publishes in response to a
// RequestEnvelope.
type ReplyEnvelope struct {
// requestID is copied from the request being answered.
requestID uint64
// wrapped carries the handler's result, addressed with the request's
// original address.
wrapped bus.Envelope
// err is set by Bus.Reply when the handler returns an error; nil otherwise.
err error
}
// Address returns the per-request reply address built by getReplyAddress
// from the wrapped envelope's address and the request ID. Bus.Request
// subscribes to an address built the same way.
func (e *ReplyEnvelope) Address() bus.Address {
return getReplyAddress(e.wrapped.Address(), e.requestID)
}
// Message returns the wrapped envelope's Message() value unchanged.
func (e *ReplyEnvelope) Message() any {
return e.wrapped.Message()
}
// Err returns the error recorded on the reply, or nil when none was set.
// Bus.Request returns a non-nil value to its caller instead of the reply.
func (e *ReplyEnvelope) Err() error {
return e.err
}
// Unwrap returns the envelope carried by the reply; this is what Bus.Request
// returns on success.
func (e *ReplyEnvelope) Unwrap() bus.Envelope {
return e.wrapped
}
func (b *Bus) Request(ctx context.Context, env bus.Envelope) (bus.Envelope, error) {
requestID := atomic.AddUint64(&b.nextRequestID, 1)
req := &RequestMessage{
RequestID: requestID,
Message: msg,
ns: msg.MessageNamespace(),
req := &RequestEnvelope{
requestID: requestID,
wrapped: env,
}
replyNamespace := createReplyNamespace(requestID)
replyAddress := getReplyAddress(env.Address(), requestID)
replies, err := b.Subscribe(ctx, replyNamespace)
subCtx, cancel := context.WithCancel(ctx)
defer cancel()
replies, err := b.Subscribe(subCtx, replyAddress)
if err != nil {
return nil, errors.WithStack(err)
}
defer func() {
b.Unsubscribe(ctx, replyNamespace, replies)
b.Unsubscribe(replyAddress, replies)
}()
logger.Debug(ctx, "publishing request", logger.F("request", req))
if err := b.Publish(ctx, req); err != nil {
if err := b.Publish(req); err != nil {
return nil, errors.WithStack(err)
}
@ -70,82 +91,93 @@ func (b *Bus) Request(ctx context.Context, msg bus.Message) (bus.Message, error)
case <-ctx.Done():
return nil, errors.WithStack(ctx.Err())
case msg, ok := <-replies:
case env, ok := <-replies:
if !ok {
return nil, errors.WithStack(bus.ErrNoResponse)
}
reply, ok := msg.(*ReplyMessage)
reply, ok := env.(*ReplyEnvelope)
if !ok {
return nil, errors.WithStack(bus.ErrUnexpectedMessage)
}
if reply.Error != nil {
if err := reply.Err(); err != nil {
return nil, errors.WithStack(err)
}
return reply.Message, nil
return reply.Unwrap(), nil
}
}
}
type RequestHandler func(evt bus.Message) (bus.Message, error)
// Reply subscribes to the request address derived from address and answers
// incoming requests with handler from a background goroutine.
//
// It returns a channel of errors. The channel is closed when the goroutine
// stops (ctx is done, or the requests channel is closed), or right after a
// subscription failure has been delivered on it. An error returned by handler
// is not sent on this channel; it is stored on the reply envelope.
//
// NOTE(review): errs is unbuffered, so every send blocks until the caller
// receives. If the caller does not drain the channel the goroutine stalls on
// the send (including the one made on ctx cancellation) and the deferred
// Unsubscribe never runs — confirm all callers consume the channel.
func (b *Bus) Reply(ctx context.Context, address bus.Address, handler bus.RequestHandler) chan error {
requestAddress := getRequestAddress(address)
func (b *Bus) Reply(ctx context.Context, msgNamespace bus.MessageNamespace, h bus.RequestHandler) error {
requests, err := b.Subscribe(ctx, msgNamespace)
errs := make(chan error)
requests, err := b.Subscribe(ctx, requestAddress)
if err != nil {
return errors.WithStack(err)
// Report the subscription failure asynchronously so Reply itself does not
// block on the unbuffered channel, then close it.
go func() {
errs <- errors.WithStack(err)
close(errs)
}()
return errs
}
defer func() {
b.Unsubscribe(ctx, msgNamespace, requests)
go func() {
// On exit: drop the subscription, then close errs to signal termination.
defer func() {
b.Unsubscribe(requestAddress, requests)
close(errs)
}()
for {
select {
case <-ctx.Done():
errs <- errors.WithStack(ctx.Err())
return
case env, ok := <-requests:
if !ok {
return
}
request, ok := env.(*RequestEnvelope)
if !ok {
// An envelope of the wrong type is reported and skipped; the loop keeps serving.
errs <- errors.WithStack(bus.ErrUnexpectedMessage)
continue
}
logger.Debug(ctx, "handling request", logger.F("request", request))
msg, err := handler(request.Unwrap())
// The reply wraps the handler result under the request's original address;
// ReplyEnvelope.Address() turns that into the per-request reply address.
reply := &ReplyEnvelope{
requestID: request.RequestID(),
wrapped: bus.NewEnvelope(request.Unwrap().Address(), msg),
}
if err != nil {
reply.err = errors.WithStack(err)
}
logger.Debug(ctx, "publishing reply", logger.F("reply", reply))
if err := b.Publish(reply); err != nil {
errs <- errors.WithStack(err)
continue
}
}
}
}()
for {
select {
case <-ctx.Done():
return errors.WithStack(ctx.Err())
case msg, ok := <-requests:
if !ok {
return nil
}
request, ok := msg.(*RequestMessage)
if !ok {
return errors.WithStack(bus.ErrUnexpectedMessage)
}
logger.Debug(ctx, "handling request", logger.F("request", request))
msg, err := h(request.Message)
reply := &ReplyMessage{
RequestID: request.RequestID,
Message: nil,
Error: nil,
ns: createReplyNamespace(request.RequestID),
}
if err != nil {
reply.Error = errors.WithStack(err)
} else {
reply.Message = msg
}
logger.Debug(ctx, "publishing reply", logger.F("reply", reply))
if err := b.Publish(ctx, reply); err != nil {
return errors.WithStack(err)
}
}
}
return errs
}
func createReplyNamespace(requestID uint64) bus.MessageNamespace {
return bus.NewMessageNamespace(
MessageNamespaceReply,
bus.MessageNamespace(strconv.FormatUint(requestID, 10)),
)
// getRequestAddress returns the address requests for addr are published on:
// AddressRequest, a "/" separator, then addr.
func getRequestAddress(addr bus.Address) bus.Address {
return AddressRequest + "/" + addr
}
// getReplyAddress returns the address the reply to one specific request is
// published on: AddressReply, addr, and requestID in base 10, joined by "/".
func getReplyAddress(addr bus.Address, requestID uint64) bus.Address {
return AddressReply + "/" + addr + "/" + bus.Address(strconv.FormatUint(requestID, 10))
}