go-tunnel/remote_client.go

175 lines
3.9 KiB
Go
Raw Normal View History

2020-10-21 18:00:15 +02:00
package tunnel
import (
"context"
"net"
2020-10-24 13:35:27 +02:00
"sync"
"time"
2020-10-21 18:00:15 +02:00
"forge.cadoles.com/wpetit/go-tunnel/control"
2020-10-24 13:35:27 +02:00
cmap "github.com/orcaman/concurrent-map"
2020-10-21 18:00:15 +02:00
"github.com/pkg/errors"
"github.com/xtaci/kcp-go/v5"
"github.com/xtaci/smux"
"gitlab.com/wpetit/goweb/logger"
)
type RemoteClient struct {
onClientAuthHook OnClientAuthHook
onClientConnectHook OnClientConnectHook
onClientDisconnectHook OnClientDisconnectHook
2020-10-24 13:35:27 +02:00
conn net.Conn
2020-10-21 18:00:15 +02:00
sess *smux.Session
control *control.Control
remoteAddr net.Addr
2020-10-24 13:35:27 +02:00
proxies cmap.ConcurrentMap
acceptStreamMutex sync.Mutex
2020-10-21 18:00:15 +02:00
}
func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error {
config := smux.DefaultConfig()
config.Version = 2
2020-10-24 13:35:27 +02:00
config.KeepAliveInterval = 10 * time.Second
config.KeepAliveTimeout = 2 * config.KeepAliveInterval
2020-10-21 18:00:15 +02:00
sess, err := smux.Server(conn, config)
if err != nil {
return errors.WithStack(err)
}
2020-10-24 13:35:27 +02:00
control := control.New()
2020-10-21 18:00:15 +02:00
2020-10-24 13:35:27 +02:00
if err := control.Init(ctx, sess, true); err != nil {
2020-10-21 18:00:15 +02:00
return errors.WithStack(err)
}
c.sess = sess
2020-10-24 13:35:27 +02:00
c.remoteAddr = conn.RemoteAddr()
c.control = control
c.conn = conn
if c.onClientConnectHook != nil {
if err := c.onClientConnectHook.OnClientConnect(ctx, c); err != nil {
return errors.WithStack(err)
}
}
2020-10-21 18:00:15 +02:00
return nil
}
func (c *RemoteClient) Listen(ctx context.Context) error {
2020-10-24 13:35:27 +02:00
defer func() {
if c.onClientDisconnectHook != nil {
if err := c.onClientDisconnectHook.OnClientDisconnect(ctx, c); err != nil {
logger.Error(ctx, "client disconnect hook error", logger.E(errors.WithStack(err)))
}
}
}()
2020-10-21 18:00:15 +02:00
logger.Debug(ctx, "listening for messages")
2020-10-24 13:35:27 +02:00
return c.control.Listen(ctx, control.Handlers{
2020-10-21 18:00:15 +02:00
control.TypeAuthRequest: c.handleAuthRequest,
})
2020-10-24 13:35:27 +02:00
}
2020-10-21 18:00:15 +02:00
2020-10-24 13:35:27 +02:00
func (c *RemoteClient) ConfigureHooks(hooks interface{}) {
if hooks == nil {
return
}
if onClientAuthHook, ok := hooks.(OnClientAuthHook); ok {
c.onClientAuthHook = onClientAuthHook
}
if OnClientConnectHook, ok := hooks.(OnClientConnectHook); ok {
c.onClientConnectHook = OnClientConnectHook
}
2020-10-21 18:00:15 +02:00
2020-10-24 13:35:27 +02:00
if OnClientDisconnectHook, ok := hooks.(OnClientDisconnectHook); ok {
c.onClientDisconnectHook = OnClientDisconnectHook
2020-10-21 18:00:15 +02:00
}
2020-10-24 13:35:27 +02:00
}
2020-10-21 18:00:15 +02:00
2020-10-24 13:35:27 +02:00
// RemoteAddr reports the peer address captured by Accept. It is nil for a
// client that has not been accepted yet; Close does not clear it.
func (c *RemoteClient) RemoteAddr() net.Addr {
	addr := c.remoteAddr
	return addr
}
// Close releases the client's transport.
//
// It closes the smux session and then the underlying connection (each only if
// set), ignoring their errors, and afterwards clears the session, connection
// and control-channel references. remoteAddr is kept.
func (c *RemoteClient) Close() {
	if sess := c.sess; sess != nil {
		sess.Close()
	}

	if conn := c.conn; conn != nil {
		conn.Close()
	}

	c.sess, c.conn, c.control = nil, nil, nil
}
func (c *RemoteClient) Proxy(ctx context.Context, network, address string) (net.Conn, error) {
ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address))
if err := c.control.ProxyReq(ctx, network, address); err != nil {
return nil, errors.WithStack(err)
}
logger.Debug(ctx, "opening proxy stream")
c.acceptStreamMutex.Lock()
stream, err := c.sess.AcceptStream()
if err != nil {
c.acceptStreamMutex.Unlock()
return nil, errors.WithStack(err)
}
c.acceptStreamMutex.Unlock()
go func() {
<-ctx.Done()
logger.Debug(ctx, "closing proxy stream")
stream.Close()
}()
return stream, nil
2020-10-21 18:00:15 +02:00
}
func (c *RemoteClient) handleAuthRequest(ctx context.Context, m *control.Message) (*control.Message, error) {
authReqPayload, ok := m.Payload.(*control.AuthRequestPayload)
if !ok {
return nil, errors.WithStack(ErrUnexpectedMessage)
}
logger.Debug(ctx, "handling auth request", logger.F("credentials", authReqPayload.Credentials))
var (
success bool
err error
)
if c.onClientAuthHook != nil {
success, err = c.onClientAuthHook.OnClientAuth(ctx, c, authReqPayload.Credentials)
if err != nil {
return nil, errors.WithStack(err)
}
}
2020-10-23 17:08:42 +02:00
logger.Debug(ctx, "auth succeeded", logger.F("credentials", authReqPayload.Credentials))
2020-10-21 18:00:15 +02:00
res := control.NewMessage(control.TypeAuthResponse, &control.AuthResponsePayload{
Success: success,
})
return res, nil
}
func NewRemoteClient() *RemoteClient {
2020-10-24 13:35:27 +02:00
return &RemoteClient{
proxies: cmap.New(),
}
2020-10-21 18:00:15 +02:00
}