// edge/pkg/storage/driver/cache/lfu/cache.go

package lfu
import (
"slices"
"sync/atomic"
"time"
"github.com/pkg/errors"
)
// Sentinel errors returned by the cache. Returned errors are wrapped with
// stack traces, so compare them with errors.Is rather than ==.
var (
// ErrNotFound is returned when a key is not indexed by the cache. Get also
// returns it for entries it found to be expired (after deleting them).
ErrNotFound = errors.New("not found")
// ErrSizeExceedCapacity is returned by Set when a single value is larger than
// the whole cache capacity.
ErrSizeExceedCapacity = errors.New("size exceed capacity")
// errExpired is internal: Get uses it to signal a TTL-expired entry, then
// deletes that entry and reports ErrNotFound to the caller instead.
errExpired = errors.New("expired")
)
// Cache is a size-bounded LFU (least-frequently-used) cache. Values are kept
// in a pluggable Store; the Cache itself tracks per-key metadata and ordered
// frequency buckets that are used to pick eviction victims.
type Cache[K comparable, V any] struct {
// index maps each key to its metadata.
index *Map[K, *cacheItem[K, V]]
// freqs is the list of frequency buckets, kept in ascending frequency order
// by increment; evict scans it from the front.
freqs *List[*frequencyItem[K, V]]
// size is the running sum of the sizes of all indexed items.
size atomic.Int32
// capacity is the size budget; eviction runs once size exceeds it.
capacity int
// store holds the actual values.
store Store[K, V]
// getValueSize computes the size charged against capacity for a value.
getValueSize GetValueSizeFunc[V]
// sync provides the per-key ReadTx/WriteTx transactions that Get, Set and
// Delete run in.
sync *Synchronizer[K]
// log receives printf-style debug messages.
log LogFunc
// ttl is the entry time-to-live; 0 disables expiry.
ttl time.Duration
}
// cacheItem is the metadata the cache keeps for one key.
type cacheItem[K any, V any] struct {
// key is the cache key this metadata belongs to.
key K
// size is the size charged against the cache capacity for this entry.
size int
// time is a Unix timestamp in seconds, set on creation and on every Refresh;
// Expired measures the TTL from it.
time atomic.Int64
// frequencyParent is the frequency bucket element that currently holds this item.
frequencyParent *Element[*frequencyItem[K, V]]
}
// Expired reports whether more than ttl has elapsed since the item's stored
// timestamp. A zero ttl means entries never expire. The timestamp has
// one-second resolution.
func (i *cacheItem[K, V]) Expired(ttl time.Duration) bool {
	if ttl == 0 {
		return false
	}
	deadline := time.Unix(i.time.Load(), 0).Add(ttl)
	return time.Now().After(deadline)
}
// Refresh stamps the item with the current time (Unix seconds), restarting
// its TTL window.
func (i *cacheItem[K, V]) Refresh() {
	now := time.Now().Unix()
	i.time.Store(now)
}
// newCacheItem builds metadata for key with the given size, timestamped with
// the current time. The item is not yet attached to any frequency bucket.
func newCacheItem[K any, V any](key K, size int) *cacheItem[K, V] {
	item := &cacheItem[K, V]{key: key, size: size}
	item.Refresh()
	return item
}
// frequencyItem is one frequency bucket: the items whose access count
// (incremented on each Set and each successful Get) equals freq.
type frequencyItem[K any, V any] struct {
// entries is used as a set of items; the struct{} value carries no data.
entries *Map[*cacheItem[K, V], struct{}]
// freq is the access count shared by every item in this bucket.
freq int
}
// newFrequencyItem returns an empty frequency bucket. Its freq is left at the
// zero value for the caller to set.
func newFrequencyItem[K any, V any]() *frequencyItem[K, V] {
	return &frequencyItem[K, V]{
		entries: NewMap[*cacheItem[K, V], struct{}](),
	}
}
// Set stores value under key and counts the write toward the key's access
// frequency. When the cache then exceeds its capacity, it evicts other entries
// (never key itself) until enough space has been freed.
//
// It returns ErrSizeExceedCapacity (wrapped) when the value alone is larger
// than the whole cache. It also propagates errors from the size function,
// from the store, and from eviction.
func (c *Cache[K, V]) Set(key K, value V) error {
	newItemSize, err := c.getValueSize(value)
	if err != nil {
		return errors.WithStack(err)
	}

	c.log("setting '%v' (size: %d)", key, newItemSize)

	if newItemSize > c.capacity {
		return errors.Wrapf(ErrSizeExceedCapacity, "item size '%d' exceed cache total capacity of '%v'", newItemSize, c.capacity)
	}

	err = c.sync.WriteTx(key, func() error {
		if err := c.store.Set(key, value); err != nil {
			return errors.WithStack(err)
		}

		var sizeDelta int
		item, ok := c.index.Get(key)
		if ok {
			// Overwrite: charge only the difference, and record the new size so
			// later overwrites/deletes subtract the right amount. Without this
			// the cache-wide size counter drifts away from the real total.
			sizeDelta = newItemSize - item.size
			item.size = newItemSize
			item.Refresh()
		} else {
			item = newCacheItem[K, V](key, newItemSize)
			c.index.Set(key, item)
			sizeDelta = newItemSize
		}

		c.size.Add(int32(sizeDelta))
		c.increment(item)

		return nil
	})
	if err != nil {
		return errors.WithStack(err)
	}

	// Eviction, if needed. The key just written is excluded from eviction.
	if err := c.Evict(key); err != nil {
		return errors.WithStack(err)
	}

	return nil
}
// Get returns the value stored under key and counts the read toward the key's
// access frequency. A successful Get does not refresh the entry's TTL.
//
// It returns ErrNotFound (wrapped) when the key is unknown. It also returns
// ErrNotFound when the entry has expired; in that case the entry is deleted
// first. Store errors are propagated.
func (c *Cache[K, V]) Get(key K) (V, error) {
	var (
		result V
		zero   V
	)

	readErr := c.sync.ReadTx(key, func(upgrade func(func())) error {
		c.log("getting '%v'", key)

		item, found := c.index.Get(key)
		switch {
		case !found:
			return errors.WithStack(ErrNotFound)
		case item.Expired(c.ttl):
			return errors.WithStack(errExpired)
		}

		stored, err := c.store.Get(key)
		if err != nil {
			return errors.WithStack(err)
		}

		// The frequency update mutates shared metadata, so it runs through
		// the transaction's upgrade hook.
		upgrade(func() { c.increment(item) })
		result = stored

		return nil
	})

	switch {
	case readErr == nil:
		return result, nil
	case !errors.Is(readErr, errExpired):
		return zero, errors.WithStack(readErr)
	}

	// Expired entry: purge it, and report it to the caller as missing.
	if err := c.Delete(key); err != nil {
		return zero, errors.WithStack(err)
	}
	return zero, errors.WithStack(ErrNotFound)
}
// Delete removes key from the store and from the cache metadata. The entry's
// size is released from the cache total.
//
// It returns ErrNotFound (wrapped) when the key is not indexed. If the store
// delete fails, the error is returned and the metadata is left untouched.
func (c *Cache[K, V]) Delete(key K) error {
	txErr := c.sync.WriteTx(key, func() error {
		c.log("deleting '%v'", key)

		entry, found := c.index.Get(key)
		if !found {
			return errors.WithStack(ErrNotFound)
		}
		if storeErr := c.store.Delete(key); storeErr != nil {
			return errors.WithStack(storeErr)
		}

		c.size.Add(-int32(entry.size))
		c.remove(entry.frequencyParent, entry)
		c.index.Delete(key)
		return nil
	})
	if txErr == nil {
		return nil
	}
	return errors.WithStack(txErr)
}
// Evict frees space when the cache size is above its capacity. It removes at
// least the overflow amount, least-frequently-used entries first, and never
// touches the keys in skipped. It does nothing when the cache is within
// capacity.
func (c *Cache[K, V]) Evict(skipped ...K) error {
	exceed, delta := c.atCapacity()
	if !exceed || delta <= 0 {
		return nil
	}
	if err := c.evict(delta, skipped...); err != nil {
		return errors.WithStack(err)
	}
	return nil
}
// Len reports how many keys are currently indexed by the cache.
func (c *Cache[K, V]) Len() int {
	count := c.index.Len()
	return count
}
// Size reports the total size currently charged against the cache capacity.
func (c *Cache[K, V]) Size() int {
	current := c.size.Load()
	return int(current)
}
// Capacity reports the configured size budget of the cache.
func (c *Cache[K, V]) Capacity() int {
	limit := c.capacity
	return limit
}
// increment moves item from its current frequency bucket into the bucket for
// the next frequency. That is current+1, or 1 for an item not yet in any
// bucket. If the neighbouring bucket does not already have that frequency, a
// new bucket is inserted. It goes right after the current bucket, or at the
// front of the list for a new item, so freqs stays in ascending order.
func (c *Cache[K, V]) increment(item *cacheItem[K, V]) {
currentFrequencyElement := item.frequencyParent
var nextFrequencyAmount int
var nextFrequencyElement *Element[*frequencyItem[K, V]]
if currentFrequencyElement == nil {
// Not in any bucket yet: the target is frequency 1, which (when present)
// is the first bucket of the ascending list.
nextFrequencyAmount = 1
nextFrequencyElement = c.freqs.First()
} else {
atomicFrequencyItem := c.freqs.Value(currentFrequencyElement)
nextFrequencyAmount = atomicFrequencyItem.freq + 1
nextFrequencyElement = c.freqs.Next(currentFrequencyElement)
}
var nextFrequency *frequencyItem[K, V]
if nextFrequencyElement != nil {
nextFrequency = c.freqs.Value(nextFrequencyElement)
}
// The neighbouring bucket can be reused only if it holds exactly the target
// frequency; otherwise create one in the right position.
if nextFrequencyElement == nil || nextFrequency == nil || nextFrequency.freq != nextFrequencyAmount {
newFrequencyItem := newFrequencyItem[K, V]()
newFrequencyItem.freq = nextFrequencyAmount
if currentFrequencyElement == nil {
nextFrequencyElement = c.freqs.PushFront(newFrequencyItem)
} else {
nextFrequencyElement = c.freqs.InsertValueAfter(newFrequencyItem, currentFrequencyElement)
}
}
item.frequencyParent = nextFrequencyElement
nextFrequency = c.freqs.Value(nextFrequencyElement)
nextFrequency.entries.Set(item, struct{}{})
// Drop the item from its previous bucket. The bucket itself stays in the list
// even if it becomes empty (nothing in this file removes buckets; evict skips
// empty ones).
if currentFrequencyElement != nil {
c.remove(currentFrequencyElement, item)
}
}
// remove detaches item from the frequency bucket stored in listItem. The
// bucket itself is left in the list.
func (c *Cache[K, V]) remove(listItem *Element[*frequencyItem[K, V]], item *cacheItem[K, V]) {
	c.freqs.Value(listItem).entries.Delete(item)
}
// atCapacity returns two values. The first reports whether the cache size has
// reached its capacity. The second is how far the size is above the capacity;
// it is zero or negative when the cache is not over budget.
func (c *Cache[K, V]) atCapacity() (bool, int) {
	used := c.Size()
	limit := c.Capacity()
	c.log("cache stats: %d/%d", used, limit)

	overflow := used - limit
	return used >= limit, overflow
}
// evict deletes entries, lowest frequency bucket first, until at least total
// units of size have been freed. Keys listed in skipped are never evicted.
// Within a bucket, entries are visited in whatever order entries.Range yields.
//
// Every bucket is visited at most once. If the request cannot be fully met
// (for example, only skipped keys remain), evict returns nil once the list is
// exhausted instead of spinning. Entries that are already gone (Delete reports
// ErrNotFound) are dropped from their bucket and do not count as freed space.
// Any other Delete error aborts eviction and is returned.
func (c *Cache[K, V]) evict(total int, skipped ...K) error {
	if total <= 0 {
		return nil
	}

	first := c.freqs.First()
	if first == nil {
		c.log("no frequency element")
		return nil
	}

	evicted := 0
	// Advance to the next bucket after each pass. Re-scanning the same bucket
	// would loop forever when its remaining entries are all skipped.
	for frequencyElement := first; frequencyElement != nil && evicted < total; frequencyElement = c.freqs.Next(frequencyElement) {
		c.log("running eviction: [to_evict:%d, evicted: %d]", total, evicted)
		c.log("first frequency element %p", frequencyElement)

		bucket := c.freqs.Value(frequencyElement)
		if bucket == nil {
			return nil
		}

		entries := bucket.entries
		if entries.Len() == 0 {
			c.log("no frequency entries")
			continue
		}

		var rangeErr error
		entries.Range(func(key, _ any) bool {
			if evicted >= total {
				c.log("evicted enough (%d >= %d), stopping", evicted, total)
				return false
			}

			entry, ok := key.(*cacheItem[K, V])
			if !ok || entry == nil {
				return true
			}

			if slices.Contains(skipped, entry.key) {
				c.log("skipping key '%v'", entry.key)
				return true
			}

			if err := c.Delete(entry.key); err != nil {
				if errors.Is(err, ErrNotFound) {
					c.log("key '%v' not found", entry.key)
					// Cleanup obsolete frequency entry.
					c.remove(frequencyElement, entry)
					return true
				}
				rangeErr = errors.WithStack(err)
				return false
			}

			c.log("evicted key '%v' (size: %d)", entry.key, entry.size)
			evicted += entry.size
			return true
		})
		if rangeErr != nil {
			return errors.WithStack(rangeErr)
		}
	}

	return nil
}
// NewCache creates an empty LFU cache backed by store. Capacity, TTL, the
// value-size function and the logger come from DefaultOptions(funcs...).
func NewCache[K comparable, V any](store Store[K, V], funcs ...OptionsFunc[K, V]) *Cache[K, V] {
	options := DefaultOptions[K, V](funcs...)

	return &Cache[K, V]{
		index:        NewMap[K, *cacheItem[K, V]](),
		freqs:        NewList[*frequencyItem[K, V]](),
		sync:         NewSynchronizer[K](),
		store:        store,
		getValueSize: options.GetValueSize,
		capacity:     options.Capacity,
		ttl:          options.TTL,
		log:          options.Log,
	}
}