File size: 2,978 Bytes
a256709 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 |
import csv
import json
import logging
import os
import sys
from abc import abstractmethod
from itertools import islice
from typing import List, Tuple, Dict, Any
from torch.utils.data import DataLoader
import PIL
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
from torchvision import transforms
from PIL import Image
from dataset.randaugment import RandomAugment
class Chestxray14_Dataset(Dataset):
    """Image/label dataset driven by a CSV index file.

    Column 0 of the CSV is used as the image path and every column from
    index 2 onward is used as the label vector; column 1 is not read.

    Args:
        csv_path: Path handed to ``pandas.read_csv``.
        is_train: When true, build the random-augmentation pipeline;
            otherwise images are only resized to 224x224 before being
            converted to a tensor and normalized.
    """

    def __init__(self, csv_path, is_train=True):
        data_info = pd.read_csv(csv_path)
        self.img_path_list = np.asarray(data_info.iloc[:, 0])
        self.class_list = np.asarray(data_info.iloc[:, 2:])

        # Per-channel mean / std applied after ToTensor (these values match
        # the widely used ImageNet statistics).
        normalize = transforms.Normalize(
            (0.485, 0.456, 0.406), (0.229, 0.224, 0.225)
        )

        if is_train:
            self.transform = transforms.Compose(
                [
                    transforms.RandomResizedCrop(
                        224, scale=(0.2, 1.0), interpolation=Image.BICUBIC
                    ),
                    transforms.RandomHorizontalFlip(),
                    # NOTE(review): the positional 2 and 7 are presumably the
                    # number of ops and their magnitude -- confirm against
                    # dataset.randaugment.RandomAugment.
                    RandomAugment(
                        2,
                        7,
                        isPIL=True,
                        augs=[
                            "Identity",
                            "AutoContrast",
                            "Equalize",
                            "Brightness",
                            "Sharpness",
                            "ShearX",
                            "ShearY",
                            "TranslateX",
                            "TranslateY",
                            "Rotate",
                        ],
                    ),
                    transforms.ToTensor(),
                    normalize,
                ]
            )
        else:
            self.transform = transforms.Compose(
                [
                    transforms.Resize([224, 224]),
                    transforms.ToTensor(),
                    normalize,
                ]
            )

    def __getitem__(self, index):
        """Load, convert and transform the sample at ``index``.

        Returns:
            A dict with the transformed image under ``"image"`` and the
            matching row of the label array under ``"label"``.
        """
        img_path = self.img_path_list[index]
        class_label = self.class_list[index]
        # Open inside a context manager so the underlying file handle is
        # released deterministically; ``convert`` returns a new image that
        # stays valid after the source image is closed.
        with Image.open(img_path) as img:
            rgb_img = img.convert("RGB")
        image = self.transform(rgb_img)
        return {"image": image, "label": class_label}

    def __len__(self):
        """Return the number of image paths read from the CSV."""
        return len(self.img_path_list)
def create_loader(datasets, samplers, batch_size, num_workers, is_trains, collate_fns):
    """Build one ``DataLoader`` per dataset from parallel argument lists.

    The six arguments are iterated in lockstep with ``zip``; the i-th entry
    of each one configures the i-th loader.

    Training loaders drop the last incomplete batch and shuffle only when no
    sampler was supplied. Non-training loaders never shuffle and keep the
    last batch. ``pin_memory`` is always enabled.

    Returns:
        A list of ``DataLoader`` objects in the same order as ``datasets``.
    """
    per_loader_settings = zip(
        datasets, samplers, batch_size, num_workers, is_trains, collate_fns
    )
    return [
        DataLoader(
            data,
            batch_size=size,
            num_workers=workers,
            pin_memory=True,
            sampler=smp,
            # Shuffling and an explicit sampler are mutually exclusive.
            shuffle=bool(training) and smp is None,
            collate_fn=collate,
            drop_last=bool(training),
        )
        for data, smp, size, workers, training, collate in per_loader_settings
    ]
|