feat: add http server with a page serving service informations
This commit is contained in:
7
ssh/diagram.txt
Normal file
7
ssh/diagram.txt
Normal file
@ -0,0 +1,7 @@
|
||||
NAT NAT
|
||||
My machine | {{ .Hostname }} | Remote Machine
|
||||
+----------+ | +----------+ | +----------+
|
||||
| |<local_service_port> | {{.Pp}}| |{{.Pp}} | {{.Rp}}| |
|
||||
| +<-----------------------+----------->+ +<-------------+------->+ |
|
||||
+----------+ | +----------+ | +----------+
|
||||
| |
|
97
ssh/direct_tcp_handler.go
Normal file
97
ssh/direct_tcp_handler.go
Normal file
@ -0,0 +1,97 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/pkg/errors"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// direct-tcpip data struct as specified in RFC4254, Section 7.2
|
||||
type localForwardChannelData struct {
|
||||
DestAddr string
|
||||
DestPort uint32
|
||||
|
||||
OriginAddr string
|
||||
OriginPort uint32
|
||||
}
|
||||
|
||||
func (s *Server) handleDirectTCP(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
|
||||
d := localForwardChannelData{}
|
||||
if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sessionID := SessionID(ctx.SessionID())
|
||||
|
||||
s.log("(%s): opening direct tcp tunnel", sessionID)
|
||||
|
||||
s.sessionManager.Set(sessionID, SessionData{
|
||||
Type: TypeServiceConsumer,
|
||||
})
|
||||
|
||||
sessionToken := ctx.User()
|
||||
|
||||
remoteSessionID := s.sessionManager.FindByToken(sessionToken)
|
||||
if remoteSessionID == "" {
|
||||
newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("could not find session associated with token '%s'", sessionToken))
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
addr := s.getSocketPath(remoteSessionID)
|
||||
|
||||
s.log("(%s): using sock '%s'", sessionID, addr)
|
||||
|
||||
if _, err := os.Stat(addr); err != nil {
|
||||
newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("could not find session associated with token '%s'", sessionToken))
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
var dialer net.Dialer
|
||||
dconn, err := dialer.DialContext(ctx, "unix", addr)
|
||||
if err != nil {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
newChan.Reject(gossh.ConnectionFailed, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
ch, reqs, err := newChan.Accept()
|
||||
if err != nil {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
dconn.Close()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
go gossh.DiscardRequests(reqs)
|
||||
|
||||
go func() {
|
||||
defer dconn.Close()
|
||||
defer ch.Close()
|
||||
|
||||
if _, err := io.Copy(ch, dconn); err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer dconn.Close()
|
||||
defer ch.Close()
|
||||
|
||||
if _, err := io.Copy(dconn, ch); err != nil {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
}
|
||||
}()
|
||||
}
|
53
ssh/options.go
Normal file
53
ssh/options.go
Normal file
@ -0,0 +1,53 @@
|
||||
package ssh
|
||||
|
||||
import "log"
|
||||
|
||||
type Options struct {
|
||||
Logger func(message string, args ...any)
|
||||
SockDir string `env:"SOCK_DIR"`
|
||||
PublicPort uint `env:"PUBLIC_PORT"`
|
||||
PublicHost string `env:"PUBLIC_HOST"`
|
||||
HostKey string `env:"HOST_KEY"`
|
||||
}
|
||||
|
||||
type OptionFunc func(*Options)
|
||||
|
||||
func DefaultOptions() *Options {
|
||||
return &Options{
|
||||
SockDir: "./socks",
|
||||
Logger: log.Printf,
|
||||
PublicPort: 2222,
|
||||
PublicHost: "127.0.0.1",
|
||||
HostKey: "./host.key",
|
||||
}
|
||||
}
|
||||
|
||||
func WithSockDir(addr string) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.SockDir = addr
|
||||
}
|
||||
}
|
||||
|
||||
func WithPublicHost(host string) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.PublicHost = host
|
||||
}
|
||||
}
|
||||
|
||||
func WithPublicPort(port uint) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.PublicPort = port
|
||||
}
|
||||
}
|
||||
|
||||
func WithHostKey(key string) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.HostKey = key
|
||||
}
|
||||
}
|
||||
|
||||
func WithLogger(logger func(message string, args ...any)) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.Logger = logger
|
||||
}
|
||||
}
|
226
ssh/request_handler.go
Normal file
226
ssh/request_handler.go
Normal file
@ -0,0 +1,226 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/pkg/errors"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
forwardedTCPChannelType = "forwarded-tcpip"
|
||||
)
|
||||
|
||||
// direct-tcpip data struct as specified in RFC4254, Section 7.2
|
||||
type remoteForwardRequest struct {
|
||||
BindAddr string
|
||||
BindPort uint32
|
||||
}
|
||||
|
||||
type remoteForwardSuccess struct {
|
||||
BindPort uint32
|
||||
}
|
||||
|
||||
type remoteForwardCancelRequest struct {
|
||||
BindAddr string
|
||||
BindPort uint32
|
||||
}
|
||||
|
||||
type remoteForwardChannelData struct {
|
||||
DestAddr string
|
||||
DestPort uint32
|
||||
OriginAddr string
|
||||
OriginPort uint32
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) {
|
||||
s.requestHandlerLock.Lock()
|
||||
if s.forwards == nil {
|
||||
s.forwards = make(map[string]net.Listener)
|
||||
}
|
||||
s.requestHandlerLock.Unlock()
|
||||
conn := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
|
||||
switch req.Type {
|
||||
case "tcpip-forward":
|
||||
var reqPayload remoteForwardRequest
|
||||
|
||||
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
|
||||
log.Printf("[ERROR] %+v", errors.WithStack(err))
|
||||
return false, []byte{}
|
||||
}
|
||||
|
||||
if reqPayload.BindPort != 0 {
|
||||
return false, []byte("bind port must 0")
|
||||
}
|
||||
|
||||
s.log("(%s): opening reverse tunnel", ctx.SessionID())
|
||||
|
||||
token, err := generateToken(16)
|
||||
if err != nil {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
return false, []byte("could not generate secret token")
|
||||
}
|
||||
|
||||
sessionID := SessionID(ctx.SessionID())
|
||||
|
||||
s.sessionManager.Set(sessionID, SessionData{
|
||||
Type: TypeServiceProvider,
|
||||
Token: token,
|
||||
})
|
||||
|
||||
addr := s.getSocketPath(sessionID)
|
||||
|
||||
ln, err := net.Listen("unix", addr)
|
||||
if err != nil {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
return false, []byte{}
|
||||
}
|
||||
|
||||
destPort := 1
|
||||
|
||||
s.requestHandlerLock.Lock()
|
||||
s.forwards[addr] = ln
|
||||
s.requestHandlerLock.Unlock()
|
||||
|
||||
cleanup := func() {
|
||||
s.log("(%s): cleaning up session", sessionID)
|
||||
|
||||
s.sessionManager.Remove(sessionID)
|
||||
|
||||
if err := os.Remove(addr); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer cleanup()
|
||||
|
||||
<-ctx.Done()
|
||||
s.requestHandlerLock.Lock()
|
||||
ln, ok := s.forwards[addr]
|
||||
s.requestHandlerLock.Unlock()
|
||||
if ok {
|
||||
ln.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
if !errors.Is(err, net.ErrClosed) {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
originAddr, orignPortStr, _ := net.SplitHostPort(c.RemoteAddr().String())
|
||||
originPort, _ := strconv.Atoi(orignPortStr)
|
||||
payload := gossh.Marshal(&remoteForwardChannelData{
|
||||
DestAddr: reqPayload.BindAddr,
|
||||
DestPort: uint32(destPort),
|
||||
OriginAddr: originAddr,
|
||||
OriginPort: uint32(originPort),
|
||||
})
|
||||
|
||||
go func() {
|
||||
ch, reqs, err := conn.OpenChannel(forwardedTCPChannelType, payload)
|
||||
if err != nil {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
|
||||
go gossh.DiscardRequests(reqs)
|
||||
|
||||
go func() {
|
||||
defer ch.Close()
|
||||
defer c.Close()
|
||||
|
||||
io.Copy(ch, c)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer ch.Close()
|
||||
defer c.Close()
|
||||
|
||||
io.Copy(c, ch)
|
||||
}()
|
||||
}()
|
||||
}
|
||||
|
||||
s.requestHandlerLock.Lock()
|
||||
delete(s.forwards, addr)
|
||||
s.requestHandlerLock.Unlock()
|
||||
}()
|
||||
|
||||
return true, gossh.Marshal(&remoteForwardSuccess{uint32(destPort)})
|
||||
|
||||
case "cancel-tcpip-forward":
|
||||
var reqPayload remoteForwardCancelRequest
|
||||
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
return false, []byte{}
|
||||
}
|
||||
|
||||
sessionID := SessionID(ctx.SessionID())
|
||||
|
||||
addr := s.getSocketPath(sessionID)
|
||||
|
||||
s.log("(%s): closing sock '%s'", sessionID, addr)
|
||||
|
||||
s.requestHandlerLock.Lock()
|
||||
ln, ok := s.forwards[addr]
|
||||
s.requestHandlerLock.Unlock()
|
||||
if ok {
|
||||
ln.Close()
|
||||
if err := os.Remove(addr); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) getSocketPath(sessionID SessionID) string {
|
||||
return filepath.Join(s.opts.SockDir, fmt.Sprintf("%s.sock", sessionID))
|
||||
}
|
||||
|
||||
func generateToken(length int) (string, error) {
|
||||
chars := []rune(
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZ" +
|
||||
"abcdefghijklmnopqrstuvwxyz" +
|
||||
"0123456789",
|
||||
)
|
||||
var b strings.Builder
|
||||
|
||||
charsLength := big.NewInt(int64(len(chars)))
|
||||
|
||||
for i := 0; i < length; i++ {
|
||||
idx, err := rand.Int(rand.Reader, charsLength)
|
||||
if err != nil {
|
||||
return "", errors.WithStack(err)
|
||||
}
|
||||
|
||||
c := chars[idx.Int64()]
|
||||
if _, err := b.WriteRune(c); err != nil {
|
||||
return "", errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
124
ssh/server.go
Normal file
124
ssh/server.go
Normal file
@ -0,0 +1,124 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"io/fs"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
ssh *ssh.Server
|
||||
opts *Options
|
||||
|
||||
sessionManager *SessionManager
|
||||
forwards map[string]net.Listener
|
||||
requestHandlerLock sync.Mutex
|
||||
}
|
||||
|
||||
func (s *Server) Serve(l net.Listener) error {
|
||||
if err := s.initSSHServer(); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
if err := s.ssh.Serve(l); err != nil {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) initSSHServer() error {
|
||||
signer, err := s.loadOrCreateSigner()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
server := &ssh.Server{
|
||||
HostSigners: []ssh.Signer{signer},
|
||||
Handler: ssh.Handler(s.handleSession),
|
||||
RequestHandlers: map[string]ssh.RequestHandler{
|
||||
"tcpip-forward": s.handleRequest,
|
||||
"cancel-tcpip-forward": s.handleRequest,
|
||||
},
|
||||
ChannelHandlers: map[string]ssh.ChannelHandler{
|
||||
"direct-tcpip": s.handleDirectTCP,
|
||||
"session": ssh.DefaultSessionHandler,
|
||||
},
|
||||
}
|
||||
|
||||
s.ssh = server
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) loadOrCreateSigner() (ssh.Signer, error) {
|
||||
var (
|
||||
signer gossh.Signer
|
||||
)
|
||||
|
||||
s.log("reading host key from '%s'", s.opts.HostKey)
|
||||
|
||||
data, err := os.ReadFile(s.opts.HostKey)
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
if data == nil {
|
||||
s.log("host key cannot be found, generating one")
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
signer, err = gossh.NewSignerFromKey(key)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
pem := pem.EncodeToMemory(
|
||||
&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||
},
|
||||
)
|
||||
|
||||
s.log("saving host key to '%s'", s.opts.HostKey)
|
||||
if err := os.WriteFile(s.opts.HostKey, pem, fs.FileMode(0640)); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
} else {
|
||||
signer, err = gossh.ParsePrivateKey(data)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
return signer, nil
|
||||
}
|
||||
|
||||
func (s *Server) log(message string, args ...any) {
|
||||
s.opts.Logger(message, args...)
|
||||
}
|
||||
|
||||
func NewServer(funcs ...OptionFunc) *Server {
|
||||
opts := DefaultOptions()
|
||||
for _, fn := range funcs {
|
||||
fn(opts)
|
||||
}
|
||||
|
||||
return &Server{
|
||||
opts: opts,
|
||||
sessionManager: NewSessionManager(30 * time.Second),
|
||||
}
|
||||
}
|
177
ssh/session_handler.go
Normal file
177
ssh/session_handler.go
Normal file
@ -0,0 +1,177 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"strings"
|
||||
"text/template"
|
||||
|
||||
_ "embed"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
//go:embed diagram.txt
|
||||
var asciiDiagram string
|
||||
var asciiDiagramTmpl = template.Must(template.New("").Parse(asciiDiagram))
|
||||
|
||||
func (s *Server) handleSession(sess ssh.Session) {
|
||||
ctx := sess.Context()
|
||||
sessionID := SessionID(ctx.SessionID())
|
||||
|
||||
s.log("(%s): session opened", sessionID)
|
||||
|
||||
message := `
|
||||
Welcome on Rebound !
|
||||
`
|
||||
|
||||
if _, err := sess.Write([]byte(message)); err != nil {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
go s.readClientInput(sess)
|
||||
|
||||
data := s.sessionManager.Get(sessionID, SessionData{
|
||||
Type: TypeServiceUnknown,
|
||||
})
|
||||
s.handleSessionData(sess, data)
|
||||
|
||||
onUpdate, close := s.sessionManager.OnUpdate(sessionID)
|
||||
defer close()
|
||||
|
||||
for {
|
||||
data, opened := <-onUpdate
|
||||
if !opened {
|
||||
return
|
||||
}
|
||||
|
||||
s.log("(%s): session data updated: %v", sessionID, data)
|
||||
|
||||
if err := s.handleSessionData(sess, data); err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return
|
||||
}
|
||||
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleSessionData(sess ssh.Session, data SessionData) error {
|
||||
switch data.Type {
|
||||
case TypeServiceConsumer:
|
||||
if err := s.writeConsumerMessage(sess, data); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
case TypeServiceProvider:
|
||||
if err := s.writeProviderMessage(sess, data); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) writeConsumerMessage(sess ssh.Session, data SessionData) error {
|
||||
message := `
|
||||
Type Ctrl+C or Ctrl+D to exit.
|
||||
`
|
||||
|
||||
if _, err := sess.Write([]byte(message)); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) writeProviderMessage(sess ssh.Session, data SessionData) error {
|
||||
var diagramBuff bytes.Buffer
|
||||
|
||||
hostname := s.opts.PublicHost
|
||||
if len(hostname) < 24 {
|
||||
halfPadding := (24 - len(hostname)) / 2
|
||||
leftPadding := strings.Repeat(" ", halfPadding)
|
||||
rightPadding := strings.Repeat(" ", halfPadding)
|
||||
hostname = fmt.Sprintf("%s%s%s", leftPadding, hostname, rightPadding)
|
||||
if len(hostname) > 24 {
|
||||
hostname = hostname[0:23]
|
||||
}
|
||||
} else if len(hostname) >= 24 {
|
||||
hostname = hostname[0:20] + "..."
|
||||
}
|
||||
|
||||
log.Printf("'%s'", hostname)
|
||||
|
||||
tmplData := struct {
|
||||
Pp string
|
||||
Rp string
|
||||
Hostname string
|
||||
}{
|
||||
Pp: fmt.Sprintf("%04d", s.opts.PublicPort),
|
||||
Rp: "<port>",
|
||||
Hostname: hostname,
|
||||
}
|
||||
|
||||
if err := asciiDiagramTmpl.Execute(&diagramBuff, tmplData); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
message := fmt.Sprintf(`
|
||||
You can connect to your tunnel by running in an other terminal:
|
||||
|
||||
ssh -L <port>:127.0.0.1:1 %s@%s -p %d
|
||||
|
||||
%s
|
||||
|
||||
Type Ctrl+C or Ctrl+D to exit.
|
||||
`, data.Token, s.opts.PublicHost, s.opts.PublicPort, diagramBuff.String())
|
||||
|
||||
if _, err := sess.Write([]byte(message)); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
CtrlC = 3
|
||||
CtrlD = 4
|
||||
)
|
||||
|
||||
func (s *Server) readClientInput(sess ssh.Session) {
|
||||
sessionID := SessionID(sess.Context().SessionID())
|
||||
defer func() {
|
||||
s.sessionManager.Remove(sessionID)
|
||||
}()
|
||||
|
||||
buff := make([]byte, 1)
|
||||
|
||||
for {
|
||||
_, err := sess.Read(buff)
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
switch buff[0] {
|
||||
case CtrlC:
|
||||
fallthrough
|
||||
case CtrlD:
|
||||
sess.Exit(0)
|
||||
return
|
||||
|
||||
default:
|
||||
s.log("(%s) user input: %v", sessionID, buff)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
184
ssh/session_manager.go
Normal file
184
ssh/session_manager.go
Normal file
@ -0,0 +1,184 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SessionID string
|
||||
|
||||
type SessionType int
|
||||
|
||||
const (
|
||||
TypeServiceUnknown SessionType = iota
|
||||
TypeServiceProvider
|
||||
TypeServiceConsumer
|
||||
)
|
||||
|
||||
type SessionData struct {
|
||||
Type SessionType
|
||||
Token string
|
||||
}
|
||||
|
||||
type SessionManager struct {
|
||||
sessions map[SessionID]SessionData
|
||||
sessionsMutex sync.Mutex
|
||||
|
||||
tokenIndex map[string]SessionID
|
||||
|
||||
updates map[SessionID][]chan SessionData
|
||||
updatesMutex sync.Mutex
|
||||
|
||||
updateReadTimeout time.Duration
|
||||
}
|
||||
|
||||
func (m *SessionManager) Get(sessID SessionID, defaultValue SessionData) SessionData {
|
||||
log.Println("reading session", sessID)
|
||||
|
||||
m.sessionsMutex.Lock()
|
||||
defer m.sessionsMutex.Unlock()
|
||||
|
||||
session, exists := m.sessions[sessID]
|
||||
if !exists {
|
||||
session = defaultValue
|
||||
m.sessions[sessID] = session
|
||||
|
||||
m.updatesMutex.Lock()
|
||||
m.dispatchUpdate(sessID, session)
|
||||
m.updatesMutex.Unlock()
|
||||
}
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
func (m *SessionManager) FindByToken(token string) SessionID {
|
||||
m.sessionsMutex.Lock()
|
||||
defer m.sessionsMutex.Unlock()
|
||||
|
||||
sessID, exists := m.tokenIndex[token]
|
||||
if !exists {
|
||||
return ""
|
||||
}
|
||||
|
||||
return sessID
|
||||
}
|
||||
|
||||
func (m *SessionManager) Set(sessID SessionID, sess SessionData) {
|
||||
log.Println("updating session", sessID, sess)
|
||||
|
||||
m.sessionsMutex.Lock()
|
||||
oldSess, ok := m.sessions[sessID]
|
||||
if ok {
|
||||
m.updateTokenIndex(sessID, sess.Token, oldSess.Token)
|
||||
} else {
|
||||
m.updateTokenIndex(sessID, sess.Token, "")
|
||||
}
|
||||
m.sessions[sessID] = sess
|
||||
m.sessionsMutex.Unlock()
|
||||
|
||||
m.updatesMutex.Lock()
|
||||
m.dispatchUpdate(sessID, sess)
|
||||
m.updatesMutex.Unlock()
|
||||
}
|
||||
|
||||
func (m *SessionManager) Remove(sessID SessionID) {
|
||||
m.sessionsMutex.Lock()
|
||||
oldSess, ok := m.sessions[sessID]
|
||||
if ok {
|
||||
m.updateTokenIndex(sessID, "", oldSess.Token)
|
||||
}
|
||||
delete(m.sessions, sessID)
|
||||
m.sessionsMutex.Unlock()
|
||||
|
||||
m.updatesMutex.Lock()
|
||||
m.closeAllUpdates(sessID)
|
||||
m.updatesMutex.Unlock()
|
||||
}
|
||||
|
||||
func (m *SessionManager) OnUpdate(sessID SessionID) (<-chan SessionData, func()) {
|
||||
update := make(chan SessionData)
|
||||
|
||||
close := func() {
|
||||
m.updatesMutex.Lock()
|
||||
m.closeUpdate(sessID, update)
|
||||
m.updatesMutex.Unlock()
|
||||
}
|
||||
|
||||
m.updatesMutex.Lock()
|
||||
defer m.updatesMutex.Unlock()
|
||||
|
||||
channels, exists := m.updates[sessID]
|
||||
if !exists {
|
||||
channels = make([]chan SessionData, 0, 1)
|
||||
}
|
||||
|
||||
channels = append(channels, update)
|
||||
m.updates[sessID] = channels
|
||||
|
||||
return update, close
|
||||
}
|
||||
|
||||
func (m *SessionManager) closeAllUpdates(sessID SessionID) {
|
||||
channels, exists := m.updates[sessID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
for _, ch := range channels {
|
||||
m.closeUpdate(sessID, ch)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SessionManager) closeUpdate(sessID SessionID, update chan SessionData) {
|
||||
channels, exists := m.updates[sessID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
for idx, ch := range channels {
|
||||
if ch != update {
|
||||
continue
|
||||
}
|
||||
|
||||
close(ch)
|
||||
m.updates[sessID] = append(channels[:idx], channels[idx+1:]...)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SessionManager) dispatchUpdate(sessID SessionID, sess SessionData) {
|
||||
channels, exists := m.updates[sessID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
for _, ch := range channels {
|
||||
timeout := time.After(m.updateReadTimeout)
|
||||
select {
|
||||
case ch <- sess:
|
||||
case <-timeout:
|
||||
err := errors.New("session update read timed out")
|
||||
log.Printf("[ERROR] %+v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SessionManager) updateTokenIndex(sessID SessionID, addedToken, deletedToken string) {
|
||||
if addedToken != "" {
|
||||
m.tokenIndex[addedToken] = sessID
|
||||
}
|
||||
|
||||
if deletedToken != "" {
|
||||
delete(m.tokenIndex, deletedToken)
|
||||
}
|
||||
}
|
||||
|
||||
func NewSessionManager(updateReadTimeout time.Duration) *SessionManager {
|
||||
return &SessionManager{
|
||||
sessions: make(map[SessionID]SessionData),
|
||||
tokenIndex: make(map[string]SessionID),
|
||||
updates: make(map[SessionID][]chan SessionData),
|
||||
updateReadTimeout: updateReadTimeout,
|
||||
}
|
||||
}
|
Reference in New Issue
Block a user