go-tunnel/http.go

103 lines
2.3 KiB
Go

package tunnel
import (
"context"
"net"
"net/http"
"net/http/httputil"
"net/url"
"time"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
)
// contextKey is a package-private type for context keys, so that values stored
// by this package cannot collide with keys defined by other packages.
type contextKey string

// remoteClientKey is the context key under which ProxyHandler stores the
// *RemoteClient selected for a request; the reverse proxy's dialer reads it.
const remoteClientKey = contextKey("go-tunnel.remoteclient")
// ErrAbortProxy may be returned (directly or wrapped) by a MatchRequestFunc to
// stop ProxyHandler from proxying the request. In that case the handler
// returns immediately without writing any response of its own.
var ErrAbortProxy = errors.New("proxy aborted")
// MatchRequestFunc selects the remote client through which an incoming request
// should be proxied. It also receives the ResponseWriter, so it may answer the
// request itself.
//
// ProxyHandler interprets the results as follows: an error for which
// errors.Is(err, ErrAbortProxy) holds stops the handler without writing a
// response; any other error is logged and answered with 500; a nil client with
// a nil error is answered with 502; otherwise the request is proxied through
// the returned client.
type MatchRequestFunc func(rw http.ResponseWriter, req *http.Request) (*RemoteClient, error)
// ProxyHandler returns an http.Handler that reverse-proxies requests to
// targetURL through a remote client chosen per request by match.
//
// The funcs are applied, in order, to the configuration returned by
// DefaultProxyConfig before targetURL is parsed. If the configuration carries a
// ConfigureReverseProxy hook, it is called with the freshly built reverse
// proxy. A parse failure or a hook error is returned to the caller.
//
// For each request, the outcome of match decides what happens:
//   - an error matching ErrAbortProxy (via errors.Is): the handler returns
//     without writing anything itself;
//   - any other error: it is logged and a 500 Internal Server Error is sent;
//   - a nil client: a 502 Bad Gateway is sent;
//   - otherwise the client is stored in the request context under
//     remoteClientKey and the request is handed to the reverse proxy.
func ProxyHandler(targetURL string, match MatchRequestFunc, funcs ...ProxyConfigFunc) (http.Handler, error) {
	conf := DefaultProxyConfig()
	for _, apply := range funcs {
		apply(conf)
	}

	parsed, parseErr := url.Parse(targetURL)
	if parseErr != nil {
		return nil, errors.WithStack(parseErr)
	}

	proxy := createReverseProxy(parsed)

	if hook := conf.ConfigureReverseProxy; hook != nil {
		if hookErr := hook(proxy); hookErr != nil {
			return nil, errors.WithStack(hookErr)
		}
	}

	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		client, matchErr := match(w, r)

		switch {
		case errors.Is(matchErr, ErrAbortProxy):
			// Nothing is written here; match was given w and owns the response.
			return
		case matchErr != nil:
			logger.Error(r.Context(), "could not match proxy request", logger.E(errors.WithStack(matchErr)))
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
			return
		case client == nil:
			http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
			return
		}

		// The transport's dialer looks the client up from the request context.
		proxied := r.WithContext(context.WithValue(r.Context(), remoteClientKey, client))
		proxy.ServeHTTP(w, proxied)
	})

	return handler, nil
}
// createReverseProxy builds a single-host reverse proxy towards target whose
// outgoing connections are dialed through the *RemoteClient stored in the
// request context under remoteClientKey (ProxyHandler puts it there). Dialing
// fails if no such client is present in the context.
//
// Keep-alives are disabled on purpose. The transport's idle-connection pool is
// keyed by the target address, which is the same for every request. With
// pooling enabled, a connection tunneled through one remote client could be
// reused for a request that was matched to a different remote client. With
// DisableKeepAlives, each request dials its own connection through its own
// client. For HTTP/2, this also makes every client connection single-use.
func createReverseProxy(target *url.URL) *httputil.ReverseProxy {
	reverse := httputil.NewSingleHostReverseProxy(target)

	// nolint: go-mnd
	reverse.Transport = &http.Transport{
		// NOTE(review): a proxy configured in the server's environment would be
		// dialed through the remote client's tunnel; confirm this is intended.
		Proxy: http.ProxyFromEnvironment,
		DialContext: func(ctx context.Context, network string, addr string) (net.Conn, error) {
			remoteClient, ok := ctx.Value(remoteClientKey).(*RemoteClient)
			if !ok {
				return nil, errors.New("could not retrieve remote client")
			}

			conn, err := remoteClient.Proxy(ctx, network, addr)
			if err != nil {
				return nil, errors.WithStack(err)
			}

			return conn, nil
		},
		// Connections depend on the per-request remote client and must never
		// be shared between requests. The idle-pool settings below only matter
		// if a ConfigureReverseProxy hook re-enables keep-alives.
		DisableKeepAlives:     true,
		ForceAttemptHTTP2:     true,
		MaxIdleConns:          100,
		IdleConnTimeout:       90 * time.Second,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
	}

	// No periodic flushing of response bodies (explicitly the zero value).
	reverse.FlushInterval = 0

	return reverse
}