package ssh import ( "crypto/rand" "fmt" "io" "log" "math/big" "net" "os" "path/filepath" "strconv" "strings" "github.com/gliderlabs/ssh" "github.com/pkg/errors" gossh "golang.org/x/crypto/ssh" ) const ( forwardedTCPChannelType = "forwarded-tcpip" ) // direct-tcpip data struct as specified in RFC4254, Section 7.2 type remoteForwardRequest struct { BindAddr string BindPort uint32 } type remoteForwardSuccess struct { BindPort uint32 } type remoteForwardCancelRequest struct { BindAddr string BindPort uint32 } type remoteForwardChannelData struct { DestAddr string DestPort uint32 OriginAddr string OriginPort uint32 } func (s *Server) handleRequest(ctx ssh.Context, srv *ssh.Server, req *gossh.Request) (bool, []byte) { s.requestHandlerLock.Lock() if s.forwards == nil { s.forwards = make(map[string]net.Listener) } s.requestHandlerLock.Unlock() conn := ctx.Value(ssh.ContextKeyConn).(*gossh.ServerConn) switch req.Type { case "tcpip-forward": var reqPayload remoteForwardRequest if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil { log.Printf("[ERROR] %+v", errors.WithStack(err)) return false, []byte{} } if reqPayload.BindPort != 0 { return false, []byte("bind port must 0") } s.log("(%s): opening reverse tunnel", ctx.SessionID()) token, err := generateToken(16) if err != nil { s.log("[ERROR] %+v", errors.WithStack(err)) return false, []byte("could not generate secret token") } sessionID := SessionID(ctx.SessionID()) s.sessionManager.Set(sessionID, SessionData{ Type: TypeServiceProvider, Token: token, }) addr := s.getSocketPath(sessionID) if err := s.ensureFileDir(addr); err != nil { s.log("[ERROR] %+v", errors.WithStack(err)) return false, []byte("internal server error") } ln, err := net.Listen("unix", addr) if err != nil { s.log("[ERROR] %+v", errors.WithStack(err)) return false, []byte{} } s.opts.Stats.Add(StatTotalOpenedTunnels, 1, 0) destPort := 1 s.requestHandlerLock.Lock() s.forwards[addr] = ln s.requestHandlerLock.Unlock() cleanup := func() { s.log("(%s): cleaning up session", sessionID) s.sessionManager.Remove(sessionID) if err := os.Remove(addr); err != nil && !errors.Is(err, os.ErrNotExist) { s.log("[ERROR] %+v", errors.WithStack(err)) } } go func() { defer cleanup() <-ctx.Done() s.requestHandlerLock.Lock() ln, ok := s.forwards[addr] s.requestHandlerLock.Unlock() if ok { ln.Close() } }() go func() { for { c, err := ln.Accept() if err != nil { if !errors.Is(err, net.ErrClosed) { s.log("[ERROR] %+v", errors.WithStack(err)) } break } originAddr, orignPortStr, _ := net.SplitHostPort(c.RemoteAddr().String()) originPort, _ := strconv.Atoi(orignPortStr) payload := gossh.Marshal(&remoteForwardChannelData{ DestAddr: reqPayload.BindAddr, DestPort: uint32(destPort), OriginAddr: originAddr, OriginPort: uint32(originPort), }) go func() { ch, reqs, err := conn.OpenChannel(forwardedTCPChannelType, payload) if err != nil { s.log("[ERROR] %+v", errors.WithStack(err)) c.Close() return } go gossh.DiscardRequests(reqs) go func() { defer ch.Close() defer c.Close() io.Copy(ch, c) }() go func() { defer ch.Close() defer c.Close() io.Copy(c, ch) }() }() } s.requestHandlerLock.Lock() delete(s.forwards, addr) s.requestHandlerLock.Unlock() }() return true, gossh.Marshal(&remoteForwardSuccess{uint32(destPort)}) case "cancel-tcpip-forward": var reqPayload remoteForwardCancelRequest if err := gossh.Unmarshal(req.Payload, &reqPayload); err != nil { s.log("[ERROR] %+v", errors.WithStack(err)) return false, []byte{} } sessionID := SessionID(ctx.SessionID()) addr := s.getSocketPath(sessionID) s.log("(%s): closing sock '%s'", sessionID, addr) s.requestHandlerLock.Lock() ln, ok := s.forwards[addr] s.requestHandlerLock.Unlock() if ok { ln.Close() if err := os.Remove(addr); err != nil && !errors.Is(err, os.ErrNotExist) { s.log("[ERROR] %+v", errors.WithStack(err)) } } return true, nil default: return false, nil } } func (s *Server) getSocketPath(sessionID SessionID) string { return filepath.Join(s.opts.SockDir, fmt.Sprintf("%s.sock", sessionID)) } func (s *Server) ensureFileDir(file string) error { dir := filepath.Dir(file) if err := os.MkdirAll(dir, os.FileMode(0750)); err != nil { return errors.WithStack(err) } return nil } func generateToken(length int) (string, error) { chars := []rune( "ABCDEFGHIJKLMNOPQRSTUVWXYZ" + "abcdefghijklmnopqrstuvwxyz" + "0123456789", ) var b strings.Builder charsLength := big.NewInt(int64(len(chars))) for i := 0; i < length; i++ { idx, err := rand.Int(rand.Reader, charsLength) if err != nil { return "", errors.WithStack(err) } c := chars[idx.Int64()] if _, err := b.WriteRune(c); err != nil { return "", errors.WithStack(err) } } return b.String(), nil }