package rebound

import (
	"bufio"
	"net"
	"sync"
	"time"

	"github.com/pkg/errors"
)

// muxListener splits l into two listeners by sniffing the first bytes each
// client sends: connections that open with "SSH" (the SSH identification
// string, RFC 4253 §4.2) are delivered to ssh, all others to other.
//
// Each accepted connection is sniffed in its own goroutine so a slow or silent
// client cannot stall the accept loop; a client has up to 10 seconds to send
// its first 3 bytes. Connections that fail sniffing, or whose target listener
// has been closed, are closed and dropped.
//
// Closing ssh or other only stops delivery to that listener; the shared l stays
// open. When l is closed the accept loop exits and both returned listeners are
// closed, unblocking any pending Accept calls.
func (s *Server) muxListener(l net.Listener) (ssh net.Listener, other net.Listener) {
	const (
		sniffTimeout     = 10 * time.Second
		sniffLen         = 3
		minAcceptBackoff = 5 * time.Millisecond
		maxAcceptBackoff = time.Second
	)

	sshListener, otherListener := newListener(l), newListener(l)

	// route sniffs conn's protocol and hands it to the matching listener,
	// closing conn on any failure so it is never leaked.
	route := func(conn net.Conn) {
		if err := conn.SetReadDeadline(time.Now().Add(sniffTimeout)); err != nil {
			s.log("[ERROR] %+v", errors.WithStack(err))
			_ = conn.Close() // abandoning the connection; close error is irrelevant
			return
		}

		// bufio enforces a minimum buffer of 16 bytes, so Peek may buffer a few
		// bytes beyond sniffLen; bufferedConn.Read serves them back first.
		bconn := bufferedConn{conn, bufio.NewReaderSize(conn, sniffLen)}
		p, err := bconn.Peek(sniffLen)
		if err != nil {
			s.log("[ERROR] %+v", errors.WithStack(err))
			_ = conn.Close()
			return
		}

		// Clear the sniffing deadline so the eventual consumer starts without one.
		if err := conn.SetReadDeadline(time.Time{}); err != nil {
			s.log("[ERROR] %+v", errors.WithStack(err))
			_ = conn.Close()
			return
		}

		target := otherListener
		if string(p) == "SSH" {
			s.log("[INFO] new ssh connection from '%s'", conn.RemoteAddr())
			target = sshListener
		} else {
			s.log("[INFO] new http connection from '%s'", conn.RemoteAddr())
		}

		if !target.deliver(bconn) {
			// The target listener was closed; nobody will ever Accept this conn.
			_ = conn.Close()
		}
	}

	go func() {
		// Once l is closed no more connections can arrive; close both
		// sub-listeners so their blocked Accept calls return instead of hanging.
		defer sshListener.Close()
		defer otherListener.Close()

		var backoff time.Duration
		for {
			conn, err := l.Accept()
			if err != nil {
				if errors.Is(err, net.ErrClosed) {
					return
				}
				s.log("[ERROR] %+v", errors.WithStack(err))

				// Back off so a persistent failure (e.g. running out of file
				// descriptors) does not spin and flood the log.
				if backoff == 0 {
					backoff = minAcceptBackoff
				} else {
					backoff *= 2
				}
				if backoff > maxAcceptBackoff {
					backoff = maxAcceptBackoff
				}
				time.Sleep(backoff)
				continue
			}
			backoff = 0

			go route(conn)
		}
	}()

	return sshListener, otherListener
}

// listener is a net.Listener fed by muxListener: Accept returns connections
// handed over through deliver instead of accepting from the embedded listener.
// The embedded net.Listener supplies Addr; its Close is deliberately shadowed
// so closing one sub-listener does not close the shared underlying listener.
type listener struct {
	accept chan net.Conn // unbuffered hand-off from the router to Accept
	net.Listener         // shared underlying listener; supplies Addr

	done      chan struct{} // closed exactly once by Close
	closeOnce sync.Once
}

// newListener returns an open listener that reports l's address.
func newListener(l net.Listener) *listener {
	return &listener{
		accept:   make(chan net.Conn),
		Listener: l,
		done:     make(chan struct{}),
	}
}

// Accept waits for the next connection routed to this listener. After Close
// it returns a nil conn and an error matching net.ErrClosed (errors.Is).
func (l *listener) Accept() (net.Conn, error) {
	// Honor a Close that already happened even if a router is waiting to
	// deliver; select would otherwise pick between the two at random.
	select {
	case <-l.done:
		return nil, errors.Wrap(net.ErrClosed, "listener closed")
	default:
	}

	select {
	case conn := <-l.accept:
		return conn, nil
	case <-l.done:
		return nil, errors.Wrap(net.ErrClosed, "listener closed")
	}
}

// deliver blocks until conn is taken by Accept or the listener is closed, and
// reports whether conn was handed over. On false the caller still owns conn.
func (l *listener) deliver(conn net.Conn) bool {
	select {
	case l.accept <- conn:
		return true
	case <-l.done:
		return false
	}
}

// Close stops the listener: pending and future Accept calls return an error
// and routed connections are no longer handed over. It does not close the
// shared underlying listener. Close is safe to call more than once and from
// multiple goroutines.
func (l *listener) Close() error {
	l.closeOnce.Do(func() { close(l.done) })
	return nil
}

// bufferedConn is a net.Conn whose reads go through r, so bytes already
// consumed from the socket by Peek are replayed to the eventual reader.
type bufferedConn struct {
	net.Conn
	r *bufio.Reader
}

// Peek returns the next n bytes without consuming them.
func (b bufferedConn) Peek(n int) ([]byte, error) {
	return b.r.Peek(n)
}

// Read reads from the buffer first, then from the underlying connection.
func (b bufferedConn) Read(p []byte) (int, error) {
	return b.r.Read(p)
}