rebound/mux_listener.go

101 lines
1.8 KiB
Go
Raw Normal View History

package rebound
import (
"bufio"
"net"
"time"
"github.com/pkg/errors"
)
// muxListener multiplexes l into two virtual listeners. A background goroutine
// accepts from l until Accept reports net.ErrClosed; every accepted connection
// is classified by the first three bytes the client sends and handed to the
// ssh listener when they are "SSH", or to the other listener otherwise.
//
// Classification runs on its own goroutine per connection, so a client that
// is slow to send its first bytes (the sniff waits up to 10 seconds) does not
// hold up connections accepted after it. A connection that cannot be
// classified, or whose target listener has already been closed, is closed
// instead of being silently dropped.
//
// NOTE(review): s.log is now called from several goroutines at once; it was
// already called from a background goroutine, so it is presumably safe for
// concurrent use — confirm.
func (s *Server) muxListener(l net.Listener) (ssh net.Listener, other net.Listener) {
	sshListener, otherListener := newListener(l), newListener(l)

	// logErr records err, annotated with a stack trace, in the server log.
	logErr := func(err error) {
		s.log("[ERROR] %+v", errors.WithStack(err))
	}

	// dispatch sniffs the protocol spoken on conn and delivers the connection
	// to the matching virtual listener.
	dispatch := func(conn net.Conn) {
		// reject logs why conn is unusable and releases its socket.
		reject := func(err error) {
			logErr(err)
			if cerr := conn.Close(); cerr != nil {
				logErr(cerr)
			}
		}

		// Bound the sniff so a client that never sends anything cannot keep
		// this goroutine and its socket alive forever.
		if err := conn.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil {
			reject(err)
			return
		}
		// NewReaderSize returns a reader whose buffer has at least the
		// requested size; the peeked bytes stay buffered and are replayed to
		// whoever reads from bconn later.
		bconn := bufferedConn{conn, bufio.NewReaderSize(conn, 3)}
		p, err := bconn.Peek(3)
		if err != nil {
			reject(err)
			return
		}
		// The deadline only applies to the sniff; clear it before handing
		// the connection over.
		if err := conn.SetReadDeadline(time.Time{}); err != nil {
			reject(err)
			return
		}

		selectedListener := otherListener
		if string(p) == "SSH" {
			s.log("[INFO] new ssh connection from '%s'", conn.RemoteAddr())
			selectedListener = sshListener
		} else {
			s.log("[INFO] new http connection from '%s'", conn.RemoteAddr())
		}

		// A nil accept channel means the target listener was closed, so
		// nobody will ever pick this connection up.
		accept := selectedListener.accept
		if accept == nil {
			if err := conn.Close(); err != nil {
				logErr(err)
			}
			return
		}
		accept <- bconn
	}

	go func() {
		for {
			conn, err := l.Accept()
			if err != nil {
				if errors.Is(err, net.ErrClosed) {
					// The underlying listener is gone; stop accepting.
					return
				}
				logErr(err)
				continue
			}
			// Sniff off the accept loop so one slow client cannot stall
			// every other incoming connection.
			go dispatch(conn)
		}
	}()

	return sshListener, otherListener
}
// listener is a virtual net.Listener: rather than accepting from a socket, it
// hands out connections that were pushed onto its accept channel (muxListener
// is the sender in this file).
//
// The embedded net.Listener supplies Addr. Accept and Close are overridden
// below, so closing a listener does not close the embedded one.
//
// NOTE(review): accept is read and reset from different goroutines without
// synchronization, and a sender blocked on the channel while Close runs will
// panic with "send on closed channel" — confirm whether callers can hit this.
type listener struct {
	// accept carries the connections routed to this listener. Close closes
	// the channel and resets the field to nil, so nil means "closed".
	accept chan net.Conn
	net.Listener
}

// newListener returns an open listener that wraps l and owns a fresh,
// unbuffered accept channel.
func newListener(l net.Listener) *listener {
	return &listener{
		accept:   make(chan net.Conn),
		Listener: l,
	}
}

// Accept blocks until a connection is delivered on the accept channel and
// returns it. It returns a "listener closed" error, and never a nil
// connection with a nil error, when the listener was closed before the call
// or is closed while the call is waiting.
func (l *listener) Accept() (net.Conn, error) {
	// Read the field once: Close resets it to nil, and receiving from a nil
	// channel would block forever.
	ch := l.accept
	if ch == nil {
		return nil, errors.New("listener closed")
	}
	conn, ok := <-ch
	if !ok {
		// The channel was closed while we were waiting.
		return nil, errors.New("listener closed")
	}
	return conn, nil
}

// Close marks the listener as closed and wakes any Accept call waiting on the
// channel. Calling Close again is a no-op. It always returns nil.
func (l *listener) Close() error {
	ch := l.accept
	if ch == nil {
		// Already closed; closing a nil channel would panic.
		return nil
	}
	l.accept = nil
	close(ch)
	return nil
}
// bufferedConn is a net.Conn whose reads are served by a bufio.Reader, so
// bytes inspected through Peek are still returned by subsequent Read calls.
// Every other net.Conn method is forwarded to the embedded Conn.
type bufferedConn struct {
	net.Conn
	// r is the buffered reader all reads go through; muxListener builds it
	// over the same Conn.
	r *bufio.Reader
}

// Peek returns the next n bytes from the buffered reader without consuming
// them, together with any error the reader reports.
func (c bufferedConn) Peek(n int) ([]byte, error) {
	head, err := c.r.Peek(n)
	return head, err
}

// Read fills p from the buffered reader (not directly from the embedded
// Conn) and reports how many bytes were copied.
func (c bufferedConn) Read(p []byte) (int, error) {
	count, err := c.r.Read(p)
	return count, err
}