Spaces:
Runtime error
Runtime error
File size: 1,588 Bytes
6709fc9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 |
# python3.7
"""Contains the class of dataset."""
import os
from PIL import Image
from .process_image import ImageProcessor
from torch.utils.data import Dataset
class InferenceDataset(Dataset):
    """Dataset of aligned, preprocessed face images read from one directory.

    Every regular file located directly inside ``root_dir`` is one sample.
    Samples are ordered by file name.
    """

    def __init__(self,
                 root_dir,
                 resolution=256,
                 aligner_path=None
                 ):
        """Initializes the dataset.

        Args:
            root_dir: Directory whose regular files are the input images.
                Sub-directories are skipped.
            resolution: Stored as ``self.resolution``; it is not applied by
                ``__getitem__``. NOTE(review): presumably output sizing is
                handled inside ``ImageProcessor`` -- confirm.
            aligner_path: Passed unchanged to ``ImageProcessor``; presumably
                the path of a face-alignment model -- confirm.

        Raises:
            OSError: If ``root_dir`` cannot be listed (e.g. it does not exist).
        """
        super().__init__()
        self.root_dir = root_dir
        self.resolution = resolution
        self.image_paths = self._list_images(root_dir)
        self.num_samples = len(self.image_paths)
        self.processor = ImageProcessor(aligner_path)

    @staticmethod
    def _list_images(root_dir):
        """Returns the sorted names of the regular files directly in `root_dir`."""
        # Directories (e.g. `.ipynb_checkpoints`) can never be opened as images;
        # dropping them here avoids a crash later in `__getitem__`.
        return sorted(
            name for name in os.listdir(root_dir)
            if os.path.isfile(os.path.join(root_dir, name))
        )

    def __len__(self):
        """Returns the number of image files found at construction time."""
        return self.num_samples

    def __getitem__(self, idx):
        """Loads, face-aligns and preprocesses the `idx`-th image.

        Args:
            idx: Index into the sorted list of file names.

        Returns:
            A dict with ``'image'`` (the result of
            ``ImageProcessor.preprocess_image`` applied to the aligned image)
            and ``'name'`` (the file name relative to ``root_dir``).
        """
        image_path = self.image_paths[idx]
        # The context manager releases the file handle even if alignment or
        # preprocessing raises; all processing happens while the file is open.
        with Image.open(os.path.join(self.root_dir, image_path)) as image:
            aligned = self.processor.align_face(image)
            processed = self.processor.preprocess_image(aligned)
        return {'image': processed, 'name': image_path}
|