File size: 578 Bytes
4dc3e99
ed95f9b
4dc3e99
a1a82a6
 
81c09b8
 
ed95f9b
ff76a8d
 
 
 
 
 
 
 
 
ed95f9b
81c09b8
ed95f9b
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
import torch
from models.ram import RAM

from huggingface_hub import hf_hub_download

# Compute device for loading weights: prefer the GPU whenever CUDA is usable,
# otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

def get_model():
    """
    Build a ``RAM`` model and load its pretrained weights from the Hugging Face Hub.

    The checkpoint ``ram.pth.tar`` is fetched from the ``mterris/ram`` Hub
    repository with ``hf_hub_download`` and loaded as a state dict into a
    freshly constructed ``RAM()`` instance.

    :return: the ``RAM`` instance with the downloaded weights loaded.
    """
    model = RAM()

    checkpoint_path = hf_hub_download(repo_id="mterris/ram", filename="ram.pth.tar")

    # NOTE(security): torch.load unpickles the file, which can execute arbitrary
    # code, and this checkpoint comes from a remote repository. Since the result
    # is used directly as a state dict, consider passing ``weights_only=True``
    # if the installed torch version supports it.
    # map_location places the loaded tensors on the module-level ``device``, so
    # a checkpoint saved from a GPU can still be read on a CPU-only machine.
    state_dict = torch.load(checkpoint_path, map_location=device)

    model.load_state_dict(state_dict)
    # NOTE(review): the model itself is never moved to ``device`` here;
    # load_state_dict copies values into the model's existing parameters, so the
    # parameters stay wherever RAM() created them. Callers presumably move the
    # model themselves — TODO confirm.
    return model