|
import os |
|
from sklearn.model_selection import train_test_split |
|
import monai |
|
from monai.data import Dataset, DataLoader |
|
from data_transforms import define_transforms, define_transforms_loadonly |
|
import torch |
|
import numpy as np |
|
from visualization import visualize_patient |
|
from monai.data import list_data_collate |
|
import pandas as pd |
|
|
|
|
|
def prepare_clinical_data(data_file, predictors):
    """
    Load the clinical spreadsheet and encode categorical columns as ordinal integers.

    Parameters
    ----------
    data_file : str
        Path to the clinical-data Excel file (the first sheet is read).
    predictors : list of str
        Clinical column names to keep alongside the patient identifier.

    Returns
    -------
    pandas.DataFrame
        One row per patient with columns ['patient_id'] + predictors.
        Category values not present in the mapping dictionaries become NaN.
    """
    info = pd.read_excel(data_file, sheet_name=0)

    # Ordinal encodings: map severity/stage labels onto increasing integers.
    info['CPS'] = info['CPS'].map({'A': 1, 'B': 2, 'C': 3})
    info['T_involvment'] = info['T_involvment'].map({'< or = 50%': 1, '>50%': 2})
    info['CLIP_Score'] = info['CLIP_Score'].map({'Stage_0': 0, 'Stage_1': 1, 'Stage_2': 2, 'Stage_3': 3, 'Stage_4': 4, 'Stage_5': 5, 'Stage_6': 6})
    info['Okuda'] = info['Okuda'].map({'Stage I': 1, 'Stage II': 2, 'Stage III': 3})
    info['TNM'] = info['TNM'].map({'Stage-I': 1, 'Stage-II': 2, 'Stage-IIIA': 3, 'Stage-IIIB': 4, 'Stage-IIIC': 5, 'Stage-IVA': 6, 'Stage-IVB': 7})
    info['BCLC'] = info['BCLC'].map({'0': 0, 'Stage-A': 1, 'Stage-B': 2, 'Stage-C': 3, 'Stage-D': 4})

    # BUGFIX: the original called info.groupby("TCIA_ID").first() but discarded
    # the result, so duplicate patient rows were never collapsed. Keep the first
    # row per patient; as_index=False keeps TCIA_ID as a regular column so the
    # selection/rename below still works.
    info = info.groupby("TCIA_ID", as_index=False).first()

    info = info[['TCIA_ID'] + predictors].rename(columns={'TCIA_ID': "patient_id"})
    return info
|
|
|
|
|
|
|
def preparare_train_test_txt(data_dir, test_patient_ratio=0.2, seed=1):
    """
    Split the patient folders under ``data_dir`` into train and test sets and
    export the two patient-ID lists as comma-separated .txt files.

    Parameters
    ----------
    data_dir : str
        Directory whose entries are patient IDs, plus the clinical-data
        spreadsheet (which is excluded from the split).
    test_patient_ratio : float, optional
        Fraction of patients assigned to the test set (default 0.2).
    seed : int, optional
        Random state for the split; also names the output folder
        'train-test-split-seed<seed>' (default 1).

    Raises
    ------
    ValueError
        If the clinical spreadsheet entry is not present in ``data_dir``.
    """
    patients = os.listdir(data_dir)
    # The clinical spreadsheet sits next to the patient folders; drop it.
    patients.remove("HCC-TACE-Seg_clinical_data-V2.xlsx")
    patients = list(set(patients))  # defensive de-duplication

    # HCC_017 has known label problems (necrosis). list.remove raises
    # ValueError when the item is absent -- catch only that, not Exception.
    try:
        patients.remove("HCC_017")
        print("The patient HCC_017 is removed due to label issues including necrosis.")
    except ValueError:
        pass

    print("Total patients:", len(patients))
    patients_train, patients_test = train_test_split(patients, test_size=test_patient_ratio, random_state=seed)
    print(" There are", len(patients_train), "patients in training")
    print(" There are", len(patients_test), "patients in test")

    out_dir = 'train-test-split-seed' + str(seed)
    # exist_ok avoids the check-then-create race of exists()+makedirs().
    os.makedirs(out_dir, exist_ok=True)
    with open(os.path.join(out_dir, 'train.txt'), 'w') as f:
        f.write(','.join(patients_train))
    with open(os.path.join(out_dir, 'test.txt'), 'w') as f:
        f.write(','.join(patients_test))

    print("Files saved to", out_dir + '/train.txt and ' + out_dir + '/test.txt')
    return
|
|
|
|
|
|
|
|
|
def extract_file_path(patient_id, data_folder):
    """
    Map file roles to file paths for a single patient.

    The segmentation file (any name containing 'seg' or 'Segmentation') is
    stored under the key 'mask'; every other file becomes 'pre_1', 'pre_2',
    ... in directory-listing order.

    Returns
    -------
    dict
        Keys 'mask' and 'pre_<i>' mapped to absolute-ish file paths under
        ``data_folder``/``patient_id``.
    """
    patient_dir = os.path.join(data_folder, patient_id)
    file_map = {}
    image_index = 1
    for entry in os.listdir(patient_dir):
        full_path = os.path.join(patient_dir, entry)
        is_mask = ("seg" in entry) or ("Segmentation" in entry)
        if is_mask:
            file_map["mask"] = full_path
        else:
            file_map["pre_" + str(image_index)] = full_path
            image_index += 1
    return file_map
|
|
|
|
|
|
|
def get_patient_dictionaries(txt_file, data_dir):
    """
    Read a comma-separated patient list from ``txt_file`` and build one
    {'patient_id', 'image', 'mask'} dictionary per image of each patient
    (a patient with several 'pre_*' images yields several entries, all
    sharing the same mask path).

    Parameters
    ----------
    txt_file : str
        Path to the .txt file holding a comma-separated list of patient IDs.
    data_dir : str
        Root directory containing one sub-folder per patient.

    Returns
    -------
    list of dict
        One dict per image-mask pair.

    Raises
    ------
    AssertionError
        If ``txt_file`` does not exist.
    KeyError
        If a patient folder with images contains no segmentation file.
    """
    assert os.path.isfile(txt_file), "The file " + txt_file + " was not found. Please check your file directory."

    # BUGFIX: use a context manager so the handle is always closed
    # (the original open() was never closed).
    with open(txt_file, "r") as file:
        patients = file.read().split(',')

    data_dict = []
    for patient_id in patients:
        patient_files = extract_file_path(patient_id, data_dir)
        for key, value in patient_files.items():
            if key != "mask":
                data_dict.append(
                    {
                        "patient_id": patient_id,
                        "image": value,  # was patient_files[key]; identical value
                        "mask": patient_files["mask"],
                    }
                )

    print(" There are", len(data_dict), "image-masks in this dataset.")
    return data_dict
|
|
|
|
|
|
|
|
|
def build_dataset(config, get_clinical=False):
    """
    Build MONAI train/validation/test DataLoaders from the patient-split files
    referenced in ``config``, and optionally a per-patient clinical DataFrame.

    Parameters
    ----------
    config : dict
        Reads the keys: 'TRAIN_PATIENTS_FILE', 'TEST_PATIENTS_FILE',
        'DATA_DIR', 'ONESAMPLETESTRUN', 'VALID_PATIENT_RATIO', 'MODEL_NAME',
        'BATCH_SIZE', 'NUM_WORKERS', and -- when get_clinical is True --
        'CLINICAL_DATA_FILE' and 'CLINICAL_PREDICTORS'.
    get_clinical : bool, optional
        When True, also compute tumor ratios from the masks and merge them
        with the clinical spreadsheet into ``df_clinical_train``.

    Returns
    -------
    tuple
        (train_loader, valid_loader, test_loader, postprocessing_transforms,
        df_clinical_train); df_clinical_train is empty unless get_clinical.
    """

    def custom_collate_fn(batch):
        """
        Collate a list of {'image', 'mask'} samples into one batch dict.

        For 3D models each sample's spatial extent is cropped to 512x512 and
        samples are stacked on a new leading dim; otherwise up to BATCH_SIZE
        random slices are drawn per sample and concatenated along dim 0.

        NOTE(review): on success this returns a dict, but on a size-mismatch
        failure it returns the tuple (None, None) -- callers must handle both.
        """
        num_samples_to_select = config['BATCH_SIZE']

        images, masks = [], []
        for sample in batch:
            # Guard against image/mask having different first-dim lengths.
            num_samples = min(sample["image"].shape[0], sample["mask"].shape[0])
            # Random subset of slice indices (only used in the 2D branch).
            random_indices = torch.randperm(num_samples)[:num_samples_to_select]
            if "3D" in config['MODEL_NAME']:
                # assumes layout (C, H, W, D) with spatial dims >= 512 -- TODO confirm
                images.append(sample["image"][:,:512,:512,:])
                masks.append(sample["mask"][:,:512,:512,:])
            else:
                # assumes layout (slices, C, H, W) -- TODO confirm
                images.append(sample["image"][random_indices,:,:512,:512])
                masks.append(sample["mask"][random_indices,:,:512,:512])

        try:
            if "3D" not in config['MODEL_NAME']:
                # 2D: merge all selected slices into one batch dimension.
                concatenated_images = torch.cat(images, dim=0)
                concatenated_masks = torch.cat(masks, dim=0)
            else:
                # 3D: one volume per batch element.
                concatenated_images = torch.stack(images, dim=0)
                concatenated_masks = torch.stack(masks, dim=0)
        except Exception as e:
            # NOTE(review): assumes at least 2 samples in the batch (indexes
            # images[1]/masks[1]); with a single sample this print itself raises.
            print("WARNING: not all images/masks are 512 by 512. Please check. ", images[0].shape, images[1].shape, masks[0].shape, masks[1].shape)
            return None, None

        return {"image": concatenated_images, "mask": concatenated_masks}

    # Build the per-image dictionaries for each split.
    train_data_dict = get_patient_dictionaries(config['TRAIN_PATIENTS_FILE'], config['DATA_DIR'])
    test_data_dict = get_patient_dictionaries(config['TEST_PATIENTS_FILE'], config['DATA_DIR'])
    # Smoke-test mode: keep only two training entries.
    if config['ONESAMPLETESTRUN']: train_data_dict = train_data_dict[:2]
    # shuffle=False + fixed random_state => deterministic train/valid split.
    ttrain_data_dict, valid_data_dict = train_test_split(train_data_dict, test_size=config['VALID_PATIENT_RATIO'], shuffle=False, random_state=1)
    print(" Training patients:", len(ttrain_data_dict), " Validation patients:", len(valid_data_dict))
    print(" Test patients:", len(test_data_dict))

    preprocessing_transforms_train, preprocessing_transforms_test, postprocessing_transforms = define_transforms(config)

    train_ds = Dataset(ttrain_data_dict, transform=preprocessing_transforms_train)
    valid_ds = Dataset(valid_data_dict, transform=preprocessing_transforms_test)
    test_ds = Dataset(test_data_dict, transform=preprocessing_transforms_test)

    if "3D" in config['MODEL_NAME']:
        # 3D: the DataLoader batches volumes; collate stacks them.
        train_loader = DataLoader(train_ds, batch_size=config['BATCH_SIZE'], collate_fn=custom_collate_fn, shuffle=False, num_workers=config['NUM_WORKERS'])
        valid_loader = DataLoader(valid_ds, batch_size=config['BATCH_SIZE'], collate_fn=custom_collate_fn, shuffle=False, num_workers=config['NUM_WORKERS'])
        test_loader = DataLoader(test_ds, batch_size=config['BATCH_SIZE'], collate_fn=custom_collate_fn, shuffle=False, num_workers=config['NUM_WORKERS'])
    else:
        # 2D: batch_size=1 patient; the collate fn draws BATCH_SIZE slices.
        train_loader = DataLoader(train_ds, batch_size=1, shuffle=False, collate_fn=custom_collate_fn, num_workers=config['NUM_WORKERS'])
        valid_loader = DataLoader(valid_ds, batch_size=1, shuffle=False, collate_fn=custom_collate_fn, num_workers=config['NUM_WORKERS'])
        test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, collate_fn=custom_collate_fn, num_workers=config['NUM_WORKERS'])

    df_clinical_train = pd.DataFrame()
    if get_clinical:
        # Load-only transforms: read images/masks without augmentation.
        simple_transforms = define_transforms_loadonly()
        simple_train_ds = Dataset(train_data_dict, transform=simple_transforms)
        simple_train_loader = DataLoader(simple_train_ds, batch_size=config['BATCH_SIZE'], collate_fn=list_data_collate, shuffle=False, num_workers=config['NUM_WORKERS'])

        df_clinical_train['patient_id'] = [p["patient_id"] for p in train_data_dict]
        # NOTE(review): ratios_test is never used.
        ratios_train, ratios_test = [], []
        for batch_data in simple_train_loader:
            labels = batch_data["mask"]
            # Fraction of the labeled region with label 2 (presumably tumor
            # voxels over all foreground voxels -- verify against label map).
            ratio = torch.sum(labels == 2, dim=(1, 2, 3, 4)) / torch.sum(labels > 0, dim=(1, 2, 3, 4))
            # NOTE(review): [0] keeps only the first element of each batch; with
            # BATCH_SIZE > 1 this drops ratios and the column assignment below
            # would mismatch the number of patients -- confirm intended use.
            ratios_train.append(ratio.cpu().numpy()[0])
        df_clinical_train['tumor_ratio'] = ratios_train

        # Merge image-derived ratios with the clinical spreadsheet; patients
        # missing from the spreadsheet get column medians.
        info = prepare_clinical_data(config['CLINICAL_DATA_FILE'], config['CLINICAL_PREDICTORS'])
        df_clinical_train = pd.merge(df_clinical_train, info, on='patient_id', how="left")
        df_clinical_train.fillna(df_clinical_train.median(), inplace=True)
        df_clinical_train.set_index("patient_id", inplace=True)

    print("Example data transformations:")
    # Re-sample the first training entry until its last mask channel is
    # non-empty, then visualize it.
    # NOTE(review): loops forever if that channel is always empty for this patient.
    while True:
        sample = preprocessing_transforms_train(train_data_dict[0])
        # Some MONAI transforms return a list of samples; take the first.
        if isinstance(sample, list):
            sample = sample[0]
        if torch.sum(sample['mask'][-1]) == 0: continue
        print(f" image shape: {sample['image'].shape}")
        print(f" mask shape: {sample['mask'].shape}")
        print(f" mask values: {np.unique(sample['mask'])}")

        print(f" image min max: {np.min(sample['image']), np.max(sample['image'])}")
        visualize_patient(sample['image'], sample['mask'], n_slices=3, z_dim_last="3D" in config['MODEL_NAME'], mask_channel=-1)
        break

    # Sanity check: pull one batch from the test loader and report its shapes.
    temp = monai.utils.first(test_loader)
    print("Test loader shapes:", temp['image'].shape, temp['mask'].shape)

    return train_loader, valid_loader, test_loader, postprocessing_transforms, df_clinical_train
|
|
|
|