File size: 826 Bytes
a3a57c6
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
import os.path
import os

import monai.networks.nets as nets
import torch

from huggingface_hub import hf_hub_download

from constants import ROOT_DIR, MODEL_FILENAME, HF_MODEL_REPO_NAME

def load_model():
    """
    Load the pretrained 3D DenseNet121 model.

    The weights file is expected at ``ROOT_DIR/model/MODEL_FILENAME``. If it is
    not there, it is downloaded from the Hugging Face repo
    ``HF_MODEL_REPO_NAME`` into that directory first.

    Returns:
        A MONAI ``DenseNet121`` (3 spatial dims, 1 input channel, 3 output
        channels) with the checkpoint's state dict loaded and switched to
        eval mode. The model is not moved to any device here; callers that
        want it on a GPU must call ``.to(device)`` themselves.
    """

    model_dir = os.path.join(ROOT_DIR, "model")
    model_path = os.path.join(model_dir, MODEL_FILENAME)

    # If the model doesn't exist locally, download it from Hugging Face.
    if not os.path.exists(model_path):
        hf_hub_download(HF_MODEL_REPO_NAME, MODEL_FILENAME, local_dir=model_dir)

    model = nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=3)

    # Always deserialize onto the CPU: load_state_dict() copies the values into
    # the model's existing parameters, so restoring the tensors onto a GPU first
    # gains nothing, wastes GPU memory, and fails when the checkpoint was saved
    # on a device index this machine does not have.
    # NOTE(review): torch.load unpickles the file; only load checkpoints from a
    # trusted repo.
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint)
    model.eval()

    return model