Spaces:
Sleeping
Sleeping
File size: 578 Bytes
4dc3e99 ed95f9b 4dc3e99 a1a82a6 81c09b8 ed95f9b ff76a8d ed95f9b 81c09b8 ed95f9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 |
import torch
from models.ram import RAM
from huggingface_hub import hf_hub_download
# Pick the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'
def get_model():
    """
    Load the pretrained RAM model.

    Builds a ``RAM()`` instance and loads the weights stored in the file
    ``ram.pth.tar`` of the Hugging Face Hub repository ``mterris/ram``
    (obtained through ``hf_hub_download``, which returns a local path to the
    downloaded or cached file).

    The checkpoint is read with ``torch.load(..., map_location=device)``,
    where ``device`` is the module-level setting (``'cuda'`` if available,
    else ``'cpu'``). ``load_state_dict`` then copies those tensors into the
    model's existing parameters. The model itself is NOT moved with
    ``.to(device)`` here, so its placement is whatever ``RAM()`` produces.
    Callers that need it on the GPU should move it themselves.

    :return: the ``RAM`` model with the pretrained weights loaded.
    """
    model = RAM()
    # Local path to the checkpoint file from the Hub repository.
    checkpoint_path = hf_hub_download(repo_id="mterris/ram", filename="ram.pth.tar")
    # NOTE(review): torch.load unpickles the file, which can execute arbitrary
    # code if the checkpoint is untrusted. If it contains only tensors,
    # consider passing weights_only=True. Confirm against the checkpoint contents.
    # Presumably the file is a plain state dict (load_state_dict needs a
    # parameter-name -> tensor mapping). TODO confirm.
    state_dict = torch.load(checkpoint_path, map_location=device)
    model.load_state_dict(state_dict)
    return model