edge/pkg/storage/driver/cache/lfu/cache.go

263 lines
5.8 KiB
Go

package lfu
import (
"log"
"sync/atomic"
"github.com/pkg/errors"
)
// ErrNotFound is returned (wrapped with a stack trace) when a key is not
// tracked by the cache.
var ErrNotFound = errors.New("not found")

// ErrSizeExceedCapacity is returned (wrapped with a stack trace) by Set when a
// single value is larger than the cache's whole capacity.
var ErrSizeExceedCapacity = errors.New("size exceed capacity")
// Cache is a size-bounded least-frequently-used cache. Values live in a
// backing Store; the Cache itself only keeps per-key bookkeeping and the
// frequency buckets used to pick eviction victims.
type Cache[K comparable, V any] struct {
	// index maps every tracked key to its bookkeeping record.
	index *Map[K, *CacheItem[K, V]]
	// freqs holds the frequency buckets; eviction starts from its first element.
	freqs *List[*FrequencyItem[K, V]]
	// size is the sum of the sizes of all tracked values.
	size atomic.Int32
	// capacity is the largest single value size Set accepts and the bound
	// that eviction works against.
	capacity int
	// store holds the actual values.
	store Store[K, V]
	// getValueSize computes the size charged for a value.
	getValueSize GetValueSizeFunc[V]
}
// CacheItem is the bookkeeping record for a single cached key.
type CacheItem[K any, V any] struct {
	// key is the cache key this record tracks.
	key K
	// size is the size charged for the key's value.
	size atomic.Int32
	// frequencyParent points at the frequency bucket currently holding this
	// item; it is nil until the item is first counted.
	frequencyParent atomic.Pointer[Element[*FrequencyItem[K, V]]]
}
// NewCacheItem returns a record for key charged with size units. The record
// does not belong to any frequency bucket yet.
func NewCacheItem[K any, V any](key K, size int32) *CacheItem[K, V] {
	rec := new(CacheItem[K, V])
	rec.key = key
	rec.size.Store(size)
	return rec
}
// FrequencyItem is one frequency bucket: the set of items whose access count
// (as tracked by the cache's increment) equals freq.
type FrequencyItem[K any, V any] struct {
	// entries is used as a set of items; the struct{} values carry no data.
	entries atomic.Pointer[Map[*CacheItem[K, V], struct{}]]
	// freq is the access count shared by every item in entries.
	freq atomic.Int32
}
// NewFrequencyItem returns a bucket with an empty entry set and freq left at
// zero; the caller is expected to set freq.
func NewFrequencyItem[K any, V any]() *FrequencyItem[K, V] {
	bucket := new(FrequencyItem[K, V])
	members := NewMap[*CacheItem[K, V], struct{}]()
	bucket.entries.Store(members)
	return bucket
}
// Set stores value under key and counts the write as one access.
//
// The value's size comes from the configured getValueSize function. A value
// larger than the whole capacity is rejected with ErrSizeExceedCapacity.
// Otherwise least-frequently-used entries are evicted until the value fits.
//
// Overwriting an existing key updates its accounted size. If the new size
// still fits, the entry keeps its frequency and is bumped once. If it does
// not fit, the entry is re-inserted with a fresh frequency, so that eviction
// can never select the key currently being written.
//
// If the store write fails, the cache bookkeeping is left untouched. If
// eviction fails, the value stays stored and tracked (the cache may then
// temporarily exceed its capacity) and the eviction error is returned.
func (c *Cache[K, V]) Set(key K, value V) error {
	newItemSize, err := c.getValueSize(value)
	if err != nil {
		return errors.WithStack(err)
	}
	log.Printf("setting '%v' (size: %d)", key, newItemSize)
	if newItemSize > c.capacity {
		return errors.WithStack(ErrSizeExceedCapacity)
	}
	if err := c.store.Set(key, value); err != nil {
		return errors.WithStack(err)
	}

	if item, ok := c.index.Get(key); ok {
		delta := newItemSize - int(item.size.Load())
		if !c.atCapacity(delta) {
			// The new size fits in place: fix the accounting and bump frequency.
			item.size.Store(int32(newItemSize))
			c.size.Add(int32(delta))
			c.increment(item)
			return nil
		}
		// The grown value needs room. Detach the old entry first so its old
		// size is not counted and eviction cannot pick this very key.
		c.size.Add(-item.size.Load())
		c.remove(item.frequencyParent.Load(), item)
		c.index.Delete(key)
	}

	item := NewCacheItem[K, V](key, int32(newItemSize))
	evictErr := c.makeRoom(newItemSize)
	// Track the item even if eviction failed, so index, size and store agree.
	c.index.Set(key, item)
	c.size.Add(int32(newItemSize))
	c.increment(item)
	return evictErr
}

// makeRoom evicts entries until size more units fit. It stops early when an
// eviction pass frees nothing (e.g. the cache is empty), so it cannot loop
// forever.
func (c *Cache[K, V]) makeRoom(size int) error {
	for c.atCapacity(size) {
		before := c.Size()
		if err := c.evict(size); err != nil {
			return errors.WithStack(err)
		}
		if c.Size() >= before {
			return nil
		}
	}
	return nil
}
// Get returns the value stored under key and counts the lookup as one access.
// A key the cache does not track yields ErrNotFound (wrapped). A failing store
// read yields the store's error (wrapped), after the access was already counted.
func (c *Cache[K, V]) Get(key K) (V, error) {
	log.Printf("getting '%v'", key)
	var zero V
	item, found := c.index.Get(key)
	if !found {
		return zero, errors.WithStack(ErrNotFound)
	}
	c.increment(item)
	value, err := c.store.Get(key)
	if err != nil {
		return zero, errors.WithStack(err)
	}
	return value, nil
}
// Delete removes key from the store and drops all cache bookkeeping for it.
// It returns ErrNotFound (wrapped) for untracked keys. If the store delete
// fails, the error is returned (wrapped) and the bookkeeping is left intact.
func (c *Cache[K, V]) Delete(key K) error {
	log.Printf("deleting '%v'", key)
	item, found := c.index.Get(key)
	if !found {
		return errors.WithStack(ErrNotFound)
	}
	if err := c.store.Delete(key); err != nil {
		return errors.WithStack(err)
	}
	// Release the charged size, then unlink from the frequency bucket and index.
	c.size.Add(-item.size.Load())
	c.remove(item.frequencyParent.Load(), item)
	c.index.Delete(key)
	return nil
}
// Len reports how many keys the cache currently tracks.
func (c *Cache[K, V]) Len() int {
	count := c.index.Len()
	return count
}
// Size reports the summed size of all tracked values, in the units produced by
// the configured getValueSize function.
func (c *Cache[K, V]) Size() int {
	total := c.size.Load()
	return int(total)
}
// Capacity reports the configured size bound of the cache.
func (c *Cache[K, V]) Capacity() int {
	limit := c.capacity
	return limit
}
// increment records one access to item by moving it from its current
// frequency bucket into the bucket for the next frequency.
//
// An item with no bucket yet goes into the frequency-1 bucket, which is kept
// at the front of c.freqs. Otherwise the target bucket is the one directly
// after the current one. If the neighbouring bucket does not hold exactly the
// next frequency, a new bucket is created in that position. The item then
// leaves its previous bucket, and remove drops that bucket once it is empty.
func (c *Cache[K, V]) increment(item *CacheItem[K, V]) {
	currentFrequencyElement := item.frequencyParent.Load()
	var nextFrequencyAmount int
	var nextFrequencyElement *Element[*FrequencyItem[K, V]]
	if currentFrequencyElement == nil {
		nextFrequencyAmount = 1
		nextFrequencyElement = c.freqs.First()
	} else {
		nextFrequencyAmount = int(currentFrequencyElement.Value().freq.Load()) + 1
		nextFrequencyElement = currentFrequencyElement.Next()
	}

	var nextFrequency *FrequencyItem[K, V]
	if nextFrequencyElement != nil {
		nextFrequency = nextFrequencyElement.Value()
	}
	if nextFrequency == nil || int(nextFrequency.freq.Load()) != nextFrequencyAmount {
		newFrequencyItem := NewFrequencyItem[K, V]()
		newFrequencyItem.freq.Store(int32(nextFrequencyAmount))
		if currentFrequencyElement == nil {
			nextFrequencyElement = c.freqs.PushFront(newFrequencyItem)
		} else {
			nextFrequencyElement = c.freqs.InsertValueAfter(newFrequencyItem, currentFrequencyElement)
		}
	}

	item.frequencyParent.Store(nextFrequencyElement)
	nextFrequencyElement.Value().entries.Load().Set(item, struct{}{})
	// Leave the PREVIOUS bucket. The original code re-loaded frequencyParent
	// here, which removed the item from the bucket it had just joined.
	if currentFrequencyElement != nil {
		c.remove(currentFrequencyElement, item)
	}
}
// remove deletes item from the bucket held by listItem. If that leaves the
// bucket empty, the bucket itself is unlinked from c.freqs. It reports false
// when there is no bucket (nil listItem or nil entry set), and true otherwise.
func (c *Cache[K, V]) remove(listItem *Element[*FrequencyItem[K, V]], item *CacheItem[K, V]) bool {
	if listItem == nil {
		return false
	}
	bucket := listItem.Value().entries.Load()
	if bucket == nil {
		return false
	}
	bucket.Delete(item)
	// An empty bucket is no longer useful for ordering; drop it from the list.
	if bucket.Len() == 0 {
		c.freqs.Remove(listItem)
	}
	return true
}
// atCapacity reports whether adding delta units to the current size would
// push the cache past its capacity.
//
// Filling the cache exactly to capacity is allowed. This matches Set, which
// only rejects values strictly larger than the capacity. The previous `>=`
// test made a value of size == capacity spin forever in Set's eviction loop.
func (c *Cache[K, V]) atCapacity(delta int) bool {
	size, capacity := c.Size(), c.Capacity()
	log.Printf("at capacity: %d/%d", size, capacity)
	return size+delta > capacity
}
// evict removes least-frequently-used entries until at least total units of
// size have been freed, or until nothing more can be evicted.
//
// Buckets are consumed from the front of c.freqs. Within a bucket, entries are
// visited in whatever order entries.Range yields them. For each evicted entry:
//   - the entry is deleted from the store and the index;
//   - its size is released;
//   - it is unlinked from its bucket.
//
// The first store error stops eviction and is returned.
func (c *Cache[K, V]) evict(total int) error {
	log.Printf("evicting for %d", total)
	for evicted := 0; evicted < total; {
		frequencyElement := c.freqs.First()
		if frequencyElement == nil {
			return nil
		}
		frequencyItem := frequencyElement.Value()
		if frequencyItem == nil {
			return nil
		}
		entries := frequencyItem.entries.Load()
		if entries == nil {
			return nil
		}
		progressed := false
		var rangeErr error
		// NOTE(review): entries are deleted while ranging; this assumes Map.Range
		// tolerates concurrent deletion (as sync.Map does) — confirm.
		entries.Range(func(key, _ any) bool {
			if evicted >= total {
				return false
			}
			entry, ok := key.(*CacheItem[K, V])
			if !ok || entry == nil {
				return true
			}
			entrySize := entry.size.Load()
			log.Printf("evicting key '%v' (size: %d)", entry.key, entrySize)
			if err := c.store.Delete(entry.key); err != nil {
				rangeErr = errors.WithStack(err)
				return false
			}
			c.index.Delete(entry.key)
			c.size.Add(-entrySize)
			c.remove(frequencyElement, entry)
			evicted += int(entrySize)
			progressed = true
			return true
		})
		if rangeErr != nil {
			return rangeErr
		}
		if !progressed {
			// Nothing was evicted from the lowest bucket. If it is just an empty
			// bucket still linked in the list, drop it and move on. Otherwise
			// stop, instead of retrying the same bucket forever.
			if entries.Len() == 0 {
				c.freqs.Remove(frequencyElement)
				continue
			}
			return nil
		}
	}
	return nil
}
// New returns an empty Cache backed by store. Capacity and the value-size
// function come from DefaultOptions applied to funcs.
func New[K comparable, V any](store Store[K, V], funcs ...OptionsFunc[K, V]) *Cache[K, V] {
	options := DefaultOptions[K, V](funcs...)
	return &Cache[K, V]{
		store:        store,
		capacity:     options.Capacity,
		getValueSize: options.GetValueSize,
		index:        NewMap[K, *CacheItem[K, V]](),
		freqs:        NewList[*FrequencyItem[K, V]](),
	}
}