import csv
import json
import logging
import os
import sys
from abc import abstractmethod
from itertools import islice
from typing import Any, Dict, List, Optional, Sequence, Tuple

import numpy as np
import pandas as pd
import PIL
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms

from dataset.randaugment import RandomAugment

# Side length (pixels) of the square image produced by both pipelines.
_IMAGE_SIZE = 224

# Per-channel normalization constants; these values match the widely used
# ImageNet RGB mean / standard deviation.
_NORM_MEAN = (0.485, 0.456, 0.406)
_NORM_STD = (0.229, 0.224, 0.225)

# Augmentation names handed to RandomAugment for the training pipeline.
_TRAIN_AUGMENTATIONS = [
    "Identity",
    "AutoContrast",
    "Equalize",
    "Brightness",
    "Sharpness",
    "ShearX",
    "ShearY",
    "TranslateX",
    "TranslateY",
    "Rotate",
]


class Chestxray14_Dataset(Dataset):
    """Image/label dataset indexed by a CSV file.

    The CSV is read with ``pandas.read_csv``. Column 0 supplies the image
    paths and every column from index 2 onward supplies the label vector of
    the corresponding row; column 1 is not used by this class.

    Attributes:
        img_path_list: NumPy array with the image path of each sample.
        class_list: NumPy array holding the label columns of each sample.
        transform: torchvision pipeline applied to every loaded image.
    """

    def __init__(self, csv_path, is_train: bool = True) -> None:
        """Load the CSV index and build the image transform.

        Args:
            csv_path: Path (or anything ``pandas.read_csv`` accepts) of the
                CSV index file.
            is_train: When true, use the random-augmentation pipeline;
                otherwise use a deterministic resize-only pipeline.
        """
        data_info = pd.read_csv(csv_path)
        self.img_path_list = np.asarray(data_info.iloc[:, 0])
        self.class_list = np.asarray(data_info.iloc[:, 2:])

        normalize = transforms.Normalize(_NORM_MEAN, _NORM_STD)
        if is_train:
            self.transform = transforms.Compose(
                [
                    transforms.RandomResizedCrop(
                        _IMAGE_SIZE,
                        scale=(0.2, 1.0),
                        interpolation=Image.BICUBIC,
                    ),
                    transforms.RandomHorizontalFlip(),
                    # NOTE(review): the meaning of the positional arguments
                    # (2, 7) is defined by dataset.randaugment -- confirm
                    # there before changing them.
                    RandomAugment(
                        2,
                        7,
                        isPIL=True,
                        augs=list(_TRAIN_AUGMENTATIONS),
                    ),
                    transforms.ToTensor(),
                    normalize,
                ]
            )
        else:
            self.transform = transforms.Compose(
                [
                    transforms.Resize([_IMAGE_SIZE, _IMAGE_SIZE]),
                    transforms.ToTensor(),
                    normalize,
                ]
            )

    def __getitem__(self, index) -> Dict[str, Any]:
        """Return the transformed image and label row for ``index``.

        Returns:
            A dict with keys ``"image"`` (output of ``self.transform``) and
            ``"label"`` (the matching row of ``self.class_list``).
        """
        img_path = self.img_path_list[index]
        class_label = self.class_list[index]

        # Image.open is lazy; the context manager guarantees the underlying
        # file is closed even if decoding fails. convert() returns a fully
        # loaded image, so it remains usable after the file is closed.
        with Image.open(img_path) as img:
            rgb_image = img.convert("RGB")

        image = self.transform(rgb_image)
        return {"image": image, "label": class_label}

    def __len__(self) -> int:
        """Return the number of samples (rows in the CSV index)."""
        return len(self.img_path_list)


def create_loader(
    datasets: Sequence[Dataset],
    samplers: Sequence[Optional[Any]],
    batch_size: Sequence[int],
    num_workers: Sequence[int],
    is_trains: Sequence[bool],
    collate_fns: Sequence[Optional[Any]],
) -> List[DataLoader]:
    """Build one ``DataLoader`` per dataset from parallel argument sequences.

    The i-th loader is built from the i-th element of every argument. The
    sequences are combined with ``zip``, so the result has as many loaders
    as the shortest argument.

    For a training entry the loader drops the last incomplete batch and
    shuffles only when no sampler is supplied (``DataLoader`` does not allow
    ``shuffle=True`` together with a sampler). For a non-training entry the
    loader neither shuffles nor drops the last batch. Every loader uses
    ``pin_memory=True``.

    Args:
        datasets: Datasets to wrap.
        samplers: Sampler for each dataset, or ``None``.
        batch_size: Batch size for each loader.
        num_workers: Worker-process count for each loader.
        is_trains: Whether each loader is used for training.
        collate_fns: Collate function for each loader, or ``None``.

    Returns:
        The list of constructed ``DataLoader`` objects, in input order.
    """
    loaders = []
    for dataset, sampler, bs, n_worker, is_train, collate_fn in zip(
        datasets, samplers, batch_size, num_workers, is_trains, collate_fns
    ):
        training = bool(is_train)
        loaders.append(
            DataLoader(
                dataset,
                batch_size=bs,
                num_workers=n_worker,
                pin_memory=True,
                sampler=sampler,
                shuffle=training and sampler is None,
                collate_fn=collate_fn,
                drop_last=training,
            )
        )
    return loaders