feat: rewrite bus to prevent deadlocks
All checks were successful
arcad/edge/pipeline/head This commit looks good
arcad/edge/pipeline/pr-master This commit looks good

This commit is contained in:
2023-11-28 16:35:49 +01:00
parent f4a7366aad
commit ad49c1718c
50 changed files with 1621 additions and 1336 deletions

View File

@ -3,11 +3,11 @@ package bus
import "context"
type Bus interface {
Subscribe(ctx context.Context, ns MessageNamespace) (<-chan Message, error)
Unsubscribe(ctx context.Context, ns MessageNamespace, ch <-chan Message)
Publish(ctx context.Context, msg Message) error
Request(ctx context.Context, msg Message) (Message, error)
Reply(ctx context.Context, ns MessageNamespace, h RequestHandler) error
Subscribe(ctx context.Context, addr Address) (<-chan Envelope, error)
Unsubscribe(addr Address, ch <-chan Envelope)
Publish(env Envelope) error
Request(ctx context.Context, env Envelope) (Envelope, error)
Reply(ctx context.Context, addr Address, h RequestHandler) chan error
}
type RequestHandler func(msg Message) (Message, error)
type RequestHandler func(env Envelope) (any, error)

32
pkg/bus/envelope.go Normal file
View File

@ -0,0 +1,32 @@
package bus
type Address string
type Envelope interface {
Message() any
Address() Address
}
type BaseEnvelope struct {
msg any
addr Address
}
// Address implements Envelope.
func (e *BaseEnvelope) Address() Address {
return e.addr
}
// Message implements Envelope.
func (e *BaseEnvelope) Message() any {
return e.msg
}
func NewEnvelope(addr Address, msg any) *BaseEnvelope {
return &BaseEnvelope{
addr: addr,
msg: msg,
}
}
var _ Envelope = &BaseEnvelope{}

View File

@ -15,13 +15,13 @@ type Bus struct {
nextRequestID uint64
}
func (b *Bus) Subscribe(ctx context.Context, ns bus.MessageNamespace) (<-chan bus.Message, error) {
func (b *Bus) Subscribe(ctx context.Context, address bus.Address) (<-chan bus.Envelope, error) {
logger.Debug(
ctx, "subscribing to messages",
logger.F("messageNamespace", ns),
ctx, "subscribing",
logger.F("address", address),
)
dispatchers := b.getDispatchers(ns)
dispatchers := b.getDispatchers(address)
disp := newEventDispatcher(b.opt.BufferSize)
go disp.Run(ctx)
@ -31,50 +31,41 @@ func (b *Bus) Subscribe(ctx context.Context, ns bus.MessageNamespace) (<-chan bu
return disp.Out(), nil
}
func (b *Bus) Unsubscribe(ctx context.Context, ns bus.MessageNamespace, ch <-chan bus.Message) {
func (b *Bus) Unsubscribe(address bus.Address, ch <-chan bus.Envelope) {
logger.Debug(
ctx, "unsubscribing from messages",
logger.F("messageNamespace", ns),
context.Background(), "unsubscribing",
logger.F("address", address),
)
dispatchers := b.getDispatchers(ns)
dispatchers := b.getDispatchers(address)
dispatchers.RemoveByOutChannel(ch)
}
func (b *Bus) Publish(ctx context.Context, msg bus.Message) error {
dispatchers := b.getDispatchers(msg.MessageNamespace())
dispatchersList := dispatchers.List()
func (b *Bus) Publish(env bus.Envelope) error {
dispatchers := b.getDispatchers(env.Address())
logger.Debug(
ctx, "publishing message",
logger.F("dispatchers", len(dispatchersList)),
logger.F("messageNamespace", msg.MessageNamespace()),
context.Background(), "publish",
logger.F("address", env.Address()),
)
for _, d := range dispatchersList {
if d.Closed() {
dispatchers.Remove(d)
continue
dispatchers.Range(func(d *eventDispatcher) {
if err := d.In(env); err != nil {
logger.Error(context.Background(), "could not publish message", logger.CapturedE(errors.WithStack(err)))
}
if err := d.In(msg); err != nil {
return errors.WithStack(err)
}
}
})
return nil
}
func (b *Bus) getDispatchers(namespace bus.MessageNamespace) *eventDispatcherSet {
strNamespace := string(namespace)
func (b *Bus) getDispatchers(address bus.Address) *eventDispatcherSet {
rawAddress := string(address)
rawDispatchers, exists := b.dispatchers.Get(strNamespace)
rawDispatchers, exists := b.dispatchers.Get(rawAddress)
dispatchers, ok := rawDispatchers.(*eventDispatcherSet)
if !exists || !ok {
dispatchers = newEventDispatcherSet()
b.dispatchers.Set(strNamespace, dispatchers)
b.dispatchers.Set(rawAddress, dispatchers)
}
return dispatchers

View File

@ -4,13 +4,23 @@ import (
"testing"
busTesting "forge.cadoles.com/arcad/edge/pkg/bus/testing"
"gitlab.com/wpetit/goweb/logger"
"go.uber.org/goleak"
)
func TestMain(m *testing.M) {
goleak.VerifyTestMain(m)
}
func TestMemoryBus(t *testing.T) {
if testing.Short() {
t.Skip("Test disabled when -short flag is set")
}
if testing.Verbose() {
logger.SetLevel(logger.LevelDebug)
}
t.Parallel()
t.Run("PublishSubscribe", func(t *testing.T) {
@ -26,4 +36,11 @@ func TestMemoryBus(t *testing.T) {
b := NewBus()
busTesting.TestRequestReply(t, b)
})
t.Run("CanceledRequestReply", func(t *testing.T) {
t.Parallel()
b := NewBus()
busTesting.TestCanceledRequest(t, b)
})
}

View File

@ -3,7 +3,6 @@ package memory
import (
"context"
"sync"
"time"
"forge.cadoles.com/arcad/edge/pkg/bus"
"github.com/pkg/errors"
@ -30,7 +29,7 @@ func (s *eventDispatcherSet) Remove(d *eventDispatcher) {
delete(s.items, d)
}
func (s *eventDispatcherSet) RemoveByOutChannel(out <-chan bus.Message) {
func (s *eventDispatcherSet) RemoveByOutChannel(out <-chan bus.Envelope) {
s.mutex.Lock()
defer s.mutex.Unlock()
@ -42,17 +41,18 @@ func (s *eventDispatcherSet) RemoveByOutChannel(out <-chan bus.Message) {
}
}
func (s *eventDispatcherSet) List() []*eventDispatcher {
func (s *eventDispatcherSet) Range(fn func(d *eventDispatcher)) {
s.mutex.Lock()
defer s.mutex.Unlock()
dispatchers := make([]*eventDispatcher, 0, len(s.items))
for d := range s.items {
dispatchers = append(dispatchers, d)
}
if d.Closed() {
s.Remove(d)
continue
}
return dispatchers
fn(d)
}
}
func newEventDispatcherSet() *eventDispatcherSet {
@ -62,8 +62,8 @@ func newEventDispatcherSet() *eventDispatcherSet {
}
type eventDispatcher struct {
in chan bus.Message
out chan bus.Message
in chan bus.Envelope
out chan bus.Envelope
mutex sync.RWMutex
closed bool
}
@ -91,7 +91,7 @@ func (d *eventDispatcher) close() {
d.closed = true
}
func (d *eventDispatcher) In(msg bus.Message) (err error) {
func (d *eventDispatcher) In(msg bus.Envelope) (err error) {
d.mutex.RLock()
defer d.mutex.RUnlock()
@ -104,67 +104,52 @@ func (d *eventDispatcher) In(msg bus.Message) (err error) {
return nil
}
func (d *eventDispatcher) Out() <-chan bus.Message {
func (d *eventDispatcher) Out() <-chan bus.Envelope {
return d.out
}
func (d *eventDispatcher) IsOut(out <-chan bus.Message) bool {
func (d *eventDispatcher) IsOut(out <-chan bus.Envelope) bool {
return d.out == out
}
func (d *eventDispatcher) Run(ctx context.Context) {
defer func() {
for {
logger.Debug(ctx, "closing dispatcher, flushing out incoming messages")
logger.Debug(ctx, "closing dispatcher, flushing out incoming messages")
close(d.out)
close(d.out)
for range d.in {
// Flush all incoming messages
for {
_, ok := <-d.in
if !ok {
return
}
}
}
}()
for {
msg, ok := <-d.in
if !ok {
return
}
timeout := time.After(time.Second)
select {
case d.out <- msg:
case <-timeout:
logger.Error(
ctx,
"out message channel timeout",
logger.F("message", msg),
)
return
case <-ctx.Done():
logger.Error(
ctx,
"message subscription context canceled",
logger.F("message", msg),
logger.CapturedE(errors.WithStack(ctx.Err())),
)
if err := ctx.Err(); !errors.Is(err, context.Canceled) {
logger.Error(
ctx,
"message subscription context canceled",
logger.CapturedE(errors.WithStack(err)),
)
}
return
case msg, ok := <-d.in:
if !ok {
return
}
d.out <- msg
}
}
}
func newEventDispatcher(bufferSize int64) *eventDispatcher {
return &eventDispatcher{
in: make(chan bus.Message, bufferSize),
out: make(chan bus.Message, bufferSize),
in: make(chan bus.Envelope, bufferSize),
out: make(chan bus.Envelope, bufferSize),
closed: false,
}
}

View File

@ -11,57 +11,78 @@ import (
)
const (
MessageNamespaceRequest bus.MessageNamespace = "reqrep/request"
MessageNamespaceReply bus.MessageNamespace = "reqrep/reply"
AddressRequest bus.Address = "bus/memory/request"
AddressReply bus.Address = "bus/memory/reply"
)
type RequestMessage struct {
RequestID uint64
Message bus.Message
ns bus.MessageNamespace
type RequestEnvelope struct {
requestID uint64
wrapped bus.Envelope
}
func (m *RequestMessage) MessageNamespace() bus.MessageNamespace {
return m.ns
func (e *RequestEnvelope) Address() bus.Address {
return getRequestAddress(e.wrapped.Address())
}
type ReplyMessage struct {
RequestID uint64
Message bus.Message
Error error
ns bus.MessageNamespace
func (e *RequestEnvelope) Message() any {
return e.wrapped.Message()
}
func (m *ReplyMessage) MessageNamespace() bus.MessageNamespace {
return m.ns
func (e *RequestEnvelope) RequestID() uint64 {
return e.requestID
}
func (b *Bus) Request(ctx context.Context, msg bus.Message) (bus.Message, error) {
func (e *RequestEnvelope) Unwrap() bus.Envelope {
return e.wrapped
}
type ReplyEnvelope struct {
requestID uint64
wrapped bus.Envelope
err error
}
func (e *ReplyEnvelope) Address() bus.Address {
return getReplyAddress(e.wrapped.Address(), e.requestID)
}
func (e *ReplyEnvelope) Message() any {
return e.wrapped.Message()
}
func (e *ReplyEnvelope) Err() error {
return e.err
}
func (e *ReplyEnvelope) Unwrap() bus.Envelope {
return e.wrapped
}
func (b *Bus) Request(ctx context.Context, env bus.Envelope) (bus.Envelope, error) {
requestID := atomic.AddUint64(&b.nextRequestID, 1)
req := &RequestMessage{
RequestID: requestID,
Message: msg,
ns: msg.MessageNamespace(),
req := &RequestEnvelope{
requestID: requestID,
wrapped: env,
}
replyNamespace := createReplyNamespace(requestID)
replyAddress := getReplyAddress(env.Address(), requestID)
replies, err := b.Subscribe(ctx, replyNamespace)
subCtx, cancel := context.WithCancel(ctx)
defer cancel()
replies, err := b.Subscribe(subCtx, replyAddress)
if err != nil {
return nil, errors.WithStack(err)
}
defer func() {
b.Unsubscribe(ctx, replyNamespace, replies)
b.Unsubscribe(replyAddress, replies)
}()
logger.Debug(ctx, "publishing request", logger.F("request", req))
if err := b.Publish(ctx, req); err != nil {
if err := b.Publish(req); err != nil {
return nil, errors.WithStack(err)
}
@ -70,82 +91,93 @@ func (b *Bus) Request(ctx context.Context, msg bus.Message) (bus.Message, error)
case <-ctx.Done():
return nil, errors.WithStack(ctx.Err())
case msg, ok := <-replies:
case env, ok := <-replies:
if !ok {
return nil, errors.WithStack(bus.ErrNoResponse)
}
reply, ok := msg.(*ReplyMessage)
reply, ok := env.(*ReplyEnvelope)
if !ok {
return nil, errors.WithStack(bus.ErrUnexpectedMessage)
}
if reply.Error != nil {
if err := reply.Err(); err != nil {
return nil, errors.WithStack(err)
}
return reply.Message, nil
return reply.Unwrap(), nil
}
}
}
type RequestHandler func(evt bus.Message) (bus.Message, error)
func (b *Bus) Reply(ctx context.Context, address bus.Address, handler bus.RequestHandler) chan error {
requestAddress := getRequestAddress(address)
func (b *Bus) Reply(ctx context.Context, msgNamespace bus.MessageNamespace, h bus.RequestHandler) error {
requests, err := b.Subscribe(ctx, msgNamespace)
errs := make(chan error)
requests, err := b.Subscribe(ctx, requestAddress)
if err != nil {
return errors.WithStack(err)
go func() {
errs <- errors.WithStack(err)
close(errs)
}()
return errs
}
defer func() {
b.Unsubscribe(ctx, msgNamespace, requests)
go func() {
defer func() {
b.Unsubscribe(requestAddress, requests)
close(errs)
}()
for {
select {
case <-ctx.Done():
errs <- errors.WithStack(ctx.Err())
return
case env, ok := <-requests:
if !ok {
return
}
request, ok := env.(*RequestEnvelope)
if !ok {
errs <- errors.WithStack(bus.ErrUnexpectedMessage)
continue
}
logger.Debug(ctx, "handling request", logger.F("request", request))
msg, err := handler(request.Unwrap())
reply := &ReplyEnvelope{
requestID: request.RequestID(),
wrapped: bus.NewEnvelope(request.Unwrap().Address(), msg),
}
if err != nil {
reply.err = errors.WithStack(err)
}
logger.Debug(ctx, "publishing reply", logger.F("reply", reply))
if err := b.Publish(reply); err != nil {
errs <- errors.WithStack(err)
continue
}
}
}
}()
for {
select {
case <-ctx.Done():
return errors.WithStack(ctx.Err())
case msg, ok := <-requests:
if !ok {
return nil
}
request, ok := msg.(*RequestMessage)
if !ok {
return errors.WithStack(bus.ErrUnexpectedMessage)
}
logger.Debug(ctx, "handling request", logger.F("request", request))
msg, err := h(request.Message)
reply := &ReplyMessage{
RequestID: request.RequestID,
Message: nil,
Error: nil,
ns: createReplyNamespace(request.RequestID),
}
if err != nil {
reply.Error = errors.WithStack(err)
} else {
reply.Message = msg
}
logger.Debug(ctx, "publishing reply", logger.F("reply", reply))
if err := b.Publish(ctx, reply); err != nil {
return errors.WithStack(err)
}
}
}
return errs
}
func createReplyNamespace(requestID uint64) bus.MessageNamespace {
return bus.NewMessageNamespace(
MessageNamespaceReply,
bus.MessageNamespace(strconv.FormatUint(requestID, 10)),
)
func getRequestAddress(addr bus.Address) bus.Address {
return AddressRequest + "/" + addr
}
func getReplyAddress(addr bus.Address, requestID uint64) bus.Address {
return AddressReply + "/" + addr + "/" + bus.Address(strconv.FormatUint(requestID, 10))
}

View File

@ -1,33 +0,0 @@
package bus
import (
"strings"
"github.com/pkg/errors"
)
type (
MessageNamespace string
)
type Message interface {
MessageNamespace() MessageNamespace
}
func NewMessageNamespace(namespaces ...MessageNamespace) MessageNamespace {
var sb strings.Builder
for i, ns := range namespaces {
if i != 0 {
if _, err := sb.WriteString(":"); err != nil {
panic(errors.Wrap(err, "could not build new message namespace"))
}
}
if _, err := sb.WriteString(string(ns)); err != nil {
panic(errors.Wrap(err, "could not build new message namespace"))
}
}
return MessageNamespace(sb.String())
}

View File

@ -2,6 +2,7 @@ package testing
import (
"context"
"fmt"
"sync"
"sync/atomic"
"testing"
@ -12,74 +13,52 @@ import (
)
const (
testNamespace bus.MessageNamespace = "testNamespace"
testAddress bus.Address = "testAddress"
)
type testMessage struct{}
func (e *testMessage) MessageNamespace() bus.MessageNamespace {
return testNamespace
}
func TestPublishSubscribe(t *testing.T, b bus.Bus) {
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
defer cancel()
t.Log("subscribe")
messages, err := b.Subscribe(ctx, testNamespace)
envelopes, err := b.Subscribe(ctx, testAddress)
if err != nil {
t.Fatal(errors.WithStack(err))
}
expectedTotal := 5
var wg sync.WaitGroup
wg.Add(5)
wg.Add(expectedTotal)
go func() {
// 5 events should be received
t.Log("publish 0")
if err := b.Publish(ctx, &testMessage{}); err != nil {
t.Error(errors.WithStack(err))
}
count := expectedTotal
t.Log("publish 1")
for i := 0; i < count; i++ {
env := bus.NewEnvelope(testAddress, fmt.Sprintf("message %d", i))
if err := b.Publish(ctx, &testMessage{}); err != nil {
t.Error(errors.WithStack(err))
}
if err := b.Publish(env); err != nil {
t.Error(errors.WithStack(err))
}
t.Log("publish 2")
if err := b.Publish(ctx, &testMessage{}); err != nil {
t.Error(errors.WithStack(err))
}
t.Log("publish 3")
if err := b.Publish(ctx, &testMessage{}); err != nil {
t.Error(errors.WithStack(err))
}
t.Log("publish 4")
if err := b.Publish(ctx, &testMessage{}); err != nil {
t.Error(errors.WithStack(err))
t.Logf("published %d", i)
}
}()
var count int32 = 0
go func() {
t.Log("range for events")
t.Log("range for received envelopes")
for msg := range messages {
for env := range envelopes {
t.Logf("received msg %d", atomic.LoadInt32(&count))
atomic.AddInt32(&count, 1)
if e, g := testNamespace, msg.MessageNamespace(); e != g {
t.Errorf("evt.MessageNamespace(): expected '%v', got '%v'", e, g)
if e, g := testAddress, env.Address(); e != g {
t.Errorf("env.Address(): expected '%v', got '%v'", e, g)
}
wg.Done()
@ -88,9 +67,9 @@ func TestPublishSubscribe(t *testing.T, b bus.Bus) {
wg.Wait()
b.Unsubscribe(ctx, testNamespace, messages)
b.Unsubscribe(testAddress, envelopes)
if e, g := int32(5), count; e != g {
t.Errorf("message received count: expected '%v', got '%v'", e, g)
if e, g := int32(expectedTotal), count; e != g {
t.Errorf("envelopes received count: expected '%v', got '%v'", e, g)
}
}

View File

@ -11,58 +11,42 @@ import (
)
const (
testTypeReqRes bus.MessageNamespace = "testNamspaceReqRes"
testTypeReqResAddress bus.Address = "testTypeReqResAddress"
)
type testReqResMessage struct {
i int
}
func (m *testReqResMessage) MessageNamespace() bus.MessageNamespace {
return testNamespace
}
func TestRequestReply(t *testing.T, b bus.Bus) {
expectedRoundTrips := 256
timeout := time.Now().Add(time.Duration(expectedRoundTrips) * time.Second)
var (
initWaitGroup sync.WaitGroup
resWaitGroup sync.WaitGroup
)
replyCtx, cancelReply := context.WithDeadline(context.Background(), timeout)
defer cancelReply()
initWaitGroup.Add(1)
var resWaitGroup sync.WaitGroup
replyErrs := b.Reply(replyCtx, testTypeReqResAddress, func(env bus.Envelope) (any, error) {
defer resWaitGroup.Done()
req, ok := env.Message().(int)
if !ok {
return nil, errors.WithStack(bus.ErrUnexpectedMessage)
}
// Simulate random work
time.Sleep(time.Millisecond * 100)
t.Logf("[RES] sending res #%d", req)
return req, nil
})
go func() {
repondCtx, cancelRespond := context.WithDeadline(context.Background(), timeout)
defer cancelRespond()
initWaitGroup.Done()
err := b.Reply(repondCtx, testNamespace, func(msg bus.Message) (bus.Message, error) {
defer resWaitGroup.Done()
req, ok := msg.(*testReqResMessage)
if !ok {
return nil, errors.WithStack(bus.ErrUnexpectedMessage)
for err := range replyErrs {
if !errors.Is(err, context.Canceled) {
t.Errorf("%+v", errors.WithStack(err))
}
result := &testReqResMessage{req.i}
// Simulate random work
time.Sleep(time.Millisecond * 100)
t.Logf("[RES] sending res #%d", req.i)
return result, nil
})
if err != nil {
t.Error(err)
}
}()
initWaitGroup.Wait()
var reqWaitGroup sync.WaitGroup
for i := 0; i < expectedRoundTrips; i++ {
@ -75,32 +59,30 @@ func TestRequestReply(t *testing.T, b bus.Bus) {
requestCtx, cancelRequest := context.WithDeadline(context.Background(), timeout)
defer cancelRequest()
req := &testReqResMessage{i}
t.Logf("[REQ] sending req #%d", i)
result, err := b.Request(requestCtx, req)
response, err := b.Request(requestCtx, bus.NewEnvelope(testTypeReqResAddress, i))
if err != nil {
t.Error(err)
}
t.Logf("[REQ] received req #%d reply", i)
if result == nil {
t.Error("result should not be nil")
if response == nil {
t.Error("response should not be nil")
return
}
res, ok := result.(*testReqResMessage)
result, ok := response.Message().(int)
if !ok {
t.Error(errors.WithStack(bus.ErrUnexpectedMessage))
return
}
if e, g := req.i, res.i; e != g {
t.Errorf("res.i: expected '%v', got '%v'", e, g)
if e, g := i, result; e != g {
t.Errorf("response.Message(): expected '%v', got '%v'", e, g)
}
}(i)
}
@ -108,3 +90,77 @@ func TestRequestReply(t *testing.T, b bus.Bus) {
reqWaitGroup.Wait()
resWaitGroup.Wait()
}
func TestCanceledRequest(t *testing.T, b bus.Bus) {
replyCtx, cancelReply := context.WithCancel(context.Background())
defer cancelReply()
errs := b.Reply(replyCtx, testTypeReqResAddress, func(env bus.Envelope) (any, error) {
return env.Message(), nil
})
go func() {
for err := range errs {
if !errors.Is(err, context.Canceled) {
t.Errorf("%+v", errors.WithStack(err))
}
}
}()
var wg sync.WaitGroup
count := 100
wg.Add(count)
for i := 0; i < count; i++ {
go func(i int) {
defer wg.Done()
t.Logf("calling %d", i)
isCanceled := i%2 == 0
var ctx context.Context
if isCanceled {
canceledCtx, cancel := context.WithCancel(context.Background())
cancel()
ctx = canceledCtx
} else {
ctx = context.Background()
}
t.Logf("publishing envelope #%d", i)
reply, err := b.Request(ctx, bus.NewEnvelope(testTypeReqResAddress, int64(i)))
if err != nil {
if errors.Is(err, context.Canceled) && isCanceled {
return
}
if errors.Is(err, bus.ErrNoResponse) && isCanceled {
return
}
t.Errorf("%+v", errors.WithStack(err))
return
}
result, ok := reply.Message().(int64)
if !ok {
t.Errorf("response.Result: expected type '%T', got '%T'", int64(0), reply.Message())
return
}
if e, g := i, int(result); e != g {
t.Errorf("response.Result: expected '%v', got '%v'", e, g)
return
}
}(i)
}
wg.Wait()
}