File size: 2,502 Bytes
82ea528 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 |
import itertools
from typing import Optional
class _LRUDict(dict):
    """Minimal size-bounded LRU mapping used when ``cachetools`` is missing.

    Relies on the insertion order of ``dict``: the first key is always the
    least recently used one. Reading via ``d[key]`` or writing via
    ``d[key] = value`` marks the key as most recently used; membership tests,
    ``pop``, ``del`` and ``items()`` do not change the order.
    """

    def __init__(self, maxsize):
        super().__init__()
        if maxsize < 1:
            # A non-positive size could never hold an entry.
            raise ValueError(f'maxsize must be positive, got {maxsize}')
        self._maxsize = maxsize

    def __getitem__(self, key):
        # Pop and re-insert to move the key to the "most recent" end.
        # dict.pop raises KeyError for unknown keys, as a lookup should.
        value = super().pop(key)
        super().__setitem__(key, value)
        return value

    def __setitem__(self, key, value):
        # Drop any previous entry first so an update also refreshes recency.
        super().pop(key, None)
        super().__setitem__(key, value)
        while len(self) > self._maxsize:
            # The first key in iteration order is the least recently used.
            super().__delitem__(next(iter(self)))


class TaggedCache:
    """Key/value cache whose entries are grouped into per-tag LRU buckets.

    Every stored value is a tuple whose first element is the tag; the tag
    selects the bucket the entry lives in. A key lives in at most one bucket
    at a time: storing it again (possibly under another tag) replaces the
    old entry.

    Bucket sizes come from ``tag_settings`` (``tag -> max entries``). Tags
    without a setting get 5 entries if the tag contains ``'ckpt'``, 100 for
    ``'latent'`` / ``'image'``, and 20 otherwise.
    """

    def __init__(self, tag_settings: Optional[dict] = None):
        # Keep the caller's dict object even when it is empty, so sizes added
        # to it later are still seen when a new bucket is created.
        # (``tag_settings or {}`` would silently replace an empty dict.)
        self._tag_settings = {} if tag_settings is None else tag_settings
        self._data = {}  # tag -> bucket (mapping of key -> stored value)

    def _bucket_of(self, key):
        """Return the bucket currently holding ``key``, or ``None``."""
        for bucket in self._data.values():
            if key in bucket:
                return bucket
        return None

    def _new_bucket(self, tag):
        """Create an empty size-bounded bucket for ``tag``."""
        default_size = 20
        if 'ckpt' in tag:
            default_size = 5
        elif tag in ['latent', 'image']:
            default_size = 100
        maxsize = self._tag_settings.get(tag, default_size)
        try:
            # Optional dependency, imported lazily as in the original code.
            from cachetools import LRUCache
        except ImportError:
            # Fall back to the built-in bounded mapping so the size limit is
            # still honoured without cachetools.
            return _LRUDict(maxsize)
        return LRUCache(maxsize=maxsize)

    def __getitem__(self, key):
        """Return the stored value for ``key``; raise ``KeyError`` if absent."""
        bucket = self._bucket_of(key)
        if bucket is None:
            raise KeyError(f'Key `{key}` does not exist')
        return bucket[key]

    def __setitem__(self, key, value: tuple):
        """Store ``value`` under ``key`` in the bucket named by ``value[0]``.

        ``value`` is expected to look like ``(tag: str, (islist: bool, data))``;
        only ``value[0]`` is inspected here.
        """
        # If the key already exists (in any bucket), drop the old entry.
        previous_bucket = self._bucket_of(key)
        if previous_bucket is not None:
            previous_bucket.pop(key, None)

        tag = value[0]
        if tag not in self._data:
            self._data[tag] = self._new_bucket(tag)
        self._data[tag][key] = value

    def __delitem__(self, key):
        """Remove ``key``; raise ``KeyError`` if it is not cached."""
        bucket = self._bucket_of(key)
        if bucket is None:
            raise KeyError(f'Key `{key}` does not exist')
        del bucket[key]

    def __contains__(self, key):
        """Return ``True`` if ``key`` is present in any bucket."""
        return self._bucket_of(key) is not None

    def items(self):
        """Yield ``(key, value)`` pairs from every bucket, bucket by bucket."""
        yield from itertools.chain.from_iterable(
            bucket.items() for bucket in self._data.values()
        )

    def get(self, key, default=None):
        """D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None."""
        bucket = self._bucket_of(key)
        if bucket is None:
            return default
        return bucket[key]

    def clear(self):
        """Drop every bucket and therefore every cached entry."""
        self._data = {}
# Module-level shared state.
# cache_settings is handed to TaggedCache as its tag_settings (per-tag cache
# sizes); it is empty here, so no per-tag size is configured at this point.
cache_settings = {}
# The shared cache instance used by update_cache() / remove_cache().
cache = TaggedCache(cache_settings)
# Per-key counter maintained by update_cache(): 0 on the first store of a
# key, incremented by one on every later store of the same key.
cache_count = {}
def update_cache(k, tag, v):
    """Store ``v`` under key ``k`` in the shared cache, filed under ``tag``.

    Also maintains ``cache_count[k]``: it becomes 0 the first time ``k`` is
    stored and grows by one on each subsequent store.
    """
    cache[k] = (tag, v)
    previous = cache_count.get(k)
    cache_count[k] = 0 if previous is None else previous + 1
def remove_cache(key):
    """Remove one entry, or every entry, from the shared module-level cache.

    Args:
        key: The cache key to drop, or the literal ``'*'`` to empty the
            whole cache.

    A key that is not cached is reported on stdout instead of raising.
    ``cache_count`` is left untouched.
    """
    if key == '*':
        # Empty the existing instance in place rather than rebinding the
        # global: modules that did ``from ... import cache`` hold a reference
        # to this object and would otherwise keep seeing the old entries.
        cache.clear()
    elif key in cache:
        del cache[key]
    else:
        print(f"invalid {key}")