File size: 410 Bytes
4dc3e99
ed95f9b
4dc3e99
81c09b8
 
ed95f9b
ff76a8d
 
 
 
 
 
 
 
 
ed95f9b
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch
from models.ram import RAM

# Module-level device string: 'cuda' when PyTorch reports CUDA as available,
# otherwise 'cpu'. Evaluated once, at import time.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

def get_model() -> "RAM":
    """
    Build and return a new ``RAM`` model instance.

    The function takes no arguments; ``RAM`` is constructed with no
    arguments, so whatever defaults its constructor defines apply.

    NOTE(review): the model is not moved to the module-level ``device``
    here; presumably callers handle device placement -- confirm.

    :return: a newly constructed ``RAM`` instance
    """
    return RAM()