// Package lfu implements a size-bounded cache that evicts its least
// frequently accessed entries first.
package lfu

import (
	"slices"
	"sync/atomic"
	"time"

	"github.com/pkg/errors"
)

var (
	// ErrNotFound is returned when a key is not cached or its entry has
	// outlived the cache TTL.
	ErrNotFound = errors.New("not found")
	// ErrSizeExceedCapacity is returned by Set when a single value is larger
	// than the whole cache capacity.
	ErrSizeExceedCapacity = errors.New("size exceed capacity")
	// errExpired is an internal signal from Get's read transaction; callers
	// of Get only ever see ErrNotFound for expired entries.
	errExpired = errors.New("expired")
)

// Cache is a generic LFU cache. Values are kept in a Store; the cache itself
// tracks, per key, the entry size, last write time and access frequency, and
// evicts the lowest-frequency entries once the summed entry sizes exceed the
// configured capacity.
type Cache[K comparable, V any] struct {
	// index maps each cached key to its bookkeeping item.
	index *Map[K, *cacheItem[K, V]]
	// freqs holds one bucket per access frequency, in ascending frequency
	// order; each bucket contains the items currently at that frequency.
	freqs *List[*frequencyItem[K, V]]
	// size is the sum of the sizes of all indexed items.
	size atomic.Int32
	// capacity is the maximum total size, in the units of getValueSize.
	capacity     int
	store        Store[K, V]
	getValueSize GetValueSizeFunc[V]
	// sync provides read/write transactions keyed by cache key.
	// NOTE(review): the freqs list is shared across keys but only guarded by
	// these per-key transactions; confirm List/Map are safe for concurrent use.
	sync *Synchronizer[K]
	log  LogFunc
	// ttl is how long an entry stays valid after its last Set; 0 disables
	// expiry.
	ttl time.Duration
}

// cacheItem is the per-key bookkeeping record kept in the index and in one
// frequency bucket.
type cacheItem[K any, V any] struct {
	key  K
	size int
	// time is the last write time in Unix nanoseconds. Nanosecond precision
	// keeps sub-second TTLs accurate; whole seconds made entries expire up to
	// a second early.
	time atomic.Int64
	// frequencyParent is the freqs node whose bucket contains this item.
	frequencyParent *Element[*frequencyItem[K, V]]
}

// Expired reports whether more than ttl has elapsed since the item was last
// written. A zero ttl means the item never expires.
func (i *cacheItem[K, V]) Expired(ttl time.Duration) bool {
	if ttl == 0 {
		return false
	}

	written := time.Unix(0, i.time.Load())

	return written.Add(ttl).Before(time.Now())
}

// Refresh records the current time as the item's last write time.
func (i *cacheItem[K, V]) Refresh() {
	i.time.Store(time.Now().UnixNano())
}

// newCacheItem returns an item for key with the given size, stamped with the
// current time and not yet attached to any frequency bucket.
func newCacheItem[K any, V any](key K, size int) *cacheItem[K, V] {
	item := &cacheItem[K, V]{
		key:  key,
		size: size,
	}
	item.Refresh()

	return item
}

// frequencyItem is a bucket holding every item accessed exactly freq times.
type frequencyItem[K any, V any] struct {
	// entries is used as a set; values are always struct{}{}.
	entries *Map[*cacheItem[K, V], struct{}]
	freq    int
}

// newFrequencyItem returns an empty bucket with freq left at zero.
func newFrequencyItem[K any, V any]() *frequencyItem[K, V] {
	bucket := &frequencyItem[K, V]{}
	bucket.entries = NewMap[*cacheItem[K, V], struct{}]()

	return bucket
}

// Set stores value under key, charges its size (as reported by getValueSize)
// against the cache capacity and counts the write as one access. Overwriting
// an existing key replaces its recorded size and restarts its TTL. Afterwards,
// other entries may be evicted to bring the total size back within capacity;
// the key just written is never evicted by this call.
//
// It returns a wrapped ErrSizeExceedCapacity if the value alone is larger than
// the capacity, and any error from getValueSize, the store or eviction.
func (c *Cache[K, V]) Set(key K, value V) error {
	newItemSize, err := c.getValueSize(value)
	if err != nil {
		return errors.WithStack(err)
	}

	c.log("setting '%v' (size: %d)", key, newItemSize)

	if newItemSize > c.capacity {
		return errors.Wrapf(ErrSizeExceedCapacity, "item size '%d' exceed cache total capacity of '%v'", newItemSize, c.capacity)
	}

	err = c.sync.WriteTx(key, func() error {
		if err := c.store.Set(key, value); err != nil {
			return errors.WithStack(err)
		}

		var sizeDelta int

		item, ok := c.index.Get(key)
		if ok {
			sizeDelta = newItemSize - item.size
			// Record the new size so a later Delete/eviction subtracts exactly
			// what is now charged, keeping c.size consistent.
			item.size = newItemSize
			item.Refresh()
		} else {
			item = newCacheItem[K, V](key, newItemSize)
			c.index.Set(key, item)
			sizeDelta = newItemSize
		}

		c.size.Add(int32(sizeDelta))
		c.increment(item)

		return nil
	})
	if err != nil {
		return errors.WithStack(err)
	}

	// Make room if needed, but never evict the entry we just wrote.
	if err := c.Evict(key); err != nil {
		return errors.WithStack(err)
	}

	return nil
}

// Get returns the value stored under key and counts the read as one access.
// It returns a wrapped ErrNotFound if the key is not cached or its entry has
// outlived the TTL; an expired entry is deleted as a side effect.
func (c *Cache[K, V]) Get(key K) (V, error) {
	var value V

	err := c.sync.ReadTx(key, func(upgrade func(func())) error {
		c.log("getting '%v'", key)

		e, ok := c.index.Get(key)
		if !ok {
			return errors.WithStack(ErrNotFound)
		}

		if e.Expired(c.ttl) {
			return errors.WithStack(errExpired)
		}

		v, err := c.store.Get(key)
		if err != nil {
			return errors.WithStack(err)
		}

		// Frequency bookkeeping mutates shared state, so it runs through the
		// upgrade callback provided by the read transaction.
		upgrade(func() {
			c.increment(e)
		})

		value = v

		return nil
	})
	if err != nil {
		if errors.Is(err, errExpired) {
			if err := c.Delete(key); err != nil {
				return *new(V), errors.WithStack(err)
			}

			return *new(V), errors.WithStack(ErrNotFound)
		}

		return *new(V), errors.WithStack(err)
	}

	return value, nil
}

// Delete removes key from the store and the cache bookkeeping and releases
// its size. It returns a wrapped ErrNotFound if the key is not cached.
func (c *Cache[K, V]) Delete(key K) error {
	err := c.sync.WriteTx(key, func() error {
		c.log("deleting '%v'", key)

		item, exists := c.index.Get(key)
		if !exists {
			return errors.WithStack(ErrNotFound)
		}

		if err := c.store.Delete(key); err != nil {
			return errors.WithStack(err)
		}

		c.size.Add(-int32(item.size))
		c.remove(item.frequencyParent, item)
		c.index.Delete(key)

		return nil
	})
	if err != nil {
		return errors.WithStack(err)
	}

	return nil
}

// Evict deletes lowest-frequency entries while the total size exceeds the
// capacity, leaving the keys in skipped untouched. It is a no-op when the
// cache is within capacity.
func (c *Cache[K, V]) Evict(skipped ...K) error {
	exceed, delta := c.atCapacity()
	if exceed && delta > 0 {
		if err := c.evict(delta, skipped...); err != nil {
			return errors.WithStack(err)
		}
	}

	return nil
}

// Len returns the number of cached keys.
func (c *Cache[K, V]) Len() int {
	return c.index.Len()
}

// Size returns the sum of the sizes of all cached entries.
func (c *Cache[K, V]) Size() int {
	return int(c.size.Load())
}

// Capacity returns the maximum total size the cache aims to hold.
func (c *Cache[K, V]) Capacity() int {
	return c.capacity
}

// increment moves item from its current frequency bucket to the bucket for
// the next frequency, creating that bucket right after the current one (or at
// the front of the list for an item that has no bucket yet) when it does not
// exist.
func (c *Cache[K, V]) increment(item *cacheItem[K, V]) {
	current := item.frequencyParent

	var (
		nextAmount int
		next       *Element[*frequencyItem[K, V]]
	)

	if current == nil {
		// A new item goes to frequency 1; because buckets are kept in
		// ascending order, that bucket (if it exists) is the list head.
		nextAmount = 1
		next = c.freqs.First()
	} else {
		nextAmount = c.freqs.Value(current).freq + 1
		next = c.freqs.Next(current)
	}

	var nextBucket *frequencyItem[K, V]
	if next != nil {
		nextBucket = c.freqs.Value(next)
	}

	if nextBucket == nil || nextBucket.freq != nextAmount {
		bucket := newFrequencyItem[K, V]()
		bucket.freq = nextAmount

		if current == nil {
			next = c.freqs.PushFront(bucket)
		} else {
			next = c.freqs.InsertValueAfter(bucket, current)
		}
	}

	item.frequencyParent = next
	c.freqs.Value(next).entries.Set(item, struct{}{})

	if current != nil {
		c.remove(current, item)
	}
}

// remove drops item from the bucket held by listItem. A nil listItem (an item
// never attached to a bucket) is ignored.
func (c *Cache[K, V]) remove(listItem *Element[*frequencyItem[K, V]], item *cacheItem[K, V]) {
	if listItem == nil {
		return
	}

	c.freqs.Value(listItem).entries.Delete(item)
}

// atCapacity reports whether the total size has reached the capacity, and by
// how much it exceeds it (zero or negative when within capacity).
func (c *Cache[K, V]) atCapacity() (bool, int) {
	size, capacity := c.Size(), c.Capacity()

	c.log("cache stats: %d/%d", size, capacity)

	return size >= capacity, size - capacity
}

// evict deletes entries, lowest access frequency first, until at least total
// size units have been freed or every frequency bucket has been visited once.
// Keys listed in skipped are left in place. Running out of candidates is not
// an error: the cache then simply stays over capacity.
func (c *Cache[K, V]) evict(total int, skipped ...K) error {
	if total <= 0 {
		return nil
	}

	element := c.freqs.First()
	if element == nil {
		c.log("no frequency element")
		return nil
	}

	evicted := 0

	// Visit each bucket once. Re-scanning the same bucket until enough was
	// evicted never terminated when its only remaining entries were skipped
	// keys (e.g. the key Set just wrote), and stepping past the last bucket
	// read a nil element.
	for ; element != nil && evicted < total; element = c.freqs.Next(element) {
		c.log("running eviction: [to_evict:%d, evicted: %d]", total, evicted)
		c.log("first frequency element %p", element)

		bucket := c.freqs.Value(element)
		if bucket == nil {
			return nil
		}

		if bucket.entries.Len() == 0 {
			c.log("no frequency entries")
			continue
		}

		bucketElement := element

		var rangeErr error

		bucket.entries.Range(func(key, _ any) bool {
			if evicted >= total {
				c.log("evicted enough (%d >= %d), stopping", evicted, total)
				return false
			}

			entry, _ := key.(*cacheItem[K, V])

			if slices.Contains(skipped, entry.key) {
				c.log("skipping key '%v'", entry.key)
				return true
			}

			if err := c.Delete(entry.key); err != nil {
				if errors.Is(err, ErrNotFound) {
					c.log("key '%v' not found", entry.key)
					// The key is already gone from the index; drop the
					// stale bucket membership as well.
					c.remove(bucketElement, entry)
					return true
				}

				rangeErr = errors.WithStack(err)
				return false
			}

			c.log("evicted key '%v' (size: %d)", entry.key, entry.size)

			evicted += entry.size

			return true
		})

		if rangeErr != nil {
			return errors.WithStack(rangeErr)
		}
	}

	if evicted < total {
		c.log("could not evict enough: [to_evict:%d, evicted: %d]", total, evicted)
	}

	return nil
}

// NewCache returns an empty cache whose values are kept in store, configured
// by the given option functions applied over DefaultOptions.
func NewCache[K comparable, V any](store Store[K, V], funcs ...OptionsFunc[K, V]) *Cache[K, V] {
	opts := DefaultOptions[K, V](funcs...)

	cache := &Cache[K, V]{
		index:        NewMap[K, *cacheItem[K, V]](),
		freqs:        NewList[*frequencyItem[K, V]](),
		capacity:     opts.Capacity,
		store:        store,
		getValueSize: opts.GetValueSize,
		sync:         NewSynchronizer[K](),
		log:          opts.Log,
		ttl:          opts.TTL,
	}

	return cache
}