feat: initial commit
This commit is contained in:
commit
1bdad36787
|
@ -0,0 +1,5 @@
|
|||
/bin
|
||||
/.mktools
|
||||
/tools
|
||||
/.env
|
||||
/socks
|
|
@ -0,0 +1,6 @@
|
|||
/tools
|
||||
/.mktools
|
||||
/bin
|
||||
/.env
|
||||
/socks
|
||||
/host.key
|
|
@ -0,0 +1,27 @@
|
|||
FROM golang:1.21 AS build
|
||||
|
||||
RUN apt-get update && apt-get install -y build-essential git bash curl ca-certificates
|
||||
|
||||
COPY . /src
|
||||
|
||||
WORKDIR /src
|
||||
|
||||
RUN make build
|
||||
|
||||
FROM busybox
|
||||
|
||||
COPY --from=build /src/bin /app
|
||||
|
||||
WORKDIR /app
|
||||
|
||||
RUN mkdir -p /app/socks /app/keys
|
||||
|
||||
EXPOSE 2222
|
||||
|
||||
ENV REBOUND_PUBLIC_HOST=127.0.0.1
|
||||
ENV REBOUND_PUBLIC_PORT=2222
|
||||
ENV REBOUND_HOST_KEY=/app/keys/host.key
|
||||
ENV REBOUND_ADDRESS=:2222
|
||||
ENV REBOUND_SOCK_DIR=/app/socks
|
||||
|
||||
CMD ["/app/server"]
|
|
@ -0,0 +1,49 @@
|
|||
SHELL := /bin/bash
|
||||
DOKKU_URL := dokku@dev.lookingfora.name:rebound
|
||||
|
||||
all: build
|
||||
|
||||
watch: tools/modd/bin/modd
|
||||
tools/modd/bin/modd
|
||||
|
||||
socks:
|
||||
mkdir -p ./socks
|
||||
|
||||
run: .env socks
|
||||
( set -o allexport && source .env && set +o allexport && bin/server )
|
||||
|
||||
build: .env
|
||||
CGO_ENABLED=0 go build -o ./bin/server ./cmd/server
|
||||
|
||||
.env:
|
||||
cp .env.dist .env
|
||||
|
||||
dokku-build:
|
||||
docker build \
|
||||
-t rebound-dokku:latest \
|
||||
.
|
||||
|
||||
dokku-run:
|
||||
docker run \
|
||||
-it --rm \
|
||||
-p 2222:2222 \
|
||||
--tmpfs /socks \
|
||||
rebound-dokku:latest
|
||||
|
||||
dokku-deploy:
|
||||
$(if $(shell git config remote.dokku.url),, git remote add dokku $(DOKKU_URL))
|
||||
git push -f dokku $(shell git rev-parse HEAD):refs/heads/master
|
||||
|
||||
.PHONY: mktools
|
||||
mktools:
|
||||
rm -rf .mktools
|
||||
curl -q https://forge.cadoles.com/Cadoles/mktools/raw/branch/master/install.sh | $(SHELL)
|
||||
|
||||
tools/modd/bin/modd:
|
||||
mkdir -p tools/modd/bin
|
||||
GOBIN=$(PWD)/tools/modd/bin go install github.com/cortesi/modd/cmd/modd@v0.8.1
|
||||
|
||||
.mktools:
|
||||
$(MAKE) mktools
|
||||
|
||||
-include .mktools/*.mk
|
|
@ -0,0 +1,19 @@
|
|||
# Rebound
|
||||
|
||||
Serveur utilisant le protocole SSH pour créer des tunnels TCP/IP.
|
||||
|
||||
## Usage
|
||||
|
||||
### Ouvrir un tunnel sur un port local
|
||||
|
||||
```shell
|
||||
ssh -R 0:<target_address>:<target_port> root@rebound.lookingfora.name -p 2222
|
||||
```
|
||||
|
||||
### Se connecter à un tunnel ouvert
|
||||
|
||||
```
|
||||
ssh -L <local_port>:0.0.0.0:1 <secret>@rebound.lookingfora.name -p 2222
|
||||
```
|
||||
|
||||
Vous pourrez ensuite accéder au service sur `127.0.0.1:<local_port>` comme si c'était `<target_address>:<target_port>`.
|
|
@ -0,0 +1,50 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"flag"
|
||||
"log"
|
||||
"os"
|
||||
"os/signal"
|
||||
|
||||
"forge.cadoles.com/wpetit/rebound"
|
||||
"github.com/caarlos0/env/v6"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func main() {
|
||||
opts := rebound.DefaultOptions()
|
||||
if err := env.Parse(opts); err != nil {
|
||||
log.Fatalf("[ERROR] %+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
address := flag.String("address", opts.Address, "server listening address")
|
||||
sockDir := flag.String("sock-dir", opts.SockDir, "sock directory")
|
||||
publicPort := flag.Uint("public-port", opts.PublicPort, "public port")
|
||||
publicHost := flag.String("public-host", opts.PublicHost, "public host")
|
||||
hostKey := flag.String("host-key", opts.HostKey, "host key")
|
||||
|
||||
flag.Parse()
|
||||
|
||||
server := rebound.NewServer(
|
||||
rebound.WithAddress(*address),
|
||||
rebound.WithSockDir(*sockDir),
|
||||
rebound.WithPublicPort(*publicPort),
|
||||
rebound.WithPublicHost(*publicHost),
|
||||
rebound.WithHostKey(*hostKey),
|
||||
)
|
||||
|
||||
go func() {
|
||||
if err := server.Run(); err != nil {
|
||||
log.Fatalf("[FATAL] %+v", errors.WithStack(err))
|
||||
}
|
||||
}()
|
||||
|
||||
c := make(chan os.Signal, 1)
|
||||
signal.Notify(c, os.Interrupt)
|
||||
|
||||
<-c
|
||||
|
||||
if err := server.Stop(); err != nil {
|
||||
log.Fatalf("[FATAL] %+v", errors.WithStack(err))
|
||||
}
|
||||
}
|
|
@ -0,0 +1,97 @@
|
|||
package rebound
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/pkg/errors"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// direct-tcpip data struct as specified in RFC4254, Section 7.2
|
||||
type localForwardChannelData struct {
|
||||
DestAddr string
|
||||
DestPort uint32
|
||||
|
||||
OriginAddr string
|
||||
OriginPort uint32
|
||||
}
|
||||
|
||||
func (s *Server) handleDirectTCP(srv *ssh.Server, conn *gossh.ServerConn, newChan gossh.NewChannel, ctx ssh.Context) {
|
||||
d := localForwardChannelData{}
|
||||
if err := gossh.Unmarshal(newChan.ExtraData(), &d); err != nil {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
newChan.Reject(gossh.ConnectionFailed, "error parsing forward data: "+err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
sessionID := SessionID(ctx.SessionID())
|
||||
|
||||
s.log("(%s): opening direct tcp tunnel", sessionID)
|
||||
|
||||
s.sessionManager.Set(sessionID, SessionData{
|
||||
Type: TypeServiceConsumer,
|
||||
})
|
||||
|
||||
sessionToken := ctx.User()
|
||||
|
||||
remoteSessionID := s.sessionManager.FindByToken(sessionToken)
|
||||
if remoteSessionID == "" {
|
||||
newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("could not find session associated with token '%s'", sessionToken))
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
addr := s.getSocketPath(remoteSessionID)
|
||||
|
||||
s.log("(%s): using sock '%s'", sessionID, addr)
|
||||
|
||||
if _, err := os.Stat(addr); err != nil {
|
||||
newChan.Reject(gossh.ConnectionFailed, fmt.Sprintf("could not find session associated with token '%s'", sessionToken))
|
||||
conn.Close()
|
||||
return
|
||||
}
|
||||
|
||||
var dialer net.Dialer
|
||||
dconn, err := dialer.DialContext(ctx, "unix", addr)
|
||||
if err != nil {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
newChan.Reject(gossh.ConnectionFailed, err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
ch, reqs, err := newChan.Accept()
|
||||
if err != nil {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
dconn.Close()
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
go gossh.DiscardRequests(reqs)
|
||||
|
||||
go func() {
|
||||
defer dconn.Close()
|
||||
defer ch.Close()
|
||||
|
||||
if _, err := io.Copy(ch, dconn); err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer dconn.Close()
|
||||
defer ch.Close()
|
||||
|
||||
if _, err := io.Copy(dconn, ch); err != nil {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
}
|
||||
}()
|
||||
}
|
|
@ -0,0 +1,15 @@
|
|||
module forge.cadoles.com/wpetit/rebound
|
||||
|
||||
go 1.21.0
|
||||
|
||||
require (
|
||||
github.com/caarlos0/env/v6 v6.10.1
|
||||
github.com/gliderlabs/ssh v0.3.5
|
||||
github.com/pkg/errors v0.9.1
|
||||
golang.org/x/crypto v0.0.0-20220826181053-bd7e27e6170d
|
||||
)
|
||||
|
||||
require (
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
|
||||
golang.org/x/sys v0.0.0-20220825204002-c680a09ffe64 // indirect
|
||||
)
|
|
@ -0,0 +1,25 @@
|
|||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be h1:9AeTilPcZAjCFIImctFaOjnTIavg87rW78vTPkQqLI8=
|
||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be/go.mod h1:ySMOLuWl6zY27l47sB3qLNK6tF2fkHG55UZxx8oIVo4=
|
||||
github.com/caarlos0/env/v6 v6.10.1 h1:t1mPSxNpei6M5yAeu1qtRdPAK29Nbcf/n3G7x+b3/II=
|
||||
github.com/caarlos0/env/v6 v6.10.1/go.mod h1:hvp/ryKXKipEkcuYjs9mI4bBCg+UI0Yhgm5Zu0ddvwc=
|
||||
github.com/gliderlabs/ssh v0.3.5 h1:OcaySEmAQJgyYcArR+gGGTHCyE7nvhEMTlYY+Dp8CpY=
|
||||
github.com/gliderlabs/ssh v0.3.5/go.mod h1:8XB4KraRrX39qHhT6yxPsHedjA08I/uBVwj4xC+/+z4=
|
||||
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
|
||||
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
|
||||
golang.org/x/crypto v0.0.0-20220826181053-bd7e27e6170d h1:3qF+Z8Hkrw9sOhrFHti9TlB1Hkac1x+DNRkv0XQiFjo=
|
||||
golang.org/x/crypto v0.0.0-20220826181053-bd7e27e6170d/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
|
||||
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
|
||||
golang.org/x/net v0.0.0-20220826154423-83b083e8dc8b/go.mod h1:YDH+HFinaLZZlnHAfSS6ZXJJ9M9t4Dl22yv3iI2vPwk=
|
||||
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220728004956-3c1f35247d10/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/sys v0.0.0-20220825204002-c680a09ffe64 h1:UiNENfZ8gDvpiWw7IpOMQ27spWmThO1RwwdQVbJahJM=
|
||||
golang.org/x/sys v0.0.0-20220825204002-c680a09ffe64/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
|
||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/term v0.0.0-20220722155259-a9ba230a4035 h1:Q5284mrmYTpACcm+eAKjKJH48BBwSyfJqmmGDTtT8Vc=
|
||||
golang.org/x/term v0.0.0-20220722155259-a9ba230a4035/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
|
||||
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
|
||||
golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
|
|
@ -0,0 +1,6 @@
|
|||
**/*.go
|
||||
Makefile {
|
||||
prep: make build
|
||||
daemon: make run
|
||||
daemon: npx http-server .
|
||||
}
|
|
@ -0,0 +1,55 @@
|
|||
package rebound
|
||||
|
||||
import "log"
|
||||
|
||||
type Options struct {
|
||||
Address string `env:"REBOUND_ADDRESS"`
|
||||
Logger func(message string, args ...any)
|
||||
SockDir string `env:"REBOUND_SOCK_DIR"`
|
||||
PublicPort uint `env:"REBOUND_PUBLIC_PORT"`
|
||||
PublicHost string `env:"REBOUND_PUBLIC_HOST"`
|
||||
HostKey string `env:"REBOUND_HOST_KEY"`
|
||||
}
|
||||
|
||||
type OptionFunc func(*Options)
|
||||
|
||||
func DefaultOptions() *Options {
|
||||
return &Options{
|
||||
Address: "127.0.0.1:2222",
|
||||
SockDir: "./socks",
|
||||
Logger: log.Printf,
|
||||
PublicPort: 2222,
|
||||
PublicHost: "127.0.0.1",
|
||||
HostKey: "./host.key",
|
||||
}
|
||||
}
|
||||
|
||||
func WithAddress(addr string) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.Address = addr
|
||||
}
|
||||
}
|
||||
|
||||
func WithSockDir(addr string) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.SockDir = addr
|
||||
}
|
||||
}
|
||||
|
||||
func WithPublicHost(host string) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.PublicHost = host
|
||||
}
|
||||
}
|
||||
|
||||
func WithPublicPort(port uint) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.PublicPort = port
|
||||
}
|
||||
}
|
||||
|
||||
func WithHostKey(key string) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.HostKey = key
|
||||
}
|
||||
}
|
|
@ -0,0 +1,226 @@
|
|||
package rebound
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"log"
|
||||
"math/big"
|
||||
"net"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/pkg/errors"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
forwardedTCPChannelType = "forwarded-tcpip"
|
||||
)
|
||||
|
||||
// direct-tcpip data struct as specified in RFC4254, Section 7.2
|
||||
type remoteForwardRequest struct {
|
||||
BindAddr string
|
||||
BindPort uint32
|
||||
}
|
||||
|
||||
type remoteForwardSuccess struct {
|
||||
BindPort uint32
|
||||
}
|
||||
|
||||
type remoteForwardCancelRequest struct {
|
||||
BindAddr string
|
||||
BindPort uint32
|
||||
}
|
||||
|
||||
type remoteForwardChannelData struct {
|
||||
DestAddr string
|
||||
DestPort uint32
|
||||
OriginAddr string
|
||||
OriginPort uint32
|
||||
}
|
||||
|
||||
func (s *Server) handleRequest(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) {
|
||||
s.requestHandlerLock.Lock()
|
||||
if s.forwards == nil {
|
||||
s.forwards = make(map[string]net.Listener)
|
||||
}
|
||||
s.requestHandlerLock.Unlock()
|
||||
conn := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
|
||||
switch req.Type {
|
||||
case "tcpip-forward":
|
||||
var reqPayload remoteForwardRequest
|
||||
|
||||
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
|
||||
log.Printf("[ERROR] %+v", errors.WithStack(err))
|
||||
return false, []byte{}
|
||||
}
|
||||
|
||||
if reqPayload.BindPort != 0 {
|
||||
return false, []byte("bind port must 0")
|
||||
}
|
||||
|
||||
s.log("(%s): opening reverse tunnel", ctx.SessionID())
|
||||
|
||||
token, err := generateToken(16)
|
||||
if err != nil {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
return false, []byte("could not generate secret token")
|
||||
}
|
||||
|
||||
sessionID := SessionID(ctx.SessionID())
|
||||
|
||||
s.sessionManager.Set(sessionID, SessionData{
|
||||
Type: TypeServiceProvider,
|
||||
Token: token,
|
||||
})
|
||||
|
||||
addr := s.getSocketPath(sessionID)
|
||||
|
||||
ln, err := net.Listen("unix", addr)
|
||||
if err != nil {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
return false, []byte{}
|
||||
}
|
||||
|
||||
destPort := 1
|
||||
|
||||
s.requestHandlerLock.Lock()
|
||||
s.forwards[addr] = ln
|
||||
s.requestHandlerLock.Unlock()
|
||||
|
||||
cleanup := func() {
|
||||
s.log("(%s): cleaning up session", sessionID)
|
||||
|
||||
s.sessionManager.Remove(sessionID)
|
||||
|
||||
if err := os.Remove(addr); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
}
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer cleanup()
|
||||
|
||||
<-ctx.Done()
|
||||
s.requestHandlerLock.Lock()
|
||||
ln, ok := s.forwards[addr]
|
||||
s.requestHandlerLock.Unlock()
|
||||
if ok {
|
||||
ln.Close()
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
for {
|
||||
c, err := ln.Accept()
|
||||
if err != nil {
|
||||
if !errors.Is(err, net.ErrClosed) {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
break
|
||||
}
|
||||
|
||||
originAddr, orignPortStr, _ := net.SplitHostPort(c.RemoteAddr().String())
|
||||
originPort, _ := strconv.Atoi(orignPortStr)
|
||||
payload := gossh.Marshal(&remoteForwardChannelData{
|
||||
DestAddr: reqPayload.BindAddr,
|
||||
DestPort: uint32(destPort),
|
||||
OriginAddr: originAddr,
|
||||
OriginPort: uint32(originPort),
|
||||
})
|
||||
|
||||
go func() {
|
||||
ch, reqs, err := conn.OpenChannel(forwardedTCPChannelType, payload)
|
||||
if err != nil {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
c.Close()
|
||||
return
|
||||
}
|
||||
|
||||
go gossh.DiscardRequests(reqs)
|
||||
|
||||
go func() {
|
||||
defer ch.Close()
|
||||
defer c.Close()
|
||||
|
||||
io.Copy(ch, c)
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer ch.Close()
|
||||
defer c.Close()
|
||||
|
||||
io.Copy(c, ch)
|
||||
}()
|
||||
}()
|
||||
}
|
||||
|
||||
s.requestHandlerLock.Lock()
|
||||
delete(s.forwards, addr)
|
||||
s.requestHandlerLock.Unlock()
|
||||
}()
|
||||
|
||||
return true, gossh.Marshal(&remoteForwardSuccess{uint32(destPort)})
|
||||
|
||||
case "cancel-tcpip-forward":
|
||||
var reqPayload remoteForwardCancelRequest
|
||||
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
return false, []byte{}
|
||||
}
|
||||
|
||||
sessionID := SessionID(ctx.SessionID())
|
||||
|
||||
addr := s.getSocketPath(sessionID)
|
||||
|
||||
s.log("(%s): closing sock '%s'", sessionID, addr)
|
||||
|
||||
s.requestHandlerLock.Lock()
|
||||
ln, ok := s.forwards[addr]
|
||||
s.requestHandlerLock.Unlock()
|
||||
if ok {
|
||||
ln.Close()
|
||||
if err := os.Remove(addr); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
}
|
||||
}
|
||||
return true, nil
|
||||
|
||||
default:
|
||||
return false, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) getSocketPath(sessionID SessionID) string {
|
||||
return filepath.Join(s.opts.SockDir, fmt.Sprintf("%s.sock", sessionID))
|
||||
}
|
||||
|
||||
func generateToken(length int) (string, error) {
|
||||
chars := []rune(
|
||||
"ABCDEFGHIJKLMNOPQRSTUVWXYZ" +
|
||||
"abcdefghijklmnopqrstuvwxyz" +
|
||||
"0123456789",
|
||||
)
|
||||
var b strings.Builder
|
||||
|
||||
charsLength := big.NewInt(int64(len(chars)))
|
||||
|
||||
for i := 0; i < length; i++ {
|
||||
idx, err := rand.Int(rand.Reader, charsLength)
|
||||
if err != nil {
|
||||
return "", errors.WithStack(err)
|
||||
}
|
||||
|
||||
c := chars[idx.Int64()]
|
||||
if _, err := b.WriteRune(c); err != nil {
|
||||
return "", errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
return b.String(), nil
|
||||
}
|
|
@ -0,0 +1,141 @@
|
|||
package rebound
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"io/fs"
|
||||
"net"
|
||||
"os"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
ssh *ssh.Server
|
||||
opts *Options
|
||||
|
||||
sessionManager *SessionManager
|
||||
forwards map[string]net.Listener
|
||||
requestHandlerLock sync.Mutex
|
||||
}
|
||||
|
||||
func (s *Server) Run() error {
|
||||
if err := s.initSSHServer(); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
s.log("listening on %s", s.opts.Address)
|
||||
|
||||
if err := s.ssh.ListenAndServe(); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) Stop() error {
|
||||
if s.ssh == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.log("stopping on '%s'", s.opts.Address)
|
||||
|
||||
if err := s.ssh.Close(); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) initSSHServer() error {
|
||||
signer, err := s.loadOrCreateSigner()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
server := &ssh.Server{
|
||||
Addr: s.opts.Address,
|
||||
HostSigners: []ssh.Signer{signer},
|
||||
Handler: ssh.Handler(s.handleSession),
|
||||
RequestHandlers: map[string]ssh.RequestHandler{
|
||||
"tcpip-forward": s.handleRequest,
|
||||
"cancel-tcpip-forward": s.handleRequest,
|
||||
},
|
||||
ChannelHandlers: map[string]ssh.ChannelHandler{
|
||||
"direct-tcpip": s.handleDirectTCP,
|
||||
"session": ssh.DefaultSessionHandler,
|
||||
},
|
||||
}
|
||||
|
||||
s.ssh = server
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) loadOrCreateSigner() (ssh.Signer, error) {
|
||||
var (
|
||||
signer gossh.Signer
|
||||
)
|
||||
|
||||
s.log("reading host key from '%s'", s.opts.HostKey)
|
||||
|
||||
data, err := os.ReadFile(s.opts.HostKey)
|
||||
if err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
if data == nil {
|
||||
s.log("host key cannot be found, generating one")
|
||||
key, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
signer, err = gossh.NewSignerFromKey(key)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
pem := pem.EncodeToMemory(
|
||||
&pem.Block{
|
||||
Type: "RSA PRIVATE KEY",
|
||||
Bytes: x509.MarshalPKCS1PrivateKey(key),
|
||||
},
|
||||
)
|
||||
|
||||
s.log("saving host key to '%s'", s.opts.HostKey)
|
||||
if err := os.WriteFile(s.opts.HostKey, pem, fs.FileMode(0640)); err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
} else {
|
||||
signer, err = gossh.ParsePrivateKey(data)
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
return signer, nil
|
||||
}
|
||||
|
||||
func (s *Server) log(message string, args ...any) {
|
||||
s.opts.Logger(message, args...)
|
||||
}
|
||||
|
||||
func NewServer(funcs ...OptionFunc) *Server {
|
||||
opts := DefaultOptions()
|
||||
for _, fn := range funcs {
|
||||
fn(opts)
|
||||
}
|
||||
|
||||
return &Server{
|
||||
opts: opts,
|
||||
sessionManager: NewSessionManager(30 * time.Second),
|
||||
}
|
||||
}
|
|
@ -0,0 +1,133 @@
|
|||
package rebound
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func (s *Server) handleSession(sess ssh.Session) {
|
||||
ctx := sess.Context()
|
||||
sessionID := SessionID(ctx.SessionID())
|
||||
|
||||
s.log("(%s): session opened", sessionID)
|
||||
|
||||
message := `
|
||||
Welcome on Rebound !
|
||||
|
||||
Type Ctrl+C or Ctrl+D to exit.
|
||||
`
|
||||
|
||||
if _, err := sess.Write([]byte(message)); err != nil {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
go s.readClientInput(sess)
|
||||
|
||||
data := s.sessionManager.Get(sessionID, SessionData{
|
||||
Type: TypeServiceUnknown,
|
||||
})
|
||||
s.handleSessionData(sess, data)
|
||||
|
||||
onUpdate, close := s.sessionManager.OnUpdate(sessionID)
|
||||
defer close()
|
||||
|
||||
for {
|
||||
data, opened := <-onUpdate
|
||||
if !opened {
|
||||
return
|
||||
}
|
||||
|
||||
s.log("(%s): session data updated: %v", sessionID, data)
|
||||
|
||||
if err := s.handleSessionData(sess, data); err != nil {
|
||||
if errors.Is(err, io.EOF) {
|
||||
return
|
||||
}
|
||||
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handleSessionData(sess ssh.Session, data SessionData) error {
|
||||
switch data.Type {
|
||||
case TypeServiceConsumer:
|
||||
if err := s.writeConsumerMessage(sess, data); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
case TypeServiceProvider:
|
||||
if err := s.writeProviderMessage(sess, data); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) writeConsumerMessage(sess ssh.Session, data SessionData) error {
|
||||
message := ``
|
||||
|
||||
if _, err := sess.Write([]byte(message)); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) writeProviderMessage(sess ssh.Session, data SessionData) error {
|
||||
message := fmt.Sprintf(`
|
||||
You can connect to your tunnel by running in an other terminal:
|
||||
|
||||
ssh -L <local-port>:0.0.0.0:1 %s@%s -p %d
|
||||
|
||||
`, data.Token, s.opts.PublicHost, s.opts.PublicPort)
|
||||
|
||||
if _, err := sess.Write([]byte(message)); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
const (
|
||||
CtrlC = 3
|
||||
CtrlD = 4
|
||||
)
|
||||
|
||||
func (s *Server) readClientInput(sess ssh.Session) {
|
||||
sessionID := SessionID(sess.Context().SessionID())
|
||||
defer func() {
|
||||
s.sessionManager.Remove(sessionID)
|
||||
}()
|
||||
|
||||
buff := make([]byte, 1)
|
||||
|
||||
for {
|
||||
_, err := sess.Read(buff)
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
switch buff[0] {
|
||||
case CtrlC:
|
||||
fallthrough
|
||||
case CtrlD:
|
||||
sess.Exit(0)
|
||||
return
|
||||
|
||||
default:
|
||||
s.log("(%s) user input: %v", sessionID, buff)
|
||||
}
|
||||
|
||||
}
|
||||
}
|
|
@ -0,0 +1,184 @@
|
|||
package rebound
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
type SessionID string
|
||||
|
||||
type SessionType int
|
||||
|
||||
const (
|
||||
TypeServiceUnknown SessionType = iota
|
||||
TypeServiceProvider
|
||||
TypeServiceConsumer
|
||||
)
|
||||
|
||||
type SessionData struct {
|
||||
Type SessionType
|
||||
Token string
|
||||
}
|
||||
|
||||
type SessionManager struct {
|
||||
sessions map[SessionID]SessionData
|
||||
sessionsMutex sync.Mutex
|
||||
|
||||
tokenIndex map[string]SessionID
|
||||
|
||||
updates map[SessionID][]chan SessionData
|
||||
updatesMutex sync.Mutex
|
||||
|
||||
updateReadTimeout time.Duration
|
||||
}
|
||||
|
||||
func (m *SessionManager) Get(sessID SessionID, defaultValue SessionData) SessionData {
|
||||
log.Println("reading session", sessID)
|
||||
|
||||
m.sessionsMutex.Lock()
|
||||
defer m.sessionsMutex.Unlock()
|
||||
|
||||
session, exists := m.sessions[sessID]
|
||||
if !exists {
|
||||
session = defaultValue
|
||||
m.sessions[sessID] = session
|
||||
|
||||
m.updatesMutex.Lock()
|
||||
m.dispatchUpdate(sessID, session)
|
||||
m.updatesMutex.Unlock()
|
||||
}
|
||||
|
||||
return session
|
||||
}
|
||||
|
||||
func (m *SessionManager) FindByToken(token string) SessionID {
|
||||
m.sessionsMutex.Lock()
|
||||
defer m.sessionsMutex.Unlock()
|
||||
|
||||
sessID, exists := m.tokenIndex[token]
|
||||
if !exists {
|
||||
return ""
|
||||
}
|
||||
|
||||
return sessID
|
||||
}
|
||||
|
||||
func (m *SessionManager) Set(sessID SessionID, sess SessionData) {
|
||||
log.Println("updating session", sessID, sess)
|
||||
|
||||
m.sessionsMutex.Lock()
|
||||
oldSess, ok := m.sessions[sessID]
|
||||
if ok {
|
||||
m.updateTokenIndex(sessID, sess.Token, oldSess.Token)
|
||||
} else {
|
||||
m.updateTokenIndex(sessID, sess.Token, "")
|
||||
}
|
||||
m.sessions[sessID] = sess
|
||||
m.sessionsMutex.Unlock()
|
||||
|
||||
m.updatesMutex.Lock()
|
||||
m.dispatchUpdate(sessID, sess)
|
||||
m.updatesMutex.Unlock()
|
||||
}
|
||||
|
||||
func (m *SessionManager) Remove(sessID SessionID) {
|
||||
m.sessionsMutex.Lock()
|
||||
oldSess, ok := m.sessions[sessID]
|
||||
if ok {
|
||||
m.updateTokenIndex(sessID, "", oldSess.Token)
|
||||
}
|
||||
delete(m.sessions, sessID)
|
||||
m.sessionsMutex.Unlock()
|
||||
|
||||
m.updatesMutex.Lock()
|
||||
m.closeAllUpdates(sessID)
|
||||
m.updatesMutex.Unlock()
|
||||
}
|
||||
|
||||
func (m *SessionManager) OnUpdate(sessID SessionID) (<-chan SessionData, func()) {
|
||||
update := make(chan SessionData)
|
||||
|
||||
close := func() {
|
||||
m.updatesMutex.Lock()
|
||||
m.closeUpdate(sessID, update)
|
||||
m.updatesMutex.Unlock()
|
||||
}
|
||||
|
||||
m.updatesMutex.Lock()
|
||||
defer m.updatesMutex.Unlock()
|
||||
|
||||
channels, exists := m.updates[sessID]
|
||||
if !exists {
|
||||
channels = make([]chan SessionData, 0, 1)
|
||||
}
|
||||
|
||||
channels = append(channels, update)
|
||||
m.updates[sessID] = channels
|
||||
|
||||
return update, close
|
||||
}
|
||||
|
||||
func (m *SessionManager) closeAllUpdates(sessID SessionID) {
|
||||
channels, exists := m.updates[sessID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
for _, ch := range channels {
|
||||
m.closeUpdate(sessID, ch)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SessionManager) closeUpdate(sessID SessionID, update chan SessionData) {
|
||||
channels, exists := m.updates[sessID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
for idx, ch := range channels {
|
||||
if ch != update {
|
||||
continue
|
||||
}
|
||||
|
||||
close(ch)
|
||||
m.updates[sessID] = append(channels[:idx], channels[idx+1:]...)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SessionManager) dispatchUpdate(sessID SessionID, sess SessionData) {
|
||||
channels, exists := m.updates[sessID]
|
||||
if !exists {
|
||||
return
|
||||
}
|
||||
|
||||
for _, ch := range channels {
|
||||
timeout := time.After(m.updateReadTimeout)
|
||||
select {
|
||||
case ch <- sess:
|
||||
case <-timeout:
|
||||
err := errors.New("session update read timed out")
|
||||
log.Printf("[ERROR] %+v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *SessionManager) updateTokenIndex(sessID SessionID, addedToken, deletedToken string) {
|
||||
if addedToken != "" {
|
||||
m.tokenIndex[addedToken] = sessID
|
||||
}
|
||||
|
||||
if deletedToken != "" {
|
||||
delete(m.tokenIndex, deletedToken)
|
||||
}
|
||||
}
|
||||
|
||||
func NewSessionManager(updateReadTimeout time.Duration) *SessionManager {
|
||||
return &SessionManager{
|
||||
sessions: make(map[SessionID]SessionData),
|
||||
tokenIndex: make(map[string]SessionID),
|
||||
updates: make(map[SessionID][]chan SessionData),
|
||||
updateReadTimeout: updateReadTimeout,
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue