# Spaces: Sleeping
import torch

from models.ram import RAM
def get_model(ckpt_path='ckpt/ram.pth.tar', map_location='cpu'):
    """
    Build a RAM model and load its weights from a checkpoint file.

    :param str ckpt_path: path of the checkpoint passed to ``torch.load``;
        defaults to ``'ckpt/ram.pth.tar'`` (the previously hard-coded path)
    :param map_location: forwarded to ``torch.load``; defaults to ``'cpu'``
        so a checkpoint saved from GPU tensors can still be loaded on a
        CPU-only machine
    :return: a freshly constructed ``RAM`` instance with the loaded weights
    """
    model = RAM()
    # The loaded object is passed straight to load_state_dict, so the file is
    # assumed to hold a bare state dict (not a wrapper dict) — confirm if the
    # checkpoint format changes.
    # load_state_dict copies values into the model's existing parameters, so
    # deserializing onto CPU first does not change where the model lives.
    state_dict = torch.load(ckpt_path, map_location=map_location)
    model.load_state_dict(state_dict)
    return model