227 lines
4.6 KiB
Go
227 lines
4.6 KiB
Go
package rebound
|
|
|
|
import (
|
|
"crypto/rand"
|
|
"fmt"
|
|
"io"
|
|
"log"
|
|
"math/big"
|
|
"net"
|
|
"os"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
|
|
"github.com/gliderlabs/ssh"
|
|
"github.com/pkg/errors"
|
|
gossh "golang.org/x/crypto/ssh"
|
|
)
|
|
|
|
// forwardedTCPChannelType is the SSH channel type opened towards the client
// for every connection accepted on a reverse tunnel (RFC 4254, Section 7.2).
const forwardedTCPChannelType = "forwarded-tcpip"
|
|
|
|
// remoteForwardRequest is the payload of a "tcpip-forward" global request
// (RFC 4254, Section 7.1). gossh.Unmarshal decodes fields in declaration
// order, so the field order must not change.
type remoteForwardRequest struct {
	// BindAddr is the address the client asked the server to bind.
	BindAddr string
	// BindPort is the port the client asked for; handleRequest only accepts 0.
	BindPort uint32
}
|
|
|
|
// remoteForwardSuccess is the reply payload sent back when a "tcpip-forward"
// request is accepted (RFC 4254, Section 7.1).
type remoteForwardSuccess struct {
	// BindPort is the port reported to the client as the one that was bound.
	BindPort uint32
}
|
|
|
|
// remoteForwardCancelRequest is the payload of a "cancel-tcpip-forward"
// global request (RFC 4254, Section 7.1). Field order is the wire order.
type remoteForwardCancelRequest struct {
	// BindAddr is the address of the forward to cancel.
	BindAddr string
	// BindPort is the port of the forward to cancel.
	BindPort uint32
}
|
|
|
|
// remoteForwardChannelData is the extra data attached when opening a
// "forwarded-tcpip" channel (RFC 4254, Section 7.2). gossh.Marshal encodes
// fields in declaration order, so the field order must not change.
type remoteForwardChannelData struct {
	// DestAddr and DestPort identify the forward the connection arrived on.
	DestAddr string
	DestPort uint32
	// OriginAddr and OriginPort describe the peer that opened the connection.
	OriginAddr string
	OriginPort uint32
}
|
|
|
|
func (s *Server) handleRequest(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) {
|
|
s.requestHandlerLock.Lock()
|
|
if s.forwards == nil {
|
|
s.forwards = make(map[string]net.Listener)
|
|
}
|
|
s.requestHandlerLock.Unlock()
|
|
conn := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
|
|
switch req.Type {
|
|
case "tcpip-forward":
|
|
var reqPayload remoteForwardRequest
|
|
|
|
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
|
|
log.Printf("[ERROR] %+v", errors.WithStack(err))
|
|
return false, []byte{}
|
|
}
|
|
|
|
if reqPayload.BindPort != 0 {
|
|
return false, []byte("bind port must 0")
|
|
}
|
|
|
|
s.log("(%s): opening reverse tunnel", ctx.SessionID())
|
|
|
|
token, err := generateToken(16)
|
|
if err != nil {
|
|
s.log("[ERROR] %+v", errors.WithStack(err))
|
|
return false, []byte("could not generate secret token")
|
|
}
|
|
|
|
sessionID := SessionID(ctx.SessionID())
|
|
|
|
s.sessionManager.Set(sessionID, SessionData{
|
|
Type: TypeServiceProvider,
|
|
Token: token,
|
|
})
|
|
|
|
addr := s.getSocketPath(sessionID)
|
|
|
|
ln, err := net.Listen("unix", addr)
|
|
if err != nil {
|
|
s.log("[ERROR] %+v", errors.WithStack(err))
|
|
return false, []byte{}
|
|
}
|
|
|
|
destPort := 1
|
|
|
|
s.requestHandlerLock.Lock()
|
|
s.forwards[addr] = ln
|
|
s.requestHandlerLock.Unlock()
|
|
|
|
cleanup := func() {
|
|
s.log("(%s): cleaning up session", sessionID)
|
|
|
|
s.sessionManager.Remove(sessionID)
|
|
|
|
if err := os.Remove(addr); err != nil && !errors.Is(err, os.ErrNotExist) {
|
|
s.log("[ERROR] %+v", errors.WithStack(err))
|
|
}
|
|
}
|
|
|
|
go func() {
|
|
defer cleanup()
|
|
|
|
<-ctx.Done()
|
|
s.requestHandlerLock.Lock()
|
|
ln, ok := s.forwards[addr]
|
|
s.requestHandlerLock.Unlock()
|
|
if ok {
|
|
ln.Close()
|
|
}
|
|
}()
|
|
|
|
go func() {
|
|
for {
|
|
c, err := ln.Accept()
|
|
if err != nil {
|
|
if !errors.Is(err, net.ErrClosed) {
|
|
s.log("[ERROR] %+v", errors.WithStack(err))
|
|
}
|
|
|
|
break
|
|
}
|
|
|
|
originAddr, orignPortStr, _ := net.SplitHostPort(c.RemoteAddr().String())
|
|
originPort, _ := strconv.Atoi(orignPortStr)
|
|
payload := gossh.Marshal(&remoteForwardChannelData{
|
|
DestAddr: reqPayload.BindAddr,
|
|
DestPort: uint32(destPort),
|
|
OriginAddr: originAddr,
|
|
OriginPort: uint32(originPort),
|
|
})
|
|
|
|
go func() {
|
|
ch, reqs, err := conn.OpenChannel(forwardedTCPChannelType, payload)
|
|
if err != nil {
|
|
s.log("[ERROR] %+v", errors.WithStack(err))
|
|
c.Close()
|
|
return
|
|
}
|
|
|
|
go gossh.DiscardRequests(reqs)
|
|
|
|
go func() {
|
|
defer ch.Close()
|
|
defer c.Close()
|
|
|
|
io.Copy(ch, c)
|
|
}()
|
|
|
|
go func() {
|
|
defer ch.Close()
|
|
defer c.Close()
|
|
|
|
io.Copy(c, ch)
|
|
}()
|
|
}()
|
|
}
|
|
|
|
s.requestHandlerLock.Lock()
|
|
delete(s.forwards, addr)
|
|
s.requestHandlerLock.Unlock()
|
|
}()
|
|
|
|
return true, gossh.Marshal(&remoteForwardSuccess{uint32(destPort)})
|
|
|
|
case "cancel-tcpip-forward":
|
|
var reqPayload remoteForwardCancelRequest
|
|
if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
|
|
s.log("[ERROR] %+v", errors.WithStack(err))
|
|
return false, []byte{}
|
|
}
|
|
|
|
sessionID := SessionID(ctx.SessionID())
|
|
|
|
addr := s.getSocketPath(sessionID)
|
|
|
|
s.log("(%s): closing sock '%s'", sessionID, addr)
|
|
|
|
s.requestHandlerLock.Lock()
|
|
ln, ok := s.forwards[addr]
|
|
s.requestHandlerLock.Unlock()
|
|
if ok {
|
|
ln.Close()
|
|
if err := os.Remove(addr); err != nil && !errors.Is(err, os.ErrNotExist) {
|
|
s.log("[ERROR] %+v", errors.WithStack(err))
|
|
}
|
|
}
|
|
return true, nil
|
|
|
|
default:
|
|
return false, nil
|
|
}
|
|
}
|
|
|
|
func (s *Server) getSocketPath(sessionID SessionID) string {
|
|
return filepath.Join(s.opts.SockDir, fmt.Sprintf("%s.sock", sessionID))
|
|
}
|
|
|
|
func generateToken(length int) (string, error) {
|
|
chars := []rune(
|
|
"ABCDEFGHIJKLMNOPQRSTUVWXYZ" +
|
|
"abcdefghijklmnopqrstuvwxyz" +
|
|
"0123456789",
|
|
)
|
|
var b strings.Builder
|
|
|
|
charsLength := big.NewInt(int64(len(chars)))
|
|
|
|
for i := 0; i < length; i++ {
|
|
idx, err := rand.Int(rand.Reader, charsLength)
|
|
if err != nil {
|
|
return "", errors.WithStack(err)
|
|
}
|
|
|
|
c := chars[idx.Int64()]
|
|
if _, err := b.WriteRune(c); err != nil {
|
|
return "", errors.WithStack(err)
|
|
}
|
|
}
|
|
|
|
return b.String(), nil
|
|
}
|