From 536100da90dc413976d32fdb19de635f1ffac1eb Mon Sep 17 00:00:00 2001 From: William Petit Date: Sat, 24 Oct 2020 13:35:27 +0200 Subject: [PATCH] fix: enhance proxy stability --- client.go | 157 +++++++++++++++------------------------- cmd/server/main.go | 2 +- control/control.go | 177 ++++++++++++++++++++++++--------------------- control/error.go | 1 + control/proxy.go | 5 +- helper.go | 50 ++++--------- remote_client.go | 155 +++++++++++++++++++++++---------------- server.go | 6 +- 8 files changed, 265 insertions(+), 288 deletions(-) diff --git a/client.go b/client.go index a69ff68..b138ef7 100644 --- a/client.go +++ b/client.go @@ -5,25 +5,24 @@ import ( "io" "net" "net/http" - "os" - "strconv" + "sync" + "time" "gitlab.com/wpetit/goweb/logger" "forge.cadoles.com/wpetit/go-tunnel/control" - cmap "github.com/orcaman/concurrent-map" "github.com/pkg/errors" "github.com/xtaci/kcp-go/v5" "github.com/xtaci/smux" ) type Client struct { - conf *ClientConfig - conn *kcp.UDPSession - sess *smux.Session - control *control.Control - http *http.Client - proxies cmap.ConcurrentMap + conf *ClientConfig + conn *kcp.UDPSession + sess *smux.Session + control *control.Control + http *http.Client + openStreamMutex sync.Mutex } func (c *Client) Connect(ctx context.Context) error { @@ -43,24 +42,22 @@ func (c *Client) Connect(ctx context.Context) error { config := smux.DefaultConfig() config.Version = 2 + config.KeepAliveInterval = 10 * time.Second + config.KeepAliveTimeout = 2 * config.KeepAliveInterval sess, err := smux.Client(conn, config) if err != nil { return errors.WithStack(err) } - controlStream, err := sess.OpenStream() - if err != nil { + control := control.New() + if err := control.Init(ctx, sess, false); err != nil { return errors.WithStack(err) } - c.conn = conn - c.sess = sess - c.control = control.New(sess, controlStream) - logger.Debug(ctx, "sending auth request") - success, err := c.control.AuthRequest(c.conf.Credentials) + success, err := control.AuthRequest(c.conf.Credentials) if err != nil { return errors.WithStack(err) } @@ -70,15 +67,21 @@ func (c *Client) Connect(ctx context.Context) error { return errors.WithStack(ErrAuthFailed) } + c.control = control + c.conn = conn + c.sess = sess + return nil } func (c *Client) Listen(ctx context.Context) error { logger.Debug(ctx, "listening for messages") + ctx, cancel := context.WithCancel(ctx) + defer cancel() + err := c.control.Listen(ctx, control.Handlers{ control.TypeProxyRequest: c.handleProxyRequest, - control.TypeCloseProxy: c.handleCloseProxy, }) if errors.Is(err, io.ErrClosedPipe) { @@ -99,107 +102,62 @@ func (c *Client) Close() error { return errors.WithStack(err) } - if c.sess != nil && !c.sess.IsClosed() { - if err := c.sess.Close(); err != nil { - return errors.WithStack(err) - } - } - return nil } -func (c *Client) handleCloseProxy(ctx context.Context, m *control.Message) (*control.Message, error) { - closeProxyPayload, ok := m.Payload.(*control.CloseProxyPayload) - if !ok { - return nil, errors.WithStack(ErrUnexpectedMessage) - } - - requestID := strconv.FormatInt(closeProxyPayload.RequestID, 10) - - rawCloseChan, exists := c.proxies.Get(requestID) - if !exists { - return nil, nil - } - - closeChan, ok := rawCloseChan.(chan struct{}) - if !ok { - return nil, nil - } - - closeChan <- struct{}{} - - return nil, nil -} - func (c *Client) handleProxyRequest(ctx context.Context, m *control.Message) (*control.Message, error) { proxyReqPayload, ok := m.Payload.(*control.ProxyRequestPayload) if !ok { return nil, errors.WithStack(ErrUnexpectedMessage) } - requestID := strconv.FormatInt(proxyReqPayload.RequestID, 10) - - ctx = logger.With(ctx, logger.F("requestID", requestID)) - - logger.Debug( - ctx, "handling proxy request", + ctx = logger.With(ctx, logger.F("network", proxyReqPayload.Network), logger.F("address", proxyReqPayload.Address), ) - stream, err := c.sess.OpenStream() + logger.Debug(ctx, "handling proxy request") + + out, err := net.Dial(proxyReqPayload.Network, proxyReqPayload.Address) if err != nil { return nil, errors.WithStack(err) } - closeChan := make(chan struct{}) - - go func() { - defer func() { - stream.Close() - logger.Debug(ctx, "proxy stream closed") - }() - - proxy := func() error { - net, err := net.Dial(proxyReqPayload.Network, proxyReqPayload.Address) - if err != nil { - return errors.WithStack(err) - } - defer net.Close() - - err = pipe(ctx, stream, net) - if errors.Is(err, os.ErrClosed) { - return nil - } - - if err != nil { - return errors.WithStack(err) - } - - return nil - } - - for { - select { - case <-closeChan: - return - default: - if err := proxy(); err != nil { - logger.Error(ctx, "error while proxying", logger.E(err)) - - continue - } - - return - } - } - }() - - c.proxies.Set(requestID, closeChan) + go c.handleProxyStream(ctx, out) return nil, nil } +func (c *Client) handleProxyStream(ctx context.Context, out net.Conn) { + c.openStreamMutex.Lock() + + in, err := c.sess.OpenStream() + if err != nil { + c.openStreamMutex.Unlock() + logger.Error(ctx, "error while accepting proxy stream", logger.E(err)) + + return + } + + c.openStreamMutex.Unlock() + + streamCopy := func(dst io.Writer, src io.ReadCloser) { + if _, err := Copy(dst, src); err != nil { + if errors.Is(err, smux.ErrInvalidProtocol) { + logger.Error(ctx, "error while proxying", logger.E(errors.WithStack(err))) + } + } + + logger.Debug(ctx, "closing proxy stream") + + in.Close() + out.Close() + } + + go streamCopy(in, out) + streamCopy(out, in) +} + func NewClient(funcs ...ClientConfigFunc) *Client { conf := DefaultClientConfig() @@ -208,8 +166,7 @@ func NewClient(funcs ...ClientConfigFunc) *Client { } return &Client{ - conf: conf, - http: &http.Client{}, - proxies: cmap.New(), + conf: conf, + http: &http.Client{}, } } diff --git a/cmd/server/main.go b/cmd/server/main.go index 28aaeef..22b60fd 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -60,7 +60,7 @@ func handleRequest(w http.ResponseWriter, r *http.Request) { return } - target, err := url.Parse("http://localhost:3000") + target, err := url.Parse("https://arcad.games") if err != nil { logger.Fatal(r.Context(), "could not parse url", logger.E(err)) } diff --git a/control/control.go b/control/control.go index 820342c..6e56806 100644 --- a/control/control.go +++ b/control/control.go @@ -3,9 +3,6 @@ package control import ( "context" "encoding/json" - "net" - "sync/atomic" - "time" "github.com/pkg/errors" "github.com/xtaci/smux" @@ -13,11 +10,39 @@ import ( ) type Control struct { - encoder *json.Encoder - decoder *json.Decoder - stream *smux.Stream - sess *smux.Session - proxyClock int64 + encoder *json.Encoder + decoder *json.Decoder + stream *smux.Stream +} + +func (c *Control) Init(ctx context.Context, sess *smux.Session, serverMode bool) error { + config := smux.DefaultConfig() + config.Version = 2 + + logger.Debug(ctx, "creating control stream") + + var ( + stream *smux.Stream + err error + ) + + if serverMode { + stream, err = sess.AcceptStream() + if err != nil { + return errors.WithStack(err) + } + } else { + stream, err = sess.OpenStream() + if err != nil { + return errors.WithStack(err) + } + } + + c.stream = stream + c.decoder = json.NewDecoder(stream) + c.encoder = json.NewEncoder(stream) + + return nil } func (c *Control) AuthRequest(credentials interface{}) (bool, error) { @@ -39,89 +64,82 @@ func (c *Control) AuthRequest(credentials interface{}) (bool, error) { return authResPayload.Success, nil } -func (c *Control) Proxy(ctx context.Context, network, address string) (net.Conn, error) { - var ( - stream *smux.Stream - err error - ) - - requestID := atomic.AddInt64(&c.proxyClock, 1) - +func (c *Control) ProxyReq(ctx context.Context, network, address string) error { req := NewMessage(TypeProxyRequest, &ProxyRequestPayload{ - RequestID: requestID, - Network: network, - Address: address, + Network: network, + Address: address, }) - ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address)) - - logger.Debug(ctx, "proxying") - if err := c.Write(req); err != nil { - return nil, errors.WithStack(err) + return errors.WithStack(err) } - logger.Debug(ctx, "opening stream") - - stream, err = c.sess.AcceptStream() - if err != nil { - return nil, errors.WithStack(err) - } - - go func() { - <-ctx.Done() - - req := NewMessage(TypeCloseProxy, &CloseProxyPayload{ - RequestID: requestID, - }) - - if err := c.Write(req); err != nil { - logger.Error(ctx, "error while closing proxy", logger.E(err)) - } - - logger.Debug(ctx, "closing proxy conn") - stream.Close() - }() - - return stream, nil + return nil } func (c *Control) Listen(ctx context.Context, handlers Handlers) error { - for { - logger.Debug(ctx, "reading next message") + errChan := make(chan error) + msgChan := make(chan *Message) + dieChan := c.stream.GetDieCh() - req, err := c.Read() - if err != nil { - return errors.WithStack(err) - } + go func(msgChan chan *Message, errChan chan error) { + for { + logger.Debug(ctx, "reading next message") - go func() { - subCtx := logger.With(ctx, logger.F("messageType", req.Type)) - - handler, exists := handlers[req.Type] - if !exists { - logger.Error(subCtx, "no message handler registered") - - return - } - - res, err := handler(subCtx, req) + msg, err := c.Read() if err != nil { - logger.Error(subCtx, "error while handling message", logger.E(err)) + errChan <- errors.WithStack(err) + + close(errChan) + close(msgChan) return } - if res == nil { - return - } + msgChan <- msg + } + }(msgChan, errChan) - if err := c.Write(res); err != nil { - logger.Error(subCtx, "error while write message response", logger.E(err)) + for { + select { + case <-ctx.Done(): + return nil - return - } - }() + case <-dieChan: + return errors.WithStack(ErrStreamClosed) + + case err := <-errChan: + return errors.WithStack(err) + + case msg := <-msgChan: + go func() { + subCtx := logger.With(ctx, logger.F("messageType", msg.Type)) + + handler, exists := handlers[msg.Type] + if !exists { + logger.Error(subCtx, "no message handler registered") + + return + } + + res, err := handler(subCtx, msg) + if err != nil { + logger.Error(subCtx, "error while handling message", logger.E(err)) + + return + } + + if res == nil { + return + } + + if err := c.Write(res); err != nil { + logger.Error(subCtx, "error while write message response", logger.E(err)) + + return + } + }() + } } } @@ -164,10 +182,6 @@ func (c *Control) read(m *Message) error { } func (c *Control) write(m *Message) error { - if err := c.stream.SetWriteDeadline(time.Now().Add(time.Second)); err != nil { - return errors.WithStack(err) - } - if err := c.encoder.Encode(m); err != nil { return errors.WithStack(err) } @@ -175,11 +189,6 @@ func (c *Control) write(m *Message) error { return nil } -func New(sess *smux.Session, controlStream *smux.Stream) *Control { - return &Control{ - encoder: json.NewEncoder(controlStream), - decoder: json.NewDecoder(controlStream), - sess: sess, - stream: controlStream, - } +func New() *Control { + return &Control{} } diff --git a/control/error.go b/control/error.go index 72c87c8..fc6bee1 100644 --- a/control/error.go +++ b/control/error.go @@ -3,5 +3,6 @@ package control import "errors" var ( + ErrStreamClosed = errors.New("stream closed") ErrUnexpectedMessage = errors.New("unexpected message") ) diff --git a/control/proxy.go b/control/proxy.go index cda1f15..9685d6b 100644 --- a/control/proxy.go +++ b/control/proxy.go @@ -1,9 +1,8 @@ package control type ProxyRequestPayload struct { - RequestID int64 `json:"i"` - Network string `json:"n"` - Address string `json:"a"` + Network string `json:"n"` + Address string `json:"a"` } type CloseProxyPayload struct { diff --git a/helper.go b/helper.go index 2e822da..845850d 100644 --- a/helper.go +++ b/helper.go @@ -1,52 +1,30 @@ package tunnel import ( - "context" "io" - "net" "github.com/pkg/errors" "github.com/xtaci/kcp-go/v5" - "gitlab.com/wpetit/goweb/logger" ) -func pipe(ctx context.Context, client net.Conn, server net.Conn) (err error) { - stop := make(chan bool) +const bufSize = 4096 - go func() { - err = relay(client, server, stop) - if err != nil { - err = errors.WithStack(err) - logger.Debug(ctx, "client->server error", logger.E(err)) - } - }() - go func() { - err = relay(server, client, stop) - if err != nil { - err = errors.WithStack(err) - logger.Debug(ctx, "server->client error", logger.E(err)) - } - }() - - select { - case <-stop: - return err +// From https://github.com/xtaci/kcptun/blob/master/generic/copy.go +// Copyright https://github.com/xtaci +func Copy(dst io.Writer, src io.Reader) (written int64, err error) { + // If the reader has a WriteTo method, use it to do the copy. + // Avoids an allocation and a copy. + if wt, ok := src.(io.WriterTo); ok { + return wt.WriteTo(dst) } -} - -func relay(src net.Conn, dst net.Conn, stop chan bool) (err error) { - _, err = io.Copy(dst, src) - if errors.Is(err, io.EOF) { - err = nil + // Similarly, if the writer has a ReadFrom method, use it to do the copy. + if rt, ok := dst.(io.ReaderFrom); ok { + return rt.ReadFrom(src) } - if err != nil { - err = errors.WithStack(err) - } - - stop <- true - - return + // fallback to standard io.CopyBuffer + buf := make([]byte, bufSize) + return io.CopyBuffer(dst, src, buf) } func createBlockCrypt(algorithm string, pass []byte) (kcp.BlockCrypt, error) { diff --git a/remote_client.go b/remote_client.go index f8a4cff..ecfe9fa 100644 --- a/remote_client.go +++ b/remote_client.go @@ -2,11 +2,13 @@ package tunnel import ( "context" - "io" "net" + "sync" + "time" "forge.cadoles.com/wpetit/go-tunnel/control" + cmap "github.com/orcaman/concurrent-map" "github.com/pkg/errors" "github.com/xtaci/kcp-go/v5" "github.com/xtaci/smux" @@ -20,10 +22,32 @@ type RemoteClient struct { sess *smux.Session control *control.Control remoteAddr net.Addr + proxies cmap.ConcurrentMap + acceptStreamMutex sync.Mutex } func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error { + config := smux.DefaultConfig() + config.Version = 2 + config.KeepAliveInterval = 10 * time.Second + config.KeepAliveTimeout = 2 * config.KeepAliveInterval + + logger.Debug(ctx, "creating server session") + + sess, err := smux.Server(conn, config) + if err != nil { + return errors.WithStack(err) + } + + ctrl := control.New() + + if err := ctrl.Init(ctx, sess, true); err != nil { + return errors.WithStack(err) + } + + c.sess = sess c.remoteAddr = conn.RemoteAddr() + c.control = ctrl if c.onClientConnectHook != nil { if err := c.onClientConnectHook.OnClientConnect(ctx, c); err != nil { @@ -31,45 +55,82 @@ func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error { } } - config := smux.DefaultConfig() - config.Version = 2 - - sess, err := smux.Server(conn, config) - if err != nil { - return errors.WithStack(err) - } - - logger.Debug(ctx, "accepting control stream") - - controlStream, err := sess.AcceptStream() - if err != nil { - return errors.WithStack(err) - } - - c.sess = sess - c.control = control.New(sess, controlStream) - return nil } func (c *RemoteClient) Listen(ctx context.Context) error { - logger.Debug(ctx, "listening for messages") - - err := c.control.Listen(ctx, control.Handlers{ - control.TypeAuthRequest: c.handleAuthRequest, - }) - - if errors.Is(err, io.ErrClosedPipe) { + defer func() { if c.onClientDisconnectHook != nil { if err := c.onClientDisconnectHook.OnClientDisconnect(ctx, c); err != nil { - logger.Error(ctx, "client disconnect hook error", logger.E(err)) + logger.Error(ctx, "client disconnect hook error", logger.E(errors.WithStack(err))) } } + }() - return errors.WithStack(ErrConnectionClosed) + logger.Debug(ctx, "listening for messages") + + return c.control.Listen(ctx, control.Handlers{ + control.TypeAuthRequest: c.handleAuthRequest, + }) +} + +func (c *RemoteClient) ConfigureHooks(hooks interface{}) { + if hooks == nil { + return } - return err + if onClientAuthHook, ok := hooks.(OnClientAuthHook); ok { + c.onClientAuthHook = onClientAuthHook + } + + if OnClientConnectHook, ok := hooks.(OnClientConnectHook); ok { + c.onClientConnectHook = OnClientConnectHook + } + + if OnClientDisconnectHook, ok := hooks.(OnClientDisconnectHook); ok { + c.onClientDisconnectHook = OnClientDisconnectHook + } +} + +func (c *RemoteClient) RemoteAddr() net.Addr { + return c.remoteAddr +} + +func (c *RemoteClient) Close() { + if c.sess != nil { + c.sess.Close() + } + + c.sess = nil + c.control = nil +} + +func (c *RemoteClient) Proxy(ctx context.Context, network, address string) (net.Conn, error) { + ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address)) + + if err := c.control.ProxyReq(ctx, network, address); err != nil { + return nil, errors.WithStack(err) + } + + logger.Debug(ctx, "opening proxy stream") + + c.acceptStreamMutex.Lock() + + stream, err := c.sess.AcceptStream() + if err != nil { + c.acceptStreamMutex.Unlock() + return nil, errors.WithStack(err) + } + + c.acceptStreamMutex.Unlock() + + go func() { + <-ctx.Done() + logger.Debug(ctx, "closing proxy stream") + stream.Close() + }() + + return stream, nil } func (c *RemoteClient) handleAuthRequest(ctx context.Context, m *control.Message) (*control.Message, error) { @@ -101,38 +162,8 @@ func (c *RemoteClient) handleAuthRequest(ctx context.Context, m *control.Message return res, nil } -func (c *RemoteClient) ConfigureHooks(hooks interface{}) { - if hooks == nil { - return - } - - if onClientAuthHook, ok := hooks.(OnClientAuthHook); ok { - c.onClientAuthHook = onClientAuthHook - } - - if OnClientConnectHook, ok := hooks.(OnClientConnectHook); ok { - c.onClientConnectHook = OnClientConnectHook - } - - if OnClientDisconnectHook, ok := hooks.(OnClientDisconnectHook); ok { - c.onClientDisconnectHook = OnClientDisconnectHook - } -} - -func (c *RemoteClient) RemoteAddr() net.Addr { - return c.remoteAddr -} - -func (c *RemoteClient) Proxy(ctx context.Context, network, address string) (net.Conn, error) { - return c.control.Proxy(ctx, network, address) -} - -func (c *RemoteClient) Close() { - if c.sess != nil && !c.sess.IsClosed() { - c.sess.Close() - } -} - func NewRemoteClient() *RemoteClient { - return &RemoteClient{} + return &RemoteClient{ + proxies: cmap.New(), + } } diff --git a/server.go b/server.go index 6a13ded..7af7635 100644 --- a/server.go +++ b/server.go @@ -37,18 +37,20 @@ func (s *Server) handleNewConn(ctx context.Context, conn *kcp.UDPSession) { ctx = logger.With(ctx, logger.F("remoteAddr", conn.RemoteAddr().String())) remoteClient := NewRemoteClient() + defer remoteClient.Close() + defer conn.Close() remoteClient.ConfigureHooks(s.conf.Hooks) if err := remoteClient.Accept(ctx, conn); err != nil { - logger.Error(ctx, "remote client error", logger.E(err)) + logger.Error(ctx, "remote client error", logger.E(errors.WithStack(err))) return } if err := remoteClient.Listen(ctx); err != nil { - logger.Error(ctx, "remote client error", logger.E(err)) + logger.Error(ctx, "remote client error", logger.E(errors.WithStack(err))) } }