feat: better proxy handling

This commit is contained in:
wpetit 2020-10-23 17:08:42 +02:00
parent 30564efd85
commit a209260778
6 changed files with 146 additions and 24 deletions

View File

@ -5,10 +5,13 @@ import (
"io" "io"
"net" "net"
"net/http" "net/http"
"os"
"strconv"
"gitlab.com/wpetit/goweb/logger" "gitlab.com/wpetit/goweb/logger"
"forge.cadoles.com/wpetit/go-tunnel/control" "forge.cadoles.com/wpetit/go-tunnel/control"
cmap "github.com/orcaman/concurrent-map"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/xtaci/kcp-go/v5" "github.com/xtaci/kcp-go/v5"
"github.com/xtaci/smux" "github.com/xtaci/smux"
@ -20,6 +23,7 @@ type Client struct {
sess *smux.Session sess *smux.Session
control *control.Control control *control.Control
http *http.Client http *http.Client
proxies cmap.ConcurrentMap
} }
func (c *Client) Connect(ctx context.Context) error { func (c *Client) Connect(ctx context.Context) error {
@ -70,6 +74,7 @@ func (c *Client) Listen(ctx context.Context) error {
err := c.control.Listen(ctx, control.Handlers{ err := c.control.Listen(ctx, control.Handlers{
control.TypeProxyRequest: c.handleProxyRequest, control.TypeProxyRequest: c.handleProxyRequest,
control.TypeCloseProxy: c.handleCloseProxy,
}) })
if errors.Is(err, io.ErrClosedPipe) { if errors.Is(err, io.ErrClosedPipe) {
@ -99,27 +104,94 @@ func (c *Client) Close() error {
return nil return nil
} }
func (c *Client) handleCloseProxy(ctx context.Context, m *control.Message) (*control.Message, error) {
closeProxyPayload, ok := m.Payload.(*control.CloseProxyPayload)
if !ok {
return nil, errors.WithStack(ErrUnexpectedMessage)
}
requestID := strconv.FormatInt(closeProxyPayload.RequestID, 10)
rawCloseChan, exists := c.proxies.Get(requestID)
if !exists {
return nil, nil
}
closeChan, ok := rawCloseChan.(chan struct{})
if !ok {
return nil, nil
}
closeChan <- struct{}{}
return nil, nil
}
func (c *Client) handleProxyRequest(ctx context.Context, m *control.Message) (*control.Message, error) { func (c *Client) handleProxyRequest(ctx context.Context, m *control.Message) (*control.Message, error) {
proxyReqPayload, ok := m.Payload.(*control.ProxyRequestPayload) proxyReqPayload, ok := m.Payload.(*control.ProxyRequestPayload)
if !ok { if !ok {
return nil, errors.WithStack(ErrUnexpectedMessage) return nil, errors.WithStack(ErrUnexpectedMessage)
} }
requestID := strconv.FormatInt(proxyReqPayload.RequestID, 10)
ctx = logger.With(ctx, logger.F("requestID", requestID))
logger.Debug(
ctx, "handling proxy request",
logger.F("network", proxyReqPayload.Network),
logger.F("address", proxyReqPayload.Address),
)
stream, err := c.sess.OpenStream() stream, err := c.sess.OpenStream()
if err != nil { if err != nil {
return nil, errors.WithStack(err) return nil, errors.WithStack(err)
} }
defer stream.Close() closeChan := make(chan struct{})
net, err := net.Dial(proxyReqPayload.Network, proxyReqPayload.Address) go func() {
if err != nil { defer func() {
return nil, errors.WithStack(err) stream.Close()
} logger.Debug(ctx, "proxy stream closed")
}()
if err := pipe(stream, net); err != nil { proxy := func() error {
return nil, errors.WithStack(err) net, err := net.Dial(proxyReqPayload.Network, proxyReqPayload.Address)
} if err != nil {
return errors.WithStack(err)
}
defer net.Close()
err = pipe(ctx, stream, net)
if errors.Is(err, os.ErrClosed) {
return nil
}
if err != nil {
return errors.WithStack(err)
}
return nil
}
for {
select {
case <-closeChan:
return
default:
if err := proxy(); err != nil {
logger.Error(ctx, "error while proxying", logger.E(err))
continue
}
return
}
}
}()
c.proxies.Set(requestID, closeChan)
return nil, nil return nil, nil
} }
@ -132,7 +204,8 @@ func NewClient(funcs ...ClientConfigFunc) *Client {
} }
return &Client{ return &Client{
conf: conf, conf: conf,
http: &http.Client{}, http: &http.Client{},
proxies: cmap.New(),
} }
} }

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"encoding/json" "encoding/json"
"net" "net"
"sync/atomic"
"time" "time"
"github.com/pkg/errors" "github.com/pkg/errors"
@ -12,10 +13,11 @@ import (
) )
type Control struct { type Control struct {
encoder *json.Encoder encoder *json.Encoder
decoder *json.Decoder decoder *json.Decoder
stream *smux.Stream stream *smux.Stream
sess *smux.Session sess *smux.Session
proxyClock int64
} }
func (c *Control) AuthRequest(credentials interface{}) (bool, error) { func (c *Control) AuthRequest(credentials interface{}) (bool, error) {
@ -37,17 +39,18 @@ func (c *Control) AuthRequest(credentials interface{}) (bool, error) {
return authResPayload.Success, nil return authResPayload.Success, nil
} }
type CloseStream func()
func (c *Control) Proxy(ctx context.Context, network, address string) (net.Conn, error) { func (c *Control) Proxy(ctx context.Context, network, address string) (net.Conn, error) {
var ( var (
stream *smux.Stream stream *smux.Stream
err error err error
) )
requestID := atomic.AddInt64(&c.proxyClock, 1)
req := NewMessage(TypeProxyRequest, &ProxyRequestPayload{ req := NewMessage(TypeProxyRequest, &ProxyRequestPayload{
Network: network, RequestID: requestID,
Address: address, Network: network,
Address: address,
}) })
ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address)) ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address))
@ -65,6 +68,21 @@ func (c *Control) Proxy(ctx context.Context, network, address string) (net.Conn,
return nil, errors.WithStack(err) return nil, errors.WithStack(err)
} }
go func() {
<-ctx.Done()
req := NewMessage(TypeCloseProxy, &CloseProxyPayload{
RequestID: requestID,
})
if err := c.Write(req); err != nil {
logger.Error(ctx, "error while closing proxy", logger.E(err))
}
logger.Debug(ctx, "closing proxy conn")
stream.Close()
}()
return stream, nil return stream, nil
} }

View File

@ -10,6 +10,7 @@ const (
TypeAuthRequest MessageType = "auth-req" TypeAuthRequest MessageType = "auth-req"
TypeAuthResponse MessageType = "auth-res" TypeAuthResponse MessageType = "auth-res"
TypeProxyRequest MessageType = "proxy-req" TypeProxyRequest MessageType = "proxy-req"
TypeCloseProxy MessageType = "close-proxy"
) )
type MessageType string type MessageType string
@ -61,6 +62,8 @@ func unmarshalPayload(mType MessageType, data []byte) (interface{}, error) {
payload = &AuthResponsePayload{} payload = &AuthResponsePayload{}
case TypeProxyRequest: case TypeProxyRequest:
payload = &ProxyRequestPayload{} payload = &ProxyRequestPayload{}
case TypeCloseProxy:
payload = &CloseProxyPayload{}
default: default:
return nil, errors.Wrapf(ErrUnexpectedMessage, "unexpected message type '%s'", mType) return nil, errors.Wrapf(ErrUnexpectedMessage, "unexpected message type '%s'", mType)
} }

View File

@ -1,6 +1,11 @@
package control package control
type ProxyRequestPayload struct { type ProxyRequestPayload struct {
Network string `json:"n"` RequestID int64 `json:"i"`
Address string `json:"a"` Network string `json:"n"`
Address string `json:"a"`
}
type CloseProxyPayload struct {
RequestID int64 `json:"i"`
} }

View File

@ -1,18 +1,30 @@
package tunnel package tunnel
import ( import (
"context"
"io" "io"
"net" "net"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
) )
func pipe(client net.Conn, server net.Conn) (err error) { func pipe(ctx context.Context, client net.Conn, server net.Conn) (err error) {
stop := make(chan bool) stop := make(chan bool)
go func() { go func() {
err = relay(client, server, stop) err = relay(client, server, stop)
if err != nil {
err = errors.WithStack(err)
logger.Debug(ctx, "client->server error", logger.E(err))
}
}() }()
go func() { go func() {
err = relay(server, client, stop) err = relay(server, client, stop)
if err != nil {
err = errors.WithStack(err)
logger.Debug(ctx, "server->client error", logger.E(err))
}
}() }()
select { select {
@ -21,11 +33,15 @@ func pipe(client net.Conn, server net.Conn) (err error) {
} }
} }
func relay(src io.ReadCloser, dst io.WriteCloser, stop chan bool) (err error) { func relay(src net.Conn, dst net.Conn, stop chan bool) (err error) {
_, err = io.Copy(dst, src) _, err = io.Copy(dst, src)
if errors.Is(err, io.EOF) {
err = nil
}
dst.Close() if err != nil {
src.Close() err = errors.WithStack(err)
}
stop <- true stop <- true

View File

@ -4,6 +4,7 @@ import (
"context" "context"
"io" "io"
"net" "net"
"time"
"forge.cadoles.com/wpetit/go-tunnel/control" "forge.cadoles.com/wpetit/go-tunnel/control"
@ -41,6 +42,10 @@ func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error {
logger.Debug(ctx, "accepting control stream") logger.Debug(ctx, "accepting control stream")
if err := sess.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
return errors.WithStack(err)
}
controlStream, err := sess.AcceptStream() controlStream, err := sess.AcceptStream()
if err != nil { if err != nil {
return errors.WithStack(err) return errors.WithStack(err)
@ -92,6 +97,8 @@ func (c *RemoteClient) handleAuthRequest(ctx context.Context, m *control.Message
} }
} }
logger.Debug(ctx, "auth succeeded", logger.F("credentials", authReqPayload.Credentials))
res := control.NewMessage(control.TypeAuthResponse, &control.AuthResponsePayload{ res := control.NewMessage(control.TypeAuthResponse, &control.AuthResponsePayload{
Success: success, Success: success,
}) })