diff --git a/Makefile b/Makefile index 792bc00..bf80a30 100644 --- a/Makefile +++ b/Makefile @@ -4,10 +4,10 @@ watch: build: build-server build-client build-server: - go build -o ./bin/server ./cmd/server + CGO_ENABLED=0 go build -o ./bin/server ./cmd/server build-client: - go build -o ./bin/client ./cmd/client + CGO_ENABLED=0 go build -o ./bin/client ./cmd/client test: go test -v -race ./... \ No newline at end of file diff --git a/client.go b/client.go index b138ef7..8ff8e5a 100644 --- a/client.go +++ b/client.go @@ -2,30 +2,27 @@ package tunnel import ( "context" + "encoding/json" "io" "net" - "net/http" - "sync" "time" "gitlab.com/wpetit/goweb/logger" - "forge.cadoles.com/wpetit/go-tunnel/control" "github.com/pkg/errors" "github.com/xtaci/kcp-go/v5" "github.com/xtaci/smux" ) type Client struct { - conf *ClientConfig - conn *kcp.UDPSession - sess *smux.Session - control *control.Control - http *http.Client - openStreamMutex sync.Mutex + conf *ClientConfig + conn *kcp.UDPSession + sess *smux.Session } func (c *Client) Connect(ctx context.Context) error { + logger.Debug(ctx, "connecting", logger.F("serverAddr", c.conf.ServerAddress)) + conn, err := kcp.DialWithOptions( c.conf.ServerAddress, c.conf.BlockCrypt, c.conf.DataShards, c.conf.ParityShards, @@ -40,34 +37,29 @@ func (c *Client) Connect(ctx context.Context) error { } } - config := smux.DefaultConfig() - config.Version = 2 - config.KeepAliveInterval = 10 * time.Second - config.KeepAliveTimeout = 2 * config.KeepAliveInterval - - sess, err := smux.Client(conn, config) + sess, err := smux.Client(conn, c.conf.SmuxConfig) if err != nil { return errors.WithStack(err) } - control := control.New() - if err := control.Init(ctx, sess, false); err != nil { + stream, err := sess.OpenStream() + if err != nil { return errors.WithStack(err) } - logger.Debug(ctx, "sending auth request") + defer stream.Close() - success, err := control.AuthRequest(c.conf.Credentials) + success, err := c.authenticate(ctx, stream) if err != nil { return errors.WithStack(err) } if !success { - defer c.Close() - return errors.WithStack(ErrAuthFailed) + return errors.WithStack(ErrAuthenticationFailed) } - c.control = control + logger.Debug(ctx, "authentication success") + c.conn = conn c.sess = sess @@ -75,87 +67,138 @@ func (c *Client) Connect(ctx context.Context) error { } func (c *Client) Listen(ctx context.Context) error { - logger.Debug(ctx, "listening for messages") + logger.Debug(ctx, "listening for proxy requests") - ctx, cancel := context.WithCancel(ctx) - defer cancel() + for { + stream, err := c.sess.AcceptStream() + if err != nil { + return errors.WithStack(err) + } - err := c.control.Listen(ctx, control.Handlers{ - control.TypeProxyRequest: c.handleProxyRequest, - }) + subCtx := logger.With(ctx, + logger.F("remoteAddr", stream.RemoteAddr()), + logger.F("localAddr", stream.LocalAddr()), + ) - if errors.Is(err, io.ErrClosedPipe) { - logger.Debug(ctx, "client connection closed") + readDeadline := time.Now().Add(c.conf.ProxyRequestTimeout) + logger.Debug(subCtx, "waiting for proxy request", logger.F("deadline", readDeadline)) - return errors.WithStack(ErrConnectionClosed) + if err := stream.SetReadDeadline(readDeadline); err != nil { + stream.Close() + logger.Error(subCtx, "could not set read deadline", logger.E(errors.WithStack(err))) + + continue + } + + decoder := json.NewDecoder(stream) + proxyReq := &proxyRequest{} + + if err := decoder.Decode(proxyReq); err != nil { + stream.Close() + logger.Error(subCtx, "could not decode proxy request", logger.E(errors.WithStack(err))) + + continue + } + + if err := stream.SetReadDeadline(time.Time{}); err != nil { + stream.Close() + logger.Error(subCtx, "could not set read deadline", logger.E(errors.WithStack(err))) + + continue + } + + go c.handleProxyStream(subCtx, stream, proxyReq.Network, proxyReq.Address) } - - return err } func (c *Client) Close() error { - if c.conn == nil { - return errors.WithStack(ErrNotConnected) + if c.sess != nil && !c.sess.IsClosed() { + if err := c.sess.Close(); err != nil { + return errors.WithStack(err) + } } - if err := c.conn.Close(); err != nil { - return errors.WithStack(err) + if c.conn != nil { + if err := c.conn.Close(); err != nil { + return errors.WithStack(err) + } } + c.conn = nil + c.sess = nil + return nil } -func (c *Client) handleProxyRequest(ctx context.Context, m *control.Message) (*control.Message, error) { - proxyReqPayload, ok := m.Payload.(*control.ProxyRequestPayload) - if !ok { - return nil, errors.WithStack(ErrUnexpectedMessage) +func (c *Client) authenticate(ctx context.Context, stream *smux.Stream) (bool, error) { + encoder := json.NewEncoder(stream) + authReq := &authRequest{ + Credentials: c.conf.Credentials, } - ctx = logger.With(ctx, - logger.F("network", proxyReqPayload.Network), - logger.F("address", proxyReqPayload.Address), - ) + start := time.Now() + writeDeadline := start.Add(c.conf.AuthenticationTimeout) + logger.Debug(ctx, "sending auth request", logger.F("deadline", writeDeadline)) - logger.Debug(ctx, "handling proxy request") - - out, err := net.Dial(proxyReqPayload.Network, proxyReqPayload.Address) - if err != nil { - return nil, errors.WithStack(err) + if err := stream.SetWriteDeadline(writeDeadline); err != nil { + return false, errors.WithStack(err) } - go c.handleProxyStream(ctx, out) + if err := encoder.Encode(authReq); err != nil { + return false, errors.WithStack(err) + } - return nil, nil + decoder := json.NewDecoder(stream) + authRes := &authResponse{} + + readDeadline := time.Now().Add(c.conf.AuthenticationTimeout - time.Now().Sub(start)) + logger.Debug(ctx, "waiting for auth response", logger.F("deadline", readDeadline)) + + if err := stream.SetReadDeadline(readDeadline); err != nil { + return false, errors.WithStack(err) + } + + if err := decoder.Decode(authRes); err != nil && !errors.Is(err, io.EOF) { + return false, errors.WithStack(err) + } + + return authRes.Success, nil } -func (c *Client) handleProxyStream(ctx context.Context, out net.Conn) { - c.openStreamMutex.Lock() +func (c *Client) handleProxyStream(ctx context.Context, in *smux.Stream, network, address string) { + defer func(start time.Time) { + logger.Debug(ctx, "handleProxyStream duration", logger.F("duration", time.Since(start))) + }(time.Now()) - in, err := c.sess.OpenStream() + defer in.Close() + + logger.Debug( + ctx, "proxying", + logger.F("network", network), + logger.F("address", address), + ) + + out, err := net.Dial(network, address) if err != nil { - c.openStreamMutex.Unlock() - logger.Error(ctx, "error while accepting proxy stream", logger.E(err)) + logger.Error(ctx, "could not dial", logger.E(errors.WithStack(err))) return } - - c.openStreamMutex.Unlock() + defer out.Close() streamCopy := func(dst io.Writer, src io.ReadCloser) { if _, err := Copy(dst, src); err != nil { if errors.Is(err, smux.ErrInvalidProtocol) { - logger.Error(ctx, "error while proxying", logger.E(errors.WithStack(err))) + logger.Error(ctx, "could not proxy", logger.E(errors.WithStack(err))) } } - logger.Debug(ctx, "closing proxy stream") - in.Close() out.Close() } - go streamCopy(in, out) - streamCopy(out, in) + go streamCopy(out, in) + streamCopy(in, out) } func NewClient(funcs ...ClientConfigFunc) *Client { @@ -167,6 +210,5 @@ func NewClient(funcs ...ClientConfigFunc) *Client { return &Client{ conf: conf, - http: &http.Client{}, } } diff --git a/client_config.go b/client_config.go index 9f2eb00..3d7ba8d 100644 --- a/client_config.go +++ b/client_config.go @@ -2,33 +2,49 @@ package tunnel import ( "crypto/sha1" + "time" "github.com/pkg/errors" "github.com/xtaci/kcp-go/v5" + "github.com/xtaci/smux" "golang.org/x/crypto/pbkdf2" ) type ClientConfig struct { - ServerAddress string - BlockCrypt kcp.BlockCrypt - DataShards int - ParityShards int - Credentials interface{} - ConfigureConn ConfigureConnFunc + ServerAddress string + BlockCrypt kcp.BlockCrypt + DataShards int + ParityShards int + Credentials interface{} + ConfigureConn ConfigureConnFunc + AuthenticationTimeout time.Duration + ProxyRequestTimeout time.Duration + SmuxConfig *smux.Config } +// nolint: go-mnd func DefaultClientConfig() *ClientConfig { unencryptedBlock, err := kcp.NewNoneBlockCrypt(nil) if err != nil { // should never happen panic(errors.WithStack(err)) } + smuxConfig := smux.DefaultConfig() + smuxConfig.Version = 2 + smuxConfig.KeepAliveInterval = 10 * time.Second + smuxConfig.MaxReceiveBuffer = 4194304 + smuxConfig.MaxStreamBuffer = 2097152 + return &ClientConfig{ - ServerAddress: "127.0.0.1:36543", - BlockCrypt: unencryptedBlock, - DataShards: 3, - ParityShards: 10, - Credentials: nil, + ServerAddress: "127.0.0.1:36543", + BlockCrypt: unencryptedBlock, + DataShards: 3, + ParityShards: 10, + Credentials: nil, + ConfigureConn: DefaultClientConfigureConn, + AuthenticationTimeout: 30 * time.Second, + ProxyRequestTimeout: 5 * time.Second, + SmuxConfig: smuxConfig, } } @@ -64,3 +80,34 @@ func WithClientConfigureConn(fn ConfigureConnFunc) ClientConfigFunc { conf.ConfigureConn = fn } } + +func WithClientSmuxConfig(c *smux.Config) ClientConfigFunc { + return func(conf *ClientConfig) { + conf.SmuxConfig = c + } +} + +// nolint: go-mnd +func DefaultClientConfigureConn(conn *kcp.UDPSession) error { + // Based on kcptun default configuration, mode 'fast3' + conn.SetStreamMode(true) + conn.SetWriteDelay(false) + conn.SetNoDelay(1, 10, 2, 1) + conn.SetWindowSize(128, 512) + conn.SetMtu(1400) + conn.SetACKNoDelay(true) + + if err := conn.SetReadBuffer(16777217); err != nil { + return errors.WithStack(err) + } + + if err := conn.SetWriteBuffer(16777217); err != nil { + return errors.WithStack(err) + } + + if err := conn.SetDSCP(46); err != nil { + return errors.WithStack(err) + } + + return nil +} diff --git a/cmd/client/main.go b/cmd/client/main.go index fe4888f..645a994 100644 --- a/cmd/client/main.go +++ b/cmd/client/main.go @@ -3,6 +3,7 @@ package main import ( "context" "flag" + "fmt" "math/rand" "time" @@ -11,15 +12,18 @@ import ( "gitlab.com/wpetit/goweb/logger" ) -const sharedKey = "go-tunnel" const salt = "go-tunnel" func main() { var ( - clientID string + clientID = fmt.Sprintf("client-%d", time.Now().Unix()) + serverAddr = "127.0.0.1:36543" + sharedKey = "go-tunnel" ) - flag.StringVar(&clientID, "id", "", "Client ID") + flag.StringVar(&sharedKey, "shared-key", sharedKey, "shared key") + flag.StringVar(&clientID, "id", clientID, "Client ID") + flag.StringVar(&serverAddr, "server-addr", serverAddr, "server address") flag.Parse() ctx := context.Background() @@ -28,12 +32,12 @@ func main() { logger.SetLevel(slog.LevelDebug) client := tunnel.NewClient( + tunnel.WithClientServerAddress(serverAddr), tunnel.WithClientCredentials(clientID), - tunnel.WithClientAESBlockCrypt(sharedKey, salt), ) defer client.Close() - initialBackoff := time.Second * 10 + initialBackoff := time.Second * 2 backoff := initialBackoff sleep := func() { diff --git a/cmd/server/main.go b/cmd/server/main.go index 22b60fd..3dfa2d5 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -2,12 +2,9 @@ package main import ( "context" - "net" + "flag" "net/http" - "net/http/httputil" - "net/url" "strings" - "time" "cdr.dev/slog" "forge.cadoles.com/wpetit/go-tunnel" @@ -15,20 +12,61 @@ import ( "gitlab.com/wpetit/goweb/logger" ) -const sharedKey = "go-tunnel" const salt = "go-tunnel" var registry = NewRegistry() func main() { + var ( + serverAddr = ":36543" + httpAddr = ":3003" + sharedKey = "go-tunnel" + targetURL = "https://arcad.games" + ) + + flag.StringVar(&serverAddr, "server-addr", serverAddr, "server address") + flag.StringVar(&targetURL, "target-url", targetURL, "target url") + flag.StringVar(&httpAddr, "http-addr", httpAddr, "http server address") + flag.StringVar(&sharedKey, "shared-key", sharedKey, "shared key") + + flag.Parse() + ctx := context.Background() logger.SetLevel(slog.LevelDebug) server := tunnel.NewServer( - tunnel.WithServerAESBlockCrypt(sharedKey, salt), + tunnel.WithServerAddress(serverAddr), tunnel.WithServerOnClientAuth(registry.OnClientAuth), tunnel.WithServerOnClientDisconnect(registry.OnClientDisconnect), + tunnel.WithServerOnClientAuth(func(ctx context.Context, remoteClient *tunnel.RemoteClient, credentials interface{}) (bool, error) { + remoteAddr := remoteClient.RemoteAddr().String() + + ctx = logger.With(ctx, logger.F("remoteAddr", remoteAddr)) + + logger.Debug(ctx, "new client auth") + + clientID, ok := credentials.(string) + if !ok { + logger.Debug(ctx, "client auth failed") + + return false, nil + } + + registry.Add(clientID, remoteAddr, remoteClient) + + logger.Debug(ctx, "client auth success") + + return true, nil + }), + tunnel.WithServerOnClientDisconnect(func(ctx context.Context, remoteClient *tunnel.RemoteClient) error { + remoteAddr := remoteClient.RemoteAddr().String() + ctx = logger.With(ctx, logger.F("remoteAddr", remoteAddr)) + logger.Debug(ctx, "client disconnected") + registry.RemoveByRemoteAddr(remoteAddr) + + return nil + }), ) go func() { @@ -37,51 +75,28 @@ func main() { } }() - if err := http.ListenAndServe(":3003", http.HandlerFunc(handleRequest)); err != nil { + handler, err := createProxyHandler(targetURL) + if err != nil { + logger.Fatal(ctx, "could not create proxy handler", logger.E(errors.WithStack(err))) + } + + if err := http.ListenAndServe(httpAddr, handler); err != nil { logger.Fatal(ctx, "error while listening", logger.E(err)) } } -func handleRequest(w http.ResponseWriter, r *http.Request) { - subdomains := strings.SplitN(r.Host, ".", 2) +func createProxyHandler(targetURL string) (http.Handler, error) { + return tunnel.ProxyHandler(targetURL, func(w http.ResponseWriter, r *http.Request) (*tunnel.RemoteClient, error) { + subdomains := strings.SplitN(r.Host, ".", 2) - if len(subdomains) < 2 { - http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) + if len(subdomains) < 2 { + http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) - return - } + return nil, tunnel.ErrAbortProxy + } - clientID := subdomains[0] - remoteClient := registry.Get(clientID) + clientID := subdomains[0] - if remoteClient == nil { - http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway) - - return - } - - target, err := url.Parse("https://arcad.games") - if err != nil { - logger.Fatal(r.Context(), "could not parse url", logger.E(err)) - } - - reverse := httputil.NewSingleHostReverseProxy(target) - reverse.Transport = &http.Transport{ - Proxy: http.ProxyFromEnvironment, - DialContext: func(ctx context.Context, network string, addr string) (net.Conn, error) { - conn, err := remoteClient.Proxy(ctx, network, addr) - if err != nil { - return nil, errors.WithStack(err) - } - - return conn, nil - }, - ForceAttemptHTTP2: true, - MaxIdleConns: 100, - IdleConnTimeout: 90 * time.Second, - TLSHandshakeTimeout: 10 * time.Second, - ExpectContinueTimeout: 1 * time.Second, - } - - reverse.ServeHTTP(w, r) + return registry.Get(clientID), nil + }) } diff --git a/control/auth.go b/control/auth.go deleted file mode 100644 index b965617..0000000 --- a/control/auth.go +++ /dev/null @@ -1,9 +0,0 @@ -package control - -type AuthRequestPayload struct { - Credentials interface{} `json:"c"` -} - -type AuthResponsePayload struct { - Success bool `json:"s"` -} diff --git a/control/control.go b/control/control.go deleted file mode 100644 index 6e56806..0000000 --- a/control/control.go +++ /dev/null @@ -1,194 +0,0 @@ -package control - -import ( - "context" - "encoding/json" - - "github.com/pkg/errors" - "github.com/xtaci/smux" - "gitlab.com/wpetit/goweb/logger" -) - -type Control struct { - encoder *json.Encoder - decoder *json.Decoder - stream *smux.Stream -} - -func (c *Control) Init(ctx context.Context, sess *smux.Session, serverMode bool) error { - config := smux.DefaultConfig() - config.Version = 2 - - logger.Debug(ctx, "creating control stream") - - var ( - stream *smux.Stream - err error - ) - - if serverMode { - stream, err = sess.AcceptStream() - if err != nil { - return errors.WithStack(err) - } - } else { - stream, err = sess.OpenStream() - if err != nil { - return errors.WithStack(err) - } - } - - c.stream = stream - c.decoder = json.NewDecoder(stream) - c.encoder = json.NewEncoder(stream) - - return nil -} - -func (c *Control) AuthRequest(credentials interface{}) (bool, error) { - req := NewMessage(TypeAuthRequest, &AuthRequestPayload{ - Credentials: credentials, - }) - - res := NewMessage(TypeAuthResponse, nil) - - if err := c.reqRes(req, res); err != nil { - return false, errors.WithStack(err) - } - - authResPayload, ok := res.Payload.(*AuthResponsePayload) - if !ok { - return false, errors.WithStack(ErrUnexpectedMessage) - } - - return authResPayload.Success, nil -} - -func (c *Control) ProxyReq(ctx context.Context, network, address string) error { - req := NewMessage(TypeProxyRequest, &ProxyRequestPayload{ - Network: network, - Address: address, - }) - - if err := c.Write(req); err != nil { - return errors.WithStack(err) - } - - return nil -} - -func (c *Control) Listen(ctx context.Context, handlers Handlers) error { - errChan := make(chan error) - msgChan := make(chan *Message) - dieChan := c.stream.GetDieCh() - - go func(msgChan chan *Message, errChan chan error) { - for { - logger.Debug(ctx, "reading next message") - - msg, err := c.Read() - if err != nil { - errChan <- errors.WithStack(err) - - close(errChan) - close(msgChan) - - return - } - - msgChan <- msg - } - }(msgChan, errChan) - - for { - select { - case <-ctx.Done(): - return nil - - case <-dieChan: - return errors.WithStack(ErrStreamClosed) - - case err := <-errChan: - return errors.WithStack(err) - - case msg := <-msgChan: - go func() { - subCtx := logger.With(ctx, logger.F("messageType", msg.Type)) - - handler, exists := handlers[msg.Type] - if !exists { - logger.Error(subCtx, "no message handler registered") - - return - } - - res, err := handler(subCtx, msg) - if err != nil { - logger.Error(subCtx, "error while handling message", logger.E(err)) - - return - } - - if res == nil { - return - } - - if err := c.Write(res); err != nil { - logger.Error(subCtx, "error while write message response", logger.E(err)) - - return - } - }() - } - } -} - -func (c *Control) Read() (*Message, error) { - message := &Message{} - - if err := c.read(message); err != nil { - return nil, errors.WithStack(err) - } - - return message, nil -} - -func (c *Control) Write(m *Message) error { - if err := c.write(m); err != nil { - return errors.WithStack(err) - } - - return nil -} - -func (c *Control) reqRes(req *Message, res *Message) error { - if err := c.write(req); err != nil { - return errors.WithStack(err) - } - - if err := c.read(res); err != nil { - return errors.WithStack(err) - } - - return nil -} - -func (c *Control) read(m *Message) error { - if err := c.decoder.Decode(m); err != nil { - return errors.WithStack(err) - } - - return nil -} - -func (c *Control) write(m *Message) error { - if err := c.encoder.Encode(m); err != nil { - return errors.WithStack(err) - } - - return nil -} - -func New() *Control { - return &Control{} -} diff --git a/control/error.go b/control/error.go deleted file mode 100644 index fc6bee1..0000000 --- a/control/error.go +++ /dev/null @@ -1,8 +0,0 @@ -package control - -import "errors" - -var ( - ErrStreamClosed = errors.New("stream closed") - ErrUnexpectedMessage = errors.New("unexpected message") -) diff --git a/control/handler.go b/control/handler.go deleted file mode 100644 index 0028c0b..0000000 --- a/control/handler.go +++ /dev/null @@ -1,7 +0,0 @@ -package control - -import "context" - -type Handlers map[MessageType]MessageHandler - -type MessageHandler func(ctx context.Context, m *Message) (*Message, error) diff --git a/control/message.go b/control/message.go deleted file mode 100644 index 008f175..0000000 --- a/control/message.go +++ /dev/null @@ -1,76 +0,0 @@ -package control - -import ( - "encoding/json" - - "github.com/pkg/errors" -) - -const ( - TypeAuthRequest MessageType = "auth-req" - TypeAuthResponse MessageType = "auth-res" - TypeProxyRequest MessageType = "proxy-req" - TypeCloseProxy MessageType = "close-proxy" -) - -type MessageType string - -type BaseMessage struct { - Type MessageType `json:"t"` - RawPayload json.RawMessage `json:"p"` -} - -type Message struct { - BaseMessage - Payload interface{} `json:"p"` -} - -func (m *Message) UnmarshalJSON(data []byte) error { - base := &BaseMessage{} - - if err := json.Unmarshal(data, base); err != nil { - return errors.WithStack(err) - } - - payload, err := unmarshalPayload(base.Type, base.RawPayload) - if err != nil { - return errors.WithStack(err) - } - - m.Type = base.Type - m.Payload = payload - - return nil -} - -func NewMessage(mType MessageType, payload interface{}) *Message { - return &Message{ - BaseMessage: BaseMessage{ - Type: mType, - }, - Payload: payload, - } -} - -func unmarshalPayload(mType MessageType, data []byte) (interface{}, error) { - var payload interface{} - - switch mType { - case TypeAuthRequest: - payload = &AuthRequestPayload{} - case TypeAuthResponse: - payload = &AuthResponsePayload{} - case TypeProxyRequest: - payload = &ProxyRequestPayload{} - case TypeCloseProxy: - payload = &CloseProxyPayload{} - default: - return nil, errors.Wrapf(ErrUnexpectedMessage, "unexpected message type '%s'", mType) - } - - if err := json.Unmarshal(data, payload); err != nil { - return nil, errors.WithStack(err) - } - - return payload, nil -} diff --git a/control/proxy.go b/control/proxy.go deleted file mode 100644 index 9685d6b..0000000 --- a/control/proxy.go +++ /dev/null @@ -1,10 +0,0 @@ -package control - -type ProxyRequestPayload struct { - Network string `json:"n"` - Address string `json:"a"` -} - -type CloseProxyPayload struct { - RequestID int64 `json:"i"` -} diff --git a/error.go b/error.go index 31af0a7..75f0e44 100644 --- a/error.go +++ b/error.go @@ -3,10 +3,10 @@ package tunnel import "errors" var ( - ErrNotConnected = errors.New("not connected") - ErrCouldNotConnect = errors.New("could not connect") - ErrConnectionClosed = errors.New("connection closed") - ErrAuthFailed = errors.New("auth failed") - ErrUnexpectedMessage = errors.New("unexpected message") - ErrUnexpectedResponse = errors.New("unexpected response") + ErrNotConnected = errors.New("not connected") + ErrCouldNotConnect = errors.New("could not connect") + ErrConnectionClosed = errors.New("connection closed") + ErrAuthenticationFailed = errors.New("authentication failed") + ErrUnexpectedMessage = errors.New("unexpected message") + ErrUnexpectedResponse = errors.New("unexpected response") ) diff --git a/go.mod b/go.mod index dd0924e..428fbe3 100644 --- a/go.mod +++ b/go.mod @@ -4,6 +4,7 @@ go 1.15 require ( cdr.dev/slog v1.3.0 + github.com/davecgh/go-spew v1.1.1 github.com/orcaman/concurrent-map v0.0.0-20190826125027-8c72a8bb44f6 github.com/pkg/errors v0.9.1 github.com/streamrail/concurrent-map v0.0.0-20160823150647-8bf1e9bacbf6 diff --git a/http.go b/http.go new file mode 100644 index 0000000..96f4825 --- /dev/null +++ b/http.go @@ -0,0 +1,102 @@ +package tunnel + +import ( + "context" + "net" + "net/http" + "net/http/httputil" + "net/url" + "time" + + "github.com/pkg/errors" + "gitlab.com/wpetit/goweb/logger" +) + +type contextKey string + +const remoteClientKey contextKey = "go-tunnel.remoteclient" + +var ( + ErrAbortProxy = errors.New("proxy aborted") +) + +type MatchRequestFunc func(w http.ResponseWriter, r *http.Request) (*RemoteClient, error) + +func ProxyHandler(targetURL string, match MatchRequestFunc, funcs ...ProxyConfigFunc) (http.Handler, error) { + conf := DefaultProxyConfig() + + for _, fn := range funcs { + fn(conf) + } + + target, err := url.Parse(targetURL) + if err != nil { + return nil, errors.WithStack(err) + } + + reverse := createReverseProxy(target) + + if conf.ConfigureReverseProxy != nil { + if err := conf.ConfigureReverseProxy(reverse); err != nil { + return nil, errors.WithStack(err) + } + } + + fn := func(w http.ResponseWriter, r *http.Request) { + remoteClient, err := match(w, r) + if errors.Is(err, ErrAbortProxy) { + return + } + + if err != nil { + logger.Error(r.Context(), "could not match proxy request", logger.E(errors.WithStack(err))) + http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) + + return + } + + if remoteClient == nil { + http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway) + + return + } + + ctx := context.WithValue(r.Context(), remoteClientKey, remoteClient) + r = r.WithContext(ctx) + + reverse.ServeHTTP(w, r) + } + + return http.HandlerFunc(fn), nil +} + +func createReverseProxy(target *url.URL) *httputil.ReverseProxy { + reverse := httputil.NewSingleHostReverseProxy(target) + + // nolint: go-mnd + reverse.Transport = &http.Transport{ + Proxy: http.ProxyFromEnvironment, + DialContext: func(ctx context.Context, network string, addr string) (net.Conn, error) { + remoteClient, ok := ctx.Value(remoteClientKey).(*RemoteClient) + if !ok { + return nil, errors.New("could not retrieve remote client") + } + + conn, err := remoteClient.Proxy(ctx, network, addr) + if err != nil { + return nil, errors.WithStack(err) + } + + return conn, nil + }, + ForceAttemptHTTP2: true, + MaxIdleConns: 100, + IdleConnTimeout: 90 * time.Second, + TLSHandshakeTimeout: 10 * time.Second, + ExpectContinueTimeout: 1 * time.Second, + } + + reverse.FlushInterval = 0 + + return reverse +} diff --git a/http_config.go b/http_config.go new file mode 100644 index 0000000..ca7deeb --- /dev/null +++ b/http_config.go @@ -0,0 +1,21 @@ +package tunnel + +import "net/http/httputil" + +type ConfigureReverseProxyFunc func(*httputil.ReverseProxy) error + +type ProxyConfig struct { + ConfigureReverseProxy ConfigureReverseProxyFunc +} + +func DefaultProxyConfig() *ProxyConfig { + return &ProxyConfig{} +} + +type ProxyConfigFunc func(c *ProxyConfig) + +func WithProxyConfigure(fn ConfigureReverseProxyFunc) ProxyConfigFunc { + return func(c *ProxyConfig) { + c.ConfigureReverseProxy = fn + } +} diff --git a/modd.conf b/modd.conf index 2991f97..d9af6fd 100644 --- a/modd.conf +++ b/modd.conf @@ -1,7 +1,7 @@ -**/*.go { +**/*.go +modd.conf { prep: make test prep: make build - daemon: ./bin/server + daemon: ./bin/server -target-url http://127.0.0.1:3000 daemon: ./bin/client -id client1 - daemon: ./bin/client -id client2 } \ No newline at end of file diff --git a/protocol.go b/protocol.go new file mode 100644 index 0000000..7fac5eb --- /dev/null +++ b/protocol.go @@ -0,0 +1,14 @@ +package tunnel + +type authRequest struct { + Credentials interface{} `json:"c"` +} + +type authResponse struct { + Success bool `json:"b"` +} + +type proxyRequest struct { + Network string `json:"n"` + Address string `json:"a"` +} diff --git a/remote_client.go b/remote_client.go index ecfe9fa..52b9944 100644 --- a/remote_client.go +++ b/remote_client.go @@ -2,13 +2,11 @@ package tunnel import ( "context" + "encoding/json" "net" "sync" "time" - "forge.cadoles.com/wpetit/go-tunnel/control" - - cmap "github.com/orcaman/concurrent-map" "github.com/pkg/errors" "github.com/xtaci/kcp-go/v5" "github.com/xtaci/smux" @@ -19,61 +17,45 @@ type RemoteClient struct { onClientAuthHook OnClientAuthHook onClientConnectHook OnClientConnectHook onClientDisconnectHook OnClientDisconnectHook + conn *kcp.UDPSession sess *smux.Session - control *control.Control remoteAddr net.Addr - proxies cmap.ConcurrentMap - acceptStreamMutex sync.Mutex + authenticationTimeout time.Duration + proxyRequestTimeout time.Duration + connMutex sync.RWMutex + smuxConfig *smux.Config } func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error { - config := smux.DefaultConfig() - config.Version = 2 - config.KeepAliveInterval = 10 * time.Second - config.KeepAliveTimeout = 2 * config.KeepAliveInterval + c.connMutex.Lock() + defer c.connMutex.Unlock() - logger.Debug(ctx, "creating server session") + if err := c.Close(); err != nil { + return errors.WithStack(err) + } - sess, err := smux.Server(conn, config) + sess, err := c.acceptSession(ctx, conn) if err != nil { return errors.WithStack(err) } - ctrl := control.New() + stream, err := sess.AcceptStream() + if err != nil { + return errors.WithStack(err) + } - if err := ctrl.Init(ctx, sess, true); err != nil { + defer stream.Close() + + if err := c.authenticate(ctx, stream); err != nil { return errors.WithStack(err) } c.sess = sess - c.remoteAddr = conn.RemoteAddr() - c.control = ctrl - - if c.onClientConnectHook != nil { - if err := c.onClientConnectHook.OnClientConnect(ctx, c); err != nil { - return errors.WithStack(err) - } - } + c.conn = conn return nil } -func (c *RemoteClient) Listen(ctx context.Context) error { - defer func() { - if c.onClientDisconnectHook != nil { - if err := c.onClientDisconnectHook.OnClientDisconnect(ctx, c); err != nil { - logger.Error(ctx, "client disconnect hook error", logger.E(errors.WithStack(err))) - } - } - }() - - logger.Debug(ctx, "listening for messages") - - return c.control.Listen(ctx, control.Handlers{ - control.TypeAuthRequest: c.handleAuthRequest, - }) -} - func (c *RemoteClient) ConfigureHooks(hooks interface{}) { if hooks == nil { return @@ -96,74 +78,164 @@ func (c *RemoteClient) RemoteAddr() net.Addr { return c.remoteAddr } -func (c *RemoteClient) Close() { +func (c *RemoteClient) Close() error { if c.sess != nil { - c.sess.Close() + if err := c.sess.Close(); err != nil { + return errors.WithStack(err) + } + } + + if c.conn != nil { + if err := c.conn.Close(); err != nil { + return errors.WithStack(err) + } } c.sess = nil - c.control = nil + c.conn = nil + + return nil +} + +func (c *RemoteClient) SwitchConn(ctx context.Context, conn *kcp.UDPSession) error { + c.connMutex.Lock() + defer c.connMutex.Unlock() + + if err := c.Close(); err != nil { + return errors.WithStack(err) + } + + sess, err := c.acceptSession(ctx, conn) + if err != nil { + return errors.WithStack(err) + } + + c.sess = sess + c.conn = conn + + return nil } func (c *RemoteClient) Proxy(ctx context.Context, network, address string) (net.Conn, error) { - ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address)) + c.connMutex.RLock() + defer c.connMutex.RUnlock() - if err := c.control.ProxyReq(ctx, network, address); err != nil { - return nil, errors.WithStack(err) - } + ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address)) logger.Debug(ctx, "opening proxy stream") - c.acceptStreamMutex.Lock() - - stream, err := c.sess.AcceptStream() + stream, err := c.sess.OpenStream() if err != nil { - c.acceptStreamMutex.Unlock() return nil, errors.WithStack(err) } - c.acceptStreamMutex.Unlock() + proxyReq := &proxyRequest{ + Network: network, + Address: address, + } + encoder := json.NewEncoder(stream) - go func() { - <-ctx.Done() - logger.Debug(ctx, "closing proxy stream") + writeDeadline := time.Now().Add(c.proxyRequestTimeout) + logger.Debug(ctx, "sending proxy req", logger.F("deadline", writeDeadline)) + + if err := stream.SetWriteDeadline(writeDeadline); err != nil { stream.Close() - }() + + return nil, errors.WithStack(err) + } + + if err := encoder.Encode(proxyReq); err != nil { + stream.Close() + + return nil, errors.WithStack(err) + } + + if err := stream.SetWriteDeadline(time.Time{}); err != nil { + stream.Close() + + return nil, errors.WithStack(err) + } return stream, nil } -func (c *RemoteClient) handleAuthRequest(ctx context.Context, m *control.Message) (*control.Message, error) { - authReqPayload, ok := m.Payload.(*control.AuthRequestPayload) - if !ok { - return nil, errors.WithStack(ErrUnexpectedMessage) +func (c *RemoteClient) acceptSession(ctx context.Context, conn *kcp.UDPSession) (*smux.Session, error) { + logger.Debug(ctx, "accepting client session") + + sess, err := smux.Server(conn, c.smuxConfig) + if err != nil { + return nil, errors.WithStack(err) } - logger.Debug(ctx, "handling auth request", logger.F("credentials", authReqPayload.Credentials)) + c.remoteAddr = conn.RemoteAddr() + + if c.onClientConnectHook != nil { + if err := c.onClientConnectHook.OnClientConnect(ctx, c); err != nil { + return nil, errors.WithStack(err) + } + } + + return sess, nil +} + +func (c *RemoteClient) authenticate(ctx context.Context, stream *smux.Stream) error { + start := time.Now() + + readDeadline := time.Now().Add(c.authenticationTimeout) + logger.Debug(ctx, "waiting for auth request", logger.F("deadline", readDeadline)) + + if err := stream.SetReadDeadline(readDeadline); err != nil { + return errors.WithStack(err) + } + + decoder := json.NewDecoder(stream) + authReq := &authRequest{} + + if err := decoder.Decode(authReq); err != nil { + return errors.WithStack(err) + } var ( success bool err error ) + logger.Debug(ctx, "received client credentials", logger.F("credentials", authReq.Credentials)) + if c.onClientAuthHook != nil { - success, err = c.onClientAuthHook.OnClientAuth(ctx, c, authReqPayload.Credentials) + success, err = c.onClientAuthHook.OnClientAuth(ctx, c, authReq.Credentials) if err != nil { - return nil, errors.WithStack(err) + return errors.WithStack(err) } } - logger.Debug(ctx, "auth succeeded", logger.F("credentials", authReqPayload.Credentials)) - - res := control.NewMessage(control.TypeAuthResponse, &control.AuthResponsePayload{ + authRes := &authResponse{ Success: success, - }) + } + encoder := json.NewEncoder(stream) - return res, nil + writeDeadline := time.Now().Add(c.authenticationTimeout - time.Since(start)) + logger.Debug(ctx, "sending auth response", logger.F("deadline", writeDeadline)) + + if err := stream.SetWriteDeadline(writeDeadline); err != nil { + return errors.WithStack(err) + } + + if err := encoder.Encode(authRes); err != nil { + return errors.WithStack(err) + } + + if !success { + return errors.WithStack(ErrAuthenticationFailed) + } + + return nil } -func NewRemoteClient() *RemoteClient { +func NewRemoteClient(smuxConfig *smux.Config, authenticationTimeout, proxyRequestTimeout time.Duration) *RemoteClient { return &RemoteClient{ - proxies: cmap.New(), + smuxConfig: smuxConfig, + authenticationTimeout: authenticationTimeout, + proxyRequestTimeout: proxyRequestTimeout, } } diff --git a/server.go b/server.go index 7af7635..25e3eca 100644 --- a/server.go +++ b/server.go @@ -3,8 +3,8 @@ package tunnel import ( "context" - cmap "github.com/orcaman/concurrent-map" "github.com/pkg/errors" + cmap "github.com/streamrail/concurrent-map" "github.com/xtaci/kcp-go/v5" "gitlab.com/wpetit/goweb/logger" ) @@ -23,6 +23,14 @@ func (s *Server) Listen(ctx context.Context) error { return errors.WithStack(err) } + if s.conf.ConfigureListener != nil { + if err := s.conf.ConfigureListener(listener); err != nil { + return errors.WithStack(err) + } + } + + logger.Debug(ctx, "accepting connections", logger.F("address", s.conf.Address)) + for { conn, err := listener.AcceptKCP() if err != nil { @@ -34,12 +42,31 @@ func (s *Server) Listen(ctx context.Context) error { } func (s *Server) handleNewConn(ctx context.Context, conn *kcp.UDPSession) { - ctx = logger.With(ctx, logger.F("remoteAddr", conn.RemoteAddr().String())) + var remoteClient *RemoteClient - remoteClient := NewRemoteClient() + remoteAddr := conn.RemoteAddr().String() + ctx = logger.With(ctx, logger.F("remoteAddr", remoteAddr)) - defer remoteClient.Close() - defer conn.Close() + rawExistingClient, exists := s.clients.Get(remoteAddr) + if exists { + logger.Debug(ctx, "remote client already exists") + + remoteClient, _ = rawExistingClient.(*RemoteClient) + + if err := remoteClient.SwitchConn(ctx, conn); err != nil { + logger.Error(ctx, "remote client error", logger.E(errors.WithStack(err))) + + s.clients.Remove(remoteAddr) + + return + } + } + + remoteClient = NewRemoteClient( + s.conf.SmuxConfig, + s.conf.AuthenticationTimeout, + s.conf.ProxyRequestTimeout, + ) remoteClient.ConfigureHooks(s.conf.Hooks) @@ -49,9 +76,7 @@ func (s *Server) handleNewConn(ctx context.Context, conn *kcp.UDPSession) { return } - if err := remoteClient.Listen(ctx); err != nil { - logger.Error(ctx, "remote client error", logger.E(errors.WithStack(err))) - } + s.clients.Set(remoteAddr, remoteClient) } func NewServer(funcs ...ServerConfigFunc) *Server { diff --git a/server_config.go b/server_config.go index 8cf3450..cde9294 100644 --- a/server_config.go +++ b/server_config.go @@ -2,29 +2,43 @@ package tunnel import ( "crypto/sha1" + "time" "github.com/pkg/errors" "github.com/xtaci/kcp-go/v5" + "github.com/xtaci/smux" "golang.org/x/crypto/pbkdf2" ) type ConfigureConnFunc func(conn *kcp.UDPSession) error +type ConfigureListenerFunc func(listener *kcp.Listener) error type ServerConfig struct { - Address string - BlockCrypt kcp.BlockCrypt - DataShards int - ParityShards int - Hooks *ServerHooks - ConfigureConn ConfigureConnFunc + Address string + BlockCrypt kcp.BlockCrypt + DataShards int + ParityShards int + Hooks *ServerHooks + ConfigureConn ConfigureConnFunc + ConfigureListener ConfigureListenerFunc + AuthenticationTimeout time.Duration + ProxyRequestTimeout time.Duration + SmuxConfig *smux.Config } +// nolint: go-mnd func DefaultServerConfig() *ServerConfig { unencryptedBlock, err := kcp.NewNoneBlockCrypt(nil) if err != nil { // should never happen panic(errors.WithStack(err)) } + smuxConfig := smux.DefaultConfig() + smuxConfig.Version = 2 + smuxConfig.KeepAliveInterval = 10 * time.Second + smuxConfig.MaxReceiveBuffer = 4194304 + smuxConfig.MaxStreamBuffer = 2097152 + return &ServerConfig{ Address: ":36543", BlockCrypt: unencryptedBlock, @@ -35,6 +49,11 @@ func DefaultServerConfig() *ServerConfig { onClientDisconnect: DefaultOnClientDisconnect, onClientAuth: DefaultOnClientAuth, }, + ConfigureConn: DefaultServerConfigureConn, + ConfigureListener: DefaultServerConfigureListener, + AuthenticationTimeout: 30 * time.Second, + ProxyRequestTimeout: 5 * time.Second, + SmuxConfig: smuxConfig, } } @@ -82,3 +101,50 @@ func WithServerConfigureConn(fn ConfigureConnFunc) ServerConfigFunc { conf.ConfigureConn = fn } } + +func WithServerConfigureListener(fn ConfigureListenerFunc) ServerConfigFunc { + return func(conf *ServerConfig) { + conf.ConfigureListener = fn + } +} + +func WithServerSmuxConfig(c *smux.Config) ServerConfigFunc { + return func(conf *ServerConfig) { + conf.SmuxConfig = c + } +} + +// nolint: go-mnd +func DefaultServerConfigureConn(conn *kcp.UDPSession) error { + // Based on kcptun default configuration, mode 'fast3' + conn.SetStreamMode(true) + conn.SetWriteDelay(false) + conn.SetNoDelay(1, 10, 2, 1) + conn.SetWindowSize(128, 512) + conn.SetMtu(1400) + conn.SetACKNoDelay(true) + + if err := conn.SetDSCP(46); err != nil { + return errors.WithStack(err) + } + + return nil +} + +// nolint: go-mnd +func DefaultServerConfigureListener(listener *kcp.Listener) error { + // Based on kcptun default configuration, mode 'fast3' + if err := listener.SetReadBuffer(16777217); err != nil { + return errors.WithStack(err) + } + + if err := listener.SetWriteBuffer(16777217); err != nil { + return errors.WithStack(err) + } + + if err := listener.SetDSCP(46); err != nil { + return errors.WithStack(err) + } + + return nil +}