diff --git a/cmd/cli/command/app/run.go b/cmd/cli/command/app/run.go index a11343a..d6169ec 100644 --- a/cmd/cli/command/app/run.go +++ b/cmd/cli/command/app/run.go @@ -23,10 +23,11 @@ import ( authModule "forge.cadoles.com/arcad/edge/pkg/module/auth" authHTTP "forge.cadoles.com/arcad/edge/pkg/module/auth/http" authModuleMiddleware "forge.cadoles.com/arcad/edge/pkg/module/auth/middleware" - "forge.cadoles.com/arcad/edge/pkg/module/blob" - "forge.cadoles.com/arcad/edge/pkg/module/cast" - "forge.cadoles.com/arcad/edge/pkg/module/fetch" + blobModule "forge.cadoles.com/arcad/edge/pkg/module/blob" + castModule "forge.cadoles.com/arcad/edge/pkg/module/cast" + fetchModule "forge.cadoles.com/arcad/edge/pkg/module/fetch" netModule "forge.cadoles.com/arcad/edge/pkg/module/net" + rpcModule "forge.cadoles.com/arcad/edge/pkg/module/rpc" shareModule "forge.cadoles.com/arcad/edge/pkg/module/share" "forge.cadoles.com/arcad/edge/pkg/storage" "gitlab.com/wpetit/goweb/logger" @@ -106,6 +107,11 @@ func RunCommand() *cli.Command { Usage: "use `FILE` as local accounts", Value: ".edge/%APPID%/accounts.json", }, + &cli.IntFlag{ + Name: "max-upload-size", + Usage: "use `MAX-UPLOAD-SIZE` as blob max upload size", + Value: 10 << (10 * 2), // 10Mb + }, }, Action: func(ctx *cli.Context) error { address := ctx.String("address") @@ -117,6 +123,7 @@ func RunCommand() *cli.Command { documentstoreDSN := ctx.String("documentstore-dsn") shareStoreDSN := ctx.String("sharestore-dsn") accountsFile := ctx.String("accounts-file") + maxUploadSize := ctx.Int("max-upload-size") logger.SetFormat(logger.Format(logFormat)) logger.SetLevel(logger.Level(logLevel)) @@ -162,7 +169,7 @@ func RunCommand() *cli.Command { appCtx := logger.With(cmdCtx, logger.F("address", address)) - if err := runApp(appCtx, path, address, documentstoreDSN, blobstoreDSN, shareStoreDSN, accountsFile, appsRepository); err != nil { + if err := runApp(appCtx, path, address, documentstoreDSN, blobstoreDSN, shareStoreDSN, accountsFile, appsRepository, maxUploadSize); err != nil { logger.Error(appCtx, "could not run app", logger.CapturedE(errors.WithStack(err))) } }(p, port, idx) @@ -175,7 +182,7 @@ func RunCommand() *cli.Command { } } -func runApp(ctx context.Context, path, address, documentStoreDSN, blobStoreDSN, shareStoreDSN, accountsFile string, appRepository appModule.Repository) error { +func runApp(ctx context.Context, path, address, documentStoreDSN, blobStoreDSN, shareStoreDSN, accountsFile string, appRepository appModule.Repository, maxUploadSize int) error { absPath, err := filepath.Abs(path) if err != nil { return errors.Wrapf(err, "could not resolve path '%s'", path) @@ -236,6 +243,8 @@ func runApp(ctx context.Context, path, address, documentStoreDSN, blobStoreDSN, return jwtutil.NewSymmetricKeySet(dummySecret) }), ), + blobModule.Mount(maxUploadSize), // 10Mb, + fetchModule.Mount(), ), appHTTP.WithHTTPMiddlewares( authModuleMiddleware.AnonymousUser(key, jwa.HS256), @@ -278,18 +287,18 @@ func getServerModules(deps *moduleDeps) []app.ServerModuleFactory { module.LifecycleModuleFactory(), module.ContextModuleFactory(), module.ConsoleModuleFactory(), - cast.CastModuleFactory(), + castModule.CastModuleFactory(), netModule.ModuleFactory(deps.Bus), - module.RPCModuleFactory(deps.Bus), + rpcModule.ModuleFactory(deps.Bus), module.StoreModuleFactory(deps.DocumentStore), - blob.ModuleFactory(deps.Bus, deps.BlobStore), + blobModule.ModuleFactory(deps.Bus, deps.BlobStore), authModule.ModuleFactory( authModule.WithJWT(func() (jwk.Set, error) { return jwtutil.NewSymmetricKeySet(dummySecret) }), ), appModule.ModuleFactory(deps.AppRepository), - fetch.ModuleFactory(deps.Bus), + fetchModule.ModuleFactory(deps.Bus), shareModule.ModuleFactory(deps.AppID, deps.ShareStore), } } diff --git a/go.mod b/go.mod index 143ff60..4549a4e 100644 --- a/go.mod +++ b/go.mod @@ -29,6 +29,7 @@ require ( github.com/lestrrat-go/iter v1.0.2 // indirect github.com/lestrrat-go/option v1.0.0 // indirect github.com/miekg/dns v1.1.53 // indirect + go.uber.org/goleak v1.3.0 // indirect golang.org/x/sync v0.1.0 // indirect google.golang.org/genproto v0.0.0-20210226172003-ab064af71705 // indirect google.golang.org/grpc v1.35.0 // indirect diff --git a/go.sum b/go.sum index a4f017c..b7aebf3 100644 --- a/go.sum +++ b/go.sum @@ -326,6 +326,8 @@ go.opencensus.io v0.22.3/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.4/go.mod h1:yxeiOL68Rb0Xd1ddK5vPZ/oVn4vY4Ynel7k9FzqtOIw= go.opencensus.io v0.22.5 h1:dntmOdLpSpHlVqbW5Eay97DelsZHe+55D+xC6i0dDS0= go.opencensus.io v0.22.5/go.mod h1:5pWMHQbX5EPX2/62yrJeAkowc+lfs/XD7Uxpq3pI6kk= +go.uber.org/goleak v1.3.0 h1:2K3zAYmnTNqV73imy9J1T3WC+gmCePx2hEGkimedGto= +go.uber.org/goleak v1.3.0/go.mod h1:CoHD4mav9JJNrW/WLlf7HGZPjdw8EucARQHekz1X6bE= golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w= golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI= diff --git a/pkg/app/promise_proxy.go b/pkg/app/promise_proxy.go index 2031be1..dfdb4b4 100644 --- a/pkg/app/promise_proxy.go +++ b/pkg/app/promise_proxy.go @@ -47,6 +47,10 @@ func NewPromiseProxyFrom(rt *goja.Runtime) *PromiseProxy { } func IsPromise(v goja.Value) (*goja.Promise, bool) { + if v == nil { + return nil, false + } + promise, ok := v.Export().(*goja.Promise) return promise, ok } diff --git a/pkg/app/server.go b/pkg/app/server.go index 7f2c08e..6006ebf 100644 --- a/pkg/app/server.go +++ b/pkg/app/server.go @@ -4,6 +4,7 @@ import ( "context" "math/rand" "sync" + "time" "github.com/dop251/goja" "github.com/dop251/goja_nodejs/eventloop" @@ -22,23 +23,7 @@ type Server struct { modules []ServerModule } -func (s *Server) Load(name string, src string) error { - var err error - - s.loop.RunOnLoop(func(rt *goja.Runtime) { - _, err = rt.RunScript(name, src) - if err != nil { - err = errors.Wrap(err, "could not run js script") - } - }) - if err != nil { - return errors.WithStack(err) - } - - return nil -} - -func (s *Server) ExecFuncByName(ctx context.Context, funcName string, args ...interface{}) (goja.Value, error) { +func (s *Server) ExecFuncByName(ctx context.Context, funcName string, args ...interface{}) (any, error) { ctx = logger.With(ctx, logger.F("function", funcName), logger.F("args", args)) ret, err := s.Exec(ctx, funcName, args...) @@ -49,16 +34,23 @@ func (s *Server) ExecFuncByName(ctx context.Context, funcName string, args ...in return ret, nil } -func (s *Server) Exec(ctx context.Context, callableOrFuncname any, args ...interface{}) (goja.Value, error) { - var ( - wg sync.WaitGroup +func (s *Server) Exec(ctx context.Context, callableOrFuncname any, args ...interface{}) (any, error) { + type result struct { value goja.Value err error - ) + } - wg.Add(1) + done := make(chan result) + + defer func() { + // Drain done channel + for range done { + } + }() s.loop.RunOnLoop(func(rt *goja.Runtime) { + defer close(done) + var callable goja.Callable switch typ := callableOrFuncname.(type) { case goja.Callable: @@ -67,7 +59,9 @@ func (s *Server) Exec(ctx context.Context, callableOrFuncname any, args ...inter case string: call, ok := goja.AssertFunction(rt.Get(typ)) if !ok { - err = errors.WithStack(ErrFuncDoesNotExist) + done <- result{ + err: errors.WithStack(ErrFuncDoesNotExist), + } return } @@ -75,28 +69,27 @@ func (s *Server) Exec(ctx context.Context, callableOrFuncname any, args ...inter callable = call default: - err = errors.Errorf("callableOrFuncname: expected callable or function name, got '%T'", callableOrFuncname) + done <- result{ + err: errors.Errorf("callableOrFuncname: expected callable or function name, got '%T'", callableOrFuncname), + } return } - logger.Debug(ctx, "executing callable") - - defer wg.Done() - defer func() { - if recovered := recover(); recovered != nil { - revoveredErr, ok := recovered.(error) - if ok { - logger.Error(ctx, "recovered runtime error", logger.CapturedE(errors.WithStack(revoveredErr))) - - err = errors.WithStack(ErrUnknownError) - - return - } + recovered := recover() + if recovered == nil { + return + } + recoveredErr, ok := recovered.(error) + if !ok { panic(recovered) } + + done <- result{ + err: recoveredErr, + } }() jsArgs := make([]goja.Value, 0, len(args)) @@ -104,22 +97,49 @@ func (s *Server) Exec(ctx context.Context, callableOrFuncname any, args ...inter jsArgs = append(jsArgs, rt.ToValue(a)) } - value, err = callable(nil, jsArgs...) + logger.Debug(ctx, "executing callable", logger.F("callable", callableOrFuncname)) + + start := time.Now() + value, err := callable(nil, jsArgs...) if err != nil { - err = errors.WithStack(err) + done <- result{ + err: errors.WithStack(err), + } + + return } + + done <- result{ + value: value, + } + + logger.Debug(ctx, "executed callable", logger.F("callable", callableOrFuncname), logger.F("duration", time.Since(start).String())) }) - wg.Wait() + select { + case <-ctx.Done(): + if err := ctx.Err(); err != nil { + return nil, errors.WithStack(err) + } - if err != nil { - return nil, errors.WithStack(err) + return nil, nil + + case result := <-done: + if result.err != nil { + return nil, errors.WithStack(result.err) + } + + value := result.value + + if promise, ok := IsPromise(value); ok { + value = s.waitForPromise(promise) + } + + return value.Export(), nil } - - return value, nil } -func (s *Server) WaitForPromise(promise *goja.Promise) goja.Value { +func (s *Server) waitForPromise(promise *goja.Promise) goja.Value { var ( wg sync.WaitGroup value goja.Value @@ -162,20 +182,40 @@ func (s *Server) WaitForPromise(promise *goja.Promise) goja.Value { return value } -func (s *Server) Start(ctx context.Context) error { +func (s *Server) Start(ctx context.Context, name string, src string) error { s.loop.Start() - var err error + done := make(chan error) s.loop.RunOnLoop(func(rt *goja.Runtime) { + defer close(done) + rt.SetFieldNameMapper(goja.TagFieldNameMapper("goja", true)) rt.SetRandSource(createRandomSource()) - if err = s.initModules(ctx, rt); err != nil { + if err := s.loadModules(ctx, rt); err != nil { err = errors.WithStack(err) + done <- err + + return } + + if _, err := rt.RunScript(name, src); err != nil { + done <- errors.Wrap(err, "could not run js script") + return + } + + if err := s.initModules(ctx, rt); err != nil { + err = errors.WithStack(err) + done <- err + + return + } + + done <- nil }) - if err != nil { + + if err := <-done; err != nil { return errors.WithStack(err) } @@ -186,7 +226,7 @@ func (s *Server) Stop() { s.loop.Stop() } -func (s *Server) initModules(ctx context.Context, rt *goja.Runtime) error { +func (s *Server) loadModules(ctx context.Context, rt *goja.Runtime) error { modules := make([]ServerModule, 0, len(s.factories)) for _, moduleFactory := range s.factories { @@ -200,7 +240,13 @@ func (s *Server) initModules(ctx context.Context, rt *goja.Runtime) error { modules = append(modules, mod) } - for _, mod := range modules { + s.modules = modules + + return nil +} + +func (s *Server) initModules(ctx context.Context, rt *goja.Runtime) error { + for _, mod := range s.modules { initMod, ok := mod.(InitializableModule) if !ok { continue @@ -213,8 +259,6 @@ func (s *Server) initModules(ctx context.Context, rt *goja.Runtime) error { } } - s.modules = modules - return nil } diff --git a/pkg/bus/bus.go b/pkg/bus/bus.go index a02d437..0d0710b 100644 --- a/pkg/bus/bus.go +++ b/pkg/bus/bus.go @@ -3,11 +3,11 @@ package bus import "context" type Bus interface { - Subscribe(ctx context.Context, ns MessageNamespace) (<-chan Message, error) - Unsubscribe(ctx context.Context, ns MessageNamespace, ch <-chan Message) - Publish(ctx context.Context, msg Message) error - Request(ctx context.Context, msg Message) (Message, error) - Reply(ctx context.Context, ns MessageNamespace, h RequestHandler) error + Subscribe(ctx context.Context, addr Address) (<-chan Envelope, error) + Unsubscribe(addr Address, ch <-chan Envelope) + Publish(env Envelope) error + Request(ctx context.Context, env Envelope) (Envelope, error) + Reply(ctx context.Context, addr Address, h RequestHandler) chan error } -type RequestHandler func(msg Message) (Message, error) +type RequestHandler func(env Envelope) (any, error) diff --git a/pkg/bus/envelope.go b/pkg/bus/envelope.go new file mode 100644 index 0000000..23a7e7d --- /dev/null +++ b/pkg/bus/envelope.go @@ -0,0 +1,32 @@ +package bus + +type Address string + +type Envelope interface { + Message() any + Address() Address +} + +type BaseEnvelope struct { + msg any + addr Address +} + +// Address implements Envelope. +func (e *BaseEnvelope) Address() Address { + return e.addr +} + +// Message implements Envelope. +func (e *BaseEnvelope) Message() any { + return e.msg +} + +func NewEnvelope(addr Address, msg any) *BaseEnvelope { + return &BaseEnvelope{ + addr: addr, + msg: msg, + } +} + +var _ Envelope = &BaseEnvelope{} diff --git a/pkg/bus/memory/bus.go b/pkg/bus/memory/bus.go index d6b3088..be7b776 100644 --- a/pkg/bus/memory/bus.go +++ b/pkg/bus/memory/bus.go @@ -15,13 +15,13 @@ type Bus struct { nextRequestID uint64 } -func (b *Bus) Subscribe(ctx context.Context, ns bus.MessageNamespace) (<-chan bus.Message, error) { +func (b *Bus) Subscribe(ctx context.Context, address bus.Address) (<-chan bus.Envelope, error) { logger.Debug( - ctx, "subscribing to messages", - logger.F("messageNamespace", ns), + ctx, "subscribing", + logger.F("address", address), ) - dispatchers := b.getDispatchers(ns) + dispatchers := b.getDispatchers(address) disp := newEventDispatcher(b.opt.BufferSize) go disp.Run(ctx) @@ -31,50 +31,41 @@ func (b *Bus) Subscribe(ctx context.Context, ns bus.MessageNamespace) (<-chan bu return disp.Out(), nil } -func (b *Bus) Unsubscribe(ctx context.Context, ns bus.MessageNamespace, ch <-chan bus.Message) { +func (b *Bus) Unsubscribe(address bus.Address, ch <-chan bus.Envelope) { logger.Debug( - ctx, "unsubscribing from messages", - logger.F("messageNamespace", ns), + context.Background(), "unsubscribing", + logger.F("address", address), ) - dispatchers := b.getDispatchers(ns) + dispatchers := b.getDispatchers(address) dispatchers.RemoveByOutChannel(ch) } -func (b *Bus) Publish(ctx context.Context, msg bus.Message) error { - dispatchers := b.getDispatchers(msg.MessageNamespace()) - dispatchersList := dispatchers.List() - +func (b *Bus) Publish(env bus.Envelope) error { + dispatchers := b.getDispatchers(env.Address()) logger.Debug( - ctx, "publishing message", - logger.F("dispatchers", len(dispatchersList)), - logger.F("messageNamespace", msg.MessageNamespace()), + context.Background(), "publish", + logger.F("address", env.Address()), ) - for _, d := range dispatchersList { - if d.Closed() { - dispatchers.Remove(d) - - continue + dispatchers.Range(func(d *eventDispatcher) { + if err := d.In(env); err != nil { + logger.Error(context.Background(), "could not publish message", logger.CapturedE(errors.WithStack(err))) } - - if err := d.In(msg); err != nil { - return errors.WithStack(err) - } - } + }) return nil } -func (b *Bus) getDispatchers(namespace bus.MessageNamespace) *eventDispatcherSet { - strNamespace := string(namespace) +func (b *Bus) getDispatchers(address bus.Address) *eventDispatcherSet { + rawAddress := string(address) - rawDispatchers, exists := b.dispatchers.Get(strNamespace) + rawDispatchers, exists := b.dispatchers.Get(rawAddress) dispatchers, ok := rawDispatchers.(*eventDispatcherSet) if !exists || !ok { dispatchers = newEventDispatcherSet() - b.dispatchers.Set(strNamespace, dispatchers) + b.dispatchers.Set(rawAddress, dispatchers) } return dispatchers diff --git a/pkg/bus/memory/bus_test.go b/pkg/bus/memory/bus_test.go index ac79f69..8efe19d 100644 --- a/pkg/bus/memory/bus_test.go +++ b/pkg/bus/memory/bus_test.go @@ -4,13 +4,23 @@ import ( "testing" busTesting "forge.cadoles.com/arcad/edge/pkg/bus/testing" + "gitlab.com/wpetit/goweb/logger" + "go.uber.org/goleak" ) +func TestMain(m *testing.M) { + goleak.VerifyTestMain(m) +} + func TestMemoryBus(t *testing.T) { if testing.Short() { t.Skip("Test disabled when -short flag is set") } + if testing.Verbose() { + logger.SetLevel(logger.LevelDebug) + } + t.Parallel() t.Run("PublishSubscribe", func(t *testing.T) { @@ -26,4 +36,11 @@ func TestMemoryBus(t *testing.T) { b := NewBus() busTesting.TestRequestReply(t, b) }) + + t.Run("CanceledRequestReply", func(t *testing.T) { + t.Parallel() + + b := NewBus() + busTesting.TestCanceledRequest(t, b) + }) } diff --git a/pkg/bus/memory/event_dispatcher.go b/pkg/bus/memory/event_dispatcher.go index a078939..c9f3941 100644 --- a/pkg/bus/memory/event_dispatcher.go +++ b/pkg/bus/memory/event_dispatcher.go @@ -3,7 +3,6 @@ package memory import ( "context" "sync" - "time" "forge.cadoles.com/arcad/edge/pkg/bus" "github.com/pkg/errors" @@ -30,7 +29,7 @@ func (s *eventDispatcherSet) Remove(d *eventDispatcher) { delete(s.items, d) } -func (s *eventDispatcherSet) RemoveByOutChannel(out <-chan bus.Message) { +func (s *eventDispatcherSet) RemoveByOutChannel(out <-chan bus.Envelope) { s.mutex.Lock() defer s.mutex.Unlock() @@ -42,17 +41,18 @@ func (s *eventDispatcherSet) RemoveByOutChannel(out <-chan bus.Message) { } } -func (s *eventDispatcherSet) List() []*eventDispatcher { +func (s *eventDispatcherSet) Range(fn func(d *eventDispatcher)) { s.mutex.Lock() defer s.mutex.Unlock() - dispatchers := make([]*eventDispatcher, 0, len(s.items)) - for d := range s.items { - dispatchers = append(dispatchers, d) - } + if d.Closed() { + s.Remove(d) + continue + } - return dispatchers + fn(d) + } } func newEventDispatcherSet() *eventDispatcherSet { @@ -62,8 +62,8 @@ func newEventDispatcherSet() *eventDispatcherSet { } type eventDispatcher struct { - in chan bus.Message - out chan bus.Message + in chan bus.Envelope + out chan bus.Envelope mutex sync.RWMutex closed bool } @@ -91,7 +91,7 @@ func (d *eventDispatcher) close() { d.closed = true } -func (d *eventDispatcher) In(msg bus.Message) (err error) { +func (d *eventDispatcher) In(msg bus.Envelope) (err error) { d.mutex.RLock() defer d.mutex.RUnlock() @@ -104,67 +104,52 @@ func (d *eventDispatcher) In(msg bus.Message) (err error) { return nil } -func (d *eventDispatcher) Out() <-chan bus.Message { +func (d *eventDispatcher) Out() <-chan bus.Envelope { return d.out } -func (d *eventDispatcher) IsOut(out <-chan bus.Message) bool { +func (d *eventDispatcher) IsOut(out <-chan bus.Envelope) bool { return d.out == out } func (d *eventDispatcher) Run(ctx context.Context) { defer func() { - for { - logger.Debug(ctx, "closing dispatcher, flushing out incoming messages") + logger.Debug(ctx, "closing dispatcher, flushing out incoming messages") - close(d.out) + close(d.out) + for range d.in { // Flush all incoming messages - for { - _, ok := <-d.in - if !ok { - return - } - } } }() for { - msg, ok := <-d.in - if !ok { - return - } - - timeout := time.After(time.Second) - select { - case d.out <- msg: - case <-timeout: - logger.Error( - ctx, - "out message channel timeout", - logger.F("message", msg), - ) - - return - case <-ctx.Done(): - logger.Error( - ctx, - "message subscription context canceled", - logger.F("message", msg), - logger.CapturedE(errors.WithStack(ctx.Err())), - ) + if err := ctx.Err(); !errors.Is(err, context.Canceled) { + logger.Error( + ctx, + "message subscription context canceled", + logger.CapturedE(errors.WithStack(err)), + ) + } return + + case msg, ok := <-d.in: + if !ok { + return + } + + d.out <- msg } } } func newEventDispatcher(bufferSize int64) *eventDispatcher { return &eventDispatcher{ - in: make(chan bus.Message, bufferSize), - out: make(chan bus.Message, bufferSize), + in: make(chan bus.Envelope, bufferSize), + out: make(chan bus.Envelope, bufferSize), closed: false, } } diff --git a/pkg/bus/memory/request_reply.go b/pkg/bus/memory/request_reply.go index aaa9390..730ead5 100644 --- a/pkg/bus/memory/request_reply.go +++ b/pkg/bus/memory/request_reply.go @@ -11,57 +11,78 @@ import ( ) const ( - MessageNamespaceRequest bus.MessageNamespace = "reqrep/request" - MessageNamespaceReply bus.MessageNamespace = "reqrep/reply" + AddressRequest bus.Address = "bus/memory/request" + AddressReply bus.Address = "bus/memory/reply" ) -type RequestMessage struct { - RequestID uint64 - - Message bus.Message - - ns bus.MessageNamespace +type RequestEnvelope struct { + requestID uint64 + wrapped bus.Envelope } -func (m *RequestMessage) MessageNamespace() bus.MessageNamespace { - return m.ns +func (e *RequestEnvelope) Address() bus.Address { + return getRequestAddress(e.wrapped.Address()) } -type ReplyMessage struct { - RequestID uint64 - Message bus.Message - Error error - - ns bus.MessageNamespace +func (e *RequestEnvelope) Message() any { + return e.wrapped.Message() } -func (m *ReplyMessage) MessageNamespace() bus.MessageNamespace { - return m.ns +func (e *RequestEnvelope) RequestID() uint64 { + return e.requestID } -func (b *Bus) Request(ctx context.Context, msg bus.Message) (bus.Message, error) { +func (e *RequestEnvelope) Unwrap() bus.Envelope { + return e.wrapped +} + +type ReplyEnvelope struct { + requestID uint64 + wrapped bus.Envelope + err error +} + +func (e *ReplyEnvelope) Address() bus.Address { + return getReplyAddress(e.wrapped.Address(), e.requestID) +} + +func (e *ReplyEnvelope) Message() any { + return e.wrapped.Message() +} + +func (e *ReplyEnvelope) Err() error { + return e.err +} + +func (e *ReplyEnvelope) Unwrap() bus.Envelope { + return e.wrapped +} + +func (b *Bus) Request(ctx context.Context, env bus.Envelope) (bus.Envelope, error) { requestID := atomic.AddUint64(&b.nextRequestID, 1) - req := &RequestMessage{ - RequestID: requestID, - Message: msg, - ns: msg.MessageNamespace(), + req := &RequestEnvelope{ + requestID: requestID, + wrapped: env, } - replyNamespace := createReplyNamespace(requestID) + replyAddress := getReplyAddress(env.Address(), requestID) - replies, err := b.Subscribe(ctx, replyNamespace) + subCtx, cancel := context.WithCancel(ctx) + defer cancel() + + replies, err := b.Subscribe(subCtx, replyAddress) if err != nil { return nil, errors.WithStack(err) } defer func() { - b.Unsubscribe(ctx, replyNamespace, replies) + b.Unsubscribe(replyAddress, replies) }() logger.Debug(ctx, "publishing request", logger.F("request", req)) - if err := b.Publish(ctx, req); err != nil { + if err := b.Publish(req); err != nil { return nil, errors.WithStack(err) } @@ -70,82 +91,93 @@ func (b *Bus) Request(ctx context.Context, msg bus.Message) (bus.Message, error) case <-ctx.Done(): return nil, errors.WithStack(ctx.Err()) - case msg, ok := <-replies: + case env, ok := <-replies: if !ok { return nil, errors.WithStack(bus.ErrNoResponse) } - reply, ok := msg.(*ReplyMessage) + reply, ok := env.(*ReplyEnvelope) if !ok { return nil, errors.WithStack(bus.ErrUnexpectedMessage) } - if reply.Error != nil { + if err := reply.Err(); err != nil { return nil, errors.WithStack(err) } - return reply.Message, nil + return reply.Unwrap(), nil } } } -type RequestHandler func(evt bus.Message) (bus.Message, error) +func (b *Bus) Reply(ctx context.Context, address bus.Address, handler bus.RequestHandler) chan error { + requestAddress := getRequestAddress(address) -func (b *Bus) Reply(ctx context.Context, msgNamespace bus.MessageNamespace, h bus.RequestHandler) error { - requests, err := b.Subscribe(ctx, msgNamespace) + errs := make(chan error) + + requests, err := b.Subscribe(ctx, requestAddress) if err != nil { - return errors.WithStack(err) + go func() { + errs <- errors.WithStack(err) + close(errs) + }() + + return errs } - defer func() { - b.Unsubscribe(ctx, msgNamespace, requests) + go func() { + defer func() { + b.Unsubscribe(requestAddress, requests) + close(errs) + }() + + for { + select { + case <-ctx.Done(): + errs <- errors.WithStack(ctx.Err()) + return + + case env, ok := <-requests: + if !ok { + return + } + + request, ok := env.(*RequestEnvelope) + if !ok { + errs <- errors.WithStack(bus.ErrUnexpectedMessage) + continue + } + + logger.Debug(ctx, "handling request", logger.F("request", request)) + + msg, err := handler(request.Unwrap()) + + reply := &ReplyEnvelope{ + requestID: request.RequestID(), + wrapped: bus.NewEnvelope(request.Unwrap().Address(), msg), + } + + if err != nil { + reply.err = errors.WithStack(err) + } + + logger.Debug(ctx, "publishing reply", logger.F("reply", reply)) + + if err := b.Publish(reply); err != nil { + errs <- errors.WithStack(err) + continue + } + } + } }() - for { - select { - case <-ctx.Done(): - return errors.WithStack(ctx.Err()) - - case msg, ok := <-requests: - if !ok { - return nil - } - - request, ok := msg.(*RequestMessage) - if !ok { - return errors.WithStack(bus.ErrUnexpectedMessage) - } - - logger.Debug(ctx, "handling request", logger.F("request", request)) - - msg, err := h(request.Message) - - reply := &ReplyMessage{ - RequestID: request.RequestID, - Message: nil, - Error: nil, - - ns: createReplyNamespace(request.RequestID), - } - - if err != nil { - reply.Error = errors.WithStack(err) - } else { - reply.Message = msg - } - - logger.Debug(ctx, "publishing reply", logger.F("reply", reply)) - - if err := b.Publish(ctx, reply); err != nil { - return errors.WithStack(err) - } - } - } + return errs } -func createReplyNamespace(requestID uint64) bus.MessageNamespace { - return bus.NewMessageNamespace( - MessageNamespaceReply, - bus.MessageNamespace(strconv.FormatUint(requestID, 10)), - ) +func getRequestAddress(addr bus.Address) bus.Address { + return AddressRequest + "/" + addr +} + +func getReplyAddress(addr bus.Address, requestID uint64) bus.Address { + return AddressReply + "/" + addr + "/" + bus.Address(strconv.FormatUint(requestID, 10)) } diff --git a/pkg/bus/message.go b/pkg/bus/message.go deleted file mode 100644 index 3a470d1..0000000 --- a/pkg/bus/message.go +++ /dev/null @@ -1,33 +0,0 @@ -package bus - -import ( - "strings" - - "github.com/pkg/errors" -) - -type ( - MessageNamespace string -) - -type Message interface { - MessageNamespace() MessageNamespace -} - -func NewMessageNamespace(namespaces ...MessageNamespace) MessageNamespace { - var sb strings.Builder - - for i, ns := range namespaces { - if i != 0 { - if _, err := sb.WriteString(":"); err != nil { - panic(errors.Wrap(err, "could not build new message namespace")) - } - } - - if _, err := sb.WriteString(string(ns)); err != nil { - panic(errors.Wrap(err, "could not build new message namespace")) - } - } - - return MessageNamespace(sb.String()) -} diff --git a/pkg/bus/testing/publish_subscribe.go b/pkg/bus/testing/publish_subscribe.go index 6db69e3..2dd917b 100644 --- a/pkg/bus/testing/publish_subscribe.go +++ b/pkg/bus/testing/publish_subscribe.go @@ -2,6 +2,7 @@ package testing import ( "context" + "fmt" "sync" "sync/atomic" "testing" @@ -12,74 +13,52 @@ import ( ) const ( - testNamespace bus.MessageNamespace = "testNamespace" + testAddress bus.Address = "testAddress" ) -type testMessage struct{} - -func (e *testMessage) MessageNamespace() bus.MessageNamespace { - return testNamespace -} - func TestPublishSubscribe(t *testing.T, b bus.Bus) { ctx, cancel := context.WithTimeout(context.Background(), time.Minute) defer cancel() t.Log("subscribe") - messages, err := b.Subscribe(ctx, testNamespace) + envelopes, err := b.Subscribe(ctx, testAddress) if err != nil { t.Fatal(errors.WithStack(err)) } + expectedTotal := 5 + var wg sync.WaitGroup - wg.Add(5) + wg.Add(expectedTotal) go func() { - // 5 events should be received - t.Log("publish 0") - if err := b.Publish(ctx, &testMessage{}); err != nil { - t.Error(errors.WithStack(err)) - } + count := expectedTotal - t.Log("publish 1") + for i := 0; i < count; i++ { + env := bus.NewEnvelope(testAddress, fmt.Sprintf("message %d", i)) - if err := b.Publish(ctx, &testMessage{}); err != nil { - t.Error(errors.WithStack(err)) - } + if err := b.Publish(env); err != nil { + t.Error(errors.WithStack(err)) + } - t.Log("publish 2") - - if err := b.Publish(ctx, &testMessage{}); err != nil { - t.Error(errors.WithStack(err)) - } - - t.Log("publish 3") - - if err := b.Publish(ctx, &testMessage{}); err != nil { - t.Error(errors.WithStack(err)) - } - - t.Log("publish 4") - - if err := b.Publish(ctx, &testMessage{}); err != nil { - t.Error(errors.WithStack(err)) + t.Logf("published %d", i) } }() var count int32 = 0 go func() { - t.Log("range for events") + t.Log("range for received envelopes") - for msg := range messages { + for env := range envelopes { t.Logf("received msg %d", atomic.LoadInt32(&count)) atomic.AddInt32(&count, 1) - if e, g := testNamespace, msg.MessageNamespace(); e != g { - t.Errorf("evt.MessageNamespace(): expected '%v', got '%v'", e, g) + if e, g := testAddress, env.Address(); e != g { + t.Errorf("env.Address(): expected '%v', got '%v'", e, g) } wg.Done() @@ -88,9 +67,9 @@ func TestPublishSubscribe(t *testing.T, b bus.Bus) { wg.Wait() - b.Unsubscribe(ctx, testNamespace, messages) + b.Unsubscribe(testAddress, envelopes) - if e, g := int32(5), count; e != g { - t.Errorf("message received count: expected '%v', got '%v'", e, g) + if e, g := int32(expectedTotal), count; e != g { + t.Errorf("envelopes received count: expected '%v', got '%v'", e, g) } } diff --git a/pkg/bus/testing/request_reply.go b/pkg/bus/testing/request_reply.go index 22ceddd..905afe1 100644 --- a/pkg/bus/testing/request_reply.go +++ b/pkg/bus/testing/request_reply.go @@ -11,58 +11,42 @@ import ( ) const ( - testTypeReqRes bus.MessageNamespace = "testNamspaceReqRes" + testTypeReqResAddress bus.Address = "testTypeReqResAddress" ) -type testReqResMessage struct { - i int -} - -func (m *testReqResMessage) MessageNamespace() bus.MessageNamespace { - return testNamespace -} - func TestRequestReply(t *testing.T, b bus.Bus) { expectedRoundTrips := 256 timeout := time.Now().Add(time.Duration(expectedRoundTrips) * time.Second) - var ( - initWaitGroup sync.WaitGroup - resWaitGroup sync.WaitGroup - ) + replyCtx, cancelReply := context.WithDeadline(context.Background(), timeout) + defer cancelReply() - initWaitGroup.Add(1) + var resWaitGroup sync.WaitGroup + + replyErrs := b.Reply(replyCtx, testTypeReqResAddress, func(env bus.Envelope) (any, error) { + defer resWaitGroup.Done() + + req, ok := env.Message().(int) + if !ok { + return nil, errors.WithStack(bus.ErrUnexpectedMessage) + } + + // Simulate random work + time.Sleep(time.Millisecond * 100) + + t.Logf("[RES] sending res #%d", req) + + return req, nil + }) go func() { - repondCtx, cancelRespond := context.WithDeadline(context.Background(), timeout) - defer cancelRespond() - - initWaitGroup.Done() - - err := b.Reply(repondCtx, testNamespace, func(msg bus.Message) (bus.Message, error) { - defer resWaitGroup.Done() - - req, ok := msg.(*testReqResMessage) - if !ok { - return nil, errors.WithStack(bus.ErrUnexpectedMessage) + for err := range replyErrs { + if !errors.Is(err, context.Canceled) { + t.Errorf("%+v", errors.WithStack(err)) } - - result := &testReqResMessage{req.i} - - // Simulate random work - time.Sleep(time.Millisecond * 100) - - t.Logf("[RES] sending res #%d", req.i) - - return result, nil - }) - if err != nil { - t.Error(err) } }() - initWaitGroup.Wait() - var reqWaitGroup sync.WaitGroup for i := 0; i < expectedRoundTrips; i++ { @@ -75,32 +59,30 @@ func TestRequestReply(t *testing.T, b bus.Bus) { requestCtx, cancelRequest := context.WithDeadline(context.Background(), timeout) defer cancelRequest() - req := &testReqResMessage{i} - t.Logf("[REQ] sending req #%d", i) - result, err := b.Request(requestCtx, req) + response, err := b.Request(requestCtx, bus.NewEnvelope(testTypeReqResAddress, i)) if err != nil { t.Error(err) } t.Logf("[REQ] received req #%d reply", i) - if result == nil { - t.Error("result should not be nil") + if response == nil { + t.Error("response should not be nil") return } - res, ok := result.(*testReqResMessage) + result, ok := response.Message().(int) if !ok { t.Error(errors.WithStack(bus.ErrUnexpectedMessage)) return } - if e, g := req.i, res.i; e != g { - t.Errorf("res.i: expected '%v', got '%v'", e, g) + if e, g := i, result; e != g { + t.Errorf("response.Message(): expected '%v', got '%v'", e, g) } }(i) } @@ -108,3 +90,77 @@ func TestRequestReply(t *testing.T, b bus.Bus) { reqWaitGroup.Wait() resWaitGroup.Wait() } + +func TestCanceledRequest(t *testing.T, b bus.Bus) { + replyCtx, cancelReply := context.WithCancel(context.Background()) + defer cancelReply() + + errs := b.Reply(replyCtx, testTypeReqResAddress, func(env bus.Envelope) (any, error) { + return env.Message(), nil + }) + + go func() { + for err := range errs { + if !errors.Is(err, context.Canceled) { + t.Errorf("%+v", errors.WithStack(err)) + } + } + }() + + var wg sync.WaitGroup + + count := 100 + + wg.Add(count) + + for i := 0; i < count; i++ { + go func(i int) { + defer wg.Done() + + t.Logf("calling %d", i) + + isCanceled := i%2 == 0 + + var ctx context.Context + if isCanceled { + canceledCtx, cancel := context.WithCancel(context.Background()) + cancel() + ctx = canceledCtx + } else { + ctx = context.Background() + } + + t.Logf("publishing envelope #%d", i) + + reply, err := b.Request(ctx, bus.NewEnvelope(testTypeReqResAddress, int64(i))) + if err != nil { + if errors.Is(err, context.Canceled) && isCanceled { + return + } + + if errors.Is(err, bus.ErrNoResponse) && isCanceled { + return + } + + t.Errorf("%+v", errors.WithStack(err)) + + return + } + + result, ok := reply.Message().(int64) + if !ok { + t.Errorf("response.Result: expected type '%T', got '%T'", int64(0), reply.Message()) + + return + } + + if e, g := i, int(result); e != g { + t.Errorf("response.Result: expected '%v', got '%v'", e, g) + + return + } + }(i) + } + + wg.Wait() +} diff --git a/pkg/http/blob.go b/pkg/http/blob.go deleted file mode 100644 index d549393..0000000 --- a/pkg/http/blob.go +++ /dev/null @@ -1,282 +0,0 @@ -package http - -import ( - "encoding/json" - "io" - "io/fs" - "mime/multipart" - "net/http" - "os" - "time" - - "forge.cadoles.com/arcad/edge/pkg/bus" - "forge.cadoles.com/arcad/edge/pkg/module" - "forge.cadoles.com/arcad/edge/pkg/module/blob" - "forge.cadoles.com/arcad/edge/pkg/storage" - "github.com/go-chi/chi/v5" - "github.com/pkg/errors" - "gitlab.com/wpetit/goweb/logger" -) - -const ( - errorCodeForbidden = "forbidden" - errorCodeInternalError = "internal-error" - errorCodeBadRequest = "bad-request" - errorCodeNotFound = "not-found" -) - -type uploadResponse struct { - Bucket string `json:"bucket"` - BlobID storage.BlobID `json:"blobId"` -} - -func (h *Handler) handleAppUpload(w http.ResponseWriter, r *http.Request) { - h.mutex.RLock() - defer h.mutex.RUnlock() - - ctx := r.Context() - - r.Body = http.MaxBytesReader(w, r.Body, h.uploadMaxFileSize) - - if err := r.ParseMultipartForm(h.uploadMaxFileSize); err != nil { - logger.Error(ctx, "could not parse multipart form", logger.CapturedE(errors.WithStack(err))) - jsonError(w, http.StatusBadRequest, errorCodeBadRequest) - - return - } - - _, fileHeader, err := r.FormFile("file") - if err != nil { - logger.Error(ctx, "could not read form file", logger.CapturedE(errors.WithStack(err))) - jsonError(w, http.StatusBadRequest, errorCodeBadRequest) - - return - } - - var metadata map[string]any - - rawMetadata := r.Form.Get("metadata") - if rawMetadata != "" { - if err := json.Unmarshal([]byte(rawMetadata), &metadata); err != nil { - logger.Error(ctx, "could not parse metadata", logger.CapturedE(errors.WithStack(err))) - jsonError(w, http.StatusBadRequest, errorCodeBadRequest) - - return - } - } - - ctx = module.WithContext(ctx, map[module.ContextKey]any{ - ContextKeyOriginRequest: r, - }) - - requestMsg := blob.NewMessageUploadRequest(ctx, fileHeader, metadata) - - reply, err := h.bus.Request(ctx, requestMsg) - if err != nil { - logger.Error(ctx, "could not retrieve file", logger.CapturedE(errors.WithStack(err))) - jsonError(w, http.StatusInternalServerError, errorCodeInternalError) - - return - } - - logger.Debug(ctx, "upload reply", logger.F("reply", reply)) - - responseMsg, ok := reply.(*blob.MessageUploadResponse) - if !ok { - logger.Error( - ctx, "unexpected upload response message", - logger.F("message", reply), - ) - jsonError(w, http.StatusInternalServerError, errorCodeInternalError) - - return - } - - if !responseMsg.Allow { - jsonError(w, http.StatusForbidden, errorCodeForbidden) - - return - } - - encoder := json.NewEncoder(w) - res := &uploadResponse{ - Bucket: responseMsg.Bucket, - BlobID: responseMsg.BlobID, - } - - if err := encoder.Encode(res); err != nil { - panic(errors.Wrap(err, "could not encode upload response")) - } -} - -func (h *Handler) handleAppDownload(w http.ResponseWriter, r *http.Request) { - h.mutex.RLock() - defer h.mutex.RUnlock() - - bucket := chi.URLParam(r, "bucket") - blobID := chi.URLParam(r, "blobID") - - ctx := logger.With(r.Context(), logger.F("blobID", blobID), logger.F("bucket", bucket)) - ctx = module.WithContext(ctx, map[module.ContextKey]any{ - ContextKeyOriginRequest: r, - }) - - requestMsg := blob.NewMessageDownloadRequest(ctx, bucket, storage.BlobID(blobID)) - - reply, err := h.bus.Request(ctx, requestMsg) - if err != nil { - logger.Error(ctx, "could not retrieve file", logger.CapturedE(errors.WithStack(err))) - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - - return - } - - replyMsg, ok := reply.(*blob.MessageDownloadResponse) - if !ok { - logger.Error( - ctx, "unexpected download response message", - logger.CapturedE(errors.WithStack(bus.ErrUnexpectedMessage)), - logger.F("message", reply), - ) - jsonError(w, http.StatusInternalServerError, errorCodeInternalError) - - return - } - - if !replyMsg.Allow { - jsonError(w, http.StatusForbidden, errorCodeForbidden) - - return - } - - if replyMsg.Blob == nil { - jsonError(w, http.StatusNotFound, errorCodeNotFound) - - return - } - - defer func() { - if err := replyMsg.Blob.Close(); err != nil { - logger.Error(ctx, "could not close blob", logger.CapturedE(errors.WithStack(err))) - } - }() - - http.ServeContent(w, r, string(replyMsg.BlobInfo.ID()), replyMsg.BlobInfo.ModTime(), replyMsg.Blob) -} - -func serveFile(w http.ResponseWriter, r *http.Request, fs fs.FS, path string) { - ctx := logger.With(r.Context(), logger.F("path", path)) - - file, err := fs.Open(path) - if err != nil { - if errors.Is(err, os.ErrNotExist) { - http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) - - return - } - - logger.Error(ctx, "error while opening fs file", logger.CapturedE(errors.WithStack(err))) - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - - return - } - - defer func() { - if err := file.Close(); err != nil { - logger.Error(ctx, "error while closing fs file", logger.CapturedE(errors.WithStack(err))) - } - }() - - info, err := file.Stat() - if err != nil { - logger.Error(ctx, "error while retrieving fs file stat", logger.CapturedE(errors.WithStack(err))) - http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) - - return - } - - reader, ok := file.(io.ReadSeeker) - if !ok { - return - } - - http.ServeContent(w, r, path, info.ModTime(), reader) -} - -type jsonErrorResponse struct { - Error jsonErr `json:"error"` -} - -type jsonErr struct { - Code string `json:"code"` -} - -func jsonError(w http.ResponseWriter, status int, code string) { - w.Header().Add("Content-Type", "application/json") - w.WriteHeader(status) - - encoder := json.NewEncoder(w) - response := jsonErrorResponse{ - Error: jsonErr{ - Code: code, - }, - } - - if err := encoder.Encode(response); err != nil { - panic(errors.WithStack(err)) - } -} - -type uploadedFile struct { - multipart.File - header *multipart.FileHeader - modTime time.Time -} - -// Stat implements fs.File -func (f *uploadedFile) Stat() (fs.FileInfo, error) { - return &uploadedFileInfo{ - header: f.header, - modTime: f.modTime, - }, nil -} - -type uploadedFileInfo struct { - header *multipart.FileHeader - modTime time.Time -} - -// IsDir implements fs.FileInfo -func (i *uploadedFileInfo) IsDir() bool { - return false -} - -// ModTime implements fs.FileInfo -func (i *uploadedFileInfo) ModTime() time.Time { - return i.modTime -} - -// Mode implements fs.FileInfo -func (i *uploadedFileInfo) Mode() fs.FileMode { - return os.ModePerm -} - -// Name implements fs.FileInfo -func (i *uploadedFileInfo) Name() string { - return i.header.Filename -} - -// Size implements fs.FileInfo -func (i *uploadedFileInfo) Size() int64 { - return i.header.Size -} - -// Sys implements fs.FileInfo -func (i *uploadedFileInfo) Sys() any { - return nil -} - -var ( - _ fs.File = &uploadedFile{} - _ fs.FileInfo = &uploadedFileInfo{} -) diff --git a/pkg/http/client.go b/pkg/http/client.go index 96dd0b5..e8383ca 100644 --- a/pkg/http/client.go +++ b/pkg/http/client.go @@ -7,11 +7,11 @@ import ( ) func (h *Handler) handleSDKClient(w http.ResponseWriter, r *http.Request) { - serveFile(w, r, &sdk.FS, "client/dist/client.js") + ServeFile(w, r, &sdk.FS, "client/dist/client.js") } func (h *Handler) handleSDKClientMap(w http.ResponseWriter, r *http.Request) { - serveFile(w, r, &sdk.FS, "client/dist/client.js.map") + ServeFile(w, r, &sdk.FS, "client/dist/client.js.map") } func (h *Handler) handleAppFiles(w http.ResponseWriter, r *http.Request) { diff --git a/pkg/http/context.go b/pkg/http/context.go new file mode 100644 index 0000000..e492dd3 --- /dev/null +++ b/pkg/http/context.go @@ -0,0 +1,55 @@ +package http + +import ( + "context" + + "net/http" + + "forge.cadoles.com/arcad/edge/pkg/bus" + "github.com/pkg/errors" +) + +type contextKey string + +var ( + contextKeyBus contextKey = "bus" + contextKeyHTTPRequest contextKey = "httpRequest" + contextKeyHTTPClient contextKey = "httpClient" +) + +func (h *Handler) contextMiddleware(next http.Handler) http.Handler { + fn := func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + ctx = context.WithValue(ctx, contextKeyBus, h.bus) + ctx = context.WithValue(ctx, contextKeyHTTPRequest, r) + ctx = context.WithValue(ctx, contextKeyHTTPClient, h.httpClient) + + r = r.WithContext(ctx) + + next.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn) +} + +func ContextBus(ctx context.Context) bus.Bus { + return contextValue[bus.Bus](ctx, contextKeyBus) +} + +func ContextHTTPRequest(ctx context.Context) *http.Request { + return contextValue[*http.Request](ctx, contextKeyHTTPRequest) +} + +func ContextHTTPClient(ctx context.Context) *http.Client { + return contextValue[*http.Client](ctx, contextKeyHTTPClient) +} + +func contextValue[T any](ctx context.Context, key any) T { + value, ok := ctx.Value(key).(T) + if !ok { + panic(errors.Errorf("could not find key '%v' on context", key)) + } + + return value +} diff --git a/pkg/http/envelope.go b/pkg/http/envelope.go new file mode 100644 index 0000000..4826377 --- /dev/null +++ b/pkg/http/envelope.go @@ -0,0 +1,30 @@ +package http + +import ( + "context" + + "forge.cadoles.com/arcad/edge/pkg/bus" +) + +var ( + AddressIncomingMessage bus.Address = "http/incoming-message" + AddressOutgoingMessage bus.Address = "http/outgoing-message" +) + +type IncomingMessage struct { + Context context.Context + Payload map[string]any +} + +func NewIncomingMessageEnvelope(ctx context.Context, payload map[string]any) bus.Envelope { + return bus.NewEnvelope(AddressIncomingMessage, &IncomingMessage{ctx, payload}) +} + +type OutgoingMessage struct { + SessionID string + Data any +} + +func NewOutgoingMessageEnvelope(sessionID string, data any) bus.Envelope { + return bus.NewEnvelope(AddressOutgoingMessage, &OutgoingMessage{sessionID, data}) +} diff --git a/pkg/http/handler.go b/pkg/http/handler.go index d486d75..c9cf99e 100644 --- a/pkg/http/handler.go +++ b/pkg/http/handler.go @@ -24,10 +24,9 @@ type Handler struct { public http.Handler router chi.Router - sockjs http.Handler - bus bus.Bus - sockjsOpts sockjs.Options - uploadMaxFileSize int64 + sockjs http.Handler + bus bus.Bus + sockjsOpts sockjs.Options server *app.Server serverModuleFactories []app.ServerModuleFactory @@ -57,10 +56,6 @@ func (h *Handler) Load(ctx context.Context, bdle bundle.Bundle) error { server := app.NewServer(h.serverModuleFactories...) - if err := server.Load(serverMainScript, string(mainScript)); err != nil { - return errors.WithStack(err) - } - fs := bundle.NewFileSystem("public", bdle) public := HTML5Fileserver(fs) sockjs := sockjs.NewHandler(sockJSPathPrefix, h.sockjsOpts, h.handleSockJSSession) @@ -69,7 +64,7 @@ func (h *Handler) Load(ctx context.Context, bdle bundle.Bundle) error { h.server.Stop() } - if err := server.Start(ctx); err != nil { + if err := server.Start(ctx, serverMainScript, string(mainScript)); err != nil { return errors.WithStack(err) } @@ -90,7 +85,6 @@ func NewHandler(funcs ...HandlerOptionFunc) *Handler { router := chi.NewRouter() handler := &Handler{ - uploadMaxFileSize: opts.UploadMaxFileSize, sockjsOpts: opts.SockJS, router: router, serverModuleFactories: opts.ServerModuleFactories, @@ -108,19 +102,15 @@ func NewHandler(funcs ...HandlerOptionFunc) *Handler { r.Get("/client.js.map", handler.handleSDKClientMap) }) - r.Route("/api", func(r chi.Router) { - r.Post("/v1/upload", handler.handleAppUpload) - r.Get("/v1/download/{bucket}/{blobID}", handler.handleAppDownload) - - r.Get("/v1/fetch", handler.handleAppFetch) + r.Group(func(r chi.Router) { + r.Use(handler.contextMiddleware) + for _, fn := range opts.HTTPMounts { + r.Group(func(r chi.Router) { + fn(r) + }) + } }) - for _, fn := range opts.HTTPMounts { - r.Group(func(r chi.Router) { - fn(r) - }) - } - r.HandleFunc("/sock/*", handler.handleSockJS) }) diff --git a/pkg/http/options.go b/pkg/http/options.go index 5b11dad..9c980f4 100644 --- a/pkg/http/options.go +++ b/pkg/http/options.go @@ -15,7 +15,6 @@ type HandlerOptions struct { Bus bus.Bus SockJS sockjs.Options ServerModuleFactories []app.ServerModuleFactory - UploadMaxFileSize int64 HTTPClient *http.Client HTTPMounts []func(r chi.Router) HTTPMiddlewares []func(next http.Handler) http.Handler @@ -31,7 +30,6 @@ func defaultHandlerOptions() *HandlerOptions { Bus: memory.NewBus(), SockJS: sockjsOptions, ServerModuleFactories: make([]app.ServerModuleFactory, 0), - UploadMaxFileSize: 10 << (10 * 2), // 10Mb HTTPClient: &http.Client{ Timeout: time.Second * 30, }, @@ -60,12 +58,6 @@ func WithBus(bus bus.Bus) HandlerOptionFunc { } } -func WithUploadMaxFileSize(size int64) HandlerOptionFunc { - return func(opts *HandlerOptions) { - opts.UploadMaxFileSize = size - } -} - func WithHTTPClient(client *http.Client) HandlerOptionFunc { return func(opts *HandlerOptions) { opts.HTTPClient = client diff --git a/pkg/http/sockjs.go b/pkg/http/sockjs.go index 57020e9..9b7b3c6 100644 --- a/pkg/http/sockjs.go +++ b/pkg/http/sockjs.go @@ -42,19 +42,18 @@ func (h *Handler) handleSockJSSession(sess sockjs.Session) { } }() - go h.handleServerMessages(ctx, sess) - h.handleClientMessages(ctx, sess) + go h.handleOutgoingMessages(ctx, sess) + h.handleIncomingMessages(ctx, sess) } -func (h *Handler) handleServerMessages(ctx context.Context, sess sockjs.Session) { - messages, err := h.bus.Subscribe(ctx, module.MessageNamespaceServer) +func (h *Handler) handleOutgoingMessages(ctx context.Context, sess sockjs.Session) { + envelopes, err := h.bus.Subscribe(ctx, AddressOutgoingMessage) if err != nil { panic(errors.WithStack(err)) } defer func() { - // Close messages subscriber - h.bus.Unsubscribe(ctx, module.MessageNamespaceServer, messages) + h.bus.Unsubscribe(AddressOutgoingMessage, envelopes) logger.Debug(ctx, "unsubscribed") @@ -72,26 +71,22 @@ func (h *Handler) handleServerMessages(ctx context.Context, sess sockjs.Session) case <-ctx.Done(): return - case msg := <-messages: - serverMessage, ok := msg.(*module.ServerMessage) + case env := <-envelopes: + outgoingMessage, ok := env.Message().(*OutgoingMessage) if !ok { logger.Error( ctx, - "unexpected server message", - logger.F("message", msg), + "unexpected outgoing message", + logger.F("message", env.Message()), ) - - continue } - sessionID := module.ContextValue[string](serverMessage.Context, ContextKeySessionID) - - isDest := sessionID == "" || sessionID == sess.ID() + isDest := outgoingMessage.SessionID == "" || outgoingMessage.SessionID == sess.ID() if !isDest { continue } - payload, err := json.Marshal(serverMessage.Data) + payload, err := json.Marshal(outgoingMessage.Data) if err != nil { logger.Error( ctx, @@ -132,7 +127,7 @@ func (h *Handler) handleServerMessages(ctx context.Context, sess sockjs.Session) } } -func (h *Handler) handleClientMessages(ctx context.Context, sess sockjs.Session) { +func (h *Handler) handleIncomingMessages(ctx context.Context, sess sockjs.Session) { for { select { case <-ctx.Done(): @@ -145,7 +140,7 @@ func (h *Handler) handleClientMessages(ctx context.Context, sess sockjs.Session) data, err := sess.RecvCtx(ctx) if err != nil { - if errors.Is(err, sockjs.ErrSessionNotOpen) { + if errors.Is(err, sockjs.ErrSessionNotOpen) || errors.Is(err, context.Canceled) { break } @@ -174,7 +169,7 @@ func (h *Handler) handleClientMessages(ctx context.Context, sess sockjs.Session) switch { case message.Type == WebsocketMessageTypeMessage: - var payload map[string]interface{} + var payload map[string]any if err := json.Unmarshal(message.Payload, &payload); err != nil { logger.Error( ctx, @@ -191,21 +186,19 @@ func (h *Handler) handleClientMessages(ctx context.Context, sess sockjs.Session) ContextKeyOriginRequest: sess.Request(), }) - clientMessage := module.NewClientMessage(ctx, payload) + incomingMessage := NewIncomingMessageEnvelope(ctx, payload) - logger.Debug(ctx, "publishing new client message", logger.F("message", clientMessage)) + logger.Debug(ctx, "publishing new incoming message", logger.F("message", incomingMessage)) - if err := h.bus.Publish(ctx, clientMessage); err != nil { + if err := h.bus.Publish(incomingMessage); err != nil { logger.Error(ctx, "could not publish message", logger.CapturedE(errors.WithStack(err)), - logger.F("message", clientMessage), + logger.F("message", incomingMessage), ) return } - logger.Debug(ctx, "new client message published", logger.F("message", clientMessage)) - default: logger.Error( ctx, diff --git a/pkg/http/util.go b/pkg/http/util.go new file mode 100644 index 0000000..4a7b328 --- /dev/null +++ b/pkg/http/util.go @@ -0,0 +1,82 @@ +package http + +import ( + "encoding/json" + "io" + "io/fs" + "net/http" + "os" + + "github.com/pkg/errors" + "gitlab.com/wpetit/goweb/logger" +) + +const ( + ErrCodeForbidden = "forbidden" + ErrCodeInternalError = "internal-error" + ErrCodeBadRequest = "bad-request" + ErrCodeNotFound = "not-found" +) + +type jsonErrorResponse struct { + Error jsonErr `json:"error"` +} + +type jsonErr struct { + Code string `json:"code"` +} + +func JSONError(w http.ResponseWriter, status int, code string) { + w.Header().Add("Content-Type", "application/json") + w.WriteHeader(status) + + encoder := json.NewEncoder(w) + response := jsonErrorResponse{ + Error: jsonErr{ + Code: code, + }, + } + + if err := encoder.Encode(response); err != nil { + panic(errors.WithStack(err)) + } +} + +func ServeFile(w http.ResponseWriter, r *http.Request, fs fs.FS, path string) { + ctx := logger.With(r.Context(), logger.F("path", path)) + + file, err := fs.Open(path) + if err != nil { + if errors.Is(err, os.ErrNotExist) { + http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) + + return + } + + logger.Error(ctx, "error while opening fs file", logger.CapturedE(errors.WithStack(err))) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + + return + } + + defer func() { + if err := file.Close(); err != nil { + logger.Error(ctx, "error while closing fs file", logger.CapturedE(errors.WithStack(err))) + } + }() + + info, err := file.Stat() + if err != nil { + logger.Error(ctx, "error while retrieving fs file stat", logger.CapturedE(errors.WithStack(err))) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + + return + } + + reader, ok := file.(io.ReadSeeker) + if !ok { + return + } + + http.ServeContent(w, r, path, info.ModTime(), reader) +} diff --git a/pkg/module/app/memory/module_test.go b/pkg/module/app/memory/module_test.go index 100a1c8..064f26a 100644 --- a/pkg/module/app/memory/module_test.go +++ b/pkg/module/app/memory/module_test.go @@ -39,21 +39,17 @@ func TestAppModuleWithMemoryRepository(t *testing.T) { )), ) - file := "testdata/app.js" + script := "testdata/app.js" - data, err := os.ReadFile(file) + data, err := os.ReadFile(script) if err != nil { t.Fatal(err) } - if err := server.Load(file, string(data)); err != nil { - t.Fatal(err) + ctx := context.Background() + if err := server.Start(ctx, script, string(data)); err != nil { + t.Fatalf("%+v", errors.WithStack(err)) } defer server.Stop() - - ctx := context.Background() - if err := server.Start(ctx); err != nil { - t.Fatalf("%+v", errors.WithStack(err)) - } } diff --git a/pkg/module/auth/module_test.go b/pkg/module/auth/module_test.go index e127add..de14312 100644 --- a/pkg/module/auth/module_test.go +++ b/pkg/module/auth/module_test.go @@ -2,8 +2,8 @@ package auth import ( "context" - "io/ioutil" "net/http" + "os" "testing" "time" @@ -22,7 +22,9 @@ import ( func TestAuthModule(t *testing.T) { t.Parallel() - logger.SetLevel(slog.LevelDebug) + if testing.Verbose() { + logger.SetLevel(slog.LevelDebug) + } key := getDummyKey() @@ -33,17 +35,15 @@ func TestAuthModule(t *testing.T) { ), ) - data, err := ioutil.ReadFile("testdata/auth.js") + script := "testdata/auth.js" + + data, err := os.ReadFile(script) if err != nil { t.Fatal(err) } - if err := server.Load("testdata/auth.js", string(data)); err != nil { - t.Fatal(err) - } - ctx := context.Background() - if err := server.Start(ctx); err != nil { + if err := server.Start(ctx, script, string(data)); err != nil { t.Fatalf("%+v", errors.WithStack(err)) } @@ -81,7 +81,9 @@ func TestAuthModule(t *testing.T) { func TestAuthAnonymousModule(t *testing.T) { t.Parallel() - logger.SetLevel(slog.LevelDebug) + if testing.Verbose() { + logger.SetLevel(slog.LevelDebug) + } key := getDummyKey() @@ -90,17 +92,15 @@ func TestAuthAnonymousModule(t *testing.T) { ModuleFactory(WithJWT(getDummyKeySet(key))), ) - data, err := ioutil.ReadFile("testdata/auth_anonymous.js") + script := "testdata/auth_anonymous.js" + + data, err := os.ReadFile("testdata/auth_anonymous.js") if err != nil { t.Fatal(err) } - if err := server.Load("testdata/auth_anonymous.js", string(data)); err != nil { - t.Fatal(err) - } - ctx := context.Background() - if err := server.Start(ctx); err != nil { + if err := server.Start(ctx, script, string(data)); err != nil { t.Fatalf("%+v", errors.WithStack(err)) } diff --git a/pkg/module/blob/blob_message.go b/pkg/module/blob/blob_message.go deleted file mode 100644 index d355a90..0000000 --- a/pkg/module/blob/blob_message.go +++ /dev/null @@ -1,92 +0,0 @@ -package blob - -import ( - "context" - "io" - "mime/multipart" - - "forge.cadoles.com/arcad/edge/pkg/bus" - "forge.cadoles.com/arcad/edge/pkg/storage" - "github.com/oklog/ulid/v2" -) - -const ( - MessageNamespaceUploadRequest bus.MessageNamespace = "uploadRequest" - MessageNamespaceUploadResponse bus.MessageNamespace = "uploadResponse" - MessageNamespaceDownloadRequest bus.MessageNamespace = "downloadRequest" - MessageNamespaceDownloadResponse bus.MessageNamespace = "downloadResponse" -) - -type MessageUploadRequest struct { - Context context.Context - RequestID string - FileHeader *multipart.FileHeader - Metadata map[string]interface{} -} - -func (m *MessageUploadRequest) MessageNamespace() bus.MessageNamespace { - return MessageNamespaceUploadRequest -} - -func NewMessageUploadRequest(ctx context.Context, fileHeader *multipart.FileHeader, metadata map[string]interface{}) *MessageUploadRequest { - return &MessageUploadRequest{ - Context: ctx, - RequestID: ulid.Make().String(), - FileHeader: fileHeader, - Metadata: metadata, - } -} - -type MessageUploadResponse struct { - RequestID string - BlobID storage.BlobID - Bucket string - Allow bool -} - -func (m *MessageUploadResponse) MessageNamespace() bus.MessageNamespace { - return MessageNamespaceDownloadResponse -} - -func NewMessageUploadResponse(requestID string) *MessageUploadResponse { - return &MessageUploadResponse{ - RequestID: requestID, - } -} - -type MessageDownloadRequest struct { - Context context.Context - RequestID string - Bucket string - BlobID storage.BlobID -} - -func (m *MessageDownloadRequest) MessageNamespace() bus.MessageNamespace { - return MessageNamespaceDownloadRequest -} - -func NewMessageDownloadRequest(ctx context.Context, bucket string, blobID storage.BlobID) *MessageDownloadRequest { - return &MessageDownloadRequest{ - Context: ctx, - RequestID: ulid.Make().String(), - Bucket: bucket, - BlobID: blobID, - } -} - -type MessageDownloadResponse struct { - RequestID string - Allow bool - BlobInfo storage.BlobInfo - Blob io.ReadSeekCloser -} - -func (m *MessageDownloadResponse) MessageNamespace() bus.MessageNamespace { - return MessageNamespaceDownloadResponse -} - -func NewMessageDownloadResponse(requestID string) *MessageDownloadResponse { - return &MessageDownloadResponse{ - RequestID: requestID, - } -} diff --git a/pkg/module/blob/envelope.go b/pkg/module/blob/envelope.go new file mode 100644 index 0000000..ea57ba5 --- /dev/null +++ b/pkg/module/blob/envelope.go @@ -0,0 +1,55 @@ +package blob + +import ( + "context" + "io" + "mime/multipart" + + "forge.cadoles.com/arcad/edge/pkg/bus" + "forge.cadoles.com/arcad/edge/pkg/storage" +) + +const ( + AddressUpload bus.Address = "module/blob/upload" + AddressDownload bus.Address = "module/blob/download" +) + +type UploadRequest struct { + Context context.Context + FileHeader *multipart.FileHeader + Metadata map[string]interface{} +} + +func NewUploadRequestEnvelope(ctx context.Context, fileHeader *multipart.FileHeader, metadata map[string]interface{}) bus.Envelope { + return bus.NewEnvelope(AddressUpload, &UploadRequest{ + Context: ctx, + FileHeader: fileHeader, + Metadata: metadata, + }) +} + +type UploadResponse struct { + Allow bool + Bucket string + BlobID storage.BlobID +} + +type DownloadRequest struct { + Context context.Context + Bucket string + BlobID storage.BlobID +} + +func NewDownloadRequestEnvelope(ctx context.Context, bucket string, blobID storage.BlobID) bus.Envelope { + return bus.NewEnvelope(AddressDownload, &DownloadRequest{ + Context: ctx, + Bucket: bucket, + BlobID: blobID, + }) +} + +type DownloadResponse struct { + Allow bool + Blob io.ReadSeekCloser + BlobInfo storage.BlobInfo +} diff --git a/pkg/module/blob/http.go b/pkg/module/blob/http.go new file mode 100644 index 0000000..fc4edcb --- /dev/null +++ b/pkg/module/blob/http.go @@ -0,0 +1,212 @@ +package blob + +import ( + "encoding/json" + "io/fs" + "mime/multipart" + "net/http" + "os" + "time" + + "forge.cadoles.com/arcad/edge/pkg/bus" + edgehttp "forge.cadoles.com/arcad/edge/pkg/http" + "forge.cadoles.com/arcad/edge/pkg/storage" + "github.com/go-chi/chi/v5" + "github.com/pkg/errors" + "gitlab.com/wpetit/goweb/logger" +) + +type uploadResponse struct { + Bucket string `json:"bucket"` + BlobID storage.BlobID `json:"blobId"` +} + +func Mount(uploadMaxFileSize int) func(r chi.Router) { + return func(r chi.Router) { + r.Post("/api/v1/upload", getAppUploadHandler(uploadMaxFileSize)) + r.Get("/api/v1/download/{bucket}/{blobID}", handleAppDownload) + } +} + +func getAppUploadHandler(fileMaxUpload int) func(w http.ResponseWriter, r *http.Request) { + return func(w http.ResponseWriter, r *http.Request) { + ctx := r.Context() + + uploadMaxFileSize := int64(8000) + + r.Body = http.MaxBytesReader(w, r.Body, uploadMaxFileSize) + + if err := r.ParseMultipartForm(uploadMaxFileSize); err != nil { + logger.Error(ctx, "could not parse multipart form", logger.CapturedE(errors.WithStack(err))) + edgehttp.JSONError(w, http.StatusBadRequest, edgehttp.ErrCodeBadRequest) + + return + } + + _, fileHeader, err := r.FormFile("file") + if err != nil { + logger.Error(ctx, "could not read form file", logger.CapturedE(errors.WithStack(err))) + edgehttp.JSONError(w, http.StatusBadRequest, edgehttp.ErrCodeBadRequest) + + return + } + + var metadata map[string]any + + rawMetadata := r.Form.Get("metadata") + if rawMetadata != "" { + if err := json.Unmarshal([]byte(rawMetadata), &metadata); err != nil { + logger.Error(ctx, "could not parse metadata", logger.CapturedE(errors.WithStack(err))) + edgehttp.JSONError(w, http.StatusBadRequest, edgehttp.ErrCodeBadRequest) + + return + } + } + + requestEnv := NewUploadRequestEnvelope(ctx, fileHeader, metadata) + + bus := edgehttp.ContextBus(ctx) + + reply, err := bus.Request(ctx, requestEnv) + if err != nil { + logger.Error(ctx, "could not retrieve file", logger.CapturedE(errors.WithStack(err))) + edgehttp.JSONError(w, http.StatusInternalServerError, edgehttp.ErrCodeInternalError) + + return + } + + logger.Debug(ctx, "upload reply", logger.F("reply", reply)) + + replyMessage, ok := reply.Message().(*UploadResponse) + if !ok { + logger.Error( + ctx, "unexpected upload response message", + logger.F("message", reply.Message()), + ) + edgehttp.JSONError(w, http.StatusInternalServerError, edgehttp.ErrCodeInternalError) + + return + } + + if !replyMessage.Allow { + edgehttp.JSONError(w, http.StatusForbidden, edgehttp.ErrCodeForbidden) + + return + } + + encoder := json.NewEncoder(w) + res := &uploadResponse{ + Bucket: replyMessage.Bucket, + BlobID: replyMessage.BlobID, + } + + if err := encoder.Encode(res); err != nil { + panic(errors.Wrap(err, "could not encode upload response")) + } + } +} + +func handleAppDownload(w http.ResponseWriter, r *http.Request) { + bucket := chi.URLParam(r, "bucket") + blobID := chi.URLParam(r, "blobID") + + ctx := logger.With(r.Context(), logger.F("blobID", blobID), logger.F("bucket", bucket)) + + requestMsg := NewDownloadRequestEnvelope(ctx, bucket, storage.BlobID(blobID)) + + bs := edgehttp.ContextBus(ctx) + + reply, err := bs.Request(ctx, requestMsg) + if err != nil { + logger.Error(ctx, "could not retrieve file", logger.CapturedE(errors.WithStack(err))) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + + return + } + + replyMessage, ok := reply.Message().(*DownloadResponse) + if !ok { + logger.Error( + ctx, "unexpected download response message", + logger.CapturedE(errors.WithStack(bus.ErrUnexpectedMessage)), + logger.F("message", reply), + ) + edgehttp.JSONError(w, http.StatusInternalServerError, edgehttp.ErrCodeInternalError) + + return + } + + if !replyMessage.Allow { + edgehttp.JSONError(w, http.StatusForbidden, edgehttp.ErrCodeForbidden) + + return + } + + if replyMessage.Blob == nil { + edgehttp.JSONError(w, http.StatusNotFound, edgehttp.ErrCodeNotFound) + + return + } + + defer func() { + if err := replyMessage.Blob.Close(); err != nil { + logger.Error(ctx, "could not close blob", logger.CapturedE(errors.WithStack(err))) + } + }() + + http.ServeContent(w, r, string(replyMessage.BlobInfo.ID()), replyMessage.BlobInfo.ModTime(), replyMessage.Blob) +} + +type uploadedFile struct { + multipart.File + header *multipart.FileHeader + modTime time.Time +} + +// Stat implements fs.File +func (f *uploadedFile) Stat() (fs.FileInfo, error) { + return &uploadedFileInfo{ + header: f.header, + modTime: f.modTime, + }, nil +} + +type uploadedFileInfo struct { + header *multipart.FileHeader + modTime time.Time +} + +// IsDir implements fs.FileInfo +func (i *uploadedFileInfo) IsDir() bool { + return false +} + +// ModTime implements fs.FileInfo +func (i *uploadedFileInfo) ModTime() time.Time { + return i.modTime +} + +// Mode implements fs.FileInfo +func (i *uploadedFileInfo) Mode() fs.FileMode { + return os.ModePerm +} + +// Name implements fs.FileInfo +func (i *uploadedFileInfo) Name() string { + return i.header.Filename +} + +// Size implements fs.FileInfo +func (i *uploadedFileInfo) Size() int64 { + return i.header.Size +} + +// Sys implements fs.FileInfo +func (i *uploadedFileInfo) Sys() any { + return nil +} + +var ( + _ fs.File = &uploadedFile{} + _ fs.FileInfo = &uploadedFileInfo{} +) diff --git a/pkg/module/blob/module.go b/pkg/module/blob/module.go index dedc245..7d34023 100644 --- a/pkg/module/blob/module.go +++ b/pkg/module/blob/module.go @@ -236,33 +236,34 @@ func (m *Module) getBucketSize(call goja.FunctionCall, rt *goja.Runtime) goja.Va func (m *Module) handleMessages() { ctx := context.Background() - go func() { - err := m.bus.Reply(ctx, MessageNamespaceUploadRequest, func(msg bus.Message) (bus.Message, error) { - uploadRequest, ok := msg.(*MessageUploadRequest) - if !ok { - return nil, errors.Wrapf(bus.ErrUnexpectedMessage, "expected message upload request, got '%T'", msg) - } + uploadRequestErrs := m.bus.Reply(ctx, AddressUpload, func(env bus.Envelope) (any, error) { + uploadRequest, ok := env.Message().(*UploadRequest) + if !ok { + return nil, errors.Wrapf(bus.ErrUnexpectedMessage, "expected message upload request, got '%T'", env.Message()) + } - res, err := m.handleUploadRequest(uploadRequest) - if err != nil { - logger.Error(ctx, "could not handle upload request", logger.CapturedE(errors.WithStack(err))) - - return nil, errors.WithStack(err) - } - - logger.Debug(ctx, "upload request response", logger.F("response", res)) - - return res, nil - }) + res, err := m.handleUploadRequest(uploadRequest) if err != nil { - panic(errors.WithStack(err)) + logger.Error(ctx, "could not handle upload request", logger.CapturedE(errors.WithStack(err))) + + return nil, errors.WithStack(err) + } + + logger.Debug(ctx, "upload request response", logger.F("response", res)) + + return res, nil + }) + + go func() { + for err := range uploadRequestErrs { + logger.Error(ctx, "error while replying to upload requests", logger.CapturedE(errors.WithStack(err))) } }() - err := m.bus.Reply(ctx, MessageNamespaceDownloadRequest, func(msg bus.Message) (bus.Message, error) { - downloadRequest, ok := msg.(*MessageDownloadRequest) + downloadRequestErrs := m.bus.Reply(ctx, AddressDownload, func(env bus.Envelope) (any, error) { + downloadRequest, ok := env.Message().(*DownloadRequest) if !ok { - return nil, errors.Wrapf(bus.ErrUnexpectedMessage, "expected message download request, got '%T'", msg) + return nil, errors.Wrapf(bus.ErrUnexpectedMessage, "expected message download request, got '%T'", env.Message()) } res, err := m.handleDownloadRequest(downloadRequest) @@ -274,14 +275,15 @@ func (m *Module) handleMessages() { return res, nil }) - if err != nil { - panic(errors.WithStack(err)) + + for err := range downloadRequestErrs { + logger.Fatal(ctx, "error while replying to download requests", logger.CapturedE(errors.WithStack(err))) } } -func (m *Module) handleUploadRequest(req *MessageUploadRequest) (*MessageUploadResponse, error) { +func (m *Module) handleUploadRequest(req *UploadRequest) (*UploadResponse, error) { blobID := storage.NewBlobID() - res := NewMessageUploadResponse(req.RequestID) + res := &UploadResponse{} ctx := logger.With(req.Context, logger.F("blobID", blobID)) @@ -302,11 +304,11 @@ func (m *Module) handleUploadRequest(req *MessageUploadRequest) (*MessageUploadR return nil, errors.WithStack(err) } - result, ok := rawResult.Export().(map[string]interface{}) + result, ok := rawResult.(map[string]interface{}) if !ok { return nil, errors.Errorf( "unexpected onBlobUpload result: expected 'map[string]interface{}', got '%T'", - rawResult.Export(), + rawResult, ) } @@ -393,8 +395,8 @@ func (m *Module) saveBlob(ctx context.Context, bucketName string, blobID storage return nil } -func (m *Module) handleDownloadRequest(req *MessageDownloadRequest) (*MessageDownloadResponse, error) { - res := NewMessageDownloadResponse(req.RequestID) +func (m *Module) handleDownloadRequest(req *DownloadRequest) (*DownloadResponse, error) { + res := &DownloadResponse{} rawResult, err := m.server.ExecFuncByName(req.Context, "onBlobDownload", req.Context, req.Bucket, req.BlobID) if err != nil { @@ -407,11 +409,11 @@ func (m *Module) handleDownloadRequest(req *MessageDownloadRequest) (*MessageDow return nil, errors.WithStack(err) } - result, ok := rawResult.Export().(map[string]interface{}) + result, ok := rawResult.(map[string]interface{}) if !ok { return nil, errors.Errorf( "unexpected onBlobDownload result: expected 'map[string]interface{}', got '%T'", - rawResult.Export(), + rawResult, ) } diff --git a/pkg/module/blob/module_test.go b/pkg/module/blob/module_test.go index 7c89bb3..bbaadcf 100644 --- a/pkg/module/blob/module_test.go +++ b/pkg/module/blob/module_test.go @@ -17,7 +17,9 @@ import ( func TestBlobModule(t *testing.T) { t.Parallel() - logger.SetLevel(slog.LevelDebug) + if testing.Verbose() { + logger.SetLevel(slog.LevelDebug) + } bus := memory.NewBus() store := sqlite.NewBlobStore(":memory:?_pragma=foreign_keys(1)&_pragma=busy_timeout=60000") @@ -28,19 +30,17 @@ func TestBlobModule(t *testing.T) { ModuleFactory(bus, store), ) - data, err := os.ReadFile("testdata/blob.js") + script := "testdata/blob.js" + + data, err := os.ReadFile(script) if err != nil { t.Fatal(err) } - if err := server.Load("testdata/blob.js", string(data)); err != nil { - t.Fatal(err) + ctx := context.Background() + if err := server.Start(ctx, script, string(data)); err != nil { + t.Fatalf("%+v", errors.WithStack(err)) } defer server.Stop() - - ctx := context.Background() - if err := server.Start(ctx); err != nil { - t.Fatalf("%+v", errors.WithStack(err)) - } } diff --git a/pkg/module/cast/cast_test.go b/pkg/module/cast/cast_test.go index ea3dd06..3170d09 100644 --- a/pkg/module/cast/cast_test.go +++ b/pkg/module/cast/cast_test.go @@ -21,7 +21,9 @@ func TestCastLoadURL(t *testing.T) { return } - logger.SetLevel(slog.LevelDebug) + if testing.Verbose() { + logger.SetLevel(slog.LevelDebug) + } ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() diff --git a/pkg/module/cast/module_test.go b/pkg/module/cast/module_test.go index 37c298f..f5bed87 100644 --- a/pkg/module/cast/module_test.go +++ b/pkg/module/cast/module_test.go @@ -2,7 +2,6 @@ package cast import ( "context" - "io/ioutil" "os" "testing" "time" @@ -24,24 +23,24 @@ func TestCastModule(t *testing.T) { return } - logger.SetLevel(slog.LevelDebug) + if testing.Verbose() { + logger.SetLevel(slog.LevelDebug) + } server := app.NewServer( module.ConsoleModuleFactory(), CastModuleFactory(), ) - data, err := ioutil.ReadFile("testdata/cast.js") + script := "testdata/cast.js" + + data, err := os.ReadFile(script) if err != nil { t.Fatal(err) } - if err := server.Load("testdata/cast.js", string(data)); err != nil { - t.Fatal(err) - } - ctx := context.Background() - if err := server.Start(ctx); err != nil { + if err := server.Start(ctx, script, string(data)); err != nil { t.Fatalf("%+v", errors.WithStack(err)) } @@ -59,24 +58,24 @@ func TestCastModuleRefreshDevices(t *testing.T) { return } - logger.SetLevel(slog.LevelDebug) + if testing.Verbose() { + logger.SetLevel(slog.LevelDebug) + } server := app.NewServer( module.ConsoleModuleFactory(), CastModuleFactory(), ) - data, err := ioutil.ReadFile("testdata/refresh_devices.js") + script := "testdata/refresh_devices.js" + + data, err := os.ReadFile(script) if err != nil { t.Fatal(err) } - if err := server.Load("testdata/refresh_devices.js", string(data)); err != nil { - t.Fatal(err) - } - ctx := context.Background() - if err := server.Start(ctx); err != nil { + if err := server.Start(ctx, script, string(data)); err != nil { t.Fatalf("%+v", errors.WithStack(err)) } @@ -87,12 +86,5 @@ func TestCastModuleRefreshDevices(t *testing.T) { t.Error(errors.WithStack(err)) } - promise, ok := app.IsPromise(result) - if !ok { - t.Fatal("expected promise") - } - - value := server.WaitForPromise(promise) - - spew.Dump(value.Export()) + spew.Dump(result) } diff --git a/pkg/module/fetch/envelope.go b/pkg/module/fetch/envelope.go new file mode 100644 index 0000000..2af30f7 --- /dev/null +++ b/pkg/module/fetch/envelope.go @@ -0,0 +1,38 @@ +package fetch + +import ( + "context" + "net/url" + + "forge.cadoles.com/arcad/edge/pkg/bus" +) + +const ( + AddressFetchRequest bus.Address = "module/fetch/request" + AddressFetchResponse bus.Address = "module/fetch/response" +) + +type FetchRequest struct { + Context context.Context + RequestID string + URL *url.URL + RemoteAddr string +} + +func NewFetchRequestEnvelope(ctx context.Context, remoteAddr string, url *url.URL) bus.Envelope { + return bus.NewEnvelope(AddressFetchRequest, &FetchRequest{ + Context: ctx, + URL: url, + RemoteAddr: remoteAddr, + }) +} + +type FetchResponse struct { + Allow bool +} + +func NewFetchResponseEnvelope(allow bool) bus.Envelope { + return bus.NewEnvelope(AddressFetchResponse, &FetchResponse{ + Allow: allow, + }) +} diff --git a/pkg/module/fetch/fetch_message.go b/pkg/module/fetch/fetch_message.go deleted file mode 100644 index f2493e1..0000000 --- a/pkg/module/fetch/fetch_message.go +++ /dev/null @@ -1,49 +0,0 @@ -package fetch - -import ( - "context" - "net/url" - - "forge.cadoles.com/arcad/edge/pkg/bus" - "github.com/oklog/ulid/v2" -) - -const ( - MessageNamespaceFetchRequest bus.MessageNamespace = "fetchRequest" - MessageNamespaceFetchResponse bus.MessageNamespace = "fetchResponse" -) - -type MessageFetchRequest struct { - Context context.Context - RequestID string - URL *url.URL - RemoteAddr string -} - -func (m *MessageFetchRequest) MessageNamespace() bus.MessageNamespace { - return MessageNamespaceFetchRequest -} - -func NewMessageFetchRequest(ctx context.Context, remoteAddr string, url *url.URL) *MessageFetchRequest { - return &MessageFetchRequest{ - Context: ctx, - RequestID: ulid.Make().String(), - RemoteAddr: remoteAddr, - URL: url, - } -} - -type MessageFetchResponse struct { - RequestID string - Allow bool -} - -func (m *MessageFetchResponse) MessageNamespace() bus.MessageNamespace { - return MessageNamespaceFetchResponse -} - -func NewMessageFetchResponse(requestID string) *MessageFetchResponse { - return &MessageFetchResponse{ - RequestID: requestID, - } -} diff --git a/pkg/http/fetch.go b/pkg/module/fetch/http.go similarity index 59% rename from pkg/http/fetch.go rename to pkg/module/fetch/http.go index 87c71e4..e30b975 100644 --- a/pkg/http/fetch.go +++ b/pkg/module/fetch/http.go @@ -1,60 +1,61 @@ -package http +package fetch import ( "io" "net/http" "net/url" - "forge.cadoles.com/arcad/edge/pkg/module" - "forge.cadoles.com/arcad/edge/pkg/module/fetch" + edgehttp "forge.cadoles.com/arcad/edge/pkg/http" + "github.com/go-chi/chi/v5" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" ) -func (h *Handler) handleAppFetch(w http.ResponseWriter, r *http.Request) { - h.mutex.RLock() - defer h.mutex.RUnlock() +func Mount() func(r chi.Router) { + return func(r chi.Router) { + r.Get("/api/v1/fetch", handleAppFetch) + } +} +func handleAppFetch(w http.ResponseWriter, r *http.Request) { ctx := r.Context() - ctx = module.WithContext(ctx, map[module.ContextKey]any{ - ContextKeyOriginRequest: r, - }) - rawURL := r.URL.Query().Get("url") url, err := url.Parse(rawURL) if err != nil { - jsonError(w, http.StatusBadRequest, errorCodeBadRequest) + edgehttp.JSONError(w, http.StatusBadRequest, edgehttp.ErrCodeBadRequest) return } - requestMsg := fetch.NewMessageFetchRequest(ctx, r.RemoteAddr, url) + requestMsg := NewFetchRequestEnvelope(ctx, r.RemoteAddr, url) - reply, err := h.bus.Request(ctx, requestMsg) + bus := edgehttp.ContextBus(ctx) + + reply, err := bus.Request(ctx, requestMsg) if err != nil { logger.Error(ctx, "could not retrieve fetch request reply", logger.CapturedE(errors.WithStack(err))) - jsonError(w, http.StatusInternalServerError, errorCodeInternalError) + edgehttp.JSONError(w, http.StatusInternalServerError, edgehttp.ErrCodeInternalError) return } logger.Debug(ctx, "fetch reply", logger.F("reply", reply)) - responseMsg, ok := reply.(*fetch.MessageFetchResponse) + responseMsg, ok := reply.Message().(*FetchResponse) if !ok { logger.Error( ctx, "unexpected fetch response message", logger.F("message", reply), ) - jsonError(w, http.StatusInternalServerError, errorCodeInternalError) + edgehttp.JSONError(w, http.StatusInternalServerError, edgehttp.ErrCodeInternalError) return } if !responseMsg.Allow { - jsonError(w, http.StatusForbidden, errorCodeForbidden) + edgehttp.JSONError(w, http.StatusForbidden, edgehttp.ErrCodeForbidden) return } @@ -65,7 +66,7 @@ func (h *Handler) handleAppFetch(w http.ResponseWriter, r *http.Request) { ctx, "could not create proxy request", logger.CapturedE(errors.WithStack(err)), ) - jsonError(w, http.StatusInternalServerError, errorCodeInternalError) + edgehttp.JSONError(w, http.StatusInternalServerError, edgehttp.ErrCodeInternalError) return } @@ -78,13 +79,15 @@ func (h *Handler) handleAppFetch(w http.ResponseWriter, r *http.Request) { proxyReq.Header.Add("X-Forwarded-From", r.RemoteAddr) - res, err := h.httpClient.Do(proxyReq) + httpClient := edgehttp.ContextHTTPClient(ctx) + + res, err := httpClient.Do(proxyReq) if err != nil { logger.Error( ctx, "could not execute proxy request", logger.CapturedE(errors.WithStack(err)), ) - jsonError(w, http.StatusInternalServerError, errorCodeInternalError) + edgehttp.JSONError(w, http.StatusInternalServerError, edgehttp.ErrCodeInternalError) return } diff --git a/pkg/module/fetch/module.go b/pkg/module/fetch/module.go index fdc3930..9da5a16 100644 --- a/pkg/module/fetch/module.go +++ b/pkg/module/fetch/module.go @@ -40,10 +40,10 @@ func (m *Module) get(call goja.FunctionCall, rt *goja.Runtime) goja.Value { func (m *Module) handleMessages() { ctx := context.Background() - err := m.bus.Reply(ctx, MessageNamespaceFetchRequest, func(msg bus.Message) (bus.Message, error) { - fetchRequest, ok := msg.(*MessageFetchRequest) + fetchErrs := m.bus.Reply(ctx, AddressFetchRequest, func(env bus.Envelope) (any, error) { + fetchRequest, ok := env.Message().(*FetchRequest) if !ok { - return nil, errors.Wrapf(bus.ErrUnexpectedMessage, "expected message fetch request, got '%T'", msg) + return nil, errors.Wrapf(bus.ErrUnexpectedMessage, "expected fetch request, got '%T'", env.Message()) } res, err := m.handleFetchRequest(fetchRequest) @@ -57,13 +57,14 @@ func (m *Module) handleMessages() { return res, nil }) - if err != nil { - panic(errors.WithStack(err)) + + for err := range fetchErrs { + logger.Fatal(ctx, "error while replying to fetch requests", logger.CapturedE(errors.WithStack(err))) } } -func (m *Module) handleFetchRequest(req *MessageFetchRequest) (*MessageFetchResponse, error) { - res := NewMessageFetchResponse(req.RequestID) +func (m *Module) handleFetchRequest(req *FetchRequest) (*FetchResponse, error) { + res := &FetchResponse{} ctx := logger.With( req.Context, @@ -83,11 +84,11 @@ func (m *Module) handleFetchRequest(req *MessageFetchRequest) (*MessageFetchResp return nil, errors.WithStack(err) } - result, ok := rawResult.Export().(map[string]interface{}) + result, ok := rawResult.(map[string]interface{}) if !ok { return nil, errors.Errorf( "unexpected onClientFetch result: expected 'map[string]interface{}', got '%T'", - rawResult.Export(), + rawResult, ) } diff --git a/pkg/module/fetch/module_test.go b/pkg/module/fetch/module_test.go index a9e8c49..88ba88b 100644 --- a/pkg/module/fetch/module_test.go +++ b/pkg/module/fetch/module_test.go @@ -2,8 +2,8 @@ package fetch import ( "context" - "io/ioutil" "net/url" + "os" "testing" "time" @@ -18,7 +18,9 @@ import ( func TestFetchModule(t *testing.T) { t.Parallel() - logger.SetLevel(slog.LevelDebug) + if testing.Verbose() { + logger.SetLevel(slog.LevelDebug) + } bus := memory.NewBus() @@ -28,22 +30,20 @@ func TestFetchModule(t *testing.T) { ModuleFactory(bus), ) - data, err := ioutil.ReadFile("testdata/fetch.js") + path := "testdata/fetch.js" + + data, err := os.ReadFile(path) if err != nil { t.Fatalf("%+v", errors.WithStack(err)) } - if err := server.Load("testdata/fetch.js", string(data)); err != nil { + ctx := context.Background() + if err := server.Start(ctx, path, string(data)); err != nil { t.Fatalf("%+v", errors.WithStack(err)) } defer server.Stop() - ctx := context.Background() - if err := server.Start(ctx); err != nil { - t.Fatalf("%+v", errors.WithStack(err)) - } - // Wait for module to startup time.Sleep(1 * time.Second) @@ -53,33 +53,33 @@ func TestFetchModule(t *testing.T) { remoteAddr := "127.0.0.1" url, _ := url.Parse("http://example.com") - rawReply, err := bus.Request(ctx, NewMessageFetchRequest(ctx, remoteAddr, url)) + reply, err := bus.Request(ctx, NewFetchRequestEnvelope(ctx, remoteAddr, url)) if err != nil { t.Fatalf("%+v", errors.WithStack(err)) } - reply, ok := rawReply.(*MessageFetchResponse) + response, ok := reply.Message().(*FetchResponse) if !ok { - t.Fatalf("unexpected reply type '%T'", rawReply) + t.Fatalf("unexpected reply message type '%T'", reply.Message()) } - if e, g := true, reply.Allow; e != g { + if e, g := true, response.Allow; e != g { t.Errorf("reply.Allow: expected '%v', got '%v'", e, g) } url, _ = url.Parse("https://google.com") - rawReply, err = bus.Request(ctx, NewMessageFetchRequest(ctx, remoteAddr, url)) + reply, err = bus.Request(ctx, NewFetchRequestEnvelope(ctx, remoteAddr, url)) if err != nil { t.Fatalf("%+v", errors.WithStack(err)) } - reply, ok = rawReply.(*MessageFetchResponse) + response, ok = reply.Message().(*FetchResponse) if !ok { - t.Fatalf("unexpected reply type '%T'", rawReply) + t.Fatalf("unexpected reply message type '%T'", reply.Message()) } - if e, g := false, reply.Allow; e != g { + if e, g := false, response.Allow; e != g { t.Errorf("reply.Allow: expected '%v', got '%v'", e, g) } } diff --git a/pkg/module/lifecycle.go b/pkg/module/lifecycle.go index 983617f..af98219 100644 --- a/pkg/module/lifecycle.go +++ b/pkg/module/lifecycle.go @@ -5,7 +5,6 @@ import ( "forge.cadoles.com/arcad/edge/pkg/app" "github.com/dop251/goja" - "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" ) @@ -19,16 +18,28 @@ func (m *LifecycleModule) Export(export *goja.Object) { } func (m *LifecycleModule) OnInit(ctx context.Context, rt *goja.Runtime) (err error) { - _, ok := goja.AssertFunction(rt.Get("onInit")) + call, ok := goja.AssertFunction(rt.Get("onInit")) if !ok { logger.Warn(ctx, "could not find onInit() function") return nil } - if _, err := rt.RunString("setTimeout(onInit, 0)"); err != nil { - return errors.WithStack(err) - } + defer func() { + recovered := recover() + if recovered == nil { + return + } + + recoveredErr, ok := recovered.(error) + if !ok { + panic(recovered) + } + + err = recoveredErr + }() + + call(nil, rt.ToValue(ctx)) return nil } diff --git a/pkg/module/message.go b/pkg/module/message.go deleted file mode 100644 index 728d03d..0000000 --- a/pkg/module/message.go +++ /dev/null @@ -1,38 +0,0 @@ -package module - -import ( - "context" - - "forge.cadoles.com/arcad/edge/pkg/bus" -) - -const ( - MessageNamespaceClient bus.MessageNamespace = "client" - MessageNamespaceServer bus.MessageNamespace = "server" -) - -type ServerMessage struct { - Context context.Context - Data interface{} -} - -func (m *ServerMessage) MessageNamespace() bus.MessageNamespace { - return MessageNamespaceServer -} - -func NewServerMessage(ctx context.Context, data interface{}) *ServerMessage { - return &ServerMessage{ctx, data} -} - -type ClientMessage struct { - Context context.Context - Data map[string]interface{} -} - -func (m *ClientMessage) MessageNamespace() bus.MessageNamespace { - return MessageNamespaceClient -} - -func NewClientMessage(ctx context.Context, data map[string]interface{}) *ClientMessage { - return &ClientMessage{ctx, data} -} diff --git a/pkg/module/net/module.go b/pkg/module/net/module.go index d26a6b3..9e4c44e 100644 --- a/pkg/module/net/module.go +++ b/pkg/module/net/module.go @@ -5,7 +5,7 @@ import ( "forge.cadoles.com/arcad/edge/pkg/app" "forge.cadoles.com/arcad/edge/pkg/bus" - edgeHTTP "forge.cadoles.com/arcad/edge/pkg/http" + edgehttp "forge.cadoles.com/arcad/edge/pkg/http" "forge.cadoles.com/arcad/edge/pkg/module" "forge.cadoles.com/arcad/edge/pkg/module/util" "github.com/dop251/goja" @@ -38,10 +38,9 @@ func (m *Module) broadcast(call goja.FunctionCall, rt *goja.Runtime) goja.Value } data := call.Argument(0).Export() - ctx := context.Background() - msg := module.NewServerMessage(ctx, data) - if err := m.bus.Publish(ctx, msg); err != nil { + env := edgehttp.NewOutgoingMessageEnvelope("", data) + if err := m.bus.Publish(env); err != nil { panic(rt.ToValue(errors.WithStack(err))) } @@ -53,38 +52,33 @@ func (m *Module) send(call goja.FunctionCall, rt *goja.Runtime) goja.Value { panic(rt.ToValue(errors.New("invalid number of argument"))) } - var ctx context.Context - firstArg := call.Argument(0) sessionID, ok := firstArg.Export().(string) - if ok { - ctx = module.WithContext(context.Background(), map[module.ContextKey]any{ - edgeHTTP.ContextKeySessionID: sessionID, - }) - } else { - ctx = util.AssertContext(firstArg, rt) + if !ok { + ctx := util.AssertContext(firstArg, rt) + sessionID = module.ContextValue[string](ctx, edgehttp.ContextKeySessionID) } data := call.Argument(1).Export() - msg := module.NewServerMessage(ctx, data) - if err := m.bus.Publish(ctx, msg); err != nil { + env := edgehttp.NewOutgoingMessageEnvelope(sessionID, data) + if err := m.bus.Publish(env); err != nil { panic(rt.ToValue(errors.WithStack(err))) } return nil } -func (m *Module) handleClientMessages() { +func (m *Module) handleIncomingMessages() { ctx := context.Background() logger.Debug( ctx, - "subscribing to bus messages", + "subscribing to bus envelopes", ) - clientMessages, err := m.bus.Subscribe(ctx, module.MessageNamespaceClient) + envelopes, err := m.bus.Subscribe(ctx, edgehttp.AddressIncomingMessage) if err != nil { panic(errors.WithStack(err)) } @@ -92,16 +86,16 @@ func (m *Module) handleClientMessages() { defer func() { logger.Debug( ctx, - "unsubscribing from bus messages", + "unsubscribing from bus envelopes", ) - m.bus.Unsubscribe(ctx, module.MessageNamespaceClient, clientMessages) + m.bus.Unsubscribe(edgehttp.AddressIncomingMessage, envelopes) }() for { logger.Debug( ctx, - "waiting for next message", + "waiting for next envelope", ) select { case <-ctx.Done(): @@ -112,13 +106,13 @@ func (m *Module) handleClientMessages() { return - case msg := <-clientMessages: - clientMessage, ok := msg.(*module.ClientMessage) + case env := <-envelopes: + incomingMessage, ok := env.Message().(*edgehttp.IncomingMessage) if !ok { logger.Warn( ctx, "unexpected message type", - logger.F("message", msg), + logger.F("message", env.Message()), ) continue @@ -126,11 +120,11 @@ func (m *Module) handleClientMessages() { logger.Debug( ctx, - "received client message", - logger.F("message", clientMessage), + "received incoming message", + logger.F("message", incomingMessage), ) - if _, err := m.server.ExecFuncByName(clientMessage.Context, "onClientMessage", clientMessage.Context, clientMessage.Data); err != nil { + if _, err := m.server.ExecFuncByName(incomingMessage.Context, "onClientMessage", incomingMessage.Context, incomingMessage.Payload); err != nil { if errors.Is(err, app.ErrFuncDoesNotExist) { continue } @@ -152,7 +146,7 @@ func ModuleFactory(bus bus.Bus) app.ServerModuleFactory { bus: bus, } - go module.handleClientMessages() + go module.handleIncomingMessages() return module } diff --git a/pkg/module/rpc.go b/pkg/module/rpc.go deleted file mode 100644 index 5877ffb..0000000 --- a/pkg/module/rpc.go +++ /dev/null @@ -1,278 +0,0 @@ -package module - -import ( - "context" - "fmt" - "sync" - - "forge.cadoles.com/arcad/edge/pkg/app" - "forge.cadoles.com/arcad/edge/pkg/bus" - "forge.cadoles.com/arcad/edge/pkg/module/util" - "github.com/dop251/goja" - "github.com/pkg/errors" - "gitlab.com/wpetit/goweb/logger" -) - -type RPCRequest struct { - Method string - Params interface{} - ID interface{} -} - -type RPCError struct { - Code int `json:"code"` - Message string `json:"message"` - Data interface{} `json:"data"` -} - -type RPCResponse struct { - Result interface{} - Error *RPCError - ID interface{} -} - -type RPCModule struct { - server *app.Server - bus bus.Bus - callbacks sync.Map -} - -func (m *RPCModule) Name() string { - return "rpc" -} - -func (m *RPCModule) Export(export *goja.Object) { - if err := export.Set("register", m.register); err != nil { - panic(errors.Wrap(err, "could not set 'register' function")) - } - - if err := export.Set("unregister", m.unregister); err != nil { - panic(errors.Wrap(err, "could not set 'unregister' function")) - } -} - -func (m *RPCModule) OnInit(ctx context.Context, rt *goja.Runtime) error { - go m.handleMessages(ctx) - - return nil -} - -func (m *RPCModule) register(call goja.FunctionCall, rt *goja.Runtime) goja.Value { - fnName := util.AssertString(call.Argument(0), rt) - - var ( - callable goja.Callable - ok bool - ) - - if len(call.Arguments) > 1 { - callable, ok = goja.AssertFunction(call.Argument(1)) - } else { - callable, ok = goja.AssertFunction(rt.Get(fnName)) - } - - if !ok { - panic(rt.NewTypeError("method should be a valid function")) - } - - ctx := context.Background() - - logger.Debug(ctx, "registering method", logger.F("method", fnName)) - - m.callbacks.Store(fnName, callable) - - return nil -} - -func (m *RPCModule) unregister(call goja.FunctionCall, rt *goja.Runtime) goja.Value { - fnName := util.AssertString(call.Argument(0), rt) - - m.callbacks.Delete(fnName) - - return nil -} - -func (m *RPCModule) handleMessages(ctx context.Context) { - clientMessages, err := m.bus.Subscribe(ctx, MessageNamespaceClient) - if err != nil { - panic(errors.WithStack(err)) - } - - defer func() { - m.bus.Unsubscribe(ctx, MessageNamespaceClient, clientMessages) - }() - - sendRes := func(ctx context.Context, req *RPCRequest, result goja.Value) { - res := &RPCResponse{ - ID: req.ID, - Result: result.Export(), - } - - logger.Debug(ctx, "sending rpc response", logger.F("response", res)) - - if err := m.sendResponse(ctx, res); err != nil { - logger.Error( - ctx, "could not send response", - logger.CapturedE(errors.WithStack(err)), - logger.F("response", res), - logger.F("request", req), - ) - } - } - - for msg := range clientMessages { - go m.handleMessage(ctx, msg, sendRes) - } -} - -func (m *RPCModule) handleMessage(ctx context.Context, msg bus.Message, sendRes func(ctx context.Context, req *RPCRequest, result goja.Value)) { - clientMessage, ok := msg.(*ClientMessage) - if !ok { - logger.Warn(ctx, "unexpected bus message", logger.F("message", msg)) - - return - } - - ok, req := m.isRPCRequest(clientMessage) - if !ok { - return - } - - logger.Debug(ctx, "received rpc request", logger.F("request", req)) - - rawCallable, exists := m.callbacks.Load(req.Method) - if !exists { - logger.Debug(ctx, "method not found", logger.F("req", req)) - - if err := m.sendMethodNotFoundResponse(clientMessage.Context, req); err != nil { - logger.Error( - ctx, "could not send method not found response", - logger.CapturedE(errors.WithStack(err)), - logger.F("request", req), - ) - } - - return - } - - callable, ok := rawCallable.(goja.Callable) - if !ok { - logger.Debug(ctx, "invalid method", logger.F("req", req)) - - if err := m.sendMethodNotFoundResponse(clientMessage.Context, req); err != nil { - logger.Error( - ctx, "could not send method not found response", - logger.CapturedE(errors.WithStack(err)), - logger.F("request", req), - ) - } - - return - } - - result, err := m.server.Exec(clientMessage.Context, callable, clientMessage.Context, req.Params) - if err != nil { - logger.Error( - ctx, "rpc call error", - logger.CapturedE(errors.WithStack(err)), - logger.F("request", req), - ) - - if err := m.sendErrorResponse(clientMessage.Context, req, err); err != nil { - logger.Error( - ctx, "could not send error response", - logger.CapturedE(errors.WithStack(err)), - logger.F("originalError", err), - logger.F("request", req), - ) - } - - return - } - - promise, ok := app.IsPromise(result) - if ok { - go func(ctx context.Context, req *RPCRequest, promise *goja.Promise) { - result := m.server.WaitForPromise(promise) - sendRes(ctx, req, result) - }(clientMessage.Context, req, promise) - } else { - sendRes(clientMessage.Context, req, result) - } -} - -func (m *RPCModule) sendErrorResponse(ctx context.Context, req *RPCRequest, err error) error { - return m.sendResponse(ctx, &RPCResponse{ - ID: req.ID, - Result: nil, - Error: &RPCError{ - Code: -32603, - Message: err.Error(), - }, - }) -} - -func (m *RPCModule) sendMethodNotFoundResponse(ctx context.Context, req *RPCRequest) error { - return m.sendResponse(ctx, &RPCResponse{ - ID: req.ID, - Result: nil, - Error: &RPCError{ - Code: -32601, - Message: fmt.Sprintf("method not found"), - }, - }) -} - -func (m *RPCModule) sendResponse(ctx context.Context, res *RPCResponse) error { - msg := NewServerMessage(ctx, map[string]interface{}{ - "jsonrpc": "2.0", - "id": res.ID, - "error": res.Error, - "result": res.Result, - }) - - if err := m.bus.Publish(ctx, msg); err != nil { - return errors.WithStack(err) - } - - return nil -} - -func (m *RPCModule) isRPCRequest(msg *ClientMessage) (bool, *RPCRequest) { - jsonRPC, exists := msg.Data["jsonrpc"] - if !exists || jsonRPC != "2.0" { - return false, nil - } - - rawMethod, exists := msg.Data["method"] - if !exists { - return false, nil - } - - method, ok := rawMethod.(string) - if !ok { - return false, nil - } - - id := msg.Data["id"] - params := msg.Data["params"] - - return true, &RPCRequest{ - ID: id, - Method: method, - Params: params, - } -} - -func RPCModuleFactory(bus bus.Bus) app.ServerModuleFactory { - return func(server *app.Server) app.ServerModule { - mod := &RPCModule{ - server: server, - bus: bus, - } - - return mod - } -} - -var _ app.InitializableModule = &RPCModule{} diff --git a/pkg/module/rpc/envelope.go b/pkg/module/rpc/envelope.go new file mode 100644 index 0000000..2654795 --- /dev/null +++ b/pkg/module/rpc/envelope.go @@ -0,0 +1,21 @@ +package rpc + +import ( + "context" + + "forge.cadoles.com/arcad/edge/pkg/bus" +) + +const ( + Address bus.Address = "module/rpc" +) + +type Request struct { + Context context.Context + Method string + Params any +} + +func NewRequestEnvelope(ctx context.Context, method string, params any) bus.Envelope { + return bus.NewEnvelope(Address, &Request{ctx, method, params}) +} diff --git a/pkg/module/rpc/error.go b/pkg/module/rpc/error.go new file mode 100644 index 0000000..0b0b7d0 --- /dev/null +++ b/pkg/module/rpc/error.go @@ -0,0 +1,7 @@ +package rpc + +import "errors" + +var ( + ErrMethodNotFound = errors.New("method not found") +) diff --git a/pkg/module/rpc/jsonrpc.go b/pkg/module/rpc/jsonrpc.go new file mode 100644 index 0000000..08c7bc0 --- /dev/null +++ b/pkg/module/rpc/jsonrpc.go @@ -0,0 +1,19 @@ +package rpc + +import "fmt" + +type JSONRPCRequest struct { + ID any + Method string + Params any +} + +type JSONRPCError struct { + Code int `json:"code"` + Message string `json:"message"` + Data interface{} `json:"data"` +} + +func (e *JSONRPCError) Error() string { + return fmt.Sprintf("json-rpc error: %d - %s", e.Code, e.Message) +} diff --git a/pkg/module/rpc/module.go b/pkg/module/rpc/module.go new file mode 100644 index 0000000..f134677 --- /dev/null +++ b/pkg/module/rpc/module.go @@ -0,0 +1,256 @@ +package rpc + +import ( + "context" + "sync" + + "forge.cadoles.com/arcad/edge/pkg/app" + "forge.cadoles.com/arcad/edge/pkg/bus" + edgehttp "forge.cadoles.com/arcad/edge/pkg/http" + "forge.cadoles.com/arcad/edge/pkg/module" + "forge.cadoles.com/arcad/edge/pkg/module/util" + "github.com/dop251/goja" + "github.com/pkg/errors" + "gitlab.com/wpetit/goweb/logger" +) + +type Module struct { + server *app.Server + bus bus.Bus + callbacks sync.Map +} + +func (m *Module) Name() string { + return "rpc" +} + +func (m *Module) Export(export *goja.Object) { + if err := export.Set("register", m.register); err != nil { + panic(errors.Wrap(err, "could not set 'register' function")) + } + + if err := export.Set("unregister", m.unregister); err != nil { + panic(errors.Wrap(err, "could not set 'unregister' function")) + } +} + +func (m *Module) OnInit(ctx context.Context, rt *goja.Runtime) error { + requestErrs := m.bus.Reply(ctx, Address, m.handleRequest) + go func() { + for err := range requestErrs { + logger.Error(ctx, "error while replying to rpc requests", logger.CapturedE(errors.WithStack(err))) + } + }() + + httpIncomingMessages, err := m.bus.Subscribe(ctx, edgehttp.AddressIncomingMessage) + if err != nil { + return errors.WithStack(err) + } + + go m.handleIncomingHTTPMessages(ctx, httpIncomingMessages) + + return nil +} + +func (m *Module) register(call goja.FunctionCall, rt *goja.Runtime) goja.Value { + fnName := util.AssertString(call.Argument(0), rt) + + var ( + callable goja.Callable + ok bool + ) + + if len(call.Arguments) > 1 { + callable, ok = goja.AssertFunction(call.Argument(1)) + } else { + callable, ok = goja.AssertFunction(rt.Get(fnName)) + } + + if !ok { + panic(rt.NewTypeError("method should be a valid function")) + } + + ctx := context.Background() + + logger.Debug(ctx, "registering method", logger.F("method", fnName)) + + m.callbacks.Store(fnName, callable) + + return nil +} + +func (m *Module) unregister(call goja.FunctionCall, rt *goja.Runtime) goja.Value { + fnName := util.AssertString(call.Argument(0), rt) + + m.callbacks.Delete(fnName) + + return nil +} + +func (m *Module) handleRequest(env bus.Envelope) (any, error) { + request, ok := env.Message().(*Request) + if !ok { + logger.Warn(context.Background(), "unexpected bus message", logger.F("message", env.Message())) + + return nil, errors.WithStack(bus.ErrUnexpectedMessage) + } + + ctx := logger.With(request.Context, logger.F("request", request)) + + logger.Debug(ctx, "received rpc request") + + rawCallable, exists := m.callbacks.Load(request.Method) + if !exists { + logger.Debug(ctx, "method not found") + + return nil, errors.WithStack(ErrMethodNotFound) + } + + callable, ok := rawCallable.(goja.Callable) + if !ok { + logger.Debug(ctx, "invalid method") + + return nil, errors.WithStack(ErrMethodNotFound) + } + + result, err := m.server.Exec(ctx, callable, request.Context, request.Params) + if err != nil { + logger.Error( + ctx, "rpc call error", + logger.CapturedE(errors.WithStack(err)), + ) + + return nil, errors.WithStack(err) + } + + return result, nil +} + +func (m *Module) handleIncomingHTTPMessages(ctx context.Context, incoming <-chan bus.Envelope) { + defer func() { + m.bus.Unsubscribe(edgehttp.AddressIncomingMessage, incoming) + }() + + for env := range incoming { + msg, ok := env.Message().(*edgehttp.IncomingMessage) + if !ok { + logger.Error(ctx, "unexpected incoming http message type", logger.F("message", env.Message())) + continue + } + + jsonReq, ok := m.isRPCRequest(msg.Payload) + if !ok { + continue + } + + requestCtx := logger.With(msg.Context, logger.F("rpcRequestMethod", jsonReq.Method), logger.F("rpcRequestID", jsonReq.ID)) + + request := NewRequestEnvelope(msg.Context, jsonReq.Method, jsonReq.Params) + sessionID := module.ContextValue[string](msg.Context, edgehttp.ContextKeySessionID) + + reply, err := m.bus.Request(requestCtx, request) + if err != nil { + err = errors.WithStack(err) + + logger.Error( + ctx, "could not execute rpc request", + logger.CapturedE(err), + ) + + if errors.Is(err, ErrMethodNotFound) { + if err := m.sendMethodNotFoundResponse(sessionID, jsonReq.ID); err != nil { + logger.Error( + ctx, "could not send json rpc error response", + logger.CapturedE(errors.WithStack(err)), + ) + } + + continue + } + + if err := m.sendErrorResponse(sessionID, jsonReq.ID, err); err != nil { + logger.Error( + ctx, "could not send json rpc error response", + logger.CapturedE(errors.WithStack(err)), + ) + } + + continue + } + + if err := m.sendResponse(sessionID, jsonReq.ID, reply.Message(), nil); err != nil { + logger.Error( + ctx, "could not send json rpc result response", + logger.CapturedE(err), + ) + } + } +} + +func (m *Module) sendErrorResponse(sessionID string, requestID any, err error) error { + return m.sendResponse(sessionID, requestID, nil, &JSONRPCError{ + Code: -32603, + Message: err.Error(), + }) +} + +func (m *Module) sendMethodNotFoundResponse(sessionID string, requestID any) error { + return m.sendResponse(sessionID, requestID, nil, &JSONRPCError{ + Code: -32601, + Message: "method not found", + }) +} + +func (m *Module) sendResponse(sessionID string, requestID any, result any, err error) error { + env := edgehttp.NewOutgoingMessageEnvelope(sessionID, map[string]interface{}{ + "jsonrpc": "2.0", + "id": requestID, + "error": err, + "result": result, + }) + + if err := m.bus.Publish(env); err != nil { + return errors.WithStack(err) + } + + return nil +} + +func (m *Module) isRPCRequest(payload map[string]any) (*JSONRPCRequest, bool) { + jsonRPC, exists := payload["jsonrpc"] + if !exists || jsonRPC != "2.0" { + return nil, false + } + + rawMethod, exists := payload["method"] + if !exists { + return nil, false + } + + method, ok := rawMethod.(string) + if !ok { + return nil, false + } + + id := payload["id"] + params := payload["params"] + + return &JSONRPCRequest{ + ID: id, + Method: method, + Params: params, + }, true +} + +func ModuleFactory(bus bus.Bus) app.ServerModuleFactory { + return func(server *app.Server) app.ServerModule { + mod := &Module{ + server: server, + bus: bus, + } + + return mod + } +} + +var _ app.InitializableModule = &Module{} diff --git a/pkg/module/rpc/module_test.go b/pkg/module/rpc/module_test.go new file mode 100644 index 0000000..32af3f6 --- /dev/null +++ b/pkg/module/rpc/module_test.go @@ -0,0 +1,109 @@ +package rpc + +import ( + "context" + "os" + "sync" + "testing" + + "forge.cadoles.com/arcad/edge/pkg/app" + "forge.cadoles.com/arcad/edge/pkg/bus" + "forge.cadoles.com/arcad/edge/pkg/bus/memory" + "forge.cadoles.com/arcad/edge/pkg/module" + "github.com/pkg/errors" + "gitlab.com/wpetit/goweb/logger" +) + +func TestServerExecDeadlock(t *testing.T) { + if testing.Verbose() { + logger.SetLevel(logger.LevelDebug) + } + + b := memory.NewBus(memory.WithBufferSize(1)) + + server := app.NewServer( + module.ConsoleModuleFactory(), + ModuleFactory(b), + module.LifecycleModuleFactory(), + ) + + data, err := os.ReadFile("testdata/deadlock.js") + if err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + ctx := context.Background() + + t.Log("starting server") + + if err := server.Start(ctx, "deadlock.js", string(data)); err != nil { + t.Fatalf("%+v", errors.WithStack(err)) + } + + defer server.Stop() + + t.Log("server started") + + count := 100 + delay := 100 + + var wg sync.WaitGroup + + wg.Add(count) + + for i := 0; i < count; i++ { + go func(i int) { + defer wg.Done() + + t.Logf("calling %d", i) + + isCanceled := i%2 == 0 + + var ctx context.Context + if isCanceled { + canceledCtx, cancel := context.WithCancel(context.Background()) + cancel() + ctx = canceledCtx + } else { + ctx = context.Background() + } + + env := NewRequestEnvelope(ctx, "doSomethingLong", map[string]any{ + "i": i, + "delay": delay, + }) + + t.Logf("publishing envelope #%d", i) + + reply, err := b.Request(ctx, env) + if err != nil { + if errors.Is(err, context.Canceled) && isCanceled { + return + } + + if errors.Is(err, bus.ErrNoResponse) && isCanceled { + return + } + + t.Errorf("%+v", errors.WithStack(err)) + + return + } + + result, ok := reply.Message().(int64) + if !ok { + t.Errorf("response.Result: expected type '%T', got '%T'", int64(0), reply.Message()) + + return + } + + if e, g := i, int(result); e != g { + t.Errorf("response.Result: expected '%v', got '%v'", e, g) + + return + } + }(i) + } + + wg.Wait() +} diff --git a/pkg/module/rpc/testdata/deadlock.js b/pkg/module/rpc/testdata/deadlock.js new file mode 100644 index 0000000..8a667a6 --- /dev/null +++ b/pkg/module/rpc/testdata/deadlock.js @@ -0,0 +1,14 @@ +function onInit() { + rpc.register("doSomethingLong", doSomethingLong) +} + +function doSomethingLong(ctx, params) { + var start = Date.now() + + while (true) { + var now = Date.now() + if (now - start >= params.delay) break + } + + return params.i; +} diff --git a/pkg/module/share/module_test.go b/pkg/module/share/module_test.go index 587d03e..5ce367d 100644 --- a/pkg/module/share/module_test.go +++ b/pkg/module/share/module_test.go @@ -33,18 +33,14 @@ func TestModule(t *testing.T) { t.Fatalf("%+v", errors.WithStack(err)) } - if err := server.Load("testdata/share.js", string(data)); err != nil { + ctx := context.Background() + if err := server.Start(ctx, "testdata/share.js", string(data)); err != nil { t.Fatalf("%+v", errors.WithStack(err)) } - ctx := context.Background() - if err := server.Start(ctx); err != nil { - t.Fatalf("%+v", errors.WithStack(err)) - } + defer server.Stop() if _, err := server.ExecFuncByName(context.Background(), "testModule"); err != nil { t.Fatalf("%+v", errors.WithStack(err)) } - - server.Stop() } diff --git a/pkg/module/store/module_test.go b/pkg/module/store/module_test.go index 080c9f2..1b70610 100644 --- a/pkg/module/store/module_test.go +++ b/pkg/module/store/module_test.go @@ -27,18 +27,14 @@ func TestStoreModule(t *testing.T) { t.Fatalf("%+v", errors.WithStack(err)) } - if err := server.Load("testdata/store.js", string(data)); err != nil { + ctx := context.Background() + if err := server.Start(ctx, "testdata/store.js", string(data)); err != nil { t.Fatalf("%+v", errors.WithStack(err)) } - ctx := context.Background() - if err := server.Start(ctx); err != nil { - t.Fatalf("%+v", errors.WithStack(err)) - } + defer server.Stop() if _, err := server.ExecFuncByName(context.Background(), "testStore"); err != nil { t.Fatalf("%+v", errors.WithStack(err)) } - - server.Stop() } diff --git a/pkg/storage/driver/rpc/client/client_pool.go b/pkg/storage/driver/rpc/client/client_pool.go index 7646d54..6126911 100644 --- a/pkg/storage/driver/rpc/client/client_pool.go +++ b/pkg/storage/driver/rpc/client/client_pool.go @@ -5,6 +5,7 @@ import ( "net/url" "strconv" "sync" + "time" "github.com/jackc/puddle/v2" "github.com/keegancsmith/rpc" @@ -74,21 +75,59 @@ func WithPooledClient(serverURL *url.URL) WithClientFunc { return errors.WithStack(err) } - clientResource, err := pool.Acquire(ctx) - if err != nil { - return errors.WithStack(err) - } + attempts := 0 + max := 5 - if err := fn(ctx, clientResource.Value()); err != nil { - if errors.Is(err, rpc.ErrShutdown) { - clientResource.Destroy() + for { + if attempts >= max { + logger.Debug(ctx, "rpc client call retrying failed", logger.F("attempts", attempts)) + + return errors.Wrapf(err, "rpc client call failed after %d attempts", max) } - return errors.WithStack(err) + clientResource, err := pool.Acquire(ctx) + if err != nil { + return errors.WithStack(err) + } + + client := clientResource.Value() + + if err := fn(ctx, client); err != nil { + if errors.Is(err, rpc.ErrShutdown) { + clientResource.Destroy() + + wait := time.Duration(8<<(attempts+1)) * time.Millisecond + + logger.Warn( + ctx, "rpc client connection is shutdown, retrying", + logger.F("attempts", attempts), + logger.F("max", max), + logger.F("delay", wait), + ) + + timer := time.NewTimer(wait) + select { + case <-timer.C: + attempts++ + continue + + case <-ctx.Done(): + if err := ctx.Err(); err != nil { + return errors.WithStack(err) + } + + return nil + } + } + + clientResource.Release() + + return errors.WithStack(err) + } + + clientResource.Release() + + return nil } - - clientResource.Release() - - return nil } } diff --git a/pkg/storage/driver/rpc/server/blob/new_blob_reader.go b/pkg/storage/driver/rpc/server/blob/new_blob_reader.go index b1d1286..006bd43 100644 --- a/pkg/storage/driver/rpc/server/blob/new_blob_reader.go +++ b/pkg/storage/driver/rpc/server/blob/new_blob_reader.go @@ -45,12 +45,12 @@ func (s *Service) NewBlobReader(ctx context.Context, args *NewBlobReaderArgs, re func (s *Service) getOpenedReader(id ReaderID) (io.ReadSeekCloser, error) { raw, exists := s.readers.Load(id) if !exists { - return nil, errors.Errorf("could not find writer '%s'", id) + return nil, errors.Errorf("could not find reader '%s'", id) } reader, ok := raw.(io.ReadSeekCloser) if !ok { - return nil, errors.Errorf("unexpected type '%T' for writer", raw) + return nil, errors.Errorf("unexpected type '%T' for reader", raw) } return reader, nil