feat: rewrite bus to prevent deadlocks
This commit is contained in:
@ -2,6 +2,7 @@ package testing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
@ -12,74 +13,52 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
testNamespace bus.MessageNamespace = "testNamespace"
|
||||
testAddress bus.Address = "testAddress"
|
||||
)
|
||||
|
||||
type testMessage struct{}
|
||||
|
||||
func (e *testMessage) MessageNamespace() bus.MessageNamespace {
|
||||
return testNamespace
|
||||
}
|
||||
|
||||
func TestPublishSubscribe(t *testing.T, b bus.Bus) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
t.Log("subscribe")
|
||||
|
||||
messages, err := b.Subscribe(ctx, testNamespace)
|
||||
envelopes, err := b.Subscribe(ctx, testAddress)
|
||||
if err != nil {
|
||||
t.Fatal(errors.WithStack(err))
|
||||
}
|
||||
|
||||
expectedTotal := 5
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(5)
|
||||
wg.Add(expectedTotal)
|
||||
|
||||
go func() {
|
||||
// 5 events should be received
|
||||
t.Log("publish 0")
|
||||
|
||||
if err := b.Publish(ctx, &testMessage{}); err != nil {
|
||||
t.Error(errors.WithStack(err))
|
||||
}
|
||||
count := expectedTotal
|
||||
|
||||
t.Log("publish 1")
|
||||
for i := 0; i < count; i++ {
|
||||
env := bus.NewEnvelope(testAddress, fmt.Sprintf("message %d", i))
|
||||
|
||||
if err := b.Publish(ctx, &testMessage{}); err != nil {
|
||||
t.Error(errors.WithStack(err))
|
||||
}
|
||||
if err := b.Publish(env); err != nil {
|
||||
t.Error(errors.WithStack(err))
|
||||
}
|
||||
|
||||
t.Log("publish 2")
|
||||
|
||||
if err := b.Publish(ctx, &testMessage{}); err != nil {
|
||||
t.Error(errors.WithStack(err))
|
||||
}
|
||||
|
||||
t.Log("publish 3")
|
||||
|
||||
if err := b.Publish(ctx, &testMessage{}); err != nil {
|
||||
t.Error(errors.WithStack(err))
|
||||
}
|
||||
|
||||
t.Log("publish 4")
|
||||
|
||||
if err := b.Publish(ctx, &testMessage{}); err != nil {
|
||||
t.Error(errors.WithStack(err))
|
||||
t.Logf("published %d", i)
|
||||
}
|
||||
}()
|
||||
|
||||
var count int32 = 0
|
||||
|
||||
go func() {
|
||||
t.Log("range for events")
|
||||
t.Log("range for received envelopes")
|
||||
|
||||
for msg := range messages {
|
||||
for env := range envelopes {
|
||||
t.Logf("received msg %d", atomic.LoadInt32(&count))
|
||||
atomic.AddInt32(&count, 1)
|
||||
|
||||
if e, g := testNamespace, msg.MessageNamespace(); e != g {
|
||||
t.Errorf("evt.MessageNamespace(): expected '%v', got '%v'", e, g)
|
||||
if e, g := testAddress, env.Address(); e != g {
|
||||
t.Errorf("env.Address(): expected '%v', got '%v'", e, g)
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
@ -88,9 +67,9 @@ func TestPublishSubscribe(t *testing.T, b bus.Bus) {
|
||||
|
||||
wg.Wait()
|
||||
|
||||
b.Unsubscribe(ctx, testNamespace, messages)
|
||||
b.Unsubscribe(testAddress, envelopes)
|
||||
|
||||
if e, g := int32(5), count; e != g {
|
||||
t.Errorf("message received count: expected '%v', got '%v'", e, g)
|
||||
if e, g := int32(expectedTotal), count; e != g {
|
||||
t.Errorf("envelopes received count: expected '%v', got '%v'", e, g)
|
||||
}
|
||||
}
|
||||
|
@ -11,58 +11,42 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
testTypeReqRes bus.MessageNamespace = "testNamspaceReqRes"
|
||||
testTypeReqResAddress bus.Address = "testTypeReqResAddress"
|
||||
)
|
||||
|
||||
type testReqResMessage struct {
|
||||
i int
|
||||
}
|
||||
|
||||
func (m *testReqResMessage) MessageNamespace() bus.MessageNamespace {
|
||||
return testNamespace
|
||||
}
|
||||
|
||||
func TestRequestReply(t *testing.T, b bus.Bus) {
|
||||
expectedRoundTrips := 256
|
||||
timeout := time.Now().Add(time.Duration(expectedRoundTrips) * time.Second)
|
||||
|
||||
var (
|
||||
initWaitGroup sync.WaitGroup
|
||||
resWaitGroup sync.WaitGroup
|
||||
)
|
||||
replyCtx, cancelReply := context.WithDeadline(context.Background(), timeout)
|
||||
defer cancelReply()
|
||||
|
||||
initWaitGroup.Add(1)
|
||||
var resWaitGroup sync.WaitGroup
|
||||
|
||||
replyErrs := b.Reply(replyCtx, testTypeReqResAddress, func(env bus.Envelope) (any, error) {
|
||||
defer resWaitGroup.Done()
|
||||
|
||||
req, ok := env.Message().(int)
|
||||
if !ok {
|
||||
return nil, errors.WithStack(bus.ErrUnexpectedMessage)
|
||||
}
|
||||
|
||||
// Simulate random work
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
|
||||
t.Logf("[RES] sending res #%d", req)
|
||||
|
||||
return req, nil
|
||||
})
|
||||
|
||||
go func() {
|
||||
repondCtx, cancelRespond := context.WithDeadline(context.Background(), timeout)
|
||||
defer cancelRespond()
|
||||
|
||||
initWaitGroup.Done()
|
||||
|
||||
err := b.Reply(repondCtx, testNamespace, func(msg bus.Message) (bus.Message, error) {
|
||||
defer resWaitGroup.Done()
|
||||
|
||||
req, ok := msg.(*testReqResMessage)
|
||||
if !ok {
|
||||
return nil, errors.WithStack(bus.ErrUnexpectedMessage)
|
||||
for err := range replyErrs {
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Errorf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
result := &testReqResMessage{req.i}
|
||||
|
||||
// Simulate random work
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
|
||||
t.Logf("[RES] sending res #%d", req.i)
|
||||
|
||||
return result, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
initWaitGroup.Wait()
|
||||
|
||||
var reqWaitGroup sync.WaitGroup
|
||||
|
||||
for i := 0; i < expectedRoundTrips; i++ {
|
||||
@ -75,32 +59,30 @@ func TestRequestReply(t *testing.T, b bus.Bus) {
|
||||
requestCtx, cancelRequest := context.WithDeadline(context.Background(), timeout)
|
||||
defer cancelRequest()
|
||||
|
||||
req := &testReqResMessage{i}
|
||||
|
||||
t.Logf("[REQ] sending req #%d", i)
|
||||
|
||||
result, err := b.Request(requestCtx, req)
|
||||
response, err := b.Request(requestCtx, bus.NewEnvelope(testTypeReqResAddress, i))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
t.Logf("[REQ] received req #%d reply", i)
|
||||
|
||||
if result == nil {
|
||||
t.Error("result should not be nil")
|
||||
if response == nil {
|
||||
t.Error("response should not be nil")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
res, ok := result.(*testReqResMessage)
|
||||
result, ok := response.Message().(int)
|
||||
if !ok {
|
||||
t.Error(errors.WithStack(bus.ErrUnexpectedMessage))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if e, g := req.i, res.i; e != g {
|
||||
t.Errorf("res.i: expected '%v', got '%v'", e, g)
|
||||
if e, g := i, result; e != g {
|
||||
t.Errorf("response.Message(): expected '%v', got '%v'", e, g)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
@ -108,3 +90,77 @@ func TestRequestReply(t *testing.T, b bus.Bus) {
|
||||
reqWaitGroup.Wait()
|
||||
resWaitGroup.Wait()
|
||||
}
|
||||
|
||||
func TestCanceledRequest(t *testing.T, b bus.Bus) {
|
||||
replyCtx, cancelReply := context.WithCancel(context.Background())
|
||||
defer cancelReply()
|
||||
|
||||
errs := b.Reply(replyCtx, testTypeReqResAddress, func(env bus.Envelope) (any, error) {
|
||||
return env.Message(), nil
|
||||
})
|
||||
|
||||
go func() {
|
||||
for err := range errs {
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Errorf("%+v", errors.WithStack(err))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
count := 100
|
||||
|
||||
wg.Add(count)
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
t.Logf("calling %d", i)
|
||||
|
||||
isCanceled := i%2 == 0
|
||||
|
||||
var ctx context.Context
|
||||
if isCanceled {
|
||||
canceledCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
ctx = canceledCtx
|
||||
} else {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
t.Logf("publishing envelope #%d", i)
|
||||
|
||||
reply, err := b.Request(ctx, bus.NewEnvelope(testTypeReqResAddress, int64(i)))
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) && isCanceled {
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(err, bus.ErrNoResponse) && isCanceled {
|
||||
return
|
||||
}
|
||||
|
||||
t.Errorf("%+v", errors.WithStack(err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
result, ok := reply.Message().(int64)
|
||||
if !ok {
|
||||
t.Errorf("response.Result: expected type '%T', got '%T'", int64(0), reply.Message())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if e, g := i, int(result); e != g {
|
||||
t.Errorf("response.Result: expected '%v', got '%v'", e, g)
|
||||
|
||||
return
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
Reference in New Issue
Block a user