feat: rewrite bus to prevent deadlocks
All checks were successful
arcad/edge/pipeline/head This commit looks good
arcad/edge/pipeline/pr-master This commit looks good

This commit is contained in:
2023-11-28 16:35:49 +01:00
parent f4a7366aad
commit ad49c1718c
50 changed files with 1621 additions and 1336 deletions

View File

@ -0,0 +1,38 @@
package fetch
import (
"context"
"net/url"
"forge.cadoles.com/arcad/edge/pkg/bus"
)
const (
AddressFetchRequest bus.Address = "module/fetch/request"
AddressFetchResponse bus.Address = "module/fetch/response"
)
type FetchRequest struct {
Context context.Context
RequestID string
URL *url.URL
RemoteAddr string
}
func NewFetchRequestEnvelope(ctx context.Context, remoteAddr string, url *url.URL) bus.Envelope {
return bus.NewEnvelope(AddressFetchRequest, &FetchRequest{
Context: ctx,
URL: url,
RemoteAddr: remoteAddr,
})
}
type FetchResponse struct {
Allow bool
}
func NewFetchResponseEnvelope(allow bool) bus.Envelope {
return bus.NewEnvelope(AddressFetchResponse, &FetchResponse{
Allow: allow,
})
}

View File

@ -1,49 +0,0 @@
package fetch
import (
"context"
"net/url"
"forge.cadoles.com/arcad/edge/pkg/bus"
"github.com/oklog/ulid/v2"
)
const (
MessageNamespaceFetchRequest bus.MessageNamespace = "fetchRequest"
MessageNamespaceFetchResponse bus.MessageNamespace = "fetchResponse"
)
type MessageFetchRequest struct {
Context context.Context
RequestID string
URL *url.URL
RemoteAddr string
}
func (m *MessageFetchRequest) MessageNamespace() bus.MessageNamespace {
return MessageNamespaceFetchRequest
}
func NewMessageFetchRequest(ctx context.Context, remoteAddr string, url *url.URL) *MessageFetchRequest {
return &MessageFetchRequest{
Context: ctx,
RequestID: ulid.Make().String(),
RemoteAddr: remoteAddr,
URL: url,
}
}
type MessageFetchResponse struct {
RequestID string
Allow bool
}
func (m *MessageFetchResponse) MessageNamespace() bus.MessageNamespace {
return MessageNamespaceFetchResponse
}
func NewMessageFetchResponse(requestID string) *MessageFetchResponse {
return &MessageFetchResponse{
RequestID: requestID,
}
}

115
pkg/module/fetch/http.go Normal file
View File

@ -0,0 +1,115 @@
package fetch
import (
"io"
"net/http"
"net/url"
edgehttp "forge.cadoles.com/arcad/edge/pkg/http"
"github.com/go-chi/chi/v5"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
)
func Mount() func(r chi.Router) {
return func(r chi.Router) {
r.Get("/api/v1/fetch", handleAppFetch)
}
}
func handleAppFetch(w http.ResponseWriter, r *http.Request) {
ctx := r.Context()
rawURL := r.URL.Query().Get("url")
url, err := url.Parse(rawURL)
if err != nil {
edgehttp.JSONError(w, http.StatusBadRequest, edgehttp.ErrCodeBadRequest)
return
}
requestMsg := NewFetchRequestEnvelope(ctx, r.RemoteAddr, url)
bus := edgehttp.ContextBus(ctx)
reply, err := bus.Request(ctx, requestMsg)
if err != nil {
logger.Error(ctx, "could not retrieve fetch request reply", logger.CapturedE(errors.WithStack(err)))
edgehttp.JSONError(w, http.StatusInternalServerError, edgehttp.ErrCodeInternalError)
return
}
logger.Debug(ctx, "fetch reply", logger.F("reply", reply))
responseMsg, ok := reply.Message().(*FetchResponse)
if !ok {
logger.Error(
ctx, "unexpected fetch response message",
logger.F("message", reply),
)
edgehttp.JSONError(w, http.StatusInternalServerError, edgehttp.ErrCodeInternalError)
return
}
if !responseMsg.Allow {
edgehttp.JSONError(w, http.StatusForbidden, edgehttp.ErrCodeForbidden)
return
}
proxyReq, err := http.NewRequest(http.MethodGet, url.String(), nil)
if err != nil {
logger.Error(
ctx, "could not create proxy request",
logger.CapturedE(errors.WithStack(err)),
)
edgehttp.JSONError(w, http.StatusInternalServerError, edgehttp.ErrCodeInternalError)
return
}
for header, values := range r.Header {
for _, value := range values {
proxyReq.Header.Add(header, value)
}
}
proxyReq.Header.Add("X-Forwarded-From", r.RemoteAddr)
httpClient := edgehttp.ContextHTTPClient(ctx)
res, err := httpClient.Do(proxyReq)
if err != nil {
logger.Error(
ctx, "could not execute proxy request",
logger.CapturedE(errors.WithStack(err)),
)
edgehttp.JSONError(w, http.StatusInternalServerError, edgehttp.ErrCodeInternalError)
return
}
defer func() {
if err := res.Body.Close(); err != nil {
logger.Error(
ctx, "could not close response body",
logger.CapturedE(errors.WithStack(err)),
)
}
}()
for header, values := range res.Header {
for _, value := range values {
w.Header().Add(header, value)
}
}
w.WriteHeader(res.StatusCode)
if _, err := io.Copy(w, res.Body); err != nil {
panic(errors.WithStack(err))
}
}

View File

@ -40,10 +40,10 @@ func (m *Module) get(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
func (m *Module) handleMessages() {
ctx := context.Background()
err := m.bus.Reply(ctx, MessageNamespaceFetchRequest, func(msg bus.Message) (bus.Message, error) {
fetchRequest, ok := msg.(*MessageFetchRequest)
fetchErrs := m.bus.Reply(ctx, AddressFetchRequest, func(env bus.Envelope) (any, error) {
fetchRequest, ok := env.Message().(*FetchRequest)
if !ok {
return nil, errors.Wrapf(bus.ErrUnexpectedMessage, "expected message fetch request, got '%T'", msg)
return nil, errors.Wrapf(bus.ErrUnexpectedMessage, "expected fetch request, got '%T'", env.Message())
}
res, err := m.handleFetchRequest(fetchRequest)
@ -57,13 +57,14 @@ func (m *Module) handleMessages() {
return res, nil
})
if err != nil {
panic(errors.WithStack(err))
for err := range fetchErrs {
logger.Fatal(ctx, "error while replying to fetch requests", logger.CapturedE(errors.WithStack(err)))
}
}
func (m *Module) handleFetchRequest(req *MessageFetchRequest) (*MessageFetchResponse, error) {
res := NewMessageFetchResponse(req.RequestID)
func (m *Module) handleFetchRequest(req *FetchRequest) (*FetchResponse, error) {
res := &FetchResponse{}
ctx := logger.With(
req.Context,
@ -83,11 +84,11 @@ func (m *Module) handleFetchRequest(req *MessageFetchRequest) (*MessageFetchResp
return nil, errors.WithStack(err)
}
result, ok := rawResult.Export().(map[string]interface{})
result, ok := rawResult.(map[string]interface{})
if !ok {
return nil, errors.Errorf(
"unexpected onClientFetch result: expected 'map[string]interface{}', got '%T'",
rawResult.Export(),
rawResult,
)
}

View File

@ -2,8 +2,8 @@ package fetch
import (
"context"
"io/ioutil"
"net/url"
"os"
"testing"
"time"
@ -18,7 +18,9 @@ import (
func TestFetchModule(t *testing.T) {
t.Parallel()
logger.SetLevel(slog.LevelDebug)
if testing.Verbose() {
logger.SetLevel(slog.LevelDebug)
}
bus := memory.NewBus()
@ -28,22 +30,20 @@ func TestFetchModule(t *testing.T) {
ModuleFactory(bus),
)
data, err := ioutil.ReadFile("testdata/fetch.js")
path := "testdata/fetch.js"
data, err := os.ReadFile(path)
if err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
if err := server.Load("testdata/fetch.js", string(data)); err != nil {
ctx := context.Background()
if err := server.Start(ctx, path, string(data)); err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
defer server.Stop()
ctx := context.Background()
if err := server.Start(ctx); err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
// Wait for module to startup
time.Sleep(1 * time.Second)
@ -53,33 +53,33 @@ func TestFetchModule(t *testing.T) {
remoteAddr := "127.0.0.1"
url, _ := url.Parse("http://example.com")
rawReply, err := bus.Request(ctx, NewMessageFetchRequest(ctx, remoteAddr, url))
reply, err := bus.Request(ctx, NewFetchRequestEnvelope(ctx, remoteAddr, url))
if err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
reply, ok := rawReply.(*MessageFetchResponse)
response, ok := reply.Message().(*FetchResponse)
if !ok {
t.Fatalf("unexpected reply type '%T'", rawReply)
t.Fatalf("unexpected reply message type '%T'", reply.Message())
}
if e, g := true, reply.Allow; e != g {
if e, g := true, response.Allow; e != g {
t.Errorf("reply.Allow: expected '%v', got '%v'", e, g)
}
url, _ = url.Parse("https://google.com")
rawReply, err = bus.Request(ctx, NewMessageFetchRequest(ctx, remoteAddr, url))
reply, err = bus.Request(ctx, NewFetchRequestEnvelope(ctx, remoteAddr, url))
if err != nil {
t.Fatalf("%+v", errors.WithStack(err))
}
reply, ok = rawReply.(*MessageFetchResponse)
response, ok = reply.Message().(*FetchResponse)
if !ok {
t.Fatalf("unexpected reply type '%T'", rawReply)
t.Fatalf("unexpected reply message type '%T'", reply.Message())
}
if e, g := false, reply.Allow; e != g {
if e, g := false, response.Allow; e != g {
t.Errorf("reply.Allow: expected '%v', got '%v'", e, g)
}
}