rebound/ssh/server.go

125 lines
2.4 KiB
Go

package ssh
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"io/fs"
"net"
"os"
"sync"
"time"
"github.com/gliderlabs/ssh"
"github.com/pkg/errors"
gossh "golang.org/x/crypto/ssh"
)
// Server is an SSH server built on github.com/gliderlabs/ssh. Construct it
// with NewServer (which fills in opts and sessionManager) and start it with
// Serve.
//
// Server contains a sync.Mutex, so it must not be copied after first use;
// always handle it through a *Server.
type Server struct {
	// ssh is the underlying gliderlabs server. It is assigned by
	// initSSHServer, which Serve calls on every invocation.
	ssh *ssh.Server

	// opts is the configuration assembled by NewServer; this file reads its
	// HostKey path and its Logger.
	opts *Options

	// sessionManager is created by NewServer with a 30-second duration.
	// NOTE(review): how it is used is outside this view — confirm in the
	// session handlers.
	sessionManager *SessionManager

	// forwards maps a string key to a listener. NOTE(review): presumably the
	// listeners opened for "tcpip-forward" requests, keyed by bind address —
	// confirm in handleRequest. NewServer does not initialize this map, so
	// it must be made before the first write.
	forwards map[string]net.Listener

	// requestHandlerLock is a mutex. NOTE(review): presumably it serializes
	// handleRequest and guards forwards — confirm in handleRequest.
	requestHandlerLock sync.Mutex
}
// Serve builds the underlying SSH server (host key, handlers) and then
// serves SSH connections accepted from l, blocking until the server stops.
//
// It returns an error when the host signer cannot be loaded or created, or
// when the underlying server stops for any reason other than an orderly
// shutdown. ssh.ErrServerClosed, which the underlying server reports after
// Close or Shutdown, is treated as a normal stop, and Serve then returns nil.
func (s *Server) Serve(l net.Listener) error {
	if err := s.initSSHServer(); err != nil {
		return errors.WithStack(err)
	}

	err := s.ssh.Serve(l)
	if errors.Is(err, ssh.ErrServerClosed) {
		// Orderly shutdown requested by the caller; not a failure.
		return nil
	}
	// Surface accept/serve failures to the caller instead of logging and
	// discarding them; WithStack(nil) is nil.
	return errors.WithStack(err)
}
// initSSHServer creates the gliderlabs server and stores it in s.ssh.
//
// The server presents the host key returned by loadOrCreateSigner and
// dispatches sessions to s.handleSession. Both "tcpip-forward" and
// "cancel-tcpip-forward" global requests go to s.handleRequest. Handlers are
// registered for the "direct-tcpip" and "session" channel types.
//
// It returns an error only if the host signer cannot be obtained, in which
// case s.ssh is left unchanged.
func (s *Server) initSSHServer() error {
	hostSigner, err := s.loadOrCreateSigner()
	if err != nil {
		return errors.WithStack(err)
	}

	requestHandlers := map[string]ssh.RequestHandler{
		"tcpip-forward":        s.handleRequest,
		"cancel-tcpip-forward": s.handleRequest,
	}
	channelHandlers := map[string]ssh.ChannelHandler{
		"direct-tcpip": s.handleDirectTCP,
		"session":      ssh.DefaultSessionHandler,
	}

	s.ssh = &ssh.Server{
		HostSigners:     []ssh.Signer{hostSigner},
		Handler:         ssh.Handler(s.handleSession),
		RequestHandlers: requestHandlers,
		ChannelHandlers: channelHandlers,
	}
	return nil
}
// loadOrCreateSigner returns the SSH host signer stored in the file at
// s.opts.HostKey.
//
// If that file does not exist, it does the following:
//   - generates a new 2048-bit RSA key;
//   - writes the key to s.opts.HostKey as a PKCS#1 "RSA PRIVATE KEY" PEM
//     block;
//   - returns a signer for the new key.
//
// The file is created with mode 0600 (before the process umask is applied)
// because it holds private key material. This matches the owner-only
// permissions OpenSSH expects for private key files.
//
// Any other read error, a failure to parse an existing file, and a failure
// to generate or save a new key are all returned as errors.
func (s *Server) loadOrCreateSigner() (ssh.Signer, error) {
	keyPath := s.opts.HostKey

	s.log("reading host key from '%s'", keyPath)
	data, err := os.ReadFile(keyPath)
	switch {
	case err == nil:
		signer, err := gossh.ParsePrivateKey(data)
		if err != nil {
			return nil, errors.WithStack(err)
		}
		return signer, nil
	case !errors.Is(err, os.ErrNotExist):
		return nil, errors.WithStack(err)
	}

	// The file does not exist yet: generate, persist and use a new key.
	s.log("host key cannot be found, generating one")
	key, err := rsa.GenerateKey(rand.Reader, 2048)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	signer, err := gossh.NewSignerFromKey(key)
	if err != nil {
		return nil, errors.WithStack(err)
	}

	keyPEM := pem.EncodeToMemory(&pem.Block{
		Type:  "RSA PRIVATE KEY",
		Bytes: x509.MarshalPKCS1PrivateKey(key),
	})

	s.log("saving host key to '%s'", keyPath)
	// Owner read/write only: this is a private key.
	if err := os.WriteFile(keyPath, keyPEM, fs.FileMode(0o600)); err != nil {
		return nil, errors.WithStack(err)
	}

	return signer, nil
}
// log passes a printf-style format string and its arguments to the Logger
// configured in the server's options.
func (s *Server) log(message string, args ...any) {
	logger := s.opts.Logger
	logger(message, args...)
}
// NewServer returns a Server whose options start from DefaultOptions and are
// then modified by each OptionFunc in funcs, applied in the order given.
//
// The server's session manager is created with a 30-second duration.
// NOTE(review): this is presumably an idle/session timeout — confirm against
// NewSessionManager.
//
// The underlying SSH server and host key are not set up here; Serve does
// that.
func NewServer(funcs ...OptionFunc) *Server {
	opts := DefaultOptions()
	for i := range funcs {
		funcs[i](opts)
	}

	srv := new(Server)
	srv.opts = opts
	srv.sessionManager = NewSessionManager(30 * time.Second)
	return srv
}