package lfu import ( "sync" "sync/atomic" "github.com/pkg/errors" ) var ( ErrNotFound = errors.New("not found") ErrSizeExceedCapacity = errors.New("size exceed capacity") ) type Cache[K comparable, V any] struct { index *Map[K, *CacheItem[K, V]] freqs *List[*FrequencyItem[K, V]] size atomic.Int32 capacity int store Store[K, V] getValueSize GetValueSizeFunc[V] sync *Synchronizer[K] evictMutex sync.Mutex log LogFunc } type CacheItem[K any, V any] struct { key K size int frequencyParent *Element[*FrequencyItem[K, V]] } func NewCacheItem[K any, V any](key K, size int) *CacheItem[K, V] { item := &CacheItem[K, V]{ key: key, size: size, } return item } type FrequencyItem[K any, V any] struct { entries *Map[*CacheItem[K, V], struct{}] freq int } func NewFrequencyItem[K any, V any]() *FrequencyItem[K, V] { frequencyItem := &FrequencyItem[K, V]{} frequencyItem.entries = NewMap[*CacheItem[K, V], struct{}]() return frequencyItem } func (c *Cache[K, V]) Set(key K, value V) error { newItemSize, err := c.getValueSize(value) if err != nil { return errors.WithStack(err) } c.log("setting '%v' (size: %d)", key, newItemSize) if newItemSize > int(c.capacity) { return errors.WithStack(ErrSizeExceedCapacity) } var sizeDelta int err = c.sync.WriteTx(key, func() error { if err := c.store.Set(key, value); err != nil { return errors.WithStack(err) } item, ok := c.index.Get(key) if ok { oldItemSize := item.size sizeDelta = -int(oldItemSize) + newItemSize } else { item = NewCacheItem[K, V](key, newItemSize) c.index.Set(key, item) sizeDelta = newItemSize } c.size.Add(int32(sizeDelta)) c.increment(item) return nil }) if err != nil { return errors.WithStack(err) } // Eviction, if needed if err := c.Evict(); err != nil { return errors.WithStack(err) } return nil } func (c *Cache[K, V]) Get(key K) (V, error) { var value V err := c.sync.WriteTx(key, func() error { c.log("getting '%v'", key) e, ok := c.index.Get(key) if !ok { return errors.WithStack(ErrNotFound) } v, err := c.store.Get(key) if err != nil { return errors.WithStack(err) } c.increment(e) value = v return nil }) if err != nil { return *new(V), errors.WithStack(err) } return value, nil } func (c *Cache[K, V]) Delete(key K) error { err := c.sync.WriteTx(key, func() error { c.log("deleting '%v'", key) item, exists := c.index.Get(key) if !exists { return errors.WithStack(ErrNotFound) } if err := c.store.Delete(key); err != nil { return errors.WithStack(err) } c.size.Add(-int32(item.size)) c.remove(item.frequencyParent, item) c.index.Delete(key) return nil }) if err != nil { return errors.WithStack(err) } return nil } func (c *Cache[K, V]) Len() int { return c.index.Len() } func (c *Cache[K, V]) Size() int { return int(c.size.Load()) } func (c *Cache[K, V]) Capacity() int { return c.capacity } func (c *Cache[K, V]) increment(item *CacheItem[K, V]) { currentFrequencyElement := item.frequencyParent var nextFrequencyAmount int var nextFrequencyElement *Element[*FrequencyItem[K, V]] if currentFrequencyElement == nil { nextFrequencyAmount = 1 nextFrequencyElement = c.freqs.First() } else { atomicFrequencyItem := c.freqs.Value(currentFrequencyElement) nextFrequencyAmount = atomicFrequencyItem.freq + 1 nextFrequencyElement = c.freqs.Next(currentFrequencyElement) } var nextFrequency *FrequencyItem[K, V] if nextFrequencyElement != nil { nextFrequency = c.freqs.Value(nextFrequencyElement) } if nextFrequencyElement == nil || nextFrequency == nil || nextFrequency.freq != nextFrequencyAmount { newFrequencyItem := NewFrequencyItem[K, V]() newFrequencyItem.freq = nextFrequencyAmount if currentFrequencyElement == nil { nextFrequencyElement = c.freqs.PushFront(newFrequencyItem) } else { nextFrequencyElement = c.freqs.InsertValueAfter(newFrequencyItem, currentFrequencyElement) } } item.frequencyParent = nextFrequencyElement nextFrequency = c.freqs.Value(nextFrequencyElement) nextFrequency.entries.Set(item, struct{}{}) if currentFrequencyElement != nil { c.remove(currentFrequencyElement, item) } } func (c *Cache[K, V]) remove(listItem *Element[*FrequencyItem[K, V]], item *CacheItem[K, V]) { entries := c.freqs.Value(listItem).entries entries.Delete(item) // if entries.Len() == 0 { // c.freqs.Remove(listItem) // } } func (c *Cache[K, V]) atCapacity() (bool, int) { size, capacity := c.Size(), c.Capacity() c.log("cache stats: %d/%d", size, capacity) return size >= capacity, size - capacity } func (c *Cache[K, V]) evict(total int) error { if total == 0 { return nil } frequencyElement := c.freqs.First() if frequencyElement == nil { c.log("no frequency element") return nil } for evicted := 0; evicted < total; { c.log("running eviction: [to_evict:%d, evicted: %d]", total, evicted) c.log("first frequency element %p", frequencyElement) frequencyItem := c.freqs.Value(frequencyElement) if frequencyItem == nil { return nil } entries := frequencyItem.entries if entries.Len() == 0 { c.log("no frequency entries") frequencyElement = c.freqs.Next(frequencyElement) continue } var rangeErr error entries.Range(func(key, v any) bool { if evicted >= total { c.log("evicted enough (%d >= %d), stopping", evicted, total) return false } entry, _ := key.(*CacheItem[K, V]) if err := c.Delete(entry.key); err != nil { if errors.Is(err, ErrNotFound) { c.log("key '%s' not found", entry.key) // Cleanup obsolete frequency c.remove(frequencyElement, entry) return true } rangeErr = errors.WithStack(err) return false } c.log("evicted key '%v' (size: %d)", entry.key, entry.size) evicted += int(entry.size) return true }) if rangeErr != nil { return errors.WithStack(rangeErr) } } return nil } func (c *Cache[K, V]) Evict() error { exceed, delta := c.atCapacity() if exceed && delta > 0 { if err := c.evict(delta); err != nil { return errors.WithStack(err) } } return nil } func New[K comparable, V any](store Store[K, V], funcs ...OptionsFunc[K, V]) *Cache[K, V] { opts := DefaultOptions[K, V](funcs...) cache := &Cache[K, V]{ index: NewMap[K, *CacheItem[K, V]](), freqs: NewList[*FrequencyItem[K, V]](), capacity: opts.Capacity, store: store, getValueSize: opts.GetValueSize, sync: NewSynchronizer[K](), log: opts.Log, } return cache }