|
from torch.utils.data import DataLoader |
|
import PIL |
|
from torch.utils.data import Dataset |
|
import numpy as np |
|
import pandas as pd |
|
from PIL import Image |
|
import torch |
|
import albumentations as A |
|
from albumentations.pytorch.transforms import ToTensorV2 |
|
|
|
|
|
class SIIM_ACR_Dataset(Dataset):
    """Dataset of SIIM-ACR pneumothorax PNG images with segmentation masks.

    Each item is a dict with three entries:

    * ``"image"`` - the augmented, normalised image tensor,
    * ``"seg"``   - the augmented mask tensor, moved to channel-first layout,
    * ``"label"`` - a one-element NumPy array holding 1 when the augmented
      mask contains any non-zero value and 0 otherwise.

    Args:
        csv_path: CSV file with an ``image_path`` column. Only the final
            ``/``-separated component of each entry is used as a file name.
        is_train: When truthy, a random subset of the rows is kept and the
            random crop / flip augmentation is used; otherwise every row is
            kept and images are only resized.
        percentage: Fraction (0.0 - 1.0) of the rows to keep when
            ``is_train`` is truthy. Ignored otherwise.
        img_root: Prefix prepended verbatim to the file name to locate the
            image, so it should end with a path separator.
        seg_root: Prefix prepended verbatim to the file name to locate the
            mask, so it should end with a path separator.

    Raises:
        ValueError: If ``is_train`` is truthy and ``percentage`` is outside
            the closed interval [0, 1].
    """

    DEFAULT_IMG_ROOT = "SIIM-CLS/siim-acr-pneumothorax/png_images/"
    DEFAULT_SEG_ROOT = "SIIM-CLS/siim-acr-pneumothorax/png_masks/"

    def __init__(
        self,
        csv_path,
        is_train=True,
        percentage=0.01,
        img_root=DEFAULT_IMG_ROOT,
        seg_root=DEFAULT_SEG_ROOT,
    ):
        super().__init__()

        # Fail early with a clear message instead of letting the sampler
        # raise an opaque error after the CSV has already been parsed.
        if is_train and not 0.0 <= percentage <= 1.0:
            raise ValueError(
                f"percentage must be within [0, 1], got {percentage!r}"
            )

        data_info = pd.read_csv(csv_path)
        image_paths = data_info["image_path"]
        if is_train:
            self.img_path_list = self._sample_paths(image_paths, percentage)
        else:
            self.img_path_list = image_paths.tolist()

        self.img_root = img_root
        self.seg_root = seg_root
        self.aug = self._build_transform(is_train)

    @staticmethod
    def _sample_paths(paths, percentage):
        """Return a random, duplicate-free subset of ``paths`` as a list."""
        subset_size = int(percentage * len(paths))
        chosen = np.random.choice(len(paths), size=subset_size, replace=False)
        # ``chosen`` holds positions, so index positionally; plain ``[]``
        # would look the values up as index labels instead.
        return paths.iloc[chosen].tolist()

    @staticmethod
    def _build_transform(is_train):
        """Build the albumentations pipeline for training or evaluation."""
        if is_train:
            spatial = [
                A.RandomResizedCrop(
                    width=224,
                    height=224,
                    scale=(0.2, 1.0),
                    always_apply=True,
                    # NOTE(review): albumentations appears to interpret this
                    # value as an OpenCV interpolation flag, while
                    # Image.BICUBIC is a PIL constant; the two enumerations
                    # do not share numbering, so this may not select bicubic
                    # resampling. Left unchanged to keep existing behaviour -
                    # confirm the intended mode.
                    interpolation=Image.BICUBIC,
                ),
                A.HorizontalFlip(p=0.5),
            ]
        else:
            spatial = [A.Resize(width=224, height=224, always_apply=True)]

        normalize = A.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
            always_apply=True,
        )
        return A.Compose([*spatial, normalize, ToTensorV2()])

    def __getitem__(self, index):
        """Load, augment and return the sample stored at ``index``."""
        # Image and mask share a file name; only the directory differs.
        file_name = self.img_path_list[index].split("/")[-1]

        # Context managers close the underlying files as soon as the pixel
        # data has been copied into NumPy arrays.
        with Image.open(self.img_root + file_name) as img_file:
            img = np.array(img_file.convert("RGB"))
        with Image.open(self.seg_root + file_name) as seg_file:
            # Adds a trailing channel axis and rescales by 255; assumes the
            # mask is a single-channel image with values 0/255 - TODO confirm.
            seg_map = np.array(seg_file)[:, :, np.newaxis] / 255

        augmented = self.aug(image=img, mask=seg_map)
        img, seg_map = augmented["image"], augmented["mask"]
        # Move the channel axis to the front.
        seg_map = seg_map.permute(2, 0, 1)

        # The label is derived from the augmented mask, so it describes the
        # region that is actually returned.
        class_label = np.array([int(torch.sum(seg_map) > 0)])
        return {"image": img, "seg": seg_map, "label": class_label}

    def __len__(self):
        """Return the number of samples kept for this split."""
        return len(self.img_path_list)
|
|
|
|
|
def create_loader_RSNA(
    datasets, samplers, batch_size, num_workers, is_trains, collate_fns
):
    """Create one ``DataLoader`` per dataset from parallel argument lists.

    The six arguments are iterated in lockstep, so position ``i`` of each
    describes loader ``i``. Training loaders drop the last incomplete batch
    and shuffle only when no explicit sampler is supplied; evaluation
    loaders never shuffle and keep every batch. Memory pinning is always
    enabled.

    Args:
        datasets: Datasets to wrap.
        samplers: Sampler for each dataset, or ``None`` for the default.
        batch_size: Batch size for each loader.
        num_workers: Worker-process count for each loader.
        is_trains: Truthy entries mark training loaders.
        collate_fns: Collate function for each loader, or ``None``.

    Returns:
        A list of ``DataLoader`` objects in the same order as ``datasets``.
    """
    configs = zip(
        datasets, samplers, batch_size, num_workers, is_trains, collate_fns
    )
    return [
        DataLoader(
            data,
            batch_size=per_batch,
            num_workers=workers,
            pin_memory=True,
            sampler=data_sampler,
            # Shuffling is only legal without an explicit sampler, and is
            # only wanted while training.
            shuffle=bool(training) and data_sampler is None,
            collate_fn=collate,
            drop_last=bool(training),
        )
        for data, data_sampler, per_batch, workers, training, collate in configs
    ]
|
|