edge/pkg/storage/driver/cache/lfu/cache.go

// Package lfu implements a least-frequently-used (LFU) cache. Values are
// persisted through a pluggable Store, while the cache tracks access
// frequency per key and evicts the least-frequently-used entries once the
// total stored size exceeds the configured capacity.
package lfu

import (
	"sync"
	"sync/atomic"

	"github.com/pkg/errors"
)

var (
	// ErrNotFound is returned when the requested key is not in the cache.
	ErrNotFound = errors.New("not found")
	// ErrSizeExceedCapacity is returned when a single value is larger than the cache capacity.
	ErrSizeExceedCapacity = errors.New("size exceeds capacity")
)
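
// Cache is a generic LFU cache keyed by K with values of type V. Keys are
// indexed in a map, and frequency buckets are kept in a list ordered by
// ascending access frequency so the least-frequently-used entries are
// evicted first.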
type Cache[K comparable, V any] struct {
	index        *Map[K, *CacheItem[K, V]]
	freqs        *List[*FrequencyItem[K, V]]
	size         atomic.Int32
	capacity     int
	store        Store[K, V]
	getValueSize GetValueSizeFunc[V]
	sync         *Synchronizer[K]
	evictMutex   sync.Mutex
	log          LogFunc
}
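
// CacheItem records a cached key, the size of its stored value, and the
// frequency-list element (bucket) the key currently belongs to.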
type CacheItem[K any, V any] struct {
	key             K
	size            int
	frequencyParent *Element[*FrequencyItem[K, V]]
}
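
// NewCacheItem returns a CacheItem for key with the given value size.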
func NewCacheItem[K any, V any](key K, size int) *CacheItem[K, V] {
	return &CacheItem[K, V]{
		key:  key,
		size: size,
	}
}
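
// FrequencyItem is a frequency bucket: the set of cache items that have been
// accessed exactly freq times.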
type FrequencyItem[K any, V any] struct {
	entries *Map[*CacheItem[K, V], struct{}]
	freq    int
}
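
// NewFrequencyItem returns an empty frequency bucket.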
func NewFrequencyItem[K any, V any]() *FrequencyItem[K, V] {
	return &FrequencyItem[K, V]{
		entries: NewMap[*CacheItem[K, V], struct{}](),
	}
}
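
// Set stores value under key, updates the cached size accounting, bumps the
// key's access frequency, and triggers eviction if the cache exceeds its
// capacity. A value larger than the cache capacity is rejected with
// ErrSizeExceedCapacity.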
func (c *Cache[K, V]) Set(key K, value V) error {
	newItemSize, err := c.getValueSize(value)
	if err != nil {
		return errors.WithStack(err)
	}

	c.log("setting '%v' (size: %d)", key, newItemSize)

	// A value larger than the whole cache can never fit.
	if newItemSize > c.capacity {
		return errors.WithStack(ErrSizeExceedCapacity)
	}

	var sizeDelta int

	err = c.sync.WriteTx(key, func() error {
		if err := c.store.Set(key, value); err != nil {
			return errors.WithStack(err)
		}

		item, ok := c.index.Get(key)
		if ok {
			// Existing key: only the size difference is accounted for.
			sizeDelta = newItemSize - item.size
		} else {
			item = NewCacheItem[K, V](key, newItemSize)
			c.index.Set(key, item)
			sizeDelta = newItemSize
		}

		c.size.Add(int32(sizeDelta))
		c.increment(item)

		return nil
	})
	if err != nil {
		return errors.WithStack(err)
	}

	// Eviction, if needed.
	if err := c.Evict(); err != nil {
		return errors.WithStack(err)
	}

	return nil
}
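
// Get returns the value stored under key and increments the key's access
// frequency. ErrNotFound is returned when the key is not cached.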
func (c *Cache[K, V]) Get(key K) (V, error) {
	var value V

	err := c.sync.WriteTx(key, func() error {
		c.log("getting '%v'", key)

		e, ok := c.index.Get(key)
		if !ok {
			return errors.WithStack(ErrNotFound)
		}

		v, err := c.store.Get(key)
		if err != nil {
			return errors.WithStack(err)
		}

		c.increment(e)
		value = v

		return nil
	})
	if err != nil {
		return *new(V), errors.WithStack(err)
	}

	return value, nil
}
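
// Delete removes key from the underlying store, the index, and its frequency
// bucket, and subtracts its size from the cache size. ErrNotFound is returned
// when the key is not cached.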
func (c *Cache[K, V]) Delete(key K) error {
	err := c.sync.WriteTx(key, func() error {
		c.log("deleting '%v'", key)

		item, exists := c.index.Get(key)
		if !exists {
			return errors.WithStack(ErrNotFound)
		}

		if err := c.store.Delete(key); err != nil {
			return errors.WithStack(err)
		}

		c.size.Add(-int32(item.size))
		c.remove(item.frequencyParent, item)
		c.index.Delete(key)

		return nil
	})
	if err != nil {
		return errors.WithStack(err)
	}

	return nil
}
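
// Len returns the number of keys currently indexed by the cache.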
func (c *Cache[K, V]) Len() int {
	return c.index.Len()
}
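
// Size returns the accounted total size of the cached values.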
func (c *Cache[K, V]) Size() int {
	return int(c.size.Load())
}
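
// Capacity returns the configured maximum total size of the cache.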
func (c *Cache[K, V]) Capacity() int {
	return c.capacity
}
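
// increment moves item to the frequency bucket for its new access count,
// creating the bucket if it does not exist yet.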
func (c *Cache[K, V]) increment(item *CacheItem[K, V]) {
	currentFrequencyElement := item.frequencyParent

	// Work out which frequency bucket the item should move to: either the very
	// first bucket (new item) or the bucket following its current one.
	var nextFrequencyAmount int
	var nextFrequencyElement *Element[*FrequencyItem[K, V]]

	if currentFrequencyElement == nil {
		nextFrequencyAmount = 1
		nextFrequencyElement = c.freqs.First()
	} else {
		currentFrequencyItem := c.freqs.Value(currentFrequencyElement)
		nextFrequencyAmount = currentFrequencyItem.freq + 1
		nextFrequencyElement = c.freqs.Next(currentFrequencyElement)
	}

	var nextFrequency *FrequencyItem[K, V]
	if nextFrequencyElement != nil {
		nextFrequency = c.freqs.Value(nextFrequencyElement)
	}

	// The target bucket does not exist yet: create it and insert it in order.
	if nextFrequencyElement == nil || nextFrequency == nil || nextFrequency.freq != nextFrequencyAmount {
		newFrequencyItem := NewFrequencyItem[K, V]()
		newFrequencyItem.freq = nextFrequencyAmount

		if currentFrequencyElement == nil {
			nextFrequencyElement = c.freqs.PushFront(newFrequencyItem)
		} else {
			nextFrequencyElement = c.freqs.InsertValueAfter(newFrequencyItem, currentFrequencyElement)
		}
	}

	// Move the item into the target bucket and out of its previous one.
	item.frequencyParent = nextFrequencyElement
	nextFrequency = c.freqs.Value(nextFrequencyElement)
	nextFrequency.entries.Set(item, struct{}{})

	if currentFrequencyElement != nil {
		c.remove(currentFrequencyElement, item)
	}
}
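
// remove deletes item from the given frequency bucket.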
func (c *Cache[K, V]) remove(listItem *Element[*FrequencyItem[K, V]], item *CacheItem[K, V]) {
	entries := c.freqs.Value(listItem).entries
	entries.Delete(item)
	// if entries.Len() == 0 {
	// 	c.freqs.Remove(listItem)
	// }
}
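
// atCapacity reports whether the cache has reached its capacity, along with
// the difference between the current size and the capacity.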
func (c *Cache[K, V]) atCapacity() (bool, int) {
	size, capacity := c.Size(), c.Capacity()
	c.log("cache stats: %d/%d", size, capacity)

	return size >= capacity, size - capacity
}
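
// evict walks the frequency buckets from least to most frequently used and
// deletes entries until at least total worth of size has been evicted.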
func (c *Cache[K, V]) evict(total int) error {
	if total == 0 {
		return nil
	}

	// Start with the least-frequently-used bucket.
	frequencyElement := c.freqs.First()
	if frequencyElement == nil {
		c.log("no frequency element")

		return nil
	}

	for evicted := 0; evicted < total; {
		c.log("running eviction: [to_evict:%d, evicted: %d]", total, evicted)
		c.log("first frequency element %p", frequencyElement)

		frequencyItem := c.freqs.Value(frequencyElement)
		if frequencyItem == nil {
			return nil
		}

		entries := frequencyItem.entries
		if entries.Len() == 0 {
			// Empty bucket: move on to the next (more frequently used) one.
			c.log("no frequency entries")

			frequencyElement = c.freqs.Next(frequencyElement)
			if frequencyElement == nil {
				// No buckets left to evict from.
				return nil
			}

			continue
		}

		var rangeErr error

		entries.Range(func(key, v any) bool {
			if evicted >= total {
				c.log("evicted enough (%d >= %d), stopping", evicted, total)

				return false
			}

			entry, _ := key.(*CacheItem[K, V])

			if err := c.Delete(entry.key); err != nil {
				if errors.Is(err, ErrNotFound) {
					c.log("key '%v' not found", entry.key)

					// Cleanup obsolete frequency entry.
					c.remove(frequencyElement, entry)

					return true
				}

				rangeErr = errors.WithStack(err)

				return false
			}

			c.log("evicted key '%v' (size: %d)", entry.key, entry.size)
			evicted += entry.size

			return true
		})
		if rangeErr != nil {
			return errors.WithStack(rangeErr)
		}
	}

	return nil
}
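
// Evict removes least-frequently-used entries whenever the cache size exceeds
// its capacity.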
func (c *Cache[K, V]) Evict() error {
	exceed, delta := c.atCapacity()
	if exceed && delta > 0 {
		if err := c.evict(delta); err != nil {
			return errors.WithStack(err)
		}
	}

	return nil
}
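
// New creates an LFU cache backed by store. Capacity, value sizing, and
// logging come from DefaultOptions and can be adjusted through the provided
// OptionsFunc values.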
func New[K comparable, V any](store Store[K, V], funcs ...OptionsFunc[K, V]) *Cache[K, V] {
	opts := DefaultOptions[K, V](funcs...)

	return &Cache[K, V]{
		index:        NewMap[K, *CacheItem[K, V]](),
		freqs:        NewList[*FrequencyItem[K, V]](),
		capacity:     opts.Capacity,
		store:        store,
		getValueSize: opts.GetValueSize,
		sync:         NewSynchronizer[K](),
		log:          opts.Log,
	}
}