File size: 826 Bytes
a3a57c6 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 |
import os.path
import os
import monai.networks.nets as nets
import torch
from huggingface_hub import hf_hub_download
from constants import ROOT_DIR, MODEL_FILENAME, HF_MODEL_REPO_NAME
def load_model():
    """
    Load the pretrained model and return it in evaluation mode.

    The weights file is expected at ``ROOT_DIR/model/MODEL_FILENAME``. If it is
    not present locally, it is first downloaded from the Hugging Face Hub
    repository ``HF_MODEL_REPO_NAME`` into that directory.

    Returns:
        A ``monai.networks.nets.DenseNet121`` (spatial_dims=3, in_channels=1,
        out_channels=3) with the checkpoint's state dict loaded and ``eval()``
        already called.
    """
    model_dir = os.path.join(ROOT_DIR, "model")
    model_path = os.path.join(model_dir, MODEL_FILENAME)

    # If the model doesn't exist locally, download it from Hugging Face.
    # Use the path hf_hub_download reports rather than assuming where it wrote.
    if not os.path.exists(model_path):
        model_path = hf_hub_download(
            HF_MODEL_REPO_NAME, MODEL_FILENAME, local_dir=model_dir
        )

    model = nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=3)

    # The model above is created without an explicit device, and
    # load_state_dict copies checkpoint values into the model's existing
    # parameters. Loading the checkpoint onto a GPU first would only waste GPU
    # memory, and it breaks if the file was saved on a CUDA device index that
    # is absent here. Mapping to CPU gives the same loaded model everywhere.
    # weights_only=True restricts unpickling to tensors/plain containers,
    # because the file may have come from a remote repository
    # (requires torch >= 1.13).
    checkpoint = torch.load(model_path, map_location="cpu", weights_only=True)
    model.load_state_dict(checkpoint)
    model.eval()
    return model