import itertools
from collections import OrderedDict
from collections.abc import MutableMapping
from typing import Optional


class _SimpleLRU(MutableMapping):
    """Minimal size-bounded LRU mapping used when ``cachetools`` is unavailable.

    Reads via ``__getitem__`` and writes via ``__setitem__`` mark a key as most
    recently used; membership tests and ``items()``/``values()`` do not.
    Once more than ``maxsize`` entries are stored, the least recently used
    entries are evicted.
    """

    def __init__(self, maxsize):
        self.maxsize = maxsize
        # Ordered from least recently used (front) to most recently used (back).
        self._store = OrderedDict()

    def __getitem__(self, key):
        value = self._store[key]
        self._store.move_to_end(key)
        return value

    def __setitem__(self, key, value):
        self._store[key] = value
        self._store.move_to_end(key)
        # Drop the oldest entries until the size limit is respected again.
        while len(self._store) > self.maxsize:
            self._store.popitem(last=False)

    def __delitem__(self, key):
        del self._store[key]

    def __iter__(self):
        return iter(self._store)

    def __len__(self):
        return len(self._store)

    def __contains__(self, key):
        # Overridden so that a membership test does not refresh recency
        # (the MutableMapping default would go through __getitem__).
        return key in self._store

    def items(self):
        # Delegate directly: the mixin view would call __getitem__ for every
        # key and reorder the store while it is being iterated.
        return self._store.items()

    def values(self):
        return self._store.values()


class TaggedCache:
    """Key/value cache whose entries are grouped into per-tag LRU stores.

    Every stored value is a tuple whose first element is the tag name, i.e.
    ``(tag, (islist, data))``.  Each tag gets its own bounded store; the size
    comes from ``tag_settings[tag]`` when present, otherwise from a built-in
    default (5 for tags containing ``'ckpt'``, 100 for ``'latent'`` and
    ``'image'``, 20 for everything else).  A key lives in at most one tag
    store at a time.
    """

    def __init__(self, tag_settings: Optional[dict]=None):
        """Create an empty cache.

        Args:
            tag_settings: optional mapping of tag name -> maximum number of
                entries kept for that tag.
        """
        self._tag_settings = tag_settings or {}  # tag cache size
        self._data = {}  # tag -> per-tag store (key -> stored value tuple)

    def _find_store(self, key):
        """Return the per-tag store currently holding ``key``, or ``None``."""
        for tag_data in self._data.values():
            if key in tag_data:
                return tag_data
        return None

    def _new_tag_store(self, tag):
        """Build the bounded store for a tag that has no store yet."""
        default_size = 20
        if 'ckpt' in tag:
            default_size = 5
        elif tag in ['latent', 'image']:
            default_size = 100
        maxsize = self._tag_settings.get(tag, default_size)

        # cachetools is optional; only the import itself is guarded so that
        # unrelated errors are not mistaken for a missing dependency.
        try:
            from cachetools import LRUCache
        except ImportError:
            # Stdlib fallback keeps the size limit instead of growing forever.
            return _SimpleLRU(maxsize)
        return LRUCache(maxsize=maxsize)

    def __getitem__(self, key):
        """Return the stored ``(tag, ...)`` tuple for ``key``.

        Raises:
            KeyError: if ``key`` is not present under any tag.
        """
        tag_data = self._find_store(key)
        if tag_data is None:
            raise KeyError(f'Key `{key}` does not exist')
        return tag_data[key]

    def __setitem__(self, key, value: tuple):
        """Store ``value`` under ``key`` in the store of tag ``value[0]``.

        ``value`` has the form ``(tag: str, (islist: bool, data: *))``.
        """
        # If the key already exists (possibly under another tag), drop the
        # old entry first so the key never lives in two stores.
        previous = self._find_store(key)
        if previous is not None:
            previous.pop(key, None)

        tag = value[0]
        if tag not in self._data:
            self._data[tag] = self._new_tag_store(tag)
        self._data[tag][key] = value

    def __delitem__(self, key):
        """Remove ``key`` from whichever tag store holds it.

        Raises:
            KeyError: if ``key`` is not present under any tag.
        """
        tag_data = self._find_store(key)
        if tag_data is None:
            raise KeyError(f'Key `{key}` does not exist')
        del tag_data[key]

    def __contains__(self, key):
        """Return True if ``key`` is stored under any tag."""
        return self._find_store(key) is not None

    def items(self):
        """Yield ``(key, stored_value)`` pairs across all tags."""
        # Collect the per-tag views up front so that a tag added while the
        # generator is being consumed does not disturb the outer iteration.
        per_tag_items = [tag_data.items() for tag_data in self._data.values()]
        yield from itertools.chain.from_iterable(per_tag_items)

    def get(self, key, default=None):
        """D.get(k[,d]) -> D[k] if k in D, else d.  d defaults to None."""
        tag_data = self._find_store(key)
        if tag_data is None:
            return default
        return tag_data[key]

    def clear(self):
        """Drop every tag store and all cached entries."""
        self._data = {}


# Module-level shared cache and its configuration.
cache_settings = {}
cache = TaggedCache(cache_settings)
# key -> number of times the key has been overwritten (0 on first insert).
cache_count = {}


def update_cache(k, tag, v):
    """Store ``v`` under key ``k`` with ``tag`` and track how often ``k`` is written.

    The counter for ``k`` starts at 0 on the first write and is incremented
    by one on every later write.
    """
    cache[k] = (tag, v)
    if cache_count.get(k) is None:
        cache_count[k] = 0
    else:
        cache_count[k] += 1


def remove_cache(key):
    """Remove ``key`` from the shared cache; ``'*'`` resets the whole cache.

    An unknown key is reported on stdout rather than raising.
    """
    global cache
    if key == '*':
        # Rebind to a fresh cache built from the current settings.
        cache = TaggedCache(cache_settings)
    elif key in cache:
        del cache[key]
    else:
        print(f"invalid {key}")