feat: rewrite bus to prevent deadlocks
This commit is contained in:
@ -3,11 +3,11 @@ package bus
|
||||
import "context"
|
||||
|
||||
type Bus interface {
|
||||
Subscribe(ctx context.Context, ns MessageNamespace) (<-chan Message, error)
|
||||
Unsubscribe(ctx context.Context, ns MessageNamespace, ch <-chan Message)
|
||||
Publish(ctx context.Context, msg Message) error
|
||||
Request(ctx context.Context, msg Message) (Message, error)
|
||||
Reply(ctx context.Context, ns MessageNamespace, h RequestHandler) error
|
||||
Subscribe(ctx context.Context, addr Address) (<-chan Envelope, error)
|
||||
Unsubscribe(addr Address, ch <-chan Envelope)
|
||||
Publish(env Envelope) error
|
||||
Request(ctx context.Context, env Envelope) (Envelope, error)
|
||||
Reply(ctx context.Context, addr Address, h RequestHandler) chan error
|
||||
}
|
||||
|
||||
type RequestHandler func(msg Message) (Message, error)
|
||||
type RequestHandler func(env Envelope) (any, error)
|
||||
|
32
pkg/bus/envelope.go
Normal file
32
pkg/bus/envelope.go
Normal file
@ -0,0 +1,32 @@
|
||||
package bus
|
||||
|
||||
type Address string
|
||||
|
||||
type Envelope interface {
|
||||
Message() any
|
||||
Address() Address
|
||||
}
|
||||
|
||||
type BaseEnvelope struct {
|
||||
msg any
|
||||
addr Address
|
||||
}
|
||||
|
||||
// Address implements Envelope.
|
||||
func (e *BaseEnvelope) Address() Address {
|
||||
return e.addr
|
||||
}
|
||||
|
||||
// Message implements Envelope.
|
||||
func (e *BaseEnvelope) Message() any {
|
||||
return e.msg
|
||||
}
|
||||
|
||||
func NewEnvelope(addr Address, msg any) *BaseEnvelope {
|
||||
return &BaseEnvelope{
|
||||
addr: addr,
|
||||
msg: msg,
|
||||
}
|
||||
}
|
||||
|
||||
var _ Envelope = &BaseEnvelope{}
|
@ -15,13 +15,13 @@ type Bus struct {
|
||||
nextRequestID uint64
|
||||
}
|
||||
|
||||
func (b *Bus) Subscribe(ctx context.Context, ns bus.MessageNamespace) (<-chan bus.Message, error) {
|
||||
func (b *Bus) Subscribe(ctx context.Context, address bus.Address) (<-chan bus.Envelope, error) {
|
||||
logger.Debug(
|
||||
ctx, "subscribing to messages",
|
||||
logger.F("messageNamespace", ns),
|
||||
ctx, "subscribing",
|
||||
logger.F("address", address),
|
||||
)
|
||||
|
||||
dispatchers := b.getDispatchers(ns)
|
||||
dispatchers := b.getDispatchers(address)
|
||||
disp := newEventDispatcher(b.opt.BufferSize)
|
||||
|
||||
go disp.Run(ctx)
|
||||
@ -31,50 +31,41 @@ func (b *Bus) Subscribe(ctx context.Context, ns bus.MessageNamespace) (<-chan bu
|
||||
return disp.Out(), nil
|
||||
}
|
||||
|
||||
func (b *Bus) Unsubscribe(ctx context.Context, ns bus.MessageNamespace, ch <-chan bus.Message) {
|
||||
func (b *Bus) Unsubscribe(address bus.Address, ch <-chan bus.Envelope) {
|
||||
logger.Debug(
|
||||
ctx, "unsubscribing from messages",
|
||||
logger.F("messageNamespace", ns),
|
||||
context.Background(), "unsubscribing",
|
||||
logger.F("address", address),
|
||||
)
|
||||
|
||||
dispatchers := b.getDispatchers(ns)
|
||||
dispatchers := b.getDispatchers(address)
|
||||
dispatchers.RemoveByOutChannel(ch)
|
||||
}
|
||||
|
||||
func (b *Bus) Publish(ctx context.Context, msg bus.Message) error {
|
||||
dispatchers := b.getDispatchers(msg.MessageNamespace())
|
||||
dispatchersList := dispatchers.List()
|
||||
|
||||
func (b *Bus) Publish(env bus.Envelope) error {
|
||||
dispatchers := b.getDispatchers(env.Address())
|
||||
logger.Debug(
|
||||
ctx, "publishing message",
|
||||
logger.F("dispatchers", len(dispatchersList)),
|
||||
logger.F("messageNamespace", msg.MessageNamespace()),
|
||||
context.Background(), "publish",
|
||||
logger.F("address", env.Address()),
|
||||
)
|
||||
|
||||
for _, d := range dispatchersList {
|
||||
if d.Closed() {
|
||||
dispatchers.Remove(d)
|
||||
|
||||
continue
|
||||
dispatchers.Range(func(d *eventDispatcher) {
|
||||
if err := d.In(env); err != nil {
|
||||
logger.Error(context.Background(), "could not publish message", logger.CapturedE(errors.WithStack(err)))
|
||||
}
|
||||
|
||||
if err := d.In(msg); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (b *Bus) getDispatchers(namespace bus.MessageNamespace) *eventDispatcherSet {
|
||||
strNamespace := string(namespace)
|
||||
func (b *Bus) getDispatchers(address bus.Address) *eventDispatcherSet {
|
||||
rawAddress := string(address)
|
||||
|
||||
rawDispatchers, exists := b.dispatchers.Get(strNamespace)
|
||||
rawDispatchers, exists := b.dispatchers.Get(rawAddress)
|
||||
dispatchers, ok := rawDispatchers.(*eventDispatcherSet)
|
||||
|
||||
if !exists || !ok {
|
||||
dispatchers = newEventDispatcherSet()
|
||||
b.dispatchers.Set(strNamespace, dispatchers)
|
||||
b.dispatchers.Set(rawAddress, dispatchers)
|
||||
}
|
||||
|
||||
return dispatchers
|
||||
|
@ -4,13 +4,23 @@ import (
|
||||
"testing"
|
||||
|
||||
busTesting "forge.cadoles.com/arcad/edge/pkg/bus/testing"
|
||||
"gitlab.com/wpetit/goweb/logger"
|
||||
"go.uber.org/goleak"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
goleak.VerifyTestMain(m)
|
||||
}
|
||||
|
||||
func TestMemoryBus(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Test disabled when -short flag is set")
|
||||
}
|
||||
|
||||
if testing.Verbose() {
|
||||
logger.SetLevel(logger.LevelDebug)
|
||||
}
|
||||
|
||||
t.Parallel()
|
||||
|
||||
t.Run("PublishSubscribe", func(t *testing.T) {
|
||||
@ -26,4 +36,11 @@ func TestMemoryBus(t *testing.T) {
|
||||
b := NewBus()
|
||||
busTesting.TestRequestReply(t, b)
|
||||
})
|
||||
|
||||
t.Run("CanceledRequestReply", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
b := NewBus()
|
||||
busTesting.TestCanceledRequest(t, b)
|
||||
})
|
||||
}
|
||||
|
@ -3,7 +3,6 @@ package memory
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"forge.cadoles.com/arcad/edge/pkg/bus"
|
||||
"github.com/pkg/errors"
|
||||
@ -30,7 +29,7 @@ func (s *eventDispatcherSet) Remove(d *eventDispatcher) {
|
||||
delete(s.items, d)
|
||||
}
|
||||
|
||||
func (s *eventDispatcherSet) RemoveByOutChannel(out <-chan bus.Message) {
|
||||
func (s *eventDispatcherSet) RemoveByOutChannel(out <-chan bus.Envelope) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
@ -42,17 +41,18 @@ func (s *eventDispatcherSet) RemoveByOutChannel(out <-chan bus.Message) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *eventDispatcherSet) List() []*eventDispatcher {
|
||||
func (s *eventDispatcherSet) Range(fn func(d *eventDispatcher)) {
|
||||
s.mutex.Lock()
|
||||
defer s.mutex.Unlock()
|
||||
|
||||
dispatchers := make([]*eventDispatcher, 0, len(s.items))
|
||||
|
||||
for d := range s.items {
|
||||
dispatchers = append(dispatchers, d)
|
||||
}
|
||||
if d.Closed() {
|
||||
s.Remove(d)
|
||||
continue
|
||||
}
|
||||
|
||||
return dispatchers
|
||||
fn(d)
|
||||
}
|
||||
}
|
||||
|
||||
func newEventDispatcherSet() *eventDispatcherSet {
|
||||
@ -62,8 +62,8 @@ func newEventDispatcherSet() *eventDispatcherSet {
|
||||
}
|
||||
|
||||
type eventDispatcher struct {
|
||||
in chan bus.Message
|
||||
out chan bus.Message
|
||||
in chan bus.Envelope
|
||||
out chan bus.Envelope
|
||||
mutex sync.RWMutex
|
||||
closed bool
|
||||
}
|
||||
@ -91,7 +91,7 @@ func (d *eventDispatcher) close() {
|
||||
d.closed = true
|
||||
}
|
||||
|
||||
func (d *eventDispatcher) In(msg bus.Message) (err error) {
|
||||
func (d *eventDispatcher) In(msg bus.Envelope) (err error) {
|
||||
d.mutex.RLock()
|
||||
defer d.mutex.RUnlock()
|
||||
|
||||
@ -104,67 +104,52 @@ func (d *eventDispatcher) In(msg bus.Message) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (d *eventDispatcher) Out() <-chan bus.Message {
|
||||
func (d *eventDispatcher) Out() <-chan bus.Envelope {
|
||||
return d.out
|
||||
}
|
||||
|
||||
func (d *eventDispatcher) IsOut(out <-chan bus.Message) bool {
|
||||
func (d *eventDispatcher) IsOut(out <-chan bus.Envelope) bool {
|
||||
return d.out == out
|
||||
}
|
||||
|
||||
func (d *eventDispatcher) Run(ctx context.Context) {
|
||||
defer func() {
|
||||
for {
|
||||
logger.Debug(ctx, "closing dispatcher, flushing out incoming messages")
|
||||
logger.Debug(ctx, "closing dispatcher, flushing out incoming messages")
|
||||
|
||||
close(d.out)
|
||||
close(d.out)
|
||||
|
||||
for range d.in {
|
||||
// Flush all incoming messages
|
||||
for {
|
||||
_, ok := <-d.in
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
msg, ok := <-d.in
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
timeout := time.After(time.Second)
|
||||
|
||||
select {
|
||||
case d.out <- msg:
|
||||
case <-timeout:
|
||||
logger.Error(
|
||||
ctx,
|
||||
"out message channel timeout",
|
||||
logger.F("message", msg),
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
case <-ctx.Done():
|
||||
logger.Error(
|
||||
ctx,
|
||||
"message subscription context canceled",
|
||||
logger.F("message", msg),
|
||||
logger.CapturedE(errors.WithStack(ctx.Err())),
|
||||
)
|
||||
if err := ctx.Err(); !errors.Is(err, context.Canceled) {
|
||||
logger.Error(
|
||||
ctx,
|
||||
"message subscription context canceled",
|
||||
logger.CapturedE(errors.WithStack(err)),
|
||||
)
|
||||
}
|
||||
|
||||
return
|
||||
|
||||
case msg, ok := <-d.in:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
d.out <- msg
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func newEventDispatcher(bufferSize int64) *eventDispatcher {
|
||||
return &eventDispatcher{
|
||||
in: make(chan bus.Message, bufferSize),
|
||||
out: make(chan bus.Message, bufferSize),
|
||||
in: make(chan bus.Envelope, bufferSize),
|
||||
out: make(chan bus.Envelope, bufferSize),
|
||||
closed: false,
|
||||
}
|
||||
}
|
||||
|
@ -11,57 +11,78 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
MessageNamespaceRequest bus.MessageNamespace = "reqrep/request"
|
||||
MessageNamespaceReply bus.MessageNamespace = "reqrep/reply"
|
||||
AddressRequest bus.Address = "bus/memory/request"
|
||||
AddressReply bus.Address = "bus/memory/reply"
|
||||
)
|
||||
|
||||
type RequestMessage struct {
|
||||
RequestID uint64
|
||||
|
||||
Message bus.Message
|
||||
|
||||
ns bus.MessageNamespace
|
||||
type RequestEnvelope struct {
|
||||
requestID uint64
|
||||
wrapped bus.Envelope
|
||||
}
|
||||
|
||||
func (m *RequestMessage) MessageNamespace() bus.MessageNamespace {
|
||||
return m.ns
|
||||
func (e *RequestEnvelope) Address() bus.Address {
|
||||
return getRequestAddress(e.wrapped.Address())
|
||||
}
|
||||
|
||||
type ReplyMessage struct {
|
||||
RequestID uint64
|
||||
Message bus.Message
|
||||
Error error
|
||||
|
||||
ns bus.MessageNamespace
|
||||
func (e *RequestEnvelope) Message() any {
|
||||
return e.wrapped.Message()
|
||||
}
|
||||
|
||||
func (m *ReplyMessage) MessageNamespace() bus.MessageNamespace {
|
||||
return m.ns
|
||||
func (e *RequestEnvelope) RequestID() uint64 {
|
||||
return e.requestID
|
||||
}
|
||||
|
||||
func (b *Bus) Request(ctx context.Context, msg bus.Message) (bus.Message, error) {
|
||||
func (e *RequestEnvelope) Unwrap() bus.Envelope {
|
||||
return e.wrapped
|
||||
}
|
||||
|
||||
type ReplyEnvelope struct {
|
||||
requestID uint64
|
||||
wrapped bus.Envelope
|
||||
err error
|
||||
}
|
||||
|
||||
func (e *ReplyEnvelope) Address() bus.Address {
|
||||
return getReplyAddress(e.wrapped.Address(), e.requestID)
|
||||
}
|
||||
|
||||
func (e *ReplyEnvelope) Message() any {
|
||||
return e.wrapped.Message()
|
||||
}
|
||||
|
||||
func (e *ReplyEnvelope) Err() error {
|
||||
return e.err
|
||||
}
|
||||
|
||||
func (e *ReplyEnvelope) Unwrap() bus.Envelope {
|
||||
return e.wrapped
|
||||
}
|
||||
|
||||
func (b *Bus) Request(ctx context.Context, env bus.Envelope) (bus.Envelope, error) {
|
||||
requestID := atomic.AddUint64(&b.nextRequestID, 1)
|
||||
|
||||
req := &RequestMessage{
|
||||
RequestID: requestID,
|
||||
Message: msg,
|
||||
ns: msg.MessageNamespace(),
|
||||
req := &RequestEnvelope{
|
||||
requestID: requestID,
|
||||
wrapped: env,
|
||||
}
|
||||
|
||||
replyNamespace := createReplyNamespace(requestID)
|
||||
replyAddress := getReplyAddress(env.Address(), requestID)
|
||||
|
||||
replies, err := b.Subscribe(ctx, replyNamespace)
|
||||
subCtx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
replies, err := b.Subscribe(subCtx, replyAddress)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
b.Unsubscribe(ctx, replyNamespace, replies)
|
||||
b.Unsubscribe(replyAddress, replies)
|
||||
}()
|
||||
|
||||
logger.Debug(ctx, "publishing request", logger.F("request", req))
|
||||
|
||||
if err := b.Publish(ctx, req); err != nil {
|
||||
if err := b.Publish(req); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
@ -70,82 +91,93 @@ func (b *Bus) Request(ctx context.Context, msg bus.Message) (bus.Message, error)
|
||||
case <-ctx.Done():
|
||||
return nil, errors.WithStack(ctx.Err())
|
||||
|
||||
case msg, ok := <-replies:
|
||||
case env, ok := <-replies:
|
||||
if !ok {
|
||||
return nil, errors.WithStack(bus.ErrNoResponse)
|
||||
}
|
||||
|
||||
reply, ok := msg.(*ReplyMessage)
|
||||
reply, ok := env.(*ReplyEnvelope)
|
||||
if !ok {
|
||||
return nil, errors.WithStack(bus.ErrUnexpectedMessage)
|
||||
}
|
||||
|
||||
if reply.Error != nil {
|
||||
if err := reply.Err(); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return reply.Message, nil
|
||||
return reply.Unwrap(), nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
type RequestHandler func(evt bus.Message) (bus.Message, error)
|
||||
func (b *Bus) Reply(ctx context.Context, address bus.Address, handler bus.RequestHandler) chan error {
|
||||
requestAddress := getRequestAddress(address)
|
||||
|
||||
func (b *Bus) Reply(ctx context.Context, msgNamespace bus.MessageNamespace, h bus.RequestHandler) error {
|
||||
requests, err := b.Subscribe(ctx, msgNamespace)
|
||||
errs := make(chan error)
|
||||
|
||||
requests, err := b.Subscribe(ctx, requestAddress)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
go func() {
|
||||
errs <- errors.WithStack(err)
|
||||
close(errs)
|
||||
}()
|
||||
|
||||
return errs
|
||||
}
|
||||
|
||||
defer func() {
|
||||
b.Unsubscribe(ctx, msgNamespace, requests)
|
||||
go func() {
|
||||
defer func() {
|
||||
b.Unsubscribe(requestAddress, requests)
|
||||
close(errs)
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
errs <- errors.WithStack(ctx.Err())
|
||||
return
|
||||
|
||||
case env, ok := <-requests:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
|
||||
request, ok := env.(*RequestEnvelope)
|
||||
if !ok {
|
||||
errs <- errors.WithStack(bus.ErrUnexpectedMessage)
|
||||
continue
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "handling request", logger.F("request", request))
|
||||
|
||||
msg, err := handler(request.Unwrap())
|
||||
|
||||
reply := &ReplyEnvelope{
|
||||
requestID: request.RequestID(),
|
||||
wrapped: bus.NewEnvelope(request.Unwrap().Address(), msg),
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
reply.err = errors.WithStack(err)
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "publishing reply", logger.F("reply", reply))
|
||||
|
||||
if err := b.Publish(reply); err != nil {
|
||||
errs <- errors.WithStack(err)
|
||||
continue
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return errors.WithStack(ctx.Err())
|
||||
|
||||
case msg, ok := <-requests:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
|
||||
request, ok := msg.(*RequestMessage)
|
||||
if !ok {
|
||||
return errors.WithStack(bus.ErrUnexpectedMessage)
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "handling request", logger.F("request", request))
|
||||
|
||||
msg, err := h(request.Message)
|
||||
|
||||
reply := &ReplyMessage{
|
||||
RequestID: request.RequestID,
|
||||
Message: nil,
|
||||
Error: nil,
|
||||
|
||||
ns: createReplyNamespace(request.RequestID),
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
reply.Error = errors.WithStack(err)
|
||||
} else {
|
||||
reply.Message = msg
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "publishing reply", logger.F("reply", reply))
|
||||
|
||||
if err := b.Publish(ctx, reply); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
}
|
||||
return errs
|
||||
}
|
||||
|
||||
func createReplyNamespace(requestID uint64) bus.MessageNamespace {
|
||||
return bus.NewMessageNamespace(
|
||||
MessageNamespaceReply,
|
||||
bus.MessageNamespace(strconv.FormatUint(requestID, 10)),
|
||||
)
|
||||
func getRequestAddress(addr bus.Address) bus.Address {
|
||||
return AddressRequest + "/" + addr
|
||||
}
|
||||
|
||||
func getReplyAddress(addr bus.Address, requestID uint64) bus.Address {
|
||||
return AddressReply + "/" + addr + "/" + bus.Address(strconv.FormatUint(requestID, 10))
|
||||
}
|
||||
|
@ -1,33 +0,0 @@
|
||||
package bus
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type (
|
||||
MessageNamespace string
|
||||
)
|
||||
|
||||
type Message interface {
|
||||
MessageNamespace() MessageNamespace
|
||||
}
|
||||
|
||||
func NewMessageNamespace(namespaces ...MessageNamespace) MessageNamespace {
|
||||
var sb strings.Builder
|
||||
|
||||
for i, ns := range namespaces {
|
||||
if i != 0 {
|
||||
if _, err := sb.WriteString(":"); err != nil {
|
||||
panic(errors.Wrap(err, "could not build new message namespace"))
|
||||
}
|
||||
}
|
||||
|
||||
if _, err := sb.WriteString(string(ns)); err != nil {
|
||||
panic(errors.Wrap(err, "could not build new message namespace"))
|
||||
}
|
||||
}
|
||||
|
||||
return MessageNamespace(sb.String())
|
||||
}
|
@ -2,6 +2,7 @@ package testing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"testing"
|
||||
@ -12,74 +13,52 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
testNamespace bus.MessageNamespace = "testNamespace"
|
||||
testAddress bus.Address = "testAddress"
|
||||
)
|
||||
|
||||
type testMessage struct{}
|
||||
|
||||
func (e *testMessage) MessageNamespace() bus.MessageNamespace {
|
||||
return testNamespace
|
||||
}
|
||||
|
||||
func TestPublishSubscribe(t *testing.T, b bus.Bus) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Minute)
|
||||
defer cancel()
|
||||
|
||||
t.Log("subscribe")
|
||||
|
||||
messages, err := b.Subscribe(ctx, testNamespace)
|
||||
envelopes, err := b.Subscribe(ctx, testAddress)
|
||||
if err != nil {
|
||||
t.Fatal(errors.WithStack(err))
|
||||
}
|
||||
|
||||
expectedTotal := 5
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
wg.Add(5)
|
||||
wg.Add(expectedTotal)
|
||||
|
||||
go func() {
|
||||
// 5 events should be received
|
||||
t.Log("publish 0")
|
||||
|
||||
if err := b.Publish(ctx, &testMessage{}); err != nil {
|
||||
t.Error(errors.WithStack(err))
|
||||
}
|
||||
count := expectedTotal
|
||||
|
||||
t.Log("publish 1")
|
||||
for i := 0; i < count; i++ {
|
||||
env := bus.NewEnvelope(testAddress, fmt.Sprintf("message %d", i))
|
||||
|
||||
if err := b.Publish(ctx, &testMessage{}); err != nil {
|
||||
t.Error(errors.WithStack(err))
|
||||
}
|
||||
if err := b.Publish(env); err != nil {
|
||||
t.Error(errors.WithStack(err))
|
||||
}
|
||||
|
||||
t.Log("publish 2")
|
||||
|
||||
if err := b.Publish(ctx, &testMessage{}); err != nil {
|
||||
t.Error(errors.WithStack(err))
|
||||
}
|
||||
|
||||
t.Log("publish 3")
|
||||
|
||||
if err := b.Publish(ctx, &testMessage{}); err != nil {
|
||||
t.Error(errors.WithStack(err))
|
||||
}
|
||||
|
||||
t.Log("publish 4")
|
||||
|
||||
if err := b.Publish(ctx, &testMessage{}); err != nil {
|
||||
t.Error(errors.WithStack(err))
|
||||
t.Logf("published %d", i)
|
||||
}
|
||||
}()
|
||||
|
||||
var count int32 = 0
|
||||
|
||||
go func() {
|
||||
t.Log("range for events")
|
||||
t.Log("range for received envelopes")
|
||||
|
||||
for msg := range messages {
|
||||
for env := range envelopes {
|
||||
t.Logf("received msg %d", atomic.LoadInt32(&count))
|
||||
atomic.AddInt32(&count, 1)
|
||||
|
||||
if e, g := testNamespace, msg.MessageNamespace(); e != g {
|
||||
t.Errorf("evt.MessageNamespace(): expected '%v', got '%v'", e, g)
|
||||
if e, g := testAddress, env.Address(); e != g {
|
||||
t.Errorf("env.Address(): expected '%v', got '%v'", e, g)
|
||||
}
|
||||
|
||||
wg.Done()
|
||||
@ -88,9 +67,9 @@ func TestPublishSubscribe(t *testing.T, b bus.Bus) {
|
||||
|
||||
wg.Wait()
|
||||
|
||||
b.Unsubscribe(ctx, testNamespace, messages)
|
||||
b.Unsubscribe(testAddress, envelopes)
|
||||
|
||||
if e, g := int32(5), count; e != g {
|
||||
t.Errorf("message received count: expected '%v', got '%v'", e, g)
|
||||
if e, g := int32(expectedTotal), count; e != g {
|
||||
t.Errorf("envelopes received count: expected '%v', got '%v'", e, g)
|
||||
}
|
||||
}
|
||||
|
@ -11,58 +11,42 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
testTypeReqRes bus.MessageNamespace = "testNamspaceReqRes"
|
||||
testTypeReqResAddress bus.Address = "testTypeReqResAddress"
|
||||
)
|
||||
|
||||
type testReqResMessage struct {
|
||||
i int
|
||||
}
|
||||
|
||||
func (m *testReqResMessage) MessageNamespace() bus.MessageNamespace {
|
||||
return testNamespace
|
||||
}
|
||||
|
||||
func TestRequestReply(t *testing.T, b bus.Bus) {
|
||||
expectedRoundTrips := 256
|
||||
timeout := time.Now().Add(time.Duration(expectedRoundTrips) * time.Second)
|
||||
|
||||
var (
|
||||
initWaitGroup sync.WaitGroup
|
||||
resWaitGroup sync.WaitGroup
|
||||
)
|
||||
replyCtx, cancelReply := context.WithDeadline(context.Background(), timeout)
|
||||
defer cancelReply()
|
||||
|
||||
initWaitGroup.Add(1)
|
||||
var resWaitGroup sync.WaitGroup
|
||||
|
||||
replyErrs := b.Reply(replyCtx, testTypeReqResAddress, func(env bus.Envelope) (any, error) {
|
||||
defer resWaitGroup.Done()
|
||||
|
||||
req, ok := env.Message().(int)
|
||||
if !ok {
|
||||
return nil, errors.WithStack(bus.ErrUnexpectedMessage)
|
||||
}
|
||||
|
||||
// Simulate random work
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
|
||||
t.Logf("[RES] sending res #%d", req)
|
||||
|
||||
return req, nil
|
||||
})
|
||||
|
||||
go func() {
|
||||
repondCtx, cancelRespond := context.WithDeadline(context.Background(), timeout)
|
||||
defer cancelRespond()
|
||||
|
||||
initWaitGroup.Done()
|
||||
|
||||
err := b.Reply(repondCtx, testNamespace, func(msg bus.Message) (bus.Message, error) {
|
||||
defer resWaitGroup.Done()
|
||||
|
||||
req, ok := msg.(*testReqResMessage)
|
||||
if !ok {
|
||||
return nil, errors.WithStack(bus.ErrUnexpectedMessage)
|
||||
for err := range replyErrs {
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Errorf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
result := &testReqResMessage{req.i}
|
||||
|
||||
// Simulate random work
|
||||
time.Sleep(time.Millisecond * 100)
|
||||
|
||||
t.Logf("[RES] sending res #%d", req.i)
|
||||
|
||||
return result, nil
|
||||
})
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
}()
|
||||
|
||||
initWaitGroup.Wait()
|
||||
|
||||
var reqWaitGroup sync.WaitGroup
|
||||
|
||||
for i := 0; i < expectedRoundTrips; i++ {
|
||||
@ -75,32 +59,30 @@ func TestRequestReply(t *testing.T, b bus.Bus) {
|
||||
requestCtx, cancelRequest := context.WithDeadline(context.Background(), timeout)
|
||||
defer cancelRequest()
|
||||
|
||||
req := &testReqResMessage{i}
|
||||
|
||||
t.Logf("[REQ] sending req #%d", i)
|
||||
|
||||
result, err := b.Request(requestCtx, req)
|
||||
response, err := b.Request(requestCtx, bus.NewEnvelope(testTypeReqResAddress, i))
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
|
||||
t.Logf("[REQ] received req #%d reply", i)
|
||||
|
||||
if result == nil {
|
||||
t.Error("result should not be nil")
|
||||
if response == nil {
|
||||
t.Error("response should not be nil")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
res, ok := result.(*testReqResMessage)
|
||||
result, ok := response.Message().(int)
|
||||
if !ok {
|
||||
t.Error(errors.WithStack(bus.ErrUnexpectedMessage))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if e, g := req.i, res.i; e != g {
|
||||
t.Errorf("res.i: expected '%v', got '%v'", e, g)
|
||||
if e, g := i, result; e != g {
|
||||
t.Errorf("response.Message(): expected '%v', got '%v'", e, g)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
@ -108,3 +90,77 @@ func TestRequestReply(t *testing.T, b bus.Bus) {
|
||||
reqWaitGroup.Wait()
|
||||
resWaitGroup.Wait()
|
||||
}
|
||||
|
||||
func TestCanceledRequest(t *testing.T, b bus.Bus) {
|
||||
replyCtx, cancelReply := context.WithCancel(context.Background())
|
||||
defer cancelReply()
|
||||
|
||||
errs := b.Reply(replyCtx, testTypeReqResAddress, func(env bus.Envelope) (any, error) {
|
||||
return env.Message(), nil
|
||||
})
|
||||
|
||||
go func() {
|
||||
for err := range errs {
|
||||
if !errors.Is(err, context.Canceled) {
|
||||
t.Errorf("%+v", errors.WithStack(err))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
|
||||
count := 100
|
||||
|
||||
wg.Add(count)
|
||||
|
||||
for i := 0; i < count; i++ {
|
||||
go func(i int) {
|
||||
defer wg.Done()
|
||||
|
||||
t.Logf("calling %d", i)
|
||||
|
||||
isCanceled := i%2 == 0
|
||||
|
||||
var ctx context.Context
|
||||
if isCanceled {
|
||||
canceledCtx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
ctx = canceledCtx
|
||||
} else {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
t.Logf("publishing envelope #%d", i)
|
||||
|
||||
reply, err := b.Request(ctx, bus.NewEnvelope(testTypeReqResAddress, int64(i)))
|
||||
if err != nil {
|
||||
if errors.Is(err, context.Canceled) && isCanceled {
|
||||
return
|
||||
}
|
||||
|
||||
if errors.Is(err, bus.ErrNoResponse) && isCanceled {
|
||||
return
|
||||
}
|
||||
|
||||
t.Errorf("%+v", errors.WithStack(err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
result, ok := reply.Message().(int64)
|
||||
if !ok {
|
||||
t.Errorf("response.Result: expected type '%T', got '%T'", int64(0), reply.Message())
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if e, g := i, int(result); e != g {
|
||||
t.Errorf("response.Result: expected '%v', got '%v'", e, g)
|
||||
|
||||
return
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
}
|
||||
|
Reference in New Issue
Block a user