File size: 1,637 Bytes
b84549f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
from torch.utils.data import Dataset
import os
from torchvision.datasets.folder import default_loader
import torchvision.transforms as T
import torch
import numpy as np
from PIL import Image


class CommonDataset(Dataset):
    """Generic (image, label) dataset pairing image files with mask files by index.

    Each sample is an input image loaded as RGB and a label image loaded as
    single-channel grayscale, with caller-supplied transforms applied to each.
    """

    def __init__(self, images_path, labels_path, x_transform, y_transform):
        """
        Args:
            images_path: sequence of file paths to input images.
            labels_path: sequence of file paths to label/mask images, aligned
                index-for-index with ``images_path``.
            x_transform: callable applied to the loaded input image.
            y_transform: callable applied to the loaded grayscale label image.
        """
        # NOTE(review): assumes len(images_path) == len(labels_path) and that
        # element i of each list describes the same sample — verify at call site.
        self.imgs_path = images_path
        self.labels_path = labels_path
        self.x_transform = x_transform
        self.y_transform = y_transform

    def __len__(self):
        """Return the number of samples (one per image path)."""
        return len(self.imgs_path)

    def __getitem__(self, idx):
        """Load, convert, and transform the sample at ``idx``.

        Returns:
            Tuple ``(x, y)`` of the transformed image and transformed label.
        """
        # Input images go through torchvision's default loader (PIL, RGB).
        # Labels are forced to mode 'L' so per-pixel class ids stay intact
        # rather than being expanded to 3 channels.
        x = default_loader(self.imgs_path[idx])
        y = Image.open(self.labels_path[idx]).convert('L')

        return self.x_transform(x), self.y_transform(y)