feat: better proxy handling
This commit is contained in:
parent
30564efd85
commit
a209260778
81
client.go
81
client.go
|
@ -5,10 +5,13 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
|
||||||
"gitlab.com/wpetit/goweb/logger"
|
"gitlab.com/wpetit/goweb/logger"
|
||||||
|
|
||||||
"forge.cadoles.com/wpetit/go-tunnel/control"
|
"forge.cadoles.com/wpetit/go-tunnel/control"
|
||||||
|
cmap "github.com/orcaman/concurrent-map"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/xtaci/kcp-go/v5"
|
"github.com/xtaci/kcp-go/v5"
|
||||||
"github.com/xtaci/smux"
|
"github.com/xtaci/smux"
|
||||||
|
@ -20,6 +23,7 @@ type Client struct {
|
||||||
sess *smux.Session
|
sess *smux.Session
|
||||||
control *control.Control
|
control *control.Control
|
||||||
http *http.Client
|
http *http.Client
|
||||||
|
proxies cmap.ConcurrentMap
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Connect(ctx context.Context) error {
|
func (c *Client) Connect(ctx context.Context) error {
|
||||||
|
@ -70,6 +74,7 @@ func (c *Client) Listen(ctx context.Context) error {
|
||||||
|
|
||||||
err := c.control.Listen(ctx, control.Handlers{
|
err := c.control.Listen(ctx, control.Handlers{
|
||||||
control.TypeProxyRequest: c.handleProxyRequest,
|
control.TypeProxyRequest: c.handleProxyRequest,
|
||||||
|
control.TypeCloseProxy: c.handleCloseProxy,
|
||||||
})
|
})
|
||||||
|
|
||||||
if errors.Is(err, io.ErrClosedPipe) {
|
if errors.Is(err, io.ErrClosedPipe) {
|
||||||
|
@ -99,28 +104,95 @@ func (c *Client) Close() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) handleCloseProxy(ctx context.Context, m *control.Message) (*control.Message, error) {
|
||||||
|
closeProxyPayload, ok := m.Payload.(*control.CloseProxyPayload)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.WithStack(ErrUnexpectedMessage)
|
||||||
|
}
|
||||||
|
|
||||||
|
requestID := strconv.FormatInt(closeProxyPayload.RequestID, 10)
|
||||||
|
|
||||||
|
rawCloseChan, exists := c.proxies.Get(requestID)
|
||||||
|
if !exists {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
closeChan, ok := rawCloseChan.(chan struct{})
|
||||||
|
if !ok {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
closeChan <- struct{}{}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Client) handleProxyRequest(ctx context.Context, m *control.Message) (*control.Message, error) {
|
func (c *Client) handleProxyRequest(ctx context.Context, m *control.Message) (*control.Message, error) {
|
||||||
proxyReqPayload, ok := m.Payload.(*control.ProxyRequestPayload)
|
proxyReqPayload, ok := m.Payload.(*control.ProxyRequestPayload)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.WithStack(ErrUnexpectedMessage)
|
return nil, errors.WithStack(ErrUnexpectedMessage)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
requestID := strconv.FormatInt(proxyReqPayload.RequestID, 10)
|
||||||
|
|
||||||
|
ctx = logger.With(ctx, logger.F("requestID", requestID))
|
||||||
|
|
||||||
|
logger.Debug(
|
||||||
|
ctx, "handling proxy request",
|
||||||
|
logger.F("network", proxyReqPayload.Network),
|
||||||
|
logger.F("address", proxyReqPayload.Address),
|
||||||
|
)
|
||||||
|
|
||||||
stream, err := c.sess.OpenStream()
|
stream, err := c.sess.OpenStream()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.WithStack(err)
|
return nil, errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defer stream.Close()
|
closeChan := make(chan struct{})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
defer func() {
|
||||||
|
stream.Close()
|
||||||
|
logger.Debug(ctx, "proxy stream closed")
|
||||||
|
}()
|
||||||
|
|
||||||
|
proxy := func() error {
|
||||||
net, err := net.Dial(proxyReqPayload.Network, proxyReqPayload.Address)
|
net, err := net.Dial(proxyReqPayload.Network, proxyReqPayload.Address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
defer net.Close()
|
||||||
|
|
||||||
|
err = pipe(ctx, stream, net)
|
||||||
|
if errors.Is(err, os.ErrClosed) {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := pipe(stream, net); err != nil {
|
if err != nil {
|
||||||
return nil, errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-closeChan:
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
if err := proxy(); err != nil {
|
||||||
|
logger.Error(ctx, "error while proxying", logger.E(err))
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
c.proxies.Set(requestID, closeChan)
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -134,5 +206,6 @@ func NewClient(funcs ...ClientConfigFunc) *Client {
|
||||||
return &Client{
|
return &Client{
|
||||||
conf: conf,
|
conf: conf,
|
||||||
http: &http.Client{},
|
http: &http.Client{},
|
||||||
|
proxies: cmap.New(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net"
|
"net"
|
||||||
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
@ -16,6 +17,7 @@ type Control struct {
|
||||||
decoder *json.Decoder
|
decoder *json.Decoder
|
||||||
stream *smux.Stream
|
stream *smux.Stream
|
||||||
sess *smux.Session
|
sess *smux.Session
|
||||||
|
proxyClock int64
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) AuthRequest(credentials interface{}) (bool, error) {
|
func (c *Control) AuthRequest(credentials interface{}) (bool, error) {
|
||||||
|
@ -37,15 +39,16 @@ func (c *Control) AuthRequest(credentials interface{}) (bool, error) {
|
||||||
return authResPayload.Success, nil
|
return authResPayload.Success, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type CloseStream func()
|
|
||||||
|
|
||||||
func (c *Control) Proxy(ctx context.Context, network, address string) (net.Conn, error) {
|
func (c *Control) Proxy(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
var (
|
var (
|
||||||
stream *smux.Stream
|
stream *smux.Stream
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
|
requestID := atomic.AddInt64(&c.proxyClock, 1)
|
||||||
|
|
||||||
req := NewMessage(TypeProxyRequest, &ProxyRequestPayload{
|
req := NewMessage(TypeProxyRequest, &ProxyRequestPayload{
|
||||||
|
RequestID: requestID,
|
||||||
Network: network,
|
Network: network,
|
||||||
Address: address,
|
Address: address,
|
||||||
})
|
})
|
||||||
|
@ -65,6 +68,21 @@ func (c *Control) Proxy(ctx context.Context, network, address string) (net.Conn,
|
||||||
return nil, errors.WithStack(err)
|
return nil, errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
|
req := NewMessage(TypeCloseProxy, &CloseProxyPayload{
|
||||||
|
RequestID: requestID,
|
||||||
|
})
|
||||||
|
|
||||||
|
if err := c.Write(req); err != nil {
|
||||||
|
logger.Error(ctx, "error while closing proxy", logger.E(err))
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug(ctx, "closing proxy conn")
|
||||||
|
stream.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
return stream, nil
|
return stream, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,6 +10,7 @@ const (
|
||||||
TypeAuthRequest MessageType = "auth-req"
|
TypeAuthRequest MessageType = "auth-req"
|
||||||
TypeAuthResponse MessageType = "auth-res"
|
TypeAuthResponse MessageType = "auth-res"
|
||||||
TypeProxyRequest MessageType = "proxy-req"
|
TypeProxyRequest MessageType = "proxy-req"
|
||||||
|
TypeCloseProxy MessageType = "close-proxy"
|
||||||
)
|
)
|
||||||
|
|
||||||
type MessageType string
|
type MessageType string
|
||||||
|
@ -61,6 +62,8 @@ func unmarshalPayload(mType MessageType, data []byte) (interface{}, error) {
|
||||||
payload = &AuthResponsePayload{}
|
payload = &AuthResponsePayload{}
|
||||||
case TypeProxyRequest:
|
case TypeProxyRequest:
|
||||||
payload = &ProxyRequestPayload{}
|
payload = &ProxyRequestPayload{}
|
||||||
|
case TypeCloseProxy:
|
||||||
|
payload = &CloseProxyPayload{}
|
||||||
default:
|
default:
|
||||||
return nil, errors.Wrapf(ErrUnexpectedMessage, "unexpected message type '%s'", mType)
|
return nil, errors.Wrapf(ErrUnexpectedMessage, "unexpected message type '%s'", mType)
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,11 @@
|
||||||
package control
|
package control
|
||||||
|
|
||||||
type ProxyRequestPayload struct {
|
type ProxyRequestPayload struct {
|
||||||
|
RequestID int64 `json:"i"`
|
||||||
Network string `json:"n"`
|
Network string `json:"n"`
|
||||||
Address string `json:"a"`
|
Address string `json:"a"`
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type CloseProxyPayload struct {
|
||||||
|
RequestID int64 `json:"i"`
|
||||||
|
}
|
||||||
|
|
24
helper.go
24
helper.go
|
@ -1,18 +1,30 @@
|
||||||
package tunnel
|
package tunnel
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"gitlab.com/wpetit/goweb/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
func pipe(client net.Conn, server net.Conn) (err error) {
|
func pipe(ctx context.Context, client net.Conn, server net.Conn) (err error) {
|
||||||
stop := make(chan bool)
|
stop := make(chan bool)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
err = relay(client, server, stop)
|
err = relay(client, server, stop)
|
||||||
|
if err != nil {
|
||||||
|
err = errors.WithStack(err)
|
||||||
|
logger.Debug(ctx, "client->server error", logger.E(err))
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
go func() {
|
go func() {
|
||||||
err = relay(server, client, stop)
|
err = relay(server, client, stop)
|
||||||
|
if err != nil {
|
||||||
|
err = errors.WithStack(err)
|
||||||
|
logger.Debug(ctx, "server->client error", logger.E(err))
|
||||||
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
@ -21,11 +33,15 @@ func pipe(client net.Conn, server net.Conn) (err error) {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func relay(src io.ReadCloser, dst io.WriteCloser, stop chan bool) (err error) {
|
func relay(src net.Conn, dst net.Conn, stop chan bool) (err error) {
|
||||||
_, err = io.Copy(dst, src)
|
_, err = io.Copy(dst, src)
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
err = nil
|
||||||
|
}
|
||||||
|
|
||||||
dst.Close()
|
if err != nil {
|
||||||
src.Close()
|
err = errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
stop <- true
|
stop <- true
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ import (
|
||||||
"context"
|
"context"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
"forge.cadoles.com/wpetit/go-tunnel/control"
|
"forge.cadoles.com/wpetit/go-tunnel/control"
|
||||||
|
|
||||||
|
@ -41,6 +42,10 @@ func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error {
|
||||||
|
|
||||||
logger.Debug(ctx, "accepting control stream")
|
logger.Debug(ctx, "accepting control stream")
|
||||||
|
|
||||||
|
if err := sess.SetDeadline(time.Now().Add(30 * time.Second)); err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
controlStream, err := sess.AcceptStream()
|
controlStream, err := sess.AcceptStream()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
|
@ -92,6 +97,8 @@ func (c *RemoteClient) handleAuthRequest(ctx context.Context, m *control.Message
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
logger.Debug(ctx, "auth succeeded", logger.F("credentials", authReqPayload.Credentials))
|
||||||
|
|
||||||
res := control.NewMessage(control.TypeAuthResponse, &control.AuthResponsePayload{
|
res := control.NewMessage(control.TypeAuthResponse, &control.AuthResponsePayload{
|
||||||
Success: success,
|
Success: success,
|
||||||
})
|
})
|
||||||
|
|
Loading…
Reference in New Issue