142 lines
2.7 KiB
Go
142 lines
2.7 KiB
Go
|
package rebound
|
||
|
|
||
|
import (
|
||
|
"crypto/rand"
|
||
|
"crypto/rsa"
|
||
|
"crypto/x509"
|
||
|
"encoding/pem"
|
||
|
"io/fs"
|
||
|
"net"
|
||
|
"os"
|
||
|
"sync"
|
||
|
"time"
|
||
|
|
||
|
"github.com/gliderlabs/ssh"
|
||
|
"github.com/pkg/errors"
|
||
|
|
||
|
gossh "golang.org/x/crypto/ssh"
|
||
|
)
|
||
|
|
||
|
type Server struct {
|
||
|
ssh *ssh.Server
|
||
|
opts *Options
|
||
|
|
||
|
sessionManager *SessionManager
|
||
|
forwards map[string]net.Listener
|
||
|
requestHandlerLock sync.Mutex
|
||
|
}
|
||
|
|
||
|
func (s *Server) Run() error {
|
||
|
if err := s.initSSHServer(); err != nil {
|
||
|
return errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
s.log("listening on %s", s.opts.Address)
|
||
|
|
||
|
if err := s.ssh.ListenAndServe(); err != nil {
|
||
|
return errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s *Server) Stop() error {
|
||
|
if s.ssh == nil {
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
s.log("stopping on '%s'", s.opts.Address)
|
||
|
|
||
|
if err := s.ssh.Close(); err != nil {
|
||
|
return errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s *Server) initSSHServer() error {
|
||
|
signer, err := s.loadOrCreateSigner()
|
||
|
if err != nil {
|
||
|
return errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
server := &ssh.Server{
|
||
|
Addr: s.opts.Address,
|
||
|
HostSigners: []ssh.Signer{signer},
|
||
|
Handler: ssh.Handler(s.handleSession),
|
||
|
RequestHandlers: map[string]ssh.RequestHandler{
|
||
|
"tcpip-forward": s.handleRequest,
|
||
|
"cancel-tcpip-forward": s.handleRequest,
|
||
|
},
|
||
|
ChannelHandlers: map[string]ssh.ChannelHandler{
|
||
|
"direct-tcpip": s.handleDirectTCP,
|
||
|
"session": ssh.DefaultSessionHandler,
|
||
|
},
|
||
|
}
|
||
|
|
||
|
s.ssh = server
|
||
|
|
||
|
return nil
|
||
|
}
|
||
|
|
||
|
func (s *Server) loadOrCreateSigner() (ssh.Signer, error) {
|
||
|
var (
|
||
|
signer gossh.Signer
|
||
|
)
|
||
|
|
||
|
s.log("reading host key from '%s'", s.opts.HostKey)
|
||
|
|
||
|
data, err := os.ReadFile(s.opts.HostKey)
|
||
|
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||
|
return nil, errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
if data == nil {
|
||
|
s.log("host key cannot be found, generating one")
|
||
|
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||
|
if err != nil {
|
||
|
return nil, errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
signer, err = gossh.NewSignerFromKey(key)
|
||
|
if err != nil {
|
||
|
return nil, errors.WithStack(err)
|
||
|
}
|
||
|
|
||
|
pem := pem.EncodeToMemory(
|
||
|
&pem.Block{
|
||
|
Type: "RSA PRIVATE KEY",
|
||
|
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||
|
},
|
||
|
)
|
||
|
|
||
|
s.log("saving host key to '%s'", s.opts.HostKey)
|
||
|
if err := os.WriteFile(s.opts.HostKey, pem, fs.FileMode(0640)); err != nil {
|
||
|
return nil, errors.WithStack(err)
|
||
|
}
|
||
|
} else {
|
||
|
signer, err = gossh.ParsePrivateKey(data)
|
||
|
if err != nil {
|
||
|
return nil, errors.WithStack(err)
|
||
|
}
|
||
|
}
|
||
|
|
||
|
return signer, nil
|
||
|
}
|
||
|
|
||
|
func (s *Server) log(message string, args ...any) {
|
||
|
s.opts.Logger(message, args...)
|
||
|
}
|
||
|
|
||
|
func NewServer(funcs ...OptionFunc) *Server {
|
||
|
opts := DefaultOptions()
|
||
|
for _, fn := range funcs {
|
||
|
fn(opts)
|
||
|
}
|
||
|
|
||
|
return &Server{
|
||
|
opts: opts,
|
||
|
sessionManager: NewSessionManager(30 * time.Second),
|
||
|
}
|
||
|
}
|