fix: enhance proxy stability
This commit is contained in:
parent
6994ab23ab
commit
536100da90
157
client.go
157
client.go
@ -5,25 +5,24 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitlab.com/wpetit/goweb/logger"
|
||||
|
||||
"forge.cadoles.com/wpetit/go-tunnel/control"
|
||||
cmap "github.com/orcaman/concurrent-map"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/xtaci/kcp-go/v5"
|
||||
"github.com/xtaci/smux"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
conf *ClientConfig
|
||||
conn *kcp.UDPSession
|
||||
sess *smux.Session
|
||||
control *control.Control
|
||||
http *http.Client
|
||||
proxies cmap.ConcurrentMap
|
||||
conf *ClientConfig
|
||||
conn *kcp.UDPSession
|
||||
sess *smux.Session
|
||||
control *control.Control
|
||||
http *http.Client
|
||||
openStreamMutex sync.Mutex
|
||||
}
|
||||
|
||||
func (c *Client) Connect(ctx context.Context) error {
|
||||
@ -43,24 +42,22 @@ func (c *Client) Connect(ctx context.Context) error {
|
||||
|
||||
config := smux.DefaultConfig()
|
||||
config.Version = 2
|
||||
config.KeepAliveInterval = 10 * time.Second
|
||||
config.KeepAliveTimeout = 2 * config.KeepAliveInterval
|
||||
|
||||
sess, err := smux.Client(conn, config)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
controlStream, err := sess.OpenStream()
|
||||
if err != nil {
|
||||
control := control.New()
|
||||
if err := control.Init(ctx, sess, false); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
c.conn = conn
|
||||
c.sess = sess
|
||||
c.control = control.New(sess, controlStream)
|
||||
|
||||
logger.Debug(ctx, "sending auth request")
|
||||
|
||||
success, err := c.control.AuthRequest(c.conf.Credentials)
|
||||
success, err := control.AuthRequest(c.conf.Credentials)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
@ -70,15 +67,21 @@ func (c *Client) Connect(ctx context.Context) error {
|
||||
return errors.WithStack(ErrAuthFailed)
|
||||
}
|
||||
|
||||
c.control = control
|
||||
c.conn = conn
|
||||
c.sess = sess
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) Listen(ctx context.Context) error {
|
||||
logger.Debug(ctx, "listening for messages")
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
|
||||
err := c.control.Listen(ctx, control.Handlers{
|
||||
control.TypeProxyRequest: c.handleProxyRequest,
|
||||
control.TypeCloseProxy: c.handleCloseProxy,
|
||||
})
|
||||
|
||||
if errors.Is(err, io.ErrClosedPipe) {
|
||||
@ -99,107 +102,62 @@ func (c *Client) Close() error {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
if c.sess != nil && !c.sess.IsClosed() {
|
||||
if err := c.sess.Close(); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) handleCloseProxy(ctx context.Context, m *control.Message) (*control.Message, error) {
|
||||
closeProxyPayload, ok := m.Payload.(*control.CloseProxyPayload)
|
||||
if !ok {
|
||||
return nil, errors.WithStack(ErrUnexpectedMessage)
|
||||
}
|
||||
|
||||
requestID := strconv.FormatInt(closeProxyPayload.RequestID, 10)
|
||||
|
||||
rawCloseChan, exists := c.proxies.Get(requestID)
|
||||
if !exists {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
closeChan, ok := rawCloseChan.(chan struct{})
|
||||
if !ok {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
closeChan <- struct{}{}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *Client) handleProxyRequest(ctx context.Context, m *control.Message) (*control.Message, error) {
|
||||
proxyReqPayload, ok := m.Payload.(*control.ProxyRequestPayload)
|
||||
if !ok {
|
||||
return nil, errors.WithStack(ErrUnexpectedMessage)
|
||||
}
|
||||
|
||||
requestID := strconv.FormatInt(proxyReqPayload.RequestID, 10)
|
||||
|
||||
ctx = logger.With(ctx, logger.F("requestID", requestID))
|
||||
|
||||
logger.Debug(
|
||||
ctx, "handling proxy request",
|
||||
ctx = logger.With(ctx,
|
||||
logger.F("network", proxyReqPayload.Network),
|
||||
logger.F("address", proxyReqPayload.Address),
|
||||
)
|
||||
|
||||
stream, err := c.sess.OpenStream()
|
||||
logger.Debug(ctx, "handling proxy request")
|
||||
|
||||
out, err := net.Dial(proxyReqPayload.Network, proxyReqPayload.Address)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
closeChan := make(chan struct{})
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
stream.Close()
|
||||
logger.Debug(ctx, "proxy stream closed")
|
||||
}()
|
||||
|
||||
proxy := func() error {
|
||||
net, err := net.Dial(proxyReqPayload.Network, proxyReqPayload.Address)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
defer net.Close()
|
||||
|
||||
err = pipe(ctx, stream, net)
|
||||
if errors.Is(err, os.ErrClosed) {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-closeChan:
|
||||
return
|
||||
default:
|
||||
if err := proxy(); err != nil {
|
||||
logger.Error(ctx, "error while proxying", logger.E(err))
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
c.proxies.Set(requestID, closeChan)
|
||||
go c.handleProxyStream(ctx, out)
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (c *Client) handleProxyStream(ctx context.Context, out net.Conn) {
|
||||
c.openStreamMutex.Lock()
|
||||
|
||||
in, err := c.sess.OpenStream()
|
||||
if err != nil {
|
||||
c.openStreamMutex.Unlock()
|
||||
logger.Error(ctx, "error while accepting proxy stream", logger.E(err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
c.openStreamMutex.Unlock()
|
||||
|
||||
streamCopy := func(dst io.Writer, src io.ReadCloser) {
|
||||
if _, err := Copy(dst, src); err != nil {
|
||||
if errors.Is(err, smux.ErrInvalidProtocol) {
|
||||
logger.Error(ctx, "error while proxying", logger.E(errors.WithStack(err)))
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "closing proxy stream")
|
||||
|
||||
in.Close()
|
||||
out.Close()
|
||||
}
|
||||
|
||||
go streamCopy(in, out)
|
||||
streamCopy(out, in)
|
||||
}
|
||||
|
||||
func NewClient(funcs ...ClientConfigFunc) *Client {
|
||||
conf := DefaultClientConfig()
|
||||
|
||||
@ -208,8 +166,7 @@ func NewClient(funcs ...ClientConfigFunc) *Client {
|
||||
}
|
||||
|
||||
return &Client{
|
||||
conf: conf,
|
||||
http: &http.Client{},
|
||||
proxies: cmap.New(),
|
||||
conf: conf,
|
||||
http: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
@ -60,7 +60,7 @@ func handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
return
|
||||
}
|
||||
|
||||
target, err := url.Parse("http://localhost:3000")
|
||||
target, err := url.Parse("https://arcad.games")
|
||||
if err != nil {
|
||||
logger.Fatal(r.Context(), "could not parse url", logger.E(err))
|
||||
}
|
||||
|
@ -3,9 +3,6 @@ package control
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/xtaci/smux"
|
||||
@ -13,11 +10,39 @@ import (
|
||||
)
|
||||
|
||||
type Control struct {
|
||||
encoder *json.Encoder
|
||||
decoder *json.Decoder
|
||||
stream *smux.Stream
|
||||
sess *smux.Session
|
||||
proxyClock int64
|
||||
encoder *json.Encoder
|
||||
decoder *json.Decoder
|
||||
stream *smux.Stream
|
||||
}
|
||||
|
||||
func (c *Control) Init(ctx context.Context, sess *smux.Session, serverMode bool) error {
|
||||
config := smux.DefaultConfig()
|
||||
config.Version = 2
|
||||
|
||||
logger.Debug(ctx, "creating control stream")
|
||||
|
||||
var (
|
||||
stream *smux.Stream
|
||||
err error
|
||||
)
|
||||
|
||||
if serverMode {
|
||||
stream, err = sess.AcceptStream()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
} else {
|
||||
stream, err = sess.OpenStream()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
c.stream = stream
|
||||
c.decoder = json.NewDecoder(stream)
|
||||
c.encoder = json.NewEncoder(stream)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Control) AuthRequest(credentials interface{}) (bool, error) {
|
||||
@ -39,89 +64,82 @@ func (c *Control) AuthRequest(credentials interface{}) (bool, error) {
|
||||
return authResPayload.Success, nil
|
||||
}
|
||||
|
||||
func (c *Control) Proxy(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
var (
|
||||
stream *smux.Stream
|
||||
err error
|
||||
)
|
||||
|
||||
requestID := atomic.AddInt64(&c.proxyClock, 1)
|
||||
|
||||
func (c *Control) ProxyReq(ctx context.Context, network, address string) error {
|
||||
req := NewMessage(TypeProxyRequest, &ProxyRequestPayload{
|
||||
RequestID: requestID,
|
||||
Network: network,
|
||||
Address: address,
|
||||
Network: network,
|
||||
Address: address,
|
||||
})
|
||||
|
||||
ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address))
|
||||
|
||||
logger.Debug(ctx, "proxying")
|
||||
|
||||
if err := c.Write(req); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "opening stream")
|
||||
|
||||
stream, err = c.sess.AcceptStream()
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
|
||||
req := NewMessage(TypeCloseProxy, &CloseProxyPayload{
|
||||
RequestID: requestID,
|
||||
})
|
||||
|
||||
if err := c.Write(req); err != nil {
|
||||
logger.Error(ctx, "error while closing proxy", logger.E(err))
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "closing proxy conn")
|
||||
stream.Close()
|
||||
}()
|
||||
|
||||
return stream, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Control) Listen(ctx context.Context, handlers Handlers) error {
|
||||
for {
|
||||
logger.Debug(ctx, "reading next message")
|
||||
errChan := make(chan error)
|
||||
msgChan := make(chan *Message)
|
||||
dieChan := c.stream.GetDieCh()
|
||||
|
||||
req, err := c.Read()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
go func(msgChan chan *Message, errChan chan error) {
|
||||
for {
|
||||
logger.Debug(ctx, "reading next message")
|
||||
|
||||
go func() {
|
||||
subCtx := logger.With(ctx, logger.F("messageType", req.Type))
|
||||
|
||||
handler, exists := handlers[req.Type]
|
||||
if !exists {
|
||||
logger.Error(subCtx, "no message handler registered")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(subCtx, req)
|
||||
msg, err := c.Read()
|
||||
if err != nil {
|
||||
logger.Error(subCtx, "error while handling message", logger.E(err))
|
||||
errChan <- errors.WithStack(err)
|
||||
|
||||
close(errChan)
|
||||
close(msgChan)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if res == nil {
|
||||
return
|
||||
}
|
||||
msgChan <- msg
|
||||
}
|
||||
}(msgChan, errChan)
|
||||
|
||||
if err := c.Write(res); err != nil {
|
||||
logger.Error(subCtx, "error while write message response", logger.E(err))
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
|
||||
return
|
||||
}
|
||||
}()
|
||||
case <-dieChan:
|
||||
return errors.WithStack(ErrStreamClosed)
|
||||
|
||||
case err := <-errChan:
|
||||
return errors.WithStack(err)
|
||||
|
||||
case msg := <-msgChan:
|
||||
go func() {
|
||||
subCtx := logger.With(ctx, logger.F("messageType", msg.Type))
|
||||
|
||||
handler, exists := handlers[msg.Type]
|
||||
if !exists {
|
||||
logger.Error(subCtx, "no message handler registered")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(subCtx, msg)
|
||||
if err != nil {
|
||||
logger.Error(subCtx, "error while handling message", logger.E(err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if res == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.Write(res); err != nil {
|
||||
logger.Error(subCtx, "error while write message response", logger.E(err))
|
||||
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@ -164,10 +182,6 @@ func (c *Control) read(m *Message) error {
|
||||
}
|
||||
|
||||
func (c *Control) write(m *Message) error {
|
||||
if err := c.stream.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
if err := c.encoder.Encode(m); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
@ -175,11 +189,6 @@ func (c *Control) write(m *Message) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func New(sess *smux.Session, controlStream *smux.Stream) *Control {
|
||||
return &Control{
|
||||
encoder: json.NewEncoder(controlStream),
|
||||
decoder: json.NewDecoder(controlStream),
|
||||
sess: sess,
|
||||
stream: controlStream,
|
||||
}
|
||||
func New() *Control {
|
||||
return &Control{}
|
||||
}
|
||||
|
@ -3,5 +3,6 @@ package control
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrStreamClosed = errors.New("stream closed")
|
||||
ErrUnexpectedMessage = errors.New("unexpected message")
|
||||
)
|
||||
|
@ -1,9 +1,8 @@
|
||||
package control
|
||||
|
||||
type ProxyRequestPayload struct {
|
||||
RequestID int64 `json:"i"`
|
||||
Network string `json:"n"`
|
||||
Address string `json:"a"`
|
||||
Network string `json:"n"`
|
||||
Address string `json:"a"`
|
||||
}
|
||||
|
||||
type CloseProxyPayload struct {
|
||||
|
50
helper.go
50
helper.go
@ -1,52 +1,30 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/xtaci/kcp-go/v5"
|
||||
"gitlab.com/wpetit/goweb/logger"
|
||||
)
|
||||
|
||||
func pipe(ctx context.Context, client net.Conn, server net.Conn) (err error) {
|
||||
stop := make(chan bool)
|
||||
const bufSize = 4096
|
||||
|
||||
go func() {
|
||||
err = relay(client, server, stop)
|
||||
if err != nil {
|
||||
err = errors.WithStack(err)
|
||||
logger.Debug(ctx, "client->server error", logger.E(err))
|
||||
}
|
||||
}()
|
||||
go func() {
|
||||
err = relay(server, client, stop)
|
||||
if err != nil {
|
||||
err = errors.WithStack(err)
|
||||
logger.Debug(ctx, "server->client error", logger.E(err))
|
||||
}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-stop:
|
||||
return err
|
||||
// From https://github.com/xtaci/kcptun/blob/master/generic/copy.go
|
||||
// Copyright https://github.com/xtaci
|
||||
func Copy(dst io.Writer, src io.Reader) (written int64, err error) {
|
||||
// If the reader has a WriteTo method, use it to do the copy.
|
||||
// Avoids an allocation and a copy.
|
||||
if wt, ok := src.(io.WriterTo); ok {
|
||||
return wt.WriteTo(dst)
|
||||
}
|
||||
}
|
||||
|
||||
func relay(src net.Conn, dst net.Conn, stop chan bool) (err error) {
|
||||
_, err = io.Copy(dst, src)
|
||||
if errors.Is(err, io.EOF) {
|
||||
err = nil
|
||||
// Similarly, if the writer has a ReadFrom method, use it to do the copy.
|
||||
if rt, ok := dst.(io.ReaderFrom); ok {
|
||||
return rt.ReadFrom(src)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
err = errors.WithStack(err)
|
||||
}
|
||||
|
||||
stop <- true
|
||||
|
||||
return
|
||||
// fallback to standard io.CopyBuffer
|
||||
buf := make([]byte, bufSize)
|
||||
return io.CopyBuffer(dst, src, buf)
|
||||
}
|
||||
|
||||
func createBlockCrypt(algorithm string, pass []byte) (kcp.BlockCrypt, error) {
|
||||
|
155
remote_client.go
155
remote_client.go
@ -2,11 +2,13 @@ package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"forge.cadoles.com/wpetit/go-tunnel/control"
|
||||
|
||||
cmap "github.com/orcaman/concurrent-map"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/xtaci/kcp-go/v5"
|
||||
"github.com/xtaci/smux"
|
||||
@ -20,10 +22,32 @@ type RemoteClient struct {
|
||||
sess *smux.Session
|
||||
control *control.Control
|
||||
remoteAddr net.Addr
|
||||
proxies cmap.ConcurrentMap
|
||||
acceptStreamMutex sync.Mutex
|
||||
}
|
||||
|
||||
func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error {
|
||||
config := smux.DefaultConfig()
|
||||
config.Version = 2
|
||||
config.KeepAliveInterval = 10 * time.Second
|
||||
config.KeepAliveTimeout = 2 * config.KeepAliveInterval
|
||||
|
||||
logger.Debug(ctx, "creating server session")
|
||||
|
||||
sess, err := smux.Server(conn, config)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
ctrl := control.New()
|
||||
|
||||
if err := ctrl.Init(ctx, sess, true); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
c.sess = sess
|
||||
c.remoteAddr = conn.RemoteAddr()
|
||||
c.control = ctrl
|
||||
|
||||
if c.onClientConnectHook != nil {
|
||||
if err := c.onClientConnectHook.OnClientConnect(ctx, c); err != nil {
|
||||
@ -31,45 +55,82 @@ func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error {
|
||||
}
|
||||
}
|
||||
|
||||
config := smux.DefaultConfig()
|
||||
config.Version = 2
|
||||
|
||||
sess, err := smux.Server(conn, config)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "accepting control stream")
|
||||
|
||||
controlStream, err := sess.AcceptStream()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
c.sess = sess
|
||||
c.control = control.New(sess, controlStream)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *RemoteClient) Listen(ctx context.Context) error {
|
||||
logger.Debug(ctx, "listening for messages")
|
||||
|
||||
err := c.control.Listen(ctx, control.Handlers{
|
||||
control.TypeAuthRequest: c.handleAuthRequest,
|
||||
})
|
||||
|
||||
if errors.Is(err, io.ErrClosedPipe) {
|
||||
defer func() {
|
||||
if c.onClientDisconnectHook != nil {
|
||||
if err := c.onClientDisconnectHook.OnClientDisconnect(ctx, c); err != nil {
|
||||
logger.Error(ctx, "client disconnect hook error", logger.E(err))
|
||||
logger.Error(ctx, "client disconnect hook error", logger.E(errors.WithStack(err)))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return errors.WithStack(ErrConnectionClosed)
|
||||
logger.Debug(ctx, "listening for messages")
|
||||
|
||||
return c.control.Listen(ctx, control.Handlers{
|
||||
control.TypeAuthRequest: c.handleAuthRequest,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *RemoteClient) ConfigureHooks(hooks interface{}) {
|
||||
if hooks == nil {
|
||||
return
|
||||
}
|
||||
|
||||
return err
|
||||
if onClientAuthHook, ok := hooks.(OnClientAuthHook); ok {
|
||||
c.onClientAuthHook = onClientAuthHook
|
||||
}
|
||||
|
||||
if OnClientConnectHook, ok := hooks.(OnClientConnectHook); ok {
|
||||
c.onClientConnectHook = OnClientConnectHook
|
||||
}
|
||||
|
||||
if OnClientDisconnectHook, ok := hooks.(OnClientDisconnectHook); ok {
|
||||
c.onClientDisconnectHook = OnClientDisconnectHook
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RemoteClient) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *RemoteClient) Close() {
|
||||
if c.sess != nil {
|
||||
c.sess.Close()
|
||||
}
|
||||
|
||||
c.sess = nil
|
||||
c.control = nil
|
||||
}
|
||||
|
||||
func (c *RemoteClient) Proxy(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address))
|
||||
|
||||
if err := c.control.ProxyReq(ctx, network, address); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "opening proxy stream")
|
||||
|
||||
c.acceptStreamMutex.Lock()
|
||||
|
||||
stream, err := c.sess.AcceptStream()
|
||||
if err != nil {
|
||||
c.acceptStreamMutex.Unlock()
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
c.acceptStreamMutex.Unlock()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
logger.Debug(ctx, "closing proxy stream")
|
||||
stream.Close()
|
||||
}()
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (c *RemoteClient) handleAuthRequest(ctx context.Context, m *control.Message) (*control.Message, error) {
|
||||
@ -101,38 +162,8 @@ func (c *RemoteClient) handleAuthRequest(ctx context.Context, m *control.Message
|
||||
return res, nil
|
||||
}
|
||||
|
||||
func (c *RemoteClient) ConfigureHooks(hooks interface{}) {
|
||||
if hooks == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if onClientAuthHook, ok := hooks.(OnClientAuthHook); ok {
|
||||
c.onClientAuthHook = onClientAuthHook
|
||||
}
|
||||
|
||||
if OnClientConnectHook, ok := hooks.(OnClientConnectHook); ok {
|
||||
c.onClientConnectHook = OnClientConnectHook
|
||||
}
|
||||
|
||||
if OnClientDisconnectHook, ok := hooks.(OnClientDisconnectHook); ok {
|
||||
c.onClientDisconnectHook = OnClientDisconnectHook
|
||||
}
|
||||
}
|
||||
|
||||
func (c *RemoteClient) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *RemoteClient) Proxy(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return c.control.Proxy(ctx, network, address)
|
||||
}
|
||||
|
||||
func (c *RemoteClient) Close() {
|
||||
if c.sess != nil && !c.sess.IsClosed() {
|
||||
c.sess.Close()
|
||||
}
|
||||
}
|
||||
|
||||
func NewRemoteClient() *RemoteClient {
|
||||
return &RemoteClient{}
|
||||
return &RemoteClient{
|
||||
proxies: cmap.New(),
|
||||
}
|
||||
}
|
||||
|
@ -37,18 +37,20 @@ func (s *Server) handleNewConn(ctx context.Context, conn *kcp.UDPSession) {
|
||||
ctx = logger.With(ctx, logger.F("remoteAddr", conn.RemoteAddr().String()))
|
||||
|
||||
remoteClient := NewRemoteClient()
|
||||
|
||||
defer remoteClient.Close()
|
||||
defer conn.Close()
|
||||
|
||||
remoteClient.ConfigureHooks(s.conf.Hooks)
|
||||
|
||||
if err := remoteClient.Accept(ctx, conn); err != nil {
|
||||
logger.Error(ctx, "remote client error", logger.E(err))
|
||||
logger.Error(ctx, "remote client error", logger.E(errors.WithStack(err)))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if err := remoteClient.Listen(ctx); err != nil {
|
||||
logger.Error(ctx, "remote client error", logger.E(err))
|
||||
logger.Error(ctx, "remote client error", logger.E(errors.WithStack(err)))
|
||||
}
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user