go-tunnel/remote_client.go

170 lines
3.8 KiB
Go

package tunnel
import (
"context"
"net"
"sync"
"time"
"forge.cadoles.com/wpetit/go-tunnel/control"
cmap "github.com/orcaman/concurrent-map"
"github.com/pkg/errors"
"github.com/xtaci/kcp-go/v5"
"github.com/xtaci/smux"
"gitlab.com/wpetit/goweb/logger"
)
// RemoteClient holds the server-side state of a single tunnel client: the
// multiplexed session established over its connection, the control channel
// running on that session, and the optional lifecycle hooks.
type RemoteClient struct {
// Optional hooks. Each one is assigned by ConfigureHooks when the supplied
// value implements the matching interface, and is nil-checked before use.
onClientAuthHook OnClientAuthHook
onClientConnectHook OnClientConnectHook
onClientDisconnectHook OnClientDisconnectHook
// sess and control are populated by Accept and reset to nil by Close.
sess *smux.Session
control *control.Control
// remoteAddr is recorded by Accept from the accepted connection.
remoteAddr net.Addr
// proxies is initialized by NewRemoteClient.
// NOTE(review): no method visible in this file reads or writes it — confirm
// whether it is still needed.
proxies cmap.ConcurrentMap
// acceptStreamMutex is held by Proxy while it accepts a stream from sess.
acceptStreamMutex sync.Mutex
}
// Accept wraps the given KCP connection in a server-side smux session
// (protocol version 2, 10s keep-alive interval, 20s keep-alive timeout),
// initializes the control channel on top of it and then records the session,
// the remote address and the control channel on the client. When an
// OnClientConnect hook is configured it is invoked last.
//
// Every returned error is wrapped with a stack trace. If the control channel
// cannot be initialized, the freshly created session is closed before
// returning: at that point it has not been stored on the client, so Close
// would not be able to release it.
func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error {
	config := smux.DefaultConfig()
	config.Version = 2
	config.KeepAliveInterval = 10 * time.Second
	config.KeepAliveTimeout = 2 * config.KeepAliveInterval

	logger.Debug(ctx, "creating server session")

	sess, err := smux.Server(conn, config)
	if err != nil {
		return errors.WithStack(err)
	}

	ctrl := control.New()
	if err := ctrl.Init(ctx, sess, true); err != nil {
		// The session is not owned by c yet; release it here so it does not
		// outlive the failed handshake.
		if closeErr := sess.Close(); closeErr != nil {
			logger.Error(ctx, "could not close session", logger.E(errors.WithStack(closeErr)))
		}

		return errors.WithStack(err)
	}

	c.sess = sess
	c.remoteAddr = conn.RemoteAddr()
	c.control = ctrl

	if c.onClientConnectHook != nil {
		// From here on the session is reachable through c, so the caller can
		// release it with Close if the hook rejects the client.
		if err := c.onClientConnectHook.OnClientConnect(ctx, c); err != nil {
			return errors.WithStack(err)
		}
	}

	return nil
}
// Listen processes control messages for this client until the control
// channel's Listen call returns, and hands back whatever error that call
// produced. Only auth requests are dispatched, to handleAuthRequest.
//
// When Listen returns, the OnClientDisconnect hook (if configured) is invoked;
// a failure of the hook is logged and does not alter the returned error.
func (c *RemoteClient) Listen(ctx context.Context) error {
	defer func() {
		hook := c.onClientDisconnectHook
		if hook == nil {
			return
		}

		if err := hook.OnClientDisconnect(ctx, c); err != nil {
			logger.Error(ctx, "client disconnect hook error", logger.E(errors.WithStack(err)))
		}
	}()

	logger.Debug(ctx, "listening for messages")

	handlers := control.Handlers{
		control.TypeAuthRequest: c.handleAuthRequest,
	}

	return c.control.Listen(ctx, handlers)
}
// ConfigureHooks inspects the given value and registers it for every hook
// interface it implements (OnClientAuthHook, OnClientConnectHook,
// OnClientDisconnectHook). Hooks whose interface is not implemented keep their
// current value. Passing nil changes nothing.
func (c *RemoteClient) ConfigureHooks(hooks interface{}) {
	// A type assertion on a nil interface value never succeeds, so a nil
	// argument falls through all three checks without touching the client.
	if auth, ok := hooks.(OnClientAuthHook); ok {
		c.onClientAuthHook = auth
	}

	if connect, ok := hooks.(OnClientConnectHook); ok {
		c.onClientConnectHook = connect
	}

	if disconnect, ok := hooks.(OnClientDisconnectHook); ok {
		c.onClientDisconnectHook = disconnect
	}
}
// RemoteAddr returns the remote address that Accept recorded from the accepted
// connection. Close does not reset it.
func (c *RemoteClient) RemoteAddr() net.Addr {
return c.remoteAddr
}
// Close shuts down the client's session, if there is one, and drops the
// references to both the session and the control channel. Any error reported
// while closing the session is discarded. Calling Close on a client without a
// session only resets the fields.
func (c *RemoteClient) Close() {
	if sess := c.sess; sess != nil {
		_ = sess.Close()
	}

	c.sess, c.control = nil, nil
}
// Proxy asks the remote client, through the control channel, to open a proxy
// towards the given network/address and returns the stream the client opens
// in response as a net.Conn.
//
// The request and the subsequent stream accept are performed under
// acceptStreamMutex as one unit. Accepting alone under the lock would let two
// concurrent callers each pick up the stream opened for the other's request.
//
// The returned stream is closed automatically when ctx is done. Errors from
// the proxy request or from accepting the stream are wrapped with a stack
// trace.
func (c *RemoteClient) Proxy(ctx context.Context, network, address string) (net.Conn, error) {
	ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address))

	c.acceptStreamMutex.Lock()
	defer c.acceptStreamMutex.Unlock()

	if err := c.control.ProxyReq(ctx, network, address); err != nil {
		return nil, errors.WithStack(err)
	}

	logger.Debug(ctx, "opening proxy stream")

	stream, err := c.sess.AcceptStream()
	if err != nil {
		return nil, errors.WithStack(err)
	}

	go func() {
		select {
		case <-ctx.Done():
			logger.Debug(ctx, "closing proxy stream")
			stream.Close()
		case <-stream.GetDieCh():
			// The stream was closed by one of its users; stop watching so
			// this goroutine does not linger until ctx is done.
		}
	}()

	return stream, nil
}
// handleAuthRequest is the control-channel handler for auth requests. It
// passes the credentials carried by the message to the OnClientAuth hook and
// answers with an auth response whose Success field is the hook's verdict.
//
// With no auth hook configured the response reports Success == false. A
// message whose payload is not an *control.AuthRequestPayload yields
// ErrUnexpectedMessage; an error returned by the hook is propagated. Both are
// wrapped with a stack trace.
func (c *RemoteClient) handleAuthRequest(ctx context.Context, m *control.Message) (*control.Message, error) {
	authReqPayload, ok := m.Payload.(*control.AuthRequestPayload)
	if !ok {
		return nil, errors.WithStack(ErrUnexpectedMessage)
	}

	// Credentials are deliberately kept out of the log fields: they are
	// secrets and must not end up in log output.
	logger.Debug(ctx, "handling auth request")

	var success bool

	if c.onClientAuthHook != nil {
		var err error

		success, err = c.onClientAuthHook.OnClientAuth(ctx, c, authReqPayload.Credentials)
		if err != nil {
			return nil, errors.WithStack(err)
		}
	}

	// Report the actual outcome instead of always claiming success.
	if success {
		logger.Debug(ctx, "auth succeeded")
	} else {
		logger.Debug(ctx, "auth failed")
	}

	res := control.NewMessage(control.TypeAuthResponse, &control.AuthResponsePayload{
		Success: success,
	})

	return res, nil
}
// NewRemoteClient returns a RemoteClient whose proxies map is initialized.
// The session, control channel and remote address are left unset until Accept
// is called, and no hooks are configured.
func NewRemoteClient() *RemoteClient {
	client := new(RemoteClient)
	client.proxies = cmap.New()

	return client
}