131 lines
3.0 KiB
Go
131 lines
3.0 KiB
Go
package tunnel
|
|
|
|
import (
|
|
"context"
|
|
"io"
|
|
"net"
|
|
|
|
"forge.cadoles.com/wpetit/go-tunnel/control"
|
|
|
|
"github.com/pkg/errors"
|
|
"github.com/xtaci/kcp-go/v5"
|
|
"github.com/xtaci/smux"
|
|
"gitlab.com/wpetit/goweb/logger"
|
|
)
|
|
|
|
type RemoteClient struct {
|
|
onClientAuthHook OnClientAuthHook
|
|
onClientConnectHook OnClientConnectHook
|
|
onClientDisconnectHook OnClientDisconnectHook
|
|
sess *smux.Session
|
|
control *control.Control
|
|
remoteAddr net.Addr
|
|
}
|
|
|
|
func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error {
|
|
c.remoteAddr = conn.RemoteAddr()
|
|
|
|
if c.onClientConnectHook != nil {
|
|
if err := c.onClientConnectHook.OnClientConnect(ctx, c); err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
}
|
|
|
|
config := smux.DefaultConfig()
|
|
config.Version = 2
|
|
|
|
sess, err := smux.Server(conn, config)
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
logger.Debug(ctx, "accepting control stream")
|
|
|
|
controlStream, err := sess.AcceptStream()
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
c.sess = sess
|
|
c.control = control.New(sess, controlStream)
|
|
|
|
return nil
|
|
}
|
|
|
|
func (c *RemoteClient) Listen(ctx context.Context) error {
|
|
logger.Debug(ctx, "listening for messages")
|
|
|
|
err := c.control.Listen(ctx, control.Handlers{
|
|
control.TypeAuthRequest: c.handleAuthRequest,
|
|
})
|
|
|
|
if errors.Is(err, io.ErrClosedPipe) {
|
|
if c.onClientDisconnectHook != nil {
|
|
if err := c.onClientDisconnectHook.OnClientDisconnect(ctx, c); err != nil {
|
|
logger.Error(ctx, "client disconnect hook error", logger.E(err))
|
|
}
|
|
}
|
|
|
|
return errors.WithStack(ErrConnectionClosed)
|
|
}
|
|
|
|
return err
|
|
}
|
|
|
|
func (c *RemoteClient) handleAuthRequest(ctx context.Context, m *control.Message) (*control.Message, error) {
|
|
authReqPayload, ok := m.Payload.(*control.AuthRequestPayload)
|
|
if !ok {
|
|
return nil, errors.WithStack(ErrUnexpectedMessage)
|
|
}
|
|
|
|
logger.Debug(ctx, "handling auth request", logger.F("credentials", authReqPayload.Credentials))
|
|
|
|
var (
|
|
success bool
|
|
err error
|
|
)
|
|
|
|
if c.onClientAuthHook != nil {
|
|
success, err = c.onClientAuthHook.OnClientAuth(ctx, c, authReqPayload.Credentials)
|
|
if err != nil {
|
|
return nil, errors.WithStack(err)
|
|
}
|
|
}
|
|
|
|
res := control.NewMessage(control.TypeAuthResponse, &control.AuthResponsePayload{
|
|
Success: success,
|
|
})
|
|
|
|
return res, nil
|
|
}
|
|
|
|
func (c *RemoteClient) ConfigureHooks(hooks interface{}) {
|
|
if hooks == nil {
|
|
return
|
|
}
|
|
|
|
if onClientAuthHook, ok := hooks.(OnClientAuthHook); ok {
|
|
c.onClientAuthHook = onClientAuthHook
|
|
}
|
|
|
|
if OnClientConnectHook, ok := hooks.(OnClientConnectHook); ok {
|
|
c.onClientConnectHook = OnClientConnectHook
|
|
}
|
|
|
|
if OnClientDisconnectHook, ok := hooks.(OnClientDisconnectHook); ok {
|
|
c.onClientDisconnectHook = OnClientDisconnectHook
|
|
}
|
|
}
|
|
|
|
func (c *RemoteClient) RemoteAddr() net.Addr {
|
|
return c.remoteAddr
|
|
}
|
|
|
|
func (c *RemoteClient) Proxy(ctx context.Context, network, address string) (net.Conn, error) {
|
|
return c.control.Proxy(ctx, network, address)
|
|
}
|
|
|
|
func NewRemoteClient() *RemoteClient {
|
|
return &RemoteClient{}
|
|
}
|