feat: general protocol rewrite

This commit is contained in:
wpetit 2020-10-26 19:42:07 +01:00
parent 536100da90
commit 96c1575a0b
20 changed files with 628 additions and 523 deletions

View File

@ -4,10 +4,10 @@ watch:
build: build-server build-client build: build-server build-client
build-server: build-server:
go build -o ./bin/server ./cmd/server CGO_ENABLED=0 go build -o ./bin/server ./cmd/server
build-client: build-client:
go build -o ./bin/client ./cmd/client CGO_ENABLED=0 go build -o ./bin/client ./cmd/client
test: test:
go test -v -race ./... go test -v -race ./...

176
client.go
View File

@ -2,15 +2,13 @@ package tunnel
import ( import (
"context" "context"
"encoding/json"
"io" "io"
"net" "net"
"net/http"
"sync"
"time" "time"
"gitlab.com/wpetit/goweb/logger" "gitlab.com/wpetit/goweb/logger"
"forge.cadoles.com/wpetit/go-tunnel/control"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/xtaci/kcp-go/v5" "github.com/xtaci/kcp-go/v5"
"github.com/xtaci/smux" "github.com/xtaci/smux"
@ -20,12 +18,11 @@ type Client struct {
conf *ClientConfig conf *ClientConfig
conn *kcp.UDPSession conn *kcp.UDPSession
sess *smux.Session sess *smux.Session
control *control.Control
http *http.Client
openStreamMutex sync.Mutex
} }
func (c *Client) Connect(ctx context.Context) error { func (c *Client) Connect(ctx context.Context) error {
logger.Debug(ctx, "connecting", logger.F("serverAddr", c.conf.ServerAddress))
conn, err := kcp.DialWithOptions( conn, err := kcp.DialWithOptions(
c.conf.ServerAddress, c.conf.BlockCrypt, c.conf.ServerAddress, c.conf.BlockCrypt,
c.conf.DataShards, c.conf.ParityShards, c.conf.DataShards, c.conf.ParityShards,
@ -40,34 +37,29 @@ func (c *Client) Connect(ctx context.Context) error {
} }
} }
config := smux.DefaultConfig() sess, err := smux.Client(conn, c.conf.SmuxConfig)
config.Version = 2
config.KeepAliveInterval = 10 * time.Second
config.KeepAliveTimeout = 2 * config.KeepAliveInterval
sess, err := smux.Client(conn, config)
if err != nil { if err != nil {
return errors.WithStack(err) return errors.WithStack(err)
} }
control := control.New() stream, err := sess.OpenStream()
if err := control.Init(ctx, sess, false); err != nil { if err != nil {
return errors.WithStack(err) return errors.WithStack(err)
} }
logger.Debug(ctx, "sending auth request") defer stream.Close()
success, err := control.AuthRequest(c.conf.Credentials) success, err := c.authenticate(ctx, stream)
if err != nil { if err != nil {
return errors.WithStack(err) return errors.WithStack(err)
} }
if !success { if !success {
defer c.Close() return errors.WithStack(ErrAuthenticationFailed)
return errors.WithStack(ErrAuthFailed)
} }
c.control = control logger.Debug(ctx, "authentication success")
c.conn = conn c.conn = conn
c.sess = sess c.sess = sess
@ -75,87 +67,138 @@ func (c *Client) Connect(ctx context.Context) error {
} }
func (c *Client) Listen(ctx context.Context) error { func (c *Client) Listen(ctx context.Context) error {
logger.Debug(ctx, "listening for messages") logger.Debug(ctx, "listening for proxy requests")
ctx, cancel := context.WithCancel(ctx) for {
defer cancel() stream, err := c.sess.AcceptStream()
if err != nil {
err := c.control.Listen(ctx, control.Handlers{ return errors.WithStack(err)
control.TypeProxyRequest: c.handleProxyRequest,
})
if errors.Is(err, io.ErrClosedPipe) {
logger.Debug(ctx, "client connection closed")
return errors.WithStack(ErrConnectionClosed)
} }
return err subCtx := logger.With(ctx,
logger.F("remoteAddr", stream.RemoteAddr()),
logger.F("localAddr", stream.LocalAddr()),
)
readDeadline := time.Now().Add(c.conf.ProxyRequestTimeout)
logger.Debug(subCtx, "waiting for proxy request", logger.F("deadline", readDeadline))
if err := stream.SetReadDeadline(readDeadline); err != nil {
stream.Close()
logger.Error(subCtx, "could not set read deadline", logger.E(errors.WithStack(err)))
continue
}
decoder := json.NewDecoder(stream)
proxyReq := &proxyRequest{}
if err := decoder.Decode(proxyReq); err != nil {
stream.Close()
logger.Error(subCtx, "could not decode proxy request", logger.E(errors.WithStack(err)))
continue
}
if err := stream.SetReadDeadline(time.Time{}); err != nil {
stream.Close()
logger.Error(subCtx, "could not set read deadline", logger.E(errors.WithStack(err)))
continue
}
go c.handleProxyStream(subCtx, stream, proxyReq.Network, proxyReq.Address)
}
} }
func (c *Client) Close() error { func (c *Client) Close() error {
if c.conn == nil { if c.sess != nil && !c.sess.IsClosed() {
return errors.WithStack(ErrNotConnected) if err := c.sess.Close(); err != nil {
return errors.WithStack(err)
}
} }
if c.conn != nil {
if err := c.conn.Close(); err != nil { if err := c.conn.Close(); err != nil {
return errors.WithStack(err) return errors.WithStack(err)
} }
}
c.conn = nil
c.sess = nil
return nil return nil
} }
func (c *Client) handleProxyRequest(ctx context.Context, m *control.Message) (*control.Message, error) { func (c *Client) authenticate(ctx context.Context, stream *smux.Stream) (bool, error) {
proxyReqPayload, ok := m.Payload.(*control.ProxyRequestPayload) encoder := json.NewEncoder(stream)
if !ok { authReq := &authRequest{
return nil, errors.WithStack(ErrUnexpectedMessage) Credentials: c.conf.Credentials,
} }
ctx = logger.With(ctx, start := time.Now()
logger.F("network", proxyReqPayload.Network), writeDeadline := start.Add(c.conf.AuthenticationTimeout)
logger.F("address", proxyReqPayload.Address), logger.Debug(ctx, "sending auth request", logger.F("deadline", writeDeadline))
if err := stream.SetWriteDeadline(writeDeadline); err != nil {
return false, errors.WithStack(err)
}
if err := encoder.Encode(authReq); err != nil {
return false, errors.WithStack(err)
}
decoder := json.NewDecoder(stream)
authRes := &authResponse{}
readDeadline := time.Now().Add(c.conf.AuthenticationTimeout - time.Now().Sub(start))
logger.Debug(ctx, "waiting for auth response", logger.F("deadline", readDeadline))
if err := stream.SetReadDeadline(readDeadline); err != nil {
return false, errors.WithStack(err)
}
if err := decoder.Decode(authRes); err != nil && !errors.Is(err, io.EOF) {
return false, errors.WithStack(err)
}
return authRes.Success, nil
}
func (c *Client) handleProxyStream(ctx context.Context, in *smux.Stream, network, address string) {
defer func(start time.Time) {
logger.Debug(ctx, "handleProxyStream duration", logger.F("duration", time.Since(start)))
}(time.Now())
defer in.Close()
logger.Debug(
ctx, "proxying",
logger.F("network", network),
logger.F("address", address),
) )
logger.Debug(ctx, "handling proxy request") out, err := net.Dial(network, address)
out, err := net.Dial(proxyReqPayload.Network, proxyReqPayload.Address)
if err != nil { if err != nil {
return nil, errors.WithStack(err) logger.Error(ctx, "could not dial", logger.E(errors.WithStack(err)))
}
go c.handleProxyStream(ctx, out)
return nil, nil
}
func (c *Client) handleProxyStream(ctx context.Context, out net.Conn) {
c.openStreamMutex.Lock()
in, err := c.sess.OpenStream()
if err != nil {
c.openStreamMutex.Unlock()
logger.Error(ctx, "error while accepting proxy stream", logger.E(err))
return return
} }
defer out.Close()
c.openStreamMutex.Unlock()
streamCopy := func(dst io.Writer, src io.ReadCloser) { streamCopy := func(dst io.Writer, src io.ReadCloser) {
if _, err := Copy(dst, src); err != nil { if _, err := Copy(dst, src); err != nil {
if errors.Is(err, smux.ErrInvalidProtocol) { if errors.Is(err, smux.ErrInvalidProtocol) {
logger.Error(ctx, "error while proxying", logger.E(errors.WithStack(err))) logger.Error(ctx, "could not proxy", logger.E(errors.WithStack(err)))
} }
} }
logger.Debug(ctx, "closing proxy stream")
in.Close() in.Close()
out.Close() out.Close()
} }
go streamCopy(in, out) go streamCopy(out, in)
streamCopy(out, in) streamCopy(in, out)
} }
func NewClient(funcs ...ClientConfigFunc) *Client { func NewClient(funcs ...ClientConfigFunc) *Client {
@ -167,6 +210,5 @@ func NewClient(funcs ...ClientConfigFunc) *Client {
return &Client{ return &Client{
conf: conf, conf: conf,
http: &http.Client{},
} }
} }

View File

@ -2,9 +2,11 @@ package tunnel
import ( import (
"crypto/sha1" "crypto/sha1"
"time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/xtaci/kcp-go/v5" "github.com/xtaci/kcp-go/v5"
"github.com/xtaci/smux"
"golang.org/x/crypto/pbkdf2" "golang.org/x/crypto/pbkdf2"
) )
@ -15,20 +17,34 @@ type ClientConfig struct {
ParityShards int ParityShards int
Credentials interface{} Credentials interface{}
ConfigureConn ConfigureConnFunc ConfigureConn ConfigureConnFunc
AuthenticationTimeout time.Duration
ProxyRequestTimeout time.Duration
SmuxConfig *smux.Config
} }
// nolint: go-mnd
func DefaultClientConfig() *ClientConfig { func DefaultClientConfig() *ClientConfig {
unencryptedBlock, err := kcp.NewNoneBlockCrypt(nil) unencryptedBlock, err := kcp.NewNoneBlockCrypt(nil)
if err != nil { // should never happen if err != nil { // should never happen
panic(errors.WithStack(err)) panic(errors.WithStack(err))
} }
smuxConfig := smux.DefaultConfig()
smuxConfig.Version = 2
smuxConfig.KeepAliveInterval = 10 * time.Second
smuxConfig.MaxReceiveBuffer = 4194304
smuxConfig.MaxStreamBuffer = 2097152
return &ClientConfig{ return &ClientConfig{
ServerAddress: "127.0.0.1:36543", ServerAddress: "127.0.0.1:36543",
BlockCrypt: unencryptedBlock, BlockCrypt: unencryptedBlock,
DataShards: 3, DataShards: 3,
ParityShards: 10, ParityShards: 10,
Credentials: nil, Credentials: nil,
ConfigureConn: DefaultClientConfigureConn,
AuthenticationTimeout: 30 * time.Second,
ProxyRequestTimeout: 5 * time.Second,
SmuxConfig: smuxConfig,
} }
} }
@ -64,3 +80,34 @@ func WithClientConfigureConn(fn ConfigureConnFunc) ClientConfigFunc {
conf.ConfigureConn = fn conf.ConfigureConn = fn
} }
} }
func WithClientSmuxConfig(c *smux.Config) ClientConfigFunc {
return func(conf *ClientConfig) {
conf.SmuxConfig = c
}
}
// nolint: go-mnd
func DefaultClientConfigureConn(conn *kcp.UDPSession) error {
// Based on kcptun default configuration, mode 'fast3'
conn.SetStreamMode(true)
conn.SetWriteDelay(false)
conn.SetNoDelay(1, 10, 2, 1)
conn.SetWindowSize(128, 512)
conn.SetMtu(1400)
conn.SetACKNoDelay(true)
if err := conn.SetReadBuffer(16777217); err != nil {
return errors.WithStack(err)
}
if err := conn.SetWriteBuffer(16777217); err != nil {
return errors.WithStack(err)
}
if err := conn.SetDSCP(46); err != nil {
return errors.WithStack(err)
}
return nil
}

View File

@ -3,6 +3,7 @@ package main
import ( import (
"context" "context"
"flag" "flag"
"fmt"
"math/rand" "math/rand"
"time" "time"
@ -11,15 +12,18 @@ import (
"gitlab.com/wpetit/goweb/logger" "gitlab.com/wpetit/goweb/logger"
) )
const sharedKey = "go-tunnel"
const salt = "go-tunnel" const salt = "go-tunnel"
func main() { func main() {
var ( var (
clientID string clientID = fmt.Sprintf("client-%d", time.Now().Unix())
serverAddr = "127.0.0.1:36543"
sharedKey = "go-tunnel"
) )
flag.StringVar(&clientID, "id", "", "Client ID") flag.StringVar(&sharedKey, "shared-key", sharedKey, "shared key")
flag.StringVar(&clientID, "id", clientID, "Client ID")
flag.StringVar(&serverAddr, "server-addr", serverAddr, "server address")
flag.Parse() flag.Parse()
ctx := context.Background() ctx := context.Background()
@ -28,12 +32,12 @@ func main() {
logger.SetLevel(slog.LevelDebug) logger.SetLevel(slog.LevelDebug)
client := tunnel.NewClient( client := tunnel.NewClient(
tunnel.WithClientServerAddress(serverAddr),
tunnel.WithClientCredentials(clientID), tunnel.WithClientCredentials(clientID),
tunnel.WithClientAESBlockCrypt(sharedKey, salt),
) )
defer client.Close() defer client.Close()
initialBackoff := time.Second * 10 initialBackoff := time.Second * 2
backoff := initialBackoff backoff := initialBackoff
sleep := func() { sleep := func() {

View File

@ -2,12 +2,9 @@ package main
import ( import (
"context" "context"
"net" "flag"
"net/http" "net/http"
"net/http/httputil"
"net/url"
"strings" "strings"
"time"
"cdr.dev/slog" "cdr.dev/slog"
"forge.cadoles.com/wpetit/go-tunnel" "forge.cadoles.com/wpetit/go-tunnel"
@ -15,20 +12,61 @@ import (
"gitlab.com/wpetit/goweb/logger" "gitlab.com/wpetit/goweb/logger"
) )
const sharedKey = "go-tunnel"
const salt = "go-tunnel" const salt = "go-tunnel"
var registry = NewRegistry() var registry = NewRegistry()
func main() { func main() {
var (
serverAddr = ":36543"
httpAddr = ":3003"
sharedKey = "go-tunnel"
targetURL = "https://arcad.games"
)
flag.StringVar(&serverAddr, "server-addr", serverAddr, "server address")
flag.StringVar(&targetURL, "target-url", targetURL, "target url")
flag.StringVar(&httpAddr, "http-addr", httpAddr, "http server address")
flag.StringVar(&sharedKey, "shared-key", sharedKey, "shared key")
flag.Parse()
ctx := context.Background() ctx := context.Background()
logger.SetLevel(slog.LevelDebug) logger.SetLevel(slog.LevelDebug)
server := tunnel.NewServer( server := tunnel.NewServer(
tunnel.WithServerAESBlockCrypt(sharedKey, salt), tunnel.WithServerAddress(serverAddr),
tunnel.WithServerOnClientAuth(registry.OnClientAuth), tunnel.WithServerOnClientAuth(registry.OnClientAuth),
tunnel.WithServerOnClientDisconnect(registry.OnClientDisconnect), tunnel.WithServerOnClientDisconnect(registry.OnClientDisconnect),
tunnel.WithServerOnClientAuth(func(ctx context.Context, remoteClient *tunnel.RemoteClient, credentials interface{}) (bool, error) {
remoteAddr := remoteClient.RemoteAddr().String()
ctx = logger.With(ctx, logger.F("remoteAddr", remoteAddr))
logger.Debug(ctx, "new client auth")
clientID, ok := credentials.(string)
if !ok {
logger.Debug(ctx, "client auth failed")
return false, nil
}
registry.Add(clientID, remoteAddr, remoteClient)
logger.Debug(ctx, "client auth success")
return true, nil
}),
tunnel.WithServerOnClientDisconnect(func(ctx context.Context, remoteClient *tunnel.RemoteClient) error {
remoteAddr := remoteClient.RemoteAddr().String()
ctx = logger.With(ctx, logger.F("remoteAddr", remoteAddr))
logger.Debug(ctx, "client disconnected")
registry.RemoveByRemoteAddr(remoteAddr)
return nil
}),
) )
go func() { go func() {
@ -37,51 +75,28 @@ func main() {
} }
}() }()
if err := http.ListenAndServe(":3003", http.HandlerFunc(handleRequest)); err != nil { handler, err := createProxyHandler(targetURL)
if err != nil {
logger.Fatal(ctx, "could not create proxy handler", logger.E(errors.WithStack(err)))
}
if err := http.ListenAndServe(httpAddr, handler); err != nil {
logger.Fatal(ctx, "error while listening", logger.E(err)) logger.Fatal(ctx, "error while listening", logger.E(err))
} }
} }
func handleRequest(w http.ResponseWriter, r *http.Request) { func createProxyHandler(targetURL string) (http.Handler, error) {
return tunnel.ProxyHandler(targetURL, func(w http.ResponseWriter, r *http.Request) (*tunnel.RemoteClient, error) {
subdomains := strings.SplitN(r.Host, ".", 2) subdomains := strings.SplitN(r.Host, ".", 2)
if len(subdomains) < 2 { if len(subdomains) < 2 {
http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound) http.Error(w, http.StatusText(http.StatusNotFound), http.StatusNotFound)
return return nil, tunnel.ErrAbortProxy
} }
clientID := subdomains[0] clientID := subdomains[0]
remoteClient := registry.Get(clientID)
if remoteClient == nil { return registry.Get(clientID), nil
http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway) })
return
}
target, err := url.Parse("https://arcad.games")
if err != nil {
logger.Fatal(r.Context(), "could not parse url", logger.E(err))
}
reverse := httputil.NewSingleHostReverseProxy(target)
reverse.Transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: func(ctx context.Context, network string, addr string) (net.Conn, error) {
conn, err := remoteClient.Proxy(ctx, network, addr)
if err != nil {
return nil, errors.WithStack(err)
}
return conn, nil
},
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
reverse.ServeHTTP(w, r)
} }

View File

@ -1,9 +0,0 @@
package control
type AuthRequestPayload struct {
Credentials interface{} `json:"c"`
}
type AuthResponsePayload struct {
Success bool `json:"s"`
}

View File

@ -1,194 +0,0 @@
package control
import (
"context"
"encoding/json"
"github.com/pkg/errors"
"github.com/xtaci/smux"
"gitlab.com/wpetit/goweb/logger"
)
type Control struct {
encoder *json.Encoder
decoder *json.Decoder
stream *smux.Stream
}
func (c *Control) Init(ctx context.Context, sess *smux.Session, serverMode bool) error {
config := smux.DefaultConfig()
config.Version = 2
logger.Debug(ctx, "creating control stream")
var (
stream *smux.Stream
err error
)
if serverMode {
stream, err = sess.AcceptStream()
if err != nil {
return errors.WithStack(err)
}
} else {
stream, err = sess.OpenStream()
if err != nil {
return errors.WithStack(err)
}
}
c.stream = stream
c.decoder = json.NewDecoder(stream)
c.encoder = json.NewEncoder(stream)
return nil
}
func (c *Control) AuthRequest(credentials interface{}) (bool, error) {
req := NewMessage(TypeAuthRequest, &AuthRequestPayload{
Credentials: credentials,
})
res := NewMessage(TypeAuthResponse, nil)
if err := c.reqRes(req, res); err != nil {
return false, errors.WithStack(err)
}
authResPayload, ok := res.Payload.(*AuthResponsePayload)
if !ok {
return false, errors.WithStack(ErrUnexpectedMessage)
}
return authResPayload.Success, nil
}
func (c *Control) ProxyReq(ctx context.Context, network, address string) error {
req := NewMessage(TypeProxyRequest, &ProxyRequestPayload{
Network: network,
Address: address,
})
if err := c.Write(req); err != nil {
return errors.WithStack(err)
}
return nil
}
func (c *Control) Listen(ctx context.Context, handlers Handlers) error {
errChan := make(chan error)
msgChan := make(chan *Message)
dieChan := c.stream.GetDieCh()
go func(msgChan chan *Message, errChan chan error) {
for {
logger.Debug(ctx, "reading next message")
msg, err := c.Read()
if err != nil {
errChan <- errors.WithStack(err)
close(errChan)
close(msgChan)
return
}
msgChan <- msg
}
}(msgChan, errChan)
for {
select {
case <-ctx.Done():
return nil
case <-dieChan:
return errors.WithStack(ErrStreamClosed)
case err := <-errChan:
return errors.WithStack(err)
case msg := <-msgChan:
go func() {
subCtx := logger.With(ctx, logger.F("messageType", msg.Type))
handler, exists := handlers[msg.Type]
if !exists {
logger.Error(subCtx, "no message handler registered")
return
}
res, err := handler(subCtx, msg)
if err != nil {
logger.Error(subCtx, "error while handling message", logger.E(err))
return
}
if res == nil {
return
}
if err := c.Write(res); err != nil {
logger.Error(subCtx, "error while write message response", logger.E(err))
return
}
}()
}
}
}
func (c *Control) Read() (*Message, error) {
message := &Message{}
if err := c.read(message); err != nil {
return nil, errors.WithStack(err)
}
return message, nil
}
func (c *Control) Write(m *Message) error {
if err := c.write(m); err != nil {
return errors.WithStack(err)
}
return nil
}
func (c *Control) reqRes(req *Message, res *Message) error {
if err := c.write(req); err != nil {
return errors.WithStack(err)
}
if err := c.read(res); err != nil {
return errors.WithStack(err)
}
return nil
}
func (c *Control) read(m *Message) error {
if err := c.decoder.Decode(m); err != nil {
return errors.WithStack(err)
}
return nil
}
func (c *Control) write(m *Message) error {
if err := c.encoder.Encode(m); err != nil {
return errors.WithStack(err)
}
return nil
}
func New() *Control {
return &Control{}
}

View File

@ -1,8 +0,0 @@
package control
import "errors"
var (
ErrStreamClosed = errors.New("stream closed")
ErrUnexpectedMessage = errors.New("unexpected message")
)

View File

@ -1,7 +0,0 @@
package control
import "context"
type Handlers map[MessageType]MessageHandler
type MessageHandler func(ctx context.Context, m *Message) (*Message, error)

View File

@ -1,76 +0,0 @@
package control
import (
"encoding/json"
"github.com/pkg/errors"
)
const (
TypeAuthRequest MessageType = "auth-req"
TypeAuthResponse MessageType = "auth-res"
TypeProxyRequest MessageType = "proxy-req"
TypeCloseProxy MessageType = "close-proxy"
)
type MessageType string
type BaseMessage struct {
Type MessageType `json:"t"`
RawPayload json.RawMessage `json:"p"`
}
type Message struct {
BaseMessage
Payload interface{} `json:"p"`
}
func (m *Message) UnmarshalJSON(data []byte) error {
base := &BaseMessage{}
if err := json.Unmarshal(data, base); err != nil {
return errors.WithStack(err)
}
payload, err := unmarshalPayload(base.Type, base.RawPayload)
if err != nil {
return errors.WithStack(err)
}
m.Type = base.Type
m.Payload = payload
return nil
}
func NewMessage(mType MessageType, payload interface{}) *Message {
return &Message{
BaseMessage: BaseMessage{
Type: mType,
},
Payload: payload,
}
}
func unmarshalPayload(mType MessageType, data []byte) (interface{}, error) {
var payload interface{}
switch mType {
case TypeAuthRequest:
payload = &AuthRequestPayload{}
case TypeAuthResponse:
payload = &AuthResponsePayload{}
case TypeProxyRequest:
payload = &ProxyRequestPayload{}
case TypeCloseProxy:
payload = &CloseProxyPayload{}
default:
return nil, errors.Wrapf(ErrUnexpectedMessage, "unexpected message type '%s'", mType)
}
if err := json.Unmarshal(data, payload); err != nil {
return nil, errors.WithStack(err)
}
return payload, nil
}

View File

@ -1,10 +0,0 @@
package control
type ProxyRequestPayload struct {
Network string `json:"n"`
Address string `json:"a"`
}
type CloseProxyPayload struct {
RequestID int64 `json:"i"`
}

View File

@ -6,7 +6,7 @@ var (
ErrNotConnected = errors.New("not connected") ErrNotConnected = errors.New("not connected")
ErrCouldNotConnect = errors.New("could not connect") ErrCouldNotConnect = errors.New("could not connect")
ErrConnectionClosed = errors.New("connection closed") ErrConnectionClosed = errors.New("connection closed")
ErrAuthFailed = errors.New("auth failed") ErrAuthenticationFailed = errors.New("authentication failed")
ErrUnexpectedMessage = errors.New("unexpected message") ErrUnexpectedMessage = errors.New("unexpected message")
ErrUnexpectedResponse = errors.New("unexpected response") ErrUnexpectedResponse = errors.New("unexpected response")
) )

1
go.mod
View File

@ -4,6 +4,7 @@ go 1.15
require ( require (
cdr.dev/slog v1.3.0 cdr.dev/slog v1.3.0
github.com/davecgh/go-spew v1.1.1
github.com/orcaman/concurrent-map v0.0.0-20190826125027-8c72a8bb44f6 github.com/orcaman/concurrent-map v0.0.0-20190826125027-8c72a8bb44f6
github.com/pkg/errors v0.9.1 github.com/pkg/errors v0.9.1
github.com/streamrail/concurrent-map v0.0.0-20160823150647-8bf1e9bacbf6 github.com/streamrail/concurrent-map v0.0.0-20160823150647-8bf1e9bacbf6

102
http.go Normal file
View File

@ -0,0 +1,102 @@
package tunnel
import (
"context"
"net"
"net/http"
"net/http/httputil"
"net/url"
"time"
"github.com/pkg/errors"
"gitlab.com/wpetit/goweb/logger"
)
type contextKey string
const remoteClientKey contextKey = "go-tunnel.remoteclient"
var (
ErrAbortProxy = errors.New("proxy aborted")
)
type MatchRequestFunc func(w http.ResponseWriter, r *http.Request) (*RemoteClient, error)
func ProxyHandler(targetURL string, match MatchRequestFunc, funcs ...ProxyConfigFunc) (http.Handler, error) {
conf := DefaultProxyConfig()
for _, fn := range funcs {
fn(conf)
}
target, err := url.Parse(targetURL)
if err != nil {
return nil, errors.WithStack(err)
}
reverse := createReverseProxy(target)
if conf.ConfigureReverseProxy != nil {
if err := conf.ConfigureReverseProxy(reverse); err != nil {
return nil, errors.WithStack(err)
}
}
fn := func(w http.ResponseWriter, r *http.Request) {
remoteClient, err := match(w, r)
if errors.Is(err, ErrAbortProxy) {
return
}
if err != nil {
logger.Error(r.Context(), "could not match proxy request", logger.E(errors.WithStack(err)))
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
return
}
if remoteClient == nil {
http.Error(w, http.StatusText(http.StatusBadGateway), http.StatusBadGateway)
return
}
ctx := context.WithValue(r.Context(), remoteClientKey, remoteClient)
r = r.WithContext(ctx)
reverse.ServeHTTP(w, r)
}
return http.HandlerFunc(fn), nil
}
func createReverseProxy(target *url.URL) *httputil.ReverseProxy {
reverse := httputil.NewSingleHostReverseProxy(target)
// nolint: go-mnd
reverse.Transport = &http.Transport{
Proxy: http.ProxyFromEnvironment,
DialContext: func(ctx context.Context, network string, addr string) (net.Conn, error) {
remoteClient, ok := ctx.Value(remoteClientKey).(*RemoteClient)
if !ok {
return nil, errors.New("could not retrieve remote client")
}
conn, err := remoteClient.Proxy(ctx, network, addr)
if err != nil {
return nil, errors.WithStack(err)
}
return conn, nil
},
ForceAttemptHTTP2: true,
MaxIdleConns: 100,
IdleConnTimeout: 90 * time.Second,
TLSHandshakeTimeout: 10 * time.Second,
ExpectContinueTimeout: 1 * time.Second,
}
reverse.FlushInterval = 0
return reverse
}

21
http_config.go Normal file
View File

@ -0,0 +1,21 @@
package tunnel
import "net/http/httputil"
type ConfigureReverseProxyFunc func(*httputil.ReverseProxy) error
type ProxyConfig struct {
ConfigureReverseProxy ConfigureReverseProxyFunc
}
func DefaultProxyConfig() *ProxyConfig {
return &ProxyConfig{}
}
type ProxyConfigFunc func(c *ProxyConfig)
func WithProxyConfigure(fn ConfigureReverseProxyFunc) ProxyConfigFunc {
return func(c *ProxyConfig) {
c.ConfigureReverseProxy = fn
}
}

View File

@ -1,7 +1,7 @@
**/*.go { **/*.go
modd.conf {
prep: make test prep: make test
prep: make build prep: make build
daemon: ./bin/server daemon: ./bin/server -target-url http://127.0.0.1:3000
daemon: ./bin/client -id client1 daemon: ./bin/client -id client1
daemon: ./bin/client -id client2
} }

14
protocol.go Normal file
View File

@ -0,0 +1,14 @@
package tunnel
type authRequest struct {
Credentials interface{} `json:"c"`
}
type authResponse struct {
Success bool `json:"b"`
}
type proxyRequest struct {
Network string `json:"n"`
Address string `json:"a"`
}

View File

@ -2,13 +2,11 @@ package tunnel
import ( import (
"context" "context"
"encoding/json"
"net" "net"
"sync" "sync"
"time" "time"
"forge.cadoles.com/wpetit/go-tunnel/control"
cmap "github.com/orcaman/concurrent-map"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/xtaci/kcp-go/v5" "github.com/xtaci/kcp-go/v5"
"github.com/xtaci/smux" "github.com/xtaci/smux"
@ -19,61 +17,45 @@ type RemoteClient struct {
onClientAuthHook OnClientAuthHook onClientAuthHook OnClientAuthHook
onClientConnectHook OnClientConnectHook onClientConnectHook OnClientConnectHook
onClientDisconnectHook OnClientDisconnectHook onClientDisconnectHook OnClientDisconnectHook
conn *kcp.UDPSession
sess *smux.Session sess *smux.Session
control *control.Control
remoteAddr net.Addr remoteAddr net.Addr
proxies cmap.ConcurrentMap authenticationTimeout time.Duration
acceptStreamMutex sync.Mutex proxyRequestTimeout time.Duration
connMutex sync.RWMutex
smuxConfig *smux.Config
} }
func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error { func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error {
config := smux.DefaultConfig() c.connMutex.Lock()
config.Version = 2 defer c.connMutex.Unlock()
config.KeepAliveInterval = 10 * time.Second
config.KeepAliveTimeout = 2 * config.KeepAliveInterval
logger.Debug(ctx, "creating server session") if err := c.Close(); err != nil {
return errors.WithStack(err)
}
sess, err := smux.Server(conn, config) sess, err := c.acceptSession(ctx, conn)
if err != nil { if err != nil {
return errors.WithStack(err) return errors.WithStack(err)
} }
ctrl := control.New() stream, err := sess.AcceptStream()
if err != nil {
return errors.WithStack(err)
}
if err := ctrl.Init(ctx, sess, true); err != nil { defer stream.Close()
if err := c.authenticate(ctx, stream); err != nil {
return errors.WithStack(err) return errors.WithStack(err)
} }
c.sess = sess c.sess = sess
c.remoteAddr = conn.RemoteAddr() c.conn = conn
c.control = ctrl
if c.onClientConnectHook != nil {
if err := c.onClientConnectHook.OnClientConnect(ctx, c); err != nil {
return errors.WithStack(err)
}
}
return nil return nil
} }
func (c *RemoteClient) Listen(ctx context.Context) error {
defer func() {
if c.onClientDisconnectHook != nil {
if err := c.onClientDisconnectHook.OnClientDisconnect(ctx, c); err != nil {
logger.Error(ctx, "client disconnect hook error", logger.E(errors.WithStack(err)))
}
}
}()
logger.Debug(ctx, "listening for messages")
return c.control.Listen(ctx, control.Handlers{
control.TypeAuthRequest: c.handleAuthRequest,
})
}
func (c *RemoteClient) ConfigureHooks(hooks interface{}) { func (c *RemoteClient) ConfigureHooks(hooks interface{}) {
if hooks == nil { if hooks == nil {
return return
@ -96,74 +78,164 @@ func (c *RemoteClient) RemoteAddr() net.Addr {
return c.remoteAddr return c.remoteAddr
} }
func (c *RemoteClient) Close() { func (c *RemoteClient) Close() error {
if c.sess != nil { if c.sess != nil {
c.sess.Close() if err := c.sess.Close(); err != nil {
return errors.WithStack(err)
}
}
if c.conn != nil {
if err := c.conn.Close(); err != nil {
return errors.WithStack(err)
}
} }
c.sess = nil c.sess = nil
c.control = nil c.conn = nil
return nil
}
func (c *RemoteClient) SwitchConn(ctx context.Context, conn *kcp.UDPSession) error {
c.connMutex.Lock()
defer c.connMutex.Unlock()
if err := c.Close(); err != nil {
return errors.WithStack(err)
}
sess, err := c.acceptSession(ctx, conn)
if err != nil {
return errors.WithStack(err)
}
c.sess = sess
c.conn = conn
return nil
} }
func (c *RemoteClient) Proxy(ctx context.Context, network, address string) (net.Conn, error) { func (c *RemoteClient) Proxy(ctx context.Context, network, address string) (net.Conn, error) {
ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address)) c.connMutex.RLock()
defer c.connMutex.RUnlock()
if err := c.control.ProxyReq(ctx, network, address); err != nil { ctx = logger.With(ctx, logger.F("network", network), logger.F("address", address))
return nil, errors.WithStack(err)
}
logger.Debug(ctx, "opening proxy stream") logger.Debug(ctx, "opening proxy stream")
c.acceptStreamMutex.Lock() stream, err := c.sess.OpenStream()
stream, err := c.sess.AcceptStream()
if err != nil { if err != nil {
c.acceptStreamMutex.Unlock()
return nil, errors.WithStack(err) return nil, errors.WithStack(err)
} }
c.acceptStreamMutex.Unlock() proxyReq := &proxyRequest{
Network: network,
Address: address,
}
encoder := json.NewEncoder(stream)
go func() { writeDeadline := time.Now().Add(c.proxyRequestTimeout)
<-ctx.Done() logger.Debug(ctx, "sending proxy req", logger.F("deadline", writeDeadline))
logger.Debug(ctx, "closing proxy stream")
if err := stream.SetWriteDeadline(writeDeadline); err != nil {
stream.Close() stream.Close()
}()
return nil, errors.WithStack(err)
}
if err := encoder.Encode(proxyReq); err != nil {
stream.Close()
return nil, errors.WithStack(err)
}
if err := stream.SetWriteDeadline(time.Time{}); err != nil {
stream.Close()
return nil, errors.WithStack(err)
}
return stream, nil return stream, nil
} }
func (c *RemoteClient) handleAuthRequest(ctx context.Context, m *control.Message) (*control.Message, error) { func (c *RemoteClient) acceptSession(ctx context.Context, conn *kcp.UDPSession) (*smux.Session, error) {
authReqPayload, ok := m.Payload.(*control.AuthRequestPayload) logger.Debug(ctx, "accepting client session")
if !ok {
return nil, errors.WithStack(ErrUnexpectedMessage) sess, err := smux.Server(conn, c.smuxConfig)
if err != nil {
return nil, errors.WithStack(err)
} }
logger.Debug(ctx, "handling auth request", logger.F("credentials", authReqPayload.Credentials)) c.remoteAddr = conn.RemoteAddr()
if c.onClientConnectHook != nil {
if err := c.onClientConnectHook.OnClientConnect(ctx, c); err != nil {
return nil, errors.WithStack(err)
}
}
return sess, nil
}
func (c *RemoteClient) authenticate(ctx context.Context, stream *smux.Stream) error {
start := time.Now()
readDeadline := time.Now().Add(c.authenticationTimeout)
logger.Debug(ctx, "waiting for auth request", logger.F("deadline", readDeadline))
if err := stream.SetReadDeadline(readDeadline); err != nil {
return errors.WithStack(err)
}
decoder := json.NewDecoder(stream)
authReq := &authRequest{}
if err := decoder.Decode(authReq); err != nil {
return errors.WithStack(err)
}
var ( var (
success bool success bool
err error err error
) )
logger.Debug(ctx, "received client credentials", logger.F("credentials", authReq.Credentials))
if c.onClientAuthHook != nil { if c.onClientAuthHook != nil {
success, err = c.onClientAuthHook.OnClientAuth(ctx, c, authReqPayload.Credentials) success, err = c.onClientAuthHook.OnClientAuth(ctx, c, authReq.Credentials)
if err != nil { if err != nil {
return nil, errors.WithStack(err) return errors.WithStack(err)
} }
} }
logger.Debug(ctx, "auth succeeded", logger.F("credentials", authReqPayload.Credentials)) authRes := &authResponse{
res := control.NewMessage(control.TypeAuthResponse, &control.AuthResponsePayload{
Success: success, Success: success,
}) }
encoder := json.NewEncoder(stream)
return res, nil writeDeadline := time.Now().Add(c.authenticationTimeout - time.Since(start))
logger.Debug(ctx, "sending auth response", logger.F("deadline", writeDeadline))
if err := stream.SetWriteDeadline(writeDeadline); err != nil {
return errors.WithStack(err)
} }
func NewRemoteClient() *RemoteClient { if err := encoder.Encode(authRes); err != nil {
return errors.WithStack(err)
}
if !success {
return errors.WithStack(ErrAuthenticationFailed)
}
return nil
}
func NewRemoteClient(smuxConfig *smux.Config, authenticationTimeout, proxyRequestTimeout time.Duration) *RemoteClient {
return &RemoteClient{ return &RemoteClient{
proxies: cmap.New(), smuxConfig: smuxConfig,
authenticationTimeout: authenticationTimeout,
proxyRequestTimeout: proxyRequestTimeout,
} }
} }

View File

@ -3,8 +3,8 @@ package tunnel
import ( import (
"context" "context"
cmap "github.com/orcaman/concurrent-map"
"github.com/pkg/errors" "github.com/pkg/errors"
cmap "github.com/streamrail/concurrent-map"
"github.com/xtaci/kcp-go/v5" "github.com/xtaci/kcp-go/v5"
"gitlab.com/wpetit/goweb/logger" "gitlab.com/wpetit/goweb/logger"
) )
@ -23,6 +23,14 @@ func (s *Server) Listen(ctx context.Context) error {
return errors.WithStack(err) return errors.WithStack(err)
} }
if s.conf.ConfigureListener != nil {
if err := s.conf.ConfigureListener(listener); err != nil {
return errors.WithStack(err)
}
}
logger.Debug(ctx, "accepting connections", logger.F("address", s.conf.Address))
for { for {
conn, err := listener.AcceptKCP() conn, err := listener.AcceptKCP()
if err != nil { if err != nil {
@ -34,12 +42,31 @@ func (s *Server) Listen(ctx context.Context) error {
} }
func (s *Server) handleNewConn(ctx context.Context, conn *kcp.UDPSession) { func (s *Server) handleNewConn(ctx context.Context, conn *kcp.UDPSession) {
ctx = logger.With(ctx, logger.F("remoteAddr", conn.RemoteAddr().String())) var remoteClient *RemoteClient
remoteClient := NewRemoteClient() remoteAddr := conn.RemoteAddr().String()
ctx = logger.With(ctx, logger.F("remoteAddr", remoteAddr))
defer remoteClient.Close() rawExistingClient, exists := s.clients.Get(remoteAddr)
defer conn.Close() if exists {
logger.Debug(ctx, "remote client already exists")
remoteClient, _ = rawExistingClient.(*RemoteClient)
if err := remoteClient.SwitchConn(ctx, conn); err != nil {
logger.Error(ctx, "remote client error", logger.E(errors.WithStack(err)))
s.clients.Remove(remoteAddr)
return
}
}
remoteClient = NewRemoteClient(
s.conf.SmuxConfig,
s.conf.AuthenticationTimeout,
s.conf.ProxyRequestTimeout,
)
remoteClient.ConfigureHooks(s.conf.Hooks) remoteClient.ConfigureHooks(s.conf.Hooks)
@ -49,9 +76,7 @@ func (s *Server) handleNewConn(ctx context.Context, conn *kcp.UDPSession) {
return return
} }
if err := remoteClient.Listen(ctx); err != nil { s.clients.Set(remoteAddr, remoteClient)
logger.Error(ctx, "remote client error", logger.E(errors.WithStack(err)))
}
} }
func NewServer(funcs ...ServerConfigFunc) *Server { func NewServer(funcs ...ServerConfigFunc) *Server {

View File

@ -2,13 +2,16 @@ package tunnel
import ( import (
"crypto/sha1" "crypto/sha1"
"time"
"github.com/pkg/errors" "github.com/pkg/errors"
"github.com/xtaci/kcp-go/v5" "github.com/xtaci/kcp-go/v5"
"github.com/xtaci/smux"
"golang.org/x/crypto/pbkdf2" "golang.org/x/crypto/pbkdf2"
) )
type ConfigureConnFunc func(conn *kcp.UDPSession) error type ConfigureConnFunc func(conn *kcp.UDPSession) error
type ConfigureListenerFunc func(listener *kcp.Listener) error
type ServerConfig struct { type ServerConfig struct {
Address string Address string
@ -17,14 +20,25 @@ type ServerConfig struct {
ParityShards int ParityShards int
Hooks *ServerHooks Hooks *ServerHooks
ConfigureConn ConfigureConnFunc ConfigureConn ConfigureConnFunc
ConfigureListener ConfigureListenerFunc
AuthenticationTimeout time.Duration
ProxyRequestTimeout time.Duration
SmuxConfig *smux.Config
} }
// nolint: go-mnd
func DefaultServerConfig() *ServerConfig { func DefaultServerConfig() *ServerConfig {
unencryptedBlock, err := kcp.NewNoneBlockCrypt(nil) unencryptedBlock, err := kcp.NewNoneBlockCrypt(nil)
if err != nil { // should never happen if err != nil { // should never happen
panic(errors.WithStack(err)) panic(errors.WithStack(err))
} }
smuxConfig := smux.DefaultConfig()
smuxConfig.Version = 2
smuxConfig.KeepAliveInterval = 10 * time.Second
smuxConfig.MaxReceiveBuffer = 4194304
smuxConfig.MaxStreamBuffer = 2097152
return &ServerConfig{ return &ServerConfig{
Address: ":36543", Address: ":36543",
BlockCrypt: unencryptedBlock, BlockCrypt: unencryptedBlock,
@ -35,6 +49,11 @@ func DefaultServerConfig() *ServerConfig {
onClientDisconnect: DefaultOnClientDisconnect, onClientDisconnect: DefaultOnClientDisconnect,
onClientAuth: DefaultOnClientAuth, onClientAuth: DefaultOnClientAuth,
}, },
ConfigureConn: DefaultServerConfigureConn,
ConfigureListener: DefaultServerConfigureListener,
AuthenticationTimeout: 30 * time.Second,
ProxyRequestTimeout: 5 * time.Second,
SmuxConfig: smuxConfig,
} }
} }
@ -82,3 +101,50 @@ func WithServerConfigureConn(fn ConfigureConnFunc) ServerConfigFunc {
conf.ConfigureConn = fn conf.ConfigureConn = fn
} }
} }
func WithServerConfigureListener(fn ConfigureListenerFunc) ServerConfigFunc {
return func(conf *ServerConfig) {
conf.ConfigureListener = fn
}
}
func WithServerSmuxConfig(c *smux.Config) ServerConfigFunc {
return func(conf *ServerConfig) {
conf.SmuxConfig = c
}
}
// nolint: go-mnd
func DefaultServerConfigureConn(conn *kcp.UDPSession) error {
// Based on kcptun default configuration, mode 'fast3'
conn.SetStreamMode(true)
conn.SetWriteDelay(false)
conn.SetNoDelay(1, 10, 2, 1)
conn.SetWindowSize(128, 512)
conn.SetMtu(1400)
conn.SetACKNoDelay(true)
if err := conn.SetDSCP(46); err != nil {
return errors.WithStack(err)
}
return nil
}
// nolint: go-mnd
func DefaultServerConfigureListener(listener *kcp.Listener) error {
// Based on kcptun default configuration, mode 'fast3'
if err := listener.SetReadBuffer(16777217); err != nil {
return errors.WithStack(err)
}
if err := listener.SetWriteBuffer(16777217); err != nil {
return errors.WithStack(err)
}
if err := listener.SetDSCP(46); err != nil {
return errors.WithStack(err)
}
return nil
}