package ssh import ( "bytes" "fmt" "io" "log" "strings" "text/template" _ "embed" "github.com/gliderlabs/ssh" "github.com/pkg/errors" ) //go:embed diagram.txt var asciiDiagram string var asciiDiagramTmpl = template.Must(template.New("").Parse(asciiDiagram)) func (s *Server) handleSession(sess ssh.Session) { ctx := sess.Context() sessionID := SessionID(ctx.SessionID()) s.log("(%s): session opened", sessionID) message := ` Welcome on Rebound ! ` if _, err := sess.Write([]byte(message)); err != nil { s.log("[ERROR] %+v", errors.WithStack(err)) } go s.readClientInput(sess) data := s.sessionManager.Get(sessionID, SessionData{ Type: TypeServiceUnknown, }) s.handleSessionData(sess, data) onUpdate, close := s.sessionManager.OnUpdate(sessionID) defer close() for { data, opened := <-onUpdate if !opened { return } s.log("(%s): session data updated: %v", sessionID, data) if err := s.handleSessionData(sess, data); err != nil { if errors.Is(err, io.EOF) { return } s.log("[ERROR] %+v", errors.WithStack(err)) return } } } func (s *Server) handleSessionData(sess ssh.Session, data SessionData) error { switch data.Type { case TypeServiceConsumer: if err := s.writeConsumerMessage(sess, data); err != nil { return errors.WithStack(err) } case TypeServiceProvider: if err := s.writeProviderMessage(sess, data); err != nil { return errors.WithStack(err) } } return nil } func (s *Server) writeConsumerMessage(sess ssh.Session, data SessionData) error { message := ` Type Ctrl+C or Ctrl+D to exit. ` if _, err := sess.Write([]byte(message)); err != nil { return errors.WithStack(err) } return nil } func (s *Server) writeProviderMessage(sess ssh.Session, data SessionData) error { var diagramBuff bytes.Buffer hostname := s.opts.PublicHost if len(hostname) < 24 { halfPadding := (24 - len(hostname)) / 2 leftPadding := strings.Repeat(" ", halfPadding) rightPadding := strings.Repeat(" ", halfPadding) hostname = fmt.Sprintf("%s%s%s", leftPadding, hostname, rightPadding) if len(hostname) > 24 { hostname = hostname[0:23] } } else if len(hostname) >= 24 { hostname = hostname[0:20] + "..." } log.Printf("'%s'", hostname) tmplData := struct { Pp string Rp string Hostname string }{ Pp: fmt.Sprintf("%04d", s.opts.PublicPort), Rp: "", Hostname: hostname, } if err := asciiDiagramTmpl.Execute(&diagramBuff, tmplData); err != nil { return errors.WithStack(err) } message := fmt.Sprintf(` You can connect to your tunnel by running in an other terminal: ssh -L :127.0.0.1:1 %s@%s -p %d %s Type Ctrl+C or Ctrl+D to exit. `, data.Token, s.opts.PublicHost, s.opts.PublicPort, diagramBuff.String()) if _, err := sess.Write([]byte(message)); err != nil { return errors.WithStack(err) } return nil } const ( CtrlC = 3 CtrlD = 4 ) func (s *Server) readClientInput(sess ssh.Session) { sessionID := SessionID(sess.Context().SessionID()) defer func() { s.sessionManager.Remove(sessionID) }() buff := make([]byte, 1) for { _, err := sess.Read(buff) if err != nil { if !errors.Is(err, io.EOF) { s.log("[ERROR] %+v", errors.WithStack(err)) } return } switch buff[0] { case CtrlC: fallthrough case CtrlD: sess.Exit(0) return default: s.log("(%s) user input: %v", sessionID, buff) } } }