package tunnel import ( "context" "io" "net" "sync" "time" "forge.cadoles.com/wpetit/go-tunnel/control" cmap "github.com/orcaman/concurrent-map" "github.com/pkg/errors" "github.com/xtaci/kcp-go/v5" "github.com/xtaci/smux" "gitlab.com/wpetit/goweb/logger" ) type RemoteClient struct { onClientAuthHook OnClientAuthHook onClientConnectHook OnClientConnectHook onClientDisconnectHook OnClientDisconnectHook conn net.Conn sess *smux.Session control *control.Control remoteAddr net.Addr proxies cmap.ConcurrentMap acceptStreamMutex sync.Mutex } func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error { config := smux.DefaultConfig() config.Version = 2 config.KeepAliveInterval = 10 * time.Second config.KeepAliveTimeout = 2 * config.KeepAliveInterval sess, err := smux.Server(conn, config) if err != nil { return errors.WithStack(err) } control := control.New() if err := control.Init(ctx, sess, true); err != nil { return errors.WithStack(err) } c.sess = sess c.remoteAddr = conn.RemoteAddr() c.control = control c.conn = conn if c.onClientConnectHook != nil { if err := c.onClientConnectHook.OnClientConnect(ctx, c); err != nil { return errors.WithStack(err) } } return nil } func (c *RemoteClient) Listen(ctx context.Context) error { logger.Debug(ctx, "listening for messages") err := c.control.Listen(ctx, control.Handlers{ control.TypeAuthRequest: c.handleAuthRequest, }) if errors.Is(err, io.ErrClosedPipe) { if c.onClientDisconnectHook != nil { if err := c.onClientDisconnectHook.OnClientDisconnect(ctx, c); err != nil { logger.Error(ctx, "client disconnect hook error", logger.E(errors.WithStack(err))) } } return errors.WithStack(ErrConnectionClosed) } return errors.WithStack(err) } func (c *RemoteClient) ConfigureHooks(hooks interface{}) { if hooks == nil { return } if onClientAuthHook, ok := hooks.(OnClientAuthHook); ok { c.onClientAuthHook = onClientAuthHook } if OnClientConnectHook, ok := hooks.(OnClientConnectHook); ok { c.onClientConnectHook = OnClientConnectHook } if OnClientDisconnectHook, ok := hooks.(OnClientDisconnectHook); ok { c.onClientDisconnectHook = OnClientDisconnectHook } } func (c *RemoteClient) RemoteAddr() net.Addr { return c.remoteAddr } func (c *RemoteClient) Close() { if c.conn != nil { c.conn.Close() } } func (c *RemoteClient) Proxy(ctx context.Context, network, address string) (net.Conn, error) { ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address)) if err := c.control.ProxyReq(ctx, network, address); err != nil { return nil, errors.WithStack(err) } logger.Debug(ctx, "opening proxy stream") c.acceptStreamMutex.Lock() stream, err := c.sess.AcceptStream() if err != nil { c.acceptStreamMutex.Unlock() return nil, errors.WithStack(err) } c.acceptStreamMutex.Unlock() go func() { <-ctx.Done() logger.Debug(ctx, "closing proxy stream") stream.Close() }() return stream, nil } func (c *RemoteClient) handleAuthRequest(ctx context.Context, m *control.Message) (*control.Message, error) { authReqPayload, ok := m.Payload.(*control.AuthRequestPayload) if !ok { return nil, errors.WithStack(ErrUnexpectedMessage) } logger.Debug(ctx, "handling auth request", logger.F("credentials", authReqPayload.Credentials)) var ( success bool err error ) if c.onClientAuthHook != nil { success, err = c.onClientAuthHook.OnClientAuth(ctx, c, authReqPayload.Credentials) if err != nil { return nil, errors.WithStack(err) } } logger.Debug(ctx, "auth succeeded", logger.F("credentials", authReqPayload.Credentials)) res := control.NewMessage(control.TypeAuthResponse, &control.AuthResponsePayload{ Success: success, }) return res, nil } func NewRemoteClient() *RemoteClient { return &RemoteClient{ proxies: cmap.New(), } }