|
import os.path |
|
import os |
|
|
|
import monai.networks.nets as nets |
|
import torch |
|
|
|
from huggingface_hub import hf_hub_download |
|
|
|
from constants import ROOT_DIR, MODEL_FILENAME, HF_MODEL_REPO_NAME |
|
|
|
def load_model():
    """
    Load the pretrained 3D DenseNet121 model.

    The weights file ``MODEL_FILENAME`` is looked up in ``<ROOT_DIR>/model``.
    If it is not there yet, it is first downloaded from the Hugging Face Hub
    repository ``HF_MODEL_REPO_NAME`` into that directory.

    Returns:
        A ``monai.networks.nets.DenseNet121`` built with ``spatial_dims=3``,
        ``in_channels=1`` and ``out_channels=3``, with the checkpoint loaded as
        its state dict and switched to evaluation mode. The model is not moved
        to any device by this function.
    """
    model_dir = os.path.join(ROOT_DIR, "model")
    model_path = os.path.join(model_dir, MODEL_FILENAME)

    if not os.path.exists(model_path):
        # hf_hub_download returns the local path of the downloaded file; use it
        # instead of assuming where the file was placed inside local_dir.
        model_path = hf_hub_download(
            HF_MODEL_REPO_NAME, MODEL_FILENAME, local_dir=model_dir
        )

    model = nets.DenseNet121(spatial_dims=3, in_channels=1, out_channels=3)

    # The model above is never moved off its construction device, so the
    # checkpoint is always mapped to the CPU. Loading it onto the GPU would
    # only consume GPU memory before load_state_dict copies the values into
    # the model's parameters.
    # NOTE(review): torch.load unpickles the file, which may have come from
    # the network. Since only a state dict is expected, consider passing
    # weights_only=True (available in torch >= 1.13).
    checkpoint = torch.load(model_path, map_location=torch.device("cpu"))
    model.load_state_dict(checkpoint)
    model.eval()

    return model