feat: initial commit

This commit is contained in:
wpetit 2023-09-08 20:00:00 -06:00
commit 1bdad36787
16 changed files with 1038 additions and 0 deletions

5
.dockerignore Normal file
View File

@ -0,0 +1,5 @@
/bin
/.mktools
/tools
/.env
/socks

0
.env.dist Normal file
View File

6
.gitignore vendored Normal file
View File

@ -0,0 +1,6 @@
/tools
/.mktools
/bin
/.env
/socks
/host.key

27
Dockerfile Normal file
View File

@ -0,0 +1,27 @@
FROM golang:1.21 AS build
RUN apt-get update && apt-get install -y build-essential git bash curl ca-certificates
COPY . /src
WORKDIR /src
RUN make build
FROM busybox
COPY --from=build /src/bin /app
WORKDIR /app
RUN mkdir -p /app/socks /app/keys
EXPOSE 2222
ENV REBOUND_PUBLIC_HOST=127.0.0.1
ENV REBOUND_PUBLIC_PORT=2222
ENV REBOUND_HOST_KEY=/app/keys/host.key
ENV REBOUND_ADDRESS=:2222
ENV REBOUND_SOCK_DIR=/app/socks
CMD ["/app/server"]

49
Makefile Normal file
View File

@ -0,0 +1,49 @@
SHELL := /bin/bash
DOKKU_URL := dokku@dev.lookingfora.name:rebound
all: build
watch: tools/modd/bin/modd
tools/modd/bin/modd
socks:
mkdir -p ./socks
run: .env socks
( set -o allexport && source .env && set +o allexport && bin/server )
build: .env
CGO_ENABLED=0 go build -o ./bin/server ./cmd/server
.env:
cp .env.dist .env
dokku-build:
docker build \
-t rebound-dokku:latest \
.
dokku-run:
docker run \
-it --rm \
-p 2222:2222 \
--tmpfs /socks \
rebound-dokku:latest
dokku-deploy:
$(if $(shell git config remote.dokku.url),, git remote add dokku $(DOKKU_URL))
git push -f dokku $(shell git rev-parse HEAD):refs/heads/master
.PHONY: mktools
mktools:
rm -rf .mktools
curl -q https://forge.cadoles.com/Cadoles/mktools/raw/branch/master/install.sh | $(SHELL)
tools/modd/bin/modd:
mkdir -p tools/modd/bin
GOBIN=$(PWD)/tools/modd/bin go install github.com/cortesi/modd/cmd/modd@v0.8.1
.mktools:
$(MAKE) mktools
-include .mktools/*.mk

19
README.md Normal file
View File

@ -0,0 +1,19 @@
# Rebound
Serveur utilisant le protocole SSH pour créer des tunnels TCP/IP.
## Usage
### Ouvrir un tunnel sur un port local
```shell
ssh -R 0:<target_address>:<target_port> root@rebound.lookingfora.name -p 2222
```
### Se connecter à un tunnel ouvert
```
ssh -L <local_port>:0.0.0.0:1 <secret>@rebound.lookingfora.name -p 2222
```
Vous pourrez ensuite accéder au service sur `127.0.0.1:<local_port>` comme si c'était `<target_address>:<target_port>`.

50
cmd/server/main.go Normal file
View File

@ -0,0 +1,50 @@
package main
import (
"flag"
"log"
"os"
"os/signal"
"forge.cadoles.com/wpetit/rebound"
"github.com/caarlos0/env/v6"
"github.com/pkg/errors"
)
func main() {
opts := rebound.DefaultOptions()
if err := env.Parse(opts); err != nil {
log.Fatalf("[ERROR] %+v", errors.WithStack(err))
}
address := flag.String("address", opts.Address, "server listening address")
sockDir := flag.String("sock-dir", opts.SockDir, "sock directory")
publicPort := flag.Uint("public-port", opts.PublicPort, "public port")
publicHost := flag.String("public-host", opts.PublicHost, "public host")
hostKey := flag.String("host-key", opts.HostKey, "host key")
flag.Parse()
server := rebound.NewServer(
rebound.WithAddress(*address),
rebound.WithSockDir(*sockDir),
rebound.WithPublicPort(*publicPort),
rebound.WithPublicHost(*publicHost),
rebound.WithHostKey(*hostKey),
)
go func() {
if err := server.Run(); err != nil {
log.Fatalf("[FATAL] %+v", errors.WithStack(err))
}
}()
c := make(chan os.Signal, 1)
signal.Notify(c, os.Interrupt)
<-c
if err := server.Stop(); err != nil {
log.Fatalf("[FATAL] %+v", errors.WithStack(err))
}
}

97
direct_tcp_handler.go Normal file
View File

@ -0,0 +1,97 @@
package rebound
import (
"fmt"
"io"
"net"
"os"
"github.com/gliderlabs/ssh"
"github.com/pkg/errors"
gossh "golang.org/x/crypto/ssh"
)
// direct-tcpip data struct as specified in RFC4254, Section 7.2
type localForwardChannelData struct {
DestAddr string
DestPort uint32
OriginAddr string
OriginPort uint32
}
func (s *Server) handleDirectTCP(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
d := localForwardChannelData{}
if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil {
s.log("[ERROR] %+v", errors.WithStack(err))
newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error())
return
}
sessionID := SessionID(ctx.SessionID())
s.log("(%s): opening direct tcp tunnel", sessionID)
s.sessionManager.Set(sessionID, SessionData{
Type: TypeServiceConsumer,
})
sessionToken := ctx.User()
remoteSessionID := s.sessionManager.FindByToken(sessionToken)
if remoteSessionID == "" {
newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("could not find session associated with token '%s'", sessionToken))
conn.Close()
return
}
addr := s.getSocketPath(remoteSessionID)
s.log("(%s): using sock '%s'", sessionID, addr)
if _, err := os.Stat(addr); err != nil {
newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("could not find session associated with token '%s'", sessionToken))
conn.Close()
return
}
var dialer net.Dialer
dconn, err := dialer.DialContext(ctx, "unix", addr)
if err != nil {
s.log("[ERROR] %+v", errors.WithStack(err))
newChan.Reject(gossh.ConnectionFailed, err.Error())
return
}
ch, reqs, err := newChan.Accept()
if err != nil {
s.log("[ERROR] %+v", errors.WithStack(err))
dconn.Close()
return
}
go gossh.DiscardRequests(reqs)
go func() {
defer dconn.Close()
defer ch.Close()
if _, err := io.Copy(ch, dconn); err != nil {
if errors.Is(err, net.ErrClosed) {
return
}
s.log("[ERROR] %+v", errors.WithStack(err))
}
}()
go func() {
defer dconn.Close()
defer ch.Close()
if _, err := io.Copy(dconn, ch); err != nil {
s.log("[ERROR] %+v", errors.WithStack(err))
}
}()
}

15
go.mod Normal file
View File

@ -0,0 +1,15 @@
module forge.cadoles.com/wpetit/rebound
go 1.21.0
require (
github.com/caarlos0/env/v6 v6.10.1
github.com/gliderlabs/ssh v0.3.5
github.com/pkg/errors v0.9.1
golang.org/x/crypto v0.0.0-20220826181053-bd7e27e6170d
)
require (
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
golang.org/x/sys v0.0.0-20220825204002-c680a09ffe64 // indirect
)

25
go.sum Normal file
View File

@ -0,0 +1,25 @@
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
github.com/caarlos0/env/v6 v6.10.1 h1:t1mPSxNpei6M5yAeu1qtRdPAK29Nbcf/n3G7x+b3/II=
github.com/caarlos0/env/v6 v6.10.1/go.mod h1:hvp/ryKXKipEkcuYjs9mI4bBCg+UI0Yhgm5Zu0ddvwc=
github.com/gliderlabs/ssh v0.3.5 h1:OcaySEmAQJgyYcArR+gGGTHCyE7nvhEMTlYY+Dp8CpY=
github.com/gliderlabs/ssh v0.3.5/go.mod h1:8XB4KraRrX39qHhT6yxPsHedjA08I/uBVwj4xC+/+z4=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
golang.org/x/crypto v0.0.0-20220826181053-bd7e27e6170d h1:3qF+Z8Hkrw9sOhrFHti9TlB1Hkac1x+DNRkv0XQiFjo=
golang.org/x/crypto v0.0.0-20220826181053-bd7e27e6170d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.0.0-20220825204002-c680a09ffe64 h1:UiNENfZ8gDvpiWw7IpOMQ27spWmThO1RwwdQVbJahJM=
golang.org/x/sys v0.0.0-20220825204002-c680a09ffe64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.0.0-20220722155259-a9ba230a4035 h1:Q5284mrmYTpACcm+eAKjKJH48BBwSyfJqmmGDTtT8Vc=
golang.org/x/term v0.0.0-20220722155259-a9ba230a4035/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=

6
modd.conf Normal file
View File

@ -0,0 +1,6 @@
**/*.go
Makefile {
prep: make build
daemon: make run
daemon: npx http-server .
}

55
option.go Normal file
View File

@ -0,0 +1,55 @@
package rebound
import "log"
type Options struct {
Address string `env:"REBOUND_ADDRESS"`
Logger func(message string, args ...any)
SockDir string `env:"REBOUND_SOCK_DIR"`
PublicPort uint `env:"REBOUND_PUBLIC_PORT"`
PublicHost string `env:"REBOUND_PUBLIC_HOST"`
HostKey string `env:"REBOUND_HOST_KEY"`
}
type OptionFunc func(*Options)
func DefaultOptions() *Options {
return &Options{
Address: "127.0.0.1:2222",
SockDir: "./socks",
Logger: log.Printf,
PublicPort: 2222,
PublicHost: "127.0.0.1",
HostKey: "./host.key",
}
}
func WithAddress(addr string) func(*Options) {
return func(opts *Options) {
opts.Address = addr
}
}
func WithSockDir(addr string) func(*Options) {
return func(opts *Options) {
opts.SockDir = addr
}
}
func WithPublicHost(host string) func(*Options) {
return func(opts *Options) {
opts.PublicHost = host
}
}
func WithPublicPort(port uint) func(*Options) {
return func(opts *Options) {
opts.PublicPort = port
}
}
func WithHostKey(key string) func(*Options) {
return func(opts *Options) {
opts.HostKey = key
}
}

226
request_handler.go Normal file
View File

@ -0,0 +1,226 @@
package rebound
import (
"crypto/rand"
"fmt"
"io"
"log"
"math/big"
"net"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/gliderlabs/ssh"
"github.com/pkg/errors"
gossh "golang.org/x/crypto/ssh"
)
const (
forwardedTCPChannelType = "forwarded-tcpip"
)
// direct-tcpip data struct as specified in RFC4254, Section 7.2
type remoteForwardRequest struct {
BindAddr string
BindPort uint32
}
type remoteForwardSuccess struct {
BindPort uint32
}
type remoteForwardCancelRequest struct {
BindAddr string
BindPort uint32
}
type remoteForwardChannelData struct {
DestAddr string
DestPort uint32
OriginAddr string
OriginPort uint32
}
func (s *Server) handleRequest(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) {
s.requestHandlerLock.Lock()
if s.forwards == nil {
s.forwards = make(map[string]net.Listener)
}
s.requestHandlerLock.Unlock()
conn := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
switch req.Type {
case "tcpip-forward":
var reqPayload remoteForwardRequest
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
log.Printf("[ERROR] %+v", errors.WithStack(err))
return false, []byte{}
}
if reqPayload.BindPort != 0 {
return false, []byte("bind port must 0")
}
s.log("(%s): opening reverse tunnel", ctx.SessionID())
token, err := generateToken(16)
if err != nil {
s.log("[ERROR] %+v", errors.WithStack(err))
return false, []byte("could not generate secret token")
}
sessionID := SessionID(ctx.SessionID())
s.sessionManager.Set(sessionID, SessionData{
Type: TypeServiceProvider,
Token: token,
})
addr := s.getSocketPath(sessionID)
ln, err := net.Listen("unix", addr)
if err != nil {
s.log("[ERROR] %+v", errors.WithStack(err))
return false, []byte{}
}
destPort := 1
s.requestHandlerLock.Lock()
s.forwards[addr] = ln
s.requestHandlerLock.Unlock()
cleanup := func() {
s.log("(%s): cleaning up session", sessionID)
s.sessionManager.Remove(sessionID)
if err := os.Remove(addr); err != nil && !errors.Is(err, os.ErrNotExist) {
s.log("[ERROR] %+v", errors.WithStack(err))
}
}
go func() {
defer cleanup()
<-ctx.Done()
s.requestHandlerLock.Lock()
ln, ok := s.forwards[addr]
s.requestHandlerLock.Unlock()
if ok {
ln.Close()
}
}()
go func() {
for {
c, err := ln.Accept()
if err != nil {
if !errors.Is(err, net.ErrClosed) {
s.log("[ERROR] %+v", errors.WithStack(err))
}
break
}
originAddr, orignPortStr, _ := net.SplitHostPort(c.RemoteAddr().String())
originPort, _ := strconv.Atoi(orignPortStr)
payload := gossh.Marshal(&remoteForwardChannelData{
DestAddr: reqPayload.BindAddr,
DestPort: uint32(destPort),
OriginAddr: originAddr,
OriginPort: uint32(originPort),
})
go func() {
ch, reqs, err := conn.OpenChannel(forwardedTCPChannelType, payload)
if err != nil {
s.log("[ERROR] %+v", errors.WithStack(err))
c.Close()
return
}
go gossh.DiscardRequests(reqs)
go func() {
defer ch.Close()
defer c.Close()
io.Copy(ch, c)
}()
go func() {
defer ch.Close()
defer c.Close()
io.Copy(c, ch)
}()
}()
}
s.requestHandlerLock.Lock()
delete(s.forwards, addr)
s.requestHandlerLock.Unlock()
}()
return true, gossh.Marshal(&remoteForwardSuccess{uint32(destPort)})
case "cancel-tcpip-forward":
var reqPayload remoteForwardCancelRequest
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
s.log("[ERROR] %+v", errors.WithStack(err))
return false, []byte{}
}
sessionID := SessionID(ctx.SessionID())
addr := s.getSocketPath(sessionID)
s.log("(%s): closing sock '%s'", sessionID, addr)
s.requestHandlerLock.Lock()
ln, ok := s.forwards[addr]
s.requestHandlerLock.Unlock()
if ok {
ln.Close()
if err := os.Remove(addr); err != nil && !errors.Is(err, os.ErrNotExist) {
s.log("[ERROR] %+v", errors.WithStack(err))
}
}
return true, nil
default:
return false, nil
}
}
func (s *Server) getSocketPath(sessionID SessionID) string {
return filepath.Join(s.opts.SockDir, fmt.Sprintf("%s.sock", sessionID))
}
func generateToken(length int) (string, error) {
chars := []rune(
"ABCDEFGHIJKLMNOPQRSTUVWXYZ" +
"abcdefghijklmnopqrstuvwxyz" +
"0123456789",
)
var b strings.Builder
charsLength := big.NewInt(int64(len(chars)))
for i := 0; i < length; i++ {
idx, err := rand.Int(rand.Reader, charsLength)
if err != nil {
return "", errors.WithStack(err)
}
c := chars[idx.Int64()]
if _, err := b.WriteRune(c); err != nil {
return "", errors.WithStack(err)
}
}
return b.String(), nil
}

141
server.go Normal file
View File

@ -0,0 +1,141 @@
package rebound
import (
"crypto/rand"
"crypto/rsa"
"crypto/x509"
"encoding/pem"
"io/fs"
"net"
"os"
"sync"
"time"
"github.com/gliderlabs/ssh"
"github.com/pkg/errors"
gossh "golang.org/x/crypto/ssh"
)
type Server struct {
ssh *ssh.Server
opts *Options
sessionManager *SessionManager
forwards map[string]net.Listener
requestHandlerLock sync.Mutex
}
func (s *Server) Run() error {
if err := s.initSSHServer(); err != nil {
return errors.WithStack(err)
}
s.log("listening on %s", s.opts.Address)
if err := s.ssh.ListenAndServe(); err != nil {
return errors.WithStack(err)
}
return nil
}
func (s *Server) Stop() error {
if s.ssh == nil {
return nil
}
s.log("stopping on '%s'", s.opts.Address)
if err := s.ssh.Close(); err != nil {
return errors.WithStack(err)
}
return nil
}
func (s *Server) initSSHServer() error {
signer, err := s.loadOrCreateSigner()
if err != nil {
return errors.WithStack(err)
}
server := &ssh.Server{
Addr: s.opts.Address,
HostSigners: []ssh.Signer{signer},
Handler: ssh.Handler(s.handleSession),
RequestHandlers: map[string]ssh.RequestHandler{
"tcpip-forward": s.handleRequest,
"cancel-tcpip-forward": s.handleRequest,
},
ChannelHandlers: map[string]ssh.ChannelHandler{
"direct-tcpip": s.handleDirectTCP,
"session": ssh.DefaultSessionHandler,
},
}
s.ssh = server
return nil
}
func (s *Server) loadOrCreateSigner() (ssh.Signer, error) {
var (
signer gossh.Signer
)
s.log("reading host key from '%s'", s.opts.HostKey)
data, err := os.ReadFile(s.opts.HostKey)
if err != nil && !errors.Is(err, os.ErrNotExist) {
return nil, errors.WithStack(err)
}
if data == nil {
s.log("host key cannot be found, generating one")
key, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
return nil, errors.WithStack(err)
}
signer, err = gossh.NewSignerFromKey(key)
if err != nil {
return nil, errors.WithStack(err)
}
pem := pem.EncodeToMemory(
&pem.Block{
Type: "RSA PRIVATE KEY",
Bytes: x509.MarshalPKCS1PrivateKey(key),
},
)
s.log("saving host key to '%s'", s.opts.HostKey)
if err := os.WriteFile(s.opts.HostKey, pem, fs.FileMode(0640)); err != nil {
return nil, errors.WithStack(err)
}
} else {
signer, err = gossh.ParsePrivateKey(data)
if err != nil {
return nil, errors.WithStack(err)
}
}
return signer, nil
}
func (s *Server) log(message string, args ...any) {
s.opts.Logger(message, args...)
}
func NewServer(funcs ...OptionFunc) *Server {
opts := DefaultOptions()
for _, fn := range funcs {
fn(opts)
}
return &Server{
opts: opts,
sessionManager: NewSessionManager(30 * time.Second),
}
}

133
session_handler.go Normal file
View File

@ -0,0 +1,133 @@
package rebound
import (
"fmt"
"io"
"github.com/gliderlabs/ssh"
"github.com/pkg/errors"
)
func (s *Server) handleSession(sess ssh.Session) {
ctx := sess.Context()
sessionID := SessionID(ctx.SessionID())
s.log("(%s): session opened", sessionID)
message := `
Welcome on Rebound !
Type Ctrl+C or Ctrl+D to exit.
`
if _, err := sess.Write([]byte(message)); err != nil {
s.log("[ERROR] %+v", errors.WithStack(err))
}
go s.readClientInput(sess)
data := s.sessionManager.Get(sessionID, SessionData{
Type: TypeServiceUnknown,
})
s.handleSessionData(sess, data)
onUpdate, close := s.sessionManager.OnUpdate(sessionID)
defer close()
for {
data, opened := <-onUpdate
if !opened {
return
}
s.log("(%s): session data updated: %v", sessionID, data)
if err := s.handleSessionData(sess, data); err != nil {
if errors.Is(err, io.EOF) {
return
}
s.log("[ERROR] %+v", errors.WithStack(err))
return
}
}
}
func (s *Server) handleSessionData(sess ssh.Session, data SessionData) error {
switch data.Type {
case TypeServiceConsumer:
if err := s.writeConsumerMessage(sess, data); err != nil {
return errors.WithStack(err)
}
case TypeServiceProvider:
if err := s.writeProviderMessage(sess, data); err != nil {
return errors.WithStack(err)
}
}
return nil
}
func (s *Server) writeConsumerMessage(sess ssh.Session, data SessionData) error {
message := ``
if _, err := sess.Write([]byte(message)); err != nil {
return errors.WithStack(err)
}
return nil
}
func (s *Server) writeProviderMessage(sess ssh.Session, data SessionData) error {
message := fmt.Sprintf(`
You can connect to your tunnel by running in an other terminal:
ssh -L <local-port>:0.0.0.0:1 %s@%s -p %d
`, data.Token, s.opts.PublicHost, s.opts.PublicPort)
if _, err := sess.Write([]byte(message)); err != nil {
return errors.WithStack(err)
}
return nil
}
const (
CtrlC = 3
CtrlD = 4
)
func (s *Server) readClientInput(sess ssh.Session) {
sessionID := SessionID(sess.Context().SessionID())
defer func() {
s.sessionManager.Remove(sessionID)
}()
buff := make([]byte, 1)
for {
_, err := sess.Read(buff)
if err != nil {
if !errors.Is(err, io.EOF) {
s.log("[ERROR] %+v", errors.WithStack(err))
}
return
}
switch buff[0] {
case CtrlC:
fallthrough
case CtrlD:
sess.Exit(0)
return
default:
s.log("(%s) user input: %v", sessionID, buff)
}
}
}

184
session_manager.go Normal file
View File

@ -0,0 +1,184 @@
package rebound
import (
"errors"
"log"
"sync"
"time"
)
type SessionID string
type SessionType int
const (
TypeServiceUnknown SessionType = iota
TypeServiceProvider
TypeServiceConsumer
)
type SessionData struct {
Type SessionType
Token string
}
type SessionManager struct {
sessions map[SessionID]SessionData
sessionsMutex sync.Mutex
tokenIndex map[string]SessionID
updates map[SessionID][]chan SessionData
updatesMutex sync.Mutex
updateReadTimeout time.Duration
}
func (m *SessionManager) Get(sessID SessionID, defaultValue SessionData) SessionData {
log.Println("reading session", sessID)
m.sessionsMutex.Lock()
defer m.sessionsMutex.Unlock()
session, exists := m.sessions[sessID]
if !exists {
session = defaultValue
m.sessions[sessID] = session
m.updatesMutex.Lock()
m.dispatchUpdate(sessID, session)
m.updatesMutex.Unlock()
}
return session
}
func (m *SessionManager) FindByToken(token string) SessionID {
m.sessionsMutex.Lock()
defer m.sessionsMutex.Unlock()
sessID, exists := m.tokenIndex[token]
if !exists {
return ""
}
return sessID
}
func (m *SessionManager) Set(sessID SessionID, sess SessionData) {
log.Println("updating session", sessID, sess)
m.sessionsMutex.Lock()
oldSess, ok := m.sessions[sessID]
if ok {
m.updateTokenIndex(sessID, sess.Token, oldSess.Token)
} else {
m.updateTokenIndex(sessID, sess.Token, "")
}
m.sessions[sessID] = sess
m.sessionsMutex.Unlock()
m.updatesMutex.Lock()
m.dispatchUpdate(sessID, sess)
m.updatesMutex.Unlock()
}
func (m *SessionManager) Remove(sessID SessionID) {
m.sessionsMutex.Lock()
oldSess, ok := m.sessions[sessID]
if ok {
m.updateTokenIndex(sessID, "", oldSess.Token)
}
delete(m.sessions, sessID)
m.sessionsMutex.Unlock()
m.updatesMutex.Lock()
m.closeAllUpdates(sessID)
m.updatesMutex.Unlock()
}
func (m *SessionManager) OnUpdate(sessID SessionID) (<-chan SessionData, func()) {
update := make(chan SessionData)
close := func() {
m.updatesMutex.Lock()
m.closeUpdate(sessID, update)
m.updatesMutex.Unlock()
}
m.updatesMutex.Lock()
defer m.updatesMutex.Unlock()
channels, exists := m.updates[sessID]
if !exists {
channels = make([]chan SessionData, 0, 1)
}
channels = append(channels, update)
m.updates[sessID] = channels
return update, close
}
func (m *SessionManager) closeAllUpdates(sessID SessionID) {
channels, exists := m.updates[sessID]
if !exists {
return
}
for _, ch := range channels {
m.closeUpdate(sessID, ch)
}
}
func (m *SessionManager) closeUpdate(sessID SessionID, update chan SessionData) {
channels, exists := m.updates[sessID]
if !exists {
return
}
for idx, ch := range channels {
if ch != update {
continue
}
close(ch)
m.updates[sessID] = append(channels[:idx], channels[idx+1:]...)
}
}
func (m *SessionManager) dispatchUpdate(sessID SessionID, sess SessionData) {
channels, exists := m.updates[sessID]
if !exists {
return
}
for _, ch := range channels {
timeout := time.After(m.updateReadTimeout)
select {
case ch <- sess:
case <-timeout:
err := errors.New("session update read timed out")
log.Printf("[ERROR] %+v", err)
}
}
}
func (m *SessionManager) updateTokenIndex(sessID SessionID, addedToken, deletedToken string) {
if addedToken != "" {
m.tokenIndex[addedToken] = sessID
}
if deletedToken != "" {
delete(m.tokenIndex, deletedToken)
}
}
func NewSessionManager(updateReadTimeout time.Duration) *SessionManager {
return &SessionManager{
sessions: make(map[SessionID]SessionData),
tokenIndex: make(map[string]SessionID),
updates: make(map[SessionID][]chan SessionData),
updateReadTimeout: updateReadTimeout,
}
}