Spaces:
Running
on
T4
Running
on
T4
import torch | |
from models.ram import RAM | |
from huggingface_hub import hf_hub_download | |
# Select the compute device once at import time: use CUDA when torch reports
# it as available, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def get_model():
    """
    Build a RAM model and fill it with the published pretrained weights.

    The checkpoint ``ram.pth.tar`` is fetched from the ``mterris/ram``
    repository through ``hf_hub_download`` and deserialized with
    ``torch.load``, using the module-level ``device`` as ``map_location``.

    The function takes no arguments. It does not call ``.to(device)`` or
    ``.eval()`` on the model; callers that need either must do so themselves.

    :return: the RAM instance after ``load_state_dict`` has been applied
    """
    # Instantiate first, matching the original evaluation order.
    network = RAM()

    # Resolve the checkpoint file, then read it with tensors mapped to `device`.
    checkpoint_path = hf_hub_download(repo_id="mterris/ram", filename="ram.pth.tar")
    weights = torch.load(checkpoint_path, map_location=device)

    network.load_state_dict(weights)
    return network