import torch.utils.data as data
import torch
from torch import nn
from pathlib import Path
from torchvision import transforms as T
import pandas as pd
from PIL import Image
from medical_diffusion.data.augmentation.augmentations_2d import Normalize, ToTensor16bit
class SimpleDataset2D(data.Dataset):
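    """Generic 2D image dataset.

    Crawls `path_root` recursively for files with extension `crawler_ext`
    (unless explicit `item_pointers` are given) and yields dicts of the form
    {'uid': <file stem>, 'source': <transformed image>}. The default transform
    resizes, optionally flips, crops, converts to a tensor and rescales
    intensities from [0, 1] to [-1, 1].
    """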
def __init__(
self,
path_root,
        item_pointers=None,  # avoid a mutable default argument
        crawler_ext='tif',  # other options: 'jpg', 'jpeg', 'png', 'tiff'
        transform=None,
        image_resize=None,
        augment_horizontal_flip=False,
        augment_vertical_flip=False,
        image_crop=None,
):
super().__init__()
self.path_root = Path(path_root)
self.crawler_ext = crawler_ext
        if item_pointers:
            self.item_pointers = item_pointers
        else:
            self.item_pointers = self.run_item_crawler(self.path_root, self.crawler_ext)
if transform is None:
self.transform = T.Compose([
T.Resize(image_resize) if image_resize is not None else nn.Identity(),
T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
T.RandomVerticalFlip() if augment_vertical_flip else nn.Identity(),
T.CenterCrop(image_crop) if image_crop is not None else nn.Identity(),
T.ToTensor(),
# T.Lambda(lambda x: torch.cat([x]*3) if x.shape[0]==1 else x),
# ToTensor16bit(),
# Normalize(), # [0, 1.0]
# T.ConvertImageDtype(torch.float),
                T.Normalize(mean=0.5, std=0.5)  # WARNING: mean and std are not target statistics but the values to subtract and divide by: ([0, 1] - 0.5) / 0.5 -> [-1, 1]
])
else:
self.transform = transform
def __len__(self):
return len(self.item_pointers)
def __getitem__(self, index):
rel_path_item = self.item_pointers[index]
path_item = self.path_root/rel_path_item
# img = Image.open(path_item)
img = self.load_item(path_item)
return {'uid':rel_path_item.stem, 'source': self.transform(img)}
def load_item(self, path_item):
return Image.open(path_item).convert('RGB')
# return cv2.imread(str(path_item), cv2.IMREAD_UNCHANGED) # NOTE: Only CV2 supports 16bit RGB images
@classmethod
def run_item_crawler(cls, path_root, extension, **kwargs):
return [path.relative_to(path_root) for path in Path(path_root).rglob(f'*.{extension}')]
def get_weights(self):
"""Return list of class-weights for WeightedSampling"""
return None
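# Hypothetical usage sketch (paths and sizes below are assumptions, not part of this module):
#   ds = SimpleDataset2D(path_root='data/images', crawler_ext='png', image_resize=256, image_crop=256)
#   loader = data.DataLoader(ds, batch_size=8, shuffle=True)
#   batch = next(iter(loader))  # {'uid': [...], 'source': tensor of shape [8, 3, 256, 256]}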
class AIROGSDataset(SimpleDataset2D):
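    """AIROGS challenge fundus images for glaucoma screening.

    Expects `train_labels.csv` (indexed by `challenge_id`) in the parent
    directory of `path_root`. Targets: 'NRG' (no referable glaucoma) = 0,
    'RG' (referable glaucoma) = 1.
    """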
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.labels = pd.read_csv(self.path_root.parent/'train_labels.csv', index_col='challenge_id')
def __len__(self):
return len(self.labels)
def __getitem__(self, index):
uid = self.labels.index[index]
path_item = self.path_root/f'{uid}.jpg'
img = self.load_item(path_item)
        str_2_int = {'NRG': 0, 'RG': 1}  # class counts: RG = 3270, NRG = 98172
target = str_2_int[self.labels.loc[uid, 'class']]
# return {'uid':uid, 'source': self.transform(img), 'target':target}
return {'source': self.transform(img), 'target':target}
def get_weights(self):
n_samples = len(self)
weight_per_class = 1/self.labels['class'].value_counts(normalize=True) # {'NRG': 1.03, 'RG': 31.02}
weights = [0] * n_samples
for index in range(n_samples):
target = self.labels.iloc[index]['class']
weights[index] = weight_per_class[target]
return weights
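    # Hypothetical sketch of how these inverse-frequency weights could be used to
    # oversample the rare RG class (the path below is an assumption):
    #   ds = AIROGSDataset(path_root='data/AIROGS/train')
    #   sampler = data.WeightedRandomSampler(ds.get_weights(), num_samples=len(ds))
    #   loader = data.DataLoader(ds, batch_size=32, sampler=sampler)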
@classmethod
def run_item_crawler(cls, path_root, extension, **kwargs):
"""Overwrite to speed up as paths are determined by .csv file anyway"""
return []
class MSIvsMSS_Dataset(SimpleDataset2D):
# https://doi.org/10.5281/zenodo.2530835
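    """Histopathology tiles labelled by microsatellite status.

    Targets are derived from the parent folder name:
    'MSIMUT' (microsatellite instable/mutated) = 0, 'MSS' (microsatellite stable) = 1.
    """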
def __getitem__(self, index):
rel_path_item = self.item_pointers[index]
path_item = self.path_root/rel_path_item
img = self.load_item(path_item)
uid = rel_path_item.stem
str_2_int = {'MSIMUT':0, 'MSS':1}
        target = str_2_int[path_item.parent.name]
return {'uid':uid, 'source': self.transform(img), 'target':target}
class MSIvsMSS_2_Dataset(SimpleDataset2D):
# https://doi.org/10.5281/zenodo.3832231
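    """Histopathology tiles labelled by microsatellite status (see DOI above).

    Targets are derived from the parent folder name:
    'MSIH' (MSI-high) = 0, 'nonMSIH' (MSI-low or MSS) = 1.
    """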
def __getitem__(self, index):
rel_path_item = self.item_pointers[index]
path_item = self.path_root/rel_path_item
img = self.load_item(path_item)
uid = rel_path_item.stem
        str_2_int = {'MSIH': 0, 'nonMSIH': 1}  # MSIH = patients with MSI-H; nonMSIH = patients with MSI-L or MSS
target = str_2_int[path_item.parent.name]
# return {'uid':uid, 'source': self.transform(img), 'target':target}
return {'source': self.transform(img), 'target':target}
class CheXpert_Dataset(SimpleDataset2D):
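    """CheXpert chest radiographs, frontal views only.

    Reads `<mode>.csv` from the parent directory of `path_root`, where `mode`
    is the name of `path_root` (e.g. 'train' or 'valid'). The Cardiomegaly
    label is shifted by +1 to {0: uncertain, 1: negative, 2: positive,
    3: not reported}.
    """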
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
mode = self.path_root.name
labels = pd.read_csv(self.path_root.parent/f'{mode}.csv', index_col='Path')
self.labels = labels.loc[labels['Frontal/Lateral'] == 'Frontal'].copy()
        self.labels.index = self.labels.index.str[20:]  # strip the 20-character dataset prefix (presumably 'CheXpert-v1.0-small/') so paths are relative to path_root
        self.labels.loc[self.labels['Sex'] == 'Unknown', 'Sex'] = 'Female'  # affects a single case; set to 'Female' to match the statistics in the publication
        self.labels.fillna(2, inplace=True)  # TODO: find a better solution; NaN (label not reported) is encoded as 2
str_2_int = {'Sex': {'Male':0, 'Female':1}, 'Frontal/Lateral':{'Frontal':0, 'Lateral':1}, 'AP/PA':{'AP':0, 'PA':1}}
self.labels.replace(str_2_int, inplace=True)
def __len__(self):
return len(self.labels)
def __getitem__(self, index):
rel_path_item = self.labels.index[index]
path_item = self.path_root/rel_path_item
img = self.load_item(path_item)
uid = str(rel_path_item)
        target = torch.tensor(self.labels.loc[uid, 'Cardiomegaly'] + 1, dtype=torch.long)  # raw labels: -1=uncertain, 0=negative, 1=positive, NaN=not reported; +1 maps them to 0, 1, 2 and (after fillna) 3
return {'uid':uid, 'source': self.transform(img), 'target':target}
@classmethod
def run_item_crawler(cls, path_root, extension, **kwargs):
"""Overwrite to speed up as paths are determined by .csv file anyway"""
return []
class CheXpert_2_Dataset(SimpleDataset2D):
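    """CheXpert variant backed by a preprocessed label table.

    Reads `labels/cheXPert_label.csv` (1 and -1/uncertain count as
    positive = 1; 0 and NaN as negative = 0) and joins the original
    `labels/train.csv` Cardiomegaly column as 'Cardiomegaly_true'
    (0 = negative, 1 = positive, 2 = uncertain). Images are expected as
    zero-padded PNGs under `path_root/data`.
    """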
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
labels = pd.read_csv(self.path_root/'labels/cheXPert_label.csv', index_col=['Path', 'Image Index']) # Note: 1 and -1 (uncertain) cases count as positives (1), 0 and NA count as negatives (0)
labels = labels.loc[labels['fold']=='train'].copy()
        labels = labels.drop(columns='fold')
labels2 = pd.read_csv(self.path_root/'labels/train.csv', index_col='Path')
labels2 = labels2.loc[labels2['Frontal/Lateral'] == 'Frontal'].copy()
labels2 = labels2[['Cardiomegaly',]].copy()
        labels2[(labels2 < 0) | labels2.isna()] = 2  # 0 = negative, 1 = positive, 2 = uncertain
        labels = labels.join(labels2['Cardiomegaly'], on='Path', rsuffix='_true')
# labels = labels[labels['Cardiomegaly_true']!=2]
self.labels = labels
def __len__(self):
return len(self.labels)
def __getitem__(self, index):
path_index, image_index = self.labels.index[index]
path_item = self.path_root/'data'/f'{image_index:06}.png'
img = self.load_item(path_item)
uid = image_index
target = int(self.labels.loc[(path_index, image_index), 'Cardiomegaly'])
# return {'uid':uid, 'source': self.transform(img), 'target':target}
return {'source': self.transform(img), 'target':target}
@classmethod
def run_item_crawler(cls, path_root, extension, **kwargs):
"""Overwrite to speed up as paths are determined by .csv file anyway"""
return []
def get_weights(self):
n_samples = len(self)
weight_per_class = 1/self.labels['Cardiomegaly'].value_counts(normalize=True)
# weight_per_class = {2.0: 1.2, 1.0: 8.2, 0.0: 24.3}
weights = [0] * n_samples
for index in range(n_samples):
target = self.labels.loc[self.labels.index[index], 'Cardiomegaly']
weights[index] = weight_per_class[target]
return weights
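# Hypothetical usage sketch; the folder layout is an assumption inferred from the paths
# hard-coded above ('labels/cheXPert_label.csv', 'labels/train.csv', 'data/<index>.png'):
#   ds = CheXpert_2_Dataset(path_root='data/CheXpert', image_resize=256, image_crop=256)
#   loader = data.DataLoader(ds, batch_size=16, shuffle=True)
# get_weights() can feed a WeightedRandomSampler as sketched for AIROGSDataset above.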