178 lines
3.3 KiB
Go
178 lines
3.3 KiB
Go
package rebound
|
|
|
|
import (
|
|
"bytes"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"strings"
|
|
"text/template"
|
|
|
|
_ "embed"
|
|
|
|
"github.com/gliderlabs/ssh"
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
// asciiDiagram holds the raw contents of diagram.txt, embedded into the
// binary at build time; it is parsed below as a text/template.
//
//go:embed diagram.txt
var asciiDiagram string

// asciiDiagramTmpl is the parsed form of asciiDiagram. template.Must panics
// during package initialization if diagram.txt is not a valid template.
var asciiDiagramTmpl = template.Must(template.New("").Parse(asciiDiagram))
|
|
|
|
func (s *Server) handleSession(sess ssh.Session) {
|
|
ctx := sess.Context()
|
|
sessionID := SessionID(ctx.SessionID())
|
|
|
|
s.log("(%s): session opened", sessionID)
|
|
|
|
message := `
|
|
Welcome on Rebound !
|
|
`
|
|
|
|
if _, err := sess.Write([]byte(message)); err != nil {
|
|
s.log("[ERROR] %+v", errors.WithStack(err))
|
|
}
|
|
|
|
go s.readClientInput(sess)
|
|
|
|
data := s.sessionManager.Get(sessionID, SessionData{
|
|
Type: TypeServiceUnknown,
|
|
})
|
|
s.handleSessionData(sess, data)
|
|
|
|
onUpdate, close := s.sessionManager.OnUpdate(sessionID)
|
|
defer close()
|
|
|
|
for {
|
|
data, opened := <-onUpdate
|
|
if !opened {
|
|
return
|
|
}
|
|
|
|
s.log("(%s): session data updated: %v", sessionID, data)
|
|
|
|
if err := s.handleSessionData(sess, data); err != nil {
|
|
if errors.Is(err, io.EOF) {
|
|
return
|
|
}
|
|
|
|
s.log("[ERROR] %+v", errors.WithStack(err))
|
|
|
|
return
|
|
}
|
|
}
|
|
}
|
|
|
|
func (s *Server) handleSessionData(sess ssh.Session, data SessionData) error {
|
|
switch data.Type {
|
|
case TypeServiceConsumer:
|
|
if err := s.writeConsumerMessage(sess, data); err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
case TypeServiceProvider:
|
|
if err := s.writeProviderMessage(sess, data); err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) writeConsumerMessage(sess ssh.Session, data SessionData) error {
|
|
message := `
|
|
Type Ctrl+C or Ctrl+D to exit.
|
|
`
|
|
|
|
if _, err := sess.Write([]byte(message)); err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (s *Server) writeProviderMessage(sess ssh.Session, data SessionData) error {
|
|
var diagramBuff bytes.Buffer
|
|
|
|
hostname := s.opts.PublicHost
|
|
if len(hostname) < 24 {
|
|
halfPadding := (24 - len(hostname)) / 2
|
|
leftPadding := strings.Repeat(" ", halfPadding)
|
|
rightPadding := strings.Repeat(" ", halfPadding)
|
|
hostname = fmt.Sprintf("%s%s%s", leftPadding, hostname, rightPadding)
|
|
if len(hostname) > 24 {
|
|
hostname = hostname[0:23]
|
|
}
|
|
} else if len(hostname) >= 24 {
|
|
hostname = hostname[0:20] + "..."
|
|
}
|
|
|
|
log.Printf("'%s'", hostname)
|
|
|
|
tmplData := struct {
|
|
Pp string
|
|
Rp string
|
|
Hostname string
|
|
}{
|
|
Pp: fmt.Sprintf("%04d", s.opts.PublicPort),
|
|
Rp: "<port>",
|
|
Hostname: hostname,
|
|
}
|
|
|
|
if err := asciiDiagramTmpl.Execute(&diagramBuff, tmplData); err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
message := fmt.Sprintf(`
|
|
You can connect to your tunnel by running in an other terminal:
|
|
|
|
ssh -L <port>:127.0.0.1:1 %s@%s -p %d
|
|
|
|
%s
|
|
|
|
Type Ctrl+C or Ctrl+D to exit.
|
|
`, data.Token, s.opts.PublicHost, s.opts.PublicPort, diagramBuff.String())
|
|
|
|
if _, err := sess.Write([]byte(message)); err != nil {
|
|
return errors.WithStack(err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
// Control bytes the client can send to leave the session.
const (
	// CtrlC is the byte produced by Ctrl+C (ASCII ETX).
	CtrlC = 0x03
	// CtrlD is the byte produced by Ctrl+D (ASCII EOT).
	CtrlD = 0x04
)
|
|
|
|
func (s *Server) readClientInput(sess ssh.Session) {
|
|
sessionID := SessionID(sess.Context().SessionID())
|
|
defer func() {
|
|
s.sessionManager.Remove(sessionID)
|
|
}()
|
|
|
|
buff := make([]byte, 1)
|
|
|
|
for {
|
|
_, err := sess.Read(buff)
|
|
if err != nil {
|
|
if !errors.Is(err, io.EOF) {
|
|
s.log("[ERROR] %+v", errors.WithStack(err))
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
switch buff[0] {
|
|
case CtrlC:
|
|
fallthrough
|
|
case CtrlD:
|
|
sess.Exit(0)
|
|
return
|
|
|
|
default:
|
|
s.log("(%s) user input: %v", sessionID, buff)
|
|
}
|
|
|
|
}
|
|
}
|