103 lines
2.3 KiB
Go
103 lines
2.3 KiB
Go
package tunnel
|
|
|
|
import (
|
|
"context"
|
|
"net"
|
|
"net/http"
|
|
"net/http/httputil"
|
|
"net/url"
|
|
"time"
|
|
|
|
"github.com/pkg/errors"
|
|
"gitlab.com/wpetit/goweb/logger"
|
|
)
|
|
|
|
// contextKey is an unexported type for context keys defined by this package,
// so they can never collide with keys set by other packages.
type contextKey string

// remoteClientKey is the context key under which the handler returned by
// ProxyHandler stores the matched *RemoteClient; the reverse proxy's dialer
// reads it back to open the upstream connection.
const remoteClientKey = contextKey("go-tunnel.remoteclient")
|
|
|
|
var (
|
|
ErrAbortProxy = errors.New("proxy aborted")
|
|
)
|
|
|
|
// MatchRequestFunc selects the RemoteClient through which an incoming request
// should be proxied.
//
// ProxyHandler interprets its results as follows:
//   - an error matching ErrAbortProxy (checked with errors.Is, so it may be
//     wrapped) ends handling without the handler writing a response;
//   - any other non-nil error is logged and answered with
//     500 Internal Server Error;
//   - a nil client with a nil error is answered with 502 Bad Gateway;
//   - otherwise the request is proxied through the returned client.
//
// The ResponseWriter is passed in, presumably so the function can write its
// own response (e.g. before returning ErrAbortProxy) — confirm with callers.
type MatchRequestFunc func(w http.ResponseWriter, r *http.Request) (*RemoteClient, error)
|
|
|
|
// ProxyHandler returns an http.Handler that reverse-proxies every request to
// targetURL through a RemoteClient chosen per request by match.
//
// For each request, match is called first:
//   - if it returns an error matching ErrAbortProxy, the handler returns
//     without writing anything;
//   - if it returns any other error, the error is logged and a
//     500 Internal Server Error is sent;
//   - if it returns a nil client, a 502 Bad Gateway is sent;
//   - otherwise the client is stored in the request context, where the
//     reverse proxy's dialer picks it up to open the upstream connection.
//
// funcs are applied in order to the default proxy configuration. An error is
// returned if match is nil, if targetURL cannot be parsed or lacks a scheme
// or host, or if the configured ConfigureReverseProxy hook fails.
func ProxyHandler(targetURL string, match MatchRequestFunc, funcs ...ProxyConfigFunc) (http.Handler, error) {
	// Fail at construction time rather than panicking on every request.
	if match == nil {
		return nil, errors.New("match function must not be nil")
	}

	conf := DefaultProxyConfig()

	for _, fn := range funcs {
		fn(conf)
	}

	target, err := url.Parse(targetURL)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	// url.Parse accepts inputs such as "localhost:8080" (parsed as scheme
	// "localhost" with no host). A target without scheme and host can never
	// be reached, so reject it here instead of failing on every request.
	if target.Scheme == "" || target.Host == "" {
		return nil, errors.New("invalid target url \"" + targetURL + "\": scheme and host are required")
	}

	reverse := createReverseProxy(target)

	if conf.ConfigureReverseProxy != nil {
		if err := conf.ConfigureReverseProxy(reverse); err != nil {
			return nil, errors.WithStack(err)
		}
	}

	handle := func(w http.ResponseWriter, r *http.Request) {
		remoteClient, err := match(w, r)
		if errors.Is(err, ErrAbortProxy) {
			// The match function took responsibility for the response.
			return
		}

		if err != nil {
			logger.Error(r.Context(), "could not match proxy request", logger.E(errors.WithStack(err)))
			http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)

			return
		}

		if remoteClient == nil {
			http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)

			return
		}

		// The transport's DialContext reads the client back from the context.
		ctx := context.WithValue(r.Context(), remoteClientKey, remoteClient)
		r = r.WithContext(ctx)

		reverse.ServeHTTP(w, r)
	}

	return http.HandlerFunc(handle), nil
}
|
|
|
|
// createReverseProxy builds a single-host reverse proxy to target. Its
// transport opens every upstream connection through the *RemoteClient stored
// in the request context under remoteClientKey (see ProxyHandler). Dialing
// fails if no such client is present.
//
// Keep-alives are disabled on purpose. http.Transport pools connections by
// target address, and every request handled here goes to the same target. With
// pooling enabled, a connection dialed through one remote client's tunnel
// could be reused for a request that was matched to a different remote client,
// sending it down the wrong tunnel. Dialing per request guarantees that each
// request travels through the client it was matched to.
func createReverseProxy(target *url.URL) *httputil.ReverseProxy {
	reverse := httputil.NewSingleHostReverseProxy(target)

	// nolint: go-mnd
	reverse.Transport = &http.Transport{
		Proxy: http.ProxyFromEnvironment,
		DialContext: func(ctx context.Context, network string, addr string) (net.Conn, error) {
			remoteClient, ok := ctx.Value(remoteClientKey).(*RemoteClient)
			if !ok {
				return nil, errors.New("could not retrieve remote client")
			}

			conn, err := remoteClient.Proxy(ctx, network, addr)
			if err != nil {
				return nil, errors.WithStack(err)
			}

			return conn, nil
		},
		// Connections must never be shared between remote clients; see the
		// function comment. This also makes HTTP/2 connections single-use.
		DisableKeepAlives:     true,
		ForceAttemptHTTP2:     true,
		TLSHandshakeTimeout:   10 * time.Second,
		ExpectContinueTimeout: 1 * time.Second,
	}

	// Zero disables periodic flushing. ReverseProxy still flushes responses it
	// recognizes as streaming immediately.
	reverse.FlushInterval = 0

	return reverse
}
|