package ssh

import (
	"fmt"
	"io"
	"net"
	"os"

	"github.com/gliderlabs/ssh"
	"github.com/pkg/errors"
	gossh "golang.org/x/crypto/ssh"
)

// localForwardChannelData is the "direct-tcpip" channel-open payload as
// specified in RFC 4254, Section 7.2.
type localForwardChannelData struct {
	DestAddr   string
	DestPort   uint32
	OriginAddr string
	OriginPort uint32
}

// handleDirectTCP serves an incoming "direct-tcpip" (local port forwarding)
// channel. The SSH user name is treated as a session token, and the channel is
// bridged to the unix socket of the session that owns that token. The
// destination requested in the channel payload is parsed (so malformed
// payloads are rejected) but is not used to choose where traffic goes.
//
// The channel is rejected with gossh.ConnectionFailed when the payload cannot
// be parsed, when the token matches no session, when that session's socket
// path does not exist, or when dialing the socket fails. In the "no session"
// and "no socket" cases the whole SSH connection is closed as well.
//
// The signature follows gliderlabs/ssh's ChannelHandler; srv is unused.
func (s *Server) handleDirectTCP(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
	d := localForwardChannelData{}
	if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil {
		s.log("[ERROR] %+v", errors.WithStack(err))
		newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error())
		return
	}

	sessionID := SessionID(ctx.SessionID())
	s.log("(%s): opening direct tcp tunnel", sessionID)

	// NOTE(review): this entry is not removed anywhere in this function,
	// including on the rejection paths below; presumably cleaned up elsewhere
	// (e.g. on connection close) — confirm.
	s.sessionManager.Set(sessionID, SessionData{
		Type: TypeServiceConsumer,
	})

	sessionToken := ctx.User()
	remoteSessionID := s.sessionManager.FindByToken(sessionToken)
	if remoteSessionID == "" {
		newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("could not find session associated with token '%s'", sessionToken))
		conn.Close()
		return
	}

	addr := s.getSocketPath(remoteSessionID)
	s.log("(%s): using sock '%s'", sessionID, addr)
	if _, err := os.Stat(addr); err != nil {
		// Record the real cause server-side; the client still receives the
		// same generic message as for an unknown token.
		s.log("[ERROR] %+v", errors.WithStack(err))
		newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("could not find session associated with token '%s'", sessionToken))
		conn.Close()
		return
	}

	var dialer net.Dialer
	dconn, err := dialer.DialContext(ctx, "unix", addr)
	if err != nil {
		s.log("[ERROR] %+v", errors.WithStack(err))
		newChan.Reject(gossh.ConnectionFailed, err.Error())
		return
	}

	ch, reqs, err := newChan.Accept()
	if err != nil {
		s.log("[ERROR] %+v", errors.WithStack(err))
		dconn.Close()
		return
	}
	go gossh.DiscardRequests(reqs)

	// pipe copies src into dst until either side errors or reaches EOF, then
	// closes both the SSH channel and the socket so the opposite direction
	// unblocks too.
	pipe := func(dst io.Writer, src io.Reader) {
		defer dconn.Close()
		defer ch.Close()

		_, err := io.Copy(dst, src)
		if err == nil {
			return
		}
		// Whichever direction finishes first closes both ends, so the other
		// direction typically fails with net.ErrClosed (socket already closed)
		// or io.EOF (x/crypto/ssh reports writes on a closed channel that
		// way). That is normal teardown, not an error worth logging.
		if errors.Is(err, net.ErrClosed) || errors.Is(err, io.EOF) {
			return
		}
		s.log("[ERROR] %+v", errors.WithStack(err))
	}

	go pipe(ch, dconn)
	go pipe(dconn, ch)
}