feat: better proxy handling

This commit is contained in:
wpetit 2020-10-23 17:08:42 +02:00
parent 30564efd85
commit a209260778
6 changed files with 146 additions and 24 deletions

View File

@ -5,10 +5,13 @@ import (
"io"
"net"
"net/http"
"os"
"strconv"
"gitlab.com/wpetit/goweb/logger"
"forge.cadoles.com/wpetit/go-tunnel/control"
cmap "github.com/orcaman/concurrent-map"
"github.com/pkg/errors"
"github.com/xtaci/kcp-go/v5"
"github.com/xtaci/smux"
@ -20,6 +23,7 @@ type Client struct {
sess *smux.Session
control *control.Control
http *http.Client
proxies cmap.ConcurrentMap
}
func (c *Client) Connect(ctx context.Context) error {
@ -70,6 +74,7 @@ func (c *Client) Listen(ctx context.Context) error {
err := c.control.Listen(ctx, control.Handlers{
control.TypeProxyRequest: c.handleProxyRequest,
control.TypeCloseProxy: c.handleCloseProxy,
})
if errors.Is(err, io.ErrClosedPipe) {
@ -99,27 +104,94 @@ func (c *Client) Close() error {
return nil
}
func (c *Client) handleCloseProxy(ctx context.Context, m *control.Message) (*control.Message, error) {
closeProxyPayload, ok := m.Payload.(*control.CloseProxyPayload)
if !ok {
return nil, errors.WithStack(ErrUnexpectedMessage)
}
requestID := strconv.FormatInt(closeProxyPayload.RequestID, 10)
rawCloseChan, exists := c.proxies.Get(requestID)
if !exists {
return nil, nil
}
closeChan, ok := rawCloseChan.(chan struct{})
if !ok {
return nil, nil
}
closeChan <- struct{}{}
return nil, nil
}
func (c *Client) handleProxyRequest(ctx context.Context, m *control.Message) (*control.Message, error) {
proxyReqPayload, ok := m.Payload.(*control.ProxyRequestPayload)
if !ok {
return nil, errors.WithStack(ErrUnexpectedMessage)
}
requestID := strconv.FormatInt(proxyReqPayload.RequestID, 10)
ctx = logger.With(ctx, logger.F("requestID", requestID))
logger.Debug(
ctx, "handling proxy request",
logger.F("network", proxyReqPayload.Network),
logger.F("address", proxyReqPayload.Address),
)
stream, err := c.sess.OpenStream()
if err != nil {
return nil, errors.WithStack(err)
}
defer stream.Close()
closeChan := make(chan struct{})
net, err := net.Dial(proxyReqPayload.Network, proxyReqPayload.Address)
if err != nil {
return nil, errors.WithStack(err)
}
go func() {
defer func() {
stream.Close()
logger.Debug(ctx, "proxy stream closed")
}()
if err := pipe(stream, net); err != nil {
return nil, errors.WithStack(err)
}
proxy := func() error {
net, err := net.Dial(proxyReqPayload.Network, proxyReqPayload.Address)
if err != nil {
return errors.WithStack(err)
}
defer net.Close()
err = pipe(ctx, stream, net)
if errors.Is(err, os.ErrClosed) {
return nil
}
if err != nil {
return errors.WithStack(err)
}
return nil
}
for {
select {
case <-closeChan:
return
default:
if err := proxy(); err != nil {
logger.Error(ctx, "error while proxying", logger.E(err))
continue
}
return
}
}
}()
c.proxies.Set(requestID, closeChan)
return nil, nil
}
@ -132,7 +204,8 @@ func NewClient(funcs ...ClientConfigFunc) *Client {
}
return &Client{
conf: conf,
http: &http.Client{},
conf: conf,
http: &http.Client{},
proxies: cmap.New(),
}
}

View File

@ -4,6 +4,7 @@ import (
"context"
"encoding/json"
"net"
"sync/atomic"
"time"
"github.com/pkg/errors"
@ -12,10 +13,11 @@ import (
)
type Control struct {
encoder *json.Encoder
decoder *json.Decoder
stream *smux.Stream
sess *smux.Session
encoder *json.Encoder
decoder *json.Decoder
stream *smux.Stream
sess *smux.Session
proxyClock int64
}
func (c *Control) AuthRequest(credentials interface{}) (bool, error) {
@ -37,17 +39,18 @@ func (c *Control) AuthRequest(credentials interface{}) (bool, error) {
return authResPayload.Success, nil
}
type CloseStream func()
func (c *Control) Proxy(ctx context.Context, network, address string) (net.Conn, error) {
var (
stream *smux.Stream
err error
)
requestID := atomic.AddInt64(&c.proxyClock, 1)
req := NewMessage(TypeProxyRequest, &ProxyRequestPayload{
Network: network,
Address: address,
RequestID: requestID,
Network: network,
Address: address,
})
ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address))
@ -65,6 +68,21 @@ func (c *Control) Proxy(ctx context.Context, network, address string) (net.Conn,
return nil, errors.WithStack(err)
}
go func() {
<-ctx.Done()
req := NewMessage(TypeCloseProxy, &CloseProxyPayload{
RequestID: requestID,
})
if err := c.Write(req); err != nil {
logger.Error(ctx, "error while closing proxy", logger.E(err))
}
logger.Debug(ctx, "closing proxy conn")
stream.Close()
}()
return stream, nil
}

View File

@ -10,6 +10,7 @@ const (
TypeAuthRequest MessageType = "auth-req"
TypeAuthResponse MessageType = "auth-res"
TypeProxyRequest MessageType = "proxy-req"
TypeCloseProxy MessageType = "close-proxy"
)
type MessageType string
@ -61,6 +62,8 @@ func unmarshalPayload(mType MessageType, data []byte) (interface{}, error) {
payload = &AuthResponsePayload{}
case TypeProxyRequest:
payload = &ProxyRequestPayload{}
case TypeCloseProxy:
payload = &CloseProxyPayload{}
default:
return nil, errors.Wrapf(ErrUnexpectedMessage, "unexpected message type '%s'", mType)
}

View File

@ -1,6 +1,11 @@
package control
type ProxyRequestPayload struct {
Network string `json:"n"`
Address string `json:"a"`
RequestID int64 `json:"i"`
Network string `json:"n"`
Address string `json:"a"`
}
type CloseProxyPayload struct {
RequestID int64 `json:"i"`
}

View File

@ -1,18 +1,30 @@
package tunnel
import (
"context"
"io"
"net"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
)
func pipe(client net.Conn, server net.Conn) (err error) {
func pipe(ctx context.Context, client net.Conn, server net.Conn) (err error) {
stop := make(chan bool)
go func() {
err = relay(client, server, stop)
if err != nil {
err = errors.WithStack(err)
logger.Debug(ctx, "client->server error", logger.E(err))
}
}()
go func() {
err = relay(server, client, stop)
if err != nil {
err = errors.WithStack(err)
logger.Debug(ctx, "server->client error", logger.E(err))
}
}()
select {
@ -21,11 +33,15 @@ func pipe(client net.Conn, server net.Conn) (err error) {
}
}
func relay(src io.ReadCloser, dst io.WriteCloser, stop chan bool) (err error) {
func relay(src net.Conn, dst net.Conn, stop chan bool) (err error) {
_, err = io.Copy(dst, src)
if errors.Is(err, io.EOF) {
err = nil
}
dst.Close()
src.Close()
if err != nil {
err = errors.WithStack(err)
}
stop <- true

View File

@ -4,6 +4,7 @@ import (
"context"
"io"
"net"
"time"
"forge.cadoles.com/wpetit/go-tunnel/control"
@ -41,6 +42,10 @@ func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error {
logger.Debug(ctx, "accepting control stream")
if err := sess.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
return errors.WithStack(err)
}
controlStream, err := sess.AcceptStream()
if err != nil {
return errors.WithStack(err)
@ -92,6 +97,8 @@ func (c *RemoteClient) handleAuthRequest(ctx context.Context, m *control.Message
}
}
logger.Debug(ctx, "auth succeeded", logger.F("credentials", authReqPayload.Credentials))
res := control.NewMessage(control.TypeAuthResponse, &control.AuthResponsePayload{
Success: success,
})