feat: add http server with a page serving service informations

This commit is contained in:
2023-09-20 21:54:08 -06:00
parent 56d7174b96
commit fdaffca43f
23 changed files with 605 additions and 135 deletions

7
ssh/diagram.txt Normal file
View File

@ -0,0 +1,7 @@
NAT NAT
My machine | {{ .Hostname }} | Remote Machine
+----------+ | +----------+ | +----------+
| |<local_service_port> | {{.Pp}}| |{{.Pp}} | {{.Rp}}| |
| +<-----------------------+----------->+ +<-------------+------->+ |
+----------+ | +----------+ | +----------+
| |

97
ssh/direct_tcp_handler.go Normal file
View File

@ -0,0 +1,97 @@
package ssh
import (
"fmt"
"io"
"net"
"os"
"github.com/gliderlabs/ssh"
"github.com/pkg/errors"
gossh "golang.org/x/crypto/ssh"
)
// direct-tcpip data struct as specified in RFC4254, Section 7.2
type localForwardChannelData struct {
DestAddr string
DestPort uint32
OriginAddr string
OriginPort uint32
}
func (s *Server) handleDirectTCP(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
d := localForwardChannelData{}
if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil {
s.log("[ERROR] %+v", errors.WithStack(err))
newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error())
return
}
sessionID := SessionID(ctx.SessionID())
s.log("(%s): opening direct tcp tunnel", sessionID)
s.sessionManager.Set(sessionID, SessionData{
Type: TypeServiceConsumer,
})
sessionToken := ctx.User()
remoteSessionID := s.sessionManager.FindByToken(sessionToken)
if remoteSessionID == "" {
newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("could not find session associated with token '%s'", sessionToken))
conn.Close()
return
}
addr := s.getSocketPath(remoteSessionID)
s.log("(%s): using sock '%s'", sessionID, addr)
if _, err := os.Stat(addr); err != nil {
newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("could not find session associated with token '%s'", sessionToken))
conn.Close()
return
}
var dialer net.Dialer
dconn, err := dialer.DialContext(ctx, "unix", addr)
if err != nil {
s.log("[ERROR] %+v", errors.WithStack(err))
newChan.Reject(gossh.ConnectionFailed, err.Error())
return
}
ch, reqs, err := newChan.Accept()
if err != nil {
s.log("[ERROR] %+v", errors.WithStack(err))
dconn.Close()
return
}
go gossh.DiscardRequests(reqs)
go func() {
defer dconn.Close()
defer ch.Close()
if _, err := io.Copy(ch, dconn); err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
s.log("[ERROR] %+v", errors.WithStack(err))
}
}()
go func() {
defer dconn.Close()
defer ch.Close()
if _, err := io.Copy(dconn, ch); err != nil {
s.log("[ERROR] %+v", errors.WithStack(err))
}
}()
}

53
ssh/options.go Normal file
View File

@ -0,0 +1,53 @@
package ssh
import "log"
type Options struct {
Logger func(message string, args ...any)
SockDir string `env:"SOCK_DIR"`
PublicPort uint `env:"PUBLIC_PORT"`
PublicHost string `env:"PUBLIC_HOST"`
HostKey string `env:"HOST_KEY"`
}
type OptionFunc func(*Options)
func DefaultOptions() *Options {
return &Options{
SockDir: "./socks",
Logger: log.Printf,
PublicPort: 2222,
PublicHost: "127.0.0.1",
HostKey: "./host.key",
}
}
func WithSockDir(addr string) func(*Options) {
return func(opts *Options) {
opts.SockDir = addr
}
}
func WithPublicHost(host string) func(*Options) {
return func(opts *Options) {
opts.PublicHost = host
}
}
func WithPublicPort(port uint) func(*Options) {
return func(opts *Options) {
opts.PublicPort = port
}
}
func WithHostKey(key string) func(*Options) {
return func(opts *Options) {
opts.HostKey = key
}
}
func WithLogger(logger func(message string, args ...any)) func(*Options) {
return func(opts *Options) {
opts.Logger = logger
}
}

226
ssh/request_handler.go Normal file
View File

@ -0,0 +1,226 @@
package ssh
import (
"crypto/rand"
"fmt"
"io"
"log"
"math/big"
"net"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/gliderlabs/ssh"
"github.com/pkg/errors"
gossh "golang.org/x/crypto/ssh"
)
const (
forwardedTCPChannelType = "forwarded-tcpip"
)
// direct-tcpip data struct as specified in RFC4254, Section 7.2
type remoteForwardRequest struct {
BindAddr string
BindPort uint32
}
type remoteForwardSuccess struct {
BindPort uint32
}
type remoteForwardCancelRequest struct {
BindAddr string
BindPort uint32
}
type remoteForwardChannelData struct {
DestAddr string
DestPort uint32
OriginAddr string
OriginPort uint32
}
func (s *Server) handleRequest(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) {
s.requestHandlerLock.Lock()
if s.forwards == nil {
s.forwards = make(map[string]net.Listener)
}
s.requestHandlerLock.Unlock()
conn := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
switch req.Type {
case "tcpip-forward":
var reqPayload remoteForwardRequest
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
log.Printf("[ERROR] %+v", errors.WithStack(err))
return false, []byte{}
}
if reqPayload.BindPort != 0 {
return false, []byte("bind port must 0")
}
s.log("(%s): opening reverse tunnel", ctx.SessionID())
token, err := generateToken(16)
if err != nil {
s.log("[ERROR] %+v", errors.WithStack(err))
return false, []byte("could not generate secret token")
}
sessionID := SessionID(ctx.SessionID())
s.sessionManager.Set(sessionID, SessionData{
Type: TypeServiceProvider,
Token: token,
})
addr := s.getSocketPath(sessionID)
ln, err := net.Listen("unix", addr)
if err != nil {
s.log("[ERROR] %+v", errors.WithStack(err))
return false, []byte{}
}
destPort := 1
s.requestHandlerLock.Lock()
s.forwards[addr] = ln
s.requestHandlerLock.Unlock()
cleanup := func() {
s.log("(%s): cleaning up session", sessionID)
s.sessionManager.Remove(sessionID)
if err := os.Remove(addr); err != nil && !errors.Is(err, os.ErrNotExist) {
s.log("[ERROR] %+v", errors.WithStack(err))
}
}
go func() {
defer cleanup()
<-ctx.Done()
s.requestHandlerLock.Lock()
ln, ok := s.forwards[addr]
s.requestHandlerLock.Unlock()
if ok {
ln.Close()
}
}()
go func() {
for {
c, err := ln.Accept()
if err != nil {
if !errors.Is(err, net.ErrClosed) {
s.log("[ERROR] %+v", errors.WithStack(err))
}
break
}
originAddr, orignPortStr, _ := net.SplitHostPort(c.RemoteAddr().String())
originPort, _ := strconv.Atoi(orignPortStr)
payload := gossh.Marshal(&remoteForwardChannelData{
DestAddr: reqPayload.BindAddr,
DestPort: uint32(destPort),
OriginAddr: originAddr,
OriginPort: uint32(originPort),
})
go func() {
ch, reqs, err := conn.OpenChannel(forwardedTCPChannelType, payload)
if err != nil {
s.log("[ERROR] %+v", errors.WithStack(err))
c.Close()
return
}
go gossh.DiscardRequests(reqs)
go func() {
defer ch.Close()
defer c.Close()
io.Copy(ch, c)
}()
go func() {
defer ch.Close()
defer c.Close()
io.Copy(c, ch)
}()
}()
}
s.requestHandlerLock.Lock()
delete(s.forwards, addr)
s.requestHandlerLock.Unlock()
}()
return true, gossh.Marshal(&remoteForwardSuccess{uint32(destPort)})
case "cancel-tcpip-forward":
var reqPayload remoteForwardCancelRequest
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
s.log("[ERROR] %+v", errors.WithStack(err))
return false, []byte{}
}
sessionID := SessionID(ctx.SessionID())
addr := s.getSocketPath(sessionID)
s.log("(%s): closing sock '%s'", sessionID, addr)
s.requestHandlerLock.Lock()
ln, ok := s.forwards[addr]
s.requestHandlerLock.Unlock()
if ok {
ln.Close()
if err := os.Remove(addr); err != nil && !errors.Is(err, os.ErrNotExist) {
s.log("[ERROR] %+v", errors.WithStack(err))
}
}
return true, nil
default:
return false, nil
}
}
func (s *Server) getSocketPath(sessionID SessionID) string {
return filepath.Join(s.opts.SockDir, fmt.Sprintf("%s.sock", sessionID))
}
func generateToken(length int) (string, error) {
chars := []rune(
"ABCDEFGHIJKLMNOPQRSTUVWXYZ" +
"abcdefghijklmnopqrstuvwxyz" +
"0123456789",
)
var b strings.Builder
charsLength := big.NewInt(int64(len(chars)))
for i := 0; i < length; i++ {
idx, err := rand.Int(rand.Reader, charsLength)
if err != nil {
return "", errors.WithStack(err)
}
c := chars[idx.Int64()]
if _, err := b.WriteRune(c); err != nil {
return "", errors.WithStack(err)
}
}
return b.String(), nil
}

124
ssh/server.go Normal file
View File

@ -0,0 +1,124 @@
package ssh
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"io/fs"
"net"
"os"
"sync"
"time"
"github.com/gliderlabs/ssh"
"github.com/pkg/errors"
gossh "golang.org/x/crypto/ssh"
)
type Server struct {
ssh *ssh.Server
opts *Options
sessionManager *SessionManager
forwards map[string]net.Listener
requestHandlerLock sync.Mutex
}
func (s *Server) Serve(l net.Listener) error {
if err := s.initSSHServer(); err != nil {
return errors.WithStack(err)
}
if err := s.ssh.Serve(l); err != nil {
s.log("[ERROR] %+v", errors.WithStack(err))
}
return nil
}
func (s *Server) initSSHServer() error {
signer, err := s.loadOrCreateSigner()
if err != nil {
return errors.WithStack(err)
}
server := &ssh.Server{
HostSigners: []ssh.Signer{signer},
Handler: ssh.Handler(s.handleSession),
RequestHandlers: map[string]ssh.RequestHandler{
"tcpip-forward": s.handleRequest,
"cancel-tcpip-forward": s.handleRequest,
},
ChannelHandlers: map[string]ssh.ChannelHandler{
"direct-tcpip": s.handleDirectTCP,
"session": ssh.DefaultSessionHandler,
},
}
s.ssh = server
return nil
}
func (s *Server) loadOrCreateSigner() (ssh.Signer, error) {
var (
signer gossh.Signer
)
s.log("reading host key from '%s'", s.opts.HostKey)
data, err := os.ReadFile(s.opts.HostKey)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, errors.WithStack(err)
}
if data == nil {
s.log("host key cannot be found, generating one")
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, errors.WithStack(err)
}
signer, err = gossh.NewSignerFromKey(key)
if err != nil {
return nil, errors.WithStack(err)
}
pem := pem.EncodeToMemory(
&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
},
)
s.log("saving host key to '%s'", s.opts.HostKey)
if err := os.WriteFile(s.opts.HostKey, pem, fs.FileMode(0640)); err != nil {
return nil, errors.WithStack(err)
}
} else {
signer, err = gossh.ParsePrivateKey(data)
if err != nil {
return nil, errors.WithStack(err)
}
}
return signer, nil
}
func (s *Server) log(message string, args ...any) {
s.opts.Logger(message, args...)
}
func NewServer(funcs ...OptionFunc) *Server {
opts := DefaultOptions()
for _, fn := range funcs {
fn(opts)
}
return &Server{
opts: opts,
sessionManager: NewSessionManager(30 * time.Second),
}
}

177
ssh/session_handler.go Normal file
View File

@ -0,0 +1,177 @@
package ssh
import (
"bytes"
"fmt"
"io"
"log"
"strings"
"text/template"
_ "embed"
"github.com/gliderlabs/ssh"
"github.com/pkg/errors"
)
//go:embed diagram.txt
var asciiDiagram string
var asciiDiagramTmpl = template.Must(template.New("").Parse(asciiDiagram))
func (s *Server) handleSession(sess ssh.Session) {
ctx := sess.Context()
sessionID := SessionID(ctx.SessionID())
s.log("(%s): session opened", sessionID)
message := `
Welcome on Rebound !
`
if _, err := sess.Write([]byte(message)); err != nil {
s.log("[ERROR] %+v", errors.WithStack(err))
}
go s.readClientInput(sess)
data := s.sessionManager.Get(sessionID, SessionData{
Type: TypeServiceUnknown,
})
s.handleSessionData(sess, data)
onUpdate, close := s.sessionManager.OnUpdate(sessionID)
defer close()
for {
data, opened := <-onUpdate
if !opened {
return
}
s.log("(%s): session data updated: %v", sessionID, data)
if err := s.handleSessionData(sess, data); err != nil {
if errors.Is(err, io.EOF) {
return
}
s.log("[ERROR] %+v", errors.WithStack(err))
return
}
}
}
func (s *Server) handleSessionData(sess ssh.Session, data SessionData) error {
switch data.Type {
case TypeServiceConsumer:
if err := s.writeConsumerMessage(sess, data); err != nil {
return errors.WithStack(err)
}
case TypeServiceProvider:
if err := s.writeProviderMessage(sess, data); err != nil {
return errors.WithStack(err)
}
}
return nil
}
func (s *Server) writeConsumerMessage(sess ssh.Session, data SessionData) error {
message := `
Type Ctrl+C or Ctrl+D to exit.
`
if _, err := sess.Write([]byte(message)); err != nil {
return errors.WithStack(err)
}
return nil
}
func (s *Server) writeProviderMessage(sess ssh.Session, data SessionData) error {
var diagramBuff bytes.Buffer
hostname := s.opts.PublicHost
if len(hostname) < 24 {
halfPadding := (24 - len(hostname)) / 2
leftPadding := strings.Repeat(" ", halfPadding)
rightPadding := strings.Repeat(" ", halfPadding)
hostname = fmt.Sprintf("%s%s%s", leftPadding, hostname, rightPadding)
if len(hostname) > 24 {
hostname = hostname[0:23]
}
} else if len(hostname) >= 24 {
hostname = hostname[0:20] + "..."
}
log.Printf("'%s'", hostname)
tmplData := struct {
Pp string
Rp string
Hostname string
}{
Pp: fmt.Sprintf("%04d", s.opts.PublicPort),
Rp: "<port>",
Hostname: hostname,
}
if err := asciiDiagramTmpl.Execute(&diagramBuff, tmplData); err != nil {
return errors.WithStack(err)
}
message := fmt.Sprintf(`
You can connect to your tunnel by running in an other terminal:
ssh -L <port>:127.0.0.1:1 %s@%s -p %d
%s
Type Ctrl+C or Ctrl+D to exit.
`, data.Token, s.opts.PublicHost, s.opts.PublicPort, diagramBuff.String())
if _, err := sess.Write([]byte(message)); err != nil {
return errors.WithStack(err)
}
return nil
}
const (
CtrlC = 3
CtrlD = 4
)
func (s *Server) readClientInput(sess ssh.Session) {
sessionID := SessionID(sess.Context().SessionID())
defer func() {
s.sessionManager.Remove(sessionID)
}()
buff := make([]byte, 1)
for {
_, err := sess.Read(buff)
if err != nil {
if !errors.Is(err, io.EOF) {
s.log("[ERROR] %+v", errors.WithStack(err))
}
return
}
switch buff[0] {
case CtrlC:
fallthrough
case CtrlD:
sess.Exit(0)
return
default:
s.log("(%s) user input: %v", sessionID, buff)
}
}
}

184
ssh/session_manager.go Normal file
View File

@ -0,0 +1,184 @@
package ssh
import (
"errors"
"log"
"sync"
"time"
)
type SessionID string
type SessionType int
const (
TypeServiceUnknown SessionType = iota
TypeServiceProvider
TypeServiceConsumer
)
type SessionData struct {
Type SessionType
Token string
}
type SessionManager struct {
sessions map[SessionID]SessionData
sessionsMutex sync.Mutex
tokenIndex map[string]SessionID
updates map[SessionID][]chan SessionData
updatesMutex sync.Mutex
updateReadTimeout time.Duration
}
func (m *SessionManager) Get(sessID SessionID, defaultValue SessionData) SessionData {
log.Println("reading session", sessID)
m.sessionsMutex.Lock()
defer m.sessionsMutex.Unlock()
session, exists := m.sessions[sessID]
if !exists {
session = defaultValue
m.sessions[sessID] = session
m.updatesMutex.Lock()
m.dispatchUpdate(sessID, session)
m.updatesMutex.Unlock()
}
return session
}
func (m *SessionManager) FindByToken(token string) SessionID {
m.sessionsMutex.Lock()
defer m.sessionsMutex.Unlock()
sessID, exists := m.tokenIndex[token]
if !exists {
return ""
}
return sessID
}
func (m *SessionManager) Set(sessID SessionID, sess SessionData) {
log.Println("updating session", sessID, sess)
m.sessionsMutex.Lock()
oldSess, ok := m.sessions[sessID]
if ok {
m.updateTokenIndex(sessID, sess.Token, oldSess.Token)
} else {
m.updateTokenIndex(sessID, sess.Token, "")
}
m.sessions[sessID] = sess
m.sessionsMutex.Unlock()
m.updatesMutex.Lock()
m.dispatchUpdate(sessID, sess)
m.updatesMutex.Unlock()
}
func (m *SessionManager) Remove(sessID SessionID) {
m.sessionsMutex.Lock()
oldSess, ok := m.sessions[sessID]
if ok {
m.updateTokenIndex(sessID, "", oldSess.Token)
}
delete(m.sessions, sessID)
m.sessionsMutex.Unlock()
m.updatesMutex.Lock()
m.closeAllUpdates(sessID)
m.updatesMutex.Unlock()
}
func (m *SessionManager) OnUpdate(sessID SessionID) (<-chan SessionData, func()) {
update := make(chan SessionData)
close := func() {
m.updatesMutex.Lock()
m.closeUpdate(sessID, update)
m.updatesMutex.Unlock()
}
m.updatesMutex.Lock()
defer m.updatesMutex.Unlock()
channels, exists := m.updates[sessID]
if !exists {
channels = make([]chan SessionData, 0, 1)
}
channels = append(channels, update)
m.updates[sessID] = channels
return update, close
}
func (m *SessionManager) closeAllUpdates(sessID SessionID) {
channels, exists := m.updates[sessID]
if !exists {
return
}
for _, ch := range channels {
m.closeUpdate(sessID, ch)
}
}
func (m *SessionManager) closeUpdate(sessID SessionID, update chan SessionData) {
channels, exists := m.updates[sessID]
if !exists {
return
}
for idx, ch := range channels {
if ch != update {
continue
}
close(ch)
m.updates[sessID] = append(channels[:idx], channels[idx+1:]...)
}
}
func (m *SessionManager) dispatchUpdate(sessID SessionID, sess SessionData) {
channels, exists := m.updates[sessID]
if !exists {
return
}
for _, ch := range channels {
timeout := time.After(m.updateReadTimeout)
select {
case ch <- sess:
case <-timeout:
err := errors.New("session update read timed out")
log.Printf("[ERROR] %+v", err)
}
}
}
func (m *SessionManager) updateTokenIndex(sessID SessionID, addedToken, deletedToken string) {
if addedToken != "" {
m.tokenIndex[addedToken] = sessID
}
if deletedToken != "" {
delete(m.tokenIndex, deletedToken)
}
}
func NewSessionManager(updateReadTimeout time.Duration) *SessionManager {
return &SessionManager{
sessions: make(map[SessionID]SessionData),
tokenIndex: make(map[string]SessionID),
updates: make(map[SessionID][]chan SessionData),
updateReadTimeout: updateReadTimeout,
}
}