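"""
Data preparation utilities for the HCC-TACE-Seg dataset: clinical feature encoding,
patient-level train/test splitting, and construction of MONAI datasets and dataloaders
(with a custom collate function) for 2D and 3D segmentation models.
"""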
import os
from sklearn.model_selection import train_test_split
import monai
from monai.data import Dataset, DataLoader
from data_transforms import define_transforms, define_transforms_loadonly
import torch
import numpy as np
from visualization import visualize_patient
from monai.data import list_data_collate
import pandas as pd
def prepare_clinical_data(data_file, predictors):
# read data file
info = pd.read_excel(data_file, sheet_name=0)
# convert to numerical
info['CPS'] = info['CPS'].map({'A': 1, 'B': 2, 'C': 3})
info['T_involvment'] = info['T_involvment'].map({'< or = 50%': 1, '>50%': 2})
info['CLIP_Score'] = info['CLIP_Score'].map({'Stage_0': 0, 'Stage_1': 1, 'Stage_2': 2, 'Stage_3': 3, 'Stage_4': 4, 'Stage_5': 5, 'Stage_6': 6})
info['Okuda'] = info['Okuda'].map({'Stage I': 1, 'Stage II': 2, 'Stage III': 3})
info['TNM'] = info['TNM'].map({'Stage-I': 1, 'Stage-II': 2, 'Stage-IIIA': 3, 'Stage-IIIB': 4, 'Stage-IIIC': 5, 'Stage-IVA': 6, 'Stage-IVB': 7})
info['BCLC'] = info['BCLC'].map({'0': 0, 'Stage-A': 1, 'Stage-B': 2, 'Stage-C': 3, 'Stage-D': 4})
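    # Note: any value not covered by the maps above becomes NaN; missing clinical
    # values are later imputed with the column median in build_dataset.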
# remove duplicates
    info = info.drop_duplicates(subset="TCIA_ID", keep="first")
# select columns
info = info[['TCIA_ID'] + predictors].rename(columns={'TCIA_ID': "patient_id"})
return info
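
# Example usage (illustrative only; the predictor list is an assumption based on the
# columns encoded above, and the file name matches the sheet shipped with the dataset):
#   clinical = prepare_clinical_data("HCC-TACE-Seg_clinical_data-V2.xlsx",
#                                    predictors=["CPS", "T_involvment", "CLIP_Score", "Okuda", "TNM", "BCLC"])
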
def preparare_train_test_txt(data_dir, test_patient_ratio=0.2, seed=1):
"""
From a list of patients, split them into train and test and export list to .txt files
"""
# split based on seed, write to txt files
patients = os.listdir(data_dir)
patients.remove("HCC-TACE-Seg_clinical_data-V2.xlsx")
    patients = sorted(set(patients))  # sort for a deterministic train/test split across runs
# remove one patient with wrong labels
    try:
        patients.remove("HCC_017")
        print("The patient HCC_017 is removed due to label issues including necrosis.")
    except ValueError:  # patient folder not present in this copy of the data
        pass
print("Total patients:", len(patients))
patients_train, patients_test = train_test_split(patients, test_size=test_patient_ratio, random_state=seed)
print(" There are", len(patients_train), "patients in training")
print(" There are", len(patients_test), "patients in test")
# export a copy
    out_dir = 'train-test-split-seed' + str(seed)
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    with open(out_dir + '/train.txt', 'w') as f:
        f.write(','.join(patients_train))
    with open(out_dir + '/test.txt', 'w') as f:
        f.write(','.join(patients_test))
    print("Files saved to", out_dir + '/train.txt', "and", out_dir + '/test.txt')
return
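
# Example usage (illustrative; the data directory path is an assumption):
#   preparare_train_test_txt("/path/to/HCC-TACE-Seg", test_patient_ratio=0.2, seed=1)
# writes train-test-split-seed1/train.txt and train-test-split-seed1/test.txt.
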
def extract_file_path(patient_id, data_folder):
"""
Given one patient's ID, obtain the file path of the image and mask data.
If patient has multiple images, they are labeled as pre1, pre2, etc.
"""
path = os.path.join(data_folder, patient_id)
files = os.listdir(path)
patient_files = {}
count = 1
for file in files:
if "seg" in file or "Segmentation" in file:
patient_files["mask"] = os.path.join(path, file)
else:
patient_files["pre_" + str(count)] = os.path.join(path, file)
count += 1
return patient_files
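
# The returned dict maps "mask" to the segmentation file and "pre_1", "pre_2", ...
# to each image series found for the patient, e.g. (paths shown schematically):
#   {"pre_1": "<data_dir>/<patient_id>/<image>", "mask": "<data_dir>/<patient_id>/<segmentation>"}
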
def get_patient_dictionaries(txt_file, data_dir):
"""
From .txt file that stores list of patients, look through data folders and extract a dictionary of patient data
"""
assert os.path.isfile(txt_file), "The file " + txt_file + " was not found. Please check your file directory."
    with open(txt_file, "r") as f:
        patients = f.read().split(',')
data_dict = []
for patient_id in patients:
# get directories for mask and images
patient_files = extract_file_path(patient_id, data_dir)
# pair up each image with the mask
for key, value in patient_files.items():
if key != "mask":
data_dict.append(
{
"patient_id": patient_id,
"image": patient_files[key],
"mask": patient_files["mask"]
}
)
print(" There are", len(data_dict), "image-masks in this dataset.")
return data_dict
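
# Each entry of the returned list pairs one image series with the patient's mask, e.g.
#   {"patient_id": "HCC_XXX", "image": ".../HCC_XXX/<image>", "mask": ".../HCC_XXX/<segmentation>"}
# (schematic; actual file names depend on the downloaded data).
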
def build_dataset(config, get_clinical=False):
def custom_collate_fn(batch):
"""
Custom collate function to stack samples along the first dimension.
Args:
batch (list): List of dictionaries with keys "image" and "mask",
where values are tensors of shape (N, 1, 512, 512).
        Returns:
            dict: Dictionary with two stacked tensors:
                - "image": stacked images of shape (B, 1, 512, 512)
                - "mask": stacked masks of shape (B, 1, 512, 512)
                where B is the total number of slices selected for the batch
                (for 3D models, whole volumes are stacked instead).
        """
# torch.manual_seed(1)
num_samples_to_select = config['BATCH_SIZE']
# Extract images and masks from the batch
images, masks = [], []
for sample in batch:
num_samples = min(sample["image"].shape[0], sample["mask"].shape[0])
random_indices = torch.randperm(num_samples)[:num_samples_to_select]
if "3D" in config['MODEL_NAME']: # 3D image
images.append(sample["image"][:,:512,:512,:]) # ensure image and mask same size
masks.append(sample["mask"][:,:512,:512,:])
else:
images.append(sample["image"][random_indices,:,:512,:512]) # ensure image and mask same size
masks.append(sample["mask"][random_indices,:,:512,:512])
#images.append(sample["image"][:,:,:512,:512]) # ensure image and mask same size
#masks.append(sample["mask"][:,:,:512,:512])
# Stack images and masks along the first dimension
try:
if "3D" not in config['MODEL_NAME']: # 3D image
concatenated_images = torch.cat(images, dim=0)
concatenated_masks = torch.cat(masks, dim=0)
else:
concatenated_images = torch.stack(images, dim=0)
concatenated_masks = torch.stack(masks, dim=0)
        except Exception:
            print("WARNING: not all images/masks are 512 by 512. Please check.",
                  [img.shape for img in images], [msk.shape for msk in masks])
            return None, None
# Return stacked images and masks as tensors
return {"image": concatenated_images, "mask": concatenated_masks}
# get list of training and test patient files
train_data_dict = get_patient_dictionaries(config['TRAIN_PATIENTS_FILE'], config['DATA_DIR'])
test_data_dict = get_patient_dictionaries(config['TEST_PATIENTS_FILE'], config['DATA_DIR'])
    if config['ONESAMPLETESTRUN']: train_data_dict = train_data_dict[:2]  # quick smoke test on two image-mask pairs
    ttrain_data_dict, valid_data_dict = train_test_split(train_data_dict, test_size=config['VALID_PATIENT_RATIO'], shuffle=False, random_state=1)  # shuffle must be False so the split stays aligned with the clinical data
print(" Training patients:", len(ttrain_data_dict), " Validation patients:", len(valid_data_dict))
print(" Test patients:", len(test_data_dict))
# define data transformations
preprocessing_transforms_train, preprocessing_transforms_test, postprocessing_transforms = define_transforms(config)
# create data loaders
train_ds = Dataset(ttrain_data_dict, transform=preprocessing_transforms_train)
valid_ds = Dataset(valid_data_dict, transform=preprocessing_transforms_test)
test_ds = Dataset(test_data_dict, transform=preprocessing_transforms_test)
if "3D" in config['MODEL_NAME']:
train_loader = DataLoader(train_ds, batch_size=config['BATCH_SIZE'], collate_fn=custom_collate_fn, shuffle=False, num_workers=config['NUM_WORKERS'])
valid_loader = DataLoader(valid_ds, batch_size=config['BATCH_SIZE'], collate_fn=custom_collate_fn, shuffle=False, num_workers=config['NUM_WORKERS'])
test_loader = DataLoader(test_ds, batch_size=config['BATCH_SIZE'], collate_fn=custom_collate_fn, shuffle=False, num_workers=config['NUM_WORKERS'])
else:
train_loader = DataLoader(train_ds, batch_size=1, shuffle=False, collate_fn=custom_collate_fn, num_workers=config['NUM_WORKERS']) #, pin_memory=torch.cuda.is_available())
valid_loader = DataLoader(valid_ds, batch_size=1, shuffle=False, collate_fn=custom_collate_fn, num_workers=config['NUM_WORKERS']) #, pin_memory=torch.cuda.is_available())
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, collate_fn=custom_collate_fn, num_workers=config['NUM_WORKERS']) #, pin_memory=torch.cuda.is_available())
# get clinical data
df_clinical_train = pd.DataFrame()
if get_clinical:
# define transforms
simple_transforms = define_transforms_loadonly()
simple_train_ds = Dataset(train_data_dict, transform=simple_transforms)
        simple_train_loader = DataLoader(simple_train_ds, batch_size=1, collate_fn=list_data_collate, shuffle=False, num_workers=config['NUM_WORKERS'])  # batch_size=1: volumes differ in depth and one ratio is needed per patient
# compute tumor ratio within liver
df_clinical_train['patient_id'] = [p["patient_id"] for p in train_data_dict]
        ratios_train = []
for batch_data in simple_train_loader:
labels = batch_data["mask"]
ratio = torch.sum(labels == 2, dim=(1, 2, 3, 4)) / torch.sum(labels > 0, dim=(1, 2, 3, 4))
            ratios_train.append(float(ratio.cpu().numpy()[0]))  # one scalar ratio per patient (batch_size=1)
df_clinical_train['tumor_ratio'] = ratios_train
# get clinical features
info = prepare_clinical_data(config['CLINICAL_DATA_FILE'], config['CLINICAL_PREDICTORS'])
df_clinical_train = pd.merge(df_clinical_train, info, on='patient_id', how="left")
        df_clinical_train.fillna(df_clinical_train.median(numeric_only=True), inplace=True)
df_clinical_train.set_index("patient_id", inplace=True)
# visualize the data loader for one image to ensure correct formatting
print("Example data transformations:")
while True:
sample = preprocessing_transforms_train(train_data_dict[0])
if isinstance(sample, list): # depending on preprocessing, one sample may be [sample] or sample
sample = sample[0]
        if torch.sum(sample['mask'][-1]) == 0: continue  # re-sample until the last mask channel is non-empty
print(f" image shape: {sample['image'].shape}")
print(f" mask shape: {sample['mask'].shape}")
print(f" mask values: {np.unique(sample['mask'])}")
#print(f" image affine:\n{sample['image'].meta['affine']}")
print(f" image min max: {np.min(sample['image']), np.max(sample['image'])}")
visualize_patient(sample['image'], sample['mask'], n_slices=3, z_dim_last="3D" in config['MODEL_NAME'], mask_channel=-1)
break
temp = monai.utils.first(test_loader)
print("Test loader shapes:", temp['image'].shape, temp['mask'].shape)
return train_loader, valid_loader, test_loader, postprocessing_transforms, df_clinical_train
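

if __name__ == "__main__":
    # Minimal usage sketch. The path below is a placeholder, and the config dictionary
    # passed to build_dataset would normally come from the experiment configuration
    # (define_transforms may expect additional keys not shown here).
    data_dir = "/path/to/HCC-TACE-Seg"  # assumption: local copy of the dataset
    preparare_train_test_txt(data_dir, test_patient_ratio=0.2, seed=1)
    train_dict = get_patient_dictionaries("train-test-split-seed1/train.txt", data_dir)
    print("First training entry:", train_dict[0])
    # A full pipeline would then call, e.g.:
    #   train_loader, valid_loader, test_loader, post_tf, df_clinical = build_dataset(config, get_clinical=True)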