File size: 3,630 Bytes
a256709 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 |
import os

import albumentations as A
import numpy as np
import pandas as pd
import PIL
import torch
from albumentations.pytorch.transforms import ToTensorV2
from PIL import Image
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
class SIIM_ACR_Dataset(Dataset):
    """SIIM-ACR pneumothorax dataset returning an image, its mask and a label.

    Rows come from the ``image_path`` column of a CSV file. Only the last
    ``/``-separated component of each entry is used: the image is loaded from
    ``img_root`` and the mask from ``seg_root`` under that same file name.

    Args:
        csv_path: Path to a CSV file containing an ``image_path`` column.
        is_train: When true, keep a random subset of the rows (sampled with
            ``np.random`` without replacement) and use the training
            augmentation (random resized crop + horizontal flip). Otherwise
            keep every row and only resize.
        percentage: Fraction of rows kept when ``is_train`` is true.
        img_root: Directory holding the images. ``None`` selects
            ``DEFAULT_IMG_ROOT``.
        seg_root: Directory holding the masks. ``None`` selects
            ``DEFAULT_SEG_ROOT``.
    """

    # Layout of the pre-processed SIIM-ACR data. Pass ``img_root`` /
    # ``seg_root`` to fit your own copy of the data.
    DEFAULT_IMG_ROOT = "SIIM-CLS/siim-acr-pneumothorax/png_images/"
    DEFAULT_SEG_ROOT = "SIIM-CLS/siim-acr-pneumothorax/png_masks/"

    # albumentations resizes with OpenCV, so ``interpolation`` takes an OpenCV
    # flag. PIL's ``Image.BICUBIC`` equals 3, which OpenCV interprets as
    # INTER_AREA. The OpenCV bicubic flag (cv2.INTER_CUBIC) is 2.
    _CV2_INTER_CUBIC = 2

    def __init__(
        self,
        csv_path,
        is_train=True,
        percentage=0.01,
        img_root=None,
        seg_root=None,
    ):
        data_info = pd.read_csv(csv_path)
        image_paths = data_info["image_path"]
        if is_train:
            total_len = int(percentage * len(data_info))
            choice_list = np.random.choice(
                len(data_info), size=total_len, replace=False
            )
            # Positional selection, independent of the frame's index labels.
            self.img_path_list = image_paths.iloc[choice_list].tolist()
        else:
            self.img_path_list = image_paths.tolist()

        self.img_root = self.DEFAULT_IMG_ROOT if img_root is None else img_root
        self.seg_root = self.DEFAULT_SEG_ROOT if seg_root is None else seg_root

        # Only the geometric part differs between splits. Normalisation
        # (standard ImageNet mean/std values) and tensor conversion are shared.
        if is_train:
            spatial = [
                A.RandomResizedCrop(
                    width=224,
                    height=224,
                    scale=(0.2, 1.0),
                    always_apply=True,
                    interpolation=self._CV2_INTER_CUBIC,
                ),
                A.HorizontalFlip(p=0.5),
            ]
        else:
            spatial = [A.Resize(width=224, height=224, always_apply=True)]
        self.aug = A.Compose(
            spatial
            + [
                A.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                    always_apply=True,
                ),
                ToTensorV2(),
            ]
        )

    def __getitem__(self, index):
        """Return ``{"image", "seg", "label"}`` for the sample at ``index``.

        ``image`` and ``seg`` are the outputs of ``self.aug``. The axes of
        ``seg`` are reordered with ``permute(2, 0, 1)`` so the last axis moves
        to the front. ``label`` is a length-1 integer array: 1 if the
        augmented mask has a positive sum, else 0.
        """
        # Entries may carry a directory prefix; only the file name is used.
        file_name = self.img_path_list[index].split("/")[-1]
        img_path = os.path.join(self.img_root, file_name)
        seg_path = os.path.join(self.seg_root, file_name)

        # Context managers release the file handles deterministically.
        with Image.open(img_path) as img_file:
            img = np.array(img_file.convert("RGB"))
        with Image.open(seg_path) as seg_file:
            # Append a trailing channel axis and divide by 255.
            # Assumes a single-channel 8-bit mask (0/255) - TODO confirm.
            seg_map = np.array(seg_file)[:, :, np.newaxis] / 255

        augmented = self.aug(image=img, mask=seg_map)
        img, seg_map = augmented["image"], augmented["mask"]
        seg_map = seg_map.permute(2, 0, 1)
        class_label = np.array([int(torch.sum(seg_map) > 0)])
        return {"image": img, "seg": seg_map, "label": class_label}

    def __len__(self):
        """Return the number of selected image paths."""
        return len(self.img_path_list)
def create_loader_RSNA(
    datasets, samplers, batch_size, num_workers, is_trains, collate_fns
):
    """Create one ``DataLoader`` per dataset from parallel per-split settings.

    The i-th entry of every argument configures the i-th loader. Iteration
    stops at the shortest argument (``zip`` semantics). Every loader uses
    ``pin_memory=True``.

    A loader whose ``is_trains`` entry is truthy drops its last incomplete
    batch and shuffles only when its sampler is ``None``. Any other loader
    neither shuffles nor drops the last batch.

    Returns:
        A list of ``DataLoader`` objects in input order.
    """
    per_split = zip(
        datasets, samplers, batch_size, num_workers, is_trains, collate_fns
    )
    return [
        DataLoader(
            ds,
            batch_size=bs,
            num_workers=workers,
            pin_memory=True,
            sampler=smp,
            # Both flags must be real bools: BatchSampler rejects a
            # non-bool drop_last.
            shuffle=bool(train) and smp is None,
            collate_fn=collate,
            drop_last=bool(train),
        )
        for ds, smp, bs, workers, train, collate in per_split
    ]
|