File size: 4,227 Bytes
a256709
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
from cmath import nan
import csv
import json
import logging
import os
import sys
import pydicom

from abc import abstractmethod
from itertools import islice
from typing import List, Tuple, Dict, Any
from torch.utils.data import DataLoader
import PIL
from torch.utils.data import Dataset
import numpy as np
import pandas as pd
from torchvision import transforms
from PIL import Image
from skimage import exposure
import torch
from torchvision.transforms import InterpolationMode
from dataset.randaugment import RandomAugment


class SIIM_ACR_Dataset(Dataset):
    """SIIM-ACR pneumothorax image-classification dataset.

    Each item is a transformed RGB image plus a one-element label array that
    is 1 when the matching mask PNG has any non-zero pixel and 0 otherwise.

    Args:
        csv_path: CSV file with an ``image_path`` column. Only the file name
            (text after the last ``/``) is used; it is appended to
            ``img_root`` / ``seg_root`` to locate the image and its mask.
        is_train: when truthy, keep a random ``percentage`` of the rows
            (sampled without replacement from NumPy's global RNG) and apply
            training augmentations; otherwise keep every row and only resize.
        percentage: fraction of rows kept when ``is_train`` is truthy.
        img_root: prefix for image files. It is joined by plain string
            concatenation, so keep the trailing separator.
        seg_root: prefix for mask files (same convention as ``img_root``).
    """

    def __init__(
        self,
        csv_path,
        is_train=True,
        percentage=0.01,
        img_root="SIIM-CLS/siim-acr-pneumothorax/png_images/",
        seg_root="SIIM-CLS/siim-acr-pneumothorax/png_masks/",
    ):
        data_info = pd.read_csv(csv_path)
        if is_train:
            total_len = int(percentage * len(data_info))
            choice_list = np.random.choice(
                len(data_info), size=total_len, replace=False
            )
            # Positional selection: choice_list holds row positions, not labels.
            self.img_path_list = data_info["image_path"].iloc[choice_list].tolist()
        else:
            self.img_path_list = data_info["image_path"].tolist()

        # Defaults point at a pre-processed PNG export of SIIM-ACR; pass other
        # roots to use a different data layout.
        self.img_root = img_root
        self.seg_root = seg_root

        normalize = transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))

        if is_train:
            self.transform = transforms.Compose(
                [
                    transforms.RandomResizedCrop(
                        224, scale=(0.2, 1.0), interpolation=InterpolationMode.BICUBIC
                    ),
                    transforms.RandomHorizontalFlip(),
                    RandomAugment(
                        2,
                        7,
                        isPIL=True,
                        augs=[
                            "Identity",
                            "AutoContrast",
                            "Equalize",
                            "Brightness",
                            "Sharpness",
                            "ShearX",
                            "ShearY",
                            "TranslateX",
                            "TranslateY",
                            "Rotate",
                        ],
                    ),
                    transforms.ToTensor(),
                    normalize,
                ]
            )
        else:
            self.transform = transforms.Compose(
                [transforms.Resize([224, 224]), transforms.ToTensor(), normalize]
            )

        # Mask tensor pipeline, kept (under its original attribute name) for
        # external callers. __getitem__ no longer uses it, because resizing
        # before thresholding can erase small positive regions.
        self.seg_transfrom = transforms.Compose(
            [
                transforms.ToTensor(),
                transforms.Resize([224, 224], interpolation=InterpolationMode.NEAREST),
            ]
        )

    def __getitem__(self, index):
        """Return ``{"image": tensor, "label": np.array([0 or 1])}`` for ``index``."""
        file_name = self.img_path_list[index].split("/")[-1]
        img_path = self.img_root + file_name
        seg_path = self.seg_root + file_name

        # Context managers release the file handles deterministically.
        with PIL.Image.open(img_path) as img:
            rgb = img.convert("RGB")
        image = self.transform(rgb)

        # Derive the label from the full-resolution mask. Nearest-neighbour
        # downsampling to 224x224 could drop a small positive region and turn
        # a positive case into a negative one.
        with PIL.Image.open(seg_path) as seg_img:
            mask = np.asarray(seg_img)
        class_label = np.array([int((mask > 0).any())])
        return {"image": image, "label": class_label}

    def __len__(self):
        """Number of selected images."""
        return len(self.img_path_list)


def create_loader_RSNA(
    datasets, samplers, batch_size, num_workers, is_trains, collate_fns
):
    """Build one ``DataLoader`` per dataset from parallel per-loader settings.

    The six sequences are consumed in lockstep (``zip`` semantics, so the
    shortest one bounds the number of loaders). Every loader pins memory.
    Training loaders shuffle only when no sampler is given and drop the last
    incomplete batch. Non-training loaders never shuffle and keep every batch.

    Returns:
        A list of ``DataLoader`` objects in input order.
    """

    def _make_loader(dataset, sampler, bs, n_worker, is_train, collate_fn):
        training = bool(is_train)
        return DataLoader(
            dataset,
            batch_size=bs,
            num_workers=n_worker,
            pin_memory=True,
            sampler=sampler,
            # An explicit sampler already defines the order, so shuffling is
            # only requested for training loaders without one.
            shuffle=training and sampler is None,
            collate_fn=collate_fn,
            drop_last=training,
        )

    per_loader_settings = zip(
        datasets, samplers, batch_size, num_workers, is_trains, collate_fns
    )
    return [_make_loader(*settings) for settings in per_loader_settings]