# medassist-liver-cancer/utils/data_preparation.py
import os

import monai
import numpy as np
import pandas as pd
import torch
from monai.data import Dataset, DataLoader, list_data_collate
from sklearn.model_selection import train_test_split

from data_transforms import define_transforms, define_transforms_loadonly
from visualization import visualize_patient


def prepare_clinical_data(data_file, predictors):
    """
    Read the clinical spreadsheet and return one row per patient with the
    requested predictor columns, encoding ordinal categories as integers.
    """
    info = pd.read_excel(data_file, sheet_name=0)
    # encode ordinal categorical variables as integers
info['CPS'] = info['CPS'].map({'A': 1, 'B': 2, 'C': 3})
info['T_involvment'] = info['T_involvment'].map({'< or = 50%': 1, '>50%': 2})
info['CLIP_Score'] = info['CLIP_Score'].map({'Stage_0': 0, 'Stage_1': 1, 'Stage_2': 2, 'Stage_3': 3, 'Stage_4': 4, 'Stage_5': 5, 'Stage_6': 6})
info['Okuda'] = info['Okuda'].map({'Stage I': 1, 'Stage II': 2, 'Stage III': 3})
info['TNM'] = info['TNM'].map({'Stage-I': 1, 'Stage-II': 2, 'Stage-IIIA': 3, 'Stage-IIIB': 4, 'Stage-IIIC': 5, 'Stage-IVA': 6, 'Stage-IVB': 7})
info['BCLC'] = info['BCLC'].map({'0': 0, 'Stage-A': 1, 'Stage-B': 2, 'Stage-C': 3, 'Stage-D': 4})
    # keep one row per patient (the original call discarded its result,
    # so duplicates were never actually removed)
    info = info.groupby("TCIA_ID", as_index=False).first()
# select columns
info = info[['TCIA_ID'] + predictors].rename(columns={'TCIA_ID': "patient_id"})
return info
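
# Usage sketch (spreadsheet name as referenced elsewhere in this module;
# the predictor list is illustrative):
#   info = prepare_clinical_data("HCC-TACE-Seg_clinical_data-V2.xlsx",
#                                ["CPS", "CLIP_Score", "Okuda", "TNM", "BCLC"])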


def preparare_train_test_txt(data_dir, test_patient_ratio=0.2, seed=1):
    """
    Split the patients found in data_dir into train and test sets and export
    the lists to .txt files.
    """
    patients = os.listdir(data_dir)
    # drop the clinical spreadsheet, which sits alongside the patient folders
    if "HCC-TACE-Seg_clinical_data-V2.xlsx" in patients:
        patients.remove("HCC-TACE-Seg_clinical_data-V2.xlsx")
    patients = list(set(patients))
    # remove one patient with known label issues
    try:
        patients.remove("HCC_017")
        print("The patient HCC_017 is removed due to label issues, including necrosis.")
    except ValueError:
        pass  # patient folder not present in this copy of the dataset
print("Total patients:", len(patients))
patients_train, patients_test = train_test_split(patients, test_size=test_patient_ratio, random_state=seed)
print(" There are", len(patients_train), "patients in training")
print(" There are", len(patients_test), "patients in test")
    # export a copy of the split
    split_dir = 'train-test-split-seed' + str(seed)
    os.makedirs(split_dir, exist_ok=True)
    with open(os.path.join(split_dir, 'train.txt'), 'w') as f:
        f.write(','.join(patients_train))
    with open(os.path.join(split_dir, 'test.txt'), 'w') as f:
        f.write(','.join(patients_test))
    print("Files saved to", os.path.join(split_dir, 'train.txt'), "and", os.path.join(split_dir, 'test.txt'))
return
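
# Usage sketch (hypothetical data directory containing one folder per patient):
#   preparare_train_test_txt("./data", test_patient_ratio=0.2, seed=1)
#   # -> writes train-test-split-seed1/train.txt and train-test-split-seed1/test.txt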


def extract_file_path(patient_id, data_folder):
    """
    Given one patient's ID, obtain the file paths of the image and mask data.
    If a patient has multiple images, they are keyed as pre_1, pre_2, etc.
    """
path = os.path.join(data_folder, patient_id)
files = os.listdir(path)
patient_files = {}
count = 1
for file in files:
if "seg" in file or "Segmentation" in file:
patient_files["mask"] = os.path.join(path, file)
else:
patient_files["pre_" + str(count)] = os.path.join(path, file)
count += 1
return patient_files
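
# The returned mapping has one "mask" entry plus one "pre_<i>" entry per image,
# e.g. (illustrative paths):
#   {"mask": "<data_dir>/HCC_003/Segmentation", "pre_1": "<data_dir>/HCC_003/scan1"}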


def get_patient_dictionaries(txt_file, data_dir):
    """
    From a .txt file that stores the list of patients, look through the data
    folders and build a list of per-image dictionaries (patient_id, image, mask).
    """
    assert os.path.isfile(txt_file), "The file " + txt_file + " was not found. Please check your file directory."
    with open(txt_file, "r") as file:
        patients = file.read().split(',')
data_dict = []
for patient_id in patients:
# get directories for mask and images
patient_files = extract_file_path(patient_id, data_dir)
# pair up each image with the mask
for key, value in patient_files.items():
if key != "mask":
data_dict.append(
{
"patient_id": patient_id,
"image": patient_files[key],
"mask": patient_files["mask"]
}
)
print(" There are", len(data_dict), "image-masks in this dataset.")
return data_dict
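
# Usage sketch (assumes the split files written by preparare_train_test_txt):
#   train_dicts = get_patient_dictionaries("train-test-split-seed1/train.txt", "./data")
#   train_dicts[0]  # -> {"patient_id": ..., "image": ..., "mask": ...}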


def build_dataset(config, get_clinical=False):
    def custom_collate_fn(batch):
        """
        Custom collate function that assembles samples into a single batch.
        Args:
            batch (list): List of dictionaries with keys "image" and "mask",
                whose values are tensors of shape (N, 1, 512, 512).
        Returns:
            dict: {"image": ..., "mask": ...}. For 2D models, randomly selected
                slices are concatenated to shape (B, 1, 512, 512), where B is the
                total number of slices in the batch; for 3D models, whole volumes
                are stacked along a new batch dimension.
        """
        num_samples_to_select = config['BATCH_SIZE']
# Extract images and masks from the batch
images, masks = [], []
for sample in batch:
num_samples = min(sample["image"].shape[0], sample["mask"].shape[0])
random_indices = torch.randperm(num_samples)[:num_samples_to_select]
if "3D" in config['MODEL_NAME']: # 3D image
images.append(sample["image"][:,:512,:512,:]) # ensure image and mask same size
masks.append(sample["mask"][:,:512,:512,:])
else:
images.append(sample["image"][random_indices,:,:512,:512]) # ensure image and mask same size
masks.append(sample["mask"][random_indices,:,:512,:512])
#images.append(sample["image"][:,:,:512,:512]) # ensure image and mask same size
#masks.append(sample["mask"][:,:,:512,:512])
# Stack images and masks along the first dimension
        try:
            if "3D" not in config['MODEL_NAME']:  # 2D model: slices concatenate along the batch dimension
                concatenated_images = torch.cat(images, dim=0)
                concatenated_masks = torch.cat(masks, dim=0)
            else:  # 3D model: volumes stack along a new batch dimension
                concatenated_images = torch.stack(images, dim=0)
                concatenated_masks = torch.stack(masks, dim=0)
        except Exception:
            print("WARNING: not all images/masks are 512 by 512. Please check.",
                  [img.shape for img in images], [msk.shape for msk in masks])
            return None, None
# Return stacked images and masks as tensors
return {"image": concatenated_images, "mask": concatenated_masks}
# get list of training and test patient files
train_data_dict = get_patient_dictionaries(config['TRAIN_PATIENTS_FILE'], config['DATA_DIR'])
test_data_dict = get_patient_dictionaries(config['TEST_PATIENTS_FILE'], config['DATA_DIR'])
    if config['ONESAMPLETESTRUN']:
        train_data_dict = train_data_dict[:2]
    # shuffle must be False so the sample order matches the clinical data assembled below
    ttrain_data_dict, valid_data_dict = train_test_split(train_data_dict, test_size=config['VALID_PATIENT_RATIO'], shuffle=False, random_state=1)
print(" Training patients:", len(ttrain_data_dict), " Validation patients:", len(valid_data_dict))
print(" Test patients:", len(test_data_dict))
# define data transformations
preprocessing_transforms_train, preprocessing_transforms_test, postprocessing_transforms = define_transforms(config)
# create data loaders
train_ds = Dataset(ttrain_data_dict, transform=preprocessing_transforms_train)
valid_ds = Dataset(valid_data_dict, transform=preprocessing_transforms_test)
test_ds = Dataset(test_data_dict, transform=preprocessing_transforms_test)
if "3D" in config['MODEL_NAME']:
train_loader = DataLoader(train_ds, batch_size=config['BATCH_SIZE'], collate_fn=custom_collate_fn, shuffle=False, num_workers=config['NUM_WORKERS'])
valid_loader = DataLoader(valid_ds, batch_size=config['BATCH_SIZE'], collate_fn=custom_collate_fn, shuffle=False, num_workers=config['NUM_WORKERS'])
test_loader = DataLoader(test_ds, batch_size=config['BATCH_SIZE'], collate_fn=custom_collate_fn, shuffle=False, num_workers=config['NUM_WORKERS'])
else:
train_loader = DataLoader(train_ds, batch_size=1, shuffle=False, collate_fn=custom_collate_fn, num_workers=config['NUM_WORKERS']) #, pin_memory=torch.cuda.is_available())
valid_loader = DataLoader(valid_ds, batch_size=1, shuffle=False, collate_fn=custom_collate_fn, num_workers=config['NUM_WORKERS']) #, pin_memory=torch.cuda.is_available())
test_loader = DataLoader(test_ds, batch_size=1, shuffle=False, collate_fn=custom_collate_fn, num_workers=config['NUM_WORKERS']) #, pin_memory=torch.cuda.is_available())
# get clinical data
df_clinical_train = pd.DataFrame()
if get_clinical:
# define transforms
simple_transforms = define_transforms_loadonly()
simple_train_ds = Dataset(train_data_dict, transform=simple_transforms)
simple_train_loader = DataLoader(simple_train_ds, batch_size=config['BATCH_SIZE'], collate_fn=list_data_collate, shuffle=False, num_workers=config['NUM_WORKERS']) #, pin_memory=torch.cuda.is_available())
# compute tumor ratio within liver
df_clinical_train['patient_id'] = [p["patient_id"] for p in train_data_dict]
        ratios_train = []
        for batch_data in simple_train_loader:
            labels = batch_data["mask"]
            # fraction of segmented voxels (label > 0) that are tumor (label == 2)
            ratio = torch.sum(labels == 2, dim=(1, 2, 3, 4)) / torch.sum(labels > 0, dim=(1, 2, 3, 4))
            ratios_train.extend(ratio.cpu().numpy().tolist())  # one ratio per sample in the batch
df_clinical_train['tumor_ratio'] = ratios_train
# get clinical features
info = prepare_clinical_data(config['CLINICAL_DATA_FILE'], config['CLINICAL_PREDICTORS'])
df_clinical_train = pd.merge(df_clinical_train, info, on='patient_id', how="left")
        df_clinical_train.fillna(df_clinical_train.median(numeric_only=True), inplace=True)
df_clinical_train.set_index("patient_id", inplace=True)
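        # df_clinical_train now has one row per training image, indexed by
        # patient_id, with tumor_ratio plus the configured clinical predictors
        # as columns.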
    # run one sample through the training transforms to verify formatting
    print("Example data transformations:")
    while True:
        sample = preprocessing_transforms_train(train_data_dict[0])
        if isinstance(sample, list):  # depending on preprocessing, one sample may be wrapped as [sample]
            sample = sample[0]
        if torch.sum(sample['mask'][-1]) == 0:
            continue  # resample until the last mask channel is non-empty
        print(f" image shape: {sample['image'].shape}")
        print(f" mask shape: {sample['mask'].shape}")
        print(f" mask values: {np.unique(sample['mask'])}")
        print(f" image min max: {np.min(sample['image']), np.max(sample['image'])}")
        visualize_patient(sample['image'], sample['mask'], n_slices=3, z_dim_last="3D" in config['MODEL_NAME'], mask_channel=-1)
        break
temp = monai.utils.first(test_loader)
print("Test loader shapes:", temp['image'].shape, temp['mask'].shape)
return train_loader, valid_loader, test_loader, postprocessing_transforms, df_clinical_train
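

if __name__ == "__main__":
    # Minimal smoke-test sketch. Paths and values are illustrative assumptions;
    # the config keys listed are the ones this module actually reads, but
    # define_transforms() in data_transforms may require additional keys.
    example_config = {
        'DATA_DIR': './data',  # hypothetical dataset root with one folder per patient
        'TRAIN_PATIENTS_FILE': 'train-test-split-seed1/train.txt',
        'TEST_PATIENTS_FILE': 'train-test-split-seed1/test.txt',
        'VALID_PATIENT_RATIO': 0.2,
        'ONESAMPLETESTRUN': True,  # keep the run small
        'MODEL_NAME': 'unet2d',  # no "3D" substring -> 2D slice pipeline
        'BATCH_SIZE': 4,
        'NUM_WORKERS': 0,
        'CLINICAL_DATA_FILE': './data/HCC-TACE-Seg_clinical_data-V2.xlsx',
        'CLINICAL_PREDICTORS': ['CPS', 'CLIP_Score', 'Okuda', 'TNM', 'BCLC'],
    }
    preparare_train_test_txt(example_config['DATA_DIR'], test_patient_ratio=0.2, seed=1)
    train_loader, valid_loader, test_loader, _, _ = build_dataset(example_config)
    batch = monai.utils.first(train_loader)
    print("First training batch:", batch["image"].shape, batch["mask"].shape)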