# nnUNet_Brain_EPTN / predict_nnunet.py
# Uploaded by Margerie in commit 59d4c6b ("First upload", verified).
import torch
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
def predictNNUNet(model_dir, input_dir, output_dir, folds, *, device=None, tile_step_size=0.9):
    """Run nnU-Net v2 inference on the images in ``input_dir``.

    Loads the ``checkpoint_final.pth`` checkpoint(s) from ``model_dir`` and
    writes predicted segmentations to ``output_dir``.

    Settings used for the run:
    - Test-time mirroring is disabled (the equivalent of the CLI's
      ``--disable_tta``).
    - Probabilities are not saved.
    - Existing outputs are overwritten.
    - Preprocessing and export each use 2 worker processes.

    Args:
        model_dir: trained-model folder passed to
            ``nnUNetPredictor.initialize_from_trained_model_folder``.
        input_dir: input source passed to ``predict_from_files``.
        output_dir: destination folder for the predictions.
        folds: folds to load; ``None`` lets nnU-Net autodetect them.
        device: torch device to run on. Defaults to ``torch.device('cpu', 0)``,
            which this function always used before the parameter existed.
        tile_step_size: sliding-window step size handed to nnU-Net. The
            original code noted 0.5 as the alternative value; 0.9 is kept as
            the default for backward compatibility (presumably chosen for
            speed — verify its accuracy impact).

    Returns:
        None; results are written to ``output_dir``.
    """
    if device is None:
        device = torch.device('cpu', 0)

    # perform_everything_on_device is left at the nnU-Net default on purpose.
    predictor = nnUNetPredictor(
        tile_step_size=tile_step_size,
        use_gaussian=True,
        use_mirroring=False,  # equivalent of --disable_tta
        device=device,
        verbose=True,
        verbose_preprocessing=False,
        allow_tqdm=True,
    )
    predictor.initialize_from_trained_model_folder(
        model_dir,
        use_folds=folds,  # None -> autodetect folds
        checkpoint_name='checkpoint_final.pth',
    )
    print("input_dir", input_dir)
    predictor.predict_from_files(
        input_dir,
        output_dir,
        save_probabilities=False,
        overwrite=True,
        num_processes_preprocessing=2,
        num_processes_segmentation_export=2,
        folder_with_segs_from_prev_stage=None,
        # Single-shard run: process every case in one job.
        num_parts=1,
        part_id=0,
    )