Compare commits

..

1 Commits

Author SHA1 Message Date
wpetit 9186173322 fix: enhance proxy stability 2020-10-24 18:43:17 +02:00
3 changed files with 21 additions and 27 deletions

View File

@ -12,6 +12,7 @@ import (
type Control struct {
encoder *json.Encoder
decoder *json.Decoder
stream *smux.Stream
}
func (c *Control) Init(ctx context.Context, sess *smux.Session, serverMode bool) error {
@ -21,24 +22,25 @@ func (c *Control) Init(ctx context.Context, sess *smux.Session, serverMode bool)
logger.Debug(ctx, "creating control stream")
var (
controlStream *smux.Stream
err error
stream *smux.Stream
err error
)
if serverMode {
controlStream, err = sess.AcceptStream()
stream, err = sess.AcceptStream()
if err != nil {
return errors.WithStack(err)
}
} else {
controlStream, err = sess.OpenStream()
stream, err = sess.OpenStream()
if err != nil {
return errors.WithStack(err)
}
}
c.decoder = json.NewDecoder(controlStream)
c.encoder = json.NewEncoder(controlStream)
c.stream = stream
c.decoder = json.NewDecoder(stream)
c.encoder = json.NewEncoder(stream)
return nil
}
@ -78,6 +80,7 @@ func (c *Control) ProxyReq(ctx context.Context, network, address string) error {
func (c *Control) Listen(ctx context.Context, handlers Handlers) error {
errChan := make(chan error)
msgChan := make(chan *Message)
dieChan := c.stream.GetDieCh()
go func(msgChan chan *Message, errChan chan error) {
for {
@ -102,18 +105,13 @@ func (c *Control) Listen(ctx context.Context, handlers Handlers) error {
case <-ctx.Done():
return nil
case err, ok := <-errChan:
if !ok {
return nil
}
case <-dieChan:
return errors.WithStack(ErrStreamClosed)
return err
case msg, ok := <-msgChan:
if !ok {
return nil
}
case err := <-errChan:
return errors.WithStack(err)
case msg := <-msgChan:
go func() {
subCtx := logger.With(ctx, logger.F("messageType", msg.Type))

View File

@ -3,5 +3,6 @@ package control
import "errors"
var (
ErrStreamClosed = errors.New("stream closed")
ErrUnexpectedMessage = errors.New("unexpected message")
)

View File

@ -2,7 +2,6 @@ package tunnel
import (
"context"
"io"
"net"
"sync"
"time"
@ -60,23 +59,19 @@ func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error {
}
func (c *RemoteClient) Listen(ctx context.Context) error {
logger.Debug(ctx, "listening for messages")
err := c.control.Listen(ctx, control.Handlers{
control.TypeAuthRequest: c.handleAuthRequest,
})
if errors.Is(err, io.ErrClosedPipe) {
defer func() {
if c.onClientDisconnectHook != nil {
if err := c.onClientDisconnectHook.OnClientDisconnect(ctx, c); err != nil {
logger.Error(ctx, "client disconnect hook error", logger.E(errors.WithStack(err)))
}
}
}()
return errors.WithStack(ErrConnectionClosed)
}
logger.Debug(ctx, "listening for messages")
return errors.WithStack(err)
return c.control.Listen(ctx, control.Handlers{
control.TypeAuthRequest: c.handleAuthRequest,
})
}
func (c *RemoteClient) ConfigureHooks(hooks interface{}) {