// Source: go-tunnel/remote_client.go (242 lines, 5.2 KiB, Go).

package tunnel
import (
	"context"
	"encoding/json"
	"io"
	"net"
	"sync"
	"time"

	"github.com/pkg/errors"
	"github.com/xtaci/kcp-go/v5"
	"github.com/xtaci/smux"
	"gitlab.com/wpetit/goweb/logger"
)
// RemoteClient is the server-side handle for one remote tunnel client. It
// holds the client's current KCP connection and the smux session created on
// top of it, runs the authentication handshake in Accept, and opens proxy
// streams to the client in Proxy.
type RemoteClient struct {
// onClientAuthHook decides whether the credentials received during the
// authentication handshake are accepted. When nil, every client is rejected.
onClientAuthHook OnClientAuthHook
// onClientConnectHook is invoked by acceptSession after the smux session is
// created; a non-nil error from it aborts the connection attempt.
onClientConnectHook OnClientConnectHook
// onClientDisconnectHook is registered through ConfigureHooks.
// NOTE(review): it is not invoked anywhere in this part of the file — confirm
// where (or whether) it is called.
onClientDisconnectHook OnClientDisconnectHook
// conn is the currently attached KCP connection (nil when none is attached).
conn *kcp.UDPSession
// sess is the server-side smux session wrapping conn (nil when none is attached).
sess *smux.Session
// remoteAddr is the peer address recorded from the most recently accepted conn.
remoteAddr net.Addr
// authenticationTimeout bounds the authentication handshake run by Accept.
authenticationTimeout time.Duration
// proxyRequestTimeout is the write deadline applied to each proxy request sent by Proxy.
proxyRequestTimeout time.Duration
// connMutex guards sess and conn: Accept and SwitchConn take the write lock,
// Proxy takes the read lock.
connMutex sync.RWMutex
// smuxConfig is passed to smux.Server for every accepted connection.
smuxConfig *smux.Config
}
// Accept attaches a newly connected client to c.
//
// Any previously attached session is closed first. conn is then wrapped in a
// server-side smux session. The first stream opened by the peer carries the
// authentication handshake (see authenticate). The new session and connection
// are stored on c only if authentication succeeds. The handshake stream is
// closed before Accept returns.
//
// If accepting the handshake stream or authenticating fails, the new session
// is closed so it does not linger. Closing an smux session also closes the
// connection it wraps, so conn is closed too in that case.
//
// Accept holds the write side of connMutex for its whole duration, so
// concurrent Proxy calls wait until it returns.
//
// NOTE(review): sess.AcceptStream does not observe ctx, so a peer that never
// opens a stream keeps Accept (and the lock) waiting.
func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error {
	c.connMutex.Lock()
	defer c.connMutex.Unlock()

	if err := c.Close(); err != nil {
		return errors.WithStack(err)
	}

	sess, err := c.acceptSession(ctx, conn)
	if err != nil {
		return errors.WithStack(err)
	}

	stream, err := sess.AcceptStream()
	if err != nil {
		// Best-effort cleanup: the AcceptStream error is the one worth reporting.
		_ = sess.Close()
		return errors.WithStack(err)
	}
	defer stream.Close()

	if err := c.authenticate(ctx, stream); err != nil {
		// The session was never attached to c, so nobody else would close it.
		_ = sess.Close()
		return errors.WithStack(err)
	}

	c.sess = sess
	c.conn = conn

	return nil
}
// ConfigureHooks registers whichever client lifecycle hooks the given value
// implements.
//
// hooks may implement any combination of OnClientAuthHook,
// OnClientConnectHook and OnClientDisconnectHook. For each interface it
// satisfies, the corresponding hook is replaced. Hooks for interfaces it does
// not satisfy are left unchanged. A nil hooks value is ignored.
func (c *RemoteClient) ConfigureHooks(hooks interface{}) {
	if hooks == nil {
		return
	}

	// Independent assertions rather than a type switch: a single value may
	// implement several hook interfaces at once.
	if authHook, ok := hooks.(OnClientAuthHook); ok {
		c.onClientAuthHook = authHook
	}

	if connectHook, ok := hooks.(OnClientConnectHook); ok {
		c.onClientConnectHook = connectHook
	}

	if disconnectHook, ok := hooks.(OnClientDisconnectHook); ok {
		c.onClientDisconnectHook = disconnectHook
	}
}
// RemoteAddr reports the peer address recorded when the most recent
// connection was accepted by Accept or SwitchConn. It is nil until a
// connection has been accepted.
//
// NOTE(review): remoteAddr is read without connMutex. Taking the lock here is
// not safe as-is: the connect and auth hooks receive c and may call
// RemoteAddr while Accept/SwitchConn hold the write lock.
func (c *RemoteClient) RemoteAddr() net.Addr {
	addr := c.remoteAddr
	return addr
}
// Close tears down the currently attached smux session and KCP connection, if
// any, and resets c so a new connection can be attached.
//
// Closing the smux session already closes the KCP connection it wraps. Either
// resource may also have been closed earlier, e.g. by the smux keepalive. An
// io.ErrClosedPipe from an already-closed resource is therefore treated as
// success rather than a failure. Both resources are always closed and the
// references are always cleared, even when an unexpected error is returned,
// so a failed Close never leaves stale handles that would make every later
// Accept/SwitchConn fail. The first unexpected error is returned.
//
// Close does not take connMutex itself: Accept and SwitchConn call it while
// already holding the write lock.
// NOTE(review): external callers racing with Proxy are not synchronised.
func (c *RemoteClient) Close() error {
	sess, conn := c.sess, c.conn
	c.sess, c.conn = nil, nil

	var firstErr error

	if sess != nil {
		if err := sess.Close(); err != nil && errors.Cause(err) != io.ErrClosedPipe {
			firstErr = errors.WithStack(err)
		}
	}

	if conn != nil {
		// Usually already closed by sess.Close above, which kcp reports as
		// io.ErrClosedPipe.
		if err := conn.Close(); err != nil && errors.Cause(err) != io.ErrClosedPipe && firstErr == nil {
			firstErr = errors.WithStack(err)
		}
	}

	return firstErr
}
// SwitchConn replaces the currently attached connection with conn.
//
// The existing session is closed first. conn is then wrapped in a new
// server-side smux session, which is attached to c. Unlike Accept, no
// authentication handshake is run on the new connection.
// NOTE(review): presumably meant for an already-authenticated client that
// reconnects — confirm with callers.
//
// The write side of connMutex is held for the whole operation.
func (c *RemoteClient) SwitchConn(ctx context.Context, conn *kcp.UDPSession) error {
	c.connMutex.Lock()
	defer c.connMutex.Unlock()

	if closeErr := c.Close(); closeErr != nil {
		return errors.WithStack(closeErr)
	}

	newSess, acceptErr := c.acceptSession(ctx, conn)
	if acceptErr != nil {
		return errors.WithStack(acceptErr)
	}

	c.sess, c.conn = newSess, conn

	return nil
}
// Proxy opens a new smux stream to the remote client and writes a
// JSON-encoded proxyRequest naming network and address onto it (presumably
// for the remote side to dial that address on the server's behalf).
//
// The request write is bounded by c.proxyRequestTimeout. After a successful
// write the deadline is cleared and the stream is returned as a net.Conn. The
// caller owns the returned stream and must close it. On any failure the
// stream is closed before returning.
//
// An error is returned, instead of panicking on a nil session, when no
// session is attached (before Accept succeeds or after Close). The read side
// of connMutex is held, so Proxy never overlaps Accept or SwitchConn.
func (c *RemoteClient) Proxy(ctx context.Context, network, address string) (net.Conn, error) {
	c.connMutex.RLock()
	defer c.connMutex.RUnlock()

	ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address))

	if c.sess == nil {
		return nil, errors.New("remote client has no active session")
	}

	logger.Debug(ctx, "opening proxy stream")

	stream, err := c.sess.OpenStream()
	if err != nil {
		return nil, errors.WithStack(err)
	}

	// Close the stream on every failure path below; ownership passes to the
	// caller only once the request has been fully written.
	handedOff := false
	defer func() {
		if !handedOff {
			_ = stream.Close()
		}
	}()

	proxyReq := &proxyRequest{
		Network: network,
		Address: address,
	}

	writeDeadline := time.Now().Add(c.proxyRequestTimeout)

	logger.Debug(ctx, "sending proxy req", logger.F("deadline", writeDeadline))

	if err := stream.SetWriteDeadline(writeDeadline); err != nil {
		return nil, errors.WithStack(err)
	}

	if err := json.NewEncoder(stream).Encode(proxyReq); err != nil {
		return nil, errors.WithStack(err)
	}

	// The deadline only applies to the request itself, not to the proxied traffic.
	if err := stream.SetWriteDeadline(time.Time{}); err != nil {
		return nil, errors.WithStack(err)
	}

	handedOff = true

	return stream, nil
}
// acceptSession wraps conn in a server-side smux session built from
// c.smuxConfig, records the peer address and, if registered, runs the
// OnClientConnect hook.
//
// It does not attach the session to c; callers do that once their own checks
// pass. If the hook rejects the client, the freshly created session is closed
// before returning so it does not linger. Closing an smux session also closes
// conn.
func (c *RemoteClient) acceptSession(ctx context.Context, conn *kcp.UDPSession) (*smux.Session, error) {
	logger.Debug(ctx, "accepting client session")

	sess, err := smux.Server(conn, c.smuxConfig)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	// Recorded before the hook runs so the hook can read it through RemoteAddr.
	c.remoteAddr = conn.RemoteAddr()

	if c.onClientConnectHook != nil {
		if err := c.onClientConnectHook.OnClientConnect(ctx, c); err != nil {
			// Best-effort cleanup: the hook's error is the one worth reporting.
			_ = sess.Close()
			return nil, errors.WithStack(err)
		}
	}

	return sess, nil
}
// authenticate runs the server side of the authentication handshake on
// stream. It decodes one JSON authRequest, asks the OnClientAuth hook whether
// the credentials are valid, and encodes an authResponse with the verdict.
// The whole exchange (read and write) must complete within
// c.authenticationTimeout.
//
// When no OnClientAuthHook is registered, every client is rejected. An error
// from the hook is returned without sending a response. A rejected client
// receives a response with Success set to false, and ErrAuthenticationFailed
// is returned.
func (c *RemoteClient) authenticate(ctx context.Context, stream *smux.Stream) error {
	// One deadline shared by the read and the write bounds the total handshake.
	deadline := time.Now().Add(c.authenticationTimeout)

	logger.Debug(ctx, "waiting for auth request", logger.F("deadline", deadline))

	if err := stream.SetReadDeadline(deadline); err != nil {
		return errors.WithStack(err)
	}

	authReq := &authRequest{}
	if err := json.NewDecoder(stream).Decode(authReq); err != nil {
		return errors.WithStack(err)
	}

	// The credentials themselves are deliberately not logged: they are
	// secrets, and debug logs are often retained or shipped elsewhere.
	logger.Debug(ctx, "received client credentials")

	success := false
	if c.onClientAuthHook != nil {
		var err error
		success, err = c.onClientAuthHook.OnClientAuth(ctx, c, authReq.Credentials)
		if err != nil {
			return errors.WithStack(err)
		}
	}

	logger.Debug(ctx, "sending auth response", logger.F("deadline", deadline))

	if err := stream.SetWriteDeadline(deadline); err != nil {
		return errors.WithStack(err)
	}

	if err := json.NewEncoder(stream).Encode(&authResponse{Success: success}); err != nil {
		return errors.WithStack(err)
	}

	if !success {
		return errors.WithStack(ErrAuthenticationFailed)
	}

	return nil
}
// NewRemoteClient returns a RemoteClient with no connection attached yet.
//
// smuxConfig is passed to smux.Server for every session the client accepts.
// authenticationTimeout bounds the handshake run by Accept. proxyRequestTimeout
// bounds the write of each proxy request sent by Proxy. Hooks can be
// registered afterwards with ConfigureHooks.
func NewRemoteClient(smuxConfig *smux.Config, authenticationTimeout, proxyRequestTimeout time.Duration) *RemoteClient {
	client := new(RemoteClient)
	client.smuxConfig = smuxConfig
	client.authenticationTimeout = authenticationTimeout
	client.proxyRequestTimeout = proxyRequestTimeout

	return client
}