package ssh import ( "crypto/rand" "crypto/rsa" "crypto/x509" "encoding/pem" "io/fs" "net" "os" "sync" "time" "github.com/gliderlabs/ssh" "github.com/pkg/errors" gossh "golang.org/x/crypto/ssh" ) type Server struct { ssh *ssh.Server opts *Options sessionManager *SessionManager forwards map[string]net.Listener requestHandlerLock sync.Mutex } func (s *Server) Serve(l net.Listener) error { if err := s.initSSHServer(); err != nil { return errors.WithStack(err) } if err := s.ssh.Serve(l); err != nil { s.log("[ERROR] %+v", errors.WithStack(err)) } return nil } func (s *Server) initSSHServer() error { signer, err := s.loadOrCreateSigner() if err != nil { return errors.WithStack(err) } server := &ssh.Server{ HostSigners: []ssh.Signer{signer}, Handler: ssh.Handler(s.handleSession), RequestHandlers: map[string]ssh.RequestHandler{ "tcpip-forward": s.handleRequest, "cancel-tcpip-forward": s.handleRequest, }, ChannelHandlers: map[string]ssh.ChannelHandler{ "direct-tcpip": s.handleDirectTCP, "session": ssh.DefaultSessionHandler, }, } s.ssh = server return nil } func (s *Server) loadOrCreateSigner() (ssh.Signer, error) { var ( signer gossh.Signer ) s.log("reading host key from '%s'", s.opts.HostKey) data, err := os.ReadFile(s.opts.HostKey) if err != nil && !errors.Is(err, os.ErrNotExist) { return nil, errors.WithStack(err) } if data == nil { s.log("host key cannot be found, generating one") key, err := rsa.GenerateKey(rand.Reader, 2048) if err != nil { return nil, errors.WithStack(err) } signer, err = gossh.NewSignerFromKey(key) if err != nil { return nil, errors.WithStack(err) } pem := pem.EncodeToMemory( &pem.Block{ Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key), }, ) s.log("saving host key to '%s'", s.opts.HostKey) if err := os.WriteFile(s.opts.HostKey, pem, fs.FileMode(0640)); err != nil { return nil, errors.WithStack(err) } } else { signer, err = gossh.ParsePrivateKey(data) if err != nil { return nil, errors.WithStack(err) } } return signer, nil } func (s *Server) log(message string, args ...any) { s.opts.Logger(message, args...) } func NewServer(funcs ...OptionFunc) *Server { opts := DefaultOptions() for _, fn := range funcs { fn(opts) } return &Server{ opts: opts, sessionManager: NewSessionManager(30 * time.Second), } }