package memory import ( "context" "strconv" "sync/atomic" "forge.cadoles.com/arcad/edge/pkg/bus" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" ) const ( AddressRequest bus.Address = "bus/memory/request" AddressReply bus.Address = "bus/memory/reply" ) type RequestEnvelope struct { requestID uint64 wrapped bus.Envelope } func (e *RequestEnvelope) Address() bus.Address { return getRequestAddress(e.wrapped.Address()) } func (e *RequestEnvelope) Message() any { return e.wrapped.Message() } func (e *RequestEnvelope) RequestID() uint64 { return e.requestID } func (e *RequestEnvelope) Unwrap() bus.Envelope { return e.wrapped } type ReplyEnvelope struct { requestID uint64 wrapped bus.Envelope err error } func (e *ReplyEnvelope) Address() bus.Address { return getReplyAddress(e.wrapped.Address(), e.requestID) } func (e *ReplyEnvelope) Message() any { return e.wrapped.Message() } func (e *ReplyEnvelope) Err() error { return e.err } func (e *ReplyEnvelope) Unwrap() bus.Envelope { return e.wrapped } func (b *Bus) Request(ctx context.Context, env bus.Envelope) (bus.Envelope, error) { requestID := atomic.AddUint64(&b.nextRequestID, 1) req := &RequestEnvelope{ requestID: requestID, wrapped: env, } replyAddress := getReplyAddress(env.Address(), requestID) subCtx, cancel := context.WithCancel(ctx) defer cancel() replies, err := b.Subscribe(subCtx, replyAddress) if err != nil { return nil, errors.WithStack(err) } defer func() { b.Unsubscribe(replyAddress, replies) }() logger.Debug(ctx, "publishing request", logger.F("request", req)) if err := b.Publish(req); err != nil { return nil, errors.WithStack(err) } for { select { case <-ctx.Done(): return nil, errors.WithStack(ctx.Err()) case env, ok := <-replies: if !ok { return nil, errors.WithStack(bus.ErrNoResponse) } reply, ok := env.(*ReplyEnvelope) if !ok { return nil, errors.WithStack(bus.ErrUnexpectedMessage) } if err := reply.Err(); err != nil { return nil, errors.WithStack(err) } return reply.Unwrap(), nil } } } func (b *Bus) Reply(ctx context.Context, address bus.Address, handler bus.RequestHandler) chan error { requestAddress := getRequestAddress(address) errs := make(chan error) requests, err := b.Subscribe(ctx, requestAddress) if err != nil { go func() { errs <- errors.WithStack(err) close(errs) }() return errs } go func() { defer func() { b.Unsubscribe(requestAddress, requests) close(errs) }() for { select { case <-ctx.Done(): errs <- errors.WithStack(ctx.Err()) return case env, ok := <-requests: if !ok { return } request, ok := env.(*RequestEnvelope) if !ok { errs <- errors.WithStack(bus.ErrUnexpectedMessage) continue } logger.Debug(ctx, "handling request", logger.F("request", request)) msg, err := handler(request.Unwrap()) reply := &ReplyEnvelope{ requestID: request.RequestID(), wrapped: bus.NewEnvelope(request.Unwrap().Address(), msg), } if err != nil { reply.err = errors.WithStack(err) } logger.Debug(ctx, "publishing reply", logger.F("reply", reply)) if err := b.Publish(reply); err != nil { errs <- errors.WithStack(err) continue } } } }() return errs } func getRequestAddress(addr bus.Address) bus.Address { return AddressRequest + "/" + addr } func getReplyAddress(addr bus.Address, requestID uint64) bus.Address { return AddressReply + "/" + addr + "/" + bus.Address(strconv.FormatUint(requestID, 10)) }