rebound/direct_tcp_handler.go

98 lines
2.1 KiB
Go
Raw Normal View History

2023-09-09 04:00:00 +02:00
package rebound
import (
"fmt"
"io"
"net"
"os"
"github.com/gliderlabs/ssh"
"github.com/pkg/errors"
gossh "golang.org/x/crypto/ssh"
)
// localForwardChannelData is the extra data carried by a "direct-tcpip"
// channel open request, as specified in RFC 4254, Section 7.2.
//
// It is decoded with gossh.Unmarshal in handleDirectTCP. The fields are
// decoded positionally from the SSH wire format, so their order and types
// must not be changed.
type localForwardChannelData struct {
// DestAddr is the host the client asks to be connected to.
DestAddr string
// DestPort is the port the client asks to be connected to.
DestPort uint32
// OriginAddr is the address the client reports the connection originated from.
OriginAddr string
// OriginPort is the port the client reports the connection originated from.
OriginPort uint32
}
// handleDirectTCP serves an SSH "direct-tcpip" channel request by piping the
// channel to the unix socket that belongs to another session.
//
// The target session is looked up through the SSH user name, which is used as
// a session token. The destination carried in the request payload is parsed
// to validate the request but is not used for routing. The calling session is
// registered in the session manager as TypeServiceConsumer.
//
// Failure handling:
//   - malformed payload: the channel is rejected.
//   - unknown token or missing socket file: the channel is rejected and the
//     whole SSH connection is closed.
//   - dialing the socket fails: the channel is rejected.
//   - accepting the channel fails: the already dialed socket is closed.
//
// On success the function returns immediately; two goroutines relay data in
// both directions and close both ends as soon as either direction finishes.
func (s *Server) handleDirectTCP(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
	// Parsed only to make sure the request is well formed; the requested
	// destination is ignored because routing is done by session token.
	d := localForwardChannelData{}
	if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil {
		s.log("[ERROR] %+v", errors.WithStack(err))
		newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error())
		return
	}

	sessionID := SessionID(ctx.SessionID())
	s.log("(%s): opening direct tcp tunnel", sessionID)
	s.sessionManager.Set(sessionID, SessionData{
		Type: TypeServiceConsumer,
	})

	// The SSH user name doubles as the token that identifies the remote session.
	sessionToken := ctx.User()
	remoteSessionID := s.sessionManager.FindByToken(sessionToken)
	if remoteSessionID == "" {
		newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("could not find session associated with token '%s'", sessionToken))
		conn.Close()
		return
	}

	addr := s.getSocketPath(remoteSessionID)
	s.log("(%s): using sock '%s'", sessionID, addr)
	if _, err := os.Stat(addr); err != nil {
		// Keep the reason in the server log; the client only gets the generic
		// "could not find session" message.
		s.log("[ERROR] %+v", errors.WithStack(err))
		newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("could not find session associated with token '%s'", sessionToken))
		conn.Close()
		return
	}

	var dialer net.Dialer
	dconn, err := dialer.DialContext(ctx, "unix", addr)
	if err != nil {
		s.log("[ERROR] %+v", errors.WithStack(err))
		newChan.Reject(gossh.ConnectionFailed, err.Error())
		return
	}

	ch, reqs, err := newChan.Accept()
	if err != nil {
		s.log("[ERROR] %+v", errors.WithStack(err))
		dconn.Close()
		return
	}
	go gossh.DiscardRequests(reqs)

	// relay copies one direction of the tunnel and tears down both ends when
	// that direction ends, which in turn unblocks the opposite direction.
	relay := func(dst io.Writer, src io.Reader) {
		defer dconn.Close()
		defer ch.Close()

		_, err := io.Copy(dst, src)
		// net.ErrClosed (socket) and io.EOF (write on an SSH channel that was
		// already closed) are the expected result of the opposite direction
		// having closed both ends first, so they are not reported as errors.
		if err == nil || errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
			return
		}
		s.log("[ERROR] %+v", errors.WithStack(err))
	}

	go relay(ch, dconn)
	go relay(dconn, ch)
}