diff --git a/client.go b/client.go index a69ff68..ddf886f 100644 --- a/client.go +++ b/client.go @@ -5,13 +5,10 @@ import ( "io" "net" "net/http" - "os" - "strconv" "gitlab.com/wpetit/goweb/logger" "forge.cadoles.com/wpetit/go-tunnel/control" - cmap "github.com/orcaman/concurrent-map" "github.com/pkg/errors" "github.com/xtaci/kcp-go/v5" "github.com/xtaci/smux" @@ -23,7 +20,6 @@ type Client struct { sess *smux.Session control *control.Control http *http.Client - proxies cmap.ConcurrentMap } func (c *Client) Connect(ctx context.Context) error { @@ -49,18 +45,14 @@ func (c *Client) Connect(ctx context.Context) error { return errors.WithStack(err) } - controlStream, err := sess.OpenStream() - if err != nil { + control := control.New() + if err := control.Init(ctx, sess, false); err != nil { return errors.WithStack(err) } - c.conn = conn - c.sess = sess - c.control = control.New(sess, controlStream) - logger.Debug(ctx, "sending auth request") - success, err := c.control.AuthRequest(c.conf.Credentials) + success, err := control.AuthRequest(c.conf.Credentials) if err != nil { return errors.WithStack(err) } @@ -70,6 +62,10 @@ func (c *Client) Connect(ctx context.Context) error { return errors.WithStack(ErrAuthFailed) } + c.control = control + c.conn = conn + c.sess = sess + return nil } @@ -78,7 +74,6 @@ func (c *Client) Listen(ctx context.Context) error { err := c.control.Listen(ctx, control.Handlers{ control.TypeProxyRequest: c.handleProxyRequest, - control.TypeCloseProxy: c.handleCloseProxy, }) if errors.Is(err, io.ErrClosedPipe) { @@ -99,107 +94,59 @@ func (c *Client) Close() error { return errors.WithStack(err) } - if c.sess != nil && !c.sess.IsClosed() { - if err := c.sess.Close(); err != nil { - return errors.WithStack(err) - } - } - return nil } -func (c *Client) handleCloseProxy(ctx context.Context, m *control.Message) (*control.Message, error) { - closeProxyPayload, ok := m.Payload.(*control.CloseProxyPayload) - if !ok { - return nil, errors.WithStack(ErrUnexpectedMessage) - } - - requestID := strconv.FormatInt(closeProxyPayload.RequestID, 10) - - rawCloseChan, exists := c.proxies.Get(requestID) - if !exists { - return nil, nil - } - - closeChan, ok := rawCloseChan.(chan struct{}) - if !ok { - return nil, nil - } - - closeChan <- struct{}{} - - return nil, nil -} - func (c *Client) handleProxyRequest(ctx context.Context, m *control.Message) (*control.Message, error) { proxyReqPayload, ok := m.Payload.(*control.ProxyRequestPayload) if !ok { return nil, errors.WithStack(ErrUnexpectedMessage) } - requestID := strconv.FormatInt(proxyReqPayload.RequestID, 10) - - ctx = logger.With(ctx, logger.F("requestID", requestID)) - - logger.Debug( - ctx, "handling proxy request", + ctx = logger.With(ctx, logger.F("network", proxyReqPayload.Network), logger.F("address", proxyReqPayload.Address), ) - stream, err := c.sess.OpenStream() + logger.Debug(ctx, "handling proxy request") + + out, err := net.Dial(proxyReqPayload.Network, proxyReqPayload.Address) if err != nil { return nil, errors.WithStack(err) } - closeChan := make(chan struct{}) - - go func() { - defer func() { - stream.Close() - logger.Debug(ctx, "proxy stream closed") - }() - - proxy := func() error { - net, err := net.Dial(proxyReqPayload.Network, proxyReqPayload.Address) - if err != nil { - return errors.WithStack(err) - } - defer net.Close() - - err = pipe(ctx, stream, net) - if errors.Is(err, os.ErrClosed) { - return nil - } - - if err != nil { - return errors.WithStack(err) - } - - return nil - } - - for { - select { - case <-closeChan: - return - default: - if err := proxy(); err != nil { - logger.Error(ctx, "error while proxying", logger.E(err)) - - continue - } - - return - } - } - }() - - c.proxies.Set(requestID, closeChan) + go c.handleProxyStream(ctx, out) return nil, nil } +func (c *Client) handleProxyStream(ctx context.Context, out net.Conn) { + in, err := c.sess.AcceptStream() + if err != nil { + logger.Error(ctx, "error while accepting proxy stream", logger.E(err)) + + return + } + + defer in.Close() + + streamCopy := func(dst io.Writer, src io.ReadCloser) { + if _, err := Copy(dst, src); err != nil { + if !errors.Is(err, io.ErrClosedPipe) && !errors.Is(err, io.EOF) { + logger.Error(ctx, "error while proxying", logger.E(err)) + } + } + + logger.Debug(ctx, "closing proxy stream") + + in.Close() + out.Close() + } + + go streamCopy(out, in) + streamCopy(in, out) +} + func NewClient(funcs ...ClientConfigFunc) *Client { conf := DefaultClientConfig() @@ -208,8 +155,7 @@ func NewClient(funcs ...ClientConfigFunc) *Client { } return &Client{ - conf: conf, - http: &http.Client{}, - proxies: cmap.New(), + conf: conf, + http: &http.Client{}, } } diff --git a/cmd/server/main.go b/cmd/server/main.go index 28aaeef..22b60fd 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -60,7 +60,7 @@ func handleRequest(w http.ResponseWriter, r *http.Request) { return } - target, err := url.Parse("http://localhost:3000") + target, err := url.Parse("https://arcad.games") if err != nil { logger.Fatal(r.Context(), "could not parse url", logger.E(err)) } diff --git a/control/control.go b/control/control.go index 820342c..7594e8a 100644 --- a/control/control.go +++ b/control/control.go @@ -3,9 +3,6 @@ package control import ( "context" "encoding/json" - "net" - "sync/atomic" - "time" "github.com/pkg/errors" "github.com/xtaci/smux" @@ -13,11 +10,37 @@ import ( ) type Control struct { - encoder *json.Encoder - decoder *json.Decoder - stream *smux.Stream - sess *smux.Session - proxyClock int64 + encoder *json.Encoder + decoder *json.Decoder +} + +func (c *Control) Init(ctx context.Context, sess *smux.Session, serverMode bool) error { + config := smux.DefaultConfig() + config.Version = 2 + + logger.Debug(ctx, "creating control stream") + + var ( + controlStream *smux.Stream + err error + ) + + if serverMode { + controlStream, err = sess.AcceptStream() + if err != nil { + return errors.WithStack(err) + } + } else { + controlStream, err = sess.OpenStream() + if err != nil { + return errors.WithStack(err) + } + } + + c.decoder = json.NewDecoder(controlStream) + c.encoder = json.NewEncoder(controlStream) + + return nil } func (c *Control) AuthRequest(credentials interface{}) (bool, error) { @@ -39,51 +62,17 @@ func (c *Control) AuthRequest(credentials interface{}) (bool, error) { return authResPayload.Success, nil } -func (c *Control) Proxy(ctx context.Context, network, address string) (net.Conn, error) { - var ( - stream *smux.Stream - err error - ) - - requestID := atomic.AddInt64(&c.proxyClock, 1) - +func (c *Control) ProxyReq(ctx context.Context, network, address string) error { req := NewMessage(TypeProxyRequest, &ProxyRequestPayload{ - RequestID: requestID, - Network: network, - Address: address, + Network: network, + Address: address, }) - ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address)) - - logger.Debug(ctx, "proxying") - if err := c.Write(req); err != nil { - return nil, errors.WithStack(err) + return errors.WithStack(err) } - logger.Debug(ctx, "opening stream") - - stream, err = c.sess.AcceptStream() - if err != nil { - return nil, errors.WithStack(err) - } - - go func() { - <-ctx.Done() - - req := NewMessage(TypeCloseProxy, &CloseProxyPayload{ - RequestID: requestID, - }) - - if err := c.Write(req); err != nil { - logger.Error(ctx, "error while closing proxy", logger.E(err)) - } - - logger.Debug(ctx, "closing proxy conn") - stream.Close() - }() - - return stream, nil + return nil } func (c *Control) Listen(ctx context.Context, handlers Handlers) error { @@ -164,10 +153,6 @@ func (c *Control) read(m *Message) error { } func (c *Control) write(m *Message) error { - if err := c.stream.SetWriteDeadline(time.Now().Add(time.Second)); err != nil { - return errors.WithStack(err) - } - if err := c.encoder.Encode(m); err != nil { return errors.WithStack(err) } @@ -175,11 +160,6 @@ func (c *Control) write(m *Message) error { return nil } -func New(sess *smux.Session, controlStream *smux.Stream) *Control { - return &Control{ - encoder: json.NewEncoder(controlStream), - decoder: json.NewDecoder(controlStream), - sess: sess, - stream: controlStream, - } +func New() *Control { + return &Control{} } diff --git a/control/proxy.go b/control/proxy.go index cda1f15..9685d6b 100644 --- a/control/proxy.go +++ b/control/proxy.go @@ -1,9 +1,8 @@ package control type ProxyRequestPayload struct { - RequestID int64 `json:"i"` - Network string `json:"n"` - Address string `json:"a"` + Network string `json:"n"` + Address string `json:"a"` } type CloseProxyPayload struct { diff --git a/helper.go b/helper.go index 2e822da..845850d 100644 --- a/helper.go +++ b/helper.go @@ -1,52 +1,30 @@ package tunnel import ( - "context" "io" - "net" "github.com/pkg/errors" "github.com/xtaci/kcp-go/v5" - "gitlab.com/wpetit/goweb/logger" ) -func pipe(ctx context.Context, client net.Conn, server net.Conn) (err error) { - stop := make(chan bool) +const bufSize = 4096 - go func() { - err = relay(client, server, stop) - if err != nil { - err = errors.WithStack(err) - logger.Debug(ctx, "client->server error", logger.E(err)) - } - }() - go func() { - err = relay(server, client, stop) - if err != nil { - err = errors.WithStack(err) - logger.Debug(ctx, "server->client error", logger.E(err)) - } - }() - - select { - case <-stop: - return err +// From https://github.com/xtaci/kcptun/blob/master/generic/copy.go +// Copyright https://github.com/xtaci +func Copy(dst io.Writer, src io.Reader) (written int64, err error) { + // If the reader has a WriteTo method, use it to do the copy. + // Avoids an allocation and a copy. + if wt, ok := src.(io.WriterTo); ok { + return wt.WriteTo(dst) } -} - -func relay(src net.Conn, dst net.Conn, stop chan bool) (err error) { - _, err = io.Copy(dst, src) - if errors.Is(err, io.EOF) { - err = nil + // Similarly, if the writer has a ReadFrom method, use it to do the copy. + if rt, ok := dst.(io.ReaderFrom); ok { + return rt.ReadFrom(src) } - if err != nil { - err = errors.WithStack(err) - } - - stop <- true - - return + // fallback to standard io.CopyBuffer + buf := make([]byte, bufSize) + return io.CopyBuffer(dst, src, buf) } func createBlockCrypt(algorithm string, pass []byte) (kcp.BlockCrypt, error) { diff --git a/remote_client.go b/remote_client.go index f8a4cff..0c1132c 100644 --- a/remote_client.go +++ b/remote_client.go @@ -7,6 +7,7 @@ import ( "forge.cadoles.com/wpetit/go-tunnel/control" + cmap "github.com/orcaman/concurrent-map" "github.com/pkg/errors" "github.com/xtaci/kcp-go/v5" "github.com/xtaci/smux" @@ -17,20 +18,14 @@ type RemoteClient struct { onClientAuthHook OnClientAuthHook onClientConnectHook OnClientConnectHook onClientDisconnectHook OnClientDisconnectHook + conn net.Conn sess *smux.Session control *control.Control remoteAddr net.Addr + proxies cmap.ConcurrentMap } func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error { - c.remoteAddr = conn.RemoteAddr() - - if c.onClientConnectHook != nil { - if err := c.onClientConnectHook.OnClientConnect(ctx, c); err != nil { - return errors.WithStack(err) - } - } - config := smux.DefaultConfig() config.Version = 2 @@ -39,15 +34,22 @@ func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error { return errors.WithStack(err) } - logger.Debug(ctx, "accepting control stream") + control := control.New() - controlStream, err := sess.AcceptStream() - if err != nil { + if err := control.Init(ctx, sess, true); err != nil { return errors.WithStack(err) } c.sess = sess - c.control = control.New(sess, controlStream) + c.remoteAddr = conn.RemoteAddr() + c.control = control + c.conn = conn + + if c.onClientConnectHook != nil { + if err := c.onClientConnectHook.OnClientConnect(ctx, c); err != nil { + return errors.WithStack(err) + } + } return nil } @@ -72,6 +74,57 @@ func (c *RemoteClient) Listen(ctx context.Context) error { return err } +func (c *RemoteClient) ConfigureHooks(hooks interface{}) { + if hooks == nil { + return + } + + if onClientAuthHook, ok := hooks.(OnClientAuthHook); ok { + c.onClientAuthHook = onClientAuthHook + } + + if OnClientConnectHook, ok := hooks.(OnClientConnectHook); ok { + c.onClientConnectHook = OnClientConnectHook + } + + if OnClientDisconnectHook, ok := hooks.(OnClientDisconnectHook); ok { + c.onClientDisconnectHook = OnClientDisconnectHook + } +} + +func (c *RemoteClient) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *RemoteClient) Close() { + if c.conn != nil { + c.conn.Close() + } +} + +func (c *RemoteClient) Proxy(ctx context.Context, network, address string) (net.Conn, error) { + ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address)) + + if err := c.control.ProxyReq(ctx, network, address); err != nil { + return nil, errors.WithStack(err) + } + + logger.Debug(ctx, "opening proxy stream") + + stream, err := c.sess.OpenStream() + if err != nil { + return nil, errors.WithStack(err) + } + + go func() { + <-ctx.Done() + logger.Debug(ctx, "closing proxy stream") + stream.Close() + }() + + return stream, nil +} + func (c *RemoteClient) handleAuthRequest(ctx context.Context, m *control.Message) (*control.Message, error) { authReqPayload, ok := m.Payload.(*control.AuthRequestPayload) if !ok { @@ -101,38 +154,8 @@ func (c *RemoteClient) handleAuthRequest(ctx context.Context, m *control.Message return res, nil } -func (c *RemoteClient) ConfigureHooks(hooks interface{}) { - if hooks == nil { - return - } - - if onClientAuthHook, ok := hooks.(OnClientAuthHook); ok { - c.onClientAuthHook = onClientAuthHook - } - - if OnClientConnectHook, ok := hooks.(OnClientConnectHook); ok { - c.onClientConnectHook = OnClientConnectHook - } - - if OnClientDisconnectHook, ok := hooks.(OnClientDisconnectHook); ok { - c.onClientDisconnectHook = OnClientDisconnectHook - } -} - -func (c *RemoteClient) RemoteAddr() net.Addr { - return c.remoteAddr -} - -func (c *RemoteClient) Proxy(ctx context.Context, network, address string) (net.Conn, error) { - return c.control.Proxy(ctx, network, address) -} - -func (c *RemoteClient) Close() { - if c.sess != nil && !c.sess.IsClosed() { - c.sess.Close() - } -} - func NewRemoteClient() *RemoteClient { - return &RemoteClient{} + return &RemoteClient{ + proxies: cmap.New(), + } }