rebound/stat/store.go

155 lines
2.7 KiB
Go
Raw Normal View History

2023-09-24 20:21:44 +02:00
package stat
import (
"encoding/json"
"os"
"path/filepath"
"sync"
"log/slog"
"github.com/pkg/errors"
)
// Store is a collection of named float64 statistics, backed by a sync.Map,
// that can be persisted to and restored from a JSON file.
//
// The zero value is ready to use (NewStore returns exactly that). A Store
// contains a sync.Map and a sync.Mutex, so it must not be copied after first
// use; pass *Store around instead.
type Store struct {
	// data holds stat name -> value. Keys are strings; Get and Snapshot
	// expect every value to be a float64.
	data sync.Map

	// loadSaveLock serializes Load and Save against each other. Get, Set and
	// Add do not acquire it.
	loadSaveLock sync.Mutex
}
// Load replaces the store's contents with the stats read from the JSON file
// at path. The file must hold a single JSON object mapping stat names to
// numbers. If the file cannot be opened or decoded (including any non-numeric
// value), an error is returned and the store is left unchanged.
//
// Load holds loadSaveLock, so it never interleaves with Save. Get, Set and Add
// do not take that lock, so concurrent callers may observe the store while it
// is being cleared and refilled.
func (s *Store) Load(path string) error {
	s.loadSaveLock.Lock()
	defer s.loadSaveLock.Unlock()

	file, err := os.Open(path)
	if err != nil {
		return errors.WithStack(err)
	}
	// The file is only read, so an error from Close cannot lose data.
	defer file.Close()

	// Decoding straight into float64 rejects non-numeric entries here, rather
	// than storing values that Get ignores and that make Snapshot/Save fail.
	data := map[string]float64{}
	if err := json.NewDecoder(file).Decode(&data); err != nil {
		return errors.WithStack(err)
	}

	// Drop every existing stat (including ones absent from the file) before
	// installing the loaded values.
	s.data.Range(func(key, _ any) bool {
		s.data.Delete(key)
		return true
	})
	for name, value := range data {
		s.data.Store(name, value)
	}
	return nil
}
// Save writes the current stats to path as a JSON object.
//
// The data is written to a temporary file in the same directory as path. It
// is flushed to stable storage and closed, then renamed over path, so path
// never holds a partially written file. The temporary file is always cleaned
// up afterwards; a failure to remove it is logged rather than returned.
func (s *Store) Save(path string) error {
	s.loadSaveLock.Lock()
	defer s.loadSaveLock.Unlock()

	data, err := s.Snapshot()
	if err != nil {
		return errors.WithStack(err)
	}

	// Create the temp file next to the target so the rename stays on one filesystem.
	temp, err := os.CreateTemp(filepath.Dir(path), filepath.Base(path)+".new*")
	if err != nil {
		return errors.WithStack(err)
	}
	tempName := temp.Name()
	defer func() {
		// On the success path temp is already closed; the second Close's
		// os.ErrClosed is irrelevant. Closing first lets Remove succeed on
		// platforms that refuse to delete open files.
		_ = temp.Close()
		// After a successful rename the temp file is gone, so ErrNotExist is expected.
		if err := os.Remove(tempName); err != nil && !errors.Is(err, os.ErrNotExist) {
			slog.Error("could not remove temporary file",
				slog.String("file", tempName),
				slog.Any("error", errors.WithStack(err)),
			)
		}
	}()

	if err := json.NewEncoder(temp).Encode(data); err != nil {
		return errors.WithStack(err)
	}
	// Make the new contents durable before they replace the old file.
	if err := temp.Sync(); err != nil {
		return errors.WithStack(err)
	}
	// Close before renaming. Close can report deferred write errors, and
	// Windows refuses to rename a file that is still open.
	if err := temp.Close(); err != nil {
		return errors.WithStack(err)
	}
	if err := os.Rename(tempName, path); err != nil {
		return errors.WithStack(err)
	}
	return nil
}
// Snapshot copies every stat into a fresh map[string]float64.
//
// It returns an error (and a nil map) as soon as it meets a key that is not a
// string or a value that is not a float64. Because it iterates with
// sync.Map.Range, entries written concurrently may or may not be included.
func (s *Store) Snapshot() (map[string]float64, error) {
	var (
		out      = map[string]float64{}
		rangeErr error
	)

	collect := func(k, v any) bool {
		name, isString := k.(string)
		if !isString {
			rangeErr = errors.Errorf("unexpected stat key of '%v'", k)
			return false
		}
		num, isFloat := v.(float64)
		if !isFloat {
			rangeErr = errors.Errorf("unexpected stat value of '%v'", v)
			return false
		}
		out[name] = num
		return true
	}
	s.data.Range(collect)

	if rangeErr != nil {
		return nil, errors.WithStack(rangeErr)
	}
	return out, nil
}
// Add adds added to the stat name and returns the resulting value.
//
// A missing stat starts from defaultValue. A stat holding something other
// than a float64 (which Get reports as defaultValue) is first reset to
// defaultValue; that reset is a plain Store and is not protected against
// concurrent writers.
//
// For float64 stats the update is a compare-and-swap loop, so concurrent Add
// calls on the same name never lose increments. A NaN stat is returned as-is
// without a swap: NaN+x is NaN anyway, and NaN never compares equal to itself,
// so a compare-and-swap against it could never succeed.
func (s *Store) Add(name string, added float64, defaultValue float64) float64 {
	for {
		// Seeds a missing stat with defaultValue atomically, without clobbering
		// a value another goroutine may have just written.
		raw, _ := s.data.LoadOrStore(name, defaultValue)

		current, ok := raw.(float64)
		if !ok {
			s.data.Store(name, defaultValue)
			continue
		}

		// NaN is the only float64 that is not equal to itself.
		if current != current {
			return current
		}

		sum := current + added
		if s.data.CompareAndSwap(name, current, sum) {
			return sum
		}
		// Another writer changed the value between load and swap; retry.
	}
}
// Set stores value under name, replacing whatever was there, and hands the
// same value back so it can be used inline by the caller.
func (s *Store) Set(name string, value float64) float64 {
	stored := value
	s.data.Store(name, stored)
	return stored
}
// Get returns the float64 stored under name. It falls back to defaultValue
// when the name is absent or the stored value is not a float64.
func (s *Store) Get(name string, defaultValue float64) float64 {
	if raw, found := s.data.Load(name); found {
		if v, isFloat := raw.(float64); isFloat {
			return v
		}
	}
	return defaultValue
}
// NewStore returns a pointer to an empty, ready-to-use Store. It is identical
// to taking the address of a zero-value Store.
func NewStore() *Store {
	return new(Store)
}