# NOTE(review): the following lines were web-viewer scrape artifacts, not code,
# and made the file syntactically invalid; preserved here as a comment:
#   Spaces: / Runtime error / Runtime error / File size: 2,117 Bytes / 6148b7c |
#   1 2 3 ... 69 |
from collections import OrderedDict
import gc
import torch
from ..lib.get_device import get_device
# Active device type string resolved once at import time (e.g. "cuda") —
# presumably get_device probes the available hardware; TODO confirm against
# ..lib.get_device.
device_type = get_device()
class ModelLRUCache:
    """LRU cache holding at most ``capacity`` torch models.

    On a cache hit, every *other* cached model is offloaded to CPU to free
    accelerator memory, and the requested model is moved back onto the
    active device (``device_type``). The least recently used entry sits at
    the front of the underlying ``OrderedDict``.
    """

    def __init__(self, capacity=5):
        # OrderedDict keeps usage order: front = least recently used.
        self.cache = OrderedDict()
        self.capacity = capacity

    def _offload_others_to_cpu(self, keep_key=None):
        """Move every cached model except ``keep_key`` to CPU, then reclaim VRAM.

        Shared by ``get`` and ``prepare_to_set`` (previously duplicated).
        """
        models_did_move = False
        for k, m in self.cache.items():
            if k != keep_key and m.device.type != 'cpu':
                models_did_move = True
                self.cache[k] = m.to('cpu')
        if models_did_move:
            gc.collect()
            # if not shared.args.cpu: # will not be running on CPUs anyway
            with torch.no_grad():
                torch.cuda.empty_cache()

    def get(self, key):
        """Return the model for ``key`` on the active device, or ``None`` on miss."""
        if key not in self.cache:
            return None
        # Accessing an entry makes it the most recently used.
        self.cache.move_to_end(key)
        # Free accelerator memory held by every other cached model.
        self._offload_others_to_cpu(keep_key=key)
        model = self.cache[key]
        # Move the model (and a wrapped inner ``model.model``, if present)
        # back to the active device. Parentheses make the original
        # `or`/`and` precedence explicit; semantics unchanged.
        if (model.device.type != device_type or
                (hasattr(model, "model") and
                 model.model.device.type != device_type)):
            model = model.to(device_type)
        return model

    def set(self, key, value):
        """Insert or update ``key``; evict the LRU entry when at capacity."""
        if key in self.cache:
            self.cache[key] = value
            # BUGFIX: updating an entry counts as a use. Without this the
            # entry kept its stale position and could be evicted too early.
            self.cache.move_to_end(key)
        else:
            if len(self.cache) >= self.capacity:
                # Drop the least recently used item (front of the dict).
                self.cache.popitem(last=False)
            self.cache[key] = value

    def clear(self):
        """Drop every cached model."""
        self.cache.clear()

    def prepare_to_set(self):
        """Make room for an upcoming ``set``: evict the LRU entry if the cache
        is full, then offload all remaining models to CPU."""
        if len(self.cache) >= self.capacity:
            self.cache.popitem(last=False)
        self._offload_others_to_cpu()