From 8789b85d921150a0c69d406158fa651bae2ca442 Mon Sep 17 00:00:00 2001 From: William Petit Date: Wed, 1 Mar 2023 13:04:40 +0100 Subject: [PATCH] feat(server): handle panic in runtime ref #2 --- pkg/app/server.go | 36 +++++++++++++++++++++++++++------ pkg/module/auth/module_test.go | 4 ++-- pkg/module/blob.go | 4 ++-- pkg/module/cast/module_test.go | 3 ++- pkg/module/lifecycle.go | 2 +- pkg/module/net/module.go | 2 +- pkg/module/rpc.go | 2 +- pkg/module/store/module_test.go | 3 ++- 8 files changed, 41 insertions(+), 15 deletions(-) diff --git a/pkg/app/server.go b/pkg/app/server.go index 78708ee..1c0f661 100644 --- a/pkg/app/server.go +++ b/pkg/app/server.go @@ -1,15 +1,20 @@ package app import ( + "context" "math/rand" "sync" "github.com/dop251/goja" "github.com/dop251/goja_nodejs/eventloop" "github.com/pkg/errors" + "gitlab.com/wpetit/goweb/logger" ) -var ErrFuncDoesNotExist = errors.New("function does not exist") +var ( + ErrFuncDoesNotExist = errors.New("function does not exist") + ErUnknownError = errors.New("unknown error") +) type Server struct { runtime *goja.Runtime @@ -26,16 +31,18 @@ func (s *Server) Load(name string, src string) error { return nil } -func (s *Server) ExecFuncByName(funcName string, args ...interface{}) (goja.Value, error) { +func (s *Server) ExecFuncByName(ctx context.Context, funcName string, args ...interface{}) (goja.Value, error) { + ctx = logger.With(ctx, logger.F("function", funcName), logger.F("args", args)) + callable, ok := goja.AssertFunction(s.runtime.Get(funcName)) if !ok { return nil, errors.WithStack(ErrFuncDoesNotExist) } - return s.Exec(callable, args...) + return s.Exec(ctx, callable, args...) } -func (s *Server) Exec(callable goja.Callable, args ...interface{}) (goja.Value, error) { +func (s *Server) Exec(ctx context.Context, callable goja.Callable, args ...interface{}) (goja.Value, error) { var ( wg sync.WaitGroup value goja.Value @@ -45,6 +52,25 @@ func (s *Server) Exec(callable goja.Callable, args ...interface{}) (goja.Value, wg.Add(1) s.loop.RunOnLoop(func(vm *goja.Runtime) { + logger.Debug(ctx, "executing callable") + + defer wg.Done() + + defer func() { + if recovered := recover(); recovered != nil { + revoveredErr, ok := recovered.(error) + if ok { + logger.Error(ctx, "recovered runtime error", logger.E(errors.WithStack(revoveredErr))) + + err = errors.WithStack(ErUnknownError) + + return + } + + panic(recovered) + } + }() + jsArgs := make([]goja.Value, 0, len(args)) for _, a := range args { jsArgs = append(jsArgs, vm.ToValue(a)) @@ -54,8 +80,6 @@ func (s *Server) Exec(callable goja.Callable, args ...interface{}) (goja.Value, if err != nil { err = errors.WithStack(err) } - - wg.Done() }) wg.Wait() diff --git a/pkg/module/auth/module_test.go b/pkg/module/auth/module_test.go index 0a7d908..cf085e4 100644 --- a/pkg/module/auth/module_test.go +++ b/pkg/module/auth/module_test.go @@ -65,7 +65,7 @@ func TestAuthModule(t *testing.T) { ctx := context.WithValue(context.Background(), edgeHTTP.ContextKeyOriginRequest, req) - if _, err := server.ExecFuncByName("testAuth", ctx); err != nil { + if _, err := server.ExecFuncByName(ctx, "testAuth", ctx); err != nil { t.Fatalf("%+v", errors.WithStack(err)) } } @@ -104,7 +104,7 @@ func TestAuthAnonymousModule(t *testing.T) { ctx := context.WithValue(context.Background(), edgeHTTP.ContextKeyOriginRequest, req) - if _, err := server.ExecFuncByName("testAuth", ctx); err != nil { + if _, err := server.ExecFuncByName(ctx, "testAuth", ctx); err != nil { t.Fatalf("%+v", errors.WithStack(err)) } } diff --git a/pkg/module/blob.go b/pkg/module/blob.go index b52cc90..b6264f6 100644 --- a/pkg/module/blob.go +++ b/pkg/module/blob.go @@ -88,7 +88,7 @@ func (m *BlobModule) handleUploadRequest(req *MessageUploadRequest) (*MessageUpl "contentType": req.FileHeader.Header.Get("Content-Type"), } - rawResult, err := m.server.ExecFuncByName("onBlobUpload", ctx, blobID, blobInfo, req.Metadata) + rawResult, err := m.server.ExecFuncByName(ctx, "onBlobUpload", ctx, blobID, blobInfo, req.Metadata) if err != nil { if errors.Is(err, app.ErrFuncDoesNotExist) { res.Allow = false @@ -193,7 +193,7 @@ func (m *BlobModule) saveBlob(ctx context.Context, bucketName string, blobID sto func (m *BlobModule) handleDownloadRequest(req *MessageDownloadRequest) (*MessageDownloadResponse, error) { res := NewMessageDownloadResponse(req.RequestID) - rawResult, err := m.server.ExecFuncByName("onBlobDownload", req.Context, req.Bucket, req.BlobID) + rawResult, err := m.server.ExecFuncByName(req.Context, "onBlobDownload", req.Context, req.Bucket, req.BlobID) if err != nil { if errors.Is(err, app.ErrFuncDoesNotExist) { res.Allow = false diff --git a/pkg/module/cast/module_test.go b/pkg/module/cast/module_test.go index 5e6f41a..01c0a1a 100644 --- a/pkg/module/cast/module_test.go +++ b/pkg/module/cast/module_test.go @@ -1,6 +1,7 @@ package cast import ( + "context" "io/ioutil" "os" "testing" @@ -79,7 +80,7 @@ func TestCastModuleRefreshDevices(t *testing.T) { defer server.Stop() - result, err := server.ExecFuncByName("refreshDevices") + result, err := server.ExecFuncByName(context.Background(), "refreshDevices") if err != nil { t.Error(errors.WithStack(err)) } diff --git a/pkg/module/lifecycle.go b/pkg/module/lifecycle.go index febd02b..2a23e4a 100644 --- a/pkg/module/lifecycle.go +++ b/pkg/module/lifecycle.go @@ -21,7 +21,7 @@ func (m *LifecycleModule) Export(export *goja.Object) { } func (m *LifecycleModule) OnInit() error { - if _, err := m.server.ExecFuncByName("onInit"); err != nil { + if _, err := m.server.ExecFuncByName(context.Background(), "onInit"); err != nil { if errors.Is(err, app.ErrFuncDoesNotExist) { logger.Warn(context.Background(), "could not find onInit() function", logger.E(errors.WithStack(err))) diff --git a/pkg/module/net/module.go b/pkg/module/net/module.go index d1f8a03..29eeb80 100644 --- a/pkg/module/net/module.go +++ b/pkg/module/net/module.go @@ -129,7 +129,7 @@ func (m *Module) handleClientMessages() { logger.F("message", clientMessage), ) - if _, err := m.server.ExecFuncByName("onClientMessage", clientMessage.Context, clientMessage.Data); err != nil { + if _, err := m.server.ExecFuncByName(clientMessage.Context, "onClientMessage", clientMessage.Context, clientMessage.Data); err != nil { if errors.Is(err, app.ErrFuncDoesNotExist) { continue } diff --git a/pkg/module/rpc.go b/pkg/module/rpc.go index dd3108e..732d9af 100644 --- a/pkg/module/rpc.go +++ b/pkg/module/rpc.go @@ -161,7 +161,7 @@ func (m *RPCModule) handleMessages() { continue } - result, err := m.server.Exec(callable, clientMessage.Context, req.Params) + result, err := m.server.Exec(clientMessage.Context, callable, clientMessage.Context, req.Params) if err != nil { logger.Error( ctx, "rpc call error", diff --git a/pkg/module/store/module_test.go b/pkg/module/store/module_test.go index 0a1f204..ebc4801 100644 --- a/pkg/module/store/module_test.go +++ b/pkg/module/store/module_test.go @@ -1,6 +1,7 @@ package store import ( + "context" "io/ioutil" "testing" @@ -34,7 +35,7 @@ func TestStoreModule(t *testing.T) { t.Fatalf("%+v", errors.WithStack(err)) } - if _, err := server.ExecFuncByName("testStore"); err != nil { + if _, err := server.ExecFuncByName(context.Background(), "testStore"); err != nil { t.Fatalf("%+v", errors.WithStack(err)) }