rebound/request_handler.go

227 lines
4.6 KiB
Go

package rebound
import (
"crypto/rand"
"fmt"
"io"
"log"
"math/big"
"net"
"os"
"path/filepath"
"strconv"
"strings"
"github.com/gliderlabs/ssh"
"github.com/pkg/errors"
gossh "golang.org/x/crypto/ssh"
)
const (
// forwardedTCPChannelType is the SSH channel type that handleRequest opens
// towards the client for every connection accepted on a forwarded listener.
forwardedTCPChannelType = "forwarded-tcpip"
)
// remoteForwardRequest is the payload of a "tcpip-forward" global request
// (RFC 4254, Section 7.1). Field order and types define the wire layout used
// by gossh.Unmarshal and must not be changed.
type remoteForwardRequest struct {
// BindAddr is the address the client asks the server to bind.
BindAddr string
// BindPort is the port the client asks the server to bind; handleRequest
// rejects any value other than 0.
BindPort uint32
}
// remoteForwardSuccess is the reply payload sent back for an accepted
// "tcpip-forward" request.
type remoteForwardSuccess struct {
// BindPort is the port number reported to the client as bound.
BindPort uint32
}
// remoteForwardCancelRequest is the payload of a "cancel-tcpip-forward"
// global request. Field order and types define the wire layout used by
// gossh.Unmarshal and must not be changed.
type remoteForwardCancelRequest struct {
// BindAddr is the bind address named in the cancel request.
BindAddr string
// BindPort is the bind port named in the cancel request.
BindPort uint32
}
// remoteForwardChannelData is the payload marshalled when opening a
// "forwarded-tcpip" channel towards the client (RFC 4254, Section 7.2).
// Field order and types define the wire layout used by gossh.Marshal and
// must not be changed.
type remoteForwardChannelData struct {
// DestAddr is the address that was connected to; handleRequest fills it
// with the bind address from the original forward request.
DestAddr string
// DestPort is the port that was connected to.
DestPort uint32
// OriginAddr is the address of the peer that connected to the listener.
OriginAddr string
// OriginPort is the port of the peer that connected to the listener.
OriginPort uint32
}
// handleRequest serves the SSH global requests used for remote (reverse)
// port forwarding.
//
//   - "tcpip-forward": validates the request (only bind port 0 is accepted),
//     generates a secret token, opens a unix socket listener at the session's
//     socket path, registers the session as a service provider, and proxies
//     every connection accepted on that socket to the client over a
//     "forwarded-tcpip" channel. The listener is closed and the session is
//     removed when ctx is done.
//   - "cancel-tcpip-forward": closes the session's listener, if any, and
//     removes its socket file.
//   - anything else is rejected.
//
// The returned bool is the success flag of the request reply and the byte
// slice is the reply payload. srv is unused; it is part of the handler
// signature.
func (s *Server) handleRequest(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) {
	// Lazily create the listener registry on first use.
	s.requestHandlerLock.Lock()
	if s.forwards == nil {
		s.forwards = make(map[string]net.Listener)
	}
	s.requestHandlerLock.Unlock()

	// Use a checked assertion: an unexpected context value must reject the
	// request rather than panic the whole server.
	conn, ok := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn)
	if !ok {
		s.log("[ERROR] %+v", errors.New("ssh server connection missing from context"))
		return false, nil
	}

	switch req.Type {
	case "tcpip-forward":
		var reqPayload remoteForwardRequest
		if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
			// NOTE(review): this is the file's only use of the standard log
			// package; it is kept so the import stays in use.
			log.Printf("[ERROR] %+v", errors.WithStack(err))
			return false, []byte{}
		}
		if reqPayload.BindPort != 0 {
			return false, []byte("bind port must be 0")
		}
		s.log("(%s): opening reverse tunnel", ctx.SessionID())

		token, err := generateToken(16)
		if err != nil {
			s.log("[ERROR] %+v", errors.WithStack(err))
			return false, []byte("could not generate secret token")
		}

		sessionID := SessionID(ctx.SessionID())
		addr := s.getSocketPath(sessionID)
		ln, err := net.Listen("unix", addr)
		if err != nil {
			s.log("[ERROR] %+v", errors.WithStack(err))
			return false, []byte{}
		}

		// Register the session only once the listener exists. Doing it
		// earlier would leave orphaned session data behind when Listen fails
		// and would overwrite the token of an already active tunnel when the
		// same session repeats the request.
		s.sessionManager.Set(sessionID, SessionData{
			Type:  TypeServiceProvider,
			Token: token,
		})

		// The tunnel endpoint is a unix socket, so there is no real TCP port;
		// this fixed value is what gets reported to the client.
		const destPort uint32 = 1

		s.requestHandlerLock.Lock()
		s.forwards[addr] = ln
		s.requestHandlerLock.Unlock()

		cleanup := func() {
			s.log("(%s): cleaning up session", sessionID)
			s.sessionManager.Remove(sessionID)
			if err := os.Remove(addr); err != nil && !errors.Is(err, os.ErrNotExist) {
				s.log("[ERROR] %+v", errors.WithStack(err))
			}
		}

		// Tear-down goroutine: once the connection context ends, close the
		// listener (if it is still registered) and drop the session.
		go func() {
			defer cleanup()
			<-ctx.Done()
			s.requestHandlerLock.Lock()
			ln, ok := s.forwards[addr]
			s.requestHandlerLock.Unlock()
			if ok {
				ln.Close()
			}
		}()

		// Accept goroutine: runs until the listener is closed, then removes
		// the listener from the registry.
		go func() {
			for {
				c, err := ln.Accept()
				if err != nil {
					// A closed listener is the normal way out of this loop.
					if !errors.Is(err, net.ErrClosed) {
						s.log("[ERROR] %+v", errors.WithStack(err))
					}
					break
				}

				// Best effort: the origin is informational only, so when the
				// peer address has no host:port form the fields simply stay
				// at their zero values.
				originAddr, originPortStr, _ := net.SplitHostPort(c.RemoteAddr().String())
				originPort, _ := strconv.Atoi(originPortStr)
				payload := gossh.Marshal(&remoteForwardChannelData{
					DestAddr:   reqPayload.BindAddr,
					DestPort:   destPort,
					OriginAddr: originAddr,
					OriginPort: uint32(originPort),
				})

				// Open the channel off the accept loop so a slow client does
				// not block further accepts.
				go func() {
					ch, reqs, err := conn.OpenChannel(forwardedTCPChannelType, payload)
					if err != nil {
						s.log("[ERROR] %+v", errors.WithStack(err))
						c.Close()
						return
					}
					go gossh.DiscardRequests(reqs)

					// Pipe both directions; whichever side finishes first
					// closes both ends, which unblocks the other copy.
					go func() {
						defer ch.Close()
						defer c.Close()
						io.Copy(ch, c)
					}()
					go func() {
						defer ch.Close()
						defer c.Close()
						io.Copy(c, ch)
					}()
				}()
			}

			s.requestHandlerLock.Lock()
			delete(s.forwards, addr)
			s.requestHandlerLock.Unlock()
		}()

		return true, gossh.Marshal(&remoteForwardSuccess{BindPort: destPort})

	case "cancel-tcpip-forward":
		var reqPayload remoteForwardCancelRequest
		if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil {
			s.log("[ERROR] %+v", errors.WithStack(err))
			return false, []byte{}
		}

		// The listener is looked up by the session's socket path, not by the
		// address/port carried in the request.
		sessionID := SessionID(ctx.SessionID())
		addr := s.getSocketPath(sessionID)
		s.log("(%s): closing sock '%s'", sessionID, addr)

		s.requestHandlerLock.Lock()
		ln, ok := s.forwards[addr]
		s.requestHandlerLock.Unlock()
		if ok {
			// Closing the listener ends the accept goroutine, which then
			// removes the registry entry.
			ln.Close()
			if err := os.Remove(addr); err != nil && !errors.Is(err, os.ErrNotExist) {
				s.log("[ERROR] %+v", errors.WithStack(err))
			}
		}
		return true, nil

	default:
		return false, nil
	}
}
// getSocketPath returns the path of the unix socket belonging to sessionID:
// a file named "<sessionID>.sock" inside the configured socket directory.
func (s *Server) getSocketPath(sessionID SessionID) string {
	sockName := fmt.Sprintf("%s.sock", sessionID)
	return filepath.Join(s.opts.SockDir, sockName)
}
// generateToken returns a random token of the given length made up of ASCII
// letters and digits, with every character picked via crypto/rand. A
// non-positive length yields an empty token. Errors from the random source
// are returned wrapped with a stack trace.
func generateToken(length int) (string, error) {
	const alphabet = "ABCDEFGHIJKLMNOPQRSTUVWXYZ" +
		"abcdefghijklmnopqrstuvwxyz" +
		"0123456789"

	// Exclusive upper bound for the random index into alphabet.
	bound := big.NewInt(int64(len(alphabet)))

	var token strings.Builder
	for remaining := length; remaining > 0; remaining-- {
		pick, err := rand.Int(rand.Reader, bound)
		if err != nil {
			return "", errors.WithStack(err)
		}
		// The alphabet is pure ASCII, so one byte is exactly one character.
		if err := token.WriteByte(alphabet[pick.Int64()]); err != nil {
			return "", errors.WithStack(err)
		}
	}
	return token.String(), nil
}