edge/pkg/bus/memory/request_reply.go
William Petit ad49c1718c
All checks were successful
arcad/edge/pipeline/head This commit looks good
arcad/edge/pipeline/pr-master This commit looks good
feat: rewrite bus to prevent deadlocks
2023-11-30 15:02:36 +01:00

184 lines
3.5 KiB
Go

package memory
import (
"context"
"strconv"
"sync/atomic"
"forge.cadoles.com/arcad/edge/pkg/bus"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
)
// Address prefixes used to route request/reply traffic over the bus.
const (
// AddressRequest is the prefix getRequestAddress prepends to a target
// address to build the address requests are published on.
AddressRequest bus.Address = "bus/memory/request"
// AddressReply is the prefix getReplyAddress prepends to a target address
// (followed by the request identifier) to build the address of a reply.
AddressReply bus.Address = "bus/memory/reply"
)
// RequestEnvelope wraps an envelope sent through Bus.Request together with
// the identifier of that request.
type RequestEnvelope struct {
// requestID identifies the request; Bus.Reply copies it into the matching
// ReplyEnvelope.
requestID uint64
// wrapped is the original envelope handed to Bus.Request.
wrapped bus.Envelope
}
// Address returns the request address derived from the wrapped envelope's
// own address (see getRequestAddress).
func (e *RequestEnvelope) Address() bus.Address {
	original := e.wrapped.Address()

	return getRequestAddress(original)
}
// Message returns the message carried by the wrapped envelope.
func (e *RequestEnvelope) Message() any {
	payload := e.wrapped.Message()

	return payload
}
// RequestID returns the identifier assigned to this request.
func (e *RequestEnvelope) RequestID() uint64 {
	id := e.requestID

	return id
}
// Unwrap returns the original envelope enclosed in the request.
func (e *RequestEnvelope) Unwrap() bus.Envelope {
	inner := e.wrapped

	return inner
}
// ReplyEnvelope wraps the envelope published by Bus.Reply in answer to a
// request, along with the error returned by the request handler, if any.
type ReplyEnvelope struct {
// requestID is the identifier of the request being answered; it is part
// of the reply address.
requestID uint64
// wrapped carries the handler's response message under the address of the
// original request envelope.
wrapped bus.Envelope
// err is the error returned by the request handler, nil on success.
err error
}
// Address returns the reply address built from the wrapped envelope's
// address and the request identifier (see getReplyAddress).
func (e *ReplyEnvelope) Address() bus.Address {
	original := e.wrapped.Address()

	return getReplyAddress(original, e.requestID)
}
// Message returns the message carried by the wrapped envelope.
func (e *ReplyEnvelope) Message() any {
	payload := e.wrapped.Message()

	return payload
}
// Err returns the error attached to the reply, or nil when there is none.
func (e *ReplyEnvelope) Err() error {
	handlerErr := e.err

	return handlerErr
}
// Unwrap returns the envelope enclosed in the reply.
func (e *ReplyEnvelope) Unwrap() bus.Envelope {
	inner := e.wrapped

	return inner
}
// Request publishes env as a request and waits for the first message received
// on the matching reply address.
//
// It returns the envelope wrapped in the reply on success. It returns an
// error when the subscription or the publication fails, when ctx is done
// before a reply arrives, when the reply channel is closed
// (bus.ErrNoResponse), when the received message is not a *ReplyEnvelope
// (bus.ErrUnexpectedMessage), or when the reply itself carries an error.
func (b *Bus) Request(ctx context.Context, env bus.Envelope) (bus.Envelope, error) {
	id := atomic.AddUint64(&b.nextRequestID, 1)
	replyAddr := getReplyAddress(env.Address(), id)

	// The subscription gets its own context so that it is cancelled as soon
	// as this call returns, whatever the outcome.
	subscriptionCtx, stop := context.WithCancel(ctx)
	defer stop()

	inbox, err := b.Subscribe(subscriptionCtx, replyAddr)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	defer b.Unsubscribe(replyAddr, inbox)

	request := &RequestEnvelope{
		requestID: id,
		wrapped:   env,
	}

	logger.Debug(ctx, "publishing request", logger.F("request", request))

	if err := b.Publish(request); err != nil {
		return nil, errors.WithStack(err)
	}

	// Every branch returns, so a single select is enough: either the caller
	// gives up or the first message on the reply address settles the call.
	select {
	case <-ctx.Done():
		return nil, errors.WithStack(ctx.Err())

	case received, open := <-inbox:
		if !open {
			return nil, errors.WithStack(bus.ErrNoResponse)
		}

		reply, isReply := received.(*ReplyEnvelope)
		if !isReply {
			return nil, errors.WithStack(bus.ErrUnexpectedMessage)
		}

		if replyErr := reply.Err(); replyErr != nil {
			return nil, errors.WithStack(replyErr)
		}

		return reply.Unwrap(), nil
	}
}
// Reply subscribes to the request address derived from address and, in a
// background goroutine, answers every *RequestEnvelope received there: the
// unwrapped envelope is passed to handler and the resulting message and error
// are published back as a *ReplyEnvelope carrying the same request
// identifier.
//
// The returned channel reports failures: a subscription error, messages that
// are not a *RequestEnvelope (bus.ErrUnexpectedMessage), reply publication
// errors and, on shutdown, the context error. It is closed once the goroutine
// has unsubscribed, which happens when ctx is done or when the subscription
// channel is closed.
//
// Callers should receive from the returned channel. Reporting an error waits
// for room in the channel, but never past the cancellation of ctx, so an
// unread channel cannot keep the goroutine (and its subscription) alive after
// ctx is done.
func (b *Bus) Reply(ctx context.Context, address bus.Address, handler bus.RequestHandler) chan error {
	requestAddress := getRequestAddress(address)

	// One slot of buffering lets the subscription error and the final context
	// error be handed over without a receiver being parked on the channel.
	errs := make(chan error, 1)

	requests, err := b.Subscribe(ctx, requestAddress)
	if err != nil {
		// The buffer is empty at this point, so this send cannot block and
		// no extra goroutine is needed.
		errs <- errors.WithStack(err)
		close(errs)

		return errs
	}

	// report hands cause to the caller but gives up as soon as ctx is done,
	// so that error reporting never prevents shutdown.
	report := func(cause error) {
		select {
		case errs <- cause:
		case <-ctx.Done():
		}
	}

	go func() {
		defer func() {
			b.Unsubscribe(requestAddress, requests)
			close(errs)
		}()

		for {
			select {
			case <-ctx.Done():
				// Best effort: the context error is dropped only when a
				// previous error is still sitting unread in the buffer; the
				// closing of errs signals termination in every case.
				select {
				case errs <- errors.WithStack(ctx.Err()):
				default:
				}

				return

			case env, ok := <-requests:
				if !ok {
					return
				}

				request, ok := env.(*RequestEnvelope)
				if !ok {
					report(errors.WithStack(bus.ErrUnexpectedMessage))

					continue
				}

				logger.Debug(ctx, "handling request", logger.F("request", request))

				msg, handlerErr := handler(request.Unwrap())

				reply := &ReplyEnvelope{
					requestID: request.RequestID(),
					wrapped:   bus.NewEnvelope(request.Unwrap().Address(), msg),
				}

				// The handler error travels inside the reply so that the
				// requester receives it from Request.
				if handlerErr != nil {
					reply.err = errors.WithStack(handlerErr)
				}

				logger.Debug(ctx, "publishing reply", logger.F("reply", reply))

				if publishErr := b.Publish(reply); publishErr != nil {
					report(errors.WithStack(publishErr))
				}
			}
		}
	}()

	return errs
}
// getRequestAddress returns the address on which requests targeting addr are
// published: AddressRequest, a slash, then addr.
func getRequestAddress(addr bus.Address) bus.Address {
	const separator = "/"

	return AddressRequest + separator + addr
}
// getReplyAddress returns the address on which the reply to request
// requestID targeting addr is published: AddressReply, addr and the decimal
// request identifier, joined with slashes.
func getReplyAddress(addr bus.Address, requestID uint64) bus.Address {
	id := strconv.FormatUint(requestID, 10)
	base := AddressReply + "/" + addr

	return base + "/" + bus.Address(id)
}