File size: 4,203 Bytes
4a1f918
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
import random
import os
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
import pandas as pd
from data_transforms.atr_transform import ATR_Transform


class ATR_Dataset(Dataset):
    """Image/mask segmentation dataset for ATR data, optionally paired with text labels.

    Samples are listed in ``<root_path>/folds_masks/train0.csv`` (training) or
    ``<root_path>/folds_masks/val0.csv`` (otherwise); only the ``mask_path`` and
    ``tgt`` columns are read. Unless ``no_text_mode`` is set, every image is expanded
    into one sample per entry of ``config['data']['label_names']``. When more than one
    label name is configured, a sample gets its real mask only if
    ``tgt + ' Vehicle'`` equals the sample's label name, and an all-zero mask otherwise.
    """

    # Substrings marking a CSV row's file name as an image (matched anywhere in the name).
    _IMAGE_EXTENSIONS = ('jpg', 'jpeg', 'png', 'bmp')

    def __init__(self, config, is_train=False, shuffle_list=True, apply_norm=True, no_text_mode=False) -> None:
        """Read the fold CSV, build the per-sample lists and create the data transform.

        Args:
            config: Nested config dict. Reads ``config['data']['root_path']`` and
                ``config['data']['label_names']`` here and ``config['data']['volume_channel']``
                in ``__getitem__``; the whole dict is also passed to ``ATR_Transform``.
            is_train: If True, read ``train0.csv`` and run the transform in training mode;
                otherwise read ``val0.csv``.
            shuffle_list: If True, permute all per-sample lists in lockstep with ``random``.
            apply_norm: Forwarded to the data transform.
            no_text_mode: If True, emit one sample per image with an empty label string
                instead of one sample per label name.
        """
        super().__init__()
        self.root_path = config['data']['root_path']
        self.img_names = []
        self.img_path_list = []
        self.label_path_list = []
        self.label_list = []
        self.class_in_image = []
        self.is_train = is_train
        self.label_names = config['data']['label_names']
        self.num_classes = len(self.label_names)
        self.config = config
        self.apply_norm = apply_norm
        self.no_text_mode = no_text_mode

        csv_name = 'train0.csv' if self.is_train else 'val0.csv'
        self.df = pd.read_csv(os.path.join(self.root_path, 'folds_masks', csv_name))

        self.populate_lists()
        if shuffle_list:
            # One shared permutation keeps every per-sample list aligned.
            order = list(range(len(self.img_path_list)))
            random.shuffle(order)
            self.img_path_list = [self.img_path_list[i] for i in order]
            self.img_names = [self.img_names[i] for i in order]
            self.label_path_list = [self.label_path_list[i] for i in order]
            self.label_list = [self.label_list[i] for i in order]
            self.class_in_image = [self.class_in_image[i] for i in order]

        # define data transform
        self.data_transform = ATR_Transform(config=config)

    def __len__(self):
        """Return the number of samples (image/label-name pairs, or images in no-text mode)."""
        return len(self.img_path_list)

    def populate_lists(self):
        """Append one entry per sample to the per-sample lists, built from ``self.df``.

        Rows whose derived image name contains none of ``_IMAGE_EXTENSIONS`` are skipped.
        """
        # No-text mode: one sample per image with an empty label; otherwise one per label name.
        prompts = [''] if self.no_text_mode else list(self.label_names)
        for mask_rel_path, target_class in zip(self.df['mask_path'], self.df['tgt']):
            # Drops the first 6 characters of mask_path to get the image file name
            # (presumably a 'masks/' folder prefix -- TODO confirm against the CSVs).
            img = mask_rel_path[6:]
            if not any(ext in img for ext in self._IMAGE_EXTENSIONS):
                continue
            img_path = os.path.join(self.root_path, 'imgs', img)
            mask_path = os.path.join(self.root_path, mask_rel_path)
            for prompt in prompts:
                self.img_names.append(img)
                self.img_path_list.append(img_path)
                self.label_path_list.append(mask_path)
                self.label_list.append(prompt)
                self.class_in_image.append(target_class)

    @staticmethod
    def _load_mask(mask_path):
        """Load a mask file as a float tensor, keeping only channel 0 of 3-D masks."""
        with Image.open(mask_path) as mask_img:
            mask = torch.Tensor(np.array(mask_img))
        if mask.dim() == 3:
            mask = mask[:, :, 0]
        return mask

    def __getitem__(self, index):
        """Return ``(img, label, img_path, label_of_interest)`` for sample ``index``.

        ``img`` is the RGB image after ``self.data_transform``; it is permuted to
        channel-first before the transform when ``config['data']['volume_channel'] == 2``.
        ``label`` is a {0, 1} mask with its leading (channel) dimension removed.
        ``label_of_interest`` is the sample's label name ('' in no-text mode).

        Errors from opening the image or mask propagate unchanged to the caller.
        """
        with Image.open(self.img_path_list[index]) as pil_img:
            img_array = np.array(pil_img.convert("RGB"))
        # Spatial size taken before any permute, so it is (H, W) for either layout.
        height, width = img_array.shape[:2]
        img = torch.as_tensor(img_array)
        if self.config['data']['volume_channel'] == 2:
            img = img.permute(2, 0, 1)

        label_of_interest = self.label_list[index]
        # NOTE(review): in no_text_mode with several label names, label_of_interest is ''
        # and never matches, so every mask is all-zero -- confirm this is intended.
        if self.num_classes > 1 and self.class_in_image[index] + ' Vehicle' != label_of_interest:
            label = torch.zeros(height, width)
        else:
            label = self._load_mask(self.label_path_list[index])

        label = label.unsqueeze(0)
        label = (label > 0) + 0

        img, label = self.data_transform(img, label, is_train=self.is_train, apply_norm=self.apply_norm)
        # convert all grayscale pixels due to resizing back to 0, 1
        label = (label >= 0.5) + 0
        label = label[0]

        return img, label, self.img_path_list[index], label_of_interest