edge/pkg/http/fetch.go

113 lines
2.4 KiB
Go

package http
import (
"io"
"net/http"
"net/url"
"forge.cadoles.com/arcad/edge/pkg/module"
"forge.cadoles.com/arcad/edge/pkg/module/fetch"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
)
// isFetchHopByHopHeader reports whether header is a hop-by-hop header
// (RFC 7230 §6.1, plus the legacy Proxy-Connection). Such headers describe a
// single transport connection and must not be relayed by a proxy.
func isFetchHopByHopHeader(header string) bool {
	switch http.CanonicalHeaderKey(header) {
	case "Connection", "Proxy-Connection", "Keep-Alive", "Proxy-Authenticate",
		"Proxy-Authorization", "Te", "Trailer", "Transfer-Encoding", "Upgrade":
		return true
	default:
		return false
	}
}

// copyFetchHeaders appends every header value found in src to dst, skipping
// hop-by-hop headers.
func copyFetchHeaders(dst, src http.Header) {
	for header, values := range src {
		if isFetchHopByHopHeader(header) {
			continue
		}
		for _, value := range values {
			dst.Add(header, value)
		}
	}
}

// handleAppFetch proxies a GET request to the URL given in the "url" query
// parameter on behalf of an app.
//
// The target must be an absolute http(s) URL; otherwise the handler answers
// 400. Before any outbound traffic, a fetch.MessageFetchRequest is sent on
// h.bus. The request is proxied only if the reply is a
// *fetch.MessageFetchResponse with Allow set. A denied request gets 403; a bus
// failure or an unexpected reply type gets 500. On success, the upstream
// status code, end-to-end headers and body are relayed to the client.
func (h *Handler) handleAppFetch(w http.ResponseWriter, r *http.Request) {
	// NOTE(review): the read lock is held for the whole upstream round trip,
	// including streaming the response body, so a slow upstream delays any
	// writer of h.mutex. Confirm which fields it guards before narrowing it.
	h.mutex.RLock()
	defer h.mutex.RUnlock()

	ctx := r.Context()
	ctx = module.WithContext(ctx, map[module.ContextKey]any{
		ContextKeyOriginRequest: r,
	})

	// Named targetURL so that the net/url package is not shadowed below.
	targetURL, err := url.Parse(r.URL.Query().Get("url"))
	if err != nil {
		jsonError(w, http.StatusBadRequest, errorCodeBadRequest)
		return
	}

	// An empty or relative "url" parses without error. Reject it here instead
	// of letting the outbound request fail later with a 500. url.Parse
	// lower-cases the scheme, so plain comparisons are sufficient.
	if (targetURL.Scheme != "http" && targetURL.Scheme != "https") || targetURL.Host == "" {
		jsonError(w, http.StatusBadRequest, errorCodeBadRequest)
		return
	}

	requestMsg := fetch.NewMessageFetchRequest(ctx, r.RemoteAddr, targetURL)

	reply, err := h.bus.Request(ctx, requestMsg)
	if err != nil {
		logger.Error(ctx, "could not retrieve fetch request reply", logger.CapturedE(errors.WithStack(err)))
		jsonError(w, http.StatusInternalServerError, errorCodeInternalError)
		return
	}

	logger.Debug(ctx, "fetch reply", logger.F("reply", reply))

	responseMsg, ok := reply.(*fetch.MessageFetchResponse)
	if !ok {
		logger.Error(
			ctx, "unexpected fetch response message",
			logger.F("message", reply),
		)
		jsonError(w, http.StatusInternalServerError, errorCodeInternalError)
		return
	}

	if !responseMsg.Allow {
		jsonError(w, http.StatusForbidden, errorCodeForbidden)
		return
	}

	// Bind the outbound request to the incoming request's context so that
	// the upstream call is cancelled when the client goes away.
	proxyReq, err := http.NewRequestWithContext(ctx, http.MethodGet, targetURL.String(), nil)
	if err != nil {
		logger.Error(
			ctx, "could not create proxy request",
			logger.CapturedE(errors.WithStack(err)),
		)
		jsonError(w, http.StatusInternalServerError, errorCodeInternalError)
		return
	}

	// NOTE(review): end-to-end headers, including Cookie and Authorization,
	// are forwarded to the target as-is. Confirm that sending the client's
	// credentials to any bus-approved URL is intended.
	copyFetchHeaders(proxyReq.Header, r.Header)
	// NOTE(review): the conventional header is X-Forwarded-For; kept as-is
	// because consumers may rely on this name.
	proxyReq.Header.Add("X-Forwarded-From", r.RemoteAddr)

	res, err := h.httpClient.Do(proxyReq)
	if err != nil {
		logger.Error(
			ctx, "could not execute proxy request",
			logger.CapturedE(errors.WithStack(err)),
		)
		jsonError(w, http.StatusInternalServerError, errorCodeInternalError)
		return
	}

	defer func() {
		if err := res.Body.Close(); err != nil {
			logger.Error(
				ctx, "could not close response body",
				logger.CapturedE(errors.WithStack(err)),
			)
		}
	}()

	copyFetchHeaders(w.Header(), res.Header)
	w.WriteHeader(res.StatusCode)

	// The status line has already been sent, so a copy failure (typically
	// the client disconnecting) can only be logged. Panicking here would
	// abort the connection and dump a stack trace for an ordinary event.
	if _, err := io.Copy(w, res.Body); err != nil {
		logger.Error(
			ctx, "could not copy proxy response body",
			logger.CapturedE(errors.WithStack(err)),
		)
	}
}