Compare commits
1 Commits
7c054fee58
...
9186173322
Author | SHA1 | Date |
---|---|---|
wpetit | 9186173322 |
|
@ -12,6 +12,7 @@ import (
|
|||
type Control struct {
|
||||
encoder *json.Encoder
|
||||
decoder *json.Decoder
|
||||
stream *smux.Stream
|
||||
}
|
||||
|
||||
func (c *Control) Init(ctx context.Context, sess *smux.Session, serverMode bool) error {
|
||||
|
@ -21,24 +22,25 @@ func (c *Control) Init(ctx context.Context, sess *smux.Session, serverMode bool)
|
|||
logger.Debug(ctx, "creating control stream")
|
||||
|
||||
var (
|
||||
controlStream *smux.Stream
|
||||
err error
|
||||
stream *smux.Stream
|
||||
err error
|
||||
)
|
||||
|
||||
if serverMode {
|
||||
controlStream, err = sess.AcceptStream()
|
||||
stream, err = sess.AcceptStream()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
} else {
|
||||
controlStream, err = sess.OpenStream()
|
||||
stream, err = sess.OpenStream()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
c.decoder = json.NewDecoder(controlStream)
|
||||
c.encoder = json.NewEncoder(controlStream)
|
||||
c.stream = stream
|
||||
c.decoder = json.NewDecoder(stream)
|
||||
c.encoder = json.NewEncoder(stream)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
@ -78,6 +80,7 @@ func (c *Control) ProxyReq(ctx context.Context, network, address string) error {
|
|||
func (c *Control) Listen(ctx context.Context, handlers Handlers) error {
|
||||
errChan := make(chan error)
|
||||
msgChan := make(chan *Message)
|
||||
dieChan := c.stream.GetDieCh()
|
||||
|
||||
go func(msgChan chan *Message, errChan chan error) {
|
||||
for {
|
||||
|
@ -102,18 +105,13 @@ func (c *Control) Listen(ctx context.Context, handlers Handlers) error {
|
|||
case <-ctx.Done():
|
||||
return nil
|
||||
|
||||
case err, ok := <-errChan:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
case <-dieChan:
|
||||
return errors.WithStack(ErrStreamClosed)
|
||||
|
||||
return err
|
||||
|
||||
case msg, ok := <-msgChan:
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
case err := <-errChan:
|
||||
return errors.WithStack(err)
|
||||
|
||||
case msg := <-msgChan:
|
||||
go func() {
|
||||
subCtx := logger.With(ctx, logger.F("messageType", msg.Type))
|
||||
|
||||
|
|
|
@ -3,5 +3,6 @@ package control
|
|||
import "errors"
|
||||
|
||||
var (
|
||||
ErrStreamClosed = errors.New("stream closed")
|
||||
ErrUnexpectedMessage = errors.New("unexpected message")
|
||||
)
|
||||
|
|
|
@ -2,7 +2,6 @@ package tunnel
|
|||
|
||||
import (
|
||||
"context"
|
||||
"io"
|
||||
"net"
|
||||
"sync"
|
||||
"time"
|
||||
|
@ -60,23 +59,19 @@ func (c *RemoteClient) Accept(ctx context.Context, conn *kcp.UDPSession) error {
|
|||
}
|
||||
|
||||
func (c *RemoteClient) Listen(ctx context.Context) error {
|
||||
logger.Debug(ctx, "listening for messages")
|
||||
|
||||
err := c.control.Listen(ctx, control.Handlers{
|
||||
control.TypeAuthRequest: c.handleAuthRequest,
|
||||
})
|
||||
|
||||
if errors.Is(err, io.ErrClosedPipe) {
|
||||
defer func() {
|
||||
if c.onClientDisconnectHook != nil {
|
||||
if err := c.onClientDisconnectHook.OnClientDisconnect(ctx, c); err != nil {
|
||||
logger.Error(ctx, "client disconnect hook error", logger.E(errors.WithStack(err)))
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return errors.WithStack(ErrConnectionClosed)
|
||||
}
|
||||
logger.Debug(ctx, "listening for messages")
|
||||
|
||||
return errors.WithStack(err)
|
||||
return c.control.Listen(ctx, control.Handlers{
|
||||
control.TypeAuthRequest: c.handleAuthRequest,
|
||||
})
|
||||
}
|
||||
|
||||
func (c *RemoteClient) ConfigureHooks(hooks interface{}) {
|
||||
|
|
Loading…
Reference in New Issue