package tunnel import ( "context" "io" "net" "net/http" "sync" "time" "gitlab.com/wpetit/goweb/logger" "forge.cadoles.com/wpetit/go-tunnel/control" "github.com/pkg/errors" "github.com/xtaci/kcp-go/v5" "github.com/xtaci/smux" ) type Client struct { conf *ClientConfig conn *kcp.UDPSession sess *smux.Session control *control.Control http *http.Client openStreamMutex sync.Mutex } func (c *Client) Connect(ctx context.Context) error { conn, err := kcp.DialWithOptions( c.conf.ServerAddress, c.conf.BlockCrypt, c.conf.DataShards, c.conf.ParityShards, ) if err != nil { return errors.WithStack(err) } if c.conf.ConfigureConn != nil { if err := c.conf.ConfigureConn(conn); err != nil { return errors.WithStack(err) } } config := smux.DefaultConfig() config.Version = 2 config.KeepAliveInterval = 10 * time.Second config.KeepAliveTimeout = 2 * config.KeepAliveInterval sess, err := smux.Client(conn, config) if err != nil { return errors.WithStack(err) } control := control.New() if err := control.Init(ctx, sess, false); err != nil { return errors.WithStack(err) } logger.Debug(ctx, "sending auth request") success, err := control.AuthRequest(c.conf.Credentials) if err != nil { return errors.WithStack(err) } if !success { defer c.Close() return errors.WithStack(ErrAuthFailed) } c.control = control c.conn = conn c.sess = sess return nil } func (c *Client) Listen(ctx context.Context) error { logger.Debug(ctx, "listening for messages") ctx, cancel := context.WithCancel(ctx) defer cancel() err := c.control.Listen(ctx, control.Handlers{ control.TypeProxyRequest: c.handleProxyRequest, }) if errors.Is(err, io.ErrClosedPipe) { logger.Debug(ctx, "client connection closed") return errors.WithStack(ErrConnectionClosed) } return err } func (c *Client) Close() error { if c.conn == nil { return errors.WithStack(ErrNotConnected) } if err := c.conn.Close(); err != nil { return errors.WithStack(err) } return nil } func (c *Client) handleProxyRequest(ctx context.Context, m *control.Message) (*control.Message, error) { proxyReqPayload, ok := m.Payload.(*control.ProxyRequestPayload) if !ok { return nil, errors.WithStack(ErrUnexpectedMessage) } ctx = logger.With(ctx, logger.F("network", proxyReqPayload.Network), logger.F("address", proxyReqPayload.Address), ) logger.Debug(ctx, "handling proxy request") out, err := net.Dial(proxyReqPayload.Network, proxyReqPayload.Address) if err != nil { return nil, errors.WithStack(err) } go c.handleProxyStream(ctx, out) return nil, nil } func (c *Client) handleProxyStream(ctx context.Context, out net.Conn) { c.openStreamMutex.Lock() in, err := c.sess.OpenStream() if err != nil { c.openStreamMutex.Unlock() logger.Error(ctx, "error while accepting proxy stream", logger.E(err)) return } c.openStreamMutex.Unlock() streamCopy := func(dst io.Writer, src io.ReadCloser) { if _, err := Copy(dst, src); err != nil { if errors.Is(err, smux.ErrInvalidProtocol) { logger.Error(ctx, "error while proxying", logger.E(errors.WithStack(err))) } } logger.Debug(ctx, "closing proxy stream") in.Close() out.Close() } go streamCopy(in, out) streamCopy(out, in) } func NewClient(funcs ...ClientConfigFunc) *Client { conf := DefaultClientConfig() for _, fn := range funcs { fn(conf) } return &Client{ conf: conf, http: &http.Client{}, } }