import os.path
import os

import monai.networks.nets as nets
import torch
from huggingface_hub import hf_hub_download

from constants import ROOT_DIR, MODEL_FILENAME, HF_MODEL_REPO_NAME


def load_model():
    """
    Load the pretrained 3D DenseNet121 model.

    The weights file ``MODEL_FILENAME`` is looked up in ``<ROOT_DIR>/model``.
    If it is not there, it is first downloaded from the Hugging Face Hub
    repository ``HF_MODEL_REPO_NAME`` into that directory.

    Returns:
        A ``monai.networks.nets.DenseNet121`` built with ``spatial_dims=3``,
        ``in_channels=1`` and ``out_channels=3``, with the checkpoint's state
        dict loaded, residing on the CPU and switched to eval mode.
    """
    model_dir = os.path.join(ROOT_DIR, "model")
    model_path = os.path.join(model_dir, MODEL_FILENAME)

    # If the model doesn't exist locally, download it from Hugging Face.
    # Use the path hf_hub_download reports rather than assuming its layout.
    if not os.path.exists(model_path):
        model_path = hf_hub_download(
            HF_MODEL_REPO_NAME, MODEL_FILENAME, local_dir=model_dir
        )

    model = nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=3)

    # The model is constructed on the CPU, and load_state_dict copies tensors
    # into its existing parameters. Always deserialize onto the CPU: loading
    # onto the GPU only wastes device memory, and it fails when the checkpoint
    # was saved on a CUDA device index this machine does not have.
    # NOTE(review): torch.load unpickles the file, which is fetched from a
    # remote repo; consider passing weights_only=True (torch >= 1.13) if the
    # installed torch version supports it.
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint)
    model.eval()
    return model