denoising / model_factory.py
mterris's picture
update
ed95f9b
raw
history blame
439 Bytes
import torch
from models.ram import RAM
def get_model(ckpt_path='ckpt/ram.pth.tar'):
    """
    Build a RAM model and load its weights from a checkpoint file.

    :param str ckpt_path: path to the checkpoint holding the state dict.
        Defaults to ``'ckpt/ram.pth.tar'``; a relative path is resolved
        against the current working directory.
    :return: the ``RAM`` instance after ``load_state_dict`` has been applied
    """
    model = RAM()
    # Deserialize onto the CPU: without map_location, a checkpoint saved from
    # a CUDA device cannot be loaded on a machine that has no GPU.
    # Assuming RAM is a torch.nn.Module, load_state_dict copies the values
    # into the model's own parameters, so the model's device is unaffected.
    # NOTE(review): torch.load unpickles the file; only load trusted
    # checkpoints.
    state_dict = torch.load(ckpt_path, map_location='cpu')
    model.load_state_dict(state_dict)
    return model