commit 1bdad36787e902ca4748d6638c7517ffd82bb4d7 Author: William Petit Date: Fri Sep 8 20:00:00 2023 -0600 feat: initial commit diff --git a/.dockerignore b/.dockerignore new file mode 100644 index 0000000..d082abb --- /dev/null +++ b/.dockerignore @@ -0,0 +1,5 @@ +/bin +/.mktools +/tools +/.env +/socks \ No newline at end of file diff --git a/.env.dist b/.env.dist new file mode 100644 index 0000000..e69de29 diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..66554cc --- /dev/null +++ b/.gitignore @@ -0,0 +1,6 @@ +/tools +/.mktools +/bin +/.env +/socks +/host.key \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..6dd7e7d --- /dev/null +++ b/Dockerfile @@ -0,0 +1,27 @@ +FROM golang:1.21 AS build + +RUN apt-get update && apt-get install -y build-essential git bash curl ca-certificates + +COPY . /src + +WORKDIR /src + +RUN make build + +FROM busybox + +COPY --from=build /src/bin /app + +WORKDIR /app + +RUN mkdir -p /app/socks /app/keys + +EXPOSE 2222 + +ENV REBOUND_PUBLIC_HOST=127.0.0.1 +ENV REBOUND_PUBLIC_PORT=2222 +ENV REBOUND_HOST_KEY=/app/keys/host.key +ENV REBOUND_ADDRESS=:2222 +ENV REBOUND_SOCK_DIR=/app/socks + +CMD ["/app/server"] \ No newline at end of file diff --git a/Makefile b/Makefile new file mode 100644 index 0000000..94ef220 --- /dev/null +++ b/Makefile @@ -0,0 +1,49 @@ +SHELL := /bin/bash +DOKKU_URL := dokku@dev.lookingfora.name:rebound + +all: build + +watch: tools/modd/bin/modd + tools/modd/bin/modd + +socks: + mkdir -p ./socks + +run: .env socks + ( set -o allexport && source .env && set +o allexport && bin/server ) + +build: .env + CGO_ENABLED=0 go build -o ./bin/server ./cmd/server + +.env: + cp .env.dist .env + +dokku-build: + docker build \ + -t rebound-dokku:latest \ + . + +dokku-run: + docker run \ + -it --rm \ + -p 2222:2222 \ + --tmpfs /socks \ + rebound-dokku:latest + +dokku-deploy: + $(if $(shell git config remote.dokku.url),, git remote add dokku $(DOKKU_URL)) + git push -f dokku $(shell git rev-parse HEAD):refs/heads/master + +.PHONY: mktools +mktools: + rm -rf .mktools + curl -q https://forge.cadoles.com/Cadoles/mktools/raw/branch/master/install.sh | $(SHELL) + +tools/modd/bin/modd: + mkdir -p tools/modd/bin + GOBIN=$(PWD)/tools/modd/bin go install github.com/cortesi/modd/cmd/modd@v0.8.1 + +.mktools: + $(MAKE) mktools + +-include .mktools/*.mk \ No newline at end of file diff --git a/README.md b/README.md new file mode 100644 index 0000000..4d0a270 --- /dev/null +++ b/README.md @@ -0,0 +1,19 @@ +# Rebound + +Serveur utilisant le protocole SSH pour créer des tunnels TCP/IP. + +## Usage + +### Ouvrir un tunnel sur un port local + +```shell +ssh -R 0:: root@rebound.lookingfora.name -p 2222 +``` + +### Se connecter à un tunnel ouvert + +``` +ssh -L :0.0.0.0:1 @rebound.lookingfora.name -p 2222 +``` + +Vous pourrez ensuite accéder au service sur `127.0.0.1:` comme si c'était `:`. \ No newline at end of file diff --git a/cmd/server/main.go b/cmd/server/main.go new file mode 100644 index 0000000..9f272b6 --- /dev/null +++ b/cmd/server/main.go @@ -0,0 +1,50 @@ +package main + +import ( + "flag" + "log" + "os" + "os/signal" + + "forge.cadoles.com/wpetit/rebound" + "github.com/caarlos0/env/v6" + "github.com/pkg/errors" +) + +func main() { + opts := rebound.DefaultOptions() + if err := env.Parse(opts); err != nil { + log.Fatalf("[ERROR] %+v", errors.WithStack(err)) + } + + address := flag.String("address", opts.Address, "server listening address") + sockDir := flag.String("sock-dir", opts.SockDir, "sock directory") + publicPort := flag.Uint("public-port", opts.PublicPort, "public port") + publicHost := flag.String("public-host", opts.PublicHost, "public host") + hostKey := flag.String("host-key", opts.HostKey, "host key") + + flag.Parse() + + server := rebound.NewServer( + rebound.WithAddress(*address), + rebound.WithSockDir(*sockDir), + rebound.WithPublicPort(*publicPort), + rebound.WithPublicHost(*publicHost), + rebound.WithHostKey(*hostKey), + ) + + go func() { + if err := server.Run(); err != nil { + log.Fatalf("[FATAL] %+v", errors.WithStack(err)) + } + }() + + c := make(chan os.Signal, 1) + signal.Notify(c, os.Interrupt) + + <-c + + if err := server.Stop(); err != nil { + log.Fatalf("[FATAL] %+v", errors.WithStack(err)) + } +} diff --git a/direct_tcp_handler.go b/direct_tcp_handler.go new file mode 100644 index 0000000..18dbd15 --- /dev/null +++ b/direct_tcp_handler.go @@ -0,0 +1,97 @@ +package rebound + +import ( + "fmt" + "io" + "net" + "os" + + "github.com/gliderlabs/ssh" + "github.com/pkg/errors" + gossh "golang.org/x/crypto/ssh" +) + +// direct-tcpip data struct as specified in RFC4254, Section 7.2 +type localForwardChannelData struct { + DestAddr string + DestPort uint32 + + OriginAddr string + OriginPort uint32 +} + +func (s *Server) handleDirectTCP(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) { + d := localForwardChannelData{} + if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil { + s.log("[ERROR] %+v", errors.WithStack(err)) + newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error()) + return + } + + sessionID := SessionID(ctx.SessionID()) + + s.log("(%s): opening direct tcp tunnel", sessionID) + + s.sessionManager.Set(sessionID, SessionData{ + Type: TypeServiceConsumer, + }) + + sessionToken := ctx.User() + + remoteSessionID := s.sessionManager.FindByToken(sessionToken) + if remoteSessionID == "" { + newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("could not find session associated with token '%s'", sessionToken)) + conn.Close() + return + } + + addr := s.getSocketPath(remoteSessionID) + + s.log("(%s): using sock '%s'", sessionID, addr) + + if _, err := os.Stat(addr); err != nil { + newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("could not find session associated with token '%s'", sessionToken)) + conn.Close() + return + } + + var dialer net.Dialer + dconn, err := dialer.DialContext(ctx, "unix", addr) + if err != nil { + s.log("[ERROR] %+v", errors.WithStack(err)) + newChan.Reject(gossh.ConnectionFailed, err.Error()) + return + } + + ch, reqs, err := newChan.Accept() + if err != nil { + s.log("[ERROR] %+v", errors.WithStack(err)) + dconn.Close() + + return + } + + go gossh.DiscardRequests(reqs) + + go func() { + defer dconn.Close() + defer ch.Close() + + if _, err := io.Copy(ch, dconn); err != nil { + if errors.Is(err, net.ErrClosed) { + return + } + + s.log("[ERROR] %+v", errors.WithStack(err)) + } + }() + + go func() { + defer dconn.Close() + defer ch.Close() + + if _, err := io.Copy(dconn, ch); err != nil { + s.log("[ERROR] %+v", errors.WithStack(err)) + } + }() +} diff --git a/go.mod b/go.mod new file mode 100644 index 0000000..aecb25f --- /dev/null +++ b/go.mod @@ -0,0 +1,15 @@ +module forge.cadoles.com/wpetit/rebound + +go 1.21.0 + +require ( + github.com/caarlos0/env/v6 v6.10.1 + github.com/gliderlabs/ssh v0.3.5 + github.com/pkg/errors v0.9.1 + golang.org/x/crypto v0.0.0-20220826181053-bd7e27e6170d +) + +require ( + github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect + golang.org/x/sys v0.0.0-20220825204002-c680a09ffe64 // indirect +) diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..d006a3e --- /dev/null +++ b/go.sum @@ -0,0 +1,25 @@ +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8= +github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4= +github.com/caarlos0/env/v6 v6.10.1 h1:t1mPSxNpei6M5yAeu1qtRdPAK29Nbcf/n3G7x+b3/II= +github.com/caarlos0/env/v6 v6.10.1/go.mod h1:hvp/ryKXKipEkcuYjs9mI4bBCg+UI0Yhgm5Zu0ddvwc= +github.com/gliderlabs/ssh v0.3.5 h1:OcaySEmAQJgyYcArR+gGGTHCyE7nvhEMTlYY+Dp8CpY= +github.com/gliderlabs/ssh v0.3.5/go.mod h1:8XB4KraRrX39qHhT6yxPsHedjA08I/uBVwj4xC+/+z4= +github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4= +github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= +golang.org/x/crypto v0.0.0-20220826181053-bd7e27e6170d h1:3qF+Z8Hkrw9sOhrFHti9TlB1Hkac1x+DNRkv0XQiFjo= +golang.org/x/crypto v0.0.0-20220826181053-bd7e27e6170d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= +golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= +golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk= +golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/sys v0.0.0-20220825204002-c680a09ffe64 h1:UiNENfZ8gDvpiWw7IpOMQ27spWmThO1RwwdQVbJahJM= +golang.org/x/sys v0.0.0-20220825204002-c680a09ffe64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= +golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= +golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/term v0.0.0-20220722155259-a9ba230a4035 h1:Q5284mrmYTpACcm+eAKjKJH48BBwSyfJqmmGDTtT8Vc= +golang.org/x/term v0.0.0-20220722155259-a9ba230a4035/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= +golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= +golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= +golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ= diff --git a/modd.conf b/modd.conf new file mode 100644 index 0000000..649a319 --- /dev/null +++ b/modd.conf @@ -0,0 +1,6 @@ +**/*.go +Makefile { + prep: make build + daemon: make run + daemon: npx http-server . +} \ No newline at end of file diff --git a/option.go b/option.go new file mode 100644 index 0000000..2e8dbee --- /dev/null +++ b/option.go @@ -0,0 +1,55 @@ +package rebound + +import "log" + +type Options struct { + Address string `env:"REBOUND_ADDRESS"` + Logger func(message string, args ...any) + SockDir string `env:"REBOUND_SOCK_DIR"` + PublicPort uint `env:"REBOUND_PUBLIC_PORT"` + PublicHost string `env:"REBOUND_PUBLIC_HOST"` + HostKey string `env:"REBOUND_HOST_KEY"` +} + +type OptionFunc func(*Options) + +func DefaultOptions() *Options { + return &Options{ + Address: "127.0.0.1:2222", + SockDir: "./socks", + Logger: log.Printf, + PublicPort: 2222, + PublicHost: "127.0.0.1", + HostKey: "./host.key", + } +} + +func WithAddress(addr string) func(*Options) { + return func(opts *Options) { + opts.Address = addr + } +} + +func WithSockDir(addr string) func(*Options) { + return func(opts *Options) { + opts.SockDir = addr + } +} + +func WithPublicHost(host string) func(*Options) { + return func(opts *Options) { + opts.PublicHost = host + } +} + +func WithPublicPort(port uint) func(*Options) { + return func(opts *Options) { + opts.PublicPort = port + } +} + +func WithHostKey(key string) func(*Options) { + return func(opts *Options) { + opts.HostKey = key + } +} diff --git a/request_handler.go b/request_handler.go new file mode 100644 index 0000000..3ed7525 --- /dev/null +++ b/request_handler.go @@ -0,0 +1,226 @@ +package rebound + +import ( + "crypto/rand" + "fmt" + "io" + "log" + "math/big" + "net" + "os" + "path/filepath" + "strconv" + "strings" + + "github.com/gliderlabs/ssh" + "github.com/pkg/errors" + gossh "golang.org/x/crypto/ssh" +) + +const ( + forwardedTCPChannelType = "forwarded-tcpip" +) + +// direct-tcpip data struct as specified in RFC4254, Section 7.2 +type remoteForwardRequest struct { + BindAddr string + BindPort uint32 +} + +type remoteForwardSuccess struct { + BindPort uint32 +} + +type remoteForwardCancelRequest struct { + BindAddr string + BindPort uint32 +} + +type remoteForwardChannelData struct { + DestAddr string + DestPort uint32 + OriginAddr string + OriginPort uint32 +} + +func (s *Server) handleRequest(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) { + s.requestHandlerLock.Lock() + if s.forwards == nil { + s.forwards = make(map[string]net.Listener) + } + s.requestHandlerLock.Unlock() + conn := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn) + switch req.Type { + case "tcpip-forward": + var reqPayload remoteForwardRequest + + if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil { + log.Printf("[ERROR] %+v", errors.WithStack(err)) + return false, []byte{} + } + + if reqPayload.BindPort != 0 { + return false, []byte("bind port must 0") + } + + s.log("(%s): opening reverse tunnel", ctx.SessionID()) + + token, err := generateToken(16) + if err != nil { + s.log("[ERROR] %+v", errors.WithStack(err)) + return false, []byte("could not generate secret token") + } + + sessionID := SessionID(ctx.SessionID()) + + s.sessionManager.Set(sessionID, SessionData{ + Type: TypeServiceProvider, + Token: token, + }) + + addr := s.getSocketPath(sessionID) + + ln, err := net.Listen("unix", addr) + if err != nil { + s.log("[ERROR] %+v", errors.WithStack(err)) + return false, []byte{} + } + + destPort := 1 + + s.requestHandlerLock.Lock() + s.forwards[addr] = ln + s.requestHandlerLock.Unlock() + + cleanup := func() { + s.log("(%s): cleaning up session", sessionID) + + s.sessionManager.Remove(sessionID) + + if err := os.Remove(addr); err != nil && !errors.Is(err, os.ErrNotExist) { + s.log("[ERROR] %+v", errors.WithStack(err)) + } + } + + go func() { + defer cleanup() + + <-ctx.Done() + s.requestHandlerLock.Lock() + ln, ok := s.forwards[addr] + s.requestHandlerLock.Unlock() + if ok { + ln.Close() + } + }() + + go func() { + for { + c, err := ln.Accept() + if err != nil { + if !errors.Is(err, net.ErrClosed) { + s.log("[ERROR] %+v", errors.WithStack(err)) + } + + break + } + + originAddr, orignPortStr, _ := net.SplitHostPort(c.RemoteAddr().String()) + originPort, _ := strconv.Atoi(orignPortStr) + payload := gossh.Marshal(&remoteForwardChannelData{ + DestAddr: reqPayload.BindAddr, + DestPort: uint32(destPort), + OriginAddr: originAddr, + OriginPort: uint32(originPort), + }) + + go func() { + ch, reqs, err := conn.OpenChannel(forwardedTCPChannelType, payload) + if err != nil { + s.log("[ERROR] %+v", errors.WithStack(err)) + c.Close() + return + } + + go gossh.DiscardRequests(reqs) + + go func() { + defer ch.Close() + defer c.Close() + + io.Copy(ch, c) + }() + + go func() { + defer ch.Close() + defer c.Close() + + io.Copy(c, ch) + }() + }() + } + + s.requestHandlerLock.Lock() + delete(s.forwards, addr) + s.requestHandlerLock.Unlock() + }() + + return true, gossh.Marshal(&remoteForwardSuccess{uint32(destPort)}) + + case "cancel-tcpip-forward": + var reqPayload remoteForwardCancelRequest + if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil { + s.log("[ERROR] %+v", errors.WithStack(err)) + return false, []byte{} + } + + sessionID := SessionID(ctx.SessionID()) + + addr := s.getSocketPath(sessionID) + + s.log("(%s): closing sock '%s'", sessionID, addr) + + s.requestHandlerLock.Lock() + ln, ok := s.forwards[addr] + s.requestHandlerLock.Unlock() + if ok { + ln.Close() + if err := os.Remove(addr); err != nil && !errors.Is(err, os.ErrNotExist) { + s.log("[ERROR] %+v", errors.WithStack(err)) + } + } + return true, nil + + default: + return false, nil + } +} + +func (s *Server) getSocketPath(sessionID SessionID) string { + return filepath.Join(s.opts.SockDir, fmt.Sprintf("%s.sock", sessionID)) +} + +func generateToken(length int) (string, error) { + chars := []rune( + "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + + "abcdefghijklmnopqrstuvwxyz" + + "0123456789", + ) + var b strings.Builder + + charsLength := big.NewInt(int64(len(chars))) + + for i := 0; i < length; i++ { + idx, err := rand.Int(rand.Reader, charsLength) + if err != nil { + return "", errors.WithStack(err) + } + + c := chars[idx.Int64()] + if _, err := b.WriteRune(c); err != nil { + return "", errors.WithStack(err) + } + } + + return b.String(), nil +} diff --git a/server.go b/server.go new file mode 100644 index 0000000..bdbea7b --- /dev/null +++ b/server.go @@ -0,0 +1,141 @@ +package rebound + +import ( + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "io/fs" + "net" + "os" + "sync" + "time" + + "github.com/gliderlabs/ssh" + "github.com/pkg/errors" + + gossh "golang.org/x/crypto/ssh" +) + +type Server struct { + ssh *ssh.Server + opts *Options + + sessionManager *SessionManager + forwards map[string]net.Listener + requestHandlerLock sync.Mutex +} + +func (s *Server) Run() error { + if err := s.initSSHServer(); err != nil { + return errors.WithStack(err) + } + + s.log("listening on %s", s.opts.Address) + + if err := s.ssh.ListenAndServe(); err != nil { + return errors.WithStack(err) + } + + return nil +} + +func (s *Server) Stop() error { + if s.ssh == nil { + return nil + } + + s.log("stopping on '%s'", s.opts.Address) + + if err := s.ssh.Close(); err != nil { + return errors.WithStack(err) + } + + return nil +} + +func (s *Server) initSSHServer() error { + signer, err := s.loadOrCreateSigner() + if err != nil { + return errors.WithStack(err) + } + + server := &ssh.Server{ + Addr: s.opts.Address, + HostSigners: []ssh.Signer{signer}, + Handler: ssh.Handler(s.handleSession), + RequestHandlers: map[string]ssh.RequestHandler{ + "tcpip-forward": s.handleRequest, + "cancel-tcpip-forward": s.handleRequest, + }, + ChannelHandlers: map[string]ssh.ChannelHandler{ + "direct-tcpip": s.handleDirectTCP, + "session": ssh.DefaultSessionHandler, + }, + } + + s.ssh = server + + return nil +} + +func (s *Server) loadOrCreateSigner() (ssh.Signer, error) { + var ( + signer gossh.Signer + ) + + s.log("reading host key from '%s'", s.opts.HostKey) + + data, err := os.ReadFile(s.opts.HostKey) + if err != nil && !errors.Is(err, os.ErrNotExist) { + return nil, errors.WithStack(err) + } + + if data == nil { + s.log("host key cannot be found, generating one") + key, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + return nil, errors.WithStack(err) + } + + signer, err = gossh.NewSignerFromKey(key) + if err != nil { + return nil, errors.WithStack(err) + } + + pem := pem.EncodeToMemory( + &pem.Block{ + Type: "RSA PRIVATE KEY", + Bytes: x509.MarshalPKCS1PrivateKey(key), + }, + ) + + s.log("saving host key to '%s'", s.opts.HostKey) + if err := os.WriteFile(s.opts.HostKey, pem, fs.FileMode(0640)); err != nil { + return nil, errors.WithStack(err) + } + } else { + signer, err = gossh.ParsePrivateKey(data) + if err != nil { + return nil, errors.WithStack(err) + } + } + + return signer, nil +} + +func (s *Server) log(message string, args ...any) { + s.opts.Logger(message, args...) +} + +func NewServer(funcs ...OptionFunc) *Server { + opts := DefaultOptions() + for _, fn := range funcs { + fn(opts) + } + + return &Server{ + opts: opts, + sessionManager: NewSessionManager(30 * time.Second), + } +} diff --git a/session_handler.go b/session_handler.go new file mode 100644 index 0000000..c9a31ab --- /dev/null +++ b/session_handler.go @@ -0,0 +1,133 @@ +package rebound + +import ( + "fmt" + "io" + + "github.com/gliderlabs/ssh" + "github.com/pkg/errors" +) + +func (s *Server) handleSession(sess ssh.Session) { + ctx := sess.Context() + sessionID := SessionID(ctx.SessionID()) + + s.log("(%s): session opened", sessionID) + + message := ` +Welcome on Rebound ! + +Type Ctrl+C or Ctrl+D to exit. +` + + if _, err := sess.Write([]byte(message)); err != nil { + s.log("[ERROR] %+v", errors.WithStack(err)) + } + + go s.readClientInput(sess) + + data := s.sessionManager.Get(sessionID, SessionData{ + Type: TypeServiceUnknown, + }) + s.handleSessionData(sess, data) + + onUpdate, close := s.sessionManager.OnUpdate(sessionID) + defer close() + + for { + data, opened := <-onUpdate + if !opened { + return + } + + s.log("(%s): session data updated: %v", sessionID, data) + + if err := s.handleSessionData(sess, data); err != nil { + if errors.Is(err, io.EOF) { + return + } + + s.log("[ERROR] %+v", errors.WithStack(err)) + + return + } + } +} + +func (s *Server) handleSessionData(sess ssh.Session, data SessionData) error { + switch data.Type { + case TypeServiceConsumer: + if err := s.writeConsumerMessage(sess, data); err != nil { + return errors.WithStack(err) + } + + case TypeServiceProvider: + if err := s.writeProviderMessage(sess, data); err != nil { + return errors.WithStack(err) + } + } + + return nil +} + +func (s *Server) writeConsumerMessage(sess ssh.Session, data SessionData) error { + message := `` + + if _, err := sess.Write([]byte(message)); err != nil { + return errors.WithStack(err) + } + + return nil +} + +func (s *Server) writeProviderMessage(sess ssh.Session, data SessionData) error { + message := fmt.Sprintf(` +You can connect to your tunnel by running in an other terminal: + +ssh -L :0.0.0.0:1 %s@%s -p %d + +`, data.Token, s.opts.PublicHost, s.opts.PublicPort) + + if _, err := sess.Write([]byte(message)); err != nil { + return errors.WithStack(err) + } + + return nil +} + +const ( + CtrlC = 3 + CtrlD = 4 +) + +func (s *Server) readClientInput(sess ssh.Session) { + sessionID := SessionID(sess.Context().SessionID()) + defer func() { + s.sessionManager.Remove(sessionID) + }() + + buff := make([]byte, 1) + + for { + _, err := sess.Read(buff) + if err != nil { + if !errors.Is(err, io.EOF) { + s.log("[ERROR] %+v", errors.WithStack(err)) + } + + return + } + + switch buff[0] { + case CtrlC: + fallthrough + case CtrlD: + sess.Exit(0) + return + + default: + s.log("(%s) user input: %v", sessionID, buff) + } + + } +} diff --git a/session_manager.go b/session_manager.go new file mode 100644 index 0000000..b9a55c5 --- /dev/null +++ b/session_manager.go @@ -0,0 +1,184 @@ +package rebound + +import ( + "errors" + "log" + "sync" + "time" +) + +type SessionID string + +type SessionType int + +const ( + TypeServiceUnknown SessionType = iota + TypeServiceProvider + TypeServiceConsumer +) + +type SessionData struct { + Type SessionType + Token string +} + +type SessionManager struct { + sessions map[SessionID]SessionData + sessionsMutex sync.Mutex + + tokenIndex map[string]SessionID + + updates map[SessionID][]chan SessionData + updatesMutex sync.Mutex + + updateReadTimeout time.Duration +} + +func (m *SessionManager) Get(sessID SessionID, defaultValue SessionData) SessionData { + log.Println("reading session", sessID) + + m.sessionsMutex.Lock() + defer m.sessionsMutex.Unlock() + + session, exists := m.sessions[sessID] + if !exists { + session = defaultValue + m.sessions[sessID] = session + + m.updatesMutex.Lock() + m.dispatchUpdate(sessID, session) + m.updatesMutex.Unlock() + } + + return session +} + +func (m *SessionManager) FindByToken(token string) SessionID { + m.sessionsMutex.Lock() + defer m.sessionsMutex.Unlock() + + sessID, exists := m.tokenIndex[token] + if !exists { + return "" + } + + return sessID +} + +func (m *SessionManager) Set(sessID SessionID, sess SessionData) { + log.Println("updating session", sessID, sess) + + m.sessionsMutex.Lock() + oldSess, ok := m.sessions[sessID] + if ok { + m.updateTokenIndex(sessID, sess.Token, oldSess.Token) + } else { + m.updateTokenIndex(sessID, sess.Token, "") + } + m.sessions[sessID] = sess + m.sessionsMutex.Unlock() + + m.updatesMutex.Lock() + m.dispatchUpdate(sessID, sess) + m.updatesMutex.Unlock() +} + +func (m *SessionManager) Remove(sessID SessionID) { + m.sessionsMutex.Lock() + oldSess, ok := m.sessions[sessID] + if ok { + m.updateTokenIndex(sessID, "", oldSess.Token) + } + delete(m.sessions, sessID) + m.sessionsMutex.Unlock() + + m.updatesMutex.Lock() + m.closeAllUpdates(sessID) + m.updatesMutex.Unlock() +} + +func (m *SessionManager) OnUpdate(sessID SessionID) (<-chan SessionData, func()) { + update := make(chan SessionData) + + close := func() { + m.updatesMutex.Lock() + m.closeUpdate(sessID, update) + m.updatesMutex.Unlock() + } + + m.updatesMutex.Lock() + defer m.updatesMutex.Unlock() + + channels, exists := m.updates[sessID] + if !exists { + channels = make([]chan SessionData, 0, 1) + } + + channels = append(channels, update) + m.updates[sessID] = channels + + return update, close +} + +func (m *SessionManager) closeAllUpdates(sessID SessionID) { + channels, exists := m.updates[sessID] + if !exists { + return + } + + for _, ch := range channels { + m.closeUpdate(sessID, ch) + } +} + +func (m *SessionManager) closeUpdate(sessID SessionID, update chan SessionData) { + channels, exists := m.updates[sessID] + if !exists { + return + } + + for idx, ch := range channels { + if ch != update { + continue + } + + close(ch) + m.updates[sessID] = append(channels[:idx], channels[idx+1:]...) + } +} + +func (m *SessionManager) dispatchUpdate(sessID SessionID, sess SessionData) { + channels, exists := m.updates[sessID] + if !exists { + return + } + + for _, ch := range channels { + timeout := time.After(m.updateReadTimeout) + select { + case ch <- sess: + case <-timeout: + err := errors.New("session update read timed out") + log.Printf("[ERROR] %+v", err) + } + } +} + +func (m *SessionManager) updateTokenIndex(sessID SessionID, addedToken, deletedToken string) { + if addedToken != "" { + m.tokenIndex[addedToken] = sessID + } + + if deletedToken != "" { + delete(m.tokenIndex, deletedToken) + } +} + +func NewSessionManager(updateReadTimeout time.Duration) *SessionManager { + return &SessionManager{ + sessions: make(map[SessionID]SessionData), + tokenIndex: make(map[string]SessionID), + updates: make(map[SessionID][]chan SessionData), + updateReadTimeout: updateReadTimeout, + } +}