rebound/session_handler.go

134 lines
2.4 KiB
Go

package rebound
import (
"fmt"
"io"
"github.com/gliderlabs/ssh"
"github.com/pkg/errors"
)
// handleSession serves one SSH session.
//
// It writes a welcome banner and starts readClientInput in its own goroutine.
// It then writes the message matching the session's current data and rewrites
// it every time the session manager publishes an update.
//
// It returns when any of these happens:
//   - the update channel is closed;
//   - the session context is done;
//   - a write fails. io.EOF is treated as a normal disconnect and is not logged.
func (s *Server) handleSession(sess ssh.Session) {
	ctx := sess.Context()
	sessionID := SessionID(ctx.SessionID())

	s.log("(%s): session opened", sessionID)

	message := `
Welcome on Rebound !
Type Ctrl+C or Ctrl+D to exit.
`
	if _, err := sess.Write([]byte(message)); err != nil {
		s.log("[ERROR] %+v", errors.WithStack(err))
	}

	go s.readClientInput(sess)

	// writeData writes the message for data and reports whether the handler
	// should keep serving the session.
	writeData := func(data SessionData) bool {
		err := s.handleSessionData(sess, data)
		if err == nil {
			return true
		}
		if !errors.Is(err, io.EOF) {
			s.log("[ERROR] %+v", errors.WithStack(err))
		}
		return false
	}

	// NOTE(review): updates published between Get and OnUpdate are not
	// observed by this handler. Whether that matters depends on the session
	// manager's semantics; confirm before relying on it.
	data := s.sessionManager.Get(sessionID, SessionData{
		Type: TypeServiceUnknown,
	})
	// The initial write's error used to be discarded. Handle it the same way
	// as update-triggered writes, so a dead session ends the handler.
	if !writeData(data) {
		return
	}

	// Named unsubscribe (not close) so the builtin close is not shadowed.
	onUpdate, unsubscribe := s.sessionManager.OnUpdate(sessionID)
	defer unsubscribe()

	for {
		select {
		case <-ctx.Done():
			// Don't block forever if the connection goes away without the
			// update channel being closed.
			return
		case updated, ok := <-onUpdate:
			if !ok {
				return
			}
			s.log("(%s): session data updated: %v", sessionID, updated)
			if !writeData(updated) {
				return
			}
		}
	}
}
// handleSessionData writes the message that matches data.Type to sess.
// Consumer and provider sessions each get their own message. Any other type
// (including TypeServiceUnknown) writes nothing and returns nil. Write errors
// are returned wrapped with a stack trace.
func (s *Server) handleSessionData(sess ssh.Session, data SessionData) error {
	var write func(ssh.Session, SessionData) error

	switch data.Type {
	case TypeServiceConsumer:
		write = s.writeConsumerMessage
	case TypeServiceProvider:
		write = s.writeProviderMessage
	default:
		return nil
	}

	if err := write(sess, data); err != nil {
		return errors.WithStack(err)
	}
	return nil
}
// writeConsumerMessage writes the message shown to consumer sessions.
//
// The message is currently empty. The write is still issued, so any error the
// session's Write reports is returned to the caller, wrapped with a stack
// trace. data is not used.
// NOTE(review): the empty message looks like a placeholder; confirm intent.
func (s *Server) writeConsumerMessage(sess ssh.Session, data SessionData) error {
	const banner = ""

	_, err := sess.Write([]byte(banner))
	if err != nil {
		return errors.WithStack(err)
	}
	return nil
}
// writeProviderMessage tells a provider session how to reach its tunnel.
// It prints an `ssh -L` command line built from the session token and the
// server's public host and port. A write error is returned wrapped with a
// stack trace.
func (s *Server) writeProviderMessage(sess ssh.Session, data SessionData) error {
	const format = `
You can connect to your tunnel by running in an other terminal:
ssh -L <local-port>:0.0.0.0:1 %s@%s -p %d
`
	// fmt.Fprintf formats into a buffer and issues a single Write, exactly
	// like Sprintf followed by Write.
	if _, err := fmt.Fprintf(sess, format, data.Token, s.opts.PublicHost, s.opts.PublicPort); err != nil {
		return errors.WithStack(err)
	}
	return nil
}
// Control bytes that end a session when received from the client.
const (
	// CtrlC is the ASCII ETX byte, produced by typing Ctrl+C.
	CtrlC = 0x03
	// CtrlD is the ASCII EOT byte, produced by typing Ctrl+D.
	CtrlD = 0x04
)
// readClientInput reads the client's input one byte at a time.
//
// Ctrl+C or Ctrl+D ends the session with exit status 0; any other byte is
// logged. Reading stops on Ctrl+C/Ctrl+D or on a read error. io.EOF is the
// normal disconnect and is not logged. On return the session is removed from
// the session manager.
func (s *Server) readClientInput(sess ssh.Session) {
	sessionID := SessionID(sess.Context().SessionID())
	defer s.sessionManager.Remove(sessionID)

	buff := make([]byte, 1)
	for {
		n, err := sess.Read(buff)

		// The io.Reader contract allows a byte to arrive together with an
		// error, and also allows (0, nil). Only inspect buff when a byte was
		// actually read, so a stale value is never reprocessed.
		if n > 0 {
			switch buff[0] {
			case CtrlC, CtrlD:
				if exitErr := sess.Exit(0); exitErr != nil && !errors.Is(exitErr, io.EOF) {
					s.log("[ERROR] %+v", errors.WithStack(exitErr))
				}
				return
			default:
				// NOTE(review): this logs raw keystrokes, which may include
				// sensitive input.
				s.log("(%s) user input: %v", sessionID, buff)
			}
		}

		if err != nil {
			if !errors.Is(err, io.EOF) {
				s.log("[ERROR] %+v", errors.WithStack(err))
			}
			return
		}
	}
}