fix: enhance proxy stability
This commit is contained in:
parent
6994ab23ab
commit
536100da90
143
client.go
143
client.go
|
@ -5,13 +5,12 @@ import (
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"os"
|
"sync"
|
||||||
"strconv"
|
"time"
|
||||||
|
|
||||||
"gitlab.com/wpetit/goweb/logger"
|
"gitlab.com/wpetit/goweb/logger"
|
||||||
|
|
||||||
"forge.cadoles.com/wpetit/go-tunnel/control"
|
"forge.cadoles.com/wpetit/go-tunnel/control"
|
||||||
cmap "github.com/orcaman/concurrent-map"
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/xtaci/kcp-go/v5"
|
"github.com/xtaci/kcp-go/v5"
|
||||||
"github.com/xtaci/smux"
|
"github.com/xtaci/smux"
|
||||||
|
@ -23,7 +22,7 @@ type Client struct {
|
||||||
sess *smux.Session
|
sess *smux.Session
|
||||||
control *control.Control
|
control *control.Control
|
||||||
http *http.Client
|
http *http.Client
|
||||||
proxies cmap.ConcurrentMap
|
openStreamMutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Connect(ctx context.Context) error {
|
func (c *Client) Connect(ctx context.Context) error {
|
||||||
|
@ -43,24 +42,22 @@ func (c *Client) Connect(ctx context.Context) error {
|
||||||
|
|
||||||
config := smux.DefaultConfig()
|
config := smux.DefaultConfig()
|
||||||
config.Version = 2
|
config.Version = 2
|
||||||
|
config.KeepAliveInterval = 10 * time.Second
|
||||||
|
config.KeepAliveTimeout = 2 * config.KeepAliveInterval
|
||||||
|
|
||||||
sess, err := smux.Client(conn, config)
|
sess, err := smux.Client(conn, config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
controlStream, err := sess.OpenStream()
|
control := control.New()
|
||||||
if err != nil {
|
if err := control.Init(ctx, sess, false); err != nil {
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.conn = conn
|
|
||||||
c.sess = sess
|
|
||||||
c.control = control.New(sess, controlStream)
|
|
||||||
|
|
||||||
logger.Debug(ctx, "sending auth request")
|
logger.Debug(ctx, "sending auth request")
|
||||||
|
|
||||||
success, err := c.control.AuthRequest(c.conf.Credentials)
|
success, err := control.AuthRequest(c.conf.Credentials)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
@ -70,15 +67,21 @@ func (c *Client) Connect(ctx context.Context) error {
|
||||||
return errors.WithStack(ErrAuthFailed)
|
return errors.WithStack(ErrAuthFailed)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
c.control = control
|
||||||
|
c.conn = conn
|
||||||
|
c.sess = sess
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Listen(ctx context.Context) error {
|
func (c *Client) Listen(ctx context.Context) error {
|
||||||
logger.Debug(ctx, "listening for messages")
|
logger.Debug(ctx, "listening for messages")
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
err := c.control.Listen(ctx, control.Handlers{
|
err := c.control.Listen(ctx, control.Handlers{
|
||||||
control.TypeProxyRequest: c.handleProxyRequest,
|
control.TypeProxyRequest: c.handleProxyRequest,
|
||||||
control.TypeCloseProxy: c.handleCloseProxy,
|
|
||||||
})
|
})
|
||||||
|
|
||||||
if errors.Is(err, io.ErrClosedPipe) {
|
if errors.Is(err, io.ErrClosedPipe) {
|
||||||
|
@ -99,107 +102,62 @@ func (c *Client) Close() error {
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.sess != nil && !c.sess.IsClosed() {
|
|
||||||
if err := c.sess.Close(); err != nil {
|
|
||||||
return errors.WithStack(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) handleCloseProxy(ctx context.Context, m *control.Message) (*control.Message, error) {
|
|
||||||
closeProxyPayload, ok := m.Payload.(*control.CloseProxyPayload)
|
|
||||||
if !ok {
|
|
||||||
return nil, errors.WithStack(ErrUnexpectedMessage)
|
|
||||||
}
|
|
||||||
|
|
||||||
requestID := strconv.FormatInt(closeProxyPayload.RequestID, 10)
|
|
||||||
|
|
||||||
rawCloseChan, exists := c.proxies.Get(requestID)
|
|
||||||
if !exists {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
closeChan, ok := rawCloseChan.(chan struct{})
|
|
||||||
if !ok {
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
closeChan <- struct{}{}
|
|
||||||
|
|
||||||
return nil, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Client) handleProxyRequest(ctx context.Context, m *control.Message) (*control.Message, error) {
|
func (c *Client) handleProxyRequest(ctx context.Context, m *control.Message) (*control.Message, error) {
|
||||||
proxyReqPayload, ok := m.Payload.(*control.ProxyRequestPayload)
|
proxyReqPayload, ok := m.Payload.(*control.ProxyRequestPayload)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.WithStack(ErrUnexpectedMessage)
|
return nil, errors.WithStack(ErrUnexpectedMessage)
|
||||||
}
|
}
|
||||||
|
|
||||||
requestID := strconv.FormatInt(proxyReqPayload.RequestID, 10)
|
ctx = logger.With(ctx,
|
||||||
|
|
||||||
ctx = logger.With(ctx, logger.F("requestID", requestID))
|
|
||||||
|
|
||||||
logger.Debug(
|
|
||||||
ctx, "handling proxy request",
|
|
||||||
logger.F("network", proxyReqPayload.Network),
|
logger.F("network", proxyReqPayload.Network),
|
||||||
logger.F("address", proxyReqPayload.Address),
|
logger.F("address", proxyReqPayload.Address),
|
||||||
)
|
)
|
||||||
|
|
||||||
stream, err := c.sess.OpenStream()
|
logger.Debug(ctx, "handling proxy request")
|
||||||
|
|
||||||
|
out, err := net.Dial(proxyReqPayload.Network, proxyReqPayload.Address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.WithStack(err)
|
return nil, errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
closeChan := make(chan struct{})
|
go c.handleProxyStream(ctx, out)
|
||||||
|
|
||||||
go func() {
|
|
||||||
defer func() {
|
|
||||||
stream.Close()
|
|
||||||
logger.Debug(ctx, "proxy stream closed")
|
|
||||||
}()
|
|
||||||
|
|
||||||
proxy := func() error {
|
|
||||||
net, err := net.Dial(proxyReqPayload.Network, proxyReqPayload.Address)
|
|
||||||
if err != nil {
|
|
||||||
return errors.WithStack(err)
|
|
||||||
}
|
|
||||||
defer net.Close()
|
|
||||||
|
|
||||||
err = pipe(ctx, stream, net)
|
|
||||||
if errors.Is(err, os.ErrClosed) {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
return errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-closeChan:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
if err := proxy(); err != nil {
|
|
||||||
logger.Error(ctx, "error while proxying", logger.E(err))
|
|
||||||
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
c.proxies.Set(requestID, closeChan)
|
|
||||||
|
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) handleProxyStream(ctx context.Context, out net.Conn) {
|
||||||
|
c.openStreamMutex.Lock()
|
||||||
|
|
||||||
|
in, err := c.sess.OpenStream()
|
||||||
|
if err != nil {
|
||||||
|
c.openStreamMutex.Unlock()
|
||||||
|
logger.Error(ctx, "error while accepting proxy stream", logger.E(err))
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
c.openStreamMutex.Unlock()
|
||||||
|
|
||||||
|
streamCopy := func(dst io.Writer, src io.ReadCloser) {
|
||||||
|
if _, err := Copy(dst, src); err != nil {
|
||||||
|
if errors.Is(err, smux.ErrInvalidProtocol) {
|
||||||
|
logger.Error(ctx, "error while proxying", logger.E(errors.WithStack(err)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug(ctx, "closing proxy stream")
|
||||||
|
|
||||||
|
in.Close()
|
||||||
|
out.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
go streamCopy(in, out)
|
||||||
|
streamCopy(out, in)
|
||||||
|
}
|
||||||
|
|
||||||
func NewClient(funcs ...ClientConfigFunc) *Client {
|
func NewClient(funcs ...ClientConfigFunc) *Client {
|
||||||
conf := DefaultClientConfig()
|
conf := DefaultClientConfig()
|
||||||
|
|
||||||
|
@ -210,6 +168,5 @@ func NewClient(funcs ...ClientConfigFunc) *Client {
|
||||||
return &Client{
|
return &Client{
|
||||||
conf: conf,
|
conf: conf,
|
||||||
http: &http.Client{},
|
http: &http.Client{},
|
||||||
proxies: cmap.New(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -60,7 +60,7 @@ func handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
target, err := url.Parse("http://localhost:3000")
|
target, err := url.Parse("https://arcad.games")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Fatal(r.Context(), "could not parse url", logger.E(err))
|
logger.Fatal(r.Context(), "could not parse url", logger.E(err))
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,9 +3,6 @@ package control
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net"
|
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/xtaci/smux"
|
"github.com/xtaci/smux"
|
||||||
|
@ -16,8 +13,36 @@ type Control struct {
|
||||||
encoder *json.Encoder
|
encoder *json.Encoder
|
||||||
decoder *json.Decoder
|
decoder *json.Decoder
|
||||||
stream *smux.Stream
|
stream *smux.Stream
|
||||||
sess *smux.Session
|
}
|
||||||
proxyClock int64
|
|
||||||
|
func (c *Control) Init(ctx context.Context, sess *smux.Session, serverMode bool) error {
|
||||||
|
config := smux.DefaultConfig()
|
||||||
|
config.Version = 2
|
||||||
|
|
||||||
|
logger.Debug(ctx, "creating control stream")
|
||||||
|
|
||||||
|
var (
|
||||||
|
stream *smux.Stream
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
if serverMode {
|
||||||
|
stream, err = sess.AcceptStream()
|
||||||
|
if err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
stream, err = sess.OpenStream()
|
||||||
|
if err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.stream = stream
|
||||||
|
c.decoder = json.NewDecoder(stream)
|
||||||
|
c.encoder = json.NewEncoder(stream)
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) AuthRequest(credentials interface{}) (bool, error) {
|
func (c *Control) AuthRequest(credentials interface{}) (bool, error) {
|
||||||
|
@ -39,73 +64,65 @@ func (c *Control) AuthRequest(credentials interface{}) (bool, error) {
|
||||||
return authResPayload.Success, nil
|
return authResPayload.Success, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) Proxy(ctx context.Context, network, address string) (net.Conn, error) {
|
func (c *Control) ProxyReq(ctx context.Context, network, address string) error {
|
||||||
var (
|
|
||||||
stream *smux.Stream
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
requestID := atomic.AddInt64(&c.proxyClock, 1)
|
|
||||||
|
|
||||||
req := NewMessage(TypeProxyRequest, &ProxyRequestPayload{
|
req := NewMessage(TypeProxyRequest, &ProxyRequestPayload{
|
||||||
RequestID: requestID,
|
|
||||||
Network: network,
|
Network: network,
|
||||||
Address: address,
|
Address: address,
|
||||||
})
|
})
|
||||||
|
|
||||||
ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address))
|
|
||||||
|
|
||||||
logger.Debug(ctx, "proxying")
|
|
||||||
|
|
||||||
if err := c.Write(req); err != nil {
|
if err := c.Write(req); err != nil {
|
||||||
return nil, errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debug(ctx, "opening stream")
|
|
||||||
|
|
||||||
stream, err = c.sess.AcceptStream()
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
<-ctx.Done()
|
|
||||||
|
|
||||||
req := NewMessage(TypeCloseProxy, &CloseProxyPayload{
|
|
||||||
RequestID: requestID,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := c.Write(req); err != nil {
|
|
||||||
logger.Error(ctx, "error while closing proxy", logger.E(err))
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debug(ctx, "closing proxy conn")
|
|
||||||
stream.Close()
|
|
||||||
}()
|
|
||||||
|
|
||||||
return stream, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Control) Listen(ctx context.Context, handlers Handlers) error {
|
|
||||||
for {
|
|
||||||
logger.Debug(ctx, "reading next message")
|
|
||||||
|
|
||||||
req, err := c.Read()
|
|
||||||
if err != nil {
|
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
go func() {
|
return nil
|
||||||
subCtx := logger.With(ctx, logger.F("messageType", req.Type))
|
}
|
||||||
|
|
||||||
handler, exists := handlers[req.Type]
|
func (c *Control) Listen(ctx context.Context, handlers Handlers) error {
|
||||||
|
errChan := make(chan error)
|
||||||
|
msgChan := make(chan *Message)
|
||||||
|
dieChan := c.stream.GetDieCh()
|
||||||
|
|
||||||
|
go func(msgChan chan *Message, errChan chan error) {
|
||||||
|
for {
|
||||||
|
logger.Debug(ctx, "reading next message")
|
||||||
|
|
||||||
|
msg, err := c.Read()
|
||||||
|
if err != nil {
|
||||||
|
errChan <- errors.WithStack(err)
|
||||||
|
|
||||||
|
close(errChan)
|
||||||
|
close(msgChan)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
msgChan <- msg
|
||||||
|
}
|
||||||
|
}(msgChan, errChan)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil
|
||||||
|
|
||||||
|
case <-dieChan:
|
||||||
|
return errors.WithStack(ErrStreamClosed)
|
||||||
|
|
||||||
|
case err := <-errChan:
|
||||||
|
return errors.WithStack(err)
|
||||||
|
|
||||||
|
case msg := <-msgChan:
|
||||||
|
go func() {
|
||||||
|
subCtx := logger.With(ctx, logger.F("messageType", msg.Type))
|
||||||
|
|
||||||
|
handler, exists := handlers[msg.Type]
|
||||||
if !exists {
|
if !exists {
|
||||||
logger.Error(subCtx, "no message handler registered")
|
logger.Error(subCtx, "no message handler registered")
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
res, err := handler(subCtx, req)
|
res, err := handler(subCtx, msg)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
logger.Error(subCtx, "error while handling message", logger.E(err))
|
logger.Error(subCtx, "error while handling message", logger.E(err))
|
||||||
|
|
||||||
|
@ -124,6 +141,7 @@ func (c *Control) Listen(ctx context.Context, handlers Handlers) error {
|
||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *Control) Read() (*Message, error) {
|
func (c *Control) Read() (*Message, error) {
|
||||||
message := &Message{}
|
message := &Message{}
|
||||||
|
@ -164,10 +182,6 @@ func (c *Control) read(m *Message) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Control) write(m *Message) error {
|
func (c *Control) write(m *Message) error {
|
||||||
if err := c.stream.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
|
|
||||||
return errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.encoder.Encode(m); err != nil {
|
if err := c.encoder.Encode(m); err != nil {
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
@ -175,11 +189,6 @@ func (c *Control) write(m *Message) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(sess *smux.Session, controlStream *smux.Stream) *Control {
|
func New() *Control {
|
||||||
return &Control{
|
return &Control{}
|
||||||
encoder: json.NewEncoder(controlStream),
|
|
||||||
decoder: json.NewDecoder(controlStream),
|
|
||||||
sess: sess,
|
|
||||||
stream: controlStream,
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -3,5 +3,6 @@ package control
|
||||||
import "errors"
|
import "errors"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
ErrStreamClosed = errors.New("stream closed")
|
||||||
ErrUnexpectedMessage = errors.New("unexpected message")
|
ErrUnexpectedMessage = errors.New("unexpected message")
|
||||||
)
|
)
|
||||||
|
|
|
@ -1,7 +1,6 @@
|
||||||
package control
|
package control
|
||||||
|
|
||||||
type ProxyRequestPayload struct {
|
type ProxyRequestPayload struct {
|
||||||
RequestID int64 `json:"i"`
|
|
||||||
Network string `json:"n"`
|
Network string `json:"n"`
|
||||||
Address string `json:"a"`
|
Address string `json:"a"`
|
||||||
}
|
}
|
||||||
|
|
50
helper.go
50
helper.go
|
@ -1,52 +1,30 @@
|
||||||
package tunnel
|
package tunnel
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"io"
|
"io"
|
||||||
"net"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/xtaci/kcp-go/v5"
|
"github.com/xtaci/kcp-go/v5"
|
||||||
"gitlab.com/wpetit/goweb/logger"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func pipe(ctx context.Context, client net.Conn, server net.Conn) (err error) {
|
const bufSize = 4096
|
||||||
stop := make(chan bool)
|
|
||||||
|
|
||||||
go func() {
|
// From https://github.com/xtaci/kcptun/blob/master/generic/copy.go
|
||||||
err = relay(client, server, stop)
|
// Copyright https://github.com/xtaci
|
||||||
if err != nil {
|
func Copy(dst io.Writer, src io.Reader) (written int64, err error) {
|
||||||
err = errors.WithStack(err)
|
// If the reader has a WriteTo method, use it to do the copy.
|
||||||
logger.Debug(ctx, "client->server error", logger.E(err))
|
// Avoids an allocation and a copy.
|
||||||
}
|
if wt, ok := src.(io.WriterTo); ok {
|
||||||
}()
|
return wt.WriteTo(dst)
|
||||||
go func() {
|
|
||||||
err = relay(server, client, stop)
|
|
||||||
if err != nil {
|
|
||||||
err = errors.WithStack(err)
|
|
||||||
logger.Debug(ctx, "server->client error", logger.E(err))
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-stop:
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
// Similarly, if the writer has a ReadFrom method, use it to do the copy.
|
||||||
|
if rt, ok := dst.(io.ReaderFrom); ok {
|
||||||
|
return rt.ReadFrom(src)
|
||||||
}
|
}
|
||||||
|
|
||||||
func relay(src net.Conn, dst net.Conn, stop chan bool) (err error) {
|
// fallback to standard io.CopyBuffer
|
||||||
_, err = io.Copy(dst, src)
|
buf := make([]byte, bufSize)
|
||||||
if errors.Is(err, io.EOF) {
|
return io.CopyBuffer(dst, src, buf)
|
||||||
err = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
err = errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
stop <- true
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func createBlockCrypt(algorithm string, pass []byte) (kcp.BlockCrypt, error) {
|
func createBlockCrypt(algorithm string, pass []byte) (kcp.BlockCrypt, error) {
|
||||||
|
|
155
remote_client.go
155
remote_client.go
|
@ -2,11 +2,13 @@ package tunnel
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"io"
|
|
||||||
"net"
|
"net"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"forge.cadoles.com/wpetit/go-tunnel/control"
|
"forge.cadoles.com/wpetit/go-tunnel/control"
|
||||||
|
|
||||||
|
cmap "github.com/orcaman/concurrent-map"
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/xtaci/kcp-go/v5"
|
"github.com/xtaci/kcp-go/v5"
|
||||||
"github.com/xtaci/smux"
|
"github.com/xtaci/smux"
|
||||||
|
@ -20,10 +22,32 @@ type RemoteClient struct {
|
||||||
sess *smux.Session
|
sess *smux.Session
|
||||||
control *control.Control
|
control *control.Control
|
||||||
remoteAddr net.Addr
|
remoteAddr net.Addr
|
||||||
|
proxies cmap.ConcurrentMap
|
||||||
|
acceptStreamMutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error {
|
func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error {
|
||||||
|
config := smux.DefaultConfig()
|
||||||
|
config.Version = 2
|
||||||
|
config.KeepAliveInterval = 10 * time.Second
|
||||||
|
config.KeepAliveTimeout = 2 * config.KeepAliveInterval
|
||||||
|
|
||||||
|
logger.Debug(ctx, "creating server session")
|
||||||
|
|
||||||
|
sess, err := smux.Server(conn, config)
|
||||||
|
if err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctrl := control.New()
|
||||||
|
|
||||||
|
if err := ctrl.Init(ctx, sess, true); err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.sess = sess
|
||||||
c.remoteAddr = conn.RemoteAddr()
|
c.remoteAddr = conn.RemoteAddr()
|
||||||
|
c.control = ctrl
|
||||||
|
|
||||||
if c.onClientConnectHook != nil {
|
if c.onClientConnectHook != nil {
|
||||||
if err := c.onClientConnectHook.OnClientConnect(ctx, c); err != nil {
|
if err := c.onClientConnectHook.OnClientConnect(ctx, c); err != nil {
|
||||||
|
@ -31,45 +55,82 @@ func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
config := smux.DefaultConfig()
|
|
||||||
config.Version = 2
|
|
||||||
|
|
||||||
sess, err := smux.Server(conn, config)
|
|
||||||
if err != nil {
|
|
||||||
return errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debug(ctx, "accepting control stream")
|
|
||||||
|
|
||||||
controlStream, err := sess.AcceptStream()
|
|
||||||
if err != nil {
|
|
||||||
return errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
c.sess = sess
|
|
||||||
c.control = control.New(sess, controlStream)
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RemoteClient) Listen(ctx context.Context) error {
|
func (c *RemoteClient) Listen(ctx context.Context) error {
|
||||||
logger.Debug(ctx, "listening for messages")
|
defer func() {
|
||||||
|
|
||||||
err := c.control.Listen(ctx, control.Handlers{
|
|
||||||
control.TypeAuthRequest: c.handleAuthRequest,
|
|
||||||
})
|
|
||||||
|
|
||||||
if errors.Is(err, io.ErrClosedPipe) {
|
|
||||||
if c.onClientDisconnectHook != nil {
|
if c.onClientDisconnectHook != nil {
|
||||||
if err := c.onClientDisconnectHook.OnClientDisconnect(ctx, c); err != nil {
|
if err := c.onClientDisconnectHook.OnClientDisconnect(ctx, c); err != nil {
|
||||||
logger.Error(ctx, "client disconnect hook error", logger.E(err))
|
logger.Error(ctx, "client disconnect hook error", logger.E(errors.WithStack(err)))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
logger.Debug(ctx, "listening for messages")
|
||||||
|
|
||||||
|
return c.control.Listen(ctx, control.Handlers{
|
||||||
|
control.TypeAuthRequest: c.handleAuthRequest,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RemoteClient) ConfigureHooks(hooks interface{}) {
|
||||||
|
if hooks == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if onClientAuthHook, ok := hooks.(OnClientAuthHook); ok {
|
||||||
|
c.onClientAuthHook = onClientAuthHook
|
||||||
|
}
|
||||||
|
|
||||||
|
if OnClientConnectHook, ok := hooks.(OnClientConnectHook); ok {
|
||||||
|
c.onClientConnectHook = OnClientConnectHook
|
||||||
|
}
|
||||||
|
|
||||||
|
if OnClientDisconnectHook, ok := hooks.(OnClientDisconnectHook); ok {
|
||||||
|
c.onClientDisconnectHook = OnClientDisconnectHook
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return errors.WithStack(ErrConnectionClosed)
|
func (c *RemoteClient) RemoteAddr() net.Addr {
|
||||||
|
return c.remoteAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
func (c *RemoteClient) Close() {
|
||||||
|
if c.sess != nil {
|
||||||
|
c.sess.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
c.sess = nil
|
||||||
|
c.control = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RemoteClient) Proxy(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address))
|
||||||
|
|
||||||
|
if err := c.control.ProxyReq(ctx, network, address); err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug(ctx, "opening proxy stream")
|
||||||
|
|
||||||
|
c.acceptStreamMutex.Lock()
|
||||||
|
|
||||||
|
stream, err := c.sess.AcceptStream()
|
||||||
|
if err != nil {
|
||||||
|
c.acceptStreamMutex.Unlock()
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.acceptStreamMutex.Unlock()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
logger.Debug(ctx, "closing proxy stream")
|
||||||
|
stream.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
return stream, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RemoteClient) handleAuthRequest(ctx context.Context, m *control.Message) (*control.Message, error) {
|
func (c *RemoteClient) handleAuthRequest(ctx context.Context, m *control.Message) (*control.Message, error) {
|
||||||
|
@ -101,38 +162,8 @@ func (c *RemoteClient) handleAuthRequest(ctx context.Context, m *control.Message
|
||||||
return res, nil
|
return res, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RemoteClient) ConfigureHooks(hooks interface{}) {
|
|
||||||
if hooks == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if onClientAuthHook, ok := hooks.(OnClientAuthHook); ok {
|
|
||||||
c.onClientAuthHook = onClientAuthHook
|
|
||||||
}
|
|
||||||
|
|
||||||
if OnClientConnectHook, ok := hooks.(OnClientConnectHook); ok {
|
|
||||||
c.onClientConnectHook = OnClientConnectHook
|
|
||||||
}
|
|
||||||
|
|
||||||
if OnClientDisconnectHook, ok := hooks.(OnClientDisconnectHook); ok {
|
|
||||||
c.onClientDisconnectHook = OnClientDisconnectHook
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *RemoteClient) RemoteAddr() net.Addr {
|
|
||||||
return c.remoteAddr
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *RemoteClient) Proxy(ctx context.Context, network, address string) (net.Conn, error) {
|
|
||||||
return c.control.Proxy(ctx, network, address)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *RemoteClient) Close() {
|
|
||||||
if c.sess != nil && !c.sess.IsClosed() {
|
|
||||||
c.sess.Close()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewRemoteClient() *RemoteClient {
|
func NewRemoteClient() *RemoteClient {
|
||||||
return &RemoteClient{}
|
return &RemoteClient{
|
||||||
|
proxies: cmap.New(),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -37,18 +37,20 @@ func (s *Server) handleNewConn(ctx context.Context, conn *kcp.UDPSession) {
|
||||||
ctx = logger.With(ctx, logger.F("remoteAddr", conn.RemoteAddr().String()))
|
ctx = logger.With(ctx, logger.F("remoteAddr", conn.RemoteAddr().String()))
|
||||||
|
|
||||||
remoteClient := NewRemoteClient()
|
remoteClient := NewRemoteClient()
|
||||||
|
|
||||||
defer remoteClient.Close()
|
defer remoteClient.Close()
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
remoteClient.ConfigureHooks(s.conf.Hooks)
|
remoteClient.ConfigureHooks(s.conf.Hooks)
|
||||||
|
|
||||||
if err := remoteClient.Accept(ctx, conn); err != nil {
|
if err := remoteClient.Accept(ctx, conn); err != nil {
|
||||||
logger.Error(ctx, "remote client error", logger.E(err))
|
logger.Error(ctx, "remote client error", logger.E(errors.WithStack(err)))
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := remoteClient.Listen(ctx); err != nil {
|
if err := remoteClient.Listen(ctx); err != nil {
|
||||||
logger.Error(ctx, "remote client error", logger.E(err))
|
logger.Error(ctx, "remote client error", logger.E(errors.WithStack(err)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue