Alzheimer-Classifier-Demo / model /download_model.py
Thomas Lucchetta
Add files via upload
a3a57c6 unverified
raw
history blame
826 Bytes
import os.path
import os
import monai.networks.nets as nets
import torch
from huggingface_hub import hf_hub_download
from constants import ROOT_DIR, MODEL_FILENAME, HF_MODEL_REPO_NAME
def load_model():
    """
    Load the pretrained 3-D DenseNet121 classifier.

    The checkpoint ``MODEL_FILENAME`` is looked up in ``<ROOT_DIR>/model``.
    If it is not present, it is first downloaded from the Hugging Face Hub
    repository ``HF_MODEL_REPO_NAME`` into that directory.

    Returns:
        A ``monai.networks.nets.DenseNet121`` (spatial_dims=3, in_channels=1,
        out_channels=3) with the checkpoint's weights loaded. The model is
        constructed here without being moved to another device and is
        switched to evaluation mode.
    """
    model_dir = os.path.join(ROOT_DIR, "model")
    model_path = os.path.join(model_dir, MODEL_FILENAME)

    # If the model doesn't exist locally, download it from the Hugging Face Hub.
    if not os.path.exists(model_path):
        # Use the path reported by the hub client instead of assuming its layout.
        model_path = hf_hub_download(
            HF_MODEL_REPO_NAME, MODEL_FILENAME, local_dir=model_dir
        )

    model = nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=3)

    # The model is created on the CPU and is not moved before the weights are
    # loaded, so load_state_dict copies every tensor into CPU parameters no
    # matter where torch.load put it. Mapping to the CPU up front yields the
    # same model as the previous CUDA branch, without a transient copy of the
    # weights in GPU memory.
    # NOTE(review): torch.load unpickles the file, and here the file may come
    # from the network. Only a state dict is needed, so consider passing
    # weights_only=True (available from torch 1.13; the default since 2.6).
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint)
    model.eval()
    return model