From c4b2db4b7f2d1b616ab4d3db81a411ce7784b421 Mon Sep 17 00:00:00 2001 From: William Petit Date: Tue, 28 Nov 2023 16:35:49 +0100 Subject: [PATCH] feat: rewrite bus to prevent deadlocks --- pkg/app/promise_proxy.go | 4 + pkg/app/server.go | 132 +++++---- pkg/bus/bus.go | 12 +- pkg/bus/envelope.go | 32 ++ pkg/bus/memory/bus.go | 40 +-- pkg/bus/memory/event_dispatcher.go | 49 ++- pkg/bus/memory/request_reply.go | 120 ++++---- pkg/bus/message.go | 33 --- pkg/bus/testing/publish_subscribe.go | 61 ++-- pkg/bus/testing/request_reply.go | 34 +-- pkg/http/client.go | 4 +- pkg/http/context.go | 55 ++++ pkg/http/envelope.go | 30 ++ pkg/http/handler.go | 13 +- pkg/http/sockjs.go | 41 ++- pkg/http/util.go | 82 ++++++ pkg/module/app/memory/module_test.go | 14 +- pkg/module/auth/module_test.go | 22 +- .../blob/{blob_message.go => envelope.go} | 43 +-- pkg/{http/blob.go => module/blob/http.go} | 128 ++------ pkg/module/blob/module_test.go | 14 +- pkg/module/cast/module_test.go | 21 +- pkg/module/fetch/envelope.go | 38 +++ pkg/module/fetch/fetch_message.go | 49 --- pkg/{http/fetch.go => module/fetch/http.go} | 41 +-- pkg/module/fetch/module.go | 10 +- pkg/module/fetch/module_test.go | 30 +- pkg/module/lifecycle.go | 21 +- pkg/module/message.go | 38 --- pkg/module/net/envelope.go | 38 +++ pkg/module/net/module.go | 35 ++- pkg/module/rpc.go | 278 ------------------ pkg/module/rpc/envelope.go | 21 ++ pkg/module/rpc/error.go | 7 + pkg/module/rpc/rpc.go | 199 +++++++++++++ pkg/module/rpc/rpc_test.go | 109 +++++++ pkg/module/rpc/testdata/deadlock.js | 14 + pkg/module/share/module_test.go | 10 +- pkg/module/store/module_test.go | 10 +- 39 files changed, 1029 insertions(+), 903 deletions(-) create mode 100644 pkg/bus/envelope.go delete mode 100644 pkg/bus/message.go create mode 100644 pkg/http/context.go create mode 100644 pkg/http/envelope.go create mode 100644 pkg/http/util.go rename pkg/module/blob/{blob_message.go => envelope.go} (52%) rename pkg/{http/blob.go => module/blob/http.go} (53%) create mode 100644 pkg/module/fetch/envelope.go delete mode 100644 pkg/module/fetch/fetch_message.go rename pkg/{http/fetch.go => module/fetch/http.go} (60%) delete mode 100644 pkg/module/message.go create mode 100644 pkg/module/net/envelope.go delete mode 100644 pkg/module/rpc.go create mode 100644 pkg/module/rpc/envelope.go create mode 100644 pkg/module/rpc/error.go create mode 100644 pkg/module/rpc/rpc.go create mode 100644 pkg/module/rpc/rpc_test.go create mode 100644 pkg/module/rpc/testdata/deadlock.js diff --git a/pkg/app/promise_proxy.go b/pkg/app/promise_proxy.go index 2031be1..dfdb4b4 100644 --- a/pkg/app/promise_proxy.go +++ b/pkg/app/promise_proxy.go @@ -47,6 +47,10 @@ func NewPromiseProxyFrom(rt *goja.Runtime) *PromiseProxy { } func IsPromise(v goja.Value) (*goja.Promise, bool) { + if v == nil { + return nil, false + } + promise, ok := v.Export().(*goja.Promise) return promise, ok } diff --git a/pkg/app/server.go b/pkg/app/server.go index 7f2c08e..6820117 100644 --- a/pkg/app/server.go +++ b/pkg/app/server.go @@ -4,6 +4,7 @@ import ( "context" "math/rand" "sync" + "time" "github.com/dop251/goja" "github.com/dop251/goja_nodejs/eventloop" @@ -22,22 +23,6 @@ type Server struct { modules []ServerModule } -func (s *Server) Load(name string, src string) error { - var err error - - s.loop.RunOnLoop(func(rt *goja.Runtime) { - _, err = rt.RunScript(name, src) - if err != nil { - err = errors.Wrap(err, "could not run js script") - } - }) - if err != nil { - return errors.WithStack(err) - } - - return nil -} - func (s *Server) ExecFuncByName(ctx context.Context, funcName string, args ...interface{}) (goja.Value, error) { ctx = logger.With(ctx, logger.F("function", funcName), logger.F("args", args)) @@ -50,15 +35,16 @@ func (s *Server) ExecFuncByName(ctx context.Context, funcName string, args ...in } func (s *Server) Exec(ctx context.Context, callableOrFuncname any, args ...interface{}) (goja.Value, error) { - var ( - wg sync.WaitGroup + type result struct { value goja.Value err error - ) + } - wg.Add(1) + done := make(chan result) s.loop.RunOnLoop(func(rt *goja.Runtime) { + defer close(done) + var callable goja.Callable switch typ := callableOrFuncname.(type) { case goja.Callable: @@ -67,7 +53,9 @@ func (s *Server) Exec(ctx context.Context, callableOrFuncname any, args ...inter case string: call, ok := goja.AssertFunction(rt.Get(typ)) if !ok { - err = errors.WithStack(ErrFuncDoesNotExist) + done <- result{ + err: errors.WithStack(ErrFuncDoesNotExist), + } return } @@ -75,28 +63,27 @@ func (s *Server) Exec(ctx context.Context, callableOrFuncname any, args ...inter callable = call default: - err = errors.Errorf("callableOrFuncname: expected callable or function name, got '%T'", callableOrFuncname) + done <- result{ + err: errors.Errorf("callableOrFuncname: expected callable or function name, got '%T'", callableOrFuncname), + } return } - logger.Debug(ctx, "executing callable") - - defer wg.Done() - defer func() { - if recovered := recover(); recovered != nil { - revoveredErr, ok := recovered.(error) - if ok { - logger.Error(ctx, "recovered runtime error", logger.CapturedE(errors.WithStack(revoveredErr))) - - err = errors.WithStack(ErrUnknownError) - - return - } + recovered := recover() + if recovered == nil { + return + } + recoveredErr, ok := recovered.(error) + if !ok { panic(recovered) } + + done <- result{ + err: recoveredErr, + } }() jsArgs := make([]goja.Value, 0, len(args)) @@ -104,19 +91,40 @@ func (s *Server) Exec(ctx context.Context, callableOrFuncname any, args ...inter jsArgs = append(jsArgs, rt.ToValue(a)) } - value, err = callable(nil, jsArgs...) + logger.Debug(ctx, "executing callable", logger.F("callable", callableOrFuncname)) + + start := time.Now() + value, err := callable(nil, jsArgs...) if err != nil { - err = errors.WithStack(err) + done <- result{ + err: errors.WithStack(err), + } + + return } + + done <- result{ + value: value, + } + + logger.Debug(ctx, "executed callable", logger.F("callable", callableOrFuncname), logger.F("duration", time.Since(start).String())) }) - wg.Wait() + select { + case <-ctx.Done(): + if err := ctx.Err(); err != nil { + return nil, errors.WithStack(err) + } - if err != nil { - return nil, errors.WithStack(err) + return nil, nil + + case result := <-done: + if result.err != nil { + return nil, errors.WithStack(result.err) + } + + return result.value, nil } - - return value, nil } func (s *Server) WaitForPromise(promise *goja.Promise) goja.Value { @@ -162,20 +170,40 @@ func (s *Server) WaitForPromise(promise *goja.Promise) goja.Value { return value } -func (s *Server) Start(ctx context.Context) error { +func (s *Server) Start(ctx context.Context, name string, src string) error { s.loop.Start() - var err error + done := make(chan error) s.loop.RunOnLoop(func(rt *goja.Runtime) { + defer close(done) + rt.SetFieldNameMapper(goja.TagFieldNameMapper("goja", true)) rt.SetRandSource(createRandomSource()) - if err = s.initModules(ctx, rt); err != nil { + if err := s.loadModules(ctx, rt); err != nil { err = errors.WithStack(err) + done <- err + + return } + + if _, err := rt.RunScript(name, src); err != nil { + done <- errors.Wrap(err, "could not run js script") + return + } + + if err := s.initModules(ctx, rt); err != nil { + err = errors.WithStack(err) + done <- err + + return + } + + done <- nil }) - if err != nil { + + if err := <-done; err != nil { return errors.WithStack(err) } @@ -186,7 +214,7 @@ func (s *Server) Stop() { s.loop.Stop() } -func (s *Server) initModules(ctx context.Context, rt *goja.Runtime) error { +func (s *Server) loadModules(ctx context.Context, rt *goja.Runtime) error { modules := make([]ServerModule, 0, len(s.factories)) for _, moduleFactory := range s.factories { @@ -200,7 +228,13 @@ func (s *Server) initModules(ctx context.Context, rt *goja.Runtime) error { modules = append(modules, mod) } - for _, mod := range modules { + s.modules = modules + + return nil +} + +func (s *Server) initModules(ctx context.Context, rt *goja.Runtime) error { + for _, mod := range s.modules { initMod, ok := mod.(InitializableModule) if !ok { continue @@ -213,8 +247,6 @@ func (s *Server) initModules(ctx context.Context, rt *goja.Runtime) error { } } - s.modules = modules - return nil } diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go index a02d437..452333a 100644 --- a/pkg/bus/bus.go +++ b/pkg/bus/bus.go @@ -3,11 +3,11 @@ package bus import "context" type Bus interface { - Subscribe(ctx context.Context, ns MessageNamespace) (<-chan Message, error) - Unsubscribe(ctx context.Context, ns MessageNamespace, ch <-chan Message) - Publish(ctx context.Context, msg Message) error - Request(ctx context.Context, msg Message) (Message, error) - Reply(ctx context.Context, ns MessageNamespace, h RequestHandler) error + Subscribe(ctx context.Context, addr Address) (<-chan Envelope, error) + Unsubscribe(addr Address, ch <-chan Envelope) + Publish(env Envelope) error + Request(ctx context.Context, env Envelope) (Envelope, error) + Reply(ctx context.Context, addr Address, h RequestHandler) error } -type RequestHandler func(msg Message) (Message, error) +type RequestHandler func(env Envelope) (any, error) diff --git a/pkg/bus/envelope.go b/pkg/bus/envelope.go new file mode 100644 index 0000000..23a7e7d --- /dev/null +++ b/pkg/bus/envelope.go @@ -0,0 +1,32 @@ +package bus + +type Address string + +type Envelope interface { + Message() any + Address() Address +} + +type BaseEnvelope struct { + msg any + addr Address +} + +// Address implements Envelope. +func (e *BaseEnvelope) Address() Address { + return e.addr +} + +// Message implements Envelope. +func (e *BaseEnvelope) Message() any { + return e.msg +} + +func NewEnvelope(addr Address, msg any) *BaseEnvelope { + return &BaseEnvelope{ + addr: addr, + msg: msg, + } +} + +var _ Envelope = &BaseEnvelope{} diff --git a/pkg/bus/memory/bus.go b/pkg/bus/memory/bus.go index d6b3088..98cd979 100644 --- a/pkg/bus/memory/bus.go +++ b/pkg/bus/memory/bus.go @@ -15,13 +15,13 @@ type Bus struct { nextRequestID uint64 } -func (b *Bus) Subscribe(ctx context.Context, ns bus.MessageNamespace) (<-chan bus.Message, error) { +func (b *Bus) Subscribe(ctx context.Context, address bus.Address) (<-chan bus.Envelope, error) { logger.Debug( - ctx, "subscribing to messages", - logger.F("messageNamespace", ns), + ctx, "subscribing", + logger.F("address", address), ) - dispatchers := b.getDispatchers(ns) + dispatchers := b.getDispatchers(address) disp := newEventDispatcher(b.opt.BufferSize) go disp.Run(ctx) @@ -31,24 +31,24 @@ func (b *Bus) Subscribe(ctx context.Context, ns bus.MessageNamespace) (<-chan bu return disp.Out(), nil } -func (b *Bus) Unsubscribe(ctx context.Context, ns bus.MessageNamespace, ch <-chan bus.Message) { +func (b *Bus) Unsubscribe(address bus.Address, ch <-chan bus.Envelope) { logger.Debug( - ctx, "unsubscribing from messages", - logger.F("messageNamespace", ns), + context.Background(), "unsubscribing", + logger.F("address", address), ) - dispatchers := b.getDispatchers(ns) + dispatchers := b.getDispatchers(address) dispatchers.RemoveByOutChannel(ch) } -func (b *Bus) Publish(ctx context.Context, msg bus.Message) error { - dispatchers := b.getDispatchers(msg.MessageNamespace()) +func (b *Bus) Publish(env bus.Envelope) error { + dispatchers := b.getDispatchers(env.Address()) dispatchersList := dispatchers.List() logger.Debug( - ctx, "publishing message", + context.Background(), "publish", logger.F("dispatchers", len(dispatchersList)), - logger.F("messageNamespace", msg.MessageNamespace()), + logger.F("address", env.Address()), ) for _, d := range dispatchersList { @@ -58,23 +58,25 @@ func (b *Bus) Publish(ctx context.Context, msg bus.Message) error { continue } - if err := d.In(msg); err != nil { - return errors.WithStack(err) - } + go func(d *eventDispatcher) { + if err := d.In(env); err != nil { + logger.Error(context.Background(), "could not publish message", logger.CapturedE(errors.WithStack(err))) + } + }(d) } return nil } -func (b *Bus) getDispatchers(namespace bus.MessageNamespace) *eventDispatcherSet { - strNamespace := string(namespace) +func (b *Bus) getDispatchers(address bus.Address) *eventDispatcherSet { + rawAddress := string(address) - rawDispatchers, exists := b.dispatchers.Get(strNamespace) + rawDispatchers, exists := b.dispatchers.Get(rawAddress) dispatchers, ok := rawDispatchers.(*eventDispatcherSet) if !exists || !ok { dispatchers = newEventDispatcherSet() - b.dispatchers.Set(strNamespace, dispatchers) + b.dispatchers.Set(rawAddress, dispatchers) } return dispatchers diff --git a/pkg/bus/memory/event_dispatcher.go b/pkg/bus/memory/event_dispatcher.go index a078939..3bb7253 100644 --- a/pkg/bus/memory/event_dispatcher.go +++ b/pkg/bus/memory/event_dispatcher.go @@ -3,7 +3,6 @@ package memory import ( "context" "sync" - "time" "forge.cadoles.com/arcad/edge/pkg/bus" "github.com/pkg/errors" @@ -30,7 +29,7 @@ func (s *eventDispatcherSet) Remove(d *eventDispatcher) { delete(s.items, d) } -func (s *eventDispatcherSet) RemoveByOutChannel(out <-chan bus.Message) { +func (s *eventDispatcherSet) RemoveByOutChannel(out <-chan bus.Envelope) { s.mutex.Lock() defer s.mutex.Unlock() @@ -62,8 +61,8 @@ func newEventDispatcherSet() *eventDispatcherSet { } type eventDispatcher struct { - in chan bus.Message - out chan bus.Message + in chan bus.Envelope + out chan bus.Envelope mutex sync.RWMutex closed bool } @@ -91,7 +90,7 @@ func (d *eventDispatcher) close() { d.closed = true } -func (d *eventDispatcher) In(msg bus.Message) (err error) { +func (d *eventDispatcher) In(msg bus.Envelope) (err error) { d.mutex.RLock() defer d.mutex.RUnlock() @@ -104,11 +103,11 @@ func (d *eventDispatcher) In(msg bus.Message) (err error) { return nil } -func (d *eventDispatcher) Out() <-chan bus.Message { +func (d *eventDispatcher) Out() <-chan bus.Envelope { return d.out } -func (d *eventDispatcher) IsOut(out <-chan bus.Message) bool { +func (d *eventDispatcher) IsOut(out <-chan bus.Envelope) bool { return d.out == out } @@ -130,29 +129,29 @@ func (d *eventDispatcher) Run(ctx context.Context) { }() for { - msg, ok := <-d.in - if !ok { - return - } - - timeout := time.After(time.Second) - select { - case d.out <- msg: - case <-timeout: - logger.Error( - ctx, - "out message channel timeout", - logger.F("message", msg), - ) + case msg, ok := <-d.in: + if !ok { + return + } - return + select { + case d.out <- msg: + case <-ctx.Done(): + logger.Error( + ctx, + "message subscription context canceled", + logger.F("message", msg), + logger.CapturedE(errors.WithStack(ctx.Err())), + ) + + return + } case <-ctx.Done(): logger.Error( ctx, "message subscription context canceled", - logger.F("message", msg), logger.CapturedE(errors.WithStack(ctx.Err())), ) @@ -163,8 +162,8 @@ func (d *eventDispatcher) Run(ctx context.Context) { func newEventDispatcher(bufferSize int64) *eventDispatcher { return &eventDispatcher{ - in: make(chan bus.Message, bufferSize), - out: make(chan bus.Message, bufferSize), + in: make(chan bus.Envelope, bufferSize), + out: make(chan bus.Envelope, bufferSize), closed: false, } } diff --git a/pkg/bus/memory/request_reply.go b/pkg/bus/memory/request_reply.go index aaa9390..80e6ed1 100644 --- a/pkg/bus/memory/request_reply.go +++ b/pkg/bus/memory/request_reply.go @@ -11,57 +11,75 @@ import ( ) const ( - MessageNamespaceRequest bus.MessageNamespace = "reqrep/request" - MessageNamespaceReply bus.MessageNamespace = "reqrep/reply" + AddressRequest bus.Address = "bus/memory/request" + AddressReply bus.Address = "bus/memory/reply" ) -type RequestMessage struct { - RequestID uint64 - - Message bus.Message - - ns bus.MessageNamespace +type RequestEnvelope struct { + requestID uint64 + wrapped bus.Envelope } -func (m *RequestMessage) MessageNamespace() bus.MessageNamespace { - return m.ns +func (e *RequestEnvelope) Address() bus.Address { + return getRequestAddress(e.wrapped.Address()) } -type ReplyMessage struct { - RequestID uint64 - Message bus.Message - Error error - - ns bus.MessageNamespace +func (e *RequestEnvelope) Message() any { + return e.wrapped.Message() } -func (m *ReplyMessage) MessageNamespace() bus.MessageNamespace { - return m.ns +func (e *RequestEnvelope) RequestID() uint64 { + return e.requestID } -func (b *Bus) Request(ctx context.Context, msg bus.Message) (bus.Message, error) { +func (e *RequestEnvelope) Unwrap() bus.Envelope { + return e.wrapped +} + +type ReplyEnvelope struct { + requestID uint64 + wrapped bus.Envelope + err error +} + +func (e *ReplyEnvelope) Address() bus.Address { + return getReplyAddress(e.wrapped.Address(), e.requestID) +} + +func (e *ReplyEnvelope) Message() any { + return e.wrapped.Message() +} + +func (e *ReplyEnvelope) Err() error { + return e.err +} + +func (e *ReplyEnvelope) Unwrap() bus.Envelope { + return e.wrapped +} + +func (b *Bus) Request(ctx context.Context, env bus.Envelope) (bus.Envelope, error) { requestID := atomic.AddUint64(&b.nextRequestID, 1) - req := &RequestMessage{ - RequestID: requestID, - Message: msg, - ns: msg.MessageNamespace(), + req := &RequestEnvelope{ + requestID: requestID, + wrapped: env, } - replyNamespace := createReplyNamespace(requestID) + replyAddress := getReplyAddress(env.Address(), requestID) - replies, err := b.Subscribe(ctx, replyNamespace) + replies, err := b.Subscribe(ctx, replyAddress) if err != nil { return nil, errors.WithStack(err) } defer func() { - b.Unsubscribe(ctx, replyNamespace, replies) + b.Unsubscribe(replyAddress, replies) }() logger.Debug(ctx, "publishing request", logger.F("request", req)) - if err := b.Publish(ctx, req); err != nil { + if err := b.Publish(req); err != nil { return nil, errors.WithStack(err) } @@ -70,35 +88,35 @@ func (b *Bus) Request(ctx context.Context, msg bus.Message) (bus.Message, error) case <-ctx.Done(): return nil, errors.WithStack(ctx.Err()) - case msg, ok := <-replies: + case env, ok := <-replies: if !ok { return nil, errors.WithStack(bus.ErrNoResponse) } - reply, ok := msg.(*ReplyMessage) + reply, ok := env.(*ReplyEnvelope) if !ok { return nil, errors.WithStack(bus.ErrUnexpectedMessage) } - if reply.Error != nil { + if err := reply.Err(); err != nil { return nil, errors.WithStack(err) } - return reply.Message, nil + return reply.Unwrap(), nil } } } -type RequestHandler func(evt bus.Message) (bus.Message, error) +func (b *Bus) Reply(ctx context.Context, address bus.Address, handler bus.RequestHandler) error { + requestAddress := getRequestAddress(address) -func (b *Bus) Reply(ctx context.Context, msgNamespace bus.MessageNamespace, h bus.RequestHandler) error { - requests, err := b.Subscribe(ctx, msgNamespace) + requests, err := b.Subscribe(ctx, requestAddress) if err != nil { return errors.WithStack(err) } defer func() { - b.Unsubscribe(ctx, msgNamespace, requests) + b.Unsubscribe(requestAddress, requests) }() for { @@ -106,46 +124,42 @@ func (b *Bus) Reply(ctx context.Context, msgNamespace bus.MessageNamespace, h bu case <-ctx.Done(): return errors.WithStack(ctx.Err()) - case msg, ok := <-requests: + case env, ok := <-requests: if !ok { return nil } - request, ok := msg.(*RequestMessage) + request, ok := env.(*RequestEnvelope) if !ok { return errors.WithStack(bus.ErrUnexpectedMessage) } logger.Debug(ctx, "handling request", logger.F("request", request)) - msg, err := h(request.Message) + msg, err := handler(request.Unwrap()) - reply := &ReplyMessage{ - RequestID: request.RequestID, - Message: nil, - Error: nil, - - ns: createReplyNamespace(request.RequestID), + reply := &ReplyEnvelope{ + requestID: request.RequestID(), + wrapped: bus.NewEnvelope(request.Unwrap().Address(), msg), } if err != nil { - reply.Error = errors.WithStack(err) - } else { - reply.Message = msg + reply.err = errors.WithStack(err) } logger.Debug(ctx, "publishing reply", logger.F("reply", reply)) - if err := b.Publish(ctx, reply); err != nil { + if err := b.Publish(reply); err != nil { return errors.WithStack(err) } } } } -func createReplyNamespace(requestID uint64) bus.MessageNamespace { - return bus.NewMessageNamespace( - MessageNamespaceReply, - bus.MessageNamespace(strconv.FormatUint(requestID, 10)), - ) +func getRequestAddress(addr bus.Address) bus.Address { + return AddressRequest + "/" + addr +} + +func getReplyAddress(addr bus.Address, requestID uint64) bus.Address { + return AddressReply + "/" + addr + "/" + bus.Address(strconv.FormatUint(requestID, 10)) } diff --git a/pkg/bus/message.go b/pkg/bus/message.go deleted file mode 100644 index 3a470d1..0000000 --- a/pkg/bus/message.go +++ /dev/null @@ -1,33 +0,0 @@ -package bus - -import ( - "strings" - - "github.com/pkg/errors" -) - -type ( - MessageNamespace string -) - -type Message interface { - MessageNamespace() MessageNamespace -} - -func NewMessageNamespace(namespaces ...MessageNamespace) MessageNamespace { - var sb strings.Builder - - for i, ns := range namespaces { - if i != 0 { - if _, err := sb.WriteString(":"); err != nil { - panic(errors.Wrap(err, "could not build new message namespace")) - } - } - - if _, err := sb.WriteString(string(ns)); err != nil { - panic(errors.Wrap(err, "could not build new message namespace")) - } - } - - return MessageNamespace(sb.String()) -} diff --git a/pkg/bus/testing/publish_subscribe.go b/pkg/bus/testing/publish_subscribe.go index 6db69e3..2dd917b 100644 --- a/pkg/bus/testing/publish_subscribe.go +++ b/pkg/bus/testing/publish_subscribe.go @@ -2,6 +2,7 @@ package testing import ( "context" + "fmt" "sync" "sync/atomic" "testing" @@ -12,74 +13,52 @@ import ( ) const ( - testNamespace bus.MessageNamespace = "testNamespace" + testAddress bus.Address = "testAddress" ) -type testMessage struct{} - -func (e *testMessage) MessageNamespace() bus.MessageNamespace { - return testNamespace -} - func TestPublishSubscribe(t *testing.T, b bus.Bus) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() t.Log("subscribe") - messages, err := b.Subscribe(ctx, testNamespace) + envelopes, err := b.Subscribe(ctx, testAddress) if err != nil { t.Fatal(errors.WithStack(err)) } + expectedTotal := 5 + var wg sync.WaitGroup - wg.Add(5) + wg.Add(expectedTotal) go func() { - // 5 events should be received - t.Log("publish 0") - if err := b.Publish(ctx, &testMessage{}); err != nil { - t.Error(errors.WithStack(err)) - } + count := expectedTotal - t.Log("publish 1") + for i := 0; i < count; i++ { + env := bus.NewEnvelope(testAddress, fmt.Sprintf("message %d", i)) - if err := b.Publish(ctx, &testMessage{}); err != nil { - t.Error(errors.WithStack(err)) - } + if err := b.Publish(env); err != nil { + t.Error(errors.WithStack(err)) + } - t.Log("publish 2") - - if err := b.Publish(ctx, &testMessage{}); err != nil { - t.Error(errors.WithStack(err)) - } - - t.Log("publish 3") - - if err := b.Publish(ctx, &testMessage{}); err != nil { - t.Error(errors.WithStack(err)) - } - - t.Log("publish 4") - - if err := b.Publish(ctx, &testMessage{}); err != nil { - t.Error(errors.WithStack(err)) + t.Logf("published %d", i) } }() var count int32 = 0 go func() { - t.Log("range for events") + t.Log("range for received envelopes") - for msg := range messages { + for env := range envelopes { t.Logf("received msg %d", atomic.LoadInt32(&count)) atomic.AddInt32(&count, 1) - if e, g := testNamespace, msg.MessageNamespace(); e != g { - t.Errorf("evt.MessageNamespace(): expected '%v', got '%v'", e, g) + if e, g := testAddress, env.Address(); e != g { + t.Errorf("env.Address(): expected '%v', got '%v'", e, g) } wg.Done() @@ -88,9 +67,9 @@ func TestPublishSubscribe(t *testing.T, b bus.Bus) { wg.Wait() - b.Unsubscribe(ctx, testNamespace, messages) + b.Unsubscribe(testAddress, envelopes) - if e, g := int32(5), count; e != g { - t.Errorf("message received count: expected '%v', got '%v'", e, g) + if e, g := int32(expectedTotal), count; e != g { + t.Errorf("envelopes received count: expected '%v', got '%v'", e, g) } } diff --git a/pkg/bus/testing/request_reply.go b/pkg/bus/testing/request_reply.go index 22ceddd..077cc4c 100644 --- a/pkg/bus/testing/request_reply.go +++ b/pkg/bus/testing/request_reply.go @@ -11,17 +11,9 @@ import ( ) const ( - testTypeReqRes bus.MessageNamespace = "testNamspaceReqRes" + testTypeReqResAddress bus.Address = "testTypeReqResAddress" ) -type testReqResMessage struct { - i int -} - -func (m *testReqResMessage) MessageNamespace() bus.MessageNamespace { - return testNamespace -} - func TestRequestReply(t *testing.T, b bus.Bus) { expectedRoundTrips := 256 timeout := time.Now().Add(time.Duration(expectedRoundTrips) * time.Second) @@ -39,22 +31,20 @@ func TestRequestReply(t *testing.T, b bus.Bus) { initWaitGroup.Done() - err := b.Reply(repondCtx, testNamespace, func(msg bus.Message) (bus.Message, error) { + err := b.Reply(repondCtx, testTypeReqResAddress, func(env bus.Envelope) (any, error) { defer resWaitGroup.Done() - req, ok := msg.(*testReqResMessage) + req, ok := env.Message().(int) if !ok { return nil, errors.WithStack(bus.ErrUnexpectedMessage) } - result := &testReqResMessage{req.i} - // Simulate random work time.Sleep(time.Millisecond * 100) - t.Logf("[RES] sending res #%d", req.i) + t.Logf("[RES] sending res #%d", req) - return result, nil + return req, nil }) if err != nil { t.Error(err) @@ -75,32 +65,30 @@ func TestRequestReply(t *testing.T, b bus.Bus) { requestCtx, cancelRequest := context.WithDeadline(context.Background(), timeout) defer cancelRequest() - req := &testReqResMessage{i} - t.Logf("[REQ] sending req #%d", i) - result, err := b.Request(requestCtx, req) + response, err := b.Request(requestCtx, bus.NewEnvelope(testTypeReqResAddress, i)) if err != nil { t.Error(err) } t.Logf("[REQ] received req #%d reply", i) - if result == nil { - t.Error("result should not be nil") + if response == nil { + t.Error("response should not be nil") return } - res, ok := result.(*testReqResMessage) + result, ok := response.Message().(int) if !ok { t.Error(errors.WithStack(bus.ErrUnexpectedMessage)) return } - if e, g := req.i, res.i; e != g { - t.Errorf("res.i: expected '%v', got '%v'", e, g) + if e, g := i, result; e != g { + t.Errorf("response.Message(): expected '%v', got '%v'", e, g) } }(i) } diff --git a/pkg/http/client.go b/pkg/http/client.go index 96dd0b5..e8383ca 100644 --- a/pkg/http/client.go +++ b/pkg/http/client.go @@ -7,11 +7,11 @@ import ( ) func (h *Handler) handleSDKClient(w http.ResponseWriter, r *http.Request) { - serveFile(w, r, &sdk.FS, "client/dist/client.js") + ServeFile(w, r, &sdk.FS, "client/dist/client.js") } func (h *Handler) handleSDKClientMap(w http.ResponseWriter, r *http.Request) { - serveFile(w, r, &sdk.FS, "client/dist/client.js.map") + ServeFile(w, r, &sdk.FS, "client/dist/client.js.map") } func (h *Handler) handleAppFiles(w http.ResponseWriter, r *http.Request) { diff --git a/pkg/http/context.go b/pkg/http/context.go new file mode 100644 index 0000000..e492dd3 --- /dev/null +++ b/pkg/http/context.go @@ -0,0 +1,55 @@ +package http + +import ( + "context" + + "net/http" + + "forge.cadoles.com/arcad/edge/pkg/bus" + "github.com/pkg/errors" +) + +type contextKey string + +var ( + contextKeyBus contextKey = "bus" + contextKeyHTTPRequest contextKey = "httpRequest" + contextKeyHTTPClient contextKey = "httpClient" +) + +func (h *Handler) contextMiddleware(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + ctx = context.WithValue(ctx, contextKeyBus, h.bus) + ctx = context.WithValue(ctx, contextKeyHTTPRequest, r) + ctx = context.WithValue(ctx, contextKeyHTTPClient, h.httpClient) + + r = r.WithContext(ctx) + + next.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) +} + +func ContextBus(ctx context.Context) bus.Bus { + return contextValue[bus.Bus](ctx, contextKeyBus) +} + +func ContextHTTPRequest(ctx context.Context) *http.Request { + return contextValue[*http.Request](ctx, contextKeyHTTPRequest) +} + +func ContextHTTPClient(ctx context.Context) *http.Client { + return contextValue[*http.Client](ctx, contextKeyHTTPClient) +} + +func contextValue[T any](ctx context.Context, key any) T { + value, ok := ctx.Value(key).(T) + if !ok { + panic(errors.Errorf("could not find key '%v' on context", key)) + } + + return value +} diff --git a/pkg/http/envelope.go b/pkg/http/envelope.go new file mode 100644 index 0000000..4826377 --- /dev/null +++ b/pkg/http/envelope.go @@ -0,0 +1,30 @@ +package http + +import ( + "context" + + "forge.cadoles.com/arcad/edge/pkg/bus" +) + +var ( + AddressIncomingMessage bus.Address = "http/incoming-message" + AddressOutgoingMessage bus.Address = "http/outgoing-message" +) + +type IncomingMessage struct { + Context context.Context + Payload map[string]any +} + +func NewIncomingMessageEnvelope(ctx context.Context, payload map[string]any) bus.Envelope { + return bus.NewEnvelope(AddressIncomingMessage, &IncomingMessage{ctx, payload}) +} + +type OutgoingMessage struct { + SessionID string + Data any +} + +func NewOutgoingMessageEnvelope(sessionID string, data any) bus.Envelope { + return bus.NewEnvelope(AddressOutgoingMessage, &OutgoingMessage{sessionID, data}) +} diff --git a/pkg/http/handler.go b/pkg/http/handler.go index d486d75..c54a3f1 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -57,10 +57,6 @@ func (h *Handler) Load(ctx context.Context, bdle bundle.Bundle) error { server := app.NewServer(h.serverModuleFactories...) - if err := server.Load(serverMainScript, string(mainScript)); err != nil { - return errors.WithStack(err) - } - fs := bundle.NewFileSystem("public", bdle) public := HTML5Fileserver(fs) sockjs := sockjs.NewHandler(sockJSPathPrefix, h.sockjsOpts, h.handleSockJSSession) @@ -69,7 +65,7 @@ func (h *Handler) Load(ctx context.Context, bdle bundle.Bundle) error { h.server.Stop() } - if err := server.Start(ctx); err != nil { + if err := server.Start(ctx, serverMainScript, string(mainScript)); err != nil { return errors.WithStack(err) } @@ -108,12 +104,7 @@ func NewHandler(funcs ...HandlerOptionFunc) *Handler { r.Get("/client.js.map", handler.handleSDKClientMap) }) - r.Route("/api", func(r chi.Router) { - r.Post("/v1/upload", handler.handleAppUpload) - r.Get("/v1/download/{bucket}/{blobID}", handler.handleAppDownload) - - r.Get("/v1/fetch", handler.handleAppFetch) - }) + r.Use(handler.contextMiddleware) for _, fn := range opts.HTTPMounts { r.Group(func(r chi.Router) { diff --git a/pkg/http/sockjs.go b/pkg/http/sockjs.go index 57020e9..ddae58c 100644 --- a/pkg/http/sockjs.go +++ b/pkg/http/sockjs.go @@ -42,19 +42,18 @@ func (h *Handler) handleSockJSSession(sess sockjs.Session) { } }() - go h.handleServerMessages(ctx, sess) - h.handleClientMessages(ctx, sess) + go h.handleOutgoingMessages(ctx, sess) + h.handleIncomingMessages(ctx, sess) } -func (h *Handler) handleServerMessages(ctx context.Context, sess sockjs.Session) { - messages, err := h.bus.Subscribe(ctx, module.MessageNamespaceServer) +func (h *Handler) handleOutgoingMessages(ctx context.Context, sess sockjs.Session) { + envelopes, err := h.bus.Subscribe(ctx, AddressOutgoingMessage) if err != nil { panic(errors.WithStack(err)) } defer func() { - // Close messages subscriber - h.bus.Unsubscribe(ctx, module.MessageNamespaceServer, messages) + h.bus.Unsubscribe(AddressOutgoingMessage, envelopes) logger.Debug(ctx, "unsubscribed") @@ -72,26 +71,22 @@ func (h *Handler) handleServerMessages(ctx context.Context, sess sockjs.Session) case <-ctx.Done(): return - case msg := <-messages: - serverMessage, ok := msg.(*module.ServerMessage) + case env := <-envelopes: + outgoingMessage, ok := env.Message().(*OutgoingMessage) if !ok { logger.Error( ctx, - "unexpected server message", - logger.F("message", msg), + "unexpected outgoing message", + logger.F("message", env.Message()), ) - - continue } - sessionID := module.ContextValue[string](serverMessage.Context, ContextKeySessionID) - - isDest := sessionID == "" || sessionID == sess.ID() + isDest := outgoingMessage.SessionID == "" || outgoingMessage.SessionID == sess.ID() if !isDest { continue } - payload, err := json.Marshal(serverMessage.Data) + payload, err := json.Marshal(outgoingMessage.Data) if err != nil { logger.Error( ctx, @@ -132,7 +127,7 @@ func (h *Handler) handleServerMessages(ctx context.Context, sess sockjs.Session) } } -func (h *Handler) handleClientMessages(ctx context.Context, sess sockjs.Session) { +func (h *Handler) handleIncomingMessages(ctx context.Context, sess sockjs.Session) { for { select { case <-ctx.Done(): @@ -174,7 +169,7 @@ func (h *Handler) handleClientMessages(ctx context.Context, sess sockjs.Session) switch { case message.Type == WebsocketMessageTypeMessage: - var payload map[string]interface{} + var payload map[string]any if err := json.Unmarshal(message.Payload, &payload); err != nil { logger.Error( ctx, @@ -191,21 +186,19 @@ func (h *Handler) handleClientMessages(ctx context.Context, sess sockjs.Session) ContextKeyOriginRequest: sess.Request(), }) - clientMessage := module.NewClientMessage(ctx, payload) + incomingMessage := NewIncomingMessageEnvelope(ctx, payload) - logger.Debug(ctx, "publishing new client message", logger.F("message", clientMessage)) + logger.Debug(ctx, "publishing new incoming message", logger.F("message", incomingMessage)) - if err := h.bus.Publish(ctx, clientMessage); err != nil { + if err := h.bus.Publish(incomingMessage); err != nil { logger.Error(ctx, "could not publish message", logger.CapturedE(errors.WithStack(err)), - logger.F("message", clientMessage), + logger.F("message", incomingMessage), ) return } - logger.Debug(ctx, "new client message published", logger.F("message", clientMessage)) - default: logger.Error( ctx, diff --git a/pkg/http/util.go b/pkg/http/util.go new file mode 100644 index 0000000..4a7b328 --- /dev/null +++ b/pkg/http/util.go @@ -0,0 +1,82 @@ +package http + +import ( + "encoding/json" + "io" + "io/fs" + "net/http" + "os" + + "github.com/pkg/errors" + "gitlab.com/wpetit/goweb/logger" +) + +const ( + ErrCodeForbidden = "forbidden" + ErrCodeInternalError = "internal-error" + ErrCodeBadRequest = "bad-request" + ErrCodeNotFound = "not-found" +) + +type jsonErrorResponse struct { + Error jsonErr `json:"error"` +} + +type jsonErr struct { + Code string `json:"code"` +} + +func JSONError(w http.ResponseWriter, status int, code string) { + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(status) + + encoder := json.NewEncoder(w) + response := jsonErrorResponse{ + Error: jsonErr{ + Code: code, + }, + } + + if err := encoder.Encode(response); err != nil { + panic(errors.WithStack(err)) + } +} + +func ServeFile(w http.ResponseWriter, r *http.Request, fs fs.FS, path string) { + ctx := logger.With(r.Context(), logger.F("path", path)) + + file, err := fs.Open(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) + + return + } + + logger.Error(ctx, "error while opening fs file", logger.CapturedE(errors.WithStack(err))) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + + return + } + + defer func() { + if err := file.Close(); err != nil { + logger.Error(ctx, "error while closing fs file", logger.CapturedE(errors.WithStack(err))) + } + }() + + info, err := file.Stat() + if err != nil { + logger.Error(ctx, "error while retrieving fs file stat", logger.CapturedE(errors.WithStack(err))) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + + return + } + + reader, ok := file.(io.ReadSeeker) + if !ok { + return + } + + http.ServeContent(w, r, path, info.ModTime(), reader) +} diff --git a/pkg/module/app/memory/module_test.go b/pkg/module/app/memory/module_test.go index 100a1c8..064f26a 100644 --- a/pkg/module/app/memory/module_test.go +++ b/pkg/module/app/memory/module_test.go @@ -39,21 +39,17 @@ func TestAppModuleWithMemoryRepository(t *testing.T) { )), ) - file := "testdata/app.js" + script := "testdata/app.js" - data, err := os.ReadFile(file) + data, err := os.ReadFile(script) if err != nil { t.Fatal(err) } - if err := server.Load(file, string(data)); err != nil { - t.Fatal(err) + ctx := context.Background() + if err := server.Start(ctx, script, string(data)); err != nil { + t.Fatalf("%+v", errors.WithStack(err)) } defer server.Stop() - - ctx := context.Background() - if err := server.Start(ctx); err != nil { - t.Fatalf("%+v", errors.WithStack(err)) - } } diff --git a/pkg/module/auth/module_test.go b/pkg/module/auth/module_test.go index e127add..0869a04 100644 --- a/pkg/module/auth/module_test.go +++ b/pkg/module/auth/module_test.go @@ -2,8 +2,8 @@ package auth import ( "context" - "io/ioutil" "net/http" + "os" "testing" "time" @@ -33,17 +33,15 @@ func TestAuthModule(t *testing.T) { ), ) - data, err := ioutil.ReadFile("testdata/auth.js") + script := "testdata/auth.js" + + data, err := os.ReadFile(script) if err != nil { t.Fatal(err) } - if err := server.Load("testdata/auth.js", string(data)); err != nil { - t.Fatal(err) - } - ctx := context.Background() - if err := server.Start(ctx); err != nil { + if err := server.Start(ctx, script, string(data)); err != nil { t.Fatalf("%+v", errors.WithStack(err)) } @@ -90,17 +88,15 @@ func TestAuthAnonymousModule(t *testing.T) { ModuleFactory(WithJWT(getDummyKeySet(key))), ) - data, err := ioutil.ReadFile("testdata/auth_anonymous.js") + script := "testdata/auth_anonymous.js" + + data, err := os.ReadFile("testdata/auth_anonymous.js") if err != nil { t.Fatal(err) } - if err := server.Load("testdata/auth_anonymous.js", string(data)); err != nil { - t.Fatal(err) - } - ctx := context.Background() - if err := server.Start(ctx); err != nil { + if err := server.Start(ctx, script, string(data)); err != nil { t.Fatalf("%+v", errors.WithStack(err)) } diff --git a/pkg/module/blob/blob_message.go b/pkg/module/blob/envelope.go similarity index 52% rename from pkg/module/blob/blob_message.go rename to pkg/module/blob/envelope.go index d355a90..b9ba465 100644 --- a/pkg/module/blob/blob_message.go +++ b/pkg/module/blob/envelope.go @@ -11,50 +11,37 @@ import ( ) const ( - MessageNamespaceUploadRequest bus.MessageNamespace = "uploadRequest" - MessageNamespaceUploadResponse bus.MessageNamespace = "uploadResponse" - MessageNamespaceDownloadRequest bus.MessageNamespace = "downloadRequest" - MessageNamespaceDownloadResponse bus.MessageNamespace = "downloadResponse" + AddressUploadRequest bus.Address = "module/blob/uploadRequest" + AddressUploadResponse bus.Address = "module/blob/uploadResponse" + AddressDownloadRequest bus.Address = "module/blob/downloadRequest" + AddressDownloadResponse bus.Address = "module/blob/downloadResponse" ) -type MessageUploadRequest struct { +type UploadRequest struct { Context context.Context - RequestID string FileHeader *multipart.FileHeader Metadata map[string]interface{} } -func (m *MessageUploadRequest) MessageNamespace() bus.MessageNamespace { - return MessageNamespaceUploadRequest -} - -func NewMessageUploadRequest(ctx context.Context, fileHeader *multipart.FileHeader, metadata map[string]interface{}) *MessageUploadRequest { - return &MessageUploadRequest{ +func NewUploadRequestEnvelope(ctx context.Context, fileHeader *multipart.FileHeader, metadata map[string]interface{}) bus.Envelope { + return bus.NewEnvelope(AddressUploadRequest, &UploadRequest{ Context: ctx, - RequestID: ulid.Make().String(), FileHeader: fileHeader, Metadata: metadata, - } + }) } -type MessageUploadResponse struct { - RequestID string - BlobID storage.BlobID - Bucket string - Allow bool +type UploadResponse struct { + Allow bool } -func (m *MessageUploadResponse) MessageNamespace() bus.MessageNamespace { - return MessageNamespaceDownloadResponse +func NewUploadResponseEnvelope(allow bool) bus.Envelope { + return bus.NewEnvelope(AddressUploadResponse, &UploadResponse{ + Allow: allow, + }) } -func NewMessageUploadResponse(requestID string) *MessageUploadResponse { - return &MessageUploadResponse{ - RequestID: requestID, - } -} - -type MessageDownloadRequest struct { +type DownloadRequest struct { Context context.Context RequestID string Bucket string diff --git a/pkg/http/blob.go b/pkg/module/blob/http.go similarity index 53% rename from pkg/http/blob.go rename to pkg/module/blob/http.go index d549393..20fe3a2 100644 --- a/pkg/http/blob.go +++ b/pkg/module/blob/http.go @@ -1,8 +1,7 @@ -package http +package blob import ( "encoding/json" - "io" "io/fs" "mime/multipart" "net/http" @@ -10,37 +9,31 @@ import ( "time" "forge.cadoles.com/arcad/edge/pkg/bus" - "forge.cadoles.com/arcad/edge/pkg/module" - "forge.cadoles.com/arcad/edge/pkg/module/blob" + edgehttp "forge.cadoles.com/arcad/edge/pkg/http" "forge.cadoles.com/arcad/edge/pkg/storage" "github.com/go-chi/chi/v5" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" ) -const ( - errorCodeForbidden = "forbidden" - errorCodeInternalError = "internal-error" - errorCodeBadRequest = "bad-request" - errorCodeNotFound = "not-found" -) - type uploadResponse struct { Bucket string `json:"bucket"` BlobID storage.BlobID `json:"blobId"` } -func (h *Handler) handleAppUpload(w http.ResponseWriter, r *http.Request) { - h.mutex.RLock() - defer h.mutex.RUnlock() +func Mount(r chi.Router) { + r.Post("/api/v1/upload", handleAppUpload) + r.Get("/api/v1/download/{bucket}/{blobID}", handleAppDownload) +} +func handleAppUpload(w http.ResponseWriter, r *http.Request) { ctx := r.Context() r.Body = http.MaxBytesReader(w, r.Body, h.uploadMaxFileSize) if err := r.ParseMultipartForm(h.uploadMaxFileSize); err != nil { logger.Error(ctx, "could not parse multipart form", logger.CapturedE(errors.WithStack(err))) - jsonError(w, http.StatusBadRequest, errorCodeBadRequest) + edgehttp.JSONError(w, http.StatusBadRequest, edgehttp.ErrCodeBadRequest) return } @@ -48,7 +41,7 @@ func (h *Handler) handleAppUpload(w http.ResponseWriter, r *http.Request) { _, fileHeader, err := r.FormFile("file") if err != nil { logger.Error(ctx, "could not read form file", logger.CapturedE(errors.WithStack(err))) - jsonError(w, http.StatusBadRequest, errorCodeBadRequest) + edgehttp.JSONError(w, http.StatusBadRequest, edgehttp.ErrCodeBadRequest) return } @@ -59,41 +52,39 @@ func (h *Handler) handleAppUpload(w http.ResponseWriter, r *http.Request) { if rawMetadata != "" { if err := json.Unmarshal([]byte(rawMetadata), &metadata); err != nil { logger.Error(ctx, "could not parse metadata", logger.CapturedE(errors.WithStack(err))) - jsonError(w, http.StatusBadRequest, errorCodeBadRequest) + edgehttp.JSONError(w, http.StatusBadRequest, edgehttp.ErrCodeBadRequest) return } } - ctx = module.WithContext(ctx, map[module.ContextKey]any{ - ContextKeyOriginRequest: r, - }) + requestMsg := NewMessageUploadRequest(ctx, fileHeader, metadata) - requestMsg := blob.NewMessageUploadRequest(ctx, fileHeader, metadata) + bus := edgehttp.ContextBus(ctx) - reply, err := h.bus.Request(ctx, requestMsg) + reply, err := bus.Request(ctx, requestMsg) if err != nil { logger.Error(ctx, "could not retrieve file", logger.CapturedE(errors.WithStack(err))) - jsonError(w, http.StatusInternalServerError, errorCodeInternalError) + edgehttp.JSONError(w, http.StatusInternalServerError, edgehttp.ErrCodeInternalError) return } logger.Debug(ctx, "upload reply", logger.F("reply", reply)) - responseMsg, ok := reply.(*blob.MessageUploadResponse) + responseMsg, ok := reply.(*MessageUploadResponse) if !ok { logger.Error( ctx, "unexpected upload response message", logger.F("message", reply), ) - jsonError(w, http.StatusInternalServerError, errorCodeInternalError) + edgehttp.JSONError(w, http.StatusInternalServerError, edgehttp.ErrCodeInternalError) return } if !responseMsg.Allow { - jsonError(w, http.StatusForbidden, errorCodeForbidden) + edgehttp.JSONError(w, http.StatusForbidden, edgehttp.ErrCodeForbidden) return } @@ -109,21 +100,17 @@ func (h *Handler) handleAppUpload(w http.ResponseWriter, r *http.Request) { } } -func (h *Handler) handleAppDownload(w http.ResponseWriter, r *http.Request) { - h.mutex.RLock() - defer h.mutex.RUnlock() - +func handleAppDownload(w http.ResponseWriter, r *http.Request) { bucket := chi.URLParam(r, "bucket") blobID := chi.URLParam(r, "blobID") ctx := logger.With(r.Context(), logger.F("blobID", blobID), logger.F("bucket", bucket)) - ctx = module.WithContext(ctx, map[module.ContextKey]any{ - ContextKeyOriginRequest: r, - }) - requestMsg := blob.NewMessageDownloadRequest(ctx, bucket, storage.BlobID(blobID)) + requestMsg := NewMessageDownloadRequest(ctx, bucket, storage.BlobID(blobID)) - reply, err := h.bus.Request(ctx, requestMsg) + bs := edgehttp.ContextBus(ctx) + + reply, err := bs.Request(ctx, requestMsg) if err != nil { logger.Error(ctx, "could not retrieve file", logger.CapturedE(errors.WithStack(err))) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) @@ -131,26 +118,26 @@ func (h *Handler) handleAppDownload(w http.ResponseWriter, r *http.Request) { return } - replyMsg, ok := reply.(*blob.MessageDownloadResponse) + replyMsg, ok := reply.(*MessageDownloadResponse) if !ok { logger.Error( ctx, "unexpected download response message", logger.CapturedE(errors.WithStack(bus.ErrUnexpectedMessage)), logger.F("message", reply), ) - jsonError(w, http.StatusInternalServerError, errorCodeInternalError) + edgehttp.JSONError(w, http.StatusInternalServerError, edgehttp.ErrCodeInternalError) return } if !replyMsg.Allow { - jsonError(w, http.StatusForbidden, errorCodeForbidden) + edgehttp.JSONError(w, http.StatusForbidden, edgehttp.ErrCodeForbidden) return } if replyMsg.Blob == nil { - jsonError(w, http.StatusNotFound, errorCodeNotFound) + edgehttp.JSONError(w, http.StatusNotFound, edgehttp.ErrCodeNotFound) return } @@ -164,69 +151,6 @@ func (h *Handler) handleAppDownload(w http.ResponseWriter, r *http.Request) { http.ServeContent(w, r, string(replyMsg.BlobInfo.ID()), replyMsg.BlobInfo.ModTime(), replyMsg.Blob) } -func serveFile(w http.ResponseWriter, r *http.Request, fs fs.FS, path string) { - ctx := logger.With(r.Context(), logger.F("path", path)) - - file, err := fs.Open(path) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) - - return - } - - logger.Error(ctx, "error while opening fs file", logger.CapturedE(errors.WithStack(err))) - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - - return - } - - defer func() { - if err := file.Close(); err != nil { - logger.Error(ctx, "error while closing fs file", logger.CapturedE(errors.WithStack(err))) - } - }() - - info, err := file.Stat() - if err != nil { - logger.Error(ctx, "error while retrieving fs file stat", logger.CapturedE(errors.WithStack(err))) - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - - return - } - - reader, ok := file.(io.ReadSeeker) - if !ok { - return - } - - http.ServeContent(w, r, path, info.ModTime(), reader) -} - -type jsonErrorResponse struct { - Error jsonErr `json:"error"` -} - -type jsonErr struct { - Code string `json:"code"` -} - -func jsonError(w http.ResponseWriter, status int, code string) { - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(status) - - encoder := json.NewEncoder(w) - response := jsonErrorResponse{ - Error: jsonErr{ - Code: code, - }, - } - - if err := encoder.Encode(response); err != nil { - panic(errors.WithStack(err)) - } -} - type uploadedFile struct { multipart.File header *multipart.FileHeader diff --git a/pkg/module/blob/module_test.go b/pkg/module/blob/module_test.go index 7c89bb3..20a9c03 100644 --- a/pkg/module/blob/module_test.go +++ b/pkg/module/blob/module_test.go @@ -28,19 +28,17 @@ func TestBlobModule(t *testing.T) { ModuleFactory(bus, store), ) - data, err := os.ReadFile("testdata/blob.js") + script := "testdata/blob.js" + + data, err := os.ReadFile(script) if err != nil { t.Fatal(err) } - if err := server.Load("testdata/blob.js", string(data)); err != nil { - t.Fatal(err) + ctx := context.Background() + if err := server.Start(ctx, script, string(data)); err != nil { + t.Fatalf("%+v", errors.WithStack(err)) } defer server.Stop() - - ctx := context.Background() - if err := server.Start(ctx); err != nil { - t.Fatalf("%+v", errors.WithStack(err)) - } } diff --git a/pkg/module/cast/module_test.go b/pkg/module/cast/module_test.go index 37c298f..d38b970 100644 --- a/pkg/module/cast/module_test.go +++ b/pkg/module/cast/module_test.go @@ -2,7 +2,6 @@ package cast import ( "context" - "io/ioutil" "os" "testing" "time" @@ -31,17 +30,15 @@ func TestCastModule(t *testing.T) { CastModuleFactory(), ) - data, err := ioutil.ReadFile("testdata/cast.js") + script := "testdata/cast.js" + + data, err := os.ReadFile(script) if err != nil { t.Fatal(err) } - if err := server.Load("testdata/cast.js", string(data)); err != nil { - t.Fatal(err) - } - ctx := context.Background() - if err := server.Start(ctx); err != nil { + if err := server.Start(ctx, script, string(data)); err != nil { t.Fatalf("%+v", errors.WithStack(err)) } @@ -66,17 +63,15 @@ func TestCastModuleRefreshDevices(t *testing.T) { CastModuleFactory(), ) - data, err := ioutil.ReadFile("testdata/refresh_devices.js") + script := "testdata/refresh_devices.js" + + data, err := os.ReadFile(script) if err != nil { t.Fatal(err) } - if err := server.Load("testdata/refresh_devices.js", string(data)); err != nil { - t.Fatal(err) - } - ctx := context.Background() - if err := server.Start(ctx); err != nil { + if err := server.Start(ctx, script, string(data)); err != nil { t.Fatalf("%+v", errors.WithStack(err)) } diff --git a/pkg/module/fetch/envelope.go b/pkg/module/fetch/envelope.go new file mode 100644 index 0000000..2af30f7 --- /dev/null +++ b/pkg/module/fetch/envelope.go @@ -0,0 +1,38 @@ +package fetch + +import ( + "context" + "net/url" + + "forge.cadoles.com/arcad/edge/pkg/bus" +) + +const ( + AddressFetchRequest bus.Address = "module/fetch/request" + AddressFetchResponse bus.Address = "module/fetch/response" +) + +type FetchRequest struct { + Context context.Context + RequestID string + URL *url.URL + RemoteAddr string +} + +func NewFetchRequestEnvelope(ctx context.Context, remoteAddr string, url *url.URL) bus.Envelope { + return bus.NewEnvelope(AddressFetchRequest, &FetchRequest{ + Context: ctx, + URL: url, + RemoteAddr: remoteAddr, + }) +} + +type FetchResponse struct { + Allow bool +} + +func NewFetchResponseEnvelope(allow bool) bus.Envelope { + return bus.NewEnvelope(AddressFetchResponse, &FetchResponse{ + Allow: allow, + }) +} diff --git a/pkg/module/fetch/fetch_message.go b/pkg/module/fetch/fetch_message.go deleted file mode 100644 index f2493e1..0000000 --- a/pkg/module/fetch/fetch_message.go +++ /dev/null @@ -1,49 +0,0 @@ -package fetch - -import ( - "context" - "net/url" - - "forge.cadoles.com/arcad/edge/pkg/bus" - "github.com/oklog/ulid/v2" -) - -const ( - MessageNamespaceFetchRequest bus.MessageNamespace = "fetchRequest" - MessageNamespaceFetchResponse bus.MessageNamespace = "fetchResponse" -) - -type MessageFetchRequest struct { - Context context.Context - RequestID string - URL *url.URL - RemoteAddr string -} - -func (m *MessageFetchRequest) MessageNamespace() bus.MessageNamespace { - return MessageNamespaceFetchRequest -} - -func NewMessageFetchRequest(ctx context.Context, remoteAddr string, url *url.URL) *MessageFetchRequest { - return &MessageFetchRequest{ - Context: ctx, - RequestID: ulid.Make().String(), - RemoteAddr: remoteAddr, - URL: url, - } -} - -type MessageFetchResponse struct { - RequestID string - Allow bool -} - -func (m *MessageFetchResponse) MessageNamespace() bus.MessageNamespace { - return MessageNamespaceFetchResponse -} - -func NewMessageFetchResponse(requestID string) *MessageFetchResponse { - return &MessageFetchResponse{ - RequestID: requestID, - } -} diff --git a/pkg/http/fetch.go b/pkg/module/fetch/http.go similarity index 60% rename from pkg/http/fetch.go rename to pkg/module/fetch/http.go index 87c71e4..be8e946 100644 --- a/pkg/http/fetch.go +++ b/pkg/module/fetch/http.go @@ -1,60 +1,59 @@ -package http +package fetch import ( "io" "net/http" "net/url" - "forge.cadoles.com/arcad/edge/pkg/module" - "forge.cadoles.com/arcad/edge/pkg/module/fetch" + edgehttp "forge.cadoles.com/arcad/edge/pkg/http" + "github.com/go-chi/chi/v5" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" ) -func (h *Handler) handleAppFetch(w http.ResponseWriter, r *http.Request) { - h.mutex.RLock() - defer h.mutex.RUnlock() +func Mount(r chi.Router) { + r.Get("/api/v1/fetch", handleAppFetch) +} +func handleAppFetch(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - ctx = module.WithContext(ctx, map[module.ContextKey]any{ - ContextKeyOriginRequest: r, - }) - rawURL := r.URL.Query().Get("url") url, err := url.Parse(rawURL) if err != nil { - jsonError(w, http.StatusBadRequest, errorCodeBadRequest) + edgehttp.JSONError(w, http.StatusBadRequest, edgehttp.ErrCodeBadRequest) return } - requestMsg := fetch.NewMessageFetchRequest(ctx, r.RemoteAddr, url) + requestMsg := NewFetchRequestEnvelope(ctx, r.RemoteAddr, url) - reply, err := h.bus.Request(ctx, requestMsg) + bus := edgehttp.ContextBus(ctx) + + reply, err := bus.Request(ctx, requestMsg) if err != nil { logger.Error(ctx, "could not retrieve fetch request reply", logger.CapturedE(errors.WithStack(err))) - jsonError(w, http.StatusInternalServerError, errorCodeInternalError) + edgehttp.JSONError(w, http.StatusInternalServerError, edgehttp.ErrCodeInternalError) return } logger.Debug(ctx, "fetch reply", logger.F("reply", reply)) - responseMsg, ok := reply.(*fetch.MessageFetchResponse) + responseMsg, ok := reply.Message().(*FetchResponse) if !ok { logger.Error( ctx, "unexpected fetch response message", logger.F("message", reply), ) - jsonError(w, http.StatusInternalServerError, errorCodeInternalError) + edgehttp.JSONError(w, http.StatusInternalServerError, edgehttp.ErrCodeInternalError) return } if !responseMsg.Allow { - jsonError(w, http.StatusForbidden, errorCodeForbidden) + edgehttp.JSONError(w, http.StatusForbidden, edgehttp.ErrCodeForbidden) return } @@ -65,7 +64,7 @@ func (h *Handler) handleAppFetch(w http.ResponseWriter, r *http.Request) { ctx, "could not create proxy request", logger.CapturedE(errors.WithStack(err)), ) - jsonError(w, http.StatusInternalServerError, errorCodeInternalError) + edgehttp.JSONError(w, http.StatusInternalServerError, edgehttp.ErrCodeInternalError) return } @@ -78,13 +77,15 @@ func (h *Handler) handleAppFetch(w http.ResponseWriter, r *http.Request) { proxyReq.Header.Add("X-Forwarded-From", r.RemoteAddr) - res, err := h.httpClient.Do(proxyReq) + httpClient := edgehttp.ContextHTTPClient(ctx) + + res, err := httpClient.Do(proxyReq) if err != nil { logger.Error( ctx, "could not execute proxy request", logger.CapturedE(errors.WithStack(err)), ) - jsonError(w, http.StatusInternalServerError, errorCodeInternalError) + edgehttp.JSONError(w, http.StatusInternalServerError, edgehttp.ErrCodeInternalError) return } diff --git a/pkg/module/fetch/module.go b/pkg/module/fetch/module.go index fdc3930..4bf2797 100644 --- a/pkg/module/fetch/module.go +++ b/pkg/module/fetch/module.go @@ -40,10 +40,10 @@ func (m *Module) get(call goja.FunctionCall, rt *goja.Runtime) goja.Value { func (m *Module) handleMessages() { ctx := context.Background() - err := m.bus.Reply(ctx, MessageNamespaceFetchRequest, func(msg bus.Message) (bus.Message, error) { - fetchRequest, ok := msg.(*MessageFetchRequest) + err := m.bus.Reply(ctx, AddressFetchRequest, func(env bus.Envelope) (any, error) { + fetchRequest, ok := env.Message().(*FetchRequest) if !ok { - return nil, errors.Wrapf(bus.ErrUnexpectedMessage, "expected message fetch request, got '%T'", msg) + return nil, errors.Wrapf(bus.ErrUnexpectedMessage, "expected fetch request, got '%T'", env.Message()) } res, err := m.handleFetchRequest(fetchRequest) @@ -62,8 +62,8 @@ func (m *Module) handleMessages() { } } -func (m *Module) handleFetchRequest(req *MessageFetchRequest) (*MessageFetchResponse, error) { - res := NewMessageFetchResponse(req.RequestID) +func (m *Module) handleFetchRequest(req *FetchRequest) (*FetchResponse, error) { + res := &FetchResponse{} ctx := logger.With( req.Context, diff --git a/pkg/module/fetch/module_test.go b/pkg/module/fetch/module_test.go index a9e8c49..19cb6af 100644 --- a/pkg/module/fetch/module_test.go +++ b/pkg/module/fetch/module_test.go @@ -2,8 +2,8 @@ package fetch import ( "context" - "io/ioutil" "net/url" + "os" "testing" "time" @@ -28,22 +28,20 @@ func TestFetchModule(t *testing.T) { ModuleFactory(bus), ) - data, err := ioutil.ReadFile("testdata/fetch.js") + path := "testdata/fetch.js" + + data, err := os.ReadFile(path) if err != nil { t.Fatalf("%+v", errors.WithStack(err)) } - if err := server.Load("testdata/fetch.js", string(data)); err != nil { + ctx := context.Background() + if err := server.Start(ctx, path, string(data)); err != nil { t.Fatalf("%+v", errors.WithStack(err)) } defer server.Stop() - ctx := context.Background() - if err := server.Start(ctx); err != nil { - t.Fatalf("%+v", errors.WithStack(err)) - } - // Wait for module to startup time.Sleep(1 * time.Second) @@ -53,33 +51,33 @@ func TestFetchModule(t *testing.T) { remoteAddr := "127.0.0.1" url, _ := url.Parse("http://example.com") - rawReply, err := bus.Request(ctx, NewMessageFetchRequest(ctx, remoteAddr, url)) + reply, err := bus.Request(ctx, NewFetchRequestEnvelope(ctx, remoteAddr, url)) if err != nil { t.Fatalf("%+v", errors.WithStack(err)) } - reply, ok := rawReply.(*MessageFetchResponse) + response, ok := reply.Message().(*FetchResponse) if !ok { - t.Fatalf("unexpected reply type '%T'", rawReply) + t.Fatalf("unexpected reply message type '%T'", reply.Message()) } - if e, g := true, reply.Allow; e != g { + if e, g := true, response.Allow; e != g { t.Errorf("reply.Allow: expected '%v', got '%v'", e, g) } url, _ = url.Parse("https://google.com") - rawReply, err = bus.Request(ctx, NewMessageFetchRequest(ctx, remoteAddr, url)) + reply, err = bus.Request(ctx, NewFetchRequestEnvelope(ctx, remoteAddr, url)) if err != nil { t.Fatalf("%+v", errors.WithStack(err)) } - reply, ok = rawReply.(*MessageFetchResponse) + response, ok = reply.Message().(*FetchResponse) if !ok { - t.Fatalf("unexpected reply type '%T'", rawReply) + t.Fatalf("unexpected reply message type '%T'", reply.Message()) } - if e, g := false, reply.Allow; e != g { + if e, g := false, response.Allow; e != g { t.Errorf("reply.Allow: expected '%v', got '%v'", e, g) } } diff --git a/pkg/module/lifecycle.go b/pkg/module/lifecycle.go index 983617f..af98219 100644 --- a/pkg/module/lifecycle.go +++ b/pkg/module/lifecycle.go @@ -5,7 +5,6 @@ import ( "forge.cadoles.com/arcad/edge/pkg/app" "github.com/dop251/goja" - "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" ) @@ -19,16 +18,28 @@ func (m *LifecycleModule) Export(export *goja.Object) { } func (m *LifecycleModule) OnInit(ctx context.Context, rt *goja.Runtime) (err error) { - _, ok := goja.AssertFunction(rt.Get("onInit")) + call, ok := goja.AssertFunction(rt.Get("onInit")) if !ok { logger.Warn(ctx, "could not find onInit() function") return nil } - if _, err := rt.RunString("setTimeout(onInit, 0)"); err != nil { - return errors.WithStack(err) - } + defer func() { + recovered := recover() + if recovered == nil { + return + } + + recoveredErr, ok := recovered.(error) + if !ok { + panic(recovered) + } + + err = recoveredErr + }() + + call(nil, rt.ToValue(ctx)) return nil } diff --git a/pkg/module/message.go b/pkg/module/message.go deleted file mode 100644 index 728d03d..0000000 --- a/pkg/module/message.go +++ /dev/null @@ -1,38 +0,0 @@ -package module - -import ( - "context" - - "forge.cadoles.com/arcad/edge/pkg/bus" -) - -const ( - MessageNamespaceClient bus.MessageNamespace = "client" - MessageNamespaceServer bus.MessageNamespace = "server" -) - -type ServerMessage struct { - Context context.Context - Data interface{} -} - -func (m *ServerMessage) MessageNamespace() bus.MessageNamespace { - return MessageNamespaceServer -} - -func NewServerMessage(ctx context.Context, data interface{}) *ServerMessage { - return &ServerMessage{ctx, data} -} - -type ClientMessage struct { - Context context.Context - Data map[string]interface{} -} - -func (m *ClientMessage) MessageNamespace() bus.MessageNamespace { - return MessageNamespaceClient -} - -func NewClientMessage(ctx context.Context, data map[string]interface{}) *ClientMessage { - return &ClientMessage{ctx, data} -} diff --git a/pkg/module/net/envelope.go b/pkg/module/net/envelope.go new file mode 100644 index 0000000..2a7c657 --- /dev/null +++ b/pkg/module/net/envelope.go @@ -0,0 +1,38 @@ +package net + +import ( + "context" + + "forge.cadoles.com/arcad/edge/pkg/bus" +) + +const ( + AddressIncoming bus.Address = "module/net/incoming" + AddressOutgoing bus.Address = "module/net/outgoing" +) + +type IncomingMessage struct { + Context context.Context + Data any +} + +func NewIncomingMessageEnvelope(ctx context.Context, data any) bus.Envelope { + return bus.NewEnvelope(AddressIncoming, &IncomingMessage{ctx, data}) +} + +type OutgoingBroadcastMessage struct { + Data any +} + +func NewOutgoingBroadcastMessageEnvelope(data any) bus.Envelope { + return bus.NewEnvelope(AddressIncoming, &OutgoingBroadcastMessage{data}) +} + +type OutgoingMessage struct { + Context context.Context + Data any +} + +func NewOutgoingMessageEnvelope(ctx context.Context, data any) bus.Envelope { + return bus.NewEnvelope(AddressOutgoing, &OutgoingMessage{ctx, data}) +} diff --git a/pkg/module/net/module.go b/pkg/module/net/module.go index d26a6b3..98dcc40 100644 --- a/pkg/module/net/module.go +++ b/pkg/module/net/module.go @@ -38,10 +38,9 @@ func (m *Module) broadcast(call goja.FunctionCall, rt *goja.Runtime) goja.Value } data := call.Argument(0).Export() - ctx := context.Background() - msg := module.NewServerMessage(ctx, data) - if err := m.bus.Publish(ctx, msg); err != nil { + env := NewOutgoingBroadcastMessageEnvelope(data) + if err := m.bus.Publish(env); err != nil { panic(rt.ToValue(errors.WithStack(err))) } @@ -68,23 +67,23 @@ func (m *Module) send(call goja.FunctionCall, rt *goja.Runtime) goja.Value { data := call.Argument(1).Export() - msg := module.NewServerMessage(ctx, data) - if err := m.bus.Publish(ctx, msg); err != nil { + env := NewOutgoingMessageEnvelope(ctx, data) + if err := m.bus.Publish(env); err != nil { panic(rt.ToValue(errors.WithStack(err))) } return nil } -func (m *Module) handleClientMessages() { +func (m *Module) handleIncomingMessages() { ctx := context.Background() logger.Debug( ctx, - "subscribing to bus messages", + "subscribing to bus envelopes", ) - clientMessages, err := m.bus.Subscribe(ctx, module.MessageNamespaceClient) + envelopes, err := m.bus.Subscribe(ctx, AddressIncoming) if err != nil { panic(errors.WithStack(err)) } @@ -92,16 +91,16 @@ func (m *Module) handleClientMessages() { defer func() { logger.Debug( ctx, - "unsubscribing from bus messages", + "unsubscribing from bus envelopes", ) - m.bus.Unsubscribe(ctx, module.MessageNamespaceClient, clientMessages) + m.bus.Unsubscribe(AddressIncoming, envelopes) }() for { logger.Debug( ctx, - "waiting for next message", + "waiting for next envelope", ) select { case <-ctx.Done(): @@ -112,13 +111,13 @@ func (m *Module) handleClientMessages() { return - case msg := <-clientMessages: - clientMessage, ok := msg.(*module.ClientMessage) + case env := <-envelopes: + incomingMessage, ok := env.Message().(*IncomingMessage) if !ok { logger.Warn( ctx, "unexpected message type", - logger.F("message", msg), + logger.F("message", env.Message()), ) continue @@ -126,11 +125,11 @@ func (m *Module) handleClientMessages() { logger.Debug( ctx, - "received client message", - logger.F("message", clientMessage), + "received incoming message", + logger.F("message", incomingMessage), ) - if _, err := m.server.ExecFuncByName(clientMessage.Context, "onClientMessage", clientMessage.Context, clientMessage.Data); err != nil { + if _, err := m.server.ExecFuncByName(incomingMessage.Context, "onClientMessage", incomingMessage.Context, incomingMessage.Data); err != nil { if errors.Is(err, app.ErrFuncDoesNotExist) { continue } @@ -152,7 +151,7 @@ func ModuleFactory(bus bus.Bus) app.ServerModuleFactory { bus: bus, } - go module.handleClientMessages() + go module.handleIncomingMessages() return module } diff --git a/pkg/module/rpc.go b/pkg/module/rpc.go deleted file mode 100644 index 5877ffb..0000000 --- a/pkg/module/rpc.go +++ /dev/null @@ -1,278 +0,0 @@ -package module - -import ( - "context" - "fmt" - "sync" - - "forge.cadoles.com/arcad/edge/pkg/app" - "forge.cadoles.com/arcad/edge/pkg/bus" - "forge.cadoles.com/arcad/edge/pkg/module/util" - "github.com/dop251/goja" - "github.com/pkg/errors" - "gitlab.com/wpetit/goweb/logger" -) - -type RPCRequest struct { - Method string - Params interface{} - ID interface{} -} - -type RPCError struct { - Code int `json:"code"` - Message string `json:"message"` - Data interface{} `json:"data"` -} - -type RPCResponse struct { - Result interface{} - Error *RPCError - ID interface{} -} - -type RPCModule struct { - server *app.Server - bus bus.Bus - callbacks sync.Map -} - -func (m *RPCModule) Name() string { - return "rpc" -} - -func (m *RPCModule) Export(export *goja.Object) { - if err := export.Set("register", m.register); err != nil { - panic(errors.Wrap(err, "could not set 'register' function")) - } - - if err := export.Set("unregister", m.unregister); err != nil { - panic(errors.Wrap(err, "could not set 'unregister' function")) - } -} - -func (m *RPCModule) OnInit(ctx context.Context, rt *goja.Runtime) error { - go m.handleMessages(ctx) - - return nil -} - -func (m *RPCModule) register(call goja.FunctionCall, rt *goja.Runtime) goja.Value { - fnName := util.AssertString(call.Argument(0), rt) - - var ( - callable goja.Callable - ok bool - ) - - if len(call.Arguments) > 1 { - callable, ok = goja.AssertFunction(call.Argument(1)) - } else { - callable, ok = goja.AssertFunction(rt.Get(fnName)) - } - - if !ok { - panic(rt.NewTypeError("method should be a valid function")) - } - - ctx := context.Background() - - logger.Debug(ctx, "registering method", logger.F("method", fnName)) - - m.callbacks.Store(fnName, callable) - - return nil -} - -func (m *RPCModule) unregister(call goja.FunctionCall, rt *goja.Runtime) goja.Value { - fnName := util.AssertString(call.Argument(0), rt) - - m.callbacks.Delete(fnName) - - return nil -} - -func (m *RPCModule) handleMessages(ctx context.Context) { - clientMessages, err := m.bus.Subscribe(ctx, MessageNamespaceClient) - if err != nil { - panic(errors.WithStack(err)) - } - - defer func() { - m.bus.Unsubscribe(ctx, MessageNamespaceClient, clientMessages) - }() - - sendRes := func(ctx context.Context, req *RPCRequest, result goja.Value) { - res := &RPCResponse{ - ID: req.ID, - Result: result.Export(), - } - - logger.Debug(ctx, "sending rpc response", logger.F("response", res)) - - if err := m.sendResponse(ctx, res); err != nil { - logger.Error( - ctx, "could not send response", - logger.CapturedE(errors.WithStack(err)), - logger.F("response", res), - logger.F("request", req), - ) - } - } - - for msg := range clientMessages { - go m.handleMessage(ctx, msg, sendRes) - } -} - -func (m *RPCModule) handleMessage(ctx context.Context, msg bus.Message, sendRes func(ctx context.Context, req *RPCRequest, result goja.Value)) { - clientMessage, ok := msg.(*ClientMessage) - if !ok { - logger.Warn(ctx, "unexpected bus message", logger.F("message", msg)) - - return - } - - ok, req := m.isRPCRequest(clientMessage) - if !ok { - return - } - - logger.Debug(ctx, "received rpc request", logger.F("request", req)) - - rawCallable, exists := m.callbacks.Load(req.Method) - if !exists { - logger.Debug(ctx, "method not found", logger.F("req", req)) - - if err := m.sendMethodNotFoundResponse(clientMessage.Context, req); err != nil { - logger.Error( - ctx, "could not send method not found response", - logger.CapturedE(errors.WithStack(err)), - logger.F("request", req), - ) - } - - return - } - - callable, ok := rawCallable.(goja.Callable) - if !ok { - logger.Debug(ctx, "invalid method", logger.F("req", req)) - - if err := m.sendMethodNotFoundResponse(clientMessage.Context, req); err != nil { - logger.Error( - ctx, "could not send method not found response", - logger.CapturedE(errors.WithStack(err)), - logger.F("request", req), - ) - } - - return - } - - result, err := m.server.Exec(clientMessage.Context, callable, clientMessage.Context, req.Params) - if err != nil { - logger.Error( - ctx, "rpc call error", - logger.CapturedE(errors.WithStack(err)), - logger.F("request", req), - ) - - if err := m.sendErrorResponse(clientMessage.Context, req, err); err != nil { - logger.Error( - ctx, "could not send error response", - logger.CapturedE(errors.WithStack(err)), - logger.F("originalError", err), - logger.F("request", req), - ) - } - - return - } - - promise, ok := app.IsPromise(result) - if ok { - go func(ctx context.Context, req *RPCRequest, promise *goja.Promise) { - result := m.server.WaitForPromise(promise) - sendRes(ctx, req, result) - }(clientMessage.Context, req, promise) - } else { - sendRes(clientMessage.Context, req, result) - } -} - -func (m *RPCModule) sendErrorResponse(ctx context.Context, req *RPCRequest, err error) error { - return m.sendResponse(ctx, &RPCResponse{ - ID: req.ID, - Result: nil, - Error: &RPCError{ - Code: -32603, - Message: err.Error(), - }, - }) -} - -func (m *RPCModule) sendMethodNotFoundResponse(ctx context.Context, req *RPCRequest) error { - return m.sendResponse(ctx, &RPCResponse{ - ID: req.ID, - Result: nil, - Error: &RPCError{ - Code: -32601, - Message: fmt.Sprintf("method not found"), - }, - }) -} - -func (m *RPCModule) sendResponse(ctx context.Context, res *RPCResponse) error { - msg := NewServerMessage(ctx, map[string]interface{}{ - "jsonrpc": "2.0", - "id": res.ID, - "error": res.Error, - "result": res.Result, - }) - - if err := m.bus.Publish(ctx, msg); err != nil { - return errors.WithStack(err) - } - - return nil -} - -func (m *RPCModule) isRPCRequest(msg *ClientMessage) (bool, *RPCRequest) { - jsonRPC, exists := msg.Data["jsonrpc"] - if !exists || jsonRPC != "2.0" { - return false, nil - } - - rawMethod, exists := msg.Data["method"] - if !exists { - return false, nil - } - - method, ok := rawMethod.(string) - if !ok { - return false, nil - } - - id := msg.Data["id"] - params := msg.Data["params"] - - return true, &RPCRequest{ - ID: id, - Method: method, - Params: params, - } -} - -func RPCModuleFactory(bus bus.Bus) app.ServerModuleFactory { - return func(server *app.Server) app.ServerModule { - mod := &RPCModule{ - server: server, - bus: bus, - } - - return mod - } -} - -var _ app.InitializableModule = &RPCModule{} diff --git a/pkg/module/rpc/envelope.go b/pkg/module/rpc/envelope.go new file mode 100644 index 0000000..b7dd44a --- /dev/null +++ b/pkg/module/rpc/envelope.go @@ -0,0 +1,21 @@ +package rpc + +import ( + "context" + + "forge.cadoles.com/arcad/edge/pkg/bus" +) + +const ( + Address bus.Address = "module/rpc" +) + +type Request struct { + Context context.Context + Method string + Params map[string]any +} + +func NewRequestEnvelope(ctx context.Context, method string, params map[string]any) bus.Envelope { + return bus.NewEnvelope(Address, &Request{ctx, method, params}) +} diff --git a/pkg/module/rpc/error.go b/pkg/module/rpc/error.go new file mode 100644 index 0000000..0b0b7d0 --- /dev/null +++ b/pkg/module/rpc/error.go @@ -0,0 +1,7 @@ +package rpc + +import "errors" + +var ( + ErrMethodNotFound = errors.New("method not found") +) diff --git a/pkg/module/rpc/rpc.go b/pkg/module/rpc/rpc.go new file mode 100644 index 0000000..f495cde --- /dev/null +++ b/pkg/module/rpc/rpc.go @@ -0,0 +1,199 @@ +package rpc + +import ( + "context" + "sync" + + "forge.cadoles.com/arcad/edge/pkg/app" + "forge.cadoles.com/arcad/edge/pkg/bus" + "forge.cadoles.com/arcad/edge/pkg/module/util" + "github.com/dop251/goja" + "github.com/pkg/errors" + "gitlab.com/wpetit/goweb/logger" +) + +type RPCModule struct { + server *app.Server + bus bus.Bus + callbacks sync.Map +} + +func (m *RPCModule) Name() string { + return "rpc" +} + +func (m *RPCModule) Export(export *goja.Object) { + if err := export.Set("register", m.register); err != nil { + panic(errors.Wrap(err, "could not set 'register' function")) + } + + if err := export.Set("unregister", m.unregister); err != nil { + panic(errors.Wrap(err, "could not set 'unregister' function")) + } +} + +func (m *RPCModule) OnInit(ctx context.Context, rt *goja.Runtime) error { + go func() { + err := m.bus.Reply(ctx, Address, m.handleRequestEnvelope) + if err != nil { + logger.Error(ctx, "could not setup reply hander", logger.CapturedE(errors.WithStack(err))) + } + }() + + return nil +} + +func (m *RPCModule) register(call goja.FunctionCall, rt *goja.Runtime) goja.Value { + fnName := util.AssertString(call.Argument(0), rt) + + var ( + callable goja.Callable + ok bool + ) + + if len(call.Arguments) > 1 { + callable, ok = goja.AssertFunction(call.Argument(1)) + } else { + callable, ok = goja.AssertFunction(rt.Get(fnName)) + } + + if !ok { + panic(rt.NewTypeError("method should be a valid function")) + } + + ctx := context.Background() + + logger.Debug(ctx, "registering method", logger.F("method", fnName)) + + m.callbacks.Store(fnName, callable) + + return nil +} + +func (m *RPCModule) unregister(call goja.FunctionCall, rt *goja.Runtime) goja.Value { + fnName := util.AssertString(call.Argument(0), rt) + + m.callbacks.Delete(fnName) + + return nil +} + +func (m *RPCModule) handleRequestEnvelope(env bus.Envelope) (any, error) { + request, ok := env.Message().(*Request) + if !ok { + logger.Warn(context.Background(), "unexpected bus message", logger.F("message", env.Message())) + + return nil, errors.WithStack(bus.ErrUnexpectedMessage) + } + + ctx := logger.With(request.Context, logger.F("request", request)) + + logger.Debug(ctx, "received rpc request") + + rawCallable, exists := m.callbacks.Load(request.Method) + if !exists { + logger.Debug(ctx, "method not found") + + return nil, errors.WithStack(ErrMethodNotFound) + } + + callable, ok := rawCallable.(goja.Callable) + if !ok { + logger.Debug(ctx, "invalid method") + + return nil, errors.WithStack(ErrMethodNotFound) + } + + result, err := m.server.Exec(ctx, callable, request, request.Params) + if err != nil { + logger.Error( + ctx, "rpc call error", + logger.CapturedE(errors.WithStack(err)), + ) + + return nil, errors.WithStack(err) + } + + promise, ok := app.IsPromise(result) + if ok { + result = m.server.WaitForPromise(promise) + } + + return result.Export(), nil +} + +// func (m *RPCModule) sendErrorResponse(ctx context.Context, req *RPCRequest, err error) error { +// return m.sendResponse(&RPCResponse{ +// ID: req.ID, +// Result: nil, +// Error: &RPCError{ +// Code: -32603, +// Message: err.Error(), +// }, +// }) +// } + +// func (m *RPCModule) sendMethodNotFoundResponse(req *RPCRequest) error { +// return m.sendResponse(&RPCResponse{ +// ID: req.ID, +// Result: nil, +// Error: &RPCError{ +// Code: -32601, +// Message: "method not found", +// }, +// }) +// } + +// func (m *RPCModule) sendResponse(res *RPCResponse) error { +// env := NewServerEnvelope(context.Background(), map[string]interface{}{ +// "jsonrpc": "2.0", +// "id": res.ID, +// "error": res.Error, +// "result": res.Result, +// }) + +// if err := m.bus.Publish(env); err != nil { +// return errors.WithStack(err) +// } + +// return nil +// } + +// func (m *RPCModule) isRPCRequest(msg *ClientMessage) (bool, *RPCRequest) { +// jsonRPC, exists := msg.Data["jsonrpc"] +// if !exists || jsonRPC != "2.0" { +// return false, nil +// } + +// rawMethod, exists := msg.Data["method"] +// if !exists { +// return false, nil +// } + +// method, ok := rawMethod.(string) +// if !ok { +// return false, nil +// } + +// id := msg.Data["id"] +// params := msg.Data["params"] + +// return true, &RPCRequest{ +// ID: id, +// Method: method, +// Params: params, +// } +// } + +func RPCModuleFactory(bus bus.Bus) app.ServerModuleFactory { + return func(server *app.Server) app.ServerModule { + mod := &RPCModule{ + server: server, + bus: bus, + } + + return mod + } +} + +var _ app.InitializableModule = &RPCModule{} diff --git a/pkg/module/rpc/rpc_test.go b/pkg/module/rpc/rpc_test.go new file mode 100644 index 0000000..e4488bb --- /dev/null +++ b/pkg/module/rpc/rpc_test.go @@ -0,0 +1,109 @@ +package rpc + +import ( + "context" + "os" + "sync" + "testing" + + "forge.cadoles.com/arcad/edge/pkg/app" + "forge.cadoles.com/arcad/edge/pkg/bus" + "forge.cadoles.com/arcad/edge/pkg/bus/memory" + "forge.cadoles.com/arcad/edge/pkg/module" + "github.com/pkg/errors" + "gitlab.com/wpetit/goweb/logger" +) + +func TestServerExecDeadlock(t *testing.T) { + if testing.Verbose() { + logger.SetLevel(logger.LevelDebug) + } + + b := memory.NewBus(memory.WithBufferSize(1)) + + server := app.NewServer( + module.ConsoleModuleFactory(), + RPCModuleFactory(b), + module.LifecycleModuleFactory(), + ) + + data, err := os.ReadFile("testdata/deadlock.js") + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + ctx := context.Background() + + t.Log("starting server") + + if err := server.Start(ctx, "deadlock.js", string(data)); err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + defer server.Stop() + + t.Log("server started") + + count := 100 + delay := 500 + + var wg sync.WaitGroup + + wg.Add(count) + + for i := 0; i < count; i++ { + go func(i int) { + defer wg.Done() + + t.Logf("calling %d", i) + + isCanceled := i%5 == 0 + + var ctx context.Context + if isCanceled { + canceledCtx, cancel := context.WithCancel(context.Background()) + cancel() + ctx = canceledCtx + } else { + ctx = context.Background() + } + + env := NewRequestEnvelope(ctx, "doSomethingLong", map[string]any{ + "i": i, + "delay": delay, + }) + + t.Logf("publishing envelope #%d", i) + + reply, err := b.Request(ctx, env) + if err != nil { + if errors.Is(err, context.Canceled) && isCanceled { + return + } + + if errors.Is(err, bus.ErrNoResponse) && isCanceled { + return + } + + t.Errorf("%+v", errors.WithStack(err)) + + return + } + + result, ok := reply.Message().(int64) + if !ok { + t.Errorf("response.Result: expected type '%T', got '%T'", int64(0), reply.Message()) + + return + } + + if e, g := i, int(result); e != g { + t.Errorf("response.Result: expected '%v', got '%v'", e, g) + + return + } + }(i) + } + + wg.Wait() +} diff --git a/pkg/module/rpc/testdata/deadlock.js b/pkg/module/rpc/testdata/deadlock.js new file mode 100644 index 0000000..8a667a6 --- /dev/null +++ b/pkg/module/rpc/testdata/deadlock.js @@ -0,0 +1,14 @@ +function onInit() { + rpc.register("doSomethingLong", doSomethingLong) +} + +function doSomethingLong(ctx, params) { + var start = Date.now() + + while (true) { + var now = Date.now() + if (now - start >= params.delay) break + } + + return params.i; +} diff --git a/pkg/module/share/module_test.go b/pkg/module/share/module_test.go index 587d03e..5ce367d 100644 --- a/pkg/module/share/module_test.go +++ b/pkg/module/share/module_test.go @@ -33,18 +33,14 @@ func TestModule(t *testing.T) { t.Fatalf("%+v", errors.WithStack(err)) } - if err := server.Load("testdata/share.js", string(data)); err != nil { + ctx := context.Background() + if err := server.Start(ctx, "testdata/share.js", string(data)); err != nil { t.Fatalf("%+v", errors.WithStack(err)) } - ctx := context.Background() - if err := server.Start(ctx); err != nil { - t.Fatalf("%+v", errors.WithStack(err)) - } + defer server.Stop() if _, err := server.ExecFuncByName(context.Background(), "testModule"); err != nil { t.Fatalf("%+v", errors.WithStack(err)) } - - server.Stop() } diff --git a/pkg/module/store/module_test.go b/pkg/module/store/module_test.go index 080c9f2..1b70610 100644 --- a/pkg/module/store/module_test.go +++ b/pkg/module/store/module_test.go @@ -27,18 +27,14 @@ func TestStoreModule(t *testing.T) { t.Fatalf("%+v", errors.WithStack(err)) } - if err := server.Load("testdata/store.js", string(data)); err != nil { + ctx := context.Background() + if err := server.Start(ctx, "testdata/store.js", string(data)); err != nil { t.Fatalf("%+v", errors.WithStack(err)) } - ctx := context.Background() - if err := server.Start(ctx); err != nil { - t.Fatalf("%+v", errors.WithStack(err)) - } + defer server.Stop() if _, err := server.ExecFuncByName(context.Background(), "testStore"); err != nil { t.Fatalf("%+v", errors.WithStack(err)) } - - server.Stop() }