feat: better proxy handling
This commit is contained in:
parent
30564efd85
commit
a209260778
93
client.go
93
client.go
@ -5,10 +5,13 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
|
||||
"gitlab.com/wpetit/goweb/logger"
|
||||
|
||||
"forge.cadoles.com/wpetit/go-tunnel/control"
|
||||
cmap "github.com/orcaman/concurrent-map"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/xtaci/kcp-go/v5"
|
||||
"github.com/xtaci/smux"
|
||||
@ -20,6 +23,7 @@ type Client struct {
|
||||
sess *smux.Session
|
||||
control *control.Control
|
||||
http *http.Client
|
||||
proxies cmap.ConcurrentMap
|
||||
}
|
||||
|
||||
func (c *Client) Connect(ctx context.Context) error {
|
||||
@ -70,6 +74,7 @@ func (c *Client) Listen(ctx context.Context) error {
|
||||
|
||||
err := c.control.Listen(ctx, control.Handlers{
|
||||
control.TypeProxyRequest: c.handleProxyRequest,
|
||||
control.TypeCloseProxy: c.handleCloseProxy,
|
||||
})
|
||||
|
||||
if errors.Is(err, io.ErrClosedPipe) {
|
||||
@ -99,27 +104,94 @@ func (c *Client) Close() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) handleCloseProxy(ctx context.Context, m *control.Message) (*control.Message, error) {
|
||||
closeProxyPayload, ok := m.Payload.(*control.CloseProxyPayload)
|
||||
if !ok {
|
||||
return nil, errors.WithStack(ErrUnexpectedMessage)
|
||||
}
|
||||
|
||||
requestID := strconv.FormatInt(closeProxyPayload.RequestID, 10)
|
||||
|
||||
rawCloseChan, exists := c.proxies.Get(requestID)
|
||||
if !exists {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
closeChan, ok := rawCloseChan.(chan struct{})
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
closeChan <- struct{}{}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *Client) handleProxyRequest(ctx context.Context, m *control.Message) (*control.Message, error) {
|
||||
proxyReqPayload, ok := m.Payload.(*control.ProxyRequestPayload)
|
||||
if !ok {
|
||||
return nil, errors.WithStack(ErrUnexpectedMessage)
|
||||
}
|
||||
|
||||
requestID := strconv.FormatInt(proxyReqPayload.RequestID, 10)
|
||||
|
||||
ctx = logger.With(ctx, logger.F("requestID", requestID))
|
||||
|
||||
logger.Debug(
|
||||
ctx, "handling proxy request",
|
||||
logger.F("network", proxyReqPayload.Network),
|
||||
logger.F("address", proxyReqPayload.Address),
|
||||
)
|
||||
|
||||
stream, err := c.sess.OpenStream()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
defer stream.Close()
|
||||
closeChan := make(chan struct{})
|
||||
|
||||
net, err := net.Dial(proxyReqPayload.Network, proxyReqPayload.Address)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
go func() {
|
||||
defer func() {
|
||||
stream.Close()
|
||||
logger.Debug(ctx, "proxy stream closed")
|
||||
}()
|
||||
|
||||
if err := pipe(stream, net); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
proxy := func() error {
|
||||
net, err := net.Dial(proxyReqPayload.Network, proxyReqPayload.Address)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
defer net.Close()
|
||||
|
||||
err = pipe(ctx, stream, net)
|
||||
if errors.Is(err, os.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-closeChan:
|
||||
return
|
||||
default:
|
||||
if err := proxy(); err != nil {
|
||||
logger.Error(ctx, "error while proxying", logger.E(err))
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
c.proxies.Set(requestID, closeChan)
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
@ -132,7 +204,8 @@ func NewClient(funcs ...ClientConfigFunc) *Client {
|
||||
}
|
||||
|
||||
return &Client{
|
||||
conf: conf,
|
||||
http: &http.Client{},
|
||||
conf: conf,
|
||||
http: &http.Client{},
|
||||
proxies: cmap.New(),
|
||||
}
|
||||
}
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
@ -12,10 +13,11 @@ import (
|
||||
)
|
||||
|
||||
type Control struct {
|
||||
encoder *json.Encoder
|
||||
decoder *json.Decoder
|
||||
stream *smux.Stream
|
||||
sess *smux.Session
|
||||
encoder *json.Encoder
|
||||
decoder *json.Decoder
|
||||
stream *smux.Stream
|
||||
sess *smux.Session
|
||||
proxyClock int64
|
||||
}
|
||||
|
||||
func (c *Control) AuthRequest(credentials interface{}) (bool, error) {
|
||||
@ -37,17 +39,18 @@ func (c *Control) AuthRequest(credentials interface{}) (bool, error) {
|
||||
return authResPayload.Success, nil
|
||||
}
|
||||
|
||||
type CloseStream func()
|
||||
|
||||
func (c *Control) Proxy(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
var (
|
||||
stream *smux.Stream
|
||||
err error
|
||||
)
|
||||
|
||||
requestID := atomic.AddInt64(&c.proxyClock, 1)
|
||||
|
||||
req := NewMessage(TypeProxyRequest, &ProxyRequestPayload{
|
||||
Network: network,
|
||||
Address: address,
|
||||
RequestID: requestID,
|
||||
Network: network,
|
||||
Address: address,
|
||||
})
|
||||
|
||||
ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address))
|
||||
@ -65,6 +68,21 @@ func (c *Control) Proxy(ctx context.Context, network, address string) (net.Conn,
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
|
||||
req := NewMessage(TypeCloseProxy, &CloseProxyPayload{
|
||||
RequestID: requestID,
|
||||
})
|
||||
|
||||
if err := c.Write(req); err != nil {
|
||||
logger.Error(ctx, "error while closing proxy", logger.E(err))
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "closing proxy conn")
|
||||
stream.Close()
|
||||
}()
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
|
@ -10,6 +10,7 @@ const (
|
||||
TypeAuthRequest MessageType = "auth-req"
|
||||
TypeAuthResponse MessageType = "auth-res"
|
||||
TypeProxyRequest MessageType = "proxy-req"
|
||||
TypeCloseProxy MessageType = "close-proxy"
|
||||
)
|
||||
|
||||
type MessageType string
|
||||
@ -61,6 +62,8 @@ func unmarshalPayload(mType MessageType, data []byte) (interface{}, error) {
|
||||
payload = &AuthResponsePayload{}
|
||||
case TypeProxyRequest:
|
||||
payload = &ProxyRequestPayload{}
|
||||
case TypeCloseProxy:
|
||||
payload = &CloseProxyPayload{}
|
||||
default:
|
||||
return nil, errors.Wrapf(ErrUnexpectedMessage, "unexpected message type '%s'", mType)
|
||||
}
|
||||
|
@ -1,6 +1,11 @@
|
||||
package control
|
||||
|
||||
type ProxyRequestPayload struct {
|
||||
Network string `json:"n"`
|
||||
Address string `json:"a"`
|
||||
RequestID int64 `json:"i"`
|
||||
Network string `json:"n"`
|
||||
Address string `json:"a"`
|
||||
}
|
||||
|
||||
type CloseProxyPayload struct {
|
||||
RequestID int64 `json:"i"`
|
||||
}
|
||||
|
24
helper.go
24
helper.go
@ -1,18 +1,30 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"gitlab.com/wpetit/goweb/logger"
|
||||
)
|
||||
|
||||
func pipe(client net.Conn, server net.Conn) (err error) {
|
||||
func pipe(ctx context.Context, client net.Conn, server net.Conn) (err error) {
|
||||
stop := make(chan bool)
|
||||
|
||||
go func() {
|
||||
err = relay(client, server, stop)
|
||||
if err != nil {
|
||||
err = errors.WithStack(err)
|
||||
logger.Debug(ctx, "client->server error", logger.E(err))
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
err = relay(server, client, stop)
|
||||
if err != nil {
|
||||
err = errors.WithStack(err)
|
||||
logger.Debug(ctx, "server->client error", logger.E(err))
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
@ -21,11 +33,15 @@ func pipe(client net.Conn, server net.Conn) (err error) {
|
||||
}
|
||||
}
|
||||
|
||||
func relay(src io.ReadCloser, dst io.WriteCloser, stop chan bool) (err error) {
|
||||
func relay(src net.Conn, dst net.Conn, stop chan bool) (err error) {
|
||||
_, err = io.Copy(dst, src)
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
}
|
||||
|
||||
dst.Close()
|
||||
src.Close()
|
||||
if err != nil {
|
||||
err = errors.WithStack(err)
|
||||
}
|
||||
|
||||
stop <- true
|
||||
|
||||
|
@ -4,6 +4,7 @@ import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
"forge.cadoles.com/wpetit/go-tunnel/control"
|
||||
|
||||
@ -41,6 +42,10 @@ func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error {
|
||||
|
||||
logger.Debug(ctx, "accepting control stream")
|
||||
|
||||
if err := sess.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
controlStream, err := sess.AcceptStream()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
@ -92,6 +97,8 @@ func (c *RemoteClient) handleAuthRequest(ctx context.Context, m *control.Message
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "auth succeeded", logger.F("credentials", authReqPayload.Credentials))
|
||||
|
||||
res := control.NewMessage(control.TypeAuthResponse, &control.AuthResponsePayload{
|
||||
Success: success,
|
||||
})
|
||||
|
Loading…
Reference in New Issue
Block a user