feat: rewrite bus to prevent deadlocks
This commit is contained in:
38
pkg/module/fetch/envelope.go
Normal file
38
pkg/module/fetch/envelope.go
Normal file
@ -0,0 +1,38 @@
|
||||
package fetch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
|
||||
"forge.cadoles.com/arcad/edge/pkg/bus"
|
||||
)
|
||||
|
||||
const (
|
||||
AddressFetchRequest bus.Address = "module/fetch/request"
|
||||
AddressFetchResponse bus.Address = "module/fetch/response"
|
||||
)
|
||||
|
||||
type FetchRequest struct {
|
||||
Context context.Context
|
||||
RequestID string
|
||||
URL *url.URL
|
||||
RemoteAddr string
|
||||
}
|
||||
|
||||
func NewFetchRequestEnvelope(ctx context.Context, remoteAddr string, url *url.URL) bus.Envelope {
|
||||
return bus.NewEnvelope(AddressFetchRequest, &FetchRequest{
|
||||
Context: ctx,
|
||||
URL: url,
|
||||
RemoteAddr: remoteAddr,
|
||||
})
|
||||
}
|
||||
|
||||
type FetchResponse struct {
|
||||
Allow bool
|
||||
}
|
||||
|
||||
func NewFetchResponseEnvelope(allow bool) bus.Envelope {
|
||||
return bus.NewEnvelope(AddressFetchResponse, &FetchResponse{
|
||||
Allow: allow,
|
||||
})
|
||||
}
|
@ -1,49 +0,0 @@
|
||||
package fetch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/url"
|
||||
|
||||
"forge.cadoles.com/arcad/edge/pkg/bus"
|
||||
"github.com/oklog/ulid/v2"
|
||||
)
|
||||
|
||||
const (
|
||||
MessageNamespaceFetchRequest bus.MessageNamespace = "fetchRequest"
|
||||
MessageNamespaceFetchResponse bus.MessageNamespace = "fetchResponse"
|
||||
)
|
||||
|
||||
type MessageFetchRequest struct {
|
||||
Context context.Context
|
||||
RequestID string
|
||||
URL *url.URL
|
||||
RemoteAddr string
|
||||
}
|
||||
|
||||
func (m *MessageFetchRequest) MessageNamespace() bus.MessageNamespace {
|
||||
return MessageNamespaceFetchRequest
|
||||
}
|
||||
|
||||
func NewMessageFetchRequest(ctx context.Context, remoteAddr string, url *url.URL) *MessageFetchRequest {
|
||||
return &MessageFetchRequest{
|
||||
Context: ctx,
|
||||
RequestID: ulid.Make().String(),
|
||||
RemoteAddr: remoteAddr,
|
||||
URL: url,
|
||||
}
|
||||
}
|
||||
|
||||
type MessageFetchResponse struct {
|
||||
RequestID string
|
||||
Allow bool
|
||||
}
|
||||
|
||||
func (m *MessageFetchResponse) MessageNamespace() bus.MessageNamespace {
|
||||
return MessageNamespaceFetchResponse
|
||||
}
|
||||
|
||||
func NewMessageFetchResponse(requestID string) *MessageFetchResponse {
|
||||
return &MessageFetchResponse{
|
||||
RequestID: requestID,
|
||||
}
|
||||
}
|
115
pkg/module/fetch/http.go
Normal file
115
pkg/module/fetch/http.go
Normal file
@ -0,0 +1,115 @@
|
||||
package fetch
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
|
||||
edgehttp "forge.cadoles.com/arcad/edge/pkg/http"
|
||||
"github.com/go-chi/chi/v5"
|
||||
"github.com/pkg/errors"
|
||||
"gitlab.com/wpetit/goweb/logger"
|
||||
)
|
||||
|
||||
func Mount() func(r chi.Router) {
|
||||
return func(r chi.Router) {
|
||||
r.Get("/api/v1/fetch", handleAppFetch)
|
||||
}
|
||||
}
|
||||
|
||||
func handleAppFetch(w http.ResponseWriter, r *http.Request) {
|
||||
ctx := r.Context()
|
||||
|
||||
rawURL := r.URL.Query().Get("url")
|
||||
|
||||
url, err := url.Parse(rawURL)
|
||||
if err != nil {
|
||||
edgehttp.JSONError(w, http.StatusBadRequest, edgehttp.ErrCodeBadRequest)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
requestMsg := NewFetchRequestEnvelope(ctx, r.RemoteAddr, url)
|
||||
|
||||
bus := edgehttp.ContextBus(ctx)
|
||||
|
||||
reply, err := bus.Request(ctx, requestMsg)
|
||||
if err != nil {
|
||||
logger.Error(ctx, "could not retrieve fetch request reply", logger.CapturedE(errors.WithStack(err)))
|
||||
edgehttp.JSONError(w, http.StatusInternalServerError, edgehttp.ErrCodeInternalError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "fetch reply", logger.F("reply", reply))
|
||||
|
||||
responseMsg, ok := reply.Message().(*FetchResponse)
|
||||
if !ok {
|
||||
logger.Error(
|
||||
ctx, "unexpected fetch response message",
|
||||
logger.F("message", reply),
|
||||
)
|
||||
edgehttp.JSONError(w, http.StatusInternalServerError, edgehttp.ErrCodeInternalError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if !responseMsg.Allow {
|
||||
edgehttp.JSONError(w, http.StatusForbidden, edgehttp.ErrCodeForbidden)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
proxyReq, err := http.NewRequest(http.MethodGet, url.String(), nil)
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
ctx, "could not create proxy request",
|
||||
logger.CapturedE(errors.WithStack(err)),
|
||||
)
|
||||
edgehttp.JSONError(w, http.StatusInternalServerError, edgehttp.ErrCodeInternalError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
for header, values := range r.Header {
|
||||
for _, value := range values {
|
||||
proxyReq.Header.Add(header, value)
|
||||
}
|
||||
}
|
||||
|
||||
proxyReq.Header.Add("X-Forwarded-From", r.RemoteAddr)
|
||||
|
||||
httpClient := edgehttp.ContextHTTPClient(ctx)
|
||||
|
||||
res, err := httpClient.Do(proxyReq)
|
||||
if err != nil {
|
||||
logger.Error(
|
||||
ctx, "could not execute proxy request",
|
||||
logger.CapturedE(errors.WithStack(err)),
|
||||
)
|
||||
edgehttp.JSONError(w, http.StatusInternalServerError, edgehttp.ErrCodeInternalError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := res.Body.Close(); err != nil {
|
||||
logger.Error(
|
||||
ctx, "could not close response body",
|
||||
logger.CapturedE(errors.WithStack(err)),
|
||||
)
|
||||
}
|
||||
}()
|
||||
|
||||
for header, values := range res.Header {
|
||||
for _, value := range values {
|
||||
w.Header().Add(header, value)
|
||||
}
|
||||
}
|
||||
|
||||
w.WriteHeader(res.StatusCode)
|
||||
|
||||
if _, err := io.Copy(w, res.Body); err != nil {
|
||||
panic(errors.WithStack(err))
|
||||
}
|
||||
}
|
@ -40,10 +40,10 @@ func (m *Module) get(call goja.FunctionCall, rt *goja.Runtime) goja.Value {
|
||||
func (m *Module) handleMessages() {
|
||||
ctx := context.Background()
|
||||
|
||||
err := m.bus.Reply(ctx, MessageNamespaceFetchRequest, func(msg bus.Message) (bus.Message, error) {
|
||||
fetchRequest, ok := msg.(*MessageFetchRequest)
|
||||
fetchErrs := m.bus.Reply(ctx, AddressFetchRequest, func(env bus.Envelope) (any, error) {
|
||||
fetchRequest, ok := env.Message().(*FetchRequest)
|
||||
if !ok {
|
||||
return nil, errors.Wrapf(bus.ErrUnexpectedMessage, "expected message fetch request, got '%T'", msg)
|
||||
return nil, errors.Wrapf(bus.ErrUnexpectedMessage, "expected fetch request, got '%T'", env.Message())
|
||||
}
|
||||
|
||||
res, err := m.handleFetchRequest(fetchRequest)
|
||||
@ -57,13 +57,14 @@ func (m *Module) handleMessages() {
|
||||
|
||||
return res, nil
|
||||
})
|
||||
if err != nil {
|
||||
panic(errors.WithStack(err))
|
||||
|
||||
for err := range fetchErrs {
|
||||
logger.Fatal(ctx, "error while replying to fetch requests", logger.CapturedE(errors.WithStack(err)))
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Module) handleFetchRequest(req *MessageFetchRequest) (*MessageFetchResponse, error) {
|
||||
res := NewMessageFetchResponse(req.RequestID)
|
||||
func (m *Module) handleFetchRequest(req *FetchRequest) (*FetchResponse, error) {
|
||||
res := &FetchResponse{}
|
||||
|
||||
ctx := logger.With(
|
||||
req.Context,
|
||||
@ -83,11 +84,11 @@ func (m *Module) handleFetchRequest(req *MessageFetchRequest) (*MessageFetchResp
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
result, ok := rawResult.Export().(map[string]interface{})
|
||||
result, ok := rawResult.(map[string]interface{})
|
||||
if !ok {
|
||||
return nil, errors.Errorf(
|
||||
"unexpected onClientFetch result: expected 'map[string]interface{}', got '%T'",
|
||||
rawResult.Export(),
|
||||
rawResult,
|
||||
)
|
||||
}
|
||||
|
||||
|
@ -2,8 +2,8 @@ package fetch
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io/ioutil"
|
||||
"net/url"
|
||||
"os"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
@ -18,7 +18,9 @@ import (
|
||||
func TestFetchModule(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
logger.SetLevel(slog.LevelDebug)
|
||||
if testing.Verbose() {
|
||||
logger.SetLevel(slog.LevelDebug)
|
||||
}
|
||||
|
||||
bus := memory.NewBus()
|
||||
|
||||
@ -28,22 +30,20 @@ func TestFetchModule(t *testing.T) {
|
||||
ModuleFactory(bus),
|
||||
)
|
||||
|
||||
data, err := ioutil.ReadFile("testdata/fetch.js")
|
||||
path := "testdata/fetch.js"
|
||||
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
if err := server.Load("testdata/fetch.js", string(data)); err != nil {
|
||||
ctx := context.Background()
|
||||
if err := server.Start(ctx, path, string(data)); err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
defer server.Stop()
|
||||
|
||||
ctx := context.Background()
|
||||
if err := server.Start(ctx); err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
// Wait for module to startup
|
||||
time.Sleep(1 * time.Second)
|
||||
|
||||
@ -53,33 +53,33 @@ func TestFetchModule(t *testing.T) {
|
||||
remoteAddr := "127.0.0.1"
|
||||
url, _ := url.Parse("http://example.com")
|
||||
|
||||
rawReply, err := bus.Request(ctx, NewMessageFetchRequest(ctx, remoteAddr, url))
|
||||
reply, err := bus.Request(ctx, NewFetchRequestEnvelope(ctx, remoteAddr, url))
|
||||
if err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
reply, ok := rawReply.(*MessageFetchResponse)
|
||||
response, ok := reply.Message().(*FetchResponse)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected reply type '%T'", rawReply)
|
||||
t.Fatalf("unexpected reply message type '%T'", reply.Message())
|
||||
}
|
||||
|
||||
if e, g := true, reply.Allow; e != g {
|
||||
if e, g := true, response.Allow; e != g {
|
||||
t.Errorf("reply.Allow: expected '%v', got '%v'", e, g)
|
||||
}
|
||||
|
||||
url, _ = url.Parse("https://google.com")
|
||||
|
||||
rawReply, err = bus.Request(ctx, NewMessageFetchRequest(ctx, remoteAddr, url))
|
||||
reply, err = bus.Request(ctx, NewFetchRequestEnvelope(ctx, remoteAddr, url))
|
||||
if err != nil {
|
||||
t.Fatalf("%+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
reply, ok = rawReply.(*MessageFetchResponse)
|
||||
response, ok = reply.Message().(*FetchResponse)
|
||||
if !ok {
|
||||
t.Fatalf("unexpected reply type '%T'", rawReply)
|
||||
t.Fatalf("unexpected reply message type '%T'", reply.Message())
|
||||
}
|
||||
|
||||
if e, g := false, reply.Allow; e != g {
|
||||
if e, g := false, response.Allow; e != g {
|
||||
t.Errorf("reply.Allow: expected '%v', got '%v'", e, g)
|
||||
}
|
||||
}
|
||||
|
Reference in New Issue
Block a user