feat: general protocol rewrite
This commit is contained in:
parent
536100da90
commit
96c1575a0b
4
Makefile
4
Makefile
|
@ -4,10 +4,10 @@ watch:
|
||||||
build: build-server build-client
|
build: build-server build-client
|
||||||
|
|
||||||
build-server:
|
build-server:
|
||||||
go build -o ./bin/server ./cmd/server
|
CGO_ENABLED=0 go build -o ./bin/server ./cmd/server
|
||||||
|
|
||||||
build-client:
|
build-client:
|
||||||
go build -o ./bin/client ./cmd/client
|
CGO_ENABLED=0 go build -o ./bin/client ./cmd/client
|
||||||
|
|
||||||
test:
|
test:
|
||||||
go test -v -race ./...
|
go test -v -race ./...
|
166
client.go
166
client.go
|
@ -2,15 +2,13 @@ package tunnel
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
|
||||||
"sync"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"gitlab.com/wpetit/goweb/logger"
|
"gitlab.com/wpetit/goweb/logger"
|
||||||
|
|
||||||
"forge.cadoles.com/wpetit/go-tunnel/control"
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/xtaci/kcp-go/v5"
|
"github.com/xtaci/kcp-go/v5"
|
||||||
"github.com/xtaci/smux"
|
"github.com/xtaci/smux"
|
||||||
|
@ -20,12 +18,11 @@ type Client struct {
|
||||||
conf *ClientConfig
|
conf *ClientConfig
|
||||||
conn *kcp.UDPSession
|
conn *kcp.UDPSession
|
||||||
sess *smux.Session
|
sess *smux.Session
|
||||||
control *control.Control
|
|
||||||
http *http.Client
|
|
||||||
openStreamMutex sync.Mutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Connect(ctx context.Context) error {
|
func (c *Client) Connect(ctx context.Context) error {
|
||||||
|
logger.Debug(ctx, "connecting", logger.F("serverAddr", c.conf.ServerAddress))
|
||||||
|
|
||||||
conn, err := kcp.DialWithOptions(
|
conn, err := kcp.DialWithOptions(
|
||||||
c.conf.ServerAddress, c.conf.BlockCrypt,
|
c.conf.ServerAddress, c.conf.BlockCrypt,
|
||||||
c.conf.DataShards, c.conf.ParityShards,
|
c.conf.DataShards, c.conf.ParityShards,
|
||||||
|
@ -40,34 +37,29 @@ func (c *Client) Connect(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
config := smux.DefaultConfig()
|
sess, err := smux.Client(conn, c.conf.SmuxConfig)
|
||||||
config.Version = 2
|
|
||||||
config.KeepAliveInterval = 10 * time.Second
|
|
||||||
config.KeepAliveTimeout = 2 * config.KeepAliveInterval
|
|
||||||
|
|
||||||
sess, err := smux.Client(conn, config)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
control := control.New()
|
stream, err := sess.OpenStream()
|
||||||
if err := control.Init(ctx, sess, false); err != nil {
|
if err != nil {
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug(ctx, "sending auth request")
|
defer stream.Close()
|
||||||
|
|
||||||
success, err := control.AuthRequest(c.conf.Credentials)
|
success, err := c.authenticate(ctx, stream)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if !success {
|
if !success {
|
||||||
defer c.Close()
|
return errors.WithStack(ErrAuthenticationFailed)
|
||||||
return errors.WithStack(ErrAuthFailed)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
c.control = control
|
logger.Debug(ctx, "authentication success")
|
||||||
|
|
||||||
c.conn = conn
|
c.conn = conn
|
||||||
c.sess = sess
|
c.sess = sess
|
||||||
|
|
||||||
|
@ -75,87 +67,138 @@ func (c *Client) Connect(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Listen(ctx context.Context) error {
|
func (c *Client) Listen(ctx context.Context) error {
|
||||||
logger.Debug(ctx, "listening for messages")
|
logger.Debug(ctx, "listening for proxy requests")
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(ctx)
|
for {
|
||||||
defer cancel()
|
stream, err := c.sess.AcceptStream()
|
||||||
|
if err != nil {
|
||||||
err := c.control.Listen(ctx, control.Handlers{
|
return errors.WithStack(err)
|
||||||
control.TypeProxyRequest: c.handleProxyRequest,
|
|
||||||
})
|
|
||||||
|
|
||||||
if errors.Is(err, io.ErrClosedPipe) {
|
|
||||||
logger.Debug(ctx, "client connection closed")
|
|
||||||
|
|
||||||
return errors.WithStack(ErrConnectionClosed)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return err
|
subCtx := logger.With(ctx,
|
||||||
|
logger.F("remoteAddr", stream.RemoteAddr()),
|
||||||
|
logger.F("localAddr", stream.LocalAddr()),
|
||||||
|
)
|
||||||
|
|
||||||
|
readDeadline := time.Now().Add(c.conf.ProxyRequestTimeout)
|
||||||
|
logger.Debug(subCtx, "waiting for proxy request", logger.F("deadline", readDeadline))
|
||||||
|
|
||||||
|
if err := stream.SetReadDeadline(readDeadline); err != nil {
|
||||||
|
stream.Close()
|
||||||
|
logger.Error(subCtx, "could not set read deadline", logger.E(errors.WithStack(err)))
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
decoder := json.NewDecoder(stream)
|
||||||
|
proxyReq := &proxyRequest{}
|
||||||
|
|
||||||
|
if err := decoder.Decode(proxyReq); err != nil {
|
||||||
|
stream.Close()
|
||||||
|
logger.Error(subCtx, "could not decode proxy request", logger.E(errors.WithStack(err)))
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := stream.SetReadDeadline(time.Time{}); err != nil {
|
||||||
|
stream.Close()
|
||||||
|
logger.Error(subCtx, "could not set read deadline", logger.E(errors.WithStack(err)))
|
||||||
|
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
go c.handleProxyStream(subCtx, stream, proxyReq.Network, proxyReq.Address)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) Close() error {
|
func (c *Client) Close() error {
|
||||||
if c.conn == nil {
|
if c.sess != nil && !c.sess.IsClosed() {
|
||||||
return errors.WithStack(ErrNotConnected)
|
if err := c.sess.Close(); err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if c.conn != nil {
|
||||||
if err := c.conn.Close(); err != nil {
|
if err := c.conn.Close(); err != nil {
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
c.conn = nil
|
||||||
|
c.sess = nil
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) handleProxyRequest(ctx context.Context, m *control.Message) (*control.Message, error) {
|
func (c *Client) authenticate(ctx context.Context, stream *smux.Stream) (bool, error) {
|
||||||
proxyReqPayload, ok := m.Payload.(*control.ProxyRequestPayload)
|
encoder := json.NewEncoder(stream)
|
||||||
if !ok {
|
authReq := &authRequest{
|
||||||
return nil, errors.WithStack(ErrUnexpectedMessage)
|
Credentials: c.conf.Credentials,
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx = logger.With(ctx,
|
start := time.Now()
|
||||||
logger.F("network", proxyReqPayload.Network),
|
writeDeadline := start.Add(c.conf.AuthenticationTimeout)
|
||||||
logger.F("address", proxyReqPayload.Address),
|
logger.Debug(ctx, "sending auth request", logger.F("deadline", writeDeadline))
|
||||||
)
|
|
||||||
|
|
||||||
logger.Debug(ctx, "handling proxy request")
|
if err := stream.SetWriteDeadline(writeDeadline); err != nil {
|
||||||
|
return false, errors.WithStack(err)
|
||||||
out, err := net.Dial(proxyReqPayload.Network, proxyReqPayload.Address)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.WithStack(err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
go c.handleProxyStream(ctx, out)
|
if err := encoder.Encode(authReq); err != nil {
|
||||||
|
return false, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil, nil
|
decoder := json.NewDecoder(stream)
|
||||||
|
authRes := &authResponse{}
|
||||||
|
|
||||||
|
readDeadline := time.Now().Add(c.conf.AuthenticationTimeout - time.Now().Sub(start))
|
||||||
|
logger.Debug(ctx, "waiting for auth response", logger.F("deadline", readDeadline))
|
||||||
|
|
||||||
|
if err := stream.SetReadDeadline(readDeadline); err != nil {
|
||||||
|
return false, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := decoder.Decode(authRes); err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
return false, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return authRes.Success, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) handleProxyStream(ctx context.Context, out net.Conn) {
|
func (c *Client) handleProxyStream(ctx context.Context, in *smux.Stream, network, address string) {
|
||||||
c.openStreamMutex.Lock()
|
defer func(start time.Time) {
|
||||||
|
logger.Debug(ctx, "handleProxyStream duration", logger.F("duration", time.Since(start)))
|
||||||
|
}(time.Now())
|
||||||
|
|
||||||
in, err := c.sess.OpenStream()
|
defer in.Close()
|
||||||
|
|
||||||
|
logger.Debug(
|
||||||
|
ctx, "proxying",
|
||||||
|
logger.F("network", network),
|
||||||
|
logger.F("address", address),
|
||||||
|
)
|
||||||
|
|
||||||
|
out, err := net.Dial(network, address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.openStreamMutex.Unlock()
|
logger.Error(ctx, "could not dial", logger.E(errors.WithStack(err)))
|
||||||
logger.Error(ctx, "error while accepting proxy stream", logger.E(err))
|
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
defer out.Close()
|
||||||
c.openStreamMutex.Unlock()
|
|
||||||
|
|
||||||
streamCopy := func(dst io.Writer, src io.ReadCloser) {
|
streamCopy := func(dst io.Writer, src io.ReadCloser) {
|
||||||
if _, err := Copy(dst, src); err != nil {
|
if _, err := Copy(dst, src); err != nil {
|
||||||
if errors.Is(err, smux.ErrInvalidProtocol) {
|
if errors.Is(err, smux.ErrInvalidProtocol) {
|
||||||
logger.Error(ctx, "error while proxying", logger.E(errors.WithStack(err)))
|
logger.Error(ctx, "could not proxy", logger.E(errors.WithStack(err)))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug(ctx, "closing proxy stream")
|
|
||||||
|
|
||||||
in.Close()
|
in.Close()
|
||||||
out.Close()
|
out.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
go streamCopy(in, out)
|
go streamCopy(out, in)
|
||||||
streamCopy(out, in)
|
streamCopy(in, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewClient(funcs ...ClientConfigFunc) *Client {
|
func NewClient(funcs ...ClientConfigFunc) *Client {
|
||||||
|
@ -167,6 +210,5 @@ func NewClient(funcs ...ClientConfigFunc) *Client {
|
||||||
|
|
||||||
return &Client{
|
return &Client{
|
||||||
conf: conf,
|
conf: conf,
|
||||||
http: &http.Client{},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -2,9 +2,11 @@ package tunnel
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/xtaci/kcp-go/v5"
|
"github.com/xtaci/kcp-go/v5"
|
||||||
|
"github.com/xtaci/smux"
|
||||||
"golang.org/x/crypto/pbkdf2"
|
"golang.org/x/crypto/pbkdf2"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -15,20 +17,34 @@ type ClientConfig struct {
|
||||||
ParityShards int
|
ParityShards int
|
||||||
Credentials interface{}
|
Credentials interface{}
|
||||||
ConfigureConn ConfigureConnFunc
|
ConfigureConn ConfigureConnFunc
|
||||||
|
AuthenticationTimeout time.Duration
|
||||||
|
ProxyRequestTimeout time.Duration
|
||||||
|
SmuxConfig *smux.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nolint: go-mnd
|
||||||
func DefaultClientConfig() *ClientConfig {
|
func DefaultClientConfig() *ClientConfig {
|
||||||
unencryptedBlock, err := kcp.NewNoneBlockCrypt(nil)
|
unencryptedBlock, err := kcp.NewNoneBlockCrypt(nil)
|
||||||
if err != nil { // should never happen
|
if err != nil { // should never happen
|
||||||
panic(errors.WithStack(err))
|
panic(errors.WithStack(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
smuxConfig := smux.DefaultConfig()
|
||||||
|
smuxConfig.Version = 2
|
||||||
|
smuxConfig.KeepAliveInterval = 10 * time.Second
|
||||||
|
smuxConfig.MaxReceiveBuffer = 4194304
|
||||||
|
smuxConfig.MaxStreamBuffer = 2097152
|
||||||
|
|
||||||
return &ClientConfig{
|
return &ClientConfig{
|
||||||
ServerAddress: "127.0.0.1:36543",
|
ServerAddress: "127.0.0.1:36543",
|
||||||
BlockCrypt: unencryptedBlock,
|
BlockCrypt: unencryptedBlock,
|
||||||
DataShards: 3,
|
DataShards: 3,
|
||||||
ParityShards: 10,
|
ParityShards: 10,
|
||||||
Credentials: nil,
|
Credentials: nil,
|
||||||
|
ConfigureConn: DefaultClientConfigureConn,
|
||||||
|
AuthenticationTimeout: 30 * time.Second,
|
||||||
|
ProxyRequestTimeout: 5 * time.Second,
|
||||||
|
SmuxConfig: smuxConfig,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -64,3 +80,34 @@ func WithClientConfigureConn(fn ConfigureConnFunc) ClientConfigFunc {
|
||||||
conf.ConfigureConn = fn
|
conf.ConfigureConn = fn
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithClientSmuxConfig(c *smux.Config) ClientConfigFunc {
|
||||||
|
return func(conf *ClientConfig) {
|
||||||
|
conf.SmuxConfig = c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// nolint: go-mnd
|
||||||
|
func DefaultClientConfigureConn(conn *kcp.UDPSession) error {
|
||||||
|
// Based on kcptun default configuration, mode 'fast3'
|
||||||
|
conn.SetStreamMode(true)
|
||||||
|
conn.SetWriteDelay(false)
|
||||||
|
conn.SetNoDelay(1, 10, 2, 1)
|
||||||
|
conn.SetWindowSize(128, 512)
|
||||||
|
conn.SetMtu(1400)
|
||||||
|
conn.SetACKNoDelay(true)
|
||||||
|
|
||||||
|
if err := conn.SetReadBuffer(16777217); err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := conn.SetWriteBuffer(16777217); err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := conn.SetDSCP(46); err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
|
@ -3,6 +3,7 @@ package main
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"flag"
|
"flag"
|
||||||
|
"fmt"
|
||||||
"math/rand"
|
"math/rand"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
@ -11,15 +12,18 @@ import (
|
||||||
"gitlab.com/wpetit/goweb/logger"
|
"gitlab.com/wpetit/goweb/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
const sharedKey = "go-tunnel"
|
|
||||||
const salt = "go-tunnel"
|
const salt = "go-tunnel"
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
var (
|
var (
|
||||||
clientID string
|
clientID = fmt.Sprintf("client-%d", time.Now().Unix())
|
||||||
|
serverAddr = "127.0.0.1:36543"
|
||||||
|
sharedKey = "go-tunnel"
|
||||||
)
|
)
|
||||||
|
|
||||||
flag.StringVar(&clientID, "id", "", "Client ID")
|
flag.StringVar(&sharedKey, "shared-key", sharedKey, "shared key")
|
||||||
|
flag.StringVar(&clientID, "id", clientID, "Client ID")
|
||||||
|
flag.StringVar(&serverAddr, "server-addr", serverAddr, "server address")
|
||||||
flag.Parse()
|
flag.Parse()
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
@ -28,12 +32,12 @@ func main() {
|
||||||
logger.SetLevel(slog.LevelDebug)
|
logger.SetLevel(slog.LevelDebug)
|
||||||
|
|
||||||
client := tunnel.NewClient(
|
client := tunnel.NewClient(
|
||||||
|
tunnel.WithClientServerAddress(serverAddr),
|
||||||
tunnel.WithClientCredentials(clientID),
|
tunnel.WithClientCredentials(clientID),
|
||||||
tunnel.WithClientAESBlockCrypt(sharedKey, salt),
|
|
||||||
)
|
)
|
||||||
defer client.Close()
|
defer client.Close()
|
||||||
|
|
||||||
initialBackoff := time.Second * 10
|
initialBackoff := time.Second * 2
|
||||||
backoff := initialBackoff
|
backoff := initialBackoff
|
||||||
|
|
||||||
sleep := func() {
|
sleep := func() {
|
||||||
|
|
|
@ -2,12 +2,9 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"flag"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
|
||||||
"net/url"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"cdr.dev/slog"
|
"cdr.dev/slog"
|
||||||
"forge.cadoles.com/wpetit/go-tunnel"
|
"forge.cadoles.com/wpetit/go-tunnel"
|
||||||
|
@ -15,20 +12,61 @@ import (
|
||||||
"gitlab.com/wpetit/goweb/logger"
|
"gitlab.com/wpetit/goweb/logger"
|
||||||
)
|
)
|
||||||
|
|
||||||
const sharedKey = "go-tunnel"
|
|
||||||
const salt = "go-tunnel"
|
const salt = "go-tunnel"
|
||||||
|
|
||||||
var registry = NewRegistry()
|
var registry = NewRegistry()
|
||||||
|
|
||||||
func main() {
|
func main() {
|
||||||
|
var (
|
||||||
|
serverAddr = ":36543"
|
||||||
|
httpAddr = ":3003"
|
||||||
|
sharedKey = "go-tunnel"
|
||||||
|
targetURL = "https://arcad.games"
|
||||||
|
)
|
||||||
|
|
||||||
|
flag.StringVar(&serverAddr, "server-addr", serverAddr, "server address")
|
||||||
|
flag.StringVar(&targetURL, "target-url", targetURL, "target url")
|
||||||
|
flag.StringVar(&httpAddr, "http-addr", httpAddr, "http server address")
|
||||||
|
flag.StringVar(&sharedKey, "shared-key", sharedKey, "shared key")
|
||||||
|
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
|
|
||||||
logger.SetLevel(slog.LevelDebug)
|
logger.SetLevel(slog.LevelDebug)
|
||||||
|
|
||||||
server := tunnel.NewServer(
|
server := tunnel.NewServer(
|
||||||
tunnel.WithServerAESBlockCrypt(sharedKey, salt),
|
tunnel.WithServerAddress(serverAddr),
|
||||||
tunnel.WithServerOnClientAuth(registry.OnClientAuth),
|
tunnel.WithServerOnClientAuth(registry.OnClientAuth),
|
||||||
tunnel.WithServerOnClientDisconnect(registry.OnClientDisconnect),
|
tunnel.WithServerOnClientDisconnect(registry.OnClientDisconnect),
|
||||||
|
tunnel.WithServerOnClientAuth(func(ctx context.Context, remoteClient *tunnel.RemoteClient, credentials interface{}) (bool, error) {
|
||||||
|
remoteAddr := remoteClient.RemoteAddr().String()
|
||||||
|
|
||||||
|
ctx = logger.With(ctx, logger.F("remoteAddr", remoteAddr))
|
||||||
|
|
||||||
|
logger.Debug(ctx, "new client auth")
|
||||||
|
|
||||||
|
clientID, ok := credentials.(string)
|
||||||
|
if !ok {
|
||||||
|
logger.Debug(ctx, "client auth failed")
|
||||||
|
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.Add(clientID, remoteAddr, remoteClient)
|
||||||
|
|
||||||
|
logger.Debug(ctx, "client auth success")
|
||||||
|
|
||||||
|
return true, nil
|
||||||
|
}),
|
||||||
|
tunnel.WithServerOnClientDisconnect(func(ctx context.Context, remoteClient *tunnel.RemoteClient) error {
|
||||||
|
remoteAddr := remoteClient.RemoteAddr().String()
|
||||||
|
ctx = logger.With(ctx, logger.F("remoteAddr", remoteAddr))
|
||||||
|
logger.Debug(ctx, "client disconnected")
|
||||||
|
registry.RemoveByRemoteAddr(remoteAddr)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}),
|
||||||
)
|
)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
|
@ -37,51 +75,28 @@ func main() {
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if err := http.ListenAndServe(":3003", http.HandlerFunc(handleRequest)); err != nil {
|
handler, err := createProxyHandler(targetURL)
|
||||||
|
if err != nil {
|
||||||
|
logger.Fatal(ctx, "could not create proxy handler", logger.E(errors.WithStack(err)))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := http.ListenAndServe(httpAddr, handler); err != nil {
|
||||||
logger.Fatal(ctx, "error while listening", logger.E(err))
|
logger.Fatal(ctx, "error while listening", logger.E(err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func handleRequest(w http.ResponseWriter, r *http.Request) {
|
func createProxyHandler(targetURL string) (http.Handler, error) {
|
||||||
|
return tunnel.ProxyHandler(targetURL, func(w http.ResponseWriter, r *http.Request) (*tunnel.RemoteClient, error) {
|
||||||
subdomains := strings.SplitN(r.Host, ".", 2)
|
subdomains := strings.SplitN(r.Host, ".", 2)
|
||||||
|
|
||||||
if len(subdomains) < 2 {
|
if len(subdomains) < 2 {
|
||||||
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
|
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
|
||||||
|
|
||||||
return
|
return nil, tunnel.ErrAbortProxy
|
||||||
}
|
}
|
||||||
|
|
||||||
clientID := subdomains[0]
|
clientID := subdomains[0]
|
||||||
remoteClient := registry.Get(clientID)
|
|
||||||
|
|
||||||
if remoteClient == nil {
|
return registry.Get(clientID), nil
|
||||||
http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
|
})
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
target, err := url.Parse("https://arcad.games")
|
|
||||||
if err != nil {
|
|
||||||
logger.Fatal(r.Context(), "could not parse url", logger.E(err))
|
|
||||||
}
|
|
||||||
|
|
||||||
reverse := httputil.NewSingleHostReverseProxy(target)
|
|
||||||
reverse.Transport = &http.Transport{
|
|
||||||
Proxy: http.ProxyFromEnvironment,
|
|
||||||
DialContext: func(ctx context.Context, network string, addr string) (net.Conn, error) {
|
|
||||||
conn, err := remoteClient.Proxy(ctx, network, addr)
|
|
||||||
if err != nil {
|
|
||||||
return nil, errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return conn, nil
|
|
||||||
},
|
|
||||||
ForceAttemptHTTP2: true,
|
|
||||||
MaxIdleConns: 100,
|
|
||||||
IdleConnTimeout: 90 * time.Second,
|
|
||||||
TLSHandshakeTimeout: 10 * time.Second,
|
|
||||||
ExpectContinueTimeout: 1 * time.Second,
|
|
||||||
}
|
|
||||||
|
|
||||||
reverse.ServeHTTP(w, r)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,9 +0,0 @@
|
||||||
package control
|
|
||||||
|
|
||||||
type AuthRequestPayload struct {
|
|
||||||
Credentials interface{} `json:"c"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type AuthResponsePayload struct {
|
|
||||||
Success bool `json:"s"`
|
|
||||||
}
|
|
|
@ -1,194 +0,0 @@
|
||||||
package control
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
"github.com/xtaci/smux"
|
|
||||||
"gitlab.com/wpetit/goweb/logger"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Control struct {
|
|
||||||
encoder *json.Encoder
|
|
||||||
decoder *json.Decoder
|
|
||||||
stream *smux.Stream
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Control) Init(ctx context.Context, sess *smux.Session, serverMode bool) error {
|
|
||||||
config := smux.DefaultConfig()
|
|
||||||
config.Version = 2
|
|
||||||
|
|
||||||
logger.Debug(ctx, "creating control stream")
|
|
||||||
|
|
||||||
var (
|
|
||||||
stream *smux.Stream
|
|
||||||
err error
|
|
||||||
)
|
|
||||||
|
|
||||||
if serverMode {
|
|
||||||
stream, err = sess.AcceptStream()
|
|
||||||
if err != nil {
|
|
||||||
return errors.WithStack(err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
stream, err = sess.OpenStream()
|
|
||||||
if err != nil {
|
|
||||||
return errors.WithStack(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c.stream = stream
|
|
||||||
c.decoder = json.NewDecoder(stream)
|
|
||||||
c.encoder = json.NewEncoder(stream)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Control) AuthRequest(credentials interface{}) (bool, error) {
|
|
||||||
req := NewMessage(TypeAuthRequest, &AuthRequestPayload{
|
|
||||||
Credentials: credentials,
|
|
||||||
})
|
|
||||||
|
|
||||||
res := NewMessage(TypeAuthResponse, nil)
|
|
||||||
|
|
||||||
if err := c.reqRes(req, res); err != nil {
|
|
||||||
return false, errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
authResPayload, ok := res.Payload.(*AuthResponsePayload)
|
|
||||||
if !ok {
|
|
||||||
return false, errors.WithStack(ErrUnexpectedMessage)
|
|
||||||
}
|
|
||||||
|
|
||||||
return authResPayload.Success, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Control) ProxyReq(ctx context.Context, network, address string) error {
|
|
||||||
req := NewMessage(TypeProxyRequest, &ProxyRequestPayload{
|
|
||||||
Network: network,
|
|
||||||
Address: address,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := c.Write(req); err != nil {
|
|
||||||
return errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Control) Listen(ctx context.Context, handlers Handlers) error {
|
|
||||||
errChan := make(chan error)
|
|
||||||
msgChan := make(chan *Message)
|
|
||||||
dieChan := c.stream.GetDieCh()
|
|
||||||
|
|
||||||
go func(msgChan chan *Message, errChan chan error) {
|
|
||||||
for {
|
|
||||||
logger.Debug(ctx, "reading next message")
|
|
||||||
|
|
||||||
msg, err := c.Read()
|
|
||||||
if err != nil {
|
|
||||||
errChan <- errors.WithStack(err)
|
|
||||||
|
|
||||||
close(errChan)
|
|
||||||
close(msgChan)
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
msgChan <- msg
|
|
||||||
}
|
|
||||||
}(msgChan, errChan)
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
return nil
|
|
||||||
|
|
||||||
case <-dieChan:
|
|
||||||
return errors.WithStack(ErrStreamClosed)
|
|
||||||
|
|
||||||
case err := <-errChan:
|
|
||||||
return errors.WithStack(err)
|
|
||||||
|
|
||||||
case msg := <-msgChan:
|
|
||||||
go func() {
|
|
||||||
subCtx := logger.With(ctx, logger.F("messageType", msg.Type))
|
|
||||||
|
|
||||||
handler, exists := handlers[msg.Type]
|
|
||||||
if !exists {
|
|
||||||
logger.Error(subCtx, "no message handler registered")
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
res, err := handler(subCtx, msg)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(subCtx, "error while handling message", logger.E(err))
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if res == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.Write(res); err != nil {
|
|
||||||
logger.Error(subCtx, "error while write message response", logger.E(err))
|
|
||||||
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Control) Read() (*Message, error) {
|
|
||||||
message := &Message{}
|
|
||||||
|
|
||||||
if err := c.read(message); err != nil {
|
|
||||||
return nil, errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return message, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Control) Write(m *Message) error {
|
|
||||||
if err := c.write(m); err != nil {
|
|
||||||
return errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Control) reqRes(req *Message, res *Message) error {
|
|
||||||
if err := c.write(req); err != nil {
|
|
||||||
return errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := c.read(res); err != nil {
|
|
||||||
return errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Control) read(m *Message) error {
|
|
||||||
if err := c.decoder.Decode(m); err != nil {
|
|
||||||
return errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *Control) write(m *Message) error {
|
|
||||||
if err := c.encoder.Encode(m); err != nil {
|
|
||||||
return errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func New() *Control {
|
|
||||||
return &Control{}
|
|
||||||
}
|
|
|
@ -1,8 +0,0 @@
|
||||||
package control
|
|
||||||
|
|
||||||
import "errors"
|
|
||||||
|
|
||||||
var (
|
|
||||||
ErrStreamClosed = errors.New("stream closed")
|
|
||||||
ErrUnexpectedMessage = errors.New("unexpected message")
|
|
||||||
)
|
|
|
@ -1,7 +0,0 @@
|
||||||
package control
|
|
||||||
|
|
||||||
import "context"
|
|
||||||
|
|
||||||
type Handlers map[MessageType]MessageHandler
|
|
||||||
|
|
||||||
type MessageHandler func(ctx context.Context, m *Message) (*Message, error)
|
|
|
@ -1,76 +0,0 @@
|
||||||
package control
|
|
||||||
|
|
||||||
import (
|
|
||||||
"encoding/json"
|
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
TypeAuthRequest MessageType = "auth-req"
|
|
||||||
TypeAuthResponse MessageType = "auth-res"
|
|
||||||
TypeProxyRequest MessageType = "proxy-req"
|
|
||||||
TypeCloseProxy MessageType = "close-proxy"
|
|
||||||
)
|
|
||||||
|
|
||||||
type MessageType string
|
|
||||||
|
|
||||||
type BaseMessage struct {
|
|
||||||
Type MessageType `json:"t"`
|
|
||||||
RawPayload json.RawMessage `json:"p"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type Message struct {
|
|
||||||
BaseMessage
|
|
||||||
Payload interface{} `json:"p"`
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Message) UnmarshalJSON(data []byte) error {
|
|
||||||
base := &BaseMessage{}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(data, base); err != nil {
|
|
||||||
return errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
payload, err := unmarshalPayload(base.Type, base.RawPayload)
|
|
||||||
if err != nil {
|
|
||||||
return errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
m.Type = base.Type
|
|
||||||
m.Payload = payload
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func NewMessage(mType MessageType, payload interface{}) *Message {
|
|
||||||
return &Message{
|
|
||||||
BaseMessage: BaseMessage{
|
|
||||||
Type: mType,
|
|
||||||
},
|
|
||||||
Payload: payload,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func unmarshalPayload(mType MessageType, data []byte) (interface{}, error) {
|
|
||||||
var payload interface{}
|
|
||||||
|
|
||||||
switch mType {
|
|
||||||
case TypeAuthRequest:
|
|
||||||
payload = &AuthRequestPayload{}
|
|
||||||
case TypeAuthResponse:
|
|
||||||
payload = &AuthResponsePayload{}
|
|
||||||
case TypeProxyRequest:
|
|
||||||
payload = &ProxyRequestPayload{}
|
|
||||||
case TypeCloseProxy:
|
|
||||||
payload = &CloseProxyPayload{}
|
|
||||||
default:
|
|
||||||
return nil, errors.Wrapf(ErrUnexpectedMessage, "unexpected message type '%s'", mType)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := json.Unmarshal(data, payload); err != nil {
|
|
||||||
return nil, errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return payload, nil
|
|
||||||
}
|
|
|
@ -1,10 +0,0 @@
|
||||||
package control
|
|
||||||
|
|
||||||
type ProxyRequestPayload struct {
|
|
||||||
Network string `json:"n"`
|
|
||||||
Address string `json:"a"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type CloseProxyPayload struct {
|
|
||||||
RequestID int64 `json:"i"`
|
|
||||||
}
|
|
2
error.go
2
error.go
|
@ -6,7 +6,7 @@ var (
|
||||||
ErrNotConnected = errors.New("not connected")
|
ErrNotConnected = errors.New("not connected")
|
||||||
ErrCouldNotConnect = errors.New("could not connect")
|
ErrCouldNotConnect = errors.New("could not connect")
|
||||||
ErrConnectionClosed = errors.New("connection closed")
|
ErrConnectionClosed = errors.New("connection closed")
|
||||||
ErrAuthFailed = errors.New("auth failed")
|
ErrAuthenticationFailed = errors.New("authentication failed")
|
||||||
ErrUnexpectedMessage = errors.New("unexpected message")
|
ErrUnexpectedMessage = errors.New("unexpected message")
|
||||||
ErrUnexpectedResponse = errors.New("unexpected response")
|
ErrUnexpectedResponse = errors.New("unexpected response")
|
||||||
)
|
)
|
||||||
|
|
1
go.mod
1
go.mod
|
@ -4,6 +4,7 @@ go 1.15
|
||||||
|
|
||||||
require (
|
require (
|
||||||
cdr.dev/slog v1.3.0
|
cdr.dev/slog v1.3.0
|
||||||
|
github.com/davecgh/go-spew v1.1.1
|
||||||
github.com/orcaman/concurrent-map v0.0.0-20190826125027-8c72a8bb44f6
|
github.com/orcaman/concurrent-map v0.0.0-20190826125027-8c72a8bb44f6
|
||||||
github.com/pkg/errors v0.9.1
|
github.com/pkg/errors v0.9.1
|
||||||
github.com/streamrail/concurrent-map v0.0.0-20160823150647-8bf1e9bacbf6
|
github.com/streamrail/concurrent-map v0.0.0-20160823150647-8bf1e9bacbf6
|
||||||
|
|
|
@ -0,0 +1,102 @@
|
||||||
|
package tunnel
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httputil"
|
||||||
|
"net/url"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
"gitlab.com/wpetit/goweb/logger"
|
||||||
|
)
|
||||||
|
|
||||||
|
type contextKey string
|
||||||
|
|
||||||
|
const remoteClientKey contextKey = "go-tunnel.remoteclient"
|
||||||
|
|
||||||
|
var (
|
||||||
|
ErrAbortProxy = errors.New("proxy aborted")
|
||||||
|
)
|
||||||
|
|
||||||
|
type MatchRequestFunc func(w http.ResponseWriter, r *http.Request) (*RemoteClient, error)
|
||||||
|
|
||||||
|
func ProxyHandler(targetURL string, match MatchRequestFunc, funcs ...ProxyConfigFunc) (http.Handler, error) {
|
||||||
|
conf := DefaultProxyConfig()
|
||||||
|
|
||||||
|
for _, fn := range funcs {
|
||||||
|
fn(conf)
|
||||||
|
}
|
||||||
|
|
||||||
|
target, err := url.Parse(targetURL)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
reverse := createReverseProxy(target)
|
||||||
|
|
||||||
|
if conf.ConfigureReverseProxy != nil {
|
||||||
|
if err := conf.ConfigureReverseProxy(reverse); err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
remoteClient, err := match(w, r)
|
||||||
|
if errors.Is(err, ErrAbortProxy) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(r.Context(), "could not match proxy request", logger.E(errors.WithStack(err)))
|
||||||
|
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if remoteClient == nil {
|
||||||
|
http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := context.WithValue(r.Context(), remoteClientKey, remoteClient)
|
||||||
|
r = r.WithContext(ctx)
|
||||||
|
|
||||||
|
reverse.ServeHTTP(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
return http.HandlerFunc(fn), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func createReverseProxy(target *url.URL) *httputil.ReverseProxy {
|
||||||
|
reverse := httputil.NewSingleHostReverseProxy(target)
|
||||||
|
|
||||||
|
// nolint: go-mnd
|
||||||
|
reverse.Transport = &http.Transport{
|
||||||
|
Proxy: http.ProxyFromEnvironment,
|
||||||
|
DialContext: func(ctx context.Context, network string, addr string) (net.Conn, error) {
|
||||||
|
remoteClient, ok := ctx.Value(remoteClientKey).(*RemoteClient)
|
||||||
|
if !ok {
|
||||||
|
return nil, errors.New("could not retrieve remote client")
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := remoteClient.Proxy(ctx, network, addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return conn, nil
|
||||||
|
},
|
||||||
|
ForceAttemptHTTP2: true,
|
||||||
|
MaxIdleConns: 100,
|
||||||
|
IdleConnTimeout: 90 * time.Second,
|
||||||
|
TLSHandshakeTimeout: 10 * time.Second,
|
||||||
|
ExpectContinueTimeout: 1 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
reverse.FlushInterval = 0
|
||||||
|
|
||||||
|
return reverse
|
||||||
|
}
|
|
@ -0,0 +1,21 @@
|
||||||
|
package tunnel
|
||||||
|
|
||||||
|
import "net/http/httputil"
|
||||||
|
|
||||||
|
type ConfigureReverseProxyFunc func(*httputil.ReverseProxy) error
|
||||||
|
|
||||||
|
type ProxyConfig struct {
|
||||||
|
ConfigureReverseProxy ConfigureReverseProxyFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultProxyConfig() *ProxyConfig {
|
||||||
|
return &ProxyConfig{}
|
||||||
|
}
|
||||||
|
|
||||||
|
type ProxyConfigFunc func(c *ProxyConfig)
|
||||||
|
|
||||||
|
func WithProxyConfigure(fn ConfigureReverseProxyFunc) ProxyConfigFunc {
|
||||||
|
return func(c *ProxyConfig) {
|
||||||
|
c.ConfigureReverseProxy = fn
|
||||||
|
}
|
||||||
|
}
|
|
@ -1,7 +1,7 @@
|
||||||
**/*.go {
|
**/*.go
|
||||||
|
modd.conf {
|
||||||
prep: make test
|
prep: make test
|
||||||
prep: make build
|
prep: make build
|
||||||
daemon: ./bin/server
|
daemon: ./bin/server -target-url http://127.0.0.1:3000
|
||||||
daemon: ./bin/client -id client1
|
daemon: ./bin/client -id client1
|
||||||
daemon: ./bin/client -id client2
|
|
||||||
}
|
}
|
|
@ -0,0 +1,14 @@
|
||||||
|
package tunnel
|
||||||
|
|
||||||
|
type authRequest struct {
|
||||||
|
Credentials interface{} `json:"c"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type authResponse struct {
|
||||||
|
Success bool `json:"b"`
|
||||||
|
}
|
||||||
|
|
||||||
|
type proxyRequest struct {
|
||||||
|
Network string `json:"n"`
|
||||||
|
Address string `json:"a"`
|
||||||
|
}
|
208
remote_client.go
208
remote_client.go
|
@ -2,13 +2,11 @@ package tunnel
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/json"
|
||||||
"net"
|
"net"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"forge.cadoles.com/wpetit/go-tunnel/control"
|
|
||||||
|
|
||||||
cmap "github.com/orcaman/concurrent-map"
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/xtaci/kcp-go/v5"
|
"github.com/xtaci/kcp-go/v5"
|
||||||
"github.com/xtaci/smux"
|
"github.com/xtaci/smux"
|
||||||
|
@ -19,61 +17,45 @@ type RemoteClient struct {
|
||||||
onClientAuthHook OnClientAuthHook
|
onClientAuthHook OnClientAuthHook
|
||||||
onClientConnectHook OnClientConnectHook
|
onClientConnectHook OnClientConnectHook
|
||||||
onClientDisconnectHook OnClientDisconnectHook
|
onClientDisconnectHook OnClientDisconnectHook
|
||||||
|
conn *kcp.UDPSession
|
||||||
sess *smux.Session
|
sess *smux.Session
|
||||||
control *control.Control
|
|
||||||
remoteAddr net.Addr
|
remoteAddr net.Addr
|
||||||
proxies cmap.ConcurrentMap
|
authenticationTimeout time.Duration
|
||||||
acceptStreamMutex sync.Mutex
|
proxyRequestTimeout time.Duration
|
||||||
|
connMutex sync.RWMutex
|
||||||
|
smuxConfig *smux.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error {
|
func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error {
|
||||||
config := smux.DefaultConfig()
|
c.connMutex.Lock()
|
||||||
config.Version = 2
|
defer c.connMutex.Unlock()
|
||||||
config.KeepAliveInterval = 10 * time.Second
|
|
||||||
config.KeepAliveTimeout = 2 * config.KeepAliveInterval
|
|
||||||
|
|
||||||
logger.Debug(ctx, "creating server session")
|
if err := c.Close(); err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
sess, err := smux.Server(conn, config)
|
sess, err := c.acceptSession(ctx, conn)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ctrl := control.New()
|
stream, err := sess.AcceptStream()
|
||||||
|
if err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := ctrl.Init(ctx, sess, true); err != nil {
|
defer stream.Close()
|
||||||
|
|
||||||
|
if err := c.authenticate(ctx, stream); err != nil {
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.sess = sess
|
c.sess = sess
|
||||||
c.remoteAddr = conn.RemoteAddr()
|
c.conn = conn
|
||||||
c.control = ctrl
|
|
||||||
|
|
||||||
if c.onClientConnectHook != nil {
|
|
||||||
if err := c.onClientConnectHook.OnClientConnect(ctx, c); err != nil {
|
|
||||||
return errors.WithStack(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RemoteClient) Listen(ctx context.Context) error {
|
|
||||||
defer func() {
|
|
||||||
if c.onClientDisconnectHook != nil {
|
|
||||||
if err := c.onClientDisconnectHook.OnClientDisconnect(ctx, c); err != nil {
|
|
||||||
logger.Error(ctx, "client disconnect hook error", logger.E(errors.WithStack(err)))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
logger.Debug(ctx, "listening for messages")
|
|
||||||
|
|
||||||
return c.control.Listen(ctx, control.Handlers{
|
|
||||||
control.TypeAuthRequest: c.handleAuthRequest,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *RemoteClient) ConfigureHooks(hooks interface{}) {
|
func (c *RemoteClient) ConfigureHooks(hooks interface{}) {
|
||||||
if hooks == nil {
|
if hooks == nil {
|
||||||
return
|
return
|
||||||
|
@ -96,74 +78,164 @@ func (c *RemoteClient) RemoteAddr() net.Addr {
|
||||||
return c.remoteAddr
|
return c.remoteAddr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RemoteClient) Close() {
|
func (c *RemoteClient) Close() error {
|
||||||
if c.sess != nil {
|
if c.sess != nil {
|
||||||
c.sess.Close()
|
if err := c.sess.Close(); err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if c.conn != nil {
|
||||||
|
if err := c.conn.Close(); err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
c.sess = nil
|
c.sess = nil
|
||||||
c.control = nil
|
c.conn = nil
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RemoteClient) SwitchConn(ctx context.Context, conn *kcp.UDPSession) error {
|
||||||
|
c.connMutex.Lock()
|
||||||
|
defer c.connMutex.Unlock()
|
||||||
|
|
||||||
|
if err := c.Close(); err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sess, err := c.acceptSession(ctx, conn)
|
||||||
|
if err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.sess = sess
|
||||||
|
c.conn = conn
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RemoteClient) Proxy(ctx context.Context, network, address string) (net.Conn, error) {
|
func (c *RemoteClient) Proxy(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address))
|
c.connMutex.RLock()
|
||||||
|
defer c.connMutex.RUnlock()
|
||||||
|
|
||||||
if err := c.control.ProxyReq(ctx, network, address); err != nil {
|
ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address))
|
||||||
return nil, errors.WithStack(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Debug(ctx, "opening proxy stream")
|
logger.Debug(ctx, "opening proxy stream")
|
||||||
|
|
||||||
c.acceptStreamMutex.Lock()
|
stream, err := c.sess.OpenStream()
|
||||||
|
|
||||||
stream, err := c.sess.AcceptStream()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
c.acceptStreamMutex.Unlock()
|
|
||||||
return nil, errors.WithStack(err)
|
return nil, errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
c.acceptStreamMutex.Unlock()
|
proxyReq := &proxyRequest{
|
||||||
|
Network: network,
|
||||||
|
Address: address,
|
||||||
|
}
|
||||||
|
encoder := json.NewEncoder(stream)
|
||||||
|
|
||||||
go func() {
|
writeDeadline := time.Now().Add(c.proxyRequestTimeout)
|
||||||
<-ctx.Done()
|
logger.Debug(ctx, "sending proxy req", logger.F("deadline", writeDeadline))
|
||||||
logger.Debug(ctx, "closing proxy stream")
|
|
||||||
|
if err := stream.SetWriteDeadline(writeDeadline); err != nil {
|
||||||
stream.Close()
|
stream.Close()
|
||||||
}()
|
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := encoder.Encode(proxyReq); err != nil {
|
||||||
|
stream.Close()
|
||||||
|
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := stream.SetWriteDeadline(time.Time{}); err != nil {
|
||||||
|
stream.Close()
|
||||||
|
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
return stream, nil
|
return stream, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *RemoteClient) handleAuthRequest(ctx context.Context, m *control.Message) (*control.Message, error) {
|
func (c *RemoteClient) acceptSession(ctx context.Context, conn *kcp.UDPSession) (*smux.Session, error) {
|
||||||
authReqPayload, ok := m.Payload.(*control.AuthRequestPayload)
|
logger.Debug(ctx, "accepting client session")
|
||||||
if !ok {
|
|
||||||
return nil, errors.WithStack(ErrUnexpectedMessage)
|
sess, err := smux.Server(conn, c.smuxConfig)
|
||||||
|
if err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug(ctx, "handling auth request", logger.F("credentials", authReqPayload.Credentials))
|
c.remoteAddr = conn.RemoteAddr()
|
||||||
|
|
||||||
|
if c.onClientConnectHook != nil {
|
||||||
|
if err := c.onClientConnectHook.OnClientConnect(ctx, c); err != nil {
|
||||||
|
return nil, errors.WithStack(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sess, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *RemoteClient) authenticate(ctx context.Context, stream *smux.Stream) error {
|
||||||
|
start := time.Now()
|
||||||
|
|
||||||
|
readDeadline := time.Now().Add(c.authenticationTimeout)
|
||||||
|
logger.Debug(ctx, "waiting for auth request", logger.F("deadline", readDeadline))
|
||||||
|
|
||||||
|
if err := stream.SetReadDeadline(readDeadline); err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
decoder := json.NewDecoder(stream)
|
||||||
|
authReq := &authRequest{}
|
||||||
|
|
||||||
|
if err := decoder.Decode(authReq); err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
var (
|
var (
|
||||||
success bool
|
success bool
|
||||||
err error
|
err error
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger.Debug(ctx, "received client credentials", logger.F("credentials", authReq.Credentials))
|
||||||
|
|
||||||
if c.onClientAuthHook != nil {
|
if c.onClientAuthHook != nil {
|
||||||
success, err = c.onClientAuthHook.OnClientAuth(ctx, c, authReqPayload.Credentials)
|
success, err = c.onClientAuthHook.OnClientAuth(ctx, c, authReq.Credentials)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.Debug(ctx, "auth succeeded", logger.F("credentials", authReqPayload.Credentials))
|
authRes := &authResponse{
|
||||||
|
|
||||||
res := control.NewMessage(control.TypeAuthResponse, &control.AuthResponsePayload{
|
|
||||||
Success: success,
|
Success: success,
|
||||||
})
|
}
|
||||||
|
encoder := json.NewEncoder(stream)
|
||||||
|
|
||||||
return res, nil
|
writeDeadline := time.Now().Add(c.authenticationTimeout - time.Since(start))
|
||||||
|
logger.Debug(ctx, "sending auth response", logger.F("deadline", writeDeadline))
|
||||||
|
|
||||||
|
if err := stream.SetWriteDeadline(writeDeadline); err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := encoder.Encode(authRes); err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !success {
|
||||||
|
return errors.WithStack(ErrAuthenticationFailed)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewRemoteClient() *RemoteClient {
|
func NewRemoteClient(smuxConfig *smux.Config, authenticationTimeout, proxyRequestTimeout time.Duration) *RemoteClient {
|
||||||
return &RemoteClient{
|
return &RemoteClient{
|
||||||
proxies: cmap.New(),
|
smuxConfig: smuxConfig,
|
||||||
|
authenticationTimeout: authenticationTimeout,
|
||||||
|
proxyRequestTimeout: proxyRequestTimeout,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
41
server.go
41
server.go
|
@ -3,8 +3,8 @@ package tunnel
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
|
||||||
cmap "github.com/orcaman/concurrent-map"
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
|
cmap "github.com/streamrail/concurrent-map"
|
||||||
"github.com/xtaci/kcp-go/v5"
|
"github.com/xtaci/kcp-go/v5"
|
||||||
"gitlab.com/wpetit/goweb/logger"
|
"gitlab.com/wpetit/goweb/logger"
|
||||||
)
|
)
|
||||||
|
@ -23,6 +23,14 @@ func (s *Server) Listen(ctx context.Context) error {
|
||||||
return errors.WithStack(err)
|
return errors.WithStack(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.conf.ConfigureListener != nil {
|
||||||
|
if err := s.conf.ConfigureListener(listener); err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Debug(ctx, "accepting connections", logger.F("address", s.conf.Address))
|
||||||
|
|
||||||
for {
|
for {
|
||||||
conn, err := listener.AcceptKCP()
|
conn, err := listener.AcceptKCP()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
@ -34,12 +42,31 @@ func (s *Server) Listen(ctx context.Context) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) handleNewConn(ctx context.Context, conn *kcp.UDPSession) {
|
func (s *Server) handleNewConn(ctx context.Context, conn *kcp.UDPSession) {
|
||||||
ctx = logger.With(ctx, logger.F("remoteAddr", conn.RemoteAddr().String()))
|
var remoteClient *RemoteClient
|
||||||
|
|
||||||
remoteClient := NewRemoteClient()
|
remoteAddr := conn.RemoteAddr().String()
|
||||||
|
ctx = logger.With(ctx, logger.F("remoteAddr", remoteAddr))
|
||||||
|
|
||||||
defer remoteClient.Close()
|
rawExistingClient, exists := s.clients.Get(remoteAddr)
|
||||||
defer conn.Close()
|
if exists {
|
||||||
|
logger.Debug(ctx, "remote client already exists")
|
||||||
|
|
||||||
|
remoteClient, _ = rawExistingClient.(*RemoteClient)
|
||||||
|
|
||||||
|
if err := remoteClient.SwitchConn(ctx, conn); err != nil {
|
||||||
|
logger.Error(ctx, "remote client error", logger.E(errors.WithStack(err)))
|
||||||
|
|
||||||
|
s.clients.Remove(remoteAddr)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
remoteClient = NewRemoteClient(
|
||||||
|
s.conf.SmuxConfig,
|
||||||
|
s.conf.AuthenticationTimeout,
|
||||||
|
s.conf.ProxyRequestTimeout,
|
||||||
|
)
|
||||||
|
|
||||||
remoteClient.ConfigureHooks(s.conf.Hooks)
|
remoteClient.ConfigureHooks(s.conf.Hooks)
|
||||||
|
|
||||||
|
@ -49,9 +76,7 @@ func (s *Server) handleNewConn(ctx context.Context, conn *kcp.UDPSession) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := remoteClient.Listen(ctx); err != nil {
|
s.clients.Set(remoteAddr, remoteClient)
|
||||||
logger.Error(ctx, "remote client error", logger.E(errors.WithStack(err)))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer(funcs ...ServerConfigFunc) *Server {
|
func NewServer(funcs ...ServerConfigFunc) *Server {
|
||||||
|
|
|
@ -2,13 +2,16 @@ package tunnel
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha1"
|
"crypto/sha1"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/pkg/errors"
|
"github.com/pkg/errors"
|
||||||
"github.com/xtaci/kcp-go/v5"
|
"github.com/xtaci/kcp-go/v5"
|
||||||
|
"github.com/xtaci/smux"
|
||||||
"golang.org/x/crypto/pbkdf2"
|
"golang.org/x/crypto/pbkdf2"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ConfigureConnFunc func(conn *kcp.UDPSession) error
|
type ConfigureConnFunc func(conn *kcp.UDPSession) error
|
||||||
|
type ConfigureListenerFunc func(listener *kcp.Listener) error
|
||||||
|
|
||||||
type ServerConfig struct {
|
type ServerConfig struct {
|
||||||
Address string
|
Address string
|
||||||
|
@ -17,14 +20,25 @@ type ServerConfig struct {
|
||||||
ParityShards int
|
ParityShards int
|
||||||
Hooks *ServerHooks
|
Hooks *ServerHooks
|
||||||
ConfigureConn ConfigureConnFunc
|
ConfigureConn ConfigureConnFunc
|
||||||
|
ConfigureListener ConfigureListenerFunc
|
||||||
|
AuthenticationTimeout time.Duration
|
||||||
|
ProxyRequestTimeout time.Duration
|
||||||
|
SmuxConfig *smux.Config
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// nolint: go-mnd
|
||||||
func DefaultServerConfig() *ServerConfig {
|
func DefaultServerConfig() *ServerConfig {
|
||||||
unencryptedBlock, err := kcp.NewNoneBlockCrypt(nil)
|
unencryptedBlock, err := kcp.NewNoneBlockCrypt(nil)
|
||||||
if err != nil { // should never happen
|
if err != nil { // should never happen
|
||||||
panic(errors.WithStack(err))
|
panic(errors.WithStack(err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
smuxConfig := smux.DefaultConfig()
|
||||||
|
smuxConfig.Version = 2
|
||||||
|
smuxConfig.KeepAliveInterval = 10 * time.Second
|
||||||
|
smuxConfig.MaxReceiveBuffer = 4194304
|
||||||
|
smuxConfig.MaxStreamBuffer = 2097152
|
||||||
|
|
||||||
return &ServerConfig{
|
return &ServerConfig{
|
||||||
Address: ":36543",
|
Address: ":36543",
|
||||||
BlockCrypt: unencryptedBlock,
|
BlockCrypt: unencryptedBlock,
|
||||||
|
@ -35,6 +49,11 @@ func DefaultServerConfig() *ServerConfig {
|
||||||
onClientDisconnect: DefaultOnClientDisconnect,
|
onClientDisconnect: DefaultOnClientDisconnect,
|
||||||
onClientAuth: DefaultOnClientAuth,
|
onClientAuth: DefaultOnClientAuth,
|
||||||
},
|
},
|
||||||
|
ConfigureConn: DefaultServerConfigureConn,
|
||||||
|
ConfigureListener: DefaultServerConfigureListener,
|
||||||
|
AuthenticationTimeout: 30 * time.Second,
|
||||||
|
ProxyRequestTimeout: 5 * time.Second,
|
||||||
|
SmuxConfig: smuxConfig,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -82,3 +101,50 @@ func WithServerConfigureConn(fn ConfigureConnFunc) ServerConfigFunc {
|
||||||
conf.ConfigureConn = fn
|
conf.ConfigureConn = fn
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func WithServerConfigureListener(fn ConfigureListenerFunc) ServerConfigFunc {
|
||||||
|
return func(conf *ServerConfig) {
|
||||||
|
conf.ConfigureListener = fn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func WithServerSmuxConfig(c *smux.Config) ServerConfigFunc {
|
||||||
|
return func(conf *ServerConfig) {
|
||||||
|
conf.SmuxConfig = c
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// nolint: go-mnd
|
||||||
|
func DefaultServerConfigureConn(conn *kcp.UDPSession) error {
|
||||||
|
// Based on kcptun default configuration, mode 'fast3'
|
||||||
|
conn.SetStreamMode(true)
|
||||||
|
conn.SetWriteDelay(false)
|
||||||
|
conn.SetNoDelay(1, 10, 2, 1)
|
||||||
|
conn.SetWindowSize(128, 512)
|
||||||
|
conn.SetMtu(1400)
|
||||||
|
conn.SetACKNoDelay(true)
|
||||||
|
|
||||||
|
if err := conn.SetDSCP(46); err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// nolint: go-mnd
|
||||||
|
func DefaultServerConfigureListener(listener *kcp.Listener) error {
|
||||||
|
// Based on kcptun default configuration, mode 'fast3'
|
||||||
|
if err := listener.SetReadBuffer(16777217); err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := listener.SetWriteBuffer(16777217); err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := listener.SetDSCP(46); err != nil {
|
||||||
|
return errors.WithStack(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
Loading…
Reference in New Issue