go-tunnel/remote_client.go

242 lines
5.2 KiB
Go
Raw Permalink Normal View History

2020-10-21 18:00:15 +02:00
package tunnel
import (
	"context"
	"encoding/json"
	"io"
	"net"
	"sync"
	"time"

	"github.com/pkg/errors"
	"github.com/xtaci/kcp-go/v5"
	"github.com/xtaci/smux"
	"gitlab.com/wpetit/goweb/logger"
)
type RemoteClient struct {
onClientAuthHook OnClientAuthHook
onClientConnectHook OnClientConnectHook
onClientDisconnectHook OnClientDisconnectHook
2020-10-26 19:42:07 +01:00
conn *kcp.UDPSession
2020-10-21 18:00:15 +02:00
sess *smux.Session
remoteAddr net.Addr
2020-10-26 19:42:07 +01:00
authenticationTimeout time.Duration
proxyRequestTimeout time.Duration
connMutex sync.RWMutex
smuxConfig *smux.Config
2020-10-21 18:00:15 +02:00
}
// Accept makes conn the client's active connection. It closes any previous
// session, starts a smux server session over conn, waits for the client to
// open its authentication stream and runs the authentication handshake on
// it. The new session is installed only when authentication succeeds. On any
// failure after the session was created, the session is closed again (smux
// closes conn with it) so that its background goroutines do not leak.
//
// When authenticationTimeout is positive, it also bounds the wait for the
// authentication stream. Without that bound, a silent client could hold
// connMutex, and with it every Proxy call, indefinitely.
func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error {
	c.connMutex.Lock()
	defer c.connMutex.Unlock()

	if err := c.Close(); err != nil {
		return errors.WithStack(err)
	}

	sess, err := c.acceptSession(ctx, conn)
	if err != nil {
		return errors.WithStack(err)
	}

	installed := false
	defer func() {
		if !installed {
			// Best effort: the handshake error is what the caller needs.
			_ = sess.Close()
		}
	}()

	if c.authenticationTimeout > 0 {
		if err := sess.SetDeadline(time.Now().Add(c.authenticationTimeout)); err != nil {
			return errors.WithStack(err)
		}
	}

	stream, err := sess.AcceptStream()
	if err != nil {
		return errors.WithStack(err)
	}
	defer stream.Close()

	// The Accept* deadline only concerns the handshake; clear it.
	if err := sess.SetDeadline(time.Time{}); err != nil {
		return errors.WithStack(err)
	}

	if err := c.authenticate(ctx, stream); err != nil {
		return errors.WithStack(err)
	}

	c.sess = sess
	c.conn = conn
	installed = true

	return nil
}
// ConfigureHooks installs whichever client hooks the given value implements.
// The value may implement any combination of OnClientAuthHook,
// OnClientConnectHook and OnClientDisconnectHook. Hooks it does not implement
// keep their current setting, and a nil value changes nothing.
func (c *RemoteClient) ConfigureHooks(hooks interface{}) {
	if hooks == nil {
		return
	}

	authHook, isAuthHook := hooks.(OnClientAuthHook)
	if isAuthHook {
		c.onClientAuthHook = authHook
	}

	connectHook, isConnectHook := hooks.(OnClientConnectHook)
	if isConnectHook {
		c.onClientConnectHook = connectHook
	}

	disconnectHook, isDisconnectHook := hooks.(OnClientDisconnectHook)
	if isDisconnectHook {
		c.onClientDisconnectHook = disconnectHook
	}
}
// RemoteAddr returns the remote address recorded for the connection most
// recently handed to Accept or SwitchConn, or nil before the first one.
//
// It deliberately reads the field without taking connMutex. Hooks invoked by
// Accept/SwitchConn while that mutex is held receive this client and may
// call RemoteAddr.
func (c *RemoteClient) RemoteAddr() net.Addr {
	addr := c.remoteAddr
	return addr
}
func (c *RemoteClient) Close() error {
2020-10-24 13:35:27 +02:00
if c.sess != nil {
2020-10-26 19:42:07 +01:00
if err := c.sess.Close(); err != nil {
return errors.WithStack(err)
}
}
if c.conn != nil {
if err := c.conn.Close(); err != nil {
return errors.WithStack(err)
}
2020-10-24 13:35:27 +02:00
}
c.sess = nil
2020-10-26 19:42:07 +01:00
c.conn = nil
return nil
}
// SwitchConn replaces the client's connection with conn without running the
// authentication handshake. The previous session is closed, and a new smux
// server session (with the OnClientConnect hook, if any) is started over conn
// and installed immediately.
func (c *RemoteClient) SwitchConn(ctx context.Context, conn *kcp.UDPSession) error {
	c.connMutex.Lock()
	defer c.connMutex.Unlock()

	if closeErr := c.Close(); closeErr != nil {
		return errors.WithStack(closeErr)
	}

	newSess, acceptErr := c.acceptSession(ctx, conn)
	if acceptErr != nil {
		return errors.WithStack(acceptErr)
	}

	c.sess, c.conn = newSess, conn

	return nil
}
// Proxy opens a new stream on the client's current session and sends a JSON
// encoded proxyRequest{Network, Address} on it. Presumably this asks the
// remote end to dial address; TODO confirm against the client side. The
// write is bounded by proxyRequestTimeout. The stream is then returned as a
// net.Conn with its write deadline cleared; on any failure it is closed.
//
// An error is returned when no session is installed, for example after
// Close or after a failed Accept/SwitchConn.
func (c *RemoteClient) Proxy(ctx context.Context, network, address string) (net.Conn, error) {
	c.connMutex.RLock()
	defer c.connMutex.RUnlock()

	// Accept/SwitchConn close the previous session before installing a new
	// one, so a failed attempt leaves sess nil; OpenStream on a nil session
	// would panic.
	if c.sess == nil {
		return nil, errors.New("no active session")
	}

	ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address))

	logger.Debug(ctx, "opening proxy stream")
	stream, err := c.sess.OpenStream()
	if err != nil {
		return nil, errors.WithStack(err)
	}

	proxyReq := &proxyRequest{
		Network: network,
		Address: address,
	}

	writeDeadline := time.Now().Add(c.proxyRequestTimeout)
	logger.Debug(ctx, "sending proxy req", logger.F("deadline", writeDeadline))

	if err := stream.SetWriteDeadline(writeDeadline); err != nil {
		stream.Close()
		return nil, errors.WithStack(err)
	}

	if err := json.NewEncoder(stream).Encode(proxyReq); err != nil {
		stream.Close()
		return nil, errors.WithStack(err)
	}

	// The caller owns the stream from here on; do not leave the request
	// deadline in place.
	if err := stream.SetWriteDeadline(time.Time{}); err != nil {
		stream.Close()
		return nil, errors.WithStack(err)
	}

	return stream, nil
}
// acceptSession starts a smux server session over conn using the client's
// smux configuration, records conn's remote address and then runs the
// OnClientConnect hook, if one is installed. It does not install the session
// on c; that is left to the caller.
//
// If the hook rejects the client, the new session is closed (smux closes conn
// with it) so that the session's background goroutines do not leak.
func (c *RemoteClient) acceptSession(ctx context.Context, conn *kcp.UDPSession) (*smux.Session, error) {
	logger.Debug(ctx, "accepting client session")

	sess, err := smux.Server(conn, c.smuxConfig)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	// Recorded before the hook runs so the hook can read it via RemoteAddr.
	c.remoteAddr = conn.RemoteAddr()

	if c.onClientConnectHook != nil {
		if err := c.onClientConnectHook.OnClientConnect(ctx, c); err != nil {
			// Best effort cleanup; the hook's error is what the caller needs.
			_ = sess.Close()
			return nil, errors.WithStack(err)
		}
	}

	return sess, nil
}
// authenticate runs the server side of the authentication handshake on
// stream:
//  1. it reads a JSON encoded authRequest;
//  2. it passes the request's credentials to the OnClientAuth hook;
//  3. it writes back an authResponse carrying the verdict.
//
// Reading and writing share one deadline, authenticationTimeout after the
// call started. Without an auth hook every client is rejected. A rejected
// client gets the negative response first, then ErrAuthenticationFailed is
// returned.
func (c *RemoteClient) authenticate(ctx context.Context, stream *smux.Stream) error {
	deadline := time.Now().Add(c.authenticationTimeout)

	logger.Debug(ctx, "waiting for auth request", logger.F("deadline", deadline))
	if err := stream.SetReadDeadline(deadline); err != nil {
		return errors.WithStack(err)
	}

	authReq := &authRequest{}
	if err := json.NewDecoder(stream).Decode(authReq); err != nil {
		return errors.WithStack(err)
	}

	// The credentials themselves are not logged: they may be secrets, and
	// debug logs are often collected and shared.
	logger.Debug(ctx, "received client credentials")

	success := false
	if c.onClientAuthHook != nil {
		var err error
		success, err = c.onClientAuthHook.OnClientAuth(ctx, c, authReq.Credentials)
		if err != nil {
			return errors.WithStack(err)
		}
	}

	logger.Debug(ctx, "sending auth response", logger.F("deadline", deadline))
	if err := stream.SetWriteDeadline(deadline); err != nil {
		return errors.WithStack(err)
	}

	authRes := &authResponse{
		Success: success,
	}
	if err := json.NewEncoder(stream).Encode(authRes); err != nil {
		return errors.WithStack(err)
	}

	if !success {
		return errors.WithStack(ErrAuthenticationFailed)
	}

	return nil
}
// NewRemoteClient returns a RemoteClient with no connection and no hooks
// installed. The client will use:
//   - smuxConfig for its smux server sessions;
//   - authenticationTimeout as the time budget of the authentication
//     handshake;
//   - proxyRequestTimeout as the write deadline when sending proxy requests.
func NewRemoteClient(smuxConfig *smux.Config, authenticationTimeout, proxyRequestTimeout time.Duration) *RemoteClient {
	client := new(RemoteClient)
	client.smuxConfig = smuxConfig
	client.authenticationTimeout = authenticationTimeout
	client.proxyRequestTimeout = proxyRequestTimeout

	return client
}