wenruifan's picture
Upload 115 files
a256709 verified
from torch.utils.data import DataLoader
import PIL
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
from PIL import Image
import torch
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
class SIIM_ACR_Dataset(Dataset):
def __init__(self, csv_path, is_train=True, percentage=0.01):
data_info = pd.read_csv(csv_path)
if is_train == True:
total_len = int(percentage * len(data_info))
choice_list = np.random.choice(
range(len(data_info)), size=total_len, replace=False
)
self.img_path_list = data_info["image_path"][choice_list].tolist()
else:
self.img_path_list = data_info["image_path"].tolist()
self.img_root = "SIIM-CLS/siim-acr-pneumothorax/png_images/"
self.seg_root = "SIIM-CLS/siim-acr-pneumothorax/png_masks/" # We have pre-processed the original SIIM_ACR data, you may change this to fix your data
if is_train:
self.aug = A.Compose(
[
A.RandomResizedCrop(
width=224,
height=224,
scale=(0.2, 1.0),
always_apply=True,
interpolation=Image.BICUBIC,
),
A.HorizontalFlip(p=0.5),
A.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
always_apply=True,
),
ToTensorV2(),
]
)
else:
self.aug = A.Compose(
[
A.Resize(width=224, height=224, always_apply=True),
A.Normalize(
mean=[0.485, 0.456, 0.406],
std=[0.229, 0.224, 0.225],
always_apply=True,
),
ToTensorV2(),
]
)
def __getitem__(self, index):
img_path = self.img_root + self.img_path_list[index].split("/")[-1] # + ".png"
seg_path = (
self.seg_root + self.img_path_list[index].split("/")[-1] # + ".png"
) # We have pre-processed the original SIIM_ACR data, you may change this to fix your data
img = np.array(PIL.Image.open(img_path).convert("RGB"))
seg_map = np.array(PIL.Image.open(seg_path))[:, :, np.newaxis] / 255
augmented = self.aug(image=img, mask=seg_map)
img, seg_map = augmented["image"], augmented["mask"]
seg_map = seg_map.permute(2, 0, 1)
class_label = np.array([int(torch.sum(seg_map) > 0)])
return {"image": img, "seg": seg_map, "label": class_label}
def __len__(self):
return len(self.img_path_list)
def create_loader_RSNA(
datasets, samplers, batch_size, num_workers, is_trains, collate_fns
):
loaders = []
for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(
datasets, samplers, batch_size, num_workers, is_trains, collate_fns
):
if is_train:
shuffle = sampler is None
drop_last = True
else:
shuffle = False
drop_last = False
loader = DataLoader(
dataset,
batch_size=bs,
num_workers=n_worker,
pin_memory=True,
sampler=sampler,
shuffle=shuffle,
collate_fn=collate_fn,
drop_last=drop_last,
)
loaders.append(loader)
return loaders