File size: 439 Bytes
4dc3e99
ed95f9b
4dc3e99
ed95f9b
ff76a8d
 
 
 
 
 
 
 
 
ed95f9b
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
from models.ram import RAM

def get_model(ckpt_path='ckpt/ram.pth.tar', map_location='cpu'):
    """
    Build a ``RAM`` model and load its weights from a checkpoint file.

    :param str ckpt_path: path to the saved state dict. The default is
        ``'ckpt/ram.pth.tar'``, the path this function previously hard-coded.
    :param map_location: forwarded to ``torch.load``. The default ``'cpu'``
        lets a checkpoint saved from GPU tensors load on a CPU-only machine.
        ``load_state_dict`` copies the values into the model's own
        parameters, so this does not change where the model lives.
    :return: a ``RAM`` instance with the checkpoint weights loaded
    :raises FileNotFoundError: if ``ckpt_path`` does not exist
    :raises RuntimeError: if the checkpoint's keys or tensor shapes do not
        match the ``RAM`` model (``load_state_dict`` runs in strict mode)
    """
    model = RAM()
    # The file must contain the state dict itself (it is passed straight to
    # load_state_dict), not a wrapper dict around it.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    state_dict = torch.load(ckpt_path, map_location=map_location)
    model.load_state_dict(state_dict)
    return model