Spaces:
Runtime error
Runtime error
File size: 612 Bytes
63775f2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 |
"""
Utility functions for the model wrapper.
---------------------------------------------------------------------
"""
import glob
import os
import torch
def load_cached_state_dict(model_folder_path):
    """Load a cached model checkpoint from ``model_folder_path``.

    Looks for regular files whose names end in ``model.bin`` directly inside
    the folder. It loads the first one in sorted (lexicographic) order, so the
    choice is deterministic when several candidates exist.

    Args:
        model_folder_path: Directory that contains the cached checkpoint.

    Returns:
        The object ``torch.load`` deserializes from the file, with storages
        mapped onto the CPU. This is presumably a ``state_dict`` (parameter
        name -> tensor); confirm against the code that saved it.

    Raises:
        FileNotFoundError: If no regular file matching ``*model.bin`` exists
            in the folder.
    """
    # Escape the folder part so characters such as "[" or "*" in the path are
    # matched literally instead of being interpreted as glob wildcards.
    pattern = os.path.join(glob.escape(model_folder_path), "*model.bin")
    # glob() yields matches in arbitrary filesystem order; sort so the chosen
    # "first" checkpoint is reproducible. Skip directories with a matching name.
    model_path_list = sorted(
        path for path in glob.glob(pattern) if os.path.isfile(path)
    )
    if not model_path_list:
        raise FileNotFoundError(
            f"model.bin not found in model folder {model_folder_path}."
        )
    model_path = model_path_list[0]
    # NOTE(security): torch.load unpickles the file. Only load checkpoints from
    # trusted sources; recent PyTorch versions support (and newer ones default
    # to) ``weights_only=True`` to restrict what can be deserialized.
    state_dict = torch.load(model_path, map_location=torch.device("cpu"))
    return state_dict
|