# NOTE(review): the lines that stood here were web-scrape residue (hosting-page
# header, "Runtime error" banners, file size, commit hash, and a line-number
# gutter dump) — not part of the Python module. Replaced with this comment so
# the file parses.
import torch.utils.data as data
import torch
from torch import nn
from pathlib import Path
from torchvision import transforms as T
import pandas as pd
from PIL import Image
from medical_diffusion.data.augmentation.augmentations_2d import Normalize, ToTensor16bit
class SimpleDataset2D(data.Dataset):
    """Generic 2D image dataset.

    Crawls ``path_root`` recursively for files with extension ``crawler_ext``
    (unless explicit ``item_pointers`` are given) and yields dicts with keys
    ``'uid'`` (file stem) and ``'source'`` (transformed image).
    """

    def __init__(
        self,
        path_root,
        item_pointers=None,  # FIX: was a mutable default ([]), shared across instances
        crawler_ext='tif',  # other options are ['jpg', 'jpeg', 'png', 'tiff']
        transform=None,
        image_resize=None,
        augment_horizontal_flip=False,
        augment_vertical_flip=False,
        image_crop=None,
    ):
        """
        Args:
            path_root: Root directory containing the images.
            item_pointers: Optional list of paths relative to ``path_root``.
                If None or empty, the directory tree is crawled for
                ``crawler_ext`` files instead.
            crawler_ext: File extension to crawl for.
            transform: Optional transform applied to each loaded image; if
                None, a default resize/flip/crop/to-tensor/normalize pipeline
                is built from the remaining arguments.
            image_resize: Size for ``T.Resize`` (None = keep original size).
            augment_horizontal_flip: Enable random horizontal flips.
            augment_vertical_flip: Enable random vertical flips.
            image_crop: Size for ``T.CenterCrop`` (None = no crop).
        """
        super().__init__()
        self.path_root = Path(path_root)
        self.crawler_ext = crawler_ext
        if item_pointers:
            self.item_pointers = item_pointers
        else:
            self.item_pointers = self.run_item_crawler(self.path_root, self.crawler_ext)
        if transform is None:
            self.transform = T.Compose([
                T.Resize(image_resize) if image_resize is not None else nn.Identity(),
                T.RandomHorizontalFlip() if augment_horizontal_flip else nn.Identity(),
                T.RandomVerticalFlip() if augment_vertical_flip else nn.Identity(),
                T.CenterCrop(image_crop) if image_crop is not None else nn.Identity(),
                T.ToTensor(),
                # WARNING: mean and std are not the target values but rather the values
                # to subtract and divide by: [0, 1] -> [0-0.5, 1-0.5]/0.5 -> [-1, 1]
                T.Normalize(mean=0.5, std=0.5),
            ])
        else:
            self.transform = transform

    def __len__(self):
        return len(self.item_pointers)

    def __getitem__(self, index):
        rel_path_item = self.item_pointers[index]
        path_item = self.path_root / rel_path_item
        img = self.load_item(path_item)
        return {'uid': rel_path_item.stem, 'source': self.transform(img)}

    def load_item(self, path_item):
        # NOTE: Only CV2 supports 16bit RGB images; PIL is used here and
        # forces 3-channel 8-bit RGB.
        return Image.open(path_item).convert('RGB')

    @classmethod
    def run_item_crawler(cls, path_root, extension, **kwargs):
        """Return all ``*.extension`` paths under ``path_root``, relative to it."""
        return [path.relative_to(path_root) for path in Path(path_root).rglob(f'*.{extension}')]

    def get_weights(self):
        """Return list of class-weights for WeightedSampling; None = uniform."""
        return None
class AIROGSDataset(SimpleDataset2D):
    """AIROGS fundus-image dataset; labels come from ``train_labels.csv``.

    Images live in ``path_root`` as ``<challenge_id>.jpg``; the label csv sits
    one directory above and is indexed by ``challenge_id``.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.labels = pd.read_csv(self.path_root.parent / 'train_labels.csv',
                                  index_col='challenge_id')

    def __len__(self):
        # Length is defined by the csv, not by crawled files.
        return len(self.labels)

    def __getitem__(self, index):
        uid = self.labels.index[index]
        path_item = self.path_root / f'{uid}.jpg'
        img = self.load_item(path_item)
        str_2_int = {'NRG': 0, 'RG': 1}  # RG = 3270, NRG = 98172
        target = str_2_int[self.labels.loc[uid, 'class']]
        return {'source': self.transform(img), 'target': target}

    def get_weights(self):
        """Return one inverse-class-frequency weight per sample for WeightedSampling."""
        # Inverse class frequency, e.g. {'NRG': ~1.03, 'RG': ~31.02}
        weight_per_class = 1 / self.labels['class'].value_counts(normalize=True)
        # Vectorized map replaces the original per-row .iloc loop (O(n) lookups
        # over >100k rows); produces the same list of weights.
        return self.labels['class'].map(weight_per_class).tolist()

    @classmethod
    def run_item_crawler(cls, path_root, extension, **kwargs):
        """Overwrite to speed up as paths are determined by .csv file anyway"""
        return []
class MSIvsMSS_Dataset(SimpleDataset2D):
    # https://doi.org/10.5281/zenodo.2530835
    def __getitem__(self, index):
        # The class label is encoded in the name of the image's parent folder.
        label_lookup = {'MSIMUT': 0, 'MSS': 1}
        rel_path = self.item_pointers[index]
        abs_path = self.path_root / rel_path
        image = self.load_item(abs_path)
        return {
            'uid': rel_path.stem,
            'source': self.transform(image),
            'target': label_lookup[abs_path.parent.name],
        }
class MSIvsMSS_2_Dataset(SimpleDataset2D):
    # https://doi.org/10.5281/zenodo.3832231
    def __getitem__(self, index):
        # patients with MSI-H = MSIH; patients with MSI-L and MSS = NonMSIH
        class_to_idx = {'MSIH': 0, 'nonMSIH': 1}
        rel_path = self.item_pointers[index]
        full_path = self.path_root / rel_path
        target = class_to_idx[full_path.parent.name]
        return {'source': self.transform(self.load_item(full_path)), 'target': target}
class CheXpert_Dataset(SimpleDataset2D):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # The split name (image folder name) selects the matching csv one level up.
        mode = self.path_root.name
        frame = pd.read_csv(self.path_root.parent / f'{mode}.csv', index_col='Path')
        frame = frame[frame['Frontal/Lateral'] == 'Frontal'].copy()
        frame.index = frame.index.str[20:]
        # Affects 1 case, must be "female" to match stats in publication
        frame.loc[frame['Sex'] == 'Unknown', 'Sex'] = 'Female'
        frame = frame.fillna(2)  # TODO: Find better solution
        encoding = {
            'Sex': {'Male': 0, 'Female': 1},
            'Frontal/Lateral': {'Frontal': 0, 'Lateral': 1},
            'AP/PA': {'AP': 0, 'PA': 1},
        }
        self.labels = frame.replace(encoding)

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        rel_path_item = self.labels.index[index]
        uid = str(rel_path_item)
        img = self.load_item(self.path_root / rel_path_item)
        # Note Labels are -1=uncertain, 0=negative, 1=positive, NA=not reported
        # -> shifted by +1 into [0, 2]; NA (filled with 2) becomes 3
        target = torch.tensor(self.labels.loc[uid, 'Cardiomegaly'] + 1, dtype=torch.long)
        return {'uid': uid, 'source': self.transform(img), 'target': target}

    @classmethod
    def run_item_crawler(cls, path_root, extension, **kwargs):
        """Overwrite to speed up as paths are determined by .csv file anyway"""
        return []
class CheXpert_2_Dataset(SimpleDataset2D):
    """CheXpert variant using a pre-folded label csv joined with the official train.csv.

    In the pre-folded labels, 1 and -1 (uncertain) cases count as positives (1);
    0 and NA count as negatives (0).
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        labels = pd.read_csv(self.path_root / 'labels/cheXPert_label.csv',
                             index_col=['Path', 'Image Index'])
        labels = labels.loc[labels['fold'] == 'train'].copy()
        labels = labels.drop(labels='fold', axis=1)
        # Official labels, recoded: 0 = Negative, 1 = Positive, 2 = Uncertain (incl. NA)
        labels2 = pd.read_csv(self.path_root / 'labels/train.csv', index_col='Path')
        labels2 = labels2.loc[labels2['Frontal/Lateral'] == 'Frontal'].copy()
        labels2 = labels2[['Cardiomegaly', ]].copy()
        labels2[(labels2 < 0) | labels2.isna()] = 2
        # Attach the official label as 'Cardiomegaly_true' alongside the folded one
        labels = labels.join(labels2['Cardiomegaly'], on=["Path", ], rsuffix='_true')
        self.labels = labels

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, index):
        path_index, image_index = self.labels.index[index]
        # Images are stored flat under data/ as zero-padded 6-digit PNGs
        path_item = self.path_root / 'data' / f'{image_index:06}.png'
        img = self.load_item(path_item)
        target = int(self.labels.loc[(path_index, image_index), 'Cardiomegaly'])
        return {'source': self.transform(img), 'target': target}

    @classmethod
    def run_item_crawler(cls, path_root, extension, **kwargs):
        """Overwrite to speed up as paths are determined by .csv file anyway"""
        return []

    def get_weights(self):
        """Return one inverse-class-frequency weight per sample for WeightedSampling."""
        # e.g. roughly {2.0: 1.2, 1.0: 8.2, 0.0: 24.3}
        weight_per_class = 1 / self.labels['Cardiomegaly'].value_counts(normalize=True)
        # Vectorized map replaces the original per-row .loc loop (O(n) index lookups)
        return self.labels['Cardiomegaly'].map(weight_per_class).tolist()