edge/pkg/bus/memory/request_reply.go

152 lines
2.9 KiB
Go
Raw Permalink Normal View History

2023-02-09 12:16:36 +01:00
package memory
import (
"context"
"strconv"
"sync/atomic"
"forge.cadoles.com/arcad/edge/pkg/bus"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
)
const (
	// MessageNamespaceRequest is the "reqrep/request" namespace.
	//
	// NOTE(review): Bus.Request publishes a request on the wrapped message's
	// own namespace, not on this one, and this constant is not referenced in
	// this file — confirm whether other files use it.
	MessageNamespaceRequest bus.MessageNamespace = "reqrep/request"

	// MessageNamespaceReply is the base namespace that createReplyNamespace
	// combines with a request ID to build a per-request reply namespace.
	MessageNamespaceReply bus.MessageNamespace = "reqrep/reply"
)
// RequestMessage is the envelope Bus.Request publishes: it carries the
// caller's message together with the request ID used to route the reply.
type RequestMessage struct {
	// RequestID identifies the request; the replier uses it to derive the
	// reply namespace.
	RequestID uint64
	// Message is the caller-supplied payload handed to the request handler.
	Message bus.Message
	// ns is the namespace this envelope is published on.
	ns bus.MessageNamespace
}

// MessageNamespace reports the namespace stored in the envelope.
func (r *RequestMessage) MessageNamespace() bus.MessageNamespace {
	return r.ns
}
// ReplyMessage is the envelope Bus.Reply publishes in response to a
// RequestMessage. Exactly one of Message or Error is set by Bus.Reply.
type ReplyMessage struct {
	// RequestID echoes the ID of the request being answered.
	RequestID uint64
	// Message is the handler's result when it succeeded.
	Message bus.Message
	// Error is the handler's error when it failed.
	Error error
	// ns is the per-request reply namespace this envelope is published on.
	ns bus.MessageNamespace
}

// MessageNamespace reports the namespace stored in the envelope.
func (r *ReplyMessage) MessageNamespace() bus.MessageNamespace {
	return r.ns
}
// Request publishes msg, wrapped in a RequestMessage on msg's own namespace,
// and blocks until the matching reply arrives or ctx is done.
//
// A fresh request ID is taken from b.nextRequestID. The caller subscribes to
// the per-request reply namespace (createReplyNamespace) before publishing,
// so a fast replier cannot answer before anyone is listening. The
// subscription is released when Request returns.
//
// On success it returns the replied message. Otherwise it returns:
//   - the error from Subscribe or Publish;
//   - ctx.Err() when ctx is done before a reply arrives;
//   - bus.ErrNoResponse when the reply channel is closed without a reply;
//   - bus.ErrUnexpectedMessage when something other than a *ReplyMessage
//     is received on the reply namespace;
//   - the replier's error when the request handler failed.
func (b *Bus) Request(ctx context.Context, msg bus.Message) (bus.Message, error) {
	requestID := atomic.AddUint64(&b.nextRequestID, 1)

	req := &RequestMessage{
		RequestID: requestID,
		Message:   msg,
		ns:        msg.MessageNamespace(),
	}

	replyNamespace := createReplyNamespace(requestID)

	replies, err := b.Subscribe(ctx, replyNamespace)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	defer func() {
		b.Unsubscribe(ctx, replyNamespace, replies)
	}()

	logger.Debug(ctx, "publishing request", logger.F("request", req))

	if err := b.Publish(ctx, req); err != nil {
		return nil, errors.WithStack(err)
	}

	// Every case returns, so a single select is enough; no loop is needed.
	select {
	case <-ctx.Done():
		return nil, errors.WithStack(ctx.Err())

	case received, ok := <-replies:
		if !ok {
			return nil, errors.WithStack(bus.ErrNoResponse)
		}

		reply, ok := received.(*ReplyMessage)
		if !ok {
			return nil, errors.WithStack(bus.ErrUnexpectedMessage)
		}

		if reply.Error != nil {
			// Surface the replier's error. The outer err is nil here, and
			// errors.WithStack(nil) is nil, so wrapping it would drop the failure.
			return nil, errors.WithStack(reply.Error)
		}

		return reply.Message, nil
	}
}
// RequestHandler is a function that receives a request's payload and returns
// either a response message or an error.
//
// NOTE(review): Bus.Reply in this file takes a bus.RequestHandler, not this
// type — confirm whether this declaration is used elsewhere in the package.
type RequestHandler func(msg bus.Message) (bus.Message, error)
// Reply subscribes to msgNamespace and serves incoming requests with h until
// ctx is done or the subscription channel is closed.
//
// For each received *RequestMessage, h is called with the request payload. A
// *ReplyMessage is then published on the reply namespace derived from the
// request ID. The reply carries h's result on success, or h's error (wrapped
// with a stack trace) on failure. The subscription is released on return.
//
// Reply returns:
//   - nil when the subscription channel is closed;
//   - ctx.Err() when ctx is done;
//   - bus.ErrUnexpectedMessage when a message that is not a *RequestMessage
//     arrives;
//   - any error from Subscribe or Publish.
func (b *Bus) Reply(ctx context.Context, msgNamespace bus.MessageNamespace, h bus.RequestHandler) error {
	requests, err := b.Subscribe(ctx, msgNamespace)
	if err != nil {
		return errors.WithStack(err)
	}

	defer func() {
		b.Unsubscribe(ctx, msgNamespace, requests)
	}()

	// serve runs the handler for one request and publishes the resulting reply.
	serve := func(request *RequestMessage) error {
		logger.Debug(ctx, "handling request", logger.F("request", request))

		response, handlerErr := h(request.Message)

		reply := &ReplyMessage{
			RequestID: request.RequestID,
			ns:        createReplyNamespace(request.RequestID),
		}

		if handlerErr == nil {
			reply.Message = response
		} else {
			reply.Error = errors.WithStack(handlerErr)
		}

		logger.Debug(ctx, "publishing reply", logger.F("reply", reply))

		if publishErr := b.Publish(ctx, reply); publishErr != nil {
			return errors.WithStack(publishErr)
		}

		return nil
	}

	for {
		select {
		case <-ctx.Done():
			return errors.WithStack(ctx.Err())

		case incoming, open := <-requests:
			if !open {
				return nil
			}

			request, isRequest := incoming.(*RequestMessage)
			if !isRequest {
				return errors.WithStack(bus.ErrUnexpectedMessage)
			}

			if serveErr := serve(request); serveErr != nil {
				return serveErr
			}
		}
	}
}
// createReplyNamespace builds the namespace a reply to request requestID is
// published on. It combines MessageNamespaceReply with the base-10 form of
// requestID via bus.NewMessageNamespace.
func createReplyNamespace(requestID uint64) bus.MessageNamespace {
	idSegment := bus.MessageNamespace(strconv.FormatUint(requestID, 10))

	return bus.NewMessageNamespace(MessageNamespaceReply, idSegment)
}