go-tunnel/remote_client.go

137 lines
3.1 KiB
Go

package tunnel
import (
"context"
"io"
"net"
"forge.cadoles.com/wpetit/go-tunnel/control"
"github.com/pkg/errors"
"github.com/xtaci/kcp-go/v5"
"github.com/xtaci/smux"
"gitlab.com/wpetit/goweb/logger"
)
// RemoteClient is the server-side handle for a single tunnel client connected
// over KCP. It holds the smux session multiplexed over that connection, the
// control channel built on the client's first stream, and the optional
// lifecycle hooks installed through ConfigureHooks.
type RemoteClient struct {
	// Optional lifecycle hooks set by ConfigureHooks; a nil hook is skipped.
	onClientAuthHook       OnClientAuthHook
	onClientConnectHook    OnClientConnectHook
	onClientDisconnectHook OnClientDisconnectHook
	// sess is the smux server session created by Accept; nil until Accept succeeds.
	sess *smux.Session
	// control wraps sess and the accepted control stream; nil until Accept succeeds.
	control *control.Control
	// remoteAddr is the peer address of the KCP connection, recorded at the start of Accept.
	remoteAddr net.Addr
}
// Accept performs the server side of the handshake on a freshly accepted KCP
// connection. It records the peer address and runs the optional connect hook.
// It then opens an smux server session (protocol version 2) over conn and
// waits for the client to open the control stream.
//
// If the connect hook fails, Accept returns before any session is created.
// If the control stream cannot be accepted, the smux session is closed before
// returning. The session was never stored on c, so Close could not release
// it later.
//
// NOTE(review): AcceptStream does not observe ctx. A client that never opens
// the control stream keeps Accept blocked until the session is closed (e.g. by
// smux's keepalive timeout) — consider honoring ctx cancellation here.
func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error {
	c.remoteAddr = conn.RemoteAddr()

	if c.onClientConnectHook != nil {
		if err := c.onClientConnectHook.OnClientConnect(ctx, c); err != nil {
			return errors.WithStack(err)
		}
	}

	config := smux.DefaultConfig()
	config.Version = 2

	sess, err := smux.Server(conn, config)
	if err != nil {
		return errors.WithStack(err)
	}

	logger.Debug(ctx, "accepting control stream")

	controlStream, err := sess.AcceptStream()
	if err != nil {
		// Release the orphaned session (and the connection it wraps) so its
		// goroutines do not leak; the accept error is the one worth reporting.
		_ = sess.Close()

		return errors.WithStack(err)
	}

	c.sess = sess
	c.control = control.New(sess, controlStream)

	return nil
}
// Listen serves control messages from the client until the control channel
// stops. Auth requests (control.TypeAuthRequest) are dispatched to
// handleAuthRequest.
//
// If control.Listen ends with io.ErrClosedPipe, Listen runs the optional
// disconnect hook and returns ErrConnectionClosed. A hook error is only
// logged. Any other result from control.Listen, including nil, is returned
// unchanged.
//
// c.control is only set by a successful Accept, so Listen must be called
// after Accept.
func (c *RemoteClient) Listen(ctx context.Context) error {
	logger.Debug(ctx, "listening for messages")

	handlers := control.Handlers{
		control.TypeAuthRequest: c.handleAuthRequest,
	}

	listenErr := c.control.Listen(ctx, handlers)
	if !errors.Is(listenErr, io.ErrClosedPipe) {
		return listenErr
	}

	if hook := c.onClientDisconnectHook; hook != nil {
		if hookErr := hook.OnClientDisconnect(ctx, c); hookErr != nil {
			logger.Error(ctx, "client disconnect hook error", logger.E(hookErr))
		}
	}

	return errors.WithStack(ErrConnectionClosed)
}
// handleAuthRequest answers a control.TypeAuthRequest message. The credentials
// in the request are passed to the optional auth hook, and the hook's verdict
// is returned in a TypeAuthResponse message. With no auth hook configured,
// every request is rejected (Success stays false).
//
// It returns ErrUnexpectedMessage if the payload is not a
// *control.AuthRequestPayload. If the auth hook fails, its error is returned
// wrapped.
func (c *RemoteClient) handleAuthRequest(ctx context.Context, m *control.Message) (*control.Message, error) {
	authReqPayload, ok := m.Payload.(*control.AuthRequestPayload)
	if !ok {
		return nil, errors.WithStack(ErrUnexpectedMessage)
	}

	// Credentials are secrets: never write them to the logs, even at debug
	// level. Log who is authenticating instead.
	logger.Debug(ctx, "handling auth request", logger.F("remoteAddr", c.remoteAddr))

	var (
		success bool
		err     error
	)

	if c.onClientAuthHook != nil {
		success, err = c.onClientAuthHook.OnClientAuth(ctx, c, authReqPayload.Credentials)
		if err != nil {
			return nil, errors.WithStack(err)
		}
	}

	res := control.NewMessage(control.TypeAuthResponse, &control.AuthResponsePayload{
		Success: success,
	})

	return res, nil
}
// ConfigureHooks installs whichever client lifecycle hooks the given value
// implements. hooks may implement any combination of OnClientAuthHook,
// OnClientConnectHook and OnClientDisconnectHook. For each interface it does
// not implement, the existing hook is left unchanged. A nil hooks changes
// nothing, because every type assertion on a nil interface fails.
func (c *RemoteClient) ConfigureHooks(hooks interface{}) {
	if authHook, ok := hooks.(OnClientAuthHook); ok {
		c.onClientAuthHook = authHook
	}

	if connectHook, ok := hooks.(OnClientConnectHook); ok {
		c.onClientConnectHook = connectHook
	}

	if disconnectHook, ok := hooks.(OnClientDisconnectHook); ok {
		c.onClientDisconnectHook = disconnectHook
	}
}
// RemoteAddr reports the peer address of the client's KCP connection, as
// recorded at the start of Accept. It is nil before Accept has been called.
func (c *RemoteClient) RemoteAddr() net.Addr {
	addr := c.remoteAddr
	return addr
}
// Proxy returns a connection to address on network, obtained through the
// client's control channel. It delegates entirely to control.Control.Proxy
// (NOTE(review): the dial/stream semantics live there). It requires a
// successful Accept, which is what sets c.control.
func (c *RemoteClient) Proxy(ctx context.Context, network, address string) (net.Conn, error) {
	ctrl := c.control
	conn, err := ctrl.Proxy(ctx, network, address)

	return conn, err
}
// Close closes the smux session if one was established and is still open.
// It does nothing before Accept has succeeded or once the session is already
// closed, so repeated calls are harmless.
func (c *RemoteClient) Close() {
	sess := c.sess
	if sess == nil || sess.IsClosed() {
		return
	}

	// Close has no error return, so the session's close error is dropped.
	_ = sess.Close()
}
// NewRemoteClient returns an empty RemoteClient with no hooks configured.
// Its session and control channel are set once Accept succeeds.
func NewRemoteClient() *RemoteClient {
	return new(RemoteClient)
}