101 lines
1.8 KiB
Go
101 lines
1.8 KiB
Go
package rebound
|
|
|
|
import (
	"bufio"
	"net"
	"sync"
	"time"

	"github.com/pkg/errors"
)
|
|
|
|
func (s *Server) muxListener(l net.Listener) (ssh net.Listener, other net.Listener) {
|
|
sshListener, otherListener := newListener(l), newListener(l)
|
|
go func() {
|
|
for {
|
|
conn, err := l.Accept()
|
|
if err != nil {
|
|
if errors.Is(err, net.ErrClosed) {
|
|
return
|
|
}
|
|
|
|
s.log("[ERROR] %+v", errors.WithStack(err))
|
|
|
|
continue
|
|
}
|
|
|
|
if err := conn.SetReadDeadline(time.Now().Add(time.Second * 10)); err != nil {
|
|
s.log("[ERROR] %+v", errors.WithStack(err))
|
|
|
|
continue
|
|
}
|
|
|
|
bconn := bufferedConn{conn, bufio.NewReaderSize(conn, 3)}
|
|
|
|
p, err := bconn.Peek(3)
|
|
if err != nil {
|
|
s.log("[ERROR] %+v", errors.WithStack(err))
|
|
|
|
continue
|
|
}
|
|
|
|
if err := conn.SetReadDeadline(time.Time{}); err != nil {
|
|
s.log("[ERROR] %+v", errors.WithStack(err))
|
|
|
|
continue
|
|
}
|
|
|
|
selectedListener := otherListener
|
|
if prefix := string(p); prefix == "SSH" {
|
|
s.log("[INFO] new ssh connection from '%s'", conn.RemoteAddr())
|
|
selectedListener = sshListener
|
|
} else {
|
|
s.log("[INFO] new http connection from '%s'", conn.RemoteAddr())
|
|
}
|
|
|
|
if selectedListener.accept != nil {
|
|
selectedListener.accept <- bconn
|
|
}
|
|
}
|
|
}()
|
|
|
|
return sshListener, otherListener
|
|
}
|
|
|
|
type listener struct {
|
|
accept chan net.Conn
|
|
net.Listener
|
|
}
|
|
|
|
func newListener(l net.Listener) *listener {
|
|
return &listener{
|
|
make(chan net.Conn),
|
|
l,
|
|
}
|
|
}
|
|
|
|
func (l *listener) Accept() (net.Conn, error) {
|
|
if l.accept == nil {
|
|
return nil, errors.New("listener closed")
|
|
}
|
|
return <-l.accept, nil
|
|
}
|
|
|
|
func (l *listener) Close() error {
|
|
close(l.accept)
|
|
l.accept = nil
|
|
return nil
|
|
}
|
|
|
|
// bufferedConn is a net.Conn whose reads are served through a bufio.Reader,
// so bytes inspected with Peek are not lost: the next Read returns them.
// Every other net.Conn method (Write, Close, deadlines, addresses) is
// promoted unchanged from the embedded Conn.
type bufferedConn struct {
	net.Conn
	r *bufio.Reader
}

// Peek returns the next n bytes without advancing the reader.
func (c bufferedConn) Peek(n int) ([]byte, error) {
	reader := c.r
	return reader.Peek(n)
}

// Read fills p from the buffered reader, draining previously peeked bytes
// before reading further from the underlying connection.
func (c bufferedConn) Read(p []byte) (int, error) {
	reader := c.r
	n, err := reader.Read(p)
	return n, err
}
|