zetavg
improve speed of switching models by offloading unused ones to CPU RAM instead of unloading them
6148b7c
import gc
from collections import OrderedDict

import torch

from ..lib.get_device import get_device

device_type = get_device()
class ModelLRUCache:
    def __init__(self, capacity=5):
        self.cache = OrderedDict()
        self.capacity = capacity

    def get(self, key):
        if key in self.cache:
            # Move the accessed item to the end of the OrderedDict
            self.cache.move_to_end(key)

            # Offload every other cached model to CPU RAM so the requested
            # one has the accelerator's memory to itself.
            models_did_move = False
            for k, m in self.cache.items():
                if key != k and m.device.type != 'cpu':
                    models_did_move = True
                    self.cache[k] = m.to('cpu')

            if models_did_move:
                gc.collect()
                # if not shared.args.cpu:  # will not be running on CPUs anyway
                with torch.no_grad():
                    torch.cuda.empty_cache()

            model = self.cache[key]

            # Move the requested model (and its wrapped model, if any)
            # back to the active device.
            if (model.device.type != device_type or
                    (hasattr(model, "model") and
                     model.model.device.type != device_type)):
                model = model.to(device_type)

            return model

        return None
    def set(self, key, value):
        if key in self.cache:
            # If the key already exists, update its value
            self.cache[key] = value
        else:
            # If the cache has reached its capacity,
            # remove the least recently used item
            if len(self.cache) >= self.capacity:
                self.cache.popitem(last=False)
            self.cache[key] = value

    def clear(self):
        self.cache.clear()
    def prepare_to_set(self):
        # Make room before a new model is loaded elsewhere: evict the least
        # recently used entry if the cache is full, then offload every cached
        # model to CPU RAM so the accelerator has free memory for the load.
        if len(self.cache) >= self.capacity:
            self.cache.popitem(last=False)

        models_did_move = False
        for k, m in self.cache.items():
            if m.device.type != 'cpu':
                models_did_move = True
                self.cache[k] = m.to('cpu')

        if models_did_move:
            gc.collect()
            # if not shared.args.cpu:  # will not be running on CPUs anyway
            with torch.no_grad():
                torch.cuda.empty_cache()