feat: general protocol rewrite
This commit is contained in:
parent
536100da90
commit
96c1575a0b
4
Makefile
4
Makefile
@ -4,10 +4,10 @@ watch:
|
||||
build: build-server build-client
|
||||
|
||||
build-server:
|
||||
go build -o ./bin/server ./cmd/server
|
||||
CGO_ENABLED=0 go build -o ./bin/server ./cmd/server
|
||||
|
||||
build-client:
|
||||
go build -o ./bin/client ./cmd/client
|
||||
CGO_ENABLED=0 go build -o ./bin/client ./cmd/client
|
||||
|
||||
test:
|
||||
go test -v -race ./...
|
172
client.go
172
client.go
@ -2,30 +2,27 @@ package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"gitlab.com/wpetit/goweb/logger"
|
||||
|
||||
"forge.cadoles.com/wpetit/go-tunnel/control"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/xtaci/kcp-go/v5"
|
||||
"github.com/xtaci/smux"
|
||||
)
|
||||
|
||||
type Client struct {
|
||||
conf *ClientConfig
|
||||
conn *kcp.UDPSession
|
||||
sess *smux.Session
|
||||
control *control.Control
|
||||
http *http.Client
|
||||
openStreamMutex sync.Mutex
|
||||
conf *ClientConfig
|
||||
conn *kcp.UDPSession
|
||||
sess *smux.Session
|
||||
}
|
||||
|
||||
func (c *Client) Connect(ctx context.Context) error {
|
||||
logger.Debug(ctx, "connecting", logger.F("serverAddr", c.conf.ServerAddress))
|
||||
|
||||
conn, err := kcp.DialWithOptions(
|
||||
c.conf.ServerAddress, c.conf.BlockCrypt,
|
||||
c.conf.DataShards, c.conf.ParityShards,
|
||||
@ -40,34 +37,29 @@ func (c *Client) Connect(ctx context.Context) error {
|
||||
}
|
||||
}
|
||||
|
||||
config := smux.DefaultConfig()
|
||||
config.Version = 2
|
||||
config.KeepAliveInterval = 10 * time.Second
|
||||
config.KeepAliveTimeout = 2 * config.KeepAliveInterval
|
||||
|
||||
sess, err := smux.Client(conn, config)
|
||||
sess, err := smux.Client(conn, c.conf.SmuxConfig)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
control := control.New()
|
||||
if err := control.Init(ctx, sess, false); err != nil {
|
||||
stream, err := sess.OpenStream()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "sending auth request")
|
||||
defer stream.Close()
|
||||
|
||||
success, err := control.AuthRequest(c.conf.Credentials)
|
||||
success, err := c.authenticate(ctx, stream)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
defer c.Close()
|
||||
return errors.WithStack(ErrAuthFailed)
|
||||
return errors.WithStack(ErrAuthenticationFailed)
|
||||
}
|
||||
|
||||
c.control = control
|
||||
logger.Debug(ctx, "authentication success")
|
||||
|
||||
c.conn = conn
|
||||
c.sess = sess
|
||||
|
||||
@ -75,87 +67,138 @@ func (c *Client) Connect(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (c *Client) Listen(ctx context.Context) error {
|
||||
logger.Debug(ctx, "listening for messages")
|
||||
logger.Debug(ctx, "listening for proxy requests")
|
||||
|
||||
ctx, cancel := context.WithCancel(ctx)
|
||||
defer cancel()
|
||||
for {
|
||||
stream, err := c.sess.AcceptStream()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
err := c.control.Listen(ctx, control.Handlers{
|
||||
control.TypeProxyRequest: c.handleProxyRequest,
|
||||
})
|
||||
subCtx := logger.With(ctx,
|
||||
logger.F("remoteAddr", stream.RemoteAddr()),
|
||||
logger.F("localAddr", stream.LocalAddr()),
|
||||
)
|
||||
|
||||
if errors.Is(err, io.ErrClosedPipe) {
|
||||
logger.Debug(ctx, "client connection closed")
|
||||
readDeadline := time.Now().Add(c.conf.ProxyRequestTimeout)
|
||||
logger.Debug(subCtx, "waiting for proxy request", logger.F("deadline", readDeadline))
|
||||
|
||||
return errors.WithStack(ErrConnectionClosed)
|
||||
if err := stream.SetReadDeadline(readDeadline); err != nil {
|
||||
stream.Close()
|
||||
logger.Error(subCtx, "could not set read deadline", logger.E(errors.WithStack(err)))
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(stream)
|
||||
proxyReq := &proxyRequest{}
|
||||
|
||||
if err := decoder.Decode(proxyReq); err != nil {
|
||||
stream.Close()
|
||||
logger.Error(subCtx, "could not decode proxy request", logger.E(errors.WithStack(err)))
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
if err := stream.SetReadDeadline(time.Time{}); err != nil {
|
||||
stream.Close()
|
||||
logger.Error(subCtx, "could not set read deadline", logger.E(errors.WithStack(err)))
|
||||
|
||||
continue
|
||||
}
|
||||
|
||||
go c.handleProxyStream(subCtx, stream, proxyReq.Network, proxyReq.Address)
|
||||
}
|
||||
|
||||
return err
|
||||
}
|
||||
|
||||
func (c *Client) Close() error {
|
||||
if c.conn == nil {
|
||||
return errors.WithStack(ErrNotConnected)
|
||||
if c.sess != nil && !c.sess.IsClosed() {
|
||||
if err := c.sess.Close(); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := c.conn.Close(); err != nil {
|
||||
return errors.WithStack(err)
|
||||
if c.conn != nil {
|
||||
if err := c.conn.Close(); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
c.conn = nil
|
||||
c.sess = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Client) handleProxyRequest(ctx context.Context, m *control.Message) (*control.Message, error) {
|
||||
proxyReqPayload, ok := m.Payload.(*control.ProxyRequestPayload)
|
||||
if !ok {
|
||||
return nil, errors.WithStack(ErrUnexpectedMessage)
|
||||
func (c *Client) authenticate(ctx context.Context, stream *smux.Stream) (bool, error) {
|
||||
encoder := json.NewEncoder(stream)
|
||||
authReq := &authRequest{
|
||||
Credentials: c.conf.Credentials,
|
||||
}
|
||||
|
||||
ctx = logger.With(ctx,
|
||||
logger.F("network", proxyReqPayload.Network),
|
||||
logger.F("address", proxyReqPayload.Address),
|
||||
)
|
||||
start := time.Now()
|
||||
writeDeadline := start.Add(c.conf.AuthenticationTimeout)
|
||||
logger.Debug(ctx, "sending auth request", logger.F("deadline", writeDeadline))
|
||||
|
||||
logger.Debug(ctx, "handling proxy request")
|
||||
|
||||
out, err := net.Dial(proxyReqPayload.Network, proxyReqPayload.Address)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
if err := stream.SetWriteDeadline(writeDeadline); err != nil {
|
||||
return false, errors.WithStack(err)
|
||||
}
|
||||
|
||||
go c.handleProxyStream(ctx, out)
|
||||
if err := encoder.Encode(authReq); err != nil {
|
||||
return false, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
decoder := json.NewDecoder(stream)
|
||||
authRes := &authResponse{}
|
||||
|
||||
readDeadline := time.Now().Add(c.conf.AuthenticationTimeout - time.Now().Sub(start))
|
||||
logger.Debug(ctx, "waiting for auth response", logger.F("deadline", readDeadline))
|
||||
|
||||
if err := stream.SetReadDeadline(readDeadline); err != nil {
|
||||
return false, errors.WithStack(err)
|
||||
}
|
||||
|
||||
if err := decoder.Decode(authRes); err != nil && !errors.Is(err, io.EOF) {
|
||||
return false, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return authRes.Success, nil
|
||||
}
|
||||
|
||||
func (c *Client) handleProxyStream(ctx context.Context, out net.Conn) {
|
||||
c.openStreamMutex.Lock()
|
||||
func (c *Client) handleProxyStream(ctx context.Context, in *smux.Stream, network, address string) {
|
||||
defer func(start time.Time) {
|
||||
logger.Debug(ctx, "handleProxyStream duration", logger.F("duration", time.Since(start)))
|
||||
}(time.Now())
|
||||
|
||||
in, err := c.sess.OpenStream()
|
||||
defer in.Close()
|
||||
|
||||
logger.Debug(
|
||||
ctx, "proxying",
|
||||
logger.F("network", network),
|
||||
logger.F("address", address),
|
||||
)
|
||||
|
||||
out, err := net.Dial(network, address)
|
||||
if err != nil {
|
||||
c.openStreamMutex.Unlock()
|
||||
logger.Error(ctx, "error while accepting proxy stream", logger.E(err))
|
||||
logger.Error(ctx, "could not dial", logger.E(errors.WithStack(err)))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
c.openStreamMutex.Unlock()
|
||||
defer out.Close()
|
||||
|
||||
streamCopy := func(dst io.Writer, src io.ReadCloser) {
|
||||
if _, err := Copy(dst, src); err != nil {
|
||||
if errors.Is(err, smux.ErrInvalidProtocol) {
|
||||
logger.Error(ctx, "error while proxying", logger.E(errors.WithStack(err)))
|
||||
logger.Error(ctx, "could not proxy", logger.E(errors.WithStack(err)))
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "closing proxy stream")
|
||||
|
||||
in.Close()
|
||||
out.Close()
|
||||
}
|
||||
|
||||
go streamCopy(in, out)
|
||||
streamCopy(out, in)
|
||||
go streamCopy(out, in)
|
||||
streamCopy(in, out)
|
||||
}
|
||||
|
||||
func NewClient(funcs ...ClientConfigFunc) *Client {
|
||||
@ -167,6 +210,5 @@ func NewClient(funcs ...ClientConfigFunc) *Client {
|
||||
|
||||
return &Client{
|
||||
conf: conf,
|
||||
http: &http.Client{},
|
||||
}
|
||||
}
|
||||
|
@ -2,33 +2,49 @@ package tunnel
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/xtaci/kcp-go/v5"
|
||||
"github.com/xtaci/smux"
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
)
|
||||
|
||||
type ClientConfig struct {
|
||||
ServerAddress string
|
||||
BlockCrypt kcp.BlockCrypt
|
||||
DataShards int
|
||||
ParityShards int
|
||||
Credentials interface{}
|
||||
ConfigureConn ConfigureConnFunc
|
||||
ServerAddress string
|
||||
BlockCrypt kcp.BlockCrypt
|
||||
DataShards int
|
||||
ParityShards int
|
||||
Credentials interface{}
|
||||
ConfigureConn ConfigureConnFunc
|
||||
AuthenticationTimeout time.Duration
|
||||
ProxyRequestTimeout time.Duration
|
||||
SmuxConfig *smux.Config
|
||||
}
|
||||
|
||||
// nolint: go-mnd
|
||||
func DefaultClientConfig() *ClientConfig {
|
||||
unencryptedBlock, err := kcp.NewNoneBlockCrypt(nil)
|
||||
if err != nil { // should never happen
|
||||
panic(errors.WithStack(err))
|
||||
}
|
||||
|
||||
smuxConfig := smux.DefaultConfig()
|
||||
smuxConfig.Version = 2
|
||||
smuxConfig.KeepAliveInterval = 10 * time.Second
|
||||
smuxConfig.MaxReceiveBuffer = 4194304
|
||||
smuxConfig.MaxStreamBuffer = 2097152
|
||||
|
||||
return &ClientConfig{
|
||||
ServerAddress: "127.0.0.1:36543",
|
||||
BlockCrypt: unencryptedBlock,
|
||||
DataShards: 3,
|
||||
ParityShards: 10,
|
||||
Credentials: nil,
|
||||
ServerAddress: "127.0.0.1:36543",
|
||||
BlockCrypt: unencryptedBlock,
|
||||
DataShards: 3,
|
||||
ParityShards: 10,
|
||||
Credentials: nil,
|
||||
ConfigureConn: DefaultClientConfigureConn,
|
||||
AuthenticationTimeout: 30 * time.Second,
|
||||
ProxyRequestTimeout: 5 * time.Second,
|
||||
SmuxConfig: smuxConfig,
|
||||
}
|
||||
}
|
||||
|
||||
@ -64,3 +80,34 @@ func WithClientConfigureConn(fn ConfigureConnFunc) ClientConfigFunc {
|
||||
conf.ConfigureConn = fn
|
||||
}
|
||||
}
|
||||
|
||||
func WithClientSmuxConfig(c *smux.Config) ClientConfigFunc {
|
||||
return func(conf *ClientConfig) {
|
||||
conf.SmuxConfig = c
|
||||
}
|
||||
}
|
||||
|
||||
// nolint: go-mnd
|
||||
func DefaultClientConfigureConn(conn *kcp.UDPSession) error {
|
||||
// Based on kcptun default configuration, mode 'fast3'
|
||||
conn.SetStreamMode(true)
|
||||
conn.SetWriteDelay(false)
|
||||
conn.SetNoDelay(1, 10, 2, 1)
|
||||
conn.SetWindowSize(128, 512)
|
||||
conn.SetMtu(1400)
|
||||
conn.SetACKNoDelay(true)
|
||||
|
||||
if err := conn.SetReadBuffer(16777217); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
if err := conn.SetWriteBuffer(16777217); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
if err := conn.SetDSCP(46); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -3,6 +3,7 @@ package main
|
||||
import (
|
||||
"context"
|
||||
"flag"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"time"
|
||||
|
||||
@ -11,15 +12,18 @@ import (
|
||||
"gitlab.com/wpetit/goweb/logger"
|
||||
)
|
||||
|
||||
const sharedKey = "go-tunnel"
|
||||
const salt = "go-tunnel"
|
||||
|
||||
func main() {
|
||||
var (
|
||||
clientID string
|
||||
clientID = fmt.Sprintf("client-%d", time.Now().Unix())
|
||||
serverAddr = "127.0.0.1:36543"
|
||||
sharedKey = "go-tunnel"
|
||||
)
|
||||
|
||||
flag.StringVar(&clientID, "id", "", "Client ID")
|
||||
flag.StringVar(&sharedKey, "shared-key", sharedKey, "shared key")
|
||||
flag.StringVar(&clientID, "id", clientID, "Client ID")
|
||||
flag.StringVar(&serverAddr, "server-addr", serverAddr, "server address")
|
||||
flag.Parse()
|
||||
|
||||
ctx := context.Background()
|
||||
@ -28,12 +32,12 @@ func main() {
|
||||
logger.SetLevel(slog.LevelDebug)
|
||||
|
||||
client := tunnel.NewClient(
|
||||
tunnel.WithClientServerAddress(serverAddr),
|
||||
tunnel.WithClientCredentials(clientID),
|
||||
tunnel.WithClientAESBlockCrypt(sharedKey, salt),
|
||||
)
|
||||
defer client.Close()
|
||||
|
||||
initialBackoff := time.Second * 10
|
||||
initialBackoff := time.Second * 2
|
||||
backoff := initialBackoff
|
||||
|
||||
sleep := func() {
|
||||
|
@ -2,12 +2,9 @@ package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"flag"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"cdr.dev/slog"
|
||||
"forge.cadoles.com/wpetit/go-tunnel"
|
||||
@ -15,20 +12,61 @@ import (
|
||||
"gitlab.com/wpetit/goweb/logger"
|
||||
)
|
||||
|
||||
const sharedKey = "go-tunnel"
|
||||
const salt = "go-tunnel"
|
||||
|
||||
var registry = NewRegistry()
|
||||
|
||||
func main() {
|
||||
var (
|
||||
serverAddr = ":36543"
|
||||
httpAddr = ":3003"
|
||||
sharedKey = "go-tunnel"
|
||||
targetURL = "https://arcad.games"
|
||||
)
|
||||
|
||||
flag.StringVar(&serverAddr, "server-addr", serverAddr, "server address")
|
||||
flag.StringVar(&targetURL, "target-url", targetURL, "target url")
|
||||
flag.StringVar(&httpAddr, "http-addr", httpAddr, "http server address")
|
||||
flag.StringVar(&sharedKey, "shared-key", sharedKey, "shared key")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
logger.SetLevel(slog.LevelDebug)
|
||||
|
||||
server := tunnel.NewServer(
|
||||
tunnel.WithServerAESBlockCrypt(sharedKey, salt),
|
||||
tunnel.WithServerAddress(serverAddr),
|
||||
tunnel.WithServerOnClientAuth(registry.OnClientAuth),
|
||||
tunnel.WithServerOnClientDisconnect(registry.OnClientDisconnect),
|
||||
tunnel.WithServerOnClientAuth(func(ctx context.Context, remoteClient *tunnel.RemoteClient, credentials interface{}) (bool, error) {
|
||||
remoteAddr := remoteClient.RemoteAddr().String()
|
||||
|
||||
ctx = logger.With(ctx, logger.F("remoteAddr", remoteAddr))
|
||||
|
||||
logger.Debug(ctx, "new client auth")
|
||||
|
||||
clientID, ok := credentials.(string)
|
||||
if !ok {
|
||||
logger.Debug(ctx, "client auth failed")
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
registry.Add(clientID, remoteAddr, remoteClient)
|
||||
|
||||
logger.Debug(ctx, "client auth success")
|
||||
|
||||
return true, nil
|
||||
}),
|
||||
tunnel.WithServerOnClientDisconnect(func(ctx context.Context, remoteClient *tunnel.RemoteClient) error {
|
||||
remoteAddr := remoteClient.RemoteAddr().String()
|
||||
ctx = logger.With(ctx, logger.F("remoteAddr", remoteAddr))
|
||||
logger.Debug(ctx, "client disconnected")
|
||||
registry.RemoveByRemoteAddr(remoteAddr)
|
||||
|
||||
return nil
|
||||
}),
|
||||
)
|
||||
|
||||
go func() {
|
||||
@ -37,51 +75,28 @@ func main() {
|
||||
}
|
||||
}()
|
||||
|
||||
if err := http.ListenAndServe(":3003", http.HandlerFunc(handleRequest)); err != nil {
|
||||
handler, err := createProxyHandler(targetURL)
|
||||
if err != nil {
|
||||
logger.Fatal(ctx, "could not create proxy handler", logger.E(errors.WithStack(err)))
|
||||
}
|
||||
|
||||
if err := http.ListenAndServe(httpAddr, handler); err != nil {
|
||||
logger.Fatal(ctx, "error while listening", logger.E(err))
|
||||
}
|
||||
}
|
||||
|
||||
func handleRequest(w http.ResponseWriter, r *http.Request) {
|
||||
subdomains := strings.SplitN(r.Host, ".", 2)
|
||||
func createProxyHandler(targetURL string) (http.Handler, error) {
|
||||
return tunnel.ProxyHandler(targetURL, func(w http.ResponseWriter, r *http.Request) (*tunnel.RemoteClient, error) {
|
||||
subdomains := strings.SplitN(r.Host, ".", 2)
|
||||
|
||||
if len(subdomains) < 2 {
|
||||
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
|
||||
if len(subdomains) < 2 {
|
||||
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
|
||||
|
||||
return
|
||||
}
|
||||
return nil, tunnel.ErrAbortProxy
|
||||
}
|
||||
|
||||
clientID := subdomains[0]
|
||||
remoteClient := registry.Get(clientID)
|
||||
clientID := subdomains[0]
|
||||
|
||||
if remoteClient == nil {
|
||||
http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
target, err := url.Parse("https://arcad.games")
|
||||
if err != nil {
|
||||
logger.Fatal(r.Context(), "could not parse url", logger.E(err))
|
||||
}
|
||||
|
||||
reverse := httputil.NewSingleHostReverseProxy(target)
|
||||
reverse.Transport = &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: func(ctx context.Context, network string, addr string) (net.Conn, error) {
|
||||
conn, err := remoteClient.Proxy(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
},
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
reverse.ServeHTTP(w, r)
|
||||
return registry.Get(clientID), nil
|
||||
})
|
||||
}
|
||||
|
@ -1,9 +0,0 @@
|
||||
package control
|
||||
|
||||
type AuthRequestPayload struct {
|
||||
Credentials interface{} `json:"c"`
|
||||
}
|
||||
|
||||
type AuthResponsePayload struct {
|
||||
Success bool `json:"s"`
|
||||
}
|
@ -1,194 +0,0 @@
|
||||
package control
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/xtaci/smux"
|
||||
"gitlab.com/wpetit/goweb/logger"
|
||||
)
|
||||
|
||||
type Control struct {
|
||||
encoder *json.Encoder
|
||||
decoder *json.Decoder
|
||||
stream *smux.Stream
|
||||
}
|
||||
|
||||
func (c *Control) Init(ctx context.Context, sess *smux.Session, serverMode bool) error {
|
||||
config := smux.DefaultConfig()
|
||||
config.Version = 2
|
||||
|
||||
logger.Debug(ctx, "creating control stream")
|
||||
|
||||
var (
|
||||
stream *smux.Stream
|
||||
err error
|
||||
)
|
||||
|
||||
if serverMode {
|
||||
stream, err = sess.AcceptStream()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
} else {
|
||||
stream, err = sess.OpenStream()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
c.stream = stream
|
||||
c.decoder = json.NewDecoder(stream)
|
||||
c.encoder = json.NewEncoder(stream)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Control) AuthRequest(credentials interface{}) (bool, error) {
|
||||
req := NewMessage(TypeAuthRequest, &AuthRequestPayload{
|
||||
Credentials: credentials,
|
||||
})
|
||||
|
||||
res := NewMessage(TypeAuthResponse, nil)
|
||||
|
||||
if err := c.reqRes(req, res); err != nil {
|
||||
return false, errors.WithStack(err)
|
||||
}
|
||||
|
||||
authResPayload, ok := res.Payload.(*AuthResponsePayload)
|
||||
if !ok {
|
||||
return false, errors.WithStack(ErrUnexpectedMessage)
|
||||
}
|
||||
|
||||
return authResPayload.Success, nil
|
||||
}
|
||||
|
||||
func (c *Control) ProxyReq(ctx context.Context, network, address string) error {
|
||||
req := NewMessage(TypeProxyRequest, &ProxyRequestPayload{
|
||||
Network: network,
|
||||
Address: address,
|
||||
})
|
||||
|
||||
if err := c.Write(req); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Control) Listen(ctx context.Context, handlers Handlers) error {
|
||||
errChan := make(chan error)
|
||||
msgChan := make(chan *Message)
|
||||
dieChan := c.stream.GetDieCh()
|
||||
|
||||
go func(msgChan chan *Message, errChan chan error) {
|
||||
for {
|
||||
logger.Debug(ctx, "reading next message")
|
||||
|
||||
msg, err := c.Read()
|
||||
if err != nil {
|
||||
errChan <- errors.WithStack(err)
|
||||
|
||||
close(errChan)
|
||||
close(msgChan)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
msgChan <- msg
|
||||
}
|
||||
}(msgChan, errChan)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil
|
||||
|
||||
case <-dieChan:
|
||||
return errors.WithStack(ErrStreamClosed)
|
||||
|
||||
case err := <-errChan:
|
||||
return errors.WithStack(err)
|
||||
|
||||
case msg := <-msgChan:
|
||||
go func() {
|
||||
subCtx := logger.With(ctx, logger.F("messageType", msg.Type))
|
||||
|
||||
handler, exists := handlers[msg.Type]
|
||||
if !exists {
|
||||
logger.Error(subCtx, "no message handler registered")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
res, err := handler(subCtx, msg)
|
||||
if err != nil {
|
||||
logger.Error(subCtx, "error while handling message", logger.E(err))
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if res == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if err := c.Write(res); err != nil {
|
||||
logger.Error(subCtx, "error while write message response", logger.E(err))
|
||||
|
||||
return
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Control) Read() (*Message, error) {
|
||||
message := &Message{}
|
||||
|
||||
if err := c.read(message); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return message, nil
|
||||
}
|
||||
|
||||
func (c *Control) Write(m *Message) error {
|
||||
if err := c.write(m); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Control) reqRes(req *Message, res *Message) error {
|
||||
if err := c.write(req); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
if err := c.read(res); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Control) read(m *Message) error {
|
||||
if err := c.decoder.Decode(m); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Control) write(m *Message) error {
|
||||
if err := c.encoder.Encode(m); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func New() *Control {
|
||||
return &Control{}
|
||||
}
|
@ -1,8 +0,0 @@
|
||||
package control
|
||||
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrStreamClosed = errors.New("stream closed")
|
||||
ErrUnexpectedMessage = errors.New("unexpected message")
|
||||
)
|
@ -1,7 +0,0 @@
|
||||
package control
|
||||
|
||||
import "context"
|
||||
|
||||
type Handlers map[MessageType]MessageHandler
|
||||
|
||||
type MessageHandler func(ctx context.Context, m *Message) (*Message, error)
|
@ -1,76 +0,0 @@
|
||||
package control
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
TypeAuthRequest MessageType = "auth-req"
|
||||
TypeAuthResponse MessageType = "auth-res"
|
||||
TypeProxyRequest MessageType = "proxy-req"
|
||||
TypeCloseProxy MessageType = "close-proxy"
|
||||
)
|
||||
|
||||
type MessageType string
|
||||
|
||||
type BaseMessage struct {
|
||||
Type MessageType `json:"t"`
|
||||
RawPayload json.RawMessage `json:"p"`
|
||||
}
|
||||
|
||||
type Message struct {
|
||||
BaseMessage
|
||||
Payload interface{} `json:"p"`
|
||||
}
|
||||
|
||||
func (m *Message) UnmarshalJSON(data []byte) error {
|
||||
base := &BaseMessage{}
|
||||
|
||||
if err := json.Unmarshal(data, base); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
payload, err := unmarshalPayload(base.Type, base.RawPayload)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
m.Type = base.Type
|
||||
m.Payload = payload
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewMessage(mType MessageType, payload interface{}) *Message {
|
||||
return &Message{
|
||||
BaseMessage: BaseMessage{
|
||||
Type: mType,
|
||||
},
|
||||
Payload: payload,
|
||||
}
|
||||
}
|
||||
|
||||
func unmarshalPayload(mType MessageType, data []byte) (interface{}, error) {
|
||||
var payload interface{}
|
||||
|
||||
switch mType {
|
||||
case TypeAuthRequest:
|
||||
payload = &AuthRequestPayload{}
|
||||
case TypeAuthResponse:
|
||||
payload = &AuthResponsePayload{}
|
||||
case TypeProxyRequest:
|
||||
payload = &ProxyRequestPayload{}
|
||||
case TypeCloseProxy:
|
||||
payload = &CloseProxyPayload{}
|
||||
default:
|
||||
return nil, errors.Wrapf(ErrUnexpectedMessage, "unexpected message type '%s'", mType)
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(data, payload); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return payload, nil
|
||||
}
|
@ -1,10 +0,0 @@
|
||||
package control
|
||||
|
||||
type ProxyRequestPayload struct {
|
||||
Network string `json:"n"`
|
||||
Address string `json:"a"`
|
||||
}
|
||||
|
||||
type CloseProxyPayload struct {
|
||||
RequestID int64 `json:"i"`
|
||||
}
|
12
error.go
12
error.go
@ -3,10 +3,10 @@ package tunnel
|
||||
import "errors"
|
||||
|
||||
var (
|
||||
ErrNotConnected = errors.New("not connected")
|
||||
ErrCouldNotConnect = errors.New("could not connect")
|
||||
ErrConnectionClosed = errors.New("connection closed")
|
||||
ErrAuthFailed = errors.New("auth failed")
|
||||
ErrUnexpectedMessage = errors.New("unexpected message")
|
||||
ErrUnexpectedResponse = errors.New("unexpected response")
|
||||
ErrNotConnected = errors.New("not connected")
|
||||
ErrCouldNotConnect = errors.New("could not connect")
|
||||
ErrConnectionClosed = errors.New("connection closed")
|
||||
ErrAuthenticationFailed = errors.New("authentication failed")
|
||||
ErrUnexpectedMessage = errors.New("unexpected message")
|
||||
ErrUnexpectedResponse = errors.New("unexpected response")
|
||||
)
|
||||
|
1
go.mod
1
go.mod
@ -4,6 +4,7 @@ go 1.15
|
||||
|
||||
require (
|
||||
cdr.dev/slog v1.3.0
|
||||
github.com/davecgh/go-spew v1.1.1
|
||||
github.com/orcaman/concurrent-map v0.0.0-20190826125027-8c72a8bb44f6
|
||||
github.com/pkg/errors v0.9.1
|
||||
github.com/streamrail/concurrent-map v0.0.0-20160823150647-8bf1e9bacbf6
|
||||
|
102
http.go
Normal file
102
http.go
Normal file
@ -0,0 +1,102 @@
|
||||
package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"gitlab.com/wpetit/goweb/logger"
|
||||
)
|
||||
|
||||
type contextKey string
|
||||
|
||||
const remoteClientKey contextKey = "go-tunnel.remoteclient"
|
||||
|
||||
var (
|
||||
ErrAbortProxy = errors.New("proxy aborted")
|
||||
)
|
||||
|
||||
type MatchRequestFunc func(w http.ResponseWriter, r *http.Request) (*RemoteClient, error)
|
||||
|
||||
func ProxyHandler(targetURL string, match MatchRequestFunc, funcs ...ProxyConfigFunc) (http.Handler, error) {
|
||||
conf := DefaultProxyConfig()
|
||||
|
||||
for _, fn := range funcs {
|
||||
fn(conf)
|
||||
}
|
||||
|
||||
target, err := url.Parse(targetURL)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
reverse := createReverseProxy(target)
|
||||
|
||||
if conf.ConfigureReverseProxy != nil {
|
||||
if err := conf.ConfigureReverseProxy(reverse); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
fn := func(w http.ResponseWriter, r *http.Request) {
|
||||
remoteClient, err := match(w, r)
|
||||
if errors.Is(err, ErrAbortProxy) {
|
||||
return
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
logger.Error(r.Context(), "could not match proxy request", logger.E(errors.WithStack(err)))
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
if remoteClient == nil {
|
||||
http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
ctx := context.WithValue(r.Context(), remoteClientKey, remoteClient)
|
||||
r = r.WithContext(ctx)
|
||||
|
||||
reverse.ServeHTTP(w, r)
|
||||
}
|
||||
|
||||
return http.HandlerFunc(fn), nil
|
||||
}
|
||||
|
||||
func createReverseProxy(target *url.URL) *httputil.ReverseProxy {
|
||||
reverse := httputil.NewSingleHostReverseProxy(target)
|
||||
|
||||
// nolint: go-mnd
|
||||
reverse.Transport = &http.Transport{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
DialContext: func(ctx context.Context, network string, addr string) (net.Conn, error) {
|
||||
remoteClient, ok := ctx.Value(remoteClientKey).(*RemoteClient)
|
||||
if !ok {
|
||||
return nil, errors.New("could not retrieve remote client")
|
||||
}
|
||||
|
||||
conn, err := remoteClient.Proxy(ctx, network, addr)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
},
|
||||
ForceAttemptHTTP2: true,
|
||||
MaxIdleConns: 100,
|
||||
IdleConnTimeout: 90 * time.Second,
|
||||
TLSHandshakeTimeout: 10 * time.Second,
|
||||
ExpectContinueTimeout: 1 * time.Second,
|
||||
}
|
||||
|
||||
reverse.FlushInterval = 0
|
||||
|
||||
return reverse
|
||||
}
|
21
http_config.go
Normal file
21
http_config.go
Normal file
@ -0,0 +1,21 @@
|
||||
package tunnel
|
||||
|
||||
import "net/http/httputil"
|
||||
|
||||
type ConfigureReverseProxyFunc func(*httputil.ReverseProxy) error
|
||||
|
||||
type ProxyConfig struct {
|
||||
ConfigureReverseProxy ConfigureReverseProxyFunc
|
||||
}
|
||||
|
||||
func DefaultProxyConfig() *ProxyConfig {
|
||||
return &ProxyConfig{}
|
||||
}
|
||||
|
||||
type ProxyConfigFunc func(c *ProxyConfig)
|
||||
|
||||
func WithProxyConfigure(fn ConfigureReverseProxyFunc) ProxyConfigFunc {
|
||||
return func(c *ProxyConfig) {
|
||||
c.ConfigureReverseProxy = fn
|
||||
}
|
||||
}
|
@ -1,7 +1,7 @@
|
||||
**/*.go {
|
||||
**/*.go
|
||||
modd.conf {
|
||||
prep: make test
|
||||
prep: make build
|
||||
daemon: ./bin/server
|
||||
daemon: ./bin/server -target-url http://127.0.0.1:3000
|
||||
daemon: ./bin/client -id client1
|
||||
daemon: ./bin/client -id client2
|
||||
}
|
14
protocol.go
Normal file
14
protocol.go
Normal file
@ -0,0 +1,14 @@
|
||||
package tunnel
|
||||
|
||||
type authRequest struct {
|
||||
Credentials interface{} `json:"c"`
|
||||
}
|
||||
|
||||
type authResponse struct {
|
||||
Success bool `json:"b"`
|
||||
}
|
||||
|
||||
type proxyRequest struct {
|
||||
Network string `json:"n"`
|
||||
Address string `json:"a"`
|
||||
}
|
208
remote_client.go
208
remote_client.go
@ -2,13 +2,11 @@ package tunnel
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"forge.cadoles.com/wpetit/go-tunnel/control"
|
||||
|
||||
cmap "github.com/orcaman/concurrent-map"
|
||||
"github.com/pkg/errors"
|
||||
"github.com/xtaci/kcp-go/v5"
|
||||
"github.com/xtaci/smux"
|
||||
@ -19,61 +17,45 @@ type RemoteClient struct {
|
||||
onClientAuthHook OnClientAuthHook
|
||||
onClientConnectHook OnClientConnectHook
|
||||
onClientDisconnectHook OnClientDisconnectHook
|
||||
conn *kcp.UDPSession
|
||||
sess *smux.Session
|
||||
control *control.Control
|
||||
remoteAddr net.Addr
|
||||
proxies cmap.ConcurrentMap
|
||||
acceptStreamMutex sync.Mutex
|
||||
authenticationTimeout time.Duration
|
||||
proxyRequestTimeout time.Duration
|
||||
connMutex sync.RWMutex
|
||||
smuxConfig *smux.Config
|
||||
}
|
||||
|
||||
func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error {
|
||||
config := smux.DefaultConfig()
|
||||
config.Version = 2
|
||||
config.KeepAliveInterval = 10 * time.Second
|
||||
config.KeepAliveTimeout = 2 * config.KeepAliveInterval
|
||||
c.connMutex.Lock()
|
||||
defer c.connMutex.Unlock()
|
||||
|
||||
logger.Debug(ctx, "creating server session")
|
||||
if err := c.Close(); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
sess, err := smux.Server(conn, config)
|
||||
sess, err := c.acceptSession(ctx, conn)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
ctrl := control.New()
|
||||
stream, err := sess.AcceptStream()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
if err := ctrl.Init(ctx, sess, true); err != nil {
|
||||
defer stream.Close()
|
||||
|
||||
if err := c.authenticate(ctx, stream); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
c.sess = sess
|
||||
c.remoteAddr = conn.RemoteAddr()
|
||||
c.control = ctrl
|
||||
|
||||
if c.onClientConnectHook != nil {
|
||||
if err := c.onClientConnectHook.OnClientConnect(ctx, c); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
c.conn = conn
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *RemoteClient) Listen(ctx context.Context) error {
|
||||
defer func() {
|
||||
if c.onClientDisconnectHook != nil {
|
||||
if err := c.onClientDisconnectHook.OnClientDisconnect(ctx, c); err != nil {
|
||||
logger.Error(ctx, "client disconnect hook error", logger.E(errors.WithStack(err)))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
logger.Debug(ctx, "listening for messages")
|
||||
|
||||
return c.control.Listen(ctx, control.Handlers{
|
||||
control.TypeAuthRequest: c.handleAuthRequest,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *RemoteClient) ConfigureHooks(hooks interface{}) {
|
||||
if hooks == nil {
|
||||
return
|
||||
@ -96,74 +78,164 @@ func (c *RemoteClient) RemoteAddr() net.Addr {
|
||||
return c.remoteAddr
|
||||
}
|
||||
|
||||
func (c *RemoteClient) Close() {
|
||||
func (c *RemoteClient) Close() error {
|
||||
if c.sess != nil {
|
||||
c.sess.Close()
|
||||
if err := c.sess.Close(); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
if c.conn != nil {
|
||||
if err := c.conn.Close(); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
c.sess = nil
|
||||
c.control = nil
|
||||
c.conn = nil
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *RemoteClient) SwitchConn(ctx context.Context, conn *kcp.UDPSession) error {
|
||||
c.connMutex.Lock()
|
||||
defer c.connMutex.Unlock()
|
||||
|
||||
if err := c.Close(); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
sess, err := c.acceptSession(ctx, conn)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
c.sess = sess
|
||||
c.conn = conn
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *RemoteClient) Proxy(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address))
|
||||
c.connMutex.RLock()
|
||||
defer c.connMutex.RUnlock()
|
||||
|
||||
if err := c.control.ProxyReq(ctx, network, address); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address))
|
||||
|
||||
logger.Debug(ctx, "opening proxy stream")
|
||||
|
||||
c.acceptStreamMutex.Lock()
|
||||
|
||||
stream, err := c.sess.AcceptStream()
|
||||
stream, err := c.sess.OpenStream()
|
||||
if err != nil {
|
||||
c.acceptStreamMutex.Unlock()
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
c.acceptStreamMutex.Unlock()
|
||||
proxyReq := &proxyRequest{
|
||||
Network: network,
|
||||
Address: address,
|
||||
}
|
||||
encoder := json.NewEncoder(stream)
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
logger.Debug(ctx, "closing proxy stream")
|
||||
writeDeadline := time.Now().Add(c.proxyRequestTimeout)
|
||||
logger.Debug(ctx, "sending proxy req", logger.F("deadline", writeDeadline))
|
||||
|
||||
if err := stream.SetWriteDeadline(writeDeadline); err != nil {
|
||||
stream.Close()
|
||||
}()
|
||||
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
if err := encoder.Encode(proxyReq); err != nil {
|
||||
stream.Close()
|
||||
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
if err := stream.SetWriteDeadline(time.Time{}); err != nil {
|
||||
stream.Close()
|
||||
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
func (c *RemoteClient) handleAuthRequest(ctx context.Context, m *control.Message) (*control.Message, error) {
|
||||
authReqPayload, ok := m.Payload.(*control.AuthRequestPayload)
|
||||
if !ok {
|
||||
return nil, errors.WithStack(ErrUnexpectedMessage)
|
||||
func (c *RemoteClient) acceptSession(ctx context.Context, conn *kcp.UDPSession) (*smux.Session, error) {
|
||||
logger.Debug(ctx, "accepting client session")
|
||||
|
||||
sess, err := smux.Server(conn, c.smuxConfig)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "handling auth request", logger.F("credentials", authReqPayload.Credentials))
|
||||
c.remoteAddr = conn.RemoteAddr()
|
||||
|
||||
if c.onClientConnectHook != nil {
|
||||
if err := c.onClientConnectHook.OnClientConnect(ctx, c); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
return sess, nil
|
||||
}
|
||||
|
||||
func (c *RemoteClient) authenticate(ctx context.Context, stream *smux.Stream) error {
|
||||
start := time.Now()
|
||||
|
||||
readDeadline := time.Now().Add(c.authenticationTimeout)
|
||||
logger.Debug(ctx, "waiting for auth request", logger.F("deadline", readDeadline))
|
||||
|
||||
if err := stream.SetReadDeadline(readDeadline); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(stream)
|
||||
authReq := &authRequest{}
|
||||
|
||||
if err := decoder.Decode(authReq); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
var (
|
||||
success bool
|
||||
err error
|
||||
)
|
||||
|
||||
logger.Debug(ctx, "received client credentials", logger.F("credentials", authReq.Credentials))
|
||||
|
||||
if c.onClientAuthHook != nil {
|
||||
success, err = c.onClientAuthHook.OnClientAuth(ctx, c, authReqPayload.Credentials)
|
||||
success, err = c.onClientAuthHook.OnClientAuth(ctx, c, authReq.Credentials)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "auth succeeded", logger.F("credentials", authReqPayload.Credentials))
|
||||
|
||||
res := control.NewMessage(control.TypeAuthResponse, &control.AuthResponsePayload{
|
||||
authRes := &authResponse{
|
||||
Success: success,
|
||||
})
|
||||
}
|
||||
encoder := json.NewEncoder(stream)
|
||||
|
||||
return res, nil
|
||||
writeDeadline := time.Now().Add(c.authenticationTimeout - time.Since(start))
|
||||
logger.Debug(ctx, "sending auth response", logger.F("deadline", writeDeadline))
|
||||
|
||||
if err := stream.SetWriteDeadline(writeDeadline); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
if err := encoder.Encode(authRes); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
if !success {
|
||||
return errors.WithStack(ErrAuthenticationFailed)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewRemoteClient() *RemoteClient {
|
||||
func NewRemoteClient(smuxConfig *smux.Config, authenticationTimeout, proxyRequestTimeout time.Duration) *RemoteClient {
|
||||
return &RemoteClient{
|
||||
proxies: cmap.New(),
|
||||
smuxConfig: smuxConfig,
|
||||
authenticationTimeout: authenticationTimeout,
|
||||
proxyRequestTimeout: proxyRequestTimeout,
|
||||
}
|
||||
}
|
||||
|
41
server.go
41
server.go
@ -3,8 +3,8 @@ package tunnel
|
||||
import (
|
||||
"context"
|
||||
|
||||
cmap "github.com/orcaman/concurrent-map"
|
||||
"github.com/pkg/errors"
|
||||
cmap "github.com/streamrail/concurrent-map"
|
||||
"github.com/xtaci/kcp-go/v5"
|
||||
"gitlab.com/wpetit/goweb/logger"
|
||||
)
|
||||
@ -23,6 +23,14 @@ func (s *Server) Listen(ctx context.Context) error {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
if s.conf.ConfigureListener != nil {
|
||||
if err := s.conf.ConfigureListener(listener); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
logger.Debug(ctx, "accepting connections", logger.F("address", s.conf.Address))
|
||||
|
||||
for {
|
||||
conn, err := listener.AcceptKCP()
|
||||
if err != nil {
|
||||
@ -34,12 +42,31 @@ func (s *Server) Listen(ctx context.Context) error {
|
||||
}
|
||||
|
||||
func (s *Server) handleNewConn(ctx context.Context, conn *kcp.UDPSession) {
|
||||
ctx = logger.With(ctx, logger.F("remoteAddr", conn.RemoteAddr().String()))
|
||||
var remoteClient *RemoteClient
|
||||
|
||||
remoteClient := NewRemoteClient()
|
||||
remoteAddr := conn.RemoteAddr().String()
|
||||
ctx = logger.With(ctx, logger.F("remoteAddr", remoteAddr))
|
||||
|
||||
defer remoteClient.Close()
|
||||
defer conn.Close()
|
||||
rawExistingClient, exists := s.clients.Get(remoteAddr)
|
||||
if exists {
|
||||
logger.Debug(ctx, "remote client already exists")
|
||||
|
||||
remoteClient, _ = rawExistingClient.(*RemoteClient)
|
||||
|
||||
if err := remoteClient.SwitchConn(ctx, conn); err != nil {
|
||||
logger.Error(ctx, "remote client error", logger.E(errors.WithStack(err)))
|
||||
|
||||
s.clients.Remove(remoteAddr)
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
remoteClient = NewRemoteClient(
|
||||
s.conf.SmuxConfig,
|
||||
s.conf.AuthenticationTimeout,
|
||||
s.conf.ProxyRequestTimeout,
|
||||
)
|
||||
|
||||
remoteClient.ConfigureHooks(s.conf.Hooks)
|
||||
|
||||
@ -49,9 +76,7 @@ func (s *Server) handleNewConn(ctx context.Context, conn *kcp.UDPSession) {
|
||||
return
|
||||
}
|
||||
|
||||
if err := remoteClient.Listen(ctx); err != nil {
|
||||
logger.Error(ctx, "remote client error", logger.E(errors.WithStack(err)))
|
||||
}
|
||||
s.clients.Set(remoteAddr, remoteClient)
|
||||
}
|
||||
|
||||
func NewServer(funcs ...ServerConfigFunc) *Server {
|
||||
|
@ -2,29 +2,43 @@ package tunnel
|
||||
|
||||
import (
|
||||
"crypto/sha1"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"github.com/xtaci/kcp-go/v5"
|
||||
"github.com/xtaci/smux"
|
||||
"golang.org/x/crypto/pbkdf2"
|
||||
)
|
||||
|
||||
type ConfigureConnFunc func(conn *kcp.UDPSession) error
|
||||
type ConfigureListenerFunc func(listener *kcp.Listener) error
|
||||
|
||||
type ServerConfig struct {
|
||||
Address string
|
||||
BlockCrypt kcp.BlockCrypt
|
||||
DataShards int
|
||||
ParityShards int
|
||||
Hooks *ServerHooks
|
||||
ConfigureConn ConfigureConnFunc
|
||||
Address string
|
||||
BlockCrypt kcp.BlockCrypt
|
||||
DataShards int
|
||||
ParityShards int
|
||||
Hooks *ServerHooks
|
||||
ConfigureConn ConfigureConnFunc
|
||||
ConfigureListener ConfigureListenerFunc
|
||||
AuthenticationTimeout time.Duration
|
||||
ProxyRequestTimeout time.Duration
|
||||
SmuxConfig *smux.Config
|
||||
}
|
||||
|
||||
// nolint: go-mnd
|
||||
func DefaultServerConfig() *ServerConfig {
|
||||
unencryptedBlock, err := kcp.NewNoneBlockCrypt(nil)
|
||||
if err != nil { // should never happen
|
||||
panic(errors.WithStack(err))
|
||||
}
|
||||
|
||||
smuxConfig := smux.DefaultConfig()
|
||||
smuxConfig.Version = 2
|
||||
smuxConfig.KeepAliveInterval = 10 * time.Second
|
||||
smuxConfig.MaxReceiveBuffer = 4194304
|
||||
smuxConfig.MaxStreamBuffer = 2097152
|
||||
|
||||
return &ServerConfig{
|
||||
Address: ":36543",
|
||||
BlockCrypt: unencryptedBlock,
|
||||
@ -35,6 +49,11 @@ func DefaultServerConfig() *ServerConfig {
|
||||
onClientDisconnect: DefaultOnClientDisconnect,
|
||||
onClientAuth: DefaultOnClientAuth,
|
||||
},
|
||||
ConfigureConn: DefaultServerConfigureConn,
|
||||
ConfigureListener: DefaultServerConfigureListener,
|
||||
AuthenticationTimeout: 30 * time.Second,
|
||||
ProxyRequestTimeout: 5 * time.Second,
|
||||
SmuxConfig: smuxConfig,
|
||||
}
|
||||
}
|
||||
|
||||
@ -82,3 +101,50 @@ func WithServerConfigureConn(fn ConfigureConnFunc) ServerConfigFunc {
|
||||
conf.ConfigureConn = fn
|
||||
}
|
||||
}
|
||||
|
||||
func WithServerConfigureListener(fn ConfigureListenerFunc) ServerConfigFunc {
|
||||
return func(conf *ServerConfig) {
|
||||
conf.ConfigureListener = fn
|
||||
}
|
||||
}
|
||||
|
||||
func WithServerSmuxConfig(c *smux.Config) ServerConfigFunc {
|
||||
return func(conf *ServerConfig) {
|
||||
conf.SmuxConfig = c
|
||||
}
|
||||
}
|
||||
|
||||
// nolint: go-mnd
|
||||
func DefaultServerConfigureConn(conn *kcp.UDPSession) error {
|
||||
// Based on kcptun default configuration, mode 'fast3'
|
||||
conn.SetStreamMode(true)
|
||||
conn.SetWriteDelay(false)
|
||||
conn.SetNoDelay(1, 10, 2, 1)
|
||||
conn.SetWindowSize(128, 512)
|
||||
conn.SetMtu(1400)
|
||||
conn.SetACKNoDelay(true)
|
||||
|
||||
if err := conn.SetDSCP(46); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// nolint: go-mnd
|
||||
func DefaultServerConfigureListener(listener *kcp.Listener) error {
|
||||
// Based on kcptun default configuration, mode 'fast3'
|
||||
if err := listener.SetReadBuffer(16777217); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
if err := listener.SetWriteBuffer(16777217); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
if err := listener.SetDSCP(46); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
Loading…
Reference in New Issue
Block a user