98 lines
2.1 KiB
Go
98 lines
2.1 KiB
Go
package rebound
|
|
|
|
import (
|
|
"fmt"
|
|
"io"
|
|
"net"
|
|
"os"
|
|
|
|
"github.com/gliderlabs/ssh"
|
|
"github.com/pkg/errors"
|
|
gossh "golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
// localForwardChannelData mirrors the type-specific payload of an SSH
// "direct-tcpip" channel-open request (RFC 4254, Section 7.2).
//
// gossh.Unmarshal fills the fields positionally from the wire encoding, so
// their order and types must not change.
type localForwardChannelData struct {
	DestAddr string // "host to connect"
	DestPort uint32 // "port to connect"

	OriginAddr string // "originator IP address"
	OriginPort uint32 // "originator port"
}
|
|
|
|
func (s *Server) handleDirectTCP(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
|
|
d := localForwardChannelData{}
|
|
if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil {
|
|
s.log("[ERROR] %+v", errors.WithStack(err))
|
|
newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error())
|
|
return
|
|
}
|
|
|
|
sessionID := SessionID(ctx.SessionID())
|
|
|
|
s.log("(%s): opening direct tcp tunnel", sessionID)
|
|
|
|
s.sessionManager.Set(sessionID, SessionData{
|
|
Type: TypeServiceConsumer,
|
|
})
|
|
|
|
sessionToken := ctx.User()
|
|
|
|
remoteSessionID := s.sessionManager.FindByToken(sessionToken)
|
|
if remoteSessionID == "" {
|
|
newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("could not find session associated with token '%s'", sessionToken))
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
addr := s.getSocketPath(remoteSessionID)
|
|
|
|
s.log("(%s): using sock '%s'", sessionID, addr)
|
|
|
|
if _, err := os.Stat(addr); err != nil {
|
|
newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("could not find session associated with token '%s'", sessionToken))
|
|
conn.Close()
|
|
return
|
|
}
|
|
|
|
var dialer net.Dialer
|
|
dconn, err := dialer.DialContext(ctx, "unix", addr)
|
|
if err != nil {
|
|
s.log("[ERROR] %+v", errors.WithStack(err))
|
|
newChan.Reject(gossh.ConnectionFailed, err.Error())
|
|
return
|
|
}
|
|
|
|
ch, reqs, err := newChan.Accept()
|
|
if err != nil {
|
|
s.log("[ERROR] %+v", errors.WithStack(err))
|
|
dconn.Close()
|
|
|
|
return
|
|
}
|
|
|
|
go gossh.DiscardRequests(reqs)
|
|
|
|
go func() {
|
|
defer dconn.Close()
|
|
defer ch.Close()
|
|
|
|
if _, err := io.Copy(ch, dconn); err != nil {
|
|
if errors.Is(err, net.ErrClosed) {
|
|
return
|
|
}
|
|
|
|
s.log("[ERROR] %+v", errors.WithStack(err))
|
|
}
|
|
}()
|
|
|
|
go func() {
|
|
defer dconn.Close()
|
|
defer ch.Close()
|
|
|
|
if _, err := io.Copy(dconn, ch); err != nil {
|
|
s.log("[ERROR] %+v", errors.WithStack(err))
|
|
}
|
|
}()
|
|
}
|