package tunnel import ( "context" "net" "net/http" "net/http/httputil" "net/url" "time" "github.com/pkg/errors" "gitlab.com/wpetit/goweb/logger" ) type contextKey string const remoteClientKey contextKey = "go-tunnel.remoteclient" var ( ErrAbortProxy = errors.New("proxy aborted") ) type MatchRequestFunc func(w http.ResponseWriter, r *http.Request) (*RemoteClient, error) func ProxyHandler(targetURL string, match MatchRequestFunc, funcs ...ProxyConfigFunc) (http.Handler, error) { conf := DefaultProxyConfig() for _, fn := range funcs { fn(conf) } target, err := url.Parse(targetURL) if err != nil { return nil, errors.WithStack(err) } reverse := createReverseProxy(target) if conf.ConfigureReverseProxy != nil { if err := conf.ConfigureReverseProxy(reverse); err != nil { return nil, errors.WithStack(err) } } fn := func(w http.ResponseWriter, r *http.Request) { remoteClient, err := match(w, r) if errors.Is(err, ErrAbortProxy) { return } if err != nil { logger.Error(r.Context(), "could not match proxy request", logger.E(errors.WithStack(err))) http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) return } if remoteClient == nil { http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway) return } ctx := context.WithValue(r.Context(), remoteClientKey, remoteClient) r = r.WithContext(ctx) reverse.ServeHTTP(w, r) } return http.HandlerFunc(fn), nil } func createReverseProxy(target *url.URL) *httputil.ReverseProxy { reverse := httputil.NewSingleHostReverseProxy(target) // nolint: go-mnd reverse.Transport = &http.Transport{ Proxy: http.ProxyFromEnvironment, DialContext: func(ctx context.Context, network string, addr string) (net.Conn, error) { remoteClient, ok := ctx.Value(remoteClientKey).(*RemoteClient) if !ok { return nil, errors.New("could not retrieve remote client") } conn, err := remoteClient.Proxy(ctx, network, addr) if err != nil { return nil, errors.WithStack(err) } return conn, nil }, ForceAttemptHTTP2: true, MaxIdleConns: 100, IdleConnTimeout: 90 * time.Second, TLSHandshakeTimeout: 10 * time.Second, ExpectContinueTimeout: 1 * time.Second, } reverse.FlushInterval = 0 return reverse }