feat: collect and display usage stats
This commit is contained in:
parent
bf14a70efe
commit
6b1637d1d8
|
@ -9,3 +9,4 @@
|
|||
tools/
|
||||
/CHANGELOG.md
|
||||
/.chglog
|
||||
/stats.json
|
|
@ -9,7 +9,7 @@ builds:
|
|||
ldflags:
|
||||
- -s
|
||||
- -w
|
||||
- -X 'main.Version=${MKT_PROJECT_VERSION}'
|
||||
- -X "main.Version={{ .Env.GORELEASER_CURRENT_TAG }}"
|
||||
gcflags:
|
||||
- -trimpath="${PWD}"
|
||||
asmflags:
|
||||
|
|
|
@ -11,6 +11,8 @@ import (
|
|||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var Version string = "unknown"
|
||||
|
||||
func main() {
|
||||
opts := rebound.DefaultOptions()
|
||||
|
||||
|
@ -18,8 +20,12 @@ func main() {
|
|||
log.Fatalf("[ERROR] %+v", errors.WithStack(err))
|
||||
}
|
||||
|
||||
opts.HTTP.TemplateData.Version = Version
|
||||
|
||||
server := rebound.NewServer(
|
||||
rebound.WithAddress(opts.Address),
|
||||
rebound.WithStatsFile(opts.StatsFile),
|
||||
rebound.WithStatsFileSaveInterval(opts.StatsFileSaveInterval),
|
||||
rebound.WithSSHOption(
|
||||
ssh.WithSockDir(opts.SSH.SockDir),
|
||||
ssh.WithPublicHost(opts.SSH.PublicHost),
|
||||
|
|
|
@ -2,16 +2,20 @@ package http
|
|||
|
||||
import (
|
||||
"log"
|
||||
|
||||
"forge.cadoles.com/wpetit/rebound/stat"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
Logger func(message string, args ...any)
|
||||
CustomDir string `env:"CUSTOM_DIR"`
|
||||
TemplateData *TemplateData `envPrefix:"TEMPLATE_DATA_"`
|
||||
Stats *stat.Store
|
||||
}
|
||||
|
||||
type TemplateData struct {
|
||||
Title string `env:"TITLE"`
|
||||
Version string
|
||||
SSHPublicHost string `env:"SSH_PUBLIC_HOST"`
|
||||
SSHPublicPort int `env:"SSH_PUBLIC_PORT"`
|
||||
}
|
||||
|
@ -27,6 +31,7 @@ func DefaultOptions() *Options {
|
|||
SSHPublicHost: "127.0.0.1",
|
||||
SSHPublicPort: 2222,
|
||||
},
|
||||
Stats: stat.NewStore(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -47,3 +52,9 @@ func WithTemplateData(templateData *TemplateData) func(*Options) {
|
|||
opts.TemplateData = templateData
|
||||
}
|
||||
}
|
||||
|
||||
func WithStats(stats *stat.Store) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.Stats = stats
|
||||
}
|
||||
}
|
||||
|
|
|
@ -3,9 +3,11 @@ package http
|
|||
import (
|
||||
"bytes"
|
||||
"embed"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io"
|
||||
"io/fs"
|
||||
"log/slog"
|
||||
"net"
|
||||
"net/http"
|
||||
"os"
|
||||
|
@ -29,8 +31,45 @@ type Server struct {
|
|||
templates template.Template
|
||||
}
|
||||
|
||||
var templateFuncs = template.FuncMap{
|
||||
"humanSize": func(b float64) string {
|
||||
const unit = 1000
|
||||
if b < unit {
|
||||
return fmt.Sprintf("%d B", int64(b))
|
||||
}
|
||||
|
||||
div, exp := int64(unit), 0
|
||||
|
||||
for n := b / unit; n >= unit; n /= unit {
|
||||
div *= unit
|
||||
exp++
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%.1f %cB",
|
||||
float64(b)/float64(div), "kMGTPE"[exp])
|
||||
},
|
||||
}
|
||||
|
||||
func (s *Server) serveHomepage(w http.ResponseWriter, r *http.Request) {
|
||||
s.renderTemplate(w, "index", s.opts.TemplateData)
|
||||
stats, err := s.opts.Stats.Snapshot()
|
||||
if err != nil {
|
||||
slog.Error("could not make stats snapshot", slog.Any("error", errors.WithStack(err)))
|
||||
http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
data := struct {
|
||||
TemplateData
|
||||
Stats map[string]float64
|
||||
}{
|
||||
TemplateData: *s.opts.TemplateData,
|
||||
Stats: stats,
|
||||
}
|
||||
|
||||
s.opts.Stats.Add(StatTotalPageView, 1, 0)
|
||||
|
||||
s.renderTemplate(w, "index", data)
|
||||
}
|
||||
|
||||
func (s *Server) Serve(l net.Listener) error {
|
||||
|
@ -67,7 +106,7 @@ func (s *Server) Serve(l net.Listener) error {
|
|||
}
|
||||
|
||||
func (s *Server) parseTemplates(fs fs.FS) error {
|
||||
templates, err := template.ParseFS(fs, "templates/*.html")
|
||||
templates, err := template.New("").Funcs(templateFuncs).ParseFS(fs, "templates/*.html")
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
|
|
@ -0,0 +1,5 @@
|
|||
package http
|
||||
|
||||
const (
|
||||
StatTotalPageView = "total_page_view"
|
||||
)
|
|
@ -2,7 +2,7 @@
|
|||
<footer class="footer">
|
||||
<div class="container">
|
||||
<div class="content has-text-centered">
|
||||
Ce service est propulsé par <a href="https://forge.cadoles.com/wpetit/rebound" title="Rebound repository">Rebound</a>, un logiciel libre diffusé sous licence <a href="https://www.gnu.org/licenses/agpl-3.0.en.html#license-text">AGPL-3.0</a>.
|
||||
Ce service est propulsé par <a href="https://forge.cadoles.com/wpetit/rebound" title="Rebound repository">rebound@{{ .Version }}</a>, un logiciel libre diffusé sous licence <a href="https://www.gnu.org/licenses/agpl-3.0.en.html#license-text">AGPL-3.0</a>.
|
||||
</div>
|
||||
</div>
|
||||
</footer>
|
||||
|
|
|
@ -13,14 +13,46 @@
|
|||
<p class="subtitle is-size-3">
|
||||
Bienvenue sur <strong>Rebound</strong>!
|
||||
</p>
|
||||
<div class="content">
|
||||
<p>Rebound est un serveur SSH permettant de créer des tunnels TCP/IP éphémères et privés entre 2 machines positionnées
|
||||
derrière un <abbr title="Network Address Traversal">NAT</abbr>.</p>
|
||||
<p>Pour l'utiliser <strong>un simple client SSH suffit !</strong></p>
|
||||
<pre class="has-background-dark has-text-white-ter is-family-monospace">ssh -R 0:127.0.0.1:<span class="has-text-info"><port></span> rebound@{{ .SSHPublicHost }} -p {{ .SSHPublicPort }}</pre>
|
||||
<p class="is-italic">Où <span class="has-text-info"><port></span> est à remplacer par le port du service
|
||||
s'exécutant sur votre machine en local.</span>
|
||||
<p>Une fois connecté, suivez les instructions. 😉</p>
|
||||
<div class="block">
|
||||
<div class="content">
|
||||
<p>Rebound est un serveur SSH permettant de créer des tunnels TCP/IP éphémères et privés entre 2 machines positionnées
|
||||
derrière un <abbr title="Network Address Traversal">NAT</abbr>.</p>
|
||||
<p>Pour l'utiliser <strong>un simple client SSH suffit !</strong></p>
|
||||
<pre class="has-background-dark has-text-white-ter is-family-monospace">ssh -R 0:127.0.0.1:<span class="has-text-info"><port></span> rebound@{{ .SSHPublicHost }} -p {{ .SSHPublicPort }}</pre>
|
||||
<p class="is-italic">Où <span class="has-text-info"><port></span> est à remplacer par le port du service
|
||||
s'exécutant sur votre machine en local.</span>
|
||||
<p>Une fois connecté, suivez les instructions. 😉</p>
|
||||
</div>
|
||||
</div>
|
||||
<hr />
|
||||
<div class="block">
|
||||
<div class="columns">
|
||||
<div class="column is-4">
|
||||
<h3 class="title is-size-4">En savoir plus</h3>
|
||||
<div class="content">
|
||||
À venir...
|
||||
</div>
|
||||
</div>
|
||||
<div class="column is-4">
|
||||
<h3 class="title is-size-4">Statistiques</h3>
|
||||
<table class="table is-bordered is-striped is-fullwidth">
|
||||
<tbody>
|
||||
<tr>
|
||||
<td><strong>Total tunnels ouverts</strong></td>
|
||||
<td>{{ index .Stats "total_opened_tunnels" }}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>Total données entrantes</strong></td>
|
||||
<td>{{ humanSize ( index .Stats "total_rx_bytes" ) }}</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td><strong>Total données sortantes</strong></td>
|
||||
<td>{{ humanSize ( index .Stats "total_tx_bytes" ) }}</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
</section>
|
||||
|
|
|
@ -4,6 +4,7 @@ export REBOUND_SSH_PUBLIC_HOST=rebound
|
|||
export REBOUND_SSH_PUBLIC_PORT=2222
|
||||
export REBOUND_SSH_SOCK_DIR=/var/lib/rebound/socks
|
||||
export REBOUND_SSH_HOST_KEY=/etc/rebound/host.key
|
||||
export REBOUND_STATS_FILE=/var/lib/rebound/stats.json
|
||||
export REBOUND_HTTP_TEMPLATE_DATA_TITLE=Rebound
|
||||
export REBOUND_HTTP_TEMPLATE_DATA_SSH_PUBLIC_HOST=127.0.0.1
|
||||
export REBOUND_HTTP_TEMPLATE_DATA_SSH_PUBLIC_PORT=8080
|
|
@ -4,6 +4,7 @@ REBOUND_SSH_PUBLIC_HOST=rebound
|
|||
REBOUND_SSH_PUBLIC_PORT=8080
|
||||
REBOUND_SSH_SOCK_DIR=/var/lib/rebound/socks
|
||||
REBOUND_SSH_HOST_KEY=/var/lib/rebound/host.key
|
||||
REBOUND_STATS_FILE=/var/lib/rebound/stats.json
|
||||
REBOUND_HTTP_TEMPLATE_DATA_TITLE=Rebound
|
||||
REBOUND_HTTP_TEMPLATE_DATA_SSH_PUBLIC_HOST=127.0.0.1
|
||||
REBOUND_HTTP_TEMPLATE_DATA_SSH_PUBLIC_PORT=8080
|
33
options.go
33
options.go
|
@ -2,6 +2,7 @@ package rebound
|
|||
|
||||
import (
|
||||
"log"
|
||||
"time"
|
||||
|
||||
"forge.cadoles.com/wpetit/rebound/http"
|
||||
"forge.cadoles.com/wpetit/rebound/ssh"
|
||||
|
@ -10,10 +11,12 @@ import (
|
|||
)
|
||||
|
||||
type Options struct {
|
||||
Address string `env:"REBOUND_ADDRESS"`
|
||||
Logger func(message string, args ...any)
|
||||
SSH *ssh.Options `envPrefix:"REBOUND_SSH_"`
|
||||
HTTP *http.Options `envPrefix:"REBOUND_HTTP_"`
|
||||
Address string `env:"REBOUND_ADDRESS"`
|
||||
StatsFile string `env:"REBOUND_STATS_FILE"`
|
||||
StatsFileSaveInterval time.Duration `env:"REBOUND_STATS_FILE_SAVE_INTERVAL"`
|
||||
Logger func(message string, args ...any)
|
||||
SSH *ssh.Options `envPrefix:"REBOUND_SSH_"`
|
||||
HTTP *http.Options `envPrefix:"REBOUND_HTTP_"`
|
||||
}
|
||||
|
||||
func (o *Options) ParseEnv() error {
|
||||
|
@ -28,10 +31,12 @@ type OptionFunc func(*Options)
|
|||
|
||||
func DefaultOptions() *Options {
|
||||
return &Options{
|
||||
Address: "127.0.0.1:2222",
|
||||
Logger: log.Printf,
|
||||
SSH: ssh.DefaultOptions(),
|
||||
HTTP: http.DefaultOptions(),
|
||||
Address: "127.0.0.1:2222",
|
||||
StatsFile: "stats.json",
|
||||
StatsFileSaveInterval: 30 * time.Second,
|
||||
Logger: log.Printf,
|
||||
SSH: ssh.DefaultOptions(),
|
||||
HTTP: http.DefaultOptions(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -49,6 +54,18 @@ func WithLogger(logger func(message string, args ...any)) func(*Options) {
|
|||
}
|
||||
}
|
||||
|
||||
func WithStatsFile(path string) func(*Options) {
|
||||
return func(o *Options) {
|
||||
o.StatsFile = path
|
||||
}
|
||||
}
|
||||
|
||||
func WithStatsFileSaveInterval(interval time.Duration) func(*Options) {
|
||||
return func(o *Options) {
|
||||
o.StatsFileSaveInterval = interval
|
||||
}
|
||||
}
|
||||
|
||||
func WithSSHOption(funcs ...ssh.OptionFunc) func(*Options) {
|
||||
return func(o *Options) {
|
||||
for _, fn := range funcs {
|
||||
|
|
35
server.go
35
server.go
|
@ -1,21 +1,34 @@
|
|||
package rebound
|
||||
|
||||
import (
|
||||
"log/slog"
|
||||
"net"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"forge.cadoles.com/wpetit/rebound/http"
|
||||
"forge.cadoles.com/wpetit/rebound/ssh"
|
||||
"forge.cadoles.com/wpetit/rebound/stat"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Server struct {
|
||||
listener net.Listener
|
||||
opts *Options
|
||||
stats *stat.Store
|
||||
}
|
||||
|
||||
func (s *Server) Start() error {
|
||||
s.log("[INFO] listening on %s", s.opts.Address)
|
||||
|
||||
if err := s.stats.Load(s.opts.StatsFile); err != nil {
|
||||
if errors.Is(err, os.ErrNotExist) {
|
||||
s.log("[INFO] stats file does not exist. ignoring.")
|
||||
} else {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
}
|
||||
|
||||
listener, err := net.Listen("tcp", s.opts.Address)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
|
@ -34,6 +47,7 @@ func (s *Server) Start() error {
|
|||
ssh.WithPublicPort(s.opts.SSH.PublicPort),
|
||||
ssh.WithSockDir(s.opts.SSH.SockDir),
|
||||
ssh.WithLogger(s.opts.SSH.Logger),
|
||||
ssh.WithStats(s.stats),
|
||||
)
|
||||
|
||||
if err := server.Serve(sshListener); err != nil {
|
||||
|
@ -49,6 +63,7 @@ func (s *Server) Start() error {
|
|||
http.WithCustomDir(s.opts.HTTP.CustomDir),
|
||||
http.WithTemplateData(s.opts.HTTP.TemplateData),
|
||||
http.WithLogger(s.opts.HTTP.Logger),
|
||||
http.WithStats(s.stats),
|
||||
)
|
||||
|
||||
if err := server.Serve(httpListener); err != nil {
|
||||
|
@ -57,6 +72,23 @@ func (s *Server) Start() error {
|
|||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer listener.Close()
|
||||
|
||||
ticker := time.NewTicker(s.opts.StatsFileSaveInterval)
|
||||
|
||||
for {
|
||||
<-ticker.C
|
||||
|
||||
slog.Info("saving stats", slog.String("file", s.opts.StatsFile), slog.Duration("interval", s.opts.StatsFileSaveInterval))
|
||||
if err := s.stats.Save(s.opts.StatsFile); err != nil {
|
||||
slog.Error("could not save stat file", slog.Any("error", errors.WithStack(err)))
|
||||
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -85,6 +117,7 @@ func NewServer(funcs ...OptionFunc) *Server {
|
|||
}
|
||||
|
||||
return &Server{
|
||||
opts: opts,
|
||||
opts: opts,
|
||||
stats: stat.NewStore(),
|
||||
}
|
||||
}
|
||||
|
|
|
@ -77,7 +77,13 @@ func (s *Server) handleDirectTCP(srv *ssh.Server, conn *gossh.ServerConn, newCha
|
|||
defer dconn.Close()
|
||||
defer ch.Close()
|
||||
|
||||
if _, err := io.Copy(ch, dconn); err != nil {
|
||||
reader := &instrumentedReader{
|
||||
internal: dconn,
|
||||
stats: s.opts.Stats,
|
||||
name: StatTotalRxBytes,
|
||||
}
|
||||
|
||||
if _, err := io.Copy(ch, reader); err != nil {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return
|
||||
}
|
||||
|
@ -90,7 +96,13 @@ func (s *Server) handleDirectTCP(srv *ssh.Server, conn *gossh.ServerConn, newCha
|
|||
defer dconn.Close()
|
||||
defer ch.Close()
|
||||
|
||||
if _, err := io.Copy(dconn, ch); err != nil {
|
||||
writer := &instrumentedWriter{
|
||||
internal: dconn,
|
||||
stats: s.opts.Stats,
|
||||
name: StatTotalTxBytes,
|
||||
}
|
||||
|
||||
if _, err := io.Copy(writer, ch); err != nil {
|
||||
s.log("[ERROR] %+v", errors.WithStack(err))
|
||||
}
|
||||
}()
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
package ssh
|
||||
|
||||
import "log"
|
||||
import (
|
||||
"log"
|
||||
|
||||
"forge.cadoles.com/wpetit/rebound/stat"
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
Logger func(message string, args ...any)
|
||||
|
@ -8,6 +12,7 @@ type Options struct {
|
|||
PublicPort uint `env:"PUBLIC_PORT"`
|
||||
PublicHost string `env:"PUBLIC_HOST"`
|
||||
HostKey string `env:"HOST_KEY"`
|
||||
Stats *stat.Store
|
||||
}
|
||||
|
||||
type OptionFunc func(*Options)
|
||||
|
@ -19,6 +24,7 @@ func DefaultOptions() *Options {
|
|||
PublicPort: 2222,
|
||||
PublicHost: "127.0.0.1",
|
||||
HostKey: "./host.key",
|
||||
Stats: stat.NewStore(),
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -51,3 +57,9 @@ func WithLogger(logger func(message string, args ...any)) func(*Options) {
|
|||
opts.Logger = logger
|
||||
}
|
||||
}
|
||||
|
||||
func WithStats(stats *stat.Store) func(*Options) {
|
||||
return func(opts *Options) {
|
||||
opts.Stats = stats
|
||||
}
|
||||
}
|
||||
|
|
|
@ -91,6 +91,8 @@ func (s *Server) handleRequest(ctx ssh.Context, srv *ssh.Server, req *gossh.Requ
|
|||
return false, []byte{}
|
||||
}
|
||||
|
||||
s.opts.Stats.Add(StatTotalOpenedTunnels, 1, 0)
|
||||
|
||||
destPort := 1
|
||||
|
||||
s.requestHandlerLock.Lock()
|
||||
|
|
|
@ -0,0 +1,43 @@
|
|||
package ssh
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"forge.cadoles.com/wpetit/rebound/stat"
|
||||
)
|
||||
|
||||
const (
|
||||
StatTotalOpenedTunnels = "total_opened_tunnels"
|
||||
StatTotalTxBytes = "total_tx_bytes"
|
||||
StatTotalRxBytes = "total_rx_bytes"
|
||||
)
|
||||
|
||||
type instrumentedWriter struct {
|
||||
name string
|
||||
stats *stat.Store
|
||||
internal io.Writer
|
||||
}
|
||||
|
||||
// Write implements io.Writer.
|
||||
func (w *instrumentedWriter) Write(p []byte) (n int, err error) {
|
||||
n, err = w.internal.Write(p)
|
||||
w.stats.Add(w.name, float64(n), 0)
|
||||
return n, err
|
||||
}
|
||||
|
||||
var _ io.Writer = &instrumentedWriter{}
|
||||
|
||||
type instrumentedReader struct {
|
||||
name string
|
||||
stats *stat.Store
|
||||
internal io.Reader
|
||||
}
|
||||
|
||||
// Read implements io.Reader.
|
||||
func (w *instrumentedReader) Read(p []byte) (n int, err error) {
|
||||
n, err = w.internal.Read(p)
|
||||
w.stats.Add(w.name, float64(n), 0)
|
||||
return n, err
|
||||
}
|
||||
|
||||
var _ io.Reader = &instrumentedReader{}
|
|
@ -0,0 +1,154 @@
|
|||
package stat
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sync"
|
||||
|
||||
"log/slog"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type Store struct {
|
||||
data sync.Map
|
||||
loadSaveLock sync.Mutex
|
||||
}
|
||||
|
||||
func (s *Store) Load(path string) error {
|
||||
s.loadSaveLock.Lock()
|
||||
defer s.loadSaveLock.Unlock()
|
||||
|
||||
file, err := os.OpenFile(path, os.O_RDONLY, os.ModePerm)
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
decoder := json.NewDecoder(file)
|
||||
data := map[string]any{}
|
||||
|
||||
if err := decoder.Decode(&data); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
s.data.Range(func(key, value any) bool {
|
||||
s.data.Delete(key)
|
||||
return true
|
||||
})
|
||||
|
||||
for k, v := range data {
|
||||
s.data.Store(k, v)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) Save(path string) error {
|
||||
s.loadSaveLock.Lock()
|
||||
defer s.loadSaveLock.Unlock()
|
||||
|
||||
data, err := s.Snapshot()
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
dir := filepath.Dir(path)
|
||||
filename := filepath.Base(path)
|
||||
|
||||
temp, err := os.CreateTemp(dir, filename+".new*")
|
||||
if err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := os.Remove(temp.Name()); err != nil && !errors.Is(err, os.ErrNotExist) {
|
||||
slog.Error("could not remove temporary file",
|
||||
slog.String("file", temp.Name()),
|
||||
slog.Any("error", errors.WithStack(err)),
|
||||
)
|
||||
}
|
||||
}()
|
||||
|
||||
encoder := json.NewEncoder(temp)
|
||||
if err := encoder.Encode(data); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
if err := os.Rename(temp.Name(), path); err != nil {
|
||||
return errors.WithStack(err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Store) Snapshot() (map[string]float64, error) {
|
||||
data := map[string]float64{}
|
||||
|
||||
var err error
|
||||
s.data.Range(func(rawKey, rawValue any) bool {
|
||||
key, ok := rawKey.(string)
|
||||
if !ok {
|
||||
err = errors.Errorf("unexpected stat key of '%v'", rawKey)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
value, ok := rawValue.(float64)
|
||||
if !ok {
|
||||
err = errors.Errorf("unexpected stat value of '%v'", rawValue)
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
data[key] = value
|
||||
|
||||
return true
|
||||
})
|
||||
if err != nil {
|
||||
return nil, errors.WithStack(err)
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
func (s *Store) Add(name string, added float64, defaultValue float64) float64 {
|
||||
for {
|
||||
value := s.Get(name, defaultValue)
|
||||
if value == defaultValue {
|
||||
s.data.Store(name, defaultValue)
|
||||
}
|
||||
|
||||
sum := value + added
|
||||
if s.data.CompareAndSwap(name, value, value+added) {
|
||||
return sum
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Store) Set(name string, value float64) float64 {
|
||||
s.data.Store(name, value)
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
func (s *Store) Get(name string, defaultValue float64) float64 {
|
||||
rawValue, ok := s.data.Load(name)
|
||||
if !ok {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
value, ok := rawValue.(float64)
|
||||
if !ok {
|
||||
return defaultValue
|
||||
}
|
||||
|
||||
return value
|
||||
}
|
||||
|
||||
func NewStore() *Store {
|
||||
return &Store{
|
||||
data: sync.Map{},
|
||||
loadSaveLock: sync.Mutex{},
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue