215 lines
4.6 KiB
Go
215 lines
4.6 KiB
Go
package tunnel
|
|
|
|
import (
|
|
"context"
|
|
"encoding/json"
|
|
"io"
|
|
"net"
|
|
"time"
|
|
|
|
"gitlab.com/wpetit/goweb/logger"
|
|
|
|
"github.com/pkg/errors"
|
|
"github.com/xtaci/kcp-go/v5"
|
|
"github.com/xtaci/smux"
|
|
)
|
|
|
|
type Client struct {
|
|
conf *ClientConfig
|
|
conn *kcp.UDPSession
|
|
sess *smux.Session
|
|
}
|
|
|
|
func (c *Client) Connect(ctx context.Context) error {
|
|
logger.Debug(ctx, "connecting", logger.F("serverAddr", c.conf.ServerAddress))
|
|
|
|
conn, err := kcp.DialWithOptions(
|
|
c.conf.ServerAddress, c.conf.BlockCrypt,
|
|
c.conf.DataShards, c.conf.ParityShards,
|
|
)
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
if c.conf.ConfigureConn != nil {
|
|
if err := c.conf.ConfigureConn(conn); err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
}
|
|
|
|
sess, err := smux.Client(conn, c.conf.SmuxConfig)
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
stream, err := sess.OpenStream()
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
defer stream.Close()
|
|
|
|
success, err := c.authenticate(ctx, stream)
|
|
if err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
if !success {
|
|
return errors.WithStack(ErrAuthenticationFailed)
|
|
}
|
|
|
|
logger.Debug(ctx, "authentication success")
|
|
|
|
c.conn = conn
|
|
c.sess = sess
|
|
|
|
return nil
|
|
}
|
|
|
|
// Listen serves the proxy requests the server sends over the session
// established by Connect.
//
// Each stream opened by the server must start with a JSON-encoded
// proxyRequest, which has to arrive within ProxyRequestTimeout. After the
// request is decoded, the read deadline is cleared and the stream is handed to
// handleProxyStream in its own goroutine. A stream whose request cannot be
// read is closed and logged, and the loop moves on to the next stream.
//
// Listen blocks until accepting a stream fails (for instance once the session
// is closed) and returns that error. It returns an error immediately if the
// client has no session. ctx is used for logging and passed to
// handleProxyStream; cancelling it does not by itself interrupt AcceptStream.
func (c *Client) Listen(ctx context.Context) error {
	// Read the session exactly once. Calling Listen before Connect, or after
	// Close (which resets c.sess to nil), must not dereference a nil session.
	sess := c.sess
	if sess == nil {
		return errors.New("client is not connected")
	}

	logger.Debug(ctx, "listening for proxy requests")

	for {
		stream, err := sess.AcceptStream()
		if err != nil {
			return errors.WithStack(err)
		}

		subCtx := logger.With(ctx,
			logger.F("remoteAddr", stream.RemoteAddr()),
			logger.F("localAddr", stream.LocalAddr()),
		)

		readDeadline := time.Now().Add(c.conf.ProxyRequestTimeout)
		logger.Debug(subCtx, "waiting for proxy request", logger.F("deadline", readDeadline))

		if err := stream.SetReadDeadline(readDeadline); err != nil {
			stream.Close()
			logger.Error(subCtx, "could not set read deadline", logger.E(errors.WithStack(err)))

			continue
		}

		// NOTE(review): json.Decoder reads ahead into its own buffer. Any bytes
		// it buffers past the request object are not forwarded, because
		// handleProxyStream reads from stream directly. This presumably relies
		// on the server writing the request on its own before any payload;
		// confirm against the server implementation.
		decoder := json.NewDecoder(stream)
		proxyReq := &proxyRequest{}

		if err := decoder.Decode(proxyReq); err != nil {
			stream.Close()
			logger.Error(subCtx, "could not decode proxy request", logger.E(errors.WithStack(err)))

			continue
		}

		// The request deadline must not apply to the proxied traffic.
		if err := stream.SetReadDeadline(time.Time{}); err != nil {
			stream.Close()
			logger.Error(subCtx, "could not set read deadline", logger.E(errors.WithStack(err)))

			continue
		}

		go c.handleProxyStream(subCtx, stream, proxyReq.Network, proxyReq.Address)
	}
}
|
|
|
|
func (c *Client) Close() error {
|
|
if c.sess != nil && !c.sess.IsClosed() {
|
|
if err := c.sess.Close(); err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
}
|
|
|
|
if c.conn != nil {
|
|
if err := c.conn.Close(); err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
}
|
|
|
|
c.conn = nil
|
|
c.sess = nil
|
|
|
|
return nil
|
|
}
|
|
|
|
// authenticate sends the configured credentials on stream as a JSON
// authRequest and waits for the server's JSON authResponse.
//
// Writing the request and reading the response share a single deadline,
// AuthenticationTimeout after the call starts, so the whole exchange is bounded
// by that timeout. It returns the response's Success flag.
//
// An io.EOF while decoding the response means the stream ended before any
// response value arrived (e.g. the server closed it without replying). That is
// not reported as an error: the zero-value response is used instead, so
// authentication is reported as unsuccessful.
func (c *Client) authenticate(ctx context.Context, stream *smux.Stream) (bool, error) {
	deadline := time.Now().Add(c.conf.AuthenticationTimeout)

	logger.Debug(ctx, "sending auth request", logger.F("deadline", deadline))

	if err := stream.SetWriteDeadline(deadline); err != nil {
		return false, errors.WithStack(err)
	}

	authReq := &authRequest{
		Credentials: c.conf.Credentials,
	}

	if err := json.NewEncoder(stream).Encode(authReq); err != nil {
		return false, errors.WithStack(err)
	}

	logger.Debug(ctx, "waiting for auth response", logger.F("deadline", deadline))

	if err := stream.SetReadDeadline(deadline); err != nil {
		return false, errors.WithStack(err)
	}

	authRes := &authResponse{}

	if err := json.NewDecoder(stream).Decode(authRes); err != nil && !errors.Is(err, io.EOF) {
		return false, errors.WithStack(err)
	}

	return authRes.Success, nil
}
|
|
|
|
func (c *Client) handleProxyStream(ctx context.Context, in *smux.Stream, network, address string) {
|
|
defer func(start time.Time) {
|
|
logger.Debug(ctx, "handleProxyStream duration", logger.F("duration", time.Since(start)))
|
|
}(time.Now())
|
|
|
|
defer in.Close()
|
|
|
|
logger.Debug(
|
|
ctx, "proxying",
|
|
logger.F("network", network),
|
|
logger.F("address", address),
|
|
)
|
|
|
|
out, err := net.Dial(network, address)
|
|
if err != nil {
|
|
logger.Error(ctx, "could not dial", logger.E(errors.WithStack(err)))
|
|
|
|
return
|
|
}
|
|
defer out.Close()
|
|
|
|
streamCopy := func(dst io.Writer, src io.ReadCloser) {
|
|
if _, err := Copy(dst, src); err != nil {
|
|
if errors.Is(err, smux.ErrInvalidProtocol) {
|
|
logger.Error(ctx, "could not proxy", logger.E(errors.WithStack(err)))
|
|
}
|
|
}
|
|
|
|
in.Close()
|
|
out.Close()
|
|
}
|
|
|
|
go streamCopy(out, in)
|
|
streamCopy(in, out)
|
|
}
|
|
|
|
func NewClient(funcs ...ClientConfigFunc) *Client {
|
|
conf := DefaultClientConfig()
|
|
|
|
for _, fn := range funcs {
|
|
fn(conf)
|
|
}
|
|
|
|
return &Client{
|
|
conf: conf,
|
|
}
|
|
}
|