diff --git a/client.go b/client.go index b197a0d..a69ff68 100644 --- a/client.go +++ b/client.go @@ -35,7 +35,11 @@ func (c *Client) Connect(ctx context.Context) error { return errors.WithStack(err) } - conn.SetWriteDelay(false) + if c.conf.ConfigureConn != nil { + if err := c.conf.ConfigureConn(conn); err != nil { + return errors.WithStack(err) + } + } config := smux.DefaultConfig() config.Version = 2 diff --git a/client_config.go b/client_config.go index 6abc7ef..9f2eb00 100644 --- a/client_config.go +++ b/client_config.go @@ -14,6 +14,7 @@ type ClientConfig struct { DataShards int ParityShards int Credentials interface{} + ConfigureConn ConfigureConnFunc } func DefaultClientConfig() *ClientConfig { @@ -31,6 +32,8 @@ func DefaultClientConfig() *ClientConfig { } } +type ClientConfigFunc func(c *ClientConfig) + func WithClientServerAddress(addr string) ClientConfigFunc { return func(conf *ClientConfig) { conf.ServerAddress = addr @@ -43,17 +46,21 @@ func WithClientCredentials(credentials interface{}) ClientConfigFunc { } } -func WithClientAESBlockCrypt(pass, salt string) ClientConfigFunc { +func WithClientBlockCrypt(alg string, pass, salt string, iterations, keyLen int) ClientConfigFunc { return func(conf *ClientConfig) { - key := pbkdf2.Key([]byte(pass), []byte(salt), 1024, 32, sha1.New) + key := pbkdf2.Key([]byte(pass), []byte(salt), iterations, keyLen, sha1.New) - block, err := kcp.NewAESBlockCrypt(key) + block, err := createBlockCrypt(alg, key) if err != nil { - panic(errors.WithStack(err)) + panic(errors.Wrap(err, "could not create block crypt")) } conf.BlockCrypt = block } } -type ClientConfigFunc func(c *ClientConfig) +func WithClientConfigureConn(fn ConfigureConnFunc) ClientConfigFunc { + return func(conf *ClientConfig) { + conf.ConfigureConn = fn + } +} diff --git a/helper.go b/helper.go index 1ea5957..2e822da 100644 --- a/helper.go +++ b/helper.go @@ -6,6 +6,7 @@ import ( "net" "github.com/pkg/errors" + "github.com/xtaci/kcp-go/v5" "gitlab.com/wpetit/goweb/logger" ) @@ -47,3 +48,34 @@ func relay(src net.Conn, dst net.Conn, stop chan bool) (err error) { return } + +func createBlockCrypt(algorithm string, pass []byte) (kcp.BlockCrypt, error) { + switch algorithm { + case "sm4": + return kcp.NewSM4BlockCrypt(pass[:16]) + case "tea": + return kcp.NewTEABlockCrypt(pass[:16]) + case "xor": + return kcp.NewSimpleXORBlockCrypt(pass) + case "none": + return kcp.NewNoneBlockCrypt(pass) + case "aes-128": + return kcp.NewAESBlockCrypt(pass[:16]) + case "aes-192": + return kcp.NewAESBlockCrypt(pass[:24]) + case "blowfish": + return kcp.NewBlowfishBlockCrypt(pass) + case "twofish": + return kcp.NewTwofishBlockCrypt(pass) + case "cast5": + return kcp.NewCast5BlockCrypt(pass[:16]) + case "3des": + return kcp.NewTripleDESBlockCrypt(pass[:24]) + case "xtea": + return kcp.NewXTEABlockCrypt(pass[:16]) + case "salsa20": + return kcp.NewSalsa20BlockCrypt(pass) + default: + return nil, errors.Errorf("unknown algorithm '%s'", algorithm) + } +} diff --git a/remote_client.go b/remote_client.go index 0fb1b66..f8a4cff 100644 --- a/remote_client.go +++ b/remote_client.go @@ -4,7 +4,6 @@ import ( "context" "io" "net" - "time" "forge.cadoles.com/wpetit/go-tunnel/control" @@ -42,10 +41,6 @@ func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error { logger.Debug(ctx, "accepting control stream") - if err := sess.SetDeadline(time.Now().Add(30 * time.Second)); err != nil { - return errors.WithStack(err) - } - controlStream, err := sess.AcceptStream() if err != nil { return errors.WithStack(err) diff --git a/server_config.go b/server_config.go index f1267be..8cf3450 100644 --- a/server_config.go +++ b/server_config.go @@ -8,12 +8,15 @@ import ( "golang.org/x/crypto/pbkdf2" ) +type ConfigureConnFunc func(conn *kcp.UDPSession) error + type ServerConfig struct { - Address string - BlockCrypt kcp.BlockCrypt - DataShards int - ParityShards int - Hooks *ServerHooks + Address string + BlockCrypt kcp.BlockCrypt + DataShards int + ParityShards int + Hooks *ServerHooks + ConfigureConn ConfigureConnFunc } func DefaultServerConfig() *ServerConfig { @@ -43,13 +46,13 @@ func WithServerAddress(address string) ServerConfigFunc { } } -func WithServerAESBlockCrypt(pass, salt string) ServerConfigFunc { +func WithServerBlockCrypt(alg string, pass, salt string, iterations, keyLen int) ServerConfigFunc { return func(conf *ServerConfig) { - key := pbkdf2.Key([]byte(pass), []byte(salt), 1024, 32, sha1.New) + key := pbkdf2.Key([]byte(pass), []byte(salt), iterations, keyLen, sha1.New) - block, err := kcp.NewAESBlockCrypt(key) + block, err := createBlockCrypt(alg, key) if err != nil { - panic(errors.WithStack(err)) + panic(errors.Wrap(err, "could not create block crypt")) } conf.BlockCrypt = block @@ -73,3 +76,9 @@ func WithServerOnClientDisconnect(fn OnClientDisconnectFunc) ServerConfigFunc { conf.Hooks.onClientDisconnect = fn } } + +func WithServerConfigureConn(fn ConfigureConnFunc) ServerConfigFunc { + return func(conf *ServerConfig) { + conf.ConfigureConn = fn + } +}