From a2092607785023e09af5451cbf12ae1cfcdd53cd Mon Sep 17 00:00:00 2001 From: William Petit Date: Fri, 23 Oct 2020 17:08:42 +0200 Subject: [PATCH] feat: better proxy handling --- client.go | 93 +++++++++++++++++++++++++++++++++++++++++----- control/control.go | 34 +++++++++++++---- control/message.go | 3 ++ control/proxy.go | 9 ++++- helper.go | 24 ++++++++++-- remote_client.go | 7 ++++ 6 files changed, 146 insertions(+), 24 deletions(-) diff --git a/client.go b/client.go index f787f8f..b197a0d 100644 --- a/client.go +++ b/client.go @@ -5,10 +5,13 @@ import ( "io" "net" "net/http" + "os" + "strconv" "gitlab.com/wpetit/goweb/logger" "forge.cadoles.com/wpetit/go-tunnel/control" + cmap "github.com/orcaman/concurrent-map" "github.com/pkg/errors" "github.com/xtaci/kcp-go/v5" "github.com/xtaci/smux" @@ -20,6 +23,7 @@ type Client struct { sess *smux.Session control *control.Control http *http.Client + proxies cmap.ConcurrentMap } func (c *Client) Connect(ctx context.Context) error { @@ -70,6 +74,7 @@ func (c *Client) Listen(ctx context.Context) error { err := c.control.Listen(ctx, control.Handlers{ control.TypeProxyRequest: c.handleProxyRequest, + control.TypeCloseProxy: c.handleCloseProxy, }) if errors.Is(err, io.ErrClosedPipe) { @@ -99,27 +104,94 @@ func (c *Client) Close() error { return nil } +func (c *Client) handleCloseProxy(ctx context.Context, m *control.Message) (*control.Message, error) { + closeProxyPayload, ok := m.Payload.(*control.CloseProxyPayload) + if !ok { + return nil, errors.WithStack(ErrUnexpectedMessage) + } + + requestID := strconv.FormatInt(closeProxyPayload.RequestID, 10) + + rawCloseChan, exists := c.proxies.Get(requestID) + if !exists { + return nil, nil + } + + closeChan, ok := rawCloseChan.(chan struct{}) + if !ok { + return nil, nil + } + + closeChan <- struct{}{} + + return nil, nil +} + func (c *Client) handleProxyRequest(ctx context.Context, m *control.Message) (*control.Message, error) { proxyReqPayload, ok := m.Payload.(*control.ProxyRequestPayload) if !ok { return nil, errors.WithStack(ErrUnexpectedMessage) } + requestID := strconv.FormatInt(proxyReqPayload.RequestID, 10) + + ctx = logger.With(ctx, logger.F("requestID", requestID)) + + logger.Debug( + ctx, "handling proxy request", + logger.F("network", proxyReqPayload.Network), + logger.F("address", proxyReqPayload.Address), + ) + stream, err := c.sess.OpenStream() if err != nil { return nil, errors.WithStack(err) } - defer stream.Close() + closeChan := make(chan struct{}) - net, err := net.Dial(proxyReqPayload.Network, proxyReqPayload.Address) - if err != nil { - return nil, errors.WithStack(err) - } + go func() { + defer func() { + stream.Close() + logger.Debug(ctx, "proxy stream closed") + }() - if err := pipe(stream, net); err != nil { - return nil, errors.WithStack(err) - } + proxy := func() error { + net, err := net.Dial(proxyReqPayload.Network, proxyReqPayload.Address) + if err != nil { + return errors.WithStack(err) + } + defer net.Close() + + err = pipe(ctx, stream, net) + if errors.Is(err, os.ErrClosed) { + return nil + } + + if err != nil { + return errors.WithStack(err) + } + + return nil + } + + for { + select { + case <-closeChan: + return + default: + if err := proxy(); err != nil { + logger.Error(ctx, "error while proxying", logger.E(err)) + + continue + } + + return + } + } + }() + + c.proxies.Set(requestID, closeChan) return nil, nil } @@ -132,7 +204,8 @@ func NewClient(funcs ...ClientConfigFunc) *Client { } return &Client{ - conf: conf, - http: &http.Client{}, + conf: conf, + http: &http.Client{}, + proxies: cmap.New(), } } diff --git a/control/control.go b/control/control.go index 4dcd07d..820342c 100644 --- a/control/control.go +++ b/control/control.go @@ -4,6 +4,7 @@ import ( "context" "encoding/json" "net" + "sync/atomic" "time" "github.com/pkg/errors" @@ -12,10 +13,11 @@ import ( ) type Control struct { - encoder *json.Encoder - decoder *json.Decoder - stream *smux.Stream - sess *smux.Session + encoder *json.Encoder + decoder *json.Decoder + stream *smux.Stream + sess *smux.Session + proxyClock int64 } func (c *Control) AuthRequest(credentials interface{}) (bool, error) { @@ -37,17 +39,18 @@ func (c *Control) AuthRequest(credentials interface{}) (bool, error) { return authResPayload.Success, nil } -type CloseStream func() - func (c *Control) Proxy(ctx context.Context, network, address string) (net.Conn, error) { var ( stream *smux.Stream err error ) + requestID := atomic.AddInt64(&c.proxyClock, 1) + req := NewMessage(TypeProxyRequest, &ProxyRequestPayload{ - Network: network, - Address: address, + RequestID: requestID, + Network: network, + Address: address, }) ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address)) @@ -65,6 +68,21 @@ func (c *Control) Proxy(ctx context.Context, network, address string) (net.Conn, return nil, errors.WithStack(err) } + go func() { + <-ctx.Done() + + req := NewMessage(TypeCloseProxy, &CloseProxyPayload{ + RequestID: requestID, + }) + + if err := c.Write(req); err != nil { + logger.Error(ctx, "error while closing proxy", logger.E(err)) + } + + logger.Debug(ctx, "closing proxy conn") + stream.Close() + }() + return stream, nil } diff --git a/control/message.go b/control/message.go index fd8e05a..008f175 100644 --- a/control/message.go +++ b/control/message.go @@ -10,6 +10,7 @@ const ( TypeAuthRequest MessageType = "auth-req" TypeAuthResponse MessageType = "auth-res" TypeProxyRequest MessageType = "proxy-req" + TypeCloseProxy MessageType = "close-proxy" ) type MessageType string @@ -61,6 +62,8 @@ func unmarshalPayload(mType MessageType, data []byte) (interface{}, error) { payload = &AuthResponsePayload{} case TypeProxyRequest: payload = &ProxyRequestPayload{} + case TypeCloseProxy: + payload = &CloseProxyPayload{} default: return nil, errors.Wrapf(ErrUnexpectedMessage, "unexpected message type '%s'", mType) } diff --git a/control/proxy.go b/control/proxy.go index 1863325..cda1f15 100644 --- a/control/proxy.go +++ b/control/proxy.go @@ -1,6 +1,11 @@ package control type ProxyRequestPayload struct { - Network string `json:"n"` - Address string `json:"a"` + RequestID int64 `json:"i"` + Network string `json:"n"` + Address string `json:"a"` +} + +type CloseProxyPayload struct { + RequestID int64 `json:"i"` } diff --git a/helper.go b/helper.go index 7493971..1ea5957 100644 --- a/helper.go +++ b/helper.go @@ -1,18 +1,30 @@ package tunnel import ( + "context" "io" "net" + + "github.com/pkg/errors" + "gitlab.com/wpetit/goweb/logger" ) -func pipe(client net.Conn, server net.Conn) (err error) { +func pipe(ctx context.Context, client net.Conn, server net.Conn) (err error) { stop := make(chan bool) go func() { err = relay(client, server, stop) + if err != nil { + err = errors.WithStack(err) + logger.Debug(ctx, "client->server error", logger.E(err)) + } }() go func() { err = relay(server, client, stop) + if err != nil { + err = errors.WithStack(err) + logger.Debug(ctx, "server->client error", logger.E(err)) + } }() select { @@ -21,11 +33,15 @@ func pipe(client net.Conn, server net.Conn) (err error) { } } -func relay(src io.ReadCloser, dst io.WriteCloser, stop chan bool) (err error) { +func relay(src net.Conn, dst net.Conn, stop chan bool) (err error) { _, err = io.Copy(dst, src) + if errors.Is(err, io.EOF) { + err = nil + } - dst.Close() - src.Close() + if err != nil { + err = errors.WithStack(err) + } stop <- true diff --git a/remote_client.go b/remote_client.go index f1a7c04..0fb1b66 100644 --- a/remote_client.go +++ b/remote_client.go @@ -4,6 +4,7 @@ import ( "context" "io" "net" + "time" "forge.cadoles.com/wpetit/go-tunnel/control" @@ -41,6 +42,10 @@ func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error { logger.Debug(ctx, "accepting control stream") + if err := sess.SetDeadline(time.Now().Add(30 * time.Second)); err != nil { + return errors.WithStack(err) + } + controlStream, err := sess.AcceptStream() if err != nil { return errors.WithStack(err) @@ -92,6 +97,8 @@ func (c *RemoteClient) handleAuthRequest(ctx context.Context, m *control.Message } } + logger.Debug(ctx, "auth succeeded", logger.F("credentials", authReqPayload.Credentials)) + res := control.NewMessage(control.TypeAuthResponse, &control.AuthResponsePayload{ Success: success, })