File size: 7,292 Bytes
6ffe23f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
import numpy as np
import torch
from monai.transforms import (
Activations, AsDiscreteD, AsDiscrete, Compose, ToTensorD,
GaussianSmoothD, LoadImageD, TransposeD, OrientationD, ScaleIntensityRangeD,
ToTensor, FillHoles, KeepLargestConnectedComponent, NormalizeIntensityD
)
from nrrd import read
from visualization import visualize_results
from data_preparation import get_patient_dictionaries
from monai.data import Dataset, DataLoader
import os
from data_transforms import ConvertMaskValues, MaskOutNonliver
from pipeline import build_model, evaluate
def _collate_cropped_volumes(batch, crop_size=512):
    """Collate dataset samples into one batch, cropping axes 1 and 2 to ``crop_size``.

    Defined at module level (rather than as a closure inside
    ``run_sequential_inference``) so that it can be pickled and sent to
    ``DataLoader`` worker processes on platforms that start workers with
    ``spawn``.

    Args:
        batch: sequence of sample dicts, each holding at least an ``"image"``
            and a ``"mask"`` entry that supports 4-D slicing. A ``"pred_liver"``
            entry is optional.
        crop_size: upper bound applied to axes 1 and 2 of every image and mask
            before stacking.

    Returns:
        A dict with the stacked ``"image"`` and ``"mask"`` tensors (a new
        leading batch axis is added by ``torch.stack``). When the last sample
        of the batch carries ``"pred_liver"``, that entry is forwarded
        unchanged under the same key.
        If stacking fails, a warning is printed and the tuple ``(None, None)``
        is returned instead.
    """
    # Crop so that samples with differing extents along axes 1/2 can still be stacked.
    images = [sample["image"][:, :crop_size, :crop_size, :] for sample in batch]
    masks = [sample["mask"][:, :crop_size, :crop_size, :] for sample in batch]

    try:
        stacked_images = torch.stack(images, dim=0)
        stacked_masks = torch.stack(masks, dim=0)
    except Exception:
        # Best-effort: report the offending shapes instead of aborting the run.
        # Shapes of *all* samples are printed so this also works for batches
        # that contain a single sample.
        print(
            f"WARNING: not all images/masks are {crop_size} by {crop_size}. Please check. ",
            [tuple(image.shape) for image in images],
            [tuple(mask.shape) for mask in masks],
        )
        # NOTE(review): consumers receive a tuple here but a dict otherwise;
        # kept as-is because the consumer (`evaluate`) is not visible from this file.
        return None, None

    collated = {"image": stacked_images, "mask": stacked_masks}

    # NOTE(review): only the last sample's liver prediction is forwarded and it
    # is not stacked; this looks like it assumes a batch size of 1 -- confirm.
    last_sample = batch[-1]
    if "pred_liver" in last_sample:
        collated["pred_liver"] = last_sample["pred_liver"]
    return collated


def run_sequential_inference(txt_file, config_liver, config_tumor, eval_metrics, output_dir, only_tumor=False, export=True):
    """Run liver segmentation followed by tumor segmentation on a test set.

    The liver stage (skipped when ``only_tumor`` is true) is evaluated first.
    The tumor stage then loads each patient's ``liver_segmentation.nrrd`` from
    ``output_dir`` and uses it to mask the image before evaluating the tumor
    model. Metrics of both stages are printed and one sample is visualized.

    Args:
        txt_file: forwarded to ``get_patient_dictionaries`` to select the test
            patients.
        config_liver: liver-model configuration. Keys read here: ``DATA_DIR``,
            ``HU_RANGE``, ``BATCH_SIZE``, ``NUM_WORKERS``.
        config_tumor: tumor-model configuration. Keys read here: ``HU_RANGE``,
            ``BATCH_SIZE``, ``NUM_WORKERS``.
        eval_metrics: forwarded unchanged to ``evaluate``.
        output_dir: root directory for outputs; one sub-folder per patient id
            is created inside it.
        only_tumor: when true, the liver stage is skipped. The tumor stage
            still loads ``liver_segmentation.nrrd`` for every patient, so those
            files must already exist in ``output_dir``.
        export: when false, an empty list of export file names is passed to
            ``evaluate`` for the tumor stage. The liver stage always receives
            its export file names.

    Returns:
        None.
    """
    ### Model preparation
    print("")
    print("Loading models....")
    liver_model = build_model(config_liver)
    tumor_model = build_model(config_tumor)

    #### Data preparation
    print("")
    print("Loading test data....")
    test_data_dict = get_patient_dictionaries(txt_file=txt_file, data_dir=config_liver['DATA_DIR'])
    print(" Number of test patients:", len(test_data_dict))

    # assign output file names and paths
    export_file_metadata = []
    os.makedirs(output_dir, exist_ok=True)
    for patient_dict in test_data_dict:
        patient_folder = os.path.join(output_dir, patient_dict['patient_id'])
        os.makedirs(patient_folder, exist_ok=True)
        patient_dict['pred_liver'] = os.path.join(patient_folder, "liver_segmentation.nrrd")
        patient_dict['pred_tumor'] = os.path.join(patient_folder, "tumor_segmentation.nrrd")
        # keep only the NRRD header of the input image (second element of `read`)
        export_file_metadata.append(read(patient_dict['image'])[1])

    #### Liver segmentation
    if not only_tumor:
        print("")
        print("Producing liver segmentations....")

        # define liver data loading and preprocessing
        liver_preprocessing = Compose([
            LoadImageD(keys=["image", "mask"], reader="NrrdReader", ensure_channel_first=True),
            OrientationD(keys=["image", "mask"], axcodes="PLI"),
            ScaleIntensityRangeD(
                keys=["image"],
                a_min=config_liver['HU_RANGE'][0],
                a_max=config_liver['HU_RANGE'][1],
                b_min=0.0, b_max=1.0, clip=True,
            ),
            ConvertMaskValues(keys=["mask"], keep_classes=["liver"]),
            ToTensorD(keys=["image", "mask"]),
        ])
        liver_postprocessing = Compose([
            Activations(sigmoid=True),
            AsDiscrete(argmax=True, to_onehot=None),
            KeepLargestConnectedComponent(applied_labels=[1]),
            FillHoles(applied_labels=[1]),
            ToTensor(),
        ])
        liver_dataset = Dataset(test_data_dict, transform=liver_preprocessing)
        liver_loader = DataLoader(
            liver_dataset,
            batch_size=config_liver['BATCH_SIZE'],
            collate_fn=_collate_cropped_volumes,
            shuffle=False,
            num_workers=config_liver['NUM_WORKERS'],
        )

        # produce liver model results
        # The liver export file names are passed regardless of `export`: the
        # tumor stage below loads each patient's `pred_liver` path from disk.
        test_metrics_liver, sample_output_liver = evaluate(
            liver_model, liver_loader, eval_metrics, config_liver,
            postprocessing_transforms=liver_postprocessing,
            export_filenames=[p['pred_liver'] for p in test_data_dict],
            export_file_metadata=export_file_metadata,
        )
        print("")
        print("==============================")
        print("Liver segmentation test performance ....")
        for key, value in test_metrics_liver.items():
            print(f' {key.replace("_avg", "_liver")}: {value:.3f}')
        print("==============================")

    ##### Tumor segmentation
    print("")
    print("Producing tumor segmentations....")

    # define tumor loading and preprocessing
    tumor_preprocessing = Compose([
        LoadImageD(keys=["image", "mask", "pred_liver"], reader="NrrdReader", ensure_channel_first=True),
        # NOTE(review): "pred_liver" is not re-oriented here, unlike image/mask -- confirm this is intended.
        OrientationD(keys=["image", "mask"], axcodes="PLI"),
        MaskOutNonliver(mask_key="pred_liver"),  # note that liver's predicted segmentation is used to crop to the liver region
        ScaleIntensityRangeD(
            keys=["image"],
            a_min=config_tumor['HU_RANGE'][0],
            a_max=config_tumor['HU_RANGE'][1],
            b_min=0.0, b_max=1.0, clip=True,
        ),
        ConvertMaskValues(keys=["mask"], keep_classes=["liver", "tumor"]),  # format mask for measuring test performance
        AsDiscreteD(keys=["mask"], to_onehot=3),  # format mask for measuring test performance
        ToTensorD(keys=["image", "mask", "pred_liver"]),
    ])
    tumor_postprocessing = Compose([
        Activations(sigmoid=True),
        AsDiscrete(argmax=True, to_onehot=3),
        ToTensor(),
    ])
    tumor_dataset = Dataset(test_data_dict, transform=tumor_preprocessing)
    tumor_loader = DataLoader(
        tumor_dataset,
        batch_size=config_tumor['BATCH_SIZE'],
        collate_fn=_collate_cropped_volumes,
        shuffle=False,
        num_workers=config_tumor['NUM_WORKERS'],
    )
    test_metrics_tumor, sample_output_tumor = evaluate(
        tumor_model, tumor_loader, eval_metrics, config_tumor, tumor_postprocessing,
        use_liver_seg=True,
        export_filenames=[p['pred_tumor'] for p in test_data_dict] if export else [],
        export_file_metadata=export_file_metadata,
    )
    print("")
    print("==============================")
    print("Tumor segmentation test performance ....")
    for key, value in test_metrics_tumor.items():
        # only the metrics whose name contains "class2" are reported, relabelled as tumor
        if "class2" in key:
            print(f' {key.replace("_class2", "_tumor")}: {value:.3f}')
    print("==============================")
    print("")

    #### Visualization
    # combine liver and tumor segmentations into one segmentation output
    if not only_tumor:
        sample_output_tumor[2][0][1] = sample_output_liver[2][0][0]

    # visualization: the displayed image comes from the liver stage when it ran,
    # otherwise from the tumor stage
    print("")
    image_source = sample_output_tumor if only_tumor else sample_output_liver
    visualize_results(
        image_source[0][0].cpu(),
        sample_output_tumor[1][0].cpu(),
        sample_output_tumor[2][0].cpu(),
        n_slices=5,
        title="",
    )
    return
|