File size: 1,204 Bytes
6e70c4a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
import torch
import torch.utils.data as data
import cv2
import numpy as np
from os.path import join


class HalftoneVOC2012(data.Dataset):
    """Paired dataset of colour input images and grayscale label images.

    Each item is an ``(input, label)`` tuple of float32 tensors whose pixel
    values are mapped from ``[0, 255]`` to ``[-1, 1]``:

    * input: 3-channel image in OpenCV's BGR channel order, shape ``(3, H, W)``.
    * label: single-channel grayscale image, shape ``(1, H, W)``.

    Args:
        data_list: mapping with keys ``'inputs'`` and ``'labels'``, each an
            iterable of image paths relative to ``root``. The i-th input is
            paired with the i-th label.
        root: directory prepended to every path. Defaults to ``'Data'``.

    Raises:
        ValueError: if ``'inputs'`` and ``'labels'`` have different lengths.
    """

    def __init__(self, data_list, root='Data'):
        super(HalftoneVOC2012, self).__init__()
        self.inputs = [join(root, x) for x in data_list['inputs']]
        self.labels = [join(root, x) for x in data_list['labels']]
        # Samples are paired purely by index, so a length mismatch would
        # silently drop labels or raise IndexError mid-epoch; fail fast.
        if len(self.inputs) != len(self.labels):
            raise ValueError(
                'inputs and labels must have the same length, got {} and {}'
                .format(len(self.inputs), len(self.labels)))

    @staticmethod
    def _imread(name, flags):
        """Read an image with OpenCV, raising instead of returning ``None``."""
        img = cv2.imread(name, flags=flags)
        # cv2.imread reports a missing/unreadable/unsupported file by
        # returning None rather than raising an exception.
        if img is None:
            raise IOError('cannot read image: {}'.format(name))
        return img

    @staticmethod
    def _to_tensor(img):
        """Convert an 8-bit array to a float32 tensor scaled to ``[-1, 1]``."""
        # 0 -> -1.0 and 255 -> 1.0
        return torch.from_numpy(img.astype(np.float32) / 127.5 - 1.0)

    @staticmethod
    def load_input(name):
        """Load a colour image as a ``(3, H, W)`` BGR tensor in ``[-1, 1]``.

        Raises:
            IOError: if OpenCV cannot read ``name``.
        """
        img = HalftoneVOC2012._imread(name, cv2.IMREAD_COLOR)
        # HWC (OpenCV layout) -> CHW (PyTorch layout)
        img = img.transpose((2, 0, 1))
        return HalftoneVOC2012._to_tensor(img)

    @staticmethod
    def load_label(name):
        """Load a grayscale image as a ``(1, H, W)`` tensor in ``[-1, 1]``.

        Raises:
            IOError: if OpenCV cannot read ``name``.
        """
        img = HalftoneVOC2012._imread(name, cv2.IMREAD_GRAYSCALE)
        # HW -> 1HW: add the single channel axis expected by PyTorch
        img = img[np.newaxis, :, :]
        return HalftoneVOC2012._to_tensor(img)

    def __getitem__(self, index):
        """Return the ``(input, label)`` tensor pair at ``index``."""
        input_data = self.load_input(self.inputs[index])
        label_data = self.load_label(self.labels[index])
        return input_data, label_data

    def __len__(self):
        """Return the number of input/label pairs."""
        return len(self.inputs)