package tunnel

import (
	"context"
	"encoding/json"
	"io"
	"net"
	"sync"
	"time"

	"github.com/pkg/errors"
	"github.com/xtaci/kcp-go/v5"
	"github.com/xtaci/smux"
	"gitlab.com/wpetit/goweb/logger"
)

// RemoteClient is the server-side handle on one tunnel client. It holds the
// client's KCP connection, the smux session multiplexed over it, and opens
// proxy streams towards the client on demand.
type RemoteClient struct {
	onClientAuthHook       OnClientAuthHook
	onClientConnectHook    OnClientConnectHook
	onClientDisconnectHook OnClientDisconnectHook

	conn       *kcp.UDPSession
	sess       *smux.Session
	remoteAddr net.Addr

	authenticationTimeout time.Duration
	proxyRequestTimeout   time.Duration

	// connMutex guards conn and sess. The connect and auth hooks are invoked
	// while it is held for writing, so they must not call Close, Accept,
	// SwitchConn or Proxy on the same client.
	connMutex sync.RWMutex

	smuxConfig *smux.Config
}

// Accept replaces the client's current connection (if any) with conn.
//
// It closes the previous session, establishes a server-side smux session over
// conn, runs the connect hook, then waits for the client to open a stream and
// authenticate on it. The new connection is only retained once authentication
// succeeds; on any failure after the session is created, that session (and,
// through smux, conn) is closed so its background goroutines do not leak.
func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error {
	c.connMutex.Lock()
	defer c.connMutex.Unlock()

	if err := c.closeConn(); err != nil {
		return errors.WithStack(err)
	}

	sess, err := c.acceptSession(ctx, conn)
	if err != nil {
		return errors.WithStack(err)
	}

	if err := c.authenticateSession(ctx, sess); err != nil {
		// Best effort: the authentication error is the one worth reporting.
		_ = sess.Close()

		return errors.WithStack(err)
	}

	c.sess = sess
	c.conn = conn

	return nil
}

// ConfigureHooks registers every hook interface implemented by hooks. hooks
// may implement any subset of OnClientAuthHook, OnClientConnectHook and
// OnClientDisconnectHook; hooks it does not implement are left unchanged.
// A nil hooks is a no-op.
func (c *RemoteClient) ConfigureHooks(hooks interface{}) {
	if hooks == nil {
		return
	}

	if h, ok := hooks.(OnClientAuthHook); ok {
		c.onClientAuthHook = h
	}

	if h, ok := hooks.(OnClientConnectHook); ok {
		c.onClientConnectHook = h
	}

	if h, ok := hooks.(OnClientDisconnectHook); ok {
		c.onClientDisconnectHook = h
	}
}

// RemoteAddr returns the remote address of the most recently accepted
// connection, or nil if none was ever accepted.
//
// It deliberately does not take connMutex: hooks receive this client while
// the lock is held for writing, and locking here would deadlock them.
func (c *RemoteClient) RemoteAddr() net.Addr {
	return c.remoteAddr
}

// Close closes the client's current session and connection, if any, and
// forgets them. Closing a client that is already closed, or whose session
// was already torn down by smux, is not an error.
func (c *RemoteClient) Close() error {
	c.connMutex.Lock()
	defer c.connMutex.Unlock()

	return c.closeConn()
}

// closeConn is the lock-free core of Close; callers must hold connMutex for
// writing. sess and conn are always cleared, even when closing fails, so a
// broken connection can never wedge later Accept/SwitchConn calls.
func (c *RemoteClient) closeConn() error {
	var firstErr error

	if c.sess != nil {
		// smux's Session.Close also closes the underlying conn, and returns
		// io.ErrClosedPipe if the session was already closed.
		if err := c.sess.Close(); err != nil && errors.Cause(err) != io.ErrClosedPipe {
			firstErr = err
		}
	}

	if c.conn != nil {
		// Usually already closed by the session above: ignore ErrClosedPipe.
		if err := c.conn.Close(); err != nil && errors.Cause(err) != io.ErrClosedPipe && firstErr == nil {
			firstErr = err
		}
	}

	c.sess = nil
	c.conn = nil

	if firstErr != nil {
		return errors.WithStack(firstErr)
	}

	return nil
}

// SwitchConn replaces the client's current connection with conn, closing the
// previous session first. Unlike Accept, it does not run the authentication
// handshake: the new session is retained as soon as it is established and the
// connect hook has succeeded.
func (c *RemoteClient) SwitchConn(ctx context.Context, conn *kcp.UDPSession) error {
	c.connMutex.Lock()
	defer c.connMutex.Unlock()

	if err := c.closeConn(); err != nil {
		return errors.WithStack(err)
	}

	sess, err := c.acceptSession(ctx, conn)
	if err != nil {
		return errors.WithStack(err)
	}

	c.sess = sess
	c.conn = conn

	return nil
}

// Proxy opens a new stream on the client's session and sends it a proxy
// request for the given network/address pair. The returned stream is the
// proxied connection; the caller owns it and must close it.
//
// It returns an error if the client has no active session.
func (c *RemoteClient) Proxy(ctx context.Context, network, address string) (net.Conn, error) {
	c.connMutex.RLock()
	defer c.connMutex.RUnlock()

	if c.sess == nil {
		return nil, errors.New("remote client has no active session")
	}

	ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address))

	logger.Debug(ctx, "opening proxy stream")

	stream, err := c.sess.OpenStream()
	if err != nil {
		return nil, errors.WithStack(err)
	}

	proxyReq := &proxyRequest{
		Network: network,
		Address: address,
	}

	if err := c.writeProxyRequest(ctx, stream, proxyReq); err != nil {
		stream.Close()

		return nil, errors.WithStack(err)
	}

	return stream, nil
}

// writeProxyRequest JSON-encodes req onto stream, bounding the write by
// proxyRequestTimeout, then clears the write deadline so the caller can use
// the stream without it.
func (c *RemoteClient) writeProxyRequest(ctx context.Context, stream *smux.Stream, req *proxyRequest) error {
	writeDeadline := time.Now().Add(c.proxyRequestTimeout)

	logger.Debug(ctx, "sending proxy req", logger.F("deadline", writeDeadline))

	if err := stream.SetWriteDeadline(writeDeadline); err != nil {
		return errors.WithStack(err)
	}

	if err := json.NewEncoder(stream).Encode(req); err != nil {
		return errors.WithStack(err)
	}

	if err := stream.SetWriteDeadline(time.Time{}); err != nil {
		return errors.WithStack(err)
	}

	return nil
}

// acceptSession wraps conn in a server-side smux session, records conn's
// remote address and runs the connect hook (if any). If the hook fails, the
// new session is closed before the error is returned.
// Callers must hold connMutex for writing.
func (c *RemoteClient) acceptSession(ctx context.Context, conn *kcp.UDPSession) (*smux.Session, error) {
	logger.Debug(ctx, "accepting client session")

	sess, err := smux.Server(conn, c.smuxConfig)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	c.remoteAddr = conn.RemoteAddr()

	if c.onClientConnectHook != nil {
		if err := c.onClientConnectHook.OnClientConnect(ctx, c); err != nil {
			// Best effort: the hook's error is the one worth reporting.
			_ = sess.Close()

			return nil, errors.WithStack(err)
		}
	}

	return sess, nil
}

// authenticateSession waits, for at most authenticationTimeout, for the
// client to open its first stream on sess. It then runs the authentication
// handshake on that stream. The stream is always closed before returning.
func (c *RemoteClient) authenticateSession(ctx context.Context, sess *smux.Session) error {
	// Without a deadline AcceptStream blocks indefinitely (while Accept holds
	// connMutex) on a client that connects but never opens a stream.
	if err := sess.SetDeadline(time.Now().Add(c.authenticationTimeout)); err != nil {
		return errors.WithStack(err)
	}

	stream, err := sess.AcceptStream()
	if err != nil {
		return errors.WithStack(err)
	}

	defer stream.Close()

	// Clear the accept deadline so it does not affect later use of the session.
	if err := sess.SetDeadline(time.Time{}); err != nil {
		return errors.WithStack(err)
	}

	return c.authenticate(ctx, stream)
}

// authenticate reads a JSON authRequest from stream, checks its credentials
// with the auth hook and writes back a JSON authResponse. The whole exchange
// is bounded by authenticationTimeout.
//
// When no auth hook is configured, every client is rejected. It returns
// ErrAuthenticationFailed (wrapped) when the hook rejects the credentials.
func (c *RemoteClient) authenticate(ctx context.Context, stream *smux.Stream) error {
	// One deadline covers reading the request, running the hook and writing
	// the response.
	deadline := time.Now().Add(c.authenticationTimeout)

	logger.Debug(ctx, "waiting for auth request", logger.F("deadline", deadline))

	if err := stream.SetReadDeadline(deadline); err != nil {
		return errors.WithStack(err)
	}

	authReq := &authRequest{}
	if err := json.NewDecoder(stream).Decode(authReq); err != nil {
		return errors.WithStack(err)
	}

	// Credentials are secrets: do not write them to the logs.
	logger.Debug(ctx, "received client credentials")

	var success bool

	if c.onClientAuthHook != nil {
		ok, err := c.onClientAuthHook.OnClientAuth(ctx, c, authReq.Credentials)
		if err != nil {
			return errors.WithStack(err)
		}

		success = ok
	}

	authRes := &authResponse{
		Success: success,
	}

	logger.Debug(ctx, "sending auth response", logger.F("deadline", deadline))

	if err := stream.SetWriteDeadline(deadline); err != nil {
		return errors.WithStack(err)
	}

	if err := json.NewEncoder(stream).Encode(authRes); err != nil {
		return errors.WithStack(err)
	}

	if !success {
		return errors.WithStack(ErrAuthenticationFailed)
	}

	return nil
}

// NewRemoteClient returns a client with no connection yet. smuxConfig is used
// for every session it accepts. authenticationTimeout bounds the
// authentication handshake, and proxyRequestTimeout bounds sending each proxy
// request.
func NewRemoteClient(smuxConfig *smux.Config, authenticationTimeout, proxyRequestTimeout time.Duration) *RemoteClient {
	return &RemoteClient{
		smuxConfig:            smuxConfig,
		authenticationTimeout: authenticationTimeout,
		proxyRequestTimeout:   proxyRequestTimeout,
	}
}