File size: 1,230 Bytes
59d4c6b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 |
import torch
from nnunetv2.inference.predict_from_raw_data import nnUNetPredictor
def predictNNUNet(model_dir, input_dir, output_dir, folds=None, *,
                  tile_step_size=0.9,
                  device=None,
                  checkpoint_name='checkpoint_final.pth',
                  num_processes=2):
    """Run nnU-Net v2 inference on the images in ``input_dir``.

    Builds an ``nnUNetPredictor`` with Gaussian tile weighting and mirroring
    (test-time augmentation) disabled. It loads the trained model from
    ``model_dir`` and writes predictions to ``output_dir``. Existing outputs
    are overwritten and probability maps are not saved.

    Args:
        model_dir: trained-model output folder, passed to
            ``initialize_from_trained_model_folder``.
        input_dir: source passed unchanged to ``predict_from_files``
            (presumably a folder of nnU-Net-formatted images — confirm with
            caller).
        output_dir: destination passed unchanged to ``predict_from_files``.
        folds: folds to use, e.g. ``(0,)`` or ``('all',)``. ``None`` lets
            nnU-Net auto-detect the available folds.
        tile_step_size: sliding-window step as a fraction of the patch size.
            A larger step means less tile overlap and fewer tiles. nnU-Net's
            own default is 0.5; this helper defaults to 0.9.
        device: torch device to run on. Defaults to ``torch.device('cpu', 0)``.
        checkpoint_name: checkpoint file name to load for each fold.
        num_processes: number of worker processes, used for both
            preprocessing and segmentation export.

    Returns:
        None. Results are written to ``output_dir``.
    """
    if device is None:
        device = torch.device('cpu', 0)

    predictor = nnUNetPredictor(
        tile_step_size=tile_step_size,
        use_gaussian=True,
        use_mirroring=False,  # equivalent of the CLI's --disable_tta
        # Keeping the whole pipeline on-device is a CUDA-only feature, so
        # state the effective value explicitly instead of relying on
        # nnU-Net to switch it off for CPU.
        perform_everything_on_device=(device.type == 'cuda'),
        device=device,
        verbose=True,
        verbose_preprocessing=False,
        allow_tqdm=True,
    )
    predictor.initialize_from_trained_model_folder(
        model_dir,
        use_folds=folds,  # None -> auto-detect folds
        checkpoint_name=checkpoint_name,
    )
    print("input_dir", input_dir)
    predictor.predict_from_files(
        input_dir,
        output_dir,
        save_probabilities=False,
        overwrite=True,
        num_processes_preprocessing=num_processes,
        num_processes_segmentation_export=num_processes,
        folder_with_segs_from_prev_stage=None,
        # Process the whole input set in a single job (no sharding).
        num_parts=1,
        part_id=0,
    )