File size: 1,588 Bytes
6709fc9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
# python3.7
"""Contains the class of dataset."""

import os
from PIL import Image
from .process_image import ImageProcessor
from torch.utils.data import Dataset

class InferenceDataset(Dataset):
    """Dataset that serves aligned, pre-processed images from one directory.

    Every regular file found directly inside `root_dir` is treated as an
    image. Each sample is opened with PIL, passed through
    `ImageProcessor.align_face()` and then
    `ImageProcessor.preprocess_image()`.
    """

    def __init__(self,
                 root_dir,
                 resolution=256,
                 aligner_path=None
                 ):
        """Initializes the dataset.

        Args:
            root_dir: Root directory containing the images.
            resolution: Stored on the instance as `self.resolution`. It is
                not used by this class itself. (default: 256)
            aligner_path: Forwarded unchanged to `ImageProcessor`.
                (default: None)
        """
        self.root_dir = root_dir
        self.resolution = resolution
        # Keep regular files only: `os.listdir` also returns sub-directories,
        # which `Image.open` cannot read. Sorting keeps the index -> file
        # mapping deterministic across runs.
        self.image_paths = sorted(
            entry for entry in os.listdir(self.root_dir)
            if os.path.isfile(os.path.join(self.root_dir, entry))
        )
        self.num_samples = len(self.image_paths)
        self.processor = ImageProcessor(aligner_path)

    def __len__(self):
        """Returns the number of samples in the dataset."""
        return self.num_samples

    def __getitem__(self, idx):
        """Loads and pre-processes the sample at position `idx`.

        Args:
            idx: Index into the sorted list of file names.

        Returns:
            A dict with two entries:
                'image': The output of `ImageProcessor.preprocess_image()`.
                'name': The file name of the sample, relative to `root_dir`.
        """
        image_name = self.image_paths[idx]
        full_path = os.path.join(self.root_dir, image_name)

        # `Image.open` is lazy and keeps the file handle open; the context
        # manager releases it once processing is done, so handles are not
        # leaked while iterating over the dataset.
        with Image.open(full_path) as raw_image:
            aligned = self.processor.align_face(raw_image)
            image = self.processor.preprocess_image(aligned)

        return {'image': image, 'name': image_name}