|
import itertools |
|
from typing import Optional |
|
|
|
class TaggedCache:
    """A key/value cache whose entries are grouped into per-tag sub-caches.

    Every stored value is a tuple whose first element is the tag naming the
    group it belongs to. A key lives in at most one tag group at a time. When
    ``cachetools`` is importable, each tag group is an ``LRUCache`` bounded by
    ``tag_settings[tag]`` (or a per-tag default); otherwise each group is an
    unbounded plain ``dict``.
    """

    def __init__(self, tag_settings: Optional[dict] = None):
        """Create an empty cache.

        Args:
            tag_settings: Optional mapping of tag -> maximum number of entries
                for that tag's LRU cache. The mapping is kept by reference (even
                when empty), so entries added to it later are honoured for tag
                groups created afterwards.
        """
        # ``is not None`` rather than ``or``: an empty dict owned by the caller
        # must be kept by reference, not silently replaced by a private one.
        self._tag_settings = tag_settings if tag_settings is not None else {}
        # tag -> per-tag mapping (LRUCache or dict) of key -> stored value tuple
        self._data = {}

    @staticmethod
    def _default_size(tag) -> int:
        """Return the LRU size used for *tag* when tag_settings has no entry."""
        if 'ckpt' in tag:
            return 5
        if tag in ['latent', 'image']:
            return 100
        return 20

    def _new_tag_store(self, tag):
        """Create storage for a new tag group: LRUCache if available, else dict."""
        try:
            from cachetools import LRUCache
        except ImportError:  # also covers ModuleNotFoundError (a subclass)
            # cachetools is optional; without it the group is unbounded.
            return {}
        return LRUCache(maxsize=self._tag_settings.get(tag, self._default_size(tag)))

    def _group_of(self, key):
        """Return the tag group currently holding *key*, or None if absent."""
        # Compare against None explicitly: an empty group is falsy.
        for tag_data in self._data.values():
            if key in tag_data:
                return tag_data
        return None

    def __getitem__(self, key):
        """Return the value stored for *key*; raise KeyError if it is absent."""
        group = self._group_of(key)
        if group is None:
            raise KeyError(f'Key `{key}` does not exist')
        return group[key]

    def __setitem__(self, key, value: tuple):
        """Store *value* under *key* in the group named by ``value[0]``.

        Any existing entry for *key* is removed first, even if it belongs to a
        different tag, so a key is never present in two groups.
        """
        previous = self._group_of(key)
        if previous is not None:
            previous.pop(key, None)

        tag = value[0]
        if tag not in self._data:
            self._data[tag] = self._new_tag_store(tag)
        self._data[tag][key] = value

    def __delitem__(self, key):
        """Remove *key* from whichever group holds it; raise KeyError if absent."""
        group = self._group_of(key)
        if group is None:
            raise KeyError(f'Key `{key}` does not exist')
        del group[key]

    def __contains__(self, key):
        """Return True if any tag group holds *key*."""
        return self._group_of(key) is not None

    def items(self):
        """Yield ``(key, value)`` pairs from every tag group."""
        # Snapshot the group list up front, as the eager unpacking used to.
        for tag_data in list(self._data.values()):
            yield from tag_data.items()

    def get(self, key, default=None):
        """D.get(k[,d]) -> D[k] if k in D, else d. d defaults to None."""
        group = self._group_of(key)
        return default if group is None else group[key]

    def clear(self):
        """Drop every tag group (and therefore every entry)."""
        self._data = {}
|
|
|
# Per-tag size limits handed to the shared cache; tags without an entry here
# fall back to the cache's built-in default sizes.
cache_settings = {}

# Per-key counter maintained by update_cache(): 0 on first store, +1 after.
cache_count = {}

# Module-wide cache instance shared by update_cache() / remove_cache().
cache = TaggedCache(cache_settings)
|
|
|
def update_cache(k, tag, v):
    """Store ``v`` under key ``k`` in the shared cache, grouped by ``tag``.

    Also maintains ``cache_count[k]``: it is set to 0 when no count exists yet
    for ``k`` and incremented by one on every later store.
    """
    cache[k] = (tag, v)

    if cache_count.get(k) is None:
        cache_count[k] = 0
    else:
        cache_count[k] += 1
|
def remove_cache(key):
    """Remove ``key`` from the shared cache, or reset it when key is ``'*'``.

    ``'*'`` rebinds the module-level ``cache`` to a fresh TaggedCache built
    from ``cache_settings``. An unknown key is reported via ``print`` instead of
    raising. ``cache_count`` is not touched.
    """
    global cache

    if key == '*':
        # NOTE(review): rebinding means any module holding a direct reference
        # to the old cache object keeps the stale instance.
        cache = TaggedCache(cache_settings)
        return

    if key not in cache:
        print(f"invalid {key}")
        return

    del cache[key]