package memory import ( "context" "strconv" "sync/atomic" "forge.cadoles.com/arcad/edge/pkg/bus" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" ) const ( MessageNamespaceRequest bus.MessageNamespace = "reqrep/request" MessageNamespaceReply bus.MessageNamespace = "reqrep/reply" ) type RequestMessage struct { RequestID uint64 Message bus.Message ns bus.MessageNamespace } func (m *RequestMessage) MessageNamespace() bus.MessageNamespace { return m.ns } type ReplyMessage struct { RequestID uint64 Message bus.Message Error error ns bus.MessageNamespace } func (m *ReplyMessage) MessageNamespace() bus.MessageNamespace { return m.ns } func (b *Bus) Request(ctx context.Context, msg bus.Message) (bus.Message, error) { requestID := atomic.AddUint64(&b.nextRequestID, 1) req := &RequestMessage{ RequestID: requestID, Message: msg, ns: msg.MessageNamespace(), } replyNamespace := createReplyNamespace(requestID) replies, err := b.Subscribe(ctx, replyNamespace) if err != nil { return nil, errors.WithStack(err) } defer func() { b.Unsubscribe(ctx, replyNamespace, replies) }() logger.Debug(ctx, "publishing request", logger.F("request", req)) if err := b.Publish(ctx, req); err != nil { return nil, errors.WithStack(err) } for { select { case <-ctx.Done(): return nil, errors.WithStack(ctx.Err()) case msg, ok := <-replies: if !ok { return nil, errors.WithStack(bus.ErrNoResponse) } reply, ok := msg.(*ReplyMessage) if !ok { return nil, errors.WithStack(bus.ErrUnexpectedMessage) } if reply.Error != nil { return nil, errors.WithStack(err) } return reply.Message, nil } } } type RequestHandler func(evt bus.Message) (bus.Message, error) func (b *Bus) Reply(ctx context.Context, msgNamespace bus.MessageNamespace, h bus.RequestHandler) error { requests, err := b.Subscribe(ctx, msgNamespace) if err != nil { return errors.WithStack(err) } defer func() { b.Unsubscribe(ctx, msgNamespace, requests) }() for { select { case <-ctx.Done(): return errors.WithStack(ctx.Err()) case msg, ok := <-requests: if !ok { return nil } request, ok := msg.(*RequestMessage) if !ok { return errors.WithStack(bus.ErrUnexpectedMessage) } logger.Debug(ctx, "handling request", logger.F("request", request)) msg, err := h(request.Message) reply := &ReplyMessage{ RequestID: request.RequestID, Message: nil, Error: nil, ns: createReplyNamespace(request.RequestID), } if err != nil { reply.Error = errors.WithStack(err) } else { reply.Message = msg } logger.Debug(ctx, "publishing reply", logger.F("reply", reply)) if err := b.Publish(ctx, reply); err != nil { return errors.WithStack(err) } } } } func createReplyNamespace(requestID uint64) bus.MessageNamespace { return bus.NewMessageNamespace( MessageNamespaceReply, bus.MessageNamespace(strconv.FormatUint(requestID, 10)), ) }